Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[BYOM] Handle Invalid Response Format #27345

Draft
wants to merge 3 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions components/ai_chat/core/browser/constants.cc
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,8 @@
#include "brave/components/ai_chat/core/browser/constants.h"

#include <array>
#include <functional>
#include <string>
#include <utility>

#include "base/containers/flat_tree.h"
#include "base/strings/string_util.h"
#include "components/grit/brave_components_strings.h"
#include "mojo/public/cpp/bindings/struct_ptr.h"
Expand Down Expand Up @@ -40,6 +37,8 @@ base::span<const webui::LocalizedString> GetLocalizedStrings() {
{"errorInvalidAPIKey", IDS_CHAT_UI_ERROR_INVALID_API_KEY},
{"errorOAIRateLimit", IDS_CHAT_UI_ERROR_OAI_RATE_LIMIT},
{"errorServiceOverloaded", IDS_CHAT_UI_ERROR_SERVICE_OVERLOADED},
{"errorInvalidResponseFormat",
IDS_CHAT_UI_ERROR_INVALID_RESPONSE_FORMAT},
{"retryButtonLabel", IDS_CHAT_UI_RETRY_BUTTON_LABEL},
{"introMessage-chat-basic", IDS_CHAT_UI_INTRO_MESSAGE_CHAT_BASIC},
{"introMessage-chat-leo-expanded",
Expand Down
5 changes: 5 additions & 0 deletions components/ai_chat/core/browser/conversation_handler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1529,6 +1529,11 @@ void ConversationHandler::OnSuggestedQuestionsResponse(
mojom::SuggestionGenerationStatus::HasGenerated;
DVLOG(2) << "Got questions:" << base::JoinString(result.value(), "\n");
} else {
// handle failure
if (result.error() != mojom::APIError::None) {
DVLOG(2) << __func__ << ": With error";
SetAPIError(std::move(result.error()));
}
// TODO(nullhook): Set a specialized error state generated questions
suggestion_generation_status_ =
mojom::SuggestionGenerationStatus::CanGenerate;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,7 @@ void EngineConsumerOAIRemote::OnGenerateQuestionSuggestionsResponse(
GenerationResult result) {
if (!result.has_value() || result->empty()) {
// Query resulted in error
DVLOG(2) << "Failed to generate question suggestions.";
std::move(callback).Run(base::unexpected(std::move(result.error())));
return;
}
Expand Down
46 changes: 22 additions & 24 deletions components/ai_chat/core/browser/engine/oai_api_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,8 @@

#include "brave/components/ai_chat/core/browser/engine/oai_api_client.h"

#include <ios>
#include <optional>
#include <ostream>
#include <string>
#include <string_view>
#include <type_traits>

#include "base/containers/flat_map.h"
#include "base/functional/bind.h"
Expand All @@ -19,7 +15,6 @@
#include "base/memory/scoped_refptr.h"
#include "base/memory/weak_ptr.h"
#include "base/metrics/field_trial_params.h"
#include "base/strings/strcat.h"
#include "base/types/expected.h"
#include "brave/components/ai_chat/core/common/features.h"
#include "brave/components/ai_chat/core/common/mojom/ai_chat.mojom.h"
Expand Down Expand Up @@ -134,27 +129,30 @@ void OAIAPIClient::OnQueryCompleted(GenerationCompletedCallback callback,
const bool success = result.Is2XXResponseCode();
// Handle successful request
if (success) {
std::string completion = "";
// We're checking for a value body in case for non-streaming API results.
if (result.value_body().is_dict()) {
const base::Value::List* choices =
result.value_body().GetDict().FindList("choices");
if (!choices) {
DVLOG(2) << "No choices list found in response.";
return;
}
if (choices->front().is_dict()) {
const base::Value::Dict* message =
choices->front().GetDict().FindDict("message");
if (!message) {
DVLOG(2) << "No message dict found in response.";
return;
}
completion = *message->FindString("content");
}
// We expect to find a list of completion choices in the response.
const base::Value::List* choices =
result.value_body().is_dict()
? result.value_body().GetDict().FindList("choices")
: nullptr;

// Retrieve the first completion choice's content, if any.
const auto* message =
(choices && !choices->empty() && choices->front().is_dict())
? choices->front().GetDict().FindStringByDottedPath(
"message.content")
: nullptr;

// Return the completion message.
if (message) {
std::move(callback).Run(base::ok(std::move(*message)));
return;
}

std::move(callback).Run(base::ok(std::move(completion)));
// If no completion was found, log and return an error.
// This situation may occur when the response format is unexpected.
DVLOG(2) << "No completion was found in response.";
std::move(callback).Run(
base::unexpected(mojom::APIError::InvalidResponseFormat));
return;
}

Expand Down
64 changes: 62 additions & 2 deletions components/ai_chat/core/browser/engine/oai_api_client_unittest.cc
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,9 @@

#include "brave/components/ai_chat/core/browser/engine/oai_api_client.h"

#include <list>
#include <optional>
#include <string>
#include <string_view>
#include <type_traits>
#include <vector>

#include "base/containers/flat_map.h"
Expand All @@ -23,6 +21,7 @@
#include "brave/components/ai_chat/core/common/mojom/ai_chat.mojom-forward.h"
#include "brave/components/ai_chat/core/common/mojom/ai_chat.mojom.h"
#include "brave/components/api_request_helper/api_request_helper.h"
#include "gmock/gmock.h"
#include "mojo/public/cpp/bindings/struct_ptr.h"
#include "net/base/net_errors.h"
#include "net/http/http_request_headers.h"
Expand Down Expand Up @@ -195,4 +194,65 @@ TEST_F(OAIAPIUnitTest, PerformRequest) {
testing::Mock::VerifyAndClearExpectations(mock_request_helper);
}

TEST_F(OAIAPIUnitTest, HandleInavlidResponseFormat) {
mojom::CustomModelOptionsPtr model_options = mojom::CustomModelOptions::New(
"test_api_key", 0, 0, 0, "test_system_prompt", GURL("https://test.com"),
"test_model");

// Simulate an invalid response format from the server
std::string invalid_server_response =
R"({"choices":[{"message":{"not_content":"value"}}]})";

MockAPIRequestHelper* mock_request_helper =
client_->GetMockAPIRequestHelper();
testing::StrictMock<MockCallbacks> mock_callbacks;
base::RunLoop run_loop;

// Intercept API Request Helper call and verify the request is as expected
EXPECT_CALL(*mock_request_helper, RequestSSE(_, _, _, _, _, _, _, _))
.WillOnce([&](const std::string& method, const GURL& url,
const std::string& body, const std::string& content_type,
DataReceivedCallback data_received_callback,
ResultCallback result_callback,
const base::flat_map<std::string, std::string>& headers,
const api_request_helper::APIRequestOptions& options) {
auto invalid_response = base::JSONReader::Read(invalid_server_response);
EXPECT_TRUE(invalid_response.has_value());

std::move(result_callback)
.Run(api_request_helper::APIRequestResult(
200, std::move(invalid_response.value()), {}, net::OK, GURL()));

run_loop.Quit();
return Ticket();
});

EXPECT_CALL(mock_callbacks, OnCompleted(_))
.WillOnce([&](GenerationResult result) {
EXPECT_FALSE(result.has_value());
EXPECT_EQ(result.error(), mojom::APIError::InvalidResponseFormat);
});

// Begin request
base::Value::List messages;

{
base::Value::Dict message;
message.Set("role", "user");
message.Set("content", "Hello, World.");
messages.Append(std::move(message));
}

client_->PerformRequest(
*model_options, std::move(messages),
base::BindRepeating(&MockCallbacks::OnDataReceived,
base::Unretained(&mock_callbacks)),
base::BindOnce(&MockCallbacks::OnCompleted,
base::Unretained(&mock_callbacks)));

run_loop.Run();
testing::Mock::VerifyAndClearExpectations(client_.get());
testing::Mock::VerifyAndClearExpectations(mock_request_helper);
}

} // namespace ai_chat
3 changes: 2 additions & 1 deletion components/ai_chat/core/common/mojom/ai_chat.mojom
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@ enum APIError {
ContextLimitReached,
InvalidEndpointURL,
InvalidAPIKey,
ServiceOverloaded
ServiceOverloaded,
InvalidResponseFormat,
};

enum ModelEngineType {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
/* Copyright (c) 2025 The Brave Authors. All rights reserved.
* This Source Code Form is subject to the terms of the Mozilla Public
* License, v. 2.0. If a copy of the MPL was not distributed with this file,
* You can obtain one at https://mozilla.org/MPL/2.0/. */

import * as React from 'react'
import Alert from '@brave/leo/react/alert'
import { getLocale } from '$web-common/locale'
import styles from './alerts.module.scss'

export default function ErrorInvalidResponseFormat() {
return (
<div className={styles.alert}>
<Alert type='error'>
{getLocale('errorInvalidResponseFormat')}
</Alert>
</div>
)
}
4 changes: 4 additions & 0 deletions components/ai_chat/resources/page/components/main/index.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import ErrorInvalidAPIKey from '../alerts/error_invalid_api_key'
import ErrorInvalidEndpointURL from '../alerts/error_invalid_endpoint_url'
import ErrorRateLimit from '../alerts/error_rate_limit'
import ErrorServiceOverloaded from '../alerts/error_service_overloaded'
import ErrorInvalidResponseFormat from '../alerts/error_invalid_response_format'
import LongConversationInfo from '../alerts/long_conversation_info'
import NoticeConversationStorage from '../notices/notice_conversation_storage'
import WarningPremiumDisconnected from '../alerts/warning_premium_disconnected'
Expand Down Expand Up @@ -104,6 +105,9 @@ function Main() {
currentErrorElement = <ErrorServiceOverloaded
onRetry={conversationContext.retryAPIRequest} />
break
case Mojom.APIError.InvalidResponseFormat:
currentErrorElement = <ErrorInvalidResponseFormat />
break
case Mojom.APIError.RateLimitReached:
currentErrorElement = <ErrorRateLimit />
break
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ import ErrorInvalidAPIKey from '../components/alerts/error_invalid_api_key'
import ErrorInvalidEndpointURL from '../components/alerts/error_invalid_endpoint_url'
import ErrorRateLimit from '../components/alerts/error_rate_limit'
import ErrorServiceOverloaded from '../components/alerts/error_service_overloaded'
import ErrorInvalidResponseFormat from '../components/alerts/error_invalid_response_format'
import LongConversationInfo from '../components/alerts/long_conversation_info'
import WarningPremiumDisconnected from '../components/alerts/warning_premium_disconnected'

Expand Down Expand Up @@ -665,6 +666,7 @@ export const _Alerts = {
<ErrorRateLimit />
<ErrorRateLimit _testIsCurrentModelLeo={false} />
<ErrorServiceOverloaded />
<ErrorInvalidResponseFormat />
<LongConversationInfo />
<WarningPremiumDisconnected />
</div>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ provideStrings({
errorInvalidAPIKey: 'The API key configured for this model is invalid. Please check your configuration and try again.',
errorRateLimit: 'You\'ve reached the premium rate limit. Please try again in a few hours.',
errorServiceOverloaded: 'The endpoint is currently overloaded. Please try again later.',
errorInvalidResponseFormat: 'The response from the model is in an invalid format. Please check your configuration and try again.',
braveLeoAssistantEndpointInvalidError: 'The endpoint URL is invalid. Please check the URL and try again.',
braveLeoAssistantEndpointValidAsPrivateIp: 'If you would like to use a private IP address, you must first enable "Private IP Addresses for Custom Model Enpoints" via brave://flags/#brave-ai-chat-allow-private-ips',
retryButtonLabel: 'Retry',
Expand Down
3 changes: 3 additions & 0 deletions components/resources/ai_chat_ui_strings.grdp
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,9 @@
<message name="IDS_CHAT_UI_ERROR_SERVICE_OVERLOADED" desc="An error presented when the API service is overloaded">
The endpoint is currently overloaded. Please try again later.
</message>
<message name="IDS_CHAT_UI_ERROR_INVALID_RESPONSE_FORMAT" desc="An error presented when the API response is in an unexpected format">
The response from the model is in an invalid format. Please check your configuration and try again.
</message>
<message name="IDS_CHAT_UI_RETRY_BUTTON_LABEL" desc="A button label to retry API again">
Retry
</message>
Expand Down
Loading