Skip to content

Commit

Permalink
fix opt step on cpu
Browse files Browse the repository at this point in the history
  • Loading branch information
samsja committed Sep 11, 2024
1 parent 04dd585 commit e926e6c
Showing 1 changed file with 11 additions and 1 deletion.
12 changes: 11 additions & 1 deletion open_diloco/train_pure_fsdp.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,14 @@ def get_model(config: Config) -> LlamaForCausalLM:


def get_offloaded_param(model: LlamaForCausalLM) -> list[torch.Tensor]:
    """Return CPU copies of the model's trainable parameters.

    Parameters whose ``requires_grad`` is False are skipped, so the returned
    list can be shorter than ``list(model.parameters())``.

    Args:
        model: Module whose ``parameters()`` are copied.

    Returns:
        One tensor per trainable parameter, in ``model.parameters()`` order.
        Each tensor is a detached clone moved to the CPU and flagged
        ``requires_grad=True``.
    """
    return [
        # clone() before the move guarantees independent storage even when the
        # parameter already lives on the CPU (where .to("cpu") is a no-op).
        # The copy is detached, so it is a leaf and requires_grad_ is allowed.
        param.data.detach().clone().to("cpu").requires_grad_(True)
        for param in model.parameters()
        if param.requires_grad
    ]


def train(config: Config):
Expand Down Expand Up @@ -241,6 +248,9 @@ def train(config: Config):
else:
dist.all_reduce(param_offloaded.grad, op=dist.ReduceOp.AVG, group=global_pg)

for param in outer_optimizer.param_groups[0]["params"]:
print(param.requires_grad)

outer_optimizer.step()
outer_optimizer.zero_grad()

Expand Down

0 comments on commit e926e6c

Please sign in to comment.