Skip to content

Commit

Permalink
fix: refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
hppRC committed Dec 9, 2024
1 parent 97f5502 commit 672e8ba
Showing 1 changed file with 4 additions and 6 deletions.
10 changes: 4 additions & 6 deletions src/jmteb/embedders/data_parallel_sbert_embedder.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,13 +85,11 @@ def encode(
# Some models (e.g. INSTRUCTOR, GRIT) require removing the prompt before pooling
# Tracking the prompt length allows us to remove the prompt during pooling
tokenized_prompt = self.sbert.tokenize([prompt])
if "input_ids" in tokenized_prompt:
extra_features["prompt_length"] = tokenized_prompt["input_ids"].shape[-1] - 1

# When `include_prompt` is False in Pooling, prompt_length is unnecessary and should be removed.
# This prevents problems arising from DataParallel
if self.include_prompt_for_pooling():
_ = extra_features.pop("prompt_length")
# When `include_prompt` is True in Pooling, prompt_length is unnecessary and should be removed.
# This prevents problems arising from DataParallel
if "input_ids" in tokenized_prompt and not self.include_prompt_for_pooling():
extra_features["prompt_length"] = tokenized_prompt["input_ids"].shape[-1] - 1

all_embeddings = []
length_sorted_idx = np.argsort([-self.sbert._text_length(sen) for sen in sentences])
Expand Down

0 comments on commit 672e8ba

Please sign in to comment.