Skip to content

Commit

Permalink
test: added tests for encode_kwargs in SentenceTransformersTextEmbedd…
Browse files Browse the repository at this point in the history
…er and SentenceTransformersDocumentEmbedder
  • Loading branch information
oroszgy committed Feb 5, 2025
1 parent c07ecbb commit 088e43b
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 3 deletions.
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
# SPDX-FileCopyrightText: 2022-present deepset GmbH <[email protected]>
#
# SPDX-License-Identifier: Apache-2.0
import random
from unittest.mock import MagicMock, patch

import random
import pytest
import torch

Expand Down Expand Up @@ -319,6 +319,20 @@ def test_embed_metadata(self):
precision="float32",
)

def test_embed_encode_kwargs(self):
    """
    Check that entries of `encode_kwargs` reach the embedding backend as keyword arguments.

    The real backend is replaced by a mock, so no model is loaded; only the
    arguments of the single `embed` call are inspected.
    """
    expected_texts = [f"document number {idx}" for idx in range(5)]
    # Everything the backend is expected to receive besides the texts themselves;
    # `task` is the entry supplied through `encode_kwargs`.
    expected_kwargs = {
        "batch_size": 32,
        "show_progress_bar": True,
        "normalize_embeddings": False,
        "precision": "float32",
        "task": "retrieval.passage",
    }

    embedder = SentenceTransformersDocumentEmbedder(model="model", encode_kwargs={"task": "retrieval.passage"})
    backend = MagicMock()
    embedder.embedding_backend = backend

    embedder.run(documents=[Document(content=content) for content in expected_texts])

    backend.embed.assert_called_once_with(expected_texts, **expected_kwargs)

def test_prefix_suffix(self):
embedder = SentenceTransformersDocumentEmbedder(
model="model",
Expand Down
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
# SPDX-FileCopyrightText: 2022-present deepset GmbH <[email protected]>
#
# SPDX-License-Identifier: Apache-2.0
import random
from unittest.mock import MagicMock, patch

import torch
import random
import pytest
import torch

from haystack.components.embedders.sentence_transformers_text_embedder import SentenceTransformersTextEmbedder
from haystack.utils import ComponentDevice, Secret
Expand Down Expand Up @@ -300,3 +300,17 @@ def test_run_quantization(self):

assert len(embedding_def) == 768
assert all(isinstance(el, int) for el in embedding_def)

def test_embed_encode_kwargs(self):
    """
    Check that entries of `encode_kwargs` reach the embedding backend as keyword arguments.

    The real backend is replaced by a mock, so no model is loaded; only the
    arguments of the single `embed` call are inspected.
    """
    query = "a nice text to embed"
    # Everything the backend is expected to receive besides the text itself;
    # `task` is the entry supplied through `encode_kwargs`.
    expected_kwargs = {
        "batch_size": 32,
        "show_progress_bar": True,
        "normalize_embeddings": False,
        "precision": "float32",
        "task": "retrieval.query",
    }

    embedder = SentenceTransformersTextEmbedder(model="model", encode_kwargs={"task": "retrieval.query"})
    backend = MagicMock()
    embedder.embedding_backend = backend

    embedder.run(text=query)

    # The single input text is expected to be wrapped in a one-element list.
    backend.embed.assert_called_once_with([query], **expected_kwargs)

0 comments on commit 088e43b

Please sign in to comment.