Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 6 additions & 4 deletions paddleformers/transformers/model_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2863,15 +2863,17 @@ def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
use_safetensors=use_safetensors,
variant=variant,
)

file_list = resolved_sharded_files if is_sharded else [resolved_archive_file]
ckpt_path = get_common_folder(file_list)
if resolved_archive_file is not None:
file_list = resolved_sharded_files if is_sharded else [resolved_archive_file]
ckpt_path = get_common_folder(file_list)
else:
ckpt_path = None
# 3. init the model
init_args = config["init_args"] or ()
with ContextManagers(init_contexts):
model = cls(config, *init_args, **model_kwargs)

if hasattr(cls, "_gen_aoa_config") and load_checkpoint_format == "flex_checkpoint":
if ckpt_path is not None and hasattr(cls, "_gen_aoa_config") and load_checkpoint_format == "flex_checkpoint":
aoa_config = cls._gen_aoa_config(config)
sharded_state_dict = model.sharded_state_dict()
dist.load_state_dict(
Expand Down
133 changes: 132 additions & 1 deletion tests/transformers/test_modeling_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,9 @@
)

from paddleformers.transformers import AutoModelForCausalLM, AutoTokenizer
from paddleformers.transformers.auto.modeling import MODEL_MAPPING
from paddleformers.transformers.configuration_utils import PretrainedConfig
from paddleformers.transformers.model_utils import PretrainedModel
from paddleformers.transformers.model_utils import PretrainedModel, load_state_dict
from paddleformers.utils.env import CONFIG_NAME, LEGACY_CONFIG_NAME # MODEL_HOME,

from ..testing_utils import slow
Expand All @@ -50,6 +51,27 @@ def _config_zero_init(config):
return configs_no_init


def _mock_init_weights(self, module):
for name, param in module.named_parameters(recurse=False):
# Use the first letter of the name to get a value and go from a <> -13 to z <> 12
value = ord(name[0].lower()) - 110
param.data.fill_(value)


def _mock_all_init_weights(self):
import paddleformers.transformers.model_utils

if paddleformers.transformers.model_utils._init_weights:
for module in self.modules():
module._is_hf_initialized = False
# Initialize weights
self.apply(self._initialize_weights)

# Tie weights should be skipped when not initializing all weights
# since from_pretrained(...) calls tie weights anyways
self.tie_weights()


def get_cluster_from_args(selected_gpus):
cluster_node_ips = "127.0.0.1"
node_ip = "127.0.0.1"
Expand Down Expand Up @@ -274,6 +296,115 @@ def check_save_load(out1, out2):
else:
check_save_load(first, second)

def test_from_pretrained_no_checkpoint(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
config.dtype = "float32"
for model_class in self.all_model_classes:
model = model_class(copy.deepcopy(config))
state_dict = model.state_dict()

new_model = model_class.from_pretrained(
pretrained_model_name_or_path=None, config=config, state_dict=state_dict
)
for p1, p2 in zip(model.parameters(), new_model.parameters()):
self.assertTrue(paddle.allclose(p1, p2))

def test_keep_in_fp32_modules(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
if model_class._keep_in_fp32_modules is None:
self.skipTest(reason="Model class has no _keep_in_fp32_modules attribute defined")

model = self._make_model_instance(config, model_class)
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)

if hasattr(config, "moe_group"):
config["moe_group"] = "dummy"

model = model_class.from_pretrained(tmpdirname, config=config, dtype=paddle.float16)

for name, param in model.named_parameters():
if any(
module_to_keep_in_fp32 in name for module_to_keep_in_fp32 in model_class._keep_in_fp32_modules
):
self.assertTrue(param.dtype == paddle.float32)
else:
self.assertTrue(param.dtype == paddle.float16, name)

def test_save_load_keys_to_ignore_on_save(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

for model_class in self.all_model_classes:
model = model_class(copy.deepcopy(config))
_keys_to_ignore_on_save = getattr(model, "_keys_to_ignore_on_save", None)
if _keys_to_ignore_on_save is None:
continue

# check the keys are in the original state_dict
for k in _keys_to_ignore_on_save:
self.assertIn(k, model.state_dict().keys(), "\n".join(model.state_dict().keys()))

# check that certain keys didn't get saved with the model
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
output_model_file = os.path.join(tmpdirname, "model_state.pdparams")
state_dict_saved = paddle.load(output_model_file)

for k in _keys_to_ignore_on_save:
self.assertNotIn(k, state_dict_saved.keys(), "\n".join(state_dict_saved.keys()))

# Test we can load the state dict in the model, necessary for the checkpointing API in Trainer.
load_result = model.load_state_dict(state_dict_saved, strict=False)
keys_to_ignore = set(model._keys_to_ignore_on_save)

self.assertTrue(len(load_result.missing_keys) == 0 or set(load_result.missing_keys) == keys_to_ignore)
self.assertTrue(len(load_result.unexpected_keys) == 0)

def test_paddle_save_load(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
if config.__class__ not in MODEL_MAPPING:
self.skipTest(reason=f"{config.__class__.__name__} not in MODEL_MAPPING")

base_class = self.base_model_class

for model_class in self.all_model_classes:
if base_class is None or model_class == base_class:
continue

# make a copy of model class to not break future tests
# from https://stackoverflow.com/questions/9541025/how-to-copy-a-python-class
class CopyClass(base_class):
pass

base_class_copy = CopyClass

# make sure that all keys are expected for test
base_class_copy._keys_to_ignore_on_load_missing = []

# make init deterministic, but make sure that
# non-initialized weights throw errors nevertheless
base_class_copy._init_weights = _mock_init_weights
base_class_copy.init_weights = _mock_all_init_weights

model = model_class(copy.deepcopy(config))
state_dict = model.state_dict()

def check_equal(loaded):
for key in state_dict:
max_diff = paddle.max(
state_dict()[key] ^ loaded[key]
if isinstance(state_dict[key], paddle.BoolTensor)
else paddle.abs(state_dict[key] - loaded[key])
).item()
self.assertLessEqual(max_diff, 1e-6, msg=f"{key} not identical")

# check that certain keys didn't get saved with the model
with tempfile.TemporaryDirectory() as tmpdirname:
pt_checkpoint_path = os.path.join(tmpdirname, "paddle_model.bin")
paddle.save(state_dict, pt_checkpoint_path)
check_equal(load_state_dict(pt_checkpoint_path))

def test_determinism(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

Expand Down