diff --git a/paddleformers/transformers/model_utils.py b/paddleformers/transformers/model_utils.py
index cdc2d8d497..9a86355482 100644
--- a/paddleformers/transformers/model_utils.py
+++ b/paddleformers/transformers/model_utils.py
@@ -2863,15 +2863,17 @@ def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
                 use_safetensors=use_safetensors,
                 variant=variant,
             )
-
-        file_list = resolved_sharded_files if is_sharded else [resolved_archive_file]
-        ckpt_path = get_common_folder(file_list)
+        if resolved_archive_file is not None:
+            file_list = resolved_sharded_files if is_sharded else [resolved_archive_file]
+            ckpt_path = get_common_folder(file_list)
+        else:
+            ckpt_path = None
 
         # 3. init the model
         init_args = config["init_args"] or ()
         with ContextManagers(init_contexts):
             model = cls(config, *init_args, **model_kwargs)
-        if hasattr(cls, "_gen_aoa_config") and load_checkpoint_format == "flex_checkpoint":
+        if ckpt_path is not None and hasattr(cls, "_gen_aoa_config") and load_checkpoint_format == "flex_checkpoint":
             aoa_config = cls._gen_aoa_config(config)
             sharded_state_dict = model.sharded_state_dict()
             dist.load_state_dict(
diff --git a/tests/transformers/test_modeling_common.py b/tests/transformers/test_modeling_common.py
index f0e6f90425..6adab8b146 100644
--- a/tests/transformers/test_modeling_common.py
+++ b/tests/transformers/test_modeling_common.py
@@ -35,8 +35,9 @@
 )
 
 from paddleformers.transformers import AutoModelForCausalLM, AutoTokenizer
+from paddleformers.transformers.auto.modeling import MODEL_MAPPING
 from paddleformers.transformers.configuration_utils import PretrainedConfig
-from paddleformers.transformers.model_utils import PretrainedModel
+from paddleformers.transformers.model_utils import PretrainedModel, load_state_dict
 from paddleformers.utils.env import CONFIG_NAME, LEGACY_CONFIG_NAME  # MODEL_HOME,
 
 from ..testing_utils import slow
@@ -50,6 +51,27 @@ def _config_zero_init(config):
    return configs_no_init
 
 
+def _mock_init_weights(self, module):
+    # paddle.nn.Layer.named_parameters takes include_sublayers, not torch's recurse
+    for name, param in module.named_parameters(include_sublayers=False):
+        # Use the first letter of the name to derive a value: 'a' -> -13 ... 'z' -> 12
+        value = ord(name[0].lower()) - 110
+        param.fill_(value)
+
+
+def _mock_all_init_weights(self):
+    import paddleformers.transformers.model_utils
+
+    if paddleformers.transformers.model_utils._init_weights:
+        for module in self.sublayers(include_self=True):
+            module._is_hf_initialized = False
+    # Initialize weights
+    self.apply(self._initialize_weights)
+
+    # Tie weights should be skipped when not initializing all weights
+    # since from_pretrained(...) calls tie weights anyways
+    self.tie_weights()
+
+
 def get_cluster_from_args(selected_gpus):
     cluster_node_ips = "127.0.0.1"
     node_ip = "127.0.0.1"
@@ -274,6 +296,115 @@ def check_save_load(out1, out2):
         else:
             check_save_load(first, second)
 
+    def test_from_pretrained_no_checkpoint(self):
+        config, _ = self.model_tester.prepare_config_and_inputs_for_common()
+        config.dtype = "float32"
+        for model_class in self.all_model_classes:
+            model = model_class(copy.deepcopy(config))
+            state_dict = model.state_dict()
+
+            new_model = model_class.from_pretrained(
+                pretrained_model_name_or_path=None, config=config, state_dict=state_dict
+            )
+            for p1, p2 in zip(model.parameters(), new_model.parameters()):
+                self.assertTrue(paddle.allclose(p1, p2))
+
+    def test_keep_in_fp32_modules(self):
+        config, _ = self.model_tester.prepare_config_and_inputs_for_common()
+        for model_class in self.all_model_classes:
+            if model_class._keep_in_fp32_modules is None:
+                self.skipTest(reason="Model class has no _keep_in_fp32_modules attribute defined")
+
+            model = self._make_model_instance(config, model_class)
+            with tempfile.TemporaryDirectory() as tmpdirname:
+                model.save_pretrained(tmpdirname)
+
+                if hasattr(config, "moe_group"):
+                    config.moe_group = "dummy"
+
+                model = model_class.from_pretrained(tmpdirname, config=config, dtype=paddle.float16)
+
+                for name, param in model.named_parameters():
+                    if any(
+                        module_to_keep_in_fp32 in name for module_to_keep_in_fp32 in model_class._keep_in_fp32_modules
+                    ):
+                        self.assertTrue(param.dtype == paddle.float32)
+                    else:
+                        self.assertTrue(param.dtype == paddle.float16, name)
+
+    def test_save_load_keys_to_ignore_on_save(self):
+        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+        for model_class in self.all_model_classes:
+            model = model_class(copy.deepcopy(config))
+            _keys_to_ignore_on_save = getattr(model, "_keys_to_ignore_on_save", None)
+            if _keys_to_ignore_on_save is None:
+                continue
+
+            # check the keys are in the original state_dict
+            for k in _keys_to_ignore_on_save:
+                self.assertIn(k, model.state_dict().keys(), "\n".join(model.state_dict().keys()))
+
+            # check that certain keys didn't get saved with the model
+            with tempfile.TemporaryDirectory() as tmpdirname:
+                model.save_pretrained(tmpdirname)
+                output_model_file = os.path.join(tmpdirname, "model_state.pdparams")
+                state_dict_saved = paddle.load(output_model_file)
+
+                for k in _keys_to_ignore_on_save:
+                    self.assertNotIn(k, state_dict_saved.keys(), "\n".join(state_dict_saved.keys()))
+
+                # Test we can load the state dict in the model, necessary for the checkpointing API in Trainer.
+                load_result = model.load_state_dict(state_dict_saved, strict=False)
+                keys_to_ignore = set(model._keys_to_ignore_on_save)
+
+                self.assertTrue(len(load_result.missing_keys) == 0 or set(load_result.missing_keys) == keys_to_ignore)
+                self.assertTrue(len(load_result.unexpected_keys) == 0)
+
+    def test_paddle_save_load(self):
+        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+        if config.__class__ not in MODEL_MAPPING:
+            self.skipTest(reason=f"{config.__class__.__name__} not in MODEL_MAPPING")
+
+        base_class = self.base_model_class
+
+        for model_class in self.all_model_classes:
+            if base_class is None or model_class == base_class:
+                continue
+
+            # make a copy of the base class so the test cannot leak state into future tests
+            # from https://stackoverflow.com/questions/9541025/how-to-copy-a-python-class
+            class CopyClass(base_class):
+                pass
+
+            base_class_copy = CopyClass
+
+            # make sure that all keys are expected for the test
+            base_class_copy._keys_to_ignore_on_load_missing = []
+
+            # make init deterministic, but make sure that
+            # non-initialized weights throw errors nevertheless
+            base_class_copy._init_weights = _mock_init_weights
+            base_class_copy.init_weights = _mock_all_init_weights
+
+            model = model_class(copy.deepcopy(config))
+            state_dict = model.state_dict()
+
+            def check_equal(loaded):
+                for key in state_dict:
+                    max_diff = paddle.max(
+                        (state_dict[key] ^ loaded[key]).astype("int64")
+                        if state_dict[key].dtype == paddle.bool
+                        else paddle.abs(state_dict[key] - loaded[key])
+                    ).item()
+                    self.assertLessEqual(max_diff, 1e-6, msg=f"{key} not identical")
+
+            # check that the state dict survives a paddle.save / load_state_dict round trip
+            with tempfile.TemporaryDirectory() as tmpdirname:
+                pt_checkpoint_path = os.path.join(tmpdirname, "paddle_model.bin")
+                paddle.save(state_dict, pt_checkpoint_path)
+                check_equal(load_state_dict(pt_checkpoint_path))
+
     def test_determinism(self):
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
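
For reference, the scenario the new ckpt_path guard unblocks looks roughly like the sketch below, mirroring test_from_pretrained_no_checkpoint; model_class and config are placeholders for any PaddleFormers model/config pair (e.g. what a model tester provides). Without the guard, calling from_pretrained with no checkpoint handed [None] to get_common_folder; with it, ckpt_path stays None, the flex_checkpoint branch is skipped, and the weights come from the in-memory state_dict:

    import copy

    import paddle

    model = model_class(copy.deepcopy(config))
    state_dict = model.state_dict()

    # No file on disk: resolved_archive_file is None inside from_pretrained,
    # so ckpt_path is left as None and flex-checkpoint loading is skipped.
    reloaded = model_class.from_pretrained(
        pretrained_model_name_or_path=None, config=config, state_dict=state_dict
    )

    # The reloaded parameters should match the originals exactly.
    for p1, p2 in zip(model.parameters(), reloaded.parameters()):
        assert paddle.allclose(p1, p2)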