From e897d49da867b6ee14227e56ac4615713f082112 Mon Sep 17 00:00:00 2001 From: Daniel Korzekwa Date: Wed, 11 Mar 2026 02:12:39 -0700 Subject: [PATCH 1/9] Add shared activation hooks infrastructure for minitron and puzzletron - Add base hooks framework in modelopt/torch/nas/plugins/megatron_hooks/ - base_hooks.py: Core hook infrastructure - base_hooks_analysis.py: Analysis utilities for hooks - megatron_hooks.py: Megatron-specific hook implementations - compare_module_outputs.py: Module comparison utilities - Add tests for activation hooks - Update test utilities for distributed testing - Update minitron pruning tests to use new activation hooks Signed-off-by: Daniel Korzekwa --- .../nas/plugins/megatron_hooks/__init__.py | 23 + .../nas/plugins/megatron_hooks/base_hooks.py | 824 ++++++++++++++++++ .../megatron_hooks/base_hooks_analysis.py | 104 +++ .../megatron_hooks/compare_module_outputs.py | 291 +++++++ .../plugins/megatron_hooks/megatron_hooks.py | 36 + tests/_test_utils/torch/distributed/utils.py | 5 + tests/conftest.py | 7 + .../plugins/megatron_hooks/test_base_hooks.py | 100 +++ .../test_base_hooks_analysis.py | 173 ++++ .../test_mcore_gpt_minitron_pruning.py | 48 + 10 files changed, 1611 insertions(+) create mode 100644 modelopt/torch/nas/plugins/megatron_hooks/__init__.py create mode 100644 modelopt/torch/nas/plugins/megatron_hooks/base_hooks.py create mode 100644 modelopt/torch/nas/plugins/megatron_hooks/base_hooks_analysis.py create mode 100644 modelopt/torch/nas/plugins/megatron_hooks/compare_module_outputs.py create mode 100644 modelopt/torch/nas/plugins/megatron_hooks/megatron_hooks.py create mode 100644 tests/gpu/torch/nas/plugins/megatron_hooks/test_base_hooks.py create mode 100644 tests/gpu/torch/nas/plugins/megatron_hooks/test_base_hooks_analysis.py diff --git a/modelopt/torch/nas/plugins/megatron_hooks/__init__.py b/modelopt/torch/nas/plugins/megatron_hooks/__init__.py new file mode 100644 index 0000000000..996d531392 --- /dev/null +++ b/modelopt/torch/nas/plugins/megatron_hooks/__init__.py @@ -0,0 +1,23 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Forward hooks for estimating importance scores for pruning.""" + +from modelopt.torch.utils import import_plugin + +from .base_hooks import * +from .base_hooks_analysis import * + +with import_plugin("megatron_hooks"): + from .megatron_hooks import * diff --git a/modelopt/torch/nas/plugins/megatron_hooks/base_hooks.py b/modelopt/torch/nas/plugins/megatron_hooks/base_hooks.py new file mode 100644 index 0000000000..56436acfdd --- /dev/null +++ b/modelopt/torch/nas/plugins/megatron_hooks/base_hooks.py @@ -0,0 +1,824 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Forward hooks for activation-based importance estimation.""" + +import gc +import json +from abc import ABC, abstractmethod +from datetime import datetime +from pathlib import Path + +import torch +import torch.nn.functional as F +from omegaconf import DictConfig, OmegaConf +from torch import nn + +import modelopt.torch.utils.distributed as dist +from modelopt.torch.puzzletron.tools.logger import aprint +from modelopt.torch.puzzletron.tools.robust_json import json_dump + +__all__ = [ + "ForwardHook", + "IndependentChannelContributionHook", + "IndependentKvHeadContributionHook", + "IterativeChannelContributionHook", + "L2NormHook", + "LayerNormContributionHook", +] + + +def clear_gpu_memory(clear: bool) -> None: + """Clear GPU memory cache if requested. + + Args: + clear: If True, runs garbage collection and empties CUDA cache. + """ + if clear: + gc.collect() + torch.cuda.empty_cache() + + +class ForwardHook(ABC): + """Base class for PyTorch forward hooks. + + This follows the PyTorch forward hook API where the second + parameter is 'args' (a tuple of positional arguments passed to forward()). + + Usage: + hook = MyHook() + module.register_forward_hook(hook) + """ + + @abstractmethod + def __call__( + self, module: nn.Module, args: tuple[torch.Tensor, ...], output: torch.Tensor + ) -> None: + """Forward hook that is called after the module's forward pass. + + Args: + module: The module this hook is registered on + args: Tuple of positional arguments passed to module.forward() + output: The output from module.forward() + + Returns: + None (does not modify the output) + """ + ... + + @abstractmethod + def accumulate(self) -> torch.Tensor: + """Return accumulated importance scores. + + This method should be called after all forward passes to retrieve + the final importance scores for each channel/feature. + + Returns: + Tensor of importance scores, one per channel/feature. + + Raises: + AssertionError: If no activations have been collected yet. + """ + ... + + @abstractmethod + def state_dict(self) -> dict: + """Return the internal state for checkpointing. + + Returns: + dict: State dictionary containing checkpoint data. + Can contain tensors, ints, lists, etc. + """ + ... + + @abstractmethod + def load_state_dict(self, state_dict: dict) -> None: + """Load the internal state from a checkpoint. + + Args: + state_dict: State dictionary previously returned by state_dict() + """ + ... + + def get_progress_info(self) -> dict: + """Get progress information for this hook. + + Returns: + dict: Progress information (e.g., current iteration, samples processed). + Default implementation returns empty dict. + """ + return {} + + @abstractmethod + def to_dict(self) -> dict[str, torch.Tensor]: + """Convert hook results to dictionary format for saving. + + Returns: + dict: Dictionary containing result tensors (e.g., "score", "channels_importance_ascending"). + """ + ... + + @classmethod + def dump_activations_logs( + cls: type["ForwardHook"], + activation_hooks: dict[str, "ForwardHook"], + activations_log_dir: Path | str, + args: DictConfig, + ) -> None: + """Default implementation for dumping final activation scores logs to disk. + + This is called only at the end of scoring to save final results. + """ + activations_log_dir = Path(activations_log_dir) + activations_log_dir.mkdir(exist_ok=True, parents=True) + rank = dist.rank() + activations_log_path = activations_log_dir / f"rank_{rank}.pth" + activations_log = { + module_name: hook.to_dict() for module_name, hook in activation_hooks.items() + } + torch.save(activations_log, activations_log_path) + + if rank == 0: + args.activation_hooks_kwargs.pop("model") + json_dump(OmegaConf.to_container(args, resolve=True), activations_log_dir / "args.json") + dist.barrier() + + aprint(f"Dumped final activations log to {activations_log_path}") + + @classmethod + def save_hook_states( + cls: type["ForwardHook"], + activation_hooks: dict[str, "ForwardHook"], + activations_log_dir: Path | str, + ) -> None: + """Save hook states for checkpointing (separate from final results). + + This can be called periodically during scoring. + Note: Synchronization should be handled at a higher level to avoid deadlocks. + """ + activations_log_dir = Path(activations_log_dir) + activations_log_dir.mkdir(exist_ok=True, parents=True) + rank = dist.rank() + + hook_states_path = activations_log_dir / f"hook_states_rank_{rank}.pth" + hook_states = { + module_name: hook.state_dict() for module_name, hook in activation_hooks.items() + } + torch.save(hook_states, hook_states_path) + + +class L2NormHook(ForwardHook): + """Hook for accumulating activation statistics for importance estimation. + + Activations are computed as mean over seq_len and then squared and summed over batch_size. + In the accumulate() method we take the square root of the sum to get the L2 norm. + + This is the base version without tensor parallelism support. + For megatron with TP > 1, use MegatronL2NormHook instead. + + Args: + max_size: Optional maximum expected size to validate against (skips if mismatch). + Useful for skipping non-max subnets during profiling. + """ + + def __init__(self, max_size: int | None = None): + """Initialize the L2NormHook.""" + self.max_size = max_size + self._activations: torch.Tensor | None = None + + def _get_input_tensor(self, args: tuple[torch.Tensor, ...]) -> torch.Tensor: + """Get input tensor from args. Override in subclass for TP gathering.""" + return args[0].detach() + + def __call__( + self, module: nn.Module, args: tuple[torch.Tensor, ...], output: torch.Tensor + ) -> None: + """Accumulate activation statistics from the forward pass. + + Args: + module: The module this hook is registered on. + args: Tuple of input tensors. args[0] expected shape: [seq_len, batch_size, hidden_size] + (Megatron sequence-first format). + output: Output tensor from the module's forward pass. + """ + input_tensor = self._get_input_tensor(args) + + if input_tensor.dim() == 2: + # For sparse experts, there is no batch dimension. + input_tensor = input_tensor[:, None, :] + + # Dont aggregate activations from non-max subnets (e.g. from profiling) + if self.max_size is not None and input_tensor.shape[-1] != self.max_size: + return + + input_tensor = input_tensor.to(torch.float32) # use full precision to avoid overflow + activations = input_tensor.abs().mean(dim=0) # [batch_size, hidden_size] + activations = activations.pow(2).sum(dim=0) # [hidden_size] + + if self._activations is None: + self._activations = activations + else: + self._activations += activations + + def accumulate(self) -> torch.Tensor: + """Return the accumulated L2 norm of activations. + + Returns: + Tensor of accumulated scores, one per channel + + Raises: + AssertionError: If no activations have been collected yet + """ + assert self._activations is not None, "No activations collected for importance estimation." + # Convert squared sum to L2 norm + return self._activations.pow(0.5) + + def to_dict(self) -> dict[str, torch.Tensor]: + """Convert to dict format for saving.""" + return {"score": self.accumulate().cpu()} + + def state_dict(self) -> dict: + """Return the state dictionary containing activations.""" + return {"activations": self._activations} + + def load_state_dict(self, state_dict: dict) -> None: + """Load activations from checkpoint.""" + self._activations = state_dict["activations"] + + +class IndependentChannelContributionHook(ForwardHook): + """Hook for channel importance estimation using weight norms and activation magnitudes. + + Computes channel importance as the product of: + - L2 norm of each column in the weight matrix (how much each input channel affects output) + - Mean absolute activation for each channel (how strongly each channel is activated) + + Args: + linear_layer: The linear projection layer to analyze. Must have a `weight` attribute + and either `in_features` (nn.Linear) or `input_size` (Megatron RowParallelLinear). + max_size: Optional maximum expected size to validate against (skips if mismatch). + Useful for skipping non-max subnets during profiling. + """ + + def __init__( + self, + linear_layer: nn.Module, + max_size: int | None = None, + ): + """Initialize the independent channel contribution hook.""" + self.max_size = max_size + + weight_matrix = linear_layer.weight.float() + self.weight_norm = torch.linalg.vector_norm(weight_matrix, dim=0) + + # Check if it's a RowParallelLinear (Megatron-Core) or nn.Linear (PyTorch) + if hasattr(linear_layer, "input_size"): + self.num_channels = linear_layer.input_size # Megatron-Core + else: + self.num_channels = linear_layer.in_features # PyTorch + + self.agg_channel_activations = torch.zeros( + size=(self.num_channels,), + dtype=torch.float32, + device=weight_matrix.device, + ) + + def __call__( + self, module: nn.Module, args: tuple[torch.Tensor, ...], output: torch.Tensor | tuple + ) -> None: + """Accumulate mean absolute activations per channel. + + Args: + module: The module this hook is registered on. + args: Tuple with single input tensor. args[0] expected shape: [batch_size, seq_len, input_channels] + (PyTorch batch-first format). + output: Output tensor of shape [batch_size, seq_len, output_channels], or tuple (output_tensor, bias) + for parallel layers. + """ + activations = args[0] + + # Don't aggregate activations from non-max subnets (e.g. from profiling) + if self.max_size is not None and activations.shape[-1] != self.max_size: + return + + mean_abs_channel_activations = ( + activations.abs().float().mean(dim=list(range(activations.ndim - 1))) + ) + self.agg_channel_activations[:] += mean_abs_channel_activations # shape [input_channels] + + def to_dict(self) -> dict[str, torch.Tensor]: + """Convert results to dict with channel importance scores. + + Returns: + Dict with "score" (weight_norm * activations), "weight_norm", and + "agg_channel_activations". + """ + return { + "score": (self.weight_norm * self.agg_channel_activations).cpu(), + "weight_norm": self.weight_norm.cpu(), + "agg_channel_activations": self.agg_channel_activations.cpu(), + } + + def accumulate(self) -> torch.Tensor: + """Return importance scores as a tensor. + + Returns: + Tensor of importance scores (weight_norm * activations), one per channel. + """ + return self.to_dict()["score"] + + def state_dict(self) -> dict: + """Save the internal state for checkpointing.""" + return { + "agg_channel_activations": self.agg_channel_activations.cpu().clone(), + "weight_norm": self.weight_norm.cpu().clone(), + } + + def load_state_dict(self, state_dict: dict) -> None: + """Load the internal state from a checkpoint.""" + self.agg_channel_activations = state_dict["agg_channel_activations"].to( + self.agg_channel_activations.device + ) + # weight_norm should be the same as it's derived from the model weights + # but we can verify it matches + expected_weight_norm = state_dict["weight_norm"].to(self.weight_norm.device) + if not torch.allclose(self.weight_norm, expected_weight_norm, rtol=1e-5): + raise AssertionError( + "weight_norm mismatch during state loading - model weights may have changed" + ) + + +def get_pruning_schedule(num_channels, pruning_iters): + """Spending decreases monotonically when num_channels >= pruning_iters. + + Intervals between spends increase monotonically when pruning_iters > num_channels. + The budget is fully utilized, and there's spending in the last iteration. + num_channels = 10, pruning_iters = 4 ==> [3, 3, 2, 2] + num_channels = 4, pruning_iters = 10 ==> [0, 1, 0, 1, 0, 0, 1, 0, 0, 1] + """ + if num_channels >= pruning_iters: + # Case when budget is greater than or equal to iterations + q = num_channels // pruning_iters # Base spend per iteration + r = num_channels % pruning_iters # Remainder to distribute + + schedule = [] + for i in range(pruning_iters): + if i < r: + # Assign higher spend to earlier iterations + schedule.append(q + 1) + else: + schedule.append(q) + else: + # Case when iterations are greater than budget + schedule = [0] * pruning_iters + for i in range(1, num_channels + 1): + # Distribute spends at positions where intervals increase monotonically + pos = ((i * pruning_iters) // num_channels) - 1 + schedule[pos] = 1 + return schedule + + +class IterativeChannelContributionHook(ForwardHook): + """Hook for iterative channel pruning based on contribution analysis. + + Progressively identifies and removes the least important input channels of a linear layer + by measuring channel contribution as the L2 norm of output change when removed. + + Args: + linear_layer: The linear projection layer to analyze. Must have a `weight` attribute + and either `in_features` (nn.Linear) or `input_size` (Megatron RowParallelLinear). + activation_hooks_kwargs: Configuration dict with: + - validation_full_iters (int): Number of pruning iterations. + - clear_gpu_memory (bool, optional): Clear GPU memory during computation. + - calibration_method (str, optional): "scale_by_magnitude" or None. + max_size: Optional maximum expected size to validate against (skips if mismatch). + Useful for skipping non-max subnets during profiling. + """ + + def __init__( + self, + linear_layer: nn.Module, + activation_hooks_kwargs: dict, + max_size: int | None = None, + ): + """Initialize the iterative channel contribution hook.""" + self.weight_matrix = linear_layer.weight + + # Check if it's a RowParallelLinear (Megatron-Core) or nn.Linear (PyTorch) + # TODO: Consider better design to handle RowParallelLinear and nn.Linear + if hasattr(linear_layer, "input_size"): + self.num_channels = linear_layer.input_size # Megatron-Core + else: + self.num_channels = linear_layer.in_features # PyTorch + + self.max_size = max_size + self.pruning_iters = activation_hooks_kwargs["validation_full_iters"] + self.clear_gpu_memory = activation_hooks_kwargs.get("clear_gpu_memory", False) + self.curr_iter = 0 + self.pruning_schedule = get_pruning_schedule( + num_channels=self.num_channels, pruning_iters=self.pruning_iters + ) + + self.agg_cont_per_channel = torch.zeros( + size=(self.num_channels,), + dtype=torch.float32, + device=self.weight_matrix.device, + ) + self.pruned_channels = [] + self.calibration_method = activation_hooks_kwargs.get("calibration_method") + self.epsilon = 1e-8 + + def __call__( + self, module: nn.Module, args: tuple[torch.Tensor, ...], output: torch.Tensor | tuple + ) -> None: + """Compute channel contributions and prune channels according to schedule. + + Args: + module: The module this hook is registered on. + args: Tuple with single input tensor. args[0] expected shape: [batch_size, seq_len, input_channels] + (PyTorch batch-first format). + output: Output tensor of shape [batch_size, seq_len, output_channels], or tuple (output_tensor, bias) + for parallel layers. + """ + # Handle case where output is a tuple (e.g., from ColumnParallelLinear/RowParallelLinear) + # TODO: Consider better design to handle RowParallelLinear and nn.Linear + if isinstance(output, tuple): + output_tensor = output[0] + else: + output_tensor = output + + activations = args[0] + + # Don't aggregate activations from non-max subnets (e.g. from profiling) + if self.max_size is not None and activations.shape[-1] != self.max_size: + return + + n_channels_to_prune = self.pruning_schedule[self.curr_iter] + + curr_activations = activations.clone() # Shape B,T,I + curr_activations[..., self.pruned_channels] = 0 + output_curr = F.linear(input=curr_activations, weight=self.weight_matrix) # Shape B,T,E + + if self.calibration_method is None: + scaling_factor_per_token = torch.ones_like(output_tensor[..., 0]) # Shape B,T + elif self.calibration_method == "scale_by_magnitude": + output_norms = torch.linalg.vector_norm(output_tensor, dim=-1) # Shape B,T + output_curr_norms = torch.linalg.vector_norm(output_curr, dim=-1) # Shape B,T + scaling_factor_per_token = output_curr_norms / (output_norms + self.epsilon) + del output_curr_norms, output_norms + else: + raise NotImplementedError + del curr_activations + clear_gpu_memory(clear=self.clear_gpu_memory) + + s = scaling_factor_per_token.unsqueeze(-1) * output_tensor - output_curr # Shape: (B, T, E) + s_squared_per_token = torch.sum(s**2, dim=-1) # Shape: (B, T) + b = s @ self.weight_matrix # Shape: (B, T, I) + c = torch.sum(self.weight_matrix**2, dim=0) # Shape: (I) + del s, output_curr + clear_gpu_memory(clear=self.clear_gpu_memory) + + contribution_squared = ( + s_squared_per_token.unsqueeze(2) + 2 * activations * b + (activations**2) * c + ) # Shape: (B, T, I) + del s_squared_per_token, b, c, activations + clear_gpu_memory(clear=self.clear_gpu_memory) + + contribution = torch.sqrt(contribution_squared + self.epsilon) # Shape: (B, T, I) + mean_cont_per_channel = torch.mean(contribution, dim=(0, 1)) # Shape: (I) + mean_cont_per_channel[self.pruned_channels] = torch.inf + del contribution, contribution_squared + clear_gpu_memory(clear=self.clear_gpu_memory) + + self.agg_cont_per_channel += mean_cont_per_channel + if n_channels_to_prune > 0: + _, worst_indices = torch.topk( + self.agg_cont_per_channel, n_channels_to_prune, largest=False + ) + worst_indices_list = worst_indices.tolist() + assert not set(self.pruned_channels).intersection(set(worst_indices_list)) + self.pruned_channels.extend(worst_indices_list) + self.agg_cont_per_channel.zero_() + self.curr_iter += 1 + + def to_dict(self) -> dict[str, torch.Tensor]: + """Convert pruning results to dict with channel importance rankings. + + Returns: + Dict with "score" (importance rank per channel) and + "channels_importance_ascending" (channel indices in ascending importance). + """ + assert self.num_channels == len(self.pruned_channels) + channels_importance_ascending = torch.tensor(self.pruned_channels, dtype=torch.long) + score = torch.empty(self.num_channels, dtype=torch.long) + score[channels_importance_ascending] = torch.arange(self.num_channels, dtype=torch.long) + + return { + "score": score.cpu(), + "channels_importance_ascending": channels_importance_ascending.cpu(), + } + + def accumulate(self) -> torch.Tensor: + """Return importance scores as a tensor. + + Returns: + Tensor of importance scores, one per channel. Lower scores indicate less important channels. + """ + return self.to_dict()["score"] + + def state_dict(self) -> dict: + """Save the internal state for checkpointing.""" + return { + "curr_iter": self.curr_iter, + "pruned_channels": self.pruned_channels.copy(), + "agg_cont_per_channel": self.agg_cont_per_channel.cpu().clone(), + "num_channels": self.num_channels, + "pruning_iters": self.pruning_iters, + "pruning_schedule": self.pruning_schedule.copy(), + "calibration_method": self.calibration_method, + "epsilon": self.epsilon, + } + + def load_state_dict(self, state_dict: dict) -> None: + """Load the internal state from a checkpoint.""" + self.curr_iter = state_dict["curr_iter"] + self.pruned_channels = state_dict["pruned_channels"].copy() + self.agg_cont_per_channel = state_dict["agg_cont_per_channel"].to(self.weight_matrix.device) + # Verify other parameters match + assert self.num_channels == state_dict["num_channels"], "Channel count mismatch" + assert self.pruning_iters == state_dict["pruning_iters"], "Iteration count mismatch" + assert self.pruning_schedule == state_dict["pruning_schedule"], "Pruning schedule mismatch" + + def get_progress_info(self) -> dict: + """Get progress information for this hook. + + Returns: + dict: Progress information including iteration count and pruned channels. + """ + progress = self.curr_iter / self.pruning_iters if self.pruning_iters > 0 else 0.0 + return { + "curr_iter": self.curr_iter, + "total_iters": self.pruning_iters, + "progress": progress, + "pruned_channels_count": len(self.pruned_channels), + "total_channels": self.num_channels, + } + + +class IndependentKvHeadContributionHook(ForwardHook): + """Hook for estimating KV head importance based on contribution analysis. + + Measures the contribution of each KV head group to the output projection + by computing L2 norms of per-head outputs. + + Args: + linear_layer: The output projection layer (o_proj). + activation_hooks_kwargs: Configuration dict with: + - model: The model instance (to get config). + - block_config: Block configuration with attention settings. + - optimize_for (str, optional): "latency" or "memory". Defaults to "memory". + """ + + def __init__(self, linear_layer: nn.Linear, activation_hooks_kwargs: dict): + """Initialize the KV head contribution hook.""" + model_config = activation_hooks_kwargs["model"].config + block_config = activation_hooks_kwargs["block_config"] + + self.optimize_for = activation_hooks_kwargs.get("optimize_for", "memory") + assert self.optimize_for in ["latency", "memory"] + + self.hidden_size = model_config.hidden_size + self.n_heads_in_group = block_config.attention.n_heads_in_group + self.num_q_heads = model_config.num_attention_heads + self.num_kv_heads = self.num_q_heads // self.n_heads_in_group + self.head_dim = getattr(model_config, "head_dim", self.hidden_size // self.num_q_heads) + + self.agg_kv_head_contributions = torch.zeros( + size=(self.num_kv_heads,), + dtype=torch.float32, + device=linear_layer.weight.device, + ) + + # Reshape weight matrix to group by KV heads + self.weight_grouped = linear_layer.weight.view( + self.hidden_size, self.num_kv_heads, self.head_dim * self.n_heads_in_group + ).permute((1, 0, 2)) + # weight_grouped.shape: (kv_heads, hidden_dim, head_dim * n_heads_in_group) + + def __call__(self, module: nn.Module, args: tuple[torch.Tensor], output: torch.Tensor) -> None: + """Compute KV head contributions from the forward pass.""" + attn_out = args[0] # Shape: (B, T, num_q_heads * head_dim) + batch_size, seq_len, _ = attn_out.shape + + # Reshape attention output to group by KV heads + attn_out_grouped = attn_out.view( + batch_size, + seq_len, + self.num_kv_heads, + self.head_dim * self.n_heads_in_group, + ).unsqueeze(-2) + # attn_out_grouped.shape: (B, T, kv_heads, 1, head_dim * n_heads_in_group) + + if self.optimize_for == "latency": + # Compute contribution per KV head group + # First compute the projection for each KV head group + layer_out_grouped = attn_out_grouped @ self.weight_grouped.transpose(-1, -2) + layer_out_grouped = layer_out_grouped.squeeze(-2) + # layer_out_grouped.shape: (B, T, kv_heads, hidden_dim) + + else: + layer_out_grouped = [] + for i in range(self.num_kv_heads): + _layer_out = attn_out_grouped[:, :, i] @ self.weight_grouped[i].transpose(-1, -2) + layer_out_grouped.append(_layer_out) + layer_out_grouped = torch.cat(layer_out_grouped, dim=2) + + # Compute L2 norm of each group's contribution + contrib_per_kv_head = torch.linalg.vector_norm(layer_out_grouped, dim=-1) + # contrib_per_kv_head.shape: (B, T, kv_heads) + + contrib_per_kv_head = contrib_per_kv_head.mean(dim=(0, 1)) + # contrib_per_kv_head.shape: (kv_heads,) + + # Accumulate contributions + self.agg_kv_head_contributions += contrib_per_kv_head + + def accumulate(self) -> torch.Tensor: + """Return accumulated KV head importance scores. + + Returns: + Tensor of importance scores, one per KV head. + """ + return self.agg_kv_head_contributions + + def to_dict(self) -> dict[str, torch.Tensor]: + """Convert to dict format for saving. + + Returns: + Dict with "score" tensor containing KV head importance scores. + """ + return { + "score": self.agg_kv_head_contributions.cpu(), + } + + def state_dict(self) -> dict: + """Return the internal state for checkpointing.""" + raise NotImplementedError("Saving state dict is not supported for this hook.") + + def load_state_dict(self, state_dict: dict) -> None: + """Load the internal state from a checkpoint.""" + raise NotImplementedError("Loading state dict is not supported for this hook.") + + +class LayerNormContributionHook(ForwardHook): + """Hook for estimating channel importance based on layer normalization activations. + + Aggregates mean absolute activation values per channel for a layer normalization layer. + + Args: + layernorm_layer: The layer normalization layer. + activation_hooks_kwargs: The activation hooks kwargs (not used). + """ + + def __init__(self, layernorm_layer: nn.Module, activation_hooks_kwargs: dict): + """Aggregates mean absolute activation values per channel for a layer normalization layer. + + Args: + layernorm_layer: The layer normalization layer + activation_hooks_kwargs: The activation hooks kwargs (not used) + """ + self.agg_embedding_activations = torch.zeros( + size=(layernorm_layer.weight.shape[0],), + dtype=torch.float32, + device=layernorm_layer.weight.device, + ) + + def __call__(self, module: nn.Module, args: tuple[torch.Tensor], output: torch.Tensor) -> None: + """Accumulate activation statistics from the forward pass.""" + self.agg_embedding_activations += ( + output.abs().float().mean(dim=list(range(output.ndim - 1))) + ) + + def accumulate(self) -> torch.Tensor: + """Return accumulated channel importance scores.""" + return self.agg_embedding_activations + + def to_dict(self) -> dict[str, torch.Tensor]: + """Convert to dict format for saving.""" + return { + "score": self.agg_embedding_activations.cpu(), + "channels_importance_ascending": self.agg_embedding_activations.sort()[1].cpu(), + } + + def state_dict(self) -> dict: + """Return the internal state for checkpointing.""" + raise NotImplementedError("Saving state dict is not supported for this hook.") + + def load_state_dict(self, state_dict: dict) -> None: + """Load the internal state from a checkpoint.""" + raise NotImplementedError("Loading state dict is not supported for this hook.") + + @classmethod + def dump_activations_logs( + cls: type["LayerNormContributionHook"], + activation_hooks: dict[str, "ForwardHook"], + activations_log_dir: Path | str, + args: DictConfig, + ) -> None: + """At the end of the default implementation of dumping activation scores to disc. + + Save aggregated channel importance results. + """ + super().dump_activations_logs(activation_hooks, activations_log_dir, args) + + rank = dist.rank() + if rank == 0: + LayerNormContributionHook._save_channel_importance_results( + activation_hooks, activations_log_dir, args + ) + + dist.barrier() + + @staticmethod + def _save_channel_importance_results( + activation_hooks: dict[str, "ForwardHook"], + activations_log_dir: Path | str, + args: DictConfig, + ) -> None: + """Save channel importance results from activation hooks.""" + # Find all activation files (for multi-rank scenarios) + activations_log_dir = Path(activations_log_dir) + activation_files = list(activations_log_dir.glob("rank_*.pth")) + if not activation_files: + aprint(f"Warning: No activation files found in {activations_log_dir}") + return + + # Load and aggregate activation data from all ranks + all_scores = [] + for activation_file in activation_files: + aprint(f"Loading activations from {activation_file}") + activation_data = torch.load(activation_file, map_location="cpu") + + # Extract scores from the activation data + for module_name, hook_data in activation_data.items(): + if "score" in hook_data: + scores = hook_data["score"] + all_scores.append(scores) + aprint(f"Loaded {len(scores)} channel scores from {module_name}") + + if not all_scores: + aprint("Warning: No valid activation data found") + return + + # Average scores across all ranks and modules + avg_scores = torch.stack(all_scores).mean(dim=0) + aprint(f"Averaged {len(all_scores)} score sets into {len(avg_scores)} channels") + + # Create channel importance ranking (descending order) + ranked_channels = torch.argsort(avg_scores, descending=True).tolist() + + # Create output data structure + timestamp = datetime.now().strftime("%Y_%m_%d__%H_%M_%S") + output_data = { + "model_path": getattr(args, "model_name_or_path", "unknown"), + "dataset_path": getattr(args, "dataset_path", "unknown"), + "experiment_id": getattr(args, "experiment_id", f"experiment_{timestamp}"), + "eval_samples": getattr(args, "eval_samples", 0), + "micro_batch_size": getattr(args, "micro_batch_size", 0), + "timestamp": timestamp, + "total_channels": len(ranked_channels), + "channel_importance_ranking": ranked_channels, + "channel_scores": avg_scores.tolist(), + "score_statistics": { + "min": float(avg_scores.min()), + "max": float(avg_scores.max()), + "mean": float(avg_scores.mean()), + "std": float(avg_scores.std()), + }, + } + + # Save the output + output_path = activations_log_dir / "channel_importance_results.json" + aprint(f"Saving channel importance data to {output_path}") + with open(output_path, "w") as f: + json.dump(output_data, f, indent=2) + + # Print summary statistics + aprint("=== Channel Importance Summary ===") + aprint(f"Total channels: {len(ranked_channels)}") + aprint(f"Top 10 most important channels: {ranked_channels[:10]}") + aprint(f"Bottom 10 least important channels: {ranked_channels[-10:]}") + aprint(f"Score range: {avg_scores.min():.4f} to {avg_scores.max():.4f}") + aprint(f"Score mean: {avg_scores.mean():.4f}") + aprint(f"Score std: {avg_scores.std():.4f}") diff --git a/modelopt/torch/nas/plugins/megatron_hooks/base_hooks_analysis.py b/modelopt/torch/nas/plugins/megatron_hooks/base_hooks_analysis.py new file mode 100644 index 0000000000..dc338a7cfa --- /dev/null +++ b/modelopt/torch/nas/plugins/megatron_hooks/base_hooks_analysis.py @@ -0,0 +1,104 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Analysis tools for evaluating importance scores from hooks.""" + +import torch +import torch.nn.functional as F +from torch import nn + +__all__ = ["evaluate_importance_scores"] + + +def evaluate_importance_scores( + linear_layer: nn.Linear, + activations_batches: list[torch.Tensor], + importance_scores: torch.Tensor, + prune_ratio: float = 0.2, +) -> dict[str, float]: + """Compute reconstruction error after pruning input channels of a linear layer. + + This function simulates channel pruning by zeroing out input channels identified as + least important, then measures how much the layer's output changes. + + Args: + linear_layer: The linear layer to analyze with shape (out_features, in_features). + For example: nn.Linear(in_features=1024, out_features=4096) + activations_batches: List of input activation tensors. + Each tensor has shape [seq_len, batch_size, in_features]. + The last dimension must match linear_layer.in_features. + Example: List of [16, 8, 1024] tensors + importance_scores: Importance score for each input channel (feature). + Shape: [in_features]. Lower scores = less important. + Example: [1024] tensor with one score per input feature + prune_ratio: Fraction of input channels to prune (default: 0.2 means prune 20%). + + Returns: + Dictionary containing averaged metrics across all activation batches: + - rmse: Root mean squared error between original and pruned output + - cosine_similarity: Cosine similarity between original and pruned output + - num_pruned: Number of input channels pruned + + Example: + >>> layer = nn.Linear(in_features=1024, out_features=4096) + >>> # Collect multiple batches for robust evaluation + >>> activations_list = [torch.randn(16, 8, 1024) for _ in range(100)] + >>> scores = torch.randn(1024) # one score per input feature + >>> metrics = evaluate_importance_scores(layer, activations_list, scores, 0.2) + >>> print(f"RMSE: {metrics['rmse']:.4f}, Pruned: {metrics['num_pruned']} channels") + + Note: + - This simulates pruning (zeros out inputs) without modifying layer weights + - "Channels" refers to INPUT features, not output features + + """ + num_channels = importance_scores.shape[0] + num_to_prune = int(num_channels * prune_ratio) + + # Identify channels to prune (lowest scoring = least important) + _, channels_to_prune = torch.topk(importance_scores, num_to_prune, largest=False) + + # Compute metrics for each batch and average + rmse_values = [] + cosine_values = [] + + for activations in activations_batches: + # Get original output + original_output = linear_layer(activations) + + # Prune by zeroing out identified channels + pruned_activations = activations.clone() + pruned_activations[..., channels_to_prune] = 0 + + # Get pruned output + pruned_output = linear_layer(pruned_activations) + + # Compute metrics for this batch + rmse = torch.sqrt(F.mse_loss(pruned_output, original_output)).item() + rmse_values.append(rmse) + + # Cosine similarity (flatten to vectors) + original_flat = original_output.reshape(-1) + pruned_flat = pruned_output.reshape(-1) + cosine = F.cosine_similarity( + original_flat.unsqueeze(0), pruned_flat.unsqueeze(0), dim=1 + ).item() + cosine_values.append(cosine) + + # Return averaged metrics + return { + "rmse": sum(rmse_values) / len(rmse_values), + "cosine_similarity": sum(cosine_values) / len(cosine_values), + "num_pruned": num_to_prune, + } diff --git a/modelopt/torch/nas/plugins/megatron_hooks/compare_module_outputs.py b/modelopt/torch/nas/plugins/megatron_hooks/compare_module_outputs.py new file mode 100644 index 0000000000..316aff76ff --- /dev/null +++ b/modelopt/torch/nas/plugins/megatron_hooks/compare_module_outputs.py @@ -0,0 +1,291 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +r"""Compare module output tensors from different model variants. + +This module provides: +1. OutputSaveHook - A PyTorch hook to capture module outputs during forward pass +2. Comparison utilities - Compute RMSE and cosine similarity between saved outputs + +Usage Example: +-------------- + +Step 1: Capture outputs from multiple layers: + + from modelopt.torch.nas.plugins.megatron_hooks.compare_module_outputs import ( + OutputSaveHook, + save_multi_layer_outputs, + ) + + # Register hooks on all target layers + hooks = {} + for name, module in model.named_modules(): + if name.endswith('mlp.linear_fc2'): + hook = OutputSaveHook(layer_name=name) + module.register_forward_hook(hook) + hooks[name] = hook + + # Run inference/training + model(input_data) + + # Save all layer outputs + save_multi_layer_outputs(hooks, "output_unpruned.pt") + +Step 2: Compare outputs from different model variants: + + python compare_module_outputs.py \ + --reference output_unpruned.pt \ + --compare output_l2norm.pt \ + --output-json comparison_stats.json + +The saved file format: +{ + 'decoder.layers.0.mlp.linear_fc2': Tensor([steps, seq_len, batch, hidden]), + 'decoder.layers.1.mlp.linear_fc2': Tensor([...]), + ... + 'metadata': {'num_layers': N, 'num_steps': M, 'layer_names': [...]} +} +""" + +import argparse + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class OutputSaveHook: + """Hook to capture and save module outputs during forward pass.""" + + def __init__(self, layer_name: str) -> None: + """Initialize the output save hook. + + Args: + layer_name: Hierarchical name of the layer (e.g., 'decoder.layers.0.mlp.linear_fc2'). + """ + self.layer_name = layer_name + self.saved_outputs: list[torch.Tensor] = [] + + def __call__( + self, + module: nn.Module, + args: tuple[torch.Tensor, ...], + output: torch.Tensor | tuple[torch.Tensor, ...], + ) -> None: + """Capture and save module output during forward pass. + + Args: + module: The PyTorch module being hooked. + args: Input arguments to the module's forward pass. + output: Output tensor(s) from the module's forward pass. + """ + # Handle tuple outputs (e.g., output, bias) + out = output[0] if isinstance(output, tuple) else output + self.saved_outputs.append(out.detach().cpu()) + + def get_outputs_list(self) -> list[torch.Tensor]: + """Return saved outputs as a list.""" + return self.saved_outputs + + +def save_multi_layer_outputs(hooks: dict[str, OutputSaveHook], path: str) -> None: + """Save outputs from multiple layers to a single file. + + Args: + hooks: Dictionary mapping layer names to their hooks. + path: Path to save the outputs. + """ + output_dict = {name: hook.get_outputs_list() for name, hook in hooks.items()} + + # Add metadata + output_dict["metadata"] = { + "num_layers": len(hooks), + # Number of forward passes (generation steps) - all hooks have same count, so use first hook + "num_steps": len(next(iter(hooks.values())).saved_outputs) if hooks else 0, + "layer_names": list(hooks.keys()), + } + + torch.save(output_dict, path) + print(f"\nSaved outputs from {len(hooks)} layers to {path}") + for name, data in output_dict.items(): + if name != "metadata": + print(f" {name}: list of {len(data)} tensors") + + +def compute_rmse(tensor1: torch.Tensor, tensor2: torch.Tensor) -> float: + """Compute Root Mean Square Error between two tensors.""" + mse = torch.mean((tensor1 - tensor2) ** 2) + rmse = torch.sqrt(mse) + return rmse.item() + + +def compute_cosine_similarity(tensor1: torch.Tensor, tensor2: torch.Tensor) -> dict: + """Compute average cosine similarity between two tensors.""" + # Flatten to 2D for cosine similarity computation + t1_flat = tensor1.reshape(-1, tensor1.shape[-1]) + t2_flat = tensor2.reshape(-1, tensor2.shape[-1]) + + # Compute cosine similarity per position + cos_sim = F.cosine_similarity(t1_flat, t2_flat, dim=-1) + + return { + "mean": cos_sim.mean().item(), + "min": cos_sim.min().item(), + "max": cos_sim.max().item(), + "std": cos_sim.std().item(), + } + + +def main(): + """Compare module output tensors from different model variants.""" + parser = argparse.ArgumentParser( + description="Compare module output tensors from different model variants", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=__doc__, + ) + parser.add_argument( + "--reference", + type=str, + required=True, + help="Path to reference output tensor (e.g., unpruned model)", + ) + parser.add_argument( + "--compare", + type=str, + required=True, + help="Path to output tensor to compare against reference", + ) + parser.add_argument( + "--output-json", + type=str, + default=None, + help="Path to save comparison statistics as JSON", + ) + args = parser.parse_args() + + # Load reference data + print(f"\nLoading reference: {args.reference}") + ref_data = torch.load(args.reference, map_location="cpu") + + # Load comparison data + print(f"Loading compare: {args.compare}") + comp_data = torch.load(args.compare, map_location="cpu") + + # Compare multi-layer outputs + compare_multi_layer(ref_data, comp_data, args.output_json) + + +def compute_layer_metrics(ref_data: list, comp_data: list) -> dict: + """Compute RMSE and cosine similarity for a layer's outputs. + + Args: + ref_data: List of reference tensors. + comp_data: List of comparison tensors. + + Returns: + Dictionary with metrics. + + Raises: + ValueError: If lengths don't match or tensor shapes don't match. + """ + if len(ref_data) != len(comp_data): + raise ValueError( + f"Length mismatch: reference has {len(ref_data)} samples, compare has {len(comp_data)}" + ) + + rmse_values = [] + cos_sim_values = [] + + for ref_tensor, comp_tensor in zip(ref_data, comp_data): + if ref_tensor.shape != comp_tensor.shape: + raise ValueError( + f"Shape mismatch at index {len(rmse_values)}: " + f"reference {ref_tensor.shape} vs compare {comp_tensor.shape}" + ) + rmse_values.append(compute_rmse(ref_tensor, comp_tensor)) + cos_sim = compute_cosine_similarity(ref_tensor, comp_tensor) + cos_sim_values.append(cos_sim["mean"]) + + return { + "rmse": sum(rmse_values) / len(rmse_values), + "cosine_sim": { + "mean": sum(cos_sim_values) / len(cos_sim_values), + "min": min(cos_sim_values), + "max": max(cos_sim_values), + "std": torch.tensor(cos_sim_values).std().item() if len(cos_sim_values) > 1 else 0.0, + }, + "num_samples": len(rmse_values), + } + + +def compare_multi_layer(ref_data: dict, comp_data: dict, output_json: str | None = None): + """Compare multi-layer outputs.""" + import json + + ref_layers = [k for k in ref_data if k != "metadata"] + comp_layers = [k for k in comp_data if k != "metadata"] + + if set(ref_layers) != set(comp_layers): + print("\nERROR: Layer mismatch!") + print(f"Reference layers: {ref_layers}") + print(f"Compare layers: {comp_layers}") + return + + results = {"aggregated": {"rmse": [], "cosine_sim_mean": []}, "per_layer": {}} + + # Per-layer comparison + for layer_name in sorted(ref_layers): + ref_layer_data = ref_data[layer_name] + comp_layer_data = comp_data[layer_name] + + metrics = compute_layer_metrics(ref_layer_data, comp_layer_data) + + results["per_layer"][layer_name] = metrics + results["aggregated"]["rmse"].append(metrics["rmse"]) + results["aggregated"]["cosine_sim_mean"].append(metrics["cosine_sim"]["mean"]) + + # Aggregated statistics + if results["aggregated"]["rmse"]: + rmse_array = torch.tensor(results["aggregated"]["rmse"]) + cos_sim_array = torch.tensor(results["aggregated"]["cosine_sim_mean"]) + + results["aggregated"]["rmse_stats"] = { + "mean": rmse_array.mean().item(), + "std": rmse_array.std().item(), + "min": rmse_array.min().item(), + "max": rmse_array.max().item(), + } + results["aggregated"]["cosine_sim_stats"] = { + "mean": cos_sim_array.mean().item(), + "std": cos_sim_array.std().item(), + "min": cos_sim_array.min().item(), + "max": cos_sim_array.max().item(), + } + results["aggregated"]["num_steps"] = ref_data.get("metadata", {}).get("num_steps", None) + results["aggregated"]["num_layers"] = len(rmse_array) + + # Save to JSON if requested + if output_json: + # Remove raw lists for JSON serialization + results["aggregated"].pop("rmse", None) + results["aggregated"].pop("cosine_sim_mean", None) + + with open(output_json, "w") as f: + json.dump(results, f, indent=2) + print(f"Saved comparison results to {output_json}") + + +if __name__ == "__main__": + main() diff --git a/modelopt/torch/nas/plugins/megatron_hooks/megatron_hooks.py b/modelopt/torch/nas/plugins/megatron_hooks/megatron_hooks.py new file mode 100644 index 0000000000..d792ff8941 --- /dev/null +++ b/modelopt/torch/nas/plugins/megatron_hooks/megatron_hooks.py @@ -0,0 +1,36 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Megatron-specific hooks with tensor parallelism support.""" + +import torch +from megatron.core.tensor_parallel import gather_from_tensor_model_parallel_region + +from .base_hooks import L2NormHook + +__all__ = ["MegatronL2NormHook"] + + +class MegatronL2NormHook(L2NormHook): + """L2NormHook with tensor parallelism support for Megatron models. + + Extends L2NormHook to gather activations across all tensor parallel regions + before computing importance scores. + """ + + def _get_input_tensor(self, args: tuple[torch.Tensor, ...]) -> torch.Tensor: + """Gather input tensor from all TP regions.""" + # Gather input [seq_len, batch_size, hidden_size] over all TP regions + # NOTE: This is not used at the moment since we restrict to TP=1 + return gather_from_tensor_model_parallel_region(args[0]).detach() diff --git a/tests/_test_utils/torch/distributed/utils.py b/tests/_test_utils/torch/distributed/utils.py index f5fccd3a01..dec0413883 100644 --- a/tests/_test_utils/torch/distributed/utils.py +++ b/tests/_test_utils/torch/distributed/utils.py @@ -34,6 +34,11 @@ def init_process(rank, size, job=None, backend="gloo", port=None): """Initialize the distributed environment.""" os.environ["MASTER_ADDR"] = "localhost" + os.environ["RANK"] = str(rank) + os.environ["LOCAL_RANK"] = str(rank) + os.environ["WORLD_SIZE"] = str(size) + os.environ["LOCAL_WORLD_SIZE"] = str(size) + os.environ["WANDB_DISABLED"] = "true" port = str(get_free_port()) if port is None else str(port) diff --git a/tests/conftest.py b/tests/conftest.py index b85924e464..53a2330c22 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -14,6 +14,7 @@ # limitations under the License. import platform +from pathlib import Path import pytest import torch @@ -112,3 +113,9 @@ def set_torch_dtype(request): @pytest.fixture(scope="session", autouse=True) def enable_hf_checkpointing(): mto.enable_huggingface_checkpointing() + + +@pytest.fixture +def project_root_path(request: pytest.FixtureRequest) -> Path: + """Fixture providing the project root path for tests.""" + return Path(request.config.rootpath) diff --git a/tests/gpu/torch/nas/plugins/megatron_hooks/test_base_hooks.py b/tests/gpu/torch/nas/plugins/megatron_hooks/test_base_hooks.py new file mode 100644 index 0000000000..aa73a3be19 --- /dev/null +++ b/tests/gpu/torch/nas/plugins/megatron_hooks/test_base_hooks.py @@ -0,0 +1,100 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for base hooks.""" + +import torch +import torch.nn as nn + +from modelopt.torch.nas.plugins.megatron_hooks import IterativeChannelContributionHook, L2NormHook + + +def _test_iterative_channel_contribution_hook_with_shape(dim1: int, dim2: int): + """Helper function to test IterativeChannelContributionHook with given activation shape. + + Args: + dim1: First dimension of activation tensor (before in_features). + dim2: Second dimension of activation tensor (before in_features). + """ + torch.manual_seed(42) + + linear_layer = nn.Linear(in_features=6, out_features=4, bias=False) + activation_hooks_kwargs = { + "validation_full_iters": 3, + "clear_gpu_memory": False, + "calibration_method": None, + } + hook = IterativeChannelContributionHook(linear_layer, activation_hooks_kwargs) + linear_layer.register_forward_hook(hook) + + for _ in range(activation_hooks_kwargs["validation_full_iters"]): + activations = torch.randn(dim1, dim2, linear_layer.in_features) + _ = linear_layer(activations) + + results = hook.to_dict() + + # + # Assertions + # + assert results["score"].shape == (6,) + assert results["channels_importance_ascending"].shape == (6,) + + expected_scores = torch.tensor([5, 1, 3, 2, 4, 0]) + assert torch.equal(results["score"], expected_scores) + + expected_channels_asc = torch.tensor([5, 1, 3, 2, 4, 0]) + assert torch.equal(results["channels_importance_ascending"], expected_channels_asc) + + # Test that accumulate() returns the same scores as to_dict()["score"] + scores_from_accumulate = hook.accumulate() + assert torch.equal(scores_from_accumulate, expected_scores) + + +def test_iterative_channel_contribution_hook_sbi(): + """Test IterativeChannelContributionHook returns correct scores for input [seq_len, batch_size, in_features].""" + _test_iterative_channel_contribution_hook_with_shape(dim1=32, dim2=8) + + +def test_iterative_channel_contribution_hook_bsi(): + """Test IterativeChannelContributionHook returns correct scores for input [batch_size, seq_len, in_features].""" + _test_iterative_channel_contribution_hook_with_shape(dim1=8, dim2=32) + + +def test_l2_norm_hook(): + """Test L2NormHook returns correct scores after accumulating activations.""" + torch.manual_seed(42) + + linear_layer = nn.Linear(in_features=6, out_features=4, bias=False) + hook = L2NormHook(max_size=None) + linear_layer.register_forward_hook(hook) + + num_iterations = 3 + for _ in range(num_iterations): + activations = torch.randn(2, 3, linear_layer.in_features) + _ = linear_layer(activations) + + scores = hook.accumulate() + + # + # Assertions + # + assert scores.shape == (6,) + + expected_scores = torch.tensor( + [3.2030, 2.5018, 2.5272, 1.9222, 2.6204, 2.2623], dtype=torch.float32 + ) + assert torch.allclose(scores, expected_scores, atol=1e-4), ( + f"Expected scores {expected_scores}, got {scores}" + ) diff --git a/tests/gpu/torch/nas/plugins/megatron_hooks/test_base_hooks_analysis.py b/tests/gpu/torch/nas/plugins/megatron_hooks/test_base_hooks_analysis.py new file mode 100644 index 0000000000..954c6e11c7 --- /dev/null +++ b/tests/gpu/torch/nas/plugins/megatron_hooks/test_base_hooks_analysis.py @@ -0,0 +1,173 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for base hooks analysis tools.""" + +import pytest +import torch +import torch.nn as nn + +from modelopt.torch.nas.plugins.megatron_hooks import ( + IndependentChannelContributionHook, + IterativeChannelContributionHook, + L2NormHook, + evaluate_importance_scores, +) + + +def test_evaluate_importance_scores_basic(): + """Test basic functionality of importance score evaluation with synthetic scores.""" + torch.manual_seed(42) + + # Create a simple linear layer (same dimensions as other tests for comparability) + layer = nn.Linear(in_features=50, out_features=30, bias=False) + + # Create synthetic hook that generates sequential importance scores + hook = SyntheticImportanceHook(num_features=50) + + # Use shared helper to run evaluation + metrics = _run_hook_and_evaluate(layer, hook, num_iterations=1000, prune_ratio=0.4) + + print(f"[SyntheticImportanceHook] Metrics: {metrics}") + + # Check values with deterministic seed + assert metrics["num_pruned"] == 20 # 40% of 50 = 20 + assert metrics["rmse"] == pytest.approx(0.3689444, rel=1e-5) + assert metrics["cosine_similarity"] == pytest.approx(0.77117118, rel=1e-5) + + +def test_evaluate_importance_scores_with_l2_norm_hook(): + """Test evaluate_importance_scores with L2NormHook.""" + torch.manual_seed(42) + + # Create layer and hook + layer = nn.Linear(in_features=50, out_features=30, bias=False) + hook = L2NormHook(max_size=None) + + # Run evaluation + metrics = _run_hook_and_evaluate(layer, hook, num_iterations=1000, prune_ratio=0.4) + + print(f"[L2NormHook] Metrics: {metrics}") + + # L2NormHook specific assertions + assert metrics["num_pruned"] == 20 # 40% of 50 = 20 + assert metrics["rmse"] == pytest.approx(0.3616334, rel=1e-5) + assert metrics["cosine_similarity"] == pytest.approx(0.7814186, rel=1e-5) + + +def test_evaluate_importance_scores_with_iterative_channel_contribution_hook(): + """Test evaluate_importance_scores with IterativeChannelContributionHook.""" + torch.manual_seed(42) + + # Create layer and hook + layer = nn.Linear(in_features=50, out_features=30, bias=False) + activation_hooks_kwargs = { + "validation_full_iters": 1000, + "clear_gpu_memory": False, + "calibration_method": None, + } + hook = IterativeChannelContributionHook(layer, activation_hooks_kwargs) + + # Run evaluation + metrics = _run_hook_and_evaluate(layer, hook, num_iterations=1000, prune_ratio=0.4) + + print(f"[IterativeChannelContributionHook] Metrics: {metrics}") + + # Iterative channel contribution hook specific assertions + assert metrics["num_pruned"] == 20 # 40% of 50 = 20 + assert metrics["rmse"] == pytest.approx(0.339014, rel=1e-5) + assert metrics["cosine_similarity"] == pytest.approx(0.8110392, rel=1e-5) + + +def test_evaluate_importance_scores_with_independent_channel_contribution_hook(): + """Test evaluate_importance_scores with IndependentChannelContributionHook.""" + torch.manual_seed(42) + + # Create layer and hook + layer = nn.Linear(in_features=50, out_features=30, bias=False) + hook = IndependentChannelContributionHook(layer) + + # Run evaluation + metrics = _run_hook_and_evaluate(layer, hook, num_iterations=1000, prune_ratio=0.4) + + print(f"[IndependentChannelContributionHook] Metrics: {metrics}") + + # Independent channel contribution hook specific assertions + assert metrics["num_pruned"] == 20 # 40% of 50 = 20 + assert metrics["rmse"] == pytest.approx(0.3385471, rel=1e-5) + assert metrics["cosine_similarity"] == pytest.approx(0.8116209, rel=1e-5) + + +def _run_hook_and_evaluate( + layer: nn.Linear, + hook, + num_iterations: int, + prune_ratio: float, +) -> dict: + """Shared helper to run hook, collect scores, and evaluate. + + Args: + layer: Linear layer to test + hook: Hook instance (already created) + num_iterations: Number of forward passes + prune_ratio: Fraction of channels to prune + + Returns: + Dictionary with evaluation metrics + """ + handle = layer.register_forward_hook(hook) # Store the handle + + # Run forward passes + all_activations = [] + for _ in range(num_iterations): + activations = torch.randn(16, 8, layer.in_features) # seq=16, batch=8, in_features=50 + all_activations.append(activations) + _ = layer(activations) + + # Get importance scores from hook + importance_scores = hook.accumulate() + + # Remove the hook before evaluation to avoid triggering it again + handle.remove() + + # Evaluate the importance scores by simulating pruning on all collected activations + # Pass the list of activations to compute averaged metrics across batches + metrics = evaluate_importance_scores( + layer, + all_activations, # List of activation batches + importance_scores, + prune_ratio=prune_ratio, + ) + + return metrics + + +class SyntheticImportanceHook: + """Synthetic hook that generates sequential importance scores for testing. + + This is a simple mock hook that doesn't compute real importance, + just returns torch.arange(num_features) to test the evaluation pipeline. + """ + + def __init__(self, num_features: int): + """Initialize with the number of features.""" + self.num_features = num_features + + def __call__(self, module, args, output): + """Hook callback - does nothing for synthetic hook.""" + + def accumulate(self) -> torch.Tensor: + """Return synthetic importance scores: [0, 1, 2, ..., num_features-1].""" + return torch.arange(self.num_features, dtype=torch.float32) diff --git a/tests/gpu_megatron/torch/prune/plugins/test_mcore_gpt_minitron_pruning.py b/tests/gpu_megatron/torch/prune/plugins/test_mcore_gpt_minitron_pruning.py index 55583a4300..2c4ce71789 100644 --- a/tests/gpu_megatron/torch/prune/plugins/test_mcore_gpt_minitron_pruning.py +++ b/tests/gpu_megatron/torch/prune/plugins/test_mcore_gpt_minitron_pruning.py @@ -38,6 +38,10 @@ SEED = 1234 +def _assert_approx(actual, expected, abs=1e-3): + assert actual == pytest.approx(expected, abs=abs), f"{actual=} != {expected=}" + + def _test_mcore_gpt_parameter_sorting(activation_func, rank, size): set_seed(SEED) # Use relatively bigger model here for more accurate test for sorting @@ -168,10 +172,12 @@ def _get_model(initialize_megatron=True): normalization=normalization, num_layers_in_first_pipeline_stage=num_layers_in_first_pipeline_stage, num_layers_in_last_pipeline_stage=num_layers_in_last_pipeline_stage, + use_cpu_initialization=True, # Ensure deterministic weight init across CUDA versions ).cuda() return model model = _get_model() + sd = model.state_dict() def forward_loop(m): @@ -208,6 +214,48 @@ def forward_loop(m): assert pruning_scores["layer_scores"] assert pruning_scores["activations_per_rank"] + # TODO: Simplify it: this unit test is too long, + # hard to read (the same set of assertions across different test cases with if-else). + + assert len(pruning_scores["activations_per_rank"]) == size + activations = pruning_scores["activations_per_rank"][rank] + + # Test case 1: MHA - pruned ffn/4 (num_attention_heads=8, num_query_groups=8, ffn_div=4) + if size == 1 and pruned_ffn_div == 4: + # Layer scores + _assert_approx(pruning_scores["layer_scores"], {1: 0.028923, 2: 0.046508}) + + # Validate decoder.layers.0.mlp activations + mlp_0_acts = activations["decoder.layers.0.mlp"] + _assert_approx(mlp_0_acts.min().item(), 0.000026) + _assert_approx(mlp_0_acts.max().item(), 0.000729) + _assert_approx(mlp_0_acts.mean().item(), 0.000201) + + # Validate decoder.layers.1.mlp activations + mlp_1_acts = activations["decoder.layers.1.mlp"] + _assert_approx(mlp_1_acts.min().item(), 0.000022) + _assert_approx(mlp_1_acts.max().item(), 0.000762) + _assert_approx(mlp_1_acts.mean().item(), 0.000162) + + # Test case 2: GQA - pruned attention/2 (num_attention_heads=8, num_query_groups=4, attention_div=2) + elif size == 1 and pruned_num_attention_heads_div == 2 and pruned_ffn_div == 1: + # Layer scores + _assert_approx(pruning_scores["layer_scores"], {1: 0.028056, 2: 0.038353}) + + # Validate decoder.layers.0.self_attention activations + attn_0_acts = activations["decoder.layers.0.self_attention"] + assert attn_0_acts.shape == torch.Size([hidden_size]) + _assert_approx(attn_0_acts.min().item(), 0.010091) + _assert_approx(attn_0_acts.max().item(), 0.023826) + _assert_approx(attn_0_acts.mean().item(), 0.014548) + + # Validate decoder.layers.1.self_attention activations + attn_1_acts = activations["decoder.layers.1.self_attention"] + assert attn_1_acts.shape == torch.Size([hidden_size]) + _assert_approx(attn_1_acts.min().item(), 0.009982) + _assert_approx(attn_1_acts.max().item(), 0.035644) + _assert_approx(attn_1_acts.mean().item(), 0.020140) + # Assert weights are pruned correctly for layer in model.decoder.layers: assert layer.mlp.linear_fc1.weight.shape == ( From b0e3c9e21e56ddbde6b6576ec59b62c325a9e313 Mon Sep 17 00:00:00 2001 From: Daniel Korzekwa Date: Wed, 11 Mar 2026 04:31:05 -0700 Subject: [PATCH 2/9] Add logger dependency for activation hooks The activation hooks infrastructure depends on aprint from puzzletron.tools.logger. Adding minimal logger module to satisfy this dependency. Note: Some docstring linting warnings are suppressed as this is copied code. Signed-off-by: Daniel Korzekwa --- modelopt/torch/puzzletron/__init__.py | 17 +++ modelopt/torch/puzzletron/tools/__init__.py | 17 +++ modelopt/torch/puzzletron/tools/logger.py | 161 ++++++++++++++++++++ 3 files changed, 195 insertions(+) create mode 100644 modelopt/torch/puzzletron/__init__.py create mode 100644 modelopt/torch/puzzletron/tools/__init__.py create mode 100644 modelopt/torch/puzzletron/tools/logger.py diff --git a/modelopt/torch/puzzletron/__init__.py b/modelopt/torch/puzzletron/__init__.py new file mode 100644 index 0000000000..416e67f797 --- /dev/null +++ b/modelopt/torch/puzzletron/__init__.py @@ -0,0 +1,17 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +"""Puzzletron module for model optimization.""" diff --git a/modelopt/torch/puzzletron/tools/__init__.py b/modelopt/torch/puzzletron/tools/__init__.py new file mode 100644 index 0000000000..70b1a3202c --- /dev/null +++ b/modelopt/torch/puzzletron/tools/__init__.py @@ -0,0 +1,17 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +"""Puzzletron tools module.""" diff --git a/modelopt/torch/puzzletron/tools/logger.py b/modelopt/torch/puzzletron/tools/logger.py new file mode 100644 index 0000000000..1d68f1514a --- /dev/null +++ b/modelopt/torch/puzzletron/tools/logger.py @@ -0,0 +1,161 @@ +# noqa: D100 +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# mypy: ignore-errors +import inspect +import logging +import os +import sys + +import torch.distributed.launch # noqa: F401 + +logging.getLogger("fsspec.local").setLevel(logging.ERROR) +logging.getLogger("websockets.client").setLevel(logging.WARN) +logging.getLogger("websockets.server").setLevel(logging.WARN) +logging.getLogger("websockets.server:connection").setLevel(logging.WARN) + + +class LogColors: # noqa: D101 + BLUE = "\033[94m" + CYAN = "\033[96m" + GREEN = "\033[92m" + YELLOW = "\033[93m" + RED = "\033[91m" + + BOLD = "\033[1m" + UNDERLINE = "\033[4m" + RESET = "\033[0m" + + +class DistributedLogger(logging.Logger): + verbosity = logging.ERROR + + def __init__ # noqa: D107(self, name, level=logging.DEBUG): + super().__init__(name, level) + self.local_rank = int(os.environ.get("LOCAL_RANK", 0)) + self.global_rank = int(os.environ.get("RANK", 0)) + self.world_size = int(os.environ.get("WORLD_SIZE", 1)) + + def dist_log(self, msg: str, ranks: str = "main"): + """Log parameter msg with the given ranks. + + Args: + msg: The message to log. + ranks: The ranks to log the message to. Choices are: + "all": log with all ranks + "main": log with only rank 0 in node 0 + "last": log with only rank -1 in node 0 + "local_main": log with only rank 0 in all nodes + """ + # print(msg, ranks) + if ranks not in ["all", "main", "local_main", "last"]: + raise NotImplementedError( + f"Could not broadcast msg {msg} - " + f"ranks parameters choices are ['all', 'main', 'local_main']. Got {ranks}" + ) + # All ranks to print + if ranks == "all": + pass + + # Only main rank at node 0 to print + elif ( + (ranks == "main" and self.global_rank != 0) + or (ranks == "last" and self.local_rank != self.world_size - 1) + or (ranks == "local_main" and self.local_rank != 0) + ): + return + + message_source = self.get_caller_location() + + self.info( + f"{LogColors.GREEN}[rank-{self.global_rank}]{LogColors.RESET}[{message_source}]\t{msg}" + ) + + # def dist_warning(self, msg): + # if self.verbosity <= logging.WARNING: + # self.warning(f"[rank-{self.global_rank}] " + msg) + + @staticmethod + def get_caller_location # noqa: D102() -> str: + # Get the caller's stack frame + frame = inspect.currentframe() + + # f_back -> class method, 2 x f_back -> utils method, 3 x f_back -> original source + caller_frame = frame.f_back.f_back.f_back + + # Get the filename and line number from the caller's stack frame + filename = os.path.basename(caller_frame.f_code.co_filename) + lineno = caller_frame.f_lineno + return f"{filename}:{lineno}" + + +# Initialize logger +logging.setLoggerClass(DistributedLogger) +logger = logging.getLogger(__name__) +logger.propagate = False + +formatter = logging.Formatter("[%(asctime)s]%(message)s") +handler = logging.StreamHandler(sys.stdout) +handler.setFormatter(formatter) +handler.setLevel(logging.DEBUG) +logger.addHandler(handler) + +# Manually edit torch logger +torch_logger = logging.getLogger("torch") +torch_logger.handlers = logger.handlers +torch_logger.propagate = False + +# Manually edit deepspeed logger + +# Show some love to Mac & Windows users who can't easily install deepspeed ;) +# This is allowing running tests on Mac & Windows and train in non-DDP +try: + from deepspeed.utils import logger as deepspeed_logger + + deepspeed_logger.handlers = logger.handlers + deepspeed_logger.propagate = False +except ImportError: + # If deepspeed is not installed - no op + pass + +# Define a custom function to redirect warnings to logger +# def custom_warning_handler(message, category, filename, lineno, file=None, line=None): +# logger.dist_warning(f'{category.__name__}: {message} (in {filename}, line {lineno})') + + +# Use the custom warning handler +# warnings.showwarning = custom_warning_handler + +logger: DistributedLogger + + +def aprint(msg: str | None): + """All ranks from all nodes print.""" + return logger.dist_log(msg=msg, ranks="all") + + +def lmprint(msg: str | None): + """All local main ranks prints (rank 0 in each node)""" + return logger.dist_log(msg=msg, ranks="local_main") + + +def mprint(msg: str | None): + """Master prints only (rank 0 in node 0)""" + return logger.dist_log(msg=msg, ranks="main") + + +def lprint(msg: str | None): + """Last rank prints only (rank -1 in node 0)""" + return logger.dist_log(msg=msg, ranks="last") From 08a4b66487a7508a049a40f4d9e5ebfd43461005 Mon Sep 17 00:00:00 2001 From: Daniel Korzekwa Date: Wed, 11 Mar 2026 06:06:37 -0700 Subject: [PATCH 3/9] Add a logger and json_dump needed by base_hooks (TODO: both should be moved later outside of puzzletron module) Signed-off-by: Daniel Korzekwa --- .../nas/plugins/megatron_hooks/base_hooks.py | 2 + modelopt/torch/puzzletron/tools/logger.py | 21 +++-- .../torch/puzzletron/tools/robust_json.py | 78 +++++++++++++++++++ 3 files changed, 94 insertions(+), 7 deletions(-) create mode 100644 modelopt/torch/puzzletron/tools/robust_json.py diff --git a/modelopt/torch/nas/plugins/megatron_hooks/base_hooks.py b/modelopt/torch/nas/plugins/megatron_hooks/base_hooks.py index 56436acfdd..8ad7f9f98c 100644 --- a/modelopt/torch/nas/plugins/megatron_hooks/base_hooks.py +++ b/modelopt/torch/nas/plugins/megatron_hooks/base_hooks.py @@ -26,6 +26,8 @@ from torch import nn import modelopt.torch.utils.distributed as dist + +# TODO: move both outside of puzzletron module from modelopt.torch.puzzletron.tools.logger import aprint from modelopt.torch.puzzletron.tools.robust_json import json_dump diff --git a/modelopt/torch/puzzletron/tools/logger.py b/modelopt/torch/puzzletron/tools/logger.py index 1d68f1514a..0185b9edfe 100644 --- a/modelopt/torch/puzzletron/tools/logger.py +++ b/modelopt/torch/puzzletron/tools/logger.py @@ -1,4 +1,3 @@ -# noqa: D100 # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 # @@ -14,6 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. # mypy: ignore-errors +"""Distributed logging utilities for multi-rank environments.""" + import inspect import logging import os @@ -27,7 +28,9 @@ logging.getLogger("websockets.server:connection").setLevel(logging.WARN) -class LogColors: # noqa: D101 +class LogColors: + """ANSI color codes for terminal output.""" + BLUE = "\033[94m" CYAN = "\033[96m" GREEN = "\033[92m" @@ -40,9 +43,12 @@ class LogColors: # noqa: D101 class DistributedLogger(logging.Logger): + """Logger for distributed multi-rank environments.""" + verbosity = logging.ERROR - def __init__ # noqa: D107(self, name, level=logging.DEBUG): + def __init__(self, name, level=logging.DEBUG): + """Initialize the distributed logger.""" super().__init__(name, level) self.local_rank = int(os.environ.get("LOCAL_RANK", 0)) self.global_rank = int(os.environ.get("RANK", 0)) @@ -88,7 +94,8 @@ def dist_log(self, msg: str, ranks: str = "main"): # self.warning(f"[rank-{self.global_rank}] " + msg) @staticmethod - def get_caller_location # noqa: D102() -> str: + def get_caller_location() -> str: + """Get the caller location from the stack frame.""" # Get the caller's stack frame frame = inspect.currentframe() @@ -147,15 +154,15 @@ def aprint(msg: str | None): def lmprint(msg: str | None): - """All local main ranks prints (rank 0 in each node)""" + """All local main ranks print (rank 0 in each node).""" return logger.dist_log(msg=msg, ranks="local_main") def mprint(msg: str | None): - """Master prints only (rank 0 in node 0)""" + """Master prints only (rank 0 in node 0).""" return logger.dist_log(msg=msg, ranks="main") def lprint(msg: str | None): - """Last rank prints only (rank -1 in node 0)""" + """Last rank prints only (rank -1 in node 0).""" return logger.dist_log(msg=msg, ranks="last") diff --git a/modelopt/torch/puzzletron/tools/robust_json.py b/modelopt/torch/puzzletron/tools/robust_json.py new file mode 100644 index 0000000000..030150a9ce --- /dev/null +++ b/modelopt/torch/puzzletron/tools/robust_json.py @@ -0,0 +1,78 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# mypy: ignore-errors + +"""Provides a robust JSON encoder that can handle various types of objects. + +Including dataclasses, paths, enums, namespaces, and functions. +""" + +import argparse +import dataclasses +import datetime +import inspect +import json +from enum import Enum +from pathlib import Path +from typing import Any + +from omegaconf import DictConfig, ListConfig, OmegaConf + + +class RobustJSONEncoder(json.JSONEncoder): + """JSON encoder that handles dataclasses, paths, enums, and other special types.""" + + def default(self, o): + """Convert special objects to JSON-serializable types.""" + if dataclasses.is_dataclass(o): + return dataclasses.asdict(o) + if isinstance(o, Path): + return str(o) + if isinstance(o, Enum): + return o.name + if isinstance(o, argparse.Namespace): + return vars(o) + if type(o).__name__ == "dtype": + return str(o) + if isinstance(o, (DictConfig, ListConfig)): + return OmegaConf.to_container(o, resolve=True) + if inspect.isfunction(o) or inspect.ismethod(o): + if o.__module__ == "__main__": + # User-defined function in main — fallback to just the name + return o.__name__ + return f"{o.__module__}.{o.__qualname__}" + if isinstance(o, datetime.timedelta): + return str(o) + return super().default(o) + + +def json_dumps(obj: Any) -> str: + """Serialize object to JSON string using RobustJSONEncoder.""" + return json.dumps(obj, cls=RobustJSONEncoder, indent=2) + + +def json_dump(obj: Any, path: Path | str) -> None: + """Serialize object to JSON file using RobustJSONEncoder.""" + path = Path(path) + path.parent.mkdir(exist_ok=True, parents=True) + json_text = json_dumps(obj) + path.write_text(json_text) + + +def json_load(path: Path | str) -> dict: + """Load JSON from file and return as dictionary.""" + path = Path(path) + text = path.read_text() + return json.loads(text) From ee51e7ba58296e5489d8e72bef23febe22c37d03 Mon Sep 17 00:00:00 2001 From: Daniel Korzekwa Date: Wed, 11 Mar 2026 07:10:14 -0700 Subject: [PATCH 4/9] Fix a broken test due to a missing random seed. Signed-off-by: Daniel Korzekwa --- .../torch/prune/plugins/test_mcore_gpt_minitron_pruning.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/gpu_megatron/torch/prune/plugins/test_mcore_gpt_minitron_pruning.py b/tests/gpu_megatron/torch/prune/plugins/test_mcore_gpt_minitron_pruning.py index 2c4ce71789..5d9be161aa 100644 --- a/tests/gpu_megatron/torch/prune/plugins/test_mcore_gpt_minitron_pruning.py +++ b/tests/gpu_megatron/torch/prune/plugins/test_mcore_gpt_minitron_pruning.py @@ -131,6 +131,8 @@ def _test_mcore_gpt_pruning( rank, size, ): + set_seed(SEED) + channel_divisor = 4 hidden_size = channel_divisor * 4 From fe7aa4b4f02974ba6009962da559c20eb7d305cf Mon Sep 17 00:00:00 2001 From: Daniel Korzekwa Date: Wed, 11 Mar 2026 07:34:06 -0700 Subject: [PATCH 5/9] Logging and robust_json refactoring. Signed-off-by: Daniel Korzekwa --- .../nas/plugins/megatron_hooks/base_hooks.py | 6 +- modelopt/torch/utils/logging.py | 6 ++ modelopt/torch/utils/robust_json.py | 78 +++++++++++++++++++ 3 files changed, 86 insertions(+), 4 deletions(-) create mode 100644 modelopt/torch/utils/robust_json.py diff --git a/modelopt/torch/nas/plugins/megatron_hooks/base_hooks.py b/modelopt/torch/nas/plugins/megatron_hooks/base_hooks.py index 8ad7f9f98c..3d6ff43fb5 100644 --- a/modelopt/torch/nas/plugins/megatron_hooks/base_hooks.py +++ b/modelopt/torch/nas/plugins/megatron_hooks/base_hooks.py @@ -26,10 +26,8 @@ from torch import nn import modelopt.torch.utils.distributed as dist - -# TODO: move both outside of puzzletron module -from modelopt.torch.puzzletron.tools.logger import aprint -from modelopt.torch.puzzletron.tools.robust_json import json_dump +from modelopt.torch.utils.logging import aprint +from modelopt.torch.utils.robust_json import json_dump __all__ = [ "ForwardHook", diff --git a/modelopt/torch/utils/logging.py b/modelopt/torch/utils/logging.py index ada1b53612..3fde9eb8df 100644 --- a/modelopt/torch/utils/logging.py +++ b/modelopt/torch/utils/logging.py @@ -33,6 +33,7 @@ __all__ = [ "DeprecatedError", + "aprint", "atomic_print", "capture_io", "no_stdout", @@ -200,3 +201,8 @@ def custom_showwarning(message, category, filename, lineno, file=None, line=None class DeprecatedError(NotImplementedError): """Error for deprecated functions.""" + + +def aprint(*args, **kwargs): + """All ranks from all nodes print.""" + print(*args, **kwargs, flush=True) diff --git a/modelopt/torch/utils/robust_json.py b/modelopt/torch/utils/robust_json.py new file mode 100644 index 0000000000..030150a9ce --- /dev/null +++ b/modelopt/torch/utils/robust_json.py @@ -0,0 +1,78 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# mypy: ignore-errors + +"""Provides a robust JSON encoder that can handle various types of objects. + +Including dataclasses, paths, enums, namespaces, and functions. +""" + +import argparse +import dataclasses +import datetime +import inspect +import json +from enum import Enum +from pathlib import Path +from typing import Any + +from omegaconf import DictConfig, ListConfig, OmegaConf + + +class RobustJSONEncoder(json.JSONEncoder): + """JSON encoder that handles dataclasses, paths, enums, and other special types.""" + + def default(self, o): + """Convert special objects to JSON-serializable types.""" + if dataclasses.is_dataclass(o): + return dataclasses.asdict(o) + if isinstance(o, Path): + return str(o) + if isinstance(o, Enum): + return o.name + if isinstance(o, argparse.Namespace): + return vars(o) + if type(o).__name__ == "dtype": + return str(o) + if isinstance(o, (DictConfig, ListConfig)): + return OmegaConf.to_container(o, resolve=True) + if inspect.isfunction(o) or inspect.ismethod(o): + if o.__module__ == "__main__": + # User-defined function in main — fallback to just the name + return o.__name__ + return f"{o.__module__}.{o.__qualname__}" + if isinstance(o, datetime.timedelta): + return str(o) + return super().default(o) + + +def json_dumps(obj: Any) -> str: + """Serialize object to JSON string using RobustJSONEncoder.""" + return json.dumps(obj, cls=RobustJSONEncoder, indent=2) + + +def json_dump(obj: Any, path: Path | str) -> None: + """Serialize object to JSON file using RobustJSONEncoder.""" + path = Path(path) + path.parent.mkdir(exist_ok=True, parents=True) + json_text = json_dumps(obj) + path.write_text(json_text) + + +def json_load(path: Path | str) -> dict: + """Load JSON from file and return as dictionary.""" + path = Path(path) + text = path.read_text() + return json.loads(text) From 6188e8d1e9ae251ffec82dd3fcb356fd11dac565 Mon Sep 17 00:00:00 2001 From: Daniel Korzekwa Date: Wed, 11 Mar 2026 07:42:20 -0700 Subject: [PATCH 6/9] Delete not needed logger and robust_json (moved to modelopt) Signed-off-by: Daniel Korzekwa --- modelopt/torch/puzzletron/__init__.py | 17 -- modelopt/torch/puzzletron/tools/__init__.py | 17 -- modelopt/torch/puzzletron/tools/logger.py | 168 ------------------ .../torch/puzzletron/tools/robust_json.py | 78 -------- 4 files changed, 280 deletions(-) delete mode 100644 modelopt/torch/puzzletron/__init__.py delete mode 100644 modelopt/torch/puzzletron/tools/__init__.py delete mode 100644 modelopt/torch/puzzletron/tools/logger.py delete mode 100644 modelopt/torch/puzzletron/tools/robust_json.py diff --git a/modelopt/torch/puzzletron/__init__.py b/modelopt/torch/puzzletron/__init__.py deleted file mode 100644 index 416e67f797..0000000000 --- a/modelopt/torch/puzzletron/__init__.py +++ /dev/null @@ -1,17 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -"""Puzzletron module for model optimization.""" diff --git a/modelopt/torch/puzzletron/tools/__init__.py b/modelopt/torch/puzzletron/tools/__init__.py deleted file mode 100644 index 70b1a3202c..0000000000 --- a/modelopt/torch/puzzletron/tools/__init__.py +++ /dev/null @@ -1,17 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -"""Puzzletron tools module.""" diff --git a/modelopt/torch/puzzletron/tools/logger.py b/modelopt/torch/puzzletron/tools/logger.py deleted file mode 100644 index 0185b9edfe..0000000000 --- a/modelopt/torch/puzzletron/tools/logger.py +++ /dev/null @@ -1,168 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# mypy: ignore-errors -"""Distributed logging utilities for multi-rank environments.""" - -import inspect -import logging -import os -import sys - -import torch.distributed.launch # noqa: F401 - -logging.getLogger("fsspec.local").setLevel(logging.ERROR) -logging.getLogger("websockets.client").setLevel(logging.WARN) -logging.getLogger("websockets.server").setLevel(logging.WARN) -logging.getLogger("websockets.server:connection").setLevel(logging.WARN) - - -class LogColors: - """ANSI color codes for terminal output.""" - - BLUE = "\033[94m" - CYAN = "\033[96m" - GREEN = "\033[92m" - YELLOW = "\033[93m" - RED = "\033[91m" - - BOLD = "\033[1m" - UNDERLINE = "\033[4m" - RESET = "\033[0m" - - -class DistributedLogger(logging.Logger): - """Logger for distributed multi-rank environments.""" - - verbosity = logging.ERROR - - def __init__(self, name, level=logging.DEBUG): - """Initialize the distributed logger.""" - super().__init__(name, level) - self.local_rank = int(os.environ.get("LOCAL_RANK", 0)) - self.global_rank = int(os.environ.get("RANK", 0)) - self.world_size = int(os.environ.get("WORLD_SIZE", 1)) - - def dist_log(self, msg: str, ranks: str = "main"): - """Log parameter msg with the given ranks. - - Args: - msg: The message to log. - ranks: The ranks to log the message to. Choices are: - "all": log with all ranks - "main": log with only rank 0 in node 0 - "last": log with only rank -1 in node 0 - "local_main": log with only rank 0 in all nodes - """ - # print(msg, ranks) - if ranks not in ["all", "main", "local_main", "last"]: - raise NotImplementedError( - f"Could not broadcast msg {msg} - " - f"ranks parameters choices are ['all', 'main', 'local_main']. Got {ranks}" - ) - # All ranks to print - if ranks == "all": - pass - - # Only main rank at node 0 to print - elif ( - (ranks == "main" and self.global_rank != 0) - or (ranks == "last" and self.local_rank != self.world_size - 1) - or (ranks == "local_main" and self.local_rank != 0) - ): - return - - message_source = self.get_caller_location() - - self.info( - f"{LogColors.GREEN}[rank-{self.global_rank}]{LogColors.RESET}[{message_source}]\t{msg}" - ) - - # def dist_warning(self, msg): - # if self.verbosity <= logging.WARNING: - # self.warning(f"[rank-{self.global_rank}] " + msg) - - @staticmethod - def get_caller_location() -> str: - """Get the caller location from the stack frame.""" - # Get the caller's stack frame - frame = inspect.currentframe() - - # f_back -> class method, 2 x f_back -> utils method, 3 x f_back -> original source - caller_frame = frame.f_back.f_back.f_back - - # Get the filename and line number from the caller's stack frame - filename = os.path.basename(caller_frame.f_code.co_filename) - lineno = caller_frame.f_lineno - return f"{filename}:{lineno}" - - -# Initialize logger -logging.setLoggerClass(DistributedLogger) -logger = logging.getLogger(__name__) -logger.propagate = False - -formatter = logging.Formatter("[%(asctime)s]%(message)s") -handler = logging.StreamHandler(sys.stdout) -handler.setFormatter(formatter) -handler.setLevel(logging.DEBUG) -logger.addHandler(handler) - -# Manually edit torch logger -torch_logger = logging.getLogger("torch") -torch_logger.handlers = logger.handlers -torch_logger.propagate = False - -# Manually edit deepspeed logger - -# Show some love to Mac & Windows users who can't easily install deepspeed ;) -# This is allowing running tests on Mac & Windows and train in non-DDP -try: - from deepspeed.utils import logger as deepspeed_logger - - deepspeed_logger.handlers = logger.handlers - deepspeed_logger.propagate = False -except ImportError: - # If deepspeed is not installed - no op - pass - -# Define a custom function to redirect warnings to logger -# def custom_warning_handler(message, category, filename, lineno, file=None, line=None): -# logger.dist_warning(f'{category.__name__}: {message} (in {filename}, line {lineno})') - - -# Use the custom warning handler -# warnings.showwarning = custom_warning_handler - -logger: DistributedLogger - - -def aprint(msg: str | None): - """All ranks from all nodes print.""" - return logger.dist_log(msg=msg, ranks="all") - - -def lmprint(msg: str | None): - """All local main ranks print (rank 0 in each node).""" - return logger.dist_log(msg=msg, ranks="local_main") - - -def mprint(msg: str | None): - """Master prints only (rank 0 in node 0).""" - return logger.dist_log(msg=msg, ranks="main") - - -def lprint(msg: str | None): - """Last rank prints only (rank -1 in node 0).""" - return logger.dist_log(msg=msg, ranks="last") diff --git a/modelopt/torch/puzzletron/tools/robust_json.py b/modelopt/torch/puzzletron/tools/robust_json.py deleted file mode 100644 index 030150a9ce..0000000000 --- a/modelopt/torch/puzzletron/tools/robust_json.py +++ /dev/null @@ -1,78 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# mypy: ignore-errors - -"""Provides a robust JSON encoder that can handle various types of objects. - -Including dataclasses, paths, enums, namespaces, and functions. -""" - -import argparse -import dataclasses -import datetime -import inspect -import json -from enum import Enum -from pathlib import Path -from typing import Any - -from omegaconf import DictConfig, ListConfig, OmegaConf - - -class RobustJSONEncoder(json.JSONEncoder): - """JSON encoder that handles dataclasses, paths, enums, and other special types.""" - - def default(self, o): - """Convert special objects to JSON-serializable types.""" - if dataclasses.is_dataclass(o): - return dataclasses.asdict(o) - if isinstance(o, Path): - return str(o) - if isinstance(o, Enum): - return o.name - if isinstance(o, argparse.Namespace): - return vars(o) - if type(o).__name__ == "dtype": - return str(o) - if isinstance(o, (DictConfig, ListConfig)): - return OmegaConf.to_container(o, resolve=True) - if inspect.isfunction(o) or inspect.ismethod(o): - if o.__module__ == "__main__": - # User-defined function in main — fallback to just the name - return o.__name__ - return f"{o.__module__}.{o.__qualname__}" - if isinstance(o, datetime.timedelta): - return str(o) - return super().default(o) - - -def json_dumps(obj: Any) -> str: - """Serialize object to JSON string using RobustJSONEncoder.""" - return json.dumps(obj, cls=RobustJSONEncoder, indent=2) - - -def json_dump(obj: Any, path: Path | str) -> None: - """Serialize object to JSON file using RobustJSONEncoder.""" - path = Path(path) - path.parent.mkdir(exist_ok=True, parents=True) - json_text = json_dumps(obj) - path.write_text(json_text) - - -def json_load(path: Path | str) -> dict: - """Load JSON from file and return as dictionary.""" - path = Path(path) - text = path.read_text() - return json.loads(text) From a4b89584ec6014f912046268472759c231620ff2 Mon Sep 17 00:00:00 2001 From: Daniel Korzekwa Date: Fri, 13 Mar 2026 03:00:51 -0700 Subject: [PATCH 7/9] Change the location for activation hooks from nas..... to prune.importance_hooks Signed-off-by: Daniel Korzekwa --- .../importance_hooks}/__init__.py | 2 +- .../importance_hooks}/base_hooks.py | 0 .../importance_hooks}/base_hooks_analysis.py | 0 .../importance_hooks}/compare_module_outputs.py | 2 +- .../prune/importance_hooks/plugins/__init__.py | 15 +++++++++++++++ .../importance_hooks/plugins}/megatron_hooks.py | 2 +- .../importance_hooks}/test_base_hooks.py | 2 +- .../importance_hooks}/test_base_hooks_analysis.py | 2 +- 8 files changed, 20 insertions(+), 5 deletions(-) rename modelopt/torch/{nas/plugins/megatron_hooks => prune/importance_hooks}/__init__.py (95%) rename modelopt/torch/{nas/plugins/megatron_hooks => prune/importance_hooks}/base_hooks.py (100%) rename modelopt/torch/{nas/plugins/megatron_hooks => prune/importance_hooks}/base_hooks_analysis.py (100%) rename modelopt/torch/{nas/plugins/megatron_hooks => prune/importance_hooks}/compare_module_outputs.py (99%) create mode 100644 modelopt/torch/prune/importance_hooks/plugins/__init__.py rename modelopt/torch/{nas/plugins/megatron_hooks => prune/importance_hooks/plugins}/megatron_hooks.py (97%) rename tests/gpu/torch/{nas/plugins/megatron_hooks => prune/importance_hooks}/test_base_hooks.py (97%) rename tests/gpu/torch/{nas/plugins/megatron_hooks => prune/importance_hooks}/test_base_hooks_analysis.py (99%) diff --git a/modelopt/torch/nas/plugins/megatron_hooks/__init__.py b/modelopt/torch/prune/importance_hooks/__init__.py similarity index 95% rename from modelopt/torch/nas/plugins/megatron_hooks/__init__.py rename to modelopt/torch/prune/importance_hooks/__init__.py index 996d531392..3bf30c2a46 100644 --- a/modelopt/torch/nas/plugins/megatron_hooks/__init__.py +++ b/modelopt/torch/prune/importance_hooks/__init__.py @@ -20,4 +20,4 @@ from .base_hooks_analysis import * with import_plugin("megatron_hooks"): - from .megatron_hooks import * + from .plugins.megatron_hooks import * diff --git a/modelopt/torch/nas/plugins/megatron_hooks/base_hooks.py b/modelopt/torch/prune/importance_hooks/base_hooks.py similarity index 100% rename from modelopt/torch/nas/plugins/megatron_hooks/base_hooks.py rename to modelopt/torch/prune/importance_hooks/base_hooks.py diff --git a/modelopt/torch/nas/plugins/megatron_hooks/base_hooks_analysis.py b/modelopt/torch/prune/importance_hooks/base_hooks_analysis.py similarity index 100% rename from modelopt/torch/nas/plugins/megatron_hooks/base_hooks_analysis.py rename to modelopt/torch/prune/importance_hooks/base_hooks_analysis.py diff --git a/modelopt/torch/nas/plugins/megatron_hooks/compare_module_outputs.py b/modelopt/torch/prune/importance_hooks/compare_module_outputs.py similarity index 99% rename from modelopt/torch/nas/plugins/megatron_hooks/compare_module_outputs.py rename to modelopt/torch/prune/importance_hooks/compare_module_outputs.py index 316aff76ff..f728b97e05 100644 --- a/modelopt/torch/nas/plugins/megatron_hooks/compare_module_outputs.py +++ b/modelopt/torch/prune/importance_hooks/compare_module_outputs.py @@ -24,7 +24,7 @@ Step 1: Capture outputs from multiple layers: - from modelopt.torch.nas.plugins.megatron_hooks.compare_module_outputs import ( + from modelopt.torch.prune.importance_hooks.compare_module_outputs import ( OutputSaveHook, save_multi_layer_outputs, ) diff --git a/modelopt/torch/prune/importance_hooks/plugins/__init__.py b/modelopt/torch/prune/importance_hooks/plugins/__init__.py new file mode 100644 index 0000000000..97fd3fd07d --- /dev/null +++ b/modelopt/torch/prune/importance_hooks/plugins/__init__.py @@ -0,0 +1,15 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Plugin-specific hooks for importance estimation.""" diff --git a/modelopt/torch/nas/plugins/megatron_hooks/megatron_hooks.py b/modelopt/torch/prune/importance_hooks/plugins/megatron_hooks.py similarity index 97% rename from modelopt/torch/nas/plugins/megatron_hooks/megatron_hooks.py rename to modelopt/torch/prune/importance_hooks/plugins/megatron_hooks.py index d792ff8941..b97c0b9a90 100644 --- a/modelopt/torch/nas/plugins/megatron_hooks/megatron_hooks.py +++ b/modelopt/torch/prune/importance_hooks/plugins/megatron_hooks.py @@ -17,7 +17,7 @@ import torch from megatron.core.tensor_parallel import gather_from_tensor_model_parallel_region -from .base_hooks import L2NormHook +from ..base_hooks import L2NormHook __all__ = ["MegatronL2NormHook"] diff --git a/tests/gpu/torch/nas/plugins/megatron_hooks/test_base_hooks.py b/tests/gpu/torch/prune/importance_hooks/test_base_hooks.py similarity index 97% rename from tests/gpu/torch/nas/plugins/megatron_hooks/test_base_hooks.py rename to tests/gpu/torch/prune/importance_hooks/test_base_hooks.py index aa73a3be19..911ecf3917 100644 --- a/tests/gpu/torch/nas/plugins/megatron_hooks/test_base_hooks.py +++ b/tests/gpu/torch/prune/importance_hooks/test_base_hooks.py @@ -18,7 +18,7 @@ import torch import torch.nn as nn -from modelopt.torch.nas.plugins.megatron_hooks import IterativeChannelContributionHook, L2NormHook +from modelopt.torch.prune.importance_hooks import IterativeChannelContributionHook, L2NormHook def _test_iterative_channel_contribution_hook_with_shape(dim1: int, dim2: int): diff --git a/tests/gpu/torch/nas/plugins/megatron_hooks/test_base_hooks_analysis.py b/tests/gpu/torch/prune/importance_hooks/test_base_hooks_analysis.py similarity index 99% rename from tests/gpu/torch/nas/plugins/megatron_hooks/test_base_hooks_analysis.py rename to tests/gpu/torch/prune/importance_hooks/test_base_hooks_analysis.py index 954c6e11c7..ba5dcb3df4 100644 --- a/tests/gpu/torch/nas/plugins/megatron_hooks/test_base_hooks_analysis.py +++ b/tests/gpu/torch/prune/importance_hooks/test_base_hooks_analysis.py @@ -19,7 +19,7 @@ import torch import torch.nn as nn -from modelopt.torch.nas.plugins.megatron_hooks import ( +from modelopt.torch.prune.importance_hooks import ( IndependentChannelContributionHook, IterativeChannelContributionHook, L2NormHook, From 91fa36d3c4d72b96368c959daa05825183befaa6 Mon Sep 17 00:00:00 2001 From: Daniel Korzekwa Date: Fri, 13 Mar 2026 03:13:04 -0700 Subject: [PATCH 8/9] Add a security comment why using torch.load(weight_only=False) Signed-off-by: Daniel Korzekwa --- modelopt/torch/prune/importance_hooks/base_hooks.py | 6 +++++- .../prune/importance_hooks/compare_module_outputs.py | 10 ++++++++-- 2 files changed, 13 insertions(+), 3 deletions(-) diff --git a/modelopt/torch/prune/importance_hooks/base_hooks.py b/modelopt/torch/prune/importance_hooks/base_hooks.py index 3d6ff43fb5..9ccc5c75a9 100644 --- a/modelopt/torch/prune/importance_hooks/base_hooks.py +++ b/modelopt/torch/prune/importance_hooks/base_hooks.py @@ -768,7 +768,11 @@ def _save_channel_importance_results( all_scores = [] for activation_file in activation_files: aprint(f"Loading activations from {activation_file}") - activation_data = torch.load(activation_file, map_location="cpu") + # SECURITY: weights_only=False is required because files contain dictionaries with tensors. + # These files are generated by dump_activations_logs() in this module and contain + # hook state dictionaries. The activations_log_dir should only contain trusted files + # generated by the same codebase, not from untrusted sources. + activation_data = torch.load(activation_file, map_location="cpu", weights_only=False) # Extract scores from the activation data for module_name, hook_data in activation_data.items(): diff --git a/modelopt/torch/prune/importance_hooks/compare_module_outputs.py b/modelopt/torch/prune/importance_hooks/compare_module_outputs.py index f728b97e05..73c3692439 100644 --- a/modelopt/torch/prune/importance_hooks/compare_module_outputs.py +++ b/modelopt/torch/prune/importance_hooks/compare_module_outputs.py @@ -177,11 +177,17 @@ def main(): # Load reference data print(f"\nLoading reference: {args.reference}") - ref_data = torch.load(args.reference, map_location="cpu") + # SECURITY: weights_only=False is required because files contain dictionaries with tensors. + # These files are expected to be generated by save_multi_layer_outputs() in this module, + # not from untrusted sources. Users should only load files they generated themselves. + ref_data = torch.load(args.reference, map_location="cpu", weights_only=False) # Load comparison data print(f"Loading compare: {args.compare}") - comp_data = torch.load(args.compare, map_location="cpu") + # SECURITY: weights_only=False is required because files contain dictionaries with tensors. + # These files are expected to be generated by save_multi_layer_outputs() in this module, + # not from untrusted sources. Users should only load files they generated themselves. + comp_data = torch.load(args.compare, map_location="cpu", weights_only=False) # Compare multi-layer outputs compare_multi_layer(ref_data, comp_data, args.output_json) From 6345290dd223a8dd4636d0b82f67ee2ec3c63097 Mon Sep 17 00:00:00 2001 From: Daniel Korzekwa Date: Fri, 13 Mar 2026 03:25:14 -0700 Subject: [PATCH 9/9] No need for computing gradients during a loop for activation hooks analysis Signed-off-by: Daniel Korzekwa --- .../importance_hooks/base_hooks_analysis.py | 51 +++++++++++-------- 1 file changed, 29 insertions(+), 22 deletions(-) diff --git a/modelopt/torch/prune/importance_hooks/base_hooks_analysis.py b/modelopt/torch/prune/importance_hooks/base_hooks_analysis.py index dc338a7cfa..016f249bbf 100644 --- a/modelopt/torch/prune/importance_hooks/base_hooks_analysis.py +++ b/modelopt/torch/prune/importance_hooks/base_hooks_analysis.py @@ -63,6 +63,10 @@ def evaluate_importance_scores( - "Channels" refers to INPUT features, not output features """ + # Validate non-empty input + if not activations_batches: + raise ValueError("activations_batches must not be None or empty") + num_channels = importance_scores.shape[0] num_to_prune = int(num_channels * prune_ratio) @@ -73,28 +77,31 @@ def evaluate_importance_scores( rmse_values = [] cosine_values = [] - for activations in activations_batches: - # Get original output - original_output = linear_layer(activations) - - # Prune by zeroing out identified channels - pruned_activations = activations.clone() - pruned_activations[..., channels_to_prune] = 0 - - # Get pruned output - pruned_output = linear_layer(pruned_activations) - - # Compute metrics for this batch - rmse = torch.sqrt(F.mse_loss(pruned_output, original_output)).item() - rmse_values.append(rmse) - - # Cosine similarity (flatten to vectors) - original_flat = original_output.reshape(-1) - pruned_flat = pruned_output.reshape(-1) - cosine = F.cosine_similarity( - original_flat.unsqueeze(0), pruned_flat.unsqueeze(0), dim=1 - ).item() - cosine_values.append(cosine) + # Wrap evaluation loop in no_grad() to avoid building autograd graphs + # This is an analysis-only function and doesn't need gradients + with torch.no_grad(): + for activations in activations_batches: + # Get original output + original_output = linear_layer(activations) + + # Prune by zeroing out identified channels + pruned_activations = activations.clone() + pruned_activations[..., channels_to_prune] = 0 + + # Get pruned output + pruned_output = linear_layer(pruned_activations) + + # Compute metrics for this batch + rmse = torch.sqrt(F.mse_loss(pruned_output, original_output)).item() + rmse_values.append(rmse) + + # Cosine similarity (flatten to vectors) + original_flat = original_output.reshape(-1) + pruned_flat = pruned_output.reshape(-1) + cosine = F.cosine_similarity( + original_flat.unsqueeze(0), pruned_flat.unsqueeze(0), dim=1 + ).item() + cosine_values.append(cosine) # Return averaged metrics return {