Commit
Merge pull request #87 from sbintuitions/fix/handle_prompt_length_for_multi_gpus

[Fix] handle prompt length for multi-GPU
hppRC authored Dec 9, 2024
2 parents 0ded6ad + 1d3d846 commit 06dbef6
Showing 2 changed files with 25 additions and 1 deletion.
src/jmteb/embedders/data_parallel_sbert_embedder.py (20 additions, 1 deletion)
@@ -8,6 +8,7 @@
 from accelerate.utils import find_executable_batch_size
 from loguru import logger
 from sentence_transformers import SentenceTransformer
+from sentence_transformers.models import Pooling
 from sentence_transformers.quantization import quantize_embeddings
 from sentence_transformers.util import truncate_embeddings
 from torch import Tensor
@@ -85,7 +86,10 @@ def encode(
             # Some models (e.g. INSTRUCTOR, GRIT) require removing the prompt before pooling
             # Tracking the prompt length allows us to remove the prompt during pooling
             tokenized_prompt = self.sbert.tokenize([prompt])
-            if "input_ids" in tokenized_prompt:
+
+            # When `include_prompt` is True in Pooling, prompt_length is unnecessary and should be removed.
+            # This prevents problems arising from DataParallel.
+            if "input_ids" in tokenized_prompt and not self.include_prompt_for_pooling():
                 extra_features["prompt_length"] = tokenized_prompt["input_ids"].shape[-1] - 1
 
         all_embeddings = []
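
For context on the new condition, a minimal standalone sketch of the flag it consults (not part of this commit; the model name is a placeholder): sentence-transformers' Pooling module exposes an `include_prompt` attribute, and `prompt_length` only matters when that flag is False, i.e. when the pooling layer is supposed to ignore the prompt tokens.

    from sentence_transformers import SentenceTransformer
    from sentence_transformers.models import Pooling

    # Placeholder checkpoint; any SentenceTransformer with a Pooling module works.
    model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")

    # A SentenceTransformer is an nn.Sequential, so its modules can be iterated directly.
    include_prompt = True
    for module in model:
        if isinstance(module, Pooling):
            include_prompt = module.include_prompt

    # Only when the pooling layer excludes the prompt does "prompt_length" need to be
    # attached to the tokenized features.
    print(include_prompt)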
@@ -97,6 +101,13 @@
             features = self.sbert.tokenize(sentences_batch)
             features.update(extra_features)
 
+            # `.gather()` in `.forward()` does not support int values, so make prompt_length
+            # a type that can be gathered: we cast it from a Python int to a torch.Tensor
+            if "prompt_length" in features and isinstance(features["prompt_length"], int):
+                batch_size = len(sentences_batch)
+                prompt_length = torch.Tensor([features["prompt_length"] for _ in range(batch_size)])
+                features["prompt_length"] = prompt_length
+
             with torch.no_grad():
                 out_features = self.forward(features)
 
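To see why this cast is needed, a standalone toy sketch (illustrative only, not JMTEB code; the module and feature names are made up): torch.nn.DataParallel scatters tensor values in the input dict along the batch dimension and gathers the dict returned by forward(), and that gather step only knows how to concatenate tensors, so a bare Python int left in the returned dict cannot be gathered and the forward call errors out on a multi-GPU machine. Replicating the scalar into a per-sample 1-D tensor lets it travel through scatter and gather like any other feature.

    import torch
    from torch import nn


    class EchoModule(nn.Module):
        """Toy module that, like the SBERT forward pass, returns a feature dict."""

        def forward(self, features: dict) -> dict:
            return {
                "embeddings": features["input"].mean(dim=1),
                # Whatever ends up here must survive DataParallel's gather step.
                "prompt_length": features["prompt_length"],
            }


    model = nn.DataParallel(EchoModule())
    batch = torch.randn(8, 4)

    # Per-sample tensor: scattered along the batch dimension with the rest of the
    # features and gathered back into a single tensor afterwards.
    features = {"input": batch, "prompt_length": torch.Tensor([3] * batch.size(0))}
    out = model(features)
    print(out["prompt_length"].shape)  # torch.Size([8])

    # With features["prompt_length"] = 3 (a bare int) instead, the gather step fails
    # on multiple GPUs; on CPU or a single GPU DataParallel skips scatter/gather,
    # so the problem only shows up in the multi-GPU setting this commit targets.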
@@ -156,6 +167,14 @@ def encode(
 
         return all_embeddings
 
+    # Implemented with reference to the `include_prompt` check in Sentence Transformers
+    # ref: https://github.com/UKPLab/sentence-transformers/blob/679ab5d38e4cf9cd73d4dcf1cda25ba2ef1ad837/sentence_transformers/trainer.py#L931 # noqa: E501
+    def include_prompt_for_pooling(self) -> bool:
+        for module in self:
+            if isinstance(module, Pooling):
+                return module.include_prompt
+        return True
+
 
 class DataParallelSentenceBertEmbedder(TextEmbedder):
     """SentenceBERT embedder with pytorch data parallel"""
tests/embedders/test_dp_sbert.py (5 additions, 0 deletions)
@@ -18,6 +18,11 @@ def test_encode(self):
         assert isinstance(embeddings, np.ndarray)
         assert embeddings.shape == (OUTPUT_DIM,)
 
+    def test_encode_with_prompt(self):
+        embeddings = self.model.encode("任意のテキスト", prefix="プロンプト")
+        assert isinstance(embeddings, np.ndarray)
+        assert embeddings.shape == (OUTPUT_DIM,)
+
     def test_get_output_dim(self):
         assert self.model.get_output_dim() == OUTPUT_DIM
 
