diff --git a/haystack/components/embedders/sentence_transformers_document_embedder.py b/haystack/components/embedders/sentence_transformers_document_embedder.py
index a5eaa9ae8b..0125be3f5c 100644
--- a/haystack/components/embedders/sentence_transformers_document_embedder.py
+++ b/haystack/components/embedders/sentence_transformers_document_embedder.py
@@ -56,6 +56,7 @@ def __init__( # noqa: PLR0913 # pylint: disable=too-many-positional-arguments
         tokenizer_kwargs: Optional[Dict[str, Any]] = None,
         config_kwargs: Optional[Dict[str, Any]] = None,
         precision: Literal["float32", "int8", "uint8", "binary", "ubinary"] = "float32",
+        encode_kwargs: Optional[Dict[str, Any]] = None,
     ):
         """
         Creates a SentenceTransformersDocumentEmbedder component.
@@ -104,6 +105,10 @@ def __init__( # noqa: PLR0913 # pylint: disable=too-many-positional-arguments
             All non-float32 precisions are quantized embeddings.
             Quantized embeddings are smaller and faster to compute, but may have a lower accuracy.
             They are useful for reducing the size of the embeddings of a corpus for semantic search, among other tasks.
+        :param encode_kwargs:
+            Additional keyword arguments for `SentenceTransformer.encode` when embedding documents.
+            This parameter is provided for fine customization. Be careful not to clash with already set parameters and
+            avoid passing parameters that change the output type.
         """

         self.model = model
@@ -121,6 +126,7 @@ def __init__( # noqa: PLR0913 # pylint: disable=too-many-positional-arguments
         self.model_kwargs = model_kwargs
         self.tokenizer_kwargs = tokenizer_kwargs
         self.config_kwargs = config_kwargs
+        self.encode_kwargs = encode_kwargs
         self.embedding_backend = None
         self.precision = precision

@@ -155,6 +161,7 @@ def to_dict(self) -> Dict[str, Any]:
             tokenizer_kwargs=self.tokenizer_kwargs,
             config_kwargs=self.config_kwargs,
             precision=self.precision,
+            encode_kwargs=self.encode_kwargs,
         )
         if serialization_dict["init_parameters"].get("model_kwargs") is not None:
             serialize_hf_model_kwargs(serialization_dict["init_parameters"]["model_kwargs"])
@@ -232,6 +239,7 @@ def run(self, documents: List[Document]):
             show_progress_bar=self.progress_bar,
             normalize_embeddings=self.normalize_embeddings,
             precision=self.precision,
+            **(self.encode_kwargs if self.encode_kwargs else {}),
         )

         for doc, emb in zip(documents, embeddings):
diff --git a/haystack/components/embedders/sentence_transformers_text_embedder.py b/haystack/components/embedders/sentence_transformers_text_embedder.py
index 0785caddbd..2cb1622e97 100644
--- a/haystack/components/embedders/sentence_transformers_text_embedder.py
+++ b/haystack/components/embedders/sentence_transformers_text_embedder.py
@@ -50,6 +50,7 @@ def __init__( # noqa: PLR0913 # pylint: disable=too-many-positional-arguments
         tokenizer_kwargs: Optional[Dict[str, Any]] = None,
         config_kwargs: Optional[Dict[str, Any]] = None,
         precision: Literal["float32", "int8", "uint8", "binary", "ubinary"] = "float32",
+        encode_kwargs: Optional[Dict[str, Any]] = None,
     ):
         """
         Create a SentenceTransformersTextEmbedder component.
@@ -94,6 +95,10 @@ def __init__( # noqa: PLR0913 # pylint: disable=too-many-positional-arguments
             All non-float32 precisions are quantized embeddings.
             Quantized embeddings are smaller in size and faster to compute, but may have a lower accuracy.
             They are useful for reducing the size of the embeddings of a corpus for semantic search, among other tasks.
+        :param encode_kwargs:
+            Additional keyword arguments for `SentenceTransformer.encode` when embedding texts.
+            This parameter is provided for fine customization. Be careful not to clash with already set parameters and
+            avoid passing parameters that change the output type.
         """

         self.model = model
@@ -109,6 +114,7 @@ def __init__( # noqa: PLR0913 # pylint: disable=too-many-positional-arguments
         self.model_kwargs = model_kwargs
         self.tokenizer_kwargs = tokenizer_kwargs
         self.config_kwargs = config_kwargs
+        self.encode_kwargs = encode_kwargs
         self.embedding_backend = None
         self.precision = precision

@@ -141,6 +147,7 @@ def to_dict(self) -> Dict[str, Any]:
             tokenizer_kwargs=self.tokenizer_kwargs,
             config_kwargs=self.config_kwargs,
             precision=self.precision,
+            encode_kwargs=self.encode_kwargs,
         )
         if serialization_dict["init_parameters"].get("model_kwargs") is not None:
             serialize_hf_model_kwargs(serialization_dict["init_parameters"]["model_kwargs"])
@@ -209,5 +216,6 @@ def run(self, text: str):
             show_progress_bar=self.progress_bar,
             normalize_embeddings=self.normalize_embeddings,
             precision=self.precision,
+            **(self.encode_kwargs if self.encode_kwargs else {}),
         )[0]
         return {"embedding": embedding}
diff --git a/releasenotes/notes/add-encode-kwargs-sentence-transformers-f4d885f6c5b1706f.yaml b/releasenotes/notes/add-encode-kwargs-sentence-transformers-f4d885f6c5b1706f.yaml
new file mode 100644
index 0000000000..407a1b7ae7
--- /dev/null
+++ b/releasenotes/notes/add-encode-kwargs-sentence-transformers-f4d885f6c5b1706f.yaml
@@ -0,0 +1,6 @@
+---
+enhancements:
+  - |
+    Enhanced `SentenceTransformersDocumentEmbedder` and `SentenceTransformersTextEmbedder` to accept
+    an additional `encode_kwargs` parameter, which is passed directly to the underlying
+    `SentenceTransformer.encode` method for greater flexibility in embedding customization.
diff --git a/test/components/embedders/test_sentence_transformers_document_embedder.py b/test/components/embedders/test_sentence_transformers_document_embedder.py
index d8813f36f4..205feac10f 100644
--- a/test/components/embedders/test_sentence_transformers_document_embedder.py
+++ b/test/components/embedders/test_sentence_transformers_document_embedder.py
@@ -1,9 +1,9 @@
 # SPDX-FileCopyrightText: 2022-present deepset GmbH
 #
 # SPDX-License-Identifier: Apache-2.0
+import random
 from unittest.mock import MagicMock, patch

-import random
 import pytest
 import torch

@@ -79,6 +79,7 @@ def test_to_dict(self):
                 "truncate_dim": None,
                 "model_kwargs": None,
                 "tokenizer_kwargs": None,
+                "encode_kwargs": None,
                 "config_kwargs": None,
                 "precision": "float32",
             },
@@ -102,6 +103,7 @@ def test_to_dict_with_custom_init_parameters(self):
             tokenizer_kwargs={"model_max_length": 512},
             config_kwargs={"use_memory_efficient_attention": True},
             precision="int8",
+            encode_kwargs={"task": "clustering"},
         )
         data = component.to_dict()

@@ -124,6 +126,7 @@ def test_to_dict_with_custom_init_parameters(self):
                 "tokenizer_kwargs": {"model_max_length": 512},
                 "config_kwargs": {"use_memory_efficient_attention": True},
                 "precision": "int8",
+                "encode_kwargs": {"task": "clustering"},
             },
         }

@@ -316,6 +319,20 @@ def test_embed_metadata(self):
             precision="float32",
         )

+    def test_embed_encode_kwargs(self):
+        embedder = SentenceTransformersDocumentEmbedder(model="model", encode_kwargs={"task": "retrieval.passage"})
+        embedder.embedding_backend = MagicMock()
+        documents = [Document(content=f"document number {i}") for i in range(5)]
+        embedder.run(documents=documents)
+        embedder.embedding_backend.embed.assert_called_once_with(
+            ["document number 0", "document number 1", "document number 2", "document number 3", "document number 4"],
+            batch_size=32,
+            show_progress_bar=True,
+            normalize_embeddings=False,
+            precision="float32",
+            task="retrieval.passage",
+        )
+
     def test_prefix_suffix(self):
         embedder = SentenceTransformersDocumentEmbedder(
             model="model",
diff --git a/test/components/embedders/test_sentence_transformers_text_embedder.py b/test/components/embedders/test_sentence_transformers_text_embedder.py
index 293b07aada..6d0e239a15 100644
--- a/test/components/embedders/test_sentence_transformers_text_embedder.py
+++ b/test/components/embedders/test_sentence_transformers_text_embedder.py
@@ -1,11 +1,11 @@
 # SPDX-FileCopyrightText: 2022-present deepset GmbH
 #
 # SPDX-License-Identifier: Apache-2.0
+import random
 from unittest.mock import MagicMock, patch

-import torch
-import random
 import pytest
+import torch

 from haystack.components.embedders.sentence_transformers_text_embedder import SentenceTransformersTextEmbedder
 from haystack.utils import ComponentDevice, Secret
@@ -70,6 +70,7 @@ def test_to_dict(self):
                 "truncate_dim": None,
                 "model_kwargs": None,
                 "tokenizer_kwargs": None,
+                "encode_kwargs": None,
                 "config_kwargs": None,
                 "precision": "float32",
             },
@@ -91,6 +92,7 @@ def test_to_dict_with_custom_init_parameters(self):
             tokenizer_kwargs={"model_max_length": 512},
             config_kwargs={"use_memory_efficient_attention": False},
             precision="int8",
+            encode_kwargs={"task": "clustering"},
         )
         data = component.to_dict()
         assert data == {
@@ -110,6 +112,7 @@ def test_to_dict_with_custom_init_parameters(self):
                 "tokenizer_kwargs": {"model_max_length": 512},
                 "config_kwargs": {"use_memory_efficient_attention": False},
                 "precision": "int8",
+                "encode_kwargs": {"task": "clustering"},
             },
         }

@@ -297,3 +300,17 @@ def test_run_quantization(self):

         assert len(embedding_def) == 768
         assert all(isinstance(el, int) for el in embedding_def)
+
+    def test_embed_encode_kwargs(self):
+        embedder = SentenceTransformersTextEmbedder(model="model", encode_kwargs={"task": "retrieval.query"})
+        embedder.embedding_backend = MagicMock()
+        text = "a nice text to embed"
+        embedder.run(text=text)
+        embedder.embedding_backend.embed.assert_called_once_with(
+            [text],
+            batch_size=32,
+            show_progress_bar=True,
+            normalize_embeddings=False,
+            precision="float32",
+            task="retrieval.query",
+        )
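
Usage sketch (illustrative, outside the diff above): the snippet below shows how the new encode_kwargs parameter forwards an option to SentenceTransformer.encode. The model name intfloat/e5-base-v2 and the prompt option are assumptions for illustration; prompt is an encode option in recent sentence-transformers releases that prepends a string to each input, whereas the task values used in the tests above are model-specific options supported only by some models.

from haystack import Document
from haystack.components.embedders import (
    SentenceTransformersDocumentEmbedder,
    SentenceTransformersTextEmbedder,
)

# Index documents with the "passage: " prompt expected by E5-style models.
doc_embedder = SentenceTransformersDocumentEmbedder(
    model="intfloat/e5-base-v2",
    encode_kwargs={"prompt": "passage: "},  # forwarded verbatim to SentenceTransformer.encode
)
doc_embedder.warm_up()
docs = doc_embedder.run(documents=[Document(content="Berlin is the capital of Germany.")])["documents"]
print(len(docs[0].embedding))

# Embed queries with the matching "query: " prompt.
text_embedder = SentenceTransformersTextEmbedder(
    model="intfloat/e5-base-v2",
    encode_kwargs={"prompt": "query: "},
)
text_embedder.warm_up()
embedding = text_embedder.run(text="What is the capital of Germany?")["embedding"]
print(len(embedding))

The same prefixes could also be added with the components' existing prefix parameter; encode_kwargs is broader in that it can forward any encode option, as long as it does not clash with the arguments the components already set (batch_size, show_progress_bar, normalize_embeddings, precision) or change the output type.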