From 088e43b2d86fae4367749a45063cdb97ef69b80b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Gy=C3=B6rgy=20Orosz?= Date: Wed, 5 Feb 2025 16:35:54 +0100 Subject: [PATCH] test: added tests for encode_kwargs in SentenceTransformersTextEmbedder and SentenceTransformersDocumentEmbedder --- ..._sentence_transformers_document_embedder.py | 16 +++++++++++++++- ...test_sentence_transformers_text_embedder.py | 18 ++++++++++++++++-- 2 files changed, 31 insertions(+), 3 deletions(-) diff --git a/test/components/embedders/test_sentence_transformers_document_embedder.py b/test/components/embedders/test_sentence_transformers_document_embedder.py index b15b74c519..205feac10f 100644 --- a/test/components/embedders/test_sentence_transformers_document_embedder.py +++ b/test/components/embedders/test_sentence_transformers_document_embedder.py @@ -1,9 +1,9 @@ # SPDX-FileCopyrightText: 2022-present deepset GmbH # # SPDX-License-Identifier: Apache-2.0 +import random from unittest.mock import MagicMock, patch -import random import pytest import torch @@ -319,6 +319,20 @@ def test_embed_metadata(self): precision="float32", ) + def test_embed_encode_kwargs(self): + embedder = SentenceTransformersDocumentEmbedder(model="model", encode_kwargs={"task": "retrieval.passage"}) + embedder.embedding_backend = MagicMock() + documents = [Document(content=f"document number {i}") for i in range(5)] + embedder.run(documents=documents) + embedder.embedding_backend.embed.assert_called_once_with( + ["document number 0", "document number 1", "document number 2", "document number 3", "document number 4"], + batch_size=32, + show_progress_bar=True, + normalize_embeddings=False, + precision="float32", + task="retrieval.passage", + ) + def test_prefix_suffix(self): embedder = SentenceTransformersDocumentEmbedder( model="model", diff --git a/test/components/embedders/test_sentence_transformers_text_embedder.py b/test/components/embedders/test_sentence_transformers_text_embedder.py index 1ae37244db..6d0e239a15 100644 --- a/test/components/embedders/test_sentence_transformers_text_embedder.py +++ b/test/components/embedders/test_sentence_transformers_text_embedder.py @@ -1,11 +1,11 @@ # SPDX-FileCopyrightText: 2022-present deepset GmbH # # SPDX-License-Identifier: Apache-2.0 +import random from unittest.mock import MagicMock, patch -import torch -import random import pytest +import torch from haystack.components.embedders.sentence_transformers_text_embedder import SentenceTransformersTextEmbedder from haystack.utils import ComponentDevice, Secret @@ -300,3 +300,17 @@ def test_run_quantization(self): assert len(embedding_def) == 768 assert all(isinstance(el, int) for el in embedding_def) + + def test_embed_encode_kwargs(self): + embedder = SentenceTransformersTextEmbedder(model="model", encode_kwargs={"task": "retrieval.query"}) + embedder.embedding_backend = MagicMock() + text = "a nice text to embed" + embedder.run(text=text) + embedder.embedding_backend.embed.assert_called_once_with( + [text], + batch_size=32, + show_progress_bar=True, + normalize_embeddings=False, + precision="float32", + task="retrieval.query", + )