feat: SentenceTransformersDocumentEmbedder and SentenceTransformersTextEmbedder can accept and pass any arguments to SentenceTransformer.encode (#8806)

* feat: SentenceTransformersDocumentEmbedder and SentenceTransformersTextEmbedder can accept and pass any arguments to SentenceTransformer.encode

* refactor: encode_kwargs parameter of SentenceTransformersDocumentEmbedder and SentenceTransformersTextEmbedder made to be the last positional parameter for backward compatibility reasons

* docs: added explanation for encode_kwargs in SentenceTransformersTextEmbedder and SentenceTransformersDocumentEmbedder

* test: added tests for encode_kwargs in SentenceTransformersTextEmbedder and SentenceTransformersDocumentEmbedder

* doc: removed empty lines from docstrings of SentenceTransformersTextEmbedder and SentenceTransformersDocumentEmbedder

* refactor: encode_kwargs parameter of SentenceTransformersDocumentEmbedder and SentenceTransformersTextEmbedder made to be the last positional parameter for backward compatibility (part II.)
oroszgy authored Feb 5, 2025
1 parent 2828d9e commit d2348ad
Showing 5 changed files with 59 additions and 3 deletions.
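For context, a minimal usage sketch of what the new parameter enables. The model name and the `prompt` values are illustrative, not taken from this commit; any keyword accepted by the installed `SentenceTransformer.encode` can be forwarded, as long as it does not clash with arguments the components already set (`batch_size`, `show_progress_bar`, `normalize_embeddings`, `precision`).

from haystack import Document
from haystack.components.embedders import (
    SentenceTransformersDocumentEmbedder,
    SentenceTransformersTextEmbedder,
)

# Index-time embedder: extra kwargs are forwarded verbatim to SentenceTransformer.encode.
doc_embedder = SentenceTransformersDocumentEmbedder(
    model="sentence-transformers/all-MiniLM-L6-v2",  # illustrative model
    encode_kwargs={"prompt": "passage: "},           # illustrative extra kwarg
)
doc_embedder.warm_up()
docs = doc_embedder.run(documents=[Document(content="Haystack is an LLM framework.")])["documents"]

# Query-time embedder: it can receive different kwargs than the document embedder.
text_embedder = SentenceTransformersTextEmbedder(
    model="sentence-transformers/all-MiniLM-L6-v2",
    encode_kwargs={"prompt": "query: "},
)
text_embedder.warm_up()
query_embedding = text_embedder.run(text="What is Haystack?")["embedding"]
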
haystack/components/embedders/sentence_transformers_document_embedder.py
@@ -56,6 +56,7 @@ def __init__( # noqa: PLR0913 # pylint: disable=too-many-positional-arguments
tokenizer_kwargs: Optional[Dict[str, Any]] = None,
config_kwargs: Optional[Dict[str, Any]] = None,
precision: Literal["float32", "int8", "uint8", "binary", "ubinary"] = "float32",
encode_kwargs: Optional[Dict[str, Any]] = None,
):
"""
Creates a SentenceTransformersDocumentEmbedder component.
@@ -104,6 +105,10 @@ def __init__( # noqa: PLR0913 # pylint: disable=too-many-positional-arguments
All non-float32 precisions are quantized embeddings.
Quantized embeddings are smaller and faster to compute, but may have a lower accuracy.
They are useful for reducing the size of the embeddings of a corpus for semantic search, among other tasks.
:param encode_kwargs:
Additional keyword arguments for `SentenceTransformer.encode` when embedding documents.
This parameter is provided for fine-grained customization. Be careful not to clash with parameters that are
already set and avoid passing parameters that change the output type.
"""

self.model = model
@@ -121,6 +126,7 @@ def __init__( # noqa: PLR0913 # pylint: disable=too-many-positional-arguments
self.model_kwargs = model_kwargs
self.tokenizer_kwargs = tokenizer_kwargs
self.config_kwargs = config_kwargs
self.encode_kwargs = encode_kwargs
self.embedding_backend = None
self.precision = precision

@@ -155,6 +161,7 @@ def to_dict(self) -> Dict[str, Any]:
tokenizer_kwargs=self.tokenizer_kwargs,
config_kwargs=self.config_kwargs,
precision=self.precision,
encode_kwargs=self.encode_kwargs,
)
if serialization_dict["init_parameters"].get("model_kwargs") is not None:
serialize_hf_model_kwargs(serialization_dict["init_parameters"]["model_kwargs"])
@@ -232,6 +239,7 @@ def run(self, documents: List[Document]):
show_progress_bar=self.progress_bar,
normalize_embeddings=self.normalize_embeddings,
precision=self.precision,
**(self.encode_kwargs if self.encode_kwargs else {}),
)

for doc, emb in zip(documents, embeddings):
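Note on the hunk above: the extra kwargs are splatted into the backend call only when `encode_kwargs` is set, so call sites that do not pass it behave exactly as before. A self-contained sketch of that forwarding pattern, with stand-in names rather than the component's real internals:

from typing import Any, Dict, List, Optional

def embed(texts: List[str], batch_size: int = 32, **kwargs: Any) -> Dict[str, Any]:
    # Stand-in for the embedding backend's call into SentenceTransformer.encode.
    return {"batch_size": batch_size, **kwargs}

def run(texts: List[str], encode_kwargs: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
    # Forward the optional kwargs only when they were provided.
    return embed(texts, batch_size=32, **(encode_kwargs if encode_kwargs else {}))

assert run(["hi"]) == {"batch_size": 32}  # previous behaviour preserved
assert run(["hi"], {"task": "clustering"}) == {"batch_size": 32, "task": "clustering"}
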
haystack/components/embedders/sentence_transformers_text_embedder.py
@@ -50,6 +50,7 @@ def __init__( # noqa: PLR0913 # pylint: disable=too-many-positional-arguments
tokenizer_kwargs: Optional[Dict[str, Any]] = None,
config_kwargs: Optional[Dict[str, Any]] = None,
precision: Literal["float32", "int8", "uint8", "binary", "ubinary"] = "float32",
encode_kwargs: Optional[Dict[str, Any]] = None,
):
"""
Create a SentenceTransformersTextEmbedder component.
@@ -94,6 +95,10 @@ def __init__( # noqa: PLR0913 # pylint: disable=too-many-positional-arguments
All non-float32 precisions are quantized embeddings.
Quantized embeddings are smaller in size and faster to compute, but may have a lower accuracy.
They are useful for reducing the size of the embeddings of a corpus for semantic search, among other tasks.
:param encode_kwargs:
Additional keyword arguments for `SentenceTransformer.encode` when embedding texts.
This parameter is provided for fine-grained customization. Be careful not to clash with parameters that are
already set and avoid passing parameters that change the output type.
"""

self.model = model
@@ -109,6 +114,7 @@ def __init__( # noqa: PLR0913 # pylint: disable=too-many-positional-arguments
self.model_kwargs = model_kwargs
self.tokenizer_kwargs = tokenizer_kwargs
self.config_kwargs = config_kwargs
self.encode_kwargs = encode_kwargs
self.embedding_backend = None
self.precision = precision

@@ -141,6 +147,7 @@ def to_dict(self) -> Dict[str, Any]:
tokenizer_kwargs=self.tokenizer_kwargs,
config_kwargs=self.config_kwargs,
precision=self.precision,
encode_kwargs=self.encode_kwargs,
)
if serialization_dict["init_parameters"].get("model_kwargs") is not None:
serialize_hf_model_kwargs(serialization_dict["init_parameters"]["model_kwargs"])
@@ -209,5 +216,6 @@ def run(self, text: str):
show_progress_bar=self.progress_bar,
normalize_embeddings=self.normalize_embeddings,
precision=self.precision,
**(self.encode_kwargs if self.encode_kwargs else {}),
)[0]
return {"embedding": embedding}
Release note (new file):
@@ -0,0 +1,6 @@
---
enhancements:
- |
Enhanced `SentenceTransformersDocumentEmbedder` and `SentenceTransformersTextEmbedder` to accept
an additional `encode_kwargs` parameter, which is passed directly to the underlying
`SentenceTransformer.encode` method for greater flexibility in embedding customization.
test/components/embedders/test_sentence_transformers_document_embedder.py
@@ -1,9 +1,9 @@
# SPDX-FileCopyrightText: 2022-present deepset GmbH <[email protected]>
#
# SPDX-License-Identifier: Apache-2.0
import random
from unittest.mock import MagicMock, patch

import random
import pytest
import torch

@@ -79,6 +79,7 @@ def test_to_dict(self):
"truncate_dim": None,
"model_kwargs": None,
"tokenizer_kwargs": None,
"encode_kwargs": None,
"config_kwargs": None,
"precision": "float32",
},
@@ -102,6 +103,7 @@ def test_to_dict_with_custom_init_parameters(self):
tokenizer_kwargs={"model_max_length": 512},
config_kwargs={"use_memory_efficient_attention": True},
precision="int8",
encode_kwargs={"task": "clustering"},
)
data = component.to_dict()

@@ -124,6 +126,7 @@ def test_to_dict_with_custom_init_parameters(self):
"tokenizer_kwargs": {"model_max_length": 512},
"config_kwargs": {"use_memory_efficient_attention": True},
"precision": "int8",
"encode_kwargs": {"task": "clustering"},
},
}

@@ -316,6 +319,20 @@ def test_embed_metadata(self):
precision="float32",
)

def test_embed_encode_kwargs(self):
embedder = SentenceTransformersDocumentEmbedder(model="model", encode_kwargs={"task": "retrieval.passage"})
embedder.embedding_backend = MagicMock()
documents = [Document(content=f"document number {i}") for i in range(5)]
embedder.run(documents=documents)
embedder.embedding_backend.embed.assert_called_once_with(
["document number 0", "document number 1", "document number 2", "document number 3", "document number 4"],
batch_size=32,
show_progress_bar=True,
normalize_embeddings=False,
precision="float32",
task="retrieval.passage",
)

def test_prefix_suffix(self):
embedder = SentenceTransformersDocumentEmbedder(
model="model",
test/components/embedders/test_sentence_transformers_text_embedder.py
@@ -1,11 +1,11 @@
# SPDX-FileCopyrightText: 2022-present deepset GmbH <[email protected]>
#
# SPDX-License-Identifier: Apache-2.0
import random
from unittest.mock import MagicMock, patch

import torch
import random
import pytest
import torch

from haystack.components.embedders.sentence_transformers_text_embedder import SentenceTransformersTextEmbedder
from haystack.utils import ComponentDevice, Secret
@@ -70,6 +70,7 @@ def test_to_dict(self):
"truncate_dim": None,
"model_kwargs": None,
"tokenizer_kwargs": None,
"encode_kwargs": None,
"config_kwargs": None,
"precision": "float32",
},
@@ -91,6 +92,7 @@ def test_to_dict_with_custom_init_parameters(self):
tokenizer_kwargs={"model_max_length": 512},
config_kwargs={"use_memory_efficient_attention": False},
precision="int8",
encode_kwargs={"task": "clustering"},
)
data = component.to_dict()
assert data == {
@@ -110,6 +112,7 @@ def test_to_dict_with_custom_init_parameters(self):
"tokenizer_kwargs": {"model_max_length": 512},
"config_kwargs": {"use_memory_efficient_attention": False},
"precision": "int8",
"encode_kwargs": {"task": "clustering"},
},
}

@@ -297,3 +300,17 @@ def test_run_quantization(self):

assert len(embedding_def) == 768
assert all(isinstance(el, int) for el in embedding_def)

def test_embed_encode_kwargs(self):
embedder = SentenceTransformersTextEmbedder(model="model", encode_kwargs={"task": "retrieval.query"})
embedder.embedding_backend = MagicMock()
text = "a nice text to embed"
embedder.run(text=text)
embedder.embedding_backend.embed.assert_called_once_with(
[text],
batch_size=32,
show_progress_bar=True,
normalize_embeddings=False,
precision="float32",
task="retrieval.query",
)
