diff --git a/src/cascadia/QueryExtension/AzureLLMProvider.cpp b/src/cascadia/QueryExtension/AzureLLMProvider.cpp index 5a6a2fd55c9..d4b8d643ca8 100644 --- a/src/cascadia/QueryExtension/AzureLLMProvider.cpp +++ b/src/cascadia/QueryExtension/AzureLLMProvider.cpp @@ -145,31 +145,49 @@ namespace winrt::Microsoft::Terminal::Query::Extension::implementation // Send the request try { - const auto response = _httpClient.SendRequestAsync(request).get(); - // Parse out the suggestion from the response - const auto string{ response.Content().ReadAsStringAsync().get() }; - const auto jsonResult{ WDJ::JsonObject::Parse(string) }; - if (jsonResult.HasKey(errorString)) - { - const auto errorObject = jsonResult.GetNamedObject(errorString); - message = errorObject.GetNamedString(messageString); - errorType = ErrorTypes::FromProvider; - } - else + const auto sendRequestOperation = _httpClient.SendRequestAsync(request); + + // if the caller cancels this operation, make sure to cancel the http request as well + auto cancellationToken{ co_await winrt::get_cancellation_token() }; + cancellationToken.callback([sendRequestOperation] { + sendRequestOperation.Cancel(); + }); + + if (sendRequestOperation.wait_for(std::chrono::seconds(5)) == AsyncStatus::Completed) { - if (_verifyModelIsValidHelper(jsonResult)) + // Parse out the suggestion from the response + const auto response = sendRequestOperation.GetResults(); + const auto string{ co_await response.Content().ReadAsStringAsync() }; + const auto jsonResult{ WDJ::JsonObject::Parse(string) }; + if (jsonResult.HasKey(errorString)) { - const auto choices = jsonResult.GetNamedArray(L"choices"); - const auto firstChoice = choices.GetAt(0).GetObject(); - const auto messageObject = firstChoice.GetNamedObject(messageString); - message = messageObject.GetNamedString(contentString); + const auto errorObject = jsonResult.GetNamedObject(errorString); + message = errorObject.GetNamedString(messageString); + errorType = ErrorTypes::FromProvider; } else { - message = RS_(L"InvalidModelMessage"); - errorType = ErrorTypes::InvalidModel; + if (_verifyModelIsValidHelper(jsonResult)) + { + const auto choices = jsonResult.GetNamedArray(L"choices"); + const auto firstChoice = choices.GetAt(0).GetObject(); + const auto messageObject = firstChoice.GetNamedObject(messageString); + message = messageObject.GetNamedString(contentString); + } + else + { + message = RS_(L"InvalidModelMessage"); + errorType = ErrorTypes::InvalidModel; + } } } + else + { + // if the http request takes too long, cancel the http request and return an error + sendRequestOperation.Cancel(); + message = RS_(L"UnknownErrorMessage"); + errorType = ErrorTypes::Unknown; + } } catch (...) { diff --git a/src/cascadia/QueryExtension/ExtensionPalette.cpp b/src/cascadia/QueryExtension/ExtensionPalette.cpp index be788e1ffd9..7b2da1bbc54 100644 --- a/src/cascadia/QueryExtension/ExtensionPalette.cpp +++ b/src/cascadia/QueryExtension/ExtensionPalette.cpp @@ -159,7 +159,16 @@ namespace winrt::Microsoft::Terminal::Query::Extension::implementation if (_lmProvider) { - result = _lmProvider.GetResponseAsync(promptCopy).get(); + const auto asyncOperation = _lmProvider.GetResponseAsync(promptCopy); + if (asyncOperation.wait_for(std::chrono::seconds(15)) == AsyncStatus::Completed) + { + result = asyncOperation.GetResults(); + } + else + { + asyncOperation.Cancel(); + result = winrt::make(RS_(L"UnknownErrorMessage"), ErrorTypes::Unknown, winrt::hstring{}); + } } else { diff --git a/src/cascadia/QueryExtension/GithubCopilotLLMProvider.cpp b/src/cascadia/QueryExtension/GithubCopilotLLMProvider.cpp index b72ceba2b2a..a90276d8d38 100644 --- a/src/cascadia/QueryExtension/GithubCopilotLLMProvider.cpp +++ b/src/cascadia/QueryExtension/GithubCopilotLLMProvider.cpp @@ -248,6 +248,7 @@ namespace winrt::Microsoft::Terminal::Query::Extension::implementation // Make sure we are on the background thread for the http request auto strongThis = get_strong(); co_await winrt::resume_background(); + auto cancellationToken{ co_await winrt::get_cancellation_token() }; for (bool refreshAttempted = false;;) { @@ -276,19 +277,37 @@ namespace winrt::Microsoft::Terminal::Query::Extension::implementation }; // Send the request - const auto jsonResult = co_await _SendRequestReturningJson(_endpointUri, requestContent, WWH::HttpMethod::Post()); - if (jsonResult.HasKey(errorKey)) + const auto sendRequestOperation = _SendRequestReturningJson(_endpointUri, requestContent, WWH::HttpMethod::Post()); + + // if the caller cancels this operation, make sure to cancel the http request as well + cancellationToken.callback([sendRequestOperation] { + sendRequestOperation.Cancel(); + }); + + if (sendRequestOperation.wait_for(std::chrono::seconds(5)) == AsyncStatus::Completed) { - const auto errorObject = jsonResult.GetNamedObject(errorKey); - message = errorObject.GetNamedString(messageKey); - errorType = ErrorTypes::FromProvider; + // Parse out the suggestion from the response + const auto jsonResult = sendRequestOperation.GetResults(); + if (jsonResult.HasKey(errorKey)) + { + const auto errorObject = jsonResult.GetNamedObject(errorKey); + message = errorObject.GetNamedString(messageKey); + errorType = ErrorTypes::FromProvider; + } + else + { + const auto choices = jsonResult.GetNamedArray(choicesKey); + const auto firstChoice = choices.GetAt(0).GetObject(); + const auto messageObject = firstChoice.GetNamedObject(messageKey); + message = messageObject.GetNamedString(contentKey); + } } else { - const auto choices = jsonResult.GetNamedArray(choicesKey); - const auto firstChoice = choices.GetAt(0).GetObject(); - const auto messageObject = firstChoice.GetNamedObject(messageKey); - message = messageObject.GetNamedString(contentKey); + // if the http request takes too long, cancel the http request and return an error + sendRequestOperation.Cancel(); + message = RS_(L"UnknownErrorMessage"); + errorType = ErrorTypes::Unknown; } break; } @@ -305,8 +324,23 @@ namespace winrt::Microsoft::Terminal::Query::Extension::implementation break; } - co_await _refreshAuthTokens(); - refreshAttempted = true; + const auto refreshTokensAction = _refreshAuthTokens(); + cancellationToken.callback([refreshTokensAction] { + refreshTokensAction.Cancel(); + }); + // allow up to 10 seconds for reauthentication + if (refreshTokensAction.wait_for(std::chrono::seconds(10)) == AsyncStatus::Completed) + { + refreshAttempted = true; + } + else + { + // if the refresh action takes too long, cancel it and return an error + refreshTokensAction.Cancel(); + message = RS_(L"UnknownErrorMessage"); + errorType = ErrorTypes::Unknown; + break; + } } // Also make a new entry in our jsonMessages list, so the AI knows the full conversation so far @@ -334,7 +368,12 @@ namespace winrt::Microsoft::Terminal::Query::Extension::implementation try { - const auto jsonResult = co_await _SendRequestReturningJson(accessTokenEndpoint, requestContent, WWH::HttpMethod::Post()); + const auto reAuthOperation = _SendRequestReturningJson(accessTokenEndpoint, requestContent, WWH::HttpMethod::Post()); + auto cancellationToken{ co_await winrt::get_cancellation_token() }; + cancellationToken.callback([reAuthOperation] { + reAuthOperation.Cancel(); + }); + const auto jsonResult{ co_await reAuthOperation }; _authToken = jsonResult.GetNamedString(accessTokenKey); _refreshToken = jsonResult.GetNamedString(refreshTokenKey); @@ -360,7 +399,12 @@ namespace winrt::Microsoft::Terminal::Query::Extension::implementation WWH::HttpRequestMessage request{ method, Uri{ uri } }; request.Content(content); - const auto response{ co_await _httpClient.SendRequestAsync(request) }; + const auto sendRequestOperation = _httpClient.SendRequestAsync(request); + auto cancellationToken{ co_await winrt::get_cancellation_token() }; + cancellationToken.callback([sendRequestOperation] { + sendRequestOperation.Cancel(); + }); + const auto response{ co_await sendRequestOperation }; const auto string{ co_await response.Content().ReadAsStringAsync() }; _lastResponse = string; const auto jsonResult{ WDJ::JsonObject::Parse(string) }; diff --git a/src/cascadia/QueryExtension/OpenAILLMProvider.cpp b/src/cascadia/QueryExtension/OpenAILLMProvider.cpp index a8184f72593..84eb18e0e71 100644 --- a/src/cascadia/QueryExtension/OpenAILLMProvider.cpp +++ b/src/cascadia/QueryExtension/OpenAILLMProvider.cpp @@ -100,22 +100,40 @@ namespace winrt::Microsoft::Terminal::Query::Extension::implementation // Send the request try { - const auto response = co_await _httpClient.SendRequestAsync(request); - // Parse out the suggestion from the response - const auto string{ co_await response.Content().ReadAsStringAsync() }; - const auto jsonResult{ WDJ::JsonObject::Parse(string) }; - if (jsonResult.HasKey(L"error")) + const auto sendRequestOperation = _httpClient.SendRequestAsync(request); + + // if the caller cancels this operation, make sure to cancel the http request as well + auto cancellationToken{ co_await winrt::get_cancellation_token() }; + cancellationToken.callback([sendRequestOperation] { + sendRequestOperation.Cancel(); + }); + + if (sendRequestOperation.wait_for(std::chrono::seconds(5)) == AsyncStatus::Completed) { - const auto errorObject = jsonResult.GetNamedObject(L"error"); - message = errorObject.GetNamedString(L"message"); - errorType = ErrorTypes::FromProvider; + // Parse out the suggestion from the response + const auto response = sendRequestOperation.GetResults(); + const auto string{ co_await response.Content().ReadAsStringAsync() }; + const auto jsonResult{ WDJ::JsonObject::Parse(string) }; + if (jsonResult.HasKey(L"error")) + { + const auto errorObject = jsonResult.GetNamedObject(L"error"); + message = errorObject.GetNamedString(L"message"); + errorType = ErrorTypes::FromProvider; + } + else + { + const auto choices = jsonResult.GetNamedArray(L"choices"); + const auto firstChoice = choices.GetAt(0).GetObject(); + const auto messageObject = firstChoice.GetNamedObject(L"message"); + message = messageObject.GetNamedString(L"content"); + } } else { - const auto choices = jsonResult.GetNamedArray(L"choices"); - const auto firstChoice = choices.GetAt(0).GetObject(); - const auto messageObject = firstChoice.GetNamedObject(L"message"); - message = messageObject.GetNamedString(L"content"); + // if the http request takes too long, cancel the http request and return an error + sendRequestOperation.Cancel(); + message = RS_(L"UnknownErrorMessage"); + errorType = ErrorTypes::Unknown; } } catch (...) diff --git a/src/cascadia/QueryExtension/Resources/en-US/Resources.resw b/src/cascadia/QueryExtension/Resources/en-US/Resources.resw index e93fcbcd159..323394e23eb 100644 --- a/src/cascadia/QueryExtension/Resources/en-US/Resources.resw +++ b/src/cascadia/QueryExtension/Resources/en-US/Resources.resw @@ -126,7 +126,7 @@ The message presented to the user when they attempt to use the AI chat feature without providing an AI endpoint and key. - An error occurred. Your AI provider might not be correctly configured, or the service might be temporarily unavailable. + An error occurred. The service might be temporarily unavailable or there might be network connectivity issues. The error message presented to the user when we were unable to query the provided endpoint. diff --git a/src/cascadia/QueryExtension/pch.h b/src/cascadia/QueryExtension/pch.h index c2745e48e79..ba7085ba859 100644 --- a/src/cascadia/QueryExtension/pch.h +++ b/src/cascadia/QueryExtension/pch.h @@ -53,6 +53,8 @@ TRACELOGGING_DECLARE_PROVIDER(g_hQueryExtensionProvider); #include +#include + // Manually include til after we include Windows.Foundation to give it winrt superpowers #include "til.h"