diff --git a/js/ai/src/model-types.ts b/js/ai/src/model-types.ts index 511359ade1..cf20755ac8 100644 --- a/js/ai/src/model-types.ts +++ b/js/ai/src/model-types.ts @@ -201,6 +201,12 @@ export const GenerationCommonConfigSchema = z 'Set of character sequences (up to 5) that will stop output generation.' ) .optional(), + apiKey: z + .string() + .describe( + 'API Key to use for the model call, overrides API key provided in plugin config.' + ) + .optional(), }) .passthrough(); diff --git a/js/genkit/src/embedder.ts b/js/genkit/src/embedder.ts index ba4e5d3991..d05fc77b34 100644 --- a/js/genkit/src/embedder.ts +++ b/js/genkit/src/embedder.ts @@ -17,6 +17,7 @@ export { EmbedderInfoSchema, embedderRef, + type EmbedRequest, type EmbedderAction, type EmbedderArgument, type EmbedderInfo, diff --git a/js/plugins/compat-oai/src/audio.ts b/js/plugins/compat-oai/src/audio.ts index 873ed9405c..7657509cc6 100644 --- a/js/plugins/compat-oai/src/audio.ts +++ b/js/plugins/compat-oai/src/audio.ts @@ -22,13 +22,15 @@ import type { import { GenerationCommonConfigSchema, Message, modelRef, z } from 'genkit'; import type { ModelAction, ModelInfo } from 'genkit/model'; import { model } from 'genkit/plugin'; -import type OpenAI from 'openai'; +import OpenAI from 'openai'; import { Response } from 'openai/core.mjs'; import type { SpeechCreateParams, Transcription, TranscriptionCreateParams, } from 'openai/resources/audio/index.mjs'; +import { PluginOptions } from './index.js'; +import { maybeCreateRequestScopedOpenAIClient } from './utils.js'; export type SpeechRequestBuilder = ( req: GenerateRequest, @@ -185,10 +187,16 @@ export function defineCompatOpenAISpeechModel< client: OpenAI; modelRef?: ModelReference; requestBuilder?: SpeechRequestBuilder; + pluginOptions?: PluginOptions; }): ModelAction { - const { name, client, modelRef, requestBuilder } = params; + const { + name, + client: defaultClient, + pluginOptions, + modelRef, + requestBuilder, + } = params; const modelName = 
name.substring(name.indexOf('/') + 1); - return model( { name, @@ -197,6 +205,11 @@ export function defineCompatOpenAISpeechModel< }, async (request, { abortSignal }) => { const ttsRequest = toTTSRequest(modelName!, request, requestBuilder); + const client = maybeCreateRequestScopedOpenAIClient( + pluginOptions, + request, + defaultClient + ); const result = await client.audio.speech.create(ttsRequest, { signal: abortSignal, }); @@ -338,11 +351,17 @@ export function defineCompatOpenAITranscriptionModel< >(params: { name: string; client: OpenAI; + pluginOptions?: PluginOptions; modelRef?: ModelReference; requestBuilder?: TranscriptionRequestBuilder; }): ModelAction { - const { name, client, modelRef, requestBuilder } = params; - + const { + name, + pluginOptions, + client: defaultClient, + modelRef, + requestBuilder, + } = params; return model( { name, @@ -353,6 +372,11 @@ export function defineCompatOpenAITranscriptionModel< const modelName = name.substring(name.indexOf('/') + 1); const params = toSttRequest(modelName!, request, requestBuilder); + const client = maybeCreateRequestScopedOpenAIClient( + pluginOptions, + request, + defaultClient + ); // Explicitly setting stream to false ensures we use the non-streaming overload const result = await client.audio.transcriptions.create( { diff --git a/js/plugins/compat-oai/src/deepseek/index.ts b/js/plugins/compat-oai/src/deepseek/index.ts index a76680b3be..522ab77c04 100644 --- a/js/plugins/compat-oai/src/deepseek/index.ts +++ b/js/plugins/compat-oai/src/deepseek/index.ts @@ -36,26 +36,25 @@ import { export type DeepSeekPluginOptions = Omit; -const resolver = async ( - client: OpenAI, - actionType: ActionType, - actionName: string -) => { - if (actionType === 'model') { - const modelRef = deepSeekModelRef({ - name: actionName, - }); - return defineCompatOpenAIModel({ - name: modelRef.name, - client, - modelRef, - requestBuilder: deepSeekRequestBuilder, - }); - } else { - logger.warn('Only model actions are supported by 
the DeepSeek plugin'); - return undefined; - } -}; +function createResolver(pluginOptions: PluginOptions) { + return async (client: OpenAI, actionType: ActionType, actionName: string) => { + if (actionType === 'model') { + const modelRef = deepSeekModelRef({ + name: actionName, + }); + return defineCompatOpenAIModel({ + name: modelRef.name, + client, + pluginOptions, + modelRef, + requestBuilder: deepSeekRequestBuilder, + }); + } else { + logger.warn('Only model actions are supported by the DeepSeek plugin'); + return undefined; + } + }; +} const listActions = async (client: OpenAI): Promise => { return await client.models.list().then((response) => @@ -87,6 +86,7 @@ export function deepSeekPlugin( 'Please pass in the API key or set the DEEPSEEK_API_KEY environment variable.', }); } + const pluginOptions = { name: 'deepseek', ...options }; return openAICompatible({ name: 'deepseek', baseURL: 'https://api.deepseek.com', @@ -97,12 +97,13 @@ export function deepSeekPlugin( defineCompatOpenAIModel({ name: modelRef.name, client, + pluginOptions, modelRef, requestBuilder: deepSeekRequestBuilder, }) ); }, - resolver, + resolver: createResolver(pluginOptions), listActions, }); } diff --git a/js/plugins/compat-oai/src/embedder.ts b/js/plugins/compat-oai/src/embedder.ts index f4499eccad..e19eb980b5 100644 --- a/js/plugins/compat-oai/src/embedder.ts +++ b/js/plugins/compat-oai/src/embedder.ts @@ -20,6 +20,8 @@ import type { EmbedderAction, EmbedderReference } from 'genkit'; import { embedder } from 'genkit/plugin'; import OpenAI from 'openai'; +import { PluginOptions } from './index.js'; +import { maybeCreateRequestScopedOpenAIClient } from './utils.js'; /** * Method to define a new Genkit Embedder that is compatibale with the Open AI @@ -37,11 +39,11 @@ import OpenAI from 'openai'; export function defineCompatOpenAIEmbedder(params: { name: string; client: OpenAI; + pluginOptions?: PluginOptions; embedderRef?: EmbedderReference; }): EmbedderAction { - const { name, client, 
embedderRef } = params; + const { name, client: defaultClient, pluginOptions, embedderRef } = params; const modelName = name.substring(name.indexOf('/') + 1); - return embedder( { name, @@ -50,6 +52,11 @@ export function defineCompatOpenAIEmbedder(params: { }, async (req) => { const { encodingFormat: encoding_format, ...restOfConfig } = req.options; + const client = maybeCreateRequestScopedOpenAIClient( + pluginOptions, + req, + defaultClient + ); const embeddings = await client.embeddings.create({ model: modelName!, input: req.input.map((d) => d.text), diff --git a/js/plugins/compat-oai/src/image.ts b/js/plugins/compat-oai/src/image.ts index d90d2c5320..1915041dcd 100644 --- a/js/plugins/compat-oai/src/image.ts +++ b/js/plugins/compat-oai/src/image.ts @@ -27,6 +27,8 @@ import type { ImageGenerateParams, ImagesResponse, } from 'openai/resources/images.mjs'; +import { PluginOptions } from './index.js'; +import { maybeCreateRequestScopedOpenAIClient } from './utils.js'; export type ImageRequestBuilder = ( req: GenerateRequest, @@ -122,10 +124,17 @@ export function defineCompatOpenAIImageModel< >(params: { name: string; client: OpenAI; + pluginOptions?: PluginOptions; modelRef?: ModelReference; requestBuilder?: ImageRequestBuilder; }): ModelAction { - const { name, client, modelRef, requestBuilder } = params; + const { + name, + client: defaultClient, + pluginOptions, + modelRef, + requestBuilder, + } = params; const modelName = name.substring(name.indexOf('/') + 1); return model( @@ -135,6 +144,11 @@ export function defineCompatOpenAIImageModel< configSchema: modelRef?.configSchema, }, async (request, { abortSignal }) => { + const client = maybeCreateRequestScopedOpenAIClient( + pluginOptions, + request, + defaultClient + ); const result = await client.images.generate( toImageGenerateParams(modelName!, request, requestBuilder), { signal: abortSignal } diff --git a/js/plugins/compat-oai/src/index.ts b/js/plugins/compat-oai/src/index.ts index 1c5a5d5ec1..82eea786b5 
100644 --- a/js/plugins/compat-oai/src/index.ts +++ b/js/plugins/compat-oai/src/index.ts @@ -17,7 +17,7 @@ import { ActionMetadata } from 'genkit'; import { ResolvableAction, genkitPluginV2 } from 'genkit/plugin'; import { ActionType } from 'genkit/registry'; -import { OpenAI, type ClientOptions } from 'openai'; +import OpenAI, { type ClientOptions } from 'openai'; import { compatOaiModelRef, defineCompatOpenAIModel } from './model.js'; export { @@ -45,7 +45,8 @@ export { type ModelRequestBuilder, } from './model.js'; -export interface PluginOptions extends Partial { +export interface PluginOptions extends Partial> { + apiKey?: ClientOptions['apiKey'] | false; name: string; initializer?: (client: OpenAI) => Promise; resolver?: ( @@ -110,24 +111,33 @@ export interface PluginOptions extends Partial { */ export const openAICompatible = (options: PluginOptions) => { let listActionsCache; + var client: OpenAI; + function createClient() { + if (client) return client; + const { apiKey, ...restofOptions } = options; + client = new OpenAI({ + ...restofOptions, + apiKey: apiKey === false ? 'placeholder' : apiKey, + }); + return client; + } return genkitPluginV2({ name: options.name, async init() { if (!options.initializer) { return []; } - const client = new OpenAI(options); - return await options.initializer(client); + return await options.initializer(createClient()); }, async resolve(actionType: ActionType, actionName: string) { - const client = new OpenAI(options); if (options.resolver) { - return await options.resolver(client, actionType, actionName); + return await options.resolver(createClient(), actionType, actionName); } else { if (actionType === 'model') { return defineCompatOpenAIModel({ name: actionName, - client, + client: createClient(), + pluginOptions: options, modelRef: compatOaiModelRef({ name: actionName, }), @@ -136,14 +146,15 @@ export const openAICompatible = (options: PluginOptions) => { return undefined; } }, - list: options.listActions - ? 
async () => { - if (listActionsCache) return listActionsCache; - const client = new OpenAI(options); - listActionsCache = await options.listActions!(client); - return listActionsCache; - } - : undefined, + list: + // Don't attempt to list models if apiKey set to false + options.listActions && options.apiKey !== false + ? async () => { + if (listActionsCache) return listActionsCache; + listActionsCache = await options.listActions!(createClient()); + return listActionsCache; + } + : undefined, }); }; diff --git a/js/plugins/compat-oai/src/model.ts b/js/plugins/compat-oai/src/model.ts index 5d18a382e5..d2ce7cce99 100644 --- a/js/plugins/compat-oai/src/model.ts +++ b/js/plugins/compat-oai/src/model.ts @@ -49,6 +49,8 @@ import type { ChatCompletionTool, CompletionChoice, } from 'openai/resources/index.mjs'; +import { PluginOptions } from './index.js'; +import { maybeCreateRequestScopedOpenAIClient } from './utils.js'; const VisualDetailLevelSchema = z.enum(['auto', 'low', 'high']).optional(); @@ -479,6 +481,7 @@ export function toOpenAIRequestBody( stopSequences: stop, version: modelVersion, tools: toolsFromConfig, + apiKey, ...restOfConfig } = request.config ?? 
{}; @@ -541,8 +544,9 @@ export function toOpenAIRequestBody( */ export function openAIModelRunner( name: string, - client: OpenAI, - requestBuilder?: ModelRequestBuilder + defaultClient: OpenAI, + requestBuilder?: ModelRequestBuilder, + pluginOptions?: Omit ) { return async ( request: GenerateRequest, @@ -552,6 +556,11 @@ abortSignal?: AbortSignal; } ): Promise => { + const client = maybeCreateRequestScopedOpenAIClient( + pluginOptions, + request, + defaultClient + ); try { let response: ChatCompletion; const body = toOpenAIRequestBody(name, request, requestBuilder); @@ -605,6 +614,12 @@ case 429: status = 'RESOURCE_EXHAUSTED'; break; + case 401: + status = 'UNAUTHENTICATED'; + break; + case 403: + status = 'PERMISSION_DENIED'; + break; case 400: status = 'INVALID_ARGUMENT'; break; @@ -648,8 +663,9 @@ export function defineCompatOpenAIModel< client: OpenAI; modelRef?: ModelReference; requestBuilder?: ModelRequestBuilder; + pluginOptions?: PluginOptions; }): ModelAction { - const { name, client, modelRef, requestBuilder } = params; + const { name, client, pluginOptions, modelRef, requestBuilder } = params; const modelName = name.substring(name.indexOf('/') + 1); return model( @@ -658,7 +674,7 @@ ...modelRef?.info, configSchema: modelRef?.configSchema, }, - openAIModelRunner(modelName!, client, requestBuilder) + openAIModelRunner(modelName!, client, requestBuilder, pluginOptions) ); } diff --git a/js/plugins/compat-oai/src/openai/index.ts b/js/plugins/compat-oai/src/openai/index.ts index 753cd02405..ae7593a65b 100644 --- a/js/plugins/compat-oai/src/openai/index.ts +++ b/js/plugins/compat-oai/src/openai/index.ts @@ -64,51 +64,57 @@ export type OpenAIPluginOptions = Omit; const UNSUPPORTED_MODEL_MATCHERS = ['babbage', 'davinci', 'codex']; -const resolver = async ( - client: OpenAI, - actionType: ActionType, - actionName: string -) => { - if (actionType === 
'embedder') { - return defineCompatOpenAIEmbedder({ name: actionName, client }); - } else if ( - actionName.includes('gpt-image-1') || - actionName.includes('dall-e') - ) { - const modelRef = openAIImageModelRef({ name: actionName }); - return defineCompatOpenAIImageModel({ - name: modelRef.name, - client, - modelRef, - }); - } else if (actionName.includes('tts')) { - const modelRef = openAISpeechModelRef({ name: actionName }); - return defineCompatOpenAISpeechModel({ - name: modelRef.name, - client, - modelRef, - }); - } else if ( - actionName.includes('whisper') || - actionName.includes('transcribe') - ) { - const modelRef = openAITranscriptionModelRef({ - name: actionName, - }); - return defineCompatOpenAITranscriptionModel({ - name: modelRef.name, - client, - modelRef, - }); - } else { - const modelRef = openAIModelRef({ name: actionName }); - return defineCompatOpenAIModel({ - name: modelRef.name, - client, - modelRef, - }); - } -}; +function createResolver(pluginOptions: PluginOptions) { + return async (client: OpenAI, actionType: ActionType, actionName: string) => { + if (actionType === 'embedder') { + return defineCompatOpenAIEmbedder({ + name: actionName, + client, + pluginOptions, + }); + } else if ( + actionName.includes('gpt-image-1') || + actionName.includes('dall-e') + ) { + const modelRef = openAIImageModelRef({ name: actionName }); + return defineCompatOpenAIImageModel({ + name: modelRef.name, + client, + pluginOptions, + modelRef, + }); + } else if (actionName.includes('tts')) { + const modelRef = openAISpeechModelRef({ name: actionName }); + return defineCompatOpenAISpeechModel({ + name: modelRef.name, + client, + pluginOptions, + modelRef, + }); + } else if ( + actionName.includes('whisper') || + actionName.includes('transcribe') + ) { + const modelRef = openAITranscriptionModelRef({ + name: actionName, + }); + return defineCompatOpenAITranscriptionModel({ + name: modelRef.name, + client, + pluginOptions, + modelRef, + }); + } else { + const 
modelRef = openAIModelRef({ name: actionName }); + return defineCompatOpenAIModel({ + name: modelRef.name, + client, + pluginOptions, + modelRef, + }); + } + }; +} function filterOpenAiModels(model: OpenAI.Model): boolean { return !UNSUPPORTED_MODEL_MATCHERS.some((m) => model.id.includes(m)); @@ -170,6 +176,7 @@ const listActions = async (client: OpenAI): Promise => { }; export function openAIPlugin(options?: OpenAIPluginOptions): GenkitPluginV2 { + const pluginOptions = { name: 'openai', ...options }; return openAICompatible({ name: 'openai', ...options, @@ -177,7 +184,12 @@ export function openAIPlugin(options?: OpenAIPluginOptions): GenkitPluginV2 { const models = [] as ResolvableAction[]; models.push( ...Object.values(SUPPORTED_GPT_MODELS).map((modelRef) => - defineCompatOpenAIModel({ name: modelRef.name, client, modelRef }) + defineCompatOpenAIModel({ + name: modelRef.name, + client, + pluginOptions, + modelRef, + }) ) ); models.push( @@ -185,6 +197,7 @@ export function openAIPlugin(options?: OpenAIPluginOptions): GenkitPluginV2 { defineCompatOpenAIEmbedder({ name: embedderRef.name, client, + pluginOptions, embedderRef, }) ) @@ -194,6 +207,7 @@ export function openAIPlugin(options?: OpenAIPluginOptions): GenkitPluginV2 { defineCompatOpenAISpeechModel({ name: modelRef.name, client, + pluginOptions, modelRef, }) ) @@ -203,6 +217,7 @@ export function openAIPlugin(options?: OpenAIPluginOptions): GenkitPluginV2 { defineCompatOpenAITranscriptionModel({ name: modelRef.name, client, + pluginOptions, modelRef, }) ) @@ -212,6 +227,7 @@ export function openAIPlugin(options?: OpenAIPluginOptions): GenkitPluginV2 { defineCompatOpenAIImageModel({ name: modelRef.name, client, + pluginOptions, modelRef, requestBuilder: modelRef.name.includes('gpt-image-1') ? 
gptImage1RequestBuilder @@ -221,7 +237,7 @@ export function openAIPlugin(options?: OpenAIPluginOptions): GenkitPluginV2 { ); return models; }, - resolver, + resolver: createResolver(pluginOptions), listActions, }); } diff --git a/js/plugins/compat-oai/src/utils.ts b/js/plugins/compat-oai/src/utils.ts new file mode 100644 index 0000000000..31c576fcba --- /dev/null +++ b/js/plugins/compat-oai/src/utils.ts @@ -0,0 +1,42 @@ +/** + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import { GenerateRequest } from 'genkit'; +import { EmbedRequest } from 'genkit/embedder'; +import OpenAI from 'openai'; +import { PluginOptions } from '.'; + +/** + * Inspects the request and if apiKey is provided in config, creates a new client. + * Otherwise falls back on the `defaultClient`. + */ +export function maybeCreateRequestScopedOpenAIClient( + pluginOptions: PluginOptions | undefined, + request: GenerateRequest | EmbedRequest, + defaultClient: OpenAI +): OpenAI { + const requestApiKey = + (request as GenerateRequest)?.config?.apiKey ?? + (request as EmbedRequest)?.options?.apiKey; + if (!requestApiKey) { + return defaultClient; + } + return new OpenAI({ + // if pluginOptions are not passed in we attempt to get options from the default client. + ...(pluginOptions ?? 
defaultClient['_options']), + apiKey: requestApiKey, + }); +} diff --git a/js/plugins/compat-oai/src/xai/index.ts b/js/plugins/compat-oai/src/xai/index.ts index 97655577ef..8a979d82ef 100644 --- a/js/plugins/compat-oai/src/xai/index.ts +++ b/js/plugins/compat-oai/src/xai/index.ts @@ -41,24 +41,23 @@ import { export type XAIPluginOptions = Omit; -const resolver = async ( - client: OpenAI, - actionType: ActionType, - actionName: string -) => { - if (actionType === 'model') { - const modelRef = xaiModelRef({ name: actionName }); - return defineCompatOpenAIModel({ - name: modelRef.name, - client, - modelRef, - requestBuilder: grokRequestBuilder, - }); - } else { - logger.warn('Only model actions are supported by the XAI plugin'); - } - return undefined; -}; +function createResolver(pluginOptions: PluginOptions) { + return async (client: OpenAI, actionType: ActionType, actionName: string) => { + if (actionType === 'model') { + const modelRef = xaiModelRef({ name: actionName }); + return defineCompatOpenAIModel({ + name: modelRef.name, + client, + pluginOptions, + modelRef, + requestBuilder: grokRequestBuilder, + }); + } else { + logger.warn('Only model actions are supported by the XAI plugin'); + } + return undefined; + }; +} const listActions = async (client: OpenAI): Promise => { return await client.models.list().then((response) => @@ -97,6 +96,7 @@ export function xAIPlugin(options?: XAIPluginOptions): GenkitPluginV2 { 'Please pass in the API key or set the XAI_API_KEY environment variable.', }); } + const pluginOptions = { name: 'xai', ...options }; return openAICompatible({ name: 'xai', baseURL: 'https://api.x.ai/v1', @@ -109,6 +109,7 @@ export function xAIPlugin(options?: XAIPluginOptions): GenkitPluginV2 { defineCompatOpenAIModel({ name: modelRef.name, client, + pluginOptions, modelRef, requestBuilder: grokRequestBuilder, }) @@ -119,13 +120,14 @@ export function xAIPlugin(options?: XAIPluginOptions): GenkitPluginV2 { defineCompatOpenAIImageModel({ name: 
modelRef.name, client, + pluginOptions, modelRef, }) ) ); return models; }, - resolver, + resolver: createResolver(pluginOptions), listActions, }); } diff --git a/js/plugins/compat-oai/tests/compat_oai_test.ts b/js/plugins/compat-oai/tests/compat_oai_test.ts index ed2c6ee76b..93844e3fba 100644 --- a/js/plugins/compat-oai/tests/compat_oai_test.ts +++ b/js/plugins/compat-oai/tests/compat_oai_test.ts @@ -15,7 +15,14 @@ * limitations under the License. */ -import { describe, expect, it, jest } from '@jest/globals'; +import { + afterEach, + beforeEach, + describe, + expect, + it, + jest, +} from '@jest/globals'; import type { GenerateRequest, GenerateResponseData, @@ -23,26 +30,25 @@ import type { Part, Role, } from 'genkit'; -import type OpenAI from 'openai'; +import OpenAI, { APIError } from 'openai'; import type { ChatCompletion, ChatCompletionChunk, ChatCompletionMessageToolCall, ChatCompletionRole, } from 'openai/resources/index.mjs'; - -import { APIError } from 'openai'; import { + ModelRequestBuilder, fromOpenAIChoice, fromOpenAIChunkChoice, fromOpenAIToolCall, - ModelRequestBuilder, openAIModelRunner, toOpenAIMessages, toOpenAIRequestBody, toOpenAIRole, toOpenAITextAndMedia, } from '../src/model'; +import { FakeOpenAIServer } from './fake_openai_server'; jest.mock('genkit/model', () => { const originalModule = @@ -1574,6 +1580,142 @@ describe('openAIModelRunner', () => { ); }); + describe('request scoping with fake server', () => { + let server: FakeOpenAIServer; + + beforeEach(async () => { + server = new FakeOpenAIServer('scoped-key'); + await server.start(); + }); + + afterEach(() => { + server.stop(); + }); + + it('should use request scoped client when apiKey is provided', async () => { + server.setNextResponse({ + body: { + choices: [ + { + message: { role: 'assistant', content: 'scoped response' }, + finish_reason: 'stop', + }, + ], + }, + }); + + const defaultClient = new OpenAI({ apiKey: 'default-key' }); + const runner = openAIModelRunner('gpt-4o', 
defaultClient, undefined, { + name: 'openai', + baseURL: server.baseUrl, + }); + + const result = await runner({ + messages: [{ role: 'user', content: [{ text: 'hi' }] }], + config: { apiKey: 'scoped-key' }, + }); + + expect(result.message?.content[0].text).toBe('scoped response'); + // Verify server received correct key + expect(server.requests.length).toBe(1); + expect(server.requests[0].headers['authorization']).toBe( + 'Bearer scoped-key' + ); + }); + + it('should handle streaming response with scoped client', async () => { + server.setNextResponse({ + stream: true, + chunks: [ + { + id: '1', + choices: [ + { + index: 0, + delta: { role: 'assistant', content: 'chunk1' }, + finish_reason: null, + }, + ], + }, + { + id: '2', + choices: [ + { index: 0, delta: { content: 'chunk2' }, finish_reason: null }, + ], + }, + { + id: '3', + choices: [{ index: 0, delta: {}, finish_reason: 'stop' }], + }, + ], + }); + + const defaultClient = new OpenAI({ apiKey: 'default-key' }); + const runner = openAIModelRunner('gpt-4o', defaultClient, undefined, { + name: 'openai', + baseURL: server.baseUrl, + }); + + let streamedContent = ''; + const result = await runner( + { + messages: [{ role: 'user', content: [{ text: 'hi' }] }], + config: { apiKey: 'scoped-key' }, + }, + { + streamingRequested: true, + sendChunk: (chunk) => { + if (chunk.content.length > 0) { + streamedContent += chunk.content[0].text; + } + }, + } + ); + + expect(streamedContent).toBe('chunk1chunk2'); + expect(result.message?.content[0].text).toBe('chunk1chunk2'); + }); + + it('should fail when invalid apiKey is provided in request', async () => { + const defaultClient = new OpenAI({ apiKey: 'default-key' }); + const runner = openAIModelRunner('gpt-4o', defaultClient, undefined, { + name: 'openai', + baseURL: server.baseUrl, + }); + + await expect( + runner({ + messages: [{ role: 'user', content: [{ text: 'hi' }] }], + config: { apiKey: 'wrong-key' }, + }) + ).rejects.toThrow( + expect.objectContaining({ + status: 
'UNAUTHENTICATED', + message: expect.stringContaining('Incorrect API key provided'), + }) + ); + }); + + it('should fail when invalid apiKey is provided in plugin options', async () => { + const defaultClient = new OpenAI({ + apiKey: 'default-key', + baseURL: server.baseUrl, + }); + const runner = openAIModelRunner('gpt-4o', defaultClient, undefined, { + name: 'openai', + baseURL: server.baseUrl, + }); + + await expect( + runner({ + messages: [{ role: 'user', content: [{ text: 'hi' }] }], + }) + ).rejects.toThrow( + expect.objectContaining({ + status: 'UNAUTHENTICATED', + message: expect.stringContaining('Incorrect API key provided'), + }) + ); + }); + }); + describe('error handling', () => { const testCases = [ { diff --git a/js/plugins/compat-oai/tests/fake_openai_server.ts b/js/plugins/compat-oai/tests/fake_openai_server.ts new file mode 100644 index 0000000000..ceb209351c --- /dev/null +++ b/js/plugins/compat-oai/tests/fake_openai_server.ts @@ -0,0 +1,128 @@ +/** + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +import * as http from 'http'; +import { AddressInfo } from 'net'; + +export interface MockResponse { + statusCode?: number; + body?: any; + headers?: http.OutgoingHttpHeaders; + stream?: boolean; + chunks?: any[]; // For streaming +} + +export class FakeOpenAIServer { + private server: http.Server; + private port: number = 0; + private responses: MockResponse[] = []; + public requests: { headers: http.IncomingHttpHeaders; body: any }[] = []; + private expectedApiKey?: string; + + constructor(expectedApiKey?: string) { + this.expectedApiKey = expectedApiKey; + this.server = http.createServer(async (req, res) => { + let body = ''; + req.on('data', (chunk) => { + body += chunk.toString(); + }); + + await new Promise((resolve) => req.on('end', resolve)); + + const parsedBody = body ? JSON.parse(body) : {}; + this.requests.push({ headers: req.headers, body: parsedBody }); + + if (this.expectedApiKey) { + const authHeader = req.headers['authorization']; + if (!authHeader || authHeader !== `Bearer ${this.expectedApiKey}`) { + res.writeHead(401, { 'Content-Type': 'application/json' }); + res.end( + JSON.stringify({ + error: { + message: 'Incorrect API key provided', + type: 'invalid_request_error', + param: null, + code: 'invalid_api_key', + }, + }) + ); + return; + } + } + + const response = this.responses.shift(); + if (!response) { + // Default response if nothing queued + res.writeHead(200, { 'Content-Type': 'application/json' }); + res.end( + JSON.stringify({ + choices: [ + { + message: { role: 'assistant', content: 'default response' }, + finish_reason: 'stop', + }, + ], + }) + ); + return; + } + + const statusCode = response.statusCode || 200; + const headers = response.headers || { + 'Content-Type': 'application/json', + }; + + if (response.stream) { + headers['Content-Type'] = 'text/event-stream'; + headers['Cache-Control'] = 'no-cache'; + headers['Connection'] = 'keep-alive'; + res.writeHead(statusCode, headers); + + if (response.chunks) { + for (const 
chunk of response.chunks) { + res.write(`data: ${JSON.stringify(chunk)}\n\n`); + } + } + res.write('data: [DONE]\n\n'); + res.end(); + } else { + res.writeHead(statusCode, headers); + res.end(JSON.stringify(response.body)); + } + }); + } + + async start() { + await new Promise((resolve) => { + this.server.listen(0, () => { + this.port = (this.server.address() as AddressInfo).port; + resolve(); + }); + }); + } + + stop() { + this.server.close(); + } + + get baseUrl() { + return `http://localhost:${this.port}/v1`; + } + + setNextResponse(response: MockResponse) { + this.responses.push(response); + } +} diff --git a/js/plugins/compat-oai/tests/openai_test.ts b/js/plugins/compat-oai/tests/openai_test.ts index a5ee506fec..fab97afda1 100644 --- a/js/plugins/compat-oai/tests/openai_test.ts +++ b/js/plugins/compat-oai/tests/openai_test.ts @@ -18,7 +18,6 @@ import { afterEach, describe, expect, it, jest } from '@jest/globals'; import { modelRef, type GenerateRequest } from 'genkit/model'; import type OpenAI from 'openai'; - import { ChatCompletionCommonConfigSchema, defineCompatOpenAIModel, diff --git a/js/plugins/compat-oai/tests/utils_test.ts b/js/plugins/compat-oai/tests/utils_test.ts new file mode 100644 index 0000000000..a4736c762e --- /dev/null +++ b/js/plugins/compat-oai/tests/utils_test.ts @@ -0,0 +1,82 @@ +/** + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +import { describe, expect, it } from '@jest/globals'; +import { GenerateRequest } from 'genkit'; +import OpenAI from 'openai'; +import { maybeCreateRequestScopedOpenAIClient } from '../src/utils'; + +describe('maybeCreateRequestScopedOpenAIClient', () => { + it('should copy options from defaultClient when pluginOptions is undefined', () => { + const defaultClient = new OpenAI({ + apiKey: 'default-key', + baseURL: 'https://example.com/v1', + timeout: 12345, + }); + + const request = { + config: { apiKey: 'scoped-key' }, + } as GenerateRequest; + + const newClient = maybeCreateRequestScopedOpenAIClient( + undefined, + request, + defaultClient + ); + + expect(newClient).not.toBe(defaultClient); + expect(newClient.apiKey).toBe('scoped-key'); + expect(newClient.baseURL).toBe('https://example.com/v1'); + expect(newClient.timeout).toBe(12345); + }); + + it('should prioritize pluginOptions over defaultClient options', () => { + const defaultClient = new OpenAI({ + apiKey: 'default-key', + baseURL: 'https://example.com/v1', + }); + + const pluginOptions = { + name: 'foo', + baseURL: 'https://plugin-override.com/v1', + }; + + const request = { + config: { apiKey: 'scoped-key' }, + } as GenerateRequest; + + const newClient = maybeCreateRequestScopedOpenAIClient( + pluginOptions, + request, + defaultClient + ); + + expect(newClient.apiKey).toBe('scoped-key'); + expect(newClient.baseURL).toBe('https://plugin-override.com/v1'); + }); + + it('should verify that _options property exists on defaultClient (library compatibility check)', () => { + // This test ensures that the OpenAI library version being used still has the private '_options' property + // that we rely on for copying configuration. 
+ const defaultClient = new OpenAI({ apiKey: 'test' }); + expect(defaultClient).toHaveProperty('_options'); + expect((defaultClient as any)['_options']).toEqual( + expect.objectContaining({ + apiKey: 'test', + }) + ); + }); +}); diff --git a/js/plugins/google-genai/src/googleai/index.ts b/js/plugins/google-genai/src/googleai/index.ts index 8aedfb293f..27dd7dde30 100644 --- a/js/plugins/google-genai/src/googleai/index.ts +++ b/js/plugins/google-genai/src/googleai/index.ts @@ -80,6 +80,10 @@ async function resolver( async function listActions( options?: GoogleAIPluginOptions ): Promise { + // Don't attempt to list models if apiKey is set to false. + if (options?.apiKey === false) { + return []; + } try { const apiKey = calculateApiKey(options?.apiKey, undefined); const allModels = await listModels(apiKey, { diff --git a/js/plugins/google-genai/tests/googleai/index_test.ts b/js/plugins/google-genai/tests/googleai/index_test.ts index bc2019057b..431e258e5e 100644 --- a/js/plugins/google-genai/tests/googleai/index_test.ts +++ b/js/plugins/google-genai/tests/googleai/index_test.ts @@ -588,7 +588,7 @@ describe('GoogleAI Plugin', () => { ); }); - it('should still call list if API key is missing for listActions', async () => { + it('should return empty list if API key is missing for listActions', async () => { delete process.env.GOOGLE_API_KEY; delete process.env.GEMINI_API_KEY; delete process.env.GOOGLE_GENAI_API_KEY; @@ -600,7 +600,7 @@ describe('GoogleAI Plugin', () => { [], 'Should return empty array if API key is not found' ); - assert.strictEqual(fetchMock.mock.callCount(), 1); + assert.strictEqual(fetchMock.mock.callCount(), 0); }); it('should use listActions cache', async () => {