From 9383e3cc1e1b150d0897ec9977755b97178287a1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Gy=C3=B6rgy=20Orosz?= Date: Tue, 4 Feb 2025 13:15:58 +0100 Subject: [PATCH 1/6] feat: SentenceTransformersDocumentEmbedder and SentenceTransformersTextEmbedder can accept and pass any arguments to SentenceTransformer.encode --- .../embedders/sentence_transformers_document_embedder.py | 6 ++++++ .../embedders/sentence_transformers_text_embedder.py | 6 ++++++ ...ncode-kwargs-sentence-transformers-f4d885f6c5b1706f.yaml | 6 ++++++ .../test_sentence_transformers_document_embedder.py | 3 +++ .../embedders/test_sentence_transformers_text_embedder.py | 3 +++ 5 files changed, 24 insertions(+) create mode 100644 releasenotes/notes/add-encode-kwargs-sentence-transformers-f4d885f6c5b1706f.yaml diff --git a/haystack/components/embedders/sentence_transformers_document_embedder.py b/haystack/components/embedders/sentence_transformers_document_embedder.py index a5eaa9ae8b..aa063c5f65 100644 --- a/haystack/components/embedders/sentence_transformers_document_embedder.py +++ b/haystack/components/embedders/sentence_transformers_document_embedder.py @@ -55,6 +55,7 @@ def __init__( # noqa: PLR0913 # pylint: disable=too-many-positional-arguments model_kwargs: Optional[Dict[str, Any]] = None, tokenizer_kwargs: Optional[Dict[str, Any]] = None, config_kwargs: Optional[Dict[str, Any]] = None, + encode_kwargs: Optional[Dict[str, Any]] = None, precision: Literal["float32", "int8", "uint8", "binary", "ubinary"] = "float32", ): """ @@ -99,6 +100,8 @@ def __init__( # noqa: PLR0913 # pylint: disable=too-many-positional-arguments Refer to specific model documentation for available kwargs. :param config_kwargs: Additional keyword arguments for `AutoConfig.from_pretrained` when loading the model configuration. + :param encode_kwargs: + Additional keyword arguments for `SentenceTransformer.encode` when embedding documents. :param precision: The precision to use for the embeddings. All non-float32 precisions are quantized embeddings. 
@@ -121,6 +124,7 @@ def __init__( # noqa: PLR0913 # pylint: disable=too-many-positional-arguments self.model_kwargs = model_kwargs self.tokenizer_kwargs = tokenizer_kwargs self.config_kwargs = config_kwargs + self.encode_kwargs = encode_kwargs self.embedding_backend = None self.precision = precision @@ -154,6 +158,7 @@ def to_dict(self) -> Dict[str, Any]: model_kwargs=self.model_kwargs, tokenizer_kwargs=self.tokenizer_kwargs, config_kwargs=self.config_kwargs, + encode_kwargs=self.encode_kwargs, precision=self.precision, ) if serialization_dict["init_parameters"].get("model_kwargs") is not None: @@ -232,6 +237,7 @@ def run(self, documents: List[Document]): show_progress_bar=self.progress_bar, normalize_embeddings=self.normalize_embeddings, precision=self.precision, + **(self.encode_kwargs if self.encode_kwargs else {}), ) for doc, emb in zip(documents, embeddings): diff --git a/haystack/components/embedders/sentence_transformers_text_embedder.py b/haystack/components/embedders/sentence_transformers_text_embedder.py index 0785caddbd..932e7bfea6 100644 --- a/haystack/components/embedders/sentence_transformers_text_embedder.py +++ b/haystack/components/embedders/sentence_transformers_text_embedder.py @@ -49,6 +49,7 @@ def __init__( # noqa: PLR0913 # pylint: disable=too-many-positional-arguments model_kwargs: Optional[Dict[str, Any]] = None, tokenizer_kwargs: Optional[Dict[str, Any]] = None, config_kwargs: Optional[Dict[str, Any]] = None, + encode_kwargs: Optional[Dict[str, Any]] = None, precision: Literal["float32", "int8", "uint8", "binary", "ubinary"] = "float32", ): """ @@ -89,6 +90,8 @@ def __init__( # noqa: PLR0913 # pylint: disable=too-many-positional-arguments Refer to specific model documentation for available kwargs. :param config_kwargs: Additional keyword arguments for `AutoConfig.from_pretrained` when loading the model configuration. + :param encode_kwargs: + Additional keyword arguments for `SentenceTransformer.encode` when embedding texts. :param precision: The precision to use for the embeddings. All non-float32 precisions are quantized embeddings. 
@@ -109,6 +112,7 @@ def __init__( # noqa: PLR0913 # pylint: disable=too-many-positional-arguments self.model_kwargs = model_kwargs self.tokenizer_kwargs = tokenizer_kwargs self.config_kwargs = config_kwargs + self.encode_kwargs = encode_kwargs self.embedding_backend = None self.precision = precision @@ -140,6 +144,7 @@ def to_dict(self) -> Dict[str, Any]: model_kwargs=self.model_kwargs, tokenizer_kwargs=self.tokenizer_kwargs, config_kwargs=self.config_kwargs, + encode_kwargs=self.encode_kwargs, precision=self.precision, ) if serialization_dict["init_parameters"].get("model_kwargs") is not None: @@ -209,5 +214,6 @@ def run(self, text: str): show_progress_bar=self.progress_bar, normalize_embeddings=self.normalize_embeddings, precision=self.precision, + **(self.encode_kwargs if self.encode_kwargs else {}), )[0] return {"embedding": embedding} diff --git a/releasenotes/notes/add-encode-kwargs-sentence-transformers-f4d885f6c5b1706f.yaml b/releasenotes/notes/add-encode-kwargs-sentence-transformers-f4d885f6c5b1706f.yaml new file mode 100644 index 0000000000..407a1b7ae7 --- /dev/null +++ b/releasenotes/notes/add-encode-kwargs-sentence-transformers-f4d885f6c5b1706f.yaml @@ -0,0 +1,6 @@ +--- +enhancements: + - | + Enhanced `SentenceTransformersDocumentEmbedder` and `SentenceTransformersTextEmbedder` to accept + an additional parameter, which is passed directly to the underlying `SentenceTransformer.encode` method + for greater flexibility in embedding customization. diff --git a/test/components/embedders/test_sentence_transformers_document_embedder.py b/test/components/embedders/test_sentence_transformers_document_embedder.py index d8813f36f4..b7c84dd500 100644 --- a/test/components/embedders/test_sentence_transformers_document_embedder.py +++ b/test/components/embedders/test_sentence_transformers_document_embedder.py @@ -79,6 +79,7 @@ def test_to_dict(self): "truncate_dim": None, "model_kwargs": None, "tokenizer_kwargs": None, + "encode_kwargs": None, "config_kwargs": None, "precision": "float32", }, @@ -101,6 +102,7 @@ def test_to_dict_with_custom_init_parameters(self): model_kwargs={"torch_dtype": torch.float32}, tokenizer_kwargs={"model_max_length": 512}, config_kwargs={"use_memory_efficient_attention": True}, + encode_kwargs={"task": "clustering"}, precision="int8", ) data = component.to_dict() @@ -123,6 +125,7 @@ def test_to_dict_with_custom_init_parameters(self): "model_kwargs": {"torch_dtype": "torch.float32"}, "tokenizer_kwargs": {"model_max_length": 512}, "config_kwargs": {"use_memory_efficient_attention": True}, + "encode_kwargs": {"task": "clustering"}, "precision": "int8", }, } diff --git a/test/components/embedders/test_sentence_transformers_text_embedder.py b/test/components/embedders/test_sentence_transformers_text_embedder.py index 195ee8efd1..572834c278 100644 --- a/test/components/embedders/test_sentence_transformers_text_embedder.py +++ b/test/components/embedders/test_sentence_transformers_text_embedder.py @@ -70,6 +70,7 @@ def test_to_dict(self): "truncate_dim": None, "model_kwargs": None, "tokenizer_kwargs": None, + "encode_kwargs": None, "config_kwargs": None, "precision": "float32", }, @@ -90,6 +91,7 @@ def test_to_dict_with_custom_init_parameters(self): model_kwargs={"torch_dtype": torch.float32}, tokenizer_kwargs={"model_max_length": 512}, config_kwargs={"use_memory_efficient_attention": False}, + encode_kwargs={"task": "clustering"}, precision="int8", ) data = component.to_dict() @@ -109,6 +111,7 @@ def test_to_dict_with_custom_init_parameters(self): "model_kwargs": 
{"torch_dtype": "torch.float32"}, "tokenizer_kwargs": {"model_max_length": 512}, "config_kwargs": {"use_memory_efficient_attention": False}, + "encode_kwargs": {"task": "clustering"}, "precision": "int8", }, } From 837d5c2782b82a9496daca3c7d03ab2d09a80240 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Gy=C3=B6rgy=20Orosz?= Date: Wed, 5 Feb 2025 16:05:38 +0100 Subject: [PATCH 2/6] refactor: encode_kwargs parameter of SentenceTransformersDocumentEmbedder and SentenceTransformersTextEmbedder mae to be the last positional parameter for backward compatibility reasons --- .../embedders/sentence_transformers_document_embedder.py | 2 +- .../embedders/sentence_transformers_text_embedder.py | 2 +- .../embedders/test_sentence_transformers_document_embedder.py | 4 ++-- .../embedders/test_sentence_transformers_text_embedder.py | 4 ++-- 4 files changed, 6 insertions(+), 6 deletions(-) diff --git a/haystack/components/embedders/sentence_transformers_document_embedder.py b/haystack/components/embedders/sentence_transformers_document_embedder.py index aa063c5f65..d79e899da3 100644 --- a/haystack/components/embedders/sentence_transformers_document_embedder.py +++ b/haystack/components/embedders/sentence_transformers_document_embedder.py @@ -55,8 +55,8 @@ def __init__( # noqa: PLR0913 # pylint: disable=too-many-positional-arguments model_kwargs: Optional[Dict[str, Any]] = None, tokenizer_kwargs: Optional[Dict[str, Any]] = None, config_kwargs: Optional[Dict[str, Any]] = None, - encode_kwargs: Optional[Dict[str, Any]] = None, precision: Literal["float32", "int8", "uint8", "binary", "ubinary"] = "float32", + encode_kwargs: Optional[Dict[str, Any]] = None, ): """ Creates a SentenceTransformersDocumentEmbedder component. diff --git a/haystack/components/embedders/sentence_transformers_text_embedder.py b/haystack/components/embedders/sentence_transformers_text_embedder.py index 932e7bfea6..6a670acf0f 100644 --- a/haystack/components/embedders/sentence_transformers_text_embedder.py +++ b/haystack/components/embedders/sentence_transformers_text_embedder.py @@ -144,8 +144,8 @@ def to_dict(self) -> Dict[str, Any]: model_kwargs=self.model_kwargs, tokenizer_kwargs=self.tokenizer_kwargs, config_kwargs=self.config_kwargs, - encode_kwargs=self.encode_kwargs, precision=self.precision, + encode_kwargs=self.encode_kwargs, ) if serialization_dict["init_parameters"].get("model_kwargs") is not None: serialize_hf_model_kwargs(serialization_dict["init_parameters"]["model_kwargs"]) diff --git a/test/components/embedders/test_sentence_transformers_document_embedder.py b/test/components/embedders/test_sentence_transformers_document_embedder.py index b7c84dd500..b15b74c519 100644 --- a/test/components/embedders/test_sentence_transformers_document_embedder.py +++ b/test/components/embedders/test_sentence_transformers_document_embedder.py @@ -102,8 +102,8 @@ def test_to_dict_with_custom_init_parameters(self): model_kwargs={"torch_dtype": torch.float32}, tokenizer_kwargs={"model_max_length": 512}, config_kwargs={"use_memory_efficient_attention": True}, - encode_kwargs={"task": "clustering"}, precision="int8", + encode_kwargs={"task": "clustering"}, ) data = component.to_dict() @@ -125,8 +125,8 @@ def test_to_dict_with_custom_init_parameters(self): "model_kwargs": {"torch_dtype": "torch.float32"}, "tokenizer_kwargs": {"model_max_length": 512}, "config_kwargs": {"use_memory_efficient_attention": True}, - "encode_kwargs": {"task": "clustering"}, "precision": "int8", + "encode_kwargs": {"task": "clustering"}, }, } diff --git 
a/test/components/embedders/test_sentence_transformers_text_embedder.py b/test/components/embedders/test_sentence_transformers_text_embedder.py index 5652e2338c..1ae37244db 100644 --- a/test/components/embedders/test_sentence_transformers_text_embedder.py +++ b/test/components/embedders/test_sentence_transformers_text_embedder.py @@ -91,8 +91,8 @@ def test_to_dict_with_custom_init_parameters(self): model_kwargs={"torch_dtype": torch.float32}, tokenizer_kwargs={"model_max_length": 512}, config_kwargs={"use_memory_efficient_attention": False}, - encode_kwargs={"task": "clustering"}, precision="int8", + encode_kwargs={"task": "clustering"}, ) data = component.to_dict() assert data == { @@ -111,8 +111,8 @@ def test_to_dict_with_custom_init_parameters(self): "model_kwargs": {"torch_dtype": "torch.float32"}, "tokenizer_kwargs": {"model_max_length": 512}, "config_kwargs": {"use_memory_efficient_attention": False}, - "encode_kwargs": {"task": "clustering"}, "precision": "int8", + "encode_kwargs": {"task": "clustering"}, }, } From c07ecbbfe1cde0f109f51218bc65c5560b48f3fa Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Gy=C3=B6rgy=20Orosz?= Date: Wed, 5 Feb 2025 16:08:48 +0100 Subject: [PATCH 3/6] docs: added explanation for encode_kwargs in SentenceTransformersTextEmbedder and SentenceTransformersDocumentEmbedder --- .../embedders/sentence_transformers_document_embedder.py | 7 +++++-- .../embedders/sentence_transformers_text_embedder.py | 7 +++++-- 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/haystack/components/embedders/sentence_transformers_document_embedder.py b/haystack/components/embedders/sentence_transformers_document_embedder.py index d79e899da3..3cfbd6c737 100644 --- a/haystack/components/embedders/sentence_transformers_document_embedder.py +++ b/haystack/components/embedders/sentence_transformers_document_embedder.py @@ -100,13 +100,16 @@ def __init__( # noqa: PLR0913 # pylint: disable=too-many-positional-arguments Refer to specific model documentation for available kwargs. :param config_kwargs: Additional keyword arguments for `AutoConfig.from_pretrained` when loading the model configuration. - :param encode_kwargs: - Additional keyword arguments for `SentenceTransformer.encode` when embedding documents. :param precision: The precision to use for the embeddings. All non-float32 precisions are quantized embeddings. Quantized embeddings are smaller and faster to compute, but may have a lower accuracy. They are useful for reducing the size of the embeddings of a corpus for semantic search, among other tasks. + :param encode_kwargs: + Additional keyword arguments for `SentenceTransformer.encode` when embedding documents. + + This parameter is provided for fine customization. Be careful not to clash with already set parameters and + avoid passing parameters that change the output type. """ self.model = model diff --git a/haystack/components/embedders/sentence_transformers_text_embedder.py b/haystack/components/embedders/sentence_transformers_text_embedder.py index 6a670acf0f..78b016ad34 100644 --- a/haystack/components/embedders/sentence_transformers_text_embedder.py +++ b/haystack/components/embedders/sentence_transformers_text_embedder.py @@ -90,13 +90,16 @@ def __init__( # noqa: PLR0913 # pylint: disable=too-many-positional-arguments Refer to specific model documentation for available kwargs. :param config_kwargs: Additional keyword arguments for `AutoConfig.from_pretrained` when loading the model configuration. 
- :param encode_kwargs: - Additional keyword arguments for `SentenceTransformer.encode` when embedding texts. :param precision: The precision to use for the embeddings. All non-float32 precisions are quantized embeddings. Quantized embeddings are smaller in size and faster to compute, but may have a lower accuracy. They are useful for reducing the size of the embeddings of a corpus for semantic search, among other tasks. + :param encode_kwargs: + Additional keyword arguments for `SentenceTransformer.encode` when embedding texts. + + This parameter is provided for fine customization. Be careful not to clash with already set parameters and + avoid passing parameters that change the output type. """ self.model = model From 088e43b2d86fae4367749a45063cdb97ef69b80b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Gy=C3=B6rgy=20Orosz?= Date: Wed, 5 Feb 2025 16:35:54 +0100 Subject: [PATCH 4/6] test: added tests for encode_kwargs in SentenceTransformersTextEmbedder and SentenceTransformersDocumentEmbedder --- ..._sentence_transformers_document_embedder.py | 16 +++++++++++++++- ...test_sentence_transformers_text_embedder.py | 18 ++++++++++++++++-- 2 files changed, 31 insertions(+), 3 deletions(-) diff --git a/test/components/embedders/test_sentence_transformers_document_embedder.py b/test/components/embedders/test_sentence_transformers_document_embedder.py index b15b74c519..205feac10f 100644 --- a/test/components/embedders/test_sentence_transformers_document_embedder.py +++ b/test/components/embedders/test_sentence_transformers_document_embedder.py @@ -1,9 +1,9 @@ # SPDX-FileCopyrightText: 2022-present deepset GmbH # # SPDX-License-Identifier: Apache-2.0 +import random from unittest.mock import MagicMock, patch -import random import pytest import torch @@ -319,6 +319,20 @@ def test_embed_metadata(self): precision="float32", ) + def test_embed_encode_kwargs(self): + embedder = SentenceTransformersDocumentEmbedder(model="model", encode_kwargs={"task": "retrieval.passage"}) + embedder.embedding_backend = MagicMock() + documents = [Document(content=f"document number {i}") for i in range(5)] + embedder.run(documents=documents) + embedder.embedding_backend.embed.assert_called_once_with( + ["document number 0", "document number 1", "document number 2", "document number 3", "document number 4"], + batch_size=32, + show_progress_bar=True, + normalize_embeddings=False, + precision="float32", + task="retrieval.passage", + ) + def test_prefix_suffix(self): embedder = SentenceTransformersDocumentEmbedder( model="model", diff --git a/test/components/embedders/test_sentence_transformers_text_embedder.py b/test/components/embedders/test_sentence_transformers_text_embedder.py index 1ae37244db..6d0e239a15 100644 --- a/test/components/embedders/test_sentence_transformers_text_embedder.py +++ b/test/components/embedders/test_sentence_transformers_text_embedder.py @@ -1,11 +1,11 @@ # SPDX-FileCopyrightText: 2022-present deepset GmbH # # SPDX-License-Identifier: Apache-2.0 +import random from unittest.mock import MagicMock, patch -import torch -import random import pytest +import torch from haystack.components.embedders.sentence_transformers_text_embedder import SentenceTransformersTextEmbedder from haystack.utils import ComponentDevice, Secret @@ -300,3 +300,17 @@ def test_run_quantization(self): assert len(embedding_def) == 768 assert all(isinstance(el, int) for el in embedding_def) + + def test_embed_encode_kwargs(self): + embedder = SentenceTransformersTextEmbedder(model="model", encode_kwargs={"task": "retrieval.query"}) + 
embedder.embedding_backend = MagicMock() + text = "a nice text to embed" + embedder.run(text=text) + embedder.embedding_backend.embed.assert_called_once_with( + [text], + batch_size=32, + show_progress_bar=True, + normalize_embeddings=False, + precision="float32", + task="retrieval.query", + ) From 1f9ae33bd8f09f389e40fc11a04f0aff0055d131 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Gy=C3=B6rgy=20Orosz?= Date: Wed, 5 Feb 2025 16:53:00 +0100 Subject: [PATCH 5/6] docs: removed empty lines from docstrings of SentenceTransformersTextEmbedder and SentenceTransformersDocumentEmbedder --- .../embedders/sentence_transformers_document_embedder.py | 1 - .../components/embedders/sentence_transformers_text_embedder.py | 1 - 2 files changed, 2 deletions(-) diff --git a/haystack/components/embedders/sentence_transformers_document_embedder.py b/haystack/components/embedders/sentence_transformers_document_embedder.py index 3cfbd6c737..da804edde0 100644 --- a/haystack/components/embedders/sentence_transformers_document_embedder.py +++ b/haystack/components/embedders/sentence_transformers_document_embedder.py @@ -107,7 +107,6 @@ def __init__( # noqa: PLR0913 # pylint: disable=too-many-positional-arguments They are useful for reducing the size of the embeddings of a corpus for semantic search, among other tasks. :param encode_kwargs: Additional keyword arguments for `SentenceTransformer.encode` when embedding documents. - This parameter is provided for fine customization. Be careful not to clash with already set parameters and avoid passing parameters that change the output type. """ diff --git a/haystack/components/embedders/sentence_transformers_text_embedder.py b/haystack/components/embedders/sentence_transformers_text_embedder.py index 78b016ad34..cf01495668 100644 --- a/haystack/components/embedders/sentence_transformers_text_embedder.py +++ b/haystack/components/embedders/sentence_transformers_text_embedder.py @@ -97,7 +97,6 @@ def __init__( # noqa: PLR0913 # pylint: disable=too-many-positional-arguments They are useful for reducing the size of the embeddings of a corpus for semantic search, among other tasks. :param encode_kwargs: Additional keyword arguments for `SentenceTransformer.encode` when embedding texts. - This parameter is provided for fine customization. Be careful not to clash with already set parameters and avoid passing parameters that change the output type. """ From ea14a19b627a35333a77ce356deb99d34aa6a2ee Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Gy=C3=B6rgy=20Orosz?= Date: Wed, 5 Feb 2025 16:53:29 +0100 Subject: [PATCH 6/6] refactor: encode_kwargs parameter of SentenceTransformersDocumentEmbedder and SentenceTransformersTextEmbedder moved to be the last positional parameter for backward compatibility (part II.)
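
With this change `encode_kwargs` sits after `precision` at the end of the signature, so existing positional call sites keep working and the new argument is intended to be passed by keyword. A minimal illustrative sketch of the resulting usage (the model name is a placeholder, not part of this change; the `task` value mirrors the one used in the tests and is forwarded as-is to `SentenceTransformer.encode`):

    from haystack.components.embedders.sentence_transformers_text_embedder import SentenceTransformersTextEmbedder

    embedder = SentenceTransformersTextEmbedder(
        model="sentence-transformers/all-MiniLM-L6-v2",  # placeholder model
        precision="float32",
        encode_kwargs={"task": "retrieval.query"},  # forwarded to SentenceTransformer.encode
    )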
--- .../embedders/sentence_transformers_document_embedder.py | 2 +- .../components/embedders/sentence_transformers_text_embedder.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/haystack/components/embedders/sentence_transformers_document_embedder.py b/haystack/components/embedders/sentence_transformers_document_embedder.py index da804edde0..0125be3f5c 100644 --- a/haystack/components/embedders/sentence_transformers_document_embedder.py +++ b/haystack/components/embedders/sentence_transformers_document_embedder.py @@ -160,8 +160,8 @@ def to_dict(self) -> Dict[str, Any]: model_kwargs=self.model_kwargs, tokenizer_kwargs=self.tokenizer_kwargs, config_kwargs=self.config_kwargs, - encode_kwargs=self.encode_kwargs, precision=self.precision, + encode_kwargs=self.encode_kwargs, ) if serialization_dict["init_parameters"].get("model_kwargs") is not None: serialize_hf_model_kwargs(serialization_dict["init_parameters"]["model_kwargs"]) diff --git a/haystack/components/embedders/sentence_transformers_text_embedder.py b/haystack/components/embedders/sentence_transformers_text_embedder.py index cf01495668..2cb1622e97 100644 --- a/haystack/components/embedders/sentence_transformers_text_embedder.py +++ b/haystack/components/embedders/sentence_transformers_text_embedder.py @@ -49,8 +49,8 @@ def __init__( # noqa: PLR0913 # pylint: disable=too-many-positional-arguments model_kwargs: Optional[Dict[str, Any]] = None, tokenizer_kwargs: Optional[Dict[str, Any]] = None, config_kwargs: Optional[Dict[str, Any]] = None, - encode_kwargs: Optional[Dict[str, Any]] = None, precision: Literal["float32", "int8", "uint8", "binary", "ubinary"] = "float32", + encode_kwargs: Optional[Dict[str, Any]] = None, ): """ Create a SentenceTransformersTextEmbedder component.