Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add embedding models configurable, from both transformers.js and TEI #646

Merged
merged 32 commits into from
Jan 9, 2024
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
491131c
Add embedding models configurable, from both Xenova and TEI
mikelfried Dec 19, 2023
3473fc2
fix lint and format
mikelfried Dec 20, 2023
aebf653
Fix bug in sentenceSimilarity
mikelfried Dec 20, 2023
c065045
Batches for TEI using /info route
mikelfried Dec 20, 2023
8df8fd2
Fix web search disapear when finish searching
mikelfried Dec 20, 2023
cc02b4c
Fix lint and format
mikelfried Dec 20, 2023
53fa58a
Add more options for better embedding model usage
mikelfried Dec 20, 2023
9a867aa
Fixing CR issues
mikelfried Dec 22, 2023
6c6e290
Fix websearch disapear in later PR
mikelfried Dec 22, 2023
dc8d4e9
Fix lint
mikelfried Dec 22, 2023
aacbfeb
Fix more minor code CR
mikelfried Dec 22, 2023
7a9950d
Valiadate embeddingModelName field in model config
mikelfried Dec 22, 2023
bce01d4
Add embeddingModel into shared conversation
mikelfried Dec 22, 2023
f822ced
Fix lint and format
mikelfried Dec 22, 2023
421ecca
Add default embedding model, and more readme explanation
mikelfried Dec 23, 2023
e85faee
Fix minor embedding model readme detailed
mikelfried Dec 23, 2023
00a970e
Merge branch 'main' into embedding_models
mikelfried Dec 31, 2023
a36c521
Update settings.json
mikelfried Dec 31, 2023
d132375
Update README.md
mikelfried Jan 6, 2024
e105225
Update README.md
mikelfried Jan 6, 2024
f615fef
Apply suggestions from code review
mikelfried Jan 6, 2024
9ef62b9
Resolved more issues
mikelfried Jan 6, 2024
7c79582
lint
nsarrazin Jan 8, 2024
60d9b23
Fix more issues
mikelfried Jan 8, 2024
65760bc
Fix format
mikelfried Jan 8, 2024
3eb93ba
fix small typo
mikelfried Jan 8, 2024
5914529
lint
nsarrazin Jan 9, 2024
25d9600
fix default model
mishig25 Jan 9, 2024
fafecb7
Rn `maxSequenceLength` -> `chunkCharLength`
mishig25 Jan 9, 2024
ed3688c
format
mishig25 Jan 9, 2024
4ed0066
add "authorization" example
mishig25 Jan 9, 2024
669e86d
format
mishig25 Jan 9, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions .env
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,19 @@ CA_PATH=#
CLIENT_KEY_PASSWORD=#
REJECT_UNAUTHORIZED=true

TEXT_EMBEDDING_MODELS = `[
{
"name": "Xenova/gte-small",
"displayName": "Xenova/gte-small",
"description": "Local embedding model running on the server.",
"maxSequenceLength": 512,
"endpoints": [
{ "type": "xenova" }
]
}
]`


