6 changes: 5 additions & 1 deletion .github/workflows/integration_test_8gpu_simple_fsdp.yaml
@@ -50,7 +50,11 @@ jobs:
python -m pip install --force-reinstall --pre torch --index-url https://download.pytorch.org/whl/nightly/cu126
mkdir artifacts-to-be-uploaded
python -m torchtitan.experiments.simple_fsdp.tests.integration_tests artifacts-to-be-uploaded --ngpu 8
# Run front-end integration tests of SimpleFSDP
python -m torchtitan.experiments.simple_fsdp.tests.frontend_integration_tests artifacts-to-be-uploaded --ngpu 8
# Run backend pass integration tests of SimpleFSDP
python -m torchtitan.experiments.simple_fsdp.tests.compiler_pass_integration_tests artifacts-to-be-uploaded --ngpu 8 --comm_mode local_tensor
Contributor Author:

I also tried FakeBackend mode, but the memory overhead is significantly higher @fegin 🤔 (~33GiB in local_tensor -> ~90GiB in FakeBackend). My suspicion is that FakeBackend initializes the whole model on one rank?

Contributor:

Actually, it is the reverse: in FakeBackend mode, memory should be lower. On the other hand, local_tensor allocates all ranks' tensors in the same process, so it should be higher. cc @dzmitry-huba

Contributor Author:

Okay, I'm not sure which part is wrong; here is an easy repro. The memory I reported in the previous message is from the compiler CI test.

1. NGPU=4 COMM_MODE="fake_backend" CONFIG_FILE="./torchtitan/models/llama3/train_configs/llama3_8b.toml" ./run_train.sh
   [titan] 2025-11-25 18:27:58,478 - root - INFO - step:  1  loss: 12.2713  grad_norm:  0.0000  memory: 47.60GiB(50.11%)  tps: 1,112  tflops: 64.41  mfu: 6.51%
2. NGPU=4 COMM_MODE="local_tensor" CONFIG_FILE="./torchtitan/models/llama3/train_configs/llama3_8b.toml" ./run_train.sh

For local_tensor, since there is no actual training, the peak memory is ~31GiB per nvidia-smi.
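
(Editor's note, a back-of-the-envelope check rather than a measurement from this PR: Llama3-8B has roughly 8e9 parameters, so one fp32 copy of the weights is about 8e9 × 4 bytes ≈ 32 GB ≈ 30 GiB. That is consistent with the ~31 GiB peak under local_tensor, where all four ranks' shards live in a single process and sum to one full copy of the model. The higher fake_backend numbers would then correspond to additional full-size buffers, e.g. gradients and optimizer states, being materialized once training actually runs, though nothing in this PR confirms that breakdown.)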

Contributor:

Uh, I see. This is because we currently skip the real training for LocalTensor, since LocalTensor doesn't support FSDP2. But it should work with SimpleFSDP.

Contributor Author:

Hmmm, it is also skipped in SimpleFSDP. Not sure why, but fake_backend gives a huge memory overhead in SimpleFSDP's compiler pass CI (~90GiB).

I think I can either verify with RealBackend (more memory overhead than LocalTensor, but less than fake_backend), or use LocalTensor mode (less memory overhead, but it doesn't execute actual training).

Curious: which parts does LocalTensor mode execute? If it runs compilation but skips the actual training, that looks sufficient for our case, since we are just verifying the compiler passes' integration.

Contributor Author:

Kind follow-up on this @fegin!

# Run the numerics unit tests of SimpleFSDP
torchrun --nproc-per-node=8 -m pytest torchtitan/experiments/simple_fsdp/tests/test_numerics.py -v
17 changes: 14 additions & 3 deletions tests/integration_tests/run_tests.py
@@ -29,15 +29,18 @@ def _run_cmd(cmd):
return subprocess.run([cmd], text=True, shell=True)


def run_single_test(test_flavor: OverrideDefinitions, full_path: str, output_dir: str):
def run_single_test(
test_flavor: OverrideDefinitions, full_path: str, output_dir: str, comm_mode: str
):
# run_test supports sequence of tests.
test_name = test_flavor.test_name
dump_folder_arg = f"--job.dump_folder {output_dir}/{test_name}"

all_ranks = ",".join(map(str, range(test_flavor.ngpu)))

for idx, override_arg in enumerate(test_flavor.override_args):
cmd = f"CONFIG_FILE={full_path} NGPU={test_flavor.ngpu} LOG_RANK={all_ranks} ./run_train.sh"
cmd = f"CONFIG_FILE={full_path} COMM_MODE={comm_mode} NGPU={test_flavor.ngpu} LOG_RANK={all_ranks} ./run_train.sh"

# dump compile trace for debugging purpose
cmd = f'TORCH_TRACE="{output_dir}/{test_name}/compile_trace" ' + cmd
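
Editor's note, for illustration only (paths are placeholders; the flavor's override args from the elided part of this hunk are appended afterwards): with comm_mode=local_tensor and an 8-GPU flavor, the command composed above expands to roughly

    TORCH_TRACE="<output_dir>/<test_name>/compile_trace" CONFIG_FILE=<config_path> COMM_MODE=local_tensor NGPU=8 LOG_RANK=0,1,2,3,4,5,6,7 ./run_train.sh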

@@ -102,14 +105,22 @@ def run_tests(args, test_list: list[OverrideDefinitions]):
f" because --ngpu arg is {args.ngpu}"
)
else:
run_single_test(test_flavor, args.config_path, args.output_dir)
run_single_test(
test_flavor, args.config_path, args.output_dir, args.comm_mode
)


def main():
parser = argparse.ArgumentParser()
parser.add_argument(
"output_dir", help="Directory to dump results generated by tests"
)
parser.add_argument(
"comm_mode",
default="default",
choices=["default", "fake_backend", "local_tensor"],
help="Communication mode to validate tests",
)
parser.add_argument(
"--gpu_arch_type",
default="cuda",
Contributor:

Not sure if splitting into two tests will incur overhead.
@wwwjn does this add any overhead to CI?

@@ -0,0 +1,261 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import argparse
import os

from tests.integration_tests import OverrideDefinitions
from tests.integration_tests.run_tests import run_tests


def build_simple_fsdp_test_list() -> list[OverrideDefinitions]:
"""
Build the list of OverrideDefinitions used to generate variations of
integration tests based on the same root config file.
"""
integration_tests_flavors = [
OverrideDefinitions(
[
[
"--model.name simple_fsdp.llama3",
"--model.flavor 8B",
"--compile.enable",
"--job.custom_config_module=torchtitan.experiments.simple_fsdp.job_config",
"--compile.backend aot_eager",
"--compile.graph_passes auto_bucketing",
],
],
"1D+autobucketing",
"1d_autobucketing",
ngpu=8,
),
OverrideDefinitions(
[
[
"--model.name simple_fsdp.llama3",
"--model.flavor 8B",
"--compile.enable",
"--job.custom_config_module=torchtitan.experiments.simple_fsdp.job_config",
"--compile.backend aot_eager",
"--compile.graph_passes transformer_block_bucketing",
],
],
"1D+transformer_block_bucketing",
"1d_transformer_block_bucketing",
ngpu=8,
),
OverrideDefinitions(
[
[
"--model.name simple_fsdp.llama3",
"--model.flavor 8B",
"--parallelism.tensor_parallel_degree 2",
"--compile.enable",
"--job.custom_config_module=torchtitan.experiments.simple_fsdp.job_config",
"--compile.backend aot_eager",
"--compile.graph_passes auto_bucketing",
],
],
"2D+autobucketing",
"2d_autobucketing",
ngpu=8,
),
OverrideDefinitions(
[
[
"--model.name simple_fsdp.llama3",
"--model.flavor 8B",
"--parallelism.tensor_parallel_degree 2",
"--compile.enable",
"--job.custom_config_module=torchtitan.experiments.simple_fsdp.job_config",
"--compile.backend aot_eager",
"--compile.graph_passes transformer_block_bucketing",
],
],
"2D+transformer_block_bucketing",
"2d_transformer_block_bucketing",
ngpu=8,
),
# TODO(ruisizhang123): add back after passes + PP is supported
# OverrideDefinitions(
# [
# [
# "--model.name simple_fsdp.llama3",
# "--model.flavor 8B",
# "--parallelism.tensor_parallel_degree 2",
# "--parallelism.pipeline_parallel_degree 2",
# "--compile.enable",
# "--job.custom_config_module=torchtitan.experiments.simple_fsdp.job_config",
# "--compile.backend aot_eager",
# "--compile.graph_passes auto_bucketing",
# ],
# ],
# "3D+autobucketing",
# "3d_autobucketing",
# ngpu=8,
# ),
# OverrideDefinitions(
# [
# [
# "--model.name simple_fsdp.llama3",
# "--model.flavor 8B",
# "--parallelism.tensor_parallel_degree 2",
# "--parallelism.pipeline_parallel_degree 2",
# "--compile.enable",
# "--job.custom_config_module=torchtitan.experiments.simple_fsdp.job_config",
# "--compile.backend aot_eager",
# "--compile.graph_passes transformer_block_bucketing",
# ],
# ],
# "3D+transformer_block_bucketing",
# "3d_transformer_block_bucketing",
# ngpu=8,
# ),
# OverrideDefinitions(
# [
# [
# "--model.name simple_fsdp.llama3",
# "--model.flavor 8B",
# "--parallelism.tensor_parallel_degree 2",
# "--parallelism.context_parallel_degree 2",
# "--compile.enable",
# "--job.custom_config_module=torchtitan.experiments.simple_fsdp.job_config",
# "--compile.backend aot_eager",
# "--compile.graph_passes auto_bucketing",
# ],
# ],
# "FSDP+TP+CP+autobucketing",
# "fsdp+tp+cp_autobucketing",
# ngpu=8,
# ),
OverrideDefinitions(
[
[
"--model.name simple_fsdp.llama3",
"--model.flavor 8B",
"--parallelism.tensor_parallel_degree 2",
"--parallelism.context_parallel_degree 2",
"--compile.enable",
"--job.custom_config_module=torchtitan.experiments.simple_fsdp.job_config",
"--compile.backend aot_eager",
"--compile.graph_passes transformer_block_bucketing",
],
],
"FSDP+TP+CP+transformer_block_bucketing",
"fsdp+tp+cp_transformer_block_bucketing",
ngpu=8,
),
OverrideDefinitions(
[
[
"--model.name simple_fsdp.deepseek_v3",
"--model.flavor 16B",
"--parallelism.data_parallel_shard_degree 4",
"--parallelism.expert_parallel_degree 2",
"--job.custom_config_module=torchtitan.experiments.simple_fsdp.job_config",
"--compile.backend aot_eager",
"--compile.graph_passes auto_bucketing",
],
],
"FSDP+EP+autobucketing",
"fsdp+ep_autobucketing",
ngpu=4,
),
OverrideDefinitions(
[
[
"--model.name simple_fsdp.deepseek_v3",
"--model.flavor 16B",
"--parallelism.data_parallel_shard_degree 4",
"--parallelism.expert_parallel_degree 2",
"--job.custom_config_module=torchtitan.experiments.simple_fsdp.job_config",
"--compile.backend aot_eager",
"--compile.graph_passes transformer_block_bucketing",
],
],
"FSDP+EP+transformer_block_bucketing",
"fsdp+ep_transformer_block_bucketing",
ngpu=4,
),
OverrideDefinitions(
[
[
"--model.name simple_fsdp.deepseek_v3",
"--model.flavor 16B",
"--parallelism.data_parallel_shard_degree 2",
"--parallelism.tensor_parallel_degree 2",
"--parallelism.expert_parallel_degree 4",
"--parallelism.expert_tensor_parallel_degree 1",
"--job.custom_config_module=torchtitan.experiments.simple_fsdp.job_config",
"--compile.backend aot_eager",
"--compile.graph_passes auto_bucketing",
],
],
"FSDP+TP+EP+autobucketing",
"fsdp+tp+ep_autobucketing",
ngpu=4,
),
OverrideDefinitions(
[
[
"--model.name simple_fsdp.deepseek_v3",
"--model.flavor 16B",
"--parallelism.data_parallel_shard_degree 2",
"--parallelism.tensor_parallel_degree 2",
"--parallelism.expert_parallel_degree 4",
"--parallelism.expert_tensor_parallel_degree 1",
"--job.custom_config_module=torchtitan.experiments.simple_fsdp.job_config",
"--compile.backend aot_eager",
"--compile.graph_passes transformer_block_bucketing",
],
],
"FSDP+TP+EP+transformer_block_bucketing",
"fsdp+tp+ep_transformer_block_bucketing",
ngpu=4,
),
]
return integration_tests_flavors


_TEST_SUITES_FUNCTION = {
"simple_fsdp": build_simple_fsdp_test_list,
}


def main():
parser = argparse.ArgumentParser()
parser.add_argument("output_dir")
parser.add_argument(
"--comm_mode",
Contributor:

I dislike this idea of creating another layer of comm_mode config. We are doing:

  1. pass comm_mode from the test config to run_train.sh's COMM_MODE
  2. pass COMM_MODE from run_train.sh to the actual training's comm.mode

The only reason we are doing this is to let COMM_MODE select the torchrun / python job starter.

If we have to differentiate between torchrun and python, what we can do is let run_train.sh select the starter by looking at the --comm.mode field passed in by the user / tests.

cc @fegin
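
An editor's sketch of that suggestion (hypothetical; run_train.sh's actual contents and the training module path are not shown in this PR), where the launcher derives the starter from the --comm.mode override instead of a separate COMM_MODE variable:

    if [[ "$*" == *"--comm.mode local_tensor"* ]]; then
        # Hypothetical: LocalTensor simulates every rank inside one process,
        # so a plain python starter may suffice.
        python -m torchtitan.train "$@"
    else
        torchrun --nproc-per-node="${NGPU}" -m torchtitan.train "$@"
    fi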

default="default",
choices=["default", "fake_backend", "local_tensor"],
help="Communication mode to validate tests",
)
parser.add_argument(
"--config_path",
default="./tests/integration_tests/base_config.toml",
help="Base config path for integration tests. This is the config that will be used as a base for all tests.",
)
parser.add_argument(
"--test_name",
default="all",
help="test to run, acceptable values: `test_name` in `build_test_list` (default: all)",
)
parser.add_argument("--ngpu", default=8, type=int)
args = parser.parse_args()

if not os.path.exists(args.output_dir):
os.makedirs(args.output_dir)
if os.listdir(args.output_dir):
raise RuntimeError("Please provide an empty output directory.")

test_list = _TEST_SUITES_FUNCTION["simple_fsdp"]()
run_tests(args, test_list)


if __name__ == "__main__":
main()
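
For reference, this new suite is exactly what the CI workflow change at the top of this PR invokes:

    python -m torchtitan.experiments.simple_fsdp.tests.compiler_pass_integration_tests artifacts-to-be-uploaded --ngpu 8 --comm_mode local_tensor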
@@ -29,32 +29,6 @@ def build_simple_fsdp_test_list() -> list[OverrideDefinitions]:
"1D",
"1d",
),
OverrideDefinitions(
[
[
"--model.name simple_fsdp.llama3",
"--compile.enable",
"--job.custom_config_module=torchtitan.experiments.simple_fsdp.job_config",
"--compile.backend aot_eager",
"--compile.graph_passes auto_bucketing",
],
],
"1D+autobucketing",
"1d_autobucketing",
),
OverrideDefinitions(
[
[
"--model.name simple_fsdp.llama3",
"--compile.enable",
"--job.custom_config_module=torchtitan.experiments.simple_fsdp.job_config",
"--compile.backend aot_eager",
"--compile.graph_passes transformer_block_bucketing",
],
],
"1D+transformer_block_bucketing",
"1d_transformer_block_bucketing",
),
OverrideDefinitions(
[
[