From 979fdbd44018ad289bb36b8b22cf2670bd5f9898 Mon Sep 17 00:00:00 2001 From: Chien-Chin Huang Date: Wed, 12 Nov 2025 13:12:05 -0800 Subject: [PATCH 1/3] Update [ghstack-poisoned] --- .../deepseek_v3/parallelize.py | 5 +- .../compiler_toolkit/graph_utils.py | 50 ++++++++++++++----- .../compiler_toolkit/llama3/parallelize.py | 5 +- 3 files changed, 45 insertions(+), 15 deletions(-) diff --git a/torchtitan/experiments/compiler_toolkit/deepseek_v3/parallelize.py b/torchtitan/experiments/compiler_toolkit/deepseek_v3/parallelize.py index bc6859af61..20ad17f301 100644 --- a/torchtitan/experiments/compiler_toolkit/deepseek_v3/parallelize.py +++ b/torchtitan/experiments/compiler_toolkit/deepseek_v3/parallelize.py @@ -80,7 +80,9 @@ def parallelize_deepseekv3( compiler_passes = get_compiler_passes_from_config(job_config) # Create compilers with specified passes (defaults to no passes) - fw_compiler, bw_compiler = make_compiler_with_passes(compiler_passes) + fw_compiler, bw_compiler = make_compiler_with_passes( + compiler_passes, dump_folder=job_config.job.dump_folder + ) # Create custom joint_graph_builder with deepseekv3-specific compilers deepseekv3_joint_graph_builder = functools.partial( @@ -88,6 +90,7 @@ def parallelize_deepseekv3( fw_compiler=fw_compiler, bw_compiler=bw_compiler, joint_custom_pass=validate_flex_attention_annotation, + dump_folder=job_config.job.dump_folder, ) # TODO: CompiledModule should take sample input as well, so that we can diff --git a/torchtitan/experiments/compiler_toolkit/graph_utils.py b/torchtitan/experiments/compiler_toolkit/graph_utils.py index aee089cad9..db998aa170 100644 --- a/torchtitan/experiments/compiler_toolkit/graph_utils.py +++ b/torchtitan/experiments/compiler_toolkit/graph_utils.py @@ -5,6 +5,7 @@ # LICENSE file in the root directory of this source tree. import contextlib +from pathlib import Path from typing import Callable, List, Optional import torch @@ -21,8 +22,18 @@ from torchtitan.tools.logging import logger +def _dump_gm(dump_folder: str | None, gm: torch.fx.GraphModule, name: str) -> None: + # TODO: make the dump rank configurable + if not dump_folder or torch.distributed.get_rank() != 0: + return + + output_path = Path(dump_folder) / "compiler" / f"{name}.txt" + output_path.parent.mkdir(parents=True, exist_ok=True) + output_path.write_text(gm.print_readable(print_output=False)) + + def export_joint( - model, args, kwargs=None + model, args, kwargs=None, dump_folder: str | None = None ) -> tuple[JointWithDescriptors, TracingContext]: if kwargs is None: kwargs = {} @@ -35,8 +46,10 @@ def export_joint( torch.fx.traceback.preserve_node_meta(), ): gm = dynamo_graph_capture_for_export(model)(*args, **kwargs) - logger.info("Dynamo gm:") - logger.info(gm.print_readable(print_output=False)) + logger.debug("Dynamo gm:") + logger.debug(gm.print_readable(print_output=False)) + _dump_gm(dump_folder, gm, "dynamo_gm") + tracing_context = gm.meta["tracing_context"] with tracing(tracing_context): @@ -68,6 +81,7 @@ def joint_graph_builder( fw_compiler: Optional[Callable] = None, bw_compiler: Optional[Callable] = None, joint_custom_pass: Optional[Callable] = None, + dump_folder: str | None = None, ): """ Build a joint forward-backward graph for the model with optional custom compilers. @@ -79,16 +93,17 @@ def joint_graph_builder( fw_compiler: Optional custom forward compiler function bw_compiler: Optional custom backward compiler function joint_custom_pass: Optional custom pass to run on the joint graph + dump_folder: Optional folder to dump the graph to """ assert isinstance(model_args, tuple) - for arg in model_args: - assert isinstance(arg, DTensor) + for idx, arg in enumerate(model_args): + assert isinstance(arg, DTensor), f"Argument {idx} is of type {type(arg)}" # get joint graph ( joint_with_descriptors, tracing_context, - ) = export_joint(model, model_args, model_kwargs) + ) = export_joint(model, model_args, model_kwargs, dump_folder=dump_folder) # Optional validation if joint_custom_pass is not None: @@ -179,6 +194,7 @@ def compiler( gm: torch.fx.GraphModule, example_inputs, passes: List[Callable] = None, + dump_folder: str | None = None, ): """ Compile a graph module by applying a sequence of compiler passes. @@ -194,19 +210,23 @@ def compiler( if passes is None: passes = DEFAULT_COMPILER_PASSES - logger.info(f"{name} before compiler:") - logger.info(gm.print_readable(print_output=False)) + logger.debug(f"{name} before compiler:") + logger.debug(gm.print_readable(print_output=False)) + _dump_gm(dump_folder, gm, f"{name}_before_compiler") for pass_fn in passes: logger.info(f"Applying pass: {pass_fn.__name__}") gm = pass_fn(gm, example_inputs) - logger.info(f"{name} after compiler:") - logger.info(gm.print_readable(print_output=False)) + logger.debug(f"{name} after compiler:") + logger.debug(gm.print_readable(print_output=False)) + _dump_gm(dump_folder, gm, f"{name}_after_compiler") return gm -def make_compiler_with_passes(passes: List[Callable] = None): +def make_compiler_with_passes( + passes: List[Callable] = None, dump_folder: str | None = None +): """ Create forward and backward compilers with specified passes. @@ -218,10 +238,14 @@ def make_compiler_with_passes(passes: List[Callable] = None): """ def fw_compiler(gm: torch.fx.GraphModule, example_inputs) -> None: - return compiler("fwd_gm", gm, example_inputs, passes=passes) + return compiler( + "fwd_gm", gm, example_inputs, passes=passes, dump_folder=dump_folder + ) def bw_compiler(gm: torch.fx.GraphModule, example_inputs) -> None: - return compiler("bwd_gm", gm, example_inputs, passes=passes) + return compiler( + "bwd_gm", gm, example_inputs, passes=passes, dump_folder=dump_folder + ) return fw_compiler, bw_compiler diff --git a/torchtitan/experiments/compiler_toolkit/llama3/parallelize.py b/torchtitan/experiments/compiler_toolkit/llama3/parallelize.py index e3dca203e9..0ffbe61b89 100644 --- a/torchtitan/experiments/compiler_toolkit/llama3/parallelize.py +++ b/torchtitan/experiments/compiler_toolkit/llama3/parallelize.py @@ -64,7 +64,9 @@ def parallelize_llama( model = simple_fsdp_parallelize_llama(model, parallel_dims, job_config) # Get compiler passes from config - compiler_passes = get_compiler_passes_from_config(job_config) + compiler_passes = get_compiler_passes_from_config( + job_config, dump_folder=job_config.job.dump_folder + ) # Create compilers with specified passes (defaults to no passes) fw_compiler, bw_compiler = make_compiler_with_passes(compiler_passes) @@ -75,6 +77,7 @@ def parallelize_llama( fw_compiler=fw_compiler, bw_compiler=bw_compiler, joint_custom_pass=validate_flex_attention_annotation, + dump_folder=job_config.job.dump_folder, ) # TODO: CompiledModule should take sample input as well, so that we can From be7e4861c7bbb5d0a7fa7fb82ce89dd10aaddf43 Mon Sep 17 00:00:00 2001 From: Chien-Chin Huang Date: Wed, 12 Nov 2025 13:30:19 -0800 Subject: [PATCH 2/3] Update [ghstack-poisoned] --- .../experiments/compiler_toolkit/llama3/parallelize.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/torchtitan/experiments/compiler_toolkit/llama3/parallelize.py b/torchtitan/experiments/compiler_toolkit/llama3/parallelize.py index 0ffbe61b89..62def3ef00 100644 --- a/torchtitan/experiments/compiler_toolkit/llama3/parallelize.py +++ b/torchtitan/experiments/compiler_toolkit/llama3/parallelize.py @@ -64,12 +64,12 @@ def parallelize_llama( model = simple_fsdp_parallelize_llama(model, parallel_dims, job_config) # Get compiler passes from config - compiler_passes = get_compiler_passes_from_config( - job_config, dump_folder=job_config.job.dump_folder - ) + compiler_passes = get_compiler_passes_from_config(job_config) # Create compilers with specified passes (defaults to no passes) - fw_compiler, bw_compiler = make_compiler_with_passes(compiler_passes) + fw_compiler, bw_compiler = make_compiler_with_passes( + compiler_passes, dump_folder=job_config.job.dump_folder + ) # Create custom joint_graph_builder with llama-specific compilers and validation llama_joint_graph_builder = functools.partial( From 097d5473acacb4c64f02f749b407b111c8d0b960 Mon Sep 17 00:00:00 2001 From: Chien-Chin Huang Date: Wed, 12 Nov 2025 23:56:59 -0800 Subject: [PATCH 3/3] Update [ghstack-poisoned] --- torchtitan/experiments/compiler_toolkit/graph_utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/torchtitan/experiments/compiler_toolkit/graph_utils.py b/torchtitan/experiments/compiler_toolkit/graph_utils.py index db998aa170..413ea066fb 100644 --- a/torchtitan/experiments/compiler_toolkit/graph_utils.py +++ b/torchtitan/experiments/compiler_toolkit/graph_utils.py @@ -206,6 +206,7 @@ def compiler( passes: List of compiler pass functions to apply. Each function should take (gm, example_inputs) and return a transformed gm. If None, uses DEFAULT_COMPILER_PASSES. + dump_folder: Optional folder to dump the graph to """ if passes is None: passes = DEFAULT_COMPILER_PASSES