diff --git a/open_diloco/train_fsdp.py b/open_diloco/train_fsdp.py index 666bcea..22f4696 100644 --- a/open_diloco/train_fsdp.py +++ b/open_diloco/train_fsdp.py @@ -402,7 +402,7 @@ def scheduler_fn(opt): else: scaler.unscale_(optimizer=optimizer) - torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) # gradient clipping + model.clip_grad_norm_(1.0) # gradient clipping if world_messenger_hv: optimizer.step(scaler=scaler)