From 508743bbef18e2c4da6a8d7356994fb374776b6e Mon Sep 17 00:00:00 2001 From: Sepehr Sameni Date: Tue, 8 Aug 2023 12:49:21 -0700 Subject: [PATCH] fix a typo in the FSDP example (#1159) --- distributed/FSDP/utils/train_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/distributed/FSDP/utils/train_utils.py b/distributed/FSDP/utils/train_utils.py index aaf0127dd8..24cf239e7c 100644 --- a/distributed/FSDP/utils/train_utils.py +++ b/distributed/FSDP/utils/train_utils.py @@ -72,7 +72,7 @@ def validation(model, rank, world_size, val_loader): model.eval() correct = 0 local_rank = int(os.environ['LOCAL_RANK']) - fsdp_loss = torch.zeros(3).to(local_rank) + fsdp_loss = torch.zeros(2).to(local_rank) if rank == 0: inner_pbar = tqdm.tqdm( range(len(val_loader)), colour="green", desc="Validation Epoch"