refactor: HF API Embedders - use `InferenceClient.feature_extraction` instead of `InferenceClient.post` (#8794)

* HF API Embedders: refactoring

* rename variables

* rm leftovers

* rm pin

* rm unused import

* relnote

* warning with truncate/normalize and serverless inference API

* test that warnings are raised
anakin87 authored Feb 3, 2025
1 parent f165212 commit 877f826
Showing 6 changed files with 139 additions and 47 deletions.
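
In short: the embedders previously issued a raw `InferenceClient.post` request and decoded the JSON payload by hand; they now call the typed `InferenceClient.feature_extraction` helper, which returns a numpy array. A minimal before/after sketch (the model name and token are placeholders, not part of this commit):

```python
from huggingface_hub import InferenceClient

client = InferenceClient(model="BAAI/bge-small-en-v1.5", token="hf_...")

# Before: low-level POST plus manual JSON decoding (removed by this commit)
# response = client.post(json={"inputs": ["hello"], "truncate": True}, task="feature-extraction")
# embeddings = json.loads(response.decode())

# After: typed helper that returns a numpy.ndarray
np_embeddings = client.feature_extraction(text="hello")
print(np_embeddings.shape)  # e.g. (1, 384) or (384,), depending on the backend
```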
36 changes: 26 additions & 10 deletions haystack/components/embedders/hugging_face_api_document_embedder.py
@@ -2,7 +2,7 @@
 #
 # SPDX-License-Identifier: Apache-2.0
 
-import json
+import warnings
 from typing import Any, Dict, List, Optional, Union
 
 from tqdm import tqdm
@@ -96,8 +96,8 @@ def __init__(
         token: Optional[Secret] = Secret.from_env_var(["HF_API_TOKEN", "HF_TOKEN"], strict=False),
         prefix: str = "",
         suffix: str = "",
-        truncate: bool = True,
-        normalize: bool = False,
+        truncate: Optional[bool] = True,
+        normalize: Optional[bool] = False,
         batch_size: int = 32,
         progress_bar: bool = True,
         meta_fields_to_embed: Optional[List[str]] = None,
@@ -124,13 +124,11 @@ def __init__(
             Applicable when `api_type` is `TEXT_EMBEDDINGS_INFERENCE`, or `INFERENCE_ENDPOINTS`
             if the backend uses Text Embeddings Inference.
             If `api_type` is `SERVERLESS_INFERENCE_API`, this parameter is ignored.
-            It is always set to `True` and cannot be changed.
         :param normalize:
             Normalizes the embeddings to unit length.
             Applicable when `api_type` is `TEXT_EMBEDDINGS_INFERENCE`, or `INFERENCE_ENDPOINTS`
             if the backend uses Text Embeddings Inference.
             If `api_type` is `SERVERLESS_INFERENCE_API`, this parameter is ignored.
-            It is always set to `False` and cannot be changed.
         :param batch_size:
             Number of documents to process at once.
         :param progress_bar:
@@ -239,18 +237,36 @@ def _embed_batch(self, texts_to_embed: List[str], batch_size: int) -> List[List[
         """
         Embed a list of texts in batches.
         """
+        truncate = self.truncate
+        normalize = self.normalize
+
+        if self.api_type == HFEmbeddingAPIType.SERVERLESS_INFERENCE_API:
+            if truncate is not None:
+                msg = "`truncate` parameter is not supported for Serverless Inference API. It will be ignored."
+                warnings.warn(msg)
+                truncate = None
+            if normalize is not None:
+                msg = "`normalize` parameter is not supported for Serverless Inference API. It will be ignored."
+                warnings.warn(msg)
+                normalize = None
+
         all_embeddings = []
         for i in tqdm(
             range(0, len(texts_to_embed), batch_size), disable=not self.progress_bar, desc="Calculating embeddings"
         ):
             batch = texts_to_embed[i : i + batch_size]
-            response = self._client.post(
-                json={"inputs": batch, "truncate": self.truncate, "normalize": self.normalize},
-                task="feature-extraction",
+
+            np_embeddings = self._client.feature_extraction(
+                # this method does not officially support list of strings, but works as expected
+                text=batch,  # type: ignore[arg-type]
+                truncate=truncate,
+                normalize=normalize,
             )
-            embeddings = json.loads(response.decode())
-            all_embeddings.extend(embeddings)
+
+            if np_embeddings.ndim != 2 or np_embeddings.shape[0] != len(batch):
+                raise ValueError(f"Expected embedding shape ({batch_size}, embedding_dim), got {np_embeddings.shape}")
+
+            all_embeddings.extend(np_embeddings.tolist())
 
         return all_embeddings

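The upshot of the `_embed_batch` changes above: `feature_extraction` returns a `numpy.ndarray` instead of raw JSON bytes, so the manual decoding step disappears and a shape check guards against unexpected responses. A standalone sketch of that check (the helper name is illustrative; the component inlines this logic):

```python
from typing import List

import numpy as np


def validate_batch_embeddings(np_embeddings: np.ndarray, batch: List[str]) -> List[List[float]]:
    # Mirrors the guard added in _embed_batch: a 2-D array with one embedding row per input text
    if np_embeddings.ndim != 2 or np_embeddings.shape[0] != len(batch):
        raise ValueError(f"Expected embedding shape ({len(batch)}, embedding_dim), got {np_embeddings.shape}")
    return np_embeddings.tolist()


# A (2, 384) array for a 2-text batch passes; a 1-D array would raise ValueError
print(len(validate_batch_embeddings(np.zeros((2, 384)), ["text 1", "text 2"])))  # 2
```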
35 changes: 25 additions & 10 deletions haystack/components/embedders/hugging_face_api_text_embedder.py
@@ -2,7 +2,7 @@
 #
 # SPDX-License-Identifier: Apache-2.0
 
-import json
+import warnings
 from typing import Any, Dict, List, Optional, Union
 
 from haystack import component, default_from_dict, default_to_dict, logging
@@ -80,8 +80,8 @@ def __init__(
         token: Optional[Secret] = Secret.from_env_var(["HF_API_TOKEN", "HF_TOKEN"], strict=False),
         prefix: str = "",
         suffix: str = "",
-        truncate: bool = True,
-        normalize: bool = False,
+        truncate: Optional[bool] = True,
+        normalize: Optional[bool] = False,
     ):  # pylint: disable=too-many-positional-arguments
         """
         Creates a HuggingFaceAPITextEmbedder component.
@@ -104,13 +104,11 @@ def __init__(
             Applicable when `api_type` is `TEXT_EMBEDDINGS_INFERENCE`, or `INFERENCE_ENDPOINTS`
             if the backend uses Text Embeddings Inference.
             If `api_type` is `SERVERLESS_INFERENCE_API`, this parameter is ignored.
-            It is always set to `True` and cannot be changed.
         :param normalize:
             Normalizes the embeddings to unit length.
             Applicable when `api_type` is `TEXT_EMBEDDINGS_INFERENCE`, or `INFERENCE_ENDPOINTS`
             if the backend uses Text Embeddings Inference.
             If `api_type` is `SERVERLESS_INFERENCE_API`, this parameter is ignored.
-            It is always set to `False` and cannot be changed.
         """
         huggingface_hub_import.check()
 
@@ -198,12 +196,29 @@ def run(self, text: str):
                 "In case you want to embed a list of Documents, please use the HuggingFaceAPIDocumentEmbedder."
             )
 
+        truncate = self.truncate
+        normalize = self.normalize
+
+        if self.api_type == HFEmbeddingAPIType.SERVERLESS_INFERENCE_API:
+            if truncate is not None:
+                msg = "`truncate` parameter is not supported for Serverless Inference API. It will be ignored."
+                warnings.warn(msg)
+                truncate = None
+            if normalize is not None:
+                msg = "`normalize` parameter is not supported for Serverless Inference API. It will be ignored."
+                warnings.warn(msg)
+                normalize = None
+
         text_to_embed = self.prefix + text + self.suffix
 
-        response = self._client.post(
-            json={"inputs": [text_to_embed], "truncate": self.truncate, "normalize": self.normalize},
-            task="feature-extraction",
-        )
-        embedding = json.loads(response.decode())[0]
+        np_embedding = self._client.feature_extraction(text=text_to_embed, truncate=truncate, normalize=normalize)
+
+        error_msg = f"Expected embedding shape (1, embedding_dim) or (embedding_dim,), got {np_embedding.shape}"
+        if np_embedding.ndim > 2:
+            raise ValueError(error_msg)
+        if np_embedding.ndim == 2 and np_embedding.shape[0] != 1:
+            raise ValueError(error_msg)
+
+        embedding = np_embedding.flatten().tolist()
 
         return {"embedding": embedding}
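For the single-text embedder, backends may return either a flat `(embedding_dim,)` array or a `(1, embedding_dim)` one; both collapse to the same flat list via `flatten().tolist()`, while anything else is rejected. A quick numpy illustration (shapes assumed for the example):

```python
import numpy as np

for shape in [(384,), (1, 384)]:
    np_embedding = np.random.rand(*shape)
    embedding = np_embedding.flatten().tolist()  # both shapes yield a 384-float list
    assert len(embedding) == 384 and all(isinstance(x, float) for x in embedding)

# A (2, 384) array is ambiguous for a single text, so run() raises ValueError instead
```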
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -87,7 +87,7 @@ extra-dependencies = [
   "numba>=0.54.0",  # This pin helps uv resolve the dependency tree. See https://github.com/astral-sh/uv/issues/7881
 
   "transformers[torch,sentencepiece]==4.47.1",  # ExtractiveReader, TransformersSimilarityRanker, LocalWhisperTranscriber, HFGenerators...
-  "huggingface_hub>=0.27.0, <0.28.0",  # Hugging Face API Generators and Embedders
+  "huggingface_hub>=0.27.0",  # Hugging Face API Generators and Embedders
   "sentence-transformers>=3.0.0",  # SentenceTransformersTextEmbedder and SentenceTransformersDocumentEmbedder
   "langdetect",  # TextLanguageRouter and DocumentLanguageClassifier
   "openai-whisper>=20231106",  # LocalWhisperTranscriber
@@ -0,0 +1,5 @@
+---
+fixes:
+  - |
+    In the Hugging Face API embedders, the `InferenceClient.feature_extraction` method is now used instead of
+    `InferenceClient.post` to compute embeddings. This ensures a more robust and future-proof implementation.
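
For reference, a minimal usage sketch of one refactored component after this change (the model name is taken from the tests; a valid token is assumed via the `HF_API_TOKEN`/`HF_TOKEN` environment variables):

```python
from haystack.components.embedders import HuggingFaceAPITextEmbedder
from haystack.utils.hf import HFEmbeddingAPIType

embedder = HuggingFaceAPITextEmbedder(
    api_type=HFEmbeddingAPIType.SERVERLESS_INFERENCE_API,
    api_params={"model": "BAAI/bge-small-en-v1.5"},
)

# On the serverless API, non-None truncate/normalize now emit a warning and are ignored
result = embedder.run(text="The food was delicious")
print(len(result["embedding"]))  # embedding dimension, e.g. 384 for bge-small
```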
64 changes: 49 additions & 15 deletions test/components/embedders/test_hugging_face_api_document_embedder.py
@@ -8,6 +8,8 @@
 import pytest
 from huggingface_hub.utils import RepositoryNotFoundError
 
+from numpy import array
+
 from haystack.components.embedders import HuggingFaceAPIDocumentEmbedder
 from haystack.dataclasses import Document
 from haystack.utils.auth import Secret
@@ -23,8 +25,8 @@ def mock_check_valid_model():
     yield mock
 
 
-def mock_embedding_generation(json, **kwargs):
-    response = str([[random.random() for _ in range(384)] for _ in range(len(json["inputs"]))]).encode()
+def mock_embedding_generation(text, **kwargs):
+    response = array([[random.random() for _ in range(384)] for _ in range(len(text))])
     return response


@@ -201,10 +203,10 @@ def test_prepare_texts_to_embed_w_suffix(self, mock_check_valid_model):
             "my_prefix document number 4 my_suffix",
         ]
 
-    def test_embed_batch(self, mock_check_valid_model):
+    def test_embed_batch(self, mock_check_valid_model, recwarn):
         texts = ["text 1", "text 2", "text 3", "text 4", "text 5"]
 
-        with patch("huggingface_hub.InferenceClient.post") as mock_embedding_patch:
+        with patch("huggingface_hub.InferenceClient.feature_extraction") as mock_embedding_patch:
             mock_embedding_patch.side_effect = mock_embedding_generation
 
             embedder = HuggingFaceAPIDocumentEmbedder(
@@ -223,6 +225,40 @@ def test_embed_batch(self, mock_check_valid_model):
             assert len(embedding) == 384
             assert all(isinstance(x, float) for x in embedding)
 
+        # Check that warnings about ignoring truncate and normalize are raised
+        assert len(recwarn) == 2
+        assert "truncate" in str(recwarn[0].message)
+        assert "normalize" in str(recwarn[1].message)
+
+    def test_embed_batch_wrong_embedding_shape(self, mock_check_valid_model):
+        texts = ["text 1", "text 2", "text 3", "text 4", "text 5"]
+
+        # embedding ndim != 2
+        with patch("huggingface_hub.InferenceClient.feature_extraction") as mock_embedding_patch:
+            mock_embedding_patch.return_value = array([0.1, 0.2, 0.3])
+
+            embedder = HuggingFaceAPIDocumentEmbedder(
+                api_type=HFEmbeddingAPIType.SERVERLESS_INFERENCE_API,
+                api_params={"model": "BAAI/bge-small-en-v1.5"},
+                token=Secret.from_token("fake-api-token"),
+            )
+
+            with pytest.raises(ValueError):
+                embedder._embed_batch(texts_to_embed=texts, batch_size=2)
+
+        # embedding ndim == 2 but shape[0] != len(batch)
+        with patch("huggingface_hub.InferenceClient.feature_extraction") as mock_embedding_patch:
+            mock_embedding_patch.return_value = array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9]])
+
+            embedder = HuggingFaceAPIDocumentEmbedder(
+                api_type=HFEmbeddingAPIType.SERVERLESS_INFERENCE_API,
+                api_params={"model": "BAAI/bge-small-en-v1.5"},
+                token=Secret.from_token("fake-api-token"),
+            )
+
+            with pytest.raises(ValueError):
+                embedder._embed_batch(texts_to_embed=texts, batch_size=2)
+
     def test_run_wrong_input_format(self, mock_check_valid_model):
         embedder = HuggingFaceAPIDocumentEmbedder(
             api_type=HFEmbeddingAPIType.SERVERLESS_INFERENCE_API, api_params={"model": "BAAI/bge-small-en-v1.5"}
@@ -252,7 +288,7 @@ def test_run(self, mock_check_valid_model):
             Document(content="A transformer is a deep learning architecture", meta={"topic": "ML"}),
         ]
 
-        with patch("huggingface_hub.InferenceClient.post") as mock_embedding_patch:
+        with patch("huggingface_hub.InferenceClient.feature_extraction") as mock_embedding_patch:
             mock_embedding_patch.side_effect = mock_embedding_generation
 
             embedder = HuggingFaceAPIDocumentEmbedder(
@@ -268,16 +304,14 @@ def test_run(self, mock_check_valid_model):
             result = embedder.run(documents=docs)
 
             mock_embedding_patch.assert_called_once_with(
-                json={
-                    "inputs": [
-                        "prefix Cuisine | I love cheese suffix",
-                        "prefix ML | A transformer is a deep learning architecture suffix",
-                    ],
-                    "truncate": True,
-                    "normalize": False,
-                },
-                task="feature-extraction",
+                text=[
+                    "prefix Cuisine | I love cheese suffix",
+                    "prefix ML | A transformer is a deep learning architecture suffix",
+                ],
+                truncate=None,
+                normalize=None,
             )
 
             documents_with_embeddings = result["documents"]
 
             assert isinstance(documents_with_embeddings, list)
@@ -294,7 +328,7 @@ def test_run_custom_batch_size(self, mock_check_valid_model):
             Document(content="A transformer is a deep learning architecture", meta={"topic": "ML"}),
         ]
 
-        with patch("huggingface_hub.InferenceClient.post") as mock_embedding_patch:
+        with patch("huggingface_hub.InferenceClient.feature_extraction") as mock_embedding_patch:
             mock_embedding_patch.side_effect = mock_embedding_generation
 
             embedder = HuggingFaceAPIDocumentEmbedder(
44 changes: 33 additions & 11 deletions test/components/embedders/test_hugging_face_api_text_embedder.py
@@ -7,7 +7,7 @@
 import random
 import pytest
 from huggingface_hub.utils import RepositoryNotFoundError
-
+from numpy import array
 from haystack.components.embedders import HuggingFaceAPITextEmbedder
 from haystack.utils.auth import Secret
 from haystack.utils.hf import HFEmbeddingAPIType
@@ -21,11 +21,6 @@ def mock_check_valid_model():
     yield mock
 
 
-def mock_embedding_generation(json, **kwargs):
-    response = str([[random.random() for _ in range(384)] for _ in range(len(json["inputs"]))]).encode()
-    return response
-
-
 class TestHuggingFaceAPITextEmbedder:
     def test_init_invalid_api_type(self):
         with pytest.raises(ValueError):
@@ -141,9 +136,9 @@ def test_run_wrong_input_format(self, mock_check_valid_model):
         with pytest.raises(TypeError):
             embedder.run(text=list_integers_input)
 
-    def test_run(self, mock_check_valid_model):
-        with patch("huggingface_hub.InferenceClient.post") as mock_embedding_patch:
-            mock_embedding_patch.side_effect = mock_embedding_generation
+    def test_run(self, mock_check_valid_model, recwarn):
+        with patch("huggingface_hub.InferenceClient.feature_extraction") as mock_embedding_patch:
+            mock_embedding_patch.return_value = array([[random.random() for _ in range(384)]])
 
             embedder = HuggingFaceAPITextEmbedder(
                 api_type=HFEmbeddingAPIType.SERVERLESS_INFERENCE_API,
@@ -156,13 +151,40 @@ def test_run(self, mock_check_valid_model, recwarn):
             result = embedder.run(text="The food was delicious")
 
             mock_embedding_patch.assert_called_once_with(
-                json={"inputs": ["prefix The food was delicious suffix"], "truncate": True, "normalize": False},
-                task="feature-extraction",
+                text="prefix The food was delicious suffix", truncate=None, normalize=None
             )
 
         assert len(result["embedding"]) == 384
         assert all(isinstance(x, float) for x in result["embedding"])
 
+        # Check that warnings about ignoring truncate and normalize are raised
+        assert len(recwarn) == 2
+        assert "truncate" in str(recwarn[0].message)
+        assert "normalize" in str(recwarn[1].message)
+
+    def test_run_wrong_embedding_shape(self, mock_check_valid_model):
+        # embedding ndim > 2
+        with patch("huggingface_hub.InferenceClient.feature_extraction") as mock_embedding_patch:
+            mock_embedding_patch.return_value = array([[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9]]])
+
+            embedder = HuggingFaceAPITextEmbedder(
+                api_type=HFEmbeddingAPIType.SERVERLESS_INFERENCE_API, api_params={"model": "BAAI/bge-small-en-v1.5"}
+            )
+
+            with pytest.raises(ValueError):
+                embedder.run(text="The food was delicious")
+
+        # embedding ndim == 2 but shape[0] != 1
+        with patch("huggingface_hub.InferenceClient.feature_extraction") as mock_embedding_patch:
+            mock_embedding_patch.return_value = array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]])
+
+            embedder = HuggingFaceAPITextEmbedder(
+                api_type=HFEmbeddingAPIType.SERVERLESS_INFERENCE_API, api_params={"model": "BAAI/bge-small-en-v1.5"}
+            )
+
+            with pytest.raises(ValueError):
+                embedder.run(text="The food was delicious")
+
     @pytest.mark.flaky(reruns=5, reruns_delay=5)
     @pytest.mark.integration
     @pytest.mark.skipif(
