diff --git a/js/plugins/googleai/src/gemini.ts b/js/plugins/googleai/src/gemini.ts index dd3b90c10b..34670ceabc 100644 --- a/js/plugins/googleai/src/gemini.ts +++ b/js/plugins/googleai/src/gemini.ts @@ -16,9 +16,9 @@ import { FunctionCallingMode, + GenerateContentCandidate, GoogleGenerativeAI, SchemaType, - type EnhancedGenerateContentResponse, type FileDataPart, type FunctionCallPart, type FunctionDeclaration, @@ -1045,14 +1045,16 @@ export function defineGoogleAIModel({ } const callGemini = async () => { - let response: EnhancedGenerateContentResponse; + let response: GenerateContentResponse; if (sendChunk) { const result = await genModel .startChat(updatedChatRequest) .sendMessageStream(msg.parts, options); + const chunks = [] as GenerateContentResponse[]; for await (const item of result.stream) { + chunks.push(item as GenerateContentResponse); (item as GenerateContentResponse).candidates?.forEach( (candidate) => { const c = fromJSONModeScopedGeminiCandidate(candidate); @@ -1064,7 +1066,7 @@ export function defineGoogleAIModel({ ); } - response = await result.response; + response = aggregateResponses(chunks); } else { const result = await genModel .startChat(updatedChatRequest) @@ -1176,3 +1178,83 @@ function toGeminiFunctionModeEnum( throw new Error(`unsupported function calling mode: ${genkitMode}`); } } + +/** + * Aggregates an array of `GenerateContentResponse`s into a single GenerateContentResponse. + * + * This code is copy-pasted from https://github.com/google-gemini/deprecated-generative-ai-js/blob/8b14949a5e8f1f3dfc35c394ebf5b19e68f92a22/src/requests/stream-reader.ts#L153 + * with a small (but critical) bug fix. + */ +export function aggregateResponses( + responses: GenerateContentResponse[] +): GenerateContentResponse { + const lastResponse = responses[responses.length - 1]; + const aggregatedResponse: GenerateContentResponse = { + promptFeedback: lastResponse?.promptFeedback, + }; + for (const response of responses) { + if (response.candidates) { + let candidateIndex = 0; + for (const candidate of response.candidates) { + if (!aggregatedResponse.candidates) { + aggregatedResponse.candidates = []; + } + if (!aggregatedResponse.candidates[candidateIndex]) { + aggregatedResponse.candidates[candidateIndex] = { + index: candidateIndex, + } as GenerateContentCandidate; + } + // Keep overwriting, the last one will be final + aggregatedResponse.candidates[candidateIndex].citationMetadata = + candidate.citationMetadata; + aggregatedResponse.candidates[candidateIndex].groundingMetadata = + candidate.groundingMetadata; + aggregatedResponse.candidates[candidateIndex].finishReason = + candidate.finishReason; + aggregatedResponse.candidates[candidateIndex].finishMessage = + candidate.finishMessage; + aggregatedResponse.candidates[candidateIndex].safetyRatings = + candidate.safetyRatings; + + /** + * Candidates should always have content and parts, but this handles + * possible malformed responses. + */ + if (candidate.content && candidate.content.parts) { + if (!aggregatedResponse.candidates[candidateIndex].content) { + aggregatedResponse.candidates[candidateIndex].content = { + role: candidate.content.role || 'user', + parts: [], + }; + } + for (const part of candidate.content.parts) { + const newPart: Partial = {}; + if (part.text) { + newPart.text = part.text; + } + if (part.functionCall) { + newPart.functionCall = part.functionCall; + } + if (part.executableCode) { + newPart.executableCode = part.executableCode; + } + if (part.codeExecutionResult) { + newPart.codeExecutionResult = part.codeExecutionResult; + } + if (Object.keys(newPart).length === 0) { + newPart.text = ''; + } + aggregatedResponse.candidates[candidateIndex].content.parts.push( + newPart as GeminiPart + ); + } + } + } + candidateIndex++; + } + if (response.usageMetadata) { + aggregatedResponse.usageMetadata = response.usageMetadata; + } + } + return aggregatedResponse; +}