
[Fix] handle prompt length for multi-GPU #87

Merged (10 commits) on Dec 9, 2024
20 changes: 18 additions & 2 deletions src/jmteb/embedders/data_parallel_sbert_embedder.py
@@ -7,7 +7,7 @@
import torch
from accelerate.utils import find_executable_batch_size
from loguru import logger
-from sentence_transformers import SentenceTransformer
+from sentence_transformers import SentenceTransformer, models
Collaborator:
nit: Only one class from models is actually used, so how about writing from sentence_transformers.models import Pooling instead? (This would avoid confusing it with self.model.)

Collaborator (Author):
You're right, that does look better. Thank you, I'll change it! (models is a bit too generic as a module name anyway, which I don't like either.)
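For reference, a minimal sketch of how the helper added further down in this diff would read once the suggested import is applied (the follow-up commit itself is not shown here, so this is only an assumption about its shape):

from sentence_transformers.models import Pooling

    def include_prompt_for_pooling(self) -> bool:
        # Pooling is imported directly, so the generic `models` name no longer appears
        for module in self:
            if isinstance(module, Pooling):
                return module.include_prompt
        return True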

from sentence_transformers.quantization import quantize_embeddings
from sentence_transformers.util import truncate_embeddings
from torch import Tensor
@@ -85,7 +85,10 @@ def encode(
# Some models (e.g. INSTRUCTOR, GRIT) require removing the prompt before pooling
# Tracking the prompt length allow us to remove the prompt during pooling
tokenized_prompt = self.sbert.tokenize([prompt])
if "input_ids" in tokenized_prompt:

# When `include_prompt` is True in Pooling, prompt_length is unnecessary and should be removed.
# This prevents problems arising from DataParallel
if "input_ids" in tokenized_prompt and not self.include_prompt_for_pooling():
extra_features["prompt_length"] = tokenized_prompt["input_ids"].shape[-1] - 1

all_embeddings = []
@@ -97,6 +100,13 @@
features = self.sbert.tokenize(sentences_batch)
features.update(extra_features)

+# `.gather()` in `.forward()` does not support int type, so make it a type that can gather
+# we cast it from an int type to a torch.Tensor type
+if "prompt_length" in features and isinstance(features["prompt_length"], int):
+    batch_size = len(sentences_batch)
+    prompt_length = torch.Tensor([features["prompt_length"] for _ in range(batch_size)])
+    features["prompt_length"] = prompt_length

with torch.no_grad():
out_features = self.forward(features)

@@ -156,6 +166,12 @@ def encode(

return all_embeddings

+def include_prompt_for_pooling(self) -> bool:
+    for module in self:
+        if isinstance(module, models.Pooling):
+            return module.include_prompt
+    return True


class DataParallelSentenceBertEmbedder(TextEmbedder):
"""SentenceBERT embedder with pytorch data parallel"""
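Note on the int-to-Tensor cast in the hunk at @@ -97,6 +100,13 @@: torch.nn.DataParallel scatters tensor values of the features dict along the batch dimension and gathers tensor outputs back from the replicas, but it cannot gather a plain Python int. The following stand-alone sketch is illustration only, not part of the PR; the toy module and shapes are made up:

import torch
import torch.nn as nn


class EchoPromptLength(nn.Module):
    """Toy module that simply returns the prompt_length it receives."""

    def forward(self, features: dict) -> dict:
        return {"prompt_length": features["prompt_length"]}


if torch.cuda.device_count() > 1:
    model = nn.DataParallel(EchoPromptLength()).cuda()
    batch = {
        "input_ids": torch.zeros(8, 4, dtype=torch.long).cuda(),
        # one prompt_length per sentence: scattered along dim 0, gathered back as a tensor
        "prompt_length": torch.Tensor([5.0] * 8).cuda(),
    }
    out = model(batch)  # works: out["prompt_length"] has one entry per sentence
    # With a plain int instead (batch["prompt_length"] = 5), DataParallel's gather()
    # fails because it cannot re-assemble a bare int from the replica outputs.

Expanding the scalar to one value per sentence, as the PR does, keeps the dict tensor-only so both scatter and gather work.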
5 changes: 5 additions & 0 deletions tests/embedders/test_dp_sbert.py
@@ -18,6 +18,11 @@ def test_encode(self):
assert isinstance(embeddings, np.ndarray)
assert embeddings.shape == (OUTPUT_DIM,)

+def test_encode_with_prompt(self):
+    embeddings = self.model.encode("任意のテキスト", prefix="プロンプト")
+    assert isinstance(embeddings, np.ndarray)
+    assert embeddings.shape == (OUTPUT_DIM,)

def test_get_output_dim(self):
assert self.model.get_output_dim() == OUTPUT_DIM
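A hypothetical stand-alone usage sketch mirroring the new test. The constructor arguments are assumed to match the non-parallel SentenceBertEmbedder, and the model name is only a placeholder; check the class for the exact signature:

from jmteb.embedders.data_parallel_sbert_embedder import DataParallelSentenceBertEmbedder

# assumed constructor: a SentenceTransformer-compatible model name or path (placeholder below)
embedder = DataParallelSentenceBertEmbedder("example-org/example-sbert-model")
# "任意のテキスト" = "arbitrary text", "プロンプト" = "prompt"; prefix is passed through as the prompt
embedding = embedder.encode("任意のテキスト", prefix="プロンプト")
print(embedding.shape)  # (output_dim,) for a single input text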
