Skip to content
Draft
Show file tree
Hide file tree
Changes from 14 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions torchtitan/experiments/compiler_toolkit/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,12 @@ NGPU=4 CONFIG_FILE=./torchtitan/models/deepseek_v3/train_configs/debug_model.tom
NGPU=4 CONFIG_FILE=./torchtitan/models/deepseek_v3/train_configs/debug_model.toml ./run_train.sh --model.name compiler_toolkit.deepseek_v3 --parallelism.data_parallel_shard_degree=2 --parallelism.tensor_parallel_degree=2 --parallelism.expert_parallel_degree=2 --activation_checkpoint.mode none --model.flavor=debugmodel_flex_attn
```

**SimpleFSDP + TP + EP + Inductor Lite**
```shell
NGPU=4 CONFIG_FILE=./torchtitan/models/deepseek_v3/train_configs/debug_model.toml ./run_train.sh --model.name compiler_toolkit.deepseek_v3 --parallelism.data_parallel_shard_degree=2 --parallelism.tensor_parallel_degree=2 --parallelism.expert_parallel_degree=2 --activation_checkpoint.mode none --job.custom_config_module=torchtitan.experiments.compiler_toolkit.job_config --compile.passes inductor_lite
```


## llama3

**SimpleFSDP + TP**
Expand All @@ -39,6 +45,11 @@ NGPU=8 CONFIG_FILE=./torchtitan/models/llama3/train_configs/debug_model.toml ./r
NGPU=8 CONFIG_FILE=./torchtitan/models/llama3/train_configs/debug_model.toml ./run_train.sh --model.name compiler_toolkit.llama3 --parallelism.data_parallel_shard_degree=2 --parallelism.tensor_parallel_degree=4 --job.custom_config_module=torchtitan.experiments.compiler_toolkit.job_config --compile.passes transformer_block_bucketing
```

**SimpleFSDP + TP + transformer-block-bucketing + inductor lite**
```shell
NGPU=8 CONFIG_FILE=./torchtitan/models/llama3/train_configs/debug_model.toml ./run_train.sh --model.name compiler_toolkit.llama3 --parallelism.data_parallel_shard_degree=2 --parallelism.tensor_parallel_degree=4 --job.custom_config_module=torchtitan.experiments.compiler_toolkit.job_config --compile.passes inductor_lite
```

