diff --git a/sentence_transformers/models/Transformer.py b/sentence_transformers/models/Transformer.py
index 1e595e155..6c5cbadea 100644
--- a/sentence_transformers/models/Transformer.py
+++ b/sentence_transformers/models/Transformer.py
@@ -1,5 +1,5 @@
 from torch import nn
-from transformers import AutoModel, AutoTokenizer, AutoConfig, T5Config, MT5Config
+from transformers import AutoModel, AutoTokenizer, AutoConfig, T5Config, MT5Config, M2M100Config
 import json
 from typing import List, Dict, Optional, Union, Tuple
 import os
@@ -61,6 +61,8 @@ def _load_model(self, model_name_or_path, config, cache_dir, **model_args):
             self._load_t5_model(model_name_or_path, config, cache_dir, **model_args)
         elif isinstance(config, MT5Config):
             self._load_mt5_model(model_name_or_path, config, cache_dir, **model_args)
+        elif isinstance(config, M2M100Config):
+            self._load_m2m100_model(model_name_or_path, config, cache_dir, **model_args)
         else:
             self.auto_model = AutoModel.from_pretrained(
                 model_name_or_path, config=config, cache_dir=cache_dir, **model_args
@@ -75,6 +77,15 @@ def _load_t5_model(self, model_name_or_path, config, cache_dir, **model_args):
             model_name_or_path, config=config, cache_dir=cache_dir, **model_args
         )
 
+    def _load_m2m100_model(self, model_name_or_path, config, cache_dir, **model_args):
+        """Loads the encoder model from M2M100 (aka NLLB, aka SONAR text encoder)"""
+        from transformers import M2M100EncoderModel
+
+        M2M100EncoderModel._keys_to_ignore_on_load_unexpected = ["decoder.*"]
+        self.auto_model = M2M100EncoderModel.from_pretrained(
+            model_name_or_path, config=config, cache_dir=cache_dir, **model_args
+        )
+
     def _load_mt5_model(self, model_name_or_path, config, cache_dir, **model_args):
         """Loads the encoder model from T5"""
         from transformers import MT5EncoderModel