Gemma2: eager attention by default (huggingface#32865)
gante authored and zucchini-nlp committed Aug 30, 2024
1 parent 0a617d2 commit 79ce5c0
Showing 2 changed files with 32 additions and 2 deletions.
14 changes: 14 additions & 0 deletions src/transformers/models/gemma2/modeling_gemma2.py
@@ -656,6 +656,20 @@ def _init_weights(self, module):
             if module.padding_idx is not None:
                 module.weight.data[module.padding_idx].zero_()
 
+    @classmethod
+    def _check_and_enable_sdpa(cls, config, hard_check_only: bool = False):
+        """
+        Overloads `PreTrainedModel._check_and_enable_sdpa` so as to DISABLE torch SDPA by default on Gemma2 models.
+        SDPA reduces the model performance on Gemma2 because of the logits softcapping.
+        """
+        config = super()._check_and_enable_sdpa(config, hard_check_only=hard_check_only)
+
+        # if using the default path -> swap sdpa by eager
+        if not hard_check_only and config._attn_implementation == "sdpa":
+            config._attn_implementation = "eager"
+
+        return config
+
 
 _CONFIG_FOR_DOC = "Gemma2Config"
 
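For context (not part of the diff): Gemma2's eager attention path applies a tanh "soft cap" to the raw attention logits before the softmax, and the fused torch SDPA kernel offers no hook for that step, so the two implementations drift apart numerically. The sketch below illustrates the capping; the 50.0 cap value and the tensor shapes mirror typical Gemma2 defaults but are assumptions used only for illustration.

import torch

def softcapped_attention_weights(query, key, softcap=50.0):
    # Raw scaled dot-product logits.
    scores = query @ key.transpose(-1, -2) / (query.shape[-1] ** 0.5)
    # Gemma2-style logit softcapping: squashes logits into (-softcap, softcap).
    scores = softcap * torch.tanh(scores / softcap)
    return torch.softmax(scores, dim=-1)

# (batch, heads, seq, head_dim) -- toy shapes for illustration.
q = torch.randn(1, 8, 16, 64)
k = torch.randn(1, 8, 16, 64)
weights = softcapped_attention_weights(q, k)
# torch.nn.functional.scaled_dot_product_attention would skip the tanh cap,
# which is why the default implementation is swapped back to "eager" above.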
20 changes: 18 additions & 2 deletions tests/models/gemma2/test_modeling_gemma2.py
@@ -82,11 +82,11 @@ def setUp(self):
         self.model_tester = Gemma2ModelTester(self)
         self.config_tester = ConfigTester(self, config_class=Gemma2Config, hidden_size=37)
 
-    @unittest.skip("Eager and SDPA do not produce the same outputs, thus this test fails")
+    @unittest.skip("Failing because of unique cache (HybridCache)")
     def test_model_outputs_equivalence(self, **kwargs):
         pass
 
-    @unittest.skip("Gemma2's outputs are expected to be different")
+    @unittest.skip("Gemma2's eager attn/sdpa attn outputs are expected to be different")
     def test_eager_matches_sdpa_inference(self):
         pass
 
@@ -182,6 +182,22 @@ def _check_past_key_values_for_generate(self, batch_size, past_key_values, seq_l
         static_layers = [layer_idx for layer_idx, boolean in enumerate(past_key_values.is_sliding) if not boolean]
         self.assertTrue(len(past_key_values.key_cache) == num_hidden_layers)
         self.assertTrue(past_key_values.key_cache[static_layers[0]].shape == static_cache_shape)
+    @unittest.skip("Gemma2's eager attn/sdpa attn outputs are expected to be different")
+    def test_sdpa_equivalence(self):
+        pass
+
+    def test_eager_attention_loaded_by_default(self):
+        """Gemma 2 + SDPA = inferior results, because of the logit softcapping. Eager is the default."""
+        config, _ = self.model_tester.prepare_config_and_inputs_for_common()
+
+        # Usually we enable SDPA by default, but not for Gemma2
+        model = Gemma2Model(config)
+        self.assertTrue(model.config._attn_implementation == "eager")
+
+        # We can still force SDPA
+        config._attn_implementation = "sdpa"
+        model = Gemma2Model(config)
+        self.assertTrue(model.config._attn_implementation == "sdpa")
 
 
 @slow
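A usage sketch of the new default (not part of the commit; the checkpoint id is a placeholder and assumes access to the Gemma 2 weights):

from transformers import AutoModelForCausalLM

# A plain load now resolves to eager attention for Gemma2.
model = AutoModelForCausalLM.from_pretrained("google/gemma-2-9b")
print(model.config._attn_implementation)  # expected: "eager"

# SDPA can still be requested explicitly, accepting the softcapping caveat.
model = AutoModelForCausalLM.from_pretrained("google/gemma-2-9b", attn_implementation="sdpa")
print(model.config._attn_implementation)  # expected: "sdpa"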
