diff --git a/docs/source/en/model_doc/m2m_100.md b/docs/source/en/model_doc/m2m_100.md
index d64545fafb0612..5cfad1145d2d9a 100644
--- a/docs/source/en/model_doc/m2m_100.md
+++ b/docs/source/en/model_doc/m2m_100.md
@@ -180,4 +180,83 @@ model = M2M100ForConditionalGeneration.from_pretrained("facebook/m2m100_418M", t
 ...
 ```
-For the best speedups, we recommend loading the model in half-precision (e.g. `torch.float16` or `torch.bfloat16`).
\ No newline at end of file
+For the best speedups, we recommend loading the model in half-precision (e.g. `torch.float16` or `torch.bfloat16`).
+
+
+# SONAR encoders and decoders
+
+The SONAR embedding space was introduced in [SONAR: Sentence-Level Multimodal and Language-Agnostic Representations](https://arxiv.org/abs/2308.11466) by Paul-Ambroise Duquenne, Holger Schwenk, and Benoît Sagot.
+In the text modality, SONAR encodes texts in 202 languages into a shared space and decodes these embeddings
+back into any of the 202 languages. The architecture of the text encoders and decoders is based on [NLLB](../nllb).
+
+The original implementation resides in the [SONAR](https://github.com/facebookresearch/sonar) repository.
+
+The Hugging Face implementation of the SONAR text encoder and decoder follows the NLLB implementation and is therefore based on
+the M2M100 architecture: the SONAR encoder is represented by `M2M100EncoderModel`, and the SONAR decoder
+by `M2M100DecoderModel`.
+
+## Usage tips and examples
+
+The encoder converts text sentences into 1024-dimensional embeddings in the following way:
+
+```python
+import torch
+
+from transformers import M2M100EncoderModel, NllbTokenizer
+
+encoder = M2M100EncoderModel.from_pretrained("cointegrated/SONAR_200_text_encoder_hf")
+tokenizer = NllbTokenizer.from_pretrained("facebook/nllb-200-distilled-600M", src_lang="eng_Latn", tgt_lang="eng_Latn")
+
+sentences = ["My name is SONAR.", "I can embed the sentences into vectorial space."]
+inputs = tokenizer(sentences, padding=True, return_tensors="pt")
+with torch.inference_mode():
+    encoder_out = encoder(**inputs, pool_last_hidden_state=True)
+    embeddings = encoder_out.last_hidden_state.squeeze(1)
+
+print(embeddings.shape)  # [2, 1024]
+```
+The argument `pool_last_hidden_state=True` tells the model to aggregate the contextual token embeddings by mean pooling
+after they are produced by the encoder. This pooling is the only architectural difference between SONAR and NLLB.
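+
+The same pooling can be reproduced by hand if you want to inspect the token-level states first. The snippet below is a
+minimal illustrative sketch of what `pool_last_hidden_state=True` computes (an attention-mask-weighted mean over the
+encoder's token embeddings), reusing `encoder`, `inputs` and `embeddings` from the example above:
+
+```python
+# Sketch only: reproduce the pooling performed by `pool_last_hidden_state=True` manually.
+with torch.inference_mode():
+    token_states = encoder(**inputs).last_hidden_state  # [batch, seq_len, 1024]
+
+mask = inputs["attention_mask"].unsqueeze(-1).type_as(token_states)  # [batch, seq_len, 1]
+manual_embeddings = (token_states * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1e-9)
+print(torch.allclose(manual_embeddings, embeddings, atol=1e-5))  # expected to print True
+```
+
+Given these embeddings, a decoder can reconstruct them back into text in any of the supported languages: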
+```python
+from transformers import M2M100DecoderModel
+from transformers.modeling_outputs import BaseModelOutput
+
+decoder = M2M100DecoderModel.from_pretrained("cointegrated/SONAR_200_text_decoder_hf")
+
+# Decoding into the original (English) language
+generator_out = decoder.generate(
+    # Passing the encoder output object directly is not recommended, because beam search modifies it in place:
+    # encoder_outputs=encoder_out,
+    # Instead, wrap the embeddings in a fresh BaseModelOutput:
+    encoder_outputs=BaseModelOutput(last_hidden_state=embeddings.unsqueeze(1)),
+    num_beams=5,
+    forced_bos_token_id=tokenizer.convert_tokens_to_ids("eng_Latn"),
+)
+text_out = tokenizer.batch_decode(generator_out, skip_special_tokens=True)
+print(text_out)  # ["My name is SONAR.", "I can embed the sentences into vector space."]
+
+# Decoding into another language (French)
+generator_out = decoder.generate(
+    # As above, wrap the embeddings in a fresh BaseModelOutput rather than passing `encoder_out` directly
+    encoder_outputs=BaseModelOutput(last_hidden_state=embeddings.unsqueeze(1)),
+    num_beams=5,
+    forced_bos_token_id=tokenizer.convert_tokens_to_ids("fra_Latn"),
+)
+text_out = tokenizer.batch_decode(generator_out, skip_special_tokens=True)
+print(text_out)  # ['Mon nom est SONAR.', "Je peux intégrer les phrases dans l'espace vectoriel."]
+```
+
+The list of BCP-47 codes for the 202 language varieties currently supported by the SONAR text models can be found in
+the [decoder model card](https://github.com/facebookresearch/SONAR/blob/main/sonar/cards/text_sonar_basic_decoder.yaml) in the
+official SONAR repository. It is the same set of languages that [NLLB](../nllb) supports, and it mostly coincides with the
+[FLORES-200 language list](https://github.com/facebookresearch/flores/blob/main/flores200/README.md#languages-in-flores-200).
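+
+Since all 202 languages are mapped into the same embedding space, embeddings can also be compared directly across
+languages. The following snippet is only an illustrative sketch (reusing `tokenizer`, `encoder` and `embeddings` from the
+examples above, with a made-up French input sentence) that scores a French sentence against the two English sentences
+by cosine similarity:
+
+```python
+# Illustrative sketch: cross-lingual sentence similarity in the shared SONAR space.
+tokenizer.src_lang = "fra_Latn"
+french_inputs = tokenizer(["Je m'appelle SONAR."], padding=True, return_tensors="pt")
+
+with torch.inference_mode():
+    french_embedding = encoder(**french_inputs, pool_last_hidden_state=True).last_hidden_state.squeeze(1)
+
+similarity = torch.nn.functional.cosine_similarity(embeddings, french_embedding)
+print(similarity)  # the first (matching) sentence is expected to score noticeably higher than the second
+```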
+ + +## M2M100EncoderModel + +[[autodoc]] M2M100EncoderModel + - forward + +## M2M100DecoderModel + +[[autodoc]] M2M100DecoderModel + - forward diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 47b43e0b90896f..76ed91a5b96f7c 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -2644,6 +2644,8 @@ ) _import_structure["models.m2m_100"].extend( [ + "M2M100DecoderModel", + "M2M100EncoderModel", "M2M100ForConditionalGeneration", "M2M100Model", "M2M100PreTrainedModel", @@ -7309,6 +7311,8 @@ LxmertVisualFeatureEncoder, ) from .models.m2m_100 import ( + M2M100DecoderModel, + M2M100EncoderModel, M2M100ForConditionalGeneration, M2M100Model, M2M100PreTrainedModel, diff --git a/src/transformers/models/m2m_100/__init__.py b/src/transformers/models/m2m_100/__init__.py index 45232f1390a53b..c223816e0f3a26 100644 --- a/src/transformers/models/m2m_100/__init__.py +++ b/src/transformers/models/m2m_100/__init__.py @@ -29,6 +29,8 @@ pass else: _import_structure["modeling_m2m_100"] = [ + "M2M100DecoderModel", + "M2M100EncoderModel", "M2M100ForConditionalGeneration", "M2M100Model", "M2M100PreTrainedModel", @@ -46,6 +48,8 @@ pass else: from .modeling_m2m_100 import ( + M2M100DecoderModel, + M2M100EncoderModel, M2M100ForConditionalGeneration, M2M100Model, M2M100PreTrainedModel, diff --git a/src/transformers/models/m2m_100/convert_m2m100_original_checkpoint_to_pytorch.py b/src/transformers/models/m2m_100/convert_m2m100_original_checkpoint_to_pytorch.py index 97265fbdcf9346..030886b54e2fa8 100644 --- a/src/transformers/models/m2m_100/convert_m2m100_original_checkpoint_to_pytorch.py +++ b/src/transformers/models/m2m_100/convert_m2m100_original_checkpoint_to_pytorch.py @@ -81,5 +81,5 @@ def convert_fairseq_m2m100_checkpoint_from_disk(checkpoint_path): parser.add_argument("fairseq_path", type=str, help="path to a model.pt on local filesystem.") parser.add_argument("pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model.") args = parser.parse_args() - model = convert_fairseq_m2m100_checkpoint_from_disk(args.fairseq_pathß) + model = convert_fairseq_m2m100_checkpoint_from_disk(args.fairseq_path) model.save_pretrained(args.pytorch_dump_folder_path) diff --git a/src/transformers/models/m2m_100/convert_sonar_original_checkpoint_to_transformers.py b/src/transformers/models/m2m_100/convert_sonar_original_checkpoint_to_transformers.py new file mode 100644 index 00000000000000..a373c5b961a86c --- /dev/null +++ b/src/transformers/models/m2m_100/convert_sonar_original_checkpoint_to_transformers.py @@ -0,0 +1,221 @@ +# Copyright 2024 The Sonar Authors and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This script converts fairseq/fairseq2 SONAR text encoder and decoder checkpoints to transformers. +The reference architectures are given in https://github.com/facebookresearch/SONAR/blob/main/sonar/models/sonar_text/builder.py. 
+The checkpoints for conversion can be found in:
+- encoder: https://github.com/facebookresearch/SONAR/blob/main/sonar/cards/text_sonar_basic_encoder.yaml
+- decoder: https://github.com/facebookresearch/SONAR/blob/main/sonar/cards/text_sonar_basic_decoder.yaml
+"""
+
+import argparse
+
+import torch
+from torch import nn
+
+from transformers import M2M100Config, M2M100DecoderModel, M2M100EncoderModel
+
+
+def make_linear_from_emb(emb):
+    vocab_size, emb_size = emb.weight.shape
+    lin_layer = nn.Linear(vocab_size, emb_size, bias=False)
+    lin_layer.weight.data = emb.weight.data
+    return lin_layer
+
+
+def get_parameter_renames(model, state_dict, is_decoder=True):
+    parameter_renames = {}
+    trf_names = {name for name, value in model.named_parameters()}
+    fs2_names = set(state_dict.keys())
+
+    for trf_name in trf_names:
+        fs2_name = trf_name
+        if trf_name == "shared.weight":
+            fs2_name = "decoder_frontend.embed.weight" if is_decoder else "encoder_frontend.embed.weight"
+        if trf_name == "lm_head.weight":
+            fs2_name = "final_proj.weight"
+
+        if trf_name.startswith("layers."):
+            fs2_name = "decoder." + trf_name if is_decoder else "encoder." + trf_name
+        if trf_name.startswith("layer_norm.") and is_decoder:
+            fs2_name = "decoder." + trf_name
+        if trf_name.startswith("encoder.layer_norm.") and not is_decoder:
+            fs2_name = trf_name.split(".", 1)[1]
+
+        if ".encoder_attn." in fs2_name:
+            fs2_name = fs2_name.replace(".encoder_attn.", ".encoder_decoder_attn.")
+        if ".encoder_attn_layer_norm." in fs2_name:
+            fs2_name = fs2_name.replace(".encoder_attn_layer_norm.", ".encoder_decoder_attn_layer_norm.")
+        if ".out_proj." in fs2_name:
+            fs2_name = fs2_name.replace(".out_proj.", ".output_proj.")
+        if ".fc1." in fs2_name:
+            fs2_name = fs2_name.replace(
+                ".fc1.",
+                ".ffn.inner_proj.",
+            )
+        if ".fc2." in fs2_name:
+            fs2_name = fs2_name.replace(
+                ".fc2.",
+                ".ffn.output_proj.",
+            )
+        if ".final_layer_norm." in fs2_name:
+            fs2_name = fs2_name.replace(
+                ".final_layer_norm.",
+                ".ffn_layer_norm.",
+            )
+
+        if fs2_name in fs2_names:
+            parameter_renames[trf_name] = fs2_name
+        else:
+            raise ValueError(f"Parameter {trf_name} could not be mapped from transformers to fairseq2 state dict.")
+    return parameter_renames
+
+
+def reorder_special_tokens(new_state_dict):
+    """
+    In fairseq2, the special tokens are ['<pad>', '<unk>', '<s>', '</s>'].
+    In transformers (NLLB) they are ['<s>', '<pad>', '</s>', '<unk>'].
+    We want to reuse the NLLB tokenizer, so we reorder the embeddings accordingly.
+    """
+    special_token_embs = new_state_dict["shared.weight"].data[[2, 0, 3, 1]].clone()
+    for param_name in [
+        "encoder.embed_tokens.weight",
+        "decoder.embed_tokens.weight",
+        "lm_head.weight",
+        "shared.weight",
+    ]:
+        if param_name in new_state_dict:
+            new_state_dict[param_name].data[[0, 1, 2, 3]] = special_token_embs
+
+
+def convert_sonar_checkpoint_from_disk(checkpoint_path):
+    checkpoint_dict = torch.load(checkpoint_path, map_location="cpu")
+    state_dict = checkpoint_dict["model"]
+
+    # In fairseq2 SONAR checkpoints, there are no configs (they are supposed to be in a yaml model card elsewhere).
+    # Thus, we just assume the "basic" hyperparameters,
+    # see the arch registry at https://github.com/facebookresearch/SONAR/blob/main/sonar/models/sonar_text/builder.py
+
+    config = M2M100Config(
+        vocab_size=256206,
+        max_position_embeddings=1024,
+        encoder_layers=24,
+        decoder_layers=24,
+        encoder_attention_heads=16,
+        decoder_attention_heads=16,
+        encoder_ffn_dim=1024 * 8,
+        decoder_ffn_dim=1024 * 8,
+        d_model=1024,
+        encoder_layerdrop=0,
+        decoder_layerdrop=0,
+        dropout=0.1,
+        attention_dropout=0.1,
+        activation_dropout=0.1,
+        activation_function="relu",
+    )
+
+    if any(parameter_name.startswith("encoder.") for parameter_name in state_dict):
+        is_decoder = False
+        model = M2M100EncoderModel(config)
+    elif any(parameter_name.startswith("decoder.") for parameter_name in state_dict):
+        is_decoder = True
+        model = M2M100DecoderModel(config)
+    else:
+        raise ValueError("The state dict does not seem to contain a SONAR encoder or decoder.")
+
+    parameter_renames = get_parameter_renames(model, state_dict, is_decoder)
+    new_state_dict = {trf_name: state_dict[fs2_name] for trf_name, fs2_name in parameter_renames.items()}
+    reorder_special_tokens(new_state_dict)
+
+    if is_decoder:
+        new_state_dict["decoder.embed_tokens.weight"] = new_state_dict["shared.weight"]
+    else:
+        new_state_dict["encoder.embed_tokens.weight"] = new_state_dict["shared.weight"]
+
+    model.load_state_dict(new_state_dict, strict=True)
+    model.tie_weights()
+
+    return model
+
+
+def test_conversion_accuracy(fairseq2_encoder_path, fairseq2_decoder_path):
+    """
+    This test is not invoked directly, because the encoder and decoder paths must be provided explicitly,
+    and these checkpoints are too heavy to download by default just for a test.
+    Please run the test from your own code as follows:
+    ```
+    from transformers.models.m2m_100.convert_sonar_original_checkpoint_to_transformers import test_conversion_accuracy
+    test_conversion_accuracy(PATH_TO_ENCODER, PATH_TO_DECODER)
+    ```
+    The fairseq2 checkpoints can be downloaded from the URLs indicated in the following cards:
+    - https://github.com/facebookresearch/SONAR/blob/main/sonar/cards/text_sonar_basic_encoder.yaml
+    - https://github.com/facebookresearch/SONAR/blob/main/sonar/cards/text_sonar_basic_decoder.yaml
+
+    The reference embeddings were obtained with:
+    ```
+    from sonar.inference_pipelines.text import TextToEmbeddingModelPipeline
+    t2vec_model = TextToEmbeddingModelPipeline(encoder="text_sonar_basic_encoder", tokenizer="text_sonar_basic_encoder")
+    ref_embeddings = t2vec_model.predict(sentences, source_lang="eng_Latn")[:, :5]
+    ```
+    """
+    from transformers import NllbTokenizer
+    from transformers.modeling_outputs import BaseModelOutput
+
+    tokenizer = NllbTokenizer.from_pretrained("facebook/nllb-200-distilled-600M", clean_up_tokenization_spaces=True)
+    sentences = ["My name is SONAR.", "I can embed the sentences into vectorial space."]
+    tokenizer.src_lang = "eng_Latn"
+    batch = tokenizer(sentences, padding=True, return_tensors="pt")
+
+    print("Converting the encoder...")
+    enc = convert_sonar_checkpoint_from_disk(fairseq2_encoder_path).eval()
+    assert isinstance(enc, M2M100EncoderModel)
+
+    print("Conversion completed, testing the embedding accuracy...")
+    with torch.inference_mode():
+        enc_out = enc(**batch, pool_last_hidden_state=True)
+        assert enc_out.last_hidden_state.shape == (2, 1, 1024)
+        embeddings = enc_out.last_hidden_state.squeeze(1)
+
+    ref_embeddings = torch.tensor(
+        [[-0.005286, 0.002008, -0.000562, 0.006344, 0.006329], [-0.000330, -0.007055, 0.007644,
0.001841, 0.003727]] + ) + assert torch.allclose(embeddings[:, :5], ref_embeddings, rtol=1e-3) + print("The embedding accuracy test has passed!") + + print("Converting the decoder...") + dec = convert_sonar_checkpoint_from_disk(fairseq2_decoder_path).eval() + assert isinstance(dec, M2M100DecoderModel) + + print("Conversion completed, testing the decoding accuracy...") + gen_out = dec.generate( + # passing encoder_outputs is not recommended, because beam search decoding modifies them in place, which is ugly + # encoder_outputs=enc_out, + encoder_outputs=BaseModelOutput(last_hidden_state=enc_out.last_hidden_state.clone()), + num_beams=5, + forced_bos_token_id=tokenizer.convert_tokens_to_ids("eng_Latn"), + ) + text_out = tokenizer.batch_decode(gen_out, skip_special_tokens=True) + assert text_out == ["My name is SONAR.", "I can embed the sentences into vector space."] + print("The decoding accuracy test has passed!") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + # Required parameters + parser.add_argument("fairseq_path", type=str, help="path to a model.pt on local filesystem.") + parser.add_argument("dump_folder_path", default=None, type=str, help="Path to the output transformers model.") + args = parser.parse_args() + model = convert_sonar_checkpoint_from_disk(args.fairseq_path) + model.save_pretrained(args.dump_folder_path) diff --git a/src/transformers/models/m2m_100/modeling_m2m_100.py b/src/transformers/models/m2m_100/modeling_m2m_100.py index cc35a3504255bf..d1f7156b474df0 100755 --- a/src/transformers/models/m2m_100/modeling_m2m_100.py +++ b/src/transformers/models/m2m_100/modeling_m2m_100.py @@ -14,6 +14,7 @@ # limitations under the License. """PyTorch M2M100 model.""" +import copy import math from typing import List, Optional, Tuple, Union @@ -1629,3 +1630,264 @@ def _reorder_cache(past_key_values, beam_idx): tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), ) return reordered_past + + +@add_start_docstrings( + "The M2M100 (SONAR) transformer decoder model for generating texts from embeddings.", + M2M_100_START_DOCSTRING, +) +class M2M100DecoderModel(M2M100PreTrainedModel, GenerationMixin): + _tied_weights_keys = ["decoder.embed_tokens.weight", "lm_head.weight"] + _keys_to_ignore_on_load_unexpected = [r"encoder"] + + def __init__(self, config: M2M100Config): + super().__init__(config) + self.shared = nn.Embedding(config.vocab_size, config.d_model) + + decoder_config = copy.deepcopy(config) + decoder_config.use_cache = False + decoder_config.is_encoder_decoder = False + self.decoder = M2M100Decoder(decoder_config, self.shared) + self.lm_head = nn.Linear(config.d_model, self.shared.num_embeddings, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.shared + + def set_input_embeddings(self, value): + self.shared = value + self.decoder.embed_tokens = self.shared + + def _tie_weights(self): + if self.config.tie_word_embeddings: + self._tie_or_clone_weights(self.decoder.embed_tokens, self.shared) + # in SONAR models, input and output projections are tied (ideally, this should be configurable) + self._tie_or_clone_weights(self.lm_head, self.shared) + + def get_decoder(self): + return self.decoder + + @add_start_docstrings_to_model_forward(M2M_100_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC) + @add_end_docstrings(M2M_100_GENERATION_EXAMPLE) + def forward( + self, + input_ids: 
Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + decoder_input_ids: Optional[torch.LongTensor] = None, + decoder_attention_mask: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.Tensor] = None, + decoder_head_mask: Optional[torch.Tensor] = None, + cross_attn_head_mask: Optional[torch.Tensor] = None, + encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + decoder_inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[torch.Tensor], Seq2SeqLMOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Returns: + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if labels is not None: + if decoder_input_ids is None: + decoder_input_ids = shift_tokens_right( + labels, self.config.pad_token_id, self.config.decoder_start_token_id + ) + + if encoder_outputs is None: + raise ValueError("M2M100DecoderModel expects the `encoder_outputs` to be always present.") + + if return_dict and not isinstance(encoder_outputs, BaseModelOutput): + encoder_outputs = BaseModelOutput( + last_hidden_state=encoder_outputs[0], + hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None, + attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, + ) + + # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn) + decoder_outputs = self.decoder( + input_ids=decoder_input_ids, + attention_mask=decoder_attention_mask, + encoder_hidden_states=encoder_outputs[0], + encoder_attention_mask=attention_mask, + head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, + past_key_values=past_key_values, + inputs_embeds=decoder_inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + lm_logits = self.lm_head(decoder_outputs[0]) + + masked_lm_loss = None + if labels is not None: + # move labels to the correct device to enable PP + labels = labels.to(lm_logits.device) + loss_fct = CrossEntropyLoss() + masked_lm_loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.view(-1)) + + if not return_dict: + output = (lm_logits,) + decoder_outputs[1:] + return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output + + return Seq2SeqLMOutput( + loss=masked_lm_loss, + logits=lm_logits, + past_key_values=decoder_outputs.past_key_values, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_outputs.last_hidden_state, + encoder_hidden_states=encoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, + ) + + @staticmethod + def _reorder_cache(past_key_values, beam_idx): + reordered_past = () + for layer_past in 
past_key_values:
+            reordered_past += (
+                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
+            )
+        return reordered_past
+
+    def prepare_inputs_for_generation(
+        self,
+        decoder_input_ids,
+        past_key_values=None,
+        attention_mask=None,
+        head_mask=None,
+        decoder_head_mask=None,
+        cross_attn_head_mask=None,
+        use_cache=None,
+        encoder_outputs=None,
+        **kwargs,
+    ):
+        # cut decoder_input_ids if past is used
+        if past_key_values is not None:
+            past_length = past_key_values[0][0].shape[2]
+
+            # Some generation methods already pass only the last input ID
+            if decoder_input_ids.shape[1] > past_length:
+                remove_prefix_length = past_length
+            else:
+                # Default to old behavior: keep only final ID
+                remove_prefix_length = decoder_input_ids.shape[1] - 1
+
+            decoder_input_ids = decoder_input_ids[:, remove_prefix_length:]
+
+        return {
+            "input_ids": None,  # encoder_outputs is defined. input_ids not needed
+            "encoder_outputs": encoder_outputs,
+            "past_key_values": past_key_values,
+            "decoder_input_ids": decoder_input_ids,
+            "attention_mask": attention_mask,
+            "head_mask": head_mask,
+            "decoder_head_mask": decoder_head_mask,
+            "cross_attn_head_mask": cross_attn_head_mask,
+            "use_cache": use_cache,  # change this to avoid caching (presumably for debugging)
+        }
+
+
+@add_start_docstrings(
+    "The M2M100 (SONAR) transformer model outputting encoder's raw hidden states. Can be used for embedding sentences.",
+    M2M_100_START_DOCSTRING,
+)
+class M2M100EncoderModel(M2M100PreTrainedModel):
+    _tied_weights_keys = ["encoder.embed_tokens.weight"]
+    _keys_to_ignore_on_load_unexpected = [r"decoder"]
+
+    def __init__(self, config: M2M100Config):
+        super().__init__(config)
+        self.shared = nn.Embedding(config.vocab_size, config.d_model)
+
+        encoder_config = copy.deepcopy(config)
+        encoder_config.use_cache = False
+        encoder_config.is_encoder_decoder = False
+        self.encoder = M2M100Encoder(encoder_config, self.shared)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def get_input_embeddings(self):
+        return self.shared
+
+    def set_input_embeddings(self, value):
+        self.shared = value
+        self.encoder.embed_tokens = self.shared
+
+    def _tie_weights(self):
+        if self.config.tie_word_embeddings:
+            self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared)
+
+    def get_encoder(self):
+        return self.encoder
+
+    @add_start_docstrings_to_model_forward(M2M_100_INPUTS_DOCSTRING)
+    @replace_return_docstrings(output_type=BaseModelOutput, config_class=_CONFIG_FOR_DOC)
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        pool_last_hidden_state: bool = False,
+    ) -> Union[Tuple[torch.FloatTensor], BaseModelOutput]:
+        r"""
+        Returns:
+
+        Example:
+
+        ```python
+        >>> from transformers import AutoTokenizer, M2M100EncoderModel
+
+        >>> tokenizer = AutoTokenizer.from_pretrained("facebook/nllb-200-distilled-600M")
+        >>> model = M2M100EncoderModel.from_pretrained("cointegrated/SONAR_200_text_encoder_hf")
+        >>> input_ids = tokenizer(
+        ...     "Studies have been shown that owning a dog is good for you", return_tensors="pt"
+        ...
).input_ids # Batch size 1 + >>> outputs = model(input_ids=input_ids) + >>> last_hidden_states = outputs.last_hidden_state + ```""" + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + encoder_outputs = self.encoder( + input_ids=input_ids, + attention_mask=attention_mask, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + if pool_last_hidden_state: + last_hidden_state = encoder_outputs.last_hidden_state + if attention_mask is None: + mean_state = last_hidden_state.mean(1, keepdims=True) + else: + input_mask_expanded = attention_mask.unsqueeze(-1).expand(last_hidden_state.size()).float() + sum_embeddings = torch.sum(last_hidden_state * input_mask_expanded, 1, keepdims=True) + sum_mask = torch.clamp(input_mask_expanded.sum(1, keepdims=True), min=1e-9) + mean_state = sum_embeddings / sum_mask + encoder_outputs.last_hidden_state = mean_state + + return encoder_outputs diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py index 36e1ff2cfe65c4..40975a2fe487c2 100644 --- a/src/transformers/utils/dummy_pt_objects.py +++ b/src/transformers/utils/dummy_pt_objects.py @@ -5661,6 +5661,20 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) +class M2M100DecoderModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class M2M100EncoderModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + class M2M100ForConditionalGeneration(metaclass=DummyObject): _backends = ["torch"] diff --git a/tests/models/m2m_100/test_modeling_m2m_100.py b/tests/models/m2m_100/test_modeling_m2m_100.py index 4fe0902c615b14..8a931e080e04fb 100644 --- a/tests/models/m2m_100/test_modeling_m2m_100.py +++ b/tests/models/m2m_100/test_modeling_m2m_100.py @@ -42,7 +42,14 @@ if is_torch_available(): import torch - from transformers import M2M100ForConditionalGeneration, M2M100Model, M2M100Tokenizer + from transformers import ( + M2M100DecoderModel, + M2M100EncoderModel, + M2M100ForConditionalGeneration, + M2M100Model, + M2M100Tokenizer, + ) + from transformers.modeling_outputs import BaseModelOutput from transformers.models.m2m_100.modeling_m2m_100 import M2M100Decoder, M2M100Encoder @@ -472,3 +479,384 @@ def test_flash_attn_2_seq_to_seq_generation(self): hypotheses_batch.tolist(), clean_up_tokenization_spaces=True, skip_special_tokens=True ) assert generated == expected_en + + +class M2M100EncoderModelTester: + def __init__( + self, + parent, + batch_size=13, + seq_length=7, + is_training=False, + use_labels=False, + vocab_size=99, + hidden_size=16, + num_hidden_layers=2, + num_attention_heads=4, + intermediate_size=4, + hidden_act="relu", + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + encoder_layerdrop=0.0, + decoder_layerdrop=0.0, + max_position_embeddings=20, + eos_token_id=2, + pad_token_id=1, + bos_token_id=0, + use_attention_mask: bool = True, + is_encoder_decoder=False, + ): + self.parent = parent + self.batch_size = batch_size + self.seq_length = seq_length + self.is_training = is_training + self.use_labels = use_labels + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.intermediate_size = intermediate_size + 
self.hidden_act = hidden_act + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.encoder_layerdrop = encoder_layerdrop + self.decoder_layerdrop = decoder_layerdrop + self.max_position_embeddings = max_position_embeddings + self.eos_token_id = eos_token_id + self.pad_token_id = pad_token_id + self.bos_token_id = bos_token_id + self.use_attention_mask = use_attention_mask + self.is_encoder_decoder = is_encoder_decoder + + def prepare_config_and_inputs(self): + input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size) + + attention_mask = None + if self.use_attention_mask: + attention_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2) + + config = M2M100Config( + vocab_size=self.vocab_size, + d_model=self.hidden_size, + encoder_layers=self.num_hidden_layers, + decoder_layers=self.num_hidden_layers, + encoder_attention_heads=self.num_attention_heads, + decoder_attention_heads=self.num_attention_heads, + encoder_ffn_dim=self.intermediate_size, + decoder_ffn_dim=self.intermediate_size, + dropout=self.hidden_dropout_prob, + attention_dropout=self.attention_probs_dropout_prob, + encoder_layerdrop=self.encoder_layerdrop, + decoder_layerdrop=self.decoder_layerdrop, + max_position_embeddings=self.max_position_embeddings, + eos_token_id=self.eos_token_id, + bos_token_id=self.bos_token_id, + pad_token_id=self.pad_token_id, + is_encoder_decoder=self.is_encoder_decoder, + ) + + return ( + config, + input_ids, + attention_mask, + ) + + def create_and_check_model( + self, + config, + input_ids, + attention_mask, + ): + model = M2M100EncoderModel(config=config) + model.to(torch_device) + model.eval() + result = model( + input_ids=input_ids, + attention_mask=attention_mask, + ) + result = model(input_ids=input_ids) + encoder_output = result.last_hidden_state + + self.parent.assertEqual(encoder_output.size(), (self.batch_size, self.seq_length, self.hidden_size)) + + def create_and_check_model_fp16_forward( + self, + config, + input_ids, + attention_mask, + ): + model = M2M100EncoderModel(config=config).to(torch_device).half().eval() + output = model(input_ids, attention_mask=attention_mask)["last_hidden_state"] + self.parent.assertFalse(torch.isnan(output).any().item()) + + def prepare_config_and_inputs_for_common(self): + config_and_inputs = self.prepare_config_and_inputs() + ( + config, + input_ids, + attention_mask, + ) = config_and_inputs + + inputs_dict = { + "input_ids": input_ids, + "attention_mask": attention_mask, + } + return config, inputs_dict + + +class M2M100EncoderModelModelTest(ModelTesterMixin, unittest.TestCase): + all_model_classes = (M2M100EncoderModel,) if is_torch_available() else () + test_pruning = False + + def setUp(self): + self.model_tester = M2M100EncoderModelTester(self) + self.config_tester = ConfigTester(self, config_class=M2M100Config, d_model=37) + + def test_config(self): + self.config_tester.run_common_tests() + + def test_model(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_model(*config_and_inputs) + + @unittest.skipIf(torch_device == "cpu", "Cant do half precision") + def test_model_fp16_forward(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_model_fp16_forward(*config_and_inputs) + + @unittest.skip( + reason="This architecure has tied weights by default and there is no way to remove it, check: 
https://github.com/huggingface/transformers/pull/31771#issuecomment-2210915245" + ) + def test_load_save_without_tied_weights(self): + pass + + +class M2M100DecoderModelTester: + def __init__( + self, + parent, + batch_size=13, + seq_length=7, + is_training=False, + use_labels=False, + vocab_size=99, + hidden_size=16, + num_hidden_layers=2, + num_attention_heads=4, + intermediate_size=4, + hidden_act="relu", + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + encoder_layerdrop=0.0, + decoder_layerdrop=0.0, + max_position_embeddings=20, + eos_token_id=2, + pad_token_id=1, + bos_token_id=0, + use_attention_mask: bool = True, + is_encoder_decoder=True, + ): + self.parent = parent + self.batch_size = batch_size + self.seq_length = seq_length + self.is_training = is_training + self.use_labels = use_labels + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.intermediate_size = intermediate_size + self.hidden_act = hidden_act + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.encoder_layerdrop = encoder_layerdrop + self.decoder_layerdrop = decoder_layerdrop + self.max_position_embeddings = max_position_embeddings + self.eos_token_id = eos_token_id + self.pad_token_id = pad_token_id + self.bos_token_id = bos_token_id + self.use_attention_mask = use_attention_mask + self.is_encoder_decoder = is_encoder_decoder + + # some tests ask for encoder-related parameters (e.g. test_attention_outputs) + self.encoder_seq_length = 1 + + def prepare_config_and_inputs(self): + # encoder_input_ids should usually be ignored, but adding them for compatibility with some tests like test_resize_tokens_embeddings + encoder_input_ids = ids_tensor([self.batch_size, 1], self.vocab_size) + decoder_input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size) + + attention_mask = None + if self.use_attention_mask: + attention_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2) + + embeddings = torch.randn([self.batch_size, 1, self.hidden_size], device=torch_device) + encoder_outputs = BaseModelOutput(last_hidden_state=embeddings) + + inputs_dict = { + "input_ids": encoder_input_ids, # the input ids are always ignored anyway + "attention_mask": None, # the inputs are pooled embeddins, so they are never masked + "decoder_input_ids": decoder_input_ids, + "decoder_attention_mask": attention_mask, + "encoder_outputs": encoder_outputs, + } + + config = M2M100Config( + vocab_size=self.vocab_size, + d_model=self.hidden_size, + encoder_layers=self.num_hidden_layers, + decoder_layers=self.num_hidden_layers, + encoder_attention_heads=self.num_attention_heads, + decoder_attention_heads=self.num_attention_heads, + encoder_ffn_dim=self.intermediate_size, + decoder_ffn_dim=self.intermediate_size, + dropout=self.hidden_dropout_prob, + attention_dropout=self.attention_probs_dropout_prob, + encoder_layerdrop=self.encoder_layerdrop, + decoder_layerdrop=self.decoder_layerdrop, + max_position_embeddings=self.max_position_embeddings, + eos_token_id=self.eos_token_id, + bos_token_id=self.bos_token_id, + pad_token_id=self.pad_token_id, + is_encoder_decoder=self.is_encoder_decoder, + ) + + return (config, inputs_dict) + + def create_and_check_model( + self, + config, + **inputs_dict, + ): + model = M2M100DecoderModel(config=config) + model.to(torch_device) + model.eval() + result = model( + **inputs_dict, + ) + logits = 
result.logits + + self.parent.assertEqual(logits.size(), (self.batch_size, self.seq_length, self.vocab_size)) + + def create_and_check_model_fp16_forward( + self, + config, + encoder_outputs, + **inputs_dict, + ): + model = M2M100DecoderModel(config=config).to(torch_device).half().eval() + encoder_outputs_half = BaseModelOutput(last_hidden_state=encoder_outputs.last_hidden_state.half()) + output = model(encoder_outputs=encoder_outputs_half, **inputs_dict) + self.parent.assertFalse(torch.isnan(output["logits"]).any().item()) + + def prepare_config_and_inputs_for_common(self): + config_and_inputs = self.prepare_config_and_inputs() + config, inputs_dict = config_and_inputs + return config, inputs_dict + + +class M2M100DecoderModelTest(ModelTesterMixin, unittest.TestCase): # TODO: add GenerationTesterMixin + all_model_classes = (M2M100DecoderModel,) if is_torch_available() else () + test_pruning = False + test_head_masking = ( + False # this would require also masking the attention heads of the fake encoder, which is cumbersome + ) + + def setUp(self): + self.model_tester = M2M100DecoderModelTester(self) + self.config_tester = ConfigTester(self, config_class=M2M100Config, d_model=37) + self.is_encoder_decoder = True + + def test_config(self): + self.config_tester.run_common_tests() + + def test_model(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_model(config, **inputs_dict) + + @unittest.skipIf(torch_device == "cpu", "Cant do half precision") + def test_model_fp16_forward(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_model_fp16_forward(config, **inputs_dict) + + @unittest.skip( + reason="This architecure has tied weights by default and there is no way to remove it, check: https://github.com/huggingface/transformers/pull/31771#issuecomment-2210915245" + ) + def test_load_save_without_tied_weights(self): + pass + + @unittest.skip( + reason="The encoder hidden states are an input to M2M100DecoderModelTest, so they do not have gradients, as the test expects." + ) + def test_retain_grad_hidden_states_attentions(self): + pass + + @unittest.skip( + reason="The batching test does not split the precomputed encoder outputs into batches, so it cannot really batch the inputs." + ) + def test_batching_equivalence(self): + pass + + @unittest.skip( + reason="For M2M100DecoderModel, the decoder inputs are usually fake, so they do not contain attentions." + ) + def test_attention_outputs(self): + pass + + @unittest.skip( + reason="For M2M100DecoderModel, the decoder inputs are usually fake, so they do not contain hidden states." 
+ ) + def test_hidden_states_output(self): + pass + + +@require_torch +@require_sentencepiece +@require_tokenizers +@slow +class SonarIntegrationTests(unittest.TestCase): + @cached_property + def default_tokenizer(self): + from transformers import NllbTokenizer + + return NllbTokenizer.from_pretrained("facebook/nllb-200-distilled-600M") + + def default_encoder(self): + return M2M100EncoderModel.from_pretrained("cointegrated/SONAR_200_text_encoder_hf").to(torch_device) + + def default_decoder(self): + return M2M100DecoderModel.from_pretrained("cointegrated/SONAR_200_text_decoder_hf").to(torch_device) + + def test_encoding_and_decoding(self): + tokenizer = self.default_tokenizer + encoder = self.default_encoder() + tokenizer.src_lang = "eng_Latn" + + sentences = ["My name is SONAR.", "I can embed the sentences into vectorial space."] + batch = tokenizer(sentences, padding=True, return_tensors="pt").to(torch_device) + + with torch.inference_mode(): + enc_out = encoder(**batch, pool_last_hidden_state=True) + assert enc_out.last_hidden_state.shape == (2, 1, 1024) + embeddings = enc_out.last_hidden_state.squeeze(1) + + ref_embeddings = torch.tensor( + [ + [-0.005286, 0.002008, -0.000562, 0.006344, 0.006329], + [-0.000330, -0.007055, 0.007644, 0.001841, 0.003727], + ], + device=torch_device, + ) + assert torch.allclose(embeddings[:, :5], ref_embeddings, rtol=1e-3) + + decoder = self.default_decoder() + + gen_out = decoder.generate( + # passing encoder_outputs is not recommended, because beam search decoding modifies them in place, which is ugly + # encoder_outputs=enc_out, + encoder_outputs=BaseModelOutput(last_hidden_state=enc_out.last_hidden_state.clone()), + num_beams=5, + forced_bos_token_id=tokenizer.convert_tokens_to_ids("eng_Latn"), + ) + text_out = tokenizer.batch_decode(gen_out, skip_special_tokens=True) + assert text_out == ["My name is SONAR.", "I can embed the sentences into vector space."]