Skip to content

Commit 2c6539e

Browse files
authored
feat(js/plugins/compat-oai): allow passing api key at runtime instead of config time (#3946)
1 parent 45defaf commit 2c6539e

File tree

17 files changed

+617
-122
lines changed

17 files changed

+617
-122
lines changed

js/ai/src/model-types.ts

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -201,6 +201,12 @@ export const GenerationCommonConfigSchema = z
201201
'Set of character sequences (up to 5) that will stop output generation.'
202202
)
203203
.optional(),
204+
apiKey: z
205+
.string()
206+
.describe(
207+
'API Key to use for the model call, overrides API key provided in plugin config.'
208+
)
209+
.optional(),
204210
})
205211
.passthrough();
206212

js/genkit/src/embedder.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
export {
1818
EmbedderInfoSchema,
1919
embedderRef,
20+
type EmbedRequest,
2021
type EmbedderAction,
2122
type EmbedderArgument,
2223
type EmbedderInfo,

js/plugins/compat-oai/src/audio.ts

Lines changed: 29 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,13 +22,15 @@ import type {
2222
import { GenerationCommonConfigSchema, Message, modelRef, z } from 'genkit';
2323
import type { ModelAction, ModelInfo } from 'genkit/model';
2424
import { model } from 'genkit/plugin';
25-
import type OpenAI from 'openai';
25+
import OpenAI from 'openai';
2626
import { Response } from 'openai/core.mjs';
2727
import type {
2828
SpeechCreateParams,
2929
Transcription,
3030
TranscriptionCreateParams,
3131
} from 'openai/resources/audio/index.mjs';
32+
import { PluginOptions } from './index.js';
33+
import { maybeCreateRequestScopedOpenAIClient } from './utils.js';
3234

3335
export type SpeechRequestBuilder = (
3436
req: GenerateRequest,
@@ -185,10 +187,16 @@ export function defineCompatOpenAISpeechModel<
185187
client: OpenAI;
186188
modelRef?: ModelReference<CustomOptions>;
187189
requestBuilder?: SpeechRequestBuilder;
190+
pluginOptions: PluginOptions;
188191
}): ModelAction {
189-
const { name, client, modelRef, requestBuilder } = params;
192+
const {
193+
name,
194+
client: defaultClient,
195+
pluginOptions,
196+
modelRef,
197+
requestBuilder,
198+
} = params;
190199
const modelName = name.substring(name.indexOf('/') + 1);
191-
192200
return model(
193201
{
194202
name,
@@ -197,6 +205,11 @@ export function defineCompatOpenAISpeechModel<
197205
},
198206
async (request, { abortSignal }) => {
199207
const ttsRequest = toTTSRequest(modelName!, request, requestBuilder);
208+
const client = maybeCreateRequestScopedOpenAIClient(
209+
pluginOptions,
210+
request,
211+
defaultClient
212+
);
200213
const result = await client.audio.speech.create(ttsRequest, {
201214
signal: abortSignal,
202215
});
@@ -338,11 +351,17 @@ export function defineCompatOpenAITranscriptionModel<
338351
>(params: {
339352
name: string;
340353
client: OpenAI;
354+
pluginOptions?: PluginOptions;
341355
modelRef?: ModelReference<CustomOptions>;
342356
requestBuilder?: TranscriptionRequestBuilder;
343357
}): ModelAction {
344-
const { name, client, modelRef, requestBuilder } = params;
345-
358+
const {
359+
name,
360+
pluginOptions,
361+
client: defaultClient,
362+
modelRef,
363+
requestBuilder,
364+
} = params;
346365
return model(
347366
{
348367
name,
@@ -353,6 +372,11 @@ export function defineCompatOpenAITranscriptionModel<
353372
const modelName = name.substring(name.indexOf('/') + 1);
354373

355374
const params = toSttRequest(modelName!, request, requestBuilder);
375+
const client = maybeCreateRequestScopedOpenAIClient(
376+
pluginOptions,
377+
request,
378+
defaultClient
379+
);
356380
// Explicitly setting stream to false ensures we use the non-streaming overload
357381
const result = await client.audio.transcriptions.create(
358382
{

js/plugins/compat-oai/src/deepseek/index.ts

Lines changed: 22 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -36,26 +36,25 @@ import {
3636

3737
export type DeepSeekPluginOptions = Omit<PluginOptions, 'name' | 'baseURL'>;
3838

39-
const resolver = async (
40-
client: OpenAI,
41-
actionType: ActionType,
42-
actionName: string
43-
) => {
44-
if (actionType === 'model') {
45-
const modelRef = deepSeekModelRef({
46-
name: actionName,
47-
});
48-
return defineCompatOpenAIModel({
49-
name: modelRef.name,
50-
client,
51-
modelRef,
52-
requestBuilder: deepSeekRequestBuilder,
53-
});
54-
} else {
55-
logger.warn('Only model actions are supported by the DeepSeek plugin');
56-
return undefined;
57-
}
58-
};
39+
function createResolver(pluginOptions: PluginOptions) {
40+
return async (client: OpenAI, actionType: ActionType, actionName: string) => {
41+
if (actionType === 'model') {
42+
const modelRef = deepSeekModelRef({
43+
name: actionName,
44+
});
45+
return defineCompatOpenAIModel({
46+
name: modelRef.name,
47+
client,
48+
pluginOptions,
49+
modelRef,
50+
requestBuilder: deepSeekRequestBuilder,
51+
});
52+
} else {
53+
logger.warn('Only model actions are supported by the DeepSeek plugin');
54+
return undefined;
55+
}
56+
};
57+
}
5958

6059
const listActions = async (client: OpenAI): Promise<ActionMetadata[]> => {
6160
return await client.models.list().then((response) =>
@@ -87,6 +86,7 @@ export function deepSeekPlugin(
8786
'Please pass in the API key or set the DEEPSEEK_API_KEY environment variable.',
8887
});
8988
}
89+
const pluginOptions = { name: 'deepseek', ...options };
9090
return openAICompatible({
9191
name: 'deepseek',
9292
baseURL: 'https://api.deepseek.com',
@@ -97,12 +97,13 @@ export function deepSeekPlugin(
9797
defineCompatOpenAIModel({
9898
name: modelRef.name,
9999
client,
100+
pluginOptions,
100101
modelRef,
101102
requestBuilder: deepSeekRequestBuilder,
102103
})
103104
);
104105
},
105-
resolver,
106+
resolver: createResolver(pluginOptions),
106107
listActions,
107108
});
108109
}

js/plugins/compat-oai/src/embedder.ts

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@
2020
import type { EmbedderAction, EmbedderReference } from 'genkit';
2121
import { embedder } from 'genkit/plugin';
2222
import OpenAI from 'openai';
23+
import { PluginOptions } from './index.js';
24+
import { maybeCreateRequestScopedOpenAIClient } from './utils.js';
2325

2426
/**
2527
* Method to define a new Genkit Embedder that is compatibale with the Open AI
@@ -37,11 +39,11 @@ import OpenAI from 'openai';
3739
export function defineCompatOpenAIEmbedder(params: {
3840
name: string;
3941
client: OpenAI;
42+
pluginOptions?: PluginOptions;
4043
embedderRef?: EmbedderReference;
4144
}): EmbedderAction {
42-
const { name, client, embedderRef } = params;
45+
const { name, client: defaultClient, pluginOptions, embedderRef } = params;
4346
const modelName = name.substring(name.indexOf('/') + 1);
44-
4547
return embedder(
4648
{
4749
name,
@@ -50,6 +52,11 @@ export function defineCompatOpenAIEmbedder(params: {
5052
},
5153
async (req) => {
5254
const { encodingFormat: encoding_format, ...restOfConfig } = req.options;
55+
const client = maybeCreateRequestScopedOpenAIClient(
56+
pluginOptions,
57+
req,
58+
defaultClient
59+
);
5360
const embeddings = await client.embeddings.create({
5461
model: modelName!,
5562
input: req.input.map((d) => d.text),

js/plugins/compat-oai/src/image.ts

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@ import type {
2727
ImageGenerateParams,
2828
ImagesResponse,
2929
} from 'openai/resources/images.mjs';
30+
import { PluginOptions } from './index.js';
31+
import { maybeCreateRequestScopedOpenAIClient } from './utils.js';
3032

3133
export type ImageRequestBuilder = (
3234
req: GenerateRequest,
@@ -122,10 +124,17 @@ export function defineCompatOpenAIImageModel<
122124
>(params: {
123125
name: string;
124126
client: OpenAI;
127+
pluginOptions?: PluginOptions;
125128
modelRef?: ModelReference<CustomOptions>;
126129
requestBuilder?: ImageRequestBuilder;
127130
}): ModelAction<CustomOptions> {
128-
const { name, client, modelRef, requestBuilder } = params;
131+
const {
132+
name,
133+
client: defaultClient,
134+
pluginOptions,
135+
modelRef,
136+
requestBuilder,
137+
} = params;
129138
const modelName = name.substring(name.indexOf('/') + 1);
130139

131140
return model(
@@ -135,6 +144,11 @@ export function defineCompatOpenAIImageModel<
135144
configSchema: modelRef?.configSchema,
136145
},
137146
async (request, { abortSignal }) => {
147+
const client = maybeCreateRequestScopedOpenAIClient(
148+
pluginOptions,
149+
request,
150+
defaultClient
151+
);
138152
const result = await client.images.generate(
139153
toImageGenerateParams(modelName!, request, requestBuilder),
140154
{ signal: abortSignal }

js/plugins/compat-oai/src/index.ts

Lines changed: 26 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
import { ActionMetadata } from 'genkit';
1818
import { ResolvableAction, genkitPluginV2 } from 'genkit/plugin';
1919
import { ActionType } from 'genkit/registry';
20-
import { OpenAI, type ClientOptions } from 'openai';
20+
import OpenAI, { type ClientOptions } from 'openai';
2121
import { compatOaiModelRef, defineCompatOpenAIModel } from './model.js';
2222

2323
export {
@@ -45,7 +45,8 @@ export {
4545
type ModelRequestBuilder,
4646
} from './model.js';
4747

48-
export interface PluginOptions extends Partial<ClientOptions> {
48+
export interface PluginOptions extends Partial<Omit<ClientOptions, 'apiKey'>> {
49+
apiKey?: ClientOptions['apiKey'] | false;
4950
name: string;
5051
initializer?: (client: OpenAI) => Promise<ResolvableAction[]>;
5152
resolver?: (
@@ -110,24 +111,33 @@ export interface PluginOptions extends Partial<ClientOptions> {
110111
*/
111112
export const openAICompatible = (options: PluginOptions) => {
112113
let listActionsCache;
114+
var client: OpenAI;
115+
function createClient() {
116+
if (client) return client;
117+
const { apiKey, ...restofOptions } = options;
118+
client = new OpenAI({
119+
...restofOptions,
120+
apiKey: apiKey === false ? 'placeholder' : apiKey,
121+
});
122+
return client;
123+
}
113124
return genkitPluginV2({
114125
name: options.name,
115126
async init() {
116127
if (!options.initializer) {
117128
return [];
118129
}
119-
const client = new OpenAI(options);
120-
return await options.initializer(client);
130+
return await options.initializer(createClient());
121131
},
122132
async resolve(actionType: ActionType, actionName: string) {
123-
const client = new OpenAI(options);
124133
if (options.resolver) {
125-
return await options.resolver(client, actionType, actionName);
134+
return await options.resolver(createClient(), actionType, actionName);
126135
} else {
127136
if (actionType === 'model') {
128137
return defineCompatOpenAIModel({
129138
name: actionName,
130-
client,
139+
client: createClient(),
140+
pluginOptions: options,
131141
modelRef: compatOaiModelRef({
132142
name: actionName,
133143
}),
@@ -136,14 +146,15 @@ export const openAICompatible = (options: PluginOptions) => {
136146
return undefined;
137147
}
138148
},
139-
list: options.listActions
140-
? async () => {
141-
if (listActionsCache) return listActionsCache;
142-
const client = new OpenAI(options);
143-
listActionsCache = await options.listActions!(client);
144-
return listActionsCache;
145-
}
146-
: undefined,
149+
list:
150+
// Don't attempt to list models if apiKey set to false
151+
options.listActions && options.apiKey !== false
152+
? async () => {
153+
if (listActionsCache) return listActionsCache;
154+
listActionsCache = await options.listActions!(createClient());
155+
return listActionsCache;
156+
}
157+
: undefined,
147158
});
148159
};
149160

0 commit comments

Comments
 (0)