feat: SentenceTransformersDocumentEmbedder and SentenceTransformersTextEmbedder can accept and pass any arguments to SentenceTransformer.encode #8806

Merged
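For context, here is a minimal usage sketch of what the new parameter enables; it is not part of the PR. The model name and the `prompt` key are illustrative assumptions: whatever is placed in `encode_kwargs` is unpacked into the `SentenceTransformer.encode` call at run time, so the keys must be arguments the installed model's `encode()` actually accepts.

# Minimal usage sketch (illustrative, not from this PR). The model name and the
# "prompt" key are assumptions; "prompt" is an encode() argument in recent
# sentence-transformers releases. Any key in encode_kwargs is forwarded unchanged
# to SentenceTransformer.encode when run() is called.
from haystack import Document
from haystack.components.embedders import SentenceTransformersDocumentEmbedder

embedder = SentenceTransformersDocumentEmbedder(
    model="sentence-transformers/all-MiniLM-L6-v2",
    encode_kwargs={"prompt": "passage: "},
)
embedder.warm_up()

result = embedder.run(documents=[Document(content="Haystack is an LLM orchestration framework.")])
print(len(result["documents"][0].embedding))  # the embedding is attached to each document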
@@ -55,6 +55,7 @@ def __init__( # noqa: PLR0913 # pylint: disable=too-many-positional-arguments
model_kwargs: Optional[Dict[str, Any]] = None,
tokenizer_kwargs: Optional[Dict[str, Any]] = None,
config_kwargs: Optional[Dict[str, Any]] = None,
encode_kwargs: Optional[Dict[str, Any]] = None,
precision: Literal["float32", "int8", "uint8", "binary", "ubinary"] = "float32",
):
"""
@@ -99,6 +100,8 @@ def __init__( # noqa: PLR0913 # pylint: disable=too-many-positional-arguments
Refer to specific model documentation for available kwargs.
:param config_kwargs:
Additional keyword arguments for `AutoConfig.from_pretrained` when loading the model configuration.
:param encode_kwargs:
Additional keyword arguments for `SentenceTransformer.encode` when embedding documents.
:param precision:
The precision to use for the embeddings.
All non-float32 precisions are quantized embeddings.
@@ -121,6 +124,7 @@ def __init__( # noqa: PLR0913 # pylint: disable=too-many-positional-arguments
self.model_kwargs = model_kwargs
self.tokenizer_kwargs = tokenizer_kwargs
self.config_kwargs = config_kwargs
self.encode_kwargs = encode_kwargs
self.embedding_backend = None
self.precision = precision

@@ -154,6 +158,7 @@ def to_dict(self) -> Dict[str, Any]:
model_kwargs=self.model_kwargs,
tokenizer_kwargs=self.tokenizer_kwargs,
config_kwargs=self.config_kwargs,
encode_kwargs=self.encode_kwargs,
precision=self.precision,
)
if serialization_dict["init_parameters"].get("model_kwargs") is not None:
@@ -232,6 +237,7 @@ def run(self, documents: List[Document]):
show_progress_bar=self.progress_bar,
normalize_embeddings=self.normalize_embeddings,
precision=self.precision,
**(self.encode_kwargs if self.encode_kwargs else {}),
)

for doc, emb in zip(documents, embeddings):
@@ -49,6 +49,7 @@ def __init__( # noqa: PLR0913 # pylint: disable=too-many-positional-arguments
model_kwargs: Optional[Dict[str, Any]] = None,
tokenizer_kwargs: Optional[Dict[str, Any]] = None,
config_kwargs: Optional[Dict[str, Any]] = None,
encode_kwargs: Optional[Dict[str, Any]] = None,
precision: Literal["float32", "int8", "uint8", "binary", "ubinary"] = "float32",
):
"""
@@ -89,6 +90,8 @@ def __init__( # noqa: PLR0913 # pylint: disable=too-many-positional-arguments
Refer to specific model documentation for available kwargs.
:param config_kwargs:
Additional keyword arguments for `AutoConfig.from_pretrained` when loading the model configuration.
:param encode_kwargs:
Additional keyword arguments for `SentenceTransformer.encode` when embedding texts.
:param precision:
The precision to use for the embeddings.
All non-float32 precisions are quantized embeddings.
@@ -109,6 +112,7 @@ def __init__( # noqa: PLR0913 # pylint: disable=too-many-positional-arguments
self.model_kwargs = model_kwargs
self.tokenizer_kwargs = tokenizer_kwargs
self.config_kwargs = config_kwargs
self.encode_kwargs = encode_kwargs
self.embedding_backend = None
self.precision = precision

@@ -140,6 +144,7 @@ def to_dict(self) -> Dict[str, Any]:
model_kwargs=self.model_kwargs,
tokenizer_kwargs=self.tokenizer_kwargs,
config_kwargs=self.config_kwargs,
encode_kwargs=self.encode_kwargs,
precision=self.precision,
)
if serialization_dict["init_parameters"].get("model_kwargs") is not None:
@@ -209,5 +214,6 @@ def run(self, text: str):
show_progress_bar=self.progress_bar,
normalize_embeddings=self.normalize_embeddings,
precision=self.precision,
**(self.encode_kwargs if self.encode_kwargs else {}),
)[0]
return {"embedding": embedding}
@@ -0,0 +1,6 @@
---
enhancements:
- |
Enhanced `SentenceTransformersDocumentEmbedder` and `SentenceTransformersTextEmbedder` to accept
an additional `encode_kwargs` parameter, whose contents are passed directly to the underlying
`SentenceTransformer.encode` method for greater flexibility in embedding customization.
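The note above covers both components; here is a similarly hedged sketch for the text embedder, mirroring the serialization checks in the tests below (the model name is an assumption, the `task` value is the one the tests use):

# Hedged sketch for SentenceTransformersTextEmbedder (illustrative, not from this PR).
# The model name is an assumption; {"task": "clustering"} matches the tests below.
# encode_kwargs is serialized alongside the other init parameters and, at run time,
# unpacked into the SentenceTransformer.encode call, so its keys must be arguments
# the chosen model's encode() supports.
from haystack.components.embedders import SentenceTransformersTextEmbedder

embedder = SentenceTransformersTextEmbedder(
    model="sentence-transformers/all-MiniLM-L6-v2",
    encode_kwargs={"task": "clustering"},
)

data = embedder.to_dict()
print(data["init_parameters"]["encode_kwargs"])  # {'task': 'clustering'}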
@@ -79,6 +79,7 @@ def test_to_dict(self):
"truncate_dim": None,
"model_kwargs": None,
"tokenizer_kwargs": None,
"encode_kwargs": None,
"config_kwargs": None,
"precision": "float32",
},
@@ -101,6 +102,7 @@ def test_to_dict_with_custom_init_parameters(self):
model_kwargs={"torch_dtype": torch.float32},
tokenizer_kwargs={"model_max_length": 512},
config_kwargs={"use_memory_efficient_attention": True},
encode_kwargs={"task": "clustering"},
precision="int8",
)
data = component.to_dict()
@@ -123,6 +125,7 @@ def test_to_dict_with_custom_init_parameters(self):
"model_kwargs": {"torch_dtype": "torch.float32"},
"tokenizer_kwargs": {"model_max_length": 512},
"config_kwargs": {"use_memory_efficient_attention": True},
"encode_kwargs": {"task": "clustering"},
"precision": "int8",
},
}
@@ -70,6 +70,7 @@ def test_to_dict(self):
"truncate_dim": None,
"model_kwargs": None,
"tokenizer_kwargs": None,
"encode_kwargs": None,
"config_kwargs": None,
"precision": "float32",
},
@@ -90,6 +91,7 @@ def test_to_dict_with_custom_init_parameters(self):
model_kwargs={"torch_dtype": torch.float32},
tokenizer_kwargs={"model_max_length": 512},
config_kwargs={"use_memory_efficient_attention": False},
encode_kwargs={"task": "clustering"},
precision="int8",
)
data = component.to_dict()
@@ -109,6 +111,7 @@ def test_to_dict_with_custom_init_parameters(self):
"model_kwargs": {"torch_dtype": "torch.float32"},
"tokenizer_kwargs": {"model_max_length": 512},
"config_kwargs": {"use_memory_efficient_attention": False},
"encode_kwargs": {"task": "clustering"},
"precision": "int8",
},
}