From a90a2497aced5e07ee48b0f22d8b0cb89c9d259c Mon Sep 17 00:00:00 2001 From: Ulises M Date: Wed, 5 Feb 2025 14:48:32 -0800 Subject: [PATCH] expose backend for ST diversity backend --- .../backends/sentence_transformers_backend.py | 9 ++------- .../rankers/sentence_transformers_diversity.py | 5 ++++- .../test_sentence_transformers_document_embedder.py | 6 ++++-- .../test_sentence_transformers_text_embedder.py | 6 ++++-- .../rankers/test_sentence_transformers_diversity.py | 13 ++++++++++--- 5 files changed, 24 insertions(+), 15 deletions(-) diff --git a/haystack/components/embedders/backends/sentence_transformers_backend.py b/haystack/components/embedders/backends/sentence_transformers_backend.py index 252196e2f7..78a3c4806c 100644 --- a/haystack/components/embedders/backends/sentence_transformers_backend.py +++ b/haystack/components/embedders/backends/sentence_transformers_backend.py @@ -28,7 +28,7 @@ def get_embedding_backend( # pylint: disable=too-many-positional-arguments model_kwargs: Optional[Dict[str, Any]] = None, tokenizer_kwargs: Optional[Dict[str, Any]] = None, config_kwargs: Optional[Dict[str, Any]] = None, - backend: Optional[Literal["torch", "onnx", "openvino"]] = "torch", + backend: Literal["torch", "onnx", "openvino"] = "torch", ): embedding_backend_id = f"{model}{device}{auth_token}{truncate_dim}{backend}" @@ -64,15 +64,10 @@ def __init__( # pylint: disable=too-many-positional-arguments model_kwargs: Optional[Dict[str, Any]] = None, tokenizer_kwargs: Optional[Dict[str, Any]] = None, config_kwargs: Optional[Dict[str, Any]] = None, - backend: Optional[Literal["torch", "onnx", "openvino"]] = "torch", + backend: Literal["torch", "onnx", "openvino"] = "torch", ): sentence_transformers_import.check() - # this line is necessary to avoid an arg-type error for the type checker - # or we can make backend not optional - if backend is None: - backend = "torch" - self.model = SentenceTransformer( model_name_or_path=model, device=device, diff --git a/haystack/components/rankers/sentence_transformers_diversity.py b/haystack/components/rankers/sentence_transformers_diversity.py index 88fdc6aaaf..3d13e74d6a 100644 --- a/haystack/components/rankers/sentence_transformers_diversity.py +++ b/haystack/components/rankers/sentence_transformers_diversity.py @@ -3,7 +3,7 @@ # SPDX-License-Identifier: Apache-2.0 from enum import Enum -from typing import Any, Dict, List, Optional, Union +from typing import Any, Dict, List, Literal, Optional, Union from haystack import Document, component, default_from_dict, default_to_dict, logging from haystack.lazy_imports import LazyImport @@ -126,6 +126,7 @@ def __init__( embedding_separator: str = "\n", strategy: Union[str, DiversityRankingStrategy] = "greedy_diversity_order", lambda_threshold: float = 0.5, + backend: Literal["torch", "onnx", "openvino"] = "torch", ): # pylint: disable=too-many-positional-arguments """ Initialize a SentenceTransformersDiversityRanker. @@ -172,6 +173,7 @@ def __init__( self.strategy = DiversityRankingStrategy.from_str(strategy) if isinstance(strategy, str) else strategy self.lambda_threshold = lambda_threshold or 0.5 self._check_lambda_threshold(self.lambda_threshold, self.strategy) + self.backend = backend def warm_up(self): """ @@ -182,6 +184,7 @@ def warm_up(self): model_name_or_path=self.model_name_or_path, device=self.device.to_torch_str(), use_auth_token=self.token.resolve_value() if self.token else None, + backend=self.backend, ) def to_dict(self) -> Dict[str, Any]: diff --git a/test/components/embedders/test_sentence_transformers_document_embedder.py b/test/components/embedders/test_sentence_transformers_document_embedder.py index ccdd62df5c..9e1d77a9fc 100644 --- a/test/components/embedders/test_sentence_transformers_document_embedder.py +++ b/test/components/embedders/test_sentence_transformers_document_embedder.py @@ -362,7 +362,8 @@ def test_prefix_suffix(self): ) @pytest.mark.integration - def test_model_onnx_quantization(self): + def test_model_onnx_quantization(self, monkeypatch): + monkeypatch.delenv("HF_API_TOKEN", raising=False) # https://github.com/deepset-ai/haystack/issues/8811 documents = [Document(content="document number 0"), Document(content="document number 1")] onnx_embedder = SentenceTransformersDocumentEmbedder( model="sentence-transformers/all-MiniLM-L6-v2", @@ -383,7 +384,8 @@ def test_model_onnx_quantization(self): @pytest.mark.skip( reason="OpenVINO backend does not support our current transformers + sentence transformers dependencies versions" ) - def test_model_openvino_quantization(self): + def test_model_openvino_quantization(self, monkeypatch): + monkeypatch.delenv("HF_API_TOKEN", raising=False) # https://github.com/deepset-ai/haystack/issues/8811 documents = [Document(content="document number 0"), Document(content="document number 1")] openvino_embedder = SentenceTransformersDocumentEmbedder( model="sentence-transformers/all-MiniLM-L6-v2", diff --git a/test/components/embedders/test_sentence_transformers_text_embedder.py b/test/components/embedders/test_sentence_transformers_text_embedder.py index 790a271c3f..2df029c0f3 100644 --- a/test/components/embedders/test_sentence_transformers_text_embedder.py +++ b/test/components/embedders/test_sentence_transformers_text_embedder.py @@ -319,7 +319,8 @@ def test_embed_encode_kwargs(self): ) @pytest.mark.integration - def test_model_onnx_quantization(self): + def test_model_onnx_quantization(self, monkeypatch): + monkeypatch.delenv("HF_API_TOKEN", raising=False) # https://github.com/deepset-ai/haystack/issues/8811 text = "a nice text to embed" onnx_embedder_def = SentenceTransformersTextEmbedder( @@ -340,7 +341,8 @@ def test_model_onnx_quantization(self): @pytest.mark.skip( reason="OpenVINO backend does not support our current transformers + sentence transformers dependencies versions" ) - def test_model_openvino_quantization(self): + def test_model_openvino_quantization(self, monkeypatch): + monkeypatch.delenv("HF_API_TOKEN", raising=False) # https://github.com/deepset-ai/haystack/issues/8811 text = "a nice text to embed" openvino_embedder_def = SentenceTransformersTextEmbedder( diff --git a/test/components/rankers/test_sentence_transformers_diversity.py b/test/components/rankers/test_sentence_transformers_diversity.py index a4dcd7e89f..1bc8a6f83d 100644 --- a/test/components/rankers/test_sentence_transformers_diversity.py +++ b/test/components/rankers/test_sentence_transformers_diversity.py @@ -291,6 +291,7 @@ def test_warm_up(self, similarity, monkeypatch): model_name_or_path="mock_model_name", device=ComponentDevice.resolve_device(None).to_torch_str(), use_auth_token=None, + backend="torch", ) assert ranker.model == mock_model_instance @@ -674,8 +675,11 @@ def test_run_real_world_use_case(self, similarity, monkeypatch): assert result_content == expected_content @pytest.mark.integration - @pytest.mark.parametrize("similarity", ["dot_product", "cosine"]) - def test_run_with_maximum_margin_relevance_strategy(self, similarity, monkeypatch): + @pytest.mark.parametrize( + "similarity,backend", + [("dot_product", "torch"), ("dot_product", "onnx"), ("cosine", "torch"), ("cosine", "onnx")], + ) # we don't use "openvino" due to dependency issues + def test_run_with_maximum_margin_relevance_strategy(self, similarity, backend, monkeypatch): monkeypatch.delenv("HF_API_TOKEN", raising=False) # https://github.com/deepset-ai/haystack/issues/8811 query = "renewable energy sources" docs = [ @@ -690,7 +694,10 @@ def test_run_with_maximum_margin_relevance_strategy(self, similarity, monkeypatc ] ranker = SentenceTransformersDiversityRanker( - model="sentence-transformers/all-MiniLM-L6-v2", similarity=similarity, strategy="maximum_margin_relevance" + model="sentence-transformers/all-MiniLM-L6-v2", + similarity=similarity, + strategy="maximum_margin_relevance", + backend=backend, ) ranker.warm_up()