diff --git a/package-lock.json b/package-lock.json index 7f81ff2..443c7ac 100644 --- a/package-lock.json +++ b/package-lock.json @@ -20394,7 +20394,8 @@ "license": "MIT", "dependencies": { "gpt-token-utils": "^1.2.0", - "uuid": "^9.0.0" + "uuid": "^9.0.0", + "zod": "^3.21.4" }, "devDependencies": { "@types/uuid": "^9.0.1", diff --git a/packages/cli/src/contexts/providers/PersistenceProvider.tsx b/packages/cli/src/contexts/providers/PersistenceProvider.tsx index cafbdbe..277860e 100644 --- a/packages/cli/src/contexts/providers/PersistenceProvider.tsx +++ b/packages/cli/src/contexts/providers/PersistenceProvider.tsx @@ -1,7 +1,5 @@ import React from "react"; import fs from "fs"; -import { Persistence, persistenceSchema } from "../../entities/persistence.js"; -import { PersistenceMessage } from "../../entities/persistenceMessage.js"; import useConversationManager from "../../hooks/useConversationManager.js"; import { PersistenceContext, @@ -20,19 +18,7 @@ export default ({ children }: PersistenceProviderProps) => { const save = React.useCallback(() => { if (!conversation || !saveFile) return; - const persistedConversations: Persistence = { - conversation: { - ...conversation.getConfig(), - messages: conversation.getMessages().map( - (message): PersistenceMessage => ({ - content: message.content, - role: message.role, - }) - ), - }, - }; - const parsed = persistenceSchema.parse(persistedConversations); - const json = JSON.stringify(parsed); + const json = JSON.stringify(conversation.toJSON()); fs.writeFileSync(saveFile, json); }, [conversation, saveFile]); @@ -58,45 +44,7 @@ export default ({ children }: PersistenceProviderProps) => { try { const data = fs.readFileSync(loadFile, "utf8"); const json = JSON.parse(data); - const { conversation: persistedConversation } = - persistenceSchema.parse(json); - const { messages, disableModeration, ...config } = - persistedConversation; - const newConversation = new Conversation({ - ...config, - disableModeration: true, - }); - - for (const message of messages) { - try { - switch (message.role) { - case "user": - await newConversation.addUserMessage( - message.content - ); - break; - case "assistant": - await newConversation.addAssistantMessage( - message.content - ); - break; - case "system": - newConversation.setContext(message.content); - } - } catch (e) { - console.error( - "Error while loading message", - (e as Error).message - ); - } - } - - newConversation.setConfig( - { - disableModeration, - }, - true - ); + const newConversation = await Conversation.fromJSON(json); setConversation(newConversation); } catch (e) { console.error(e); @@ -117,10 +65,7 @@ export default ({ children }: PersistenceProviderProps) => { messageOffs.push( message.onMessageStreamingStop(() => { save(); - }) - ); - - messageOffs.push( + }), message.onMessageUpdate((_, m) => { if (m.isStreaming) return; save(); diff --git a/packages/cli/src/entities/persistence.ts b/packages/cli/src/entities/persistence.ts deleted file mode 100644 index a66d933..0000000 --- a/packages/cli/src/entities/persistence.ts +++ /dev/null @@ -1,8 +0,0 @@ -import { z } from "zod"; -import { persistenceConversationSchema } from "./persistenceConversation.js"; - -export const persistenceSchema = z.object({ - conversation: persistenceConversationSchema, -}); - -export type Persistence = z.infer; diff --git a/packages/cli/src/entities/persistenceConversation.ts b/packages/cli/src/entities/persistenceConversation.ts deleted file mode 100644 index 664c2b2..0000000 --- a/packages/cli/src/entities/persistenceConversation.ts +++ /dev/null @@ -1,26 +0,0 @@ -import { z } from "zod"; -import { persistenceMessageSchema } from "./persistenceMessage.js"; - -export const persistenceConversationSchema = z.object({ - messages: z.array(persistenceMessageSchema), - - context: z.string(), - dry: z.boolean(), - disableModeration: z.boolean().or(z.literal("soft")), - - apiKey: z.string(), - model: z.string(), - stream: z.boolean(), - frequency_penalty: z.number().optional(), - presence_penalty: z.number().optional(), - max_tokens: z.number().optional(), - logit_bias: z.record(z.number(), z.number()).optional(), - stop: z.string().or(z.array(z.string())).or(z.null()).optional(), - temperature: z.number().optional(), - top_p: z.number().optional(), - user: z.string().optional(), -}); - -export type PersistenceConversation = z.infer< - typeof persistenceConversationSchema ->; diff --git a/packages/cli/src/entities/persistenceMessage.ts b/packages/cli/src/entities/persistenceMessage.ts deleted file mode 100644 index 56839fa..0000000 --- a/packages/cli/src/entities/persistenceMessage.ts +++ /dev/null @@ -1,8 +0,0 @@ -import { z } from "zod"; - -export const persistenceMessageSchema = z.object({ - content: z.string(), - role: z.enum(["user", "system", "assistant"]), -}); - -export type PersistenceMessage = z.infer; diff --git a/packages/cli/src/index.tsx b/packages/cli/src/index.tsx index 57a2398..f2ade23 100644 --- a/packages/cli/src/index.tsx +++ b/packages/cli/src/index.tsx @@ -132,7 +132,9 @@ const saveFile = (() => { const savePath = path.normalize( typeof save === "string" ? save - : `gptturbo-conversation-${new Date().toISOString()}` + : `gptturbo-conversation-${new Date() + .toISOString() + .replace(/:/g, "-")}` ); return savePath.endsWith(".json") ? savePath : `${savePath}.json`; })(); diff --git a/packages/discord/src/managers/ConversationManager.ts b/packages/discord/src/managers/ConversationManager.ts index 1031693..f296990 100644 --- a/packages/discord/src/managers/ConversationManager.ts +++ b/packages/discord/src/managers/ConversationManager.ts @@ -1,4 +1,4 @@ -import { ChatCompletionRequestMessageRoleEnum, Conversation } from "gpt-turbo"; +import { Conversation, MessageModel } from "gpt-turbo"; import getConversationConfig from "../utils/getConversationConfig.js"; import BotException from "../exceptions/BotException.js"; import QuotaManager from "./QuotaManager.js"; @@ -8,10 +8,7 @@ export default class ConversationManager { constructor(private readonly quotaManager: QuotaManager) {} public async getChatCompletion( - messages: ( - | { content: string; role: ChatCompletionRequestMessageRoleEnum } - | string - )[], + messages: (MessageModel | string)[], userId: string ) { const user: ConversationUser = `discord-${userId}`; @@ -52,15 +49,9 @@ export default class ConversationManager { } private getAlternatedMessages( - messages: ( - | { content: string; role: ChatCompletionRequestMessageRoleEnum } - | string - )[] - ): { content: string; role: ChatCompletionRequestMessageRoleEnum }[] { - const alternatedMessages: { - content: string; - role: ChatCompletionRequestMessageRoleEnum; - }[] = []; + messages: (MessageModel | string)[] + ): MessageModel[] { + const alternatedMessages: MessageModel[] = []; messages.forEach((message, index) => { alternatedMessages.push({ diff --git a/packages/discord/src/message-handlers/DMReplyHandler.ts b/packages/discord/src/message-handlers/DMReplyHandler.ts index 3bad21b..16d0467 100644 --- a/packages/discord/src/message-handlers/DMReplyHandler.ts +++ b/packages/discord/src/message-handlers/DMReplyHandler.ts @@ -3,7 +3,7 @@ import MessageHandler from "./MessageHandler.js"; import getPromptAndReplyMessages from "../utils/getPromptAndReplyMessages.js"; import BotException from "../exceptions/BotException.js"; import getCleanContent from "../utils/getCleanContent.js"; -import { ChatCompletionRequestMessageRoleEnum } from "gpt-turbo"; +import { MessageModel } from "gpt-turbo"; export default class DMReplyHandler extends MessageHandler { public get name(): string { @@ -41,12 +41,7 @@ export default class DMReplyHandler extends MessageHandler { const [conversationMessages, prompt] = await Promise.all([ Promise.all( messages.map( - async ( - m - ): Promise<{ - content: string; - role: ChatCompletionRequestMessageRoleEnum; - }> => ({ + async (m): Promise => ({ content: (await getCleanContent(m)) || "Hello", role: m.author.id === message.author.id diff --git a/packages/lib/package.json b/packages/lib/package.json index d66d568..9ea2585 100644 --- a/packages/lib/package.json +++ b/packages/lib/package.json @@ -71,7 +71,8 @@ }, "dependencies": { "gpt-token-utils": "^1.2.0", - "uuid": "^9.0.0" + "uuid": "^9.0.0", + "zod": "^3.21.4" }, "gitHead": "328d04dc73d86236dba158a181581a53caf37720" } diff --git a/packages/lib/src/classes/Conversation.ts b/packages/lib/src/classes/Conversation.ts index 28a0f46..5c9183a 100644 --- a/packages/lib/src/classes/Conversation.ts +++ b/packages/lib/src/classes/Conversation.ts @@ -10,7 +10,6 @@ import { MessageRoleException } from "../exceptions/index.js"; import { v4 as uuid } from "uuid"; import { AddMessageListener, - ChatCompletionRequestMessageRoleEnum, ConversationConfigParameters, HandleChatCompletionOptions, PromptOptions, @@ -18,6 +17,11 @@ import { RequestOptions, } from "../utils/types.js"; import { DEFAULT_DISABLEMODERATION } from "../index.js"; +import { + ConversationModel, + conversationSchema, +} from "../schemas/conversation.schema.js"; +import { MessageModel } from "../schemas/message.schema.js"; /** * A Conversation manages the messages sent to and from the OpenAI API and handles the logic for providing the message history to the API for each prompt. @@ -40,94 +44,30 @@ export class Conversation { * Creates a new Conversation instance. * * @param config The configuration for this conversation. See {@link ConversationConfigParameters} - * @param options The HTTP request options. See {@link RequestOptions} + * @param requestOptions The HTTP request options. See {@link RequestOptions} */ constructor( config: ConversationConfigParameters = {}, - options: RequestOptions = {} + requestOptions: RequestOptions = {} ) { this.config = new ConversationConfig(config); - this.requestOptions = options; + this.requestOptions = requestOptions; this.clearMessages(); } - private notifyMessageAdded(message: Message) { - this.addMessageListeners.forEach((listener) => listener(message)); - } - - private notifyMessageRemoved(message: Message) { - this.removeMessageListeners.forEach((listener) => listener(message)); - } - - private async addMessage(message: Message) { - message.content = message.content.trim(); - - if (!message.content && message.role === "user") { - throw new Error("User message content cannot be empty."); - } - - if (this.config.isModerationEnabled) { - const flags = await message.moderate( - this.config.apiKey, - this.requestOptions - ); - if (this.config.isModerationStrict && flags.length > 0) { - throw new ModerationException(flags); - } - } - - if (message.role === "system") { - this.config.context = message.content; - if (message.content) { - // Update the system message or add it if it doesn't exist. - if (this.messages[0]?.role === "system") { - this.messages[0] = message; - } else { - this.messages.unshift(message); - } - } else if (this.messages[0]?.role === "system") { - // Remove the system message if it exists and the new content is empty. - this.messages.shift(); - } - } else { - if ( - this.messages[this.messages.length - 1]?.role === message.role - ) { - throw new MessageRoleException(); - } - this.messages.push(message); - } - - this.notifyMessageAdded(message); - return message; - } - - private addSystemMessage(message: string) { - const systemMessage = new Message("system", message, this.config.model); - return this.addMessage(systemMessage); - } - /** - * Creates a new Conversation instance from a list of messages. Useful for loading a conversation from a database. + * Creates a new Conversation instance from a list of messages. * - * @param messages The messages to add to the conversation. - * Can be an array of strings (starting from the user prompt and alternating between user/assistant) or an array of objects with a `role` and `content` property. - * Regardless of the format, the messages should alternate between user and assistant messages. (or an error will be thrown) + * @param messages The messages (as JSON) to add to the conversation. See {@link MessageModel} * @param config The configuration for this conversation. See {@link ConversationConfigParameters} - * @param options The HTTP request options. See {@link RequestOptions} + * @param requestOptions The HTTP request options. See {@link RequestOptions} * @param disableInitialModeration Whether to disable moderation for the initial messages. Defaults to `true` to prevent multiple API calls in a short period of time. * @returns The new Conversation instance. */ public static async fromMessages( - messages: ( - | { - role: ChatCompletionRequestMessageRoleEnum; - content: string; - } - | string - )[], + messages: MessageModel[], config: ConversationConfigParameters = {}, - options: RequestOptions = {}, + requestOptions: RequestOptions = {}, disableInitialModeration: ConversationConfigParameters["disableModeration"] = true ) { const conversation = new Conversation( @@ -135,20 +75,10 @@ export class Conversation { ...config, disableModeration: disableInitialModeration, }, - options + requestOptions ); - let isUserMessage = true; - const initialMessages = messages.map((message) => { - if (typeof message === "string") { - const role = isUserMessage ? "user" : "assistant"; - isUserMessage = !isUserMessage; - return { role, content: message }; - } - return message; - }); - - for (const message of initialMessages) { + for (const message of messages) { switch (message.role) { case "user": await conversation.addUserMessage(message.content); @@ -158,6 +88,7 @@ export class Conversation { break; case "system": conversation.setContext(message.content); + break; } } @@ -172,6 +103,55 @@ export class Conversation { return conversation; } + /** + * Creates a new Conversation instance from a JSON object. + * + * @param json The JSON object of the Conversation instance. + * @param config Overrides for the loaded configuration of this conversation. See {@link ConversationConfigParameters} + * @param requestOptions Overrides for the loaded HTTP request options. See {@link RequestOptions} + * @param disableInitialModeration Whether to disable moderation for the initial messages. Defaults to `true` to prevent multiple API calls in a short period of time. + * @returns The new Conversation instance. + */ + public static async fromJSON( + json: ConversationModel, + config: ConversationConfigParameters = {}, + requestOptions: RequestOptions = {}, + disableInitialModeration: ConversationConfigParameters["disableModeration"] = true + ) { + const conversationJson = conversationSchema.parse(json); + const conversation = await Conversation.fromMessages( + conversationJson.messages, + { + ...(conversationJson.config ?? {}), + ...config, + }, + { + ...(conversationJson.requestOptions ?? {}), + ...requestOptions, + }, + disableInitialModeration + ); + + if (conversationJson.id) conversation.id = conversationJson.id; + + return conversation; + } + + /** + * Serializes the `Conversation` to JSON. + * + * @returns A JSON representation of the `Conversation` instance. + */ + public toJSON(): ConversationModel { + const json: ConversationModel = { + id: this.id, + config: this.config.toJSON(), + messages: this.messages.map((message) => message.toJSON()), + requestOptions: this.requestOptions, + }; + return conversationSchema.parse(json); + } + /** * Adds a message with the role of `"assistant"` to the conversation. * @@ -274,87 +254,6 @@ export class Conversation { this.notifyMessageRemoved(removedMessage); } - private async handleStreamedResponse( - options: HandleChatCompletionOptions = {}, - requestOptions: RequestOptions = {} - ) { - const message = new Message("assistant", "", this.config.model); - const messages = this.messages.map(({ role, content }) => ({ - role, - content, - })); - - const unsubscribeStreaming = message.onMessageStreamingStop((m) => { - this.cumulativeSize += this.getSize() + getMessageSize(m.content); - this.cumulativeCost += this.getCost() + m.cost; - unsubscribeStreaming(); - }); - - if (this.config.dry) { - const response = createDryChatCompletion( - this.messages[messages.length - 1]?.content ?? "", - { - model: this.config.model, - } - ); - message.readContentFromStream(response); - } else { - createChatCompletion( - { - ...this.config.chatCompletionConfig, - ...options, - stream: true, - messages, - }, - { - ...this.requestOptions, - ...requestOptions, - } - ).then((response) => { - // Using .then() to get the message out as soon as possible, since the content is known to be empty at first. - // This gives time for client code to subscribe to the streaming events. - message.readContentFromStream(response); - }); - } - - return message; - } - - private async handleNonStreamedResponse( - options: HandleChatCompletionOptions = {}, - requestOptions: RequestOptions = {} - ) { - const message = new Message("assistant", "", this.config.model); - const messages = this.messages.map(({ role, content }) => ({ - role, - content, - })); - - if (this.config.dry) { - await new Promise((resolve) => setTimeout(resolve, 1000)); - message.content = messages[messages.length - 1]?.content ?? null; - } else { - const response = await createChatCompletion( - { - ...this.config.chatCompletionConfig, - ...options, - stream: false, - messages, - }, - { - ...this.requestOptions, - ...requestOptions, - } - ); - message.content = response.choices[0].message?.content ?? ""; - } - - this.cumulativeSize += this.getSize() + getMessageSize(message.content); - this.cumulativeCost += this.getCost() + message.cost; - - return message; - } - /** * Sends a Create Chat Completion request to the OpenAI API using the current messages stored in the conversation's history. * @@ -372,18 +271,6 @@ export class Conversation { : this.handleNonStreamedResponse(options, requestOptions); } - private async getAssistantResponse( - options?: PromptOptions, - requestOptions?: RequestOptions - ) { - const completion = await this.getChatCompletionResponse( - options, - requestOptions - ); - const assistantMessage = await this.addMessage(completion); - return assistantMessage; - } - /** * This is the **recommended** way to interact with the GPT model. It's a wrapper method around other public methods that handles the logic of adding a user message, sending a request to the OpenAI API, and adding the assistant's response. * @@ -543,4 +430,153 @@ export class Conversation { this.setContext(config.context); } } + + private notifyMessageAdded(message: Message) { + this.addMessageListeners.forEach((listener) => listener(message)); + } + + private notifyMessageRemoved(message: Message) { + this.removeMessageListeners.forEach((listener) => listener(message)); + } + + private async addMessage(message: Message) { + message.content = message.content.trim(); + + if (!message.content && message.role === "user") { + throw new Error("User message content cannot be empty."); + } + + if (this.config.isModerationEnabled) { + const flags = await message.moderate( + this.config.apiKey, + this.requestOptions + ); + if (this.config.isModerationStrict && flags.length > 0) { + throw new ModerationException(flags); + } + } + + if (message.role === "system") { + this.config.context = message.content; + if (message.content) { + // Update the system message or add it if it doesn't exist. + if (this.messages[0]?.role === "system") { + this.messages[0] = message; + } else { + this.messages.unshift(message); + } + } else if (this.messages[0]?.role === "system") { + // Remove the system message if it exists and the new content is empty. + this.messages.shift(); + } + } else { + if ( + this.messages[this.messages.length - 1]?.role === message.role + ) { + throw new MessageRoleException(); + } + this.messages.push(message); + } + + this.notifyMessageAdded(message); + return message; + } + + private addSystemMessage(message: string) { + const systemMessage = new Message("system", message, this.config.model); + return this.addMessage(systemMessage); + } + + private async handleStreamedResponse( + options: HandleChatCompletionOptions = {}, + requestOptions: RequestOptions = {} + ) { + const message = new Message("assistant", "", this.config.model); + const messages = this.messages.map(({ role, content }) => ({ + role, + content, + })); + + const unsubscribeStreaming = message.onMessageStreamingStop((m) => { + this.cumulativeSize += this.getSize() + getMessageSize(m.content); + this.cumulativeCost += this.getCost() + m.cost; + unsubscribeStreaming(); + }); + + if (this.config.dry) { + const response = createDryChatCompletion( + this.messages[messages.length - 1]?.content ?? "", + { + model: this.config.model, + } + ); + message.readContentFromStream(response); + } else { + createChatCompletion( + { + ...this.config.chatCompletionConfig, + ...options, + stream: true, + messages, + }, + { + ...this.requestOptions, + ...requestOptions, + } + ).then((response) => { + // Using .then() to get the message out as soon as possible, since the content is known to be empty at first. + // This gives time for client code to subscribe to the streaming events. + message.readContentFromStream(response); + }); + } + + return message; + } + + private async handleNonStreamedResponse( + options: HandleChatCompletionOptions = {}, + requestOptions: RequestOptions = {} + ) { + const message = new Message("assistant", "", this.config.model); + const messages = this.messages.map(({ role, content }) => ({ + role, + content, + })); + + if (this.config.dry) { + await new Promise((resolve) => setTimeout(resolve, 1000)); + message.content = messages[messages.length - 1]?.content ?? null; + } else { + const response = await createChatCompletion( + { + ...this.config.chatCompletionConfig, + ...options, + stream: false, + messages, + }, + { + ...this.requestOptions, + ...requestOptions, + } + ); + message.content = response.choices[0].message?.content ?? ""; + } + + this.cumulativeSize += this.getSize() + getMessageSize(message.content); + this.cumulativeCost += this.getCost() + message.cost; + + return message; + } + + private async getAssistantResponse( + options?: PromptOptions, + requestOptions?: RequestOptions + ) { + const completion = await this.getChatCompletionResponse( + options, + requestOptions + ); + const assistantMessage = await this.addMessage(completion); + return assistantMessage; + } } diff --git a/packages/lib/src/classes/ConversationConfig.ts b/packages/lib/src/classes/ConversationConfig.ts index 1ba141d..06c02f2 100644 --- a/packages/lib/src/classes/ConversationConfig.ts +++ b/packages/lib/src/classes/ConversationConfig.ts @@ -5,6 +5,10 @@ import { DEFAULT_MODEL, DEFAULT_STREAM, } from "../config/constants.js"; +import { + ConversationConfigModel, + conversationConfigSchema, +} from "../schemas/conversationConfig.schema.js"; import { ConversationConfigChatCompletionOptions, ConversationConfigOptions, @@ -81,6 +85,41 @@ export class ConversationConfig { this.user = user; } + /** + * Creates a new `ConversationConfig` instance from a serialized config. + * + * @param json The JSON object of a ConversationConfig instance. + * @returns A new `ConversationConfig` instance + */ + public static fromJSON(json: ConversationConfigModel) { + return new ConversationConfig(conversationConfigSchema.parse(json)); + } + + /** + * Serializes the ConversationConfig to JSON. + * + * @returns The `ConversationConfig` as a JSON object. + */ + public toJSON(): ConversationConfigModel { + const json: ConversationConfigModel = { + dry: this.dry, + apiKey: this.apiKey, + model: this.model, + stream: this.stream, + disableModeration: this.disableModeration, + context: this.context, + frequency_penalty: this.frequencyPenalty, + presence_penalty: this.presencePenalty, + max_tokens: this.maxTokens, + logit_bias: this.logitBias, + temperature: this.temperature, + top_p: this.topP, + user: this.user, + stop: this.stop, + }; + return conversationConfigSchema.parse(json); + } + public get config(): Required { return { context: this.context, diff --git a/packages/lib/src/classes/Message.ts b/packages/lib/src/classes/Message.ts index d595c31..e5ca875 100644 --- a/packages/lib/src/classes/Message.ts +++ b/packages/lib/src/classes/Message.ts @@ -10,6 +10,7 @@ import { MessageUpdateListener, } from "../utils/types.js"; import createModeration from "../utils/createModeration.js"; +import { MessageModel, messageSchema } from "../schemas/message.schema.js"; /** * A message in a Conversation. @@ -48,17 +49,38 @@ export class Message { this.content = content; } - private notifyMessageUpdate() { - const content = this.content; - this.messageUpdateListeners.forEach((listener) => - listener(content, this) + /** + * Creates a new `Message` instance from a serialized message. + * + * @param json The JSON object of the Message instance. + * @returns A new `Message` instance + */ + public static fromJSON(json: MessageModel): Message { + const messageJson = messageSchema.parse(json); + const message = new Message( + messageJson.role, + messageJson.content, + messageJson.model ); + if (messageJson.id) message.id = messageJson.id; + if (messageJson.flags) message.flags = messageJson.flags; + return message; } - private notifyMessageStreaming() { - this.messageStreamingListeners.forEach((listener) => { - listener(this.isStreaming, this); - }); + /** + * Serializes the message to JSON. + * + * @returns The `Message` as a JSON object. + */ + public toJSON(): MessageModel { + const json: MessageModel = { + id: this.id, + role: this.role, + content: this.content, + model: this.model, + flags: this.flags, + }; + return messageSchema.parse(json); } /** @@ -212,6 +234,19 @@ export class Message { } } + private notifyMessageUpdate() { + const content = this.content; + this.messageUpdateListeners.forEach((listener) => + listener(content, this) + ); + } + + private notifyMessageStreaming() { + this.messageStreamingListeners.forEach((listener) => { + listener(this.isStreaming, this); + }); + } + /** The role of who this message is from. */ get role() { return this._role; diff --git a/packages/lib/src/index.ts b/packages/lib/src/index.ts index effd6d0..50e6c94 100644 --- a/packages/lib/src/index.ts +++ b/packages/lib/src/index.ts @@ -2,3 +2,4 @@ export * from "./classes/index.js"; export * from "./config/constants.js"; export * from "./utils/index.js"; export * from "./exceptions/index.js"; +export * from "./schemas/index.js"; diff --git a/packages/lib/src/schemas/conversation.schema.ts b/packages/lib/src/schemas/conversation.schema.ts new file mode 100644 index 0000000..04d72c8 --- /dev/null +++ b/packages/lib/src/schemas/conversation.schema.ts @@ -0,0 +1,35 @@ +import { z } from "zod"; +import { messageSchema } from "./message.schema.js"; +import { conversationConfigSchema } from "./conversationConfig.schema.js"; + +/** + * A JSON representation of a Conversation instance. + */ +export const conversationSchema = z.object({ + id: z.string().uuid().optional(), + messages: z.array(messageSchema), + config: conversationConfigSchema.optional(), + requestOptions: z + .object({ + headers: z.record(z.string(), z.string()).optional(), + proxy: z + .object({ + host: z.string(), + port: z.number().optional(), + protocol: z.enum(["http", "https"]).optional(), + auth: z + .object({ + username: z.string(), + password: z.string(), + }) + .optional(), + }) + .optional(), + }) + .optional(), +}); + +/** + * A JSON representation of a Conversation instance. + */ +export type ConversationModel = z.infer; diff --git a/packages/nest/src/db/schemas/persistenceConversation.ts b/packages/lib/src/schemas/conversationConfig.schema.ts similarity index 62% rename from packages/nest/src/db/schemas/persistenceConversation.ts rename to packages/lib/src/schemas/conversationConfig.schema.ts index 538cc8f..073b50a 100644 --- a/packages/nest/src/db/schemas/persistenceConversation.ts +++ b/packages/lib/src/schemas/conversationConfig.schema.ts @@ -1,26 +1,26 @@ -import { z } from "nestjs-zod/z"; -import { persistenceMessageSchema } from "./persistenceMessage.js"; - -export const persistenceConversationSchema = z.object({ - messages: z.array(persistenceMessageSchema).default([]), +import { z } from "zod"; +/** + * A JSON representation of a ConversationConfig instance. + */ +export const conversationConfigSchema = z.object({ context: z.string().optional(), dry: z.boolean().optional(), disableModeration: z.boolean().or(z.literal("soft")).optional(), - - apiKey: z.string().optional(), + apiKey: z.string().optional().optional(), model: z.string().optional(), + temperature: z.number().optional(), + top_p: z.number().optional(), stream: z.boolean().optional(), - frequency_penalty: z.number().optional(), - presence_penalty: z.number().optional(), + stop: z.string().or(z.array(z.string())).or(z.null()).optional(), max_tokens: z.number().optional(), + presence_penalty: z.number().optional(), + frequency_penalty: z.number().optional(), logit_bias: z.record(z.number(), z.number()).optional(), - stop: z.string().or(z.array(z.string())).or(z.null()).optional(), - temperature: z.number().optional(), - top_p: z.number().optional(), user: z.string().optional(), }); -export type PersistenceConversation = z.infer< - typeof persistenceConversationSchema ->; +/** + * A JSON representation of a ConversationConfig instance. + */ +export type ConversationConfigModel = z.infer; diff --git a/packages/lib/src/schemas/index.ts b/packages/lib/src/schemas/index.ts new file mode 100644 index 0000000..cbd443c --- /dev/null +++ b/packages/lib/src/schemas/index.ts @@ -0,0 +1,3 @@ +export * from "./conversation.schema.js"; +export * from "./conversationConfig.schema.js"; +export * from "./message.schema.js"; diff --git a/packages/lib/src/schemas/message.schema.ts b/packages/lib/src/schemas/message.schema.ts new file mode 100644 index 0000000..afa1108 --- /dev/null +++ b/packages/lib/src/schemas/message.schema.ts @@ -0,0 +1,21 @@ +import { z } from "zod"; + +/** + * A JSON schema for a Message instance. + */ +export const messageSchema = z.object({ + id: z.string().uuid().optional(), + role: z.union([ + z.literal("user"), + z.literal("assistant"), + z.literal("system"), + ]), + content: z.string(), + model: z.string().optional(), + flags: z.array(z.string()).or(z.null()).optional(), +}); + +/** + * A JSON representation of a Message instance. + */ +export type MessageModel = z.infer; diff --git a/packages/nest/src/db/db.service.ts b/packages/nest/src/db/db.service.ts index 2065430..c9fc21f 100644 --- a/packages/nest/src/db/db.service.ts +++ b/packages/nest/src/db/db.service.ts @@ -1,29 +1,23 @@ import { Low } from "lowdb"; import { JSONFile } from "lowdb/node"; import { Injectable, OnModuleDestroy, OnModuleInit } from "@nestjs/common"; -import { Persistence, persistenceSchema } from "./schemas/persistence.js"; -import { Conversation, DEFAULT_DISABLEMODERATION } from "gpt-turbo"; -import { - PersistenceConversation, - persistenceConversationSchema, -} from "./schemas/persistenceConversation.js"; +import { Conversation, ConversationModel } from "gpt-turbo"; import { ConversationsService } from "../conversations/conversations.service.js"; -import { PersistenceMessage } from "./schemas/persistenceMessage.js"; @Injectable() export class DbService implements OnModuleInit, OnModuleDestroy { - private readonly db!: Low; + private readonly db!: Low; constructor(private readonly conversationsService: ConversationsService) {} async onModuleInit() { - const adapter = new JSONFile("db.json"); + const adapter = new JSONFile("db.json"); // @ts-ignore this.db = new Low(adapter); await this.db.read(); - this.db.data ||= { conversations: [] }; + this.db.data ||= []; const conversations = await Promise.all( - this.db.data.conversations.map((c) => this.loadConversation(c)) + this.db.data.map((c) => Conversation.fromJSON(c)) ); conversations.forEach((c) => this.conversationsService.addConversation(c) @@ -36,65 +30,7 @@ export class DbService implements OnModuleInit, OnModuleDestroy { async save() { const conversations = this.conversationsService.getConversations(); - - const persistedConversations: Persistence = { - conversations: conversations.map((conversation) => ({ - ...conversation.getConfig(), - messages: conversation.getMessages().map( - (message): PersistenceMessage => ({ - content: message.content, - role: message.role, - }) - ), - })), - }; - const parsed = persistenceSchema.parse(persistedConversations); - - this.db.data = parsed; + this.db.data = conversations.map((c) => c.toJSON()); await this.db.write(); } - - private async loadConversation(dbConversation: PersistenceConversation) { - const persistedConversation = - persistenceConversationSchema.parse(dbConversation); - const { - messages, - disableModeration = DEFAULT_DISABLEMODERATION, - ...config - } = persistedConversation; - const newConversation = new Conversation({ - ...config, - disableModeration: true, - }); - - for (const message of messages) { - try { - switch (message.role) { - case "user": - await newConversation.addUserMessage(message.content); - break; - case "assistant": - await newConversation.addAssistantMessage( - message.content - ); - break; - case "system": - newConversation.setContext(message.content); - } - } catch (e) { - console.error( - "Error while loading message", - (e as Error).message - ); - } - } - - newConversation.setConfig( - { - disableModeration, - }, - true - ); - return newConversation; - } } diff --git a/packages/nest/src/db/schemas/persistence.ts b/packages/nest/src/db/schemas/persistence.ts deleted file mode 100644 index fc0eabc..0000000 --- a/packages/nest/src/db/schemas/persistence.ts +++ /dev/null @@ -1,8 +0,0 @@ -import { z } from "nestjs-zod/z"; -import { persistenceConversationSchema } from "./persistenceConversation.js"; - -export const persistenceSchema = z.object({ - conversations: z.array(persistenceConversationSchema), -}); - -export type Persistence = z.infer; diff --git a/packages/nest/src/db/schemas/persistenceMessage.ts b/packages/nest/src/db/schemas/persistenceMessage.ts deleted file mode 100644 index ebaaa74..0000000 --- a/packages/nest/src/db/schemas/persistenceMessage.ts +++ /dev/null @@ -1,8 +0,0 @@ -import { z } from "nestjs-zod/z"; - -export const persistenceMessageSchema = z.object({ - content: z.string(), - role: z.enum(["user", "system", "assistant"]), -}); - -export type PersistenceMessage = z.infer; diff --git a/packages/web/src/contexts/ConversationManagerContext.ts b/packages/web/src/contexts/ConversationManagerContext.ts index df6b56e..fd68ee2 100644 --- a/packages/web/src/contexts/ConversationManagerContext.ts +++ b/packages/web/src/contexts/ConversationManagerContext.ts @@ -8,7 +8,7 @@ export interface ConversationManagerContextValue { activeConversation: Conversation | null; showUsage: boolean; addConversation: ( - conversationConfig: ConversationConfigParameters + conversationConfig: ConversationConfigParameters | Conversation ) => Conversation; removeConversation: (id: string) => void; removeAllConversations: () => void; diff --git a/packages/web/src/contexts/providers/ConversationManagerProvider.tsx b/packages/web/src/contexts/providers/ConversationManagerProvider.tsx index 85d38da..9354f66 100644 --- a/packages/web/src/contexts/providers/ConversationManagerProvider.tsx +++ b/packages/web/src/contexts/providers/ConversationManagerProvider.tsx @@ -26,8 +26,11 @@ export default ({ children }: ConversationManagerProviderProps) => { ); const addConversation = React.useCallback( - (conversation: ConversationConfigParameters) => { - const newConversation = new Conversation(conversation); + (conversation: ConversationConfigParameters | Conversation) => { + const newConversation = + conversation instanceof Conversation + ? conversation + : new Conversation(conversation); setConversations((c) => [...c, newConversation]); return newConversation; }, diff --git a/packages/web/src/contexts/providers/PersistenceProvider.tsx b/packages/web/src/contexts/providers/PersistenceProvider.tsx index 79162d4..0333ee4 100644 --- a/packages/web/src/contexts/providers/PersistenceProvider.tsx +++ b/packages/web/src/contexts/providers/PersistenceProvider.tsx @@ -7,8 +7,7 @@ import useStorage from "../../hooks/useStorage"; import useConversationManager from "../../hooks/useConversationManager"; import { Persistence, persistenceSchema } from "../../entities/persistence"; import { PersistenceConversation } from "../../entities/persistenceConversation"; -import { PersistenceMessage } from "../../entities/persistenceMessage"; -import { Message } from "gpt-turbo"; +import { Conversation, Message } from "gpt-turbo"; interface PersistenceProviderProps { children?: React.ReactNode; @@ -56,14 +55,8 @@ export default ({ children }: PersistenceProviderProps) => { conversation.getMessages().length ) .map((conversation) => ({ - ...conversation.getConfig(), + ...conversation.toJSON(), name: getConversationName(conversation.id), - messages: conversation.getMessages().map( - (message): PersistenceMessage => ({ - content: message.content, - role: message.role, - }) - ), })); setPersistence({ @@ -102,48 +95,15 @@ export default ({ children }: PersistenceProviderProps) => { const load = async () => { let i = -1; for (const { - messages, - disableModeration, - ...config + name, + ...conversationJson } of persistence.conversations) { - const newConversation = addConversation({ - ...config, - disableModeration: true, - }); + const newConversation = addConversation( + await Conversation.fromJSON(conversationJson) + ); if (++i === 0) setActiveConversation(newConversation.id, true); addPersistedConversationId(newConversation.id); - setConversationName(newConversation.id, config.name); - - for (const message of messages) { - try { - switch (message.role) { - case "user": - await newConversation.addUserMessage( - message.content - ); - break; - case "assistant": - await newConversation.addAssistantMessage( - message.content - ); - break; - case "system": - newConversation.setContext(message.content); - } - } catch (e) { - console.error( - "Error while loading message", - (e as Error).message - ); - } - } - - newConversation.setConfig( - { - disableModeration, - }, - true - ); + setConversationName(newConversation.id, name); } }; load().then(() => { diff --git a/packages/web/src/entities/persistenceConversation.ts b/packages/web/src/entities/persistenceConversation.ts index 7ce2758..e46576d 100644 --- a/packages/web/src/entities/persistenceConversation.ts +++ b/packages/web/src/entities/persistenceConversation.ts @@ -1,25 +1,8 @@ import { z } from "zod"; -import { persistenceMessageSchema } from "./persistenceMessage"; +import { conversationSchema } from "gpt-turbo"; -export const persistenceConversationSchema = z.object({ +export const persistenceConversationSchema = conversationSchema.extend({ name: z.string(), - messages: z.array(persistenceMessageSchema), - - context: z.string(), - dry: z.boolean(), - disableModeration: z.boolean().or(z.literal("soft")), - - apiKey: z.string(), - model: z.string(), - stream: z.boolean(), - frequency_penalty: z.number().optional(), - presence_penalty: z.number().optional(), - max_tokens: z.number().optional(), - logit_bias: z.record(z.number(), z.number()).optional(), - stop: z.string().or(z.array(z.string())).or(z.null()).optional(), - temperature: z.number().optional(), - top_p: z.number().optional(), - user: z.string().optional(), }); export type PersistenceConversation = z.infer< diff --git a/packages/web/src/entities/persistenceMessage.ts b/packages/web/src/entities/persistenceMessage.ts deleted file mode 100644 index 56839fa..0000000 --- a/packages/web/src/entities/persistenceMessage.ts +++ /dev/null @@ -1,8 +0,0 @@ -import { z } from "zod"; - -export const persistenceMessageSchema = z.object({ - content: z.string(), - role: z.enum(["user", "system", "assistant"]), -}); - -export type PersistenceMessage = z.infer;