Natively support SONAR text models as M2M100 encoder and decoder models #29646

Open · avidale wants to merge 24 commits into main from sonar-text-models

Commits (24)
ab00388
add m2m100 encoder model code
avidale Mar 13, 2024
5f20047
Merge branch 'main' into sonar-text-models
avidale Mar 13, 2024
97e2356
add the decoder-only model
avidale Apr 15, 2024
776bbb1
add SONAR checkpoint conversion script
avidale Apr 16, 2024
9dc73a5
auto style fixers
avidale Apr 16, 2024
f797f86
Merge remote-tracking branch 'origin/main' into sonar-text-models
avidale Apr 16, 2024
05bcfe3
add an optional mean pooling
avidale Aug 5, 2024
158e70c
Merge branch 'main' into sonar-text-models
avidale Aug 5, 2024
3b20d77
fix a typo
avidale Aug 5, 2024
a3ff402
adding a test of model conversion
avidale Aug 5, 2024
40f93a7
fixup, add doc stub, add integration tests
avidale Aug 5, 2024
6ef942e
update the docs and fix the integration tests
avidale Aug 5, 2024
d240680
add special tests for the SONAR encoder model
avidale Aug 5, 2024
e3cc795
create tests for the sonar decoder
avidale Aug 5, 2024
d23973e
fix decoder unit tests
avidale Aug 5, 2024
7dbf383
fix the rest of decoder unit tests
avidale Aug 5, 2024
28eefe3
change the copyright header
avidale Aug 28, 2024
1c51a62
Merge branch 'main' into sonar-text-models
avidale Oct 24, 2024
25931cb
remove _reorder_cache and dummy encoder attentions and states
avidale Oct 24, 2024
9f18897
return M2M100 _reorder_cache and skip extra outputs test for SONAR de…
avidale Oct 24, 2024
f60ebea
resurrect one more _reorder_cache
avidale Oct 24, 2024
bc5479d
Merge branch 'main' into sonar-text-models
avidale Oct 25, 2024
6eb0dbd
Merge branch 'main' into sonar-text-models
avidale Oct 25, 2024
f9d10d9
Merge branch 'main' into sonar-text-models
avidale Nov 8, 2024
81 changes: 80 additions & 1 deletion docs/source/en/model_doc/m2m_100.md
@@ -180,4 +180,83 @@ model = M2M100ForConditionalGeneration.from_pretrained("facebook/m2m100_418M", t
...
```

For the best speedups, we recommend loading the model in half-precision (e.g. `torch.float16` or `torch.bfloat16`).


# SONAR encoders and decoders

The SONAR embedding space was introduced in [SONAR: Sentence-Level Multimodal and Language-Agnostic Representations](https://arxiv.org/abs/2308.11466) by Paul-Ambroise Duquenne, Holger Schwenk, and Benoît Sagot.
In the text modality, SONAR allows encoding texts in 202 languages into a shared space, and decoding these embeddings
back to any of the 202 languages. The architecture of the text encoders and decoders is based on [NLLB](../nllb).

The original implementation resides in the [SONAR](https://github.com/facebookresearch/sonar) repository.

The Hugging Face implementation of the SONAR text encoder and decoder, like the NLLB implementation, is based on
the M2M100 architecture. Therefore, a SONAR encoder is represented by `M2M100EncoderModel`, and a SONAR decoder
by `M2M100DecoderModel`.

## Usage tips and examples

The encoder converts text sentences into 1024-dimensional embeddings in the following way:

```python
import torch
from transformers import M2M100EncoderModel, NllbTokenizer

encoder = M2M100EncoderModel.from_pretrained("cointegrated/SONAR_200_text_encoder_hf")
tokenizer = NllbTokenizer.from_pretrained("facebook/nllb-200-distilled-600M", src_lang="eng_Latn", tgt_lang="eng_Latn")

sentences = ["My name is SONAR.", "I can embed the sentences into vector space."]
inputs = tokenizer(sentences, padding=True, return_tensors="pt")
with torch.inference_mode():
    encoder_out = encoder(**inputs, pool_last_hidden_state=True)
    embeddings = encoder_out.last_hidden_state.squeeze(1)

print(embeddings.shape)  # torch.Size([2, 1024])
```
The argument `pool_last_hidden_state=True` tells the model to aggregate the contextual token embeddings by mean pooling
after the encoder produces them. This pooling is the only architectural difference between SONAR and NLLB.
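
If the description above holds, the pooled embedding is just an attention-mask-aware mean over the per-token encoder states. The following is a minimal sketch of the equivalent manual computation, reusing `encoder`, `inputs`, and `embeddings` from the snippet above (it assumes the flag averages only over non-padding positions):

```python
with torch.inference_mode():
    token_states = encoder(**inputs).last_hidden_state  # [batch, seq_len, 1024], no pooling

mask = inputs["attention_mask"].unsqueeze(-1)  # [batch, seq_len, 1]
manual_embeddings = (token_states * mask).sum(dim=1) / mask.sum(dim=1)

# Should print True if the pooling is a plain masked mean (our assumption here)
print(torch.allclose(manual_embeddings, embeddings, atol=1e-5))
```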

Given these embeddings, the decoder can reconstruct them back into text.
```python
from transformers import M2M100DecoderModel
from transformers.modeling_outputs import BaseModelOutput

decoder = M2M100DecoderModel.from_pretrained("cointegrated/SONAR_200_text_decoder_hf")

# Decoding into the original (English) language.
# Passing encoder_out directly is not recommended, because beam search decoding
# modifies the encoder outputs in place; wrap a fresh BaseModelOutput instead.
generator_out = decoder.generate(
    encoder_outputs=BaseModelOutput(last_hidden_state=embeddings.unsqueeze(1)),
    num_beams=5,
    forced_bos_token_id=tokenizer.convert_tokens_to_ids("eng_Latn"),
)
text_out = tokenizer.batch_decode(generator_out, skip_special_tokens=True)
print(text_out)  # ["My name is SONAR.", "I can embed the sentences into vector space."]

# Decoding into another (French) language
generator_out = decoder.generate(
    encoder_outputs=BaseModelOutput(last_hidden_state=embeddings.unsqueeze(1)),
    num_beams=5,
    forced_bos_token_id=tokenizer.convert_tokens_to_ids("fra_Latn"),
)
text_out = tokenizer.batch_decode(generator_out, skip_special_tokens=True)
print(text_out)  # ['Mon nom est SONAR.', "Je peux intégrer les phrases dans l'espace vectoriel."]
```

The list of BCP-47 codes for the 202 language varieties currently supported by the SONAR text models can be found in
[a model card](https://github.com/facebookresearch/SONAR/blob/main/sonar/cards/text_sonar_basic_decoder.yaml) in the
official SONAR repo. It is the same list of languages that [NLLB](../nllb) supports, and it mostly coincides with the
[FLORES-200 language list](https://github.com/facebookresearch/flores/blob/main/flores200/README.md#languages-in-flores-200).
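
Because the source and target languages are chosen independently, the same pair of models also performs translation-style round trips through the embedding space. Below is a sketch reusing the `encoder`, `decoder`, and `tokenizer` loaded above; the French input and the expected output are illustrative:

```python
from transformers.modeling_outputs import BaseModelOutput

tokenizer.src_lang = "fra_Latn"  # switch the source language code
inputs = tokenizer(["Le chat dort sur le canapé."], padding=True, return_tensors="pt")

with torch.inference_mode():
    # the pooled output keeps a sequence dimension of 1: [batch, 1, 1024]
    fr_embeddings = encoder(**inputs, pool_last_hidden_state=True).last_hidden_state

generated = decoder.generate(
    encoder_outputs=BaseModelOutput(last_hidden_state=fr_embeddings),
    num_beams=5,
    forced_bos_token_id=tokenizer.convert_tokens_to_ids("eng_Latn"),
)
print(tokenizer.batch_decode(generated, skip_special_tokens=True))
# e.g. ["The cat is sleeping on the couch."] (illustrative)
```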


## M2M100EncoderModel

[[autodoc]] M2M100EncoderModel
- forward

## M2M100DecoderModel

[[autodoc]] M2M100DecoderModel
- forward
4 changes: 4 additions & 0 deletions src/transformers/__init__.py
@@ -2644,6 +2644,8 @@
)
_import_structure["models.m2m_100"].extend(
[
"M2M100DecoderModel",
"M2M100EncoderModel",
"M2M100ForConditionalGeneration",
"M2M100Model",
"M2M100PreTrainedModel",
@@ -7309,6 +7311,8 @@
LxmertVisualFeatureEncoder,
)
from .models.m2m_100 import (
M2M100DecoderModel,
M2M100EncoderModel,
M2M100ForConditionalGeneration,
M2M100Model,
M2M100PreTrainedModel,
4 changes: 4 additions & 0 deletions src/transformers/models/m2m_100/__init__.py
@@ -29,6 +29,8 @@
pass
else:
_import_structure["modeling_m2m_100"] = [
"M2M100DecoderModel",
"M2M100EncoderModel",
"M2M100ForConditionalGeneration",
"M2M100Model",
"M2M100PreTrainedModel",
@@ -46,6 +48,8 @@
pass
else:
from .modeling_m2m_100 import (
M2M100DecoderModel,
M2M100EncoderModel,
M2M100ForConditionalGeneration,
M2M100Model,
M2M100PreTrainedModel,
@@ -81,5 +81,5 @@ def convert_fairseq_m2m100_checkpoint_from_disk(checkpoint_path):
parser.add_argument("fairseq_path", type=str, help="path to a model.pt on local filesystem.")
parser.add_argument("pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model.")
args = parser.parse_args()
model = convert_fairseq_m2m100_checkpoint_from_disk(args.fairseq_pathß)
model = convert_fairseq_m2m100_checkpoint_from_disk(args.fairseq_path)
model.save_pretrained(args.pytorch_dump_folder_path)
src/transformers/models/m2m_100/convert_sonar_original_checkpoint_to_transformers.py (new file)
@@ -0,0 +1,221 @@
# Copyright 2024 The Sonar Authors and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
This script converts fairseq/fairseq2 SONAR text encoder and decoder checkpoints to transformers.
The reference architectures are given in https://github.com/facebookresearch/SONAR/blob/main/sonar/models/sonar_text/builder.py.
The checkpoints for conversion can be found in:
- encoder: https://github.com/facebookresearch/SONAR/blob/main/sonar/cards/text_sonar_basic_encoder.yaml
- decoder: https://github.com/facebookresearch/SONAR/blob/main/sonar/cards/text_sonar_basic_decoder.yaml
"""

import argparse

import torch
from torch import nn

from transformers import M2M100Config, M2M100DecoderModel, M2M100EncoderModel


def make_linear_from_emb(emb):
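    # Build a bias-free linear layer that shares its weight tensor with the embedding
    # matrix (the constructor dimensions are immaterial: .weight.data is replaced outright).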
    vocab_size, emb_size = emb.weight.shape
    lin_layer = nn.Linear(vocab_size, emb_size, bias=False)
    lin_layer.weight.data = emb.weight.data
    return lin_layer


def get_parameter_renames(model, state_dict, is_decoder=True):
    parameter_renames = {}
    trf_names = {name for name, value in model.named_parameters()}
    fs2_names = set(state_dict.keys())

    for trf_name in trf_names:
        fs2_name = trf_name
        if trf_name == "shared.weight":
            fs2_name = "decoder_frontend.embed.weight" if is_decoder else "encoder_frontend.embed.weight"
        if trf_name == "lm_head.weight":
            fs2_name = "final_proj.weight"

        if trf_name.startswith("layers."):
            fs2_name = "decoder." + trf_name if is_decoder else "encoder." + trf_name
        if trf_name.startswith("layer_norm.") and is_decoder:
            fs2_name = "decoder." + trf_name
        if trf_name.startswith("encoder.layer_norm.") and not is_decoder:
            fs2_name = trf_name.split(".", 1)[1]

        if ".encoder_attn." in fs2_name:
            fs2_name = fs2_name.replace(".encoder_attn.", ".encoder_decoder_attn.")
        if ".encoder_attn_layer_norm." in fs2_name:
            fs2_name = fs2_name.replace(".encoder_attn_layer_norm.", ".encoder_decoder_attn_layer_norm.")
        if ".out_proj." in fs2_name:
            fs2_name = fs2_name.replace(".out_proj.", ".output_proj.")
        if ".fc1." in fs2_name:
            fs2_name = fs2_name.replace(".fc1.", ".ffn.inner_proj.")
        if ".fc2." in fs2_name:
            fs2_name = fs2_name.replace(".fc2.", ".ffn.output_proj.")
        if ".final_layer_norm." in fs2_name:
            fs2_name = fs2_name.replace(".final_layer_norm.", ".ffn_layer_norm.")

        if fs2_name in fs2_names:
            parameter_renames[trf_name] = fs2_name
        else:
            raise ValueError(f"Parameter {trf_name} could not be mapped from transformers to fairseq2 state dict.")
    return parameter_renames


def reorder_special_tokens(new_state_dict):
    """
    In fairseq2, special tokens are ['<pad>', '<unk>', '<s>', '</s>'].
    In transformers (NLLB) they are ['<s>', '<pad>', '</s>', '<unk>'].
    We want to reuse the NLLB tokenizer, so we reorder the embeddings.
    """
    special_token_embs = new_state_dict["shared.weight"].data[[2, 0, 3, 1]].clone()
    for param_name in [
        "encoder.embed_tokens.weight",
        "decoder.embed_tokens.weight",
        "lm_head.weight",
        "shared.weight",
    ]:
        if param_name in new_state_dict:
            new_state_dict[param_name].data[[0, 1, 2, 3]] = special_token_embs


def convert_sonar_checkpoint_from_disk(checkpoint_path):
    checkpoint_dict = torch.load(checkpoint_path, map_location="cpu")
    state_dict = checkpoint_dict["model"]

    # In fairseq2 SONAR checkpoints, there are no configs (they are supposed to be in a yaml model card elsewhere).
    # Thus, we just assume the "basic" hyperparameters;
    # see the arch registry at https://github.com/facebookresearch/SONAR/blob/main/sonar/models/sonar_text/builder.py
    config = M2M100Config(
        vocab_size=256206,
        max_position_embeddings=1024,
        encoder_layers=24,
        decoder_layers=24,
        encoder_attention_heads=16,
        decoder_attention_heads=16,
        encoder_ffn_dim=1024 * 8,
        decoder_ffn_dim=1024 * 8,
        d_model=1024,
        encoder_layerdrop=0,
        decoder_layerdrop=0,
        dropout=0.1,
        attention_dropout=0.1,
        activation_dropout=0.1,
        activation_function="relu",
    )

    if any(parameter_name.startswith("encoder.") for parameter_name in state_dict):
        is_decoder = False
        model = M2M100EncoderModel(config)
    elif any(parameter_name.startswith("decoder.") for parameter_name in state_dict):
        is_decoder = True
        model = M2M100DecoderModel(config)
    else:
        raise ValueError("The state dict does not seem to contain a SONAR encoder or decoder.")

    parameter_renames = get_parameter_renames(model, state_dict, is_decoder)
    new_state_dict = {trf_name: state_dict[fs2_name] for trf_name, fs2_name in parameter_renames.items()}
    reorder_special_tokens(new_state_dict)

    if is_decoder:
        new_state_dict["decoder.embed_tokens.weight"] = new_state_dict["shared.weight"]
    else:
        new_state_dict["encoder.embed_tokens.weight"] = new_state_dict["shared.weight"]

    model.load_state_dict(new_state_dict, strict=True)
    model.tie_weights()

    return model


def test_conversion_accuracy(fairseq2_encoder_path, fairseq2_decoder_path):
    """
    This test is not invoked directly, because the encoder and decoder paths must be provided explicitly,
    and these checkpoints are too heavy to download by default just for a test.
    Please run the test from your code like below:
    ```
    from transformers.models.m2m_100.convert_sonar_original_checkpoint_to_transformers import test_conversion_accuracy
    test_conversion_accuracy(PATH_TO_ENCODER, PATH_TO_DECODER)
    ```
    The fairseq2 checkpoints can be downloaded from the urls indicated in the following cards:
    - https://github.com/facebookresearch/SONAR/blob/main/sonar/cards/text_sonar_basic_encoder.yaml
    - https://github.com/facebookresearch/SONAR/blob/main/sonar/cards/text_sonar_basic_decoder.yaml

    The reference embeddings were obtained with:
    ```
    from sonar.inference_pipelines.text import TextToEmbeddingModelPipeline
    t2vec_model = TextToEmbeddingModelPipeline(encoder="text_sonar_basic_encoder", tokenizer="text_sonar_basic_encoder")
    ref_embeddings = t2vec_model.predict(sentences, source_lang="eng_Latn")[:, :5]
    ```
    """
    from transformers import NllbTokenizer
    from transformers.modeling_outputs import BaseModelOutput

    tokenizer = NllbTokenizer.from_pretrained("facebook/nllb-200-distilled-600M", clean_up_tokenization_spaces=True)
    sentences = ["My name is SONAR.", "I can embed the sentences into vectorial space."]
    tokenizer.src_lang = "eng_Latn"
    batch = tokenizer(sentences, padding=True, return_tensors="pt")

    print("Converting the encoder...")
    enc = convert_sonar_checkpoint_from_disk(fairseq2_encoder_path).eval()
    assert isinstance(enc, M2M100EncoderModel)

    print("Conversion completed, testing the embedding accuracy...")
    with torch.inference_mode():
        enc_out = enc(**batch, pool_last_hidden_state=True)
    assert enc_out.last_hidden_state.shape == (2, 1, 1024)
    embeddings = enc_out.last_hidden_state.squeeze(1)

    ref_embeddings = torch.tensor(
        [[-0.005286, 0.002008, -0.000562, 0.006344, 0.006329], [-0.000330, -0.007055, 0.007644, 0.001841, 0.003727]]
    )
    assert torch.allclose(embeddings[:, :5], ref_embeddings, rtol=1e-3)
Collaborator: let's not assert on the embeddings as they are bound to change depending on the model!

avidale (author): I used the official SONAR model https://github.com/facebookresearch/SONAR/blob/main/sonar/cards/text_sonar_basic_encoder.yaml, which is the only SONAR text encoder that has ever been released so far. I intended this test only to reproduce how this particular model is converted (and as a template if anyone ever applies my conversion script to other models).

Collaborator: (let's remove this still, specifically because the conversion script should allow anyone who has trained a model with your framework to also convert it without hassle! We have integration tests that make sure embeddings or outputs are valid; conversion scripts are not the place for this!)

avidale (author, Aug 29, 2024): The function test_conversion_accuracy is not part of the conversion script; it is an integration test for the conversion script. And it is optional, because it requires downloading the huge original checkpoints. If you insist, though, I can remove it or move it to another file.

print("The embedding accuracy test has passed!")

print("Converting the decoder...")
dec = convert_sonar_checkpoint_from_disk(fairseq2_decoder_path).eval()
assert isinstance(dec, M2M100DecoderModel)

print("Conversion completed, testing the decoding accuracy...")
gen_out = dec.generate(
# passing encoder_outputs is not recommended, because beam search decoding modifies them in place, which is ugly
# encoder_outputs=enc_out,
encoder_outputs=BaseModelOutput(last_hidden_state=enc_out.last_hidden_state.clone()),
num_beams=5,
forced_bos_token_id=tokenizer.convert_tokens_to_ids("eng_Latn"),
)
text_out = tokenizer.batch_decode(gen_out, skip_special_tokens=True)
assert text_out == ["My name is SONAR.", "I can embed the sentences into vector space."]
print("The decoding accuracy test has passed!")


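# Example invocation (paths are illustrative; download a checkpoint from the cards listed above):
#   python convert_sonar_original_checkpoint_to_transformers.py /path/to/model.pt ./sonar_model_hf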
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Required parameters
    parser.add_argument("fairseq_path", type=str, help="path to a model.pt on local filesystem.")
    parser.add_argument("dump_folder_path", default=None, type=str, help="Path to the output transformers model.")
    args = parser.parse_args()
    model = convert_sonar_checkpoint_from_disk(args.fairseq_path)
    model.save_pretrained(args.dump_folder_path)