6 changes: 6 additions & 0 deletions js/ai/src/model-types.ts
@@ -201,6 +201,12 @@ export const GenerationCommonConfigSchema = z
'Set of character sequences (up to 5) that will stop output generation.'
)
.optional(),
apiKey: z
.string()
.describe(
'API Key to use for the model call, overrides API key provided in plugin config.'
)
.optional(),
})
.passthrough();

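The new apiKey config field lets a caller override the plugin-level API key for a single model call. A minimal sketch, assuming the @genkit-ai/compat-oai package entry point; the plugin name, base URL, model name, and environment variables are placeholders, not part of this change:

```ts
import { genkit } from 'genkit';
import { openAICompatible } from '@genkit-ai/compat-oai';

const ai = genkit({
  plugins: [
    openAICompatible({
      name: 'myProvider',                    // placeholder plugin name
      baseURL: 'https://api.example.com/v1', // placeholder endpoint
      apiKey: process.env.DEFAULT_API_KEY,   // plugin-level key
    }),
  ],
});

// The per-request key overrides the key provided in the plugin config.
const { text } = await ai.generate({
  model: 'myProvider/some-model', // placeholder model name
  prompt: 'Hello!',
  config: { apiKey: process.env.TENANT_API_KEY },
});
```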
1 change: 1 addition & 0 deletions js/genkit/src/embedder.ts
@@ -17,6 +17,7 @@
export {
EmbedderInfoSchema,
embedderRef,
type EmbedRequest,
type EmbedderAction,
type EmbedderArgument,
type EmbedderInfo,
34 changes: 29 additions & 5 deletions js/plugins/compat-oai/src/audio.ts
@@ -22,13 +22,15 @@ import type {
import { GenerationCommonConfigSchema, Message, modelRef, z } from 'genkit';
import type { ModelAction, ModelInfo } from 'genkit/model';
import { model } from 'genkit/plugin';
import type OpenAI from 'openai';
import OpenAI from 'openai';
import { Response } from 'openai/core.mjs';
import type {
SpeechCreateParams,
Transcription,
TranscriptionCreateParams,
} from 'openai/resources/audio/index.mjs';
import { PluginOptions } from './index.js';
import { maybeCreateRequestScopedOpenAIClient } from './utils.js';

export type SpeechRequestBuilder = (
req: GenerateRequest,
@@ -185,10 +187,16 @@ export function defineCompatOpenAISpeechModel<
client: OpenAI;
modelRef?: ModelReference<CustomOptions>;
requestBuilder?: SpeechRequestBuilder;
pluginOptions: PluginOptions;
}): ModelAction {
const { name, client, modelRef, requestBuilder } = params;
const {
name,
client: defaultClient,
pluginOptions,
modelRef,
requestBuilder,
} = params;
const modelName = name.substring(name.indexOf('/') + 1);

return model(
{
name,
@@ -197,6 +205,11 @@
},
async (request, { abortSignal }) => {
const ttsRequest = toTTSRequest(modelName!, request, requestBuilder);
const client = maybeCreateRequestScopedOpenAIClient(
pluginOptions,
request,
defaultClient
);
const result = await client.audio.speech.create(ttsRequest, {
signal: abortSignal,
});
@@ -338,11 +351,17 @@ export function defineCompatOpenAITranscriptionModel<
>(params: {
name: string;
client: OpenAI;
pluginOptions?: PluginOptions;
modelRef?: ModelReference<CustomOptions>;
requestBuilder?: TranscriptionRequestBuilder;
}): ModelAction {
const { name, client, modelRef, requestBuilder } = params;

const {
name,
pluginOptions,
client: defaultClient,
modelRef,
requestBuilder,
} = params;
return model(
{
name,
@@ -353,6 +372,11 @@
const modelName = name.substring(name.indexOf('/') + 1);

const params = toSttRequest(modelName!, request, requestBuilder);
const client = maybeCreateRequestScopedOpenAIClient(
pluginOptions,
request,
defaultClient
);
// Explicitly setting stream to false ensures we use the non-streaming overload
const result = await client.audio.transcriptions.create(
{
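The maybeCreateRequestScopedOpenAIClient helper is imported from ./utils.js, but its body is not part of this diff. Based on how it is called above, a plausible sketch (an assumption, not the actual implementation) is:

```ts
import OpenAI from 'openai';
import type { PluginOptions } from './index.js';

// Sketch only: return the default client unless the incoming request carries its own
// apiKey (in `config` for model requests, `options` for embed requests); if it does,
// build a client scoped to that key from the plugin's remaining client options.
export function maybeCreateRequestScopedOpenAIClient(
  pluginOptions: PluginOptions | undefined,
  request: { config?: unknown; options?: unknown },
  defaultClient: OpenAI
): OpenAI {
  const requestKey =
    (request.config as { apiKey?: string } | undefined)?.apiKey ??
    (request.options as { apiKey?: string } | undefined)?.apiKey;
  if (!requestKey) return defaultClient;
  const {
    name: _name,
    initializer: _init,
    resolver: _resolve,
    listActions: _list,
    apiKey: _pluginKey,
    ...clientOptions
  } = pluginOptions ?? ({ name: 'unknown' } as PluginOptions);
  return new OpenAI({ ...clientOptions, apiKey: requestKey });
}
```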
43 changes: 22 additions & 21 deletions js/plugins/compat-oai/src/deepseek/index.ts
@@ -36,26 +36,25 @@ import {

export type DeepSeekPluginOptions = Omit<PluginOptions, 'name' | 'baseURL'>;

const resolver = async (
client: OpenAI,
actionType: ActionType,
actionName: string
) => {
if (actionType === 'model') {
const modelRef = deepSeekModelRef({
name: actionName,
});
return defineCompatOpenAIModel({
name: modelRef.name,
client,
modelRef,
requestBuilder: deepSeekRequestBuilder,
});
} else {
logger.warn('Only model actions are supported by the DeepSeek plugin');
return undefined;
}
};
function createResolver(pluginOptions: PluginOptions) {
return async (client: OpenAI, actionType: ActionType, actionName: string) => {
if (actionType === 'model') {
const modelRef = deepSeekModelRef({
name: actionName,
});
return defineCompatOpenAIModel({
name: modelRef.name,
client,
pluginOptions,
modelRef,
requestBuilder: deepSeekRequestBuilder,
});
} else {
logger.warn('Only model actions are supported by the DeepSeek plugin');
return undefined;
}
};
}

const listActions = async (client: OpenAI): Promise<ActionMetadata[]> => {
return await client.models.list().then((response) =>
@@ -87,6 +86,7 @@ export function deepSeekPlugin(
'Please pass in the API key or set the DEEPSEEK_API_KEY environment variable.',
});
}
const pluginOptions = { name: 'deepseek', ...options };
return openAICompatible({
name: 'deepseek',
baseURL: 'https://api.deepseek.com',
@@ -97,12 +97,13 @@
defineCompatOpenAIModel({
name: modelRef.name,
client,
pluginOptions,
modelRef,
requestBuilder: deepSeekRequestBuilder,
})
);
},
resolver,
resolver: createResolver(pluginOptions),
listActions,
});
}
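As a usage illustration for the DeepSeek plugin, a per-request key now flows through createResolver into defineCompatOpenAIModel. The subpath import and model name below are assumptions; only deepSeekPlugin itself is defined in this file:

```ts
import { genkit } from 'genkit';
import { deepSeekPlugin } from '@genkit-ai/compat-oai/deepseek'; // assumed import path

const ai = genkit({
  // DEEPSEEK_API_KEY (or apiKey here) still supplies the default credential.
  plugins: [deepSeekPlugin({ apiKey: process.env.DEEPSEEK_API_KEY })],
});

const { text } = await ai.generate({
  model: 'deepseek/deepseek-chat', // placeholder model name
  prompt: 'Summarize this change.',
  // Overrides the plugin-level key for this call only.
  config: { apiKey: process.env.CUSTOMER_DEEPSEEK_KEY },
});
```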
11 changes: 9 additions & 2 deletions js/plugins/compat-oai/src/embedder.ts
@@ -20,6 +20,8 @@
import type { EmbedderAction, EmbedderReference } from 'genkit';
import { embedder } from 'genkit/plugin';
import OpenAI from 'openai';
import { PluginOptions } from './index.js';
import { maybeCreateRequestScopedOpenAIClient } from './utils.js';

/**
 * Method to define a new Genkit Embedder that is compatible with the Open AI
@@ -37,11 +39,11 @@ import OpenAI from 'openai';
export function defineCompatOpenAIEmbedder(params: {
name: string;
client: OpenAI;
pluginOptions?: PluginOptions;
embedderRef?: EmbedderReference;
}): EmbedderAction {
const { name, client, embedderRef } = params;
const { name, client: defaultClient, pluginOptions, embedderRef } = params;
const modelName = name.substring(name.indexOf('/') + 1);

return embedder(
{
name,
@@ -50,6 +52,11 @@ export function defineCompatOpenAIEmbedder(params: {
},
async (req) => {
const { encodingFormat: encoding_format, ...restOfConfig } = req.options;
const client = maybeCreateRequestScopedOpenAIClient(
pluginOptions,
req,
defaultClient
);
const embeddings = await client.embeddings.create({
model: modelName!,
input: req.input.map((d) => d.text),
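A hedged example of the embedder path, reusing the ai instance from the earlier sketch. The embedder name is a placeholder, and this assumes the helper reads apiKey from the embed request's options, as the call above suggests:

```ts
// Per-request key for an embedding call; the plugin-level key is used when
// options.apiKey is absent.
const embeddings = await ai.embed({
  embedder: 'myProvider/text-embedding-3-small', // placeholder embedder name
  content: 'Hello world',
  options: { apiKey: process.env.TENANT_API_KEY },
});
```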
16 changes: 15 additions & 1 deletion js/plugins/compat-oai/src/image.ts
@@ -27,6 +27,8 @@ import type {
ImageGenerateParams,
ImagesResponse,
} from 'openai/resources/images.mjs';
import { PluginOptions } from './index.js';
import { maybeCreateRequestScopedOpenAIClient } from './utils.js';

export type ImageRequestBuilder = (
req: GenerateRequest,
@@ -122,10 +124,17 @@ export function defineCompatOpenAIImageModel<
>(params: {
name: string;
client: OpenAI;
pluginOptions?: PluginOptions;
modelRef?: ModelReference<CustomOptions>;
requestBuilder?: ImageRequestBuilder;
}): ModelAction<CustomOptions> {
const { name, client, modelRef, requestBuilder } = params;
const {
name,
client: defaultClient,
pluginOptions,
modelRef,
requestBuilder,
} = params;
const modelName = name.substring(name.indexOf('/') + 1);

return model(
@@ -135,6 +144,11 @@
configSchema: modelRef?.configSchema,
},
async (request, { abortSignal }) => {
const client = maybeCreateRequestScopedOpenAIClient(
pluginOptions,
request,
defaultClient
);
const result = await client.images.generate(
toImageGenerateParams(modelName!, request, requestBuilder),
{ signal: abortSignal }
41 changes: 26 additions & 15 deletions js/plugins/compat-oai/src/index.ts
@@ -17,7 +17,7 @@
import { ActionMetadata } from 'genkit';
import { ResolvableAction, genkitPluginV2 } from 'genkit/plugin';
import { ActionType } from 'genkit/registry';
import { OpenAI, type ClientOptions } from 'openai';
import OpenAI, { type ClientOptions } from 'openai';
import { compatOaiModelRef, defineCompatOpenAIModel } from './model.js';

export {
@@ -45,7 +45,8 @@ export {
type ModelRequestBuilder,
} from './model.js';

export interface PluginOptions extends Partial<ClientOptions> {
export interface PluginOptions extends Partial<Omit<ClientOptions, 'apiKey'>> {
apiKey?: ClientOptions['apiKey'] | false;
name: string;
initializer?: (client: OpenAI) => Promise<ResolvableAction[]>;
resolver?: (
@@ -110,24 +111,33 @@
*/
export const openAICompatible = (options: PluginOptions) => {
let listActionsCache;
var client: OpenAI;
function createClient() {
if (client) return client;
const { apiKey, ...restofOptions } = options;
client = new OpenAI({
...restofOptions,
apiKey: apiKey === false ? 'placeholder' : apiKey,
});
return client;
}
return genkitPluginV2({
name: options.name,
async init() {
if (!options.initializer) {
return [];
}
const client = new OpenAI(options);
return await options.initializer(client);
return await options.initializer(createClient());
},
async resolve(actionType: ActionType, actionName: string) {
const client = new OpenAI(options);
if (options.resolver) {
return await options.resolver(client, actionType, actionName);
return await options.resolver(createClient(), actionType, actionName);
} else {
if (actionType === 'model') {
return defineCompatOpenAIModel({
name: actionName,
client,
client: createClient(),
pluginOptions: options,
modelRef: compatOaiModelRef({
name: actionName,
}),
@@ -136,14 +146,15 @@ export const openAICompatible = (options: PluginOptions) => {
return undefined;
}
},
list: options.listActions
? async () => {
if (listActionsCache) return listActionsCache;
const client = new OpenAI(options);
listActionsCache = await options.listActions!(client);
return listActionsCache;
}
: undefined,
list:
// Don't attempt to list models if apiKey set to false
options.listActions && options.apiKey !== false
? async () => {
if (listActionsCache) return listActionsCache;
listActionsCache = await options.listActions!(createClient());
return listActionsCache;
}
: undefined,
});
};

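Finally, the widened option (apiKey?: ClientOptions['apiKey'] | false) lets a deployment skip the plugin-level key entirely and rely on per-request keys; model listing is disabled in that mode because the placeholder key could not authenticate the list call. A sketch, reusing the imports from the first example; provider details are placeholders:

```ts
const ai = genkit({
  plugins: [
    openAICompatible({
      name: 'myProvider',
      baseURL: 'https://api.example.com/v1', // placeholder endpoint
      apiKey: false, // no default key; each call must supply config.apiKey
    }),
  ],
});

await ai.generate({
  model: 'myProvider/some-model', // placeholder model name
  prompt: 'Hi',
  config: { apiKey: process.env.CALLER_PROVIDED_KEY }, // e.g. a key taken from the incoming request
});
```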