diff --git a/src/lib/server/embeddingEndpoints/tei/embeddingEndpoints.ts b/src/lib/server/embeddingEndpoints/tei/embeddingEndpoints.ts
new file mode 100644
index 00000000000..17bdc34ae64
--- /dev/null
+++ b/src/lib/server/embeddingEndpoints/tei/embeddingEndpoints.ts
@@ -0,0 +1,65 @@
+import { z } from "zod";
+import type { EmbeddingEndpoint, Embedding } from "$lib/types/EmbeddingEndpoints";
+import { chunk } from "$lib/utils/chunk";
+
+export const embeddingEndpointTeiParametersSchema = z.object({
+ weight: z.number().int().positive().default(1),
+ model: z.any(),
+ type: z.literal("tei"),
+ url: z.string().url(),
+ authorization: z.string().optional(),
+});
+
+const getModelInfoByUrl = async (url: string, authorization?: string) => {
+ const { origin } = new URL(url);
+
+ const response = await fetch(`${origin}/info`, {
+ headers: {
+ Accept: "application/json",
+ "Content-Type": "application/json",
+ ...(authorization ? { Authorization: authorization } : {}),
+ },
+ });
+
+ const json = await response.json();
+ return json;
+};
+
+export async function embeddingEndpointTei(
+ input: z.input<typeof embeddingEndpointTeiParametersSchema>
+): Promise<EmbeddingEndpoint> {
+ const { url, model, authorization } = embeddingEndpointTeiParametersSchema.parse(input);
+
+ const { max_client_batch_size, max_batch_tokens } = await getModelInfoByUrl(url, authorization);
+ const maxBatchSize = Math.min(
+ max_client_batch_size,
+ Math.floor(max_batch_tokens / model.chunkCharLength)
+ );
+
+ return async ({ inputs }) => {
+ const { origin } = new URL(url);
+
+ const batchesInputs = chunk(inputs, maxBatchSize);
+
+ const batchesResults = await Promise.all(
+ batchesInputs.map(async (batchInputs) => {
+ const response = await fetch(`${origin}/embed`, {
+ method: "POST",
+ headers: {
+ Accept: "application/json",
+ "Content-Type": "application/json",
+ ...(authorization ? { Authorization: authorization } : {}),
+ },
+ body: JSON.stringify({ inputs: batchInputs, normalize: true, truncate: true }),
+ });
+
+ const embeddings: Embedding[] = await response.json();
+ return embeddings;
+ })
+ );
+
+ const flatAllEmbeddings = batchesResults.flat();
+
+ return flatAllEmbeddings;
+ };
+}
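
For reference, a minimal usage sketch of the endpoint factory above (the URL is a placeholder, and `model` is deliberately loose since the schema types it as `z.any()`; only `chunkCharLength` is read when sizing batches):

```ts
import { embeddingEndpointTei } from "$lib/server/embeddingEndpoints/tei/embeddingEndpoints";

// Hypothetical TEI deployment; any text-embeddings-inference server
// exposing /info and /embed should behave the same way.
const endpoint = await embeddingEndpointTei({
	type: "tei",
	weight: 1,
	url: "http://localhost:8080",
	model: { chunkCharLength: 512 },
});

// Inputs are split into batches no larger than the server allows,
// embedded in parallel, and flattened back into a single Embedding[].
const embeddings = await endpoint({ inputs: ["first passage", "second passage"] });
console.log(embeddings.length); // 2
```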
diff --git a/src/lib/server/embeddingEndpoints/transformersjs/embeddingEndpoints.ts b/src/lib/server/embeddingEndpoints/transformersjs/embeddingEndpoints.ts
new file mode 100644
index 00000000000..7cedddcfe15
--- /dev/null
+++ b/src/lib/server/embeddingEndpoints/transformersjs/embeddingEndpoints.ts
@@ -0,0 +1,46 @@
+import { z } from "zod";
+import type { EmbeddingEndpoint } from "$lib/types/EmbeddingEndpoints";
+import type { Tensor, Pipeline } from "@xenova/transformers";
+import { pipeline } from "@xenova/transformers";
+
+export const embeddingEndpointTransformersJSParametersSchema = z.object({
+ weight: z.number().int().positive().default(1),
+ model: z.any(),
+ type: z.literal("transformersjs"),
+});
+
+// Use the Singleton pattern to enable lazy construction of the pipeline.
+class TransformersJSModelsSingleton {
+ static instances: Array<[string, Promise<Pipeline>]> = [];
+
+ static async getInstance(modelName: string): Promise<Pipeline> {
+ const modelPipelineInstance = this.instances.find(([name]) => name === modelName);
+
+ if (modelPipelineInstance) {
+ const [, modelPipeline] = modelPipelineInstance;
+ return modelPipeline;
+ }
+
+ const newModelPipeline = pipeline("feature-extraction", modelName);
+ this.instances.push([modelName, newModelPipeline]);
+
+ return newModelPipeline;
+ }
+}
+
+export async function calculateEmbedding(modelName: string, inputs: string[]) {
+ const extractor = await TransformersJSModelsSingleton.getInstance(modelName);
+ const output: Tensor = await extractor(inputs, { pooling: "mean", normalize: true });
+
+ return output.tolist();
+}
+
+export function embeddingEndpointTransformersJS(
+ input: z.input<typeof embeddingEndpointTransformersJSParametersSchema>
+): EmbeddingEndpoint {
+ const { model } = embeddingEndpointTransformersJSParametersSchema.parse(input);
+
+ return async ({ inputs }) => {
+ return calculateEmbedding(model.name, inputs);
+ };
+}
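
A quick sketch of calling the Transformers.js path directly; "Xenova/gte-small" is the same default model configured in embeddingModels.ts below:

```ts
import { calculateEmbedding } from "$lib/server/embeddingEndpoints/transformersjs/embeddingEndpoints";

// The first call constructs (and caches) the feature-extraction pipeline;
// later calls with the same model name reuse the stored promise.
const vectors = await calculateEmbedding("Xenova/gte-small", ["hello world"]);
console.log(vectors[0].length); // embedding dimension (384 for gte-small)
```

Caching the promise rather than the resolved pipeline also deduplicates concurrent first calls: both callers await the same in-flight construction.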
diff --git a/src/lib/server/embeddingModels.ts b/src/lib/server/embeddingModels.ts
new file mode 100644
index 00000000000..13305867d95
--- /dev/null
+++ b/src/lib/server/embeddingModels.ts
@@ -0,0 +1,99 @@
+import { TEXT_EMBEDDING_MODELS } from "$env/static/private";
+
+import { z } from "zod";
+import { sum } from "$lib/utils/sum";
+import {
+ embeddingEndpoints,
+ embeddingEndpointSchema,
+ type EmbeddingEndpoint,
+} from "$lib/types/EmbeddingEndpoints";
+import { embeddingEndpointTransformersJS } from "$lib/server/embeddingEndpoints/transformersjs/embeddingEndpoints";
+
+const modelConfig = z.object({
+ /** Used as an identifier in DB */
+ id: z.string().optional(),
+ /** Used to link to the model page, and for inference */
+ name: z.string().min(1),
+ displayName: z.string().min(1).optional(),
+ description: z.string().min(1).optional(),
+ websiteUrl: z.string().url().optional(),
+ modelUrl: z.string().url().optional(),
+ endpoints: z.array(embeddingEndpointSchema).nonempty(),
+ chunkCharLength: z.number().positive(),
+ preQuery: z.string().default(""),
+ prePassage: z.string().default(""),
+});
+
+// Default embedding model for backward compatibility
+const rawEmbeddingModelJSON =
+ TEXT_EMBEDDING_MODELS ||
+ `[
+ {
+ "name": "Xenova/gte-small",
+ "chunkCharLength": 512,
+ "endpoints": [
+ { "type": "transformersjs" }
+ ]
+ }
+]`;
+
+const embeddingModelsRaw = z.array(modelConfig).parse(JSON.parse(rawEmbeddingModelJSON));
+
+const processEmbeddingModel = async (m: z.infer<typeof modelConfig>) => ({
+ ...m,
+ id: m.id || m.name,
+});
+
+const addEndpoint = (m: Awaited<ReturnType<typeof processEmbeddingModel>>) => ({
+ ...m,
+ getEndpoint: async (): Promise<EmbeddingEndpoint> => {
+ if (!m.endpoints) {
+ return embeddingEndpointTransformersJS({
+ type: "transformersjs",
+ weight: 1,
+ model: m,
+ });
+ }
+
+ const totalWeight = sum(m.endpoints.map((e) => e.weight));
+
+ let random = Math.random() * totalWeight;
+
+ for (const endpoint of m.endpoints) {
+ if (random < endpoint.weight) {
+ const args = { ...endpoint, model: m };
+
+ switch (args.type) {
+ case "tei":
+ return embeddingEndpoints.tei(args);
+ case "transformersjs":
+ return embeddingEndpoints.transformersjs(args);
+ }
+ }
+
+ random -= endpoint.weight;
+ }
+
+ throw new Error(`Failed to select embedding endpoint`);
+ },
+});
+
+export const embeddingModels = await Promise.all(
+ embeddingModelsRaw.map((e) => processEmbeddingModel(e).then(addEndpoint))
+);
+
+export const defaultEmbeddingModel = embeddingModels[0];
+
+const validateEmbeddingModel = (_models: EmbeddingBackendModel[], key: "id" | "name") => {
+ return z.enum([_models[0][key], ..._models.slice(1).map((m) => m[key])]);
+};
+
+export const validateEmbeddingModelById = (_models: EmbeddingBackendModel[]) => {
+ return validateEmbeddingModel(_models, "id");
+};
+
+export const validateEmbeddingModelByName = (_models: EmbeddingBackendModel[]) => {
+ return validateEmbeddingModel(_models, "name");
+};
+
+export type EmbeddingBackendModel = typeof defaultEmbeddingModel;
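
For configuration, `TEXT_EMBEDDING_MODELS` is parsed against the `modelConfig` schema above. A sketch of an `.env.local` entry mixing both endpoint types for one model (the URL and weights are placeholders):

```env
TEXT_EMBEDDING_MODELS=`[
  {
    "name": "Xenova/gte-small",
    "chunkCharLength": 512,
    "endpoints": [
      { "type": "transformersjs", "weight": 1 },
      { "type": "tei", "url": "http://localhost:8080", "weight": 9 }
    ]
  }
]`
```

With these weights, `getEndpoint()` routes roughly 9 of 10 calls to the TEI server and the remainder to the in-process pipeline, since each endpoint is selected with probability weight/totalWeight.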
diff --git a/src/lib/server/models.ts b/src/lib/server/models.ts
index 58d05bd7a9b..5a820234aba 100644
--- a/src/lib/server/models.ts
+++ b/src/lib/server/models.ts
@@ -12,6 +12,7 @@ import { z } from "zod";
import endpoints, { endpointSchema, type Endpoint } from "./endpoints/endpoints";
import endpointTgi from "./endpoints/tgi/endpointTgi";
import { sum } from "$lib/utils/sum";
+import { embeddingModels, validateEmbeddingModelByName } from "./embeddingModels";
type Optional<T, K extends keyof T> = Pick<Partial<T>, K> & Omit<T, K>;
@@ -66,6 +67,7 @@ const modelConfig = z.object({
.optional(),
multimodal: z.boolean().default(false),
unlisted: z.boolean().default(false),
+ embeddingModel: validateEmbeddingModelByName(embeddingModels).optional(),
});
const modelsRaw = z.array(modelConfig).parse(JSON.parse(MODELS));
diff --git a/src/lib/server/sentenceSimilarity.ts b/src/lib/server/sentenceSimilarity.ts
new file mode 100644
index 00000000000..455b25d4d06
--- /dev/null
+++ b/src/lib/server/sentenceSimilarity.ts
@@ -0,0 +1,42 @@
+import { dot } from "@xenova/transformers";
+import type { EmbeddingBackendModel } from "$lib/server/embeddingModels";
+import type { Embedding } from "$lib/types/EmbeddingEndpoints";
+
+// see here: https://github.com/nmslib/hnswlib/blob/359b2ba87358224963986f709e593d799064ace6/README.md?plain=1#L34
+function innerProduct(embeddingA: Embedding, embeddingB: Embedding) {
+ return 1.0 - dot(embeddingA, embeddingB);
+}
+
+export async function findSimilarSentences(
+ embeddingModel: EmbeddingBackendModel,
+ query: string,
+ sentences: string[],
+ { topK = 5 }: { topK: number }
+): Promise<number[]> {
+ const inputs = [
+ `${embeddingModel.preQuery}${query}`,
+ ...sentences.map((sentence) => `${embeddingModel.prePassage}${sentence}`),
+ ];
+
+ const embeddingEndpoint = await embeddingModel.getEndpoint();
+ const output = await embeddingEndpoint({ inputs });
+
+ const queryEmbedding: Embedding = output[0];
+ const sentencesEmbeddings: Embedding[] = output.slice(1);
+
+ const distancesFromQuery: { distance: number; index: number }[] = sentencesEmbeddings.map(
+ (sentenceEmbedding: Embedding, index: number) => {
+ return {
+ distance: innerProduct(queryEmbedding, sentenceEmbedding),
+ index: index,
+ };
+ }
+ );
+
+ distancesFromQuery.sort((a, b) => {
+ return a.distance - b.distance;
+ });
+
+ // Return the indexes of the closest topK sentences
+ return distancesFromQuery.slice(0, topK).map((item) => item.index);
+}
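
To make the distance concrete: with normalized embeddings, `dot(a, b)` is cosine similarity, so `1 - dot(a, b)` is smallest for the most similar sentence. A toy check with hand-picked unit vectors:

```ts
import { dot } from "@xenova/transformers";

// Toy 2-D unit vectors standing in for normalized embeddings.
const query = [1, 0];
const close = [0.96, 0.28]; // 0.96² + 0.28² = 1, so still unit length
const far = [0, 1]; // orthogonal to the query

console.log(1.0 - dot(query, close)); // ≈ 0.04 (small distance, similar)
console.log(1.0 - dot(query, far)); // 1 (large distance, unrelated)
```

Sorting ascending by this distance and taking the first topK entries is exactly what findSimilarSentences does above.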
diff --git a/src/lib/server/websearch/runWebSearch.ts b/src/lib/server/websearch/runWebSearch.ts
index 0869ea8b494..0390fe165b8 100644
--- a/src/lib/server/websearch/runWebSearch.ts
+++ b/src/lib/server/websearch/runWebSearch.ts
@@ -4,13 +4,11 @@ import type { WebSearch, WebSearchSource } from "$lib/types/WebSearch";
import { generateQuery } from "$lib/server/websearch/generateQuery";
import { parseWeb } from "$lib/server/websearch/parseWeb";
import { chunk } from "$lib/utils/chunk";
-import {
- MAX_SEQ_LEN as CHUNK_CAR_LEN,
- findSimilarSentences,
-} from "$lib/server/websearch/sentenceSimilarity";
+import { findSimilarSentences } from "$lib/server/sentenceSimilarity";
import type { Conversation } from "$lib/types/Conversation";
import type { MessageUpdate } from "$lib/types/MessageUpdate";
import { getWebSearchProvider } from "./searchWeb";
+import { defaultEmbeddingModel, embeddingModels } from "$lib/server/embeddingModels";
const MAX_N_PAGES_SCRAPE = 10 as const;
const MAX_N_PAGES_EMBED = 5 as const;
@@ -57,6 +55,14 @@ export async function runWebSearch(
.filter(({ link }) => !DOMAIN_BLOCKLIST.some((el) => link.includes(el))) // filter out blocklist links
.slice(0, MAX_N_PAGES_SCRAPE); // limit to first 10 links only
+ // fetch the embedding model for this conversation, falling back to the default
+ const embeddingModel =
+ embeddingModels.find((m) => m.id === conv.embeddingModel) ?? defaultEmbeddingModel;
+
+ if (!embeddingModel) {
+ throw new Error(`Embedding model ${conv.embeddingModel} not available anymore`);
+ }
+
let paragraphChunks: { source: WebSearchSource; text: string }[] = [];
if (webSearch.results.length > 0) {
appendUpdate("Browsing results");
@@ -72,7 +78,7 @@ export async function runWebSearch(
}
}
const MAX_N_CHUNKS = 100;
- const texts = chunk(text, CHUNK_CAR_LEN).slice(0, MAX_N_CHUNKS);
+ const texts = chunk(text, embeddingModel.chunkCharLength).slice(0, MAX_N_CHUNKS);
return texts.map((t) => ({ source: result, text: t }));
});
const nestedParagraphChunks = (await Promise.all(promises)).slice(0, MAX_N_PAGES_EMBED);
@@ -87,7 +93,7 @@ export async function runWebSearch(
appendUpdate("Extracting relevant information");
const topKClosestParagraphs = 8;
const texts = paragraphChunks.map(({ text }) => text);
- const indices = await findSimilarSentences(prompt, texts, {
+ const indices = await findSimilarSentences(embeddingModel, prompt, texts, {
topK: topKClosestParagraphs,
});
webSearch.context = indices.map((idx) => texts[idx]).join("");
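
Both the text splitting here and the request batching in the TEI endpoint lean on the `chunk()` utility from `$lib/utils/chunk`. A minimal sketch of such a generic fixed-size splitter (the actual helper may differ in details):

```ts
// Splits a string or an array into pieces of at most `size` elements.
function chunk<T extends string | unknown[]>(items: T, size: number): T[] {
	if (size <= 0) throw new RangeError("size must be a positive integer");
	const out: T[] = [];
	for (let i = 0; i < items.length; i += size) {
		out.push(items.slice(i, i + size) as T);
	}
	return out;
}

chunk("abcdef", 4); // ["abcd", "ef"]
chunk([1, 2, 3, 4, 5], 2); // [[1, 2], [3, 4], [5]]
```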
diff --git a/src/lib/server/websearch/sentenceSimilarity.ts b/src/lib/server/websearch/sentenceSimilarity.ts
deleted file mode 100644
index a877f8e0cd6..00000000000
--- a/src/lib/server/websearch/sentenceSimilarity.ts
+++ /dev/null
@@ -1,52 +0,0 @@
-import type { Tensor, Pipeline } from "@xenova/transformers";
-import { pipeline, dot } from "@xenova/transformers";
-
-// see here: https://github.com/nmslib/hnswlib/blob/359b2ba87358224963986f709e593d799064ace6/README.md?plain=1#L34
-function innerProduct(tensor1: Tensor, tensor2: Tensor) {
- return 1.0 - dot(tensor1.data, tensor2.data);
-}
-
-// Use the Singleton pattern to enable lazy construction of the pipeline.
-class PipelineSingleton {
- static modelId = "Xenova/gte-small";
- static instance: Promise<Pipeline> | null = null;
- static async getInstance() {
- if (this.instance === null) {
- this.instance = pipeline("feature-extraction", this.modelId);
- }
- return this.instance;
- }
-}
-
-// see https://huggingface.co/thenlper/gte-small/blob/d8e2604cadbeeda029847d19759d219e0ce2e6d8/README.md?code=true#L2625
-export const MAX_SEQ_LEN = 512 as const;
-
-export async function findSimilarSentences(
- query: string,
- sentences: string[],
- { topK = 5 }: { topK: number }
-) {
- const input = [query, ...sentences];
-
- const extractor = await PipelineSingleton.getInstance();
- const output: Tensor = await extractor(input, { pooling: "mean", normalize: true });
-
- const queryTensor: Tensor = output[0];
- const sentencesTensor: Tensor = output.slice([1, input.length - 1]);
-
- const distancesFromQuery: { distance: number; index: number }[] = [...sentencesTensor].map(
- (sentenceTensor: Tensor, index: number) => {
- return {
- distance: innerProduct(queryTensor, sentenceTensor),
- index: index,
- };
- }
- );
-
- distancesFromQuery.sort((a, b) => {
- return a.distance - b.distance;
- });
-
- // Return the indexes of the closest topK sentences
- return distancesFromQuery.slice(0, topK).map((item) => item.index);
-}
diff --git a/src/lib/types/Conversation.ts b/src/lib/types/Conversation.ts
index 5788ce63fd8..665a688f6b4 100644
--- a/src/lib/types/Conversation.ts
+++ b/src/lib/types/Conversation.ts
@@ -10,6 +10,7 @@ export interface Conversation extends Timestamps {
userId?: User["_id"];
model: string;
+ embeddingModel: string;
title: string;
messages: Message[];
diff --git a/src/lib/types/EmbeddingEndpoints.ts b/src/lib/types/EmbeddingEndpoints.ts
new file mode 100644
index 00000000000..57cd425c578
--- /dev/null
+++ b/src/lib/types/EmbeddingEndpoints.ts
@@ -0,0 +1,41 @@
+import { z } from "zod";
+import {
+ embeddingEndpointTei,
+ embeddingEndpointTeiParametersSchema,
+} from "$lib/server/embeddingEndpoints/tei/embeddingEndpoints";
+import {
+ embeddingEndpointTransformersJS,
+ embeddingEndpointTransformersJSParametersSchema,
+} from "$lib/server/embeddingEndpoints/transformersjs/embeddingEndpoints";
+
+// parameters passed when generating embeddings
+interface EmbeddingEndpointParameters {
+ inputs: string[];
+}
+
+export type Embedding = number[];
+
+// type signature for the endpoint
+export type EmbeddingEndpoint = (params: EmbeddingEndpointParameters) => Promise<Embedding[]>;
+
+export const embeddingEndpointSchema = z.discriminatedUnion("type", [
+ embeddingEndpointTeiParametersSchema,
+ embeddingEndpointTransformersJSParametersSchema,
+]);
+
+type EmbeddingEndpointTypeOptions = z.infer<typeof embeddingEndpointSchema>["type"];
+
+// generator function that takes in the type discriminator value and returns the endpoint
+export type EmbeddingEndpointGenerator<T extends EmbeddingEndpointTypeOptions> = (
+ inputs: Extract<z.infer<typeof embeddingEndpointSchema>, { type: T }>
+) => EmbeddingEndpoint | Promise<EmbeddingEndpoint>;
+
+// list of all endpoint generators
+export const embeddingEndpoints: {
+ [Key in EmbeddingEndpointTypeOptions]: EmbeddingEndpointGenerator<Key>;
+} = {
+ tei: embeddingEndpointTei,
+ transformersjs: embeddingEndpointTransformersJS,
+};
+
+export default embeddingEndpoints;
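
The discriminated union lets a raw config object be validated once and then dispatched to the matching generator with full type narrowing. A small sketch (the URL is a placeholder):

```ts
import embeddingEndpoints, { embeddingEndpointSchema } from "$lib/types/EmbeddingEndpoints";

// zod narrows the parsed value to exactly one member of the union.
const config = embeddingEndpointSchema.parse({
	type: "tei",
	url: "http://localhost:8080",
	model: { chunkCharLength: 512 },
});

// Branching on the discriminator narrows `config`, so each generator
// receives precisely the parameter shape it declares.
const endpoint =
	config.type === "tei"
		? await embeddingEndpoints.tei(config)
		: await embeddingEndpoints.transformersjs(config);

const embeddings = await endpoint({ inputs: ["hello"] });
```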
diff --git a/src/lib/types/SharedConversation.ts b/src/lib/types/SharedConversation.ts
index 8571f2c3f3a..1996bcc6ff9 100644
--- a/src/lib/types/SharedConversation.ts
+++ b/src/lib/types/SharedConversation.ts
@@ -7,6 +7,8 @@ export interface SharedConversation extends Timestamps {
hash: string;
model: string;
+ embeddingModel: string;
+
title: string;
messages: Message[];
preprompt?: string;
diff --git a/src/routes/conversation/+server.ts b/src/routes/conversation/+server.ts
index 6452e985d67..2870eddd1bc 100644
--- a/src/routes/conversation/+server.ts
+++ b/src/routes/conversation/+server.ts
@@ -6,6 +6,7 @@ import { base } from "$app/paths";
import { z } from "zod";
import type { Message } from "$lib/types/Message";
import { models, validateModel } from "$lib/server/models";
+import { defaultEmbeddingModel } from "$lib/server/embeddingModels";
export const POST: RequestHandler = async ({ locals, request }) => {
const body = await request.text();
@@ -22,6 +23,7 @@ export const POST: RequestHandler = async ({ locals, request }) => {
.parse(JSON.parse(body));
let preprompt = values.preprompt;
+ let embeddingModel: string | undefined;
if (values.fromShare) {
const conversation = await collections.sharedConversations.findOne({
@@ -35,6 +37,7 @@ export const POST: RequestHandler = async ({ locals, request }) => {
title = conversation.title;
messages = conversation.messages;
values.model = conversation.model;
+ embeddingModel = conversation.embeddingModel;
preprompt = conversation.preprompt;
}
@@ -44,6 +47,8 @@ export const POST: RequestHandler = async ({ locals, request }) => {
throw error(400, "Invalid model");
}
+ embeddingModel ??= model.embeddingModel ?? defaultEmbeddingModel.name;
+
if (model.unlisted) {
throw error(400, "Can't start a conversation with an unlisted model");
}
@@ -59,6 +64,7 @@ export const POST: RequestHandler = async ({ locals, request }) => {
preprompt: preprompt === model?.preprompt ? model?.preprompt : preprompt,
createdAt: new Date(),
updatedAt: new Date(),
+ embeddingModel: embeddingModel,
...(locals.user ? { userId: locals.user._id } : { sessionId: locals.sessionId }),
...(values.fromShare ? { meta: { fromShareId: values.fromShare } } : {}),
});
diff --git a/src/routes/conversation/[id]/+page.svelte b/src/routes/conversation/[id]/+page.svelte
index 363d14d6176..ba00e9757a9 100644
--- a/src/routes/conversation/[id]/+page.svelte
+++ b/src/routes/conversation/[id]/+page.svelte
@@ -173,6 +173,7 @@
inputs.forEach(async (el: string) => {
try {
const update = JSON.parse(el) as MessageUpdate;
+
if (update.type === "finalAnswer") {
finalAnswer = update.text;
reader.cancel();
@@ -225,7 +226,7 @@
});
}
- // reset the websearchmessages
+ // reset the websearchMessages
webSearchMessages = [];
await invalidate(UrlDependency.ConversationList);
diff --git a/src/routes/conversation/[id]/share/+server.ts b/src/routes/conversation/[id]/share/+server.ts
index e3f81222180..4877de755ad 100644
--- a/src/routes/conversation/[id]/share/+server.ts
+++ b/src/routes/conversation/[id]/share/+server.ts
@@ -38,6 +38,7 @@ export async function POST({ params, url, locals }) {
updatedAt: new Date(),
title: conversation.title,
model: conversation.model,
+ embeddingModel: conversation.embeddingModel,
preprompt: conversation.preprompt,
};
diff --git a/src/routes/login/callback/updateUser.spec.ts b/src/routes/login/callback/updateUser.spec.ts
index 54229914571..fefaf8b0f5a 100644
--- a/src/routes/login/callback/updateUser.spec.ts
+++ b/src/routes/login/callback/updateUser.spec.ts
@@ -6,6 +6,7 @@ import { ObjectId } from "mongodb";
import { DEFAULT_SETTINGS } from "$lib/types/Settings";
import { defaultModel } from "$lib/server/models";
import { findUser } from "$lib/server/auth";
+import { defaultEmbeddingModel } from "$lib/server/embeddingModels";
const userData = {
preferred_username: "new-username",
@@ -46,6 +47,7 @@ const insertRandomConversations = async (count: number) => {
title: "random title",
messages: [],
model: defaultModel.id,
+ embeddingModel: defaultEmbeddingModel.id,
createdAt: new Date(),
updatedAt: new Date(),
sessionId: locals.sessionId,