
Commit 41e78c5

server : add support for embd_normalize parameter (ggml-org#14964)
This commit adds support for the `embd_normalize` parameter in the server code. The motivation is that currently, if the server is started with a pooling type other than `none`, Euclidean/L2 normalization is always applied to embeddings. This is not always the desired behavior; users may want a different normalization (or none at all), and this commit allows that.

Example usage:

```console
curl --request POST \
    --url http://localhost:8080/embedding \
    --header "Content-Type: application/json" \
    --data '{"input": "Hello world today", "embd_normalize": -1}'
```
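As a rough sanity check of which normalization a response actually carries, a client can recompute a few norms over the returned vector. The following is a minimal C++ sketch under that assumption (the sample values are placeholders, not real model output): with the default `embd_normalize` of `2` the L2 norm should come out close to 1, while with `-1` the raw magnitudes are preserved.

```cpp
// Minimal sketch: recompute norms of an embedding vector to see whether it
// was L2-normalized by the server. The sample values below are hypothetical.
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

int main() {
    std::vector<float> embd = {0.12f, -0.48f, 0.64f, 0.36f, -0.24f}; // placeholder data

    double l1 = 0.0, l2 = 0.0, max_abs = 0.0;
    for (float v : embd) {
        l1 += std::fabs(v);
        l2 += static_cast<double>(v) * v;
        max_abs = std::max(max_abs, static_cast<double>(std::fabs(v)));
    }
    l2 = std::sqrt(l2);

    // For embd_normalize = 2 (the default), l2 should be ~1.0.
    // For embd_normalize = -1, the vector is returned unscaled.
    std::printf("L1 = %.4f, L2 = %.4f, max|x| = %.4f\n", l1, l2, max_abs);
    return 0;
}
```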
Parent: ad4a700

File tree: 2 files changed, +22 -1 lines changed


tools/server/README.md

Lines changed: 9 additions & 0 deletions
````diff
@@ -644,6 +644,15 @@ The same as [the embedding example](../embedding) does.
 
 `image_data`: An array of objects to hold base64-encoded image `data` and its `id`s to be reference in `content`. You can determine the place of the image in the content as in the following: `Image: [img-21].\nCaption: This is a picture of a house`. In this case, `[img-21]` will be replaced by the embeddings of the image with id `21` in the following `image_data` array: `{..., "image_data": [{"data": "<BASE64_STRING>", "id": 21}]}`. Use `image_data` only with multimodal models, e.g., LLaVA.
 
+`embd_normalize`: Normalization for pooled embeddings. Can be one of the following values:
+```
+-1: No normalization
+0: Max absolute
+1: Taxicab
+2: Euclidean/L2
+>2: P-Norm
+```
+
 ### POST `/reranking`: Rerank documents according to a given query
 
 Similar to https://jina.ai/reranker/ but might change in the future.
````
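To make the list of values above concrete, here is a short, illustrative C++ sketch of the arithmetic each `embd_normalize` value selects. It is not the server's actual `common_embd_normalize` implementation; in particular, the function name is made up and the int16 scale factor used for mode 0 is an assumption based on the "max absolute int16" comment in the server.cpp diff below.

```cpp
// Illustrative sketch of the normalization modes listed in the README.
// Not the actual common_embd_normalize code; the int16 scale factor for
// mode 0 is an assumption.
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

std::vector<float> normalize_sketch(const std::vector<float> & inp, int embd_norm) {
    double sum = 0.0;
    if (embd_norm < 0) {
        sum = 1.0;                               // -1: no normalization
    } else if (embd_norm == 0) {
        for (float v : inp) sum = std::max(sum, (double) std::fabs(v));
        sum /= 32760.0;                          //  0: max absolute, scaled to an int16-like range (assumed factor)
    } else if (embd_norm == 1) {
        for (float v : inp) sum += std::fabs(v); //  1: taxicab (L1)
    } else if (embd_norm == 2) {
        for (float v : inp) sum += (double) v * v;
        sum = std::sqrt(sum);                    //  2: Euclidean (L2)
    } else {
        for (float v : inp) sum += std::pow(std::fabs(v), embd_norm);
        sum = std::pow(sum, 1.0 / embd_norm);    // >2: p-norm with p = embd_norm
    }

    const double scale = sum > 0.0 ? 1.0 / sum : 0.0;
    std::vector<float> out(inp.size());
    for (size_t i = 0; i < inp.size(); ++i) {
        out[i] = (float) (inp[i] * scale);
    }
    return out;
}
```

Note that mode 0 rescales magnitudes rather than producing unit-norm vectors, which matches the "max absolute int16" wording in the `slot_params` comment below.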

tools/server/server.cpp

Lines changed: 13 additions & 1 deletion
```diff
@@ -138,6 +138,9 @@ struct slot_params {
     std::string oaicompat_cmpl_id;
     common_chat_syntax oaicompat_chat_syntax;
 
+    // Embeddings
+    int32_t embd_normalize = 2; // (-1=none, 0=max absolute int16, 1=taxicab, 2=Euclidean/L2, >2=p-norm)
+
     json to_json() const {
         std::vector<std::string> samplers;
         samplers.reserve(sampling.samplers.size());
@@ -2601,7 +2604,7 @@ struct server_context {
 
             // normalize only when there is pooling
             if (llama_pooling_type(slot.ctx) != LLAMA_POOLING_TYPE_NONE) {
-                common_embd_normalize(embd, embd_res.data(), n_embd, 2);
+                common_embd_normalize(embd, embd_res.data(), n_embd, slot.params.embd_normalize);
                 res->embedding.push_back(embd_res);
                 break;
             } else {
@@ -4614,6 +4617,14 @@ int main(int argc, char ** argv) {
             }
         }
 
+        int embd_normalize = 2; // default to Euclidean/L2 norm
+        if (body.count("embd_normalize") != 0) {
+            embd_normalize = body.at("embd_normalize");
+            if (llama_pooling_type(ctx_server.ctx) == LLAMA_POOLING_TYPE_NONE) {
+                SRV_DBG("embd_normalize is not supported by pooling type %d, ignoring it\n", llama_pooling_type(ctx_server.ctx));
+            }
+        }
+
         // create and queue the task
         json responses = json::array();
         bool error = false;
@@ -4629,6 +4640,7 @@
 
             // OAI-compat
             task.params.oaicompat = oaicompat;
+            task.params.embd_normalize = embd_normalize;
 
             tasks.push_back(std::move(task));
         }
```
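The request handling above reads the optional field with the `count()`/`at()` pattern. Assuming `body` is an nlohmann::json object (as it is elsewhere in server.cpp), a compact equivalent would look like the sketch below; the helper name is hypothetical.

```cpp
// Sketch only: a compact equivalent of the count()/at() pattern in the diff,
// assuming body is an nlohmann::json object. read_embd_normalize is a
// hypothetical helper, not part of server.cpp.
#include <cstdint>
#include <nlohmann/json.hpp>

using json = nlohmann::json;

static int32_t read_embd_normalize(const json & body) {
    // json::value() returns the default when the key is absent.
    return body.value("embd_normalize", 2); // 2 = Euclidean/L2, the server default
}
```

The diff keeps the explicit check so it can also emit the `SRV_DBG` notice when the pooling type is `none`; in that case the value is ignored anyway, since the normalization call in `server_context` only runs when pooling is enabled.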
