feat: Add ONNX & OpenVINO backend support, and torch dtype kwargs in Sentence Transformers Components #8813

Open · wants to merge 13 commits into base: main
haystack/components/embedders/backends/sentence_transformers_backend.py
@@ -2,7 +2,7 @@
#
# SPDX-License-Identifier: Apache-2.0

from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Literal, Optional

from haystack.lazy_imports import LazyImport
from haystack.utils.auth import Secret
@@ -28,8 +28,9 @@ def get_embedding_backend( # pylint: disable=too-many-positional-arguments
model_kwargs: Optional[Dict[str, Any]] = None,
tokenizer_kwargs: Optional[Dict[str, Any]] = None,
config_kwargs: Optional[Dict[str, Any]] = None,
backend: Literal["torch", "onnx", "openvino"] = "torch",
):
embedding_backend_id = f"{model}{device}{auth_token}{truncate_dim}"
embedding_backend_id = f"{model}{device}{auth_token}{truncate_dim}{backend}"

if embedding_backend_id in _SentenceTransformersEmbeddingBackendFactory._instances:
return _SentenceTransformersEmbeddingBackendFactory._instances[embedding_backend_id]
@@ -42,6 +43,7 @@ def get_embedding_backend( # pylint: disable=too-many-positional-arguments
model_kwargs=model_kwargs,
tokenizer_kwargs=tokenizer_kwargs,
config_kwargs=config_kwargs,
backend=backend,
)
_SentenceTransformersEmbeddingBackendFactory._instances[embedding_backend_id] = embedding_backend
return embedding_backend
@@ -62,8 +64,10 @@ def __init__( # pylint: disable=too-many-positional-arguments
model_kwargs: Optional[Dict[str, Any]] = None,
tokenizer_kwargs: Optional[Dict[str, Any]] = None,
config_kwargs: Optional[Dict[str, Any]] = None,
backend: Literal["torch", "onnx", "openvino"] = "torch",
):
sentence_transformers_import.check()

self.model = SentenceTransformer(
model_name_or_path=model,
device=device,
@@ -73,6 +77,7 @@ def __init__( # pylint: disable=too-many-positional-arguments
model_kwargs=model_kwargs,
tokenizer_kwargs=tokenizer_kwargs,
config_kwargs=config_kwargs,
backend=backend,
)

def embed(self, data: List[str], **kwargs) -> List[List[float]]:
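
For context: the diff above forwards `backend` straight to the `SentenceTransformer` constructor and adds it to the factory's cache key, so the same model loaded with different backends is cached separately. A minimal sketch of the underlying library call (assuming sentence-transformers >= 3.2, where `SentenceTransformer` accepts a `backend` argument, plus the relevant optimum extras and network access to download the model):

```python
from sentence_transformers import SentenceTransformer

# The factory above keys its cache on f"{model}{device}{auth_token}{truncate_dim}{backend}",
# so these two loads would produce two distinct cached backend instances.
torch_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2", backend="torch")
onnx_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2", backend="onnx")

print(torch_model.encode(["hello world"]).shape)  # (1, 384) for this model
```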
haystack/components/embedders/sentence_transformers_document_embedder.py
@@ -57,6 +57,7 @@ def __init__( # noqa: PLR0913 # pylint: disable=too-many-positional-arguments
config_kwargs: Optional[Dict[str, Any]] = None,
precision: Literal["float32", "int8", "uint8", "binary", "ubinary"] = "float32",
encode_kwargs: Optional[Dict[str, Any]] = None,
backend: Literal["torch", "onnx", "openvino"] = "torch",
):
"""
Creates a SentenceTransformersDocumentEmbedder component.
@@ -109,6 +110,10 @@ def __init__( # noqa: PLR0913 # pylint: disable=too-many-positional-arguments
Additional keyword arguments for `SentenceTransformer.encode` when embedding documents.
This parameter is provided for fine customization. Be careful not to clash with already set parameters and
avoid passing parameters that change the output type.
:param backend:
The backend to use for the Sentence Transformers model. Choose from "torch", "onnx", or "openvino".
Refer to the [Sentence Transformers documentation](https://sbert.net/docs/sentence_transformer/usage/efficiency.html)
for more information on acceleration and quantization options.
"""

self.model = model
@@ -129,6 +134,7 @@ def __init__( # noqa: PLR0913 # pylint: disable=too-many-positional-arguments
self.encode_kwargs = encode_kwargs
self.embedding_backend = None
self.precision = precision
self.backend = backend

def _get_telemetry_data(self) -> Dict[str, Any]:
"""
@@ -162,6 +168,7 @@ def to_dict(self) -> Dict[str, Any]:
config_kwargs=self.config_kwargs,
precision=self.precision,
encode_kwargs=self.encode_kwargs,
backend=self.backend,
)
if serialization_dict["init_parameters"].get("model_kwargs") is not None:
serialize_hf_model_kwargs(serialization_dict["init_parameters"]["model_kwargs"])
@@ -199,6 +206,7 @@ def warm_up(self):
model_kwargs=self.model_kwargs,
tokenizer_kwargs=self.tokenizer_kwargs,
config_kwargs=self.config_kwargs,
backend=self.backend,
)
if self.tokenizer_kwargs and self.tokenizer_kwargs.get("model_max_length"):
self.embedding_backend.model.max_seq_length = self.tokenizer_kwargs["model_max_length"]
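
To make the new parameter concrete, here is a usage sketch that mirrors the ONNX integration test added later in this PR (it assumes the `optimum[onnxruntime]` extra is installed and downloads the model from the Hub):

```python
from haystack import Document
from haystack.components.embedders import SentenceTransformersDocumentEmbedder

embedder = SentenceTransformersDocumentEmbedder(
    model="sentence-transformers/all-MiniLM-L6-v2",
    backend="onnx",
    # Optional: this repo already ships "onnx/model.onnx"; naming the file
    # explicitly just avoids a Hugging Face warning.
    model_kwargs={"file_name": "onnx/model.onnx"},
)
embedder.warm_up()

result = embedder.run(documents=[Document(content="document number 0")])
print(len(result["documents"][0].embedding))  # 384 for this model
```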
haystack/components/embedders/sentence_transformers_text_embedder.py
@@ -51,6 +51,7 @@ def __init__( # noqa: PLR0913 # pylint: disable=too-many-positional-arguments
config_kwargs: Optional[Dict[str, Any]] = None,
precision: Literal["float32", "int8", "uint8", "binary", "ubinary"] = "float32",
encode_kwargs: Optional[Dict[str, Any]] = None,
backend: Literal["torch", "onnx", "openvino"] = "torch",
):
"""
Create a SentenceTransformersTextEmbedder component.
@@ -99,6 +100,10 @@ def __init__( # noqa: PLR0913 # pylint: disable=too-many-positional-arguments
Additional keyword arguments for `SentenceTransformer.encode` when embedding texts.
This parameter is provided for fine customization. Be careful not to clash with already set parameters and
avoid passing parameters that change the output type.
:param backend:
The backend to use for the Sentence Transformers model. Choose from "torch", "onnx", or "openvino".
Refer to the [Sentence Transformers documentation](https://sbert.net/docs/sentence_transformer/usage/efficiency.html)
for more information on acceleration and quantization options.
"""

self.model = model
@@ -117,6 +122,7 @@ def __init__( # noqa: PLR0913 # pylint: disable=too-many-positional-arguments
self.encode_kwargs = encode_kwargs
self.embedding_backend = None
self.precision = precision
self.backend = backend

def _get_telemetry_data(self) -> Dict[str, Any]:
"""
@@ -148,6 +154,7 @@ def to_dict(self) -> Dict[str, Any]:
config_kwargs=self.config_kwargs,
precision=self.precision,
encode_kwargs=self.encode_kwargs,
backend=self.backend,
)
if serialization_dict["init_parameters"].get("model_kwargs") is not None:
serialize_hf_model_kwargs(serialization_dict["init_parameters"]["model_kwargs"])
@@ -185,6 +192,7 @@ def warm_up(self):
model_kwargs=self.model_kwargs,
tokenizer_kwargs=self.tokenizer_kwargs,
config_kwargs=self.config_kwargs,
backend=self.backend,
)
if self.tokenizer_kwargs and self.tokenizer_kwargs.get("model_max_length"):
self.embedding_backend.model.max_seq_length = self.tokenizer_kwargs["model_max_length"]
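
The text embedder gains the same parameter; an analogous sketch with the OpenVINO backend (assumes the `optimum-intel[openvino]` extra added to pyproject.toml below):

```python
from haystack.components.embedders import SentenceTransformersTextEmbedder

text_embedder = SentenceTransformersTextEmbedder(
    model="sentence-transformers/all-MiniLM-L6-v2",
    backend="openvino",
)
text_embedder.warm_up()

embedding = text_embedder.run(text="document number 0")["embedding"]
print(len(embedding))  # 384 for this model
```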
42 changes: 38 additions & 4 deletions haystack/components/rankers/sentence_transformers_diversity.py
@@ -3,11 +3,12 @@
# SPDX-License-Identifier: Apache-2.0

from enum import Enum
from typing import Any, Dict, List, Optional, Union
from typing import Any, Dict, List, Literal, Optional, Union

from haystack import Document, component, default_from_dict, default_to_dict, logging
from haystack.lazy_imports import LazyImport
from haystack.utils import ComponentDevice, Secret, deserialize_secrets_inplace
from haystack.utils.hf import deserialize_hf_model_kwargs, serialize_hf_model_kwargs

logger = logging.getLogger(__name__)

@@ -111,7 +112,7 @@ class SentenceTransformersDiversityRanker:
```
""" # noqa: E501

def __init__(
def __init__( # noqa: PLR0913 # pylint: disable=too-many-positional-arguments
self,
model: str = "sentence-transformers/all-MiniLM-L6-v2",
top_k: int = 10,
@@ -126,7 +127,11 @@ def __init__(
embedding_separator: str = "\n",
strategy: Union[str, DiversityRankingStrategy] = "greedy_diversity_order",
lambda_threshold: float = 0.5,
): # pylint: disable=too-many-positional-arguments
model_kwargs: Optional[Dict[str, Any]] = None,
tokenizer_kwargs: Optional[Dict[str, Any]] = None,
config_kwargs: Optional[Dict[str, Any]] = None,
backend: Literal["torch", "onnx", "openvino"] = "torch",
):
"""
Initialize a SentenceTransformersDiversityRanker.
@@ -152,6 +157,18 @@ def __init__(
"maximum_margin_relevance".
:param lambda_threshold: The trade-off parameter between relevance and diversity. Only used when strategy is
"maximum_margin_relevance".
:param model_kwargs:
Additional keyword arguments for `AutoModelForSequenceClassification.from_pretrained`
when loading the model. Refer to specific model documentation for available kwargs.
:param tokenizer_kwargs:
Additional keyword arguments for `AutoTokenizer.from_pretrained` when loading the tokenizer.
Refer to specific model documentation for available kwargs.
:param config_kwargs:
Additional keyword arguments for `AutoConfig.from_pretrained` when loading the model configuration.
:param backend:
The backend to use for the Sentence Transformers model. Choose from "torch", "onnx", or "openvino".
Refer to the [Sentence Transformers documentation](https://sbert.net/docs/sentence_transformer/usage/efficiency.html)
for more information on acceleration and quantization options.
"""
torch_and_sentence_transformers_import.check()

@@ -172,6 +189,10 @@ def __init__(
self.strategy = DiversityRankingStrategy.from_str(strategy) if isinstance(strategy, str) else strategy
self.lambda_threshold = lambda_threshold or 0.5
self._check_lambda_threshold(self.lambda_threshold, self.strategy)
self.model_kwargs = model_kwargs
self.tokenizer_kwargs = tokenizer_kwargs
self.config_kwargs = config_kwargs
self.backend = backend

def warm_up(self):
"""
@@ -182,6 +203,10 @@ def warm_up(self):
model_name_or_path=self.model_name_or_path,
device=self.device.to_torch_str(),
use_auth_token=self.token.resolve_value() if self.token else None,
model_kwargs=self.model_kwargs,
tokenizer_kwargs=self.tokenizer_kwargs,
config_kwargs=self.config_kwargs,
backend=self.backend,
)

def to_dict(self) -> Dict[str, Any]:
@@ -191,7 +216,7 @@ def to_dict(self) -> Dict[str, Any]:
:returns:
Dictionary with serialized data.
"""
return default_to_dict(
serialization_dict = default_to_dict(
self,
model=self.model_name_or_path,
top_k=self.top_k,
@@ -206,7 +231,14 @@
embedding_separator=self.embedding_separator,
strategy=str(self.strategy),
lambda_threshold=self.lambda_threshold,
model_kwargs=self.model_kwargs,
tokenizer_kwargs=self.tokenizer_kwargs,
config_kwargs=self.config_kwargs,
backend=self.backend,
)
if serialization_dict["init_parameters"].get("model_kwargs") is not None:
serialize_hf_model_kwargs(serialization_dict["init_parameters"]["model_kwargs"])
return serialization_dict

@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "SentenceTransformersDiversityRanker":
@@ -222,6 +254,8 @@ def from_dict(cls, data: Dict[str, Any]) -> "SentenceTransformersDiversityRanker":
if init_params.get("device") is not None:
init_params["device"] = ComponentDevice.from_dict(init_params["device"])
deserialize_secrets_inplace(init_params, keys=["token"])
if init_params.get("model_kwargs") is not None:
deserialize_hf_model_kwargs(init_params["model_kwargs"])
return default_from_dict(cls, data)

def _prepare_texts_to_embed(self, documents: List[Document]) -> List[str]:
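
A usage sketch for the ranker with the new pass-through kwargs (the `torch_dtype` value follows the dtype tests later in this diff; half precision mainly pays off on GPU, so treat this as illustrative):

```python
from haystack import Document
from haystack.components.rankers import SentenceTransformersDiversityRanker

ranker = SentenceTransformersDiversityRanker(
    model="sentence-transformers/all-MiniLM-L6-v2",
    backend="torch",
    model_kwargs={"torch_dtype": "float16"},  # the torch dtype kwargs from this PR's title
)
ranker.warm_up()

result = ranker.run(
    query="renewable energy",
    documents=[Document(content="solar power"), Document(content="football scores")],
)
print([doc.content for doc in result["documents"]])
```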
2 changes: 2 additions & 0 deletions pyproject.toml
@@ -89,6 +89,8 @@ extra-dependencies = [
"transformers[torch,sentencepiece]==4.47.1", # ExtractiveReader, TransformersSimilarityRanker, LocalWhisperTranscriber, HFGenerators...
"huggingface_hub>=0.27.0", # Hugging Face API Generators and Embedders
"sentence-transformers>=3.0.0", # SentenceTransformersTextEmbedder and SentenceTransformersDocumentEmbedder
"optimum[onnxruntime]>=1.24.0", # ONNX CPU runtime for Sentence Transformers components
"optimum-intel[openvino]>=1.22.0", # OpenVINO runtime for Sentence Transformers components (older versions will not work with transformers>=4.7.0)
"langdetect", # TextLanguageRouter and DocumentLanguageClassifier
"openai-whisper>=20231106", # LocalWhisperTranscriber
"arrow>=1.3.0", # Jinja2TimeExtension
New file: release note
@@ -0,0 +1,5 @@
---
features:
- |
Sentence Transformers components now support ONNX and OpenVINO backends through the "backend" parameter.
Supported backends are torch (default), onnx, and openvino. Refer to the [Sentence Transformers documentation](https://sbert.net/docs/sentence_transformer/usage/efficiency.html) for more information.
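
In practice this means `backend` is a regular, serializable init parameter. A quick sketch of the round-trip behavior the tests below verify (construction alone needs no model download; `to_dict` works before `warm_up`):

```python
from haystack.components.embedders import SentenceTransformersDocumentEmbedder

embedder = SentenceTransformersDocumentEmbedder(
    model="sentence-transformers/all-MiniLM-L6-v2",
    backend="onnx",
)
config = embedder.to_dict()
assert config["init_parameters"]["backend"] == "onnx"

restored = SentenceTransformersDocumentEmbedder.from_dict(config)
assert restored.backend == "onnx"
```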
test/components/embedders/test_sentence_transformers_document_embedder.py
@@ -82,6 +82,7 @@ def test_to_dict(self):
"encode_kwargs": None,
"config_kwargs": None,
"precision": "float32",
"backend": "torch",
},
}

@@ -127,6 +128,7 @@ def test_to_dict_with_custom_init_parameters(self):
"config_kwargs": {"use_memory_efficient_attention": True},
"precision": "int8",
"encode_kwargs": {"task": "clustering"},
"backend": "torch",
},
}

@@ -252,6 +254,7 @@ def test_warmup(self, mocked_factory):
model_kwargs=None,
tokenizer_kwargs={"model_max_length": 512},
config_kwargs={"use_memory_efficient_attention": True},
backend="torch",
)

@patch(
@@ -357,3 +360,63 @@ def test_prefix_suffix(self):
normalize_embeddings=False,
precision="float32",
)

@pytest.mark.integration
def test_model_onnx_quantization(self, monkeypatch):
monkeypatch.delenv("HF_API_TOKEN", raising=False) # https://github.com/deepset-ai/haystack/issues/8811
documents = [Document(content="document number 0"), Document(content="document number 1")]
onnx_embedder = SentenceTransformersDocumentEmbedder(
model="sentence-transformers/all-MiniLM-L6-v2",
backend="onnx",
model_kwargs={
"file_name": "onnx/model.onnx"
}, # setting the path isn't necessary if the repo contains an "onnx/model.onnx" file, but it prevents an HF warning
)
onnx_embedder.warm_up()

result = onnx_embedder.run(documents=documents)

assert len(result["documents"]) == 2
assert len(result["documents"][0].embedding) == 384
assert len(result["documents"][1].embedding) == 384
assert result["documents"][0].embedding[0] == pytest.approx(0.0, abs=0.1)

@pytest.mark.integration
def test_model_openvino_quantization(self, monkeypatch):
monkeypatch.delenv("HF_API_TOKEN", raising=False) # https://github.com/deepset-ai/haystack/issues/8811
documents = [Document(content="document number 0"), Document(content="document number 1")]
openvino_embedder = SentenceTransformersDocumentEmbedder(
model="sentence-transformers/all-MiniLM-L6-v2",
backend="openvino",
model_kwargs={
"file_name": "openvino/openvino_model.xml"
}, # setting the path isn't necessary if the repo contains an "openvino/openvino_model.xml" file, but it prevents an HF warning
)
openvino_embedder.warm_up()

result = openvino_embedder.run(documents=documents)

assert len(result["documents"]) == 2
assert len(result["documents"][0].embedding) == 384
assert len(result["documents"][1].embedding) == 384
assert result["documents"][0].embedding[0] == pytest.approx(0.0, abs=0.1)

@pytest.mark.skip(reason="Test env doesn't compile Torch with CUDA support")
@pytest.mark.integration
@pytest.mark.parametrize("model_kwargs", [{"torch_dtype": "bfloat16"}, {"torch_dtype": "float16"}])
Review comment from a Contributor on lines +404 to +406:

> Thanks for writing this, but we don't plan on adding GPU support to our CI any time soon, so it'd be best if we could remove this.

def test_dtype_on_gpu(self, model_kwargs, monkeypatch):
monkeypatch.delenv("HF_API_TOKEN", raising=False) # https://github.com/deepset-ai/haystack/issues/8811
documents = [Document(content="document number 0"), Document(content="document number 1")]
torch_dtype_embedder = SentenceTransformersDocumentEmbedder(
model="sentence-transformers/all-MiniLM-L6-v2",
device=ComponentDevice.from_str("cuda:0"),
model_kwargs=model_kwargs,
)
torch_dtype_embedder.warm_up()

result = torch_dtype_embedder.run(documents=documents)

assert len(result["documents"]) == 2
assert len(result["documents"][0].embedding) == 384
assert len(result["documents"][1].embedding) == 384
assert result["documents"][0].embedding[0] == pytest.approx(0.0, abs=0.1)
test/components/embedders/test_sentence_transformers_embedding_backend.py
@@ -33,6 +33,7 @@ def test_model_initialization(mock_sentence_transformer):
auth_token=Secret.from_token("fake-api-token"),
trust_remote_code=True,
truncate_dim=256,
backend="torch",
)
mock_sentence_transformer.assert_called_once_with(
model_name_or_path="model",
Expand All @@ -43,6 +44,7 @@ def test_model_initialization(mock_sentence_transformer):
model_kwargs=None,
tokenizer_kwargs=None,
config_kwargs=None,
backend="torch",
)

