idefics2 enable_input_require_grads not aligned with disable_input_require_grads (huggingface#33194)

* idefics2 enable_input_require_grads is not aligned with disable_input_require_grads,
which makes disabling gradient checkpointing fail for peft + idefics2 checkpoints

Signed-off-by: Wang, Yi <[email protected]>

* split test case

Signed-off-by: Wang, Yi <[email protected]>

* fix ci failure

Signed-off-by: Wang, Yi <[email protected]>

* refine test

Signed-off-by: Wang, Yi <[email protected]>

---------

Signed-off-by: Wang, Yi <[email protected]>
sywangyi authored Sep 17, 2024
1 parent 642256d commit 74026b4
Showing 3 changed files with 58 additions and 0 deletions.
8 changes: 8 additions & 0 deletions src/transformers/models/idefics2/modeling_idefics2.py
@@ -1256,6 +1256,10 @@ def make_inputs_require_grads(module, input, output):
            make_inputs_require_grads
        )

    def disable_input_require_grads(self):
        self._text_require_grads_hook.remove()
        self._vision_require_grads_hook.remove()

    def get_input_embeddings(self):
        return self.text_model.get_input_embeddings()

@@ -1466,6 +1470,10 @@ def make_inputs_require_grads(module, input, output):
            make_inputs_require_grads
        )

    def disable_input_require_grads(self):
        self._text_require_grads_hook.remove()
        self._vision_require_grads_hook.remove()

    def get_input_embeddings(self):
        return self.model.text_model.get_input_embeddings()

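For context, the enable side that these new overrides mirror registers two forward hooks, one on the text input embeddings and one on the vision input embeddings, and stores their handles in the attributes removed above. The following is a minimal sketch of that pairing, not the verbatim upstream code, and the vision embedding accessor is an assumption:

    def enable_input_require_grads(self):
        def make_inputs_require_grads(module, input, output):
            output.requires_grad_(True)

        # handles are stored so the matching disable can remove them later
        self._text_require_grads_hook = self.get_input_embeddings().register_forward_hook(
            make_inputs_require_grads
        )
        # vision accessor assumed; Idefics2 also needs gradients flowing through image features
        self._vision_require_grads_hook = self.vision_model.get_input_embeddings().register_forward_hook(
            make_inputs_require_grads
        )

    def disable_input_require_grads(self):
        # both handles must be removed; the generic base-class disable only knows
        # about a single hook, which is why Idefics2 needs its own override
        self._text_require_grads_hook.remove()
        self._vision_require_grads_hook.remove()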
12 changes: 12 additions & 0 deletions tests/models/speecht5/test_modeling_speecht5.py
@@ -239,6 +239,12 @@ def test_torchscript_output_hidden_state(self):
    def test_torchscript_simple(self):
        pass

    @unittest.skip(
        reason="Model returns None for input_embeds, check: https://github.com/huggingface/transformers/issues/33527"
    )
    def test_peft_gradient_checkpointing_enable_disable(self):
        pass


@require_torch
class SpeechT5ForSpeechToTextTester:
@@ -1743,6 +1749,12 @@ def test_training_gradient_checkpointing_use_reentrant(self):
    def test_training_gradient_checkpointing_use_reentrant_false(self):
        pass

    @unittest.skip(
        reason="Model returns None for input_embeds, check: https://github.com/huggingface/transformers/issues/33527"
    )
    def test_peft_gradient_checkpointing_enable_disable(self):
        pass

    # overwrite from test_modeling_common
    def _mock_init_weights(self, module):
        if hasattr(module, "weight") and module.weight is not None:
38 changes: 38 additions & 0 deletions tests/test_modeling_common.py
@@ -403,6 +403,44 @@ def test_gradient_checkpointing_enable_disable(self):
                        m.gradient_checkpointing, f"Module {n} does not have gradient_checkpointing set to False"
                    )

    def test_peft_gradient_checkpointing_enable_disable(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            if not model_class.supports_gradient_checkpointing:
                continue

            # at init model should have gradient checkpointing disabled
            model = model_class(config)
            self.assertFalse(model.is_gradient_checkpointing)

            # check enable works
            model._hf_peft_config_loaded = True
            try:
                model.gradient_checkpointing_enable()
            except NotImplementedError:
                continue

            self.assertTrue(model.is_gradient_checkpointing)

            # Loop over all modules and check that relevant modules have gradient_checkpointing set to True
            for n, m in model.named_modules():
                if hasattr(m, "gradient_checkpointing"):
                    self.assertTrue(
                        m.gradient_checkpointing, f"Module {n} does not have gradient_checkpointing set to True"
                    )

            # check disable works
            model.gradient_checkpointing_disable()
            self.assertFalse(model.is_gradient_checkpointing)

            # Loop over all modules and check that relevant modules have gradient_checkpointing set to False
            for n, m in model.named_modules():
                if hasattr(m, "gradient_checkpointing"):
                    self.assertFalse(
                        m.gradient_checkpointing, f"Module {n} does not have gradient_checkpointing set to False"
                    )

    @is_flaky(description="low likelihood of failure, reason not yet discovered")
    def test_save_load_fast_init_from_base(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
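The new common test simulates a loaded PEFT adapter by setting model._hf_peft_config_loaded = True before toggling gradient checkpointing. In a real workflow the same code path is reached roughly as in the sketch below; the checkpoint name and LoRA target modules are illustrative assumptions, not taken from this change:

    from peft import LoraConfig
    from transformers import AutoModelForVision2Seq

    model = AutoModelForVision2Seq.from_pretrained("HuggingFaceM4/idefics2-8b")
    # add_adapter marks the model as having a PEFT config loaded
    model.add_adapter(LoraConfig(target_modules=["q_proj", "v_proj"]))

    # with a PEFT config loaded, gradient_checkpointing_enable() also calls
    # enable_input_require_grads(); the matching disable must remove every hook
    # handle that enable stored, which is what the Idefics2 override now does
    model.gradient_checkpointing_enable()
    model.gradient_checkpointing_disable()  # failed for Idefics2 before this fix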
