From ee34da49f33d4c050add90f022303fe3b60f4cf6 Mon Sep 17 00:00:00 2001
From: hppRC
Date: Fri, 6 Dec 2024 22:14:23 +0900
Subject: [PATCH 1/6] turn prompt_length from int into tensor

---
 src/jmteb/embedders/data_parallel_sbert_embedder.py | 6 ++++++
 1 file changed, 6 insertions(+)

diff --git a/src/jmteb/embedders/data_parallel_sbert_embedder.py b/src/jmteb/embedders/data_parallel_sbert_embedder.py
index 6fb7e87..afa9b7a 100644
--- a/src/jmteb/embedders/data_parallel_sbert_embedder.py
+++ b/src/jmteb/embedders/data_parallel_sbert_embedder.py
@@ -97,6 +97,12 @@ def encode(
             features = self.sbert.tokenize(sentences_batch)
             features.update(extra_features)
+            # `.gather()` in `.forward()` does not support int type, so make it a type that can gather
+            if "prompt_length" in features and isinstance(features["prompt_length"], int):
+                batch_size = len(sentences_batch)
+                prompt_length = torch.Tensor([features["prompt_length"] for _ in range(batch_size)])
+                features["prompt_length"] = prompt_length
+
             with torch.no_grad():
                 out_features = self.forward(features)

From e4149e936f0c09ec6e97572f9741037b539172ab Mon Sep 17 00:00:00 2001
From: hppRC
Date: Mon, 9 Dec 2024 17:10:06 +0900
Subject: [PATCH 2/6] fix: made DPSbertEmbedder work with prompts.

---
 .../embedders/data_parallel_sbert_embedder.py | 14 +++++++++++++-
 tests/embedders/test_dp_sbert.py              |  5 +++++
 2 files changed, 18 insertions(+), 1 deletion(-)

diff --git a/src/jmteb/embedders/data_parallel_sbert_embedder.py b/src/jmteb/embedders/data_parallel_sbert_embedder.py
index afa9b7a..4cde996 100644
--- a/src/jmteb/embedders/data_parallel_sbert_embedder.py
+++ b/src/jmteb/embedders/data_parallel_sbert_embedder.py
@@ -7,7 +7,7 @@
 import torch
 from accelerate.utils import find_executable_batch_size
 from loguru import logger
-from sentence_transformers import SentenceTransformer
+from sentence_transformers import SentenceTransformer, models
 from sentence_transformers.quantization import quantize_embeddings
 from sentence_transformers.util import truncate_embeddings
 from torch import Tensor
@@ -88,6 +88,11 @@ def encode(
             if "input_ids" in tokenized_prompt:
                 extra_features["prompt_length"] = tokenized_prompt["input_ids"].shape[-1] - 1
+            # When `include_prompt` is False in Pooling, prompt_length is unnecessary and should be removed.
+            # This prevents problems arising from DataParallel
+            if self.include_prompt_for_pooling():
+                _ = extra_features.pop("prompt_length")
+
         all_embeddings = []
         length_sorted_idx = np.argsort([-self.sbert._text_length(sen) for sen in sentences])
         sentences_sorted = [sentences[idx] for idx in length_sorted_idx]
@@ -98,6 +103,7 @@ def encode(
             features.update(extra_features)
             # `.gather()` in `.forward()` does not support int type, so make it a type that can gather
+            # we cast it from an int type to a torch.Tensor type
             if "prompt_length" in features and isinstance(features["prompt_length"], int):
                 batch_size = len(sentences_batch)
                 prompt_length = torch.Tensor([features["prompt_length"] for _ in range(batch_size)])
                 features["prompt_length"] = prompt_length
@@ -162,6 +168,12 @@ def encode(
         return all_embeddings

+    def include_prompt_for_pooling(self) -> bool:
+        for module in self:
+            if isinstance(module, models.Pooling):
+                return module.include_prompt
+        return True
+
 class DataParallelSentenceBertEmbedder(TextEmbedder):
     """SentenceBERT embedder with pytorch data parallel"""

diff --git a/tests/embedders/test_dp_sbert.py b/tests/embedders/test_dp_sbert.py
index 028e240..1d3e67b 100644
--- a/tests/embedders/test_dp_sbert.py
+++ b/tests/embedders/test_dp_sbert.py
@@ -18,6 +18,11 @@ def test_encode(self):
         assert isinstance(embeddings, np.ndarray)
         assert embeddings.shape == (OUTPUT_DIM,)

+    def test_encode_with_prompt(self):
+        embeddings = self.model.encode("任意のテキスト", prompt="プロンプト")
+        assert isinstance(embeddings, np.ndarray)
+        assert embeddings.shape == (OUTPUT_DIM,)
+
     def test_get_output_dim(self):
         assert self.model.get_output_dim() == OUTPUT_DIM
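torch.nn.DataParallel scatters the feature dict across GPU replicas and gathers the outputs back, and that machinery only handles tensors (and containers of tensors), so a plain int `prompt_length` cannot survive the round trip through `forward()`. The following is a minimal standalone sketch of the conversion PATCH 1/6 performs, not code taken from the repository; the batch contents and the prompt length value are made up.

import torch

# Minimal sketch of the int -> tensor conversion done inside encode():
# DataParallel's scatter/gather only handles tensors, so the scalar prompt
# length is broadcast into one value per example in the batch.
sentences_batch = ["text a", "text b", "text c"]  # hypothetical batch
features = {"prompt_length": 5}                   # int, as derived from the tokenized prompt

if "prompt_length" in features and isinstance(features["prompt_length"], int):
    batch_size = len(sentences_batch)
    features["prompt_length"] = torch.Tensor([features["prompt_length"]] * batch_size)

print(features["prompt_length"])                  # tensor([5., 5., 5.])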
From 672e8ba40e1687643e79b52f0cd625afa116e62c Mon Sep 17 00:00:00 2001
From: hppRC
Date: Mon, 9 Dec 2024 17:48:51 +0900
Subject: [PATCH 3/6] fix: refactor

---
 src/jmteb/embedders/data_parallel_sbert_embedder.py | 10 ++++------
 1 file changed, 4 insertions(+), 6 deletions(-)

diff --git a/src/jmteb/embedders/data_parallel_sbert_embedder.py b/src/jmteb/embedders/data_parallel_sbert_embedder.py
index 4cde996..d1b0b11 100644
--- a/src/jmteb/embedders/data_parallel_sbert_embedder.py
+++ b/src/jmteb/embedders/data_parallel_sbert_embedder.py
@@ -85,13 +85,11 @@ def encode(
             # Some models (e.g. INSTRUCTOR, GRIT) require removing the prompt before pooling
             # Tracking the prompt length allow us to remove the prompt during pooling
             tokenized_prompt = self.sbert.tokenize([prompt])
-            if "input_ids" in tokenized_prompt:
-                extra_features["prompt_length"] = tokenized_prompt["input_ids"].shape[-1] - 1
-            # When `include_prompt` is False in Pooling, prompt_length is unnecessary and should be removed.
-            # This prevents problems arising from DataParallel
-            if self.include_prompt_for_pooling():
-                _ = extra_features.pop("prompt_length")
+            # When `include_prompt` is True in Pooling, prompt_length is unnecessary and should be removed.
+            # This prevents problems arising from DataParallel
+            if "input_ids" in tokenized_prompt and not self.include_prompt_for_pooling():
+                extra_features["prompt_length"] = tokenized_prompt["input_ids"].shape[-1] - 1
         all_embeddings = []
         length_sorted_idx = np.argsort([-self.sbert._text_length(sen) for sen in sentences])
         sentences_sorted = [sentences[idx] for idx in length_sorted_idx]
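After this refactor, `prompt_length` is only attached when the Pooling module is configured with `include_prompt=False`, since that is the only case in which pooling needs to know how many leading tokens belong to the prompt. The snippet below is an illustration of that mechanism written for this note (it is not code from the PR or from Sentence Transformers): with `include_prompt=False`, the first `prompt_length` tokens are masked out of mean pooling so the instruction text does not influence the embedding. Shapes and values are made up.

import torch

# Illustration: excluding the first `prompt_length` tokens from mean pooling.
token_embeddings = torch.randn(2, 6, 4)  # (batch, seq_len, hidden_dim), made-up shapes
attention_mask = torch.ones(2, 6)
prompt_length = torch.tensor([2, 2])     # per-example tensor, as the patches provide

pooling_mask = attention_mask.clone()
for i, plen in enumerate(prompt_length.long()):
    pooling_mask[i, :plen] = 0           # drop the prompt tokens from the average

pooled = (token_embeddings * pooling_mask.unsqueeze(-1)).sum(dim=1) / pooling_mask.sum(dim=1, keepdim=True)
print(pooled.shape)                      # torch.Size([2, 4])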
From 6428c196669b7161066557b6b4232014fdf2a848 Mon Sep 17 00:00:00 2001
From: hppRC
Date: Mon, 9 Dec 2024 17:53:33 +0900
Subject: [PATCH 4/6] fix: argument name

---
 tests/embedders/test_dp_sbert.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/embedders/test_dp_sbert.py b/tests/embedders/test_dp_sbert.py
index 1d3e67b..e1cd4d2 100644
--- a/tests/embedders/test_dp_sbert.py
+++ b/tests/embedders/test_dp_sbert.py
@@ -19,7 +19,7 @@ def test_encode(self):
         assert embeddings.shape == (OUTPUT_DIM,)

     def test_encode_with_prompt(self):
-        embeddings = self.model.encode("任意のテキスト", prompt="プロンプト")
+        embeddings = self.model.encode("任意のテキスト", prefix="プロンプト")
         assert isinstance(embeddings, np.ndarray)
         assert embeddings.shape == (OUTPUT_DIM,)

From f177e60436fc7edb8739086a3102d837462d7d7b Mon Sep 17 00:00:00 2001
From: hppRC
Date: Mon, 9 Dec 2024 18:48:35 +0900
Subject: [PATCH 5/6] :+1: changed the import source of `models` and added comments.

---
 src/jmteb/embedders/data_parallel_sbert_embedder.py | 7 +++++--
 1 file changed, 5 insertions(+), 2 deletions(-)

diff --git a/src/jmteb/embedders/data_parallel_sbert_embedder.py b/src/jmteb/embedders/data_parallel_sbert_embedder.py
index d1b0b11..ba5864f 100644
--- a/src/jmteb/embedders/data_parallel_sbert_embedder.py
+++ b/src/jmteb/embedders/data_parallel_sbert_embedder.py
@@ -7,7 +7,8 @@
 import torch
 from accelerate.utils import find_executable_batch_size
 from loguru import logger
-from sentence_transformers import SentenceTransformer, models
+from sentence_transformers import SentenceTransformer
+from sentence_transformers.models import Pooling
 from sentence_transformers.quantization import quantize_embeddings
 from sentence_transformers.util import truncate_embeddings
 from torch import Tensor
@@ -166,9 +167,11 @@ def encode(
         return all_embeddings

+    # Sentence Transformersの`include_prompt`判定メソッドを参考に実装
+    # ref: https://github.com/UKPLab/sentence-transformers/blob/679ab5d38e4cf9cd73d4dcf1cda25ba2ef1ad837/sentence_transformers/trainer.py#L931
     def include_prompt_for_pooling(self) -> bool:
         for module in self:
-            if isinstance(module, models.Pooling):
+            if isinstance(module, Pooling):
                 return module.include_prompt
         return True
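PATCH 4/6 switches the test to the embedder's actual keyword argument, `prefix=` rather than `prompt=`; the Japanese test strings mean "arbitrary text" and "prompt", and the Japanese comments added in PATCH 5/6 note that the helper is implemented with reference to Sentence Transformers' own `include_prompt` check, followed by the reference URL. The usage below is a hypothetical sketch only: the constructor argument and model id are assumptions for illustration, not taken from the PR.

import numpy as np

from jmteb.embedders.data_parallel_sbert_embedder import DataParallelSentenceBertEmbedder

# Assumption: the embedder is constructed from a model name/path, as in the test setup.
embedder = DataParallelSentenceBertEmbedder("prajjwal1/bert-tiny")  # hypothetical model id
embedding = embedder.encode("任意のテキスト", prefix="プロンプト")

assert isinstance(embedding, np.ndarray)
print(embedding.shape)  # (OUTPUT_DIM,) for a single input, as asserted in the test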
From 1d3d846cb34b57761bf6642443b4eb9c9e6e7197 Mon Sep 17 00:00:00 2001
From: hppRC
Date: Mon, 9 Dec 2024 18:51:48 +0900
Subject: [PATCH 6/6] :+1: supress a lint error

---
 src/jmteb/embedders/data_parallel_sbert_embedder.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/jmteb/embedders/data_parallel_sbert_embedder.py b/src/jmteb/embedders/data_parallel_sbert_embedder.py
index ba5864f..7416fe4 100644
--- a/src/jmteb/embedders/data_parallel_sbert_embedder.py
+++ b/src/jmteb/embedders/data_parallel_sbert_embedder.py
@@ -168,7 +168,7 @@ def encode(
         return all_embeddings

     # Sentence Transformersの`include_prompt`判定メソッドを参考に実装
-    # ref: https://github.com/UKPLab/sentence-transformers/blob/679ab5d38e4cf9cd73d4dcf1cda25ba2ef1ad837/sentence_transformers/trainer.py#L931
+    # ref: https://github.com/UKPLab/sentence-transformers/blob/679ab5d38e4cf9cd73d4dcf1cda25ba2ef1ad837/sentence_transformers/trainer.py#L931  # noqa: E501
     def include_prompt_for_pooling(self) -> bool:
         for module in self:
             if isinstance(module, Pooling):