Add configurable embedding models, from both transformers.js and TEI #646

Merged
merged 32 commits on Jan 9, 2024
32 commits
491131c
Add configurable embedding models, from both Xenova and TEI
mikelfried Dec 19, 2023
3473fc2
fix lint and format
mikelfried Dec 20, 2023
aebf653
Fix bug in sentenceSimilarity
mikelfried Dec 20, 2023
c065045
Batches for TEI using /info route
mikelfried Dec 20, 2023
8df8fd2
Fix web search disappearing when finished searching
mikelfried Dec 20, 2023
cc02b4c
Fix lint and format
mikelfried Dec 20, 2023
53fa58a
Add more options for better embedding model usage
mikelfried Dec 20, 2023
9a867aa
Fixing CR issues
mikelfried Dec 22, 2023
6c6e290
Fix websearch disappearing in later PR
mikelfried Dec 22, 2023
dc8d4e9
Fix lint
mikelfried Dec 22, 2023
aacbfeb
Fix more minor code CR
mikelfried Dec 22, 2023
7a9950d
Validate embeddingModelName field in model config
mikelfried Dec 22, 2023
bce01d4
Add embeddingModel into shared conversation
mikelfried Dec 22, 2023
f822ced
Fix lint and format
mikelfried Dec 22, 2023
421ecca
Add default embedding model, and more readme explanation
mikelfried Dec 23, 2023
e85faee
Fix minor embedding model readme details
mikelfried Dec 23, 2023
00a970e
Merge branch 'main' into embedding_models
mikelfried Dec 31, 2023
a36c521
Update settings.json
mikelfried Dec 31, 2023
d132375
Update README.md
mikelfried Jan 6, 2024
e105225
Update README.md
mikelfried Jan 6, 2024
f615fef
Apply suggestions from code review
mikelfried Jan 6, 2024
9ef62b9
Resolved more issues
mikelfried Jan 6, 2024
7c79582
lint
nsarrazin Jan 8, 2024
60d9b23
Fix more issues
mikelfried Jan 8, 2024
65760bc
Fix format
mikelfried Jan 8, 2024
3eb93ba
fix small typo
mikelfried Jan 8, 2024
5914529
lint
nsarrazin Jan 9, 2024
25d9600
fix default model
mishig25 Jan 9, 2024
fafecb7
Rename `maxSequenceLength` -> `chunkCharLength`
mishig25 Jan 9, 2024
ed3688c
format
mishig25 Jan 9, 2024
4ed0066
add "authorization" example
mishig25 Jan 9, 2024
669e86d
format
mishig25 Jan 9, 2024
12 changes: 12 additions & 0 deletions .env
@@ -46,6 +46,18 @@ CA_PATH=#
CLIENT_KEY_PASSWORD=#
REJECT_UNAUTHORIZED=true

TEXT_EMBEDDING_MODELS = `[
{
"name": "Xenova/gte-small",
"displayName": "Xenova/gte-small",
"description": "Local embedding model running on the server.",
"maxSequenceLength": 512,
"endpoints": [
{ "type": "transformersjs" }
]
}
]`

# 'name', 'userMessageToken', 'assistantMessageToken' are required
MODELS=`[
{
1 change: 0 additions & 1 deletion .env.template
@@ -231,7 +231,6 @@ TASK_MODEL='mistralai/Mistral-7B-Instruct-v0.2'
# "stop": ["</s>"]
# }}`


APP_BASE="/chat"
PUBLIC_ORIGIN=https://huggingface.co
PUBLIC_SHARE_PREFIX=https://hf.co/chat
84 changes: 80 additions & 4 deletions README.md
@@ -20,9 +20,10 @@ A chat interface using open source models, eg OpenAssistant or Llama. It is a Sv
1. [Setup](#setup)
2. [Launch](#launch)
3. [Web Search](#web-search)
4. [Extra parameters](#extra-parameters)
5. [Deploying to a HF Space](#deploying-to-a-hf-space)
6. [Building](#building)
4. [Text Embedding Models](#text-embedding-models)
5. [Extra parameters](#extra-parameters)
6. [Deploying to a HF Space](#deploying-to-a-hf-space)
7. [Building](#building)

## No Setup Deploy

@@ -78,10 +79,46 @@ Chat UI features a powerful Web Search feature. It works by:

1. Generating an appropriate search query from the user prompt.
2. Performing web search and extracting content from webpages.
3. Creating embeddings from texts using [transformers.js](https://huggingface.co/docs/transformers.js). Specifically, using [Xenova/gte-small](https://huggingface.co/Xenova/gte-small) model.
3. Creating embeddings from texts using a text embedding model.
4. Finding, from these embeddings, the ones closest to the user query using vector similarity search. Specifically, we use `inner product` distance (see the sketch below).
5. Getting the texts corresponding to those closest embeddings and performing [Retrieval-Augmented Generation](https://huggingface.co/papers/2005.11401) (i.e. expanding the user prompt with those texts so that an LLM can use this information).
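
For illustration, here is a minimal sketch of the ranking in step 4 (hypothetical helper names; it assumes embeddings are L2-normalized, so the inner product equals cosine similarity, and is not the actual chat-ui implementation):

```ts
type Embedding = number[];

// Inner product of two equal-length vectors.
function innerProduct(a: Embedding, b: Embedding): number {
	return a.reduce((sum, v, i) => sum + v * b[i], 0);
}

// Indices of the k passage embeddings closest to the query embedding.
function topK(query: Embedding, passages: Embedding[], k: number): number[] {
	return passages
		.map((p, i) => ({ i, score: innerProduct(query, p) }))
		.sort((a, b) => b.score - a.score)
		.slice(0, k)
		.map(({ i }) => i);
}
```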

## Text Embedding Models

By default (for backward compatibility), when the `TEXT_EMBEDDING_MODELS` environment variable is not defined, [transformers.js](https://huggingface.co/docs/transformers.js) is used for embedding tasks, specifically the [Xenova/gte-small](https://huggingface.co/Xenova/gte-small) model.

You can customize the embedding model by setting `TEXT_EMBEDDING_MODELS` in your `.env.local` file. For example:

```env
TEXT_EMBEDDING_MODELS = `[
{
"name": "Xenova/gte-small",
"displayName": "Xenova/gte-small",
"description": "locally running embedding",
"maxSequenceLength": 512,
"endpoints": [
{"type": "transformersjs"}
]
},
{
"name": "intfloat/e5-base-v2",
"displayName": "intfloat/e5-base-v2",
"description": "hosted embedding model",
"maxSequenceLength": 768,
"preQuery": "query: ", # See https://huggingface.co/intfloat/e5-base-v2#faq
"prePassage": "passage: ", # See https://huggingface.co/intfloat/e5-base-v2#faq
"endpoints": [
{"type": "tei", "url": "http://127.0.0.1:8080/"}
]
}
]`
```

The required fields are `name`, `maxSequenceLength` and `endpoints`. The `preQuery` and `prePassage` prefixes are model-specific; see the [intfloat/e5-base-v2 FAQ](https://huggingface.co/intfloat/e5-base-v2#faq).
Supported text embedding backends are [`transformers.js`](https://huggingface.co/docs/transformers.js) and [`TEI`](https://github.com/huggingface/text-embeddings-inference). `transformers.js` models run locally as part of `chat-ui`, whereas `TEI` models run in a separate environment and are accessed through an API endpoint.

When multiple embedding models are supplied in the `.env.local` file, the first one is used by default; the others are used only by LLMs whose `embeddingModel` field is set to that model's name.
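
As an illustration, the lookup amounts to something like this sketch (using the `embeddingModels` and `defaultEmbeddingModel` exports added in this PR; `embeddingModelName` is a hypothetical stand-in for the LLM's configured value):

```ts
import { embeddingModels, defaultEmbeddingModel } from "$lib/server/embeddingModels";

// Hypothetical: taken from the LLM's `embeddingModel` field.
const embeddingModelName = "intfloat/e5-base-v2";

// Fall back to the first configured model when none is set or matched.
const embeddingModel =
	embeddingModels.find((m) => m.name === embeddingModelName) ?? defaultEmbeddingModel;
```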

## Extra parameters

### OpenID connect
@@ -425,6 +462,45 @@ If you're using a certificate signed by a private CA, you will also need to add

If you're using a self-signed certificate, e.g. for testing or development purposes, you can set the `REJECT_UNAUTHORIZED` parameter to `false` in your `.env.local`. This will disable certificate validation, and allow Chat UI to connect to your custom endpoint.

#### Specific Embedding Model

A model can use any of the embedding models defined in `.env.local` (currently only used for web search).
By default it uses the first embedding model, but this can be changed with the `embeddingModel` field:

```env
TEXT_EMBEDDING_MODELS = `[
{
"name": "Xenova/gte-small",
"maxSequenceLength": 512,
"endpoints": [
{"type": "transformersjs"}
]
},
{
"name": "intfloat/e5-base-v2",
"maxSequenceLength": 768,
"endpoints": [
{"type": "tei", "url": "http://127.0.0.1:8080/", "authorization": "Basic VVNFUjpQQVNT"},
{"type": "tei", "url": "http://127.0.0.1:8081/"}
]
}
]`

MODELS=`[
{
"name": "Ollama Mistral",
"chatPromptTemplate": "...",
"embeddingModel": "intfloat/e5-base-v2"
"parameters": {
...
},
"endpoints": [
...
]
}
]`
```

## Deploying to a HF Space

Create a `DOTENV_LOCAL` secret in your HF Space with the content of your `.env.local`, and it will be picked up automatically when you run.
4 changes: 2 additions & 2 deletions src/lib/components/OpenWebSearchResults.svelte
@@ -30,8 +30,8 @@
{:else}
<CarbonCheckmark class="my-auto text-gray-500" />
{/if}
<span class="px-2 font-medium" class:text-red-700={error} class:dark:text-red-500={error}
>Web search
<span class="px-2 font-medium" class:text-red-700={error} class:dark:text-red-500={error}>
Web search
</span>
<div class="my-auto transition-all" class:rotate-90={detailsOpen}>
<CarbonCaretRight />
65 changes: 65 additions & 0 deletions src/lib/server/embeddingEndpoints/tei/embeddingEndpoints.ts
@@ -0,0 +1,65 @@
import { z } from "zod";
import type { EmbeddingEndpoint, Embedding } from "$lib/types/EmbeddingEndpoints";
import { chunk } from "$lib/utils/chunk";

export const embeddingEndpointTeiParametersSchema = z.object({
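// `weight`: relative probability of picking this endpoint when several are configured.
// `model`: the parsed embedding-model config is injected here (hence z.any()).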
weight: z.number().int().positive().default(1),
model: z.any(),
type: z.literal("tei"),
url: z.string().url(),
authorization: z.string().optional(),
});

const getModelInfoByUrl = async (url: string, authorization?: string) => {
const { origin } = new URL(url);

const response = await fetch(`${origin}/info`, {
headers: {
Accept: "application/json",
"Content-Type": "application/json",
...(authorization ? { Authorization: authorization } : {}),
},
});

const json = await response.json();
return json;
};

export async function embeddingEndpointTei(
input: z.input<typeof embeddingEndpointTeiParametersSchema>
): Promise<EmbeddingEndpoint> {
const { url, model, authorization } = embeddingEndpointTeiParametersSchema.parse(input);

const { max_client_batch_size, max_batch_tokens } = await getModelInfoByUrl(url, authorization);
const maxBatchSize = Math.min(
max_client_batch_size,
Math.floor(max_batch_tokens / model.maxSequenceLength)
);
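// Example with hypothetical numbers: max_client_batch_size = 32,
// max_batch_tokens = 16384, maxSequenceLength = 512
// => floor(16384 / 512) = 32, so maxBatchSize = min(32, 32) = 32.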

return async ({ inputs }) => {
const { origin } = new URL(url);

const batchesInputs = chunk(inputs, maxBatchSize);
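// e.g. chunk(["a", "b", "c"], 2) => [["a", "b"], ["c"]]; each batch holds at most maxBatchSize inputs.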

const batchesResults = await Promise.all(
batchesInputs.map(async (batchInputs) => {
const response = await fetch(`${origin}/embed`, {
method: "POST",
headers: {
Accept: "application/json",
"Content-Type": "application/json",
...(authorization ? { Authorization: authorization } : {}),
},
body: JSON.stringify({ inputs: batchInputs, normalize: true, truncate: true }),
});

const embeddings: Embedding[] = await response.json();
return embeddings;
})
);

const flatAllEmbeddings = batchesResults.flat();

return flatAllEmbeddings;
};
}
46 changes: 46 additions & 0 deletions src/lib/server/embeddingEndpoints/transformersjs/embeddingEndpoints.ts
@@ -0,0 +1,46 @@
import { z } from "zod";
import type { EmbeddingEndpoint } from "$lib/types/EmbeddingEndpoints";
import type { Tensor, Pipeline } from "@xenova/transformers";
import { pipeline } from "@xenova/transformers";

export const embeddingEndpointTransformersJSParametersSchema = z.object({
weight: z.number().int().positive().default(1),
model: z.any(),
type: z.literal("transformersjs"),
});

// Use the Singleton pattern to enable lazy construction of the pipeline.
class TransformersJSModelsSingleton {
static instances: Array<[string, Promise<Pipeline>]> = [];

static async getInstance(modelName: string): Promise<Pipeline> {
const modelPipelineInstance = this.instances.find(([name]) => name === modelName);

if (modelPipelineInstance) {
const [, modelPipeline] = modelPipelineInstance;
return modelPipeline;
}

const newModelPipeline = pipeline("feature-extraction", modelName);
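// Cache the promise itself, so concurrent callers share a single model load.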
this.instances.push([modelName, newModelPipeline]);

return newModelPipeline;
}
}

export async function calculateEmbedding(modelName: string, inputs: string[]) {
const extractor = await TransformersJSModelsSingleton.getInstance(modelName);
const output: Tensor = await extractor(inputs, { pooling: "mean", normalize: true });
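// Mean-pools token embeddings into one vector per input and L2-normalizes it,
// so inner products downstream equal cosine similarity.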

return output.tolist();
}

export function embeddingEndpointTransformersJS(
input: z.input<typeof embeddingEndpointTransformersJSParametersSchema>
): EmbeddingEndpoint {
const { model } = embeddingEndpointTransformersJSParametersSchema.parse(input);

return async ({ inputs }) => {
return calculateEmbedding(model.name, inputs);
};
}
99 changes: 99 additions & 0 deletions src/lib/server/embeddingModels.ts
@@ -0,0 +1,99 @@
import { TEXT_EMBEDDING_MODELS } from "$env/static/private";

import { z } from "zod";
import { sum } from "$lib/utils/sum";
import {
embeddingEndpoints,
embeddingEndpointSchema,
type EmbeddingEndpoint,
} from "$lib/types/EmbeddingEndpoints";
import { embeddingEndpointTransformersJS } from "$lib/server/embeddingEndpoints/transformersjs/embeddingEndpoints";

const modelConfig = z.object({
/** Used as an identifier in DB */
id: z.string().optional(),
/** Used to link to the model page, and for inference */
name: z.string().min(1),
displayName: z.string().min(1).optional(),
description: z.string().min(1).optional(),
websiteUrl: z.string().url().optional(),
modelUrl: z.string().url().optional(),
endpoints: z.array(embeddingEndpointSchema).nonempty(),
maxSequenceLength: z.number().positive(),
preQuery: z.string().default(""),
prePassage: z.string().default(""),
});

// Default embedding model for backward compatibility
const rawEmbeddingModelJSON =
TEXT_EMBEDDING_MODELS ||
`[
{
"name": "Xenova/gte-small",
"maxSequenceLength": 512,
"endpoints": [
{ "type": "transformersjs" }
]
}
]`;

const embeddingModelsRaw = z.array(modelConfig).parse(JSON.parse(rawEmbeddingModelJSON));

const processEmbeddingModel = async (m: z.infer<typeof modelConfig>) => ({
...m,
id: m.id || m.name,
});

const addEndpoint = (m: Awaited<ReturnType<typeof processEmbeddingModel>>) => ({
...m,
getEndpoint: async (): Promise<EmbeddingEndpoint> => {
if (!m.endpoints) {
return embeddingEndpointTransformersJS({
type: "transformersjs",
weight: 1,
model: m,
});
}

const totalWeight = sum(m.endpoints.map((e) => e.weight));

let random = Math.random() * totalWeight;
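// Weighted random selection: each endpoint is picked with probability
// endpoint.weight / totalWeight.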

for (const endpoint of m.endpoints) {
if (random < endpoint.weight) {
const args = { ...endpoint, model: m };

switch (args.type) {
case "tei":
return embeddingEndpoints.tei(args);
case "transformersjs":
return embeddingEndpoints.transformersjs(args);
}
}

random -= endpoint.weight;
}

throw new Error(`Failed to select embedding endpoint`);
},
});

export const embeddingModels = await Promise.all(
embeddingModelsRaw.map((e) => processEmbeddingModel(e).then(addEndpoint))
);

export const defaultEmbeddingModel = embeddingModels[0];

const validateEmbeddingModel = (_models: EmbeddingBackendModel[], key: "id" | "name") => {
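// z.enum requires a non-empty tuple, hence [first, ...rest] rather than a plain map().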
return z.enum([_models[0][key], ..._models.slice(1).map((m) => m[key])]);
};

export const validateEmbeddingModelById = (_models: EmbeddingBackendModel[]) => {
return validateEmbeddingModel(_models, "id");
};

export const validateEmbeddingModelByName = (_models: EmbeddingBackendModel[]) => {
return validateEmbeddingModel(_models, "name");
};

export type EmbeddingBackendModel = typeof defaultEmbeddingModel;
2 changes: 2 additions & 0 deletions src/lib/server/models.ts
@@ -12,6 +12,7 @@ import { z } from "zod";
import endpoints, { endpointSchema, type Endpoint } from "./endpoints/endpoints";
import endpointTgi from "./endpoints/tgi/endpointTgi";
import { sum } from "$lib/utils/sum";
import { embeddingModels, validateEmbeddingModelByName } from "./embeddingModels";

type Optional<T, K extends keyof T> = Pick<Partial<T>, K> & Omit<T, K>;

@@ -66,6 +67,7 @@ const modelConfig = z.object({
.optional(),
multimodal: z.boolean().default(false),
unlisted: z.boolean().default(false),
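// Must match the `name` of one of the configured TEXT_EMBEDDING_MODELS.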
embeddingModel: validateEmbeddingModelByName(embeddingModels).optional(),
});

const modelsRaw = z.array(modelConfig).parse(JSON.parse(MODELS));