From 3d67fd50c8bf63dec531a5919db14059d1d47a5a Mon Sep 17 00:00:00 2001
From: Tristan Chin <23557893+maxijonson@users.noreply.github.com>
Date: Sun, 18 Jun 2023 02:33:09 -0400
Subject: [PATCH] Function Calling and pricing changes (#24)

---
 README.md                                     |   4 +-
 packages/cli/src/components/Message.tsx       |   4 +
 packages/cli/src/hooks/usePagedMessages.ts    |  20 +-
 .../src/managers/ConversationManager.ts       |   3 +
 packages/lib/README.md                        | 109 +++++++++-
 packages/lib/package.json                     |   1 -
 packages/lib/src/classes/Conversation.ts      | 193 ++++++++++++++++--
 .../lib/src/classes/ConversationConfig.ts     |  10 +
 packages/lib/src/classes/Message.ts           |  99 ++++++++-
 packages/lib/src/config/constants.ts          |   6 +-
 .../src/schemas/conversationConfig.schema.ts  |  19 ++
 packages/lib/src/schemas/index.ts             |   1 +
 packages/lib/src/schemas/jsonSchema.schema.ts | 185 +++++++++++++++++
 packages/lib/src/schemas/message.schema.ts    |  10 +-
 packages/lib/src/utils/getPricing.ts          |   3 +-
 packages/lib/src/utils/types.ts               | 100 ++++++++-
 packages/nest/package.json                    |   3 +-
 packages/web/package.json                     |   2 +-
 packages/web/src/components/Message.tsx       |  42 +++-
 19 files changed, 758 insertions(+), 56 deletions(-)
 create mode 100644 packages/lib/src/schemas/jsonSchema.schema.ts

diff --git a/README.md b/README.md
index 6f2139a..22f82ef 100644
--- a/README.md
+++ b/README.md
@@ -18,9 +18,11 @@ GPT Turbo is a JavaScript library for seamless integration with OpenAI's Chat Co
 
 ## Features
 
+> ✨ New (June 2023): Added support for **Function calling**
+
 🤖 Supports all Chat Completion models, including **GPT-4**. (full list [here](https://platform.openai.com/docs/models/model-endpoint-compatibility))
 
-💬 Supports both single and streamed completions, just like ChatGPT.
+💬 Supports single, streamed, and function completions, just like ChatGPT.
 
 ⚙ Tune chat completion parameters, such as temperature, top-p, and frequency penalty. 
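A quick consumer-side sketch of what this PR enables end to end. This is illustrative, not code from the patch: it assumes the published package name `gpt-turbo` and uses only APIs visible in this diff (`Conversation#prompt` plus the new `isCompletion`/`isFunctionCall` type guards).

```ts
import { Conversation } from "gpt-turbo";

// Minimal sketch: a single prompt whose response may now be either a
// regular completion or a function call, as introduced by this PR.
const conversation = new Conversation({
    apiKey: process.env.OPENAI_API_KEY ?? "",
});

const respond = async (prompt: string) => {
    const message = await conversation.prompt(prompt);
    if (message.isCompletion()) {
        // Regular chat completion: `content` is a string.
        console.info(message.content);
    } else if (message.isFunctionCall()) {
        // The assistant wants a function evaluated: `content` is null and
        // `functionCall` carries the name and the parsed JSON arguments.
        const { name, arguments: args } = message.functionCall;
        console.info(`Assistant requested ${name}(${JSON.stringify(args)})`);
    }
};

await respond("Where can I find Obi-Wan Kenobi?");
```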
diff --git a/packages/cli/src/components/Message.tsx b/packages/cli/src/components/Message.tsx index 2ff99b1..5e56d0a 100644 --- a/packages/cli/src/components/Message.tsx +++ b/packages/cli/src/components/Message.tsx @@ -5,11 +5,13 @@ import React from "react"; export const SENDER_USER = "You"; export const SENDER_ASSISTANT = "GPT"; export const SENDER_SYSTEM = "SYS"; +export const SENDER_FUNCTION = "FUN"; export const SENDER_SUFFIX = ": "; export const SENDER_WIDTH = [ SENDER_USER, SENDER_ASSISTANT, SENDER_SYSTEM, + SENDER_FUNCTION, ].reduce( (max, sender) => Math.max(max, sender.length + SENDER_SUFFIX.length), 0 @@ -28,6 +30,8 @@ export default ({ message }: MessageProps) => { return SENDER_ASSISTANT; case "system": return SENDER_SYSTEM; + case "function": + return SENDER_FUNCTION; case "user": default: return SENDER_USER; diff --git a/packages/cli/src/hooks/usePagedMessages.ts b/packages/cli/src/hooks/usePagedMessages.ts index 3748120..6163747 100644 --- a/packages/cli/src/hooks/usePagedMessages.ts +++ b/packages/cli/src/hooks/usePagedMessages.ts @@ -29,7 +29,23 @@ export default (messages: Message[], maxWidth: number, maxHeight: number) => { const msgs = messages.slice(); for (let i = 0; i < msgs.length; i++) { const message = msgs[i]; - const messageHeight = getMessageHeight(message.content, maxWidth); + const messageContent = (() => { + if (message.isCompletion()) { + return message.content; + } + if (message.isFunction()) { + return `${message.name}() => ${message.content}`; + } + if (message.isFunctionCall()) { + const { name, arguments: args } = message.functionCall; + const parameters = Object.entries(args) + .map(([param, value]) => `${param}=${value}`) + .join(", "); + return `${name}(${parameters})`; + } + return "[Unknown message type]"; + })(); + const messageHeight = getMessageHeight(messageContent, maxWidth); const isHuge = messageHeight > maxHeight; const isOverflowing = pageHeight + messageHeight > maxHeight; @@ -37,7 +53,7 @@ export default (messages: Message[], maxWidth: number, maxHeight: number) => { if (isHuge) { const remainingHeight = maxHeight - pageHeight; const [firstMessageContent, secondMessageContent] = - splitMessage(message.content, maxWidth, remainingHeight); + splitMessage(messageContent, maxWidth, remainingHeight); if (firstMessageContent.length && secondMessageContent.length) { msgs[i] = new Message( diff --git a/packages/discord/src/managers/ConversationManager.ts b/packages/discord/src/managers/ConversationManager.ts index f296990..d5d0ba1 100644 --- a/packages/discord/src/managers/ConversationManager.ts +++ b/packages/discord/src/managers/ConversationManager.ts @@ -39,6 +39,9 @@ export default class ConversationManager { max_tokens: maxTokens, }); + // Should never happen, since we're not using functions. But this check provides type guards. 
+        if (!response.isCompletion()) throw new Error("Not a completion");
+
         try {
             await conversation.addAssistantMessage(response.content);
         } finally {
diff --git a/packages/lib/README.md b/packages/lib/README.md
index d36bf1c..36f9b8c 100644
--- a/packages/lib/README.md
+++ b/packages/lib/README.md
@@ -47,13 +47,13 @@ const conversation = new Conversation({
 });
 
 const response = await conversation.prompt("How can I make my code more efficient than a droid army?");
-process.stdout.write(`Response: `);
-const unsubscribe = response.onMessageUpdate((content) => {
-    process.stdout.write(content);
+const unsubscribeUpdate = response.onMessageUpdate((content) => {
+    console.log(content);
 });
 
-response.onStreamingStop(() => {
-    unsubscribe();
+const unsubscribeStop = response.onStreamingStop(() => {
+    unsubscribeUpdate();
+    unsubscribeStop();
 });
 ```
 
@@ -147,6 +147,105 @@ const second = await conversation.prompt("Take a seat, young Skywalker."); // "I
 const edit = await conversation.reprompt(first, "We grant you the rank of Master.");
 ```
 
+### Function Calling
+
+> ⚠ Function calling is relatively new and the implementation in this library may change as more is discovered about it.
+>
+> Limitations (of the GPT Turbo library) with function calling:
+> - Token count is not currently calculated for assistant function calls and context. This means the cost of function calls is not taken into account at the moment. This will be fixed in a future release, as I learn more about how function call tokens are calculated by OpenAI.
+> - Function calls are not currently supported in dry mode. There is no planned support for this in the near future.
+> - While this feature is typed, it may not be as strongly typed as you'd expect. In other words, there's no strict type checking of the function name and arguments against the definition you gave to the configuration's `functions` property. This may or may not be improved in the future, depending on how relevant strong typing is for this feature without sacrificing usability.
+
+You can use OpenAI's Function Calling feature with GPT Turbo through the `functionPrompt` method. Define your functions in the conversation configuration (or during prompting), just like you would with the Chat Completion API directly.
+
+⚠ Unless you configure `function_call` to explicitly call a function by name (it defaults to `auto`), make sure you also plan for standard chat completions in your code. To help with detecting which type of response you got, the `Message` class exposes two (type-guarded!) functions: `isFunctionCall` and `isCompletion`.
+
+> At the time of writing, Function Calling is not supported on the latest version of the GPT model. In this example, we'll use the `gpt-3.5-turbo-0613` model, but the standard `gpt-3.5-turbo` model might work at the time you're reading this.
+
+```ts
+const locateJedi = (jedi, locationType = "planet") => {
+    return {
+        name: jedi,
+        location: locationType === "planet" ? 
"Tatooine" : "Mos Eisley", + }; +}; + +const conversation = new Conversation({ + apiKey: /** Your API key */, + model: "gpt-3.5-turbo-0613", + functions: [ + { + name: "locateJedi", + description: "Returns the current location of a Jedi", + parameters: { + type: "object", + properties: { + jedi: { + type: "string", + description: "The name of the Jedi to locate", + }, + locationType: { + type: "string", + enum: ["planet", "city"], + }, + }, + required: ["jedi"], + }, + }, + ], +}); + +const r1 = await conversation.prompt("Where can I find Obi-Wan Kenobi?"); + +if (r1.isCompletion()) { + console.info(r1.content); +} else if (r1.isFunctionCall()) { + const { jedi, locationType } = r1.functionCall.arguments; + const r2 = await conversation.functionPrompt( + r1.functionCall.name, + locateJedi(jedi, locationType) + ); + console.info(r2.content); // "Obi-Wan Kenobi can be found on Tatooine." +} +``` + +For streamed completions and function calls, it gets a bit more complicated, but still supported! Hopefully, a better flow will be implemented in the future. + +```ts +const conversation = new Conversation({ /* ... */, stream: true }); + +const r1 = await conversation.prompt("In which city is Obi-Wan Kenobi?"); + +const unsubscribeUpdates = r1.onMessageUpdate((_, message) => { + if (!message.isCompletion()) { + return; + } + console.info(message.content); +}); + +const unsubscribeStop = r1.onMessageStreamingStop(async (message) => { + if (message.isFunctionCall()) { + const { jedi, locationType } = message.functionCall.arguments; + const r2 = await conversation.functionPrompt( + message.functionCall.name, + locateJedi(jedi, locationType) + ); + + const unsubscribeFunctionUpdate = r2.onMessageUpdate((content) => { + console.info(content); // "Obi-Wan Kenobi is located in the city of Mos Eisley." + }); + + const unsubscribeFunctionStop = r2.onMessageStreamingStop(() => { + unsubscribeFunctionUpdate(); + unsubscribeFunctionStop(); + }); + } + + unsubscribeUpdates(); + unsubscribeStop(); +}); +``` + ## Documentation View the full documentation [here](https://gpt-turbo.chintristan.io/). The documentation website is auto-generated based on the TSdoc comments in the source code for the latest version of the library. 
diff --git a/packages/lib/package.json b/packages/lib/package.json index d107a06..2bafd59 100644 --- a/packages/lib/package.json +++ b/packages/lib/package.json @@ -9,7 +9,6 @@ "lint:strict": "npm run lint -- --max-warnings 0", "lint:fix": "npm run lint -- --fix", "build": "npm run lint:strict && rimraf dist && tsc -p tsconfig.build.json && copyfiles -u 1 -e \"src/**/*.ts\" \"src/**/*\" dist", - "sandbox": "ts-node-esm sandbox/index.ts", "docs": "typedoc" }, "keywords": [ diff --git a/packages/lib/src/classes/Conversation.ts b/packages/lib/src/classes/Conversation.ts index 5c9183a..e03c36b 100644 --- a/packages/lib/src/classes/Conversation.ts +++ b/packages/lib/src/classes/Conversation.ts @@ -2,7 +2,6 @@ import { ConversationConfig } from "./ConversationConfig.js"; import { createChatCompletion, createDryChatCompletion, - getMessageSize, } from "../utils/index.js"; import { ModerationException } from "../exceptions/ModerationException.js"; import { Message } from "./Message.js"; @@ -11,6 +10,8 @@ import { v4 as uuid } from "uuid"; import { AddMessageListener, ConversationConfigParameters, + CreateChatCompletionFunctionCallMessage, + CreateChatCompletionFunctionMessage, HandleChatCompletionOptions, PromptOptions, RemoveMessageListener, @@ -81,14 +82,38 @@ export class Conversation { for (const message of messages) { switch (message.role) { case "user": + if (message.content === null) + throw new Error("User message content cannot be null."); await conversation.addUserMessage(message.content); break; case "assistant": - await conversation.addAssistantMessage(message.content); + if (message.content === null) { + if (!message.function_call) + throw new Error("Function call must be provided."); + await conversation.addFunctionCallMessage( + message.function_call + ); + } else { + await conversation.addAssistantMessage(message.content); + } break; case "system": + if (message.content === null) + throw new Error("Context cannot be null."); conversation.setContext(message.content); break; + case "function": + if (!message.name) + throw new Error("Function name must be specified."); + if (message.content === null) + throw new Error( + "Function message content cannot be null." 
+                    );
+                    await conversation.addFunctionMessage(
+                        message.content,
+                        message.name
+                    );
+                    break;
             }
         }
 
@@ -164,6 +189,9 @@ export class Conversation {
             message,
             this.config.model
         );
+        if (!assistantMessage.isCompletion()) {
+            throw new Error("Not a completion message.");
+        }
         return this.addMessage(assistantMessage);
     }
 
@@ -175,9 +203,41 @@ export class Conversation {
      */
     public addUserMessage(message: string) {
         const userMessage = new Message("user", message, this.config.model);
+        if (!userMessage.isCompletion()) {
+            throw new Error("Not a completion message.");
+        }
         return this.addMessage(userMessage);
     }
 
+    public addFunctionCallMessage(functionCall: {
+        name: string;
+        arguments: Record<string, any>;
+    }) {
+        const functionCallMessage = new Message(
+            "assistant",
+            null,
+            this.config.model
+        );
+        functionCallMessage.functionCall = functionCall;
+        if (!functionCallMessage.isFunctionCall()) {
+            throw new Error("Not a function call message.");
+        }
+        return this.addMessage(functionCallMessage);
+    }
+
+    public addFunctionMessage(message: string, name: string) {
+        const functionMessage = new Message(
+            "function",
+            message,
+            this.config.model
+        );
+        functionMessage.name = name;
+        if (!functionMessage.isFunction()) {
+            throw new Error("Not a function message.");
+        }
+        return this.addMessage(functionMessage);
+    }
+
     /**
      * Get the messages in the conversation.
      *
@@ -335,13 +395,16 @@ export class Conversation {
         if (fromIndex === -1) {
             throw new Error(`Message with ID "${id}" not found.`);
         }
-        const from = this.messages[fromIndex];
 
         // Find the previous user message
-        const previousUserMessageIndex =
-            from.role === "user" ? fromIndex : fromIndex - 1;
-        const previousUserMessage = this.messages[previousUserMessageIndex];
-        if (!previousUserMessage) {
+        let previousUserMessageIndex = fromIndex;
+        let previousUserMessage = this.messages[previousUserMessageIndex];
+        while (previousUserMessage.role !== "user") {
+            previousUserMessageIndex--;
+            if (previousUserMessageIndex < 0) break;
+            previousUserMessage = this.messages[previousUserMessageIndex];
+        }
+        if (previousUserMessage?.role !== "user") {
             throw new Error(
                 `Could not find a previous user message to reprompt from (${id}).`
             );
@@ -370,6 +433,39 @@ export class Conversation {
         }
     }
 
+    /**
+     * Sends the result of a user-evaluated function call to the GPT model and gets the assistant's response.
+     * This method should usually be called after receiving a function_call message from the assistant (using `getChatCompletionResponse()` or `prompt()`) and evaluating your own function with the provided arguments from that message.
+     *
+     * @param name The name of the function used to generate the result. This function must be defined in the `functions` config option.
+     * @param result The result of the function call. If the result is anything other than a string, it will be JSON stringified. Since `result` can be anything, the `T` template is provided for your typing convenience, but is not used internally.
+     * @param options Additional options to pass to the Create Chat Completion API endpoint. This overrides the config passed to the constructor.
+     * @param requestOptions Additional options to pass for the HTTP request. This overrides the config passed to the constructor.
+     * @returns The assistant's response as a [`Message`](./Message.js) instance.
+     */
+    public async functionPrompt<T = any>(
+        name: string,
+        result: T,
+        options?: PromptOptions,
+        requestOptions?: RequestOptions
+    ) {
+        const functionMessage = await this.addFunctionMessage(
+            typeof result === "string" ? 
result : JSON.stringify(result),
+            name
+        );
+
+        try {
+            const assistantMessage = await this.getAssistantResponse(
+                options,
+                requestOptions
+            );
+            return assistantMessage;
+        } catch (e) {
+            this.removeMessage(functionMessage);
+            throw e;
+        }
+    }
+
     /**
      * Returns the sum of the token count of each message in the conversation's current messages. The next time `getChatCompletionResponse()` is called, this is the minimum amount of tokens that will be sent to the OpenAI API (estimated).
      */
@@ -440,7 +536,9 @@ export class Conversation {
     }
 
     private async addMessage(message: Message) {
-        message.content = message.content.trim();
+        if (message.isCompletion() || message.isFunction()) {
+            message.content = message.content.trim();
+        }
 
         if (!message.content && message.role === "user") {
             throw new Error("User message content cannot be empty.");
@@ -457,7 +555,7 @@ export class Conversation {
         }
 
         if (message.role === "system") {
-            this.config.context = message.content;
+            this.config.context = message.content ?? "";
             if (message.content) {
                 // Update the system message or add it if it doesn't exist.
                 if (this.messages[0]?.role === "system") {
@@ -492,14 +590,14 @@ export class Conversation {
         requestOptions: RequestOptions = {}
     ) {
         const message = new Message("assistant", "", this.config.model);
-        const messages = this.messages.map(({ role, content }) => ({
-            role,
-            content,
-        }));
+        const messages = this.getCreateChatCompletionMessages();
 
         const unsubscribeStreaming = message.onMessageStreamingStop((m) => {
-            this.cumulativeSize += this.getSize() + getMessageSize(m.content);
-            this.cumulativeCost += this.getCost() + m.cost;
+            // FIXME: Find out how the size is calculated for messages with function calls, fix in Message class and remove this condition
+            if (!message.isFunctionCall()) {
+                this.cumulativeSize += this.getSize() + m.size;
+                this.cumulativeCost += this.getCost() + m.cost;
+            }
             unsubscribeStreaming();
         });
 
@@ -538,10 +636,7 @@ export class Conversation {
         requestOptions: RequestOptions = {}
     ) {
         const message = new Message("assistant", "", this.config.model);
-        const messages = this.messages.map(({ role, content }) => ({
-            role,
-            content,
-        }));
+        const messages = this.getCreateChatCompletionMessages();
 
         if (this.config.dry) {
             await new Promise((resolve) => setTimeout(resolve, 1000));
@@ -559,11 +654,29 @@ export class Conversation {
                     ...requestOptions,
                 }
             );
-            message.content = response.choices[0].message?.content ?? "";
+            const responseMessage = response.choices[0].message;
+            message.content = responseMessage.content;
+            if (responseMessage.function_call) {
+                try {
+                    message.functionCall = {
+                        name: responseMessage.function_call.name,
+                        arguments: JSON.parse(
+                            responseMessage.function_call.arguments
+                        ),
+                    };
+                } catch {
+                    throw new Error(
+                        "Assistant did not generate valid JSON arguments." 
+                    );
+                }
+            }
         }
 
-        this.cumulativeSize += this.getSize() + getMessageSize(message.content);
-        this.cumulativeCost += this.getCost() + message.cost;
+        // FIXME: Find out how the size is calculated for messages with function calls, fix in Message class and remove this condition
+        if (!message.isFunctionCall()) {
+            this.cumulativeSize += this.getSize() + message.size;
+            this.cumulativeCost += this.getCost() + message.cost;
+        }
 
         return message;
     }
@@ -579,4 +692,40 @@ export class Conversation {
         const assistantMessage = await this.addMessage(completion);
         return assistantMessage;
     }
+
+    private getCreateChatCompletionMessages() {
+        return this.messages.map((message) => {
+            if (message.isFunctionCall()) {
+                const m: CreateChatCompletionFunctionCallMessage = {
+                    role: message.role,
+                    content: message.content,
+                    function_call: {
+                        name: message.functionCall.name,
+                        arguments: JSON.stringify(
+                            message.functionCall.arguments
+                        ),
+                    },
+                };
+                return m;
+            }
+
+            if (message.isFunction()) {
+                const m: CreateChatCompletionFunctionMessage = {
+                    content: message.content,
+                    name: message.name,
+                    role: message.role,
+                };
+                return m;
+            }
+
+            if (message.isCompletion()) {
+                return {
+                    content: message.content,
+                    role: message.role,
+                };
+            }
+
+            throw new Error("Message type not recognized.");
+        });
+    }
 }
diff --git a/packages/lib/src/classes/ConversationConfig.ts b/packages/lib/src/classes/ConversationConfig.ts
index 06c02f2..4d42ad7 100644
--- a/packages/lib/src/classes/ConversationConfig.ts
+++ b/packages/lib/src/classes/ConversationConfig.ts
@@ -41,6 +41,8 @@ export class ConversationConfig {
     public temperature: ConversationConfigChatCompletionOptions["temperature"];
     public topP: ConversationConfigChatCompletionOptions["top_p"];
     public user: ConversationConfigChatCompletionOptions["user"];
+    public functions: ConversationConfigChatCompletionOptions["functions"];
+    public functionCall: ConversationConfigChatCompletionOptions["function_call"];
 
     private _stop: ConversationConfigChatCompletionOptions["stop"];
     private _apiKey!: ConversationConfigChatCompletionRequiredOption<"apiKey">;
@@ -66,6 +68,8 @@ export class ConversationConfig {
             temperature,
             top_p,
             user,
+            functions,
+            function_call,
         } = chatCompletionConfig;
 
         this.apiKey = apiKey;
@@ -83,6 +87,8 @@ export class ConversationConfig {
         this.temperature = temperature;
         this.topP = top_p;
         this.user = user;
+        this.functions = functions;
+        this.functionCall = function_call;
     }
 
     /**
@@ -116,6 +122,8 @@ export class ConversationConfig {
             top_p: this.topP,
             user: this.user,
             stop: this.stop,
+            functions: this.functions,
+            function_call: this.functionCall,
         };
         return conversationConfigSchema.parse(json);
     }
@@ -147,6 +155,8 @@ export class ConversationConfig {
             temperature: this.temperature,
             top_p: this.topP,
             user: this.user,
+            functions: this.functions,
+            function_call: this.functionCall,
         };
     }
 
diff --git a/packages/lib/src/classes/Message.ts b/packages/lib/src/classes/Message.ts
index e5ca875..13e751e 100644
--- a/packages/lib/src/classes/Message.ts
+++ b/packages/lib/src/classes/Message.ts
@@ -8,6 +8,9 @@ import {
     MessageStreamingStartListener,
     MessageStreamingStopListener,
     MessageUpdateListener,
+    FunctionCallMessage,
+    FunctionMessage,
+    CompletionMessage,
 } from "../utils/types.js";
 import createModeration from "../utils/createModeration.js";
 import { MessageModel, messageSchema } from "../schemas/message.schema.js";
@@ -23,7 +26,14 @@ export class Message {
     private _role!: ChatCompletionRequestMessageRoleEnum;
     private _model!: string;
-    private _content!: string;
+    private 
_content!: string | null;
+    private _name: string | undefined;
+    private _functionCall:
+        | {
+              name: string;
+              arguments: Record<string, any>;
+          }
+        | undefined;
     private _flags: string[] | null = null;
     private _size: number | null = null;
     private _cost: number | null = null;
@@ -41,7 +51,7 @@ export class Message {
      */
     constructor(
         role: ChatCompletionRequestMessageRoleEnum = "user",
-        content = "",
+        content: string | null = "",
         model = ""
     ) {
         this.role = role;
@@ -78,6 +88,8 @@ export class Message {
             role: this.role,
             content: this.content,
             model: this.model,
+            name: this._name,
+            function_call: this._functionCall,
             flags: this.flags,
         };
         return messageSchema.parse(json);
@@ -207,6 +219,9 @@ export class Message {
         const reader = response.getReader();
         const decoder = new TextDecoder();
 
+        let functionCallName = "";
+        let functionCallArguments = "";
+
         while (true) {
             const { done, value } = await reader.read();
             if (done) break;
@@ -219,9 +234,29 @@ export class Message {
                         JSON.parse(chunk);
 
                     const content =
                         json?.choices?.[0]?.delta?.content ?? null;
-
-                    if (!content) continue;
-                    this.content += json.choices[0].delta.content;
+                    const functionCall =
+                        json?.choices?.[0]?.delta?.function_call ?? null;
+
+                    if (functionCall) {
+                        const { name, arguments: argsStr } = functionCall;
+                        functionCallName += name ?? "";
+                        functionCallArguments += argsStr ?? "";
+
+                        let args = this.functionCall
+                            ? { ...this.functionCall.arguments }
+                            : {};
+                        try {
+                            args = JSON.parse(functionCallArguments);
+                        } finally {
+                            this.functionCall = {
+                                name: functionCallName,
+                                arguments: args,
+                            };
+                        }
+                    } else {
+                        if (!content) continue;
+                        this.content += content;
+                    }
                 } catch {
                     continue;
                 }
@@ -234,6 +269,36 @@ export class Message {
         }
     }
 
+    /**
+     * Whether the message is a function call by the assistant
+     */
+    public isFunctionCall(): this is FunctionCallMessage {
+        return (
+            this.role === "assistant" &&
+            this.content === null &&
+            this.functionCall !== undefined
+        );
+    }
+
+    /**
+     * Whether the message is a function call result by the user
+     */
+    public isFunction(): this is FunctionMessage {
+        return this.role === "function" && this.name !== undefined;
+    }
+
+    /**
+     * Whether the message is a regular chat completion message
+     */
+    public isCompletion(): this is CompletionMessage {
+        return (
+            this.role !== "function" &&
+            this.content !== null &&
+            this.functionCall === undefined &&
+            this.name === undefined
+        );
+    }
+
     private notifyMessageUpdate() {
         const content = this.content;
         this.messageUpdateListeners.forEach((listener) =>
@@ -279,6 +344,26 @@ export class Message {
         this.notifyMessageUpdate();
     }
 
+    get name() {
+        return this._name;
+    }
+
+    set name(name) {
+        this._name = name;
+    }
+
+    get functionCall() {
+        return this._functionCall;
+    }
+
+    set functionCall(functionCall) {
+        this._functionCall = functionCall;
+        this.flags = null;
+        this.size = null;
+        this.cost = null;
+        this.content = null; // also call notifyMessageUpdate() to notify listeners
+    }
+
     /** The flags detected by OpenAI's moderation API. Only set after calling `moderate`. */
     get flags() {
         return this._flags?.slice() ?? 
null;
     }
@@ -298,6 +383,10 @@ export class Message {
         if (this._size) {
             return this._size;
         }
+        // FIXME: Find out how the size is calculated for messages with function calls
+        if (this._content === null) {
+            return 0;
+        }
         const s = getMessageSize(this._content);
         this.size = s;
         return this._size as typeof s;
diff --git a/packages/lib/src/config/constants.ts b/packages/lib/src/config/constants.ts
index 52fa350..45f5c43 100644
--- a/packages/lib/src/config/constants.ts
+++ b/packages/lib/src/config/constants.ts
@@ -20,9 +20,13 @@ export const PRICING_TABLE = {
         completion: 0,
     },
     "3.5": {
-        prompt: 0.000002,
+        prompt: 0.0000015,
         completion: 0.000002,
     },
+    "3.5-16k": {
+        prompt: 0.000003,
+        completion: 0.000004,
+    },
     "4": {
         prompt: 0.00003,
         completion: 0.00006,
diff --git a/packages/lib/src/schemas/conversationConfig.schema.ts b/packages/lib/src/schemas/conversationConfig.schema.ts
index 351affa..817efc0 100644
--- a/packages/lib/src/schemas/conversationConfig.schema.ts
+++ b/packages/lib/src/schemas/conversationConfig.schema.ts
@@ -1,4 +1,5 @@
 import { z } from "zod";
+import { jsonSchemaObjectSchema } from "./jsonSchema.schema.js";
 
 /**
  * A JSON representation of a ConversationConfig instance.
@@ -23,6 +24,24 @@ export const conversationConfigSchema = z.object({
         )
         .optional(),
     user: z.string().optional(),
+    functions: z
+        .array(
+            z.object({
+                name: z.string().min(1),
+                description: z.string().min(1).optional(),
+                properties: jsonSchemaObjectSchema.optional(),
+            })
+        )
+        .optional(),
+    function_call: z
+        .union([
+            z.literal("none"),
+            z.literal("auto"),
+            z.object({
+                name: z.string().min(1),
+            }),
+        ])
+        .optional(),
 });
 
 /**
diff --git a/packages/lib/src/schemas/index.ts b/packages/lib/src/schemas/index.ts
index cbd443c..4b4623f 100644
--- a/packages/lib/src/schemas/index.ts
+++ b/packages/lib/src/schemas/index.ts
@@ -1,3 +1,4 @@
 export * from "./conversation.schema.js";
 export * from "./conversationConfig.schema.js";
 export * from "./message.schema.js";
+export * from "./jsonSchema.schema.js";
diff --git a/packages/lib/src/schemas/jsonSchema.schema.ts b/packages/lib/src/schemas/jsonSchema.schema.ts
new file mode 100644
index 0000000..bb289df
--- /dev/null
+++ b/packages/lib/src/schemas/jsonSchema.schema.ts
@@ -0,0 +1,185 @@
+import { z } from "zod";
+
+// The following Zod schemas are all based on the JSON Schema specification and may not be complete:
+// https://json-schema.org/understanding-json-schema/reference/index.html
+
+// -------------------- Base --------------------
+export const jsonSchemaBaseSchema = z.object({
+    title: z.string().optional(),
+    description: z.string().optional(),
+    default: z.any().optional(),
+    examples: z.array(z.any()).optional(),
+    readOnly: z.boolean().optional(),
+    writeOnly: z.boolean().optional(),
+    deprecated: z.boolean().optional(),
+    $comment: z.string().optional(),
+});
+
+export type JsonSchemaBase = z.infer<typeof jsonSchemaBaseSchema>;
+
+// -------------------- String --------------------
+
+export const jsonSchemaStringSchema = jsonSchemaBaseSchema.extend({
+    type: z.literal("string"),
+    minLength: z.number().optional(),
+    maxLength: z.number().optional(),
+    pattern: z.string().optional(),
+    enum: z.array(z.string()).optional(),
+    format: z.union([
+        z.literal("date-time"),
+        z.literal("time"),
+        z.literal("date"),
+        z.literal("duration"),
+        z.literal("email"),
+        z.literal("idn-email"),
+        z.literal("hostname"),
+        z.literal("idn-hostname"),
+        z.literal("ipv4"),
+        z.literal("ipv6"),
+        z.literal("uuid"),
+        z.literal("uri"),
+        z.literal("uri-reference"),
+        z.literal("iri"),
 
z.literal("iri-reference"), + z.literal("uri-template"), + z.literal("json-pointer"), + z.literal("relative-json-pointer"), + z.literal("regex"), + ]), +}); + +export type JsonSchemaString = z.infer; + +// -------------------- Number -------------------- + +export const jsonSchemaNumberSchema = jsonSchemaBaseSchema.extend({ + type: z.literal("number"), + multipleOf: z.number().optional(), + minimum: z.number().optional(), + exclusiveMinimum: z.union([z.number(), z.boolean()]).optional(), + maximum: z.number().optional(), + exclusiveMaximum: z.union([z.number(), z.boolean()]).optional(), +}); + +export type JsonSchemaNumber = z.infer; + +// -------------------- Boolean -------------------- + +export const jsonSchemaBooleanSchema = jsonSchemaBaseSchema.extend({ + type: z.literal("boolean"), +}); + +export type JsonSchemaBoolean = z.infer; + +// -------------------- Null -------------------- + +export const jsonSchemaNullSchema = jsonSchemaBaseSchema.extend({ + type: z.literal("null"), +}); + +export type JsonSchemaNull = z.infer; + +// -------------------- Enum -------------------- + +export const jsonSchemaEnumSchema = jsonSchemaBaseSchema.extend({ + enum: z.array(z.union([z.string(), z.number(), z.boolean(), z.null()])), +}); + +export type JsonSchemaEnum = z.infer; + +// -------------------- Const -------------------- + +export const jsonSchemaConstSchema = jsonSchemaBaseSchema.extend({ + const: z.union([z.string(), z.number(), z.boolean(), z.null()]), +}); + +export type JsonSchemaConst = z.infer; + +// -------------------- Object (without recursive types) -------------------- + +export const baseJsonSchemaObjectSchema = jsonSchemaBaseSchema.extend({ + type: z.literal("object"), + required: z.array(z.string()).optional(), + unevaluatedProperties: z.boolean().optional(), + minProperties: z.number().optional(), + maxProperties: z.number().optional(), +}); + +export type BaseJsonSchemaObject = z.infer; + +// -------------------- Array (without recursive types) -------------------- + +export const baseJsonSchemaArraySchema = jsonSchemaBaseSchema.extend({ + type: z.literal("array"), + minContains: z.number().optional(), + maxContains: z.number().optional(), + minItems: z.number().optional(), + maxItems: z.number().optional(), + uniqueItems: z.boolean().optional(), +}); + +export type BaseJsonSchemaArray = z.infer; + +// -------------------- JSON -------------------- + +export const baseJsonSchemaSchema = z.union([ + jsonSchemaStringSchema, + jsonSchemaNumberSchema, + jsonSchemaBooleanSchema, + jsonSchemaNullSchema, + jsonSchemaEnumSchema, + jsonSchemaConstSchema, +]); + +// Manually define the recursive types +export type JsonSchema = + | z.infer + | (BaseJsonSchemaObject & { + properties?: Record; + patternProperties?: Record; + additionalProperties?: boolean | JsonSchema; + }) + | (BaseJsonSchemaArray & { + items?: JsonSchema | boolean; + prefixItems?: JsonSchema[]; + contains?: JsonSchema; + }); + +// Define the recursive types using zod's "lazy" function +export const jsonSchemaSchema: z.ZodType = z.union([ + baseJsonSchemaSchema, + baseJsonSchemaObjectSchema.extend({ + properties: z.record(z.lazy(() => jsonSchemaSchema)).optional(), + patternProperties: z.record(z.lazy(() => jsonSchemaSchema)).optional(), + additionalProperties: z + .union([z.boolean(), z.lazy(() => jsonSchemaSchema)]) + .optional(), + }), + baseJsonSchemaArraySchema.extend({ + items: z + .union([z.lazy(() => jsonSchemaSchema), z.boolean()]) + .optional(), + prefixItems: z.array(z.lazy(() => 
jsonSchemaSchema)).optional(),
+        contains: z.lazy(() => jsonSchemaSchema).optional(),
+    }),
+]);
+
+// -------------------- Object --------------------
+
+export const jsonSchemaObjectSchema = baseJsonSchemaObjectSchema.extend({
+    properties: z.record(jsonSchemaSchema).optional(),
+    patternProperties: z.record(jsonSchemaSchema).optional(),
+    additionalProperties: z.union([z.boolean(), jsonSchemaSchema]).optional(),
+});
+
+export type JsonSchemaObject = z.infer<typeof jsonSchemaObjectSchema>;
+
+// -------------------- Array --------------------
+
+export const jsonSchemaArraySchema = baseJsonSchemaArraySchema.extend({
+    items: z.union([jsonSchemaSchema, z.boolean()]).optional(),
+    prefixItems: z.array(jsonSchemaSchema).optional(),
+    contains: jsonSchemaSchema.optional(),
+});
+
+export type JsonSchemaArray = z.infer<typeof jsonSchemaArraySchema>;
diff --git a/packages/lib/src/schemas/message.schema.ts b/packages/lib/src/schemas/message.schema.ts
index afa1108..5899ecd 100644
--- a/packages/lib/src/schemas/message.schema.ts
+++ b/packages/lib/src/schemas/message.schema.ts
@@ -9,8 +9,16 @@ export const messageSchema = z.object({
         z.literal("user"),
         z.literal("assistant"),
         z.literal("system"),
+        z.literal("function"),
     ]),
-    content: z.string(),
+    content: z.string().nullable(),
+    name: z.string().min(1).optional(),
+    function_call: z
+        .object({
+            name: z.string().min(1),
+            arguments: z.record(z.any()),
+        })
+        .optional(),
     model: z.string().optional(),
     flags: z.array(z.string()).or(z.null()).optional(),
 });
diff --git a/packages/lib/src/utils/getPricing.ts b/packages/lib/src/utils/getPricing.ts
index 6754d12..4cd7597 100644
--- a/packages/lib/src/utils/getPricing.ts
+++ b/packages/lib/src/utils/getPricing.ts
@@ -7,7 +7,8 @@ import { PRICING_TABLE } from "../config/constants.js";
  */
 export default (model: string) => {
     if (!model) return PRICING_TABLE["unknown"];
-    if (model.startsWith("gpt-3.5")) return PRICING_TABLE["3.5"];
+    if (model.startsWith("gpt-3.5-turbo-16k")) return PRICING_TABLE["3.5-16k"];
+    if (model.startsWith("gpt-3.5-turbo")) return PRICING_TABLE["3.5"];
     if (model.startsWith("gpt-4-32k")) return PRICING_TABLE["4-32k"];
     if (model.startsWith("gpt-4")) return PRICING_TABLE["4"];
     return PRICING_TABLE["unknown"];
diff --git a/packages/lib/src/utils/types.ts b/packages/lib/src/utils/types.ts
index b472ff0..ae5e825 100644
--- a/packages/lib/src/utils/types.ts
+++ b/packages/lib/src/utils/types.ts
@@ -2,6 +2,7 @@
 import { Conversation } from "../classes/Conversation.js";
 import { ConversationConfig } from "../classes/ConversationConfig.js";
 import { Message } from "../classes/Message.js";
+import { JsonSchemaObject } from "../schemas/jsonSchema.schema.js";
 
 /**
  * Supported values for the `role` property of a message.
@@ -9,7 +10,8 @@ import { Message } from "../classes/Message.js";
 export type ChatCompletionRequestMessageRoleEnum =
     | "user"
     | "system"
-    | "assistant";
+    | "assistant"
+    | "function";
 
 /**
  * Overridable {@link CreateChatCompletionRequest} properties of a {@link Conversation}'s config for a single prompt.
@@ -120,12 +122,75 @@ export interface CreateChatCompletionMessage {
      *
      * @see {@link ChatCompletionRequestMessageRoleEnum}
      */
-    role: ChatCompletionRequestMessageRoleEnum;
+    role: Exclude<ChatCompletionRequestMessageRoleEnum, "function">;
 
     /**
      * The message content.
      */
     content: string;
+
+    function_call?: undefined;
+    name?: undefined;
+}
+
+export type CompletionMessage = Message & {
+    role: Exclude<ChatCompletionRequestMessageRoleEnum, "function">;
+    content: string;
+    functionCall: undefined;
+    name: undefined;
+};
+
+/**
+ * A function_call-related message in OpenAI's chat format. 
+ *
+ * @see {@link https://platform.openai.com/docs/api-reference/chat/create#chat/create-messages Create Chat Completion Request Body - messages}
+ */
+export interface CreateChatCompletionFunctionCallMessage {
+    role: Extract<ChatCompletionRequestMessageRoleEnum, "assistant">;
+    content: null;
+    function_call: {
+        name: string;
+        arguments: string;
+    };
+
+    name?: undefined;
+}
+
+export type FunctionCallMessage = Message & {
+    role: "assistant";
+    content: null;
+    functionCall: {
+        name: string;
+        arguments: Record<string, any>;
+    };
+};
+
+/**
+ * A function-related message in OpenAI's chat format.
+ *
+ * @see {@link https://platform.openai.com/docs/api-reference/chat/create#chat/create-messages Create Chat Completion Request Body - messages}
+ */
+export interface CreateChatCompletionFunctionMessage {
+    role: Extract<ChatCompletionRequestMessageRoleEnum, "function">;
+    content: string;
+    name: string;
+
+    function_call?: undefined;
+}
+
+export type FunctionMessage = Message & {
+    role: "function";
+    name: string;
+    content: string;
+};
+
+/**
+ * A function that can be called by the model.
+ */
+export interface CreateChatCompletionFunction {
+    name: string;
+    description?: string;
+    parameters?: JsonSchemaObject;
+}
 
 /**
@@ -157,7 +222,25 @@ export interface CreateChatCompletionRequest {
      *
      * @see {@link https://platform.openai.com/docs/api-reference/chat/create#chat/create-messages Create Chat Completion Request Body - messages}
      */
-    messages: CreateChatCompletionMessage[];
+    messages: (
+        | CreateChatCompletionMessage
+        | CreateChatCompletionFunctionCallMessage
+        | CreateChatCompletionFunctionMessage
+    )[];
+
+    /**
+     * A list of functions the model may generate JSON inputs for.
+     *
+     * @see {@link https://platform.openai.com/docs/api-reference/chat/create#chat/create-functions Create Chat Completion Request Body - functions}
+     */
+    functions?: CreateChatCompletionFunction[];
+
+    /**
+     * Controls how the model responds to function calls. "none" means the model does not call a function, and responds to the end-user. "auto" means the model can pick between responding to the end-user or calling a function. Specifying a particular function via `{"name": "my_function"}` forces the model to call that function. "none" is the default when no functions are present. "auto" is the default if functions are present.
+     *
+     * @see {@link https://platform.openai.com/docs/api-reference/chat/create#chat/create-function_call Create Chat Completion Request Body - function_call}
+     */
+    function_call?: "none" | "auto" | { name: string };
 
     /**
      * What sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic.
@@ -300,7 +383,10 @@ export interface CreateChatCompletionResponse {
         /**
          * The assistant's response to the prompt.
          */
-        message: CreateChatCompletionMessage;
+        message:
+            | CreateChatCompletionMessage
+            | CreateChatCompletionFunctionCallMessage
+            | CreateChatCompletionFunctionMessage;
 
         /**
          * The reason the chat completion ended. Always "stop" for non-streamed completions.
@@ -347,7 +433,9 @@ export interface CreateChatCompletionStreamResponse {
         /**
          * A token representing part of the assistant's response to the prompt.
         */
-        delta: CreateChatCompletionMessage;
+        delta:
+            | CreateChatCompletionMessage
+            | CreateChatCompletionFunctionCallMessage;
 
         /**
          * The reason the chat completion ended. `null` when streaming, otherwise "stop".
@@ -517,7 +605,7 @@ export type MessageUpdateListener = (
     /**
      * The new content of the message. 
*/ - content: string, + content: string | null, /** * The {@link Message message} instance that was updated. diff --git a/packages/nest/package.json b/packages/nest/package.json index 36874eb..2a68468 100644 --- a/packages/nest/package.json +++ b/packages/nest/package.json @@ -9,7 +9,8 @@ "lint": "eslint --ext .ts src", "lint:strict": "npm run lint -- --max-warnings 0", "lint:fix": "npm run lint -- --fix", - "build": "nest build", + "tscheck": "tsc --noEmit", + "build": "npm run lint:strict && npm run tscheck && nest build", "start": "nest start", "dev": "nest start --watch", "debug": "nest start --debug --watch", diff --git a/packages/web/package.json b/packages/web/package.json index fbfa002..1985f5c 100644 --- a/packages/web/package.json +++ b/packages/web/package.json @@ -10,7 +10,7 @@ "lint:strict": "npm run lint -- --max-warnings 0", "lint:fix": "npm run lint -- --fix", "tscheck": "tsc --noEmit", - "build": "npm run lint:strict && rimraf dist && vite build", + "build": "npm run lint:strict && npm run tscheck && rimraf dist && vite build", "dev": "vite", "preview": "npm run build && vite preview" }, diff --git a/packages/web/src/components/Message.tsx b/packages/web/src/components/Message.tsx index 78bea91..10a103e 100644 --- a/packages/web/src/components/Message.tsx +++ b/packages/web/src/components/Message.tsx @@ -20,6 +20,7 @@ import { BiUser, BiX, } from "react-icons/bi"; +import { AiOutlineFunction } from "react-icons/ai"; import React from "react"; import TippedActionIcon from "./TippedActionIcon"; import { useForm } from "@mantine/form"; @@ -56,7 +57,7 @@ export default ({ message }: MessageProps) => { const [isEditing, setIsEditing] = React.useState(false); const form = useForm({ initialValues: { - content: message.content, + content: message.content ?? "", }, }); const editFormRef = React.useRef(null); @@ -95,6 +96,8 @@ export default ({ message }: MessageProps) => { return SiOpenai; case "user": return BiUser; + case "function": + return AiOutlineFunction; case "system": default: return BiCog; @@ -107,14 +110,33 @@ export default ({ message }: MessageProps) => { return "teal"; case "user": return "blue"; + case "function": + return "cyan"; case "system": default: return "gray"; } })(); + const messageContent = (() => { + if (message.isCompletion()) { + return message.content; + } + if (message.isFunction()) { + return `${message.name}() => ${message.content}`; + } + if (message.isFunctionCall()) { + const { name, arguments: args } = message.functionCall; + const parameters = Object.entries(args) + .map(([param, value]) => `${param}=${value}`) + .join(", "); + return `${name}(${parameters})`; + } + return "[Unknown message type]"; + })(); + const MessageContent = React.useMemo(() => { - const lines = message.content.split("\n"); + const lines = messageContent.split("\n"); const output: JSX.Element[] = []; let isCode = false; @@ -168,12 +190,12 @@ export default ({ message }: MessageProps) => { } return output; - }, [message.content, message.isFlagged]); + }, [message, messageContent]); const Actions = React.useMemo(() => { if (message.role === "system") return null; - if (message.role === "assistant") { + if (message.role === "assistant" || message.role === "function") { return ( reprompt()}> @@ -253,11 +275,13 @@ export default ({ message }: MessageProps) => { title="Save Prompt" centered > - + {message.content && ( + + )} );
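The CLI and web `Message` components above derive their display string with the same three-way branch. A shared helper is one possible follow-up; this sketch is hypothetical (the `formatMessageContent` name and its location are not part of this patch) and simply mirrors the logic duplicated in `packages/cli/src/hooks/usePagedMessages.ts` and `packages/web/src/components/Message.tsx`.

```ts
import { Message } from "gpt-turbo";

// Hypothetical shared helper mirroring the branch duplicated in the
// CLI and web packages in this PR.
export const formatMessageContent = (message: Message): string => {
    if (message.isCompletion()) {
        // Plain completion: content is the display string itself.
        return message.content;
    }
    if (message.isFunction()) {
        // Function result message, e.g. `locateJedi() => {"name":"Obi-Wan Kenobi",...}`
        return `${message.name}() => ${message.content}`;
    }
    if (message.isFunctionCall()) {
        // Assistant function call, rendered as `name(param=value, ...)`.
        const { name, arguments: args } = message.functionCall;
        const parameters = Object.entries(args)
            .map(([param, value]) => `${param}=${value}`)
            .join(", ");
        return `${name}(${parameters})`;
    }
    return "[Unknown message type]";
};
```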