From e4e4f86564586bc91b27ac710efeee87e282c2b2 Mon Sep 17 00:00:00 2001
From: Sami Jaghouar
Date: Thu, 18 Jul 2024 16:29:18 +0000
Subject: [PATCH] fix grad norm fsdp

---
 open_diloco/train_fsdp.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/open_diloco/train_fsdp.py b/open_diloco/train_fsdp.py
index 666bcea..22f4696 100644
--- a/open_diloco/train_fsdp.py
+++ b/open_diloco/train_fsdp.py
@@ -402,7 +402,7 @@ def scheduler_fn(opt):
         else:
             scaler.unscale_(optimizer=optimizer)

-            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) # gradient clipping
+            model.clip_grad_norm_(1.0) # gradient clipping

         if world_messenger_hv:
             optimizer.step(scaler=scaler)