Skip to content

Commit

Permalink
expose backend for ST diversity backend
Browse files Browse the repository at this point in the history
  • Loading branch information
lbux committed Feb 5, 2025
1 parent a0c9531 commit a90a249
Show file tree
Hide file tree
Showing 5 changed files with 24 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def get_embedding_backend( # pylint: disable=too-many-positional-arguments
model_kwargs: Optional[Dict[str, Any]] = None,
tokenizer_kwargs: Optional[Dict[str, Any]] = None,
config_kwargs: Optional[Dict[str, Any]] = None,
backend: Optional[Literal["torch", "onnx", "openvino"]] = "torch",
backend: Literal["torch", "onnx", "openvino"] = "torch",
):
embedding_backend_id = f"{model}{device}{auth_token}{truncate_dim}{backend}"

Expand Down Expand Up @@ -64,15 +64,10 @@ def __init__( # pylint: disable=too-many-positional-arguments
model_kwargs: Optional[Dict[str, Any]] = None,
tokenizer_kwargs: Optional[Dict[str, Any]] = None,
config_kwargs: Optional[Dict[str, Any]] = None,
backend: Optional[Literal["torch", "onnx", "openvino"]] = "torch",
backend: Literal["torch", "onnx", "openvino"] = "torch",
):
sentence_transformers_import.check()

# this line is necessary to avoid an arg-type error for the type checker
# or we can make backend not optional
if backend is None:
backend = "torch"

self.model = SentenceTransformer(
model_name_or_path=model,
device=device,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
# SPDX-License-Identifier: Apache-2.0

from enum import Enum
from typing import Any, Dict, List, Optional, Union
from typing import Any, Dict, List, Literal, Optional, Union

from haystack import Document, component, default_from_dict, default_to_dict, logging
from haystack.lazy_imports import LazyImport
Expand Down Expand Up @@ -126,6 +126,7 @@ def __init__(
embedding_separator: str = "\n",
strategy: Union[str, DiversityRankingStrategy] = "greedy_diversity_order",
lambda_threshold: float = 0.5,
backend: Literal["torch", "onnx", "openvino"] = "torch",
): # pylint: disable=too-many-positional-arguments
"""
Initialize a SentenceTransformersDiversityRanker.
Expand Down Expand Up @@ -172,6 +173,7 @@ def __init__(
self.strategy = DiversityRankingStrategy.from_str(strategy) if isinstance(strategy, str) else strategy
self.lambda_threshold = lambda_threshold or 0.5
self._check_lambda_threshold(self.lambda_threshold, self.strategy)
self.backend = backend

def warm_up(self):
"""
Expand All @@ -182,6 +184,7 @@ def warm_up(self):
model_name_or_path=self.model_name_or_path,
device=self.device.to_torch_str(),
use_auth_token=self.token.resolve_value() if self.token else None,
backend=self.backend,
)

def to_dict(self) -> Dict[str, Any]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -362,7 +362,8 @@ def test_prefix_suffix(self):
)

@pytest.mark.integration
def test_model_onnx_quantization(self):
def test_model_onnx_quantization(self, monkeypatch):
monkeypatch.delenv("HF_API_TOKEN", raising=False) # https://github.com/deepset-ai/haystack/issues/8811
documents = [Document(content="document number 0"), Document(content="document number 1")]
onnx_embedder = SentenceTransformersDocumentEmbedder(
model="sentence-transformers/all-MiniLM-L6-v2",
Expand All @@ -383,7 +384,8 @@ def test_model_onnx_quantization(self):
@pytest.mark.skip(
reason="OpenVINO backend does not support our current transformers + sentence transformers dependencies versions"
)
def test_model_openvino_quantization(self):
def test_model_openvino_quantization(self, monkeypatch):
monkeypatch.delenv("HF_API_TOKEN", raising=False) # https://github.com/deepset-ai/haystack/issues/8811
documents = [Document(content="document number 0"), Document(content="document number 1")]
openvino_embedder = SentenceTransformersDocumentEmbedder(
model="sentence-transformers/all-MiniLM-L6-v2",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -319,7 +319,8 @@ def test_embed_encode_kwargs(self):
)

@pytest.mark.integration
def test_model_onnx_quantization(self):
def test_model_onnx_quantization(self, monkeypatch):
monkeypatch.delenv("HF_API_TOKEN", raising=False) # https://github.com/deepset-ai/haystack/issues/8811
text = "a nice text to embed"

onnx_embedder_def = SentenceTransformersTextEmbedder(
Expand All @@ -340,7 +341,8 @@ def test_model_onnx_quantization(self):
@pytest.mark.skip(
reason="OpenVINO backend does not support our current transformers + sentence transformers dependencies versions"
)
def test_model_openvino_quantization(self):
def test_model_openvino_quantization(self, monkeypatch):
monkeypatch.delenv("HF_API_TOKEN", raising=False) # https://github.com/deepset-ai/haystack/issues/8811
text = "a nice text to embed"

openvino_embedder_def = SentenceTransformersTextEmbedder(
Expand Down
13 changes: 10 additions & 3 deletions test/components/rankers/test_sentence_transformers_diversity.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,6 +291,7 @@ def test_warm_up(self, similarity, monkeypatch):
model_name_or_path="mock_model_name",
device=ComponentDevice.resolve_device(None).to_torch_str(),
use_auth_token=None,
backend="torch",
)
assert ranker.model == mock_model_instance

Expand Down Expand Up @@ -674,8 +675,11 @@ def test_run_real_world_use_case(self, similarity, monkeypatch):
assert result_content == expected_content

@pytest.mark.integration
@pytest.mark.parametrize("similarity", ["dot_product", "cosine"])
def test_run_with_maximum_margin_relevance_strategy(self, similarity, monkeypatch):
@pytest.mark.parametrize(
"similarity,backend",
[("dot_product", "torch"), ("dot_product", "onnx"), ("cosine", "torch"), ("cosine", "onnx")],
) # we don't use "openvino" due to dependency issues
def test_run_with_maximum_margin_relevance_strategy(self, similarity, backend, monkeypatch):
monkeypatch.delenv("HF_API_TOKEN", raising=False) # https://github.com/deepset-ai/haystack/issues/8811
query = "renewable energy sources"
docs = [
Expand All @@ -690,7 +694,10 @@ def test_run_with_maximum_margin_relevance_strategy(self, similarity, monkeypatc
]

ranker = SentenceTransformersDiversityRanker(
model="sentence-transformers/all-MiniLM-L6-v2", similarity=similarity, strategy="maximum_margin_relevance"
model="sentence-transformers/all-MiniLM-L6-v2",
similarity=similarity,
strategy="maximum_margin_relevance",
backend=backend,
)
ranker.warm_up()

Expand Down

0 comments on commit a90a249

Please sign in to comment.