diff --git a/js/plugins/compat-oai/src/openai/gpt.ts b/js/plugins/compat-oai/src/openai/gpt.ts index 53b01501f..fa2168a7e 100644 --- a/js/plugins/compat-oai/src/openai/gpt.ts +++ b/js/plugins/compat-oai/src/openai/gpt.ts @@ -15,7 +15,7 @@ * limitations under the License. */ -import { z, type ModelReference } from 'genkit'; +import { z } from 'genkit'; import { GenerationCommonConfigSchema, modelRef } from 'genkit/model'; export const ChatCompletionConfigSchema = GenerationCommonConfigSchema.extend({ @@ -291,10 +291,7 @@ export const gpt35Turbo = modelRef({ configSchema: ChatCompletionConfigSchema, }); -export const SUPPORTED_GPT_MODELS: Record< - string, - ModelReference -> = { +export const SUPPORTED_GPT_MODELS = { 'gpt-4.5': gpt45, 'gpt-4o': gpt4o, 'gpt-4o-mini': gpt4oMini, diff --git a/js/plugins/compat-oai/src/openai/index.ts b/js/plugins/compat-oai/src/openai/index.ts index 97f8680dc..a5ad892dc 100644 --- a/js/plugins/compat-oai/src/openai/index.ts +++ b/js/plugins/compat-oai/src/openai/index.ts @@ -212,16 +212,45 @@ export function openAIPlugin(options?: OpenAIPluginOptions): GenkitPlugin { export type OpenAIPlugin = { (params?: OpenAIPluginOptions): GenkitPlugin; + model( + name: + | keyof typeof SUPPORTED_GPT_MODELS + | (`gpt-${string}` & {}) + | (`o${number}` & {}), + config?: z.infer + ): ModelReference; + model( + name: + | keyof typeof SUPPORTED_IMAGE_MODELS + | (`dall-e${string}` & {}) + | (`gpt-image-${string}` & {}), + config?: z.infer + ): ModelReference; + model( + name: + | keyof typeof SUPPORTED_TTS_MODELS + | (`tts-${string}` & {}) + | (`${string}-tts` & {}), + config?: z.infer + ): ModelReference; + model( + name: + | keyof typeof SUPPORTED_STT_MODELS + | (`whisper-${string}` & {}) + | (`${string}-transcribe` & {}), + config?: z.infer + ): ModelReference; model(name: string, config?: any): ModelReference; + embedder( + name: + | keyof typeof SUPPORTED_EMBEDDING_MODELS + | (`${string}-embedding-${string}` & {}), + config?: z.infer + ): EmbedderReference; embedder(name: string, config?: any): EmbedderReference; }; -export const openAI = openAIPlugin as OpenAIPlugin; -// provide generic implementation for the model function overloads. -(openAI as any).model = ( - name: string, - config?: any -): ModelReference => { +const model = ((name: string, config?: any): ModelReference => { if (name.includes('gpt-image-1') || name.includes('dall-e')) { return modelRef({ name: `openai/${name}`, @@ -248,8 +277,9 @@ export const openAI = openAIPlugin as OpenAIPlugin; config, configSchema: ChatCompletionConfigSchema, }); -}; -openAI.embedder = ( +}) as OpenAIPlugin['model']; + +const embedder = (( name: string, config?: any ): EmbedderReference => { @@ -258,6 +288,11 @@ openAI.embedder = ( config, configSchema: TextEmbeddingConfigSchema, }); -}; +}) as OpenAIPlugin['embedder']; + +export const openAI: OpenAIPlugin = Object.assign(openAIPlugin, { + model, + embedder, +}); export default openAI; diff --git a/js/testapps/compat-oai/src/index.ts b/js/testapps/compat-oai/src/index.ts index 9cbf4f347..31f642c1f 100644 --- a/js/testapps/compat-oai/src/index.ts +++ b/js/testapps/compat-oai/src/index.ts @@ -23,7 +23,7 @@ import { genkit, z } from 'genkit'; dotenv.config(); const ai = genkit({ - plugins: [openAI({ apiKey: process.env.OPENAI_API_KEY, name: 'openai' })], + plugins: [openAI()], }); export const jokeFlow = ai.defineFlow( @@ -51,7 +51,7 @@ export const embedFlow = ai.defineFlow( }, async (text) => { const embedding = await ai.embed({ - embedder: 'openai/text-embedding-ada-002', + embedder: openAI.embedder('text-embedding-ada-002'), content: text, });