From b3909989d38cc2666110e7bc6d114d29d942639c Mon Sep 17 00:00:00 2001 From: Matthew Douglas <38992547+matthewdouglas@users.noreply.github.com> Date: Wed, 4 Sep 2024 12:37:54 -0400 Subject: [PATCH] Fix excessive CPU memory usage with FSDP and cpu_ram_efficient_loading (#33154) --- src/transformers/modeling_utils.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index eb0e61e26fac52..2faa60210ed4d7 100755 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -958,6 +958,9 @@ def _load_state_dict_into_meta_model( ) ) ): + if is_fsdp_enabled(): + param_device = "cpu" if is_local_dist_rank_0() else "meta" + # For backward compatibility with older versions of `accelerate` and for non-quantized params set_module_tensor_to_device(model, param_name, param_device, **set_module_kwargs) else: @@ -968,7 +971,10 @@ def _load_state_dict_into_meta_model( if is_fsdp_enabled() or is_deepspeed_zero3_enabled(): module, tensor_name = get_module_from_name(model, param_name) value = getattr(module, tensor_name) - value = type(value)(value.data.to("cpu"), **value.__dict__) + param_to = "cpu" + if is_fsdp_enabled() and not is_local_dist_rank_0(): + param_to = "meta" + value = type(value)(value.data.to(param_to), **value.__dict__) setattr(module, tensor_name, value) # TODO: consider removing used param_parts from state_dict before return