fix_mbart_tied_weights (huggingface#26422)
* fix_mbart_tied_weights

* add test
SunMarc authored Sep 28, 2023
1 parent 216dff7 commit 5e11d72
Showing 2 changed files with 42 additions and 0 deletions.
5 changes: 5 additions & 0 deletions src/transformers/models/mbart/modeling_mbart.py
@@ -1184,6 +1184,11 @@ def get_encoder(self):
    def get_decoder(self):
        return self.decoder

+    def _tie_weights(self):
+        if self.config.tie_word_embeddings:
+            self._tie_or_clone_weights(self.encoder.embed_tokens, self.get_input_embeddings())
+            self._tie_or_clone_weights(self.decoder.embed_tokens, self.get_input_embeddings())
+
    @add_start_docstrings_to_model_forward(MBART_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
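For context, `_tie_or_clone_weights` is inherited from `PreTrainedModel`: roughly, it points one module's `weight` at the other's `Parameter`, cloning the values only when exporting to TorchScript, which cannot represent shared parameters. A simplified standalone sketch of that logic (bias handling omitted, function name and signature adapted for illustration):

from torch import nn

def tie_or_clone_weights(output_embeddings: nn.Module, input_embeddings: nn.Module, torchscript: bool = False):
    """Simplified sketch of PreTrainedModel._tie_or_clone_weights (bias padding omitted)."""
    if torchscript:
        # TorchScript cannot share parameters, so copy the values instead.
        output_embeddings.weight = nn.Parameter(input_embeddings.weight.clone())
    else:
        # Normal case: both modules now hold the very same Parameter object.
        output_embeddings.weight = input_embeddings.weight

# After tying, both attributes reference a single Parameter:
shared = nn.Embedding(10, 4)
head = nn.Linear(4, 10, bias=False)
tie_or_clone_weights(head, shared)
assert head.weight is shared.weight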
37 changes: 37 additions & 0 deletions tests/models/mbart/test_modeling_mbart.py
@@ -327,6 +327,43 @@ def test_generate_fp16(self):
        model.generate(input_ids, attention_mask=attention_mask)
        model.generate(num_beams=4, do_sample=True, early_stopping=False, num_return_sequences=3)

+    def test_ensure_weights_are_shared(self):
+        config, inputs_dict = self.model_tester.prepare_config_and_inputs()
+
+        config.tie_word_embeddings = True
+        model = MBartForConditionalGeneration(config)
+
+        # MBart ties four weights: the LM head, the shared input embeddings, and the
+        # encoder and decoder token embeddings. Leaving them untied is harmless for
+        # torch.load, but it is an issue for safetensors serialization.
+        self.assertEqual(
+            len(
+                {
+                    model.get_output_embeddings().weight.data_ptr(),
+                    model.get_input_embeddings().weight.data_ptr(),
+                    model.base_model.decoder.embed_tokens.weight.data_ptr(),
+                    model.base_model.encoder.embed_tokens.weight.data_ptr(),
+                }
+            ),
+            1,
+        )
+
+        config.tie_word_embeddings = False
+        model = MBartForConditionalGeneration(config)
+
+        # Without word-embedding tying, the encoder and decoder embeddings still share
+        # one tensor, but the LM head keeps its own, so two distinct pointers remain.
+        self.assertEqual(
+            len(
+                {
+                    model.get_output_embeddings().weight.data_ptr(),
+                    model.get_input_embeddings().weight.data_ptr(),
+                    model.base_model.decoder.embed_tokens.weight.data_ptr(),
+                    model.base_model.encoder.embed_tokens.weight.data_ptr(),
+                }
+            ),
+            2,
+        )


def assert_tensors_close(a, b, atol=1e-12, prefix=""):
    """If tensors have different shapes, different values or a and b are not both tensors, raise a nice Assertion error."""
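The test relies on `Tensor.data_ptr()`, which returns the address of a tensor's underlying storage: tied parameters collapse to a single entry in the set, while cloned or independent ones do not. A minimal standalone illustration of that mechanism, independent of the test suite:

from torch import nn

embed = nn.Embedding(10, 4)
lm_head = nn.Linear(4, 10, bias=False)

# Tying: both attributes now reference one Parameter, so the pointers match.
lm_head.weight = embed.weight
assert lm_head.weight.data_ptr() == embed.weight.data_ptr()
assert len({lm_head.weight.data_ptr(), embed.weight.data_ptr()}) == 1

# Cloning: equal values but separate storage, so the pointers differ.
lm_head.weight = nn.Parameter(embed.weight.clone())
assert lm_head.weight.data_ptr() != embed.weight.data_ptr()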
