Skip to content

Commit

Permalink
improve robustness
Browse files Browse the repository at this point in the history
  • Loading branch information
lsz05 committed Sep 9, 2024
1 parent b30acc5 commit 5db9bed
Showing 1 changed file with 8 additions and 2 deletions.
10 changes: 8 additions & 2 deletions src/jmteb/embedders/transformers_embedder.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ def __init__(
pooling_mode: str | None = None,
model_kwargs: dict = {},
tokenizer_kwargs: dict = {},
word_embedding_dimension: int | None = None,
encode_method_name: str | None = None,
encode_method_text_argument: str = "text",
encode_method_prefix_argument: str = "prefix",
Expand Down Expand Up @@ -76,6 +77,10 @@ def __init__(

self.add_eos = add_eos
self.truncate_dim = truncate_dim
if word_embedding_dimension:
self.word_embedding_dimension = word_embedding_dimension
else:
self.word_embedding_dimension = getattr(self.model.config, "hidden_size")

if pooling_mode:
pooling_config: dict = {
Expand All @@ -86,7 +91,7 @@ def __init__(
pooling_config: dict = self._load_pooling_config(os.path.join(model_name_or_path, pooling_config))

self.pooling = Pooling(
word_embedding_dimension=pooling_config.get("word_embedding_dimension"),
word_embedding_dimension=self.word_embedding_dimension,
pooling_mode=pooling_config.get("pooling_mode", None),
pooling_mode_cls_token=pooling_config.get("pooling_mode_cls_token", False),
pooling_mode_max_tokens=pooling_config.get("pooling_mode_max_tokens", False),
Expand Down Expand Up @@ -184,6 +189,7 @@ def _encode(self, text: str | list[str], prefix: str | None = None) -> torch.Ten

if self.encode_method_name and hasattr(self.model, self.encode_method_name):
# ensure the built-in encoding method accepts positional arguments for text and prefix
logger.info("Used built-in encoding method")
sentence_embeddings = getattr(self.model, self.encode_method_name)(
**{self.encode_method_text_argument: text, self.encode_method_prefix_argument: prefix}
)
Expand Down Expand Up @@ -275,7 +281,7 @@ def tokenize(
to_tokenize = [[str(s).strip() for s in col] for col in to_tokenize]

# Lowercase
if self.tokenizer.do_lower_case:
if getattr(self.tokenizer, "do_lower_case", False):
to_tokenize = [[s.lower() for s in col] for col in to_tokenize]

output.update(
Expand Down

0 comments on commit 5db9bed

Please sign in to comment.