Add embedding models configurable, from both transformers.js and TEI (#646)

* Add embedding models configurable, from both Xenova and TEI

* fix lint and format

* Fix bug in sentenceSimilarity

* Batches for TEI using /info route

* Fix web search disappearing when finished searching

* Fix lint and format

* Add more options for better embedding model usage

* Fixing CR issues

* Fix websearch disappearing in later PR

* Fix lint

* Fix more minor code CR

* Validate embeddingModelName field in model config

* Add embeddingModel into shared conversation

* Fix lint and format

* Add default embedding model, and more readme explanation

* Fix minor embedding model readme details

* Update settings.json

* Update README.md

Co-authored-by: Mishig <[email protected]>

* Update README.md

Co-authored-by: Mishig <[email protected]>

* Apply suggestions from code review

Co-authored-by: Mishig <[email protected]>

* Resolved more issues

* lint

* Fix more issues

* Fix format

* fix small typo

* lint

* fix default model

* Rename `maxSequenceLength` -> `chunkCharLength`

* format

* add "authorization" example

* format

---------

Co-authored-by: Mishig <[email protected]>
Co-authored-by: Nathan Sarrazin <[email protected]>
Co-authored-by: Mishig Davaadorj <[email protected]>
4 people committed Jan 9, 2024
1 parent 69c0464 commit 3a01622
Showing 18 changed files with 419 additions and 66 deletions.
12 changes: 12 additions & 0 deletions .env
@@ -46,6 +46,18 @@ CA_PATH=#
CLIENT_KEY_PASSWORD=#
REJECT_UNAUTHORIZED=true

TEXT_EMBEDDING_MODELS = `[
  {
    "name": "Xenova/gte-small",
    "displayName": "Xenova/gte-small",
    "description": "Local embedding model running on the server.",
    "chunkCharLength": 512,
    "endpoints": [
      { "type": "transformersjs" }
    ]
  }
]`

# 'name', 'userMessageToken', 'assistantMessageToken' are required
MODELS=`[
{
1 change: 0 additions & 1 deletion .env.template
@@ -204,7 +204,6 @@ TASK_MODEL='mistralai/Mistral-7B-Instruct-v0.2'
# "stop": ["</s>"]
# }}`


APP_BASE="/chat"
PUBLIC_ORIGIN=https://huggingface.co
PUBLIC_SHARE_PREFIX=https://hf.co/chat
88 changes: 84 additions & 4 deletions README.md
@@ -20,9 +20,10 @@ A chat interface using open source models, eg OpenAssistant or Llama. It is a Sv
1. [Setup](#setup)
2. [Launch](#launch)
3. [Web Search](#web-search)
-4. [Extra parameters](#extra-parameters)
-5. [Deploying to a HF Space](#deploying-to-a-hf-space)
-6. [Building](#building)
+4. [Text Embedding Models](#text-embedding-models)
+5. [Extra parameters](#extra-parameters)
+6. [Deploying to a HF Space](#deploying-to-a-hf-space)
+7. [Building](#building)

## No Setup Deploy

@@ -78,10 +79,50 @@ Chat UI features a powerful Web Search feature. It works by:

1. Generating an appropriate search query from the user prompt.
2. Performing web search and extracting content from webpages.
-3. Creating embeddings from texts using [transformers.js](https://huggingface.co/docs/transformers.js). Specifically, using [Xenova/gte-small](https://huggingface.co/Xenova/gte-small) model.
+3. Creating embeddings from texts using a text embedding model.
4. Finding, from these embeddings, the ones closest to the user query using a vector similarity search (specifically, `inner product` distance; see the sketch after this list).
5. Getting the texts corresponding to those closest embeddings and performing [Retrieval-Augmented Generation](https://huggingface.co/papers/2005.11401) (i.e. expanding the user prompt by adding those texts so that an LLM can use this information).
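
An illustrative sketch of steps 3-5 (not code from this commit; the helper names are hypothetical, and with normalized embeddings the inner product equals cosine similarity):

```ts
type Embedding = number[];

// Inner product of two vectors; equals cosine similarity when both are
// normalized, which both embedding backends below request.
function innerProduct(a: Embedding, b: Embedding): number {
  return a.reduce((sum, value, i) => sum + value * b[i], 0);
}

// Hypothetical helper: return the k passages scoring highest against the query.
function topKPassages(
  queryEmbedding: Embedding,
  passageEmbeddings: Embedding[],
  passages: string[],
  k: number
): string[] {
  return passageEmbeddings
    .map((embedding, i) => ({ score: innerProduct(queryEmbedding, embedding), text: passages[i] }))
    .sort((a, b) => b.score - a.score)
    .slice(0, k)
    .map((entry) => entry.text);
}
```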

## Text Embedding Models

By default (for backward compatibility), when the `TEXT_EMBEDDING_MODELS` environment variable is not defined, a [transformers.js](https://huggingface.co/docs/transformers.js) embedding model is used for embedding tasks; specifically, the [Xenova/gte-small](https://huggingface.co/Xenova/gte-small) model.

You can customize the embedding model by setting `TEXT_EMBEDDING_MODELS` in your `.env.local` file. For example:

```env
TEXT_EMBEDDING_MODELS = `[
  {
    "name": "Xenova/gte-small",
    "displayName": "Xenova/gte-small",
    "description": "locally running embedding",
    "chunkCharLength": 512,
    "endpoints": [
      { "type": "transformersjs" }
    ]
  },
  {
    "name": "intfloat/e5-base-v2",
    "displayName": "intfloat/e5-base-v2",
    "description": "hosted embedding model",
    "chunkCharLength": 768,
    "preQuery": "query: ", // See https://huggingface.co/intfloat/e5-base-v2#faq
    "prePassage": "passage: ", // See https://huggingface.co/intfloat/e5-base-v2#faq
    "endpoints": [
      {
        "type": "tei",
        "url": "http://127.0.0.1:8080/",
        "authorization": "TOKEN_TYPE TOKEN" // optional authorization field. Example: "Basic VVNFUjpQQVNT"
      }
    ]
  }
]`
```

The required fields are `name`, `chunkCharLength`, and `endpoints`.
Supported text embedding backends are [`transformers.js`](https://huggingface.co/docs/transformers.js) and [`TEI`](https://github.com/huggingface/text-embeddings-inference). `transformers.js` models run locally as part of `chat-ui`, whereas `TEI` models run in a separate environment and are accessed through an API endpoint.

When more than one embedding model is supplied in the `.env.local` file, the first is used by default, and the others are only used by LLMs whose `embeddingModel` field is set to that model's name.

## Extra parameters

### OpenID connect
@@ -425,6 +466,45 @@ If you're using a certificate signed by a private CA, you will also need to add

If you're using a self-signed certificate, e.g. for testing or development purposes, you can set the `REJECT_UNAUTHORIZED` parameter to `false` in your `.env.local`. This will disable certificate validation, and allow Chat UI to connect to your custom endpoint.

#### Specific Embedding Model

A model can use any of the embedding models defined in `.env.local` (currently used when web searching). By default it will use the first embedding model, but this can be changed with the `embeddingModel` field:

```env
TEXT_EMBEDDING_MODELS = `[
  {
    "name": "Xenova/gte-small",
    "chunkCharLength": 512,
    "endpoints": [
      { "type": "transformersjs" }
    ]
  },
  {
    "name": "intfloat/e5-base-v2",
    "chunkCharLength": 768,
    "endpoints": [
      { "type": "tei", "url": "http://127.0.0.1:8080/", "authorization": "Basic VVNFUjpQQVNT" },
      { "type": "tei", "url": "http://127.0.0.1:8081/" }
    ]
  }
]`

MODELS=`[
  {
    "name": "Ollama Mistral",
    "chatPromptTemplate": "...",
    "embeddingModel": "intfloat/e5-base-v2",
    "parameters": {
      ...
    },
    "endpoints": [
      ...
    ]
  }
]`
```

## Deploying to a HF Space

Create a `DOTENV_LOCAL` secret in your HF Space with the content of your `.env.local`, and it will be picked up automatically when you run.
4 changes: 2 additions & 2 deletions src/lib/components/OpenWebSearchResults.svelte
@@ -30,8 +30,8 @@
  {:else}
    <CarbonCheckmark class="my-auto text-gray-500" />
  {/if}
-  <span class="px-2 font-medium" class:text-red-700={error} class:dark:text-red-500={error}
-    >Web search
+  <span class="px-2 font-medium" class:text-red-700={error} class:dark:text-red-500={error}>
+    Web search
  </span>
  <div class="my-auto transition-all" class:rotate-90={detailsOpen}>
    <CarbonCaretRight />
65 changes: 65 additions & 0 deletions src/lib/server/embeddingEndpoints/tei/embeddingEndpoints.ts
@@ -0,0 +1,65 @@
import { z } from "zod";
import type { EmbeddingEndpoint, Embedding } from "$lib/types/EmbeddingEndpoints";
import { chunk } from "$lib/utils/chunk";

export const embeddingEndpointTeiParametersSchema = z.object({
  weight: z.number().int().positive().default(1),
  model: z.any(),
  type: z.literal("tei"),
  url: z.string().url(),
  authorization: z.string().optional(),
});

// Query the TEI server's /info route for its batching limits.
const getModelInfoByUrl = async (url: string, authorization?: string) => {
  const { origin } = new URL(url);

  const response = await fetch(`${origin}/info`, {
    headers: {
      Accept: "application/json",
      "Content-Type": "application/json",
      ...(authorization ? { Authorization: authorization } : {}),
    },
  });

  const json = await response.json();
  return json;
};

export async function embeddingEndpointTei(
  input: z.input<typeof embeddingEndpointTeiParametersSchema>
): Promise<EmbeddingEndpoint> {
  const { url, model, authorization } = embeddingEndpointTeiParametersSchema.parse(input);

  // Pass the authorization header here too, so a protected /info route works.
  const { max_client_batch_size, max_batch_tokens } = await getModelInfoByUrl(url, authorization);
  // Cap each batch by the server's client limit and by its token budget,
  // using chunkCharLength as an upper bound on tokens per input.
  const maxBatchSize = Math.min(
    max_client_batch_size,
    Math.floor(max_batch_tokens / model.chunkCharLength)
  );

  return async ({ inputs }) => {
    const { origin } = new URL(url);

    const batchesInputs = chunk(inputs, maxBatchSize);

    const batchesResults = await Promise.all(
      batchesInputs.map(async (batchInputs) => {
        const response = await fetch(`${origin}/embed`, {
          method: "POST",
          headers: {
            Accept: "application/json",
            "Content-Type": "application/json",
            ...(authorization ? { Authorization: authorization } : {}),
          },
          body: JSON.stringify({ inputs: batchInputs, normalize: true, truncate: true }),
        });

        const embeddings: Embedding[] = await response.json();
        return embeddings;
      })
    );

    const flatAllEmbeddings = batchesResults.flat();

    return flatAllEmbeddings;
  };
}
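
A hypothetical usage sketch for this endpoint (illustrative, not part of the commit; the URL, token, and model values are placeholders, and `weight` falls back to its schema default):

```ts
const teiEndpoint = await embeddingEndpointTei({
  type: "tei",
  url: "http://127.0.0.1:8080/",
  authorization: "Basic VVNFUjpQQVNT", // optional
  model: { name: "intfloat/e5-base-v2", chunkCharLength: 768 },
});

// Inputs beyond the computed maxBatchSize are split into parallel /embed calls.
const embeddings = await teiEndpoint({ inputs: ["query: how are batches sized?"] });
console.log(embeddings.length); // one embedding per input
```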
46 changes: 46 additions & 0 deletions src/lib/server/embeddingEndpoints/transformersjs/embeddingEndpoints.ts
@@ -0,0 +1,46 @@
import { z } from "zod";
import type { EmbeddingEndpoint } from "$lib/types/EmbeddingEndpoints";
import type { Tensor, Pipeline } from "@xenova/transformers";
import { pipeline } from "@xenova/transformers";

export const embeddingEndpointTransformersJSParametersSchema = z.object({
  weight: z.number().int().positive().default(1),
  model: z.any(),
  type: z.literal("transformersjs"),
});

// Use the Singleton pattern to enable lazy construction of the pipeline.
class TransformersJSModelsSingleton {
  static instances: Array<[string, Promise<Pipeline>]> = [];

  static async getInstance(modelName: string): Promise<Pipeline> {
    const modelPipelineInstance = this.instances.find(([name]) => name === modelName);

    if (modelPipelineInstance) {
      const [, modelPipeline] = modelPipelineInstance;
      return modelPipeline;
    }

    const newModelPipeline = pipeline("feature-extraction", modelName);
    this.instances.push([modelName, newModelPipeline]);

    return newModelPipeline;
  }
}

export async function calculateEmbedding(modelName: string, inputs: string[]) {
  const extractor = await TransformersJSModelsSingleton.getInstance(modelName);
  const output: Tensor = await extractor(inputs, { pooling: "mean", normalize: true });

  return output.tolist();
}

export function embeddingEndpointTransformersJS(
  input: z.input<typeof embeddingEndpointTransformersJSParametersSchema>
): EmbeddingEndpoint {
  const { model } = embeddingEndpointTransformersJSParametersSchema.parse(input);

  return async ({ inputs }) => {
    return calculateEmbedding(model.name, inputs);
  };
}
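
A matching sketch for the local backend (again illustrative; transformers.js downloads the model weights on first use, after which the singleton reuses the pipeline):

```ts
const localEndpoint = embeddingEndpointTransformersJS({
  type: "transformersjs",
  model: { name: "Xenova/gte-small", chunkCharLength: 512 },
});

const vectors = await localEndpoint({ inputs: ["passage: chat-ui embeds web search results."] });
console.log(vectors[0].length); // embedding dimension (384 for gte-small)
```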
99 changes: 99 additions & 0 deletions src/lib/server/embeddingModels.ts
@@ -0,0 +1,99 @@
import { TEXT_EMBEDDING_MODELS } from "$env/static/private";

import { z } from "zod";
import { sum } from "$lib/utils/sum";
import {
  embeddingEndpoints,
  embeddingEndpointSchema,
  type EmbeddingEndpoint,
} from "$lib/types/EmbeddingEndpoints";
import { embeddingEndpointTransformersJS } from "$lib/server/embeddingEndpoints/transformersjs/embeddingEndpoints";

const modelConfig = z.object({
  /** Used as an identifier in DB */
  id: z.string().optional(),
  /** Used to link to the model page, and for inference */
  name: z.string().min(1),
  displayName: z.string().min(1).optional(),
  description: z.string().min(1).optional(),
  websiteUrl: z.string().url().optional(),
  modelUrl: z.string().url().optional(),
  endpoints: z.array(embeddingEndpointSchema).nonempty(),
  chunkCharLength: z.number().positive(),
  preQuery: z.string().default(""),
  prePassage: z.string().default(""),
});

// Default embedding model for backward compatibility
const rawEmbeddingModelJSON =
  TEXT_EMBEDDING_MODELS ||
  `[
  {
    "name": "Xenova/gte-small",
    "chunkCharLength": 512,
    "endpoints": [
      { "type": "transformersjs" }
    ]
  }
]`;

const embeddingModelsRaw = z.array(modelConfig).parse(JSON.parse(rawEmbeddingModelJSON));

const processEmbeddingModel = async (m: z.infer<typeof modelConfig>) => ({
  ...m,
  id: m.id || m.name,
});

const addEndpoint = (m: Awaited<ReturnType<typeof processEmbeddingModel>>) => ({
  ...m,
  getEndpoint: async (): Promise<EmbeddingEndpoint> => {
    if (!m.endpoints) {
      return embeddingEndpointTransformersJS({
        type: "transformersjs",
        weight: 1,
        model: m,
      });
    }

    // Weighted random selection across the configured endpoints.
    const totalWeight = sum(m.endpoints.map((e) => e.weight));

    let random = Math.random() * totalWeight;

    for (const endpoint of m.endpoints) {
      if (random < endpoint.weight) {
        const args = { ...endpoint, model: m };

        switch (args.type) {
          case "tei":
            return embeddingEndpoints.tei(args);
          case "transformersjs":
            return embeddingEndpoints.transformersjs(args);
        }
      }

      random -= endpoint.weight;
    }

    throw new Error(`Failed to select embedding endpoint`);
  },
});

export const embeddingModels = await Promise.all(
  embeddingModelsRaw.map((e) => processEmbeddingModel(e).then(addEndpoint))
);

export const defaultEmbeddingModel = embeddingModels[0];

const validateEmbeddingModel = (_models: EmbeddingBackendModel[], key: "id" | "name") => {
  return z.enum([_models[0][key], ..._models.slice(1).map((m) => m[key])]);
};

export const validateEmbeddingModelById = (_models: EmbeddingBackendModel[]) => {
  return validateEmbeddingModel(_models, "id");
};

export const validateEmbeddingModelByName = (_models: EmbeddingBackendModel[]) => {
  return validateEmbeddingModel(_models, "name");
};

export type EmbeddingBackendModel = typeof defaultEmbeddingModel;
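
A sketch of how these exports might be consumed elsewhere (illustrative; the query text is a placeholder):

```ts
import { defaultEmbeddingModel } from "$lib/server/embeddingModels";

// Resolve one endpoint by weighted random choice, then embed a prefixed query.
const getEmbeddings = await defaultEmbeddingModel.getEndpoint();
const [queryEmbedding] = await getEmbeddings({
  inputs: [`${defaultEmbeddingModel.preQuery}what is chat-ui?`],
});
```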
2 changes: 2 additions & 0 deletions src/lib/server/models.ts
@@ -12,6 +12,7 @@ import { z } from "zod";
import endpoints, { endpointSchema, type Endpoint } from "./endpoints/endpoints";
import endpointTgi from "./endpoints/tgi/endpointTgi";
import { sum } from "$lib/utils/sum";
import { embeddingModels, validateEmbeddingModelByName } from "./embeddingModels";

import JSON5 from "json5";

@@ -68,6 +69,7 @@ const modelConfig = z.object({
    .optional(),
  multimodal: z.boolean().default(false),
  unlisted: z.boolean().default(false),
  embeddingModel: validateEmbeddingModelByName(embeddingModels).optional(),
});

const modelsRaw = z.array(modelConfig).parse(JSON5.parse(MODELS));
