Fix excessive CPU memory usage with FSDP and cpu_ram_efficient_loading (
matthewdouglas committed Sep 4, 2024
1 parent a1faf22 commit b390998
Showing 1 changed file with 7 additions and 1 deletion.
src/transformers/modeling_utils.py: 7 additions & 1 deletion
@@ -958,6 +958,9 @@ def _load_state_dict_into_meta_model(
                 )
             )
         ):
+            if is_fsdp_enabled():
+                param_device = "cpu" if is_local_dist_rank_0() else "meta"
+
             # For backward compatibility with older versions of `accelerate` and for non-quantized params
             set_module_tensor_to_device(model, param_name, param_device, **set_module_kwargs)
         else:
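
For context, here is a minimal, self-contained sketch of the placement rule this hunk introduces: under FSDP with cpu_ram_efficient_loading, only the local rank-0 process materializes weights on CPU, while every other rank keeps them on the meta device (shape and dtype only, no storage) until the shards are distributed. The two helpers mirror the names used in the diff but are simplified stand-ins, not the transformers implementations; the env-var checks are assumptions for illustration.

import os

import torch


def is_fsdp_enabled() -> bool:
    # Assumption: FSDP use is signaled via this accelerate env var. The real
    # helper in transformers also consults torch.distributed state.
    return os.environ.get("ACCELERATE_USE_FSDP", "false").lower() == "true"


def is_local_dist_rank_0() -> bool:
    # Local rank 0 is the one process per node that holds real CPU weights.
    return int(os.environ.get("LOCAL_RANK", "0")) == 0


param_device = "cpu"  # illustrative default when FSDP is not active
if is_fsdp_enabled():
    # The rule from the hunk: only local rank 0 pays the host-RAM cost.
    param_device = "cpu" if is_local_dist_rank_0() else "meta"

# A meta tensor allocates no storage, so non-rank-0 processes no longer
# duplicate the full checkpoint in CPU RAM.
weight = torch.empty(4096, 4096, device=param_device)
print(weight.device)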
@@ -968,7 +971,10 @@
         if is_fsdp_enabled() or is_deepspeed_zero3_enabled():
             module, tensor_name = get_module_from_name(model, param_name)
             value = getattr(module, tensor_name)
-            value = type(value)(value.data.to("cpu"), **value.__dict__)
+            param_to = "cpu"
+            if is_fsdp_enabled() and not is_local_dist_rank_0():
+                param_to = "meta"
+            value = type(value)(value.data.to(param_to), **value.__dict__)
             setattr(module, tensor_name, value)
     # TODO: consider removing used param_parts from state_dict before return
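
The second hunk applies the same rank test to the quantized-parameter path, where the parameter may be a tensor subclass that carries extra metadata and must be rebuilt rather than moved in place. Below is a hypothetical sketch of that reconstruction idiom; QuantParam is an invented stand-in for such a subclass, not a real transformers or bitsandbytes class.

import torch


class QuantParam(torch.nn.Parameter):
    # Hypothetical stand-in for a quantized parameter subclass: its extra
    # metadata lives in instance attributes, i.e. in __dict__.
    def __new__(cls, data, quant_state=None, requires_grad=False):
        self = super().__new__(cls, data, requires_grad=requires_grad)
        self.quant_state = quant_state
        return self


value = QuantParam(torch.ones(2, 2), quant_state={"blocksize": 64})

# The idiom from the diff: rebuild the subclass on the target device while
# carrying its instance attributes over via **value.__dict__. "meta" drops
# the storage on non-rank-0 FSDP processes; "cpu" keeps it on local rank 0.
param_to = "meta"  # would be "cpu" on local rank 0
rebuilt = type(value)(value.data.to(param_to), **value.__dict__)

print(type(rebuilt).__name__, rebuilt.device, rebuilt.quant_state)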

