feat: SentenceTransformersDocumentEmbedder and SentenceTransformersTextEmbedder can accept and pass any arguments to SentenceTransformer.encode #8806

Merged
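In practice, the change lets callers forward arbitrary keyword arguments to `SentenceTransformer.encode` through a new `encode_kwargs` init parameter. A minimal sketch of the document embedder follows; it is not part of the diff below, the model name is illustrative, and any keyword accepted by the installed sentence-transformers version's `encode` would work the same way:

from haystack import Document
from haystack.components.embedders import SentenceTransformersDocumentEmbedder

# Whatever is passed in encode_kwargs is unpacked into the SentenceTransformer.encode
# call, on top of the arguments the component already sets
# (batch_size, show_progress_bar, normalize_embeddings, precision).
embedder = SentenceTransformersDocumentEmbedder(
    model="sentence-transformers/all-MiniLM-L6-v2",  # illustrative model
    encode_kwargs={"prompt": "passage: "},           # forwarded verbatim to encode()
)
embedder.warm_up()
result = embedder.run(documents=[Document(content="Berlin is the capital of Germany.")])
embedded_docs = result["documents"]  # each Document now carries an .embedding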
@@ -56,6 +56,7 @@ def __init__( # noqa: PLR0913 # pylint: disable=too-many-positional-arguments
tokenizer_kwargs: Optional[Dict[str, Any]] = None,
config_kwargs: Optional[Dict[str, Any]] = None,
precision: Literal["float32", "int8", "uint8", "binary", "ubinary"] = "float32",
encode_kwargs: Optional[Dict[str, Any]] = None,
):
"""
Creates a SentenceTransformersDocumentEmbedder component.
@@ -104,6 +105,10 @@ def __init__( # noqa: PLR0913 # pylint: disable=too-many-positional-arguments
All non-float32 precisions are quantized embeddings.
Quantized embeddings are smaller and faster to compute, but may have a lower accuracy.
They are useful for reducing the size of the embeddings of a corpus for semantic search, among other tasks.
:param encode_kwargs:
Additional keyword arguments for `SentenceTransformer.encode` when embedding documents.
This parameter is intended for fine-grained customization. Be careful not to clash with parameters the component
already sets, and avoid passing arguments that change the output type.
"""

self.model = model
@@ -121,6 +126,7 @@ def __init__( # noqa: PLR0913 # pylint: disable=too-many-positional-arguments
self.model_kwargs = model_kwargs
self.tokenizer_kwargs = tokenizer_kwargs
self.config_kwargs = config_kwargs
self.encode_kwargs = encode_kwargs
self.embedding_backend = None
self.precision = precision

@@ -155,6 +161,7 @@ def to_dict(self) -> Dict[str, Any]:
tokenizer_kwargs=self.tokenizer_kwargs,
config_kwargs=self.config_kwargs,
precision=self.precision,
encode_kwargs=self.encode_kwargs,
)
if serialization_dict["init_parameters"].get("model_kwargs") is not None:
serialize_hf_model_kwargs(serialization_dict["init_parameters"]["model_kwargs"])
@@ -232,6 +239,7 @@ def run(self, documents: List[Document]):
show_progress_bar=self.progress_bar,
normalize_embeddings=self.normalize_embeddings,
precision=self.precision,
**(self.encode_kwargs if self.encode_kwargs else {}),
)

for doc, emb in zip(documents, embeddings):
@@ -50,6 +50,7 @@ def __init__( # noqa: PLR0913 # pylint: disable=too-many-positional-arguments
tokenizer_kwargs: Optional[Dict[str, Any]] = None,
config_kwargs: Optional[Dict[str, Any]] = None,
precision: Literal["float32", "int8", "uint8", "binary", "ubinary"] = "float32",
encode_kwargs: Optional[Dict[str, Any]] = None,
):
"""
Create a SentenceTransformersTextEmbedder component.
@@ -94,6 +95,10 @@ def __init__( # noqa: PLR0913 # pylint: disable=too-many-positional-arguments
All non-float32 precisions are quantized embeddings.
Quantized embeddings are smaller in size and faster to compute, but may have a lower accuracy.
They are useful for reducing the size of the embeddings of a corpus for semantic search, among other tasks.
:param encode_kwargs:
Additional keyword arguments for `SentenceTransformer.encode` when embedding texts.
This parameter is intended for fine-grained customization. Be careful not to clash with parameters the component
already sets, and avoid passing arguments that change the output type.
"""

self.model = model
@@ -109,6 +114,7 @@ def __init__( # noqa: PLR0913 # pylint: disable=too-many-positional-arguments
self.model_kwargs = model_kwargs
self.tokenizer_kwargs = tokenizer_kwargs
self.config_kwargs = config_kwargs
self.encode_kwargs = encode_kwargs
self.embedding_backend = None
self.precision = precision

@@ -141,6 +147,7 @@ def to_dict(self) -> Dict[str, Any]:
tokenizer_kwargs=self.tokenizer_kwargs,
config_kwargs=self.config_kwargs,
precision=self.precision,
encode_kwargs=self.encode_kwargs,
)
if serialization_dict["init_parameters"].get("model_kwargs") is not None:
serialize_hf_model_kwargs(serialization_dict["init_parameters"]["model_kwargs"])
@@ -209,5 +216,6 @@ def run(self, text: str):
show_progress_bar=self.progress_bar,
normalize_embeddings=self.normalize_embeddings,
precision=self.precision,
**(self.encode_kwargs if self.encode_kwargs else {}),
)[0]
return {"embedding": embedding}
@@ -0,0 +1,6 @@
---
enhancements:
- |
Enhanced `SentenceTransformersDocumentEmbedder` and `SentenceTransformersTextEmbedder` to accept
an additional `encode_kwargs` parameter, which is passed directly to the underlying `SentenceTransformer.encode` method
for greater flexibility in embedding customization.
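Because `to_dict` now includes `encode_kwargs` (see the hunks above), the setting should survive serialization. A small sketch of that round trip, assuming the component's usual `from_dict` counterpart; the `task` value mirrors the updated tests:

from haystack.components.embedders import SentenceTransformersTextEmbedder

component = SentenceTransformersTextEmbedder(model="model", encode_kwargs={"task": "clustering"})
data = component.to_dict()
assert data["init_parameters"]["encode_kwargs"] == {"task": "clustering"}

restored = SentenceTransformersTextEmbedder.from_dict(data)
assert restored.encode_kwargs == {"task": "clustering"}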
@@ -1,9 +1,9 @@
# SPDX-FileCopyrightText: 2022-present deepset GmbH <[email protected]>
#
# SPDX-License-Identifier: Apache-2.0
import random
from unittest.mock import MagicMock, patch

import random
import pytest
import torch

@@ -79,6 +79,7 @@ def test_to_dict(self):
"truncate_dim": None,
"model_kwargs": None,
"tokenizer_kwargs": None,
"encode_kwargs": None,
"config_kwargs": None,
"precision": "float32",
},
@@ -102,6 +103,7 @@ def test_to_dict_with_custom_init_parameters(self):
tokenizer_kwargs={"model_max_length": 512},
config_kwargs={"use_memory_efficient_attention": True},
precision="int8",
encode_kwargs={"task": "clustering"},
)
data = component.to_dict()

@@ -124,6 +126,7 @@ def test_to_dict_with_custom_init_parameters(self):
"tokenizer_kwargs": {"model_max_length": 512},
"config_kwargs": {"use_memory_efficient_attention": True},
"precision": "int8",
"encode_kwargs": {"task": "clustering"},
},
}

@@ -316,6 +319,20 @@ def test_embed_metadata(self):
precision="float32",
)

def test_embed_encode_kwargs(self):
embedder = SentenceTransformersDocumentEmbedder(model="model", encode_kwargs={"task": "retrieval.passage"})
embedder.embedding_backend = MagicMock()
documents = [Document(content=f"document number {i}") for i in range(5)]
embedder.run(documents=documents)
embedder.embedding_backend.embed.assert_called_once_with(
["document number 0", "document number 1", "document number 2", "document number 3", "document number 4"],
batch_size=32,
show_progress_bar=True,
normalize_embeddings=False,
precision="float32",
task="retrieval.passage",
)

def test_prefix_suffix(self):
embedder = SentenceTransformersDocumentEmbedder(
model="model",
@@ -1,11 +1,11 @@
# SPDX-FileCopyrightText: 2022-present deepset GmbH <[email protected]>
#
# SPDX-License-Identifier: Apache-2.0
import random
from unittest.mock import MagicMock, patch

import torch
import random
import pytest
import torch

from haystack.components.embedders.sentence_transformers_text_embedder import SentenceTransformersTextEmbedder
from haystack.utils import ComponentDevice, Secret
@@ -70,6 +70,7 @@ def test_to_dict(self):
"truncate_dim": None,
"model_kwargs": None,
"tokenizer_kwargs": None,
"encode_kwargs": None,
"config_kwargs": None,
"precision": "float32",
},
@@ -91,6 +92,7 @@ def test_to_dict_with_custom_init_parameters(self):
tokenizer_kwargs={"model_max_length": 512},
config_kwargs={"use_memory_efficient_attention": False},
precision="int8",
encode_kwargs={"task": "clustering"},
)
data = component.to_dict()
assert data == {
@@ -110,6 +112,7 @@ def test_to_dict_with_custom_init_parameters(self):
"tokenizer_kwargs": {"model_max_length": 512},
"config_kwargs": {"use_memory_efficient_attention": False},
"precision": "int8",
"encode_kwargs": {"task": "clustering"},
},
}

@@ -297,3 +300,17 @@ def test_run_quantization(self):

assert len(embedding_def) == 768
assert all(isinstance(el, int) for el in embedding_def)

def test_embed_encode_kwargs(self):
embedder = SentenceTransformersTextEmbedder(model="model", encode_kwargs={"task": "retrieval.query"})
embedder.embedding_backend = MagicMock()
text = "a nice text to embed"
embedder.run(text=text)
embedder.embedding_backend.embed.assert_called_once_with(
[text],
batch_size=32,
show_progress_bar=True,
normalize_embeddings=False,
precision="float32",
task="retrieval.query",
)