From ddb864d14a56876771f46bd4b3db8ccd5e7242cd Mon Sep 17 00:00:00 2001 From: Pankaj Bhojwani Date: Thu, 21 Nov 2024 16:17:34 -0800 Subject: [PATCH 1/7] add 5s timeout --- src/cascadia/QueryExtension/ExtensionPalette.cpp | 10 +++++++++- .../QueryExtension/Resources/en-US/Resources.resw | 2 +- src/cascadia/QueryExtension/pch.h | 2 ++ 3 files changed, 12 insertions(+), 2 deletions(-) diff --git a/src/cascadia/QueryExtension/ExtensionPalette.cpp b/src/cascadia/QueryExtension/ExtensionPalette.cpp index be788e1ffd9..52922ffb7b9 100644 --- a/src/cascadia/QueryExtension/ExtensionPalette.cpp +++ b/src/cascadia/QueryExtension/ExtensionPalette.cpp @@ -159,7 +159,15 @@ namespace winrt::Microsoft::Terminal::Query::Extension::implementation if (_lmProvider) { - result = _lmProvider.GetResponseAsync(promptCopy).get(); + const auto asyncOperation = _lmProvider.GetResponseAsync(promptCopy); + if (asyncOperation.wait_for(std::chrono::seconds(5)) == AsyncStatus::Completed) + { + result = asyncOperation.GetResults(); + } + else + { + result = winrt::make(RS_(L"UnknownErrorMessage"), ErrorTypes::Unknown, winrt::hstring{}); + } } else { diff --git a/src/cascadia/QueryExtension/Resources/en-US/Resources.resw b/src/cascadia/QueryExtension/Resources/en-US/Resources.resw index e93fcbcd159..323394e23eb 100644 --- a/src/cascadia/QueryExtension/Resources/en-US/Resources.resw +++ b/src/cascadia/QueryExtension/Resources/en-US/Resources.resw @@ -126,7 +126,7 @@ The message presented to the user when they attempt to use the AI chat feature without providing an AI endpoint and key. - An error occurred. Your AI provider might not be correctly configured, or the service might be temporarily unavailable. + An error occurred. The service might be temporarily unavailable or there might be network connectivity issues. The error message presented to the user when we were unable to query the provided endpoint. diff --git a/src/cascadia/QueryExtension/pch.h b/src/cascadia/QueryExtension/pch.h index c2745e48e79..ba7085ba859 100644 --- a/src/cascadia/QueryExtension/pch.h +++ b/src/cascadia/QueryExtension/pch.h @@ -53,6 +53,8 @@ TRACELOGGING_DECLARE_PROVIDER(g_hQueryExtensionProvider); #include +#include + // Manually include til after we include Windows.Foundation to give it winrt superpowers #include "til.h" From 48debd94636af35292f0056e0b7a860fb8017d90 Mon Sep 17 00:00:00 2001 From: Pankaj Bhojwani Date: Fri, 22 Nov 2024 11:18:15 -0800 Subject: [PATCH 2/7] cancel the http request too --- src/cascadia/QueryExtension/AzureLLMProvider.cpp | 12 +++++++++++- src/cascadia/QueryExtension/AzureLLMProvider.h | 1 + src/cascadia/QueryExtension/ExtensionPalette.cpp | 1 + .../QueryExtension/GithubCopilotLLMProvider.cpp | 12 +++++++++++- .../QueryExtension/GithubCopilotLLMProvider.h | 1 + src/cascadia/QueryExtension/OpenAILLMProvider.cpp | 12 +++++++++++- src/cascadia/QueryExtension/OpenAILLMProvider.h | 1 + 7 files changed, 37 insertions(+), 3 deletions(-) diff --git a/src/cascadia/QueryExtension/AzureLLMProvider.cpp b/src/cascadia/QueryExtension/AzureLLMProvider.cpp index 5a6a2fd55c9..bc0cabb5e53 100644 --- a/src/cascadia/QueryExtension/AzureLLMProvider.cpp +++ b/src/cascadia/QueryExtension/AzureLLMProvider.cpp @@ -79,6 +79,14 @@ namespace winrt::Microsoft::Terminal::Query::Extension::implementation winrt::Windows::Foundation::IAsyncOperation AzureLLMProvider::GetResponseAsync(const winrt::hstring& userPrompt) { + auto cancelation_token{ co_await winrt::get_cancellation_token() }; + cancelation_token.callback([=] { + if (_lastRequest) + { + _lastRequest.Cancel(); + } + }); + // Use the ErrorTypes enum to flag whether the response the user receives is an error message // we pass this enum back to the caller so they can handle it appropriately (specifically, ExtensionPalette will send the correct telemetry event) ErrorTypes errorType{ ErrorTypes::None }; @@ -145,7 +153,9 @@ namespace winrt::Microsoft::Terminal::Query::Extension::implementation // Send the request try { - const auto response = _httpClient.SendRequestAsync(request).get(); + const auto sendRequestOperation = _httpClient.SendRequestAsync(request); + const auto response{ co_await sendRequestOperation }; + _lastRequest = sendRequestOperation; // Parse out the suggestion from the response const auto string{ response.Content().ReadAsStringAsync().get() }; const auto jsonResult{ WDJ::JsonObject::Parse(string) }; diff --git a/src/cascadia/QueryExtension/AzureLLMProvider.h b/src/cascadia/QueryExtension/AzureLLMProvider.h index 1899bb93099..99139769f82 100644 --- a/src/cascadia/QueryExtension/AzureLLMProvider.h +++ b/src/cascadia/QueryExtension/AzureLLMProvider.h @@ -39,6 +39,7 @@ namespace winrt::Microsoft::Terminal::Query::Extension::implementation winrt::hstring _azureKey; winrt::Windows::Web::Http::HttpClient _httpClient{ nullptr }; IBrandingData _brandingData{ winrt::make() }; + winrt::Windows::Foundation::IAsyncOperationWithProgress _lastRequest{ nullptr }; Extension::IContext _context; diff --git a/src/cascadia/QueryExtension/ExtensionPalette.cpp b/src/cascadia/QueryExtension/ExtensionPalette.cpp index 52922ffb7b9..77b80e9b142 100644 --- a/src/cascadia/QueryExtension/ExtensionPalette.cpp +++ b/src/cascadia/QueryExtension/ExtensionPalette.cpp @@ -166,6 +166,7 @@ namespace winrt::Microsoft::Terminal::Query::Extension::implementation } else { + asyncOperation.Cancel(); result = winrt::make(RS_(L"UnknownErrorMessage"), ErrorTypes::Unknown, winrt::hstring{}); } } diff --git a/src/cascadia/QueryExtension/GithubCopilotLLMProvider.cpp b/src/cascadia/QueryExtension/GithubCopilotLLMProvider.cpp index b72ceba2b2a..bfaa4f29f6b 100644 --- a/src/cascadia/QueryExtension/GithubCopilotLLMProvider.cpp +++ b/src/cascadia/QueryExtension/GithubCopilotLLMProvider.cpp @@ -237,6 +237,14 @@ namespace winrt::Microsoft::Terminal::Query::Extension::implementation winrt::Windows::Foundation::IAsyncOperation GithubCopilotLLMProvider::GetResponseAsync(const winrt::hstring& userPrompt) { + auto cancelation_token{ co_await winrt::get_cancellation_token() }; + cancelation_token.callback([=] { + if (_lastRequest) + { + _lastRequest.Cancel(); + } + }); + // Use the ErrorTypes enum to flag whether the response the user receives is an error message // we pass this enum back to the caller so they can handle it appropriately (specifically, ExtensionPalette will send the correct telemetry event) ErrorTypes errorType{ ErrorTypes::None }; @@ -360,7 +368,9 @@ namespace winrt::Microsoft::Terminal::Query::Extension::implementation WWH::HttpRequestMessage request{ method, Uri{ uri } }; request.Content(content); - const auto response{ co_await _httpClient.SendRequestAsync(request) }; + const auto sendRequestOperation = _httpClient.SendRequestAsync(request); + const auto response{ co_await sendRequestOperation }; + _lastRequest = sendRequestOperation; const auto string{ co_await response.Content().ReadAsStringAsync() }; _lastResponse = string; const auto jsonResult{ WDJ::JsonObject::Parse(string) }; diff --git a/src/cascadia/QueryExtension/GithubCopilotLLMProvider.h b/src/cascadia/QueryExtension/GithubCopilotLLMProvider.h index 98f69cd6fcc..d711607c131 100644 --- a/src/cascadia/QueryExtension/GithubCopilotLLMProvider.h +++ b/src/cascadia/QueryExtension/GithubCopilotLLMProvider.h @@ -51,6 +51,7 @@ namespace winrt::Microsoft::Terminal::Query::Extension::implementation winrt::Windows::Web::Http::HttpClient _httpClient{ nullptr }; IBrandingData _brandingData{ winrt::make() }; winrt::hstring _lastResponse; + winrt::Windows::Foundation::IAsyncOperationWithProgress _lastRequest{ nullptr }; Extension::IContext _context; diff --git a/src/cascadia/QueryExtension/OpenAILLMProvider.cpp b/src/cascadia/QueryExtension/OpenAILLMProvider.cpp index a8184f72593..e7e25d26333 100644 --- a/src/cascadia/QueryExtension/OpenAILLMProvider.cpp +++ b/src/cascadia/QueryExtension/OpenAILLMProvider.cpp @@ -62,6 +62,14 @@ namespace winrt::Microsoft::Terminal::Query::Extension::implementation winrt::Windows::Foundation::IAsyncOperation OpenAILLMProvider::GetResponseAsync(const winrt::hstring userPrompt) { + auto cancelation_token{ co_await winrt::get_cancellation_token() }; + cancelation_token.callback([=] { + if (_lastRequest) + { + _lastRequest.Cancel(); + } + }); + // Use the ErrorTypes enum to flag whether the response the user receives is an error message // we pass this enum back to the caller so they can handle it appropriately (specifically, ExtensionPalette will send the correct telemetry event) ErrorTypes errorType{ ErrorTypes::None }; @@ -100,7 +108,9 @@ namespace winrt::Microsoft::Terminal::Query::Extension::implementation // Send the request try { - const auto response = co_await _httpClient.SendRequestAsync(request); + const auto sendRequestOperation = _httpClient.SendRequestAsync(request); + const auto response{ co_await sendRequestOperation }; + _lastRequest = sendRequestOperation; // Parse out the suggestion from the response const auto string{ co_await response.Content().ReadAsStringAsync() }; const auto jsonResult{ WDJ::JsonObject::Parse(string) }; diff --git a/src/cascadia/QueryExtension/OpenAILLMProvider.h b/src/cascadia/QueryExtension/OpenAILLMProvider.h index c1f489d310c..5f4f770e97b 100644 --- a/src/cascadia/QueryExtension/OpenAILLMProvider.h +++ b/src/cascadia/QueryExtension/OpenAILLMProvider.h @@ -38,6 +38,7 @@ namespace winrt::Microsoft::Terminal::Query::Extension::implementation winrt::hstring _AIKey; winrt::Windows::Web::Http::HttpClient _httpClient{ nullptr }; IBrandingData _brandingData{ winrt::make() }; + winrt::Windows::Foundation::IAsyncOperationWithProgress _lastRequest{ nullptr }; Extension::IContext _context; From 2fece1350cc166457f6834f01b74c64a963cb0cb Mon Sep 17 00:00:00 2001 From: Pankaj Bhojwani Date: Fri, 22 Nov 2024 12:00:29 -0800 Subject: [PATCH 3/7] camel --- src/cascadia/QueryExtension/AzureLLMProvider.cpp | 4 ++-- src/cascadia/QueryExtension/GithubCopilotLLMProvider.cpp | 4 ++-- src/cascadia/QueryExtension/OpenAILLMProvider.cpp | 4 ++-- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/cascadia/QueryExtension/AzureLLMProvider.cpp b/src/cascadia/QueryExtension/AzureLLMProvider.cpp index bc0cabb5e53..0708b0c2361 100644 --- a/src/cascadia/QueryExtension/AzureLLMProvider.cpp +++ b/src/cascadia/QueryExtension/AzureLLMProvider.cpp @@ -79,8 +79,8 @@ namespace winrt::Microsoft::Terminal::Query::Extension::implementation winrt::Windows::Foundation::IAsyncOperation AzureLLMProvider::GetResponseAsync(const winrt::hstring& userPrompt) { - auto cancelation_token{ co_await winrt::get_cancellation_token() }; - cancelation_token.callback([=] { + auto cancellationToken{ co_await winrt::get_cancellation_token() }; + cancellationToken.callback([=] { if (_lastRequest) { _lastRequest.Cancel(); diff --git a/src/cascadia/QueryExtension/GithubCopilotLLMProvider.cpp b/src/cascadia/QueryExtension/GithubCopilotLLMProvider.cpp index bfaa4f29f6b..ca11ba3bb38 100644 --- a/src/cascadia/QueryExtension/GithubCopilotLLMProvider.cpp +++ b/src/cascadia/QueryExtension/GithubCopilotLLMProvider.cpp @@ -237,8 +237,8 @@ namespace winrt::Microsoft::Terminal::Query::Extension::implementation winrt::Windows::Foundation::IAsyncOperation GithubCopilotLLMProvider::GetResponseAsync(const winrt::hstring& userPrompt) { - auto cancelation_token{ co_await winrt::get_cancellation_token() }; - cancelation_token.callback([=] { + auto cancellationToken{ co_await winrt::get_cancellation_token() }; + cancellationToken.callback([=] { if (_lastRequest) { _lastRequest.Cancel(); diff --git a/src/cascadia/QueryExtension/OpenAILLMProvider.cpp b/src/cascadia/QueryExtension/OpenAILLMProvider.cpp index e7e25d26333..d511c0c67d5 100644 --- a/src/cascadia/QueryExtension/OpenAILLMProvider.cpp +++ b/src/cascadia/QueryExtension/OpenAILLMProvider.cpp @@ -62,8 +62,8 @@ namespace winrt::Microsoft::Terminal::Query::Extension::implementation winrt::Windows::Foundation::IAsyncOperation OpenAILLMProvider::GetResponseAsync(const winrt::hstring userPrompt) { - auto cancelation_token{ co_await winrt::get_cancellation_token() }; - cancelation_token.callback([=] { + auto cancellationToken{ co_await winrt::get_cancellation_token() }; + cancellationToken.callback([=] { if (_lastRequest) { _lastRequest.Cancel(); From 7affcb6f76552c8d57545dd8a2d28c28d0157054 Mon Sep 17 00:00:00 2001 From: Pankaj Bhojwani Date: Mon, 25 Nov 2024 15:59:14 -0800 Subject: [PATCH 4/7] remove as member --- .../QueryExtension/AzureLLMProvider.cpp | 18 ++++++++---------- src/cascadia/QueryExtension/AzureLLMProvider.h | 1 - .../QueryExtension/OpenAILLMProvider.cpp | 16 +++++++--------- .../QueryExtension/OpenAILLMProvider.h | 1 - 4 files changed, 15 insertions(+), 21 deletions(-) diff --git a/src/cascadia/QueryExtension/AzureLLMProvider.cpp b/src/cascadia/QueryExtension/AzureLLMProvider.cpp index 0708b0c2361..e34886fc5e3 100644 --- a/src/cascadia/QueryExtension/AzureLLMProvider.cpp +++ b/src/cascadia/QueryExtension/AzureLLMProvider.cpp @@ -79,14 +79,6 @@ namespace winrt::Microsoft::Terminal::Query::Extension::implementation winrt::Windows::Foundation::IAsyncOperation AzureLLMProvider::GetResponseAsync(const winrt::hstring& userPrompt) { - auto cancellationToken{ co_await winrt::get_cancellation_token() }; - cancellationToken.callback([=] { - if (_lastRequest) - { - _lastRequest.Cancel(); - } - }); - // Use the ErrorTypes enum to flag whether the response the user receives is an error message // we pass this enum back to the caller so they can handle it appropriately (specifically, ExtensionPalette will send the correct telemetry event) ErrorTypes errorType{ ErrorTypes::None }; @@ -154,10 +146,16 @@ namespace winrt::Microsoft::Terminal::Query::Extension::implementation try { const auto sendRequestOperation = _httpClient.SendRequestAsync(request); + auto cancellationToken{ co_await winrt::get_cancellation_token() }; + cancellationToken.callback([sendRequestOperation] { + if (sendRequestOperation) + { + sendRequestOperation.Cancel(); + } + }); const auto response{ co_await sendRequestOperation }; - _lastRequest = sendRequestOperation; // Parse out the suggestion from the response - const auto string{ response.Content().ReadAsStringAsync().get() }; + const auto string{ co_await response.Content().ReadAsStringAsync() }; const auto jsonResult{ WDJ::JsonObject::Parse(string) }; if (jsonResult.HasKey(errorString)) { diff --git a/src/cascadia/QueryExtension/AzureLLMProvider.h b/src/cascadia/QueryExtension/AzureLLMProvider.h index 99139769f82..1899bb93099 100644 --- a/src/cascadia/QueryExtension/AzureLLMProvider.h +++ b/src/cascadia/QueryExtension/AzureLLMProvider.h @@ -39,7 +39,6 @@ namespace winrt::Microsoft::Terminal::Query::Extension::implementation winrt::hstring _azureKey; winrt::Windows::Web::Http::HttpClient _httpClient{ nullptr }; IBrandingData _brandingData{ winrt::make() }; - winrt::Windows::Foundation::IAsyncOperationWithProgress _lastRequest{ nullptr }; Extension::IContext _context; diff --git a/src/cascadia/QueryExtension/OpenAILLMProvider.cpp b/src/cascadia/QueryExtension/OpenAILLMProvider.cpp index d511c0c67d5..9f957cf46d3 100644 --- a/src/cascadia/QueryExtension/OpenAILLMProvider.cpp +++ b/src/cascadia/QueryExtension/OpenAILLMProvider.cpp @@ -62,14 +62,6 @@ namespace winrt::Microsoft::Terminal::Query::Extension::implementation winrt::Windows::Foundation::IAsyncOperation OpenAILLMProvider::GetResponseAsync(const winrt::hstring userPrompt) { - auto cancellationToken{ co_await winrt::get_cancellation_token() }; - cancellationToken.callback([=] { - if (_lastRequest) - { - _lastRequest.Cancel(); - } - }); - // Use the ErrorTypes enum to flag whether the response the user receives is an error message // we pass this enum back to the caller so they can handle it appropriately (specifically, ExtensionPalette will send the correct telemetry event) ErrorTypes errorType{ ErrorTypes::None }; @@ -109,8 +101,14 @@ namespace winrt::Microsoft::Terminal::Query::Extension::implementation try { const auto sendRequestOperation = _httpClient.SendRequestAsync(request); + auto cancellationToken{ co_await winrt::get_cancellation_token() }; + cancellationToken.callback([sendRequestOperation] { + if (sendRequestOperation) + { + sendRequestOperation.Cancel(); + } + }); const auto response{ co_await sendRequestOperation }; - _lastRequest = sendRequestOperation; // Parse out the suggestion from the response const auto string{ co_await response.Content().ReadAsStringAsync() }; const auto jsonResult{ WDJ::JsonObject::Parse(string) }; diff --git a/src/cascadia/QueryExtension/OpenAILLMProvider.h b/src/cascadia/QueryExtension/OpenAILLMProvider.h index 5f4f770e97b..c1f489d310c 100644 --- a/src/cascadia/QueryExtension/OpenAILLMProvider.h +++ b/src/cascadia/QueryExtension/OpenAILLMProvider.h @@ -38,7 +38,6 @@ namespace winrt::Microsoft::Terminal::Query::Extension::implementation winrt::hstring _AIKey; winrt::Windows::Web::Http::HttpClient _httpClient{ nullptr }; IBrandingData _brandingData{ winrt::make() }; - winrt::Windows::Foundation::IAsyncOperationWithProgress _lastRequest{ nullptr }; Extension::IContext _context; From f1019334e6e245ac2eb0e63fe17a98356ee7bdae Mon Sep 17 00:00:00 2001 From: Pankaj Bhojwani Date: Mon, 25 Nov 2024 16:21:59 -0800 Subject: [PATCH 5/7] same for capi --- .../QueryExtension/AzureLLMProvider.cpp | 5 +---- .../GithubCopilotLLMProvider.cpp | 19 ++++++++++--------- .../QueryExtension/OpenAILLMProvider.cpp | 5 +---- 3 files changed, 12 insertions(+), 17 deletions(-) diff --git a/src/cascadia/QueryExtension/AzureLLMProvider.cpp b/src/cascadia/QueryExtension/AzureLLMProvider.cpp index e34886fc5e3..e935a325593 100644 --- a/src/cascadia/QueryExtension/AzureLLMProvider.cpp +++ b/src/cascadia/QueryExtension/AzureLLMProvider.cpp @@ -148,10 +148,7 @@ namespace winrt::Microsoft::Terminal::Query::Extension::implementation const auto sendRequestOperation = _httpClient.SendRequestAsync(request); auto cancellationToken{ co_await winrt::get_cancellation_token() }; cancellationToken.callback([sendRequestOperation] { - if (sendRequestOperation) - { - sendRequestOperation.Cancel(); - } + sendRequestOperation.Cancel(); }); const auto response{ co_await sendRequestOperation }; // Parse out the suggestion from the response diff --git a/src/cascadia/QueryExtension/GithubCopilotLLMProvider.cpp b/src/cascadia/QueryExtension/GithubCopilotLLMProvider.cpp index ca11ba3bb38..ddd37009743 100644 --- a/src/cascadia/QueryExtension/GithubCopilotLLMProvider.cpp +++ b/src/cascadia/QueryExtension/GithubCopilotLLMProvider.cpp @@ -237,14 +237,6 @@ namespace winrt::Microsoft::Terminal::Query::Extension::implementation winrt::Windows::Foundation::IAsyncOperation GithubCopilotLLMProvider::GetResponseAsync(const winrt::hstring& userPrompt) { - auto cancellationToken{ co_await winrt::get_cancellation_token() }; - cancellationToken.callback([=] { - if (_lastRequest) - { - _lastRequest.Cancel(); - } - }); - // Use the ErrorTypes enum to flag whether the response the user receives is an error message // we pass this enum back to the caller so they can handle it appropriately (specifically, ExtensionPalette will send the correct telemetry event) ErrorTypes errorType{ ErrorTypes::None }; @@ -284,7 +276,12 @@ namespace winrt::Microsoft::Terminal::Query::Extension::implementation }; // Send the request - const auto jsonResult = co_await _SendRequestReturningJson(_endpointUri, requestContent, WWH::HttpMethod::Post()); + const auto jsonResultOperation = _SendRequestReturningJson(_endpointUri, requestContent, WWH::HttpMethod::Post()); + auto cancellationToken{ co_await winrt::get_cancellation_token() }; + cancellationToken.callback([jsonResultOperation] { + jsonResultOperation.Cancel(); + }); + const auto jsonResult = co_await jsonResultOperation; if (jsonResult.HasKey(errorKey)) { const auto errorObject = jsonResult.GetNamedObject(errorKey); @@ -369,6 +366,10 @@ namespace winrt::Microsoft::Terminal::Query::Extension::implementation request.Content(content); const auto sendRequestOperation = _httpClient.SendRequestAsync(request); + auto cancellationToken{ co_await winrt::get_cancellation_token() }; + cancellationToken.callback([sendRequestOperation] { + sendRequestOperation.Cancel(); + }); const auto response{ co_await sendRequestOperation }; _lastRequest = sendRequestOperation; const auto string{ co_await response.Content().ReadAsStringAsync() }; diff --git a/src/cascadia/QueryExtension/OpenAILLMProvider.cpp b/src/cascadia/QueryExtension/OpenAILLMProvider.cpp index 9f957cf46d3..da8fe7a5220 100644 --- a/src/cascadia/QueryExtension/OpenAILLMProvider.cpp +++ b/src/cascadia/QueryExtension/OpenAILLMProvider.cpp @@ -103,10 +103,7 @@ namespace winrt::Microsoft::Terminal::Query::Extension::implementation const auto sendRequestOperation = _httpClient.SendRequestAsync(request); auto cancellationToken{ co_await winrt::get_cancellation_token() }; cancellationToken.callback([sendRequestOperation] { - if (sendRequestOperation) - { - sendRequestOperation.Cancel(); - } + sendRequestOperation.Cancel(); }); const auto response{ co_await sendRequestOperation }; // Parse out the suggestion from the response From 733e12362b3db44af4143193e26bc4559f9db53a Mon Sep 17 00:00:00 2001 From: Pankaj Bhojwani Date: Tue, 26 Nov 2024 16:17:49 -0800 Subject: [PATCH 6/7] lmproviders timeout as well --- .../QueryExtension/AzureLLMProvider.cpp | 49 ++++++++----- .../QueryExtension/ExtensionPalette.cpp | 2 +- .../GithubCopilotLLMProvider.cpp | 68 ++++++++++++++----- .../QueryExtension/GithubCopilotLLMProvider.h | 1 - .../QueryExtension/OpenAILLMProvider.cpp | 37 ++++++---- 5 files changed, 108 insertions(+), 49 deletions(-) diff --git a/src/cascadia/QueryExtension/AzureLLMProvider.cpp b/src/cascadia/QueryExtension/AzureLLMProvider.cpp index e935a325593..d4b8d643ca8 100644 --- a/src/cascadia/QueryExtension/AzureLLMProvider.cpp +++ b/src/cascadia/QueryExtension/AzureLLMProvider.cpp @@ -146,35 +146,48 @@ namespace winrt::Microsoft::Terminal::Query::Extension::implementation try { const auto sendRequestOperation = _httpClient.SendRequestAsync(request); + + // if the caller cancels this operation, make sure to cancel the http request as well auto cancellationToken{ co_await winrt::get_cancellation_token() }; cancellationToken.callback([sendRequestOperation] { sendRequestOperation.Cancel(); }); - const auto response{ co_await sendRequestOperation }; - // Parse out the suggestion from the response - const auto string{ co_await response.Content().ReadAsStringAsync() }; - const auto jsonResult{ WDJ::JsonObject::Parse(string) }; - if (jsonResult.HasKey(errorString)) - { - const auto errorObject = jsonResult.GetNamedObject(errorString); - message = errorObject.GetNamedString(messageString); - errorType = ErrorTypes::FromProvider; - } - else + + if (sendRequestOperation.wait_for(std::chrono::seconds(5)) == AsyncStatus::Completed) { - if (_verifyModelIsValidHelper(jsonResult)) + // Parse out the suggestion from the response + const auto response = sendRequestOperation.GetResults(); + const auto string{ co_await response.Content().ReadAsStringAsync() }; + const auto jsonResult{ WDJ::JsonObject::Parse(string) }; + if (jsonResult.HasKey(errorString)) { - const auto choices = jsonResult.GetNamedArray(L"choices"); - const auto firstChoice = choices.GetAt(0).GetObject(); - const auto messageObject = firstChoice.GetNamedObject(messageString); - message = messageObject.GetNamedString(contentString); + const auto errorObject = jsonResult.GetNamedObject(errorString); + message = errorObject.GetNamedString(messageString); + errorType = ErrorTypes::FromProvider; } else { - message = RS_(L"InvalidModelMessage"); - errorType = ErrorTypes::InvalidModel; + if (_verifyModelIsValidHelper(jsonResult)) + { + const auto choices = jsonResult.GetNamedArray(L"choices"); + const auto firstChoice = choices.GetAt(0).GetObject(); + const auto messageObject = firstChoice.GetNamedObject(messageString); + message = messageObject.GetNamedString(contentString); + } + else + { + message = RS_(L"InvalidModelMessage"); + errorType = ErrorTypes::InvalidModel; + } } } + else + { + // if the http request takes too long, cancel the http request and return an error + sendRequestOperation.Cancel(); + message = RS_(L"UnknownErrorMessage"); + errorType = ErrorTypes::Unknown; + } } catch (...) { diff --git a/src/cascadia/QueryExtension/ExtensionPalette.cpp b/src/cascadia/QueryExtension/ExtensionPalette.cpp index 77b80e9b142..7b2da1bbc54 100644 --- a/src/cascadia/QueryExtension/ExtensionPalette.cpp +++ b/src/cascadia/QueryExtension/ExtensionPalette.cpp @@ -160,7 +160,7 @@ namespace winrt::Microsoft::Terminal::Query::Extension::implementation if (_lmProvider) { const auto asyncOperation = _lmProvider.GetResponseAsync(promptCopy); - if (asyncOperation.wait_for(std::chrono::seconds(5)) == AsyncStatus::Completed) + if (asyncOperation.wait_for(std::chrono::seconds(15)) == AsyncStatus::Completed) { result = asyncOperation.GetResults(); } diff --git a/src/cascadia/QueryExtension/GithubCopilotLLMProvider.cpp b/src/cascadia/QueryExtension/GithubCopilotLLMProvider.cpp index ddd37009743..a7e6bfc3abd 100644 --- a/src/cascadia/QueryExtension/GithubCopilotLLMProvider.cpp +++ b/src/cascadia/QueryExtension/GithubCopilotLLMProvider.cpp @@ -247,7 +247,9 @@ namespace winrt::Microsoft::Terminal::Query::Extension::implementation // Make sure we are on the background thread for the http request auto strongThis = get_strong(); + co_await winrt::resume_background(); + auto cancellationToken{ co_await winrt::get_cancellation_token() }; for (bool refreshAttempted = false;;) { @@ -276,24 +278,37 @@ namespace winrt::Microsoft::Terminal::Query::Extension::implementation }; // Send the request - const auto jsonResultOperation = _SendRequestReturningJson(_endpointUri, requestContent, WWH::HttpMethod::Post()); - auto cancellationToken{ co_await winrt::get_cancellation_token() }; - cancellationToken.callback([jsonResultOperation] { - jsonResultOperation.Cancel(); + const auto sendRequestOperation = _SendRequestReturningJson(_endpointUri, requestContent, WWH::HttpMethod::Post()); + + // if the caller cancels this operation, make sure to cancel the http request as well + cancellationToken.callback([sendRequestOperation] { + sendRequestOperation.Cancel(); }); - const auto jsonResult = co_await jsonResultOperation; - if (jsonResult.HasKey(errorKey)) + + if (sendRequestOperation.wait_for(std::chrono::seconds(5)) == AsyncStatus::Completed) { - const auto errorObject = jsonResult.GetNamedObject(errorKey); - message = errorObject.GetNamedString(messageKey); - errorType = ErrorTypes::FromProvider; + // Parse out the suggestion from the response + const auto jsonResult = sendRequestOperation.GetResults(); + if (jsonResult.HasKey(errorKey)) + { + const auto errorObject = jsonResult.GetNamedObject(errorKey); + message = errorObject.GetNamedString(messageKey); + errorType = ErrorTypes::FromProvider; + } + else + { + const auto choices = jsonResult.GetNamedArray(L"ayy"); + const auto firstChoice = choices.GetAt(0).GetObject(); + const auto messageObject = firstChoice.GetNamedObject(messageKey); + message = messageObject.GetNamedString(contentKey); + } } else { - const auto choices = jsonResult.GetNamedArray(choicesKey); - const auto firstChoice = choices.GetAt(0).GetObject(); - const auto messageObject = firstChoice.GetNamedObject(messageKey); - message = messageObject.GetNamedString(contentKey); + // if the http request takes too long, cancel the http request and return an error + sendRequestOperation.Cancel(); + message = RS_(L"UnknownErrorMessage"); + errorType = ErrorTypes::Unknown; } break; } @@ -310,8 +325,23 @@ namespace winrt::Microsoft::Terminal::Query::Extension::implementation break; } - co_await _refreshAuthTokens(); - refreshAttempted = true; + const auto refreshTokensAction = _refreshAuthTokens(); + cancellationToken.callback([refreshTokensAction] { + refreshTokensAction.Cancel(); + }); + // allow up to 10 seconds for reauthentication + if (refreshTokensAction.wait_for(std::chrono::seconds(10)) == AsyncStatus::Completed) + { + refreshAttempted = true; + } + else + { + // if the refresh action takes too long, cancel it and return an error + refreshTokensAction.Cancel(); + message = RS_(L"UnknownErrorMessage"); + errorType = ErrorTypes::Unknown; + break; + } } // Also make a new entry in our jsonMessages list, so the AI knows the full conversation so far @@ -339,7 +369,12 @@ namespace winrt::Microsoft::Terminal::Query::Extension::implementation try { - const auto jsonResult = co_await _SendRequestReturningJson(accessTokenEndpoint, requestContent, WWH::HttpMethod::Post()); + const auto reAuthOperation = _SendRequestReturningJson(accessTokenEndpoint, requestContent, WWH::HttpMethod::Post()); + auto cancellationToken{ co_await winrt::get_cancellation_token() }; + cancellationToken.callback([reAuthOperation] { + reAuthOperation.Cancel(); + }); + const auto jsonResult{ co_await reAuthOperation }; _authToken = jsonResult.GetNamedString(accessTokenKey); _refreshToken = jsonResult.GetNamedString(refreshTokenKey); @@ -371,7 +406,6 @@ namespace winrt::Microsoft::Terminal::Query::Extension::implementation sendRequestOperation.Cancel(); }); const auto response{ co_await sendRequestOperation }; - _lastRequest = sendRequestOperation; const auto string{ co_await response.Content().ReadAsStringAsync() }; _lastResponse = string; const auto jsonResult{ WDJ::JsonObject::Parse(string) }; diff --git a/src/cascadia/QueryExtension/GithubCopilotLLMProvider.h b/src/cascadia/QueryExtension/GithubCopilotLLMProvider.h index d711607c131..98f69cd6fcc 100644 --- a/src/cascadia/QueryExtension/GithubCopilotLLMProvider.h +++ b/src/cascadia/QueryExtension/GithubCopilotLLMProvider.h @@ -51,7 +51,6 @@ namespace winrt::Microsoft::Terminal::Query::Extension::implementation winrt::Windows::Web::Http::HttpClient _httpClient{ nullptr }; IBrandingData _brandingData{ winrt::make() }; winrt::hstring _lastResponse; - winrt::Windows::Foundation::IAsyncOperationWithProgress _lastRequest{ nullptr }; Extension::IContext _context; diff --git a/src/cascadia/QueryExtension/OpenAILLMProvider.cpp b/src/cascadia/QueryExtension/OpenAILLMProvider.cpp index da8fe7a5220..84eb18e0e71 100644 --- a/src/cascadia/QueryExtension/OpenAILLMProvider.cpp +++ b/src/cascadia/QueryExtension/OpenAILLMProvider.cpp @@ -101,26 +101,39 @@ namespace winrt::Microsoft::Terminal::Query::Extension::implementation try { const auto sendRequestOperation = _httpClient.SendRequestAsync(request); + + // if the caller cancels this operation, make sure to cancel the http request as well auto cancellationToken{ co_await winrt::get_cancellation_token() }; cancellationToken.callback([sendRequestOperation] { sendRequestOperation.Cancel(); }); - const auto response{ co_await sendRequestOperation }; - // Parse out the suggestion from the response - const auto string{ co_await response.Content().ReadAsStringAsync() }; - const auto jsonResult{ WDJ::JsonObject::Parse(string) }; - if (jsonResult.HasKey(L"error")) + + if (sendRequestOperation.wait_for(std::chrono::seconds(5)) == AsyncStatus::Completed) { - const auto errorObject = jsonResult.GetNamedObject(L"error"); - message = errorObject.GetNamedString(L"message"); - errorType = ErrorTypes::FromProvider; + // Parse out the suggestion from the response + const auto response = sendRequestOperation.GetResults(); + const auto string{ co_await response.Content().ReadAsStringAsync() }; + const auto jsonResult{ WDJ::JsonObject::Parse(string) }; + if (jsonResult.HasKey(L"error")) + { + const auto errorObject = jsonResult.GetNamedObject(L"error"); + message = errorObject.GetNamedString(L"message"); + errorType = ErrorTypes::FromProvider; + } + else + { + const auto choices = jsonResult.GetNamedArray(L"choices"); + const auto firstChoice = choices.GetAt(0).GetObject(); + const auto messageObject = firstChoice.GetNamedObject(L"message"); + message = messageObject.GetNamedString(L"content"); + } } else { - const auto choices = jsonResult.GetNamedArray(L"choices"); - const auto firstChoice = choices.GetAt(0).GetObject(); - const auto messageObject = firstChoice.GetNamedObject(L"message"); - message = messageObject.GetNamedString(L"content"); + // if the http request takes too long, cancel the http request and return an error + sendRequestOperation.Cancel(); + message = RS_(L"UnknownErrorMessage"); + errorType = ErrorTypes::Unknown; } } catch (...) From 2a90fb94f2949b4a7fe35c38a39f27a7bc3a4b03 Mon Sep 17 00:00:00 2001 From: Pankaj Bhojwani Date: Tue, 26 Nov 2024 16:23:29 -0800 Subject: [PATCH 7/7] things --- src/cascadia/QueryExtension/GithubCopilotLLMProvider.cpp | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/cascadia/QueryExtension/GithubCopilotLLMProvider.cpp b/src/cascadia/QueryExtension/GithubCopilotLLMProvider.cpp index a7e6bfc3abd..a90276d8d38 100644 --- a/src/cascadia/QueryExtension/GithubCopilotLLMProvider.cpp +++ b/src/cascadia/QueryExtension/GithubCopilotLLMProvider.cpp @@ -247,7 +247,6 @@ namespace winrt::Microsoft::Terminal::Query::Extension::implementation // Make sure we are on the background thread for the http request auto strongThis = get_strong(); - co_await winrt::resume_background(); auto cancellationToken{ co_await winrt::get_cancellation_token() }; @@ -297,7 +296,7 @@ namespace winrt::Microsoft::Terminal::Query::Extension::implementation } else { - const auto choices = jsonResult.GetNamedArray(L"ayy"); + const auto choices = jsonResult.GetNamedArray(choicesKey); const auto firstChoice = choices.GetAt(0).GetObject(); const auto messageObject = firstChoice.GetNamedObject(messageKey); message = messageObject.GetNamedString(contentKey);