From 26de41904319c7094afc53a3ee809de47112d387 Mon Sep 17 00:00:00 2001
From: Mark Saroufim
Date: Fri, 28 Jun 2024 23:35:07 -0700
Subject: [PATCH] Fix AC in T5 example (#1273)

---
 distributed/FSDP/T5_training.py  | 1 +
 distributed/FSDP/configs/fsdp.py | 2 +-
 2 files changed, 2 insertions(+), 1 deletion(-)

diff --git a/distributed/FSDP/T5_training.py b/distributed/FSDP/T5_training.py
index 1aae5d0990..4ab136eace 100644
--- a/distributed/FSDP/T5_training.py
+++ b/distributed/FSDP/T5_training.py
@@ -121,6 +121,7 @@ def fsdp_main(args):
         device_id=torch.cuda.current_device(),
         limit_all_gathers=fsdp_config.limit_all_gathers)

+    # Enabling this causes https://github.com/pytorch/examples/issues/1210
     if fsdp_config.fsdp_activation_checkpointing:
         policies.apply_fsdp_checkpointing(model)

diff --git a/distributed/FSDP/configs/fsdp.py b/distributed/FSDP/configs/fsdp.py
index 301771cd26..220cc67c55 100644
--- a/distributed/FSDP/configs/fsdp.py
+++ b/distributed/FSDP/configs/fsdp.py
@@ -8,7 +8,7 @@ class fsdp_config:
     mixed_precision: bool=True
     use_fp16: bool=False
     seed: int=42
-    fsdp_activation_checkpointing: bool=True
+    fsdp_activation_checkpointing: bool=False
     limit_all_gathers: bool=True
     sharding_strategy: ShardingStrategy = ShardingStrategy.FULL_SHARD #HYBRID_SHARD, SHARD_GRAD_OP
     checkpoint_type: StateDictType = StateDictType.FULL_STATE_DICT # alternatively can use SHARDED_STATE_DICT to avoid OOMs
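
Note (editor, not part of the patch): the fsdp_activation_checkpointing flag gates the call to
policies.apply_fsdp_checkpointing(model). For context, below is a minimal sketch of how such a
helper is commonly written for T5 with FSDP, using PyTorch's apply_activation_checkpointing API;
the T5Block check_fn and the NO_REENTRANT choice are assumptions for illustration, not taken from
this patch.

    # Illustrative sketch: wrap each repeated T5 transformer block in a
    # non-reentrant checkpoint wrapper so activations are recomputed during
    # the backward pass instead of being stored.
    from functools import partial

    from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
        CheckpointImpl,
        apply_activation_checkpointing,
        checkpoint_wrapper,
    )
    from transformers.models.t5.modeling_t5 import T5Block  # assumed target module

    def apply_fsdp_checkpointing(model):
        non_reentrant_wrapper = partial(
            checkpoint_wrapper,
            checkpoint_impl=CheckpointImpl.NO_REENTRANT,
        )
        # Only wrap the T5 blocks; embeddings and the LM head are left untouched.
        apply_activation_checkpointing(
            model,
            checkpoint_wrapper_fn=non_reentrant_wrapper,
            check_fn=lambda module: isinstance(module, T5Block),
        )

With the config default now False, this wrapping step is skipped unless the user opts back in.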