# 'name', 'userMessageToken', 'assistantMessageToken' are required
MODELS=`[
{
Expand Down
33 changes: 33 additions & 0 deletions .env.template
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,7 @@ MODELS=`[
"max_new_tokens" : 8192,
"stop" : ["</s>"]
},
"embeddingModelName": "thenlper/gte-base",
"promptExamples" : [
{
"title": "Write an email from bullet list",
Expand All @@ -215,6 +216,38 @@ OLD_MODELS=`[{"name":"bigcode/starcoder"}, {"name":"OpenAssistant/oasst-sft-6-ll

TASK_MODEL='mistralai/Mistral-7B-Instruct-v0.2'

# Default to using the first text embedding model when not specifying 'embeddingModelName' in the model itself.
TEXT_EMBEDDING_MODELS = `[
mikelfried marked this conversation as resolved.
Show resolved Hide resolved
{
"name": "Xenova/gte-small",
"displayName": "Xenova/gte-small",
"description": "Local embedding model running on the server.",
"maxSequenceLength": 512,
"endpoints": [
{ "type": "xenova" }
]
},
{
"name": "thenlper/gte-base",
"displayName": "thenlper/gte-base",
"description": "Hosted embedding model running on the cloud somewhere.",
"maxSequenceLength": 512,
"endpoints": [
{ "type": "tei", "http://localhost:8080/" }
]
},
{
"name": "intfloat/multilingual-e5-large",
"displayName": "intfloat/multilingual-e5-large",
"description": "Hosted embedding model running on the cloud somewhere.",
"maxSequenceLength": 512,
"preQuery": "query: ", # See https://huggingface.co/intfloat/multilingual-e5-large#faq
"prePassage": "passage: ", # See https://huggingface.co/intfloat/multilingual-e5-large#faq
"endpoints": [
{ "type": "tei", "http://localhost:8085/" }
]
}
]`

APP_BASE="/chat"
PUBLIC_ORIGIN=https://huggingface.co
Expand Down
4 changes: 2 additions & 2 deletions src/lib/components/OpenWebSearchResults.svelte
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@
{:else}
<CarbonCheckmark class="my-auto text-gray-500" />
{/if}
<span class="px-2 font-medium" class:text-red-700={error} class:dark:text-red-500={error}
>Web search
<span class="px-2 font-medium" class:text-red-700={error} class:dark:text-red-500={error}>
Web search
</span>
<div class="my-auto transition-all" class:rotate-90={detailsOpen}>
<CarbonCaretRight />
Expand Down
38 changes: 38 additions & 0 deletions src/lib/server/embeddingEndpoints/embeddingEndpoints.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
import {
mikelfried marked this conversation as resolved.
Show resolved Hide resolved
embeddingEndpointTei,
embeddingEndpointTeiParametersSchema,
} from "./tei/embeddingEndpoints";
import { z } from "zod";
import embeddingEndpointXenova, {
embeddingEndpointXenovaParametersSchema,
} from "./xenova/embeddingEndpoints";
mikelfried marked this conversation as resolved.
Show resolved Hide resolved

// parameters passed when generating text
interface EmbeddingEndpointParameters {
inputs: string[];
}

interface CommonEmbeddingEndpoint {
weight: number;
}

// type signature for the endpoint
export type EmbeddingEndpoint = (params: EmbeddingEndpointParameters) => Promise<number[][]>;

// generator function that takes in parameters for defining the endpoint and return the endpoint
export type EmbeddingEndpointGenerator<T extends CommonEmbeddingEndpoint> = (
parameters: T
) => EmbeddingEndpoint;

// list of all endpoint generators
export const embeddingEndpoints = {
tei: embeddingEndpointTei,
xenova: embeddingEndpointXenova,
mikelfried marked this conversation as resolved.
Show resolved Hide resolved
};

export const embeddingEndpointSchema = z.discriminatedUnion("type", [
embeddingEndpointTeiParametersSchema,
embeddingEndpointXenovaParametersSchema,
]);

export default embeddingEndpoints;
65 changes: 65 additions & 0 deletions src/lib/server/embeddingEndpoints/tei/embeddingEndpoints.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
import { z } from "zod";
import type { EmbeddingEndpoint } from "../embeddingEndpoints";
import { chunk } from "$lib/utils/chunk";

export const embeddingEndpointTeiParametersSchema = z.object({
weight: z.number().int().positive().default(1),
mikelfried marked this conversation as resolved.
Show resolved Hide resolved
model: z.any(),
type: z.literal("tei"),
url: z.string().url(),
});

const getModelInfoByUrl = async (url: string) => {
const { origin } = new URL(url);

const response = await fetch(`${origin}/info`, {
headers: {
Accept: "application/json",
"Content-Type": "application/json",
},
});

const info = await response.json();

return info;
mikelfried marked this conversation as resolved.
Show resolved Hide resolved
};

export async function embeddingEndpointTei(
input: z.input<typeof embeddingEndpointTeiParametersSchema>
): Promise<EmbeddingEndpoint> {
const { url, model } = embeddingEndpointTeiParametersSchema.parse(input);

const { max_client_batch_size, max_batch_tokens } = await getModelInfoByUrl(url);
const maxBatchSize = Math.min(
max_client_batch_size,
Math.floor(max_batch_tokens / model.maxSequenceLength)
);

return async ({ inputs }) => {
const { origin } = new URL(url);

const batchesInputs = chunk(inputs, maxBatchSize);

const batchesResults = await Promise.all(
batchesInputs.map(async (batchInputs) => {
const response = await fetch(`${origin}/embed`, {
method: "POST",
headers: {
mikelfried marked this conversation as resolved.
Show resolved Hide resolved
Accept: "application/json",
"Content-Type": "application/json",
},
body: JSON.stringify({ inputs: batchInputs, normalize: true, truncate: true }),
});

const embeddings: number[][] = await response.json();
return embeddings;
})
);

const allEmbeddings = batchesResults.flatMap((embeddings) => embeddings);
mikelfried marked this conversation as resolved.
Show resolved Hide resolved

return allEmbeddings;
};
}

export default embeddingEndpointTei;
47 changes: 47 additions & 0 deletions src/lib/server/embeddingEndpoints/xenova/embeddingEndpoints.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
import { z } from "zod";
import type { EmbeddingEndpoint } from "../embeddingEndpoints";
import type { Tensor, Pipeline } from "@xenova/transformers";
import { pipeline } from "@xenova/transformers";

export const embeddingEndpointXenovaParametersSchema = z.object({
weight: z.number().int().positive().default(1),
model: z.any(),
type: z.literal("xenova"),
});

// Use the Singleton pattern to enable lazy construction of the pipeline.
class XenovaModelsSingleton {
static instances: Array<[string, Promise<Pipeline>]> = [];

static async getInstance(modelName: string): Promise<Pipeline> {
const modelPipeline = this.instances.find(([name]) => name === modelName);

if (modelPipeline) {
return modelPipeline[1];
}

const newModelPipeline = pipeline("feature-extraction", modelName);
this.instances.push([modelName, newModelPipeline]);

return newModelPipeline;
}
}

export async function calculateEmbedding(modelName: string, inputs: string[]) {
const extractor = await XenovaModelsSingleton.getInstance(modelName);
const output: Tensor = await extractor(inputs, { pooling: "mean", normalize: true });

return output.tolist();
}

export function embeddingEndpointXenova(
input: z.input<typeof embeddingEndpointXenovaParametersSchema>
): EmbeddingEndpoint {
const { model } = embeddingEndpointXenovaParametersSchema.parse(input);

return async ({ inputs }) => {
return calculateEmbedding(model.name, inputs);
};
}

export default embeddingEndpointXenova;
78 changes: 78 additions & 0 deletions src/lib/server/embeddingModels.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
import { TEXT_EMBEDDING_MODELS } from "$env/static/private";

import { z } from "zod";
import { sum } from "$lib/utils/sum";
import embeddingEndpoints, {
embeddingEndpointSchema,
type EmbeddingEndpoint,
} from "./embeddingEndpoints/embeddingEndpoints";
import embeddingEndpointXenova from "./embeddingEndpoints/xenova/embeddingEndpoints";

const modelConfig = z.object({
/** Used as an identifier in DB */
id: z.string().optional(),
/** Used to link to the model page, and for inference */
name: z.string().min(1),
displayName: z.string().min(1).optional(),
description: z.string().min(1).optional(),
websiteUrl: z.string().url().optional(),
modelUrl: z.string().url().optional(),
endpoints: z.array(embeddingEndpointSchema).optional(),
maxSequenceLength: z.number().positive(),
preQuery: z.string().default(""),
prePassage: z.string().default(""),
});

const embeddingModelsRaw = z.array(modelConfig).parse(JSON.parse(TEXT_EMBEDDING_MODELS));

const processEmbeddingModel = async (m: z.infer<typeof modelConfig>) => ({
...m,
id: m.id || m.name,
});

const addEndpoint = (m: Awaited<ReturnType<typeof processEmbeddingModel>>) => ({
...m,
getEndpoint: async (): Promise<EmbeddingEndpoint> => {
if (!m.endpoints) {
return embeddingEndpointXenova({
type: "xenova",
weight: 1,
model: m,
});
}

const totalWeight = sum(m.endpoints.map((e) => e.weight));

let random = Math.random() * totalWeight;

for (const endpoint of m.endpoints) {
if (random < endpoint.weight) {
const args = { ...endpoint, model: m };

switch (args.type) {
case "tei":
return embeddingEndpoints.tei(args);
case "xenova":
return embeddingEndpoints.xenova(args);
}
}

random -= endpoint.weight;
}

throw new Error(`Failed to select endpoint`);
},
});

export const embeddingModels = await Promise.all(
embeddingModelsRaw.map((e) => processEmbeddingModel(e).then(addEndpoint))
);

export const defaultEmbeddingModel = embeddingModels[0];

export const validateEmbeddingModel = (_models: EmbeddingBackendModel[]) => {
// Zod enum function requires 2 parameters
return z.enum([_models[0].id, ..._models.slice(1).map((m) => m.id)]);
};

export type EmbeddingBackendModel = typeof defaultEmbeddingModel;
1 change: 1 addition & 0 deletions src/lib/server/models.ts
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ const modelConfig = z.object({
.optional(),
multimodal: z.boolean().default(false),
unlisted: z.boolean().default(false),
embeddingModelName: z.string().optional(),
});

const modelsRaw = z.array(modelConfig).parse(JSON.parse(MODELS));
Expand Down
41 changes: 41 additions & 0 deletions src/lib/server/sentenceSimilarity.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
import { dot } from "@xenova/transformers";
import type { EmbeddingBackendModel } from "./embeddingModels";

// see here: https://github.com/nmslib/hnswlib/blob/359b2ba87358224963986f709e593d799064ace6/README.md?plain=1#L34
function innerProduct(embeddingA: number[], embeddingB: number[]) {
return 1.0 - dot(embeddingA, embeddingB);
}

export async function findSimilarSentences(
embeddingModel: EmbeddingBackendModel,
query: string,
sentences: string[],
{ topK = 5 }: { topK: number }
): Promise<number[]> {
const inputs = [
`${embeddingModel.preQuery}${query}`,
...sentences.map((sentence) => `${embeddingModel.prePassage}${sentence}`),
];

const embeddingEndpoint = await embeddingModel.getEndpoint();
const output = await embeddingEndpoint({ inputs });

const queryEmbedding: number[] = output[0];
const sentencesEmbeddings: number[][] = output.slice(1, inputs.length - 1);

const distancesFromQuery: { distance: number; index: number }[] = [...sentencesEmbeddings].map(
(sentenceEmbedding: number[], index: number) => {
return {
distance: innerProduct(queryEmbedding, sentenceEmbedding),
index: index,
};
}
);

distancesFromQuery.sort((a, b) => {
return a.distance - b.distance;
});

// Return the indexes of the closest topK sentences
return distancesFromQuery.slice(0, topK).map((item) => item.index);
}
Loading
Loading