From 9951d78217b363ea99f312e8396060f7c953bffa Mon Sep 17 00:00:00 2001 From: Robin Zhang Date: Wed, 9 Oct 2024 00:48:03 -0700 Subject: [PATCH 01/14] Align RNG tracker with megatron Signed-off-by: Robin Zhang Co-authored-by: Yifei Song --- transformer_engine/pytorch/distributed.py | 22 +++++++++++++++------- 1 file changed, 15 insertions(+), 7 deletions(-) diff --git a/transformer_engine/pytorch/distributed.py b/transformer_engine/pytorch/distributed.py index 490ac3b160..d6bba0e624 100644 --- a/transformer_engine/pytorch/distributed.py +++ b/transformer_engine/pytorch/distributed.py @@ -720,16 +720,22 @@ class CudaRNGStatesTracker: """ def __init__(self): - # Map from a string name to the cuda rng state. - self.states_ = {} - # Seeds are just for book keeping and ensure no seed is set twice. - self.seeds_ = set() + self.reset() + + def is_initialized(self): + """Checks if the internal RNG state has been set wirth set_states().""" + return self._is_initialized def reset(self): - """ - Set to the initial state (no tracker). - """ + """Set to the initial state (no tracker).""" + + # Track if initialized. + self._is_initialized = False + + # Map from a string name to the cuda rng state. self.states_ = {} + + # Seeds are just for book keeping and ensure no seed is set twice. self.seeds_ = set() def get_states(self) -> Dict[str, torch.Tensor]: @@ -750,6 +756,7 @@ def set_states(self, states: Dict[str, torch.Tensor]) -> None: states: Dict[str, torch.Tensor] A mapping from string names to RNG states. """ + self._is_initialized = True self.states_ = states def add(self, name: str, seed: int) -> None: @@ -761,6 +768,7 @@ def add(self, name: str, seed: int) -> None: seed: int PyTorch seed for the RNG state. """ + self._is_initialized = True # Check seed is not already used. if seed in self.seeds_: raise RuntimeError(f"seed {seed} already exists") From b5f7cdfe80b273d39aa8b9814cd7903852f8d124 Mon Sep 17 00:00:00 2001 From: Robin Zhang Date: Wed, 9 Oct 2024 16:26:05 -0700 Subject: [PATCH 02/14] Fix module_params order and warmup bug in cudagraph Signed-off-by: Robin Zhang Co-authored-by: Yifei Song --- transformer_engine/pytorch/graph.py | 44 +++++++++++++++++++++++++---- 1 file changed, 38 insertions(+), 6 deletions(-) diff --git a/transformer_engine/pytorch/graph.py b/transformer_engine/pytorch/graph.py index c47b792a95..ca8bc7f407 100644 --- a/transformer_engine/pytorch/graph.py +++ b/transformer_engine/pytorch/graph.py @@ -173,11 +173,14 @@ def _make_graphed_callables( ] else: per_callable_module_params = [] - for c in callables: - for i in range(num_microbatches): - per_callable_module_params.append( - tuple(c.parameters()) if isinstance(c, torch.nn.Module) else () - ) + for m_chunk in range(num_model_chunks): + for _ in range(num_microbatches): + for l_no in range(num_layers): + per_callable_module_params.append( + tuple(callables[m_chunk * num_layers + l_no].parameters()) + if isinstance(callables[m_chunk * num_layers + l_no], torch.nn.Module) + else () + ) assert len(per_callable_module_params) == len(flatten_sample_args) per_callable_static_input_surfaces = [ flatten_sample_args[i] + per_callable_module_params[i] @@ -201,8 +204,37 @@ def _make_graphed_callables( # Hopefully prevents cudnn benchmarking and other lazy-initialization cuda work # from ending up in any captures. torch.cuda.synchronize() - with torch.cuda.stream(torch.cuda.Stream()): + + # Get warmup func and func_idx. 
+ warmup_func_idx = [] + warmup_func = [] + if _order is None: for func_idx, func in enumerate(callables): + warmup_func_idx.append(func_idx) + warmup_func.append(func) + else: + fwd_idx = [0] * num_model_chunks + for idx, c_id in enumerate(_order): + if c_id > 0: + m_chunk = c_id - 1 + for l_no in range(num_layers): + func = callables[m_chunk * num_layers + l_no] + func_idx = (m_chunk * num_microbatches * num_layers) + ( + fwd_idx[m_chunk] * num_layers + l_no + ) + warmup_func_idx.append(func_idx) + warmup_func.append(func) + fwd_idx[m_chunk] += 1 + assert len(warmup_func) == len( + sample_args + ), f"Warmup runs {len(warmup_func)} don't match args {len(sample_args)}." + assert len(warmup_func_idx) == len( + set(warmup_func_idx) + ), f"Warmup runs {len(warmup_func)} but only {len(set(warmup_func_idx))} are unique." + + # Run warmup. + with torch.cuda.stream(torch.cuda.Stream()): + for func_idx, func in zip(warmup_func_idx, warmup_func): args = sample_args[func_idx] kwargs = sample_kwargs[func_idx] static_input_surface = per_callable_static_input_surfaces[func_idx] From 0ede5cb4eea3c05cb3edbfd4356ef20e60b42e60 Mon Sep 17 00:00:00 2001 From: Robin Zhang Date: Wed, 9 Oct 2024 16:28:40 -0700 Subject: [PATCH 03/14] Add fp8_group argument and fix fp8 accuracy issue for cudagraph Signed-off-by: Robin Zhang Co-authored-by: Yifei Song --- transformer_engine/pytorch/fp8.py | 12 ++++++------ transformer_engine/pytorch/graph.py | 14 +++++++++++++- .../pytorch/module/layernorm_linear.py | 7 +++++-- transformer_engine/pytorch/module/layernorm_mlp.py | 5 ++++- transformer_engine/pytorch/module/linear.py | 6 ++++-- 5 files changed, 32 insertions(+), 12 deletions(-) diff --git a/transformer_engine/pytorch/fp8.py b/transformer_engine/pytorch/fp8.py index 2a909dabc6..15f20c81e5 100644 --- a/transformer_engine/pytorch/fp8.py +++ b/transformer_engine/pytorch/fp8.py @@ -442,16 +442,16 @@ def get_old_fp8_meta_tensors_for_recompute(cls, fp8_meta: Dict[str, Any]) -> Non stashed_fp8_meta = cls.fp8_tensors_recompute_buffer[fp8_meta[buffer_position_key]].popleft() # Replace amaxes and scales with stashed values for phase 2 forward - fp8_meta["scaling_fwd"].amax_history = stashed_fp8_meta[0] - fp8_meta["scaling_fwd"].scale = stashed_fp8_meta[1] - fp8_meta["scaling_fwd"].scale_inv = stashed_fp8_meta[2] + fp8_meta["scaling_fwd"].amax_history.copy_(stashed_fp8_meta[0]) + fp8_meta["scaling_fwd"].scale.copy_(stashed_fp8_meta[1]) + fp8_meta["scaling_fwd"].scale_inv.copy_(stashed_fp8_meta[2]) @staticmethod def restore_fp8_meta_tensors(fp8_meta: Dict[str, Any]) -> None: """Restore latest scaling factors and amaxes after recompute forward run.""" - fp8_meta["scaling_fwd"].amax_history = fp8_meta["updated_amax_history_fwd"] - fp8_meta["scaling_fwd"].scale = fp8_meta["updated_scale_fwd"] - fp8_meta["scaling_fwd"].scale_inv = fp8_meta["updated_scale_inv_fwd"] + fp8_meta["scaling_fwd"].amax_history.copy_(fp8_meta["updated_amax_history_fwd"]) + fp8_meta["scaling_fwd"].scale.copy_(fp8_meta["updated_scale_fwd"]) + fp8_meta["scaling_fwd"].scale_inv.copy_(fp8_meta["updated_scale_inv_fwd"]) @contextmanager diff --git a/transformer_engine/pytorch/graph.py b/transformer_engine/pytorch/graph.py index ca8bc7f407..4bcc935986 100644 --- a/transformer_engine/pytorch/graph.py +++ b/transformer_engine/pytorch/graph.py @@ -12,6 +12,7 @@ from torch._C import _graph_pool_handle from transformer_engine.common.recipe import DelayedScaling +from transformer_engine.pytorch.constants import dist_group_type from .fp8 import ( fp8_autocast, 
FP8GlobalStateManager, @@ -248,6 +249,9 @@ def _make_graphed_callables( allow_unused=allow_unused_input, ) del outputs, grad_inputs + for module in func.modules(): + if hasattr(module, "is_first_microbatch"): + module.is_first_microbatch = True torch.cuda.synchronize() # All captures here share a mempool. To avoid replays corrupting each other's memory, @@ -570,6 +574,7 @@ def make_graphed_callables( fp8_enabled: bool = False, fp8_calibrating: bool = False, fp8_recipe: Optional[DelayedScaling] = None, + fp8_group: Optional[dist_group_type] = None, fp8_weight_caching: bool = False, _order: Optional[List[int]] = None, pool: Optional[Tuple[int, ...]] = None, @@ -611,6 +616,9 @@ def make_graphed_callables( using a higher precision. fp8_recipe: recipe.DelayedScaling, default = `None` recipe used for FP8 training. + fp8_group: torch._C._distributed_c10d.ProcessGroup, default = `None` + distributed group over which amaxes for the fp8 tensors + are reduced at the end of each training step. fp8_weight_caching: bool, default = `False` Whether or not to cache FP8 weights across microbatches. if set to `True`, the `is_first_microbatch` boolean argument must be passed into the forward @@ -639,7 +647,11 @@ def wrap_autocast(block): def forward_func(*args, **kwargs): with fp8_autocast( - enabled=fp8_enabled, calibrating=fp8_calibrating, fp8_recipe=fp8_recipe, _graph=True + enabled=fp8_enabled, + calibrating=fp8_calibrating, + fp8_recipe=fp8_recipe, + fp8_group=fp8_group, + _graph=True, ): outputs = old_forward(*args, **kwargs) return outputs diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index fbf1b97704..b179a6c619 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -1152,7 +1152,10 @@ def forward( produced) """ - skip_fp8_weight_update = FP8GlobalStateManager.get_skip_fp8_weight_update_tensor() + if FP8GlobalStateManager.fp8_graph_capturing(): + skip_fp8_weight_update = FP8GlobalStateManager.get_skip_fp8_weight_update_tensor() + else: + skip_fp8_weight_update = None if skip_fp8_weight_update is not None: is_first_microbatch = False @@ -1214,7 +1217,7 @@ def forward( self.apply_bias and not self.gemm_bias_unfused_add, self.eps, is_first_microbatch, - self.fp8, + False if self.fp8 is None else self.fp8, self.fp8_calibration, self.fp8_meta, self.fuse_wgrad_accumulation, diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index 64e8c9ce36..1a651474bf 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -1484,7 +1484,10 @@ def forward( produced) """ - skip_fp8_weight_update = FP8GlobalStateManager.get_skip_fp8_weight_update_tensor() + if FP8GlobalStateManager.fp8_graph_capturing(): + skip_fp8_weight_update = FP8GlobalStateManager.get_skip_fp8_weight_update_tensor() + else: + skip_fp8_weight_update = None if skip_fp8_weight_update is not None: is_first_microbatch = False diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 1fed467210..9492725f56 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -938,8 +938,10 @@ def forward( first microbatch (since it is the first gradient being produced) """ - - skip_fp8_weight_update = FP8GlobalStateManager.get_skip_fp8_weight_update_tensor() + if 
FP8GlobalStateManager.fp8_graph_capturing(): + skip_fp8_weight_update = FP8GlobalStateManager.get_skip_fp8_weight_update_tensor() + else: + skip_fp8_weight_update = None if skip_fp8_weight_update is not None: is_first_microbatch = False From 5596437dc1310ef417bddfe874f8fb9a947e47fe Mon Sep 17 00:00:00 2001 From: Robin Zhang Date: Wed, 9 Oct 2024 16:29:47 -0700 Subject: [PATCH 04/14] Add TE modules and weights filters to support MoE models Signed-off-by: Robin Zhang Co-authored-by: Yifei Song --- transformer_engine/pytorch/graph.py | 56 ++++++++++++++++++++++++++++- 1 file changed, 55 insertions(+), 1 deletion(-) diff --git a/transformer_engine/pytorch/graph.py b/transformer_engine/pytorch/graph.py index 4bcc935986..6898d9d230 100644 --- a/transformer_engine/pytorch/graph.py +++ b/transformer_engine/pytorch/graph.py @@ -233,14 +233,43 @@ def _make_graphed_callables( set(warmup_func_idx) ), f"Warmup runs {len(warmup_func)} but only {len(set(warmup_func_idx))} are unique." - # Run warmup. + # Filter the TE modules that cudagraph can access. + visited_te_modules = set() + + def hook_fn(module, input, output): + if ( + isinstance(module, TransformerEngineBaseModule) + and FP8GlobalStateManager.is_fp8_enabled() + ): + visited_te_modules.add(module) + + # Filter the weights without gradients in backward. These weights are not in the computation graph so can be removed from cudagraph inputs. + no_grad_weights = None + + def get_no_grads(static_input_surface, grad_inputs, func_idx): + grad_index = 0 + none_grads = [] + for i in range(len(static_input_surface)): + if static_input_surface[i].requires_grad: + if grad_inputs[grad_index] is None and i >= len(flatten_sample_args[func_idx]): + none_grads.append(i) + grad_index += 1 + return set(none_grads) + + # Run warmup and do the above two filtering. with torch.cuda.stream(torch.cuda.Stream()): for func_idx, func in zip(warmup_func_idx, warmup_func): args = sample_args[func_idx] kwargs = sample_kwargs[func_idx] static_input_surface = per_callable_static_input_surfaces[func_idx] for _ in range(num_warmup_iters): + hooks = [] + for module in func.modules(): + hook = module.register_forward_hook(hook_fn) + hooks.append(hook) outputs, _ = _tree_flatten(func(*args, **kwargs)) + for hook in hooks: + hook.remove() grad_inputs = torch.autograd.grad( outputs=tuple(o for o in outputs if o.requires_grad), inputs=tuple(i for i in static_input_surface if i.requires_grad), @@ -248,10 +277,24 @@ def _make_graphed_callables( only_inputs=True, allow_unused=allow_unused_input, ) + if no_grad_weights is None: + no_grad_weights = get_no_grads(static_input_surface, grad_inputs, func_idx) del outputs, grad_inputs for module in func.modules(): if hasattr(module, "is_first_microbatch"): module.is_first_microbatch = True + if len(no_grad_weights) > 0: + per_callable_static_input_surfaces[func_idx] = tuple( + inp + for i, inp in enumerate(per_callable_static_input_surfaces[func_idx]) + if i not in no_grad_weights + ) + per_callable_module_params[func_idx] = tuple( + param + for i, param in enumerate(per_callable_module_params[func_idx]) + if i + len(flatten_sample_args[func_idx]) not in no_grad_weights + ) + no_grad_weights = None torch.cuda.synchronize() # All captures here share a mempool. 
To avoid replays corrupting each other's memory, @@ -498,6 +541,17 @@ def new_fwd(*user_args, **user_kwargs): isinstance(m, TransformerEngineBaseModule) and FP8GlobalStateManager.is_fp8_enabled() ): + if m not in visited_te_modules: + # Only Set the FP8 meta for the modules included by forward + continue + fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() + if ( + not fp8_recipe.fp8_mha + and not fp8_recipe.fp8_dpa + and hasattr(m, "attention_dropout") + and m.deterministic + ): + continue m.fp8_meta["fp8_group"] = FP8GlobalStateManager.get_fp8_group() m.fp8_meta["recipe"] = FP8GlobalStateManager.get_fp8_recipe() FP8GlobalStateManager.add_fp8_tensors_to_global_buffer( From 1d0759efc597ee6572a4850f39c1a6bf2203e65c Mon Sep 17 00:00:00 2001 From: Robin Zhang Date: Thu, 10 Oct 2024 20:07:18 -0700 Subject: [PATCH 05/14] Revert self.fp8 Signed-off-by: Robin Zhang --- transformer_engine/pytorch/module/layernorm_linear.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index b179a6c619..92b37fcb07 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -1217,7 +1217,7 @@ def forward( self.apply_bias and not self.gemm_bias_unfused_add, self.eps, is_first_microbatch, - False if self.fp8 is None else self.fp8, + self.fp8, self.fp8_calibration, self.fp8_meta, self.fuse_wgrad_accumulation, From 73a22e2e8bcf43ec84c23bc844b8d16d06626e26 Mon Sep 17 00:00:00 2001 From: Robin Zhang Date: Thu, 10 Oct 2024 20:10:33 -0700 Subject: [PATCH 06/14] Use hooks to filter module params Signed-off-by: Robin Zhang --- transformer_engine/pytorch/graph.py | 43 ++++++++++------------------- 1 file changed, 14 insertions(+), 29 deletions(-) diff --git a/transformer_engine/pytorch/graph.py b/transformer_engine/pytorch/graph.py index 6898d9d230..d4470130e3 100644 --- a/transformer_engine/pytorch/graph.py +++ b/transformer_engine/pytorch/graph.py @@ -233,8 +233,9 @@ def _make_graphed_callables( set(warmup_func_idx) ), f"Warmup runs {len(warmup_func)} but only {len(set(warmup_func_idx))} are unique." - # Filter the TE modules that cudagraph can access. + # Filter the TE modules and parameters that cudagraph can access. visited_te_modules = set() + visited_params = set() def hook_fn(module, input, output): if ( @@ -242,21 +243,9 @@ def hook_fn(module, input, output): and FP8GlobalStateManager.is_fp8_enabled() ): visited_te_modules.add(module) + visited_params.update(module.parameters(recurse=False)) - # Filter the weights without gradients in backward. These weights are not in the computation graph so can be removed from cudagraph inputs. - no_grad_weights = None - - def get_no_grads(static_input_surface, grad_inputs, func_idx): - grad_index = 0 - none_grads = [] - for i in range(len(static_input_surface)): - if static_input_surface[i].requires_grad: - if grad_inputs[grad_index] is None and i >= len(flatten_sample_args[func_idx]): - none_grads.append(i) - grad_index += 1 - return set(none_grads) - - # Run warmup and do the above two filtering. + # Run warmup and do the above filtering. 
with torch.cuda.stream(torch.cuda.Stream()): for func_idx, func in zip(warmup_func_idx, warmup_func): args = sample_args[func_idx] @@ -277,24 +266,20 @@ def get_no_grads(static_input_surface, grad_inputs, func_idx): only_inputs=True, allow_unused=allow_unused_input, ) - if no_grad_weights is None: - no_grad_weights = get_no_grads(static_input_surface, grad_inputs, func_idx) del outputs, grad_inputs for module in func.modules(): if hasattr(module, "is_first_microbatch"): module.is_first_microbatch = True - if len(no_grad_weights) > 0: - per_callable_static_input_surfaces[func_idx] = tuple( - inp - for i, inp in enumerate(per_callable_static_input_surfaces[func_idx]) - if i not in no_grad_weights - ) - per_callable_module_params[func_idx] = tuple( - param - for i, param in enumerate(per_callable_module_params[func_idx]) - if i + len(flatten_sample_args[func_idx]) not in no_grad_weights - ) - no_grad_weights = None + + # Remove cudagraph input params that are not accessed. + for idx in range(len(flatten_sample_args)): + per_callable_module_params[idx] = tuple( + param for param in per_callable_module_params[idx] if param in visited_params + ) + per_callable_static_input_surfaces[idx] = ( + flatten_sample_args[idx] + per_callable_module_params[idx] + ) + torch.cuda.synchronize() # All captures here share a mempool. To avoid replays corrupting each other's memory, From cd56618005c84007254683b706bfb41fd5822de1 Mon Sep 17 00:00:00 2001 From: Robin Zhang Date: Fri, 11 Oct 2024 00:07:53 -0700 Subject: [PATCH 07/14] Filter all TE modules in hooks Signed-off-by: Robin Zhang Co-authored-by: Yifei Song --- transformer_engine/pytorch/graph.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/transformer_engine/pytorch/graph.py b/transformer_engine/pytorch/graph.py index d4470130e3..04b68e5e35 100644 --- a/transformer_engine/pytorch/graph.py +++ b/transformer_engine/pytorch/graph.py @@ -238,10 +238,7 @@ def _make_graphed_callables( visited_params = set() def hook_fn(module, input, output): - if ( - isinstance(module, TransformerEngineBaseModule) - and FP8GlobalStateManager.is_fp8_enabled() - ): + if isinstance(module, TransformerEngineBaseModule): visited_te_modules.add(module) visited_params.update(module.parameters(recurse=False)) From 2a7f54b488ff16c950458998f3fd006ace32c088 Mon Sep 17 00:00:00 2001 From: Robin Zhang Date: Fri, 11 Oct 2024 01:07:00 -0700 Subject: [PATCH 08/14] Format code Signed-off-by: Robin Zhang --- transformer_engine/pytorch/graph.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/transformer_engine/pytorch/graph.py b/transformer_engine/pytorch/graph.py index 04b68e5e35..dcee2bfd0d 100644 --- a/transformer_engine/pytorch/graph.py +++ b/transformer_engine/pytorch/graph.py @@ -215,7 +215,7 @@ def _make_graphed_callables( warmup_func.append(func) else: fwd_idx = [0] * num_model_chunks - for idx, c_id in enumerate(_order): + for c_id in _order: if c_id > 0: m_chunk = c_id - 1 for l_no in range(num_layers): @@ -237,7 +237,7 @@ def _make_graphed_callables( visited_te_modules = set() visited_params = set() - def hook_fn(module, input, output): + def hook_fn(module, inputs, outputs): # pylint: disable=unused-argument if isinstance(module, TransformerEngineBaseModule): visited_te_modules.add(module) visited_params.update(module.parameters(recurse=False)) From c6dddaf7e3ae618c1caf50602fe059c8b15b3934 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 11 Oct 2024 08:07:30 +0000 
Subject: [PATCH 09/14] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/pytorch/graph.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transformer_engine/pytorch/graph.py b/transformer_engine/pytorch/graph.py index dcee2bfd0d..24f26a6306 100644 --- a/transformer_engine/pytorch/graph.py +++ b/transformer_engine/pytorch/graph.py @@ -237,7 +237,7 @@ def _make_graphed_callables( visited_te_modules = set() visited_params = set() - def hook_fn(module, inputs, outputs): # pylint: disable=unused-argument + def hook_fn(module, inputs, outputs): # pylint: disable=unused-argument if isinstance(module, TransformerEngineBaseModule): visited_te_modules.add(module) visited_params.update(module.parameters(recurse=False)) From a01602e9bd7918bbf4ea2c45dbde7d4be183bf65 Mon Sep 17 00:00:00 2001 From: Xin Yao Date: Fri, 11 Oct 2024 22:10:07 +0800 Subject: [PATCH 10/14] Update graph.py Signed-off-by: Xin Yao --- transformer_engine/pytorch/graph.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transformer_engine/pytorch/graph.py b/transformer_engine/pytorch/graph.py index 24f26a6306..ae1336233c 100644 --- a/transformer_engine/pytorch/graph.py +++ b/transformer_engine/pytorch/graph.py @@ -269,7 +269,7 @@ def hook_fn(module, inputs, outputs): # pylint: disable=unused-argument module.is_first_microbatch = True # Remove cudagraph input params that are not accessed. - for idx in range(len(flatten_sample_args)): + for idx, _ in enumerate(flatten_sample_args): per_callable_module_params[idx] = tuple( param for param in per_callable_module_params[idx] if param in visited_params ) From 8017b6d9ce96c7d90d2fe46647cec1cca18bd4a3 Mon Sep 17 00:00:00 2001 From: Robin Zhang Date: Thu, 31 Oct 2024 18:50:58 -0700 Subject: [PATCH 11/14] Revert CudaRNGStatesTracker Signed-off-by: Robin Zhang --- transformer_engine/pytorch/distributed.py | 22 +++++++--------------- 1 file changed, 7 insertions(+), 15 deletions(-) diff --git a/transformer_engine/pytorch/distributed.py b/transformer_engine/pytorch/distributed.py index d6bba0e624..490ac3b160 100644 --- a/transformer_engine/pytorch/distributed.py +++ b/transformer_engine/pytorch/distributed.py @@ -720,24 +720,18 @@ class CudaRNGStatesTracker: """ def __init__(self): - self.reset() - - def is_initialized(self): - """Checks if the internal RNG state has been set wirth set_states().""" - return self._is_initialized - - def reset(self): - """Set to the initial state (no tracker).""" - - # Track if initialized. - self._is_initialized = False - # Map from a string name to the cuda rng state. self.states_ = {} - # Seeds are just for book keeping and ensure no seed is set twice. self.seeds_ = set() + def reset(self): + """ + Set to the initial state (no tracker). + """ + self.states_ = {} + self.seeds_ = set() + def get_states(self) -> Dict[str, torch.Tensor]: """ Get rng states. Copy the dictionary so we have direct pointers @@ -756,7 +750,6 @@ def set_states(self, states: Dict[str, torch.Tensor]) -> None: states: Dict[str, torch.Tensor] A mapping from string names to RNG states. """ - self._is_initialized = True self.states_ = states def add(self, name: str, seed: int) -> None: @@ -768,7 +761,6 @@ def add(self, name: str, seed: int) -> None: seed: int PyTorch seed for the RNG state. """ - self._is_initialized = True # Check seed is not already used. 
if seed in self.seeds_: raise RuntimeError(f"seed {seed} already exists") From 41d6100571134de9b5f0c883c97ff6cff6cb6784 Mon Sep 17 00:00:00 2001 From: Yifei Song Date: Tue, 19 Nov 2024 22:50:11 -0800 Subject: [PATCH 12/14] Format Update Signed-off-by: Yifei Song --- transformer_engine/pytorch/graph.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/transformer_engine/pytorch/graph.py b/transformer_engine/pytorch/graph.py index ae1336233c..7ebfe7e379 100644 --- a/transformer_engine/pytorch/graph.py +++ b/transformer_engine/pytorch/graph.py @@ -264,6 +264,8 @@ def hook_fn(module, inputs, outputs): # pylint: disable=unused-argument allow_unused=allow_unused_input, ) del outputs, grad_inputs + # The following code is added specifically for MCore's special requirements, + # aimed at preventing warmup from altering the control flow. for module in func.modules(): if hasattr(module, "is_first_microbatch"): module.is_first_microbatch = True @@ -527,12 +529,14 @@ def new_fwd(*user_args, **user_kwargs): # Only Set the FP8 meta for the modules included by forward continue fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() + from transformer_engine.pytorch.attention import DotProductAttention + if ( - not fp8_recipe.fp8_mha + isinstance(m, DotProductAttention) + and not fp8_recipe.fp8_mha and not fp8_recipe.fp8_dpa - and hasattr(m, "attention_dropout") - and m.deterministic ): + # Don't need to update FP8 meta for non-FP8 DPA continue m.fp8_meta["fp8_group"] = FP8GlobalStateManager.get_fp8_group() m.fp8_meta["recipe"] = FP8GlobalStateManager.get_fp8_recipe() From c82faa2268ea326a1d6ad09bd51aef9a96ae401e Mon Sep 17 00:00:00 2001 From: Yifei Song Date: Wed, 20 Nov 2024 09:54:55 +0000 Subject: [PATCH 13/14] Revert "Use hooks to filter module params" This reverts commit 73a22e2e8bcf43ec84c23bc844b8d16d06626e26. Signed-off-by: Yifei Song --- transformer_engine/pytorch/graph.py | 43 +++++++++++++++++++---------- 1 file changed, 29 insertions(+), 14 deletions(-) diff --git a/transformer_engine/pytorch/graph.py b/transformer_engine/pytorch/graph.py index 7ebfe7e379..d53ede87bf 100644 --- a/transformer_engine/pytorch/graph.py +++ b/transformer_engine/pytorch/graph.py @@ -233,16 +233,27 @@ def _make_graphed_callables( set(warmup_func_idx) ), f"Warmup runs {len(warmup_func)} but only {len(set(warmup_func_idx))} are unique." - # Filter the TE modules and parameters that cudagraph can access. + # Filter the TE modules that cudagraph can access. visited_te_modules = set() - visited_params = set() def hook_fn(module, inputs, outputs): # pylint: disable=unused-argument if isinstance(module, TransformerEngineBaseModule): visited_te_modules.add(module) - visited_params.update(module.parameters(recurse=False)) - # Run warmup and do the above filtering. + # Filter the weights without gradients in backward. These weights are not in the computation graph so can be removed from cudagraph inputs. + no_grad_weights = None + + def get_no_grads(static_input_surface, grad_inputs, func_idx): + grad_index = 0 + none_grads = [] + for i in range(len(static_input_surface)): + if static_input_surface[i].requires_grad: + if grad_inputs[grad_index] is None and i >= len(flatten_sample_args[func_idx]): + none_grads.append(i) + grad_index += 1 + return set(none_grads) + + # Run warmup and do the above two filtering. 
with torch.cuda.stream(torch.cuda.Stream()): for func_idx, func in zip(warmup_func_idx, warmup_func): args = sample_args[func_idx] @@ -263,22 +274,26 @@ def hook_fn(module, inputs, outputs): # pylint: disable=unused-argument only_inputs=True, allow_unused=allow_unused_input, ) + if no_grad_weights is None: + no_grad_weights = get_no_grads(static_input_surface, grad_inputs, func_idx) del outputs, grad_inputs # The following code is added specifically for MCore's special requirements, # aimed at preventing warmup from altering the control flow. for module in func.modules(): if hasattr(module, "is_first_microbatch"): module.is_first_microbatch = True - - # Remove cudagraph input params that are not accessed. - for idx, _ in enumerate(flatten_sample_args): - per_callable_module_params[idx] = tuple( - param for param in per_callable_module_params[idx] if param in visited_params - ) - per_callable_static_input_surfaces[idx] = ( - flatten_sample_args[idx] + per_callable_module_params[idx] - ) - + if len(no_grad_weights) > 0: + per_callable_static_input_surfaces[func_idx] = tuple( + inp + for i, inp in enumerate(per_callable_static_input_surfaces[func_idx]) + if i not in no_grad_weights + ) + per_callable_module_params[func_idx] = tuple( + param + for i, param in enumerate(per_callable_module_params[func_idx]) + if i + len(flatten_sample_args[func_idx]) not in no_grad_weights + ) + no_grad_weights = None torch.cuda.synchronize() # All captures here share a mempool. To avoid replays corrupting each other's memory, From 938f32549d55be50509d6d37c62023cde3dff8d5 Mon Sep 17 00:00:00 2001 From: Robin Zhang Date: Wed, 20 Nov 2024 04:56:47 -0800 Subject: [PATCH 14/14] Remove filtering module params Signed-off-by: Robin Zhang --- transformer_engine/pytorch/graph.py | 29 +---------------------------- 1 file changed, 1 insertion(+), 28 deletions(-) diff --git a/transformer_engine/pytorch/graph.py b/transformer_engine/pytorch/graph.py index d53ede87bf..6c33cc72b9 100644 --- a/transformer_engine/pytorch/graph.py +++ b/transformer_engine/pytorch/graph.py @@ -240,20 +240,7 @@ def hook_fn(module, inputs, outputs): # pylint: disable=unused-argument if isinstance(module, TransformerEngineBaseModule): visited_te_modules.add(module) - # Filter the weights without gradients in backward. These weights are not in the computation graph so can be removed from cudagraph inputs. - no_grad_weights = None - - def get_no_grads(static_input_surface, grad_inputs, func_idx): - grad_index = 0 - none_grads = [] - for i in range(len(static_input_surface)): - if static_input_surface[i].requires_grad: - if grad_inputs[grad_index] is None and i >= len(flatten_sample_args[func_idx]): - none_grads.append(i) - grad_index += 1 - return set(none_grads) - - # Run warmup and do the above two filtering. + # Run warmup and do the above filtering. with torch.cuda.stream(torch.cuda.Stream()): for func_idx, func in zip(warmup_func_idx, warmup_func): args = sample_args[func_idx] @@ -274,26 +261,12 @@ def get_no_grads(static_input_surface, grad_inputs, func_idx): only_inputs=True, allow_unused=allow_unused_input, ) - if no_grad_weights is None: - no_grad_weights = get_no_grads(static_input_surface, grad_inputs, func_idx) del outputs, grad_inputs # The following code is added specifically for MCore's special requirements, # aimed at preventing warmup from altering the control flow. 
for module in func.modules(): if hasattr(module, "is_first_microbatch"): module.is_first_microbatch = True - if len(no_grad_weights) > 0: - per_callable_static_input_surfaces[func_idx] = tuple( - inp - for i, inp in enumerate(per_callable_static_input_surfaces[func_idx]) - if i not in no_grad_weights - ) - per_callable_module_params[func_idx] = tuple( - param - for i, param in enumerate(per_callable_module_params[func_idx]) - if i + len(flatten_sample_args[func_idx]) not in no_grad_weights - ) - no_grad_weights = None torch.cuda.synchronize() # All captures here share a mempool. To avoid replays corrupting each other's memory,
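
---

Usage note (not part of the patch series): a minimal, hypothetical sketch of how the new `fp8_group` argument added to `make_graphed_callables` in PATCH 03/14 might be used. It assumes `torch.distributed` is already initialized (e.g. via `torchrun`) and that an FP8-capable GPU is available; the layer sizes, tensor shapes, and the `dp_group` process group below are illustrative assumptions, not values taken from the patches.

```python
import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling

# Assumption: the default process group has already been initialized
# (e.g. torchrun --nproc_per_node=N). dp_group stands in for the group
# over which FP8 amaxes should be reduced each training step.
dp_group = torch.distributed.new_group(backend="nccl")

# A single TE module and sample input used for graph capture.
layer = te.Linear(1024, 1024).cuda()
sample_input = torch.randn(8, 1024, device="cuda", requires_grad=True)

# Capture the layer into a CUDA graph. fp8_group is forwarded to the
# fp8_autocast context inside the graphed forward, as added in PATCH 03/14.
graphed_layer = te.make_graphed_callables(
    layer,
    (sample_input,),
    fp8_enabled=True,
    fp8_recipe=DelayedScaling(),
    fp8_group=dp_group,
)

# Replay the captured graph as a normal module call.
out = graphed_layer(sample_input)
out.sum().backward()
```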