From 491131cb4d088189aa93450da0ce229f87fd4b99 Mon Sep 17 00:00:00 2001 From: Michael Fried Date: Tue, 19 Dec 2023 23:31:27 +0000 Subject: [PATCH 01/31] Add embedding models configurable, from both Xenova and TEI --- .env | 13 ++++ .env.template | 22 ++++++ .../embeddingEndpoints/embeddingEndpoints.ts | 33 +++++++++ .../tei/embeddingEndpoints.ts | 31 ++++++++ .../xenova/embeddingEndpoints.ts | 51 +++++++++++++ src/lib/server/embeddingModels.ts | 74 +++++++++++++++++++ src/lib/server/models.ts | 1 + src/lib/server/sentenceSimilarity.ts | 38 ++++++++++ src/lib/server/websearch/runWebSearch.ts | 16 ++-- .../server/websearch/sentenceSimilarity.ts | 52 ------------- src/lib/types/Conversation.ts | 1 + src/lib/types/SharedConversation.ts | 2 + src/routes/conversation/+server.ts | 6 ++ 13 files changed, 282 insertions(+), 58 deletions(-) create mode 100644 src/lib/server/embeddingEndpoints/embeddingEndpoints.ts create mode 100644 src/lib/server/embeddingEndpoints/tei/embeddingEndpoints.ts create mode 100644 src/lib/server/embeddingEndpoints/xenova/embeddingEndpoints.ts create mode 100644 src/lib/server/embeddingModels.ts create mode 100644 src/lib/server/sentenceSimilarity.ts delete mode 100644 src/lib/server/websearch/sentenceSimilarity.ts diff --git a/.env b/.env index c4d23840998..d9c743775d8 100644 --- a/.env +++ b/.env @@ -42,6 +42,19 @@ CA_PATH=# CLIENT_KEY_PASSWORD=# REJECT_UNAUTHORIZED=true +TEXT_EMBEDDING_MODELS = `[ + { + "name": "Xenova/gte-small", + "displayName": "Xenova/gte-small", + "description": "Local embedding model running on the server.", + "maxSequenceLength": 512, + "endpoints": [ + { "type": "xenova" } + ] + } +]` + + # 'name', 'userMessageToken', 'assistantMessageToken' are required MODELS=`[ { diff --git a/.env.template b/.env.template index c9d7e248b83..6101175e68f 100644 --- a/.env.template +++ b/.env.template @@ -196,6 +196,7 @@ MODELS=`[ "max_new_tokens" : 8192, "stop" : [""] }, + "embeddingModelName": "thenlper/gte-base", "promptExamples" : [ 
{ "title": "Write an email from bullet list", @@ -215,6 +216,27 @@ OLD_MODELS=`[{"name":"bigcode/starcoder"}, {"name":"OpenAssistant/oasst-sft-6-ll TASK_MODEL='mistralai/Mistral-7B-Instruct-v0.2' +# Default to using the first text embedding model when not specifying 'embeddingModelName' in the model itself. +TEXT_EMBEDDING_MODELS = `[ + { + "name": "Xenova/gte-small", + "displayName": "Xenova/gte-small", + "description": "Local embedding model running on the server.", + "maxSequenceLength": 512, + "endpoints": [ + { "type": "xenova" } + ] + }, + { + "name": "thenlper/gte-base", + "displayName": "thenlper/gte-base", + "description": "Hosted embedding model running on the cloud somewhere.", + "maxSequenceLength": 512, + "endpoints": [ + { "type": "tei", "http://localhost:8080/" } # Make sure the MAX_CLIENT_BATCH_SIZE env in TEI is large enough to accommodate web search chunks. + ] + } +]` APP_BASE="/chat" PUBLIC_ORIGIN=https://huggingface.co diff --git a/src/lib/server/embeddingEndpoints/embeddingEndpoints.ts b/src/lib/server/embeddingEndpoints/embeddingEndpoints.ts new file mode 100644 index 00000000000..840c18ecc20 --- /dev/null +++ b/src/lib/server/embeddingEndpoints/embeddingEndpoints.ts @@ -0,0 +1,33 @@ +import { embeddingEndpointTei, embeddingEndpointTeiParametersSchema } from "./tei/embeddingEndpoints"; +import { z } from "zod"; +import embeddingEndpointXenova, { embeddingEndpointXenovaParametersSchema } from "./xenova/embeddingEndpoints"; + +// parameters passed when generating text +interface EmbeddingEndpointParameters { + inputs: string[] +} + +interface CommonEmbeddingEndpoint { + weight: number; +} + +// type signature for the endpoint +export type EmbeddingEndpoint = ( + params: EmbeddingEndpointParameters +) => Promise; // TODO: type + +// generator function that takes in parameters for defining the endpoint and return the endpoint +export type EmbeddingEndpointGenerator = (parameters: T) => EmbeddingEndpoint; + +// list of all endpoint generators 
+export const embeddingEndpoints = { + tei: embeddingEndpointTei, + xenova: embeddingEndpointXenova +}; + +export const embeddingEndpointSchema = z.discriminatedUnion("type", [ + embeddingEndpointTeiParametersSchema, + embeddingEndpointXenovaParametersSchema +]); + +export default embeddingEndpoints; diff --git a/src/lib/server/embeddingEndpoints/tei/embeddingEndpoints.ts b/src/lib/server/embeddingEndpoints/tei/embeddingEndpoints.ts new file mode 100644 index 00000000000..be944fe02ee --- /dev/null +++ b/src/lib/server/embeddingEndpoints/tei/embeddingEndpoints.ts @@ -0,0 +1,31 @@ +import { HF_ACCESS_TOKEN, HF_TOKEN } from "$env/static/private"; +import { featureExtraction } from "@huggingface/inference"; +import { z } from "zod"; +import type { EmbeddingEndpoint } from "../embeddingEndpoints"; + +export const embeddingEndpointTeiParametersSchema = z.object({ + weight: z.number().int().positive().default(1), + model: z.any(), + type: z.literal("tei"), + url: z.string().url(), +}); + +export function embeddingEndpointTei(input: z.input): EmbeddingEndpoint { + const { url } = embeddingEndpointTeiParametersSchema.parse(input); + return async ({ inputs }) => { + const { origin } = new URL(url) + + const response = await fetch(`${origin}/embed`, { + method: 'POST', + headers: { + 'Accept': 'application/json', + 'Content-Type': 'application/json' + }, + body: JSON.stringify({ inputs, normalize: true, truncate: true }) + }); + + return response.json(); + }; +} + +export default embeddingEndpointTei; diff --git a/src/lib/server/embeddingEndpoints/xenova/embeddingEndpoints.ts b/src/lib/server/embeddingEndpoints/xenova/embeddingEndpoints.ts new file mode 100644 index 00000000000..25e4861a16e --- /dev/null +++ b/src/lib/server/embeddingEndpoints/xenova/embeddingEndpoints.ts @@ -0,0 +1,51 @@ +import { z } from "zod"; +import type { EmbeddingEndpoint } from "../embeddingEndpoints"; +import type { Tensor, Pipeline } from "@xenova/transformers"; +import { pipeline } from 
"@xenova/transformers"; + +export const embeddingEndpointXenovaParametersSchema = z.object({ + weight: z.number().int().positive().default(1), + model: z.any(), + type: z.literal("xenova"), +}); + + +// Use the Singleton pattern to enable lazy construction of the pipeline. +class XenovaModelsSingleton { + static instances: Array<[string, Promise]> = []; + + static async getInstance(modelName: string): Promise { + const modelPipeline = this.instances.find(([name]) => name === modelName) + + if (modelPipeline) { + return modelPipeline[1]; + } + + const newModelPipeline = pipeline("feature-extraction", modelName) + this.instances.push([modelName, newModelPipeline]) + + return newModelPipeline; + } +} + + +export async function calculateEmbedding( + modelName: string, + inputs: string[] +) { + const extractor = await XenovaModelsSingleton.getInstance(modelName); + const output: Tensor = await extractor(inputs, { pooling: "mean", normalize: true }); + + return output.tolist(); +} + + +export function embeddingEndpointXenova(input: z.input): EmbeddingEndpoint { + const { model } = embeddingEndpointXenovaParametersSchema.parse(input); + + return async ({ inputs }) => { + return calculateEmbedding(model.name, inputs); + }; +} + +export default embeddingEndpointXenova; diff --git a/src/lib/server/embeddingModels.ts b/src/lib/server/embeddingModels.ts new file mode 100644 index 00000000000..5284c5284d3 --- /dev/null +++ b/src/lib/server/embeddingModels.ts @@ -0,0 +1,74 @@ +import { + TEXT_EMBEDDING_MODELS, +} from "$env/static/private"; + +import { z } from "zod"; +import { sum } from "$lib/utils/sum"; +import embeddingEndpoints, { embeddingEndpointSchema, type EmbeddingEndpoint } from "./embeddingEndpoints/embeddingEndpoints"; +import embeddingEndpointXenova from "./embeddingEndpoints/xenova/embeddingEndpoints"; + +const modelConfig = z.object({ + /** Used as an identifier in DB */ + id: z.string().optional(), + /** Used to link to the model page, and for inference */ + 
name: z.string().min(1), + displayName: z.string().min(1).optional(), + description: z.string().min(1).optional(), + websiteUrl: z.string().url().optional(), + modelUrl: z.string().url().optional(), + endpoints: z.array(embeddingEndpointSchema).optional(), + maxSequenceLength: z.number().positive(), +}); + +const embeddingModelsRaw = z.array(modelConfig).parse(JSON.parse(TEXT_EMBEDDING_MODELS)); + +const processEmbeddingModel = async (m: z.infer) => ({ + ...m, + id: m.id || m.name, +}); + +const addEndpoint = (m: Awaited>) => ({ + ...m, + getEndpoint: async (): Promise => { + if (!m.endpoints) { + return embeddingEndpointXenova({ + type: "xenova", + weight: 1, + model: m, + }); + } + + const totalWeight = sum(m.endpoints.map((e) => e.weight)); + + let random = Math.random() * totalWeight; + + for (const endpoint of m.endpoints) { + if (random < endpoint.weight) { + const args = { ...endpoint, model: m }; + + switch (args.type) { + case "tei": + return embeddingEndpoints.tei(args); + case "xenova": + return embeddingEndpoints.xenova(args); + } + } + + random -= endpoint.weight; + } + + throw new Error(`Failed to select endpoint`); + }, +}); + +export const embeddingModels = await Promise.all(embeddingModelsRaw.map((e) => processEmbeddingModel(e).then(addEndpoint))); + +export const defaultEmbeddingModel = embeddingModels[0]; + +export const validateEmbeddingModel = (_models: EmbeddingBackendModel[]) => { + // Zod enum function requires 2 parameters + return z.enum([_models[0].id, ..._models.slice(1).map((m) => m.id)]); +}; + + +export type EmbeddingBackendModel = typeof defaultEmbeddingModel; diff --git a/src/lib/server/models.ts b/src/lib/server/models.ts index 58d05bd7a9b..c0921ff903e 100644 --- a/src/lib/server/models.ts +++ b/src/lib/server/models.ts @@ -66,6 +66,7 @@ const modelConfig = z.object({ .optional(), multimodal: z.boolean().default(false), unlisted: z.boolean().default(false), + embeddingModelName: z.string().optional() }); const modelsRaw = 
z.array(modelConfig).parse(JSON.parse(MODELS)); diff --git a/src/lib/server/sentenceSimilarity.ts b/src/lib/server/sentenceSimilarity.ts new file mode 100644 index 00000000000..b0b42fc780e --- /dev/null +++ b/src/lib/server/sentenceSimilarity.ts @@ -0,0 +1,38 @@ +import { dot } from "@xenova/transformers"; +import type { EmbeddingBackendModel } from "./embeddingModels"; + +// see here: https://github.com/nmslib/hnswlib/blob/359b2ba87358224963986f709e593d799064ace6/README.md?plain=1#L34 +function innerProduct(embeddingA: number[], embeddingB: number[]) { + return 1.0 - dot(embeddingA, embeddingB); +} + +export async function findSimilarSentences( + embeddingModel: EmbeddingBackendModel, + query: string, + sentences: string[], + { topK = 5 }: { topK: number } +): Promise { + const inputs = [query, ...sentences]; + + const embeddingEndpoint = await embeddingModel.getEndpoint(); + const output = await embeddingEndpoint({ inputs }) + + const queryEmbedding: number[] = output[0]; + const sentencesEmbeddings: number[][] = output.slice([1, inputs.length - 1]); + + const distancesFromQuery: { distance: number; index: number }[] = [...sentencesEmbeddings].map( + (sentenceEmbedding: number[], index: number) => { + return { + distance: innerProduct(queryEmbedding, sentenceEmbedding), + index: index, + }; + } + ); + + distancesFromQuery.sort((a, b) => { + return a.distance - b.distance; + }); + + // Return the indexes of the closest topK sentences + return distancesFromQuery.slice(0, topK).map((item) => item.index); +} diff --git a/src/lib/server/websearch/runWebSearch.ts b/src/lib/server/websearch/runWebSearch.ts index 0869ea8b494..22ead137dc6 100644 --- a/src/lib/server/websearch/runWebSearch.ts +++ b/src/lib/server/websearch/runWebSearch.ts @@ -4,13 +4,11 @@ import type { WebSearch, WebSearchSource } from "$lib/types/WebSearch"; import { generateQuery } from "$lib/server/websearch/generateQuery"; import { parseWeb } from "$lib/server/websearch/parseWeb"; import { chunk } 
from "$lib/utils/chunk"; -import { - MAX_SEQ_LEN as CHUNK_CAR_LEN, - findSimilarSentences, -} from "$lib/server/websearch/sentenceSimilarity"; +import { findSimilarSentences } from "$lib/server/sentenceSimilarity"; import type { Conversation } from "$lib/types/Conversation"; import type { MessageUpdate } from "$lib/types/MessageUpdate"; import { getWebSearchProvider } from "./searchWeb"; +import { embeddingModels } from "../embeddingModels"; const MAX_N_PAGES_SCRAPE = 10 as const; const MAX_N_PAGES_EMBED = 5 as const; @@ -57,6 +55,12 @@ export async function runWebSearch( .filter(({ link }) => !DOMAIN_BLOCKLIST.some((el) => link.includes(el))) // filter out blocklist links .slice(0, MAX_N_PAGES_SCRAPE); // limit to first 10 links only + // fetch the model + const embeddingModel = embeddingModels.find((m) => m.id === conv.embeddingModel); + if (!embeddingModel) { + throw new Error(`Embedding model ${conv.embeddingModel} not available anymore`); + } + let paragraphChunks: { source: WebSearchSource; text: string }[] = []; if (webSearch.results.length > 0) { appendUpdate("Browsing results"); @@ -72,7 +76,7 @@ export async function runWebSearch( } } const MAX_N_CHUNKS = 100; - const texts = chunk(text, CHUNK_CAR_LEN).slice(0, MAX_N_CHUNKS); + const texts = chunk(text, embeddingModel.maxSequenceLength).slice(0, MAX_N_CHUNKS); return texts.map((t) => ({ source: result, text: t })); }); const nestedParagraphChunks = (await Promise.all(promises)).slice(0, MAX_N_PAGES_EMBED); @@ -87,7 +91,7 @@ export async function runWebSearch( appendUpdate("Extracting relevant information"); const topKClosestParagraphs = 8; const texts = paragraphChunks.map(({ text }) => text); - const indices = await findSimilarSentences(prompt, texts, { + const indices = await findSimilarSentences(embeddingModel, prompt, texts, { topK: topKClosestParagraphs, }); webSearch.context = indices.map((idx) => texts[idx]).join(""); diff --git a/src/lib/server/websearch/sentenceSimilarity.ts 
b/src/lib/server/websearch/sentenceSimilarity.ts deleted file mode 100644 index a877f8e0cd6..00000000000 --- a/src/lib/server/websearch/sentenceSimilarity.ts +++ /dev/null @@ -1,52 +0,0 @@ -import type { Tensor, Pipeline } from "@xenova/transformers"; -import { pipeline, dot } from "@xenova/transformers"; - -// see here: https://github.com/nmslib/hnswlib/blob/359b2ba87358224963986f709e593d799064ace6/README.md?plain=1#L34 -function innerProduct(tensor1: Tensor, tensor2: Tensor) { - return 1.0 - dot(tensor1.data, tensor2.data); -} - -// Use the Singleton pattern to enable lazy construction of the pipeline. -class PipelineSingleton { - static modelId = "Xenova/gte-small"; - static instance: Promise | null = null; - static async getInstance() { - if (this.instance === null) { - this.instance = pipeline("feature-extraction", this.modelId); - } - return this.instance; - } -} - -// see https://huggingface.co/thenlper/gte-small/blob/d8e2604cadbeeda029847d19759d219e0ce2e6d8/README.md?code=true#L2625 -export const MAX_SEQ_LEN = 512 as const; - -export async function findSimilarSentences( - query: string, - sentences: string[], - { topK = 5 }: { topK: number } -) { - const input = [query, ...sentences]; - - const extractor = await PipelineSingleton.getInstance(); - const output: Tensor = await extractor(input, { pooling: "mean", normalize: true }); - - const queryTensor: Tensor = output[0]; - const sentencesTensor: Tensor = output.slice([1, input.length - 1]); - - const distancesFromQuery: { distance: number; index: number }[] = [...sentencesTensor].map( - (sentenceTensor: Tensor, index: number) => { - return { - distance: innerProduct(queryTensor, sentenceTensor), - index: index, - }; - } - ); - - distancesFromQuery.sort((a, b) => { - return a.distance - b.distance; - }); - - // Return the indexes of the closest topK sentences - return distancesFromQuery.slice(0, topK).map((item) => item.index); -} diff --git a/src/lib/types/Conversation.ts b/src/lib/types/Conversation.ts 
index 5788ce63fd8..665a688f6b4 100644 --- a/src/lib/types/Conversation.ts +++ b/src/lib/types/Conversation.ts @@ -10,6 +10,7 @@ export interface Conversation extends Timestamps { userId?: User["_id"]; model: string; + embeddingModel: string; title: string; messages: Message[]; diff --git a/src/lib/types/SharedConversation.ts b/src/lib/types/SharedConversation.ts index 8571f2c3f3a..1996bcc6ff9 100644 --- a/src/lib/types/SharedConversation.ts +++ b/src/lib/types/SharedConversation.ts @@ -7,6 +7,8 @@ export interface SharedConversation extends Timestamps { hash: string; model: string; + embeddingModel: string; + title: string; messages: Message[]; preprompt?: string; diff --git a/src/routes/conversation/+server.ts b/src/routes/conversation/+server.ts index 6452e985d67..6b79d46e13e 100644 --- a/src/routes/conversation/+server.ts +++ b/src/routes/conversation/+server.ts @@ -6,6 +6,7 @@ import { base } from "$app/paths"; import { z } from "zod"; import type { Message } from "$lib/types/Message"; import { models, validateModel } from "$lib/server/models"; +import { defaultEmbeddingModel, embeddingModels, validateEmbeddingModel } from "$lib/server/embeddingModels"; export const POST: RequestHandler = async ({ locals, request }) => { const body = await request.text(); @@ -17,6 +18,7 @@ export const POST: RequestHandler = async ({ locals, request }) => { .object({ fromShare: z.string().optional(), model: validateModel(models), + embeddingModel: validateEmbeddingModel(embeddingModels).optional(), preprompt: z.string().optional(), }) .parse(JSON.parse(body)); @@ -35,11 +37,14 @@ export const POST: RequestHandler = async ({ locals, request }) => { title = conversation.title; messages = conversation.messages; values.model = conversation.model; + values.embeddingModel = conversation.embeddingModel; preprompt = conversation.preprompt; } const model = models.find((m) => m.name === values.model); + values.embeddingModel = values.embeddingModel ?? model?.embeddingModelName ?? 
defaultEmbeddingModel.name + if (!model) { throw error(400, "Invalid model"); } @@ -59,6 +64,7 @@ export const POST: RequestHandler = async ({ locals, request }) => { preprompt: preprompt === model?.preprompt ? model?.preprompt : preprompt, createdAt: new Date(), updatedAt: new Date(), + embeddingModel: values.embeddingModel, ...(locals.user ? { userId: locals.user._id } : { sessionId: locals.sessionId }), ...(values.fromShare ? { meta: { fromShareId: values.fromShare } } : {}), }); From 3473fc22aefa2e6d66e80a22f9c80bc549f4b1f0 Mon Sep 17 00:00:00 2001 From: Michael Fried Date: Wed, 20 Dec 2023 20:27:14 +0000 Subject: [PATCH 02/31] fix lint and format --- .../embeddingEndpoints/embeddingEndpoints.ts | 23 +++++++++++-------- .../tei/embeddingEndpoints.ts | 16 ++++++------- .../xenova/embeddingEndpoints.ts | 18 ++++++--------- src/lib/server/embeddingModels.ts | 14 ++++++----- src/lib/server/models.ts | 2 +- src/lib/server/sentenceSimilarity.ts | 2 +- src/routes/conversation/+server.ts | 9 ++++++-- 7 files changed, 46 insertions(+), 38 deletions(-) diff --git a/src/lib/server/embeddingEndpoints/embeddingEndpoints.ts b/src/lib/server/embeddingEndpoints/embeddingEndpoints.ts index 840c18ecc20..289e00135e4 100644 --- a/src/lib/server/embeddingEndpoints/embeddingEndpoints.ts +++ b/src/lib/server/embeddingEndpoints/embeddingEndpoints.ts @@ -1,10 +1,15 @@ -import { embeddingEndpointTei, embeddingEndpointTeiParametersSchema } from "./tei/embeddingEndpoints"; +import { + embeddingEndpointTei, + embeddingEndpointTeiParametersSchema, +} from "./tei/embeddingEndpoints"; import { z } from "zod"; -import embeddingEndpointXenova, { embeddingEndpointXenovaParametersSchema } from "./xenova/embeddingEndpoints"; +import embeddingEndpointXenova, { + embeddingEndpointXenovaParametersSchema, +} from "./xenova/embeddingEndpoints"; // parameters passed when generating text interface EmbeddingEndpointParameters { - inputs: string[] + inputs: string[]; } interface CommonEmbeddingEndpoint { 
@@ -12,22 +17,22 @@ interface CommonEmbeddingEndpoint { } // type signature for the endpoint -export type EmbeddingEndpoint = ( - params: EmbeddingEndpointParameters -) => Promise; // TODO: type +export type EmbeddingEndpoint = (params: EmbeddingEndpointParameters) => Promise; // generator function that takes in parameters for defining the endpoint and return the endpoint -export type EmbeddingEndpointGenerator = (parameters: T) => EmbeddingEndpoint; +export type EmbeddingEndpointGenerator = ( + parameters: T +) => EmbeddingEndpoint; // list of all endpoint generators export const embeddingEndpoints = { tei: embeddingEndpointTei, - xenova: embeddingEndpointXenova + xenova: embeddingEndpointXenova, }; export const embeddingEndpointSchema = z.discriminatedUnion("type", [ embeddingEndpointTeiParametersSchema, - embeddingEndpointXenovaParametersSchema + embeddingEndpointXenovaParametersSchema, ]); export default embeddingEndpoints; diff --git a/src/lib/server/embeddingEndpoints/tei/embeddingEndpoints.ts b/src/lib/server/embeddingEndpoints/tei/embeddingEndpoints.ts index be944fe02ee..161ee7088e8 100644 --- a/src/lib/server/embeddingEndpoints/tei/embeddingEndpoints.ts +++ b/src/lib/server/embeddingEndpoints/tei/embeddingEndpoints.ts @@ -1,5 +1,3 @@ -import { HF_ACCESS_TOKEN, HF_TOKEN } from "$env/static/private"; -import { featureExtraction } from "@huggingface/inference"; import { z } from "zod"; import type { EmbeddingEndpoint } from "../embeddingEndpoints"; @@ -10,18 +8,20 @@ export const embeddingEndpointTeiParametersSchema = z.object({ url: z.string().url(), }); -export function embeddingEndpointTei(input: z.input): EmbeddingEndpoint { +export function embeddingEndpointTei( + input: z.input +): EmbeddingEndpoint { const { url } = embeddingEndpointTeiParametersSchema.parse(input); return async ({ inputs }) => { - const { origin } = new URL(url) + const { origin } = new URL(url); const response = await fetch(`${origin}/embed`, { - method: 'POST', + method: "POST", 
headers: { - 'Accept': 'application/json', - 'Content-Type': 'application/json' + Accept: "application/json", + "Content-Type": "application/json", }, - body: JSON.stringify({ inputs, normalize: true, truncate: true }) + body: JSON.stringify({ inputs, normalize: true, truncate: true }), }); return response.json(); diff --git a/src/lib/server/embeddingEndpoints/xenova/embeddingEndpoints.ts b/src/lib/server/embeddingEndpoints/xenova/embeddingEndpoints.ts index 25e4861a16e..ba580290d39 100644 --- a/src/lib/server/embeddingEndpoints/xenova/embeddingEndpoints.ts +++ b/src/lib/server/embeddingEndpoints/xenova/embeddingEndpoints.ts @@ -9,38 +9,34 @@ export const embeddingEndpointXenovaParametersSchema = z.object({ type: z.literal("xenova"), }); - // Use the Singleton pattern to enable lazy construction of the pipeline. class XenovaModelsSingleton { static instances: Array<[string, Promise]> = []; static async getInstance(modelName: string): Promise { - const modelPipeline = this.instances.find(([name]) => name === modelName) + const modelPipeline = this.instances.find(([name]) => name === modelName); if (modelPipeline) { return modelPipeline[1]; } - const newModelPipeline = pipeline("feature-extraction", modelName) - this.instances.push([modelName, newModelPipeline]) + const newModelPipeline = pipeline("feature-extraction", modelName); + this.instances.push([modelName, newModelPipeline]); return newModelPipeline; } } - -export async function calculateEmbedding( - modelName: string, - inputs: string[] -) { +export async function calculateEmbedding(modelName: string, inputs: string[]) { const extractor = await XenovaModelsSingleton.getInstance(modelName); const output: Tensor = await extractor(inputs, { pooling: "mean", normalize: true }); return output.tolist(); } - -export function embeddingEndpointXenova(input: z.input): EmbeddingEndpoint { +export function embeddingEndpointXenova( + input: z.input +): EmbeddingEndpoint { const { model } = 
embeddingEndpointXenovaParametersSchema.parse(input); return async ({ inputs }) => { diff --git a/src/lib/server/embeddingModels.ts b/src/lib/server/embeddingModels.ts index 5284c5284d3..78f6a773694 100644 --- a/src/lib/server/embeddingModels.ts +++ b/src/lib/server/embeddingModels.ts @@ -1,10 +1,11 @@ -import { - TEXT_EMBEDDING_MODELS, -} from "$env/static/private"; +import { TEXT_EMBEDDING_MODELS } from "$env/static/private"; import { z } from "zod"; import { sum } from "$lib/utils/sum"; -import embeddingEndpoints, { embeddingEndpointSchema, type EmbeddingEndpoint } from "./embeddingEndpoints/embeddingEndpoints"; +import embeddingEndpoints, { + embeddingEndpointSchema, + type EmbeddingEndpoint, +} from "./embeddingEndpoints/embeddingEndpoints"; import embeddingEndpointXenova from "./embeddingEndpoints/xenova/embeddingEndpoints"; const modelConfig = z.object({ @@ -61,7 +62,9 @@ const addEndpoint = (m: Awaited>) => ({ }, }); -export const embeddingModels = await Promise.all(embeddingModelsRaw.map((e) => processEmbeddingModel(e).then(addEndpoint))); +export const embeddingModels = await Promise.all( + embeddingModelsRaw.map((e) => processEmbeddingModel(e).then(addEndpoint)) +); export const defaultEmbeddingModel = embeddingModels[0]; @@ -70,5 +73,4 @@ export const validateEmbeddingModel = (_models: EmbeddingBackendModel[]) => { return z.enum([_models[0].id, ..._models.slice(1).map((m) => m.id)]); }; - export type EmbeddingBackendModel = typeof defaultEmbeddingModel; diff --git a/src/lib/server/models.ts b/src/lib/server/models.ts index c0921ff903e..fa2952b2827 100644 --- a/src/lib/server/models.ts +++ b/src/lib/server/models.ts @@ -66,7 +66,7 @@ const modelConfig = z.object({ .optional(), multimodal: z.boolean().default(false), unlisted: z.boolean().default(false), - embeddingModelName: z.string().optional() + embeddingModelName: z.string().optional(), }); const modelsRaw = z.array(modelConfig).parse(JSON.parse(MODELS)); diff --git 
a/src/lib/server/sentenceSimilarity.ts b/src/lib/server/sentenceSimilarity.ts index b0b42fc780e..b4d9caf950b 100644 --- a/src/lib/server/sentenceSimilarity.ts +++ b/src/lib/server/sentenceSimilarity.ts @@ -15,7 +15,7 @@ export async function findSimilarSentences( const inputs = [query, ...sentences]; const embeddingEndpoint = await embeddingModel.getEndpoint(); - const output = await embeddingEndpoint({ inputs }) + const output = await embeddingEndpoint({ inputs }); const queryEmbedding: number[] = output[0]; const sentencesEmbeddings: number[][] = output.slice([1, inputs.length - 1]); diff --git a/src/routes/conversation/+server.ts b/src/routes/conversation/+server.ts index 6b79d46e13e..d3fadb396cc 100644 --- a/src/routes/conversation/+server.ts +++ b/src/routes/conversation/+server.ts @@ -6,7 +6,11 @@ import { base } from "$app/paths"; import { z } from "zod"; import type { Message } from "$lib/types/Message"; import { models, validateModel } from "$lib/server/models"; -import { defaultEmbeddingModel, embeddingModels, validateEmbeddingModel } from "$lib/server/embeddingModels"; +import { + defaultEmbeddingModel, + embeddingModels, + validateEmbeddingModel, +} from "$lib/server/embeddingModels"; export const POST: RequestHandler = async ({ locals, request }) => { const body = await request.text(); @@ -43,7 +47,8 @@ export const POST: RequestHandler = async ({ locals, request }) => { const model = models.find((m) => m.name === values.model); - values.embeddingModel = values.embeddingModel ?? model?.embeddingModelName ?? defaultEmbeddingModel.name + values.embeddingModel = + values.embeddingModel ?? model?.embeddingModelName ?? 
defaultEmbeddingModel.name; if (!model) { throw error(400, "Invalid model"); From aebf653324404c38dbadb94b0a4c19437f9e45cf Mon Sep 17 00:00:00 2001 From: Michael Fried Date: Wed, 20 Dec 2023 21:07:45 +0000 Subject: [PATCH 03/31] Fix bug in sentenceSimilarity --- src/lib/server/sentenceSimilarity.ts | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/lib/server/sentenceSimilarity.ts b/src/lib/server/sentenceSimilarity.ts index b4d9caf950b..d2fbb390364 100644 --- a/src/lib/server/sentenceSimilarity.ts +++ b/src/lib/server/sentenceSimilarity.ts @@ -18,7 +18,7 @@ export async function findSimilarSentences( const output = await embeddingEndpoint({ inputs }); const queryEmbedding: number[] = output[0]; - const sentencesEmbeddings: number[][] = output.slice([1, inputs.length - 1]); + const sentencesEmbeddings: number[][] = output.slice(1, inputs.length - 1); const distancesFromQuery: { distance: number; index: number }[] = [...sentencesEmbeddings].map( (sentenceEmbedding: number[], index: number) => { From c065045196bca33373489bc84f0034623822f5c4 Mon Sep 17 00:00:00 2001 From: Michael Fried Date: Wed, 20 Dec 2023 21:09:15 +0000 Subject: [PATCH 04/31] Batches for TEI using /info route --- .../tei/embeddingEndpoints.ts | 56 +++++++++++++++---- 1 file changed, 44 insertions(+), 12 deletions(-) diff --git a/src/lib/server/embeddingEndpoints/tei/embeddingEndpoints.ts b/src/lib/server/embeddingEndpoints/tei/embeddingEndpoints.ts index 161ee7088e8..2cb5bc5b942 100644 --- a/src/lib/server/embeddingEndpoints/tei/embeddingEndpoints.ts +++ b/src/lib/server/embeddingEndpoints/tei/embeddingEndpoints.ts @@ -1,5 +1,6 @@ import { z } from "zod"; import type { EmbeddingEndpoint } from "../embeddingEndpoints"; +import { chunk } from "$lib/utils/chunk"; export const embeddingEndpointTeiParametersSchema = z.object({ weight: z.number().int().positive().default(1), @@ -8,23 +9,54 @@ export const embeddingEndpointTeiParametersSchema = z.object({ url: z.string().url(), }); 
-export function embeddingEndpointTei( + +const getModelInfoByUrl = async (url: string) => { + const { origin } = new URL(url); + + const response = await fetch(`${origin}/info`, { + headers: { + Accept: "application/json", + "Content-Type": "application/json", + } + }); + + const info = await response.json(); + + return info; +} + +export async function embeddingEndpointTei( input: z.input -): EmbeddingEndpoint { - const { url } = embeddingEndpointTeiParametersSchema.parse(input); +): Promise { + const { url, model } = embeddingEndpointTeiParametersSchema.parse(input); + + const { max_client_batch_size, max_batch_tokens } = await getModelInfoByUrl(url); + const maxBatchSize = Math.min(max_client_batch_size, Math.floor(max_batch_tokens / model.maxSequenceLength)) + return async ({ inputs }) => { const { origin } = new URL(url); - const response = await fetch(`${origin}/embed`, { - method: "POST", - headers: { - Accept: "application/json", - "Content-Type": "application/json", - }, - body: JSON.stringify({ inputs, normalize: true, truncate: true }), - }); + const batchesInputs = chunk(inputs, maxBatchSize) + + const batchesResults = await Promise.all( + batchesInputs.map(async (batchInputs) => { + const response = await fetch(`${origin}/embed`, { + method: "POST", + headers: { + Accept: "application/json", + "Content-Type": "application/json", + }, + body: JSON.stringify({ inputs: batchInputs, normalize: true, truncate: true }), + }); + + const embeddings: number[][] = await response.json(); + return embeddings; + }) + ) + + const allEmbeddings = batchesResults.flatMap(embeddings => embeddings) - return response.json(); + return allEmbeddings; }; } From 8df8fd2c5358868f99a4bafb287516210c0e3952 Mon Sep 17 00:00:00 2001 From: Michael Fried Date: Wed, 20 Dec 2023 21:36:19 +0000 Subject: [PATCH 05/31] Fix web search disapear when finish searching --- src/lib/components/OpenWebSearchResults.svelte | 4 ++-- src/routes/conversation/[id]/+page.svelte | 10 ++++++++++ 2 files 
changed, 12 insertions(+), 2 deletions(-) diff --git a/src/lib/components/OpenWebSearchResults.svelte b/src/lib/components/OpenWebSearchResults.svelte index aac5fa54141..3e8c8190410 100644 --- a/src/lib/components/OpenWebSearchResults.svelte +++ b/src/lib/components/OpenWebSearchResults.svelte @@ -30,8 +30,8 @@ {:else} {/if} - Web search + + Web search
diff --git a/src/routes/conversation/[id]/+page.svelte b/src/routes/conversation/[id]/+page.svelte index 363d14d6176..e7c420358ae 100644 --- a/src/routes/conversation/[id]/+page.svelte +++ b/src/routes/conversation/[id]/+page.svelte @@ -138,6 +138,7 @@ // eslint-disable-next-line no-undef const encoder = new TextDecoderStream(); const reader = response?.body?.pipeThrough(encoder).getReader(); + let importantUpdates: MessageUpdate[] = []; let finalAnswer = ""; // set str queue @@ -173,7 +174,9 @@ inputs.forEach(async (el: string) => { try { const update = JSON.parse(el) as MessageUpdate; + if (update.type === "finalAnswer") { + importantUpdates.push(update); finalAnswer = update.text; reader.cancel(); loading = false; @@ -194,8 +197,10 @@ messages = [...messages]; } } else if (update.type === "webSearch") { + importantUpdates.push(update); webSearchMessages = [...webSearchMessages, update]; } else if (update.type === "status") { + importantUpdates.push(update); if (update.status === "title" && update.message) { const conv = data.conversations.find(({ id }) => id === $page.params.id); if (conv) { @@ -210,6 +215,7 @@ $error = update.message ?? 
"An error has occurred"; } } else if (update.type === "error") { + importantUpdates.push(update); error.set(update.message); reader.cancel(); } @@ -225,7 +231,11 @@ }); } + const lastMessage = messages[messages.length - 1]; + lastMessage.updates = importantUpdates; + // reset the websearchmessages + importantUpdates = [] webSearchMessages = []; await invalidate(UrlDependency.ConversationList); From cc02b4c2a1959a8a67fc16049cbdf1abcd8ef701 Mon Sep 17 00:00:00 2001 From: Michael Fried Date: Wed, 20 Dec 2023 21:38:41 +0000 Subject: [PATCH 06/31] Fix lint and format --- .../embeddingEndpoints/tei/embeddingEndpoints.ts | 16 +++++++++------- src/routes/conversation/[id]/+page.svelte | 4 ++-- 2 files changed, 11 insertions(+), 9 deletions(-) diff --git a/src/lib/server/embeddingEndpoints/tei/embeddingEndpoints.ts b/src/lib/server/embeddingEndpoints/tei/embeddingEndpoints.ts index 2cb5bc5b942..e689a37214d 100644 --- a/src/lib/server/embeddingEndpoints/tei/embeddingEndpoints.ts +++ b/src/lib/server/embeddingEndpoints/tei/embeddingEndpoints.ts @@ -9,7 +9,6 @@ export const embeddingEndpointTeiParametersSchema = z.object({ url: z.string().url(), }); - const getModelInfoByUrl = async (url: string) => { const { origin } = new URL(url); @@ -17,13 +16,13 @@ const getModelInfoByUrl = async (url: string) => { headers: { Accept: "application/json", "Content-Type": "application/json", - } + }, }); const info = await response.json(); return info; -} +}; export async function embeddingEndpointTei( input: z.input @@ -31,12 +30,15 @@ export async function embeddingEndpointTei( const { url, model } = embeddingEndpointTeiParametersSchema.parse(input); const { max_client_batch_size, max_batch_tokens } = await getModelInfoByUrl(url); - const maxBatchSize = Math.min(max_client_batch_size, Math.floor(max_batch_tokens / model.maxSequenceLength)) + const maxBatchSize = Math.min( + max_client_batch_size, + Math.floor(max_batch_tokens / model.maxSequenceLength) + ); return async ({ inputs }) => { 
const { origin } = new URL(url); - const batchesInputs = chunk(inputs, maxBatchSize) + const batchesInputs = chunk(inputs, maxBatchSize); const batchesResults = await Promise.all( batchesInputs.map(async (batchInputs) => { @@ -52,9 +54,9 @@ export async function embeddingEndpointTei( const embeddings: number[][] = await response.json(); return embeddings; }) - ) + ); - const allEmbeddings = batchesResults.flatMap(embeddings => embeddings) + const allEmbeddings = batchesResults.flatMap((embeddings) => embeddings); return allEmbeddings; }; diff --git a/src/routes/conversation/[id]/+page.svelte b/src/routes/conversation/[id]/+page.svelte index e7c420358ae..076b9dcf1bc 100644 --- a/src/routes/conversation/[id]/+page.svelte +++ b/src/routes/conversation/[id]/+page.svelte @@ -174,7 +174,7 @@ inputs.forEach(async (el: string) => { try { const update = JSON.parse(el) as MessageUpdate; - + if (update.type === "finalAnswer") { importantUpdates.push(update); finalAnswer = update.text; @@ -235,7 +235,7 @@ lastMessage.updates = importantUpdates; // reset the websearchmessages - importantUpdates = [] + importantUpdates = []; webSearchMessages = []; await invalidate(UrlDependency.ConversationList); From 53fa58a1a0cc224b0102ea33f71adaa7b9a56f85 Mon Sep 17 00:00:00 2001 From: Michael Fried Date: Wed, 20 Dec 2023 22:03:38 +0000 Subject: [PATCH 07/31] Add more options for better embedding model usage --- .env.template | 13 ++++++++++++- src/lib/server/embeddingModels.ts | 2 ++ src/lib/server/sentenceSimilarity.ts | 5 ++++- 3 files changed, 18 insertions(+), 2 deletions(-) diff --git a/.env.template b/.env.template index 6101175e68f..9f81ee3e1ee 100644 --- a/.env.template +++ b/.env.template @@ -233,7 +233,18 @@ TEXT_EMBEDDING_MODELS = `[ "description": "Hosted embedding model running on the cloud somewhere.", "maxSequenceLength": 512, "endpoints": [ - { "type": "tei", "http://localhost:8080/" } # Make sure the MAX_CLIENT_BATCH_SIZE env in TEI is large enough to accommodate web search 
chunks. + { "type": "tei", "http://localhost:8080/" } + ] + }, + { + "name": "intfloat/multilingual-e5-large", + "displayName": "intfloat/multilingual-e5-large", + "description": "Hosted embedding model running on the cloud somewhere.", + "maxSequenceLength": 512, + "preQuery": "query: ", # See https://huggingface.co/intfloat/multilingual-e5-large#faq + "prePassage": "passage: ", # See https://huggingface.co/intfloat/multilingual-e5-large#faq + "endpoints": [ + { "type": "tei", "http://localhost:8085/" } ] } ]` diff --git a/src/lib/server/embeddingModels.ts b/src/lib/server/embeddingModels.ts index 78f6a773694..8eaf960be2a 100644 --- a/src/lib/server/embeddingModels.ts +++ b/src/lib/server/embeddingModels.ts @@ -19,6 +19,8 @@ const modelConfig = z.object({ modelUrl: z.string().url().optional(), endpoints: z.array(embeddingEndpointSchema).optional(), maxSequenceLength: z.number().positive(), + preQuery: z.string().default(""), + prePassage: z.string().default(""), }); const embeddingModelsRaw = z.array(modelConfig).parse(JSON.parse(TEXT_EMBEDDING_MODELS)); diff --git a/src/lib/server/sentenceSimilarity.ts b/src/lib/server/sentenceSimilarity.ts index d2fbb390364..b2d0356aecd 100644 --- a/src/lib/server/sentenceSimilarity.ts +++ b/src/lib/server/sentenceSimilarity.ts @@ -12,7 +12,10 @@ export async function findSimilarSentences( sentences: string[], { topK = 5 }: { topK: number } ): Promise { - const inputs = [query, ...sentences]; + const inputs = [ + `${embeddingModel.preQuery}${query}`, + ...sentences.map((sentence) => `${embeddingModel.prePassage}${sentence}`), + ]; const embeddingEndpoint = await embeddingModel.getEndpoint(); const output = await embeddingEndpoint({ inputs }); From 9a867aae478011b03a395e8a5eeee1be5bbcab57 Mon Sep 17 00:00:00 2001 From: Michael Fried Date: Fri, 22 Dec 2023 11:45:39 +0000 Subject: [PATCH 08/31] Fixing CR issues --- .env | 2 +- .env.template | 2 +- .vscode/settings.json | 2 +- .../embeddingEndpoints/embeddingEndpoints.ts | 37 
++++++++++--------- .../tei/embeddingEndpoints.ts | 6 +-- .../embeddingEndpoints.ts | 23 ++++++------ src/lib/server/sentenceSimilarity.ts | 2 +- src/lib/server/websearch/runWebSearch.ts | 2 +- src/lib/{server => types}/embeddingModels.ts | 15 ++++---- src/routes/conversation/+server.ts | 2 +- 10 files changed, 46 insertions(+), 47 deletions(-) rename src/lib/server/embeddingEndpoints/{xenova => transformersjs}/embeddingEndpoints.ts (60%) rename src/lib/{server => types}/embeddingModels.ts (84%) diff --git a/.env b/.env index d9c743775d8..bab06e2e968 100644 --- a/.env +++ b/.env @@ -49,7 +49,7 @@ TEXT_EMBEDDING_MODELS = `[ "description": "Local embedding model running on the server.", "maxSequenceLength": 512, "endpoints": [ - { "type": "xenova" } + { "type": "transformersjs" } ] } ]` diff --git a/.env.template b/.env.template index 9f81ee3e1ee..8ae440599b6 100644 --- a/.env.template +++ b/.env.template @@ -224,7 +224,7 @@ TEXT_EMBEDDING_MODELS = `[ "description": "Local embedding model running on the server.", "maxSequenceLength": 512, "endpoints": [ - { "type": "xenova" } + { "type": "transformersjs" } ] }, { diff --git a/.vscode/settings.json b/.vscode/settings.json index c32c1bbc3ef..0d24922796c 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -2,7 +2,7 @@ "editor.formatOnSave": true, "editor.defaultFormatter": "esbenp.prettier-vscode", "editor.codeActionsOnSave": { - "source.fixAll": true + "source.fixAll": "explicit" }, "eslint.validate": ["javascript", "svelte"] } diff --git a/src/lib/server/embeddingEndpoints/embeddingEndpoints.ts b/src/lib/server/embeddingEndpoints/embeddingEndpoints.ts index 289e00135e4..ebd9d6c7b3d 100644 --- a/src/lib/server/embeddingEndpoints/embeddingEndpoints.ts +++ b/src/lib/server/embeddingEndpoints/embeddingEndpoints.ts @@ -3,36 +3,37 @@ import { embeddingEndpointTeiParametersSchema, } from "./tei/embeddingEndpoints"; import { z } from "zod"; -import embeddingEndpointXenova, { - 
embeddingEndpointXenovaParametersSchema, -} from "./xenova/embeddingEndpoints"; +import { + embeddingEndpointTransformersJS, + embeddingEndpointTransformersJSParametersSchema, +} from "./transformersjs/embeddingEndpoints"; // parameters passed when generating text interface EmbeddingEndpointParameters { inputs: string[]; } -interface CommonEmbeddingEndpoint { - weight: number; -} - // type signature for the endpoint export type EmbeddingEndpoint = (params: EmbeddingEndpointParameters) => Promise; -// generator function that takes in parameters for defining the endpoint and return the endpoint -export type EmbeddingEndpointGenerator = ( - parameters: T -) => EmbeddingEndpoint; +export const embeddingEndpointSchema = z.discriminatedUnion("type", [ + embeddingEndpointTeiParametersSchema, + embeddingEndpointTransformersJSParametersSchema, +]); + +type EmbeddingEndpointTypeOptions = z.infer["type"]; + +// generator function that takes in type discrimantor value for defining the endpoint and return the endpoint +export type EmbeddingEndpointGenerator = ( + inputs: Extract, { type: T }> +) => EmbeddingEndpoint | Promise; // list of all endpoint generators -export const embeddingEndpoints = { +export const embeddingEndpoints: { + [Key in EmbeddingEndpointTypeOptions]: EmbeddingEndpointGenerator; +} = { tei: embeddingEndpointTei, - xenova: embeddingEndpointXenova, + transformersjs: embeddingEndpointTransformersJS, }; -export const embeddingEndpointSchema = z.discriminatedUnion("type", [ - embeddingEndpointTeiParametersSchema, - embeddingEndpointXenovaParametersSchema, -]); - export default embeddingEndpoints; diff --git a/src/lib/server/embeddingEndpoints/tei/embeddingEndpoints.ts b/src/lib/server/embeddingEndpoints/tei/embeddingEndpoints.ts index e689a37214d..1ceb6bbd720 100644 --- a/src/lib/server/embeddingEndpoints/tei/embeddingEndpoints.ts +++ b/src/lib/server/embeddingEndpoints/tei/embeddingEndpoints.ts @@ -56,10 +56,8 @@ export async function embeddingEndpointTei( }) 
); - const allEmbeddings = batchesResults.flatMap((embeddings) => embeddings); + const flatAllEmbeddings = batchesResults.flat(); - return allEmbeddings; + return flatAllEmbeddings; }; } - -export default embeddingEndpointTei; diff --git a/src/lib/server/embeddingEndpoints/xenova/embeddingEndpoints.ts b/src/lib/server/embeddingEndpoints/transformersjs/embeddingEndpoints.ts similarity index 60% rename from src/lib/server/embeddingEndpoints/xenova/embeddingEndpoints.ts rename to src/lib/server/embeddingEndpoints/transformersjs/embeddingEndpoints.ts index ba580290d39..53424b3f1c0 100644 --- a/src/lib/server/embeddingEndpoints/xenova/embeddingEndpoints.ts +++ b/src/lib/server/embeddingEndpoints/transformersjs/embeddingEndpoints.ts @@ -3,21 +3,22 @@ import type { EmbeddingEndpoint } from "../embeddingEndpoints"; import type { Tensor, Pipeline } from "@xenova/transformers"; import { pipeline } from "@xenova/transformers"; -export const embeddingEndpointXenovaParametersSchema = z.object({ +export const embeddingEndpointTransformersJSParametersSchema = z.object({ weight: z.number().int().positive().default(1), model: z.any(), - type: z.literal("xenova"), + type: z.literal("transformersjs"), }); // Use the Singleton pattern to enable lazy construction of the pipeline. 
-class XenovaModelsSingleton { +class TransformersJSModelsSingleton { static instances: Array<[string, Promise]> = []; static async getInstance(modelName: string): Promise { - const modelPipeline = this.instances.find(([name]) => name === modelName); + const modelPipelineInstance = this.instances.find(([name]) => name === modelName); - if (modelPipeline) { - return modelPipeline[1]; + if (modelPipelineInstance) { + const [, modelPipeline] = modelPipelineInstance; + return modelPipeline; } const newModelPipeline = pipeline("feature-extraction", modelName); @@ -28,20 +29,18 @@ class XenovaModelsSingleton { } export async function calculateEmbedding(modelName: string, inputs: string[]) { - const extractor = await XenovaModelsSingleton.getInstance(modelName); + const extractor = await TransformersJSModelsSingleton.getInstance(modelName); const output: Tensor = await extractor(inputs, { pooling: "mean", normalize: true }); return output.tolist(); } -export function embeddingEndpointXenova( - input: z.input +export function embeddingEndpointTransformersJS( + input: z.input ): EmbeddingEndpoint { - const { model } = embeddingEndpointXenovaParametersSchema.parse(input); + const { model } = embeddingEndpointTransformersJSParametersSchema.parse(input); return async ({ inputs }) => { return calculateEmbedding(model.name, inputs); }; } - -export default embeddingEndpointXenova; diff --git a/src/lib/server/sentenceSimilarity.ts b/src/lib/server/sentenceSimilarity.ts index b2d0356aecd..4911be829b5 100644 --- a/src/lib/server/sentenceSimilarity.ts +++ b/src/lib/server/sentenceSimilarity.ts @@ -1,5 +1,5 @@ import { dot } from "@xenova/transformers"; -import type { EmbeddingBackendModel } from "./embeddingModels"; +import type { EmbeddingBackendModel } from "$lib/types/embeddingModels"; // see here: https://github.com/nmslib/hnswlib/blob/359b2ba87358224963986f709e593d799064ace6/README.md?plain=1#L34 function innerProduct(embeddingA: number[], embeddingB: number[]) { diff --git 
a/src/lib/server/websearch/runWebSearch.ts b/src/lib/server/websearch/runWebSearch.ts index 22ead137dc6..c3151b995d0 100644 --- a/src/lib/server/websearch/runWebSearch.ts +++ b/src/lib/server/websearch/runWebSearch.ts @@ -8,7 +8,7 @@ import { findSimilarSentences } from "$lib/server/sentenceSimilarity"; import type { Conversation } from "$lib/types/Conversation"; import type { MessageUpdate } from "$lib/types/MessageUpdate"; import { getWebSearchProvider } from "./searchWeb"; -import { embeddingModels } from "../embeddingModels"; +import { embeddingModels } from "$lib/types/embeddingModels"; const MAX_N_PAGES_SCRAPE = 10 as const; const MAX_N_PAGES_EMBED = 5 as const; diff --git a/src/lib/server/embeddingModels.ts b/src/lib/types/embeddingModels.ts similarity index 84% rename from src/lib/server/embeddingModels.ts rename to src/lib/types/embeddingModels.ts index 8eaf960be2a..c08b1c30947 100644 --- a/src/lib/server/embeddingModels.ts +++ b/src/lib/types/embeddingModels.ts @@ -2,11 +2,12 @@ import { TEXT_EMBEDDING_MODELS } from "$env/static/private"; import { z } from "zod"; import { sum } from "$lib/utils/sum"; -import embeddingEndpoints, { +import { + embeddingEndpoints, embeddingEndpointSchema, type EmbeddingEndpoint, -} from "./embeddingEndpoints/embeddingEndpoints"; -import embeddingEndpointXenova from "./embeddingEndpoints/xenova/embeddingEndpoints"; +} from "$lib/server/embeddingEndpoints/embeddingEndpoints"; +import { embeddingEndpointTransformersJS } from "$lib/server/embeddingEndpoints/transformersjs/embeddingEndpoints"; const modelConfig = z.object({ /** Used as an identifier in DB */ @@ -34,8 +35,8 @@ const addEndpoint = (m: Awaited>) => ({ ...m, getEndpoint: async (): Promise => { if (!m.endpoints) { - return embeddingEndpointXenova({ - type: "xenova", + return embeddingEndpointTransformersJS({ + type: "transformersjs", weight: 1, model: m, }); @@ -52,8 +53,8 @@ const addEndpoint = (m: Awaited>) => ({ switch (args.type) { case "tei": return 
embeddingEndpoints.tei(args); - case "xenova": - return embeddingEndpoints.xenova(args); + case "transformersjs": + return embeddingEndpoints.transformersjs(args); } } diff --git a/src/routes/conversation/+server.ts b/src/routes/conversation/+server.ts index d3fadb396cc..583a05cf428 100644 --- a/src/routes/conversation/+server.ts +++ b/src/routes/conversation/+server.ts @@ -10,7 +10,7 @@ import { defaultEmbeddingModel, embeddingModels, validateEmbeddingModel, -} from "$lib/server/embeddingModels"; +} from "$lib/types/embeddingModels"; export const POST: RequestHandler = async ({ locals, request }) => { const body = await request.text(); From 6c6e29041fbb1e44044fd51f56fffb72365021a4 Mon Sep 17 00:00:00 2001 From: Michael Fried Date: Fri, 22 Dec 2023 11:48:51 +0000 Subject: [PATCH 09/31] Fix websearch disappear in later PR --- src/routes/conversation/[id]/+page.svelte | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/src/routes/conversation/[id]/+page.svelte b/src/routes/conversation/[id]/+page.svelte index 076b9dcf1bc..d4e156f0723 100644 --- a/src/routes/conversation/[id]/+page.svelte +++ b/src/routes/conversation/[id]/+page.svelte @@ -138,7 +138,6 @@ // eslint-disable-next-line no-undef const encoder = new TextDecoderStream(); const reader = response?.body?.pipeThrough(encoder).getReader(); - let importantUpdates: MessageUpdate[] = []; let finalAnswer = ""; // set str queue @@ -176,7 +175,6 @@ const update = JSON.parse(el) as MessageUpdate; if (update.type === "finalAnswer") { - importantUpdates.push(update); finalAnswer = update.text; reader.cancel(); loading = false; @@ -197,10 +195,8 @@ messages = [...messages]; } } else if (update.type === "webSearch") { - importantUpdates.push(update); webSearchMessages = [...webSearchMessages, update]; } else if (update.type === "status") { - importantUpdates.push(update); if (update.status === "title" && update.message) { const conv = data.conversations.find(({ id }) => id === $page.params.id); if
(conv) { @@ -215,7 +211,6 @@ $error = update.message ?? "An error has occurred"; } } else if (update.type === "error") { - importantUpdates.push(update); error.set(update.message); reader.cancel(); } @@ -232,10 +227,8 @@ } const lastMessage = messages[messages.length - 1]; - lastMessage.updates = importantUpdates; - // reset the websearchmessages - importantUpdates = []; + // reset the websearchMessages webSearchMessages = []; await invalidate(UrlDependency.ConversationList); From dc8d4e932979b5871d74006daaf4facc1fb1322a Mon Sep 17 00:00:00 2001 From: Michael Fried Date: Fri, 22 Dec 2023 11:49:59 +0000 Subject: [PATCH 10/31] Fix lint --- src/routes/conversation/[id]/+page.svelte | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/routes/conversation/[id]/+page.svelte b/src/routes/conversation/[id]/+page.svelte index d4e156f0723..ba00e9757a9 100644 --- a/src/routes/conversation/[id]/+page.svelte +++ b/src/routes/conversation/[id]/+page.svelte @@ -226,8 +226,6 @@ }); } - const lastMessage = messages[messages.length - 1]; - // reset the websearchMessages webSearchMessages = []; From aacbfeb9d1fdab1e8074577e64f581df84735b1b Mon Sep 17 00:00:00 2001 From: Michael Fried Date: Fri, 22 Dec 2023 11:58:26 +0000 Subject: [PATCH 11/31] Fix more minor code CR --- src/lib/server/embeddingEndpoints/tei/embeddingEndpoints.ts | 2 +- .../embeddingEndpoints/transformersjs/embeddingEndpoints.ts | 2 +- src/lib/{types => server}/embeddingModels.ts | 5 ++--- src/lib/server/sentenceSimilarity.ts | 2 +- src/lib/server/websearch/runWebSearch.ts | 2 +- .../embeddingEndpoints.ts => types/EmbeddingEndpoints.ts} | 6 +++--- src/routes/conversation/+server.ts | 2 +- 7 files changed, 10 insertions(+), 11 deletions(-) rename src/lib/{types => server}/embeddingModels.ts (93%) rename src/lib/{server/embeddingEndpoints/embeddingEndpoints.ts => types/EmbeddingEndpoints.ts} (90%) diff --git a/src/lib/server/embeddingEndpoints/tei/embeddingEndpoints.ts 
b/src/lib/server/embeddingEndpoints/tei/embeddingEndpoints.ts index 1ceb6bbd720..2ed14ba3f6c 100644 --- a/src/lib/server/embeddingEndpoints/tei/embeddingEndpoints.ts +++ b/src/lib/server/embeddingEndpoints/tei/embeddingEndpoints.ts @@ -1,5 +1,5 @@ import { z } from "zod"; -import type { EmbeddingEndpoint } from "../embeddingEndpoints"; +import type { EmbeddingEndpoint } from "$lib/types/EmbeddingEndpoints"; import { chunk } from "$lib/utils/chunk"; export const embeddingEndpointTeiParametersSchema = z.object({ diff --git a/src/lib/server/embeddingEndpoints/transformersjs/embeddingEndpoints.ts b/src/lib/server/embeddingEndpoints/transformersjs/embeddingEndpoints.ts index 53424b3f1c0..7cedddcfe15 100644 --- a/src/lib/server/embeddingEndpoints/transformersjs/embeddingEndpoints.ts +++ b/src/lib/server/embeddingEndpoints/transformersjs/embeddingEndpoints.ts @@ -1,5 +1,5 @@ import { z } from "zod"; -import type { EmbeddingEndpoint } from "../embeddingEndpoints"; +import type { EmbeddingEndpoint } from "$lib/types/EmbeddingEndpoints"; import type { Tensor, Pipeline } from "@xenova/transformers"; import { pipeline } from "@xenova/transformers"; diff --git a/src/lib/types/embeddingModels.ts b/src/lib/server/embeddingModels.ts similarity index 93% rename from src/lib/types/embeddingModels.ts rename to src/lib/server/embeddingModels.ts index c08b1c30947..8c4f0d6f6f3 100644 --- a/src/lib/types/embeddingModels.ts +++ b/src/lib/server/embeddingModels.ts @@ -6,7 +6,7 @@ import { embeddingEndpoints, embeddingEndpointSchema, type EmbeddingEndpoint, -} from "$lib/server/embeddingEndpoints/embeddingEndpoints"; +} from "$lib/types/EmbeddingEndpoints"; import { embeddingEndpointTransformersJS } from "$lib/server/embeddingEndpoints/transformersjs/embeddingEndpoints"; const modelConfig = z.object({ @@ -61,7 +61,7 @@ const addEndpoint = (m: Awaited>) => ({ random -= endpoint.weight; } - throw new Error(`Failed to select endpoint`); + throw new Error(`Failed to select embedding endpoint`); 
}, }); @@ -72,7 +72,6 @@ export const embeddingModels = await Promise.all( export const defaultEmbeddingModel = embeddingModels[0]; export const validateEmbeddingModel = (_models: EmbeddingBackendModel[]) => { - // Zod enum function requires 2 parameters return z.enum([_models[0].id, ..._models.slice(1).map((m) => m.id)]); }; diff --git a/src/lib/server/sentenceSimilarity.ts b/src/lib/server/sentenceSimilarity.ts index 4911be829b5..6058fc43ad9 100644 --- a/src/lib/server/sentenceSimilarity.ts +++ b/src/lib/server/sentenceSimilarity.ts @@ -1,5 +1,5 @@ import { dot } from "@xenova/transformers"; -import type { EmbeddingBackendModel } from "$lib/types/embeddingModels"; +import type { EmbeddingBackendModel } from "$lib/server/embeddingModels"; // see here: https://github.com/nmslib/hnswlib/blob/359b2ba87358224963986f709e593d799064ace6/README.md?plain=1#L34 function innerProduct(embeddingA: number[], embeddingB: number[]) { diff --git a/src/lib/server/websearch/runWebSearch.ts b/src/lib/server/websearch/runWebSearch.ts index c3151b995d0..ecbfef88dc3 100644 --- a/src/lib/server/websearch/runWebSearch.ts +++ b/src/lib/server/websearch/runWebSearch.ts @@ -8,7 +8,7 @@ import { findSimilarSentences } from "$lib/server/sentenceSimilarity"; import type { Conversation } from "$lib/types/Conversation"; import type { MessageUpdate } from "$lib/types/MessageUpdate"; import { getWebSearchProvider } from "./searchWeb"; -import { embeddingModels } from "$lib/types/embeddingModels"; +import { embeddingModels } from "$lib/server/embeddingModels"; const MAX_N_PAGES_SCRAPE = 10 as const; const MAX_N_PAGES_EMBED = 5 as const; diff --git a/src/lib/server/embeddingEndpoints/embeddingEndpoints.ts b/src/lib/types/EmbeddingEndpoints.ts similarity index 90% rename from src/lib/server/embeddingEndpoints/embeddingEndpoints.ts rename to src/lib/types/EmbeddingEndpoints.ts index ebd9d6c7b3d..b7805a731f5 100644 --- a/src/lib/server/embeddingEndpoints/embeddingEndpoints.ts +++ 
b/src/lib/types/EmbeddingEndpoints.ts @@ -1,12 +1,12 @@ +import { z } from "zod"; import { embeddingEndpointTei, embeddingEndpointTeiParametersSchema, -} from "./tei/embeddingEndpoints"; -import { z } from "zod"; +} from "$lib/server/embeddingEndpoints/tei/embeddingEndpoints"; import { embeddingEndpointTransformersJS, embeddingEndpointTransformersJSParametersSchema, -} from "./transformersjs/embeddingEndpoints"; +} from "$lib/server/embeddingEndpoints/transformersjs/embeddingEndpoints"; // parameters passed when generating text interface EmbeddingEndpointParameters { diff --git a/src/routes/conversation/+server.ts b/src/routes/conversation/+server.ts index 583a05cf428..d3fadb396cc 100644 --- a/src/routes/conversation/+server.ts +++ b/src/routes/conversation/+server.ts @@ -10,7 +10,7 @@ import { defaultEmbeddingModel, embeddingModels, validateEmbeddingModel, -} from "$lib/types/embeddingModels"; +} from "$lib/server/embeddingModels"; export const POST: RequestHandler = async ({ locals, request }) => { const body = await request.text(); From 7a9950d9b678169c9971cee271fe0a3229fc4571 Mon Sep 17 00:00:00 2001 From: Michael Fried Date: Fri, 22 Dec 2023 12:50:30 +0000 Subject: [PATCH 12/31] Validate embeddingModelName field in model config --- src/lib/server/embeddingModels.ts | 12 ++++++++++-- src/lib/server/models.ts | 3 ++- src/routes/conversation/+server.ts | 11 +++++------ 3 files changed, 17 insertions(+), 9 deletions(-) diff --git a/src/lib/server/embeddingModels.ts b/src/lib/server/embeddingModels.ts index 8c4f0d6f6f3..29da04790f6 100644 --- a/src/lib/server/embeddingModels.ts +++ b/src/lib/server/embeddingModels.ts @@ -71,8 +71,16 @@ export const embeddingModels = await Promise.all( export const defaultEmbeddingModel = embeddingModels[0]; -export const validateEmbeddingModel = (_models: EmbeddingBackendModel[]) => { - return z.enum([_models[0].id, ..._models.slice(1).map((m) => m.id)]); +const validateEmbeddingModel = (_models: EmbeddingBackendModel[], key: "id"
| "name") => { + return z.enum([_models[0][key], ..._models.slice(1).map((m) => m[key])]); +}; + +export const validateEmbeddingModelById = (_models: EmbeddingBackendModel[]) => { + return validateEmbeddingModel(_models, "id"); +}; + +export const validateEmbeddingModelByName = (_models: EmbeddingBackendModel[]) => { + return validateEmbeddingModel(_models, "name"); }; export type EmbeddingBackendModel = typeof defaultEmbeddingModel; diff --git a/src/lib/server/models.ts b/src/lib/server/models.ts index fa2952b2827..8f8091def12 100644 --- a/src/lib/server/models.ts +++ b/src/lib/server/models.ts @@ -12,6 +12,7 @@ import { z } from "zod"; import endpoints, { endpointSchema, type Endpoint } from "./endpoints/endpoints"; import endpointTgi from "./endpoints/tgi/endpointTgi"; import { sum } from "$lib/utils/sum"; +import { embeddingModels, validateEmbeddingModelByName } from "./embeddingModels"; type Optional = Pick, K> & Omit; @@ -66,7 +67,7 @@ const modelConfig = z.object({ .optional(), multimodal: z.boolean().default(false), unlisted: z.boolean().default(false), - embeddingModelName: z.string().optional(), + embeddingModelName: validateEmbeddingModelByName(embeddingModels).optional(), }); const modelsRaw = z.array(modelConfig).parse(JSON.parse(MODELS)); diff --git a/src/routes/conversation/+server.ts b/src/routes/conversation/+server.ts index d3fadb396cc..739274f01f9 100644 --- a/src/routes/conversation/+server.ts +++ b/src/routes/conversation/+server.ts @@ -9,7 +9,7 @@ import { models, validateModel } from "$lib/server/models"; import { defaultEmbeddingModel, embeddingModels, - validateEmbeddingModel, + validateEmbeddingModelById, } from "$lib/server/embeddingModels"; export const POST: RequestHandler = async ({ locals, request }) => { @@ -22,12 +22,12 @@ export const POST: RequestHandler = async ({ locals, request }) => { .object({ fromShare: z.string().optional(), model: validateModel(models), - embeddingModel: validateEmbeddingModel(embeddingModels).optional(), 
preprompt: z.string().optional(), }) .parse(JSON.parse(body)); let preprompt = values.preprompt; + let embeddingModelName: string; if (values.fromShare) { const conversation = await collections.sharedConversations.findOne({ @@ -41,19 +41,18 @@ export const POST: RequestHandler = async ({ locals, request }) => { title = conversation.title; messages = conversation.messages; values.model = conversation.model; - values.embeddingModel = conversation.embeddingModel; + embeddingModelName = conversation.embeddingModel; preprompt = conversation.preprompt; } const model = models.find((m) => m.name === values.model); - values.embeddingModel = - values.embeddingModel ?? model?.embeddingModelName ?? defaultEmbeddingModel.name; - if (!model) { throw error(400, "Invalid model"); } + embeddingModelName ??= model.embeddingModelName ?? defaultEmbeddingModel.name; + if (model.unlisted) { throw error(400, "Can't start a conversation with an unlisted model"); } From bce01d46803ffe344b0893e6053bb61af04fd494 Mon Sep 17 00:00:00 2001 From: Michael Fried Date: Fri, 22 Dec 2023 12:50:50 +0000 Subject: [PATCH 13/31] Add embeddingModel into shared conversation --- src/routes/conversation/[id]/share/+server.ts | 1 + src/routes/login/callback/updateUser.spec.ts | 2 ++ 2 files changed, 3 insertions(+) diff --git a/src/routes/conversation/[id]/share/+server.ts b/src/routes/conversation/[id]/share/+server.ts index e3f81222180..4877de755ad 100644 --- a/src/routes/conversation/[id]/share/+server.ts +++ b/src/routes/conversation/[id]/share/+server.ts @@ -38,6 +38,7 @@ export async function POST({ params, url, locals }) { updatedAt: new Date(), title: conversation.title, model: conversation.model, + embeddingModel: conversation.embeddingModel, preprompt: conversation.preprompt, }; diff --git a/src/routes/login/callback/updateUser.spec.ts b/src/routes/login/callback/updateUser.spec.ts index 54229914571..fefaf8b0f5a 100644 --- a/src/routes/login/callback/updateUser.spec.ts +++ 
b/src/routes/login/callback/updateUser.spec.ts @@ -6,6 +6,7 @@ import { ObjectId } from "mongodb"; import { DEFAULT_SETTINGS } from "$lib/types/Settings"; import { defaultModel } from "$lib/server/models"; import { findUser } from "$lib/server/auth"; +import { defaultEmbeddingModel } from "$lib/server/embeddingModels"; const userData = { preferred_username: "new-username", @@ -46,6 +47,7 @@ const insertRandomConversations = async (count: number) => { title: "random title", messages: [], model: defaultModel.id, + embeddingModel: defaultEmbeddingModel.id, createdAt: new Date(), updatedAt: new Date(), sessionId: locals.sessionId, From f822ced9edb6cd3f8112ef2cfd2449a377d2e3c6 Mon Sep 17 00:00:00 2001 From: Michael Fried Date: Fri, 22 Dec 2023 12:51:47 +0000 Subject: [PATCH 14/31] Fix lint and format --- src/routes/conversation/+server.ts | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/src/routes/conversation/+server.ts b/src/routes/conversation/+server.ts index 739274f01f9..8f79de4ae69 100644 --- a/src/routes/conversation/+server.ts +++ b/src/routes/conversation/+server.ts @@ -6,11 +6,7 @@ import { base } from "$app/paths"; import { z } from "zod"; import type { Message } from "$lib/types/Message"; import { models, validateModel } from "$lib/server/models"; -import { - defaultEmbeddingModel, - embeddingModels, - validateEmbeddingModelById, -} from "$lib/server/embeddingModels"; +import { defaultEmbeddingModel } from "$lib/server/embeddingModels"; export const POST: RequestHandler = async ({ locals, request }) => { const body = await request.text(); @@ -68,7 +64,7 @@ export const POST: RequestHandler = async ({ locals, request }) => { preprompt: preprompt === model?.preprompt ? model?.preprompt : preprompt, createdAt: new Date(), updatedAt: new Date(), - embeddingModel: values.embeddingModel, + embeddingModel: embeddingModelName, ...(locals.user ? { userId: locals.user._id } : { sessionId: locals.sessionId }), ...(values.fromShare ? 
{ meta: { fromShareId: values.fromShare } } : {}), }); From 421eccae25a03ea224faf78cd2c94210e97a7bd2 Mon Sep 17 00:00:00 2001 From: Michael Fried Date: Sat, 23 Dec 2023 16:09:59 +0000 Subject: [PATCH 15/31] Add default embedding model, and more readme explanation --- .env.template | 1 - README.md | 74 ++++++++++++++++++++++++++++++- src/lib/server/embeddingModels.ts | 17 ++++++- 3 files changed, 88 insertions(+), 4 deletions(-) diff --git a/.env.template b/.env.template index 8ae440599b6..301fced8285 100644 --- a/.env.template +++ b/.env.template @@ -216,7 +216,6 @@ OLD_MODELS=`[{"name":"bigcode/starcoder"}, {"name":"OpenAssistant/oasst-sft-6-ll TASK_MODEL='mistralai/Mistral-7B-Instruct-v0.2' -# Default to using the first text embedding model when not specifying 'embeddingModelName' in the model itself. TEXT_EMBEDDING_MODELS = `[ { "name": "Xenova/gte-small", diff --git a/README.md b/README.md index 64293a8247d..87a36a781f7 100644 --- a/README.md +++ b/README.md @@ -78,10 +78,44 @@ Chat UI features a powerful Web Search feature. It works by: 1. Generating an appropriate search query from the user prompt. 2. Performing web search and extracting content from webpages. -3. Creating embeddings from texts using [transformers.js](https://huggingface.co/docs/transformers.js). Specifically, using [Xenova/gte-small](https://huggingface.co/Xenova/gte-small) model. +3. Creating embeddings from texts using a text embedding model. 4. From these embeddings, find the ones that are closest to the user query using a vector similarity search. Specifically, we use `inner product` distance. 5. Get the corresponding texts to those closest embeddings and perform [Retrieval-Augmented Generation](https://huggingface.co/papers/2005.11401) (i.e. expand user prompt by adding those texts so that an LLM can use this information). 
+### Text Embedding Models + +By default (for backward compatibility), when the TEXT_EMBEDDING_MODELS environment variable is not defined, it will use [transformers.js](https://huggingface.co/docs/transformers.js), specifically, the [Xenova/gte-small](https://huggingface.co/Xenova/gte-small) model. + +You can customize the embedding model by setting TEXT_EMBEDDING_MODELS in your `.env.local`, for example + +```env +TEXT_EMBEDDING_MODELS = `[ + { + "name": "Xenova/gte-small", + "displayName": "Xenova/gte-small", + "description": "locally running embedding", + "maxSequenceLength": 512, + "endpoints": [ + {"type": "xenova"} + ] + }, + { + "name": "intfloat/e5-base-v2", + "displayName": "intfloat/e5-base-v2", + "description": "hosted embedding model", + "maxSequenceLength": 512, + "endpoints": [ + {"type": "tei", "url": "http://127.0.0.1:8080/"} + ] + } +]` +``` + +The required fields are `name`, `maxSequenceLength` and `endpoints`. +It supports [transformers.js](https://huggingface.co/docs/transformers.js) and [TEI](https://github.com/huggingface/text-embeddings-inference); transformers.js models run locally, while TEI models run in a different environment. Each of the provided `endpoints` supports a `weight` parameter, which will be used to determine the probability of requesting a particular endpoint. + +When defining more than one embedding model, the first will be used by default, and the others will only be used on LLMs whose `embeddingModelName` is configured to the name of that model. + ## Extra parameters ### OpenID connect @@ -425,6 +459,44 @@ If you're using a certificate signed by a private CA, you will also need to add If you're using a self-signed certificate, e.g. for testing or development purposes, you can set the `REJECT_UNAUTHORIZED` parameter to `false` in your `.env.local`. This will disable certificate validation, and allow Chat UI to connect to your custom endpoint.
+#### Specific Embedding Model + +A model can use any of the embedding models defined in `.env.local`, (currently used when web searching), +by default it will use the first embedding model, but it can be changed with the field `embeddingModelName`: + +```env +TEXT_EMBEDDING_MODELS = `[ + { + "name": "Xenova/gte-small", + "maxSequenceLength": 512, + "endpoints": [ + {"type": "xenova"} + ] + }, + { + "name": "intfloat/e5-base-v2", + "maxSequenceLength": 512, + "endpoints": [ + ... + ] + } +]` + +MODELS=[ + { + "name": "Ollama Mistral", + "chatPromptTemplate": "...", + "embeddingModelName": "intfloat/e5-base-v2" + "parameters": { + ... + }, + "endpoints": [ + ... + ] + } +] +``` + ## Deploying to a HF Space Create a `DOTENV_LOCAL` secret to your HF space with the content of your .env.local, and they will be picked up automatically when you run. diff --git a/src/lib/server/embeddingModels.ts b/src/lib/server/embeddingModels.ts index 29da04790f6..eb9b4fb7b35 100644 --- a/src/lib/server/embeddingModels.ts +++ b/src/lib/server/embeddingModels.ts @@ -18,13 +18,26 @@ const modelConfig = z.object({ description: z.string().min(1).optional(), websiteUrl: z.string().url().optional(), modelUrl: z.string().url().optional(), - endpoints: z.array(embeddingEndpointSchema).optional(), + endpoints: z.array(embeddingEndpointSchema).nonempty(), maxSequenceLength: z.number().positive(), preQuery: z.string().default(""), prePassage: z.string().default(""), }); -const embeddingModelsRaw = z.array(modelConfig).parse(JSON.parse(TEXT_EMBEDDING_MODELS)); +// Default embedding model for backward compatibility +const rawEmbeddingModelJSON = + TEXT_EMBEDDING_MODELS || + `[ + { + "name": "Xenova/gte-small", + "maxSequenceLength": 512, + "endpoints": [ + { "type": "transformersjs" } + ] + } +]`; + +const embeddingModelsRaw = z.array(modelConfig).parse(JSON.parse(rawEmbeddingModelJSON)); const processEmbeddingModel = async (m: z.infer) => ({ ...m, From e85faee0ed23273ca97e91dc8be4a4436f7e78b0 Mon 
Sep 17 00:00:00 2001 From: Michael Fried Date: Sat, 23 Dec 2023 16:20:21 +0000 Subject: [PATCH 16/31] Fix minor embedding model readme detailed --- .env.template | 2 +- README.md | 6 ++++-- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/.env.template b/.env.template index 301fced8285..e5de3722002 100644 --- a/.env.template +++ b/.env.template @@ -239,7 +239,7 @@ TEXT_EMBEDDING_MODELS = `[ "name": "intfloat/multilingual-e5-large", "displayName": "intfloat/multilingual-e5-large", "description": "Hosted embedding model running on the cloud somewhere.", - "maxSequenceLength": 512, + "maxSequenceLength": 1024, "preQuery": "query: ", # See https://huggingface.co/intfloat/multilingual-e5-large#faq "prePassage": "passage: ", # See https://huggingface.co/intfloat/multilingual-e5-large#faq "endpoints": [ diff --git a/README.md b/README.md index 87a36a781f7..ff1d7591fc1 100644 --- a/README.md +++ b/README.md @@ -103,7 +103,9 @@ TEXT_EMBEDDING_MODELS = `[ "name": "intfloat/e5-base-v2", "displayName": "intfloat/e5-base-v2", "description": "hosted embedding model", - "maxSequenceLength": 512, + "maxSequenceLength": 768, + "preQuery": "query: ", # See https://huggingface.co/intfloat/e5-base-v2#faq + "prePassage": "passage: ", # See https://huggingface.co/intfloat/e5-base-v2#faq "endpoints": [ {"type": "tei", "url": "http://127.0.0.1:8080/"} ] @@ -475,7 +477,7 @@ TEXT_EMBEDDING_MODELS = `[ }, { "name": "intfloat/e5-base-v2", - "maxSequenceLength": 512, + "maxSequenceLength": 768, "endpoints": [ ... 
] From a36c521b4a0304e0df9f0e6cd3174d9776635e9e Mon Sep 17 00:00:00 2001 From: Michael Fried Date: Sun, 31 Dec 2023 22:27:50 +0200 Subject: [PATCH 17/31] Update settings.json --- .vscode/settings.json | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.vscode/settings.json b/.vscode/settings.json index 0d24922796c..c32c1bbc3ef 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -2,7 +2,7 @@ "editor.formatOnSave": true, "editor.defaultFormatter": "esbenp.prettier-vscode", "editor.codeActionsOnSave": { - "source.fixAll": "explicit" + "source.fixAll": true }, "eslint.validate": ["javascript", "svelte"] } From d132375c90401f4c7e7fd75cdb621d6b4f3f2d8a Mon Sep 17 00:00:00 2001 From: Michael Fried Date: Sat, 6 Jan 2024 23:22:09 +0200 Subject: [PATCH 18/31] Update README.md Co-authored-by: Mishig --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index eb64046254d..687452b6aee 100644 --- a/README.md +++ b/README.md @@ -84,7 +84,7 @@ Chat UI features a powerful Web Search feature. It works by: ### Text Embedding Models -By default (for backward compatibility) when not defining TEXT_EMBEDDING_MODELS environment variable It will use [transformers.js](https://huggingface.co/docs/transformers.js), specifically, [Xenova/gte-small](https://huggingface.co/Xenova/gte-small) model. +By default (for backward compatibility), when `TEXT_EMBEDDING_MODELS` environment variable is not defined, [transformers.js](https://huggingface.co/docs/transformers.js) embedding models will be used for embedding tasks, specifically, [Xenova/gte-small](https://huggingface.co/Xenova/gte-small) model. 
You can customize the embedding model by setting TEXT_EMBEDDING_MODELS in your `.env.local`, for example From e1052250b11e9bbfad9446f35087d90b8d1b0521 Mon Sep 17 00:00:00 2001 From: Michael Fried Date: Sat, 6 Jan 2024 23:22:27 +0200 Subject: [PATCH 19/31] Update README.md Co-authored-by: Mishig --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 687452b6aee..a2cb7dc0b71 100644 --- a/README.md +++ b/README.md @@ -86,7 +86,7 @@ Chat UI features a powerful Web Search feature. It works by: By default (for backward compatibility), when `TEXT_EMBEDDING_MODELS` environment variable is not defined, [transformers.js](https://huggingface.co/docs/transformers.js) embedding models will be used for embedding tasks, specifically, [Xenova/gte-small](https://huggingface.co/Xenova/gte-small) model. -You can customize the embedding model by setting TEXT_EMBEDDING_MODELS in your `.env.local`, for example +You can customize the embedding model by setting `TEXT_EMBEDDING_MODELS` in your `.env.local` file. For example: ```env TEXT_EMBEDDING_MODELS = `[ From f615fef10efaa9289ca053fd58da44b403b424bd Mon Sep 17 00:00:00 2001 From: Michael Fried Date: Sat, 6 Jan 2024 23:24:39 +0200 Subject: [PATCH 20/31] Apply suggestions from code review Co-authored-by: Mishig --- README.md | 4 ++-- src/lib/server/embeddingEndpoints/tei/embeddingEndpoints.ts | 5 ++--- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/README.md b/README.md index a2cb7dc0b71..c913f3e09f3 100644 --- a/README.md +++ b/README.md @@ -114,9 +114,9 @@ TEXT_EMBEDDING_MODELS = `[ ``` The required fields are `name`, `maxSequenceLength` and `endpoints`. -It supports [transformers.js](https://huggingface.co/docs/transformers.js) and [TEI](https://github.com/huggingface/text-embeddings-inference), transformers.js model run locally, and TEI models run in a different environment. 
each `endpoints` provided supports a `weight` parameter which will be used to determine the probability of requesting a particular endpoint. +Supported text embedding backends are: [`transformers.js`](https://huggingface.co/docs/transformers.js) and [`TEI`](https://github.com/huggingface/text-embeddings-inference). `transformers.js` models run locally as part of `chat-ui`, whereas `TEI` models run in a different environment & accessed through an API endpoint. -When defining more than one embedding model, the first will be used by default, and the others will only be used on LLM's which configured `embeddingModelName` to the name of the model. +When more than one embedding models are supplied in `.env.local` file, the first will be used by default, and the others will only be used on LLM's which configured `embeddingModelName` to the name of the model. ## Extra parameters diff --git a/src/lib/server/embeddingEndpoints/tei/embeddingEndpoints.ts b/src/lib/server/embeddingEndpoints/tei/embeddingEndpoints.ts index 2ed14ba3f6c..52a6df3c42d 100644 --- a/src/lib/server/embeddingEndpoints/tei/embeddingEndpoints.ts +++ b/src/lib/server/embeddingEndpoints/tei/embeddingEndpoints.ts @@ -19,9 +19,8 @@ const getModelInfoByUrl = async (url: string) => { }, }); - const info = await response.json(); - - return info; + const json = await response.json(); + return json; }; export async function embeddingEndpointTei( From 9ef62b9ecc7d4d21dfd2aa7fd7c8e851e8b68053 Mon Sep 17 00:00:00 2001 From: Michael Fried Date: Sat, 6 Jan 2024 21:39:29 +0000 Subject: [PATCH 21/31] Resolved more issues --- README.md | 15 ++++++++------- src/lib/components/MobileNav.svelte | 2 +- src/lib/components/NavMenu.svelte | 4 ++-- src/lib/components/chat/ChatIntroduction.svelte | 2 +- src/lib/components/chat/ChatMessage.svelte | 6 +++--- src/lib/components/chat/ChatWindow.svelte | 6 +++--- .../embeddingEndpoints/tei/embeddingEndpoints.ts | 11 +++++++---- src/lib/server/models.ts | 2 +- 
src/lib/server/sentenceSimilarity.ts | 11 ++++++----- src/lib/types/EmbeddingEndpoints.ts | 4 +++- src/routes/+layout.svelte | 2 +- src/routes/conversation/+server.ts | 8 ++++---- 12 files changed, 40 insertions(+), 33 deletions(-) diff --git a/README.md b/README.md index c913f3e09f3..60a75a377cf 100644 --- a/README.md +++ b/README.md @@ -20,9 +20,10 @@ A chat interface using open source models, eg OpenAssistant or Llama. It is a Sv 1. [Setup](#setup) 2. [Launch](#launch) 3. [Web Search](#web-search) -4. [Extra parameters](#extra-parameters) -5. [Deploying to a HF Space](#deploying-to-a-hf-space) -6. [Building](#building) +4. [Text Embedding Models](#text-embedding-models) +5. [Extra parameters](#extra-parameters) +6. [Deploying to a HF Space](#deploying-to-a-hf-space) +7. [Building](#building) ## No Setup Deploy @@ -82,7 +83,7 @@ Chat UI features a powerful Web Search feature. It works by: 4. From these embeddings, find the ones that are closest to the user query using a vector similarity search. Specifically, we use `inner product` distance. 5. Get the corresponding texts to those closest embeddings and perform [Retrieval-Augmented Generation](https://huggingface.co/papers/2005.11401) (i.e. expand user prompt by adding those texts so that an LLM can use this information). -### Text Embedding Models +## Text Embedding Models By default (for backward compatibility), when `TEXT_EMBEDDING_MODELS` environment variable is not defined, [transformers.js](https://huggingface.co/docs/transformers.js) embedding models will be used for embedding tasks, specifically, [Xenova/gte-small](https://huggingface.co/Xenova/gte-small) model. @@ -116,7 +117,7 @@ TEXT_EMBEDDING_MODELS = `[ The required fields are `name`, `maxSequenceLength` and `endpoints`. Supported text embedding backends are: [`transformers.js`](https://huggingface.co/docs/transformers.js) and [`TEI`](https://github.com/huggingface/text-embeddings-inference). 
`transformers.js` models run locally as part of `chat-ui`, whereas `TEI` models run in a different environment & accessed through an API endpoint. -When more than one embedding models are supplied in `.env.local` file, the first will be used by default, and the others will only be used on LLM's which configured `embeddingModelName` to the name of the model. +When more than one embedding models are supplied in `.env.local` file, the first will be used by default, and the others will only be used on LLM's which configured `embeddingModel` to the name of the model. ## Extra parameters @@ -464,7 +465,7 @@ If you're using a self-signed certificate, e.g. for testing or development purpo #### Specific Embedding Model A model can use any of the embedding models defined in `.env.local`, (currently used when web searching), -by default it will use the first embedding model, but it can be changed with the field `embeddingModelName`: +by default it will use the first embedding model, but it can be changed with the field `embeddingModel`: ```env TEXT_EMBEDDING_MODELS = `[ @@ -488,7 +489,7 @@ MODELS=[ { "name": "Ollama Mistral", "chatPromptTemplate": "...", - "embeddingModelName": "intfloat/e5-base-v2" + "embeddingModel": "intfloat/e5-base-v2" "parameters": { ... }, diff --git a/src/lib/components/MobileNav.svelte b/src/lib/components/MobileNav.svelte index 2befba31b77..ca13eb4a972 100644 --- a/src/lib/components/MobileNav.svelte +++ b/src/lib/components/MobileNav.svelte @@ -30,7 +30,7 @@
{#each Object.entries(groupedConversations) as [group, convs]} {#if convs.length} @@ -89,7 +89,7 @@ > diff --git a/src/lib/components/chat/ChatIntroduction.svelte b/src/lib/components/chat/ChatIntroduction.svelte index 8a055ef214e..078f44e5d8b 100644 --- a/src/lib/components/chat/ChatIntroduction.svelte +++ b/src/lib/components/chat/ChatIntroduction.svelte @@ -78,7 +78,7 @@ {#each currentModelMetadata.promptExamples as example} diff --git a/src/lib/server/embeddingEndpoints/tei/embeddingEndpoints.ts b/src/lib/server/embeddingEndpoints/tei/embeddingEndpoints.ts index 52a6df3c42d..8c57b706d32 100644 --- a/src/lib/server/embeddingEndpoints/tei/embeddingEndpoints.ts +++ b/src/lib/server/embeddingEndpoints/tei/embeddingEndpoints.ts @@ -1,5 +1,5 @@ import { z } from "zod"; -import type { EmbeddingEndpoint } from "$lib/types/EmbeddingEndpoints"; +import type { EmbeddingEndpoint, Embedding } from "$lib/types/EmbeddingEndpoints"; import { chunk } from "$lib/utils/chunk"; export const embeddingEndpointTeiParametersSchema = z.object({ @@ -7,15 +7,17 @@ export const embeddingEndpointTeiParametersSchema = z.object({ model: z.any(), type: z.literal("tei"), url: z.string().url(), + authorization: z.string().optional(), }); -const getModelInfoByUrl = async (url: string) => { +const getModelInfoByUrl = async (url: string, authorization?: string) => { const { origin } = new URL(url); const response = await fetch(`${origin}/info`, { headers: { Accept: "application/json", "Content-Type": "application/json", + ...(authorization ? 
{ Authorization: authorization } : {}), }, }); @@ -26,7 +28,7 @@ const getModelInfoByUrl = async (url: string) => { export async function embeddingEndpointTei( input: z.input ): Promise { - const { url, model } = embeddingEndpointTeiParametersSchema.parse(input); + const { url, model, authorization } = embeddingEndpointTeiParametersSchema.parse(input); const { max_client_batch_size, max_batch_tokens } = await getModelInfoByUrl(url); const maxBatchSize = Math.min( @@ -46,11 +48,12 @@ export async function embeddingEndpointTei( headers: { Accept: "application/json", "Content-Type": "application/json", + ...(authorization ? { Authorization: authorization } : {}), }, body: JSON.stringify({ inputs: batchInputs, normalize: true, truncate: true }), }); - const embeddings: number[][] = await response.json(); + const embeddings: Embedding[] = await response.json(); return embeddings; }) ); diff --git a/src/lib/server/models.ts b/src/lib/server/models.ts index 8f8091def12..5a820234aba 100644 --- a/src/lib/server/models.ts +++ b/src/lib/server/models.ts @@ -67,7 +67,7 @@ const modelConfig = z.object({ .optional(), multimodal: z.boolean().default(false), unlisted: z.boolean().default(false), - embeddingModelName: validateEmbeddingModelByName(embeddingModels).optional(), + embeddingModel: validateEmbeddingModelByName(embeddingModels).optional(), }); const modelsRaw = z.array(modelConfig).parse(JSON.parse(MODELS)); diff --git a/src/lib/server/sentenceSimilarity.ts b/src/lib/server/sentenceSimilarity.ts index 6058fc43ad9..455b25d4d06 100644 --- a/src/lib/server/sentenceSimilarity.ts +++ b/src/lib/server/sentenceSimilarity.ts @@ -1,8 +1,9 @@ import { dot } from "@xenova/transformers"; import type { EmbeddingBackendModel } from "$lib/server/embeddingModels"; +import type { Embedding } from "$lib/types/EmbeddingEndpoints"; // see here: https://github.com/nmslib/hnswlib/blob/359b2ba87358224963986f709e593d799064ace6/README.md?plain=1#L34 -function innerProduct(embeddingA: number[], 
embeddingB: number[]) { +function innerProduct(embeddingA: Embedding, embeddingB: Embedding) { return 1.0 - dot(embeddingA, embeddingB); } @@ -11,7 +12,7 @@ export async function findSimilarSentences( query: string, sentences: string[], { topK = 5 }: { topK: number } -): Promise { +): Promise { const inputs = [ `${embeddingModel.preQuery}${query}`, ...sentences.map((sentence) => `${embeddingModel.prePassage}${sentence}`), @@ -20,11 +21,11 @@ export async function findSimilarSentences( const embeddingEndpoint = await embeddingModel.getEndpoint(); const output = await embeddingEndpoint({ inputs }); - const queryEmbedding: number[] = output[0]; - const sentencesEmbeddings: number[][] = output.slice(1, inputs.length - 1); + const queryEmbedding: Embedding = output[0]; + const sentencesEmbeddings: Embedding[] = output.slice(1, inputs.length - 1); const distancesFromQuery: { distance: number; index: number }[] = [...sentencesEmbeddings].map( - (sentenceEmbedding: number[], index: number) => { + (sentenceEmbedding: Embedding, index: number) => { return { distance: innerProduct(queryEmbedding, sentenceEmbedding), index: index, diff --git a/src/lib/types/EmbeddingEndpoints.ts b/src/lib/types/EmbeddingEndpoints.ts index b7805a731f5..57cd425c578 100644 --- a/src/lib/types/EmbeddingEndpoints.ts +++ b/src/lib/types/EmbeddingEndpoints.ts @@ -13,8 +13,10 @@ interface EmbeddingEndpointParameters { inputs: string[]; } +export type Embedding = number[]; + // type signature for the endpoint -export type EmbeddingEndpoint = (params: EmbeddingEndpointParameters) => Promise; +export type EmbeddingEndpoint = (params: EmbeddingEndpointParameters) => Promise; export const embeddingEndpointSchema = z.discriminatedUnion("type", [ embeddingEndpointTeiParametersSchema, diff --git a/src/routes/+layout.svelte b/src/routes/+layout.svelte index 318c21cc3ff..63b1b9d61fc 100644 --- a/src/routes/+layout.svelte +++ b/src/routes/+layout.svelte @@ -148,7 +148,7 @@
{ .parse(JSON.parse(body)); let preprompt = values.preprompt; - let embeddingModelName: string; + let embeddingModel: string; if (values.fromShare) { const conversation = await collections.sharedConversations.findOne({ @@ -37,7 +37,7 @@ export const POST: RequestHandler = async ({ locals, request }) => { title = conversation.title; messages = conversation.messages; values.model = conversation.model; - embeddingModelName = conversation.embeddingModel; + embeddingModel = conversation.embeddingModel; preprompt = conversation.preprompt; } @@ -47,7 +47,7 @@ export const POST: RequestHandler = async ({ locals, request }) => { throw error(400, "Invalid model"); } - embeddingModelName ??= model.embeddingModelName ?? defaultEmbeddingModel.name; + embeddingModel ??= model.embeddingModel ?? defaultEmbeddingModel.name; if (model.unlisted) { throw error(400, "Can't start a conversation with an unlisted model"); @@ -64,7 +64,7 @@ export const POST: RequestHandler = async ({ locals, request }) => { preprompt: preprompt === model?.preprompt ? model?.preprompt : preprompt, createdAt: new Date(), updatedAt: new Date(), - embeddingModel: embeddingModelName, + embeddingModel: embeddingModel, ...(locals.user ? { userId: locals.user._id } : { sessionId: locals.sessionId }), ...(values.fromShare ? 
{ meta: { fromShareId: values.fromShare } } : {}), }); From 7c795826321674655cb5e1ad589bde35b73f749d Mon Sep 17 00:00:00 2001 From: Nathan Sarrazin Date: Mon, 8 Jan 2024 10:56:34 +0100 Subject: [PATCH 22/31] lint --- src/lib/components/MobileNav.svelte | 2 +- src/lib/components/NavMenu.svelte | 4 ++-- src/lib/components/chat/ChatIntroduction.svelte | 2 +- src/lib/components/chat/ChatMessage.svelte | 6 +++--- src/lib/components/chat/ChatWindow.svelte | 6 +++--- src/routes/+layout.svelte | 2 +- 6 files changed, 11 insertions(+), 11 deletions(-) diff --git a/src/lib/components/MobileNav.svelte b/src/lib/components/MobileNav.svelte index ca13eb4a972..2befba31b77 100644 --- a/src/lib/components/MobileNav.svelte +++ b/src/lib/components/MobileNav.svelte @@ -30,7 +30,7 @@
{#each Object.entries(groupedConversations) as [group, convs]} {#if convs.length} @@ -89,7 +89,7 @@ > diff --git a/src/lib/components/chat/ChatIntroduction.svelte b/src/lib/components/chat/ChatIntroduction.svelte index 078f44e5d8b..8a055ef214e 100644 --- a/src/lib/components/chat/ChatIntroduction.svelte +++ b/src/lib/components/chat/ChatIntroduction.svelte @@ -78,7 +78,7 @@ {#each currentModelMetadata.promptExamples as example} diff --git a/src/routes/+layout.svelte b/src/routes/+layout.svelte index 63b1b9d61fc..318c21cc3ff 100644 --- a/src/routes/+layout.svelte +++ b/src/routes/+layout.svelte @@ -148,7 +148,7 @@
Date: Mon, 8 Jan 2024 22:08:26 +0000 Subject: [PATCH 23/31] Fix more issues --- .env | 1 - .env.template | 32 -------------------------------- README.md | 7 ++++--- 3 files changed, 4 insertions(+), 36 deletions(-) diff --git a/.env b/.env index 433e95570a7..fde7110f20b 100644 --- a/.env +++ b/.env @@ -58,7 +58,6 @@ TEXT_EMBEDDING_MODELS = `[ } ]` - # 'name', 'userMessageToken', 'assistantMessageToken' are required MODELS=`[ { diff --git a/.env.template b/.env.template index b3517a360d4..042a248caad 100644 --- a/.env.template +++ b/.env.template @@ -231,38 +231,6 @@ TASK_MODEL='mistralai/Mistral-7B-Instruct-v0.2' # "stop": [""] # }}` -TEXT_EMBEDDING_MODELS = `[ - { - "name": "Xenova/gte-small", - "displayName": "Xenova/gte-small", - "description": "Local embedding model running on the server.", - "maxSequenceLength": 512, - "endpoints": [ - { "type": "transformersjs" } - ] - }, - { - "name": "thenlper/gte-base", - "displayName": "thenlper/gte-base", - "description": "Hosted embedding model running on the cloud somewhere.", - "maxSequenceLength": 512, - "endpoints": [ - { "type": "tei", "http://localhost:8080/" } - ] - }, - { - "name": "intfloat/multilingual-e5-large", - "displayName": "intfloat/multilingual-e5-large", - "description": "Hosted embedding model running on the cloud somewhere.", - "maxSequenceLength": 1024, - "preQuery": "query: ", # See https://huggingface.co/intfloat/multilingual-e5-large#faq - "prePassage": "passage: ", # See https://huggingface.co/intfloat/multilingual-e5-large#faq - "endpoints": [ - { "type": "tei", "http://localhost:8085/" } - ] - } -]` - APP_BASE="/chat" PUBLIC_ORIGIN=https://huggingface.co PUBLIC_SHARE_PREFIX=https://hf.co/chat diff --git a/README.md b/README.md index 60a75a377cf..d7695b7bc50 100644 --- a/README.md +++ b/README.md @@ -97,7 +97,7 @@ TEXT_EMBEDDING_MODELS = `[ "description": "locally running embedding", "maxSequenceLength": 512, "endpoints": [ - {"type": "xenova"} + {"type": "transformersjs"} ] }, { @@ -473,14 
+473,15 @@ TEXT_EMBEDDING_MODELS = `[ "name": "Xenova/gte-small", "maxSequenceLength": 512, "endpoints": [ - {"type": "xenova"} + {"type": "transformersjs"} ] }, { "name": "intfloat/e5-base-v2", "maxSequenceLength": 768, "endpoints": [ - ... + {"type": "tei", "url": "http://127.0.0.1:8080/", "authorization": "Basic VVNFUjpQQVNT"}, + {"type": "tei", "url": "http://127.0.0.1:8081/"} ] } ]` From 65760bcbae2d081b9940c5e921b73fa78dabf610 Mon Sep 17 00:00:00 2001 From: Michael Fried Date: Mon, 8 Jan 2024 22:09:48 +0000 Subject: [PATCH 24/31] Fix format --- src/lib/components/MobileNav.svelte | 2 +- src/lib/components/NavMenu.svelte | 4 ++-- src/lib/components/chat/ChatIntroduction.svelte | 2 +- src/lib/components/chat/ChatMessage.svelte | 6 +++--- src/lib/components/chat/ChatWindow.svelte | 6 +++--- src/routes/+layout.svelte | 2 +- 6 files changed, 11 insertions(+), 11 deletions(-) diff --git a/src/lib/components/MobileNav.svelte b/src/lib/components/MobileNav.svelte index 2befba31b77..ca13eb4a972 100644 --- a/src/lib/components/MobileNav.svelte +++ b/src/lib/components/MobileNav.svelte @@ -30,7 +30,7 @@
{#each Object.entries(groupedConversations) as [group, convs]} {#if convs.length} @@ -89,7 +89,7 @@ > diff --git a/src/lib/components/chat/ChatIntroduction.svelte b/src/lib/components/chat/ChatIntroduction.svelte index 8a055ef214e..078f44e5d8b 100644 --- a/src/lib/components/chat/ChatIntroduction.svelte +++ b/src/lib/components/chat/ChatIntroduction.svelte @@ -78,7 +78,7 @@ {#each currentModelMetadata.promptExamples as example} diff --git a/src/routes/+layout.svelte b/src/routes/+layout.svelte index 318c21cc3ff..63b1b9d61fc 100644 --- a/src/routes/+layout.svelte +++ b/src/routes/+layout.svelte @@ -148,7 +148,7 @@
Date: Mon, 8 Jan 2024 22:12:58 +0000 Subject: [PATCH 25/31] fix small typo --- README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index d7695b7bc50..1cf4dee89b3 100644 --- a/README.md +++ b/README.md @@ -486,7 +486,7 @@ TEXT_EMBEDDING_MODELS = `[ } ]` -MODELS=[ +MODELS=`[ { "name": "Ollama Mistral", "chatPromptTemplate": "...", @@ -498,7 +498,7 @@ MODELS=[ ... ] } -] +]` ``` ## Deploying to a HF Space From 591452923d6d3220ed29fc48b8a15052c6a3f558 Mon Sep 17 00:00:00 2001 From: Nathan Sarrazin Date: Tue, 9 Jan 2024 11:21:05 +0100 Subject: [PATCH 26/31] lint --- src/lib/components/MobileNav.svelte | 2 +- src/lib/components/NavMenu.svelte | 4 ++-- src/lib/components/chat/ChatIntroduction.svelte | 2 +- src/lib/components/chat/ChatMessage.svelte | 6 +++--- src/lib/components/chat/ChatWindow.svelte | 6 +++--- src/routes/+layout.svelte | 2 +- 6 files changed, 11 insertions(+), 11 deletions(-) diff --git a/src/lib/components/MobileNav.svelte b/src/lib/components/MobileNav.svelte index ca13eb4a972..2befba31b77 100644 --- a/src/lib/components/MobileNav.svelte +++ b/src/lib/components/MobileNav.svelte @@ -30,7 +30,7 @@
{#each Object.entries(groupedConversations) as [group, convs]} {#if convs.length} @@ -89,7 +89,7 @@ > diff --git a/src/lib/components/chat/ChatIntroduction.svelte b/src/lib/components/chat/ChatIntroduction.svelte index 078f44e5d8b..8a055ef214e 100644 --- a/src/lib/components/chat/ChatIntroduction.svelte +++ b/src/lib/components/chat/ChatIntroduction.svelte @@ -78,7 +78,7 @@ {#each currentModelMetadata.promptExamples as example} diff --git a/src/routes/+layout.svelte b/src/routes/+layout.svelte index 63b1b9d61fc..318c21cc3ff 100644 --- a/src/routes/+layout.svelte +++ b/src/routes/+layout.svelte @@ -148,7 +148,7 @@
Date: Tue, 9 Jan 2024 15:09:08 +0100 Subject: [PATCH 27/31] fix default model --- src/lib/server/websearch/runWebSearch.ts | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/lib/server/websearch/runWebSearch.ts b/src/lib/server/websearch/runWebSearch.ts index ecbfef88dc3..deb0ced5b76 100644 --- a/src/lib/server/websearch/runWebSearch.ts +++ b/src/lib/server/websearch/runWebSearch.ts @@ -8,7 +8,7 @@ import { findSimilarSentences } from "$lib/server/sentenceSimilarity"; import type { Conversation } from "$lib/types/Conversation"; import type { MessageUpdate } from "$lib/types/MessageUpdate"; import { getWebSearchProvider } from "./searchWeb"; -import { embeddingModels } from "$lib/server/embeddingModels"; +import { defaultEmbeddingModel, embeddingModels } from "$lib/server/embeddingModels"; const MAX_N_PAGES_SCRAPE = 10 as const; const MAX_N_PAGES_EMBED = 5 as const; @@ -56,7 +56,8 @@ export async function runWebSearch( .slice(0, MAX_N_PAGES_SCRAPE); // limit to first 10 links only // fetch the model - const embeddingModel = embeddingModels.find((m) => m.id === conv.embeddingModel); + const embeddingModel = embeddingModels.find((m) => m.id === conv.embeddingModel) ?? 
defaultEmbeddingModel; + if (!embeddingModel) { throw new Error(`Embedding model ${conv.embeddingModel} not available anymore`); } From fafecb75f30043184dfbd65a6289106e17343e9c Mon Sep 17 00:00:00 2001 From: Mishig Davaadorj Date: Tue, 9 Jan 2024 15:19:40 +0100 Subject: [PATCH 28/31] Rn `maxSequenceLength` -> `chunkCharLength` --- .env | 2 +- README.md | 10 +++++----- .../embeddingEndpoints/tei/embeddingEndpoints.ts | 2 +- src/lib/server/embeddingModels.ts | 4 ++-- src/lib/server/websearch/runWebSearch.ts | 2 +- 5 files changed, 10 insertions(+), 10 deletions(-) diff --git a/.env b/.env index fde7110f20b..01ee88b4cf2 100644 --- a/.env +++ b/.env @@ -51,7 +51,7 @@ TEXT_EMBEDDING_MODELS = `[ "name": "Xenova/gte-small", "displayName": "Xenova/gte-small", "description": "Local embedding model running on the server.", - "maxSequenceLength": 512, + "chunkCharLength": 512, "endpoints": [ { "type": "transformersjs" } ] diff --git a/README.md b/README.md index 1cf4dee89b3..0d5241deb43 100644 --- a/README.md +++ b/README.md @@ -95,7 +95,7 @@ TEXT_EMBEDDING_MODELS = `[ "name": "Xenova/gte-small", "displayName": "Xenova/gte-small", "description": "locally running embedding", - "maxSequenceLength": 512, + "chunkCharLength": 512, "endpoints": [ {"type": "transformersjs"} ] @@ -104,7 +104,7 @@ TEXT_EMBEDDING_MODELS = `[ "name": "intfloat/e5-base-v2", "displayName": "intfloat/e5-base-v2", "description": "hosted embedding model", - "maxSequenceLength": 768, + "chunkCharLength": 768, "preQuery": "query: ", # See https://huggingface.co/intfloat/e5-base-v2#faq "prePassage": "passage: ", # See https://huggingface.co/intfloat/e5-base-v2#faq "endpoints": [ @@ -114,7 +114,7 @@ TEXT_EMBEDDING_MODELS = `[ ]` ``` -The required fields are `name`, `maxSequenceLength` and `endpoints`. +The required fields are `name`, `chunkCharLength` and `endpoints`. 
Supported text embedding backends are: [`transformers.js`](https://huggingface.co/docs/transformers.js) and [`TEI`](https://github.com/huggingface/text-embeddings-inference). `transformers.js` models run locally as part of `chat-ui`, whereas `TEI` models run in a different environment & accessed through an API endpoint. When more than one embedding models are supplied in `.env.local` file, the first will be used by default, and the others will only be used on LLM's which configured `embeddingModel` to the name of the model. @@ -471,14 +471,14 @@ by default it will use the first embedding model, but it can be changed with the TEXT_EMBEDDING_MODELS = `[ { "name": "Xenova/gte-small", - "maxSequenceLength": 512, + "chunkCharLength": 512, "endpoints": [ {"type": "transformersjs"} ] }, { "name": "intfloat/e5-base-v2", - "maxSequenceLength": 768, + "chunkCharLength": 768, "endpoints": [ {"type": "tei", "url": "http://127.0.0.1:8080/", "authorization": "Basic VVNFUjpQQVNT"}, {"type": "tei", "url": "http://127.0.0.1:8081/"} diff --git a/src/lib/server/embeddingEndpoints/tei/embeddingEndpoints.ts b/src/lib/server/embeddingEndpoints/tei/embeddingEndpoints.ts index 8c57b706d32..17bdc34ae64 100644 --- a/src/lib/server/embeddingEndpoints/tei/embeddingEndpoints.ts +++ b/src/lib/server/embeddingEndpoints/tei/embeddingEndpoints.ts @@ -33,7 +33,7 @@ export async function embeddingEndpointTei( const { max_client_batch_size, max_batch_tokens } = await getModelInfoByUrl(url); const maxBatchSize = Math.min( max_client_batch_size, - Math.floor(max_batch_tokens / model.maxSequenceLength) + Math.floor(max_batch_tokens / model.chunkCharLength) ); return async ({ inputs }) => { diff --git a/src/lib/server/embeddingModels.ts b/src/lib/server/embeddingModels.ts index eb9b4fb7b35..13305867d95 100644 --- a/src/lib/server/embeddingModels.ts +++ b/src/lib/server/embeddingModels.ts @@ -19,7 +19,7 @@ const modelConfig = z.object({ websiteUrl: z.string().url().optional(), modelUrl: 
z.string().url().optional(), endpoints: z.array(embeddingEndpointSchema).nonempty(), - maxSequenceLength: z.number().positive(), + chunkCharLength: z.number().positive(), preQuery: z.string().default(""), prePassage: z.string().default(""), }); @@ -30,7 +30,7 @@ const rawEmbeddingModelJSON = `[ { "name": "Xenova/gte-small", - "maxSequenceLength": 512, + "chunkCharLength": 512, "endpoints": [ { "type": "transformersjs" } ] diff --git a/src/lib/server/websearch/runWebSearch.ts b/src/lib/server/websearch/runWebSearch.ts index deb0ced5b76..8a53826a22a 100644 --- a/src/lib/server/websearch/runWebSearch.ts +++ b/src/lib/server/websearch/runWebSearch.ts @@ -77,7 +77,7 @@ export async function runWebSearch( } } const MAX_N_CHUNKS = 100; - const texts = chunk(text, embeddingModel.maxSequenceLength).slice(0, MAX_N_CHUNKS); + const texts = chunk(text, embeddingModel.chunkCharLength).slice(0, MAX_N_CHUNKS); return texts.map((t) => ({ source: result, text: t })); }); const nestedParagraphChunks = (await Promise.all(promises)).slice(0, MAX_N_PAGES_EMBED); From ed3688cc7e85cdfe0a9cef52becc67a3f6dc817e Mon Sep 17 00:00:00 2001 From: Mishig Davaadorj Date: Tue, 9 Jan 2024 15:22:17 +0100 Subject: [PATCH 29/31] format --- src/lib/server/websearch/runWebSearch.ts | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/lib/server/websearch/runWebSearch.ts b/src/lib/server/websearch/runWebSearch.ts index 8a53826a22a..0390fe165b8 100644 --- a/src/lib/server/websearch/runWebSearch.ts +++ b/src/lib/server/websearch/runWebSearch.ts @@ -56,8 +56,9 @@ export async function runWebSearch( .slice(0, MAX_N_PAGES_SCRAPE); // limit to first 10 links only // fetch the model - const embeddingModel = embeddingModels.find((m) => m.id === conv.embeddingModel) ?? defaultEmbeddingModel; - + const embeddingModel = + embeddingModels.find((m) => m.id === conv.embeddingModel) ?? 
defaultEmbeddingModel; + if (!embeddingModel) { throw new Error(`Embedding model ${conv.embeddingModel} not available anymore`); } From 4ed0066c58597db91f9ae5e82914768562799ca0 Mon Sep 17 00:00:00 2001 From: Mishig Davaadorj Date: Tue, 9 Jan 2024 16:16:48 +0100 Subject: [PATCH 30/31] add "authorization" example --- README.md | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 0d5241deb43..34c15398c01 100644 --- a/README.md +++ b/README.md @@ -108,7 +108,11 @@ TEXT_EMBEDDING_MODELS = `[ "preQuery": "query: ", # See https://huggingface.co/intfloat/e5-base-v2#faq "prePassage": "passage: ", # See https://huggingface.co/intfloat/e5-base-v2#faq "endpoints": [ - {"type": "tei", "url": "http://127.0.0.1:8080/"} + { + "type": "tei", + "url": "http://127.0.0.1:8080/", + "authorization": "TOKEN_TYPE TOKEN" // optional authorization field. Example: "Basic VVNFUjpQQVNT" + } ] } ]` From 669e86d997d85de78f9becc593704c7a51cdf2a9 Mon Sep 17 00:00:00 2001 From: Mishig Davaadorj Date: Tue, 9 Jan 2024 16:17:33 +0100 Subject: [PATCH 31/31] format --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 34c15398c01..057f2470dc3 100644 --- a/README.md +++ b/README.md @@ -109,7 +109,7 @@ TEXT_EMBEDDING_MODELS = `[ "prePassage": "passage: ", # See https://huggingface.co/intfloat/e5-base-v2#faq "endpoints": [ { - "type": "tei", + "type": "tei", "url": "http://127.0.0.1:8080/", "authorization": "TOKEN_TYPE TOKEN" // optional authorization field. Example: "Basic VVNFUjpQQVNT" }