**SimpleFSDP + TP + FlexAttention**
```shell
NGPU=8 CONFIG_FILE=./torchtitan/models/llama3/train_configs/debug_model.toml ./run_train.sh --model.name compiler_toolkit.llama3 --parallelism.data_parallel_shard_degree=2 --parallelism.tensor_parallel_degree=4 --model.flavor=debugmodel_flex_attn
Expand Down
9 changes: 9 additions & 0 deletions torchtitan/experiments/compiler_toolkit/common_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
# LICENSE file in the root directory of this source tree.

from contextlib import contextmanager
from typing import Callable

import torch
from torch.distributed.tensor import DTensor, Replicate
Expand Down Expand Up @@ -53,3 +54,11 @@ def register_blockmask_pytree_node():
flatten_with_keys_fn=BlockMask._flatten_with_keys,
serialized_type_name="torch.nn.attention.flex_attention.BlockMask",
)


def end_with_pass(passes: list[Callable], names: list[str]) -> bool:
return (
len(passes) > 0
and (last_pass_name := getattr(passes[-1], "__name__", None))
and (last_pass_name in names)
)
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
CompiledModule,
get_compiler_passes_from_config,
get_joint_custom_passes_from_config,
GraphBuilderOptions,
is_using_inductor_lite,
joint_graph_builder,
make_compiler_with_passes,
)
Expand Down Expand Up @@ -87,13 +89,18 @@ def parallelize_deepseekv3(
compiler_passes, dump_folder=job_config.job.dump_folder
)

options = GraphBuilderOptions(
dump_folder=job_config.job.dump_folder,
use_inductor_lite=is_using_inductor_lite(job_config),
)

# Create custom joint_graph_builder with deepseekv3-specific compilers
deepseekv3_joint_graph_builder = functools.partial(
joint_graph_builder,
fw_compiler=fw_compiler,
bw_compiler=bw_compiler,
joint_custom_passes=joint_custom_passes,
dump_folder=job_config.job.dump_folder,
options=options,
)

# TODO: CompiledModule should take sample input as well, so that we can
Expand Down
80 changes: 69 additions & 11 deletions torchtitan/experiments/compiler_toolkit/graph_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
# LICENSE file in the root directory of this source tree.

import contextlib
import dataclasses
import functools
from pathlib import Path
from typing import Any, Callable, List, Optional
Expand All @@ -20,9 +21,16 @@
from torch.distributed.tensor import DTensor
from torchtitan.config import JobConfig
from torchtitan.distributed import ParallelDims
from torchtitan.experiments.compiler_toolkit.common_utils import end_with_pass
from torchtitan.tools.logging import logger


@dataclasses.dataclass(frozen=True)
class GraphBuilderOptions:
dump_folder: str | None = None
use_inductor_lite: bool = False


def _dump_gm(dump_folder: str | None, gm: torch.fx.GraphModule, name: str) -> None:
# TODO: make the dump rank configurable
if not dump_folder or torch.distributed.get_rank() != 0:
Expand Down Expand Up @@ -88,7 +96,7 @@ def joint_graph_builder(
fw_compiler: Optional[Callable] = None,
bw_compiler: Optional[Callable] = None,
joint_custom_passes: Optional[List[Callable]] = None,
dump_folder: str | None = None,
options: GraphBuilderOptions = None,
):
"""
Build a joint forward-backward graph for the model with optional custom compilers.
Expand All @@ -100,7 +108,7 @@ def joint_graph_builder(
fw_compiler: Optional custom forward compiler function
bw_compiler: Optional custom backward compiler function
joint_custom_passes: list of custom passes to run on the joint graph
dump_folder: Optional folder to dump the graph to
options: Optional configs for graph builder
"""
assert isinstance(model_args, tuple)
for idx, arg in enumerate(model_args):
Expand All @@ -110,7 +118,7 @@ def joint_graph_builder(
(
joint_with_descriptors,
tracing_context,
) = export_joint(model, model_args, model_kwargs, dump_folder=dump_folder)
) = export_joint(model, model_args, model_kwargs, dump_folder=options.dump_folder)

# run custom passes on joint-graph before partitioner
if joint_custom_passes is not None:
Expand All @@ -119,7 +127,9 @@ def joint_graph_builder(
joint_with_descriptors.graph_module
)

with tracing(tracing_context):
with tracing(tracing_context), torch._functorch.config.patch(
selective_decompose=options.use_inductor_lite
):
fn = aot_compile_joint_with_descriptors(
joint_with_descriptors, fw_compiler=fw_compiler, bw_compiler=bw_compiler
)
Expand Down Expand Up @@ -217,6 +227,7 @@ def compiler(
example_inputs,
passes: List[Callable] = None,
dump_folder: str | None = None,
is_forward: bool = True,
):
"""
Compile a graph module by applying a sequence of compiler passes.
Expand All @@ -233,6 +244,17 @@ def compiler(
if passes is None:
passes = DEFAULT_COMPILER_PASSES

if end_with_pass(passes, ["inductor_lite_pass"]):
# inductor lite pass is always the last pass if it is applied since it
# behaves differently for forward and backwawrd. so we explicitly pass
# the info. For example, different methods are used to identify static input
# indices.
last_pass = passes[-1]
_last_pass = functools.partial(last_pass, is_forward=is_forward)

# keep the function name for debug log
passes[-1] = functools.wraps(last_pass)(_last_pass)

logger.debug(f"{name} before compiler:")
logger.debug(
gm.print_readable(print_output=False, include_stride=True, include_device=True)
Expand All @@ -248,11 +270,16 @@ def compiler(
logger.info(f"Applying pass: {pass_name}")
gm = pass_fn(gm, example_inputs)

logger.debug(f"{name} after compiler:")
logger.debug(
gm.print_readable(print_output=False, include_stride=True, include_device=True)
)
_dump_gm(dump_folder, gm, f"{name}_after_compiler")
if not end_with_pass(passes, ["inductor_lite_pass"]):
# inductor lite mode returns a CompiledFxGraph which does not support print_readable.
logger.debug(f"{name} after compiler:")
logger.debug(
gm.print_readable(
print_output=False, include_stride=True, include_device=True
)
)
_dump_gm(dump_folder, gm, f"{name}_after_compiler")

return gm


Expand All @@ -271,17 +298,41 @@ def make_compiler_with_passes(

def fw_compiler(gm: torch.fx.GraphModule, example_inputs) -> None:
return compiler(
"fwd_gm", gm, example_inputs, passes=passes, dump_folder=dump_folder
"fwd_gm",
gm,
example_inputs,
passes=passes,
dump_folder=dump_folder,
is_forward=True,
)

def bw_compiler(gm: torch.fx.GraphModule, example_inputs) -> None:
return compiler(
"bwd_gm", gm, example_inputs, passes=passes, dump_folder=dump_folder
"bwd_gm",
gm,
example_inputs,
passes=passes,
dump_folder=dump_folder,
is_forward=False,
)

return fw_compiler, bw_compiler


def validate_pass_names(pass_names: list[str]) -> None:
if "inductor_lite" in pass_names:
# inductor lite supports regional_inductor by default. They share the same
# user-facing frontend API (i.e., the context manager), use different
# backend implementations, and achieve the same compilation result.
assert "regional_inductor" not in pass_names, (
"inductor_lite uses regional_inductor by default. please use one "
"pass at a time."
)
assert (
pass_names[-1] == "inductor_lite"
), "inductor_lite has to be the last pass to apply"


def get_compiler_passes_from_config(model: torch.nn.Module, job_config: JobConfig):
"""
Extract and validate compiler passes from job config.
Expand All @@ -298,6 +349,8 @@ def get_compiler_passes_from_config(model: torch.nn.Module, job_config: JobConfi
)

pass_names = getattr(job_config.compile, "passes", [])
validate_pass_names(pass_names)

if (
"autobucketing_reordering" in pass_names
and "transformer_block_bucketing" in pass_names
Expand Down Expand Up @@ -371,3 +424,8 @@ def get_joint_custom_passes_from_config(
)

return joint_custom_passes


def is_using_inductor_lite(job_config: JobConfig) -> bool:
pass_names = getattr(job_config.compile, "passes", [])
return "inductor_lite" in pass_names
69 changes: 69 additions & 0 deletions torchtitan/experiments/compiler_toolkit/inductor_lite.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""
Inductor lite pass for the compiler toolkit.

This module provides inductor lite pass that can be applied to graph modules
during compilation.
"""
from typing import Optional

import torch
from torchtitan.tools.logging import logger


def get_inductor_lite_fw_compiler(extra_config: Optional[dict] = None):
from torch._inductor import lite_mode_options
from torch._inductor.compile_fx import compile_fx_inner

context = torch._guards.TracingContext.try_get()

if not context or not context.fw_metadata:
logger.warn("No context or fw_metadata available")
static_input_idxs = ()
else:
static_input_idxs = context.fw_metadata.static_input_indices

inductor_config = lite_mode_options
if extra_config:
inductor_config.update(extra_config)

def fw_compiler(gm: torch.fx.GraphModule, example_inputs: tuple):
with torch._inductor.config.patch(inductor_config):
compiled_fn = compile_fx_inner(
gm,
example_inputs,
static_input_idxs=static_input_idxs,
is_backward=False,
)
return compiled_fn

return fw_compiler


def get_inductor_lite_bw_compiler(extra_config: Optional[dict] = None):
from torch._inductor import lite_mode_options
from torch._inductor.compile_fx import compile_fx_inner
from torch._inductor.utils import count_tangents

inductor_config = lite_mode_options
if extra_config:
inductor_config.update(extra_config)

def bw_compiler(gm: torch.fx.GraphModule, example_inputs: tuple):
fixed = count_tangents(gm)

with torch._inductor.config.patch(inductor_config):
compiled_fn = compile_fx_inner(
gm,
example_inputs,
static_input_idxs=list(range(fixed)),
is_backward=True,
)
return compiled_fn

return bw_compiler
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
CompiledModule,
get_compiler_passes_from_config,
get_joint_custom_passes_from_config,
GraphBuilderOptions,
is_using_inductor_lite,
joint_graph_builder,
make_compiler_with_passes,
)
Expand Down Expand Up @@ -74,13 +76,18 @@ def parallelize_llama(
compiler_passes, dump_folder=job_config.job.dump_folder
)

options = GraphBuilderOptions(
dump_folder=job_config.job.dump_folder,
use_inductor_lite=is_using_inductor_lite(job_config),
)

# Create custom joint_graph_builder with llama-specific compilers
llama_joint_graph_builder = functools.partial(
joint_graph_builder,
fw_compiler=fw_compiler,
bw_compiler=bw_compiler,
joint_custom_passes=joint_custom_passes,
dump_folder=job_config.job.dump_folder,
options=options,
)

# TODO: CompiledModule should take sample input as well, so that we can
Expand Down
29 changes: 29 additions & 0 deletions torchtitan/experiments/compiler_toolkit/passes.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,16 @@
during compilation. Passes can be selected and configured via job config.
"""

from typing import Callable

import torch
from torch._inductor.fx_passes.overlap_manual_scheduling import manual_overlap_bucketing
from torch._inductor.fx_passes.overlap_scheduling import schedule_overlap_bucketing
from torch.fx.passes.regional_inductor import regional_inductor
from torchtitan.experiments.compiler_toolkit.inductor_lite import (
get_inductor_lite_bw_compiler,
get_inductor_lite_fw_compiler,
)
from torchtitan.experiments.simple_fsdp.reshard_after_forward import (
annotate_fsdp_all_gather,
)
Expand Down Expand Up @@ -83,9 +89,32 @@ def fsdp_reshard_after_fwd_pass(
return gm


def inductor_lite_pass(
gm: torch.fx.GraphModule, example_inputs, is_forward: bool
) -> Callable:
"""
Apply inductor lite mode.

This pass takes a gm and generates a callable (not gm) using inductor. The lite
mode falls back for all ops except explicitly user-annotated ops under
regional compile.
"""
# TODO: fix inductor size assertion for all_reduce
# https://github.com/pytorch/pytorch/issues/167430
extra_inductor_config = {"size_asserts": False}

if is_forward:
_compiler = get_inductor_lite_fw_compiler(extra_inductor_config)
else:
_compiler = get_inductor_lite_bw_compiler(extra_inductor_config)

return _compiler(gm, example_inputs)


# Registry mapping pass names to pass functions
AVAILABLE_COMPILER_PASSES = {
"autobucketing_reordering": autobucketing_reordering_pass,
"transformer_block_bucketing": transformer_block_bucketing_reordering_pass,
"regional_inductor": regional_inductor_pass,
"inductor_lite": inductor_lite_pass,
}
Loading
Loading