
Commit 4b9440c

modify save list
1 parent 607c70d commit 4b9440c

4 files changed: +7 lines, -3 lines

tests/integration_tests/features.py

Lines changed: 4 additions & 3 deletions
@@ -350,12 +350,13 @@ def build_features_test_list() -> list[OverrideDefinitions]:
         [
             [
                 "--parallelism.data_parallel_shard_degree=4",
-                "--activation_checkpoint.mode='full'",
+                "--activation_checkpoint.mode=selective",
+                "--activation_checkpoint.selective_ac_option=op",
                 "--model.flavor=debugmodel_varlen_attn",
             ]
         ],
-        "FSDP+VARLEN_ATTN",
-        "fsdp+varlen_attn",
+        "FSDP+VARLEN_ATTN + per op SAC",
+        "fsdp+varlen_attn+per_op_sac",
         ngpu=4,
         skip_rocm_test=True,
     ),
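
The override change above switches this test from full activation checkpointing to selective activation checkpointing at per-op granularity (--activation_checkpoint.mode=selective plus --activation_checkpoint.selective_ac_option=op), so only the outputs of ops on a save list are kept for backward; the save lists edited in the files below now include the varlen attention op. The following is a minimal, hedged sketch of such a per-op policy using PyTorch's public CheckpointPolicy enum; the save-list contents mirror this commit, but the helper names are illustrative rather than torchtitan's actual internals. A second sketch after the last file shows how such a list is passed to checkpoint().

# Hedged sketch: a per-op SAC policy over a save list like the ones edited below.
# CheckpointPolicy comes from torch.utils.checkpoint (PyTorch 2.4+); everything
# else here is illustrative, not copied from torchtitan.
import torch
from torch.utils.checkpoint import CheckpointPolicy

_save_list = {
    torch.ops.aten.max.default,  # used to compute the scaling factor for quantization
    torch._higher_order_ops.flex_attention,
}
# The op this commit adds; it resolves only when the varlen attention custom op
# is registered in the running PyTorch build, so look it up defensively here.
try:
    _save_list.add(torch.ops.torch_attn._varlen_attn)
except (AttributeError, RuntimeError):
    pass

def _policy_fn(ctx, op, *args, **kwargs):
    # Keep the outputs of save-listed ops for backward; recompute everything else.
    if op in _save_list:
        return CheckpointPolicy.MUST_SAVE
    return CheckpointPolicy.PREFER_RECOMPUTE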

tests/unit_tests/test_activation_checkpoint.py

Lines changed: 1 addition & 0 deletions
@@ -25,6 +25,7 @@
     # used to compute the scaling factor for quantization.
     torch.ops.aten.max.default,
     torch._higher_order_ops.flex_attention,
+    torch.ops.torch_attn._varlen_attn,
 }

torchtitan/experiments/simple_fsdp/llama3/parallelize.py

Lines changed: 1 addition & 0 deletions
@@ -30,6 +30,7 @@
     # used to compute the scaling factor for quantization.
     torch.ops.aten.max.default,
    torch._higher_order_ops.flex_attention,
+    torch.ops.torch_attn._varlen_attn,
 }

torchtitan/models/llama3/infra/parallelize.py

Lines changed: 1 addition & 0 deletions
@@ -41,6 +41,7 @@
     # used to compute the scaling factor for quantization.
     torch.ops.aten.max.default,
    torch._higher_order_ops.flex_attention,
+    torch.ops.torch_attn._varlen_attn,
 }
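
For reference, a minimal, hedged sketch of how a save list like the ones above typically plugs into selective checkpointing: torch.utils.checkpoint.create_selective_checkpoint_contexts accepts either a policy function or a plain list of ops to save, and the result is passed as context_fn to a non-reentrant checkpoint call. The block function and op choices below are illustrative; this is not torchtitan's actual apply-AC wrapper.

# Hedged sketch (PyTorch >= 2.4 public API, not torchtitan's wrapper): feed a
# save list to create_selective_checkpoint_contexts and use it as context_fn.
from functools import partial

import torch
from torch.utils.checkpoint import checkpoint, create_selective_checkpoint_contexts

# Ops whose outputs are saved; everything else is recomputed in backward.
ops_to_save = [torch.ops.aten.mm.default, torch.ops.aten.max.default]
context_fn = partial(create_selective_checkpoint_contexts, ops_to_save)

def block(x):
    # Stand-in for a transformer block.
    return torch.relu(x @ x.t()).sum()

x = torch.randn(8, 8, requires_grad=True)
out = checkpoint(block, x, use_reentrant=False, context_fn=context_fn)
out.backward()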
