diff --git a/src/jmteb/embedders/data_parallel_sbert_embedder.py b/src/jmteb/embedders/data_parallel_sbert_embedder.py
index 6fb7e87..7416fe4 100644
--- a/src/jmteb/embedders/data_parallel_sbert_embedder.py
+++ b/src/jmteb/embedders/data_parallel_sbert_embedder.py
@@ -8,6 +8,7 @@
 from accelerate.utils import find_executable_batch_size
 from loguru import logger
 from sentence_transformers import SentenceTransformer
+from sentence_transformers.models import Pooling
 from sentence_transformers.quantization import quantize_embeddings
 from sentence_transformers.util import truncate_embeddings
 from torch import Tensor
@@ -85,7 +86,10 @@ def encode(
             # Some models (e.g. INSTRUCTOR, GRIT) require removing the prompt before pooling
             # Tracking the prompt length allow us to remove the prompt during pooling
             tokenized_prompt = self.sbert.tokenize([prompt])
-            if "input_ids" in tokenized_prompt:
+
+            # When `include_prompt` is True in Pooling, prompt_length is unnecessary and should be removed.
+            # This prevents problems arising from DataParallel.
+            if "input_ids" in tokenized_prompt and not self.include_prompt_for_pooling():
                 extra_features["prompt_length"] = tokenized_prompt["input_ids"].shape[-1] - 1
 
         all_embeddings = []
@@ -97,6 +101,13 @@
             features = self.sbert.tokenize(sentences_batch)
             features.update(extra_features)
 
+            # `.gather()` in `.forward()` does not support plain int values,
+            # so cast `prompt_length` from int to a torch.Tensor that can be gathered.
+            if "prompt_length" in features and isinstance(features["prompt_length"], int):
+                batch_size = len(sentences_batch)
+                prompt_length = torch.Tensor([features["prompt_length"] for _ in range(batch_size)])
+                features["prompt_length"] = prompt_length
+
             with torch.no_grad():
                 out_features = self.forward(features)
 
@@ -156,6 +167,14 @@
         return all_embeddings
 
+    # Implemented with reference to the `include_prompt` check in Sentence Transformers
+    # ref: https://github.com/UKPLab/sentence-transformers/blob/679ab5d38e4cf9cd73d4dcf1cda25ba2ef1ad837/sentence_transformers/trainer.py#L931  # noqa: E501
+    def include_prompt_for_pooling(self) -> bool:
+        for module in self:
+            if isinstance(module, Pooling):
+                return module.include_prompt
+        return True
+
 
 class DataParallelSentenceBertEmbedder(TextEmbedder):
     """SentenceBERT embedder with pytorch data parallel"""
diff --git a/tests/embedders/test_dp_sbert.py b/tests/embedders/test_dp_sbert.py
index 028e240..e1cd4d2 100644
--- a/tests/embedders/test_dp_sbert.py
+++ b/tests/embedders/test_dp_sbert.py
@@ -18,6 +18,11 @@ def test_encode(self):
         assert isinstance(embeddings, np.ndarray)
         assert embeddings.shape == (OUTPUT_DIM,)
 
+    def test_encode_with_prompt(self):
+        embeddings = self.model.encode("任意のテキスト", prefix="プロンプト")
+        assert isinstance(embeddings, np.ndarray)
+        assert embeddings.shape == (OUTPUT_DIM,)
+
     def test_get_output_dim(self):
         assert self.model.get_output_dim() == OUTPUT_DIM
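
Reviewer note: a minimal sketch, not part of the patch, of why the scalar prompt_length must become a per-sample tensor before the forward pass. torch.nn.DataParallel scatters tensor entries of the features dict along dim 0 across replicas and gathers tensor outputs; its gather step cannot handle a plain Python int, which is what the patch's cast works around. The batch shape below is hypothetical.

    import torch

    # Hypothetical tokenized batch: 4 sequences of 16 token ids each.
    features = {"input_ids": torch.zeros(4, 16, dtype=torch.long)}

    # A plain int is replicated as-is by DataParallel's scatter and then
    # rejected by its gather; one value per sample scatters along dim 0
    # together with input_ids, so each replica sees its own slice.
    prompt_length = 5
    batch_size = features["input_ids"].shape[0]
    features["prompt_length"] = torch.Tensor([prompt_length] * batch_size)
    assert features["prompt_length"].shape == (batch_size,)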
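
And a sketch of what the new include_prompt_for_pooling() helper inspects: a SentenceTransformer is an nn.Sequential whose modules can be iterated, and the Pooling module carries the include_prompt flag. The model name below is a hypothetical stand-in.

    from sentence_transformers import SentenceTransformer
    from sentence_transformers.models import Pooling

    model = SentenceTransformer("example-org/example-model")  # hypothetical model name

    # Same scan as the helper: find the Pooling module and read its flag,
    # defaulting to True when no Pooling module is present.
    include_prompt = True
    for module in model:
        if isinstance(module, Pooling):
            include_prompt = module.include_prompt
            break

    # prompt_length is only needed when the prompt is excluded from pooling.
    print(include_prompt)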