From 97a81a22eb34e49bb1b2fb4088d75dc49d7ba197 Mon Sep 17 00:00:00 2001 From: Andrew Heard Date: Wed, 29 May 2024 10:24:06 -0400 Subject: [PATCH] Send `GenerateContentRequest` in `CountTokensRequest` (#175) --- Sources/GoogleAI/CountTokensRequest.swift | 4 +- Sources/GoogleAI/GenerateContentRequest.swift | 1 + Sources/GoogleAI/GenerativeModel.swift | 13 +- .../GenerateContentRequestTests.swift | 144 ++++++++++++++++++ 4 files changed, 158 insertions(+), 4 deletions(-) create mode 100644 Tests/GoogleAITests/GenerateContentRequestTests.swift diff --git a/Sources/GoogleAI/CountTokensRequest.swift b/Sources/GoogleAI/CountTokensRequest.swift index de852ae..d8bfc0e 100644 --- a/Sources/GoogleAI/CountTokensRequest.swift +++ b/Sources/GoogleAI/CountTokensRequest.swift @@ -17,7 +17,7 @@ import Foundation @available(iOS 15.0, macOS 11.0, macCatalyst 15.0, *) struct CountTokensRequest { let model: String - let contents: [ModelContent] + let generateContentRequest: GenerateContentRequest let options: RequestOptions } @@ -42,7 +42,7 @@ public struct CountTokensResponse { @available(iOS 15.0, macOS 11.0, macCatalyst 15.0, *) extension CountTokensRequest: Encodable { enum CodingKeys: CodingKey { - case contents + case generateContentRequest } } diff --git a/Sources/GoogleAI/GenerateContentRequest.swift b/Sources/GoogleAI/GenerateContentRequest.swift index 05abadf..c360583 100644 --- a/Sources/GoogleAI/GenerateContentRequest.swift +++ b/Sources/GoogleAI/GenerateContentRequest.swift @@ -31,6 +31,7 @@ struct GenerateContentRequest { @available(iOS 15.0, macOS 11.0, macCatalyst 15.0, *) extension GenerateContentRequest: Encodable { enum CodingKeys: String, CodingKey { + case model case contents case generationConfig case safetySettings diff --git a/Sources/GoogleAI/GenerativeModel.swift b/Sources/GoogleAI/GenerativeModel.swift index ed1aecd..fc9c985 100644 --- a/Sources/GoogleAI/GenerativeModel.swift +++ b/Sources/GoogleAI/GenerativeModel.swift @@ -325,9 +325,18 @@ public final class GenerativeModel { public func countTokens(_ content: @autoclosure () throws -> [ModelContent]) async throws -> CountTokensResponse { do { - let countTokensRequest = try CountTokensRequest( + let generateContentRequest = try GenerateContentRequest(model: modelResourceName, + contents: content(), + generationConfig: generationConfig, + safetySettings: safetySettings, + tools: tools, + toolConfig: toolConfig, + systemInstruction: systemInstruction, + isStreaming: false, + options: requestOptions) + let countTokensRequest = CountTokensRequest( model: modelResourceName, - contents: content(), + generateContentRequest: generateContentRequest, options: requestOptions ) return try await generativeAIService.loadRequest(request: countTokensRequest) diff --git a/Tests/GoogleAITests/GenerateContentRequestTests.swift b/Tests/GoogleAITests/GenerateContentRequestTests.swift new file mode 100644 index 0000000..a808799 --- /dev/null +++ b/Tests/GoogleAITests/GenerateContentRequestTests.swift @@ -0,0 +1,144 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +import Foundation +import XCTest + +@testable import GoogleGenerativeAI + +@available(iOS 15.0, macOS 11.0, macCatalyst 15.0, *) +final class GenerateContentRequestTests: XCTestCase { + let encoder = JSONEncoder() + let role = "test-role" + let prompt = "test-prompt" + let modelName = "test-model-name" + + override func setUp() { + encoder.outputFormatting = .init( + arrayLiteral: .prettyPrinted, .sortedKeys, .withoutEscapingSlashes + ) + } + + // MARK: GenerateContentRequest Encoding + + func testEncodeRequest_allFieldsIncluded() throws { + let content = [ModelContent(role: role, parts: prompt)] + let request = GenerateContentRequest( + model: modelName, + contents: content, + generationConfig: GenerationConfig(temperature: 0.5), + safetySettings: [SafetySetting( + harmCategory: .dangerousContent, + threshold: .blockLowAndAbove + )], + tools: [Tool(functionDeclarations: [FunctionDeclaration( + name: "test-function-name", + description: "test-function-description", + parameters: nil + )])], + toolConfig: ToolConfig(functionCallingConfig: FunctionCallingConfig(mode: .auto)), + systemInstruction: ModelContent(role: "system", parts: "test-system-instruction"), + isStreaming: false, + options: RequestOptions() + ) + + let jsonData = try encoder.encode(request) + + let json = try XCTUnwrap(String(data: jsonData, encoding: .utf8)) + XCTAssertEqual(json, """ + { + "contents" : [ + { + "parts" : [ + { + "text" : "\(prompt)" + } + ], + "role" : "\(role)" + } + ], + "generationConfig" : { + "temperature" : 0.5 + }, + "model" : "\(modelName)", + "safetySettings" : [ + { + "category" : "HARM_CATEGORY_DANGEROUS_CONTENT", + "threshold" : "BLOCK_LOW_AND_ABOVE" + } + ], + "systemInstruction" : { + "parts" : [ + { + "text" : "test-system-instruction" + } + ], + "role" : "system" + }, + "toolConfig" : { + "functionCallingConfig" : { + "mode" : "AUTO" + } + }, + "tools" : [ + { + "functionDeclarations" : [ + { + "description" : "test-function-description", + "name" : "test-function-name", + "parameters" : { + "type" : "OBJECT" + } + } + ] + } + ] + } + """) + } + + func testEncodeRequest_optionalFieldsOmitted() throws { + let content = [ModelContent(role: role, parts: prompt)] + let request = GenerateContentRequest( + model: modelName, + contents: content, + generationConfig: nil, + safetySettings: nil, + tools: nil, + toolConfig: nil, + systemInstruction: nil, + isStreaming: false, + options: RequestOptions() + ) + + let jsonData = try encoder.encode(request) + + let json = try XCTUnwrap(String(data: jsonData, encoding: .utf8)) + XCTAssertEqual(json, """ + { + "contents" : [ + { + "parts" : [ + { + "text" : "\(prompt)" + } + ], + "role" : "\(role)" + } + ], + "model" : "\(modelName)" + } + """) + } +}