Skip to content

Commit 6d56edd

Browse files
ganteBernardZach
authored andcommitted
Config: lower save_pretrained exception to warning (huggingface#33906)
* lower to warning * msg * make fixup * rm extra comma
1 parent b45cef2 commit 6d56edd

File tree

2 files changed

+9
-5
lines changed

2 files changed

+9
-5
lines changed

src/transformers/configuration_utils.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -380,11 +380,14 @@ def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub:
380380

381381
non_default_generation_parameters = self._get_non_default_generation_parameters()
382382
if len(non_default_generation_parameters) > 0:
383-
raise ValueError(
383+
# TODO (joao): this should be an exception if the user has modified the loaded config. See #33886
384+
warnings.warn(
384385
"Some non-default generation parameters are set in the model config. These should go into either a) "
385386
"`model.generation_config` (as opposed to `model.config`); OR b) a GenerationConfig file "
386-
"(https://huggingface.co/docs/transformers/generation_strategies#save-a-custom-decoding-strategy-with-your-model) "
387-
f"\nNon-default generation parameters: {str(non_default_generation_parameters)}"
387+
"(https://huggingface.co/docs/transformers/generation_strategies#save-a-custom-decoding-strategy-with-your-model)."
388+
"This warning will become an exception in the future."
389+
f"\nNon-default generation parameters: {str(non_default_generation_parameters)}",
390+
UserWarning,
388391
)
389392

390393
os.makedirs(save_directory, exist_ok=True)

tests/utils/test_configuration_utils.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -313,11 +313,12 @@ def test_repo_versioning_before(self):
313313
old_configuration = old_transformers.models.auto.AutoConfig.from_pretrained(repo)
314314
self.assertEqual(old_configuration.hidden_size, 768)
315315

316-
def test_saving_config_with_custom_generation_kwargs_raises_exception(self):
316+
def test_saving_config_with_custom_generation_kwargs_raises_warning(self):
317317
config = BertConfig(min_length=3) # `min_length = 3` is a non-default generation kwarg
318318
with tempfile.TemporaryDirectory() as tmp_dir:
319-
with self.assertRaises(ValueError):
319+
with self.assertWarns(UserWarning) as cm:
320320
config.save_pretrained(tmp_dir)
321+
self.assertIn("min_length", str(cm.warning))
321322

322323
def test_get_non_default_generation_parameters(self):
323324
config = BertConfig()

0 commit comments

Comments
 (0)