
Commit 53e949c

modify save list for varlen attn (#2082)
Adds the varlen attention op to the AC save list.

**Testing:** used DebugMode() to print out the op list and verified that the forward op is not being recomputed in the backward step.

```
[rank0]:forward ops
[rank0]:varlen_attn in forward: True
...
[rank0]:varlen_attn recomputed in backward: False
[rank0]:saved correctly
```
1 parent: b39377f
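The exact import path for DebugMode varies across recent PyTorch versions, so here is a version-agnostic sketch of the same kind of check using a custom TorchDispatchMode to log which ops run in forward vs. backward; the model and op names are illustrative, not the test's actual code.

```python
import torch
from torch.utils._python_dispatch import TorchDispatchMode


class OpLogger(TorchDispatchMode):
    """Prints every dispatched op, tagged with the phase it ran in."""

    def __init__(self, tag: str):
        super().__init__()
        self.tag = tag

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        print(f"[{self.tag}] {func}")
        return func(*args, **(kwargs or {}))


x = torch.randn(4, 4, requires_grad=True)
with OpLogger("forward"):
    y = torch.relu(x).sum()
with OpLogger("backward"):
    y.backward()
# Under activation checkpointing, an op that shows up in the forward log
# but not in the backward log was saved rather than recomputed.
```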

File tree: 5 files changed, +8 −4 lines changed


tests/integration_tests/features.py

Lines changed: 4 additions & 3 deletions
```diff
@@ -350,12 +350,13 @@ def build_features_test_list() -> list[OverrideDefinitions]:
         [
             [
                 "--parallelism.data_parallel_shard_degree=4",
-                "--activation_checkpoint.mode='full'",
+                "--activation_checkpoint.mode=selective",
+                "--activation_checkpoint.selective_ac_option=op",
                 "--model.flavor=debugmodel_varlen_attn",
             ]
         ],
-        "FSDP+VARLEN_ATTN",
-        "fsdp+varlen_attn",
+        "FSDP+VARLEN_ATTN + per op SAC",
+        "fsdp+varlen_attn+per_op_sac",
         ngpu=4,
         skip_rocm_test=True,
     ),
```

tests/unit_tests/test_activation_checkpoint.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -28,6 +28,7 @@
     # used to compute the scaling factor for quantization.
     torch.ops.aten.max.default,
     torch._higher_order_ops.flex_attention,
+    torch.ops.torch_attn._varlen_attn,
 }
 
 
```

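For context on how a save list like this is consumed: below is a minimal, self-contained sketch (not torchtitan's actual wiring; the save-list contents are illustrative) of per-op selective activation checkpointing with PyTorch's create_selective_checkpoint_contexts. Ops in the save list keep their forward outputs; everything else is recomputed during backward.

```python
from functools import partial

import torch
from torch.utils.checkpoint import (
    CheckpointPolicy,
    checkpoint,
    create_selective_checkpoint_contexts,
)

# Illustrative save list in the spirit of the sets edited in this commit.
_save_list = {
    torch.ops.aten.mm.default,
    torch.ops.aten.max.default,
}


def _policy_fn(ctx, op, *args, **kwargs):
    # Save outputs of listed ops; prefer recomputing everything else.
    if op in _save_list:
        return CheckpointPolicy.MUST_SAVE
    return CheckpointPolicy.PREFER_RECOMPUTE


context_fn = partial(create_selective_checkpoint_contexts, _policy_fn)


def block(x, w):
    return torch.relu(x @ w).max()


x = torch.randn(8, 8, requires_grad=True)
w = torch.randn(8, 8, requires_grad=True)
out = checkpoint(block, x, w, use_reentrant=False, context_fn=context_fn)
out.backward()
```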
torchtitan/experiments/simple_fsdp/llama3/parallelize.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -33,6 +33,7 @@
     # used to compute the scaling factor for quantization.
     torch.ops.aten.max.default,
     torch._higher_order_ops.flex_attention,
+    torch.ops.torch_attn._varlen_attn,
 }
 
 
```

torchtitan/models/llama3/infra/parallelize.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -44,6 +44,7 @@
     # used to compute the scaling factor for quantization.
     torch.ops.aten.max.default,
     torch._higher_order_ops.flex_attention,
+    torch.ops.torch_attn._varlen_attn.default,
 }
 
 
```

torchtitan/models/qwen3/infra/parallelize.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -46,7 +46,7 @@
     # used to compute the scaling factor for quantization.
     torch.ops.aten.max.default,
     torch._higher_order_ops.flex_attention,
-    torch.ops.torch_attn._varlen_attn,
+    torch.ops.torch_attn._varlen_attn.default,
 }
 
 
```

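One detail worth noting about the `.default` suffix used in the last two files. As a general point about the torch op registry (standard behavior, not anything specific to this repo), `torch.ops.<ns>.<op>` is an OpOverloadPacket, while `<op>.default` is the concrete OpOverload that per-op SAC policy functions are matched against:

```python
import torch

# torch.ops.aten.max names a bundle of overloads; .default picks one.
print(type(torch.ops.aten.max).__name__)          # OpOverloadPacket
print(type(torch.ops.aten.max.default).__name__)  # OpOverload
```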