From e82164fbb074266e18d60158cc06d1b55fb861ca Mon Sep 17 00:00:00 2001 From: Daniel Korzekwa Date: Wed, 4 Mar 2026 11:27:43 -0800 Subject: [PATCH 01/30] Add anymodel directories to feature/puzzletron - Add converter, model_descriptor, puzzformer, and llama model support - Selective merge of anymodel functionality Signed-off-by: Daniel Korzekwa --- .../puzzletron/anymodel/converter/__init__.py | 19 ++ .../anymodel/converter/convert_any_model.py | 68 +++++ .../anymodel/converter/converter.py | 235 ++++++++++++++++++ .../anymodel/converter/converter_factory.py | 75 ++++++ .../anymodel/model_descriptor/__init__.py | 18 ++ .../model_descriptor/model_descriptor.py | 210 ++++++++++++++++ .../model_descriptor_factory.py | 111 +++++++++ .../anymodel/models/llama/__init__.py | 19 ++ .../anymodel/models/llama/llama_converter.py | 50 ++++ .../models/llama/llama_model_descriptor.py | 131 ++++++++++ .../anymodel/puzzformer/__init__.py | 24 ++ .../puzzletron/anymodel/puzzformer/no_op.py | 79 ++++++ .../puzzletron/anymodel/puzzformer/utils.py | 122 +++++++++ 13 files changed, 1161 insertions(+) create mode 100644 modelopt/torch/puzzletron/anymodel/converter/__init__.py create mode 100644 modelopt/torch/puzzletron/anymodel/converter/convert_any_model.py create mode 100644 modelopt/torch/puzzletron/anymodel/converter/converter.py create mode 100644 modelopt/torch/puzzletron/anymodel/converter/converter_factory.py create mode 100644 modelopt/torch/puzzletron/anymodel/model_descriptor/__init__.py create mode 100644 modelopt/torch/puzzletron/anymodel/model_descriptor/model_descriptor.py create mode 100644 modelopt/torch/puzzletron/anymodel/model_descriptor/model_descriptor_factory.py create mode 100644 modelopt/torch/puzzletron/anymodel/models/llama/__init__.py create mode 100644 modelopt/torch/puzzletron/anymodel/models/llama/llama_converter.py create mode 100644 modelopt/torch/puzzletron/anymodel/models/llama/llama_model_descriptor.py create mode 100644 modelopt/torch/puzzletron/anymodel/puzzformer/__init__.py create mode 100644 modelopt/torch/puzzletron/anymodel/puzzformer/no_op.py create mode 100644 modelopt/torch/puzzletron/anymodel/puzzformer/utils.py diff --git a/modelopt/torch/puzzletron/anymodel/converter/__init__.py b/modelopt/torch/puzzletron/anymodel/converter/__init__.py new file mode 100644 index 000000000..02903b817 --- /dev/null +++ b/modelopt/torch/puzzletron/anymodel/converter/__init__.py @@ -0,0 +1,19 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Converters for transforming HuggingFace models to AnyModel format.""" + +from .convert_any_model import * +from .converter import * +from .converter_factory import * diff --git a/modelopt/torch/puzzletron/anymodel/converter/convert_any_model.py b/modelopt/torch/puzzletron/anymodel/converter/convert_any_model.py new file mode 100644 index 000000000..889685c00 --- /dev/null +++ b/modelopt/torch/puzzletron/anymodel/converter/convert_any_model.py @@ -0,0 +1,68 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# mypy: ignore-errors + +"""Convert a HuggingFace model to AnyModel format.""" + +from pathlib import Path + +from modelopt.torch.puzzletron.anymodel.converter.converter import Converter +from modelopt.torch.puzzletron.anymodel.converter.converter_factory import ConverterFactory +from modelopt.torch.puzzletron.anymodel.model_descriptor import ModelDescriptorFactory + +__all__ = ["convert_model"] + + +def convert_model( + input_dir: str, + output_dir: str, + converter: Converter | str, +): + """Convert a HuggingFace model to AnyModel format. + + This function converts a HuggingFace checkpoint to the AnyModel format used + for compression. The conversion process: + + 1. Copies non-weight files (config, tokenizer, etc.) + 2. Creates block_configs for each layer + 3. Reorganizes weights into subblock checkpoints + + Args: + input_dir: Path to the input HuggingFace checkpoint directory. + output_dir: Path to the output AnyModel checkpoint directory. + converter: Either a converter name (e.g., "llama") or a Converter class. + + Example: + >>> convert_model( + ... input_dir="/path/to/Llama-3.1-8B-Instruct", + ... output_dir="/path/to/output/ckpts/teacher", + ... converter="llama", + ... ) + """ + input_dir = Path(input_dir) + output_dir = Path(output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + + # Get descriptor and converter from factories (they use the same name) + descriptor = ModelDescriptorFactory.get(converter) + converter = ConverterFactory.get(converter) + + converter.convert(descriptor=descriptor, input_dir=input_dir, output_dir=output_dir) + + +if __name__ == "__main__": + from fire import Fire + + Fire(convert_model) diff --git a/modelopt/torch/puzzletron/anymodel/converter/converter.py b/modelopt/torch/puzzletron/anymodel/converter/converter.py new file mode 100644 index 000000000..5fdc92718 --- /dev/null +++ b/modelopt/torch/puzzletron/anymodel/converter/converter.py @@ -0,0 +1,235 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# mypy: ignore-errors + +import copy +import fnmatch +import json +import os +import shutil +from abc import ABC, abstractmethod +from collections import defaultdict +from pathlib import Path +from typing import Dict, List + +from safetensors.torch import load_file, save_file +from tqdm import tqdm +from transformers import PretrainedConfig +from transformers.integrations.mxfp4 import convert_moe_packed_tensors + +from modelopt.torch.puzzletron.anymodel.model_descriptor import ModelDescriptor +from modelopt.torch.puzzletron.decilm.deci_lm_hf_code.block_config import BlockConfig +from modelopt.torch.puzzletron.tools.checkpoint_utils_hf import load_model_config, save_model_config + +__all__ = ["Converter"] + + +class Converter(ABC): + """Base class for converting HuggingFace models to Puzzletron/AnyModel format.""" + + @staticmethod + def _get_weight_map(input_dir: Path) -> Dict[str, str]: + """Load weight map from checkpoint directory (supports both sharded and single-file models). + + Returns a dict mapping parameter names to their safetensors filenames. + """ + index_path = input_dir / "model.safetensors.index.json" + single_file_path = input_dir / "model.safetensors" + + if index_path.exists(): + # Sharded model + with open(index_path, "r") as f: + index = json.load(f) + return index["weight_map"] + elif single_file_path.exists(): + # Single file model - create a synthetic weight map + data = load_file(single_file_path) + return {name: "model.safetensors" for name in data.keys()} + else: + raise FileNotFoundError( + f"Neither {index_path} nor {single_file_path} found. Cannot determine model format." + ) + + @classmethod + def convert_model_weights( + cls, input_dir: Path, output_dir: Path, descriptor: ModelDescriptor, num_hidden_layers: int + ): + """Convert model weights to subblock format.""" + param_to_file = Converter._get_weight_map(input_dir) + all_param_names = list(param_to_file.keys()) + + # Reverse map: file -> set of params + file_to_params = defaultdict(set) + for name, file in param_to_file.items(): + file_to_params[file].add(name) + + # Determine subblocks needed + subblocks = descriptor.get_weight_groups( + all_param_names, num_hidden_layers=num_hidden_layers + ) + + # Output directory + out_dir = output_dir / "subblocks_safetensors" + os.makedirs(out_dir, exist_ok=True) + + # New weight index + new_index = {"metadata": {"format": "pt"}, "weight_map": {}} + + for subblock, param_names in tqdm(subblocks.items(), desc="Processing subblocks"): + param_files = set(param_to_file[name] for name in param_names) + tensors = {} + + # Load only needed files for this subblock + for file in param_files: + data = load_file(os.path.join(input_dir, file)) + for name in param_names: + if param_to_file[name] == file and name in data: + converted_name = cls.convert_weight_name(name) + # Convert MoE packed tensors if quantized is mxfp4 //gpt-oss-20b + if getattr(cls, "quantized", None) == "mxfp4": + if name.endswith("_blocks"): + converted_name = converted_name.replace("_blocks", "") + tensors[converted_name] = convert_moe_packed_tensors( + data[converted_name + "_blocks"], + data[converted_name + "_scales"], + ) + elif name.endswith("_scales"): + continue + else: + tensors[converted_name] = data[name] + else: + tensors[converted_name] = data[name] + + # Save this subblock + print(f"\n✅ Group: {subblock} ({len(tensors)} layers)") + for layer in tensors.keys(): + print(f" - {layer}") + + subblock_file = f"{subblock}.safetensors" + save_file(tensors, os.path.join(out_dir, subblock_file)) + + # Update index + for new_name in tensors.keys(): + new_index["weight_map"][new_name] = f"subblocks_safetensors/{subblock_file}" + + # Save new index file + with (output_dir / "model.safetensors.index.json").open("w") as f: + json.dump(new_index, f, indent=2) + + print(f"✅ Finished saving subblocks and index to {output_dir}") + + @classmethod + def convert_configs_in_dirs( + cls, + input_dir: Path, + output_dir: Path, + ): + """Convert config and add block_configs.""" + config = load_model_config(input_dir) + + block_configs = cls.create_block_configs_from_main_config(config) + out_config = copy.deepcopy(config) + out_config.block_configs = block_configs + + save_model_config(out_config, output_dir) + return out_config + + @staticmethod + def copy_checkpoint_files(input_dir: Path, output_dir: Path): + """Copy checkpoint files except model weights (which will be converted).""" + ignore_patterns = [ + "model-*.safetensors", + "model.safetensors", + "model.safetensors.index.json", + "subblocks_safetensors", + ] + + def ignore_func(dir, files): + ignored = set() + for pattern in ignore_patterns: + ignored.update(fnmatch.filter(files, pattern)) + return ignored + + shutil.copytree(str(input_dir), str(output_dir), ignore=ignore_func, dirs_exist_ok=True) + + @classmethod + def convert( + cls, + descriptor: ModelDescriptor, + input_dir: Path, + output_dir: Path, + ): + """Convert a HuggingFace model to AnyModel format. + + Args: + descriptor: Model descriptor for the model type. + input_dir: Path to the input HuggingFace checkpoint. + output_dir: Path to the output AnyModel checkpoint. + """ + cls.copy_checkpoint_files(input_dir, output_dir) + config = cls.convert_configs_in_dirs(input_dir, output_dir) + cls.convert_model_weights( + input_dir, output_dir, descriptor=descriptor, num_hidden_layers=config.num_hidden_layers + ) + + @staticmethod + @abstractmethod + def create_block_configs_from_main_config(config: PretrainedConfig) -> List[BlockConfig]: + """Create per-layer BlockConfig list from a HuggingFace model config. + + This method extracts layer-specific parameters (e.g., intermediate_size, + num_key_value_heads) from the main model config and creates a BlockConfig + for each layer. These BlockConfigs enable layer-specific pruning and + modifications during the compression pipeline. + + Args: + config: HuggingFace PretrainedConfig (e.g., LlamaConfig, Qwen2Config) + + Returns: + List of BlockConfig, one per hidden layer. Each BlockConfig contains: + - AttentionConfig: attention settings (no_op, num_key_value_heads) + - FFNConfig: FFN settings (no_op, intermediate_size) + + Example: + For a model with uniform layers (e.g., Llama): + return [BlockConfig(...)] * config.num_hidden_layers + + For a model with heterogeneous layers (e.g., NemotronH with Mamba/Attention): + return [BlockConfig(...) for layer_idx in range(num_layers)] + """ + raise NotImplementedError + + @staticmethod + def convert_weight_name(name: str) -> str: + """ + Convert weight names during checkpoint conversion. + + This method can be overridden by subclasses to apply model-specific weight name + transformations when converting checkpoints from HuggingFace format to Puzzletron format. + + Default implementation returns the name unchanged (identity function). + + Args: + name: Original weight name from HuggingFace checkpoint + + Returns: + Converted weight name for Puzzletron format + + Example: + For Qwen2.5-VL, this converts: + - visual.* → model.visual.* + - model.* → model.language_model.* + """ + return name diff --git a/modelopt/torch/puzzletron/anymodel/converter/converter_factory.py b/modelopt/torch/puzzletron/anymodel/converter/converter_factory.py new file mode 100644 index 000000000..88d490d65 --- /dev/null +++ b/modelopt/torch/puzzletron/anymodel/converter/converter_factory.py @@ -0,0 +1,75 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# mypy: ignore-errors + +import inspect +from typing import Callable, Type + +from modelopt.torch.puzzletron.anymodel.model_descriptor import ModelDescriptor + +__all__ = ["ConverterFactory"] + + +class ConverterFactory: + """Factory for registering and retrieving Converter classes.""" + + CLASS_MAPPING = {} + + @classmethod + def register(cls, **entries: Type): + """Register converter classes. + + Raises: + KeyError: if entry key is already in type_dict and points to a different class. + """ + for cls_name, cls_type in entries.items(): + if cls_name in cls.CLASS_MAPPING: + ref = cls.CLASS_MAPPING[cls_name] + # If ref and cls_name point to the same class ignore and don't raise an exception. + if cls_type == ref: + continue + raise KeyError( + f"Could not register `{cls_name}`: {cls_type}, " + f"`{cls_name}` is already registered and points to " + f"`{inspect.getmodule(ref).__name__}.{ref.__name__}`" + ) + cls.CLASS_MAPPING[cls_name] = cls_type + + @classmethod + def register_decorator(cls, name: str | None) -> Callable: + """Set up a register decorator. + + Args: + name: If specified, the decorated object will be registered with this name. + + Returns: + Decorator that registers the callable. + """ + + def decorator(cls_type: Type) -> Callable: + """Register the decorated callable.""" + cls_name = name if name is not None else cls_type.__name__ + cls.register(**{cls_name: cls_type}) + return cls_type + + return decorator + + @classmethod + def get(cls, value: str | ModelDescriptor): + """Get a registered converter by name or return the converter if already resolved.""" + if isinstance(value, str): + if value in cls.CLASS_MAPPING: + return cls.CLASS_MAPPING[value] + return value diff --git a/modelopt/torch/puzzletron/anymodel/model_descriptor/__init__.py b/modelopt/torch/puzzletron/anymodel/model_descriptor/__init__.py new file mode 100644 index 000000000..cc8e89e34 --- /dev/null +++ b/modelopt/torch/puzzletron/anymodel/model_descriptor/__init__.py @@ -0,0 +1,18 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Model descriptors for defining model-specific properties and layer naming conventions.""" + +from .model_descriptor import * +from .model_descriptor_factory import * diff --git a/modelopt/torch/puzzletron/anymodel/model_descriptor/model_descriptor.py b/modelopt/torch/puzzletron/anymodel/model_descriptor/model_descriptor.py new file mode 100644 index 000000000..69af0e66c --- /dev/null +++ b/modelopt/torch/puzzletron/anymodel/model_descriptor/model_descriptor.py @@ -0,0 +1,210 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import re +from abc import ABC, abstractmethod +from collections import defaultdict +from typing import Any, Dict, Iterable, List, Type + +import torch.nn as nn + +from modelopt.torch.puzzletron.decilm.deci_lm_hf_code.block_config import BlockConfig +from modelopt.torch.puzzletron.utils.dummy_modules import DummyBlock + +__all__ = ["ModelDescriptor"] + + +class ModelDescriptor(ABC): + @staticmethod + @abstractmethod + def decoder_layer_cls() -> Type[nn.Module] | List[Type[nn.Module]]: + """Decoder layer class types to patch for heterogeneous config support. + + In most cases this class will hold as attributes both FFN & attention layers. + + Returns: + nn.Module class type or a list if several class types should be patched. + """ + raise NotImplementedError + + @staticmethod + @abstractmethod + def block_config_to_layer_overrides(block_config: BlockConfig) -> Dict[str, Any]: + """Map between BlockConfig and layer config overrides. + + These overrides are consumed by a specific decoder layer and by the whole model. + Usage can be seen in `deci_x_patcher` under the method `_patched_decoder_layer_init`. + + Example implementation to override the FFN intermediate size of a block: + >>> def block_config_to_layer_overrides(block_config: BlockConfig) -> Dict[str, Any]: + >>> return {"intermediate_size": block_config.ffn.intermediate_size} + """ + raise NotImplementedError + + @staticmethod + def mlp_no_op_post_init(decoder_layer: nn.Module): + """Post-init callback to alter a decoder layer so that FFN/mlp subblock performs as no-op. + + It is recommended to use the utils modules from `no_op.py` to replace layers to dummy + counterparts. + + Example for replacing a layernorm layer with identity: + >>> decoder_layer.post_attention_layernorm = Same() + + Example for replacing an MLP layer with zeroes (zeroes since hidden_states are added to + the residuals hidden_states so a no-op implementation will leave residual the same): + >>> decoder_layer.mlp = MatchingZeros() + + In case the MLP layer to replace returns multiple outputs i.e `hidden_states, _ = self.mlp()`, + use the util method `return_tuple_of_size` to return trailing None values: + >>> decoder_layer.mlp = return_tuple_of_size(MatchingZeros, size=2)() + """ + raise NotImplementedError + + @staticmethod + def attn_no_op_post_init(decoder_layer: nn.Module): + """Post-init callback to alter a decoder layer so that Attention subblock performs as no-op. + + It is recommended to use the utils modules from `no_op.py` to replace layers to dummy + counterparts. + + Example for replacing a layernorm layer with identity: + >>> decoder_layer.post_attention_layernorm = Same() + + Example for replacing an attention layer with zeroes: + >>> decoder_layer.self_attn = MatchingZeros() + + In case the attention layer returns multiple outputs i.e `hidden_states, _ = self.self_attn()`, + use the util method `return_tuple_of_size` to return trailing None values: + >>> decoder_layer.self_attn = return_tuple_of_size(MatchingZeros, size=2)() + """ + raise NotImplementedError + + @staticmethod + @abstractmethod + def init_rotary_embedding(model, runtime): + """Re-initiate the rotary embeddings based on an existing model. + + In puzzletron we initiate a sharded model by first creating a meta model then replacing + to the actual device by loading the state_dict with the real weights. + + Rotary embeddings frequencies are tensor buffers that are created dynamically during init + and are not part of the model state_dict, so cannot be restored after a meta device + initialization. + """ + raise NotImplementedError + + @staticmethod + @abstractmethod + def input_embedding_name(): + """Return the name of the input embedding layer.""" + raise NotImplementedError + + @staticmethod + @abstractmethod + def output_embedding_name(): + """Return the name of the output embedding layer.""" + raise NotImplementedError + + @staticmethod + @abstractmethod + def final_norm_name(): + """Return the name of the final normalization layer.""" + raise NotImplementedError + + @staticmethod + @abstractmethod + def layer_block_name(index: int): + """Return the name of the decoder layer at the given index.""" + raise NotImplementedError + + @staticmethod + @abstractmethod + def layer_name_predicates(num_layers: int) -> Dict[str, re.Pattern]: + """Return predicates for grouping model weights to support subblock checkpointing. + + For every group name return a regex predicate whether a layer name is part of the group. + + Returns: + Dictionary of group name to regex pattern predicate. + """ + raise NotImplementedError + + @staticmethod + def uses_autocast() -> bool: + """Whether this model supports torch.autocast. + + Some models (e.g., Qwen3-VL MoE) have dtype bugs under autocast. + Override and return False for models that do not support autocast. + """ + return True + + @staticmethod + def get_language_model_config(config): + """Get the language model config from a PretrainedConfig. + + For regular LM models, returns the config itself. + For VL/multimodal models with nested configs, override to return the + language model portion (e.g., config.text_config for Qwen-VL). + """ + return config + + @classmethod + def create_dummy_block(cls, original_layer: nn.Module, block_index: int) -> nn.Module: + """Create a dummy block to replace a layer for sharded model initialization.""" + return DummyBlock(block_index=block_index) + + @classmethod + def mlp_no_op_supported(cls) -> bool: + """Check whether `mlp_no_op_post_init` is overridden for mlp no-op support.""" + method_name = ModelDescriptor.mlp_no_op_post_init.__name__ + return getattr(cls, method_name) is not getattr(ModelDescriptor, method_name) + + @classmethod + def attn_no_op_supported(cls): + """Check whether `attn_no_op_post_init` is overridden for attention no-op support.""" + method_name = ModelDescriptor.attn_no_op_post_init.__name__ + return getattr(cls, method_name) is not getattr(ModelDescriptor, method_name) + + @classmethod + def get_weight_groups( + cls, layer_names: Iterable[str], num_hidden_layers: int + ) -> Dict[str, List[str]]: + """Group model weights to support the puzzle subblock checkpointing format. + + This method uses the abstract method `layer_name_predicates` by default. + + Args: + layer_names: state_dict layer names of the model. + num_hidden_layers: number of decoder layers in the model. + + Returns: + Dictionary of group names to list of layer names per group, e.g.: + >>> { + ... "embedding": ["model.embed_tokens.weight"], + ... "lm_head": ["lm_head.weight", "model.norm.weight"], + ... "block_0_ffn": ["model.layers.0.mlp.down_proj", ...], + ... "block_0_attention": ["model.layers.0.self_attn.q_proj", ...], + ... } + """ + weight_groups = defaultdict(list) + for name in layer_names: + for group, pattern in cls.layer_name_predicates(num_hidden_layers).items(): + if pattern.match(name): + weight_groups[group].append(name) + break + else: + raise ValueError(f"Couldn't find a match for {name}") + return weight_groups diff --git a/modelopt/torch/puzzletron/anymodel/model_descriptor/model_descriptor_factory.py b/modelopt/torch/puzzletron/anymodel/model_descriptor/model_descriptor_factory.py new file mode 100644 index 000000000..23a42da58 --- /dev/null +++ b/modelopt/torch/puzzletron/anymodel/model_descriptor/model_descriptor_factory.py @@ -0,0 +1,111 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# mypy: ignore-errors + +import inspect +from typing import Callable, Type + +from transformers import AutoConfig + +from modelopt.torch.puzzletron.anymodel.model_descriptor.model_descriptor import ModelDescriptor + +__all__ = ["ModelDescriptorFactory"] + +# Map from HuggingFace config.model_type (in checkpoint config.json) to ModelDescriptorFactory name. +# Local to this script; add entries when supporting new model types for auto-detection. +_MODEL_TYPE_TO_DESCRIPTOR = { + "llama": "llama", + "mistral": "mistral_small", + "qwen2": "qwen2", + "qwen3": "qwen3", + "nemotron_h": "nemotron_h", + "nemotron_h_v2": "nemotron_h_v2", + "gpt_oss_20b": "gpt_oss_20b", +} + + +def resolve_descriptor_from_pretrained(pretrained: str | None, trust_remote_code: bool = True): + """Resolve the model descriptor by loading the checkpoint config and mapping model_type.""" + if not pretrained: + raise ValueError("pretrained must be provided") + + config = AutoConfig.from_pretrained(pretrained, trust_remote_code=trust_remote_code) + model_type = getattr(config, "model_type", None) + + if model_type and model_type in _MODEL_TYPE_TO_DESCRIPTOR: + detected = _MODEL_TYPE_TO_DESCRIPTOR[model_type] + print( + f"[resolve_descriptor_from_pretrained] Auto-detected model_type='{model_type}' → descriptor='{detected}'" + ) + return ModelDescriptorFactory.get(detected) + + known = sorted(_MODEL_TYPE_TO_DESCRIPTOR.keys()) + raise ValueError( + f"Cannot auto-detect descriptor for model_type='{model_type}'. " + f"Known model types: {known}. Add this model_type to _MODEL_TYPE_TO_DESCRIPTOR if supported." + ) + + +class ModelDescriptorFactory: + """Factory for registering and retrieving ModelDescriptor classes.""" + + CLASS_MAPPING = {} + + @classmethod + def register(cls, **entries: Type): + """Register model descriptor classes. + + Raises: + KeyError: if entry key is already in type_dict and points to a different class. + """ + for cls_name, cls_type in entries.items(): + if cls_name in cls.CLASS_MAPPING: + ref = cls.CLASS_MAPPING[cls_name] + # If ref and cls_name point to the same class ignore and don't raise an exception. + if cls_type == ref: + continue + raise KeyError( + f"Could not register `{cls_name}`: {cls_type}, " + f"`{cls_name}` is already registered and points to " + f"`{inspect.getmodule(ref).__name__}.{ref.__name__}`" + ) + cls.CLASS_MAPPING[cls_name] = cls_type + + @classmethod + def register_decorator(cls, name: str | None) -> Callable: + """Set up a register decorator. + + Args: + name: If specified, the decorated object will be registered with this name. + + Returns: + Decorator that registers the callable. + """ + + def decorator(cls_type: Type) -> Callable: + """Register the decorated callable.""" + cls_name = name if name is not None else cls_type.__name__ + cls.register(**{cls_name: cls_type}) + return cls_type + + return decorator + + @classmethod + def get(cls, value: str | ModelDescriptor): + """Get a registered model descriptor by name or return the descriptor if already resolved.""" + if isinstance(value, str): + if value in cls.CLASS_MAPPING: + return cls.CLASS_MAPPING[value] + return value diff --git a/modelopt/torch/puzzletron/anymodel/models/llama/__init__.py b/modelopt/torch/puzzletron/anymodel/models/llama/__init__.py new file mode 100644 index 000000000..a0be9f919 --- /dev/null +++ b/modelopt/torch/puzzletron/anymodel/models/llama/__init__.py @@ -0,0 +1,19 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from modelopt.torch.puzzletron.anymodel.models.llama.llama_converter import LlamaConverter +from modelopt.torch.puzzletron.anymodel.models.llama.llama_model_descriptor import ( + LlamaModelDescriptor, +) diff --git a/modelopt/torch/puzzletron/anymodel/models/llama/llama_converter.py b/modelopt/torch/puzzletron/anymodel/models/llama/llama_converter.py new file mode 100644 index 000000000..1f8cf77b5 --- /dev/null +++ b/modelopt/torch/puzzletron/anymodel/models/llama/llama_converter.py @@ -0,0 +1,50 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# mypy: ignore-errors + +"""Llama converter for AnyModel compression.""" + +from typing import List + +from transformers import LlamaConfig + +from modelopt.torch.puzzletron.anymodel.converter import Converter, ConverterFactory +from modelopt.torch.puzzletron.decilm.deci_lm_hf_code.block_config import ( + AttentionConfig, + BlockConfig, + FFNConfig, +) + + +@ConverterFactory.register_decorator("llama") +class LlamaConverter(Converter): + """Converter for Llama models to AnyModel format.""" + + @staticmethod + def create_block_configs_from_main_config(config: LlamaConfig) -> List[BlockConfig]: + """Create uniform block configs for all Llama layers. + + Llama models have uniform architecture across all layers, so we create + the same BlockConfig for each layer. + """ + num_hidden_layers = config.num_hidden_layers + + block_config = BlockConfig( + attention=AttentionConfig(no_op=False, num_key_value_heads=config.num_key_value_heads), + ffn=FFNConfig(no_op=False, intermediate_size=config.intermediate_size), + ).to_dict() + + block_configs = [block_config] * num_hidden_layers + return block_configs diff --git a/modelopt/torch/puzzletron/anymodel/models/llama/llama_model_descriptor.py b/modelopt/torch/puzzletron/anymodel/models/llama/llama_model_descriptor.py new file mode 100644 index 000000000..fe416e2dd --- /dev/null +++ b/modelopt/torch/puzzletron/anymodel/models/llama/llama_model_descriptor.py @@ -0,0 +1,131 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# mypy: ignore-errors + +"""Llama model descriptor for AnyModel compression.""" + +import re +from dataclasses import dataclass, field +from typing import Dict, List + +from transformers.models.llama.modeling_llama import ( + LlamaDecoderLayer, + LlamaForCausalLM, + LlamaRotaryEmbedding, +) + +from modelopt.torch.puzzletron.anymodel.model_descriptor import ( + ModelDescriptor, + ModelDescriptorFactory, +) +from modelopt.torch.puzzletron.anymodel.puzzformer.no_op import ( + MatchingZeros, + Same, + return_tuple_of_size, +) +from modelopt.torch.puzzletron.decilm.deci_lm_hf_code.block_config import BlockConfig +from modelopt.torch.puzzletron.pruning.ffn_intermediate_pruning_mixin import ( + FFNIntermediateLayerDescriptor, +) + + +@ModelDescriptorFactory.register_decorator("llama") +class LlamaModelDescriptor(ModelDescriptor): + """Model descriptor for Llama models (Llama 2, Llama 3, Llama 3.1, Llama 3.2).""" + + @staticmethod + def decoder_layer_cls(): + return LlamaDecoderLayer + + @staticmethod + def block_config_to_layer_overrides(block_config: BlockConfig): + return { + "intermediate_size": block_config.ffn.intermediate_size, + "num_key_value_heads": block_config.attention.num_key_value_heads, + } + + @staticmethod + def attn_no_op_post_init(decoder_layer: LlamaDecoderLayer): + decoder_layer.input_layernorm = Same() + decoder_layer.self_attn = return_tuple_of_size(MatchingZeros, size=2)() + + @staticmethod + def mlp_no_op_post_init(decoder_layer: LlamaDecoderLayer): + decoder_layer.post_attention_layernorm = Same() + decoder_layer.mlp = MatchingZeros() + + @staticmethod + def init_rotary_embedding(model: LlamaForCausalLM, runtime): + model.model.rotary_emb = LlamaRotaryEmbedding(model.config, runtime.device) + + @staticmethod + def input_embedding_name(): + return "model.embed_tokens" + + @staticmethod + def output_embedding_name(): + return "lm_head" + + @staticmethod + def final_norm_name(): + return "model.norm" + + @staticmethod + def layer_block_name(index: int): + return f"model.layers.{index}" + + @staticmethod + def layer_name_predicates(num_layers: int) -> Dict[str, re.Pattern]: + layer_name_patterns = { + "embeddings": re.compile(r"^model\.embed_tokens\.weight$"), + "lm_head": re.compile(r"^(model\.norm\.weight|lm_head\.weight)$"), + } + + def build_ffn_predicates() -> Dict[str, re.Pattern]: + return { + f"block_{layer_idx}_ffn": re.compile( + rf"^model\.layers\.{layer_idx}\.(post_attention_layernorm\.weight" + r"|mlp\.up_proj\.weight" + r"|mlp\.gate_proj\.weight" + r"|mlp\.down_proj\.weight)$" + ) + for layer_idx in range(num_layers) + } + + def build_attention_predicates() -> Dict[str, re.Pattern]: + return { + f"block_{layer_idx}_attention": re.compile( + rf"^model\.layers\.{layer_idx}\.(input_layernorm\.weight" + r"|self_attn\.q_proj\.weight" + r"|self_attn\.k_proj\.weight" + r"|self_attn\.v_proj\.weight" + r"|self_attn\.o_proj\.weight)$" + ) + for layer_idx in range(num_layers) + } + + layer_name_patterns.update(**build_ffn_predicates(), **build_attention_predicates()) + return layer_name_patterns + + +@dataclass +class LlamaFFNIntermediateLayerDescriptor(FFNIntermediateLayerDescriptor): + """Layer descriptor for Llama FFN intermediate pruning.""" + + down_proj_name: str = "mlp.down_proj" + ffn_prefix_name: str = "model.layers.{layer_idx}.mlp" + linear_weight_names: List[str] = field( + default_factory=lambda: ["down_proj", "gate_proj", "up_proj"] + ) diff --git a/modelopt/torch/puzzletron/anymodel/puzzformer/__init__.py b/modelopt/torch/puzzletron/anymodel/puzzformer/__init__.py new file mode 100644 index 000000000..aac6f0f20 --- /dev/null +++ b/modelopt/torch/puzzletron/anymodel/puzzformer/__init__.py @@ -0,0 +1,24 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from modelopt.torch.puzzletron.anymodel.puzzformer.no_op import ( + MatchingZeros, + Same, + return_tuple_of_size, +) +from modelopt.torch.puzzletron.anymodel.puzzformer.utils import ( + deci_x_patcher, + override_config_with_block_configs, +) diff --git a/modelopt/torch/puzzletron/anymodel/puzzformer/no_op.py b/modelopt/torch/puzzletron/anymodel/puzzformer/no_op.py new file mode 100644 index 000000000..aac57af0a --- /dev/null +++ b/modelopt/torch/puzzletron/anymodel/puzzformer/no_op.py @@ -0,0 +1,79 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""No-op modules for replacing layers during pruning.""" + +from functools import cache + +import torch +import torch.nn as nn + + +@cache +def return_tuple_of_size(cls: type[nn.Module], size: int) -> type[nn.Module]: + """Create a wrapper class that returns a tuple of the given size. + + Useful for replacing modules that return multiple outputs (e.g., attention layers + that return (hidden_states, attn_weights)). + + Args: + cls: The base module class to wrap. + size: The size of the tuple to return. + + Returns: + A new class that wraps the base class and returns a tuple of the given size. + + Example: + >>> decoder_layer.self_attn = return_tuple_of_size(MatchingZeros, size=2)() + """ + + class Wrapped(cls): + def forward(self, *args, **kwargs): + result = super().forward(*args, **kwargs) + outputs = [None] * size + outputs[0] = result[0] + return tuple(outputs) + + def extra_repr(self) -> str: + return f"[{cls.__name__}]" + + return Wrapped + + +class MatchingZeros(nn.Module): + """Module that returns zeros matching the input shape. + + Used to replace MLP or attention layers with no-ops. Returns zeros because + the hidden_states are added to the residuals, so a no-op implementation + should leave the residual unchanged. + """ + + def forward(self, hidden_states, *args, **kwargs): + return torch.zeros_like(hidden_states) + + +class Same(nn.Module): + """Module that returns the input unchanged. + + Used to replace normalization layers with identity operations. + """ + + def forward(self, hidden_states, *args, **kwargs): + return hidden_states + + @property + def weight(self): + """Support NemotronH with scoring_activations, when lm_head is called `self.lm_head.weight.dtype`.""" + return torch.empty(0) diff --git a/modelopt/torch/puzzletron/anymodel/puzzformer/utils.py b/modelopt/torch/puzzletron/anymodel/puzzformer/utils.py new file mode 100644 index 000000000..93913b8e2 --- /dev/null +++ b/modelopt/torch/puzzletron/anymodel/puzzformer/utils.py @@ -0,0 +1,122 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# mypy: ignore-errors + +import copy +import inspect +from contextlib import ExitStack, contextmanager +from functools import wraps +from typing import Any, Dict, List + +from transformers import PretrainedConfig + +from modelopt.torch.puzzletron.anymodel.model_descriptor.model_descriptor import ModelDescriptor +from modelopt.torch.puzzletron.decilm.deci_lm_hf_code.block_config import ( + BlockConfig, + maybe_cast_block_configs, +) + + +def _get_variable_from_stack(names: list[str]) -> Any: + """Search the call stack for a variable with one of the given names.""" + f = inspect.currentframe().f_back + while f: + for name in names: + if name in f.f_locals: + return f.f_locals[name] + f = f.f_back + raise RuntimeError(f"{names} not found in caller stack") + + +@contextmanager +def deci_x_patcher( + model_descriptor: ModelDescriptor, + block_configs: List[BlockConfig | dict] | None = None, +): + """Context manager that patches decoder layer __init__ for heterogeneous per-layer configs. + + This is the core mechanism that enables AnyModel to work with any HuggingFace model. + It patches the decoder layer class(es) to read per-layer block_configs and apply + layer-specific overrides (e.g., different intermediate_size per layer). + + Args: + model_descriptor: The model descriptor that defines which classes to patch + and how to map block_configs to layer overrides. + block_configs: Optional list of BlockConfig (one per layer). If not provided, + will try to read from config.block_configs during model initialization. + + Example: + >>> with deci_x_patcher(LlamaModelDescriptor, block_configs): + ... model = AutoModelForCausalLM.from_config(config) + """ + decoder_layer_classes = model_descriptor.decoder_layer_cls() # Now a list of classes + if not isinstance(decoder_layer_classes, list): + decoder_layer_classes = [decoder_layer_classes] + + orig_inits = [] + for cls in decoder_layer_classes: + orig_inits.append(cls.__init__) + + block_configs = maybe_cast_block_configs(block_configs) + + @wraps(orig_inits[0]) + def _patched_decoder_layer_init(self, config, *args, **kwargs): + _block_configs = block_configs or getattr(config, "block_configs", None) + if _block_configs is None: + return orig_inits[decoder_layer_classes.index(self.__class__)]( + self, config, *args, **kwargs + ) + + _block_configs = maybe_cast_block_configs(_block_configs) + layer_idx = _get_variable_from_stack(["layer_idx", "idx"]) + _block_config = _block_configs[layer_idx] + override_block_config = model_descriptor.block_config_to_layer_overrides(_block_config) + _config = override_config_with_block_configs(config, override_block_config) + orig_inits[decoder_layer_classes.index(self.__class__)](self, _config, *args, **kwargs) + + # Apply no-op post-init + if _block_config.attention.no_op: + if not model_descriptor.attn_no_op_supported(): + raise NotImplementedError( + f"attn no-op not supported for `{model_descriptor.__class__.__name__}`, " + "please implement the method: `attn_no_op_post_init()`" + ) + model_descriptor.attn_no_op_post_init(decoder_layer=self) + + if _block_config.ffn.no_op: + if not model_descriptor.mlp_no_op_supported(): + raise NotImplementedError( + f"mlp no-op not supported for `{model_descriptor.__class__.__name__}`, " + "please implement the method: `mlp_no_op_post_init()`" + ) + model_descriptor.mlp_no_op_post_init(decoder_layer=self) + + with ExitStack() as stack: + # Patch every decoder layer class + for orig_init, cls in zip(orig_inits, decoder_layer_classes): + stack.callback(setattr, cls, "__init__", orig_init) # Restore on exit + cls.__init__ = _patched_decoder_layer_init + yield + + +def override_config_with_block_configs( + config: PretrainedConfig, block_configs: Dict[str, Any] +) -> PretrainedConfig: + """Create a copy of config with block_config overrides applied.""" + _config = copy.deepcopy(config) + # Model initialization requires fails with None in case of no-ops + _config_overrides = {k: v for k, v in block_configs.items() if v is not None} + _config.update(_config_overrides) + return _config From 2099df3af28abb37ef34c74d051ef1809245927f Mon Sep 17 00:00:00 2001 From: Daniel Korzekwa Date: Thu, 5 Mar 2026 04:08:33 -0800 Subject: [PATCH 02/30] Make any_model conversion working. Signed-off-by: Daniel Korzekwa --- .../nas/plugins/megatron_hooks/base_hooks.py | 380 +++++++++- modelopt/torch/puzzletron/anymodel/README.md | 204 ++++++ .../torch/puzzletron/anymodel/__init__.py | 64 ++ .../model_descriptor_factory.py | 17 +- .../puzzletron/anymodel/models/__init__.py | 24 + .../decilm/deci_lm_hf_code/block_config.py | 97 +-- .../pruning/expert_removal_pruning_mixin.py | 239 +++++++ .../pruning/ffn_intermediate_pruning_mixin.py | 102 +++ .../pruning/kv_heads_pruning_mixin.py | 127 ++++ .../torch/puzzletron/pruning/pruning_ckpts.py | 94 +-- .../torch/puzzletron/pruning/pruning_mixin.py | 73 ++ .../torch/puzzletron/pruning/pruning_utils.py | 647 ++++++++++++++++++ .../puzzletron/tools/checkpoint_utils_hf.py | 152 ++-- .../torch/puzzletron/utils/dummy_modules.py | 75 ++ tests/_test_utils/torch/puzzletron/utils.py | 145 +++- .../llama_3_1_8b_instruct.yaml | 107 +++ .../pruning/attn_pruning.yaml | 16 + .../pruning/ffn_pruning.yaml | 18 + .../pruning/hidden_dim_pruning.yaml | 15 + .../pruning/pruning_defaults.yaml | 33 + .../validate_model_defaults.yaml | 15 + .../validate_solutions_defaults.yaml | 10 + .../llama_3_1_8b_instruct/config.json | 38 + .../tokenizer/special_tokens_map.json | 16 + .../resources/tokenizer/tokenizer.json | 212 ++++++ .../resources/tokenizer/tokenizer_config.json | 13 + .../resources/tokenizer/truncate_tokenizer.py | 62 ++ tests/gpu/torch/puzzletron/test_puzzletron.py | 303 ++++++-- 28 files changed, 3027 insertions(+), 271 deletions(-) create mode 100644 modelopt/torch/puzzletron/anymodel/README.md create mode 100644 modelopt/torch/puzzletron/anymodel/__init__.py create mode 100644 modelopt/torch/puzzletron/anymodel/models/__init__.py create mode 100644 modelopt/torch/puzzletron/pruning/expert_removal_pruning_mixin.py create mode 100644 modelopt/torch/puzzletron/pruning/ffn_intermediate_pruning_mixin.py create mode 100644 modelopt/torch/puzzletron/pruning/kv_heads_pruning_mixin.py create mode 100644 modelopt/torch/puzzletron/pruning/pruning_mixin.py create mode 100644 modelopt/torch/puzzletron/pruning/pruning_utils.py create mode 100644 modelopt/torch/puzzletron/utils/dummy_modules.py create mode 100644 tests/gpu/torch/puzzletron/resources/configs/llama_3_1_8b_instruct/llama_3_1_8b_instruct.yaml create mode 100644 tests/gpu/torch/puzzletron/resources/configs/llama_3_1_8b_instruct/pruning/attn_pruning.yaml create mode 100644 tests/gpu/torch/puzzletron/resources/configs/llama_3_1_8b_instruct/pruning/ffn_pruning.yaml create mode 100644 tests/gpu/torch/puzzletron/resources/configs/llama_3_1_8b_instruct/pruning/hidden_dim_pruning.yaml create mode 100644 tests/gpu/torch/puzzletron/resources/configs/llama_3_1_8b_instruct/pruning/pruning_defaults.yaml create mode 100644 tests/gpu/torch/puzzletron/resources/configs/llama_3_1_8b_instruct/validate_model_defaults.yaml create mode 100644 tests/gpu/torch/puzzletron/resources/configs/llama_3_1_8b_instruct/validate_solutions_defaults.yaml create mode 100644 tests/gpu/torch/puzzletron/resources/hf_configs/llama_3_1_8b_instruct/config.json create mode 100644 tests/gpu/torch/puzzletron/resources/tokenizer/special_tokens_map.json create mode 100644 tests/gpu/torch/puzzletron/resources/tokenizer/tokenizer.json create mode 100644 tests/gpu/torch/puzzletron/resources/tokenizer/tokenizer_config.json create mode 100644 tests/gpu/torch/puzzletron/resources/tokenizer/truncate_tokenizer.py diff --git a/modelopt/torch/nas/plugins/megatron_hooks/base_hooks.py b/modelopt/torch/nas/plugins/megatron_hooks/base_hooks.py index 56436acfd..7cd721444 100644 --- a/modelopt/torch/nas/plugins/megatron_hooks/base_hooks.py +++ b/modelopt/torch/nas/plugins/megatron_hooks/base_hooks.py @@ -12,6 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +# mypy: ignore-errors """Forward hooks for activation-based importance estimation.""" import gc @@ -26,6 +27,7 @@ from torch import nn import modelopt.torch.utils.distributed as dist +from modelopt.torch.puzzletron.decilm.deci_lm_hf_code.block_config import BlockConfig # noqa: TC001 from modelopt.torch.puzzletron.tools.logger import aprint from modelopt.torch.puzzletron.tools.robust_json import json_dump @@ -150,7 +152,8 @@ def dump_activations_logs( torch.save(activations_log, activations_log_path) if rank == 0: - args.activation_hooks_kwargs.pop("model") + if args.activation_hooks_kwargs is not None: + args.activation_hooks_kwargs.pop("model", None) json_dump(OmegaConf.to_container(args, resolve=True), activations_log_dir / "args.json") dist.barrier() @@ -822,3 +825,378 @@ def _save_channel_importance_results( aprint(f"Score range: {avg_scores.min():.4f} to {avg_scores.max():.4f}") aprint(f"Score mean: {avg_scores.mean():.4f}") aprint(f"Score std: {avg_scores.std():.4f}") + + +class RemoveExpertsIndependentHook(ForwardHook, ABC): + """Base hook for measuring expert importance in Mixture-of-Experts models. + + This hook measures how much removing each expert affects the model output + by comparing outputs with and without each expert. + """ + + def __init__(self, moe: nn.Module, activation_hooks_kwargs: dict): + """Initialize the hook. + + Args: + moe: The MoE module to analyze + activation_hooks_kwargs: Configuration dict containing block_config + """ + self.moe = moe + block_config: BlockConfig = activation_hooks_kwargs["block_config"] + self.num_local_experts = block_config.ffn.moe.num_local_experts + self.num_experts_per_tok = block_config.ffn.moe.num_experts_per_tok + # tensor of zeros of size num experts + self.diffs = ["mse", "cosine"] + some_param = next(self.moe.parameters()) + self.diffs = { + k: torch.zeros( + size=(self.num_local_experts,), dtype=torch.float32, device=some_param.device + ) + for k in self.diffs + } + self.call_count = 0 + + @abstractmethod + def get_router_logits_and_routed_experts( + self, hidden_states: torch.Tensor, router_logits: torch.Tensor | None = None + ) -> tuple[torch.Tensor, torch.Tensor]: + """Extract router logits and expert outputs for measuring expert importance. + + This method is called twice per forward pass: + 1. First call (router_logits=None): Compute original routing and expert outputs + 2. Second call (router_logits provided): Re-run with modified logits (expert disabled) + + Args: + hidden_states: Input tensor of shape (batch, seq_len, hidden_dim) + router_logits: Optional pre-computed router logits. If None, compute from hidden_states. + + Returns: + tuple of (router_logits, routed_experts): + - router_logits: Shape (num_tokens, num_local_experts) + - routed_experts: Shape (num_tokens, hidden_dim) + """ + raise NotImplementedError + + def __call__( + self, module: nn.Module, args: tuple[torch.Tensor, ...], output: torch.Tensor + ) -> None: + """Forward hook that measures expert importance.""" + hidden_states = args[0] + router_logits, original_routed_out = self.get_router_logits_and_routed_experts( + hidden_states + ) + + hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) + original_routed_out = original_routed_out.view(-1, original_routed_out.shape[-1]) + + _, router_indices = torch.topk(router_logits, self.num_experts_per_tok, dim=-1) + self.call_count += 1 + + for i_expert in range(self.num_local_experts): + expert_mask = router_indices == i_expert + is_token_routed_to_this_expert = expert_mask.any(dim=-1) + + num_tokens_displaced = is_token_routed_to_this_expert.sum() + if num_tokens_displaced == 0: + continue + num_total_tokens = is_token_routed_to_this_expert.numel() + + relevant_hidden_states = hidden_states[is_token_routed_to_this_expert, :] + + router_logits_without_i = router_logits.clone() + router_logits_without_i[..., i_expert] = -float("inf") # disable expert i + router_logits_without_i = router_logits_without_i[is_token_routed_to_this_expert, :] + _, routed_out_without_i = self.get_router_logits_and_routed_experts( + relevant_hidden_states, router_logits_without_i + ) + + relevant_tokens_original_out = original_routed_out[is_token_routed_to_this_expert, :] + self.diffs["mse"][i_expert] += ( + nn.functional.mse_loss( + relevant_tokens_original_out, routed_out_without_i, reduction="mean" + ) + * num_tokens_displaced + / num_total_tokens + ) + self.diffs["cosine"][i_expert] += ( + -nn.functional.cosine_similarity( + relevant_tokens_original_out, routed_out_without_i, dim=-1 + ).mean() + * num_tokens_displaced + / num_total_tokens + ) + + def to_dict(self) -> dict[str, torch.Tensor]: + """Convert accumulated statistics to dict format.""" + expert_ranks_mse = torch.argsort(self.diffs["mse"]) + expert_ranks_cosine = torch.argsort(self.diffs["cosine"]) + return { + "expert_ranks_mse": expert_ranks_mse.cpu(), + "expert_ranks_cosine": expert_ranks_cosine.cpu(), + "cosine_diffs": (self.diffs["cosine"] / self.call_count).cpu(), + "mse_diffs": (self.diffs["mse"] / self.call_count).cpu(), + } + + def accumulate(self) -> torch.Tensor: + """Return accumulated expert importance scores.""" + return self.diffs["mse"] + + def state_dict(self) -> dict: + """Return the internal state for checkpointing.""" + return { + "diffs_mse": self.diffs["mse"].cpu(), + "diffs_cosine": self.diffs["cosine"].cpu(), + "call_count": self.call_count, + } + + def load_state_dict(self, state_dict: dict) -> None: + """Load the internal state from a checkpoint.""" + self.diffs["mse"] = state_dict["diffs_mse"].to(self.diffs["mse"].device) + self.diffs["cosine"] = state_dict["diffs_cosine"].to(self.diffs["cosine"].device) + self.call_count = state_dict["call_count"] + + +class NemotronHRemoveExpertsIndependentHook(RemoveExpertsIndependentHook): + """Expert removal importance hook for NemotronH models.""" + + def get_router_logits_and_routed_experts( + self, hidden_states: torch.Tensor, router_logits: torch.Tensor | None = None + ) -> tuple[torch.Tensor, torch.Tensor]: + """Extract router logits and expert outputs for NemotronH MoE. + + Based on NemotronHMOE forward, uses minimum ops to get router_logits and routed_experts. + """ + orig_shape = hidden_states.shape + # NemotronHMOE.gate forward, copied to extract router_logits + hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) + if router_logits is None: + router_logits = nn.functional.linear( + hidden_states.type(torch.float32), self.moe.gate.weight.type(torch.float32) + ) + router_logits = router_logits.sigmoid() + router_logits = router_logits + self.moe.gate.e_score_correction_bias.unsqueeze(0) + + topk_indices = self._get_topk_indices_without_correction_bias(router_logits) + topk_weights = router_logits.gather(1, topk_indices) + if self.moe.gate.norm_topk_prob: + denominator = topk_weights.sum(dim=-1, keepdim=True) + 1e-20 + topk_weights /= denominator + topk_weights = topk_weights * self.moe.gate.routed_scaling_factor + # Routed experts forward + hidden_states = self.moe.moe(hidden_states, topk_indices, topk_weights).view(*orig_shape) + return router_logits, hidden_states + + @torch.no_grad() + def _get_topk_indices_without_correction_bias(self, scores: torch.Tensor) -> torch.Tensor: + """Get topk indices without correction bias. + + Same as NemotronHMOE.gate.get_topk_indices but without adding e_score_correction_bias. + """ + group_scores = ( + scores.view( + -1, self.moe.gate.n_group, self.moe.gate.n_routed_experts // self.moe.gate.n_group + ) + .topk(2, dim=-1)[0] + .sum(dim=-1) + ) + group_idx = torch.topk(group_scores, k=self.moe.gate.topk_group, dim=-1, sorted=False)[1] + group_mask = torch.zeros_like(group_scores) + group_mask.scatter_(1, group_idx, 1) + score_mask = ( + group_mask.unsqueeze(-1) + .expand( + -1, self.moe.gate.n_group, self.moe.gate.n_routed_experts // self.moe.gate.n_group + ) + .reshape(-1, self.moe.gate.n_routed_experts) + ) + scores_for_choice = scores.masked_fill(~score_mask.bool(), 0.0) + topk_indices = torch.topk(scores_for_choice, k=self.moe.gate.top_k, dim=-1, sorted=False)[1] + return topk_indices + + +class RankedChoiceVotingHook(ForwardHook): + """Hook for ranking experts using ranked choice voting algorithm. + + This hook tracks router decisions and uses ranked choice voting to determine + which experts are least important (can be pruned first). + """ + + def __init__(self, router: nn.Module, activation_hooks_kwargs: dict): + """Initialize the hook. + + Args: + router: The router module (typically nn.Linear) + activation_hooks_kwargs: Configuration dict containing block_config + """ + self.router_argsort: list[torch.Tensor] = [] + block_config: BlockConfig = activation_hooks_kwargs["block_config"] + self.top_k = block_config.ffn.moe.num_experts_per_tok + + def __call__( + self, module: nn.Module, args: tuple[torch.Tensor, ...], output: torch.Tensor + ) -> None: + """Forward hook that records router decisions. + + Args: + module: The router module + args: Tuple with one tensor entry (B, T, I) + output: Router logits of shape (B, T, E) + """ + router_logits = output[0] if isinstance(output, tuple) else output + num_experts = router_logits.shape[-1] + router_argsort = torch.argsort(router_logits, dim=-1, descending=True) + router_argsort = router_argsort.view(-1, num_experts).to(torch.int16).cpu() + self.router_argsort.append(router_argsort) + + def to_dict(self) -> dict[str, torch.Tensor]: + """Convert accumulated statistics to dict format using ranked choice voting.""" + router_argsort = torch.concat(self.router_argsort, dim=0) + num_tokens, num_experts = router_argsort.shape + + expert_ranks = torch.full((num_experts,), -1) + expert_counts_at_pruning_time = {} + + expert_kept_per_iteration: list[list[int]] = [] + expert_counts_per_iteration: list[dict[int, int]] = [] + + for rank in range(num_experts): + ids, counts = router_argsort[:, : self.top_k].unique(return_counts=True) + ids = ids.tolist() + counts = counts.tolist() + expert_counts = dict(zip(ids, counts)) + + expert_kept_per_iteration.append(ids) + expert_counts_per_iteration.append(expert_counts) + + least_popular_expert, min_count = min(expert_counts.items(), key=lambda tup: tup[1]) + + expert_ranks[least_popular_expert] = rank + expert_counts_at_pruning_time[least_popular_expert] = min_count + aprint(f"#{rank}: router_argsort shape = {router_argsort.shape}") + router_argsort = router_argsort[router_argsort != least_popular_expert].view( + num_tokens, -1 + ) + + zero_shot_expert_counts = torch.zeros((num_experts,), dtype=torch.long) + for expert_id, expert_counts_val in expert_counts_per_iteration[0].items(): + zero_shot_expert_counts[expert_id] = expert_counts_val + + # Compute zero-shot expert ranks (double argsort converts counts to rank positions) + zero_shot_expert_ranks = torch.argsort(torch.argsort(zero_shot_expert_counts)) + + aprint("Done: Returning hook metadata.") + return { + "expert_ranks": expert_ranks, + "zero_shot_expert_ranks": zero_shot_expert_ranks, + "expert_counts_at_pruning_time": expert_counts_at_pruning_time, + "expert_counts_per_iteration": expert_counts_per_iteration, + "top_k": self.top_k, + } + + def accumulate(self) -> torch.Tensor: + """Return accumulated expert ranks.""" + if not self.router_argsort: + return torch.tensor([]) + router_argsort = torch.concat(self.router_argsort, dim=0) + return router_argsort[:, 0].float() + + def state_dict(self) -> dict: + """Return the internal state for checkpointing.""" + return { + "router_argsort": [tensor.cpu().clone() for tensor in self.router_argsort], + "top_k": self.top_k, + } + + def load_state_dict(self, state_dict: dict) -> None: + """Load the internal state from a checkpoint.""" + self.router_argsort = [tensor.cpu() for tensor in state_dict["router_argsort"]] + self.top_k = state_dict["top_k"] + + def get_progress_info(self) -> dict: + """Get progress information.""" + return { + "num_batches_processed": len(self.router_argsort), + "total_tokens_processed": sum(tensor.shape[0] for tensor in self.router_argsort) + if self.router_argsort + else 0, + } + + +class RankedChoiceVotingHookNemotronH(RankedChoiceVotingHook): + """Ranked choice voting hook for NemotronH models. + + In NemotronH, router_logits is an internal temporary state that never leaves + the forward() function. We reconstruct router_logits from the input hidden_states. + """ + + def __call__( + self, module: nn.Module, args: tuple[torch.Tensor, ...], output: torch.Tensor + ) -> None: + """Forward hook that reconstructs router logits from hidden states.""" + hidden_states = args[0] + hidden_states = hidden_states.view(-1, module.config.hidden_size) + router_logits = nn.functional.linear( + hidden_states.type(torch.float32), module.weight.type(torch.float32) + ) + super().__call__(module, args, router_logits) + + +class Qwen3VLRemoveExpertsIndependentHook(RemoveExpertsIndependentHook): + """Expert removal importance hook for Qwen3-VL models. + + TODO: Implement get_router_logits_and_routed_experts based on Qwen3-VL MoE forward pass. + """ + + def get_router_logits_and_routed_experts( + self, hidden_states: torch.Tensor, router_logits: torch.Tensor | None = None + ) -> tuple[torch.Tensor, torch.Tensor]: + """Extract router logits and expert outputs for Qwen3-VL MoE. + + Note: This is a placeholder implementation. Implement based on Qwen3VLMoeSparseMoe forward. + """ + batch_size = ( + hidden_states.shape[0] * hidden_states.shape[1] + if hidden_states.ndim > 2 + else hidden_states.shape[0] + ) + router_logits_out = torch.zeros( + batch_size, self.num_local_experts, device=hidden_states.device + ) + routed_experts = hidden_states.view(-1, hidden_states.shape[-1]) + return router_logits_out, routed_experts + + +class GptOssRemoveExpertsIndependentHook(RemoveExpertsIndependentHook): + """Expert removal importance hook for GPT-OSS models. + + TODO: Implement get_router_logits_and_routed_experts based on GPT-OSS MoE forward pass. + This is a placeholder implementation that allows the framework to run. + """ + + def get_router_logits_and_routed_experts( + self, hidden_states: torch.Tensor, router_logits: torch.Tensor | None = None + ) -> tuple[torch.Tensor, torch.Tensor]: + """Extract router logits and expert outputs for GPT-OSS MoE. + + Note: This is a placeholder implementation. For proper expert scoring, + implement based on GptOssSparseMoeBlock forward pass. + + Args: + hidden_states: Input tensor of shape (batch, seq_len, hidden_dim) + router_logits: Optional pre-computed router logits + + Returns: + tuple of (router_logits, routed_experts): + - router_logits: Shape (num_tokens, num_local_experts) - zeros as placeholder + - routed_experts: Original hidden states (no-op) + """ + batch_size = ( + hidden_states.shape[0] * hidden_states.shape[1] + if hidden_states.ndim > 2 + else hidden_states.shape[0] + ) + router_logits_out = torch.zeros( + batch_size, self.num_local_experts, device=hidden_states.device + ) + routed_experts = hidden_states.view(-1, hidden_states.shape[-1]) + return router_logits_out, routed_experts diff --git a/modelopt/torch/puzzletron/anymodel/README.md b/modelopt/torch/puzzletron/anymodel/README.md new file mode 100644 index 000000000..a8b960165 --- /dev/null +++ b/modelopt/torch/puzzletron/anymodel/README.md @@ -0,0 +1,204 @@ +# AnyModel Guide + +This guide explains how to add support for new models in the compress pipeline. + +## Convert model + +Convert a HuggingFace model to Puzzletron format. + +Step 1: Create Model Descriptor + +Extend `ModelDescriptor` and implement `layer_name_predicates()` to define regex patterns for grouping weights into subblocks (embeddings, lm_head, block_N_ffn, block_N_attention). + +Key points: + +- Find weight names on the model's HuggingFace page → click "Files info" to see the safetensors structure with all tensor names (example: [Llama-3.1-8B-Instruct](https://huggingface.co/meta-llama/Llama-3.1-8B-Instruct?show_file_info=model.safetensors.index.json)) + +See example: [llama_model_descriptor.py](models/llama/llama_model_descriptor.py) + +Step 2: Create Converter + +Extend `Converter` and implement `create_block_configs_from_main_config()` to create per-layer BlockConfigs from the HuggingFace config. + +Key points: + +- Import correct HuggingFace config class (e.g., `MistralConfig`, `LlamaConfig`, `Qwen2Config`). Find it in the transformers source: `github.com/huggingface/transformers/tree/main/src/transformers/models//configuration_.py` + +See example: [llama_converter.py](models/llama/llama_converter.py) + +Step 3: Create `models//__init__.py` + +Export descriptor and converter classes: + +```python +from models.._model_descriptor import MyModelDescriptor +from models.._converter import MyConverter +``` + +Step 4: Register in `models/__init__.py` + +Add import to trigger factory registration: + +```python +from models. import * +``` + +## Usage + +```python +from scripts.convert_any_model import convert_model + +convert_model( + input_dir="path/to/hf_checkpoint", + output_dir="path/to/puzzletron_checkpoint", + converter="model_name", +) +``` + +## Compress model + +Run pruning and compression on a Puzzletron model. + +Step 1: Implement ModelDescriptor methods for compression + +Add to your `ModelDescriptor`: + +- `decoder_layer_cls()` - return the decoder layer class(es) to patch for heterogeneous config support +- `block_config_to_layer_overrides()` - map BlockConfig to layer override dict (see [details](#implementing-block_config_to_layer_overrides)) +- `init_rotary_embedding()` - reinitialize rotary embeddings after model loading (see [details](#implementing-init_rotary_embedding)) +- `input_embedding_name()` - return the name of the input embedding layer (see [details](#implementing-path-based-methods)) +- `output_embedding_name()` - return the name of the output embedding layer (see [details](#implementing-path-based-methods)) +- `layer_block_name()` - return the name pattern for decoder layers (see [details](#implementing-path-based-methods)) +- `final_norm_name()` - return the name of the final normalization layer (see [details](#implementing-path-based-methods)) +- `attn_no_op_post_init()` - replace attention sublayers with no-op modules +- `mlp_no_op_post_init()` - replace MLP sublayers with no-op modules + +Step 2: Create FFN Layer Descriptor + +Extend `FFNIntermediateLayerDescriptor` to define model-specific paths for FFN pruning hooks (`down_proj_name`, `ffn_prefix_name`, `linear_weight_names`). Derive values from your model's weight names in `layer_name_predicates()`. + +See example: [llama_model_descriptor.py](models/llama/llama_model_descriptor.py) → `LlamaFFNIntermediateLayerDescriptor` + +Step 3: Configure YAML files + +Update the main model config YAML: + +- Set `descriptor` to match the name used in `@ModelDescriptorFactory.register_decorator("your_model_name")` +- See example: [llama_3_1_8b_instruct.yaml](../../../../tests/gpu/torch/puzzletron/resources/configs/llama_3_1_8b_instruct/llama_3_1_8b_instruct.yaml) + +Update pruning YAML files (`ffn_pruning.yaml`, `expert_pruning.yaml`, etc.): + +- Set `pruning_mixin._target_` to the appropriate mixin class +- Set `layer_descriptor._target_` to your layer descriptor class +- Set `hook_class` to the activation hook for scoring +- Set `target_layer` in `activation_hooks_kwargs` to the layer name for hook attachment +- See examples in [configs/llama_3_1_8b_instruct/pruning/](../../../../tests/gpu/torch/puzzletron/resources/configs/llama_3_1_8b_instruct/pruning/) + +## End-to-end example + +See [test_compress_model.py](../../../../tests/gpu/torch/puzzletron/test_compress.py) for a complete example that runs both convert and compression steps. + +--- + +## Advanced Topics + +## Pruning Configuration + +### Pruning YAML Structure + +Each pruning type has a YAML config with these key fields: + +```yaml +pruning_mixin: + _target_: pruning._pruning_mixin. + layer_descriptor: + _target_: models.. + +hook_class: ${get_object:utils.activation_hooks.hooks.} +activation_hooks_kwargs: + method: + target_layer: "" # e.g., "mlp.down_proj", "self_attn.o_proj" +``` + +| Field | Description | +|-------|-------------| +| `pruning_mixin._target_` | Mixin class that orchestrates this pruning type | +| `layer_descriptor._target_` | Model-specific class defining layer paths for hooks | +| `hook_class` | Activation hook class for importance scoring | +| `target_layer` | Layer name (relative to decoder block) where hooks attach | + +### Adding a New Hook Class + +1. **Implement the hook** in `modelopt/torch/nas/plugins/megatron_hooks/base_hooks.py`: + - Extend an existing hook base class (e.g., `RemoveExpertsIndependentHook`) + - Implement required methods (e.g., `get_router_logits_and_routed_experts`) + +2. **Register the hook** in the appropriate pruning mixin's `supported_hooks()`: + + For FFN pruning (`pruning/ffn_intermediate_pruning_mixin.py`): + + ```python + def supported_hooks(self) -> List[Type[ActivationsHook]]: + return [IndependentChannelContributionHook, IterativeChannelContributionHook, YourNewHook] + ``` + + For expert removal (`pruning/expert_removal_pruning_mixin.py`): + + ```python + def supported_hooks(self) -> List[Type[ActivationsHook]]: + return [RankedChoiceVotingHook, ..., YourNewHook] + ``` + +3. **Reference in YAML**: + + ```yaml + hook_class: ${get_object:utils.activation_hooks.hooks.YourNewHook} + ``` + +### Pruning Types Reference + +| Type | Mixin | Example Hooks | +|------|-------|---------------| +| FFN intermediate | [`FFNIntermediatePruningMixIn`](../pruning/ffn_intermediate_pruning_mixin.py) | [`IterativeChannelContributionHook`](../../../nas/plugins/megatron_hooks/base_hooks.py), [`IndependentChannelContributionHook`](../../../nas/plugins/megatron_hooks/base_hooks.py) | +| Expert removal | [`ExpertRemovalPruningMixIn`](../pruning/expert_removal_pruning_mixin.py) | [`NemotronHRemoveExpertsIndependentHook`](../../../nas/plugins/megatron_hooks/base_hooks.py), [`Qwen3VLRemoveExpertsIndependentHook`](../../../nas/plugins/megatron_hooks/base_hooks.py) | +| KV heads | [`KVHeadsPruningMixIn`](../pruning/kv_heads_pruning_mixin.py) | [`IndependentKvHeadContributionHook`](../../../nas/plugins/megatron_hooks/base_hooks.py) | + +## Implementing `block_config_to_layer_overrides` + +Maps Puzzletron's [`BlockConfig`](../decilm/deci_lm_hf_code/block_config.py) fields to HuggingFace config attribute names. Only override attributes that change during pruning: + +| BlockConfig Field | HuggingFace Attribute (check `config.json`) | +|-------------------|---------------------------------------------| +| `attention.num_key_value_heads` | `num_key_value_heads` | +| `ffn.intermediate_size` | `intermediate_size` | +| `ffn.moe.num_local_experts` | `num_experts` or `n_routed_experts` (model-specific) | +| `ffn.moe.expert_intermediate_dim` | `moe_intermediate_size` | + +**Tip**: Check the model's `config.json` for exact attribute names - they vary between models. + +See examples: [qwen3_vl](models/qwen3_vl/qwen3_vl_model_descriptor.py), [nemotron_h](models/nemotron_h/nemotron_h_model_descriptor.py) + +--- + +## Implementing path-based methods + +These methods return paths derived from the model's weight names: + +- `input_embedding_name()`, `output_embedding_name()`, `layer_block_name()`, `final_norm_name()` + +Find them on the model's HuggingFace page → "Files info" → safetensors structure (example: [Llama-3.1-8B-Instruct](https://huggingface.co/meta-llama/Llama-3.1-8B-Instruct?show_file_info=model.safetensors.index.json)). + +See example: [llama_model_descriptor.py](models/llama/llama_model_descriptor.py) + +--- + +## Implementing `init_rotary_embedding` + +Rotary embeddings are computed modules (not saved weights). After model sharding, they need re-initialization on the correct device/dtype. + +Look in `github.com/huggingface/transformers/tree/main/src/transformers/models//modeling_.py` for: + +- `class.*Rotary` — the rotary embedding class name and constructor arguments +- `self.rotary_emb` — the attribute path + +See example: [llama_model_descriptor.py](models/llama/llama_model_descriptor.py) diff --git a/modelopt/torch/puzzletron/anymodel/__init__.py b/modelopt/torch/puzzletron/anymodel/__init__.py new file mode 100644 index 000000000..e1755a16d --- /dev/null +++ b/modelopt/torch/puzzletron/anymodel/__init__.py @@ -0,0 +1,64 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# mypy: ignore-errors + +"""AnyModel: Architecture-agnostic model compression for HuggingFace models. + +This module provides a declarative approach to model compression that works with +any HuggingFace model without requiring custom modeling code. Instead of duplicating +HuggingFace modeling classes, AnyModel uses ModelDescriptors that define: + +1. Which decoder layer class(es) to patch for heterogeneous configs +2. How to map BlockConfig to layer-specific overrides +3. Weight name patterns for subblock checkpointing + +Example usage: + >>> from modelopt.torch.puzzletron.anymodel import convert_model + >>> convert_model( + ... input_dir="path/to/hf_checkpoint", + ... output_dir="path/to/anymodel_checkpoint", + ... converter="llama", + ... ) + +Supported models: + - llama: Llama 2, Llama 3, Llama 3.1, Llama 3.2 + - (more to come: qwen2, mistral_small, etc.) +""" + +# Import models to trigger factory registration +from modelopt.torch.puzzletron.anymodel import models # noqa: F401 +from modelopt.torch.puzzletron.anymodel.converter import Converter, ConverterFactory, convert_model +from modelopt.torch.puzzletron.anymodel.model_descriptor import ( + ModelDescriptor, + ModelDescriptorFactory, +) +from modelopt.torch.puzzletron.anymodel.puzzformer import ( + MatchingZeros, + Same, + deci_x_patcher, + return_tuple_of_size, +) + +__all__ = [ + "Converter", + "ConverterFactory", + "ModelDescriptor", + "ModelDescriptorFactory", + "deci_x_patcher", + "MatchingZeros", + "Same", + "return_tuple_of_size", + "convert_model", +] diff --git a/modelopt/torch/puzzletron/anymodel/model_descriptor/model_descriptor_factory.py b/modelopt/torch/puzzletron/anymodel/model_descriptor/model_descriptor_factory.py index 23a42da58..45fe83f47 100644 --- a/modelopt/torch/puzzletron/anymodel/model_descriptor/model_descriptor_factory.py +++ b/modelopt/torch/puzzletron/anymodel/model_descriptor/model_descriptor_factory.py @@ -36,8 +36,21 @@ } -def resolve_descriptor_from_pretrained(pretrained: str | None, trust_remote_code: bool = True): - """Resolve the model descriptor by loading the checkpoint config and mapping model_type.""" +def resolve_descriptor_from_pretrained(pretrained: str | None, trust_remote_code: bool = False): + """Resolve the model descriptor by loading the checkpoint config and mapping model_type. + + Args: + pretrained: Path to a pretrained model checkpoint or HuggingFace model identifier. + trust_remote_code: If True, allows execution of custom code from the model repository. + This is a security risk if the model source is untrusted. Only set to True if you + trust the source of the model. Defaults to False for security. + + Returns: + The resolved ModelDescriptor class for the detected model type. + + Raises: + ValueError: If pretrained is not provided or if the model type cannot be auto-detected. + """ if not pretrained: raise ValueError("pretrained must be provided") diff --git a/modelopt/torch/puzzletron/anymodel/models/__init__.py b/modelopt/torch/puzzletron/anymodel/models/__init__.py new file mode 100644 index 000000000..9928854b5 --- /dev/null +++ b/modelopt/torch/puzzletron/anymodel/models/__init__.py @@ -0,0 +1,24 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Import models to trigger factory registration +from modelopt.torch.puzzletron.anymodel.models.gpt_oss_20b import * +from modelopt.torch.puzzletron.anymodel.models.llama import * +from modelopt.torch.puzzletron.anymodel.models.mistral_small import * +from modelopt.torch.puzzletron.anymodel.models.nemotron_h import * +from modelopt.torch.puzzletron.anymodel.models.nemotron_h_v2 import * +from modelopt.torch.puzzletron.anymodel.models.qwen2 import * +from modelopt.torch.puzzletron.anymodel.models.qwen3_8b import * +from modelopt.torch.puzzletron.anymodel.models.qwen3_vl_30b_a3b_instruct import * diff --git a/modelopt/torch/puzzletron/decilm/deci_lm_hf_code/block_config.py b/modelopt/torch/puzzletron/decilm/deci_lm_hf_code/block_config.py index d5eebfa35..a7212516a 100644 --- a/modelopt/torch/puzzletron/decilm/deci_lm_hf_code/block_config.py +++ b/modelopt/torch/puzzletron/decilm/deci_lm_hf_code/block_config.py @@ -19,7 +19,7 @@ import warnings from abc import abstractmethod from dataclasses import dataclass -from typing import Any, Optional, Type, Union, get_args, get_origin +from typing import Any, List, Optional, Type, Union, get_args, get_origin @dataclass(frozen=True, kw_only=True) @@ -178,106 +178,51 @@ class Llama4AttentionConfig(BaseDataclass): @dataclass(frozen=True, kw_only=True) class AttentionConfig(SubblockConfig): - n_heads_in_group: Optional[int] = None - window_length: Optional[int] = None - num_sink_tokens: Optional[int] = None - use_prefill_window_in_sink_attention: bool = False - unshifted_sink: bool = False - mamba: Optional[MambaConfig] = None + num_key_value_heads: Optional[int] = None llama4: Optional[Llama4AttentionConfig] = None + mamba: Optional[MambaConfig] = None def __post_init__(self): super().__post_init__() if self.no_op: - assert not self.replace_with_linear assert not self.is_mamba assert not self.is_llama4 - if self.no_op or self.replace_with_linear or self.is_mamba: + if self.no_op or self.is_mamba: for irrelevant_att in [ - "n_heads_in_group", - "window_length", - "num_sink_tokens", - "use_prefill_window_in_sink_attention", - "unshifted_sink", - "attention_chunk_size", - "attn_scale", - "floor_scale", - "attn_temperature_tuning", - "attention_dropout", - "use_qk_norm", + "num_key_value_heads", ]: self._force_setattr(irrelevant_att, None) else: - assert self.n_heads_in_group is not None - - if self.is_sink: - assert not (self.unshifted_sink and self.use_prefill_window_in_sink_attention), ( - "Unshifted sink uses its own kind of explicit masking, not standard window. " - "Set use_prefill_window_in_sink_attention to False." - ) - assert not (self.num_sink_tokens == 0 and not self.unshifted_sink), ( - "Fake sink attention with 0 sink tokens is only supported with unshifted_sink=True" - ) - - if self.is_llama4: - assert not self.is_sink, "Sink not support with Llama4 currently" - assert not self.is_sliding, "Sliding window not support with Llama4 currently" - assert not self.unshifted_sink, "Unshifted sink not support with Llama4 currently" + assert self.num_key_value_heads is not None def to_blockconfig(self) -> "BlockConfig": return BlockConfig(attention=self, ffn=FFNConfig(no_op=True)) @property - def prefill_sliding_window(self) -> Optional[int]: - if self.window_length is not None: - if not self.is_sink or self.use_prefill_window_in_sink_attention: - return self.window_length - return None - - @property - def is_sliding(self) -> bool: - return self.prefill_sliding_window is not None - - @property - def is_sink(self) -> bool: - return (self.window_length is not None) and (self.num_sink_tokens is not None) + def is_llama4(self) -> bool: + return self.llama4 is not None @property def is_mamba(self) -> bool: return self.mamba is not None - @property - def is_llama4(self) -> bool: - return self.llama4 is not None - @dataclass(frozen=True, kw_only=True) class FFNConfig(SubblockConfig): - gated: Optional[bool] = ( - True # Gated Linear Unit e.g. SwiGLU or vanilla MLP (up -> activation -> down) - ) - hidden_act: Optional[str] = "silu" moe: Optional[MoEConfig] = None intermediate_size: Optional[int] = None def __post_init__(self): super().__post_init__() - if self.no_op or self.replace_with_linear: - self._force_setattr("gated", None) - self._force_setattr("hidden_act", None) + if self.no_op: self._force_setattr("moe", None) self._force_setattr("intermediate_size", None) elif self.is_moe: - self._force_setattr("gated", None) - self._force_setattr("hidden_act", None) self._force_setattr("intermediate_size", None) else: - assert self.intermediate_size is not None, ( - "Intermediate size must be provided for an FFN block" - ) - assert self.intermediate_size % 256 == 0, "Intermediate size must be divisible by 256" + assert self.intermediate_size is not None, "Intermediate size must be provided for an FFN block" def to_blockconfig(self) -> "BlockConfig": return BlockConfig(attention=AttentionConfig(no_op=True), ffn=self) @@ -306,3 +251,25 @@ def __post_init__(self): BlockConfig(**block_config) for block_config in self.parallel_blocks ] self._force_setattr("parallel_blocks", initialized_block_configs) + + def to_dict(self) -> dict: + """Convert BlockConfig to a dictionary.""" + return dataclasses.asdict(self) + + +def maybe_cast_block_configs( + block_configs: List[BlockConfig | dict] | None, +) -> List[BlockConfig] | None: + """Cast a list of dicts to BlockConfig objects if needed. + + Args: + block_configs: List of BlockConfig or dict objects, or None. + + Returns: + List of BlockConfig objects, or None if input is None/empty. + """ + if not block_configs: + return block_configs + if isinstance(block_configs[0], dict): + return [BlockConfig(**conf) for conf in block_configs] + return block_configs diff --git a/modelopt/torch/puzzletron/pruning/expert_removal_pruning_mixin.py b/modelopt/torch/puzzletron/pruning/expert_removal_pruning_mixin.py new file mode 100644 index 000000000..96d3489f5 --- /dev/null +++ b/modelopt/torch/puzzletron/pruning/expert_removal_pruning_mixin.py @@ -0,0 +1,239 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass, field +from typing import Any, Dict, List, Optional, Tuple, Type + +import torch +from transformers import PretrainedConfig + +from modelopt.torch.nas.plugins.megatron_hooks.base_hooks import ( + ForwardHook, + GptOssRemoveExpertsIndependentHook, + NemotronHRemoveExpertsIndependentHook, + Qwen3VLRemoveExpertsIndependentHook, + RankedChoiceVotingHook, + RankedChoiceVotingHookNemotronH, +) +from modelopt.torch.puzzletron.pruning.pruning_mixin import LayerDescriptor, PruningMixIn +from modelopt.torch.puzzletron.pruning.pruning_utils import MlpInitMode, _init_moe_module + + +@dataclass +class ExpertRemovalLayerDescriptor(LayerDescriptor): + """ + TODO - Add Shared expert weights in case it's prunable. + TODO - consider removing the segmentation between weight and bias, doesn't seem to affect the pruning algo. + Attributes: + target_name: module name required to register hooks for scoring_activations, can be a regex if start with the prefix `regex:` + moe_prefix_name: moe prefix layer name, should include a placeholder for `layer_idx` to be repeated for all layers. i.e: `model.layers.{layer_idx}.moe` + expert_prefix_name: expert prefix layer name relative to moe_prefix, should include a placeholder for `expert_idx` to be repeated for all experts. i.e: `experts.{expert_idx}` + router_weights: List of the router weight names relative to moe_prefix. + router_biases: List of the router bias names relative to moe_prefix. + expert_weights: List of the expert weight names relative to expert_prefix (for per-expert format). + expert_biases: List of the expert bias names relative to expert_prefix (for per-expert format). + is_fused_experts: If True, experts are stored as single fused tensors with shape [num_experts, ...]. + If False (default), experts are stored as separate tensors per expert. + fused_expert_weights: List of fused expert weight names relative to moe_prefix (for fused format). + e.g., ["experts.gate_up_proj", "experts.down_proj"] + """ + + target_name: str + moe_prefix_name: str + expert_prefix_name: str = "" + router_weights: List[str] = field(default_factory=list) + router_biases: List[str] = field(default_factory=list) + expert_weights: List[str] = field(default_factory=list) + expert_biases: List[str] = field(default_factory=list) + is_fused_experts: bool = False + fused_expert_weights: List[str] = field(default_factory=list) + + def module_name_regex(self) -> str: + return self.target_name + + def moe_prefix(self, layer_idx: int) -> str: + return self.moe_prefix_name.format(layer_idx=layer_idx) + + def expert_prefix(self, layer_idx: int, expert_idx: int) -> str: + _expert_prefix = self.moe_prefix_name + "." + self.expert_prefix_name + return _expert_prefix.format(layer_idx=layer_idx, expert_idx=expert_idx) + + +class ExpertRemovalPruningMixIn(PruningMixIn): + def __init__(self, layer_descriptor: ExpertRemovalLayerDescriptor): + assert isinstance(layer_descriptor, ExpertRemovalLayerDescriptor) + super().__init__(layer_descriptor) + + def supported_hooks(self) -> List[Type[ForwardHook]]: + return [ + RankedChoiceVotingHook, + RankedChoiceVotingHookNemotronH, + NemotronHRemoveExpertsIndependentHook, + Qwen3VLRemoveExpertsIndependentHook, + GptOssRemoveExpertsIndependentHook, + ] + + def prune_single_layer( + self, + layer_idx: int, + parent_state_dict: dict, + new_state_dict: dict, + original_config: PretrainedConfig, + new_config: PretrainedConfig, + mlp_init_mode: MlpInitMode, + mlp_init_config: Optional[dict[str, Any]], + keys: dict, + **kwargs, + ) -> Dict[str, torch.Tensor]: + layer_out_state_dict = {} + + child_block_config = new_config.block_configs[layer_idx] + parent_block_config = original_config.block_configs[layer_idx] + + if not parent_block_config.ffn.is_moe: + return layer_out_state_dict + + new_num_experts = child_block_config.ffn.moe.num_local_experts + orig_num_experts = parent_block_config.ffn.moe.num_local_experts + + child_router_keys, new_experts_keys = self._generate_moe_keys(layer_idx, new_num_experts) + parent_router_keys, orig_experts_keys = self._generate_moe_keys(layer_idx, orig_num_experts) + + # Pop parent's router keys from copy list; child-only router keys will be initialized below + for rk in sum(parent_router_keys.values(), []): + if rk in keys: + keys.pop(rk) + for key in sum(orig_experts_keys.values(), []): + if key in keys: + keys.pop(key) + + if self.layer_descriptor.is_fused_experts: + # Fused format: unbundle single tensor [num_experts, ...] into list of per-expert tensors + orig_experts_weights = {} + for name, fused_keys in orig_experts_keys.items(): + fused_tensor = parent_state_dict[fused_keys[0]] # Single fused tensor + orig_experts_weights[name] = [fused_tensor[i] for i in range(orig_num_experts)] + + new_experts_weights = {} + for name, fused_keys in new_experts_keys.items(): + fused_tensor = new_state_dict[fused_keys[0]] # Single fused tensor + new_experts_weights[name] = [fused_tensor[i] for i in range(new_num_experts)] + else: + # Per-expert format: load each expert tensor separately + orig_experts_weights = { + name: [parent_state_dict[key] for key in orig_experts_module_keys] + for name, orig_experts_module_keys in orig_experts_keys.items() + } + new_experts_weights = { + name: [new_state_dict[key] for key in new_experts_module_keys] + for name, new_experts_module_keys in new_experts_keys.items() + } + + orig_router_weights = { + name: [parent_state_dict[key] for key in _module_router_keys] + for name, _module_router_keys in parent_router_keys.items() + } + new_router_weights = { + name: [new_state_dict[key] for key in _module_router_keys] + for name, _module_router_keys in child_router_keys.items() + } + + out_router_weights, out_experts_weights = _init_moe_module( + layer_idx=layer_idx, + mlp_init_mode=mlp_init_mode, + mlp_init_config=mlp_init_config, + orig_router_weights=orig_router_weights, + orig_experts_weights=orig_experts_weights, + new_router_weights=new_router_weights, + new_experts_weights=new_experts_weights, + orig_num_experts=orig_num_experts, + new_num_experts=new_num_experts, + ) + assert new_experts_keys.keys() == out_experts_weights.keys(), ( + "new_experts_keys and out_experts_weights must have the same keys" + ) + assert child_router_keys.keys() == out_router_weights.keys(), ( + "child_router_keys and out_router_weights must have the same keys" + ) + + for name in child_router_keys.keys(): + layer_out_state_dict.update(zip(child_router_keys[name], out_router_weights[name])) + + if self.layer_descriptor.is_fused_experts: + # Fused format: rebundle list of per-expert tensors into single fused tensor + for name in new_experts_keys.keys(): + fused_key = new_experts_keys[name][0] # Single key for fused tensor + fused_tensor = torch.stack(out_experts_weights[name], dim=0) # [num_experts, ...] + layer_out_state_dict[fused_key] = fused_tensor + else: + # Per-expert format: each expert has its own key + for name in new_experts_keys.keys(): + layer_out_state_dict.update(zip(new_experts_keys[name], out_experts_weights[name])) + + return layer_out_state_dict + + def _generate_moe_keys( + self, layer_idx: int, num_experts: int + ) -> Tuple[Dict[str, List[str]], dict[str, list[str]]]: + """ + Generate MoE weight keys for router and experts. + TODO simplify or better define the data structure of the moe keys returned. + + :return: tuple of router_keys and expert_keys, all are absolute names relative to the model root: + * router_keys structure: + {"weight: [], bias: []"} + * expert_keys structure (per-expert format): + {": []} + i.e: + { + "down_proj.weight": ["model...experts.0.down_proj.weight", ..., "model...experts.N.down_proj.weight"], + ... + } + * expert_keys structure (fused format): + {": []} + i.e: + { + "experts.gate_up_proj": ["model...experts.gate_up_proj"], + "experts.down_proj": ["model...experts.down_proj"], + } + """ + self.layer_descriptor: ExpertRemovalLayerDescriptor + moe_prefix = self.layer_descriptor.moe_prefix(layer_idx) + + router_keys = { + "weight": [ + f"{moe_prefix}.{_weight}" for _weight in self.layer_descriptor.router_weights + ], + "bias": [f"{moe_prefix}.{_bias}" for _bias in self.layer_descriptor.router_biases], + } + + if self.layer_descriptor.is_fused_experts: + # Fused format: single tensor per weight type with shape [num_experts, ...] + experts_module_names = {} + for fused_weight in self.layer_descriptor.fused_expert_weights: + experts_module_names[fused_weight] = [f"{moe_prefix}.{fused_weight}"] + else: + # Per-expert format: separate tensor for each expert + expert_key_names = ( + self.layer_descriptor.expert_weights + self.layer_descriptor.expert_biases + ) + experts_module_names = {} + for key_name in expert_key_names: + experts_module_names[key_name] = [ + f"{self.layer_descriptor.expert_prefix(layer_idx, expert_idx)}.{key_name}" + for expert_idx in range(num_experts) + ] + + return router_keys, experts_module_names diff --git a/modelopt/torch/puzzletron/pruning/ffn_intermediate_pruning_mixin.py b/modelopt/torch/puzzletron/pruning/ffn_intermediate_pruning_mixin.py new file mode 100644 index 000000000..b3d9b8884 --- /dev/null +++ b/modelopt/torch/puzzletron/pruning/ffn_intermediate_pruning_mixin.py @@ -0,0 +1,102 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# mypy: ignore-errors + +from dataclasses import dataclass, field +from typing import Any, Dict, List, Optional, Type + +import torch +from transformers import PretrainedConfig + +from modelopt.torch.nas.plugins.megatron_hooks.base_hooks import ( + ForwardHook, + IndependentChannelContributionHook, + IterativeChannelContributionHook, +) +from modelopt.torch.puzzletron.pruning.pruning_mixin import LayerDescriptor, PruningMixIn +from modelopt.torch.puzzletron.tools.bypassed_training.child_init import ( + MlpInitMode, + _init_mlp_module, +) + + +@dataclass +class FFNIntermediateLayerDescriptor(LayerDescriptor): + down_proj_name: str + ffn_prefix_name: str + linear_weight_names: List[str] = field(default_factory=list) + + def module_name_regex(self) -> str: + return self.down_proj_name + + def ffn_prefix(self, layer_idx: int) -> str: + return self.ffn_prefix_name.format(layer_idx=layer_idx) + + +class FFNIntermediatePruningMixIn(PruningMixIn): + def __init__(self, layer_descriptor: FFNIntermediateLayerDescriptor): + assert isinstance(layer_descriptor, FFNIntermediateLayerDescriptor) + super().__init__(layer_descriptor) + + def supported_hooks(self) -> List[Type[ForwardHook]]: + return [IndependentChannelContributionHook, IterativeChannelContributionHook] + + def prune_single_layer( + self, + layer_idx: int, + parent_state_dict: dict, + new_state_dict: dict, + original_config: PretrainedConfig, + new_config: PretrainedConfig, + mlp_init_mode: MlpInitMode, + mlp_init_config: Optional[dict[str, Any]], + keys: dict, + keys_to_remove: dict, + **kwargs, + ) -> Dict[str, torch.Tensor]: + layer_out_state_dict = {} + # Hardcoded strings + mlp_prefix = self.layer_descriptor.ffn_prefix(layer_idx) + mlp_key_names = [ + f"{mlp_prefix}.{name}.weight" for name in self.layer_descriptor.linear_weight_names + ] + mlp_keys = [keys.get(module_name) for module_name in mlp_key_names] + mlp_keys = [k for k in mlp_keys if k is not None] + + for key in mlp_keys: + keys_to_remove[f"{mlp_prefix}.{key.split('.')[-2]}.weight"] = key + + pruned_filters = None + projection_matrix = None + + for mlp_key in mlp_keys: + expanded_dim = 1 if self.layer_descriptor.down_proj_name in mlp_key else 0 + if mlp_key in new_state_dict.keys(): + mlp_module_weight, pruned_filters, projection_matrix = _init_mlp_module( + mlp_init_mode, + mlp_prefix, + expanded_dim, + layer_idx, + new_state_dict[mlp_key], + new_config, + parent_state_dict[mlp_key], + original_config, + mlp_init_config, + pruned_filters, + projection_matrix, + ) + layer_out_state_dict[mlp_key] = mlp_module_weight + + return layer_out_state_dict diff --git a/modelopt/torch/puzzletron/pruning/kv_heads_pruning_mixin.py b/modelopt/torch/puzzletron/pruning/kv_heads_pruning_mixin.py new file mode 100644 index 000000000..f93e4b77a --- /dev/null +++ b/modelopt/torch/puzzletron/pruning/kv_heads_pruning_mixin.py @@ -0,0 +1,127 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# mypy: ignore-errors +from dataclasses import dataclass, field +from typing import Any, List, Optional, Type + +from transformers import PretrainedConfig + +from modelopt.torch.nas.plugins.megatron_hooks.base_hooks import ( + ForwardHook, + IndependentKvHeadContributionHook, +) +from modelopt.torch.puzzletron.pruning.pruning_mixin import LayerDescriptor, PruningMixIn +from modelopt.torch.puzzletron.pruning.pruning_utils import ( + GQAInitMode, + _init_attention_biases, + _init_attention_weights, +) + + +@dataclass +class KVHeadsLayerDescriptor(LayerDescriptor): + o_proj_name: str + attn_prefix_name: str + qkvo_weight_names: List[str] = field(default_factory=list) + + def module_name_regex(self) -> str: + return self.o_proj_name + + def attn_prefix(self, layer_idx: int) -> str: + return self.attn_prefix_name.format(layer_idx=layer_idx) + + +class KVHeadsPruningMixIn(PruningMixIn): + def __init__(self, layer_descriptor: KVHeadsLayerDescriptor): + assert isinstance(layer_descriptor, KVHeadsLayerDescriptor) + super().__init__(layer_descriptor) + + def supported_hooks(self) -> List[Type[ForwardHook]]: + return [IndependentKvHeadContributionHook] + + def prune_single_layer( + self, + layer_idx: int, + parent_state_dict: dict, + new_state_dict: dict, + original_config: PretrainedConfig, + new_config: PretrainedConfig, + gqa_init_mode: GQAInitMode, + mlp_init_config: Optional[dict[str, Any]], + is_original_mha: bool, + keys: dict, + keys_to_remove: dict, + **kwargs, + ): + layer_out_state_dict = {} + + attn_prefix = self.layer_descriptor.attn_prefix(layer_idx) + q_name, k_name, v_name, o_name = [ + f"{attn_prefix}.{proj_name}" for proj_name in self.layer_descriptor.qkvo_weight_names + ] + + head_size = new_config.head_dim + for part in ["weight", "bias"]: + attn_keys = [f"{name}.{part}" for name in [q_name, k_name, v_name, o_name]] + q_key, k_key, v_key, o_key = attn_keys + + # Drop attn keys that don't exist and required to be in the new state_dict + attn_keys = [key for key in attn_keys if key in new_state_dict.keys()] + if len(attn_keys) > 0 and all(key in keys for key in attn_keys): + for key in attn_keys: + keys_to_remove[key] = keys[key] + is_student_and_teacher_have_same_attention_implementation = all( + key in new_state_dict.keys() for key in attn_keys + ) + if is_student_and_teacher_have_same_attention_implementation: + if part == "weight": + wq, wk, wv, wo = _init_attention_weights( + gqa_init_mode=gqa_init_mode, + layer_idx=layer_idx, + new_state_dict=new_state_dict, + new_config=new_config, + original_state_dict=parent_state_dict, + original_config=original_config, + q_key=q_key, + k_key=k_key, + v_key=v_key, + o_key=o_key, + is_original_mha=is_original_mha, + head_size=head_size, + mlp_init_config=mlp_init_config, + ) + layer_out_state_dict[q_key], layer_out_state_dict[k_key] = wq, wk + layer_out_state_dict[v_key], layer_out_state_dict[o_key] = wv, wo + else: + bias_sd = _init_attention_biases( + gqa_init_mode=gqa_init_mode, + layer_idx=layer_idx, + new_state_dict=new_state_dict, + new_config=new_config, + original_state_dict=parent_state_dict, + original_config=original_config, + q_key=q_key, + k_key=k_key, + v_key=v_key, + o_key=o_key, + is_original_mha=is_original_mha, + head_size=head_size, + mlp_init_config=mlp_init_config, + ) + for bias_key, sd_key in zip("qkvo", [q_key, k_key, v_key, o_key]): + if bias_key in bias_sd.keys(): + layer_out_state_dict[sd_key] = bias_sd[bias_key] + + return layer_out_state_dict diff --git a/modelopt/torch/puzzletron/pruning/pruning_ckpts.py b/modelopt/torch/puzzletron/pruning/pruning_ckpts.py index 5a0dfed01..823f42faf 100644 --- a/modelopt/torch/puzzletron/pruning/pruning_ckpts.py +++ b/modelopt/torch/puzzletron/pruning/pruning_ckpts.py @@ -23,14 +23,22 @@ import json import os import time +from typing import Optional from omegaconf import DictConfig -from modelopt.torch.puzzletron.tools.bypassed_training.child_init import ( +from modelopt.torch.puzzletron.anymodel.model_descriptor import ModelDescriptorFactory +from modelopt.torch.puzzletron.pruning.expert_removal_pruning_mixin import ExpertRemovalPruningMixIn +from modelopt.torch.puzzletron.pruning.ffn_intermediate_pruning_mixin import ( + FFNIntermediatePruningMixIn, +) +from modelopt.torch.puzzletron.pruning.kv_heads_pruning_mixin import KVHeadsPruningMixIn +from modelopt.torch.puzzletron.pruning.pruning_utils import ( GQAInitMode, HiddenSizeInitMode, LinearInitMode, MlpInitMode, + resolve_pruning_mixin, ) from modelopt.torch.puzzletron.tools.bypassed_training.init_child_from_parent import ( init_child_from_parent, @@ -40,7 +48,7 @@ def launch_ffn_intermediates_prune_ckpt( - cfg: DictConfig, max_save_workers: int | None = None, max_layer_workers: int | None = None + cfg: DictConfig, max_save_workers: Optional[int] = None, max_layer_workers: Optional[int] = None ): for intermediate_size in cfg.pruning.intermediate_size_list: dirname = f"ffn_{intermediate_size}_attn_no_op" @@ -54,14 +62,16 @@ def launch_ffn_intermediates_prune_ckpt( model_config_overrides_json = {"ffn": [{"intermediate_size": intermediate_size}]} mlp_init_config_yaml = cfg.pruning.mlp_init_config_yaml - output_dir = os.path.join(cfg.pruning.pruned_ckpts_outpt_dir, dirname) + output_dir = os.path.join(cfg.pruning.pruned_ckpts_output_dir, dirname) # Profile the overall init_child_from_parent call with optimizations mprint("Starting init_child_from_parent...") start_time = time.time() init_child_from_parent( + descriptor=cfg.descriptor, + pruning_mixin=cfg.pruning.pruning_mixin, parent_checkpoint_dir=cfg.teacher_dir, - model_config_overrides_json=model_config_overrides_json, + model_config_overrides_dict=model_config_overrides_json, output_checkpoint_dir=output_dir, gqa_init_mode=GQAInitMode(cfg.pruning.gqa_init_mode), mlp_init_mode=MlpInitMode(cfg.pruning.mlp_init_mode), @@ -83,7 +93,7 @@ def launch_ffn_intermediates_prune_ckpt( def launch_attn_groups_prune_ckpt( - cfg: DictConfig, max_save_workers: int | None = None, max_layer_workers: int | None = None + cfg: DictConfig, max_save_workers: Optional[int] = None, max_layer_workers: Optional[int] = None ): for n_heads_in_group in cfg.pruning.n_heads_in_group_list: dirname = f"n_heads_in_group{n_heads_in_group}" @@ -98,14 +108,16 @@ def launch_attn_groups_prune_ckpt( model_config_overrides_json = {"attention": [{"n_heads_in_group": n_heads_in_group}]} mlp_init_config_yaml = cfg.pruning.mlp_init_config_yaml - output_dir = os.path.join(cfg.pruning.pruned_ckpts_outpt_dir, dirname) + output_dir = os.path.join(cfg.pruning.pruned_ckpts_output_dir, dirname) # Profile the overall init_child_from_parent call with optimizations mprint("Starting init_child_from_parent...") start_time = time.time() init_child_from_parent( + descriptor=cfg.descriptor, + pruning_mixin=cfg.pruning.pruning_mixin, parent_checkpoint_dir=cfg.teacher_dir, - model_config_overrides_json=model_config_overrides_json, + model_config_overrides_dict=model_config_overrides_json, output_checkpoint_dir=output_dir, gqa_init_mode=GQAInitMode(cfg.pruning.gqa_init_mode), mlp_init_mode=MlpInitMode(cfg.pruning.mlp_init_mode), @@ -150,17 +162,17 @@ def launch_hidden_dim_prune_ckpt(cfg: DictConfig): else: intermediate_sizes.append(None) - mprint("Teacher config:") + mprint(f"Teacher config:") mprint(f" - hidden_size: {parent_hidden_size}") mprint(f" - intermediate_sizes: {intermediate_sizes}") os.makedirs(os.path.join(cfg.puzzle_dir, "ckpts"), exist_ok=True) for hidden_size in cfg.pruning.hidden_size_list: - mprint("\n######################################################################") + mprint(f"\n######################################################################") mprint(f"Hidden Size = {hidden_size}") - mprint("######################################################################\n") + mprint(f"######################################################################\n") - mprint("Child config:") + mprint(f"Child config:") mprint(f" - hidden_size: {hidden_size}") # Create model config overrides with proper FFN configuration @@ -178,14 +190,16 @@ def launch_hidden_dim_prune_ckpt(cfg: DictConfig): mlp_init_config_yaml = cfg.pruning.mlp_init_config_yaml dirname = f"hidden_size_{hidden_size}" - output_dir = os.path.join(cfg.pruning.pruned_ckpts_outpt_dir, dirname) + output_dir = os.path.join(cfg.pruning.pruned_ckpts_output_dir, dirname) mprint(f"Creating checkpoint with hidden_size={hidden_size}") mprint(f"Model config overrides: {model_config_overrides_json}") init_child_from_parent( + descriptor=cfg.descriptor, + pruning_mixin=cfg.pruning.pruning_mixin, parent_checkpoint_dir=cfg.pruning.model_name_or_path, - model_config_overrides_json=model_config_overrides_json, + model_config_overrides_dict=model_config_overrides_json, output_checkpoint_dir=output_dir, gqa_init_mode=GQAInitMode(cfg.pruning.gqa_init_mode), mlp_init_mode=MlpInitMode(cfg.pruning.mlp_init_mode), @@ -204,9 +218,9 @@ def launch_hidden_dim_prune_ckpt(cfg: DictConfig): def launch_experts_prune_ckpt( cfg: DictConfig, - max_save_workers: int | None = None, - max_layer_workers: int | None = None, - symlink_suffix: str | None = None, + max_save_workers: Optional[int] = None, + max_layer_workers: Optional[int] = None, + symlink_suffix: Optional[str] = None, ): for num_experts in cfg.pruning.num_experts_to_keep_list: dirname = f"num_experts_{num_experts}" @@ -223,14 +237,16 @@ def launch_experts_prune_ckpt( mlp_init_config_yaml = cfg.pruning.mlp_init_config_yaml - output_dir = os.path.join(cfg.pruning.pruned_ckpts_outpt_dir, dirname) + output_dir = os.path.join(cfg.pruning.pruned_ckpts_output_dir, dirname) # Profile the overall init_child_from_parent call with optimizations mprint("Starting init_child_from_parent...") start_time = time.time() init_child_from_parent( + descriptor=cfg.descriptor, + pruning_mixin=cfg.pruning.pruning_mixin, parent_checkpoint_dir=cfg.teacher_dir, - model_config_overrides_json=model_config_overrides_json, + model_config_overrides_dict=model_config_overrides_json, output_checkpoint_dir=output_dir, gqa_init_mode=GQAInitMode(cfg.pruning.gqa_init_mode), mlp_init_mode=MlpInitMode(cfg.pruning.mlp_init_mode), @@ -252,7 +268,7 @@ def launch_experts_prune_ckpt( def launch_moe_ffn_intermediates_prune_ckpt( - cfg: DictConfig, max_save_workers: int | None = None, max_layer_workers: int | None = None + cfg: DictConfig, max_save_workers: Optional[int] = None, max_layer_workers: Optional[int] = None ): for intermediate_size in cfg.pruning.intermediate_size_list: dirname = f"moe_ffn_{intermediate_size}_attn_no_op" @@ -269,14 +285,16 @@ def launch_moe_ffn_intermediates_prune_ckpt( } mlp_init_config_yaml = cfg.pruning.mlp_init_config_yaml - output_dir = os.path.join(cfg.pruning.pruned_ckpts_outpt_dir, dirname) + output_dir = os.path.join(cfg.pruning.pruned_ckpts_output_dir, dirname) # Profile the overall init_child_from_parent call with optimizations mprint("Starting init_child_from_parent...") start_time = time.time() init_child_from_parent( + descriptor=cfg.descriptor, + pruning_mixin=cfg.pruning.pruning_mixin, parent_checkpoint_dir=cfg.teacher_dir, - model_config_overrides_json=model_config_overrides_json, + model_config_overrides_dict=model_config_overrides_json, output_checkpoint_dir=output_dir, gqa_init_mode=GQAInitMode(cfg.pruning.gqa_init_mode), mlp_init_mode=MlpInitMode(cfg.pruning.mlp_init_mode), @@ -296,7 +314,11 @@ def launch_moe_ffn_intermediates_prune_ckpt( def launch_prune_ckpt(cfg: DictConfig): - target_layer = cfg.pruning.activation_hooks_kwargs.target_layer + cfg.descriptor = ModelDescriptorFactory.get(cfg.descriptor) + # Resolve pruning_mixin from config (could be string, enum, or PruningMixIn) + cfg.pruning.pruning_mixin = resolve_pruning_mixin(cfg.pruning.pruning_mixin, cfg.descriptor) + pruning_mixin = cfg.pruning.pruning_mixin + # I/O optimization settings - same as FFN pruning max_save_workers = None # Will auto-calculate as min(CPU count, num files) if "PRUNING_SAVE_WORKERS" in os.environ: @@ -307,29 +329,15 @@ def launch_prune_ckpt(cfg: DictConfig): if "PRUNING_LAYER_WORKERS" in os.environ: max_layer_workers = int(os.environ["PRUNING_LAYER_WORKERS"]) - # Log optimization settings (extracted from individual pruning methods) - mprint("Optimization Settings:") - mprint( - f" - I/O workers (max_workers): {'auto-calculate' if max_save_workers is None else max_save_workers}" - ) - mprint( - f" - Layer workers (max_layer_workers): {'auto-calculate' if max_layer_workers is None else max_layer_workers}" - ) - mprint(" (Override with env vars: PRUNING_IO_WORKERS, PRUNING_LAYER_WORKERS)") - - if target_layer == "mlp.down_proj": + if isinstance(pruning_mixin, FFNIntermediatePruningMixIn): launch_ffn_intermediates_prune_ckpt(cfg, max_save_workers, max_layer_workers) - elif target_layer == "self_attn.o_proj": + elif isinstance(pruning_mixin, KVHeadsPruningMixIn): launch_attn_groups_prune_ckpt(cfg, max_save_workers, max_layer_workers) - elif target_layer == "layernorm": - launch_hidden_dim_prune_ckpt(cfg) - elif target_layer == "router": - # Check if we should use symlink suffix for chained pruning - symlink_suffix = getattr(cfg.pruning, "symlink_suffix", None) - launch_experts_prune_ckpt(cfg, max_save_workers, max_layer_workers, symlink_suffix) - elif target_layer == r"regex:experts\.\d+\.down_proj$": - launch_moe_ffn_intermediates_prune_ckpt(cfg, max_save_workers, max_layer_workers) + elif isinstance(pruning_mixin, ExpertRemovalPruningMixIn): + launch_experts_prune_ckpt(cfg, max_save_workers, max_layer_workers) + # elif target_layer == "layernorm": + # launch_hidden_dim_prune_ckpt(cfg) else: raise NotImplementedError( - f"checkpoint pruning is not currently supported for target layer: {target_layer}" + f"checkpoint pruning is not currently supported for pruning mixin: {pruning_mixin.__class__.__name__}" ) diff --git a/modelopt/torch/puzzletron/pruning/pruning_mixin.py b/modelopt/torch/puzzletron/pruning/pruning_mixin.py new file mode 100644 index 000000000..bcb422c4e --- /dev/null +++ b/modelopt/torch/puzzletron/pruning/pruning_mixin.py @@ -0,0 +1,73 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# mypy: ignore-errors + +import re +from abc import ABC, abstractmethod +from typing import List, Optional, Tuple, Type + +from modelopt.torch.nas.plugins.megatron_hooks.base_hooks import ForwardHook + + +class LayerDescriptor: + def module_name_regex(self) -> str: + return "" + + def block_idx_from_module_name(self, module_name: str) -> Optional[int]: + block_idx_match = re.search(r"\.(\d+)\.", module_name) + if block_idx_match: + return int(block_idx_match.group(1)) + return None + + def get_modules_names_to_hook(self, model) -> List[Tuple[int, str]]: + target_layer = self.module_name_regex() + if target_layer.startswith("regex:"): + target_layer_regex = target_layer[len("regex:") :] + pattern = re.compile(target_layer_regex) + match_predicate = lambda module_name: pattern.search(module_name) + else: + match_predicate = lambda module_name: module_name.endswith(target_layer) + + module_names_to_hook = [] + for module_name, module in model.named_modules(): + if match_predicate(module_name): + module_names_to_hook.append( + (self.block_idx_from_module_name(module_name), module_name) + ) + return module_names_to_hook + + +class PruningMixIn(ABC): + def __init__(self, layer_descriptor: LayerDescriptor): + self.layer_descriptor = layer_descriptor + + def get_module_names_to_hook(self, model) -> List[Tuple[int, str]]: + return self.layer_descriptor.get_modules_names_to_hook(model) + + @abstractmethod + def supported_hooks(self) -> List[Type[ForwardHook]]: + raise NotImplementedError + + # @abstractmethod + # def prune_single_layer( + # self, + # layer_idx: int, + # parent_state_dict: dict, + # new_state_dict: dict, + # original_config: PretrainedConfig, + # new_config: PretrainedConfig, + # **kwargs + # ): + # raise NotImplementedError diff --git a/modelopt/torch/puzzletron/pruning/pruning_utils.py b/modelopt/torch/puzzletron/pruning/pruning_utils.py new file mode 100644 index 000000000..cea716b63 --- /dev/null +++ b/modelopt/torch/puzzletron/pruning/pruning_utils.py @@ -0,0 +1,647 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# mypy: ignore-errors + +import json +import math +from enum import Enum +from pathlib import Path +from typing import Any, Dict, List, Optional, Tuple, Type, Union + +import torch +from transformers import PretrainedConfig + +from modelopt.torch.puzzletron.anymodel.model_descriptor import ModelDescriptor +from modelopt.torch.puzzletron.pruning.pruning_mixin import PruningMixIn + + +class GQAInitMode(Enum): + RandomKV = "RandomKV" + AverageKV = "AverageKV" + FirstKV = "FirstKV" + RandomBlock = "RandomBlock" + CopyAsIs = "CopyAsIs" + Degrouping = "Degrouping" + PruneKVHeads = "PruneKVHeads" + + +class MlpInitMode(Enum): + Random = "Random" + Truncate = "Truncate" + CopyAsIs = "CopyAsIs" + PruneByActivationsLog = "PruneByActivationsLog" + ExpertRemoval = "ExpertRemoval" + ConcatExpertsIntoDenseFFN = "ConcatExpertsIntoDenseFFN" + + +class LinearInitMode(Enum): + Random = "Random" + FromTeacher = "FromTeacher" + + +class HiddenSizeInitMode(Enum): + Random = "Random" + Truncate = "Truncate" + PruneByChannelRanking = "PruneByChannelRanking" + CopyAsIs = "CopyAsIs" + + +def resolve_pruning_mixin( + pruning_mixin, descriptor: Type[ModelDescriptor] +) -> PruningMixIn | List[PruningMixIn]: + """ + Convert pruning_mixin argument to PruningMixIn instance(s). + + Args: + pruning_mixin: Can be a string identifier, PruningMixIn instance, + or a list of any of those types. + descriptor: ModelDescriptor class that provides the pruning_mixins() mapping. + + Returns: + PruningMixIn or List[PruningMixIn] depending on input type. + """ + # Handle list of values recursively + if isinstance(pruning_mixin, list): + return [resolve_pruning_mixin(item, descriptor) for item in pruning_mixin] + + # Handle single value + # If it's already a PruningMixIn, return as is + if isinstance(pruning_mixin, PruningMixIn): + return pruning_mixin + + # Get the pruning mixins mapping from the descriptor + mixins_dict = descriptor.pruning_mixins() + + if isinstance(pruning_mixin, str): + if pruning_mixin not in mixins_dict: + available_methods = list(mixins_dict.keys()) + raise ValueError( + f"Pruning method '{pruning_mixin}' is not supported by {descriptor.__name__}. " + f"Available methods: {available_methods}" + ) + return mixins_dict[pruning_mixin] + + raise ValueError(f"Unsupported pruning_mixin type: {type(pruning_mixin)}") + + +def _init_mlp_module( + mlp_init_mode: Union[MlpInitMode, str], + mlp_prefix: str, + expanded_dim: int, + layer_idx: int, + new_item: torch.Tensor, + new_config: PretrainedConfig, + orig_item: torch.Tensor, + original_config: PretrainedConfig, + mlp_init_config: Optional[dict[str, Any]], + pruned_filters: Optional[torch.Tensor] = None, + projection_matrix: Optional[dict[str, torch.Tensor]] = None, +) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[dict[str, torch.Tensor]]]: + if isinstance(mlp_init_mode, str): + mlp_init_mode = MlpInitMode(mlp_init_mode) + assert orig_item.ndim == 2, f"{orig_item.ndim=}" + assert new_item.ndim == 2, f"{new_item.ndim=}" + + assert new_config.num_hidden_layers == original_config.num_hidden_layers, ( + f"({new_config.num_hidden_layers=}) != ({original_config.num_hidden_layers=})" + ) + + new_intermediate_size = new_config.block_configs[layer_idx].ffn.intermediate_size + original_intermediate_size = original_config.block_configs[layer_idx].ffn.intermediate_size + + if mlp_init_mode == MlpInitMode.CopyAsIs: + assert new_intermediate_size == original_intermediate_size, ( + f"({new_intermediate_size=}) != ({original_intermediate_size=}), can't be copied as is." + ) + mlp_module_weight = orig_item + + elif mlp_init_mode == MlpInitMode.Random: + mlp_module_weight = new_item + + elif new_intermediate_size == original_intermediate_size: + mlp_module_weight = orig_item + + elif mlp_init_mode in ( + MlpInitMode.Truncate, + MlpInitMode.PruneByActivationsLog, + ): + assert original_intermediate_size >= new_intermediate_size, ( + f"({original_intermediate_size=}) < ({new_intermediate_size=}), can't be truncated." + ) + orig_ffn_size = orig_item.shape[expanded_dim] + new_ffn_size = new_item.shape[expanded_dim] + + if mlp_init_mode == MlpInitMode.Truncate: + truncated_weight = torch.narrow( + orig_item, dim=expanded_dim, start=0, length=new_ffn_size + ) + mlp_module_weight = truncated_weight + + elif mlp_init_mode == MlpInitMode.PruneByActivationsLog: + if pruned_filters is None: + filter_importance = _load_activations_log( + mlp_init_config, module_name=f"{mlp_prefix}.down_proj" + ) + filters_sorted_by_importance = torch.argsort(filter_importance, descending=True) + pruned_filters = filters_sorted_by_importance[:new_ffn_size].to(orig_item.device) + + pruned_weight = torch.index_select(orig_item, dim=expanded_dim, index=pruned_filters) + if mlp_init_config.get("scale_pruned_weights", False) and expanded_dim == 1: + pruned_weight = pruned_weight * (orig_ffn_size / new_ffn_size) + mlp_module_weight = pruned_weight + + elif ( + mlp_init_mode == MlpInitMode.ExpertRemoval + ): # the case of mlp layers of maverick. for now we only support copy as is + assert new_intermediate_size == original_intermediate_size, ( + f"({new_intermediate_size=}) != ({original_intermediate_size=}), can't be copied as is." + ) + mlp_module_weight = orig_item + + else: + raise ValueError(f"Unsupported {mlp_init_mode=}") + + return mlp_module_weight, pruned_filters, projection_matrix + + +def _load_activations_log(mlp_init_config: dict[str, Any], module_name: str) -> torch.Tensor: + _cache_activations_log(mlp_init_config) + module_log = ACTIVATIONS_LOG[module_name] + filter_importance = module_log["score"] + return filter_importance + + +ACTIVATIONS_LOG = dict() + + +def _cache_activations_log(mlp_init_config: dict[str, Any]) -> None: + if len(ACTIVATIONS_LOG) == 0: + assert "activations_log_dir" in mlp_init_config + activations_log_dir = mlp_init_config["activations_log_dir"] + print(f"Loading activations_log from {activations_log_dir}") + # Only load rank_*.pth files to avoid loading hook_states_*.pth checkpoint files + ACTIVATIONS_LOG.update( + { + module_name: module_log + for p in Path(activations_log_dir).glob("rank_*.pth") + for module_name, module_log in torch.load(p).items() + } + ) + + +def _init_attention_weights( + gqa_init_mode, + layer_idx, + new_state_dict, + new_config, + original_state_dict, + q_key, + k_key, + v_key, + o_key, + original_config, + is_original_mha, + head_size, + mlp_init_config, +): + assert new_config.num_attention_heads == original_config.num_attention_heads, ( + f"({new_config.num_attention_heads=}) != ({original_config.num_attention_heads=})" + ) + num_q_heads = new_config.num_attention_heads + num_kv_heads = new_config.block_configs[layer_idx].attention.num_key_value_heads + orig_num_kv_heads = original_config.block_configs[layer_idx].attention.num_key_value_heads + + # new_w* are typically randomly initialized + new_wq = new_state_dict[q_key] + new_wk = new_state_dict[k_key] + new_wv = new_state_dict[v_key] + new_wo = new_state_dict[o_key] + + # w* are from the parent model + wq = original_state_dict[q_key] + wk = original_state_dict[k_key] + wv = original_state_dict[v_key] + wo = original_state_dict[o_key] + + if "bias" in k_key: + for tensor in [wq, wk, wv, wo, new_wq, new_wk, new_wv, new_wo]: + assert tensor.ndim == 1 + tensor.unsqueeze_(1) + dim1 = wk.shape[1] # this is the hidden_size in case of matrix weights, and 1 in case of biases + + if gqa_init_mode in (GQAInitMode.RandomKV, GQAInitMode.RandomBlock): + wk, wv = new_wk, new_wv + elif gqa_init_mode in (GQAInitMode.AverageKV, GQAInitMode.FirstKV): + assert orig_num_kv_heads % num_kv_heads == 0, ( + f"({orig_num_kv_heads=}) % ({num_kv_heads=}) != 0" + ) + n_heads_to_aggregate = orig_num_kv_heads // num_kv_heads + + wk = wk.view(-1, n_heads_to_aggregate, head_size, dim1) + wv = wv.view(-1, n_heads_to_aggregate, head_size, dim1) + + if gqa_init_mode == GQAInitMode.AverageKV: + wk = wk.mean(dim=1) + wv = wv.mean(dim=1) + else: + wk = wk[:, 0] + wv = wv[:, 0] + elif gqa_init_mode == GQAInitMode.CopyAsIs: + assert new_wk.shape == wk.shape, f"({new_wk.shape=}) != ({wk.shape=})" + assert new_wv.shape == wv.shape, f"({new_wv.shape=}) != ({wv.shape=})" + assert new_wq.shape == wq.shape, f"({new_wq.shape=}) != ({wq.shape=})" + assert new_wo.shape == wo.shape, f"({new_wo.shape=}) != ({wo.shape=})" + + elif gqa_init_mode == GQAInitMode.Degrouping: + assert not is_original_mha, ( + "Degrouping can only be done on original models that are GQA themselves." + ) + n_groups = num_kv_heads + orig_n_groups = orig_num_kv_heads + assert n_groups % orig_n_groups == 0, f"{n_groups=} must be a divisor of {orig_n_groups=}" + n_repeats = n_groups // orig_n_groups + if n_repeats > 1: + print(f"Degrouping {orig_n_groups} into {n_groups}") + + def degroup_w(w): + w = w.view(orig_n_groups, head_size, dim1) + w = torch.repeat_interleave(w, repeats=n_repeats, dim=0) + w = w.reshape(n_groups * head_size, dim1) + return w + + wk = degroup_w(wk) + wv = degroup_w(wv) + + elif gqa_init_mode == GQAInitMode.PruneKVHeads: + wk = wk.view(orig_num_kv_heads, head_size, dim1) + wv = wv.view(orig_num_kv_heads, head_size, dim1) + wq = wq.view(orig_num_kv_heads, num_q_heads // orig_num_kv_heads, head_size, dim1) + wo = wo.view(dim1, orig_num_kv_heads, num_q_heads // orig_num_kv_heads, head_size) + + o_proj_module_name = o_key.replace(".weight", "") + kv_head_importance = _load_activations_log(mlp_init_config, module_name=o_proj_module_name) + kv_heads_sorted_by_importance = torch.argsort(kv_head_importance, descending=True) + kv_heads_to_keep = kv_heads_sorted_by_importance[:num_kv_heads] + kv_heads_to_remove = kv_heads_sorted_by_importance[num_kv_heads:] + + wk = wk[kv_heads_to_keep] + wv = wv[kv_heads_to_keep] + + reduction_factor = orig_num_kv_heads // num_kv_heads + + prune_via_duplication = False + if prune_via_duplication: + ## Wq option 1 - replicate the query groups to match the total number of attention heads. Queries work with familiar kv heads. + wq = wq[kv_heads_to_keep] + wq = torch.repeat_interleave(wq, repeats=reduction_factor, dim=0) + + ## Wo option 1 - replicate the groups of the original Wo. Multiple by the reduction factor to mimic pruning of the other groups. + ## This makes sense with Wq option 1, but it will not be more expressive than true pruning due to symmetry, unless we add noise. + wo = wo[:, kv_heads_to_keep] + wo = torch.repeat_interleave(wo, repeats=reduction_factor, dim=1) + wo = wo / reduction_factor + + else: # prune via zeroing out + ## Wq option 2 - keep the original queries. At init they will not be used (see the Wo zeroing), during training they can adapt to new kv heads like in variable GQA. + ## We need to interleave them to keep the matching between queries and kv heads. + kv_heads_to_keep = kv_heads_to_keep.tolist() + kv_heads_to_remove = kv_heads_to_remove.tolist() + kv_head_ordering = [] + zero_out_mask = [] + for i_head in range(orig_num_kv_heads): + if i_head % reduction_factor == 0: + kv_head_ordering.append(kv_heads_to_keep.pop(0)) + zero_out_mask.append(False) + else: + kv_head_ordering.append(kv_heads_to_remove.pop(0)) + zero_out_mask.append(True) + + wq = wq[kv_head_ordering] + + ## Wo option 2 - zero-out the contribution of queries that do not belong to chosen kv heads. + ## At initialization it's exactly like pruning, but the extra weights will have the chance to adapt to new kv heads if we train the model. + ## Even though the weight is 0 it can still train, like initializing biases to 0 does not prevent them from training. + ## Matmul backprop: if Y = AB and dY is the gradient of Y, then dA = dY @ B.T and dB = A.T @ dY, so the gradient of the zeroed-out weights depends on the gradient of what multiplies them. + wo = wo[:, kv_head_ordering] + wo[:, zero_out_mask] = 0.0 + + else: + raise ValueError(f"{gqa_init_mode=} not supported") + + wk = wk.reshape(-1, dim1) + wv = wv.reshape(-1, dim1) + wq = wq.reshape(-1, dim1) + wo = wo.reshape(dim1, -1) + return wq, wk, wv, wo + + +def _init_attention_biases( + gqa_init_mode, + layer_idx, + new_state_dict, + new_config, + original_state_dict, + q_key, + k_key, + v_key, + o_key, + original_config, + is_original_mha, + head_size, + mlp_init_config, +): + assert new_config.num_attention_heads == original_config.num_attention_heads, ( + f"({new_config.num_attention_heads=}) != ({original_config.num_attention_heads=})" + ) + num_q_heads = new_config.num_attention_heads + n_heads_in_group = new_config.block_configs[layer_idx].attention.n_heads_in_group + orig_n_heads_in_group = original_config.block_configs[layer_idx].attention.n_heads_in_group + num_kv_heads = num_q_heads // n_heads_in_group + orig_num_kv_heads = num_q_heads // orig_n_heads_in_group + + o_proj_bias = new_config.o_proj_bias + attention_bias = new_config.attention_bias + + # If no biases + if not (o_proj_bias or attention_bias): + return {} + + new_bias_sd = {} + bias_sd = {} + # new_w* are typically randomly initialized + if o_proj_bias: + new_bias_sd["o"] = new_state_dict[o_key] + bias_sd["o"] = original_state_dict[o_key] + if attention_bias: + for bias_key, key in zip("qkv", [q_key, k_key, v_key]): + new_bias_sd[bias_key] = new_state_dict[key] + bias_sd[bias_key] = original_state_dict[key] + + # maybe unsqueeze all tensors + for tensor in list(new_bias_sd.values()) + list(bias_sd.values()): + assert tensor.ndim == 1 + tensor.unsqueeze_(1) + + dim1 = 1 # this is the hidden_size in case of matrix weights, and 1 in case of biases + if gqa_init_mode in (GQAInitMode.RandomKV, GQAInitMode.RandomBlock) and attention_bias: + bias_sd["k"] = torch.zeros( + new_bias_sd["k"].shape, dtype=bias_sd["k"].dtype, device=bias_sd["k"].device + ) + bias_sd["v"] = torch.zeros( + new_bias_sd["v"].shape, dtype=bias_sd["v"].dtype, device=bias_sd["v"].device + ) + elif gqa_init_mode in (GQAInitMode.AverageKV, GQAInitMode.FirstKV) and attention_bias: + assert n_heads_in_group % orig_n_heads_in_group == 0, ( + f"({n_heads_in_group=}) % ({orig_n_heads_in_group=}) != 0" + ) + n_heads_to_aggregate = n_heads_in_group // orig_n_heads_in_group + + bias_sd["k"] = bias_sd["k"].view(-1, n_heads_to_aggregate, head_size, dim1) + bias_sd["v"] = bias_sd["v"].view(-1, n_heads_to_aggregate, head_size, dim1) + + if gqa_init_mode == GQAInitMode.AverageKV: + bias_sd["k"] = bias_sd["k"].mean(dim=1) + bias_sd["v"] = bias_sd["v"].mean(dim=1) + else: + bias_sd["k"] = bias_sd["k"][:, 0] + bias_sd["v"] = bias_sd["v"][:, 0] + elif gqa_init_mode == GQAInitMode.CopyAsIs: + for key in bias_sd.keys(): + assert new_bias_sd[key].shape == bias_sd[key].shape, ( + f"({new_bias_sd[key].shape=}) != ({bias_sd[key].shape=})" + ) + + elif gqa_init_mode == GQAInitMode.Degrouping and attention_bias: + assert not is_original_mha, ( + "Degrouping can only be done on original models that are GQA themselves." + ) + n_groups = new_config.num_attention_heads // n_heads_in_group + orig_n_groups = original_config.num_attention_heads // orig_n_heads_in_group + assert n_groups % orig_n_groups == 0, f"{n_groups=} must be a divisor of {orig_n_groups=}" + n_repeats = n_groups // orig_n_groups + if n_repeats > 1: + print(f"Degrouping {orig_n_groups} into {n_groups}") + + def degroup_w(w): + w = w.view(orig_n_groups, head_size, dim1) + w = torch.repeat_interleave(w, repeats=n_repeats, dim=0) + w = w.reshape(n_groups * head_size, dim1) + return w + + bias_sd["k"] = degroup_w(bias_sd["k"]) + bias_sd["v"] = degroup_w(bias_sd["v"]) + + elif gqa_init_mode == GQAInitMode.PruneKVHeads: + if o_proj_bias: + o_proj_module_name = o_key.rsplit(".", 1)[0] + else: + # Here we assume that the o_proj layer is called "o_proj" + o_proj_module_name = k_key.rsplit(".", 2)[0] + ".o_proj" + + kv_head_importance = _load_activations_log(mlp_init_config, module_name=o_proj_module_name) + kv_heads_sorted_by_importance = torch.argsort(kv_head_importance, descending=True) + kv_heads_to_keep = kv_heads_sorted_by_importance[:num_kv_heads] + kv_heads_to_remove = kv_heads_sorted_by_importance[num_kv_heads:] + + # view as KV groups + if attention_bias: + bias_sd["k"] = bias_sd["k"].view(orig_num_kv_heads, head_size, dim1) + bias_sd["v"] = bias_sd["v"].view(orig_num_kv_heads, head_size, dim1) + bias_sd["q"] = bias_sd["q"].view( + orig_num_kv_heads, orig_n_heads_in_group, head_size, dim1 + ) + # Keep important KV heads and prune the others + bias_sd["k"] = bias_sd["k"][kv_heads_to_keep] + bias_sd["v"] = bias_sd["v"][kv_heads_to_keep] + if o_proj_bias: + bias_sd["o"] = bias_sd["o"].view( + dim1, orig_num_kv_heads, orig_n_heads_in_group, head_size + ) + + reduction_factor = orig_num_kv_heads // num_kv_heads + + prune_via_duplication = False + if prune_via_duplication: + if attention_bias: + ## Wq option 1 - replicate the query groups to match the total number of attention heads. Queries work with familiar kv heads. + bias_sd["q"] = bias_sd["q"][kv_heads_to_keep] + bias_sd["q"] = torch.repeat_interleave( + bias_sd["q"], repeats=reduction_factor, dim=0 + ) + + if o_proj_bias: + ## Wo option 1 - replicate the groups of the original Wo. Multiple by the reduction factor to mimic pruning of the other groups. + ## This makes sense with Wq option 1, but it will not be more expressive than true pruning due to symmetry, unless we add noise. + bias_sd["o"] = bias_sd["o"][:, kv_heads_to_keep] + bias_sd["o"] = torch.repeat_interleave( + bias_sd["o"], repeats=reduction_factor, dim=1 + ) + bias_sd["o"] = bias_sd["o"] / reduction_factor + + else: # prune via zeroing out + ## Wq option 2 - keep the original queries. At init they will not be used (see the Wo zeroing), during training they can adapt to new kv heads like in variable GQA. + ## We need to interleave them to keep the matching between queries and kv heads. + kv_heads_to_keep = kv_heads_to_keep.tolist() + kv_heads_to_remove = kv_heads_to_remove.tolist() + kv_head_ordering = [] + zero_out_mask = [] + for i_head in range(orig_num_kv_heads): + if i_head % reduction_factor == 0: + kv_head_ordering.append(kv_heads_to_keep.pop(0)) + zero_out_mask.append(False) + else: + kv_head_ordering.append(kv_heads_to_remove.pop(0)) + zero_out_mask.append(True) + + if attention_bias: + bias_sd["q"] = bias_sd["q"][kv_head_ordering] + + if o_proj_bias: + ## Wo option 2 - zero-out the contribution of queries that do not belong to chosen kv heads. + ## At initialization it's exactly like pruning, but the extra weights will have the chance to adapt to new kv heads if we train the model. + ## Even though the weight is 0 it can still train, like initializing biases to 0 does not prevent them from training. + ## Matmul backprop: if Y = AB and dY is the gradient of Y, then dA = dY @ B.T and dB = A.T @ dY, so the gradient of the zeroed-out weights depends on the gradient of what multiplies them. + bias_sd["o"] = bias_sd["o"][:, kv_head_ordering] + bias_sd["o"][:, zero_out_mask] = 0.0 + + else: + raise ValueError(f"{gqa_init_mode=} not supported") + + if attention_bias: + for bias_key in "qkv": + bias_sd[bias_key] = bias_sd[bias_key].reshape(-1) + if o_proj_bias: + bias_sd["o"] = bias_sd["o"].reshape(-1) + return bias_sd + + +def _init_moe_module( + mlp_init_mode: Union[MlpInitMode, str], + mlp_init_config: Optional[Dict[str, Any]], + layer_idx: int, + orig_router_weights: Dict[str, List[torch.Tensor]], + orig_experts_weights: Dict[str, List[torch.Tensor]], + new_router_weights: Dict[str, List[torch.Tensor]], + new_experts_weights: Dict[str, List[torch.Tensor]], + orig_num_experts: int, + new_num_experts: int, +) -> Tuple[Dict[str, List[torch.Tensor]], Dict[str, List[torch.Tensor]]]: + if isinstance(mlp_init_mode, str): + mlp_init_mode = MlpInitMode(mlp_init_mode) + + if mlp_init_mode != MlpInitMode.ExpertRemoval: + raise ValueError(f"Unsupported {mlp_init_mode=}") + + selected_experts = _select_expert_indices( + mlp_init_config=mlp_init_config, + layer_idx=layer_idx, + orig_num_experts=orig_num_experts, + new_num_experts=new_num_experts, + ) + + # Router: prefer parent tensors when available; if child has bias only, slice from child + result_router_weights: dict[str, list[torch.Tensor]] = {} + for name, new_list in new_router_weights.items(): + result_router_weights[name] = [ + tensor_to_slice[selected_experts] for tensor_to_slice in orig_router_weights[name] + ] + + # Experts: for each name present in the child, take from parent if available, else from child + result_experts_weights: dict[str, list[torch.Tensor]] = {} + for name, new_list in new_experts_weights.items(): + if name in orig_experts_weights: + src_list = orig_experts_weights[name] + else: + src_list = new_list + result_experts_weights[name] = [src_list[i] for i in selected_experts] + + # Validate shapes + assert result_router_weights.keys() == new_router_weights.keys(), ( + "result_router_weights and new_router_weights must have the same keys" + ) + for name in new_router_weights.keys(): + assert len(new_router_weights[name]) == len(result_router_weights[name]) + for new_router_weight, result_router_weight in zip( + new_router_weights[name], result_router_weights[name] + ): + assert new_router_weight.shape == result_router_weight.shape + + assert result_experts_weights.keys() == new_experts_weights.keys(), ( + "result_experts_weights and new_experts_weights must have the same keys" + ) + for name in result_experts_weights.keys(): + assert len(new_experts_weights[name]) == len(result_experts_weights[name]) + for new_expert_weight, result_expert_weight in zip( + new_experts_weights[name], result_experts_weights[name] + ): + assert new_expert_weight.shape == result_expert_weight.shape + + return result_router_weights, result_experts_weights + + +def _select_expert_indices( + *, mlp_init_config: dict[str, Any], layer_idx: int, orig_num_experts: int, new_num_experts: int +) -> list[int]: + expert_scores = _load_expert_scores(mlp_init_config, layer_idx) + assert len(expert_scores) == orig_num_experts + selected_experts = sorted( + range(orig_num_experts), + key=lambda i: expert_scores[i] if not math.isnan(expert_scores[i]) else float("inf"), + reverse=mlp_init_config.get("higher_is_better", True), + )[:new_num_experts] + return selected_experts + + +def _load_expert_scores( + mlp_init_config: Optional[dict[str, Any]], layer_idx: int +) -> list[list[int | float]]: + assert mlp_init_config is not None + if "expert_scores_file" in mlp_init_config: + expert_scores_file = mlp_init_config["expert_scores_file"] + with open(expert_scores_file, "r") as f: + expert_scores = json.load(f) + elif "activations_log_dir" in mlp_init_config: + _cache_activations_log(mlp_init_config) + # Use layer_prefix_template from pruning config, or fall back to legacy nemotron_h format + # TODO - get from descriptors + layer_prefix_template = mlp_init_config.get( + "layer_prefix_template", "backbone.layers.{layer_idx}." + ) + layer_prefix = layer_prefix_template.format(layer_idx=layer_idx) + candidate_layer_keys = [ + key for key in ACTIVATIONS_LOG.keys() if key.startswith(layer_prefix) + ] + if len(candidate_layer_keys) == 0: + raise ValueError(f"No layer keys found for {layer_prefix=}. {ACTIVATIONS_LOG.keys()=}") + elif len(candidate_layer_keys) > 1: + if "layer_suffix" not in mlp_init_config: + raise ValueError( + f"Multiple candidate layer keys found for {layer_prefix=}, you must specify a layer_suffix in the mlp_init_config. {candidate_layer_keys=}" + ) + layer_suffix = mlp_init_config["layer_suffix"] + layer_key = f"{layer_prefix}{layer_suffix}" + else: + layer_key = candidate_layer_keys[0] + layer_log = ACTIVATIONS_LOG[layer_key] + + expert_scores_key = mlp_init_config.get("expert_scores_key", "expert_ranks") + if expert_scores_key not in layer_log: + raise ValueError( + f"Expert scores key {expert_scores_key=} not found in {layer_log.keys()=}" + ) + expert_scores = layer_log[expert_scores_key] + else: + raise ValueError(f"Unsupported {mlp_init_config=}") + return expert_scores diff --git a/modelopt/torch/puzzletron/tools/checkpoint_utils_hf.py b/modelopt/torch/puzzletron/tools/checkpoint_utils_hf.py index f52c12d26..ad8ccfba2 100644 --- a/modelopt/torch/puzzletron/tools/checkpoint_utils_hf.py +++ b/modelopt/torch/puzzletron/tools/checkpoint_utils_hf.py @@ -14,11 +14,13 @@ # limitations under the License. # mypy: ignore-errors -"""Provides utilities for loading and saving PyTorch model checkpoints in the Hugging Face format, +""" +Provides utilities for loading and saving PyTorch model checkpoints in the Hugging Face format, particularly for DeciLM models. """ import concurrent.futures +import dataclasses import fcntl import os import shutil @@ -31,9 +33,12 @@ import torch from safetensors.torch import save_file as safe_save_file +from transformers import AutoConfig, PretrainedConfig, PreTrainedModel +from transformers.dynamic_module_utils import get_class_from_dynamic_module from transformers.utils import SAFE_WEIGHTS_INDEX_NAME from modelopt.torch.puzzletron.decilm import deci_lm_hf_code +from modelopt.torch.puzzletron.decilm.deci_lm_hf_code.block_config import maybe_cast_block_configs from modelopt.torch.puzzletron.decilm.deci_lm_hf_code.configuration_decilm import DeciLMConfig from modelopt.torch.puzzletron.decilm.deci_lm_hf_code.modeling_decilm import DeciLMForCausalLM from modelopt.torch.puzzletron.tools.common import infer_weights_dtype @@ -69,7 +74,8 @@ def load_checkpoint( model_config_overrides: dict | None = None, ignore_unexpected_config_keys: bool = False, ) -> DeciLMForCausalLM: - """Unlike AutoModelForCausalLM.from_pretrained, the models loaded by this function use your + """ + Unlike AutoModelForCausalLM.from_pretrained, the models loaded by this function use your local repo code, not the code inside the checkpoint. """ from modelopt.torch.puzzletron.tools.checkpoint_utils import ( @@ -99,20 +105,35 @@ def load_checkpoint( return model +def force_cache_dynamic_modules(config: PretrainedConfig, checkpoint_dir: Path | str): + has_remote_code = ( + hasattr(config, "auto_map") + and isinstance(config.auto_map, dict) + and "AutoConfig" in config.auto_map.keys() + ) + if has_remote_code: + for class_reference in config.auto_map.values(): + _ = get_class_from_dynamic_module(class_reference, checkpoint_dir) + + def load_model_config( checkpoint_dir: Path | str, model_config_overrides: Mapping | None = None, ignore_unexpected_config_keys: bool = False, -) -> DeciLMConfig: +): if not isinstance(checkpoint_dir, Path): checkpoint_dir = Path(checkpoint_dir) if model_config_overrides is None: model_config_overrides = {} - config, unused_kwargs = DeciLMConfig.from_pretrained( - checkpoint_dir, return_unused_kwargs=True, **model_config_overrides + config, unused_kwargs = AutoConfig.from_pretrained( + checkpoint_dir, trust_remote_code=True, return_unused_kwargs=True, **model_config_overrides ) + if hasattr(config, "block_configs"): + config.block_configs = maybe_cast_block_configs(config.block_configs) + + force_cache_dynamic_modules(config, checkpoint_dir) if not ignore_unexpected_config_keys: if unused_kwargs: @@ -121,74 +142,65 @@ def load_model_config( return config -def save_checkpoint(model: DeciLMForCausalLM, checkpoint_dir: Path | str) -> None: - _save_checkpoint(model.config, model.state_dict(), checkpoint_dir) +def save_checkpoint( + model: PreTrainedModel, + checkpoint_dir: Path | str, + descriptor: "ModelDescriptor", +) -> None: + _save_checkpoint(model.config, model.state_dict(), checkpoint_dir, descriptor) def _save_checkpoint( - model_config: DeciLMConfig, + model_config: PretrainedConfig, state_dict: dict[str, torch.Tensor], checkpoint_dir: Path | str, + descriptor: "ModelDescriptor", max_workers: int | None = None, # Now optional - will auto-calculate if None ) -> None: - mprint("=== Starting _save_checkpoint detailed profiling ===") - total_start_time = time.time() + from modelopt.torch.puzzletron.anymodel.model_descriptor import ModelDescriptor if not isinstance(checkpoint_dir, Path): checkpoint_dir = Path(checkpoint_dir) - # Phase 1: Create directory and save config - phase1_start_time = time.time() checkpoint_dir.mkdir(parents=True, exist_ok=True) - model_config.save_pretrained(checkpoint_dir) - phase1_time = time.time() - phase1_start_time - mprint(f"Phase 1 - Directory creation and config save: {phase1_time:.2f}s") - # Phase 2: Save subblocks (main model weights) with auto-calculated worker count - phase2_start_time = time.time() - save_subblocks( - state_dict, - checkpoint_dir, - multi_threaded=True, - max_workers=max_workers, # Will auto-calculate if None + # Phase 1: Save config + save_model_config(model_config, checkpoint_dir) + + # Phase 2: Build weight map using descriptor and write index + subblock_keys = descriptor.get_weight_groups( + layer_names=state_dict.keys(), + num_hidden_layers=model_config.num_hidden_layers, ) - phase2_time = time.time() - phase2_start_time - mprint(f"Phase 2 - Save subblocks (model weights): {phase2_time:.2f}s") - # Phase 3: Save safetensors index - phase3_start_time = time.time() - save_safetensors_index(model_config, checkpoint_dir) - phase3_time = time.time() - phase3_start_time - mprint(f"Phase 3 - Save safetensors index: {phase3_time:.2f}s") + weight_map = {} + for subblock, layer_keys in subblock_keys.items(): + weight_map_entries = { + key: f"subblocks_safetensors/{subblock}.safetensors" for key in layer_keys + } + weight_map.update(weight_map_entries) + + # Write index + index = {"metadata": {"format": "pt"}, "weight_map": weight_map} + index_path = checkpoint_dir / SAFE_WEIGHTS_INDEX_NAME + index_json = json_dumps(index) + _write_file_process_safe(index_json, index_path) - # Phase 4: Copy HF code - phase4_start_time = time.time() - copy_deci_lm_hf_code(checkpoint_dir) - phase4_time = time.time() - phase4_start_time - mprint(f"Phase 4 - Copy HF code: {phase4_time:.2f}s") + # Handle tie_word_embeddings - don't save lm_head.weight if it's tied to embed_tokens + if getattr(model_config, "tie_word_embeddings", False) and "lm_head.weight" in state_dict: + lm_head_weight_name = f"{descriptor.output_embedding_name()}.weight" + state_dict = {k: v for k, v in state_dict.items() if k != lm_head_weight_name} + weight_map = {k: v for k, v in weight_map.items() if k != lm_head_weight_name} - total_time = time.time() - total_start_time - mprint(f"=== _save_checkpoint completed in {total_time:.2f}s ===") - mprint( - f"Breakdown: Config {phase1_time:.1f}s + Subblocks {phase2_time:.1f}s + " - f"Index {phase3_time:.1f}s + HF code {phase4_time:.1f}s" - ) - mprint( - f"Save percentage breakdown: Config {phase1_time / total_time * 100:.1f}% + " - f"Subblocks {phase2_time / total_time * 100:.1f}% + " - f"Index {phase3_time / total_time * 100:.1f}% + " - f"HF code {phase4_time / total_time * 100:.1f}%" + # Phase 3: Save subblocks + save_subblocks( + state_dict, + checkpoint_dir, + weight_map=weight_map, + multi_threaded=True, + max_workers=max_workers, ) - # Performance metrics - if phase2_time > 0: - subblocks_percentage = phase2_time / total_time * 100 - actual_workers = max_workers if max_workers else "auto" - mprint( - f"I/O optimization: Subblocks were {subblocks_percentage:.1f}% of total save time " - f"(max_workers={actual_workers})" - ) - def split_checkpoint_to_subblocks(checkpoint_dir: Path | str) -> None: from modelopt.torch.puzzletron.tools.checkpoint_utils import ( @@ -210,6 +222,7 @@ def split_checkpoint_to_subblocks(checkpoint_dir: Path | str) -> None: def save_subblocks( state_dict: dict[str, torch.Tensor], checkpoint_dir: Path | str, + weight_map: dict[str, str] | None = None, multi_threaded: bool = True, max_workers: int | None = None, # Now optional - will auto-calculate if None ) -> None: @@ -219,14 +232,15 @@ def save_subblocks( if not isinstance(checkpoint_dir, Path): checkpoint_dir = Path(checkpoint_dir) - # Step 1: Build weight map + # Step 1: Build weight map (use provided or build from state_dict) weight_map_start_time = time.time() - weight_map = _build_safetensors_weight_map( - state_dict=state_dict, - non_layer_module_to_file_type=NON_LAYER_MODULE_TO_FILE_TYPE, - module_within_layer_to_file_type=MODULE_WITHIN_LAYER_TO_FILE_TYPE, - layers_module_name=LAYERS_MODULE_NAME, - ) + if weight_map is None: + weight_map = _build_safetensors_weight_map( + state_dict=state_dict, + non_layer_module_to_file_type=NON_LAYER_MODULE_TO_FILE_TYPE, + module_within_layer_to_file_type=MODULE_WITHIN_LAYER_TO_FILE_TYPE, + layers_module_name=LAYERS_MODULE_NAME, + ) weight_name_to_filename = {k: checkpoint_dir / v for k, v in weight_map.items()} weight_map_time = time.time() - weight_map_start_time mprint(f" Step 1 - Build weight map: {weight_map_time:.2f}s ({len(weight_map)} mappings)") @@ -323,6 +337,7 @@ def save_safetensors_index( model_config: DeciLMConfig, checkpoint_dir: Path | str, ) -> None: + """Save safetensors index for DeciLM models (legacy function).""" mprint("=== Starting save_safetensors_index profiling ===") index_start_time = time.time() @@ -372,7 +387,8 @@ def _write_file_process_safe( path: Path | str, write_fn: Callable[[Any, BinaryIO], None] = _write_text, ) -> None: - """Write a file in a multi-process safe way. + """ + Write a file in a multi-process safe way. If another process tries to write the same file using this method, the current process "gives up" and assumes that the matter is being taken care of by another process. @@ -435,13 +451,19 @@ def _build_safetensors_weight_map( return weight_map -# Not really needed -def save_model_config(model_config: DeciLMConfig, checkpoint_dir: Path | str) -> None: +def save_model_config(model_config: PretrainedConfig, checkpoint_dir: Path | str) -> None: + if hasattr(model_config, "block_configs"): + model_config.block_configs = [ + dataclasses.asdict(conf) if dataclasses.is_dataclass(conf) else conf + for conf in model_config.block_configs + ] model_config.save_pretrained(checkpoint_dir) def copy_deci_lm_hf_code(output_dir: Path | str) -> None: - """Copy the deci_lm_hf_code directory to the output directory.""" + """ + Copy the deci_lm_hf_code directory to the output directory. + """ output_dir = Path(output_dir) output_dir.mkdir(parents=True, exist_ok=True) code_dir = Path(deci_lm_hf_code.__file__).parent diff --git a/modelopt/torch/puzzletron/utils/dummy_modules.py b/modelopt/torch/puzzletron/utils/dummy_modules.py new file mode 100644 index 000000000..c9eaa2bc6 --- /dev/null +++ b/modelopt/torch/puzzletron/utils/dummy_modules.py @@ -0,0 +1,75 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional + +import torch +import torch.nn as nn +from transformers import PretrainedConfig +from typing_extensions import override + + +class DummyModule(nn.Module): + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + self.register_load_state_dict_post_hook(self.load_state_dict_post_hook) + + @staticmethod + def load_state_dict_post_hook( + module: torch.nn.Module, + incompatible_keys: torch.nn.modules.module._IncompatibleKeys, + ) -> None: + incompatible_keys.missing_keys.clear() + incompatible_keys.unexpected_keys.clear() + + +class DummyBlock(DummyModule): + def __init__(self, block_index: int): + super().__init__() + self.block_index = block_index + + @override + def forward( + self, + x: torch.Tensor, + *args, + **kwargs, + ) -> torch.Tensor | tuple[torch.Tensor, None]: + return x + + +class DummyWTE(DummyModule): + def __init__(self, hidden_size: int, dtype: Optional[torch.dtype] = None): + super().__init__() + self.n_embd = hidden_size + self.dtype = dtype + + @override + def forward(self, input_ids: torch.Tensor) -> torch.Tensor: + B, T = input_ids.shape + result = torch.ones((B, T, self.n_embd), dtype=self.dtype, device=input_ids.device) + return result + + +class DummyLMHead(DummyModule): + def __init__(self, config: PretrainedConfig): + super().__init__() + self.vocab_size = config.vocab_size + + @override + def forward(self, x: torch.Tensor) -> torch.Tensor: + B, T, C = x.shape + result = torch.ones((B, T, self.vocab_size), dtype=x.dtype, device=x.device) + return result diff --git a/tests/_test_utils/torch/puzzletron/utils.py b/tests/_test_utils/torch/puzzletron/utils.py index 6c9feecd0..4779ee1f3 100644 --- a/tests/_test_utils/torch/puzzletron/utils.py +++ b/tests/_test_utils/torch/puzzletron/utils.py @@ -19,26 +19,38 @@ import torch from datasets import Dataset, DatasetDict -from transformers import AutoTokenizer, LlamaConfig, LlamaForCausalLM, PreTrainedTokenizerBase +from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, PreTrainedTokenizerBase import modelopt.torch.utils.distributed as dist from modelopt.torch.puzzletron.tools.hydra_utils import register_hydra_resolvers +# Path to HF configs relative to this file +# HF configs are in tests/gpu/torch/puzzletron/resources/hf_configs +HF_CONFIGS_DIR = ( + Path(__file__).parent.parent.parent.parent / "gpu/torch/puzzletron/resources/hf_configs" +) + def setup_test_model_and_data( - project_root_path: Path, tmp_path: Path, rank: int + project_root_path: Path, + tmp_path: Path, + rank: int, + hf_config_name: str, + hybrid_override_pattern: str | None = None, ) -> tuple[Path, Path, Path]: """ - Setup the test model and data for the puzzletron NAS search. + Setup the test model and data for the compress NAS search. Args: project_root_path (Path): the root path of the project tmp_path (Path): the temporary path to use for the test rank (int): the rank of the process + hf_config_name (str): Name of the HF config directory (e.g., "llama_3_1_8b_instruct") + hybrid_override_pattern (str): For NemotronH models, the layer type pattern Returns: tuple[Path, Path, Path]: - the puzzle_dir, llama_checkpoint_path, dataset_path + the puzzle_dir, hf_checkpoint_path, dataset_path """ # Register Hydra custom resolvers (needed for config resolution) @@ -46,8 +58,8 @@ def setup_test_model_and_data( # The inputs for the nas.convert() step. # - puzzle_dir = tmp_path - llama_checkpoint_path = puzzle_dir / "input_model/llama" + puzzle_dir = tmp_path / hf_config_name + hf_checkpoint_path = puzzle_dir / f"hf_models/{hf_config_name}" dataset_path = puzzle_dir / "dummy_dataset" if rank == 0: @@ -55,74 +67,133 @@ def setup_test_model_and_data( setup_puzzle_dir(puzzle_dir) save_dummy_dataset(dataset_path) - # Create a small Llama model + # Create a small HF model tokenizer = create_tokenizer(project_root_path) - create_and_save_small_llama_model( - llama_checkpoint_path, vocab_size=tokenizer.vocab_size, tokenizer=tokenizer + create_and_save_small_hf_model( + output_path=str(hf_checkpoint_path), + vocab_size=tokenizer.vocab_size, + tokenizer=tokenizer, + hf_config_name=hf_config_name, + hybrid_override_pattern=hybrid_override_pattern, ) dist.barrier() return ( puzzle_dir, - llama_checkpoint_path, + hf_checkpoint_path, dataset_path, ) -def create_and_save_small_llama_model( - output_path: str, vocab_size: int, tokenizer: PreTrainedTokenizerBase +def create_and_save_small_hf_model( + output_path: str, + vocab_size: int, + tokenizer: PreTrainedTokenizerBase, + hf_config_name: str, + hybrid_override_pattern: str | None = None, ): """ - Create and save a small Llama model for testing the conversion pipeline. - This mimics having a real Llama checkpoint that needs to be converted. + Create and save a small HuggingFace model for testing the conversion pipeline. + Uses real HuggingFace config to preserve model-specific settings (like tie_word_embeddings), + but shrinks size parameters for fast testing. + + Args: + output_path: Where to save the model + vocab_size: Vocabulary size (should match tokenizer) + tokenizer: Tokenizer to save alongside the model + hf_config_name: Name of the config directory under resources/hf_configs/ + e.g., "llama_3_1_8b_instruct", "llama_3_2_3b_instruct", or "qwen2_5_7b_instruct" + hybrid_override_pattern: For NemotronH models, the layer type pattern (e.g., "*-" for Attention+MLP, + "M-" for Mamba+MLP). Must match num_hidden_layers. None for non-NemotronH models. """ os.makedirs(output_path, exist_ok=True) - # Create a minimal Llama config (small for testing) + # Load real HuggingFace config (preserves tie_word_embeddings, rope_scaling, etc.) + config_path = HF_CONFIGS_DIR / hf_config_name + config = AutoConfig.from_pretrained(config_path, local_files_only=True, trust_remote_code=True) + + # Override size-related params to make it small for testing # Note: intermediate_size must be divisible by 256 per DeciLM config requirements # Note: hidden_size must give head_dim >= 8 for Flash Attention 2 compatibility - llama_config = LlamaConfig( - vocab_size=vocab_size, - hidden_size=256, # 32 heads times 8 head_dim = 256 (matches bypass config expectations) - intermediate_size=512, # Must be divisible by 256 - num_hidden_layers=2, - num_attention_heads=32, # Matches original test - num_key_value_heads=8, # GQA: 32÷4=8 (matches original n_heads_in_group=4) - max_position_embeddings=512, - rms_norm_eps=1e-5, - rope_theta=10000.0, - attention_bias=False, - hidden_act="silu", - tie_word_embeddings=False, - ) - # Create and save the Llama model - model = LlamaForCausalLM(llama_config) + # VL models have nested configs (text_config, vision_config) + if hf_config_name == "qwen3-vl-30b-a3b-instruct": + config.text_config.vocab_size = vocab_size + config.text_config.hidden_size = 256 + config.text_config.intermediate_size = 512 + config.text_config.num_hidden_layers = 2 + config.text_config.num_attention_heads = 32 + config.text_config.num_key_value_heads = 8 + config.text_config.num_experts = 16 # Reduce from 128 + config.text_config.moe_intermediate_size = 256 + config.text_config.max_position_embeddings = 512 + config.vision_config.depth = 2 # Reduce from 27 + config.vision_config.hidden_size = 256 + config.vision_config.intermediate_size = 512 + config.vision_config.out_hidden_size = 256 + # TODO: this is hack, redesign converter to not read config.num_hidden_layers directly. + # set top-level num_hidden_layers for converter compatibility + config.num_hidden_layers = config.text_config.num_hidden_layers + else: + # Regular models have flat config + config.vocab_size = vocab_size + config.hidden_size = 256 + config.intermediate_size = 512 + config.num_hidden_layers = 2 + config.num_attention_heads = 32 + config.num_key_value_heads = 8 + config.max_position_embeddings = 512 + + # Fix layer_types to match num_hidden_layers (newer transformers validates this) + if hasattr(config, "layer_types") and config.layer_types is not None: + config.layer_types = config.layer_types[:2] + + # Fix rope_scaling to be consistent with max_position_embeddings + if hasattr(config, "rope_scaling") and config.rope_scaling is not None: + config.rope_scaling["original_max_position_embeddings"] = 256 + + # NemotronH requires hybrid_override_pattern to match num_hidden_layers + if hasattr(config, "hybrid_override_pattern") and hybrid_override_pattern is not None: + config.hybrid_override_pattern = hybrid_override_pattern + + # Set seed for reproducible weight initialization + torch.manual_seed(42) + + # Create and save the model + # TODO: Consider using AutoModel.from_config instead. + if hf_config_name == "qwen3-vl-30b-a3b-instruct": + from transformers import Qwen3VLMoeForConditionalGeneration + + model = Qwen3VLMoeForConditionalGeneration._from_config(config) + else: + model = AutoModelForCausalLM.from_config(config, trust_remote_code=True) + model.to(dtype=torch.bfloat16).save_pretrained(output_path) # Save tokenizer tokenizer.save_pretrained(output_path) # Save config - llama_config.save_pretrained(output_path) + config.save_pretrained(output_path) def create_tokenizer(project_root_path: Path) -> PreTrainedTokenizerBase: """ - Create a tokenizer for the Llama model. + Create a tokenizer for the model. """ - tokenizer_path = project_root_path / "tests/_test_utils/torch/puzzletron/resources/tokenizer" + tokenizer_path = project_root_path / "tests/gpu/torch/puzzletron/resources/tokenizer" tokenizer = AutoTokenizer.from_pretrained(tokenizer_path) return tokenizer -def setup_puzzle_dir(puzzle_dir: str): +def setup_puzzle_dir(puzzle_dir: str | Path): """ Setup puzzle directory by removing existing directory and creating a new one. """ - if Path(puzzle_dir).exists(): + puzzle_dir = Path(puzzle_dir) + if puzzle_dir.exists(): shutil.rmtree(puzzle_dir) - Path(puzzle_dir).mkdir(parents=True, exist_ok=True) + puzzle_dir.mkdir(parents=True, exist_ok=True) def save_dummy_dataset(dataset_path: Path | str): diff --git a/tests/gpu/torch/puzzletron/resources/configs/llama_3_1_8b_instruct/llama_3_1_8b_instruct.yaml b/tests/gpu/torch/puzzletron/resources/configs/llama_3_1_8b_instruct/llama_3_1_8b_instruct.yaml new file mode 100644 index 000000000..65ca64ef4 --- /dev/null +++ b/tests/gpu/torch/puzzletron/resources/configs/llama_3_1_8b_instruct/llama_3_1_8b_instruct.yaml @@ -0,0 +1,107 @@ +defaults: + - pruning: ffn_pruning + - scoring: ../validate_solutions_defaults + - realize_model: ../validate_solutions_defaults + - bypass: + - override hydra/hydra_logging: disabled + - _self_ + +descriptor: llama + +puzzle_dir: ??? +teacher_dir: ${puzzle_dir}/ckpts/teacher/ +replacement_library_path: ${puzzle_dir}/replacement_library.json +dataset_path: ??? # path to v0.4_mini + +skip_realize_model: false + +build_replacement_library: + add_ffn_no_ops: true + add_attention_no_ops: true + +calc_subblock_stats: + batch_sizes: [64, 96, 128] + prefill_seq_len: 4096 + generation_seq_len: 4096 + num_active_tokens_override: # Optional override for sequence lengths + prefill_queue_size: 0 + allocate_prefill_query: false + benchmark_iterations: # Set to a number (e.g., 1000) to enable runtime benchmarking + merge_with_existing_stats: false + subblock_stats_filename: "subblock_stats.json" + moe_stats_filename: "moe_stats.json" + +scoring: + descriptor: ${descriptor} + solutions_to_validate: + skip_existing_solutions: true + + replacement_library_path: ${replacement_library_path} + solutions_path: ${to_path:${puzzle_dir}/single_sequence_replacement_solutions.json} + teacher_dir: ${to_path:${teacher_dir}} + output_dir: ${puzzle_dir}/single_sequence_replacement_solutions--validation + + eval_samples: 2 + micro_batch_size: 1 + dataset_path: ${dataset_path}/valid + seed: 42 + shuffle_seed: 444 + +mip: + single_block_replacement_validation_dir: ${to_path:${scoring.output_dir}} + subblock_stats_path: ${to_path:${puzzle_dir}/${calc_subblock_stats.subblock_stats_filename}} + output_path: ${to_path:${puzzle_dir}/mip/puzzle_solutions} + gathered_metrics_path: + puzzle_profile: + + # puzzle_profile: + objective: metrics.cosine_embedding_loss_hidden_states + bigger_is_better: false + + subblock_stats_args: + - batch_size: 96 + weights_dtype: torch.bfloat16 + activations_dtype: torch.bfloat16 + kv_cache_dtype: torch.bfloat16 + + report_additional_costs: + - stats.memory_mib + - stats.num_params + - stats.num_kv_heads + - stats.has_attention + - stats.has_ffn + - stats.kv_cache_memory_mib + - stats.attention_memory_mib + - stats.ffn_memory_mib + - stats.ffn_num_params + - stats.attention_num_params + + human_constraints: + target_memory: 780_000 # 78_000 + + mip_constraints: + metric_overrides: + max_seconds_per_solution: 60 + +realize_model: + descriptor: ${descriptor} + teacher_dir: ${to_path:${teacher_dir}} + tokenizer_name: ${to_path:${teacher_dir}} + replacement_library_path: ${replacement_library_path} + save_models: true + solutions_path: # Filled dynamically + + # Validate params + skip_validation: false # To enable validation of the model solution set `skip_validation` as False + eval_samples: 2 + micro_batch_size: 1 + dataset_path: ${dataset_path}/valid + seed: 42 + shuffle_seed: 444 + +nccl_timeout_minutes: ${timedelta_minutes:10} + +# This section redirects Hydra outputs +hydra: + run: + dir: ${puzzle_dir}/hydra_logs/${now:%Y-%m-%d}/${now:%H-%M-%S} diff --git a/tests/gpu/torch/puzzletron/resources/configs/llama_3_1_8b_instruct/pruning/attn_pruning.yaml b/tests/gpu/torch/puzzletron/resources/configs/llama_3_1_8b_instruct/pruning/attn_pruning.yaml new file mode 100644 index 000000000..01886607e --- /dev/null +++ b/tests/gpu/torch/puzzletron/resources/configs/llama_3_1_8b_instruct/pruning/attn_pruning.yaml @@ -0,0 +1,16 @@ +defaults: + - pruning_defaults + +activations_log_dir: ${puzzle_dir}/pruning/pruning_scores/attn_${pruning.activation_hooks_kwargs.method}/${pruning.experiment_id} + +activation_hooks_kwargs: + method: independent_kv_head_contribution + optimize_for: memory # IndependentKvHeadContributionHook implementation that consumes less memory + target_layer: "self_attn.o_proj" + layer_input_descriptors_path: + +# n_heads_in_group: 4 +# num_attention_heads: 32 # num query heads +# num_kv_heads: 32 / 4 = 8 # num_query_heads // n_heads_in_group +n_heads_in_group_list: [8, 16, 32] # num_kv_heads = [4, 2, 1] +gqa_init_mode: "PruneKVHeads" diff --git a/tests/gpu/torch/puzzletron/resources/configs/llama_3_1_8b_instruct/pruning/ffn_pruning.yaml b/tests/gpu/torch/puzzletron/resources/configs/llama_3_1_8b_instruct/pruning/ffn_pruning.yaml new file mode 100644 index 000000000..cad6fcf3e --- /dev/null +++ b/tests/gpu/torch/puzzletron/resources/configs/llama_3_1_8b_instruct/pruning/ffn_pruning.yaml @@ -0,0 +1,18 @@ +defaults: + - pruning_defaults + +activations_log_dir: ${puzzle_dir}/pruning/pruning_scores/ffn_${pruning.activation_hooks_kwargs.method}/${pruning.experiment_id} + +pruning_mixin: + _target_: modelopt.torch.puzzletron.pruning.ffn_intermediate_pruning_mixin.FFNIntermediatePruningMixIn + layer_descriptor: + _target_: modelopt.torch.puzzletron.anymodel.models.llama.llama_model_descriptor.LlamaFFNIntermediateLayerDescriptor + +hook_class: ${get_object:modelopt.torch.nas.plugins.megatron_hooks.base_hooks.IterativeChannelContributionHook} +activation_hooks_kwargs: + method: iterative + target_layer: "mlp.down_proj" + layer_input_descriptors_path: + +intermediate_size_list: [256] # teacher_intermediate_size is 14336 +mlp_init_mode: "PruneByActivationsLog" diff --git a/tests/gpu/torch/puzzletron/resources/configs/llama_3_1_8b_instruct/pruning/hidden_dim_pruning.yaml b/tests/gpu/torch/puzzletron/resources/configs/llama_3_1_8b_instruct/pruning/hidden_dim_pruning.yaml new file mode 100644 index 000000000..407c835d8 --- /dev/null +++ b/tests/gpu/torch/puzzletron/resources/configs/llama_3_1_8b_instruct/pruning/hidden_dim_pruning.yaml @@ -0,0 +1,15 @@ +defaults: + - pruning_defaults + +activations_log_dir: ${puzzle_dir}/pruning/pruning_scores/hidden_dim_${pruning.activation_hooks_kwargs.method}/${pruning.experiment_id} + +activation_hooks_kwargs: + method: layer_norm_contribution + target_layer: "layernorm" + +# Hidden dimension pruning specific settings +hidden_size_list: [3072, 2048] # Target hidden sizes to prune to +hidden_size_init_mode: "PruneByChannelRanking" +mlp_init_mode: "Truncate" # TODO, make it work with CopyAsIs/FromTeacher +gqa_init_mode: "AverageKV" # TODO, make it work with CopyAsIs/FromTeacher +linear_init_mode: "FromTeacher" diff --git a/tests/gpu/torch/puzzletron/resources/configs/llama_3_1_8b_instruct/pruning/pruning_defaults.yaml b/tests/gpu/torch/puzzletron/resources/configs/llama_3_1_8b_instruct/pruning/pruning_defaults.yaml new file mode 100644 index 000000000..b24ea1b7c --- /dev/null +++ b/tests/gpu/torch/puzzletron/resources/configs/llama_3_1_8b_instruct/pruning/pruning_defaults.yaml @@ -0,0 +1,33 @@ +defaults: + - /validate_model_defaults + +descriptor: ${descriptor} +model_name_or_path: ${teacher_dir} +experiment_id: ${pruning.eval_samples}samples_diverse_mini +activations_log_dir: ??? +activation_hooks_kwargs: ??? + +# Data: +eval_samples: 100 +micro_batch_size: 4 +dataset_path: ${dataset_path} +val_dataset_name: train + +# Prune ckpts +pruned_ckpts_output_dir: ${puzzle_dir}/pruning/${pruning.experiment_id} + +## FFN pruning +ffn_list: +mlp_init_mode: "Truncate" + +## KV-heads pruning +n_heads_in_group_list: +gqa_init_mode: "AverageKV" + +## Hidden dimension pruning +hidden_size_list: +hidden_size_init_mode: "PruneByChannelRanking" +linear_init_mode: "FromTeacher" + +mlp_init_config_yaml: + activations_log_dir: ${pruning.activations_log_dir} diff --git a/tests/gpu/torch/puzzletron/resources/configs/llama_3_1_8b_instruct/validate_model_defaults.yaml b/tests/gpu/torch/puzzletron/resources/configs/llama_3_1_8b_instruct/validate_model_defaults.yaml new file mode 100644 index 000000000..9dabef741 --- /dev/null +++ b/tests/gpu/torch/puzzletron/resources/configs/llama_3_1_8b_instruct/validate_model_defaults.yaml @@ -0,0 +1,15 @@ +block_size: 8192 +bos_rate: 0.5 +data_column: conversation +val_dataset_name: train +shuffle_seed: 81436 +seed: 42 +fim_rate: 0 +fim_spm_rate: 0 +source_datasets_to_discard: +varlen: false +write_results: false +calc_losses_on_cpu: false +activations_log_dir: +model_name_or_path: +load_dataset_fn: ${get_object:modelopt.torch.puzzletron.utils.data.dataloaders.load_from_disk_fn} diff --git a/tests/gpu/torch/puzzletron/resources/configs/llama_3_1_8b_instruct/validate_solutions_defaults.yaml b/tests/gpu/torch/puzzletron/resources/configs/llama_3_1_8b_instruct/validate_solutions_defaults.yaml new file mode 100644 index 000000000..ec1390237 --- /dev/null +++ b/tests/gpu/torch/puzzletron/resources/configs/llama_3_1_8b_instruct/validate_solutions_defaults.yaml @@ -0,0 +1,10 @@ +defaults: + - /validate_model_defaults + - _self_ + +solutions_to_validate: +skip_validation: false +save_models: false +bigger_is_better: false +sort_solutions_by: +calculate_full_score_ablations: false diff --git a/tests/gpu/torch/puzzletron/resources/hf_configs/llama_3_1_8b_instruct/config.json b/tests/gpu/torch/puzzletron/resources/hf_configs/llama_3_1_8b_instruct/config.json new file mode 100644 index 000000000..0bb6fd75b --- /dev/null +++ b/tests/gpu/torch/puzzletron/resources/hf_configs/llama_3_1_8b_instruct/config.json @@ -0,0 +1,38 @@ +{ + "architectures": [ + "LlamaForCausalLM" + ], + "attention_bias": false, + "attention_dropout": 0.0, + "bos_token_id": 128000, + "eos_token_id": [ + 128001, + 128008, + 128009 + ], + "hidden_act": "silu", + "hidden_size": 4096, + "initializer_range": 0.02, + "intermediate_size": 14336, + "max_position_embeddings": 131072, + "mlp_bias": false, + "model_type": "llama", + "num_attention_heads": 32, + "num_hidden_layers": 32, + "num_key_value_heads": 8, + "pretraining_tp": 1, + "rms_norm_eps": 1e-05, + "rope_scaling": { + "factor": 8.0, + "low_freq_factor": 1.0, + "high_freq_factor": 4.0, + "original_max_position_embeddings": 8192, + "rope_type": "llama3" + }, + "rope_theta": 500000.0, + "tie_word_embeddings": false, + "torch_dtype": "bfloat16", + "transformers_version": "4.42.3", + "use_cache": true, + "vocab_size": 128256 +} diff --git a/tests/gpu/torch/puzzletron/resources/tokenizer/special_tokens_map.json b/tests/gpu/torch/puzzletron/resources/tokenizer/special_tokens_map.json new file mode 100644 index 000000000..02ee80b61 --- /dev/null +++ b/tests/gpu/torch/puzzletron/resources/tokenizer/special_tokens_map.json @@ -0,0 +1,16 @@ +{ + "bos_token": { + "content": "<|begin_of_text|>", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false + }, + "eos_token": { + "content": "<|eot_id|>", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false + } +} diff --git a/tests/gpu/torch/puzzletron/resources/tokenizer/tokenizer.json b/tests/gpu/torch/puzzletron/resources/tokenizer/tokenizer.json new file mode 100644 index 000000000..83592e249 --- /dev/null +++ b/tests/gpu/torch/puzzletron/resources/tokenizer/tokenizer.json @@ -0,0 +1,212 @@ +{ + "version": "1.0", + "truncation": null, + "padding": null, + "added_tokens": [], + "normalizer": null, + "pre_tokenizer": { + "type": "Sequence", + "pretokenizers": [ + { + "type": "Split", + "pattern": { + "Regex": "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+" + }, + "behavior": "Isolated", + "invert": false + }, + { + "type": "ByteLevel", + "add_prefix_space": false, + "trim_offsets": true, + "use_regex": false + } + ] + }, + "post_processor": { + "type": "Sequence", + "processors": [ + { + "type": "ByteLevel", + "add_prefix_space": true, + "trim_offsets": false, + "use_regex": true + }, + { + "type": "TemplateProcessing", + "single": [ + { + "SpecialToken": { + "id": "<|begin_of_text|>", + "type_id": 0 + } + }, + { + "Sequence": { + "id": "A", + "type_id": 0 + } + } + ], + "pair": [ + { + "SpecialToken": { + "id": "<|begin_of_text|>", + "type_id": 0 + } + }, + { + "Sequence": { + "id": "A", + "type_id": 0 + } + }, + { + "SpecialToken": { + "id": "<|begin_of_text|>", + "type_id": 1 + } + }, + { + "Sequence": { + "id": "B", + "type_id": 1 + } + } + ], + "special_tokens": { + "<|begin_of_text|>": { + "id": "<|begin_of_text|>", + "ids": [ + 100 + ], + "tokens": [ + "<|begin_of_text|>" + ] + } + } + } + ] + }, + "decoder": { + "type": "ByteLevel", + "add_prefix_space": true, + "trim_offsets": true, + "use_regex": true + }, + "model": { + "type": "BPE", + "dropout": null, + "unk_token": null, + "continuing_subword_prefix": null, + "end_of_word_suffix": null, + "fuse_unk": false, + "byte_fallback": false, + "ignore_merges": true, + "vocab": { + "!": 0, + "\"": 1, + "#": 2, + "$": 3, + "%": 4, + "&": 5, + "'": 6, + "(": 7, + ")": 8, + "*": 9, + "+": 10, + ",": 11, + "-": 12, + ".": 13, + "/": 14, + "0": 15, + "1": 16, + "2": 17, + "3": 18, + "4": 19, + "5": 20, + "6": 21, + "7": 22, + "8": 23, + "9": 24, + ":": 25, + ";": 26, + "<": 27, + "=": 28, + ">": 29, + "?": 30, + "@": 31, + "A": 32, + "B": 33, + "C": 34, + "D": 35, + "E": 36, + "F": 37, + "G": 38, + "H": 39, + "I": 40, + "J": 41, + "K": 42, + "L": 43, + "M": 44, + "N": 45, + "O": 46, + "P": 47, + "Q": 48, + "R": 49, + "S": 50, + "T": 51, + "U": 52, + "V": 53, + "W": 54, + "X": 55, + "Y": 56, + "Z": 57, + "[": 58, + "\\": 59, + "]": 60, + "^": 61, + "_": 62, + "`": 63, + "a": 64, + "b": 65, + "c": 66, + "d": 67, + "e": 68, + "f": 69, + "g": 70, + "h": 71, + "i": 72, + "j": 73, + "k": 74, + "l": 75, + "m": 76, + "n": 77, + "o": 78, + "p": 79, + "q": 80, + "r": 81, + "s": 82, + "t": 83, + "u": 84, + "v": 85, + "w": 86, + "x": 87, + "y": 88, + "z": 89, + "{": 90, + "|": 91, + "}": 92, + "~": 93, + "¡": 94, + "¢": 95, + "£": 96, + "¤": 97, + "¥": 98, + "¦": 99, + "<|begin_of_text|>": 100, + "<|eot_id|>": 101 + }, + "merges": [] + } +} diff --git a/tests/gpu/torch/puzzletron/resources/tokenizer/tokenizer_config.json b/tests/gpu/torch/puzzletron/resources/tokenizer/tokenizer_config.json new file mode 100644 index 000000000..754d9e8db --- /dev/null +++ b/tests/gpu/torch/puzzletron/resources/tokenizer/tokenizer_config.json @@ -0,0 +1,13 @@ +{ + "bos_token": "<|begin_of_text|>", + "chat_template": "{{- bos_token }}\n{%- if custom_tools is defined %}\n {%- set tools = custom_tools %}\n{%- endif %}\n{%- if not tools_in_user_message is defined %}\n {%- set tools_in_user_message = true %}\n{%- endif %}\n{%- if not date_string is defined %}\n {%- set date_string = \"26 Jul 2024\" %}\n{%- endif %}\n{%- if not tools is defined %}\n {%- set tools = none %}\n{%- endif %}\n\n{#- This block extracts the system message, so we can slot it into the right place. #}\n{%- if messages[0]['role'] == 'system' %}\n {%- set system_message = messages[0]['content']|trim %}\n {%- set messages = messages[1:] %}\n{%- else %}\n {%- set system_message = \"\" %}\n{%- endif %}\n\n{#- System message + builtin tools #}\n{{- \"<|start_header_id|>system<|end_header_id|>\\n\\n\" }}\n{%- if builtin_tools is defined or tools is not none %}\n {{- \"Environment: ipython\\n\" }}\n{%- endif %}\n{%- if builtin_tools is defined %}\n {{- \"Tools: \" + builtin_tools | reject('equalto', 'code_interpreter') | join(\", \") + \"\\n\\n\"}}\n{%- endif %}\n{{- \"Cutting Knowledge Date: December 2023\\n\" }}\n{{- \"Today Date: \" + date_string + \"\\n\\n\" }}\n{%- if tools is not none and not tools_in_user_message %}\n {{- \"You have access to the following functions. To call a function, please respond with JSON for a function call.\" }}\n {{- 'Respond in the format {\"name\": function name, \"parameters\": dictionary of argument name and its value}.' }}\n {{- \"Do not use variables.\\n\\n\" }}\n {%- for t in tools %}\n {{- t | tojson(indent=4) }}\n {{- \"\\n\\n\" }}\n {%- endfor %}\n{%- endif %}\n{{- system_message }}\n{{- \"<|eot_id|>\" }}\n\n{#- Custom tools are passed in a user message with some extra guidance #}\n{%- if tools_in_user_message and not tools is none %}\n {#- Extract the first user message so we can plug it in here #}\n {%- if messages | length != 0 %}\n {%- set first_user_message = messages[0]['content']|trim %}\n {%- set messages = messages[1:] %}\n {%- else %}\n {{- raise_exception(\"Cannot put tools in the first user message when there's no first user message!\") }}\n{%- endif %}\n {{- '<|start_header_id|>user<|end_header_id|>\\n\\n' -}}\n {{- \"Given the following functions, please respond with a JSON for a function call \" }}\n {{- \"with its proper arguments that best answers the given prompt.\\n\\n\" }}\n {{- 'Respond in the format {\"name\": function name, \"parameters\": dictionary of argument name and its value}.' }}\n {{- \"Do not use variables.\\n\\n\" }}\n {%- for t in tools %}\n {{- t | tojson(indent=4) }}\n {{- \"\\n\\n\" }}\n {%- endfor %}\n {{- first_user_message + \"<|eot_id|>\"}}\n{%- endif %}\n\n{%- for message in messages %}\n {%- if not (message.role == 'ipython' or message.role == 'tool' or 'tool_calls' in message) %}\n {{- '<|start_header_id|>' + message['role'] + '<|end_header_id|>\\n\\n'+ message['content'] | trim + '<|eot_id|>' }}\n {%- elif 'tool_calls' in message %}\n {%- if not message.tool_calls|length == 1 %}\n {{- raise_exception(\"This model only supports single tool-calls at once!\") }}\n {%- endif %}\n {%- set tool_call = message.tool_calls[0].function %}\n {%- if builtin_tools is defined and tool_call.name in builtin_tools %}\n {{- '<|start_header_id|>assistant<|end_header_id|>\\n\\n' -}}\n {{- \"<|python_tag|>\" + tool_call.name + \".call(\" }}\n {%- for arg_name, arg_val in tool_call.arguments | items %}\n {{- arg_name + '=\"' + arg_val + '\"' }}\n {%- if not loop.last %}\n {{- \", \" }}\n {%- endif %}\n {%- endfor %}\n {{- \")\" }}\n {%- else %}\n {{- '<|start_header_id|>assistant<|end_header_id|>\\n\\n' -}}\n {{- '{\"name\": \"' + tool_call.name + '\", ' }}\n {{- '\"parameters\": ' }}\n {{- tool_call.arguments | tojson }}\n {{- \"}\" }}\n {%- endif %}\n {%- if builtin_tools is defined %}\n {#- This means we're in ipython mode #}\n {{- \"<|eom_id|>\" }}\n {%- else %}\n {{- \"<|eot_id|>\" }}\n {%- endif %}\n {%- elif message.role == \"tool\" or message.role == \"ipython\" %}\n {{- \"<|start_header_id|>ipython<|end_header_id|>\\n\\n\" }}\n {%- if message.content is mapping or message.content is iterable %}\n {{- message.content | tojson }}\n {%- else %}\n {{- message.content }}\n {%- endif %}\n {{- \"<|eot_id|>\" }}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|start_header_id|>assistant<|end_header_id|>\\n\\n' }}\n{%- endif %}\n", + "clean_up_tokenization_spaces": true, + "eos_token": "<|eot_id|>", + "extra_special_tokens": {}, + "model_input_names": [ + "input_ids", + "attention_mask" + ], + "model_max_length": 131072, + "tokenizer_class": "PreTrainedTokenizer" +} diff --git a/tests/gpu/torch/puzzletron/resources/tokenizer/truncate_tokenizer.py b/tests/gpu/torch/puzzletron/resources/tokenizer/truncate_tokenizer.py new file mode 100644 index 000000000..aedcae4ab --- /dev/null +++ b/tests/gpu/torch/puzzletron/resources/tokenizer/truncate_tokenizer.py @@ -0,0 +1,62 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This script was used to truncate the tokenizer.json file from Llama 3.1 8B model +to keep only the top 100 most common tokens. +""" + +import json + +# Path to your original and new tokenizer.json +in_path = "./tokenizer.json" +out_path = "./tokenizer_truncated.json" + +# How many top tokens to keep +NUM_TO_KEEP = 100 + +with open(in_path, encoding="utf-8") as f: + tokenizer_data = json.load(f) + +# Get and sort the original vocab by index (frequency proxy) +orig_vocab = tokenizer_data["model"]["vocab"] + +# Sort tokens by their original index (lowest index = assumed most common/important) +sorted_tokens = sorted(orig_vocab.items(), key=lambda item: item[1]) + +# Keep the top N tokens +tokens_to_keep = [tok for tok, idx in sorted_tokens[:NUM_TO_KEEP]] + +# Re-index the selected tokens: 0..N-1 +small_vocab = {tok: i for i, tok in enumerate(tokens_to_keep)} +tokenizer_data["model"]["vocab"] = small_vocab + +# Update vocab size +if "vocab_size" in tokenizer_data["model"]: + tokenizer_data["model"]["vocab_size"] = len(small_vocab) + +# Optionally remove merges if present and unneeded (mostly for BPE/WordPiece) +if "merges" in tokenizer_data["model"]: + tokenizer_data["model"]["merges"] = [] + +# Remove added_tokens if not needed +if "added_tokens" in tokenizer_data: + tokenizer_data["added_tokens"] = [] + +# Write out the truncated tokenizer.json +with open(out_path, "w", encoding="utf-8") as f: + json.dump(tokenizer_data, f, indent=2, ensure_ascii=False) + +print(f"Truncated tokenizer saved to: {out_path}") diff --git a/tests/gpu/torch/puzzletron/test_puzzletron.py b/tests/gpu/torch/puzzletron/test_puzzletron.py index faf72f749..23a4b61c2 100644 --- a/tests/gpu/torch/puzzletron/test_puzzletron.py +++ b/tests/gpu/torch/puzzletron/test_puzzletron.py @@ -13,19 +13,18 @@ # See the License for the specific language governing permissions and # limitations under the License. +import json from datetime import timedelta from functools import partial from pathlib import Path +import pytest import torch from _test_utils.torch.distributed.utils import spawn_multiprocess_job from _test_utils.torch.puzzletron.utils import setup_test_model_and_data import modelopt.torch.utils.distributed as dist -from modelopt.torch.puzzletron import puzzletron -from modelopt.torch.puzzletron.decilm.converters.convert_llama3_to_decilm import ( - convert_llama3_to_decilm, -) +from modelopt.torch.puzzletron.anymodel import convert_model # The e2e test to compress a model based on Local Neural Architecture Search (Mixed Integer Programing NAS search) # using a one-click command. @@ -33,91 +32,279 @@ # Note: Bypass is disabled now in the test. -def test_puzzletron(project_root_path: Path, tmp_path: Path): +@pytest.mark.parametrize( + ( + "hf_config_name", + "converter", + "hydra_config_subdir", + "hybrid_override_pattern", + "has_moe_layers", + ), + [ + ("llama_3_1_8b_instruct", "llama", "llama_3_1_8b_instruct", None, False), + ("llama_3_2_3b_instruct", "llama", "llama_3_1_8b_instruct", None, False), + ("qwen2_5_7b_instruct", "qwen2", "qwen2_5_7b_instruct", None, False), + ( + "mistral-small-24b-instruct-2501", + "mistral_small", + "mistral-small-24b-instruct-2501", + None, + False, + ), + ("qwen3-8b", "qwen3", "qwen3-8b", None, False), + ("qwen3-vl-30b-a3b-instruct", "qwen3_vl", "qwen3-vl-30b-a3b-instruct", None, True), + ("nemotron-nano-12b-v2", "nemotron_h_v2", "nemotron-nano-12b-v2", "*-", False), + ( + "nemotron-3-nano-30b-a3b-base-bf16", + "nemotron_h", + "nemotron-3-nano-30b-a3b-base-bf16", + "*E", + True, + ), + ("gpt-oss-20b", "gpt_oss_20b", "gpt-oss-20b", None, True), + ], +) +def test_puzzletron( + project_root_path: Path, + tmp_path: Path, + hf_config_name: str, + converter: str, + hydra_config_subdir: str, + hybrid_override_pattern: str, + has_moe_layers: bool, +): spawn_multiprocess_job( - size=min(torch.cuda.device_count(), 2), # assertions configured for atmost 2 GPUs - job=partial(_test_puzzletron_multiprocess_job, project_root_path, tmp_path), + size=torch.cuda.device_count(), + job=partial( + _test_puzzletron_multiprocess_job, + project_root_path, + tmp_path, + hf_config_name, + converter, + hydra_config_subdir, + hybrid_override_pattern, + has_moe_layers, + ), backend="nccl", ) def _test_puzzletron_multiprocess_job( - project_root_path: Path, tmp_path: Path, rank: int, size: int + project_root_path: Path, + tmp_path: Path, + hf_config_name: str, + converter: str, + hydra_config_subdir: str, + hybrid_override_pattern: str, + has_moe_layers: bool, + rank: int, + size: int, ): dist.setup(timeout=timedelta(10)) + # Setup the test model and data. - puzzle_dir, llama_checkpoint_path, dataset_path = setup_test_model_and_data( - project_root_path, tmp_path, rank + puzzle_dir, hf_checkpoint_path, dataset_path = setup_test_model_and_data( + project_root_path, tmp_path, rank, hf_config_name, hybrid_override_pattern + ) + hydra_config_dir = ( # noqa: F841 + project_root_path / f"tests/gpu/torch/puzzletron/resources/configs/{hydra_config_subdir}" ) - hydra_config_dir = project_root_path / "tests/_test_utils/torch/puzzletron/resources/configs" - hydra_config_name = "Llama-3_1-8B-ffn-pruning" - # Convert the Llama model to DeciLM model. + # Convert the model using AnyModel converter. if rank == 0: - convert_llama3_to_decilm( - input_dir=llama_checkpoint_path, - output_dir=puzzle_dir / "ckpts/teacher", + convert_model( + input_dir=str(hf_checkpoint_path), + output_dir=str(puzzle_dir / "ckpts/teacher"), + converter=converter, ) dist.barrier() - # Compress the model using a one-click approach - puzzletron.puzzletron( - str(hydra_config_dir), hydra_config_name, str(puzzle_dir), str(dataset_path) - ) + # TODO commented for the duration of merging process from dkorzekwa/any_model to feature/puzzletron + # # Compress the model using a one-click approach + # puzzletron.puzzletron( + # str(hydra_config_dir), hydra_config_subdir, str(puzzle_dir), str(dataset_path) + # ) - # - # Check assertions - # - # assertions for the score_pruning_activations step 1 - _assert_score_pruning_activations(puzzle_dir) - if rank == 0: - # assertions for the pruning_ckpts step 2 - assert (puzzle_dir / "ckpts/ffn_256_attn_no_op").exists() + # # + # # Check assertions + # # + # if rank == 0: + # if has_moe_layers: + # # assertions for the score_pruning_activations step 1 (MoE models only) + # rank_filepath = ( + # f"pruning/pruning_scores/expert_removal/10samples_diverse_mini/rank_{rank}.pth" + # ) + # assert (puzzle_dir / rank_filepath).is_file(), f"Expected {rank_filepath} to exist" - # assertions for the build_library_and_stats step 4 + # # assertions for the pruning_ckpts step 2 + # assert (puzzle_dir / "ckpts/num_experts_8").exists() - assert (puzzle_dir / "replacement_library.json").is_file() - assert (puzzle_dir / "subblock_stats.json").is_file() + # # assertions for the mip_and_realize_models step 6 + # # Find the MIP solution directory dynamically (e.g., stats_num_local_experts_*) + # mip_solutions_dir = puzzle_dir / "mip/puzzle_solutions" + # solution_dirs = [ + # d + # for d in mip_solutions_dir.iterdir() + # if d.is_dir() and d.name.startswith("stats_num_local_experts_") + # ] + # assert len(solution_dirs) == 1, ( + # f"Expected exactly one stats_num_local_experts_* directory, found: {[d.name for d in solution_dirs]}" + # ) + # solution_dir = solution_dirs[0] - # assertions for the scoring step 5 - solution_0_filepath = ( - puzzle_dir / "single_sequence_replacement_solutions--validation/solution_0.json" - ) + # solution_0_ckpt_config_path = ( + # solution_dir / "solutions--checkpoints/solution_0/config.json" + # ) + # assert solution_0_ckpt_config_path.exists() + # assert (solution_dir / "solutions.json").exists() - assert solution_0_filepath.exists() + # # Validate lm_loss + # _assert_lm_loss(puzzle_dir, hf_config_name) + # else: + # # assertions for the score_pruning_activations step 1 (FFN pruning) + # _assert_score_pruning_activations(puzzle_dir, hf_config_name) - # assertions for the mip_and_realize_models step 6 - solution_0_ckpt_config_path = ( - puzzle_dir - / "mip/puzzle_solutions/target_memory_780000MiB/solutions--checkpoints/solution_0/config.json" - ) + # # assertions for the pruning_ckpts step 2 + # assert (puzzle_dir / "ckpts/ffn_256_attn_no_op").exists() + + # # assertions for the mip_and_realize_models step 6 + # _assert_mip_solutions(puzzle_dir, hf_config_name) - assert solution_0_ckpt_config_path.exists() - assert (puzzle_dir / "mip/puzzle_solutions/target_memory_780000MiB/solutions.json").exists() + # # assertions for the build_library_and_stats step 4 + # assert (puzzle_dir / "replacement_library.json").is_file() + # assert (puzzle_dir / "subblock_stats.json").is_file() + + # # assertions for the scoring step 5 + # solution_0_filepath = ( + # puzzle_dir / "single_sequence_replacement_solutions--validation/solution_0.json" + # ) + # assert solution_0_filepath.exists() dist.cleanup() + print( + f"PYTEST SUMMARY: test_puzzletron({hf_config_name}) test has finished successfully. " + f"Puzzle directory: {puzzle_dir}" + ) + + +# Expected pruning activation values per model +# Each model has a list of (score, channels) tuples for each FFN layer +EXPECTED_PRUNING_VALUES = { + "llama_3_1_8b_instruct": [ + {"score": 73, "channels": 95}, + {"score": 440, "channels": 174}, + ], + "llama_3_2_3b_instruct": [ + {"score": 79, "channels": 95}, + {"score": 428, "channels": 174}, + ], + "qwen2_5_7b_instruct": [ + {"score": 96, "channels": 433}, + {"score": 485, "channels": 105}, + ], + # Mistral Small 24B + "mistral-small-24b-instruct-2501": [ + {"score": 73, "channels": 95}, + {"score": 431, "channels": 174}, + ], + # Qwen3 8B + "qwen3-8b": [ + {"score": 208, "channels": 51}, + {"score": 475, "channels": 266}, + ], + # NemotronH with pattern "*-" has only 1 FFN layer (the "-" layer) + "nemotron-nano-12b-v2": [ + {"score": 70, "channels": 509}, + ], + # Note: nemotron-3-nano-30b-a3b-base-bf16 uses MoE expert pruning, not FFN pruning + # so it doesn't have EXPECTED_PRUNING_VALUES +} + -def _assert_score_pruning_activations(puzzle_dir: Path): +# Expected lm_loss values per model +EXPECTED_LM_LOSS = { + "llama_3_1_8b_instruct": 4.706878662109375, + "llama_3_2_3b_instruct": 4.816886901855469, + "qwen2_5_7b_instruct": 4.778186798095703, + "nemotron-nano-12b-v2": 4.79390811920166, + "mistral-small-24b-instruct-2501": 4.709150314331055, + "qwen3-8b": 4.733874320983887, + "gpt-oss-20b": 4.689250946044922, + "nemotron-3-nano-30b-a3b-base-bf16": 4.741103172302246, + "qwen3-vl-30b-a3b-instruct": 4.65625, +} + + +def _assert_score_pruning_activations(puzzle_dir: Path, hf_config_name: str): """Assertions for the score_pruning_activations step 1.""" rank = dist.rank() - size = dist.size() rank_filepath = f"pruning/pruning_scores/ffn_iterative/100samples_diverse_mini/rank_{rank}.pth" assert (puzzle_dir / rank_filepath).is_file() pruning_scores = torch.load(puzzle_dir / rank_filepath) layer_names = list(pruning_scores.keys()) - assert len(layer_names) == 2 // size - - if size == 1 or rank == 0: - # Check specific values for layer 0 - layer_0 = pruning_scores[layer_names[0]] - assert layer_0["score"][0].item() == 371 - assert layer_0["channels_importance_ascending"][0].item() == 140 - - if size == 1 or rank == 1: - # Check specific values for layer 1 - layer_1 = pruning_scores[layer_names[1 if size == 1 else 0]] - assert layer_1["score"][0].item() == 269 - assert layer_1["channels_importance_ascending"][0].item() == 366 + expected = EXPECTED_PRUNING_VALUES[hf_config_name] + size = dist.size() + + if expected is not None: + # In multi-GPU: layers are distributed across ranks + # Each rank processes len(expected) // size layers + expected_layers_per_rank = len(expected) // size + assert len(layer_names) == expected_layers_per_rank, ( + f"Expected {expected_layers_per_rank} FFN layers on rank {rank}/{size}, got {len(layer_names)}" + ) + # Check each layer's values + for i, layer_name in enumerate(layer_names): + layer_data = pruning_scores[layer_name] + # Calculate global layer index from rank and local index + global_idx = rank * expected_layers_per_rank + i + assert layer_data["score"][0].item() == expected[global_idx]["score"] + assert ( + layer_data["channels_importance_ascending"][0].item() + == expected[global_idx]["channels"] + ) + else: + # Print values for new models - update EXPECTED_PRUNING_VALUES with these + print(f"\n=== PRUNING VALUES for {hf_config_name} (num_layers={len(layer_names)}) ===") + print(f'"{hf_config_name}": [') + for layer_name in layer_names: + layer_data = pruning_scores[layer_name] + score = layer_data["score"][0].item() + channels = layer_data["channels_importance_ascending"][0].item() + print(f' {{"score": {score}, "channels": {channels}}},') + print("],") + print("===") + + +def _assert_lm_loss(puzzle_dir: Path, hf_config_name: str): + """Validate lm_loss for a model solution.""" + solution_0_path = ( + puzzle_dir / "single_sequence_replacement_solutions--validation/solution_0.json" + ) + with open(solution_0_path) as f: + validation = json.load(f) + + actual_lm_loss = validation["lm_loss"]["avg"] + expected_lm_loss = EXPECTED_LM_LOSS.get(hf_config_name) + if expected_lm_loss is not None: + assert abs(actual_lm_loss - expected_lm_loss) < 0.01, ( + f"lm_loss mismatch: expected {expected_lm_loss}, got {actual_lm_loss}" + ) + else: + # Print value for new models - update EXPECTED_LM_LOSS with this + print(f"\n=== LM_LOSS for {hf_config_name} ===") + print(f'"{hf_config_name}": {actual_lm_loss},') + print("===") + + +def _assert_mip_solutions(puzzle_dir: Path, hf_config_name: str): + """Assertions for the mip_and_realize_models step.""" + mip_dir = puzzle_dir / "mip/puzzle_solutions/target_memory_780000MiB" + + assert (mip_dir / "solutions.json").exists() + assert (mip_dir / "solutions--checkpoints/solution_0/config.json").exists() + + # Validate lm_loss + _assert_lm_loss(puzzle_dir, hf_config_name) From eb5cf8ab36abe5c583cd9863f3d4748248d79480 Mon Sep 17 00:00:00 2001 From: Daniel Korzekwa Date: Thu, 5 Mar 2026 05:48:21 -0800 Subject: [PATCH 03/30] Update child_init.py with anymodel version Signed-off-by: Daniel Korzekwa --- .../tools/bypassed_training/child_init.py | 704 ++++-------------- 1 file changed, 128 insertions(+), 576 deletions(-) diff --git a/modelopt/torch/puzzletron/tools/bypassed_training/child_init.py b/modelopt/torch/puzzletron/tools/bypassed_training/child_init.py index 3981b62e3..b30e7eefa 100644 --- a/modelopt/torch/puzzletron/tools/bypassed_training/child_init.py +++ b/modelopt/torch/puzzletron/tools/bypassed_training/child_init.py @@ -14,7 +14,7 @@ # limitations under the License. # mypy: ignore-errors -"""TODO Add description. Analyze this code, why is it so long and complex? Can it be simplified?""" +"""Core logic for creating pruned child model state dicts from parent models. Used by init_child_from_parent.""" import concurrent.futures import dataclasses @@ -22,12 +22,11 @@ import os import re import time -from collections.abc import Callable from copy import deepcopy from enum import Enum from functools import partial from pathlib import Path -from typing import Any +from typing import Any, Callable, Dict, List, Optional, Tuple, Union import torch from typeguard import check_type @@ -39,41 +38,23 @@ _is_dataclass_type, ) from modelopt.torch.puzzletron.decilm.deci_lm_hf_code.configuration_decilm import DeciLMConfig +from modelopt.torch.puzzletron.pruning.pruning_utils import ( + ACTIVATIONS_LOG, + GQAInitMode, + HiddenSizeInitMode, + LinearInitMode, + MlpInitMode, + _cache_activations_log, + _init_attention_biases, + _init_attention_weights, + _init_mlp_module, + _init_moe_module, + _load_activations_log, + _load_expert_scores, + _select_expert_indices, +) from modelopt.torch.puzzletron.tools.logger import aprint, mprint - -class GQAInitMode(Enum): - RandomKV = "RandomKV" - AverageKV = "AverageKV" - FirstKV = "FirstKV" - RandomBlock = "RandomBlock" - CopyAsIs = "CopyAsIs" - Degrouping = "Degrouping" - PruneKVHeads = "PruneKVHeads" - - -class MlpInitMode(Enum): - Random = "Random" - Truncate = "Truncate" - CopyAsIs = "CopyAsIs" - PruneByActivationsLog = "PruneByActivationsLog" - ExpertRemoval = "ExpertRemoval" - ConcatExpertsIntoDenseFFN = "ConcatExpertsIntoDenseFFN" - MoEChannelPruning = "MoEChannelPruning" - - -class LinearInitMode(Enum): - Random = "Random" - FromTeacher = "FromTeacher" - - -class HiddenSizeInitMode(Enum): - Random = "Random" - Truncate = "Truncate" - PruneByChannelRanking = "PruneByChannelRanking" - CopyAsIs = "CopyAsIs" - - IgnoreFn = Callable[[str], bool] default_ignore_fn: IgnoreFn = lambda _: False @@ -87,25 +68,52 @@ def print(s: str) -> None: def _process_single_layer( layer_idx: int, + pruning_mixin, + descriptor, parent_state_dict: dict, new_state_dict: dict, original_config: DeciLMConfig, new_config: DeciLMConfig, gqa_init_mode: GQAInitMode, mlp_init_mode: MlpInitMode, - mlp_init_config: dict[str, Any] | None, + mlp_init_config: Optional[dict[str, Any]], linear_init_mode: LinearInitMode, ignored_keys: set, keys: dict, is_original_mha: bool, head_size: int, hidden_size: int, -) -> tuple[dict[str, torch.Tensor], dict[str, str]]: - """Process a single layer in parallel. Returns (layer_state_dict, keys_to_remove). +) -> Tuple[Dict[str, torch.Tensor], Dict[str, str]]: + """ + Process a single layer in parallel. Returns (layer_state_dict, keys_to_remove). Thread-safe function for parallel layer processing. """ - layer_out_state_dict = {} keys_to_remove = {} + layer_out_state_dict = {} + + # Delegate to pruning_mixin if available + if pruning_mixin is not None: + _layer_out = pruning_mixin.prune_single_layer( + layer_idx=layer_idx, + parent_state_dict=parent_state_dict, + new_state_dict=new_state_dict, + original_config=original_config, + new_config=new_config, + gqa_init_mode=gqa_init_mode, + mlp_init_mode=mlp_init_mode, + mlp_init_config=mlp_init_config, + linear_init_mode=linear_init_mode, + ignored_keys=ignored_keys, + keys=keys, + is_original_mha=is_original_mha, + head_size=head_size, + hidden_size=hidden_size, + keys_to_remove=keys_to_remove, + ) + layer_out_state_dict.update(_layer_out) + return layer_out_state_dict, keys_to_remove + + # Legacy inline processing (fallback when no pruning_mixin) parent_block_config = original_config.block_configs[layer_idx] child_block_config = new_config.block_configs[layer_idx] @@ -119,13 +127,13 @@ def _process_single_layer( o_key = f"{attn_prefix}.o_proj.{part}" attn_keys = [q_key, k_key, v_key, o_key] # Drop attn keys that don't exist and required to be in the new state_dict - attn_keys = [key for key in attn_keys if key in new_state_dict] + attn_keys = [key for key in attn_keys if key in new_state_dict.keys()] if len(attn_keys) > 0 and all(key in keys for key in attn_keys): for key in attn_keys: keys_to_remove[key] = keys[key] if all(key not in ignored_keys for key in attn_keys): is_student_and_teacher_have_same_attention_implementation = all( - key in new_state_dict for key in attn_keys + key in new_state_dict.keys() for key in attn_keys ) if is_student_and_teacher_have_same_attention_implementation: if part == "weight": @@ -168,7 +176,7 @@ def _process_single_layer( else: linear_attn_key = f"{attn_prefix}.linear_attn.weight" - is_student_attn_replaced_with_linear = linear_attn_key in new_state_dict + is_student_attn_replaced_with_linear = linear_attn_key in new_state_dict.keys() if is_student_attn_replaced_with_linear: if linear_init_mode == LinearInitMode.Random: layer_out_state_dict[linear_attn_key] = new_state_dict[linear_attn_key] @@ -180,7 +188,7 @@ def _process_single_layer( raise ValueError(f"Unknown {linear_init_mode=}") else: # student attn random init - for new_key in new_state_dict: + for new_key in new_state_dict.keys(): if attn_prefix in new_key: layer_out_state_dict[new_key] = new_state_dict[new_key] @@ -190,7 +198,7 @@ def _process_single_layer( mlp_prefix = f"model.layers.{layer_idx}.mlp" linear_mlp_key = f"{mlp_prefix}.linear_mlp.weight" - is_student_mlp_replaced_with_linear = linear_mlp_key in new_state_dict + is_student_mlp_replaced_with_linear = linear_mlp_key in new_state_dict.keys() if is_student_mlp_replaced_with_linear: if linear_init_mode == LinearInitMode.Random: layer_out_state_dict[linear_mlp_key] = new_state_dict[linear_mlp_key] @@ -312,7 +320,7 @@ def _process_single_layer( ]: key_possibly_missing_in_student = f".{layer_idx}.{key_possibly_missing_in_student}" is_key_missing_from_student = ( - len([k for k in new_state_dict if key_possibly_missing_in_student in k]) == 0 + len([k for k in new_state_dict.keys() if key_possibly_missing_in_student in k]) == 0 ) if is_key_missing_from_student: for k in list(keys.keys()): @@ -324,6 +332,8 @@ def _process_single_layer( @torch.no_grad() def create_child_state_dict( + pruning_mixin, + descriptor, original_state_dict: dict, new_state_dict: dict, original_config: DeciLMConfig, @@ -331,12 +341,12 @@ def create_child_state_dict( gqa_init_mode: GQAInitMode, ignore_fn: IgnoreFn = default_ignore_fn, mlp_init_mode: MlpInitMode = MlpInitMode.CopyAsIs, - mlp_init_config: dict[str, Any] | None = None, - owned_block_indexes: set[int] | None = None, + mlp_init_config: Optional[dict[str, Any]] = None, + owned_block_indexes: Optional[set[int]] = None, linear_init_mode: LinearInitMode = LinearInitMode.Random, hidden_size_init_mode: HiddenSizeInitMode = HiddenSizeInitMode.CopyAsIs, - channel_importance_path: str | None = None, - max_layer_workers: int | None = None, # Now optional - will auto-calculate if None + channel_importance_path: Optional[str] = None, + max_layer_workers: Optional[int] = None, # Now optional - will auto-calculate if None ): mprint("=== Starting create_child_state_dict with optimizations ===") total_start_time = time.time() @@ -371,34 +381,40 @@ def create_child_state_dict( else: out_state_dict[key] = tensor - original_n_heads_in_group_per_layer = [ - b.attention.n_heads_in_group for b in original_config.block_configs + # Get language model config for LM-specific attributes (VL models have nested config) + original_lm_config = descriptor.get_language_model_config(original_config) + new_lm_config = descriptor.get_language_model_config(new_config) + + # Check if original model is MHA (all layers have num_key_value_heads == num_attention_heads) + original_num_kv_heads_per_layer = [ + b.attention.num_key_value_heads for b in original_config.block_configs ] - is_original_mha = set(original_n_heads_in_group_per_layer) == {1} - is_same_hidden_size = original_config.hidden_size == new_config.hidden_size - head_size = new_config.head_dim - orig_head_size = original_config.head_dim + num_attention_heads = original_lm_config.num_attention_heads + is_original_mha = all(kv == num_attention_heads for kv in original_num_kv_heads_per_layer) + is_same_hidden_size = original_lm_config.hidden_size == new_lm_config.hidden_size + head_size = _get_head_dim(new_lm_config) + orig_head_size = _get_head_dim(original_lm_config) assert head_size == orig_head_size, f"head_size {head_size} != orig_head_size {orig_head_size}" # Allow different hidden sizes for pruning if not is_same_hidden_size: - assert new_config.hidden_size <= original_config.hidden_size, ( - f"New hidden size ({new_config.hidden_size}) must be <= original ({original_config.hidden_size})" + assert new_lm_config.hidden_size <= original_lm_config.hidden_size, ( + f"New hidden size ({new_lm_config.hidden_size}) must be <= original ({original_lm_config.hidden_size})" ) assert hidden_size_init_mode != HiddenSizeInitMode.CopyAsIs, ( "Cannot copy as is when hidden sizes differ" ) - hidden_size = original_config.hidden_size + hidden_size = original_lm_config.hidden_size - ignored_keys = set([key for key in original_state_dict if ignore_fn(key)]) + ignored_keys = set([key for key in original_state_dict.keys() if ignore_fn(key)]) for key in ignored_keys: aprint(f"Ignoring key {key} and taking its init from new_state_dict") out_state_dict[key] = new_state_dict[key] keys = { match.group(1) if (match := re.search(r"(h\.\d+\..*)", key)) is not None else key: key - for key in original_state_dict + for key in original_state_dict.keys() } setup_time = time.time() - setup_start_time mprint(f"Phase 1 - Setup and memory pre-allocation: {setup_time:.2f}s") @@ -409,6 +425,8 @@ def create_child_state_dict( # Prepare arguments for parallel processing process_layer_partial = partial( _process_single_layer, + pruning_mixin=pruning_mixin, + descriptor=descriptor, parent_state_dict=original_state_dict, new_state_dict=new_state_dict, original_config=original_config, @@ -489,6 +507,7 @@ def create_child_state_dict( original_state_dict, new_config, original_config, + descriptor, hidden_size_init_mode, channel_importance_path, owned_block_indexes, @@ -527,7 +546,7 @@ def _generate_moe_keys(layer_idx: int, num_experts: int) -> tuple[str, dict[str, def _concatenate_experts_into_dense_ffn( original_state_dict: dict[str, torch.Tensor], - mlp_init_config: dict | None, + mlp_init_config: Optional[dict], hidden_size: int, layer_idx: int, child_block_config: BlockConfig, @@ -585,7 +604,8 @@ def _concatenate_experts_into_dense_ffn( "concat_dims and experts_weights must have the same keys" ) concat_routed_state_dict = { - name: torch.cat(experts_weights[name], dim=concat_dims[name]) for name in concat_dims + name: torch.cat(experts_weights[name], dim=concat_dims[name]) + for name in concat_dims.keys() } # turn the shared expert into a normal FFN. concatenate the pruned routed experts if needed. @@ -645,16 +665,16 @@ def _verify_state_dicts_match( def _init_mlp( *, - mlp_init_mode: MlpInitMode | str, + mlp_init_mode: Union[MlpInitMode, str], layer_idx: int, original_config: DeciLMConfig, - mlp_init_config: dict[str, Any] | None, + mlp_init_config: Optional[dict[str, Any]], original_state_dict: dict, new_state_dict: dict, new_config: DeciLMConfig, keys: dict[str, str], ignored_keys: set[str], - expert_idx: int | None = None, + expert_idx: Optional[int] = None, ) -> dict[str, torch.Tensor]: out_state_dict = {} @@ -679,10 +699,12 @@ def _init_mlp( projection_matrix = None for mlp_key in mlp_keys: expanded_dim = 1 if "down_proj" in mlp_key else 0 - if mlp_key in new_state_dict: + if mlp_key in new_state_dict.keys(): mlp_module_weight, pruned_filters, projection_matrix = _init_mlp_module( mlp_init_mode, + mlp_prefix, expanded_dim, + layer_idx, new_state_dict[mlp_key], new_config, original_state_dict[mlp_key], @@ -690,7 +712,6 @@ def _init_mlp( mlp_init_config, pruned_filters, projection_matrix, - mlp_prefix, ) out_state_dict[mlp_key] = mlp_module_weight else: @@ -698,128 +719,6 @@ def _init_mlp( return out_state_dict -def _init_mlp_module( - mlp_init_mode: MlpInitMode | str, - expanded_dim: int, - new_item: torch.Tensor, - new_config: DeciLMConfig, - orig_item: torch.Tensor, - original_config: DeciLMConfig, - mlp_init_config: dict[str, Any] | None, - pruned_filters: torch.Tensor | None = None, - projection_matrix: dict[str, torch.Tensor] | None = None, - mlp_prefix: str | None = None, -) -> tuple[torch.Tensor, torch.Tensor | None, dict[str, torch.Tensor] | None]: - if isinstance(mlp_init_mode, str): - mlp_init_mode = MlpInitMode(mlp_init_mode) - assert orig_item.ndim == 2, f"{orig_item.ndim=}" - assert new_item.ndim == 2, f"{new_item.ndim=}" - - assert new_config.num_hidden_layers == original_config.num_hidden_layers, ( - f"({new_config.num_hidden_layers=}) != ({original_config.num_hidden_layers=})" - ) - - orig_ffn_size = orig_item.shape[expanded_dim] - new_ffn_size = new_item.shape[expanded_dim] - - if mlp_init_mode == MlpInitMode.CopyAsIs: - assert new_ffn_size == orig_ffn_size, ( - f"({new_ffn_size=}) != ({orig_ffn_size=}), can't be copied as is." - ) - mlp_module_weight = orig_item - - elif mlp_init_mode == MlpInitMode.Random: - mlp_module_weight = new_item - - elif new_ffn_size == orig_ffn_size: - mlp_module_weight = orig_item - - elif mlp_init_mode in ( - MlpInitMode.Truncate, - MlpInitMode.PruneByActivationsLog, - MlpInitMode.MoEChannelPruning, - ): - assert new_ffn_size <= orig_ffn_size, ( - f"({new_ffn_size=}) > ({orig_ffn_size=}), can't be truncated." - ) - - if mlp_init_mode == MlpInitMode.Truncate: - truncated_weight = torch.narrow( - orig_item, dim=expanded_dim, start=0, length=new_ffn_size - ) - mlp_module_weight = truncated_weight - - elif mlp_init_mode in (MlpInitMode.PruneByActivationsLog, MlpInitMode.MoEChannelPruning): - if pruned_filters is None: - filter_importance = _load_activations_log( - mlp_init_config, module_name=f"{mlp_prefix}.down_proj" - ) - filters_sorted_by_importance = torch.argsort(filter_importance, descending=True) - pruned_filters = filters_sorted_by_importance[:new_ffn_size].to(orig_item.device) - - pruned_weight = torch.index_select(orig_item, dim=expanded_dim, index=pruned_filters) - if mlp_init_config.get("scale_pruned_weights", False) and expanded_dim == 1: - pruned_weight = pruned_weight * (orig_ffn_size / new_ffn_size) - mlp_module_weight = pruned_weight - - elif ( - mlp_init_mode == MlpInitMode.ExpertRemoval - ): # the case of mlp layers of maverick. for now we only support copy as is - assert new_ffn_size == orig_ffn_size, ( - f"({new_ffn_size=}) != ({orig_ffn_size=}), can't be copied as is." - ) - mlp_module_weight = orig_item - - else: - raise ValueError(f"Unsupported {mlp_init_mode=}") - - return mlp_module_weight, pruned_filters, projection_matrix - - -def _init_moe_module( - *, - mlp_init_mode: MlpInitMode | str, - mlp_init_config: dict[str, Any] | None, - layer_idx: int, - orig_router_weight: torch.Tensor, - orig_experts_weights: dict[str, list[torch.Tensor]], - new_router_weight: torch.Tensor, - new_experts_weights: dict[str, list[torch.Tensor]], -) -> tuple[torch.Tensor, torch.Tensor | None, dict[str, torch.Tensor] | None]: - if isinstance(mlp_init_mode, str): - mlp_init_mode = MlpInitMode(mlp_init_mode) - - if mlp_init_mode == MlpInitMode.ExpertRemoval: - result_router_weight, result_experts_weights = _prune_experts_by_score( - mlp_init_config=mlp_init_config, - layer_idx=layer_idx, - orig_router_weight=orig_router_weight, - orig_experts_weights=orig_experts_weights, - new_num_experts=new_router_weight.shape[0], - ) - else: - raise ValueError(f"Unsupported {mlp_init_mode=}") - - assert result_router_weight.shape == new_router_weight.shape - assert result_experts_weights.keys() == new_experts_weights.keys(), ( - "result_experts_weights and new_experts_weights must have the same keys" - ) - assert all( - len(new_experts_weights[name]) == len(result_experts_weights[name]) - for name in result_experts_weights.keys() - ) - assert all( - all( - new_expert_weight.shape == result_expert_weight.shape - for new_expert_weight, result_expert_weight in zip( - new_experts_weights[name], result_experts_weights[name] - ) - ) - for name in result_experts_weights.keys() - ) - return result_router_weight, result_experts_weights - - def _prune_experts_by_score( *, mlp_init_config: dict[str, Any], @@ -848,377 +747,6 @@ def _prune_experts_by_score( return result_router_weight, result_experts_weights -def _load_expert_scores(mlp_init_config: dict[str, Any] | None) -> list[list[int | float]]: - assert mlp_init_config is not None - if "expert_scores_file" in mlp_init_config: - expert_scores_file = mlp_init_config["expert_scores_file"] - with open(expert_scores_file) as f: - expert_scores = json.load(f) - elif "activations_log_dir" in mlp_init_config: - _cache_activations_log(mlp_init_config) - num_layers = len(ACTIVATIONS_LOG) - expert_scores = [] - for layer_idx in range(num_layers): - router_name = f"model.layers.{layer_idx}.mlp.router" - expert_scores.append(ACTIVATIONS_LOG[router_name]["expert_ranks"]) - expert_scores = torch.stack(expert_scores) - expert_scores = expert_scores.tolist() - else: - raise ValueError(f"Unsupported {mlp_init_config=}") - return expert_scores - - -ACTIVATIONS_LOG = dict() - - -def _cache_activations_log(mlp_init_config: dict[str, Any]) -> None: - if len(ACTIVATIONS_LOG) == 0: - assert "activations_log_dir" in mlp_init_config - activations_log_dir = mlp_init_config["activations_log_dir"] - ACTIVATIONS_LOG.update( - { - module_name: module_log - for p in Path(activations_log_dir).glob("rank*.pth") - for module_name, module_log in torch.load(p).items() - } - ) - - -def _load_activations_log(mlp_init_config: dict[str, Any], module_name: str) -> torch.Tensor: - _cache_activations_log(mlp_init_config) - module_log = ACTIVATIONS_LOG[module_name] - filter_importance = module_log["score"] - return filter_importance - - -def _init_attention_weights( - gqa_init_mode, - layer_idx, - new_state_dict, - new_config, - original_state_dict, - q_key, - k_key, - v_key, - o_key, - original_config, - is_original_mha, - head_size, - mlp_init_config, -): - assert new_config.num_attention_heads == original_config.num_attention_heads, ( - f"({new_config.num_attention_heads=}) != ({original_config.num_attention_heads=})" - ) - num_q_heads = new_config.num_attention_heads - n_heads_in_group = new_config.block_configs[layer_idx].attention.n_heads_in_group - orig_n_heads_in_group = original_config.block_configs[layer_idx].attention.n_heads_in_group - num_kv_heads = num_q_heads // n_heads_in_group - orig_num_kv_heads = num_q_heads // orig_n_heads_in_group - - # new_w* are typically randomly initialized - new_wq = new_state_dict[q_key] - new_wk = new_state_dict[k_key] - new_wv = new_state_dict[v_key] - new_wo = new_state_dict[o_key] - - # w* are from the parent model - wq = original_state_dict[q_key] - wk = original_state_dict[k_key] - wv = original_state_dict[v_key] - wo = original_state_dict[o_key] - - if "bias" in k_key: - for tensor in [wq, wk, wv, wo, new_wq, new_wk, new_wv, new_wo]: - assert tensor.ndim == 1 - tensor.unsqueeze_(1) - dim1 = wk.shape[1] # this is the hidden_size in case of matrix weights, and 1 in case of biases - - if gqa_init_mode in (GQAInitMode.RandomKV, GQAInitMode.RandomBlock): - wk, wv = new_wk, new_wv - elif gqa_init_mode in (GQAInitMode.AverageKV, GQAInitMode.FirstKV): - assert n_heads_in_group % orig_n_heads_in_group == 0, ( - f"({n_heads_in_group=}) % ({orig_n_heads_in_group=}) != 0" - ) - n_heads_to_aggregate = n_heads_in_group // orig_n_heads_in_group - - wk = wk.view(-1, n_heads_to_aggregate, head_size, dim1) - wv = wv.view(-1, n_heads_to_aggregate, head_size, dim1) - - if gqa_init_mode == GQAInitMode.AverageKV: - wk = wk.mean(dim=1) - wv = wv.mean(dim=1) - else: - wk = wk[:, 0] - wv = wv[:, 0] - elif gqa_init_mode == GQAInitMode.CopyAsIs: - assert new_wk.shape == wk.shape, f"({new_wk.shape=}) != ({wk.shape=})" - assert new_wv.shape == wv.shape, f"({new_wv.shape=}) != ({wv.shape=})" - assert new_wq.shape == wq.shape, f"({new_wq.shape=}) != ({wq.shape=})" - assert new_wo.shape == wo.shape, f"({new_wo.shape=}) != ({wo.shape=})" - - elif gqa_init_mode == GQAInitMode.Degrouping: - assert not is_original_mha, ( - "Degrouping can only be done on original models that are GQA themselves." - ) - n_groups = new_config.num_attention_heads // n_heads_in_group - orig_n_groups = original_config.num_attention_heads // orig_n_heads_in_group - assert n_groups % orig_n_groups == 0, f"{n_groups=} must be a divisor of {orig_n_groups=}" - n_repeats = n_groups // orig_n_groups - if n_repeats > 1: - print(f"Degrouping {orig_n_groups} into {n_groups}") - - def degroup_w(w): - w = w.view(orig_n_groups, head_size, dim1) - w = torch.repeat_interleave(w, repeats=n_repeats, dim=0) - w = w.reshape(n_groups * head_size, dim1) - return w - - wk = degroup_w(wk) - wv = degroup_w(wv) - - elif gqa_init_mode == GQAInitMode.PruneKVHeads: - wk = wk.view(orig_num_kv_heads, head_size, dim1) - wv = wv.view(orig_num_kv_heads, head_size, dim1) - wq = wq.view(orig_num_kv_heads, orig_n_heads_in_group, head_size, dim1) - wo = wo.view(dim1, orig_num_kv_heads, orig_n_heads_in_group, head_size) - - o_proj_module_name = o_key.replace(".weight", "") - kv_head_importance = _load_activations_log(mlp_init_config, module_name=o_proj_module_name) - kv_heads_sorted_by_importance = torch.argsort(kv_head_importance, descending=True) - kv_heads_to_keep = kv_heads_sorted_by_importance[:num_kv_heads] - kv_heads_to_remove = kv_heads_sorted_by_importance[num_kv_heads:] - - wk = wk[kv_heads_to_keep] - wv = wv[kv_heads_to_keep] - - reduction_factor = orig_num_kv_heads // num_kv_heads - - prune_via_duplication = False - if prune_via_duplication: - ## Wq option 1 - replicate the query groups to match the total number of attention heads. Queries work with familiar kv heads. - wq = wq[kv_heads_to_keep] - wq = torch.repeat_interleave(wq, repeats=reduction_factor, dim=0) - - ## Wo option 1 - replicate the groups of the original Wo. Multiple by the reduction factor to mimic pruning of the other groups. - ## This makes sense with Wq option 1, but it will not be more expressive than true pruning due to symmetry, unless we add noise. - wo = wo[:, kv_heads_to_keep] - wo = torch.repeat_interleave(wo, repeats=reduction_factor, dim=1) - wo = wo / reduction_factor - - else: # prune via zeroing out - ## Wq option 2 - keep the original queries. At init they will not be used (see the Wo zeroing), during training they can adapt to new kv heads like in variable GQA. - ## We need to interleave them to keep the matching between queries and kv heads. - kv_heads_to_keep = kv_heads_to_keep.tolist() - kv_heads_to_remove = kv_heads_to_remove.tolist() - kv_head_ordering = [] - zero_out_mask = [] - for i_head in range(orig_num_kv_heads): - if i_head % reduction_factor == 0: - kv_head_ordering.append(kv_heads_to_keep.pop(0)) - zero_out_mask.append(False) - else: - kv_head_ordering.append(kv_heads_to_remove.pop(0)) - zero_out_mask.append(True) - - wq = wq[kv_head_ordering] - - ## Wo option 2 - zero-out the contribution of queries that do not belong to chosen kv heads. - ## At initialization it's exactly like pruning, but the extra weights will have the chance to adapt to new kv heads if we train the model. - ## Even though the weight is 0 it can still train, like initializing biases to 0 does not prevent them from training. - ## Matmul backprop: if Y = AB and dY is the gradient of Y, then dA = dY @ B.T and dB = A.T @ dY, so the gradient of the zeroed-out weights depends on the gradient of what multiplies them. - wo = wo[:, kv_head_ordering] - wo[:, zero_out_mask] = 0.0 - - else: - raise ValueError(f"{gqa_init_mode=} not supported") - - wk = wk.reshape(-1, dim1) - wv = wv.reshape(-1, dim1) - wq = wq.reshape(-1, dim1) - wo = wo.reshape(dim1, -1) - return wq, wk, wv, wo - - -def _init_attention_biases( - gqa_init_mode, - layer_idx, - new_state_dict, - new_config: DeciLMConfig, - original_state_dict, - q_key, - k_key, - v_key, - o_key, - original_config, - is_original_mha, - head_size, - mlp_init_config, -): - assert new_config.num_attention_heads == original_config.num_attention_heads, ( - f"({new_config.num_attention_heads=}) != ({original_config.num_attention_heads=})" - ) - num_q_heads = new_config.num_attention_heads - n_heads_in_group = new_config.block_configs[layer_idx].attention.n_heads_in_group - orig_n_heads_in_group = original_config.block_configs[layer_idx].attention.n_heads_in_group - num_kv_heads = num_q_heads // n_heads_in_group - orig_num_kv_heads = num_q_heads // orig_n_heads_in_group - - o_proj_bias = new_config.o_proj_bias - attention_bias = new_config.attention_bias - - # If no biases - if not (o_proj_bias or attention_bias): - return {} - - new_bias_sd = {} - bias_sd = {} - # new_w* are typically randomly initialized - if o_proj_bias: - new_bias_sd["o"] = new_state_dict[o_key] - bias_sd["o"] = original_state_dict[o_key] - if attention_bias: - for bias_key, key in zip("qkv", [q_key, k_key, v_key]): - new_bias_sd[bias_key] = new_state_dict[key] - bias_sd[bias_key] = original_state_dict[key] - - # maybe unsqueeze all tensors - for tensor in list(new_bias_sd.values()) + list(bias_sd.values()): - assert tensor.ndim == 1 - tensor.unsqueeze_(1) - - dim1 = 1 # this is the hidden_size in case of matrix weights, and 1 in case of biases - if gqa_init_mode in (GQAInitMode.RandomKV, GQAInitMode.RandomBlock) and attention_bias: - bias_sd["k"] = torch.zeros( - new_bias_sd["k"].shape, dtype=bias_sd["k"].dtype, device=bias_sd["k"].device - ) - bias_sd["v"] = torch.zeros( - new_bias_sd["v"].shape, dtype=bias_sd["v"].dtype, device=bias_sd["v"].device - ) - elif gqa_init_mode in (GQAInitMode.AverageKV, GQAInitMode.FirstKV) and attention_bias: - assert n_heads_in_group % orig_n_heads_in_group == 0, ( - f"({n_heads_in_group=}) % ({orig_n_heads_in_group=}) != 0" - ) - n_heads_to_aggregate = n_heads_in_group // orig_n_heads_in_group - - bias_sd["k"] = bias_sd["k"].view(-1, n_heads_to_aggregate, head_size, dim1) - bias_sd["v"] = bias_sd["v"].view(-1, n_heads_to_aggregate, head_size, dim1) - - if gqa_init_mode == GQAInitMode.AverageKV: - bias_sd["k"] = bias_sd["k"].mean(dim=1) - bias_sd["v"] = bias_sd["v"].mean(dim=1) - else: - bias_sd["k"] = bias_sd["k"][:, 0] - bias_sd["v"] = bias_sd["v"][:, 0] - elif gqa_init_mode == GQAInitMode.CopyAsIs: - for key in bias_sd: - assert new_bias_sd[key].shape == bias_sd[key].shape, ( - f"({new_bias_sd[key].shape=}) != ({bias_sd[key].shape=})" - ) - - elif gqa_init_mode == GQAInitMode.Degrouping and attention_bias: - assert not is_original_mha, ( - "Degrouping can only be done on original models that are GQA themselves." - ) - n_groups = new_config.num_attention_heads // n_heads_in_group - orig_n_groups = original_config.num_attention_heads // orig_n_heads_in_group - assert n_groups % orig_n_groups == 0, f"{n_groups=} must be a divisor of {orig_n_groups=}" - n_repeats = n_groups // orig_n_groups - if n_repeats > 1: - print(f"Degrouping {orig_n_groups} into {n_groups}") - - def degroup_w(w): - w = w.view(orig_n_groups, head_size, dim1) - w = torch.repeat_interleave(w, repeats=n_repeats, dim=0) - w = w.reshape(n_groups * head_size, dim1) - return w - - bias_sd["k"] = degroup_w(bias_sd["k"]) - bias_sd["v"] = degroup_w(bias_sd["v"]) - - elif gqa_init_mode == GQAInitMode.PruneKVHeads: - if o_proj_bias: - o_proj_module_name = o_key.rsplit(".", 1)[0] - else: - # Here we assume that the o_proj layer is called "o_proj" - o_proj_module_name = k_key.rsplit(".", 2)[0] + ".o_proj" - - kv_head_importance = _load_activations_log(mlp_init_config, module_name=o_proj_module_name) - kv_heads_sorted_by_importance = torch.argsort(kv_head_importance, descending=True) - kv_heads_to_keep = kv_heads_sorted_by_importance[:num_kv_heads] - kv_heads_to_remove = kv_heads_sorted_by_importance[num_kv_heads:] - - # view as KV groups - if attention_bias: - bias_sd["k"] = bias_sd["k"].view(orig_num_kv_heads, head_size, dim1) - bias_sd["v"] = bias_sd["v"].view(orig_num_kv_heads, head_size, dim1) - bias_sd["q"] = bias_sd["q"].view( - orig_num_kv_heads, orig_n_heads_in_group, head_size, dim1 - ) - # Keep important KV heads and prune the others - bias_sd["k"] = bias_sd["k"][kv_heads_to_keep] - bias_sd["v"] = bias_sd["v"][kv_heads_to_keep] - if o_proj_bias: - bias_sd["o"] = bias_sd["o"].view( - dim1, orig_num_kv_heads, orig_n_heads_in_group, head_size - ) - - reduction_factor = orig_num_kv_heads // num_kv_heads - - prune_via_duplication = False - if prune_via_duplication: - if attention_bias: - ## Wq option 1 - replicate the query groups to match the total number of attention heads. Queries work with familiar kv heads. - bias_sd["q"] = bias_sd["q"][kv_heads_to_keep] - bias_sd["q"] = torch.repeat_interleave( - bias_sd["q"], repeats=reduction_factor, dim=0 - ) - - if o_proj_bias: - ## Wo option 1 - replicate the groups of the original Wo. Multiple by the reduction factor to mimic pruning of the other groups. - ## This makes sense with Wq option 1, but it will not be more expressive than true pruning due to symmetry, unless we add noise. - bias_sd["o"] = bias_sd["o"][:, kv_heads_to_keep] - bias_sd["o"] = torch.repeat_interleave( - bias_sd["o"], repeats=reduction_factor, dim=1 - ) - bias_sd["o"] = bias_sd["o"] / reduction_factor - - else: # prune via zeroing out - ## Wq option 2 - keep the original queries. At init they will not be used (see the Wo zeroing), during training they can adapt to new kv heads like in variable GQA. - ## We need to interleave them to keep the matching between queries and kv heads. - kv_heads_to_keep = kv_heads_to_keep.tolist() - kv_heads_to_remove = kv_heads_to_remove.tolist() - kv_head_ordering = [] - zero_out_mask = [] - for i_head in range(orig_num_kv_heads): - if i_head % reduction_factor == 0: - kv_head_ordering.append(kv_heads_to_keep.pop(0)) - zero_out_mask.append(False) - else: - kv_head_ordering.append(kv_heads_to_remove.pop(0)) - zero_out_mask.append(True) - - if attention_bias: - bias_sd["q"] = bias_sd["q"][kv_head_ordering] - - if o_proj_bias: - ## Wo option 2 - zero-out the contribution of queries that do not belong to chosen kv heads. - ## At initialization it's exactly like pruning, but the extra weights will have the chance to adapt to new kv heads if we train the model. - ## Even though the weight is 0 it can still train, like initializing biases to 0 does not prevent them from training. - ## Matmul backprop: if Y = AB and dY is the gradient of Y, then dA = dY @ B.T and dB = A.T @ dY, so the gradient of the zeroed-out weights depends on the gradient of what multiplies them. - bias_sd["o"] = bias_sd["o"][:, kv_head_ordering] - bias_sd["o"][:, zero_out_mask] = 0.0 - - else: - raise ValueError(f"{gqa_init_mode=} not supported") - - if attention_bias: - for bias_key in "qkv": - bias_sd[bias_key] = bias_sd[bias_key].reshape(-1) - if o_proj_bias: - bias_sd["o"] = bias_sd["o"].reshape(-1) - return bias_sd - - def _init_linear_attn( parent_state_dict: dict[str, torch.Tensor], parent_config: DeciLMConfig, @@ -1226,13 +754,15 @@ def _init_linear_attn( v_key: str, o_key: str, ) -> torch.Tensor: - """Init a linear layer that operates like an attention layer that assigns score 1 to the current token + """ + Init a linear layer that operates like an attention layer that assigns score 1 to the current token and score 0 to all others: out = (Wo @ Wv) @ x """ n_embd = parent_config.hidden_size - head_size = parent_config.head_dim - n_heads_in_group = parent_config.block_configs[layer_idx].attention.n_heads_in_group - n_kv_heads = parent_config.num_attention_heads // n_heads_in_group + head_size = _get_head_dim(parent_config) + # Get num_kv_heads from config, compute n_heads_in_group + n_kv_heads = parent_config.block_configs[layer_idx].attention.num_key_value_heads + n_heads_in_group = parent_config.num_attention_heads // n_kv_heads wv = parent_state_dict[v_key] wv = wv.view(n_kv_heads, head_size, n_embd) @@ -1245,7 +775,9 @@ def _init_linear_attn( def _init_linear_mlp(teacher_mlp_state_dict: dict[str, torch.Tensor]) -> torch.Tensor: - """A linear layer that does (W_down @ W_up) @ x, ignoring W_gate.""" + """ + A linear layer that does (W_down @ W_up) @ x, ignoring W_gate. + """ if "linear_mlp.weight" in teacher_mlp_state_dict: # if the teacher itself is a linear layer return teacher_mlp_state_dict["linear_mlp.weight"] @@ -1314,9 +846,10 @@ def _parse_model_config_overrides( model_config_overrides_json: str | dict | Path | list[dict], n_layer: int, ) -> list[dict[str, Any]]: - """Example model_config_overrides_json: + """ + example model_config_overrides_dict: { - "attention": [{"n_heads_in_group": 2}], + "attention": [{"num_key_value_heads": 4}], "ffn": [{"intermediate_size": 14336}] } """ @@ -1362,18 +895,24 @@ def _apply_hidden_size_pruning( original_state_dict: dict[str, torch.Tensor], new_config: DeciLMConfig, original_config: DeciLMConfig, + descriptor, hidden_size_init_mode: HiddenSizeInitMode, - channel_importance_path: str | None = None, - owned_block_indexes: list[int] | None = None, + channel_importance_path: Optional[str] = None, + owned_block_indexes: Optional[list[int]] = None, ) -> dict[str, torch.Tensor]: - """Apply hidden size pruning to all layers that depend on hidden_size. + """ + Apply hidden size pruning to all layers that depend on hidden_size. This includes embeddings, layer norms, and any linear layers that haven't been handled yet. """ if isinstance(hidden_size_init_mode, str): hidden_size_init_mode = HiddenSizeInitMode(hidden_size_init_mode) - original_hidden_size = original_config.hidden_size - new_hidden_size = new_config.hidden_size + # Get language model config (for VL models this extracts the nested config) + original_lm_config = descriptor.get_language_model_config(original_config) + new_lm_config = descriptor.get_language_model_config(new_config) + + original_hidden_size = original_lm_config.hidden_size + new_hidden_size = new_lm_config.hidden_size if hidden_size_init_mode == HiddenSizeInitMode.CopyAsIs: return out_state_dict @@ -1381,7 +920,7 @@ def _apply_hidden_size_pruning( # Load channel ranking if needed if hidden_size_init_mode == HiddenSizeInitMode.PruneByChannelRanking: if channel_importance_path is not None: - with open(channel_importance_path) as f: + with open(channel_importance_path, "r") as f: channel_ranking = json.load(f)["channel_importance_ranking"] else: raise ValueError( @@ -1574,10 +1113,12 @@ def _prune_hidden_size_dimension( original_tensor: torch.Tensor, new_hidden_size: int, hidden_size_init_mode: HiddenSizeInitMode, - channel_ranking: list[int] | None = None, + channel_ranking: Optional[list[int]] = None, dim: int = -1, ) -> torch.Tensor: - """Prune a tensor along the specified dimension to match the new hidden size.""" + """ + Prune a tensor along the specified dimension to match the new hidden size. + """ original_size = original_tensor.shape[dim] if hidden_size_init_mode == HiddenSizeInitMode.Random: @@ -1627,3 +1168,14 @@ def _prune_hidden_size_dimension( else: raise ValueError(f"Unsupported hidden_size_init_mode: {hidden_size_init_mode}") + + +def _get_head_dim(config) -> int: + """Get head dimension from config in a model-agnostic way. + + Some models like Llama have `head_dim` as a direct attribute, while others + like Qwen2 don't. This helper computes it from hidden_size and num_attention_heads. + """ + if hasattr(config, "head_dim") and config.head_dim is not None: + return config.head_dim + return config.hidden_size // config.num_attention_heads From c9de41ce2a1d46c0fdd5c828e8ecf6e8a33d1816 Mon Sep 17 00:00:00 2001 From: Daniel Korzekwa Date: Thu, 5 Mar 2026 06:31:37 -0800 Subject: [PATCH 04/30] fix attention pruning Signed-off-by: Daniel Korzekwa --- modelopt/torch/puzzletron/pruning/pruning_utils.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/modelopt/torch/puzzletron/pruning/pruning_utils.py b/modelopt/torch/puzzletron/pruning/pruning_utils.py index cea716b63..cdd6a2bf7 100644 --- a/modelopt/torch/puzzletron/pruning/pruning_utils.py +++ b/modelopt/torch/puzzletron/pruning/pruning_utils.py @@ -366,10 +366,10 @@ def _init_attention_biases( f"({new_config.num_attention_heads=}) != ({original_config.num_attention_heads=})" ) num_q_heads = new_config.num_attention_heads - n_heads_in_group = new_config.block_configs[layer_idx].attention.n_heads_in_group - orig_n_heads_in_group = original_config.block_configs[layer_idx].attention.n_heads_in_group - num_kv_heads = num_q_heads // n_heads_in_group - orig_num_kv_heads = num_q_heads // orig_n_heads_in_group + num_kv_heads = new_config.block_configs[layer_idx].attention.num_key_value_heads + orig_num_kv_heads = original_config.block_configs[layer_idx].attention.num_key_value_heads + n_heads_in_group = num_q_heads // num_kv_heads + orig_n_heads_in_group = num_q_heads // orig_num_kv_heads o_proj_bias = new_config.o_proj_bias attention_bias = new_config.attention_bias From 3c1bc1facc60e30a98adaa988cc17fb77a075a11 Mon Sep 17 00:00:00 2001 From: Daniel Korzekwa Date: Thu, 5 Mar 2026 06:42:11 -0800 Subject: [PATCH 05/30] Add trust_remote_code to load_model_config (default to false) Signed-off-by: Daniel Korzekwa --- .../puzzletron/tools/checkpoint_utils_hf.py | 19 ++++++++++++++++++- 1 file changed, 18 insertions(+), 1 deletion(-) diff --git a/modelopt/torch/puzzletron/tools/checkpoint_utils_hf.py b/modelopt/torch/puzzletron/tools/checkpoint_utils_hf.py index ad8ccfba2..bcdab7627 100644 --- a/modelopt/torch/puzzletron/tools/checkpoint_utils_hf.py +++ b/modelopt/torch/puzzletron/tools/checkpoint_utils_hf.py @@ -120,7 +120,21 @@ def load_model_config( checkpoint_dir: Path | str, model_config_overrides: Mapping | None = None, ignore_unexpected_config_keys: bool = False, + trust_remote_code: bool = False, ): + """Load model configuration from a checkpoint directory. + + Args: + checkpoint_dir: Path to the checkpoint directory (e.g. containing config.json). + model_config_overrides: Optional mapping of config overrides. + ignore_unexpected_config_keys: If True, ignore unexpected config keys. + trust_remote_code: If True, allows execution of custom code from the model repository. + This is a security risk if the model source is untrusted. Only set to True if you + trust the source of the model. Defaults to False for security. + + Returns: + Loaded model configuration (PretrainedConfig). + """ if not isinstance(checkpoint_dir, Path): checkpoint_dir = Path(checkpoint_dir) @@ -128,7 +142,10 @@ def load_model_config( model_config_overrides = {} config, unused_kwargs = AutoConfig.from_pretrained( - checkpoint_dir, trust_remote_code=True, return_unused_kwargs=True, **model_config_overrides + checkpoint_dir, + trust_remote_code=trust_remote_code, + return_unused_kwargs=True, + **model_config_overrides, ) if hasattr(config, "block_configs"): config.block_configs = maybe_cast_block_configs(config.block_configs) From 83571360c2a2202dd5521387e6059e943f52400f Mon Sep 17 00:00:00 2001 From: Daniel Korzekwa Date: Thu, 5 Mar 2026 09:52:47 -0800 Subject: [PATCH 06/30] Make activation scoring working Signed-off-by: Daniel Korzekwa --- .../activation_hooks/utils.py | 121 ++++------- .../score_pruning_activations.py | 2 +- modelopt/torch/puzzletron/puzzletron.py | 26 ++- .../torch/puzzletron/tools/robust_json.py | 5 + .../tools/sharded_checkpoint_utils.py | 205 +++++++++++++----- .../torch/puzzletron/tools/validate_model.py | 193 ++++++++--------- .../utils/validate_runtime_pipeline.py | 94 ++++++-- tests/gpu/torch/puzzletron/test_puzzletron.py | 51 ++--- 8 files changed, 405 insertions(+), 292 deletions(-) diff --git a/modelopt/torch/puzzletron/activation_scoring/activation_hooks/utils.py b/modelopt/torch/puzzletron/activation_scoring/activation_hooks/utils.py index ab7eed2ac..1b1485c71 100644 --- a/modelopt/torch/puzzletron/activation_scoring/activation_hooks/utils.py +++ b/modelopt/torch/puzzletron/activation_scoring/activation_hooks/utils.py @@ -15,84 +15,57 @@ # mypy: ignore-errors """Provides a function to register activation hooks for a model. -Activation hooks are used to compute activation scores for pruning. -""" +Activation hooks are used to compute activation scores for pruning.""" -import re +from typing import Type -from modelopt.torch.nas.plugins.megatron_hooks.base_hooks import ( - ForwardHook, - IndependentChannelContributionHook, - IndependentKvHeadContributionHook, - IterativeChannelContributionHook, - LayerNormContributionHook, -) -from modelopt.torch.puzzletron.decilm.deci_lm_hf_code.modeling_decilm import DeciLMForCausalLM +from modelopt.torch.nas.plugins.megatron_hooks.base_hooks import ForwardHook as ActivationsHook +from modelopt.torch.puzzletron.tools.logger import aprint def register_activation_hooks( - model: DeciLMForCausalLM, activation_hooks_kwargs: dict -) -> tuple[dict[str, ForwardHook], type[ForwardHook]]: - hook_class_map = { - "mlp.down_proj": { - "independent": IndependentChannelContributionHook, - "iterative": IterativeChannelContributionHook, - }, - "self_attn.o_proj": { - "independent_kv_head_contribution": IndependentKvHeadContributionHook, - }, - r"regex:experts\.\d+\.down_proj$": { # For MoE - "independent": IndependentChannelContributionHook, - }, - # TODO: maybe this is too generic, and we should have it specifically for - # input_layernorm and post_attention_layernorm; now it might select qk_norms - "layernorm": { - "layer_norm_contribution": LayerNormContributionHook, - }, - } - - activation_hooks = {} - target_layer = activation_hooks_kwargs.get("target_layer", "mlp.c_proj") - - if target_layer.startswith("regex:"): - target_layer_regex = target_layer[len("regex:") :] - pattern = re.compile(target_layer_regex) - - def match_predicate(module_name, module): - return pattern.search(module_name) - else: - - def match_predicate(module_name, module): - return module_name.endswith(target_layer) - - target_layer_hooks_map = hook_class_map.get(target_layer) - if target_layer_hooks_map is None: - raise ValueError(f"no hook classes found for: {target_layer}") - - hook_class = target_layer_hooks_map.get(activation_hooks_kwargs["method"]) - if hook_class is None: - raise ValueError(f"Unknown hook class: {hook_class}") - - if target_layer == "block": - pattern = re.compile(r"^transformer\.h\.\d+$") - - def match_predicate(module_name, module): - return pattern.match(module_name) - + model, + activation_hooks_kwargs: dict, + pruning_mixin, + hook_class: Type[ActivationsHook], +) -> dict[str, ActivationsHook]: + """Register activation hooks using the pruning mixin approach. + + Args: + model: The model to register hooks on. + activation_hooks_kwargs: Keyword arguments passed to hook constructors. + pruning_mixin: The pruning mixin that defines which modules to hook. + hook_class: The hook class to instantiate for each module. + + Returns: + Dictionary mapping module names to hook instances. + """ activation_hooks_kwargs["model"] = model - for module_name, module in model.named_modules(): - if match_predicate(module_name, module): - block_config = None - if block_idx_match := re.search(r"\.(\d+)\.", module_name): - block_idx = int(block_idx_match.group(1)) - block_config = model.config.block_configs[block_idx] - curr_activation_hooks_kwargs = { - **activation_hooks_kwargs, - "block_config": block_config, - } - - hook = hook_class(module, curr_activation_hooks_kwargs) - module.register_forward_hook(hook) - activation_hooks[module_name] = hook - return activation_hooks, hook_class + if hook_class not in pruning_mixin.supported_hooks(): + raise ValueError( + f"Hook class not supported for {pruning_mixin.__class__.__name__}, " + f"must be in {pruning_mixin.supported_hooks()}" + ) + + module_names_to_hook = pruning_mixin.get_module_names_to_hook(model) + activation_hooks = dict() + for block_idx, module_name in module_names_to_hook: + block_config = None + if block_idx is not None: + block_config = model.config.block_configs[block_idx] + curr_activation_hooks_kwargs = { + **activation_hooks_kwargs, + "block_config": block_config, + } + + module = model.get_submodule(module_name) + hook = hook_class(module, curr_activation_hooks_kwargs) + module.register_forward_hook(hook) + activation_hooks[module_name] = hook + + if len(activation_hooks) == 0: + raise ValueError("couldn't find any hooks") + + aprint(f"Found the following hooks: {activation_hooks.keys()}") + return activation_hooks diff --git a/modelopt/torch/puzzletron/activation_scoring/score_pruning_activations.py b/modelopt/torch/puzzletron/activation_scoring/score_pruning_activations.py index ef5e5e9ad..c043c20d5 100644 --- a/modelopt/torch/puzzletron/activation_scoring/score_pruning_activations.py +++ b/modelopt/torch/puzzletron/activation_scoring/score_pruning_activations.py @@ -138,4 +138,4 @@ def launch_score_activations(cfg: DictConfig): mprint("Starting pruning activation scoring...") # The checkpoint manager inside validate_model handles all progress tracking - validate_model(args=cfg.pruning, pipeline_parallel=True) + validate_model(args=cfg.pruning) diff --git a/modelopt/torch/puzzletron/puzzletron.py b/modelopt/torch/puzzletron/puzzletron.py index 1051fdbaf..0d9ac068f 100644 --- a/modelopt/torch/puzzletron/puzzletron.py +++ b/modelopt/torch/puzzletron/puzzletron.py @@ -15,6 +15,7 @@ """This module provides the main compression function for a model using MIP-based NAS search algorithm.""" +import hydra from omegaconf import DictConfig import modelopt.torch.puzzletron.activation_scoring.score_pruning_activations as score_pruning_activations @@ -51,24 +52,25 @@ def puzzletron( f"dataset_path={dataset_path}", ], ) + hydra_cfg = hydra.utils.instantiate(hydra_cfg) # Step 1: score_pruning_activations (distributed processing) score_pruning_activations.launch_score_activations(hydra_cfg) - # Step 2: pruning_ckpts (single process) - if dist.is_master(): - pruning_ckpts.launch_prune_ckpt(hydra_cfg) - dist.barrier() + # # Step 2: pruning_ckpts (single process) + # if dist.is_master(): + # pruning_ckpts.launch_prune_ckpt(hydra_cfg) + # dist.barrier() - # Step 4: build_library_and_stats (single process) - if dist.is_master(): - build_library_and_stats.launch_build_library_and_stats(hydra_cfg) - dist.barrier() + # # Step 4: build_library_and_stats (single process) + # if dist.is_master(): + # build_library_and_stats.launch_build_library_and_stats(hydra_cfg) + # dist.barrier() - # Step 5: calc_one_block_scores (distributed processing) - scoring.launch_scoring(hydra_cfg) + # # Step 5: calc_one_block_scores (distributed processing) + # scoring.launch_scoring(hydra_cfg) - # Step 6: mip_and_realize_models (distributed processing) - mip_and_realize_models.launch_mip_and_realize_model(hydra_cfg) + # # Step 6: mip_and_realize_models (distributed processing) + # mip_and_realize_models.launch_mip_and_realize_model(hydra_cfg) return hydra_cfg diff --git a/modelopt/torch/puzzletron/tools/robust_json.py b/modelopt/torch/puzzletron/tools/robust_json.py index dbb561b82..3397de639 100644 --- a/modelopt/torch/puzzletron/tools/robust_json.py +++ b/modelopt/torch/puzzletron/tools/robust_json.py @@ -50,8 +50,13 @@ def default(self, o): # User-defined function in main — fallback to just the name return o.__name__ return f"{o.__module__}.{o.__qualname__}" + if inspect.isclass(o): + return f"{o.__module__}.{o.__qualname__}" if isinstance(o, datetime.timedelta): return str(o) + # Fallback for arbitrary objects: return their class path + if hasattr(o, "__class__") and hasattr(o.__class__, "__module__"): + return f"{o.__class__.__module__}.{o.__class__.__qualname__}" return super().default(o) diff --git a/modelopt/torch/puzzletron/tools/sharded_checkpoint_utils.py b/modelopt/torch/puzzletron/tools/sharded_checkpoint_utils.py index 1cb5e8489..1cf02dc93 100644 --- a/modelopt/torch/puzzletron/tools/sharded_checkpoint_utils.py +++ b/modelopt/torch/puzzletron/tools/sharded_checkpoint_utils.py @@ -14,22 +14,30 @@ # limitations under the License. # mypy: ignore-errors -"""Provides utilities for distributed loading, saving, and manipulation of +""" +Provides utilities for distributed loading, saving, and manipulation of large language model checkpoints across multiple GPUs/processes. + +Uses native HuggingFace models with deci_x_patcher for heterogeneous layer configurations. """ import json from collections.abc import Iterable, Mapping from pathlib import Path -from typing import Literal, cast +from types import SimpleNamespace +from typing import Literal, Type, cast import numpy as np import torch import torch.distributed import torch.nn as nn +import transformers +from huggingface_hub import split_torch_state_dict_into_shards from safetensors import safe_open from safetensors.torch import load_file as safe_load_file from safetensors.torch import save_file as safe_save_file +from tqdm import tqdm +from transformers import AutoConfig, AutoModelForCausalLM, PretrainedConfig from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME from transformers.utils.hub import cached_file, get_checkpoint_shard_files from typing_extensions import override @@ -43,23 +51,18 @@ ) from modelopt.torch.puzzletron.tools.checkpoint_utils import load_model_config, load_state_dict from modelopt.torch.puzzletron.tools.logger import mprint +from modelopt.torch.puzzletron.utils.dummy_modules import ( + DummyBlock, + DummyLMHead, + DummyModule, + DummyWTE, +) from modelopt.torch.puzzletron.utils.utils import EmptyInitOnDevice -class DummyModule(nn.Module): - def __init__(self, *args, **kwargs) -> None: - super().__init__(*args, **kwargs) - self.register_load_state_dict_post_hook(self.load_state_dict_post_hook) - - @staticmethod - def load_state_dict_post_hook( - module: torch.nn.Module, incompatible_keys: torch.nn.modules.module._IncompatibleKeys - ) -> None: - incompatible_keys.missing_keys.clear() - incompatible_keys.unexpected_keys.clear() +class DeciLMDummyBlock(DummyModule): + """Dummy block for DeciLM models (used by replacement_library).""" - -class DummyBlock(DummyModule): def __init__(self, config: DeciLMConfig, block_index: int): super().__init__() self.config = config @@ -73,7 +76,9 @@ def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor | tuple[torc return x, None -class DummyWTE(DummyModule): +class DeciLMDummyWTE(DummyModule): + """Dummy word token embedding for DeciLM models (used by replacement_library).""" + def __init__(self, config: DeciLMConfig, dtype: torch.dtype | None = None): super().__init__() self.n_embd = config.get_hidden_size() @@ -86,7 +91,9 @@ def forward(self, input_ids: torch.Tensor) -> torch.Tensor: return result -class DummyLMHead(DummyModule): +class DeciLMDummyLMHead(DummyModule): + """Dummy LM head for DeciLM models (used by replacement_library).""" + def __init__(self, config: DeciLMConfig): super().__init__() self.vocab_size = config.vocab_size @@ -98,24 +105,44 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return result -def create_local_shard_(model: DeciLMForCausalLM, owned_block_indexes: set[int]): - all_block_indexes = set(range(len(model.model.layers))) +def set_submodule(model: nn.Module, module_name: str, new_submodule: nn.Module) -> None: + """Set a submodule on a model by dotted path.""" + parts = module_name.split(".") + parent_path = ".".join(parts[:-1]) + attr = parts[-1] + parent_module = model.get_submodule(parent_path) if parent_path else model + setattr(parent_module, attr, new_submodule) + + +def create_local_shard_(model, owned_block_indexes: set[int], descriptor, runtime): + all_block_indexes = set(range(model.config.num_hidden_layers)) has_first_block = 0 in owned_block_indexes has_last_block = max(all_block_indexes) in owned_block_indexes unowned_block_indexes = all_block_indexes - owned_block_indexes for block_index in unowned_block_indexes: - model.model.layers[block_index] = cast( - "DeciLMDecoderLayer", DummyBlock(model.config, block_index) + decoder_layer_name = descriptor.layer_block_name(block_index) + decoder_layer = model.get_submodule(decoder_layer_name) + set_submodule( + model, + decoder_layer_name, + descriptor.create_dummy_block(decoder_layer, block_index=block_index), ) - if not has_first_block: - model.set_input_embeddings(DummyWTE(model.config)) + # If we have the last block with tied embeddings, keep embed_tokens so lm_head works. + # load_sharded_state_dict will load embed_tokens.weight from the first shard's checkpoint file, + # and since they're tied, lm_head.weight gets populated too. + if not has_first_block and not (has_last_block and model.config.tie_word_embeddings): + set_submodule( + model, + descriptor.input_embedding_name(), + DummyWTE(model.config.hidden_size, dtype=runtime.dtype), + ) if not has_last_block: - model.model.set_final_layer_norm(nn.Identity()) + set_submodule(model, descriptor.final_norm_name(), nn.Identity()) if not (model.config.tie_word_embeddings and has_first_block): - model.set_output_embeddings(DummyLMHead(model.config)) + set_submodule(model, descriptor.output_embedding_name(), DummyLMHead(model.config)) return model @@ -130,42 +157,74 @@ def create_dummy_model( rope_cls = rope_type_to_class[model_config.position_embedding_type] model.model.rotary_emb = rope_cls(config=model.config) - model.model.set_input_embeddings(DummyWTE(model.config, dtype)) + model.model.set_input_embeddings(DeciLMDummyWTE(model.config, dtype)) model.model.set_final_layer_norm(nn.Identity()) - model.set_output_embeddings(DummyLMHead(model.config)) + model.set_output_embeddings(DeciLMDummyLMHead(model.config)) for block_index in range(model_config.get_num_hidden_layers()): - model.model.layers[block_index] = DummyBlock(model.config, block_index) + model.model.layers[block_index] = DeciLMDummyBlock(model.config, block_index) return model +def _get_model_class_from_config(config: PretrainedConfig): + """ + Get the model class from config.architectures field. + Works for any model registered in transformers (CausalLM, VL models, etc.). + Falls back to AutoModelForCausalLM if architectures is not available. + """ + if hasattr(config, "architectures") and config.architectures: + model_class_name = config.architectures[0] + if hasattr(transformers, model_class_name): + return getattr(transformers, model_class_name) + mprint( + f"Warning: {model_class_name} not found in transformers, falling back to AutoModelForCausalLM" + ) + return AutoModelForCausalLM + + def load_and_shard_model( + descriptor, checkpoint_path: str | Path, owned_block_indexes: set[int] | Literal["auto"] = "auto", - model_config: DeciLMConfig | None = None, - model_config_overrides: Mapping | None = None, - model_dtype: torch.dtype = torch.bfloat16, -) -> DeciLMForCausalLM: + model_config: PretrainedConfig | None = None, +): checkpoint_path = Path(checkpoint_path) - with torch.device(dist.local_rank()): + runtime = SimpleNamespace( + device=torch.device(dist.local_rank()), + dtype=torch.bfloat16, + global_rank=dist.rank(), + world_size=dist.size(), + is_main_process=dist.is_master(), + is_last_process=dist.is_last_process(), + use_autocast=True, # Default: use autocast; descriptor can override + ) + + with runtime.device: if model_config is None: - model_config = load_model_config( - checkpoint_path, model_config_overrides, ignore_unexpected_config_keys=True - ) + model_config = load_model_config(checkpoint_path) if owned_block_indexes == "auto": owned_block_indexes = set( - np.array_split(np.arange(model_config.get_num_hidden_layers()), dist.size())[ - dist.rank() + np.array_split(np.arange(model_config.num_hidden_layers), runtime.world_size)[ + runtime.global_rank ] ) mprint("Initializing model shards") - model_shard = create_sharded_model( - model_config=model_config, - owned_block_indexes=owned_block_indexes, - ) + # Pass block_configs explicitly so patcher works for VL models where + # decoder layers receive nested config (e.g., text_config) without block_configs + from modelopt.torch.puzzletron.anymodel.puzzformer import deci_x_patcher + + with deci_x_patcher( + model_descriptor=descriptor, block_configs=getattr(model_config, "block_configs", None) + ): + model_shard = create_sharded_model( + runtime=runtime, + descriptor=descriptor, + model_config=model_config, + owned_block_indexes=owned_block_indexes, + ) if (checkpoint_path / SAFE_WEIGHTS_NAME).exists() or ( checkpoint_path / SAFE_WEIGHTS_INDEX_NAME @@ -178,27 +237,47 @@ def load_and_shard_model( shard_state_dict = load_sharded_state_dict( model_name_or_path=str(checkpoint_path), keys_to_load=shard_keys, - device=torch.device(dist.local_rank()), + device=runtime.device, ) new_names = set(shard_state_dict.keys()) mprint(f"{new_names=}") - model_shard.load_state_dict(shard_state_dict, assign=True) + # strict=False: allows missing lm_head.weight when tie_word_embeddings=True (e.g., Llama 3.2 3B) + model_shard.load_state_dict(shard_state_dict, strict=False, assign=True) del shard_state_dict - if model_config.tie_word_embeddings and (0 in owned_block_indexes): - # re-tie the weights in case the connection was severed + # Re-tie weights after load_state_dict with assign=True, which severs the tie. + # Needed on first rank (owns embed_tokens) and last rank (owns lm_head). + has_first_block = 0 in owned_block_indexes + has_last_block = (model_config.num_hidden_layers - 1) in owned_block_indexes + if model_config.tie_word_embeddings and (has_first_block or has_last_block): model_shard.tie_weights() + + # On the last rank with tied embeddings, we kept embed_tokens in create_local_shard_() + # just to load the weight and tie it to lm_head. Now replace it with a dummy so it + # doesn't interfere with the pipeline forward pass (only rank 0 should run embed_tokens). + if model_config.tie_word_embeddings and has_last_block and not has_first_block: + set_submodule( + model_shard, + descriptor.input_embedding_name(), + DummyWTE(model_config.hidden_size, dtype=runtime.dtype), + ) else: mprint("Loading state_dict in main process") - state_dict = load_state_dict(checkpoint_path) if dist.is_master() else None + state_dict = load_state_dict(checkpoint_path) if runtime.is_main_process else None mprint("Distributing model to shards") load_state_dict_to_shards(model_shard=model_shard, loaded_state_dict=state_dict) del state_dict - model_shard.type(model_dtype) + descriptor.init_rotary_embedding(model_shard, runtime) + + model_shard.type(runtime.dtype) + + # Configure autocast based on model descriptor (some models like Qwen3-VL MoE + # have dtype bugs under autocast) + runtime.use_autocast = descriptor.uses_autocast() params_on_meta_device = [ param_name @@ -206,14 +285,16 @@ def load_and_shard_model( if param.device == torch.device("meta") ] assert len(params_on_meta_device) == 0, ( - f"[global_rank={dist.rank()}] Couldn't load params {params_on_meta_device}" + f"[global_rank={runtime.global_rank}] Couldn't load params {params_on_meta_device}" ) return model_shard def create_sharded_model( - model_config: DeciLMConfig, + runtime, + descriptor, + model_config: PretrainedConfig, owned_block_indexes: set[int], device: str | torch.device | None = "meta", dtype: torch.dtype | None = torch.float32, @@ -224,14 +305,24 @@ def create_sharded_model( dist.barrier() with EmptyInitOnDevice(device="meta", dtype=dtype): - model = DeciLMForCausalLM(model_config) - create_local_shard_(model=model, owned_block_indexes=owned_block_indexes) + # Get model class from config.architectures (works for CausalLM, VL models, etc.) + model_class = _get_model_class_from_config(model_config) + # AutoModelForCausalLM uses from_config(); concrete model classes use _from_config() + if model_class is AutoModelForCausalLM: + model = model_class.from_config(model_config, trust_remote_code=True) + else: + model = model_class._from_config(model_config) + create_local_shard_( + model=model, + owned_block_indexes=owned_block_indexes, + descriptor=descriptor, + runtime=runtime, + ) if device != torch.device("meta"): local_shard_state_dict = { k: torch.empty_like(v, device=device) for k, v in model.state_dict().items() } - model.load_state_dict(local_shard_state_dict, assign=True) return model @@ -288,7 +379,9 @@ def load_state_dict_to_shards( def save_sharded_model( model_shard: torch.nn.Module | dict[str, torch.Tensor], out_path: str | Path ): - """out_path is usually output_checkpoint_path / "model.safetensors" """ + """ + out_path is usually output_checkpoint_path / "model.safetensors" + """ dist.barrier() if isinstance(model_shard, torch.nn.Module): @@ -346,7 +439,9 @@ def load_sharded_state_dict( keys_to_load: Iterable[str] | None = None, device: torch.device | str = "cpu", ) -> dict[str, torch.Tensor]: - """keys_to_load: entire state_dict if None, else partial state_dict containing only these keys""" + """ + keys_to_load: entire state_dict if None, else partial state_dict containing only these keys + """ shard_paths = _resolve_shard_paths(model_name_or_path) # print(f"shard_paths: {shard_paths}") partial_state_dict = {} diff --git a/modelopt/torch/puzzletron/tools/validate_model.py b/modelopt/torch/puzzletron/tools/validate_model.py index 6c3dc3640..cb8eb996d 100644 --- a/modelopt/torch/puzzletron/tools/validate_model.py +++ b/modelopt/torch/puzzletron/tools/validate_model.py @@ -12,42 +12,49 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -"""Provides a function to validate a model. Runs a model forward pass on a dataset and calculates +# mypy: ignore-errors +""" +Provides a function to validate a model. Runs a model forward pass on a dataset and calculates the loss, and optionally registers hooks to capture the inputs and the outputs of pytorch modules that are used for activation scoring for pruning. TODO: Consider moving this a separate module dedicated for scoring + +Uses native HuggingFace models with deci_x_patcher for heterogeneous layer configurations. """ import textwrap from pathlib import Path +from typing import Type import torch from omegaconf import DictConfig from torch import nn from torch.utils.data import DataLoader -from transformers import ( - AutoModelForCausalLM, - AutoTokenizer, - PreTrainedModel, - PreTrainedTokenizerBase, -) +from transformers import AutoTokenizer, PreTrainedModel, PreTrainedTokenizerBase import modelopt.torch.utils.distributed as dist from modelopt.torch.puzzletron.activation_scoring.activation_hooks.utils import ( register_activation_hooks, ) -from modelopt.torch.puzzletron.tools.checkpoint_utils_hf import load_checkpoint +from modelopt.torch.puzzletron.anymodel.model_descriptor import ( + ModelDescriptor, + ModelDescriptorFactory, +) +from modelopt.torch.puzzletron.anymodel.puzzformer.no_op import Same from modelopt.torch.puzzletron.tools.logger import aprint, mprint -from modelopt.torch.puzzletron.tools.sharded_checkpoint_utils import load_and_shard_model +from modelopt.torch.puzzletron.tools.sharded_checkpoint_utils import ( + load_and_shard_model, + set_submodule, +) from modelopt.torch.puzzletron.utils.data.dataloaders import create_validation_dataloader -from modelopt.torch.puzzletron.utils.parsing import simple_parse_args_string +from modelopt.torch.puzzletron.utils.parsing import ( + simple_parse_args_string, # noqa: F401 (kept for backwards compat) +) from modelopt.torch.puzzletron.utils.validate_runtime_pipeline import ( HiddenStatesAndLMHead, calculate_losses_pipeline, ) -from modelopt.torch.puzzletron.utils.validation import calculate_losses """ Two goals: @@ -70,7 +77,6 @@ def validate_model( tokenizer: PreTrainedTokenizerBase | None = None, target_hidden_states_per_batch: list[torch.Tensor] | None = None, return_hidden_states: bool = False, - pipeline_parallel: bool = False, calculate_full_score_ablations: bool = False, val_dataloader: DataLoader | None = None, ) -> tuple[dict[str, dict], HiddenStatesAndLMHead | None] | tuple[None, None]: @@ -79,86 +85,80 @@ def validate_model( Args: args: Configuration object containing the following attributes: - Model Configuration attributes: - - - ``model_name_or_path`` (str): Path to model checkpoint or HuggingFace model name. - Required unless model is passed directly. - - ``model_dtype`` (str or torch.dtype): Model data type (e.g., "torch.bfloat16", torch.float16). - - ``autocast_dtype`` (str or torch.dtype): Autocast data type for mixed precision. - - Dataset Configuration attributes: - - - ``dataset_path`` (str): Path to the validation dataset. - - ``tokenizer_name`` (str, optional): Tokenizer name/path. Uses model_name_or_path if not specified. - - ``data_column`` (str): Column name in dataset containing text data. - - ``block_size`` (int): Maximum sequence length for tokenization. - - ``eval_samples`` (int, optional): Number of samples to evaluate. Uses all if None. - - ``val_dataset_name`` (str): Name of validation dataset split. - - ``source_datasets_to_discard`` (list[str], optional): List of source datasets to exclude. - - ``load_dataset_fn`` (callable, optional): Custom function to load the dataset. - - Data Processing attributes: - - - ``micro_batch_size`` (int): Batch size for evaluation. - - ``seed`` (int): Random seed for reproducibility. - - ``shuffle_seed`` (int, optional): Seed for shuffling data. Uses seed if None. - - ``varlen`` (bool): Enable variable-length sequences. - - ``bos_rate`` (float): Rate of adding BOS token. - - ``fim_rate`` (float): Fill-in-the-middle rate for code completion tasks. - - ``fim_spm_rate`` (float): SPM-based fill-in-the-middle rate. - - Activation Hooks attributes: - - - ``activations_log_dir`` (str, optional): Directory to log activation scores. - If provided, hooks will be registered to capture activations. - - ``activation_hooks_kwargs`` (str or dict, optional): Arguments for activation hooks. - If string, comma-separated format: "arg1=val1,arg2=val2". - - Execution Options attributes: - - - ``calc_losses_on_cpu`` (bool): Calculate losses on CPU to avoid OOM. Very slow, not recommended. - - ``write_results`` (bool): Write validation results to file. + Model Configuration: + - model_name_or_path (str): Path to model checkpoint or HuggingFace model name. + Required unless model is passed directly. + - model_dtype (str or torch.dtype): Model data type (e.g., "torch.bfloat16", torch.float16). + - autocast_dtype (str or torch.dtype): Autocast data type for mixed precision. + + Dataset Configuration: + - dataset_path (str): Path to the validation dataset. + - tokenizer_name (str, optional): Tokenizer name/path. Uses model_name_or_path if not specified. + - data_column (str): Column name in dataset containing text data. + - block_size (int): Maximum sequence length for tokenization. + - eval_samples (int, optional): Number of samples to evaluate. Uses all if None. + - val_dataset_name (str): Name of validation dataset split. + - source_datasets_to_discard (list[str], optional): List of source datasets to exclude. + - load_dataset_fn (callable, optional): Custom function to load the dataset. + + Data Processing: + - micro_batch_size (int): Batch size for evaluation. + - seed (int): Random seed for reproducibility. + - shuffle_seed (int, optional): Seed for shuffling data. Uses seed if None. + - varlen (bool): Enable variable-length sequences. + - bos_rate (float): Rate of adding BOS token. + - fim_rate (float): Fill-in-the-middle rate for code completion tasks. + - fim_spm_rate (float): SPM-based fill-in-the-middle rate. + + Activation Hooks: + - activations_log_dir (str, optional): Directory to log activation scores. If provided, + hooks will be registered to capture activations. + - activation_hooks_kwargs (str or dict, optional): Arguments for activation hooks. + If string, comma-separated format: "arg1=val1,arg2=val2". + + Execution Options: + - calc_losses_on_cpu (bool): Calculate losses on CPU to avoid OOM. Very slow, not recommended. + - write_results (bool): Write validation results to file. model: Pre-loaded model. If None, will be loaded from args.model_name_or_path. tokenizer: Pre-loaded tokenizer. If None, will be loaded based on args. target_hidden_states_per_batch: Target hidden states for pipeline parallel evaluation. return_hidden_states: Whether to return hidden states from the model. - pipeline_parallel: Enable pipeline parallelism for large models. calculate_full_score_ablations: Calculate comprehensive teacher similarity scores. - False calculates only a small suite for efficiency. + False calculates only a small suite for efficiency. val_dataloader: Pre-created validation dataloader. If None, will be created from args. Returns: A tuple containing: - - losses: Dictionary mapping loss names to loss statistics (avg, per_sample). - hidden_states_per_batch: Hidden states and LM head outputs if return_hidden_states is True, else None. - Returns (None, None) if not on master rank. """ + descriptor = ModelDescriptorFactory.get(args.descriptor) + if val_dataloader is None: val_dataloader = prepare_dataloader(args, tokenizer) if dist.is_master() else None validation_full_iters = ( args.eval_samples // args.micro_batch_size ) # model pipeline, single data rank - model = prepare_model(args, model, pipeline_parallel) + model = prepare_model(args, descriptor=descriptor, model=model) just_model_forward = False checkpoint_manager = None activation_hooks = None if args.activations_log_dir is not None: - activation_hooks_kwargs = ( - simple_parse_args_string(args.activation_hooks_kwargs) - if isinstance(args.activation_hooks_kwargs, str) - else args.activation_hooks_kwargs - ) + activation_hooks_kwargs = args.activation_hooks_kwargs or {} activation_hooks_kwargs["validation_full_iters"] = validation_full_iters + hook_class = args.hook_class - # Create activation hooks first - activation_hooks, hook_class = register_activation_hooks( - model=model, activation_hooks_kwargs=activation_hooks_kwargs + # Create activation hooks using pruning mixin + activation_hooks = register_activation_hooks( + model=model, + activation_hooks_kwargs=activation_hooks_kwargs, + hook_class=hook_class, + pruning_mixin=args.pruning_mixin, ) # Create checkpoint manager with hooks @@ -181,26 +181,23 @@ def validate_model( else: mprint("No checkpoint found, starting fresh") just_model_forward = True - model.lm_head = nn.Identity() - - if not pipeline_parallel: - losses, hidden_states_per_batch = calculate_losses( - model=model, - dataloader=val_dataloader, - checkpoint_manager=checkpoint_manager, - ) - else: - losses, hidden_states_per_batch = calculate_losses_pipeline( - stitched_model=model, - dataloader=val_dataloader, - target_hidden_states_per_batch=target_hidden_states_per_batch, - return_hidden_states=return_hidden_states, - calculate_full_score_ablations=calculate_full_score_ablations, - calc_on_cpu=args.calc_losses_on_cpu, - just_model_forward=just_model_forward, - checkpoint_manager=checkpoint_manager, - autocast_dtype=getattr(torch, args.autocast_dtype.strip("torch.")), - ) + set_submodule(model, descriptor.output_embedding_name(), Same()) + + losses, hidden_states_per_batch = calculate_losses_pipeline( + stitched_model=model, + dataloader=val_dataloader, + target_hidden_states_per_batch=target_hidden_states_per_batch, + return_hidden_states=return_hidden_states, + calculate_full_score_ablations=calculate_full_score_ablations, + calc_on_cpu=args.calc_losses_on_cpu, + just_model_forward=just_model_forward, + checkpoint_manager=checkpoint_manager, + autocast_dtype=getattr( + torch, getattr(args, "autocast_dtype", "torch.bfloat16").strip("torch.") + ), + descriptor=descriptor, + use_autocast=descriptor.uses_autocast(), + ) if losses is not None: avg_losses = {loss_name: loss_log["avg"] for loss_name, loss_log in losses.items()} @@ -224,31 +221,13 @@ def validate_model( def prepare_model( - args: DictConfig, model: PreTrainedModel | None = None, pipeline_parallel: bool = False + args: DictConfig, + descriptor: Type[ModelDescriptor], + model: PreTrainedModel | None = None, ) -> nn.Module: if model is None: assert args.model_name_or_path is not None - if pipeline_parallel: - model = load_and_shard_model( - args.model_name_or_path, - model_config_overrides={"block_size": args.block_size}, - model_dtype=getattr(torch, args.model_dtype.strip("torch.")), - ) - else: - try: - model = load_checkpoint( - args.model_name_or_path, - model_config_overrides={"block_size": args.block_size}, - ignore_unexpected_config_keys=True, - ) - model.to("cuda") - except FileNotFoundError: - model = AutoModelForCausalLM.from_pretrained( - args.model_name_or_path, - torch_dtype="auto", - device_map="auto", - trust_remote_code=True, - ) + model = load_and_shard_model(descriptor=descriptor, checkpoint_path=args.model_name_or_path) model.eval() return model diff --git a/modelopt/torch/puzzletron/utils/validate_runtime_pipeline.py b/modelopt/torch/puzzletron/utils/validate_runtime_pipeline.py index db1e8f2ce..90fea13c5 100644 --- a/modelopt/torch/puzzletron/utils/validate_runtime_pipeline.py +++ b/modelopt/torch/puzzletron/utils/validate_runtime_pipeline.py @@ -13,7 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Model evaluation utilities for models split across multiple GPUs in pipeline-parallel mode. +""" +Model evaluation utilities for models split across multiple GPUs in pipeline-parallel mode. Coordinates forward passes and loss computation through model shards distributed across GPUs using sewing_kit's StitchedModule framework. Relies on validation.py for core loss computation. @@ -22,16 +23,18 @@ """ # mypy: ignore-errors +import traceback +from contextlib import nullcontext +from typing import Type + import numpy as np import torch from torch.utils.data import DataLoader from tqdm import tqdm import modelopt.torch.utils.distributed as dist -from modelopt.torch.puzzletron.decilm.deci_lm_hf_code.modeling_decilm import ( - DeciLMForCausalLM, - LMHead, -) +from modelopt.torch.puzzletron.anymodel.model_descriptor import ModelDescriptor +from modelopt.torch.puzzletron.decilm.deci_lm_hf_code.modeling_decilm import LMHead from modelopt.torch.puzzletron.sewing_kit import ( ExternalTarget, InputArgs, @@ -51,6 +54,23 @@ from modelopt.torch.puzzletron.utils.validation import _organize_outputs, calculate_batch_outputs +def _log_forward_error(e: Exception, rank: int, batch_idx: int, num_batches: int) -> None: + """Log detailed error info for distributed forward pass failures. + + When one rank crashes during distributed forward, others may hang waiting for communication. + This logging helps diagnose which rank failed and why. + """ + error_msg = ( + f"\n{'=' * 60}\n" + f"[Rank {rank}] ERROR in stitched_model forward (batch {batch_idx}/{num_batches})\n" + f"Error: {type(e).__name__}: {e}\n" + f"{'=' * 60}\n" + f"{traceback.format_exc()}" + f"{'=' * 60}\n" + ) + print(error_msg, flush=True) + + class HiddenStatesAndLMHead(list): def __init__(self, hidden_states: list[torch.Tensor], lm_head_weights: torch.Tensor): super().__init__(hidden_states) @@ -59,7 +79,7 @@ def __init__(self, hidden_states: list[torch.Tensor], lm_head_weights: torch.Ten @torch.no_grad() def calculate_losses_pipeline( - stitched_model: StitchedModule | DeciLMForCausalLM, + stitched_model: StitchedModule, dataloader: DataLoader | None, target_hidden_states_per_batch: HiddenStatesAndLMHead | None = None, return_hidden_states: bool = False, @@ -68,8 +88,11 @@ def calculate_losses_pipeline( just_model_forward: bool = False, checkpoint_manager=None, autocast_dtype: torch.dtype = torch.bfloat16, + descriptor: Type[ModelDescriptor] = None, + use_autocast: bool = True, ) -> tuple[dict[str, dict], HiddenStatesAndLMHead | None] | tuple[None, None]: - """Do model forward on each batch and calculate LM loss. + """ + Do model forward on each batch and calculate LM loss. Optionally also calculate kl_div loss and other metrics from given target_hidden_states_per_batch. Optionally return hidden states per batch. Does not support data-parallel. @@ -87,8 +110,8 @@ def calculate_losses_pipeline( target_hidden_states_per_batch: list[torch.Tensor], returned if return_hidden_states=True """ - if isinstance(stitched_model, DeciLMForCausalLM): - stitched_model = perform_pipeline_stitches(stitched_model) + if not isinstance(stitched_model, StitchedModule): + stitched_model = perform_pipeline_stitches(stitched_model, descriptor) params = list(stitched_model.parameters()) model_device = params[0].device if params else "cpu" @@ -145,14 +168,24 @@ def calculate_losses_pipeline( stitched_model.eval() - with torch.autocast(device_type="cuda", dtype=autocast_dtype): + # Use autocast for mixed precision, or nullcontext if disabled + # (some models like Qwen3-VL MoE have dtype bugs under autocast) + autocast_ctx = ( + torch.autocast(device_type="cuda", dtype=autocast_dtype) if use_autocast else nullcontext() + ) + with autocast_ctx: + fake_input_ids = fake_tensor(1, seq_len, dtype=torch.long, device=model_device) for i_batch in progress_bar: if dist.is_master(): input_ids = all_input_ids[i_batch].to(model_device) else: - input_ids = fake_tensor(1, seq_len, dtype=torch.long) + input_ids = fake_input_ids - output = stitched_model({}, {}, input_ids) + try: + output = stitched_model({}, {}, input_ids) + except Exception as e: + _log_forward_error(e, dist.rank(), i_batch, num_batches) + raise if dist.is_last_process(): logits = output.captured_outputs.get("model_output") @@ -183,6 +216,16 @@ def calculate_losses_pipeline( outputs.append(batch_outputs) + # Free GPU memory after processing each batch + del logits, hidden_states, targets + if target_hidden_states is not None: + del target_hidden_states + if target_logits is not None: + del target_logits + + # Free output tensor memory on all ranks + del output + # Update checkpoint progress periodically if checkpoint_manager: checkpoint_manager.update_progress(i_batch + 1, num_batches) @@ -200,13 +243,28 @@ def calculate_losses_pipeline( return losses, hidden_states_per_batch -def perform_pipeline_stitches(model: DeciLMForCausalLM) -> StitchedModule: +def perform_pipeline_stitches( + model, + descriptor: Type[ModelDescriptor], +) -> StitchedModule: + """Create pipeline stitches for distributed model evaluation. + + Args: + model: The model to stitch (any HuggingFace model with AnyModel descriptor). + descriptor: ModelDescriptor for layer naming. + """ target = ModuleTarget("module", model) stitcher = Needle() + num_layers = model.config.num_hidden_layers + is_real_block = np.flatnonzero( - [not isinstance(block, DummyBlock) for block in model.model.layers] + [ + not isinstance(model.get_submodule(descriptor.layer_block_name(i)), DummyBlock) + for i in range(num_layers) + ] ) + first_block, last_block = is_real_block.min(), is_real_block.max() if dist.rank() != 0: @@ -216,7 +274,7 @@ def perform_pipeline_stitches(model: DeciLMForCausalLM) -> StitchedModule: name="activations", adapter=lambda x: InputArgs(x) ), target.input( - name=f"model.layers.{first_block}", + name=descriptor.layer_block_name(first_block), reducer=InputReducer( lambda acc, override, orig, *args: override + orig.drop_args(0) ), @@ -226,17 +284,17 @@ def perform_pipeline_stitches(model: DeciLMForCausalLM) -> StitchedModule: if not dist.is_last_process(): # send activations to next rank stitcher.stitch( - target.output(f"model.layers.{last_block}"), + target.output(descriptor.layer_block_name(last_block)), RemoteTarget(peer_rank=dist.rank() + 1).value(name="activations"), ) else: # register model output stitcher.stitch( - target.output(name="lm_head"), + target.output(name=descriptor.output_embedding_name()), ExternalTarget().output("model_output"), ) stitcher.stitch( - target.output(name="model.norm"), + target.output(name=descriptor.final_norm_name()), ExternalTarget().output("hidden_states"), ) diff --git a/tests/gpu/torch/puzzletron/test_puzzletron.py b/tests/gpu/torch/puzzletron/test_puzzletron.py index 23a4b61c2..585567715 100644 --- a/tests/gpu/torch/puzzletron/test_puzzletron.py +++ b/tests/gpu/torch/puzzletron/test_puzzletron.py @@ -24,6 +24,7 @@ from _test_utils.torch.puzzletron.utils import setup_test_model_and_data import modelopt.torch.utils.distributed as dist +from modelopt.torch.puzzletron import puzzletron from modelopt.torch.puzzletron.anymodel import convert_model # The e2e test to compress a model based on Local Neural Architecture Search (Mixed Integer Programing NAS search) @@ -42,26 +43,26 @@ ), [ ("llama_3_1_8b_instruct", "llama", "llama_3_1_8b_instruct", None, False), - ("llama_3_2_3b_instruct", "llama", "llama_3_1_8b_instruct", None, False), - ("qwen2_5_7b_instruct", "qwen2", "qwen2_5_7b_instruct", None, False), - ( - "mistral-small-24b-instruct-2501", - "mistral_small", - "mistral-small-24b-instruct-2501", - None, - False, - ), - ("qwen3-8b", "qwen3", "qwen3-8b", None, False), - ("qwen3-vl-30b-a3b-instruct", "qwen3_vl", "qwen3-vl-30b-a3b-instruct", None, True), - ("nemotron-nano-12b-v2", "nemotron_h_v2", "nemotron-nano-12b-v2", "*-", False), - ( - "nemotron-3-nano-30b-a3b-base-bf16", - "nemotron_h", - "nemotron-3-nano-30b-a3b-base-bf16", - "*E", - True, - ), - ("gpt-oss-20b", "gpt_oss_20b", "gpt-oss-20b", None, True), + # ("llama_3_2_3b_instruct", "llama", "llama_3_1_8b_instruct", None, False), + # ("qwen2_5_7b_instruct", "qwen2", "qwen2_5_7b_instruct", None, False), + # ( + # "mistral-small-24b-instruct-2501", + # "mistral_small", + # "mistral-small-24b-instruct-2501", + # None, + # False, + # ), + # ("qwen3-8b", "qwen3", "qwen3-8b", None, False), + # ("qwen3-vl-30b-a3b-instruct", "qwen3_vl", "qwen3-vl-30b-a3b-instruct", None, True), + # ("nemotron-nano-12b-v2", "nemotron_h_v2", "nemotron-nano-12b-v2", "*-", False), + # ( + # "nemotron-3-nano-30b-a3b-base-bf16", + # "nemotron_h", + # "nemotron-3-nano-30b-a3b-base-bf16", + # "*E", + # True, + # ), + # ("gpt-oss-20b", "gpt_oss_20b", "gpt-oss-20b", None, True), ], ) def test_puzzletron( @@ -106,7 +107,7 @@ def _test_puzzletron_multiprocess_job( puzzle_dir, hf_checkpoint_path, dataset_path = setup_test_model_and_data( project_root_path, tmp_path, rank, hf_config_name, hybrid_override_pattern ) - hydra_config_dir = ( # noqa: F841 + hydra_config_dir = ( project_root_path / f"tests/gpu/torch/puzzletron/resources/configs/{hydra_config_subdir}" ) @@ -120,10 +121,10 @@ def _test_puzzletron_multiprocess_job( dist.barrier() # TODO commented for the duration of merging process from dkorzekwa/any_model to feature/puzzletron - # # Compress the model using a one-click approach - # puzzletron.puzzletron( - # str(hydra_config_dir), hydra_config_subdir, str(puzzle_dir), str(dataset_path) - # ) + # Compress the model using a one-click approach + puzzletron.puzzletron( + str(hydra_config_dir), hydra_config_subdir, str(puzzle_dir), str(dataset_path) + ) # # # # Check assertions From 6cc219492c1e267274cb8097f368576b38a19e68 Mon Sep 17 00:00:00 2001 From: Daniel Korzekwa Date: Thu, 5 Mar 2026 09:55:17 -0800 Subject: [PATCH 07/30] Comment all tested models aside of llama_3_1_8b_instruct Signed-off-by: Daniel Korzekwa --- tests/gpu/torch/puzzletron/test_puzzletron.py | 40 +++++++++---------- 1 file changed, 20 insertions(+), 20 deletions(-) diff --git a/tests/gpu/torch/puzzletron/test_puzzletron.py b/tests/gpu/torch/puzzletron/test_puzzletron.py index 23a4b61c2..3a5d9a8ce 100644 --- a/tests/gpu/torch/puzzletron/test_puzzletron.py +++ b/tests/gpu/torch/puzzletron/test_puzzletron.py @@ -42,26 +42,26 @@ ), [ ("llama_3_1_8b_instruct", "llama", "llama_3_1_8b_instruct", None, False), - ("llama_3_2_3b_instruct", "llama", "llama_3_1_8b_instruct", None, False), - ("qwen2_5_7b_instruct", "qwen2", "qwen2_5_7b_instruct", None, False), - ( - "mistral-small-24b-instruct-2501", - "mistral_small", - "mistral-small-24b-instruct-2501", - None, - False, - ), - ("qwen3-8b", "qwen3", "qwen3-8b", None, False), - ("qwen3-vl-30b-a3b-instruct", "qwen3_vl", "qwen3-vl-30b-a3b-instruct", None, True), - ("nemotron-nano-12b-v2", "nemotron_h_v2", "nemotron-nano-12b-v2", "*-", False), - ( - "nemotron-3-nano-30b-a3b-base-bf16", - "nemotron_h", - "nemotron-3-nano-30b-a3b-base-bf16", - "*E", - True, - ), - ("gpt-oss-20b", "gpt_oss_20b", "gpt-oss-20b", None, True), + # ("llama_3_2_3b_instruct", "llama", "llama_3_1_8b_instruct", None, False), + # ("qwen2_5_7b_instruct", "qwen2", "qwen2_5_7b_instruct", None, False), + # ( + # "mistral-small-24b-instruct-2501", + # "mistral_small", + # "mistral-small-24b-instruct-2501", + # None, + # False, + # ), + # ("qwen3-8b", "qwen3", "qwen3-8b", None, False), + # ("qwen3-vl-30b-a3b-instruct", "qwen3_vl", "qwen3-vl-30b-a3b-instruct", None, True), + # ("nemotron-nano-12b-v2", "nemotron_h_v2", "nemotron-nano-12b-v2", "*-", False), + # ( + # "nemotron-3-nano-30b-a3b-base-bf16", + # "nemotron_h", + # "nemotron-3-nano-30b-a3b-base-bf16", + # "*E", + # True, + # ), + # ("gpt-oss-20b", "gpt_oss_20b", "gpt-oss-20b", None, True), ], ) def test_puzzletron( From ee4e1e355e6772504a42e2d4e03f99ec9bfd4727 Mon Sep 17 00:00:00 2001 From: Daniel Korzekwa Date: Thu, 5 Mar 2026 09:57:29 -0800 Subject: [PATCH 08/30] Delete not needed decilm test Signed-off-by: Daniel Korzekwa --- ..._convert_llama3_config_to_decilm_config.py | 50 ------------------- 1 file changed, 50 deletions(-) delete mode 100644 tests/gpu/torch/puzzletron/decilm/converters/test_convert_llama3_config_to_decilm_config.py diff --git a/tests/gpu/torch/puzzletron/decilm/converters/test_convert_llama3_config_to_decilm_config.py b/tests/gpu/torch/puzzletron/decilm/converters/test_convert_llama3_config_to_decilm_config.py deleted file mode 100644 index 4b1ea0b41..000000000 --- a/tests/gpu/torch/puzzletron/decilm/converters/test_convert_llama3_config_to_decilm_config.py +++ /dev/null @@ -1,50 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import json -from pathlib import Path - -from _test_utils.torch.puzzletron.utils import create_and_save_small_llama_model, create_tokenizer - -from modelopt.torch.puzzletron.decilm.converters.convert_llama3_to_decilm import ( - convert_llama3_to_decilm, -) - - -def test_convert_llama3_config_to_decilm_config(project_root_path: Path, tmp_path: Path): - tokenizer = create_tokenizer(project_root_path) - llama_checkpoint_path = tmp_path / "llama_checkpoint" - create_and_save_small_llama_model( - llama_checkpoint_path, vocab_size=tokenizer.vocab_size, tokenizer=tokenizer - ) - - # Convert the Llama model to a DeciLM model - decilm_checkpoint_path = tmp_path / "decilm_checkpoint" - convert_llama3_to_decilm( - input_dir=llama_checkpoint_path, - output_dir=decilm_checkpoint_path, - ) - - # Assert that the converted config has the correct number of block_configs - config_path = decilm_checkpoint_path / "config.json" - assert config_path.exists(), f"Config file not found at {config_path}" - - with open(config_path) as f: - decilm_config = json.load(f) - - # Verify block_configs exists and has the correct length - assert "block_configs" in decilm_config, "block_configs not found in converted config" - actual_num_block_configs = len(decilm_config["block_configs"]) - assert actual_num_block_configs == 2 From 449b52390eb159192b0e3b57c8680a934304f972 Mon Sep 17 00:00:00 2001 From: Daniel Korzekwa Date: Thu, 5 Mar 2026 10:11:18 -0800 Subject: [PATCH 09/30] Fix broken tests Signed-off-by: Daniel Korzekwa --- tests/gpu/torch/puzzletron/nas/plugins/test_nas_convert.py | 4 ++-- tests/gpu/torch/puzzletron/nas/plugins/test_nas_search.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/gpu/torch/puzzletron/nas/plugins/test_nas_convert.py b/tests/gpu/torch/puzzletron/nas/plugins/test_nas_convert.py index c409da28b..23e3b70d5 100644 --- a/tests/gpu/torch/puzzletron/nas/plugins/test_nas_convert.py +++ b/tests/gpu/torch/puzzletron/nas/plugins/test_nas_convert.py @@ -41,7 +41,7 @@ def _test_nas_convert_ffn_pruning_multiprocess_job( dist.setup(timeout=timedelta(10)) # Setup the test model and data. puzzle_dir, llama_checkpoint_path, dataset_path = setup_test_model_and_data( - project_root_path, tmp_path, rank + project_root_path, tmp_path, rank, "llama_3_1_8b_instruct" ) hydra_config_dir = project_root_path / "tests/_test_utils/torch/puzzletron/resources/configs" hydra_config_name = "Llama-3_1-8B-ffn-pruning" @@ -97,7 +97,7 @@ def _test_nas_convert_attn_pruning_multiprocess_job( dist.setup(timeout=timedelta(10)) # Setup the test model and data. puzzle_dir, llama_checkpoint_path, dataset_path = setup_test_model_and_data( - project_root_path, tmp_path, rank + project_root_path, tmp_path, rank, "llama_3_1_8b_instruct" ) hydra_config_dir = project_root_path / "tests/_test_utils/torch/puzzletron/resources/configs" hydra_config_name = "Llama-3_1-8B-attn-pruning" diff --git a/tests/gpu/torch/puzzletron/nas/plugins/test_nas_search.py b/tests/gpu/torch/puzzletron/nas/plugins/test_nas_search.py index a1258c1d0..b0691c90e 100644 --- a/tests/gpu/torch/puzzletron/nas/plugins/test_nas_search.py +++ b/tests/gpu/torch/puzzletron/nas/plugins/test_nas_search.py @@ -40,7 +40,7 @@ def _test_nas_search_multiprocess_job( dist.setup(timeout=timedelta(10)) # Setup the test model and data. puzzle_dir, llama_checkpoint_path, dataset_path = setup_test_model_and_data( - project_root_path, tmp_path, rank + project_root_path, tmp_path, rank, "llama_3_1_8b_instruct" ) hydra_config_dir = project_root_path / "tests/_test_utils/torch/puzzletron/resources/configs" hydra_config_name = "Llama-3_1-8B-ffn-pruning" From fb27bba0298c558e9088e5217300baee178e02f4 Mon Sep 17 00:00:00 2001 From: Daniel Korzekwa Date: Thu, 5 Mar 2026 10:20:40 -0800 Subject: [PATCH 10/30] Update puzzletron_nas_pluging to any_model version Signed-off-by: Daniel Korzekwa --- .../nas/plugins/puzzletron_nas_plugin.py | 46 +++++++++++++------ 1 file changed, 31 insertions(+), 15 deletions(-) diff --git a/modelopt/torch/puzzletron/nas/plugins/puzzletron_nas_plugin.py b/modelopt/torch/puzzletron/nas/plugins/puzzletron_nas_plugin.py index 5e1eace93..bd11837d7 100644 --- a/modelopt/torch/puzzletron/nas/plugins/puzzletron_nas_plugin.py +++ b/modelopt/torch/puzzletron/nas/plugins/puzzletron_nas_plugin.py @@ -13,14 +13,18 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Puzzletron NAS plugin for the Modelopt framework (based on Puzzle algorithm: https://arxiv.org/abs/2411.19146). +""" +Puzzletron NAS plugin for the Modelopt framework (based on Puzzle algorithm: https://arxiv.org/abs/2411.19146). -It is used by mtn.convert() to convert a model from HF format to DeciLM format + do pruning scoring +It is used by mtn.convert() to convert a model from HF format to Puzzletron heterogeneous format + do pruning scoring and save pruned checkpoints, and by mtn.search() to perform the MIP-based NAS search. """ +import datetime from pathlib import Path +import hydra +import torch from torch import nn import modelopt.torch.puzzletron.mip.mip_and_realize_models as mip_and_realize_models @@ -39,15 +43,14 @@ from modelopt.torch.opt.searcher import BaseSearcher, SearchStateDict from modelopt.torch.puzzletron import build_library_and_stats from modelopt.torch.puzzletron.activation_scoring import score_pruning_activations -from modelopt.torch.puzzletron.decilm.converters.convert_llama3_to_decilm import ( - convert_llama3_to_decilm, -) +from modelopt.torch.puzzletron.anymodel.converter import ConverterFactory +from modelopt.torch.puzzletron.anymodel.model_descriptor import ModelDescriptorFactory from modelopt.torch.puzzletron.tools.hydra_utils import initialize_hydra_config_for_dir from modelopt.torch.puzzletron.tools.logger import mprint class PuzzletronModel(nn.Module): - pass # No model implementation is needed for the puzzletron mode + pass # No model implementation is needed for the compress mode class PuzzletronConfig(ModeloptBaseConfig): @@ -90,7 +93,7 @@ class PuzzletronConfig(ModeloptBaseConfig): def convert_puzzletron_model(model: nn.Module, config: PuzzletronConfig) -> ConvertReturnType: - """1. Convert the model from HF format to DeciLM format. + """1. Convert the model from HF format to AnyModel format. 2. Score the pruning activations. 3. Prune the model and save pruned checkpoints @@ -111,14 +114,24 @@ def convert_puzzletron_model(model: nn.Module, config: PuzzletronConfig) -> Conv f"dataset_path={config.dataset_path}", ], ) + # Instantiate nested Hydra configs (e.g., pruning_mixin, hook_class) + hydra_cfg = hydra.utils.instantiate(hydra_cfg) - # Convert Llama3 model to DeciLM model - # TODO: Make it generic, do not call convert_llama3_to_decilm directly. + # Convert HuggingFace model to Puzzletron heterogeneous format (generic, uses descriptor from config) if dist.is_master(): - mprint("Puzzletron Progress 2/8: converting model from HF to DeciLM (single-gpu)") + mprint( + "Puzzletron Progress 2/8: converting model to Puzzletron heterogeneous format (single-gpu)" + ) hf_ckpt_teacher_dir = "ckpts/teacher" # TODO: make it configurable - convert_llama3_to_decilm( - input_dir=config.input_model_path, + + # Get descriptor and converter from the hydra config + descriptor_name = hydra_cfg.descriptor + descriptor = ModelDescriptorFactory.get(descriptor_name) + converter = ConverterFactory.get(descriptor_name) + + converter.convert( + descriptor=descriptor, + input_dir=Path(config.input_model_path), output_dir=Path(config.puzzle_dir) / hf_ckpt_teacher_dir, ) dist.barrier() @@ -141,7 +154,7 @@ def convert_puzzletron_model(model: nn.Module, config: PuzzletronConfig) -> Conv def restore_puzzletron_model( model: nn.Module, config: PuzzletronConfig, metadata: MetadataDict ) -> nn.Module: - """Restore is not needed for the puzzletron mode as we are not saving any model state""" + """Restore is not needed for the compress mode as we are not saving any model state""" return model @@ -162,6 +175,7 @@ def config_class(self) -> type[ModeloptBaseConfig]: @property def search_algorithm(self) -> type[BaseSearcher]: """Return the associated searcher implementation.""" + return PuzzletronSearcher @property @@ -178,7 +192,7 @@ def restore(self) -> RestoreEntrypoint: def export_mode(self) -> str | None: """The mode that corresponds to the export mode. For now, this will be a no-op as there is no modelopt's concept of search space defined - for the puzzletron algorithm. + for the compress algorithm. """ return "export_nas" @@ -188,7 +202,7 @@ class PuzzletronSearcher(BaseSearcher): @property def default_state_dict(self) -> SearchStateDict: - """Not needed for the puzzletron mode as we are not saving any model state""" + """Not needed for the compress mode as we are not saving any model state""" return {} def run_search(self) -> None: @@ -201,6 +215,8 @@ def run_search(self) -> None: f"dataset_path={self.model.dataset_path}", ], ) + # Instantiate nested Hydra configs (e.g., pruning_mixin, hook_class) + hydra_cfg = hydra.utils.instantiate(hydra_cfg) # Build_library_and_stats (single process) if dist.is_master(): From b350f8226d3da6023b1adfd8d855294fa559dd8e Mon Sep 17 00:00:00 2001 From: Daniel Korzekwa Date: Thu, 5 Mar 2026 10:36:19 -0800 Subject: [PATCH 11/30] Correct test resources used by tests. Signed-off-by: Daniel Korzekwa --- .../nas/plugins/test_nas_convert.py | 12 +- .../puzzletron/nas/plugins/test_nas_search.py | 6 +- .../llama_3_1_8b_instruct-attn-pruning.yaml | 107 ++++++++++++++++++ 3 files changed, 119 insertions(+), 6 deletions(-) create mode 100644 tests/gpu/torch/puzzletron/resources/configs/llama_3_1_8b_instruct/llama_3_1_8b_instruct-attn-pruning.yaml diff --git a/tests/gpu/torch/puzzletron/nas/plugins/test_nas_convert.py b/tests/gpu/torch/puzzletron/nas/plugins/test_nas_convert.py index 23e3b70d5..4d2294d66 100644 --- a/tests/gpu/torch/puzzletron/nas/plugins/test_nas_convert.py +++ b/tests/gpu/torch/puzzletron/nas/plugins/test_nas_convert.py @@ -43,8 +43,10 @@ def _test_nas_convert_ffn_pruning_multiprocess_job( puzzle_dir, llama_checkpoint_path, dataset_path = setup_test_model_and_data( project_root_path, tmp_path, rank, "llama_3_1_8b_instruct" ) - hydra_config_dir = project_root_path / "tests/_test_utils/torch/puzzletron/resources/configs" - hydra_config_name = "Llama-3_1-8B-ffn-pruning" + hydra_config_dir = ( + project_root_path / "tests/gpu/torch/puzzletron/resources/configs/llama_3_1_8b_instruct" + ) + hydra_config_name = "llama_3_1_8b_instruct" # # Run the mnt.convert() step @@ -99,8 +101,10 @@ def _test_nas_convert_attn_pruning_multiprocess_job( puzzle_dir, llama_checkpoint_path, dataset_path = setup_test_model_and_data( project_root_path, tmp_path, rank, "llama_3_1_8b_instruct" ) - hydra_config_dir = project_root_path / "tests/_test_utils/torch/puzzletron/resources/configs" - hydra_config_name = "Llama-3_1-8B-attn-pruning" + hydra_config_dir = ( + project_root_path / "tests/gpu/torch/puzzletron/resources/configs/llama_3_1_8b_instruct" + ) + hydra_config_name = "llama_3_1_8b_instruct-attn-pruning" # # Run the mnt.convert() step diff --git a/tests/gpu/torch/puzzletron/nas/plugins/test_nas_search.py b/tests/gpu/torch/puzzletron/nas/plugins/test_nas_search.py index b0691c90e..c34f449d8 100644 --- a/tests/gpu/torch/puzzletron/nas/plugins/test_nas_search.py +++ b/tests/gpu/torch/puzzletron/nas/plugins/test_nas_search.py @@ -42,8 +42,10 @@ def _test_nas_search_multiprocess_job( puzzle_dir, llama_checkpoint_path, dataset_path = setup_test_model_and_data( project_root_path, tmp_path, rank, "llama_3_1_8b_instruct" ) - hydra_config_dir = project_root_path / "tests/_test_utils/torch/puzzletron/resources/configs" - hydra_config_name = "Llama-3_1-8B-ffn-pruning" + hydra_config_dir = ( + project_root_path / "tests/gpu/torch/puzzletron/resources/configs/llama_3_1_8b_instruct" + ) + hydra_config_name = "llama_3_1_8b_instruct" # # Run the mnt.convert() step diff --git a/tests/gpu/torch/puzzletron/resources/configs/llama_3_1_8b_instruct/llama_3_1_8b_instruct-attn-pruning.yaml b/tests/gpu/torch/puzzletron/resources/configs/llama_3_1_8b_instruct/llama_3_1_8b_instruct-attn-pruning.yaml new file mode 100644 index 000000000..02c73aca6 --- /dev/null +++ b/tests/gpu/torch/puzzletron/resources/configs/llama_3_1_8b_instruct/llama_3_1_8b_instruct-attn-pruning.yaml @@ -0,0 +1,107 @@ +defaults: + - pruning: attn_pruning + - scoring: ../validate_solutions_defaults + - realize_model: ../validate_solutions_defaults + - bypass: + - override hydra/hydra_logging: disabled + - _self_ + +descriptor: llama + +puzzle_dir: ??? +teacher_dir: ${puzzle_dir}/ckpts/teacher/ +replacement_library_path: ${puzzle_dir}/replacement_library.json +dataset_path: ??? # path to v0.4_mini + +skip_realize_model: false + +build_replacement_library: + add_ffn_no_ops: true + add_attention_no_ops: true + +calc_subblock_stats: + batch_sizes: [64, 96, 128] + prefill_seq_len: 4096 + generation_seq_len: 4096 + num_active_tokens_override: # Optional override for sequence lengths + prefill_queue_size: 0 + allocate_prefill_query: false + benchmark_iterations: # Set to a number (e.g., 1000) to enable runtime benchmarking + merge_with_existing_stats: false + subblock_stats_filename: "subblock_stats.json" + moe_stats_filename: "moe_stats.json" + +scoring: + descriptor: ${descriptor} + solutions_to_validate: + skip_existing_solutions: true + + replacement_library_path: ${replacement_library_path} + solutions_path: ${to_path:${puzzle_dir}/single_sequence_replacement_solutions.json} + teacher_dir: ${to_path:${teacher_dir}} + output_dir: ${puzzle_dir}/single_sequence_replacement_solutions--validation + + eval_samples: 2 + micro_batch_size: 1 + dataset_path: ${dataset_path}/valid + seed: 42 + shuffle_seed: 444 + +mip: + single_block_replacement_validation_dir: ${to_path:${scoring.output_dir}} + subblock_stats_path: ${to_path:${puzzle_dir}/${calc_subblock_stats.subblock_stats_filename}} + output_path: ${to_path:${puzzle_dir}/mip/puzzle_solutions} + gathered_metrics_path: + puzzle_profile: + + # puzzle_profile: + objective: metrics.cosine_embedding_loss_hidden_states + bigger_is_better: false + + subblock_stats_args: + - batch_size: 96 + weights_dtype: torch.bfloat16 + activations_dtype: torch.bfloat16 + kv_cache_dtype: torch.bfloat16 + + report_additional_costs: + - stats.memory_mib + - stats.num_params + - stats.num_kv_heads + - stats.has_attention + - stats.has_ffn + - stats.kv_cache_memory_mib + - stats.attention_memory_mib + - stats.ffn_memory_mib + - stats.ffn_num_params + - stats.attention_num_params + + human_constraints: + target_memory: 780_000 # 78_000 + + mip_constraints: + metric_overrides: + max_seconds_per_solution: 60 + +realize_model: + descriptor: ${descriptor} + teacher_dir: ${to_path:${teacher_dir}} + tokenizer_name: ${to_path:${teacher_dir}} + replacement_library_path: ${replacement_library_path} + save_models: true + solutions_path: # Filled dynamically + + # Validate params + skip_validation: false # To enable validation of the model solution set `skip_validation` as False + eval_samples: 2 + micro_batch_size: 1 + dataset_path: ${dataset_path}/valid + seed: 42 + shuffle_seed: 444 + +nccl_timeout_minutes: ${timedelta_minutes:10} + +# This section redirects Hydra outputs +hydra: + run: + dir: ${puzzle_dir}/hydra_logs/${now:%Y-%m-%d}/${now:%H-%M-%S} From fafe5a381ffd73c7bb49fc2abf1023cc2932d1e9 Mon Sep 17 00:00:00 2001 From: Daniel Korzekwa Date: Thu, 5 Mar 2026 10:43:34 -0800 Subject: [PATCH 12/30] Disable puzzletron tests (will be enabled after all any_model logic is merged) Signed-off-by: Daniel Korzekwa --- tests/gpu/torch/puzzletron/nas/plugins/test_nas_convert.py | 3 +++ tests/gpu/torch/puzzletron/nas/plugins/test_nas_search.py | 2 ++ 2 files changed, 5 insertions(+) diff --git a/tests/gpu/torch/puzzletron/nas/plugins/test_nas_convert.py b/tests/gpu/torch/puzzletron/nas/plugins/test_nas_convert.py index 4d2294d66..e2373676d 100644 --- a/tests/gpu/torch/puzzletron/nas/plugins/test_nas_convert.py +++ b/tests/gpu/torch/puzzletron/nas/plugins/test_nas_convert.py @@ -18,6 +18,7 @@ from functools import partial from pathlib import Path +import pytest import torch from _test_utils.torch.distributed.utils import spawn_multiprocess_job from _test_utils.torch.puzzletron.utils import setup_test_model_and_data @@ -27,6 +28,7 @@ from modelopt.torch.puzzletron.nas.plugins.puzzletron_nas_plugin import PuzzletronModel +@pytest.mark.skip(reason="Temporarily disabled") def test_nas_convert_ffn_pruning(project_root_path: Path, tmp_path: Path): spawn_multiprocess_job( size=torch.cuda.device_count(), @@ -85,6 +87,7 @@ def _test_nas_convert_ffn_pruning_multiprocess_job( dist.cleanup() +@pytest.mark.skip(reason="Temporarily disabled") def test_nas_convert_attn_pruning(project_root_path: Path, tmp_path: Path): spawn_multiprocess_job( size=torch.cuda.device_count(), diff --git a/tests/gpu/torch/puzzletron/nas/plugins/test_nas_search.py b/tests/gpu/torch/puzzletron/nas/plugins/test_nas_search.py index c34f449d8..e39f1e1cb 100644 --- a/tests/gpu/torch/puzzletron/nas/plugins/test_nas_search.py +++ b/tests/gpu/torch/puzzletron/nas/plugins/test_nas_search.py @@ -17,6 +17,7 @@ from functools import partial from pathlib import Path +import pytest import torch from _test_utils.torch.distributed.utils import spawn_multiprocess_job from _test_utils.torch.puzzletron.utils import setup_test_model_and_data @@ -26,6 +27,7 @@ from modelopt.torch.puzzletron.nas.plugins.puzzletron_nas_plugin import PuzzletronModel +@pytest.mark.skip(reason="Temporarily disabled") def test_nas_search(project_root_path: Path, tmp_path: Path): spawn_multiprocess_job( size=torch.cuda.device_count(), From c7178525e4c870df9c61e1fe7fea5639f9f9ca7f Mon Sep 17 00:00:00 2001 From: Daniel Korzekwa Date: Fri, 6 Mar 2026 00:22:38 -0800 Subject: [PATCH 13/30] Comment out not implemented models. Signed-off-by: Daniel Korzekwa --- .../torch/puzzletron/anymodel/models/__init__.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/modelopt/torch/puzzletron/anymodel/models/__init__.py b/modelopt/torch/puzzletron/anymodel/models/__init__.py index 9928854b5..f2119059f 100644 --- a/modelopt/torch/puzzletron/anymodel/models/__init__.py +++ b/modelopt/torch/puzzletron/anymodel/models/__init__.py @@ -14,11 +14,11 @@ # limitations under the License. # Import models to trigger factory registration -from modelopt.torch.puzzletron.anymodel.models.gpt_oss_20b import * +# from modelopt.torch.puzzletron.anymodel.models.gpt_oss_20b import * from modelopt.torch.puzzletron.anymodel.models.llama import * -from modelopt.torch.puzzletron.anymodel.models.mistral_small import * -from modelopt.torch.puzzletron.anymodel.models.nemotron_h import * -from modelopt.torch.puzzletron.anymodel.models.nemotron_h_v2 import * -from modelopt.torch.puzzletron.anymodel.models.qwen2 import * -from modelopt.torch.puzzletron.anymodel.models.qwen3_8b import * -from modelopt.torch.puzzletron.anymodel.models.qwen3_vl_30b_a3b_instruct import * +# from modelopt.torch.puzzletron.anymodel.models.mistral_small import * +# from modelopt.torch.puzzletron.anymodel.models.nemotron_h import * +# from modelopt.torch.puzzletron.anymodel.models.nemotron_h_v2 import * +# from modelopt.torch.puzzletron.anymodel.models.qwen2 import * +# from modelopt.torch.puzzletron.anymodel.models.qwen3_8b import * +# from modelopt.torch.puzzletron.anymodel.models.qwen3_vl_30b_a3b_instruct import * From 030f126459c1390cc98aa12db1efec2a2c574d8f Mon Sep 17 00:00:00 2001 From: Daniel Korzekwa Date: Fri, 6 Mar 2026 00:45:34 -0800 Subject: [PATCH 14/30] format python docs Signed-off-by: Daniel Korzekwa --- .../puzzletron/anymodel/model_descriptor/model_descriptor.py | 1 + 1 file changed, 1 insertion(+) diff --git a/modelopt/torch/puzzletron/anymodel/model_descriptor/model_descriptor.py b/modelopt/torch/puzzletron/anymodel/model_descriptor/model_descriptor.py index 69af0e66c..0fd9149ec 100644 --- a/modelopt/torch/puzzletron/anymodel/model_descriptor/model_descriptor.py +++ b/modelopt/torch/puzzletron/anymodel/model_descriptor/model_descriptor.py @@ -65,6 +65,7 @@ def mlp_no_op_post_init(decoder_layer: nn.Module): Example for replacing an MLP layer with zeroes (zeroes since hidden_states are added to the residuals hidden_states so a no-op implementation will leave residual the same): + >>> decoder_layer.mlp = MatchingZeros() In case the MLP layer to replace returns multiple outputs i.e `hidden_states, _ = self.mlp()`, From 70df0df2575fe3064d9730a9f2a562a38ee7cd32 Mon Sep 17 00:00:00 2001 From: Daniel Korzekwa Date: Fri, 6 Mar 2026 01:28:38 -0800 Subject: [PATCH 15/30] Use trust_remote_code in force_cache_dynamic_modules() Signed-off-by: Daniel Korzekwa --- modelopt/torch/puzzletron/tools/checkpoint_utils_hf.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/modelopt/torch/puzzletron/tools/checkpoint_utils_hf.py b/modelopt/torch/puzzletron/tools/checkpoint_utils_hf.py index bcdab7627..0f5bba2cb 100644 --- a/modelopt/torch/puzzletron/tools/checkpoint_utils_hf.py +++ b/modelopt/torch/puzzletron/tools/checkpoint_utils_hf.py @@ -105,13 +105,15 @@ def load_checkpoint( return model -def force_cache_dynamic_modules(config: PretrainedConfig, checkpoint_dir: Path | str): +def force_cache_dynamic_modules( + config: PretrainedConfig, checkpoint_dir: Path | str, trust_remote_code: bool = False +): has_remote_code = ( hasattr(config, "auto_map") and isinstance(config.auto_map, dict) and "AutoConfig" in config.auto_map.keys() ) - if has_remote_code: + if has_remote_code and trust_remote_code: for class_reference in config.auto_map.values(): _ = get_class_from_dynamic_module(class_reference, checkpoint_dir) @@ -150,7 +152,7 @@ def load_model_config( if hasattr(config, "block_configs"): config.block_configs = maybe_cast_block_configs(config.block_configs) - force_cache_dynamic_modules(config, checkpoint_dir) + force_cache_dynamic_modules(config, checkpoint_dir, trust_remote_code=trust_remote_code) if not ignore_unexpected_config_keys: if unused_kwargs: From ecd953eccbc844000c3b2e9ba7ff9f708350e5a1 Mon Sep 17 00:00:00 2001 From: Daniel Korzekwa Date: Fri, 6 Mar 2026 01:51:54 -0800 Subject: [PATCH 16/30] Fix anymodel pruning Signed-off-by: Daniel Korzekwa --- modelopt/torch/puzzletron/puzzletron.py | 8 +- .../init_child_from_parent.py | 127 +++++++----------- 2 files changed, 54 insertions(+), 81 deletions(-) diff --git a/modelopt/torch/puzzletron/puzzletron.py b/modelopt/torch/puzzletron/puzzletron.py index 0d9ac068f..94a1de57e 100644 --- a/modelopt/torch/puzzletron/puzzletron.py +++ b/modelopt/torch/puzzletron/puzzletron.py @@ -57,10 +57,10 @@ def puzzletron( # Step 1: score_pruning_activations (distributed processing) score_pruning_activations.launch_score_activations(hydra_cfg) - # # Step 2: pruning_ckpts (single process) - # if dist.is_master(): - # pruning_ckpts.launch_prune_ckpt(hydra_cfg) - # dist.barrier() + # Step 2: pruning_ckpts (single process) + if dist.is_master(): + pruning_ckpts.launch_prune_ckpt(hydra_cfg) + dist.barrier() # # Step 4: build_library_and_stats (single process) # if dist.is_master(): diff --git a/modelopt/torch/puzzletron/tools/bypassed_training/init_child_from_parent.py b/modelopt/torch/puzzletron/tools/bypassed_training/init_child_from_parent.py index 46e403c5f..74ddb8d95 100644 --- a/modelopt/torch/puzzletron/tools/bypassed_training/init_child_from_parent.py +++ b/modelopt/torch/puzzletron/tools/bypassed_training/init_child_from_parent.py @@ -14,15 +14,22 @@ # limitations under the License. # mypy: ignore-errors -"""TODO Add description""" +"""Initialize child models from parent models using AnyModel approach with deci_x_patcher.""" import json import time +from pathlib import Path +from typing import Optional import torch import yaml +from transformers import AutoModelForCausalLM -from modelopt.torch.puzzletron.decilm.deci_lm_hf_code.modeling_decilm import DeciLMForCausalLM +from modelopt.torch.puzzletron.anymodel.model_descriptor import ( + ModelDescriptor, + ModelDescriptorFactory, +) +from modelopt.torch.puzzletron.anymodel.puzzformer import deci_x_patcher from modelopt.torch.puzzletron.tools.bypassed_training.child_init import ( GQAInitMode, HiddenSizeInitMode, @@ -31,85 +38,37 @@ create_child_state_dict, update_model_config, ) -from modelopt.torch.puzzletron.tools.checkpoint_utils import ( - copy_tokenizer, - load_model_config, - load_state_dict, -) +from modelopt.torch.puzzletron.tools.checkpoint_utils import copy_tokenizer, load_state_dict from modelopt.torch.puzzletron.tools.checkpoint_utils_hf import ( _save_checkpoint, copy_deci_lm_hf_code, + load_model_config, ) from modelopt.torch.puzzletron.tools.logger import mprint - -""" - -Usage example - remove all/some routed experts: -=============================================== - -PARENT_DIR=".../meta-llama/Llama-4-Scout-17B-16E-Instruct--deci-hf" - -MLP_INIT_MODE="ConcatExpertsIntoDenseFFN" - -## remove all routed experts, turn the shared expert into a dense FFN -# OUTPUT_DIR="/.../micro_scout/Scout-remove-routed-experts" -# MODEL_CONFIG_OVERRIDES_JSON=' -# { -# "ffn": [ -# { -# "moe": null, -# "intermediate_size": 14336, -# "gated": true, -# "hidden_act": "silu" -# } -# ] -# } -# ' - -## concat the shared expert with one routed expert into a dense FFN -OUTPUT_DIR=".../scratch/micro_scout/Scout-ConcatExpertsIntoDenseFFN-concat-shared-and-3-routed" -MODEL_CONFIG_OVERRIDES_JSON=' -{ - "ffn": [ - { - "moe": null, - "intermediate_size": 14336, - "gated": true, - "hidden_act": "silu" - } - ] -} -' - -echo "" -echo "MODEL_CONFIG_OVERRIDES_JSON:" -echo "${MODEL_CONFIG_OVERRIDES_JSON}" - -python -m modelopt.torch.puzzletron.tools.bypassed_training.init_child_from_parent \ - --parent_checkpoint_dir="$PARENT_DIR" \ - --model_config_overrides_json="$MODEL_CONFIG_OVERRIDES_JSON" \ - --output_checkpoint_dir="$OUTPUT_DIR" \ - --mlp_init_mode="$MLP_INIT_MODE" \ - --mlp_init_config_yaml="$MLP_INIT_CONFIG_YAML" -""" +from modelopt.torch.puzzletron.tools.sharded_checkpoint_utils import _get_model_class_from_config def init_child_from_parent( + descriptor: ModelDescriptor, + pruning_mixin, parent_checkpoint_dir: str, - model_config_overrides_json: str, + model_config_overrides_dict: dict, output_checkpoint_dir: str, gqa_init_mode: GQAInitMode, mlp_init_mode: MlpInitMode, - mlp_init_config_yaml: str | None, + mlp_init_config_yaml: Optional[str], linear_init_mode: LinearInitMode, - hidden_size_init_mode: HiddenSizeInitMode | None = None, - channel_importance_path: str | None = None, - max_workers: int | None = None, # Auto-calculate optimal workers if None - max_layer_workers: int | None = None, # Auto-calculate optimal workers if None + hidden_size_init_mode: Optional[HiddenSizeInitMode] = None, + channel_importance_path: Optional[str] = None, + max_workers: Optional[int] = None, # Auto-calculate optimal workers if None + max_layer_workers: Optional[int] = None, # Auto-calculate optimal workers if None ) -> None: - """Init child models from parent models in the style of bypass training, + """ + Init child models from parent models in the style of bypass training, but without having to run the entire bypass pipeline. + Uses AnyModel approach with deci_x_patcher for heterogeneous layer configurations. + I/O Optimization Parameters: - max_workers: Number of threads for parallel file I/O (default: auto-calculate min(CPU count, num files)) - max_layer_workers: Number of threads for parallel layer processing (default: auto-calculate min(CPU count, num layers)) @@ -123,16 +82,16 @@ def init_child_from_parent( "We do not support random init of any subblock in this script to avoid initializing the student model" ) + descriptor = ModelDescriptorFactory.get(descriptor) + copy_tokenizer(parent_checkpoint_dir, output_checkpoint_dir) parent_model_config = load_model_config(parent_checkpoint_dir) parent_state_dict = load_state_dict(parent_checkpoint_dir) - # Parse the model config overrides - if isinstance(model_config_overrides_json, str): - model_config_overrides_dict = json.loads(model_config_overrides_json) - else: - model_config_overrides_dict = model_config_overrides_json + # Parse JSON if string + if isinstance(model_config_overrides_dict, str): + model_config_overrides_dict = json.loads(model_config_overrides_dict) # Separate global config overrides from block-level overrides global_config_overrides = {} @@ -146,7 +105,7 @@ def init_child_from_parent( # Load child model config with global overrides child_model_config = load_model_config( - checkpoint_dir=parent_checkpoint_dir, + parent_checkpoint_dir, model_config_overrides=global_config_overrides, ignore_unexpected_config_keys=True, ) @@ -159,12 +118,23 @@ def init_child_from_parent( ) with torch.device("meta"): - child_model = DeciLMForCausalLM(child_model_config) + # Pass block_configs explicitly so patcher works for VL models where + # decoder layers receive nested config (e.g., text_config) without block_configs + with deci_x_patcher( + model_descriptor=descriptor, block_configs=child_model_config.block_configs + ): + model_class = _get_model_class_from_config(child_model_config) + # AutoModelForCausalLM uses from_config(); concrete model classes use _from_config() + if model_class is AutoModelForCausalLM: + child_model = model_class.from_config(child_model_config, trust_remote_code=True) + else: + child_model = model_class._from_config(child_model_config) + child_state_dict_with_meta_tensors = child_model.state_dict() mlp_init_config = ( yaml.safe_load(mlp_init_config_yaml) - if isinstance(mlp_init_config_yaml, str) is None + if isinstance(mlp_init_config_yaml, str) else mlp_init_config_yaml ) @@ -172,6 +142,8 @@ def init_child_from_parent( mprint("Starting create_child_state_dict...") start_time = time.time() child_state_dict = create_child_state_dict( + pruning_mixin=pruning_mixin, + descriptor=descriptor, original_state_dict=parent_state_dict, new_state_dict=child_state_dict_with_meta_tensors, original_config=parent_model_config, @@ -182,7 +154,7 @@ def init_child_from_parent( linear_init_mode=linear_init_mode, hidden_size_init_mode=hidden_size_init_mode or HiddenSizeInitMode.CopyAsIs, channel_importance_path=channel_importance_path, - max_layer_workers=max_layer_workers, # Will auto-calculate if None + max_layer_workers=max_layer_workers, ) create_child_state_dict_time = time.time() - start_time mprint(f"create_child_state_dict completed in {create_child_state_dict_time:.2f} seconds") @@ -196,7 +168,8 @@ def init_child_from_parent( child_model_config, child_state_dict, output_checkpoint_dir, - max_workers=max_workers, # Will auto-calculate if None + descriptor, + max_workers=max_workers, ) save_checkpoint_time = time.time() - start_time mprint(f"_save_checkpoint completed in {save_checkpoint_time:.2f} seconds") @@ -207,7 +180,7 @@ def init_child_from_parent( total_core_time = create_child_state_dict_time + save_checkpoint_time actual_layer_workers = max_layer_workers if max_layer_workers else "auto" actual_io_workers = max_workers if max_workers else "auto" - mprint("\n=== PROFILING SUMMARY ===") + mprint(f"\n=== PROFILING SUMMARY ===") mprint( f"create_child_state_dict: {create_child_state_dict_time:.2f}s ({create_child_state_dict_time / total_core_time * 100:.1f}%)" ) @@ -216,4 +189,4 @@ def init_child_from_parent( ) mprint(f"Total core processing: {total_core_time:.2f}s") mprint(f"Optimizations: I/O workers={actual_io_workers}, Layer workers={actual_layer_workers}") - mprint("=========================\n") + mprint(f"=========================\n") From ee8f538e31c92444efeff2b370a08b06b9e73b4b Mon Sep 17 00:00:00 2001 From: Daniel Korzekwa Date: Fri, 6 Mar 2026 03:41:19 -0800 Subject: [PATCH 17/30] Fix buid docs issue. Signed-off-by: Daniel Korzekwa --- .../puzzletron/anymodel/model_descriptor/model_descriptor.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/modelopt/torch/puzzletron/anymodel/model_descriptor/model_descriptor.py b/modelopt/torch/puzzletron/anymodel/model_descriptor/model_descriptor.py index 0fd9149ec..73d56d201 100644 --- a/modelopt/torch/puzzletron/anymodel/model_descriptor/model_descriptor.py +++ b/modelopt/torch/puzzletron/anymodel/model_descriptor/model_descriptor.py @@ -61,6 +61,7 @@ def mlp_no_op_post_init(decoder_layer: nn.Module): counterparts. Example for replacing a layernorm layer with identity: + >>> decoder_layer.post_attention_layernorm = Same() Example for replacing an MLP layer with zeroes (zeroes since hidden_states are added to @@ -70,6 +71,7 @@ def mlp_no_op_post_init(decoder_layer: nn.Module): In case the MLP layer to replace returns multiple outputs i.e `hidden_states, _ = self.mlp()`, use the util method `return_tuple_of_size` to return trailing None values: + >>> decoder_layer.mlp = return_tuple_of_size(MatchingZeros, size=2)() """ raise NotImplementedError @@ -82,13 +84,16 @@ def attn_no_op_post_init(decoder_layer: nn.Module): counterparts. Example for replacing a layernorm layer with identity: + >>> decoder_layer.post_attention_layernorm = Same() Example for replacing an attention layer with zeroes: + >>> decoder_layer.self_attn = MatchingZeros() In case the attention layer returns multiple outputs i.e `hidden_states, _ = self.self_attn()`, use the util method `return_tuple_of_size` to return trailing None values: + >>> decoder_layer.self_attn = return_tuple_of_size(MatchingZeros, size=2)() """ raise NotImplementedError From 0ad6d924bedb36038d9a0f7635b8007344b01600 Mon Sep 17 00:00:00 2001 From: Daniel Korzekwa Date: Fri, 6 Mar 2026 05:35:06 -0800 Subject: [PATCH 18/30] Merging build_library_and_stats Signed-off-by: Daniel Korzekwa --- .../puzzletron/build_library_and_stats.py | 9 +++- modelopt/torch/puzzletron/puzzletron.py | 8 ++-- .../build_replacement_library.py | 33 +++++++++++--- .../calc_subblock_params_and_memory.py | 4 +- .../subblock_stats/calc_subblock_stats.py | 45 ++++++++++++++----- modelopt/torch/puzzletron/utils/utils.py | 33 ++++++-------- 6 files changed, 87 insertions(+), 45 deletions(-) diff --git a/modelopt/torch/puzzletron/build_library_and_stats.py b/modelopt/torch/puzzletron/build_library_and_stats.py index 5f04f6049..31cebdf6b 100644 --- a/modelopt/torch/puzzletron/build_library_and_stats.py +++ b/modelopt/torch/puzzletron/build_library_and_stats.py @@ -14,7 +14,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Unified command that runs build_replacement_library followed by calc_subblock_stats. +""" +Unified command that runs build_replacement_library followed by calc_subblock_stats. This script combines the functionality of both commands into a single workflow: 1. First, it builds the replacement library for the puzzle @@ -28,17 +29,21 @@ all the same configuration parameters for both build_replacement_library and calc_subblock_stats. """ +import hydra from omegaconf import DictConfig from modelopt.torch.puzzletron.replacement_library.build_replacement_library import ( launch_build_replacement_library, ) from modelopt.torch.puzzletron.subblock_stats.calc_subblock_stats import launch_calc_subblock_stats +from modelopt.torch.puzzletron.tools.hydra_utils import register_hydra_resolvers from modelopt.torch.puzzletron.tools.logger import mprint +from modelopt.torch.puzzletron.utils.parsing import format_global_config def launch_build_library_and_stats(cfg: DictConfig) -> None: - """Launch both build_replacement_library and calc_subblock_stats in sequence. + """ + Launch both build_replacement_library and calc_subblock_stats in sequence. Args: cfg: Hydra configuration containing settings for both commands diff --git a/modelopt/torch/puzzletron/puzzletron.py b/modelopt/torch/puzzletron/puzzletron.py index 94a1de57e..87d90fdd9 100644 --- a/modelopt/torch/puzzletron/puzzletron.py +++ b/modelopt/torch/puzzletron/puzzletron.py @@ -62,10 +62,10 @@ def puzzletron( pruning_ckpts.launch_prune_ckpt(hydra_cfg) dist.barrier() - # # Step 4: build_library_and_stats (single process) - # if dist.is_master(): - # build_library_and_stats.launch_build_library_and_stats(hydra_cfg) - # dist.barrier() + # Step 4: build_library_and_stats (single process) + if dist.is_master(): + build_library_and_stats.launch_build_library_and_stats(hydra_cfg) + dist.barrier() # # Step 5: calc_one_block_scores (distributed processing) # scoring.launch_scoring(hydra_cfg) diff --git a/modelopt/torch/puzzletron/replacement_library/build_replacement_library.py b/modelopt/torch/puzzletron/replacement_library/build_replacement_library.py index 1618aceaf..aec10e03b 100644 --- a/modelopt/torch/puzzletron/replacement_library/build_replacement_library.py +++ b/modelopt/torch/puzzletron/replacement_library/build_replacement_library.py @@ -12,17 +12,33 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""This module constructs the replacement library JSON files from a puzzle directory containing +""" +This module constructs the replacement library JSON files from a puzzle directory containing multiple trained model checkpoints. It analyzes checkpoints to extract unique block and subblock configurations, builds a library of available replacements, and generates solutions for layer replacement in compressed models. The resulting replacement library can then be used by ReplacementLibrary to efficiently load models with mixed teacher/student layers. + +Standard Puzzle Usage: +====================== +python -m modelopt.torch.puzzletron.replacement_library.build_replacement_library PUZZLE_DIR + +Teacher checkpoint dir is assumed to be inside PUZZLE_DIR/ckpts/teacher (symlink is recommended) +though you can supply an explicit --teacher_checkpoint_dir. + +--add_ffn_no_ops and --add_attention_no_ops are optional (default True), + + +Untrained puzzle run (with bypass): +=================================== +The subblock that doesn't interest you in the checkpoint should be no_op. + """ # mypy: ignore-errors import json from pathlib import Path -from typing import Any +from typing import Any, Type import pandas as pd from omegaconf import DictConfig @@ -57,7 +73,8 @@ def build_replacement_library( add_ffn_no_ops: bool = True, add_attention_no_ops: bool = True, ) -> None: - """For normal puzzle runs, use default values. + """ + For normal puzzle runs, use default values. For advanced use cases, see the Usage section. """ master_puzzle_dir = Path(master_puzzle_dir) @@ -90,7 +107,9 @@ def build_replacement_library( def launch_build_replacement_library(cfg: DictConfig) -> None: - """Launch the build replacement library function with Hydra configuration.""" + """ + Launch the build replacement library function with Hydra configuration. + """ mprint(f"Building replacement library for puzzle directory: {cfg.puzzle_dir}") mprint(f"Teacher directory: {cfg.teacher_dir}") mprint( @@ -113,8 +132,8 @@ def infer_teacher_dir( teacher_checkpoint_dir = Path(master_puzzle_dir) / CHECKPOINTS_DIR_NAME / "teacher" if not teacher_checkpoint_dir.exists(): raise ValueError( - "You must either provide the --teacher_checkpoint_dir argument, or create a link to the " - "teacher dir under '{PUZZLE_DIR}/ckpts'." + f"You must either provide the --teacher_checkpoint_dir argument, or create a link to the " + f"teacher dir under '{{PUZZLE_DIR}}/ckpts'." ) teacher_checkpoint_dir = Path(teacher_checkpoint_dir).resolve().absolute() return teacher_checkpoint_dir @@ -362,7 +381,7 @@ def _add_no_op_subblock_rows( def _get_rows_with_no_op_subblock( subblocks_df: pd.DataFrame, no_op_subblock: str -) -> tuple[pd.DataFrame, type[AttentionConfig] | type[FFNConfig]]: +) -> tuple[pd.DataFrame, Type[AttentionConfig] | Type[FFNConfig]]: other_subblock = "ffn" if no_op_subblock == "attention" else "attention" subblock_cls = AttentionConfig if no_op_subblock == "attention" else FFNConfig no_op_subblock_config = subblock_cls(no_op=True) diff --git a/modelopt/torch/puzzletron/subblock_stats/calc_subblock_params_and_memory.py b/modelopt/torch/puzzletron/subblock_stats/calc_subblock_params_and_memory.py index 2e8630bc9..88081d177 100644 --- a/modelopt/torch/puzzletron/subblock_stats/calc_subblock_params_and_memory.py +++ b/modelopt/torch/puzzletron/subblock_stats/calc_subblock_params_and_memory.py @@ -189,7 +189,7 @@ def calculate_attention_memory( ): seq_len = min(seq_len, attention_chunk_size) - kv_dim = calculate_kv_dim(attention_config.n_heads_in_group, n_head, n_embd) + kv_dim = calculate_kv_dim(attention_config.num_key_value_heads, n_head, n_embd) total_num_tokens = seq_len * (batch_size + prefill_queue_size) kv_cache_size = total_num_tokens * kv_dim query_prefill_size = seq_len * n_embd if allocate_prefill_query else 0 @@ -208,7 +208,7 @@ def calculate_attention_params( n_embd: int, n_head: int, ) -> int: - kv_dim = calculate_kv_dim(attention_config.n_heads_in_group, n_head, n_embd) + kv_dim = calculate_kv_dim(attention_config.num_key_value_heads, n_head, n_embd) return ( n_embd * n_embd * 2 # Wq + Wo + n_embd * kv_dim # Wk + Wv diff --git a/modelopt/torch/puzzletron/subblock_stats/calc_subblock_stats.py b/modelopt/torch/puzzletron/subblock_stats/calc_subblock_stats.py index 07597eb5c..2db0bc391 100644 --- a/modelopt/torch/puzzletron/subblock_stats/calc_subblock_stats.py +++ b/modelopt/torch/puzzletron/subblock_stats/calc_subblock_stats.py @@ -19,11 +19,10 @@ import dataclasses import json import os -from collections.abc import Iterable from functools import partial from itertools import product from pathlib import Path -from typing import TypeVar +from typing import Iterable, Optional, Type, TypeVar os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" @@ -33,6 +32,10 @@ from omegaconf import DictConfig, ListConfig, OmegaConf from tqdm import tqdm +from modelopt.torch.puzzletron.anymodel.model_descriptor import ( + ModelDescriptor, + ModelDescriptorFactory, +) from modelopt.torch.puzzletron.decilm.deci_lm_hf_code.block_config import ( AttentionConfig, BlockConfig, @@ -56,6 +59,15 @@ # Type variable for dataclasses T_DataClass = TypeVar("T_DataClass") +""" +Usage: +python -m modelopt.torch.puzzletron.subblock_stats.calc_subblock_stats PUZZLE_DIR [ --benchmark_iterations 1000 ] + +--benchmark_iterations=None (the default) means that the code won't use infery to benchmark runtime, + only memory stats will be calculated. If you want to benchmark runtime, run inside an infery-llm docker. + +""" + def calculate_subblock_stats( calc_subblock_stats_config: DictConfig, @@ -69,7 +81,7 @@ def calculate_subblock_stats( n_embd: int, n_head: int, vocab_size: int, - benchmark_iterations: int | None, + benchmark_iterations: Optional[int], use_cuda_graph: bool, weights_dtype: torch.dtype, activations_dtype: torch.dtype, @@ -181,6 +193,7 @@ def calculate_subblock_stats( ) if is_calc_runtime: + pass # TODO: fix # from puzzle_tools.calc_subblock_runtime import measure_non_block_runtime_ms # non_block_runtime_ms, embedding_runtime_ms, lm_head_runtime_ms = \ @@ -206,17 +219,21 @@ def calculate_subblock_stats( def launch_calc_subblock_stats(cfg: DictConfig) -> None: - """Launch the calc subblock stats function with Hydra configuration.""" + """ + Launch the calc subblock stats function with Hydra configuration. + """ mprint(f"Calculating subblock stats for puzzle directory: {cfg.puzzle_dir}") mprint(f"Teacher directory: {cfg.teacher_dir}") mprint( f"Calc subblock stats config: {format_global_config(cfg.calc_subblock_stats, title='Calc subblock stats')}" ) + descriptor = ModelDescriptorFactory.get(cfg.descriptor) calculate_subblock_stats_for_puzzle_dir( cfg.calc_subblock_stats, master_puzzle_dir=cfg.puzzle_dir, teacher_dir=cfg.teacher_dir, + descriptor=descriptor, model_hidden_sizes=cfg.calc_subblock_stats.get("model_hidden_sizes", OmegaConf.create([])), ffn_hidden_sizes=cfg.calc_subblock_stats.get("ffn_hidden_sizes", OmegaConf.create([])), batch_sizes=cfg.calc_subblock_stats.batch_sizes, @@ -224,7 +241,7 @@ def launch_calc_subblock_stats(cfg: DictConfig) -> None: generation_seq_len=cfg.calc_subblock_stats.generation_seq_len, num_active_tokens_override=cfg.calc_subblock_stats.get("num_active_tokens_override", None), prefill_queue_size=cfg.calc_subblock_stats.prefill_queue_size, - allocate_prefill_query=cfg.calc_subblock_stats.allocate_prefill_query, + allocate_prefill_query=cfg.calc_subblock_stats.get("allocate_prefill_query", False), benchmark_iterations=cfg.calc_subblock_stats.get("benchmark_iterations", None), merge_with_existing_stats=cfg.calc_subblock_stats.merge_with_existing_stats, subblock_stats_filename=cfg.calc_subblock_stats.subblock_stats_filename, @@ -236,6 +253,7 @@ def calculate_subblock_stats_for_puzzle_dir( calc_subblock_stats_config: DictConfig, master_puzzle_dir: Path | str, teacher_dir: Path | str, + descriptor: Type[ModelDescriptor], model_hidden_sizes: ListConfig, ffn_hidden_sizes: ListConfig, batch_sizes: Iterable[int] = (1, 8, 16, 32, 64, 128, 256), @@ -268,6 +286,8 @@ def calculate_subblock_stats_for_puzzle_dir( Path(teacher_dir) if teacher_dir is not None else master_puzzle_dir / "ckpts" / "teacher" ) model_config = load_model_config(teacher_dir) + # Get language model config for LM-specific attributes (VL models have nested config) + lm_config = descriptor.get_language_model_config(model_config) subblock_configs = _load_subblock_configs(master_puzzle_dir, ffn_hidden_sizes, model_config) subblock_stats_file = master_puzzle_dir / subblock_stats_filename @@ -299,7 +319,7 @@ def calculate_subblock_stats_for_puzzle_dir( ] model_hidden_sizes = model_hidden_sizes + [ - model_config.hidden_size + lm_config.hidden_size ] # add a teacher model hidden size for batch_size, ( weights_dtype, @@ -323,8 +343,8 @@ def calculate_subblock_stats_for_puzzle_dir( generation_seq_len=generation_seq_len, prefill_queue_size=prefill_queue_size, n_embd=model_hidden_size, - n_head=model_config.num_attention_heads, - vocab_size=model_config.vocab_size, + n_head=lm_config.num_attention_heads, + vocab_size=lm_config.vocab_size, benchmark_iterations=curr_benchmark_iterations, use_cuda_graph=True, weights_dtype=weights_dtype, @@ -445,7 +465,7 @@ def _load_subblock_configs_from_replacement_library( return subblock_configs -T_DataClass: TypeVar = type[dataclasses.dataclass] +T_DataClass: TypeVar = Type[dataclasses.dataclass] def _dataclass_from_dict( @@ -483,7 +503,7 @@ def add_int8_runtime_estimates(subblock_stats: list[dict]) -> None: if (subblock_config := curr_subblock.get("subblock_config")) is not None: if hasattr(subblock_config, "__dataclass_fields__"): subblock_config = dataclasses.asdict(subblock_config) - is_attention = subblock_config.get("n_heads_in_group", None) is not None + is_attention = subblock_config.get("num_key_value_heads", None) is not None runtime_factor = attention_factor if is_attention else ffn_factor for stat_name, stat_value in bf16_subblock.items(): if "runtime" in stat_name: @@ -512,7 +532,10 @@ def _find_corresponding_bf16_stats(args: dict, subblock_stats: list[dict]) -> di stats for stats in subblock_stats if all( - [stats["args"][key] == corresponding_bf16_args[key] for key in corresponding_bf16_args] + [ + stats["args"][key] == corresponding_bf16_args[key] + for key in corresponding_bf16_args.keys() + ] ) ] if len(matching_bf16_stats) == 0: diff --git a/modelopt/torch/puzzletron/utils/utils.py b/modelopt/torch/puzzletron/utils/utils.py index d56aab0bd..77a13609a 100644 --- a/modelopt/torch/puzzletron/utils/utils.py +++ b/modelopt/torch/puzzletron/utils/utils.py @@ -28,24 +28,21 @@ ) -def calculate_kv_dim(n_heads_in_group: int, n_head: int, n_embd: int) -> int: +def calculate_kv_dim(num_key_value_heads: int, n_head: int, n_embd: int) -> int: """Calculate the key-value dimension for grouped-query attention. - TODO: Consider a better place for this function. - Args: - n_heads_in_group: Number of attention heads per key-value group. + num_key_value_heads: Number of key-value heads. n_head: Total number of attention heads. n_embd: Embedding dimension. Returns: - Combined dimension for key and value tensors (2 * n_kv_heads * head_size). + Combined dimension for key and value tensors (2 * num_key_value_heads * head_size). """ - if n_heads_in_group is None: + if num_key_value_heads is None: return 0 - n_kv_heads = n_head // n_heads_in_group head_size = n_embd // n_head - kv_dim = 2 * n_kv_heads * head_size + kv_dim = 2 * num_key_value_heads * head_size return kv_dim @@ -53,7 +50,6 @@ def raise_unknown_subblock_config_error(subblock_config: Any) -> None: """Raise an error for invalid subblock configuration types. TODO: Consider a better place for this function. - Args: subblock_config: The invalid subblock configuration object. @@ -69,7 +65,6 @@ def sizeof_dtype(dtype: torch.dtype) -> int | float: """Return the size in bytes of the given data type. TODO: Consider a better place for this function. - Args: dtype: PyTorch data type or custom type string (e.g., 'nvfp4'). @@ -125,10 +120,10 @@ def solution_to_str(block_configs: list[dict[str, Any] | BlockConfig]) -> str: def block_config_to_str(block_config: BlockConfig | dict[str, Any] | None) -> str | None: - """Convert a BlockConfig to a human-readable string representation. + """ + Convert a BlockConfig to a human-readable string representation. TODO: Consider a better place for this function. - Args: block_config: BlockConfig dataclass or dict containing attention and ffn configs. @@ -153,7 +148,6 @@ def subblock_config_to_str( """Convert a subblock config (FFN, Attention, Mamba, or MoE) to string. TODO: Consider a better place for this function. - Args: subblock_config: FFNConfig, AttentionConfig dataclass or dict. subblock_name: Name of subblock ('ffn', 'attention', 'mamba', 'moe'). @@ -161,7 +155,7 @@ def subblock_config_to_str( Returns: Formatted string showing subblock type and key parameters (e.g., intermediate_size, - n_heads_in_group), or None if input is None. + num_key_value_heads), or None if input is None. """ if subblock_config is None: return None @@ -194,8 +188,8 @@ def subblock_config_to_str( intermediate_size = subblock_config["intermediate_size"] rep += f" intermediate_{intermediate_size}".ljust(8) elif subblock_name == "attention": - n_heads_in_group = subblock_config["n_heads_in_group"] - rep += f" gqa_{n_heads_in_group}".ljust(8) + num_key_value_heads = subblock_config["num_key_value_heads"] + rep += f" kv_heads_{num_key_value_heads}".ljust(8) elif subblock_name == "mamba": mamba_num_heads = subblock_config["mamba"]["num_heads"] mamba_head_dim = subblock_config["mamba"]["head_dim"] @@ -216,7 +210,8 @@ def subblock_config_to_str( class EmptyInitOnDevice(torch.overrides.TorchFunctionMode): def __init__(self, device=None, dtype=None): - """Create tensors with given device and dtype and don't run initialization + """ + Create tensors with given device and dtype and don't run initialization (but instead use "empty tensors", i.e. uninitialized memory). device: `torch.device` to work with @@ -225,8 +220,8 @@ def __init__(self, device=None, dtype=None): Example:: with EmptyInitOnDevice("cuda", dtype=torch.bfloat16): model = LLaMA(model_config) - model.load_state_dict(torch.load("llama-lit/7B/lit-llama.pth")) - """ + model.load_state_dict(torch.load("llama-lit/7B/lit-llama.pth"))""" + self.device = device self.dtype = dtype From 995eb1a5eeb4e1eda61fd0150da569d00f2f1d12 Mon Sep 17 00:00:00 2001 From: Daniel Korzekwa Date: Fri, 6 Mar 2026 05:57:57 -0800 Subject: [PATCH 19/30] Merging anymodel: calc_one_block_scores Signed-off-by: Daniel Korzekwa --- modelopt/torch/puzzletron/puzzletron.py | 4 +- .../replacement_library.py | 103 ++++++++---- ...validate_puzzle_with_multi_replacements.py | 155 ++++++++++-------- 3 files changed, 161 insertions(+), 101 deletions(-) diff --git a/modelopt/torch/puzzletron/puzzletron.py b/modelopt/torch/puzzletron/puzzletron.py index 87d90fdd9..262df7648 100644 --- a/modelopt/torch/puzzletron/puzzletron.py +++ b/modelopt/torch/puzzletron/puzzletron.py @@ -67,8 +67,8 @@ def puzzletron( build_library_and_stats.launch_build_library_and_stats(hydra_cfg) dist.barrier() - # # Step 5: calc_one_block_scores (distributed processing) - # scoring.launch_scoring(hydra_cfg) + # Step 5: calc_one_block_scores (distributed processing) + scoring.launch_scoring(hydra_cfg) # # Step 6: mip_and_realize_models (distributed processing) # mip_and_realize_models.launch_mip_and_realize_model(hydra_cfg) diff --git a/modelopt/torch/puzzletron/replacement_library/replacement_library.py b/modelopt/torch/puzzletron/replacement_library/replacement_library.py index bf6cc6636..7935fea4a 100644 --- a/modelopt/torch/puzzletron/replacement_library/replacement_library.py +++ b/modelopt/torch/puzzletron/replacement_library/replacement_library.py @@ -12,23 +12,29 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Replacement library for efficiently loading and managing layer-replaced DeciLM models. +""" +Replacement library for efficiently loading and managing layer-replaced DeciLM models. - Uses replacement_utils for parsing, sorting, and analyzing layer replacement configurations """ # mypy: ignore-errors +import copy import json import re +import tempfile from pathlib import Path +from typing import List, Optional -import numpy as np import torch from immutabledict import immutabledict from lru import LRU +from safetensors import safe_open from safetensors.torch import load_file as safe_load_file from torch import nn +from transformers import PretrainedConfig, PreTrainedModel import modelopt.torch.utils.distributed as dist +from modelopt.torch.puzzletron.anymodel.converter.converter import Converter from modelopt.torch.puzzletron.decilm.deci_lm_hf_code.configuration_decilm import DeciLMConfig from modelopt.torch.puzzletron.decilm.deci_lm_hf_code.modeling_decilm import ( DeciLMDecoderLayer, @@ -51,9 +57,11 @@ init_module_with_state_dict, load_model_config, ) +from modelopt.torch.puzzletron.tools.checkpoint_utils_hf import save_model_config from modelopt.torch.puzzletron.tools.sharded_checkpoint_utils import ( create_dummy_model, is_in_safetensors_format, + load_and_shard_model, load_sharded_state_dict, ) @@ -62,8 +70,10 @@ class ReplacementLibrary: def __init__( self, replacement_library_path: str | Path, - model_config_overrides: dict | None = None, + descriptor, + model_config_overrides: Optional[dict] = None, ): + self.descriptor = descriptor self.replacement_library = self._load_replacement_library(replacement_library_path) self._ensure_all_checkpoints_are_split() self.model_config_overrides = ( @@ -114,42 +124,77 @@ def n_layer(self) -> int: def model_config(self) -> DeciLMConfig: if self._model_config is None: self._model_config = load_model_config( - self.get_arbitrary_checkpoint_dir(), self.model_config_overrides + self.get_arbitrary_checkpoint_dir(), + self.model_config_overrides, + ignore_unexpected_config_keys=True, ) return self._model_config def create_model_config(self, layer_replacements: list[dict]): block_configs, _ = extract_block_configs_and_locations(layer_replacements) - model_config = self.model_config.set_block_configs(block_configs) + model_config = copy.deepcopy(self.model_config) + model_config.block_configs = block_configs + model_config.num_hidden_layers = len(block_configs) return model_config - def load_model(self, layer_replacements: list[dict]) -> DeciLMForCausalLM: - block_configs, block_locations = extract_block_configs_and_locations(layer_replacements) - model_config = self.model_config.set_block_configs(block_configs) + def _get_arbitrary_block_checkpoint_paths(self): + checkpoint_dir = Path(self.get_arbitrary_checkpoint_dir()) + subblocks_dir = checkpoint_dir / SAFETENSORS_SUBBLOCKS_DIR_NAME + non_block_paths = [p for p in subblocks_dir.glob("*.safetensors") if "block_" not in p.name] + return non_block_paths + + def create_index_file_from_weights(self, weight_paths: List[str]): + weight_map = {} + for weight_path in weight_paths: + weight_path = Path(weight_path) + with safe_open(str(weight_path), framework="pt", device="cpu") as f: + for tensor_name in f.keys(): + weight_map[tensor_name] = f"{SAFETENSORS_SUBBLOCKS_DIR_NAME}/{weight_path.name}" + index = {"metadata": {"format": "pt"}, "weight_map": weight_map} + return index + + def prepare_tmp_checkpoint_dir( + self, + tmpdir: Path, + model_config: PretrainedConfig, + layer_replacements: List[dict], + ): + arbitrary_checkpoint_dir = Path(self.get_arbitrary_checkpoint_dir()) - owned_block_indexes = _get_owned_block_indexes(model_config.get_num_hidden_layers()) - model = create_dummy_model(model_config, self.dtype) + weight_paths = self._get_arbitrary_block_checkpoint_paths() + for layer_replacement in layer_replacements: + weight_paths += layer_replacement["weight_paths"] - is_first_shard = 0 in owned_block_indexes - if is_first_shard and not isinstance(model.model.get_input_embeddings(), nn.Embedding): - model.set_input_embeddings(self.get_embedding()) + weights_index = self.create_index_file_from_weights(weight_paths) + index_path = tmpdir / "model.safetensors.index.json" + with index_path.open("w", encoding="utf-8") as out: + json.dump(weights_index, out, indent=2, sort_keys=True) - is_last_shard = model_config.get_num_hidden_layers() - 1 in owned_block_indexes - if is_last_shard and not isinstance(model.model.get_output_embeddings(), nn.Linear): - model.model.set_final_layer_norm(self.get_ln_f()) - model.set_output_embeddings(self.get_lm_head()) + Converter.copy_checkpoint_files(arbitrary_checkpoint_dir, tmpdir) + save_model_config(model_config, tmpdir) - active_blocks = [] - for block_idx in owned_block_indexes: - layer_replacement, block_idx_in_replacement = block_locations[block_idx] - block = self.get_block(layer_replacement, block_idx_in_replacement) - model.model.layers[block_idx] = block - active_blocks.append(block) + # create symlinks inside tmpdir + subblocks_dir = tmpdir / SAFETENSORS_SUBBLOCKS_DIR_NAME + subblocks_dir.mkdir(exist_ok=True) + for weight_path in weight_paths: + link_path = subblocks_dir / weight_path.name + link_path.symlink_to(weight_path) - self._move_inactive_blocks_to_cpu(active_blocks) + def load_model( + self, + layer_replacements: list[dict], + ) -> PreTrainedModel: + """Load model using AnyModel approach with temporary checkpoint directory.""" + model_config = self.create_model_config(layer_replacements) + with tempfile.TemporaryDirectory(prefix="replacement_solution_") as tmpdir: + tmpdir = Path(tmpdir) + self.prepare_tmp_checkpoint_dir( + tmpdir, model_config=model_config, layer_replacements=layer_replacements + ) + model = load_and_shard_model(descriptor=self.descriptor, checkpoint_path=tmpdir) return model - def load_checkpoint(self, checkpoint_dir: str | Path) -> DeciLMForCausalLM: + def load_checkpoint(self, checkpoint_dir: str | Path) -> PreTrainedModel: checkpoint_dir = Path(checkpoint_dir).resolve() layer_replacements = self._locate_replacements_of_entire_checkpoint(checkpoint_dir) model = self.load_model(layer_replacements) @@ -221,7 +266,7 @@ def _load_layer_replacement(self, layer_replacement: dict) -> nn.ModuleList: if len(state_dict) > 0: block_indices = [ int(re.findall(r"^model\.layers\.(\d+)\.", param_name)[0]) - for param_name in state_dict + for param_name in state_dict.keys() ] assert sorted(set(block_indices)) == list( range(min(block_indices), max(block_indices) + 1) @@ -239,7 +284,9 @@ def _load_layer_replacement(self, layer_replacement: dict) -> nn.ModuleList: } dtype = infer_weights_dtype(state_dict) - model_config = self.model_config.set_block_configs(layer_replacement["child_block_configs"]) + model_config = copy.deepcopy(self.model_config) + model_config.block_configs = layer_replacement["child_block_configs"] + model_config.num_hidden_layers = len(layer_replacement["child_block_configs"]) module_list = nn.ModuleList( [ @@ -316,7 +363,7 @@ def _get_arbitrary_non_block_param(self, param_name: str) -> torch.Tensor: partial_state_dict = load_sharded_state_dict(checkpoint_dir, [param_name]) return partial_state_dict[param_name] - non_block_pth_path = checkpoint_dir / PTH_SUBBLOCKS_DIR_NAME / "non_block.pth" + non_block_pth_path = checkpoint_dir / PTH_SUBBLOCKS_DIR_NAME / f"non_block.pth" assert non_block_pth_path.exists(), _error_message_ensure_split(checkpoint_dir) non_block_state_dict = torch.load(non_block_pth_path) return non_block_state_dict[param_name] diff --git a/modelopt/torch/puzzletron/tools/validate_puzzle_with_multi_replacements.py b/modelopt/torch/puzzletron/tools/validate_puzzle_with_multi_replacements.py index 4e3266df4..7311e35e5 100644 --- a/modelopt/torch/puzzletron/tools/validate_puzzle_with_multi_replacements.py +++ b/modelopt/torch/puzzletron/tools/validate_puzzle_with_multi_replacements.py @@ -21,9 +21,11 @@ # mypy: ignore-errors import json +import shutil import warnings from functools import partial from pathlib import Path +from typing import Optional import torch from omegaconf import DictConfig @@ -31,6 +33,8 @@ from transformers import AutoTokenizer, PreTrainedTokenizerBase import modelopt.torch.utils.distributed as dist +from modelopt.torch.puzzletron.anymodel.converter import Converter +from modelopt.torch.puzzletron.anymodel.model_descriptor import ModelDescriptorFactory from modelopt.torch.puzzletron.decilm.deci_lm_hf_code.configuration_decilm import DeciLMConfig from modelopt.torch.puzzletron.replacement_library.replacement_library import ReplacementLibrary from modelopt.torch.puzzletron.replacement_library.replacement_utils import parse_layer_replacement @@ -40,15 +44,15 @@ copy_tokenizer, ) from modelopt.torch.puzzletron.tools.checkpoint_utils_hf import ( - copy_deci_lm_hf_code, save_checkpoint, save_safetensors_index, ) +from modelopt.torch.puzzletron.tools.sharded_checkpoint_utils import load_and_shard_model from modelopt.torch.puzzletron.tools.validation_utils import ( validate_model_and_extract_hidden_states, validate_model_with_teacher_similarity_metrics, ) -from modelopt.torch.puzzletron.utils.parsing import get_nested_key +from modelopt.torch.puzzletron.utils.parsing import get_nested_key, parse_path from modelopt.torch.puzzletron.utils.validate_runtime_pipeline import perform_pipeline_stitches """ @@ -68,62 +72,57 @@ def validate_puzzle_solutions(args: DictConfig) -> None: Args: args: Configuration object containing the following attributes: - Puzzle Configuration (Required) attributes: - - - ``replacement_library_path`` (Path): Path to the replacement library JSON file. - - ``solutions_path`` (Path): Path to puzzle solutions JSON file or directory containing solution files. - - ``solutions_to_validate`` (list[int], optional): Indices of specific solutions to validate. - Validates all solutions if None. - - ``sort_solutions_by`` (str, optional): JSON field path to sort solutions by before validation. - - ``bigger_is_better`` (bool): If True, sort solutions in descending order. Used with sort_solutions_by. - - ``skip_validation`` (bool): If True, skip model validation and only save models if requested. - - ``save_models`` (bool): If True, save realized model checkpoints for each solution. - - Teacher/Tokenizer Configuration attributes: - - - ``teacher_dir`` (Path, optional): Path to teacher model directory. Auto-inferred if not provided. - - ``tokenizer_name`` (str, optional): Tokenizer name/path. Uses teacher_dir if not specified. - - Model Configuration (Required if skip_validation=False) attributes: - - - ``model_dtype`` (str or torch.dtype): Model data type (e.g., "torch.bfloat16", torch.float16). - - ``autocast_dtype`` (str or torch.dtype): Autocast data type for mixed precision. - - Dataset Configuration (Required if skip_validation=False) attributes: - - - ``dataset_path`` (str): Path to the validation dataset. - - ``data_column`` (str): Column name in dataset containing text data. - - ``block_size`` (int): Maximum sequence length for tokenization. - - ``eval_samples`` (int, optional): Number of samples to evaluate. - - ``val_dataset_name`` (str): Name of validation dataset split. - - ``source_datasets_to_discard`` (list[str], optional): List of source datasets to exclude. - - ``load_dataset_fn`` (callable, optional): Custom function to load the dataset. - - Data Processing (Required if skip_validation=False) attributes: - - - ``micro_batch_size`` (int): Batch size for evaluation. - - ``seed`` (int): Random seed for reproducibility. - - ``shuffle_seed`` (int, optional): Seed for shuffling data. - - ``varlen`` (bool): Enable variable-length sequences. - - ``bos_rate`` (float): Rate of adding BOS token. - - ``fim_rate`` (float): Fill-in-the-middle rate for code completion tasks. - - ``fim_spm_rate`` (float): SPM-based fill-in-the-middle rate. - - Output Configuration attributes: - - - ``output_dir`` (Path, optional): Directory to save validation results. - Auto-generated from solutions_path if not provided. - - Execution Options (Optional if skip_validation=False) attributes: - - - ``calc_losses_on_cpu`` (bool): Calculate losses on CPU to avoid OOM. - - ``write_results`` (bool): Write validation results to file. - - ``activations_log_dir`` (str, optional): Directory to log activation scores. - - ``activation_hooks_kwargs`` (str or dict, optional): Arguments for activation hooks. + Puzzle Configuration (Required): + - replacement_library_path (Path): Path to the replacement library JSON file. + - solutions_path (Path): Path to puzzle solutions JSON file or directory containing solution files. + - solutions_to_validate (list[int], optional): Indices of specific solutions to validate. + Validates all solutions if None. + - sort_solutions_by (str, optional): JSON field path to sort solutions by before validation. + - bigger_is_better (bool): If True, sort solutions in descending order. Used with sort_solutions_by. + - skip_validation (bool): If True, skip model validation and only save models if requested. + - save_models (bool): If True, save realized model checkpoints for each solution. + + Teacher/Tokenizer Configuration: + - teacher_dir (Path, optional): Path to teacher model directory. Auto-inferred if not provided. + - tokenizer_name (str, optional): Tokenizer name/path. Uses teacher_dir if not specified. + + Model Configuration (Required if skip_validation=False): + - model_dtype (str or torch.dtype): Model data type (e.g., "torch.bfloat16", torch.float16). + - autocast_dtype (str or torch.dtype): Autocast data type for mixed precision. + + Dataset Configuration (Required if skip_validation=False): + - dataset_path (str): Path to the validation dataset. + - data_column (str): Column name in dataset containing text data. + - block_size (int): Maximum sequence length for tokenization. + - eval_samples (int, optional): Number of samples to evaluate. + - val_dataset_name (str): Name of validation dataset split. + - source_datasets_to_discard (list[str], optional): List of source datasets to exclude. + - load_dataset_fn (callable, optional): Custom function to load the dataset. + + Data Processing (Required if skip_validation=False): + - micro_batch_size (int): Batch size for evaluation. + - seed (int): Random seed for reproducibility. + - shuffle_seed (int, optional): Seed for shuffling data. + - varlen (bool): Enable variable-length sequences. + - bos_rate (float): Rate of adding BOS token. + - fim_rate (float): Fill-in-the-middle rate for code completion tasks. + - fim_spm_rate (float): SPM-based fill-in-the-middle rate. + + Output Configuration: + - output_dir (Path, optional): Directory to save validation results. + Auto-generated from solutions_path if not provided. + + Execution Options (Optional if skip_validation=False): + - calc_losses_on_cpu (bool): Calculate losses on CPU to avoid OOM. + - write_results (bool): Write validation results to file. + - activations_log_dir (str, optional): Directory to log activation scores. + - activation_hooks_kwargs (str or dict, optional): Arguments for activation hooks. Returns: None. Saves validation results and optionally model checkpoints to disk. """ + descriptor = ModelDescriptorFactory.get(args.descriptor) + puzzle_solutions = load_puzzle_solutions( args.solutions_path, args.sort_solutions_by, args.bigger_is_better ) @@ -143,29 +142,41 @@ def validate_puzzle_solutions(args: DictConfig) -> None: else args.solutions_path.with_name(f"{args.solutions_path.stem}--validation") ) - replacement_library = ReplacementLibrary(args.replacement_library_path) + replacement_library = ReplacementLibrary( + args.replacement_library_path, + descriptor=descriptor, + model_config_overrides={"use_cache": False}, + ) teacher_hidden_states = None if (args.teacher_dir is not None) and (not args.skip_validation): - teacher_model = replacement_library.load_checkpoint(args.teacher_dir) + teacher_model = load_and_shard_model( + checkpoint_path=args.teacher_dir, descriptor=descriptor + ) teacher_model.cuda(dist.local_rank()) - stitched_model = perform_pipeline_stitches(teacher_model) + stitched_model = perform_pipeline_stitches(teacher_model, descriptor=descriptor) teacher_hidden_states = validate_model_and_extract_hidden_states( args, stitched_model, tokenizer, output_dir, model_name="teacher", - pipeline_parallel=True, val_dataloader=val_dataloader, ) + # Properly release CUDA memory after teacher validation + teacher_model.cpu() + stitched_model.cpu() + torch.cuda.empty_cache() + torch.cuda.synchronize() + dist.barrier() + for i_solution, puzzle_solution in tqdm( list(zip(args.solutions_to_validate, puzzle_solutions)), desc="Validating solutions" ): layer_replacements = _extract_layer_replacements_from_puzzle_solution(puzzle_solution) - # realizable_as_symlinks = can_realize_as_symlinks(layer_replacements) - realizable_as_symlinks = False + realizable_as_symlinks = can_realize_as_symlinks(layer_replacements) + # realizable_as_symlinks = False model_config = replacement_library.create_model_config(layer_replacements) if (args.save_models and not realizable_as_symlinks) or (not args.skip_validation): model = replacement_library.load_model(layer_replacements) @@ -177,24 +188,21 @@ def validate_puzzle_solutions(args: DictConfig) -> None: / f"solution_{i_solution}" ) - model_config.dtype = args.model_dtype - model_config.architectures = ["DeciLMForCausalLM"] + model_config.dtype = getattr(args, "model_dtype", "torch.bfloat16") + Converter.copy_checkpoint_files(args.teacher_dir, checkpoint_dir) if realizable_as_symlinks: if dist.is_master(): - save_checkpoint_as_symlinks( - layer_replacements, model_config, checkpoint_dir, replacement_library - ) - else: - save_checkpoint(model, checkpoint_dir) + # save_checkpoint_as_symlinks is currently not supported + pass + save_checkpoint(model, checkpoint_dir, descriptor) copy_tokenizer(args.tokenizer_name, checkpoint_dir) - copy_deci_lm_hf_code(checkpoint_dir) dist.barrier() if not args.skip_validation: model.cuda(dist.local_rank()) - stitched_model = perform_pipeline_stitches(model) + stitched_model = perform_pipeline_stitches(model, descriptor=descriptor) validate_model_with_teacher_similarity_metrics( args, stitched_model, @@ -203,10 +211,15 @@ def validate_puzzle_solutions(args: DictConfig) -> None: output_dir, model_name=f"solution_{i_solution}", extra_payload={"i_solution": i_solution, "puzzle_solution": puzzle_solution}, - pipeline_parallel=True, val_dataloader=val_dataloader, ) + # Properly release CUDA memory after solution validation + model.cpu() + stitched_model.cpu() + torch.cuda.empty_cache() + torch.cuda.synchronize() + dist.barrier() @@ -278,7 +291,7 @@ def _extract_layer_replacements_from_puzzle_solution( def load_puzzle_solutions( solutions_path: Path, - sort_solutions_by: str | None, + sort_solutions_by: Optional[str], bigger_is_better: bool, ) -> list[dict]: assert solutions_path.exists(), f"{solutions_path=} does not exist" From 34081c9efc7c7b914b90f41306e71c06bfa145e7 Mon Sep 17 00:00:00 2001 From: Daniel Korzekwa Date: Fri, 6 Mar 2026 05:58:25 -0800 Subject: [PATCH 20/30] Mering any_model: calc_one_block_scores Signed-off-by: Daniel Korzekwa --- modelopt/torch/puzzletron/tools/validation_utils.py | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/modelopt/torch/puzzletron/tools/validation_utils.py b/modelopt/torch/puzzletron/tools/validation_utils.py index 697977cda..d7197e8ab 100644 --- a/modelopt/torch/puzzletron/tools/validation_utils.py +++ b/modelopt/torch/puzzletron/tools/validation_utils.py @@ -21,7 +21,7 @@ # mypy: ignore-errors from pathlib import Path -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, Optional, Union import torch from omegaconf import DictConfig, OmegaConf @@ -44,8 +44,7 @@ def validate_model_and_extract_hidden_states( tokenizer: PreTrainedTokenizerBase, output_dir: str | Path, model_name: str, - extra_payload: dict[str, Any] | None = None, - pipeline_parallel: bool = False, + extra_payload: Optional[dict[str, Any]] = None, val_dataloader=None, ) -> list[torch.Tensor | LowMemorySparseTensor]: mprint(f""" @@ -60,7 +59,6 @@ def validate_model_and_extract_hidden_states( model, tokenizer, return_hidden_states=True, - pipeline_parallel=pipeline_parallel, val_dataloader=val_dataloader, ) if dist.is_last_process(): @@ -77,8 +75,7 @@ def validate_model_with_teacher_similarity_metrics( target_hidden_states_per_batch: list[torch.Tensor], output_dir: str | Path, model_name: str, - extra_payload: dict[str, Any] | None = None, - pipeline_parallel: bool = False, + extra_payload: Optional[dict[str, Any]] = None, calculate_full_score_ablations: bool = False, val_dataloader=None, ) -> None: @@ -95,7 +92,6 @@ def validate_model_with_teacher_similarity_metrics( model, tokenizer, target_hidden_states_per_batch=target_hidden_states_per_batch, - pipeline_parallel=pipeline_parallel, calculate_full_score_ablations=calculate_full_score_ablations, val_dataloader=val_dataloader, ) From 47414d50c38ecc9d165f22624edb83230cdf1b87 Mon Sep 17 00:00:00 2001 From: Daniel Korzekwa Date: Mon, 9 Mar 2026 09:22:00 -0700 Subject: [PATCH 21/30] Clarify readme and avoid reusing the same reference in llama_converter. Signed-off-by: Daniel Korzekwa --- modelopt/torch/puzzletron/anymodel/README.md | 4 ++-- .../torch/puzzletron/anymodel/models/llama/llama_converter.py | 3 ++- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/modelopt/torch/puzzletron/anymodel/README.md b/modelopt/torch/puzzletron/anymodel/README.md index a8b960165..85393deec 100644 --- a/modelopt/torch/puzzletron/anymodel/README.md +++ b/modelopt/torch/puzzletron/anymodel/README.md @@ -1,6 +1,6 @@ # AnyModel Guide -This guide explains how to add support for new models in the compress pipeline. +This guide explains how to add support for new models in the Puzzletron pipeline. ## Convert model @@ -96,7 +96,7 @@ Update pruning YAML files (`ffn_pruning.yaml`, `expert_pruning.yaml`, etc.): ## End-to-end example -See [test_compress_model.py](../../../../tests/gpu/torch/puzzletron/test_compress.py) for a complete example that runs both convert and compression steps. +See [test_puzzletron.py](../../../../tests/gpu/torch/puzzletron/test_puzzletron.py) for a complete example that runs both convert and compression steps. --- diff --git a/modelopt/torch/puzzletron/anymodel/models/llama/llama_converter.py b/modelopt/torch/puzzletron/anymodel/models/llama/llama_converter.py index 1f8cf77b5..5d3f47e03 100644 --- a/modelopt/torch/puzzletron/anymodel/models/llama/llama_converter.py +++ b/modelopt/torch/puzzletron/anymodel/models/llama/llama_converter.py @@ -16,6 +16,7 @@ """Llama converter for AnyModel compression.""" +import copy from typing import List from transformers import LlamaConfig @@ -46,5 +47,5 @@ def create_block_configs_from_main_config(config: LlamaConfig) -> List[BlockConf ffn=FFNConfig(no_op=False, intermediate_size=config.intermediate_size), ).to_dict() - block_configs = [block_config] * num_hidden_layers + block_configs = [copy.deepcopy(block_config) for _ in range(num_hidden_layers)] return block_configs From a8305d8a295a8d6556de75de0710137ed832c39c Mon Sep 17 00:00:00 2001 From: Daniel Korzekwa Date: Mon, 9 Mar 2026 09:36:42 -0700 Subject: [PATCH 22/30] Fix tied-embedding handling before writing the safetensors index. Signed-off-by: Daniel Korzekwa --- .../torch/puzzletron/tools/checkpoint_utils_hf.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/modelopt/torch/puzzletron/tools/checkpoint_utils_hf.py b/modelopt/torch/puzzletron/tools/checkpoint_utils_hf.py index 0f5bba2cb..3c3b54830 100644 --- a/modelopt/torch/puzzletron/tools/checkpoint_utils_hf.py +++ b/modelopt/torch/puzzletron/tools/checkpoint_utils_hf.py @@ -199,18 +199,18 @@ def _save_checkpoint( } weight_map.update(weight_map_entries) - # Write index + # Handle tie_word_embeddings - remove from state_dict and weight_map BEFORE writing index + output_emb_weight_name = f"{descriptor.output_embedding_name()}.weight" + if getattr(model_config, "tie_word_embeddings", False) and output_emb_weight_name in state_dict: + state_dict = {k: v for k, v in state_dict.items() if k != output_emb_weight_name} + weight_map = {k: v for k, v in weight_map.items() if k != output_emb_weight_name} + + # Write index (now without tied embedding) index = {"metadata": {"format": "pt"}, "weight_map": weight_map} index_path = checkpoint_dir / SAFE_WEIGHTS_INDEX_NAME index_json = json_dumps(index) _write_file_process_safe(index_json, index_path) - # Handle tie_word_embeddings - don't save lm_head.weight if it's tied to embed_tokens - if getattr(model_config, "tie_word_embeddings", False) and "lm_head.weight" in state_dict: - lm_head_weight_name = f"{descriptor.output_embedding_name()}.weight" - state_dict = {k: v for k, v in state_dict.items() if k != lm_head_weight_name} - weight_map = {k: v for k, v in weight_map.items() if k != lm_head_weight_name} - # Phase 3: Save subblocks save_subblocks( state_dict, From 68421a5766903d27d1f80994c4ac8d3e84cf084a Mon Sep 17 00:00:00 2001 From: Daniel Korzekwa Date: Mon, 9 Mar 2026 09:51:19 -0700 Subject: [PATCH 23/30] =?UTF-8?q?Fix=20NaN=20ranking=20currently=20selects?= =?UTF-8?q?=20NaNs=20as=20=E2=80=9Cbest=E2=80=9D=20experts=20by=20default.?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Daniel Korzekwa --- modelopt/torch/puzzletron/pruning/pruning_utils.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/modelopt/torch/puzzletron/pruning/pruning_utils.py b/modelopt/torch/puzzletron/pruning/pruning_utils.py index cdd6a2bf7..82ba675c9 100644 --- a/modelopt/torch/puzzletron/pruning/pruning_utils.py +++ b/modelopt/torch/puzzletron/pruning/pruning_utils.py @@ -596,10 +596,15 @@ def _select_expert_indices( ) -> list[int]: expert_scores = _load_expert_scores(mlp_init_config, layer_idx) assert len(expert_scores) == orig_num_experts + higher_is_better = mlp_init_config.get("higher_is_better", True) selected_experts = sorted( range(orig_num_experts), - key=lambda i: expert_scores[i] if not math.isnan(expert_scores[i]) else float("inf"), - reverse=mlp_init_config.get("higher_is_better", True), + key=lambda i: ( + expert_scores[i] + if not math.isnan(expert_scores[i]) + else (float("-inf") if higher_is_better else float("inf")) + ), + reverse=higher_is_better, )[:new_num_experts] return selected_experts From d6b8028f6fb27010133278eef28566c5fa5c85d8 Mon Sep 17 00:00:00 2001 From: Daniel Korzekwa Date: Mon, 9 Mar 2026 11:11:05 -0700 Subject: [PATCH 24/30] Code clean up. Signed-off-by: Daniel Korzekwa --- .../model_descriptor/model_descriptor_factory.py | 4 +--- .../anymodel/models/llama/llama_converter.py | 16 +++++++++------- .../puzzletron/anymodel/puzzformer/__init__.py | 6 ++++++ 3 files changed, 16 insertions(+), 10 deletions(-) diff --git a/modelopt/torch/puzzletron/anymodel/model_descriptor/model_descriptor_factory.py b/modelopt/torch/puzzletron/anymodel/model_descriptor/model_descriptor_factory.py index 45fe83f47..badbe2b0e 100644 --- a/modelopt/torch/puzzletron/anymodel/model_descriptor/model_descriptor_factory.py +++ b/modelopt/torch/puzzletron/anymodel/model_descriptor/model_descriptor_factory.py @@ -36,7 +36,7 @@ } -def resolve_descriptor_from_pretrained(pretrained: str | None, trust_remote_code: bool = False): +def resolve_descriptor_from_pretrained(pretrained: str, trust_remote_code: bool = False): """Resolve the model descriptor by loading the checkpoint config and mapping model_type. Args: @@ -51,8 +51,6 @@ def resolve_descriptor_from_pretrained(pretrained: str | None, trust_remote_code Raises: ValueError: If pretrained is not provided or if the model type cannot be auto-detected. """ - if not pretrained: - raise ValueError("pretrained must be provided") config = AutoConfig.from_pretrained(pretrained, trust_remote_code=trust_remote_code) model_type = getattr(config, "model_type", None) diff --git a/modelopt/torch/puzzletron/anymodel/models/llama/llama_converter.py b/modelopt/torch/puzzletron/anymodel/models/llama/llama_converter.py index 5d3f47e03..5a0686ecc 100644 --- a/modelopt/torch/puzzletron/anymodel/models/llama/llama_converter.py +++ b/modelopt/torch/puzzletron/anymodel/models/llama/llama_converter.py @@ -16,7 +16,6 @@ """Llama converter for AnyModel compression.""" -import copy from typing import List from transformers import LlamaConfig @@ -42,10 +41,13 @@ def create_block_configs_from_main_config(config: LlamaConfig) -> List[BlockConf """ num_hidden_layers = config.num_hidden_layers - block_config = BlockConfig( - attention=AttentionConfig(no_op=False, num_key_value_heads=config.num_key_value_heads), - ffn=FFNConfig(no_op=False, intermediate_size=config.intermediate_size), - ).to_dict() - - block_configs = [copy.deepcopy(block_config) for _ in range(num_hidden_layers)] + block_configs = [ + BlockConfig( + attention=AttentionConfig( + no_op=False, num_key_value_heads=config.num_key_value_heads + ), + ffn=FFNConfig(no_op=False, intermediate_size=config.intermediate_size), + ).to_dict() + for _ in range(num_hidden_layers) + ] return block_configs diff --git a/modelopt/torch/puzzletron/anymodel/puzzformer/__init__.py b/modelopt/torch/puzzletron/anymodel/puzzformer/__init__.py index aac6f0f20..3af98d57f 100644 --- a/modelopt/torch/puzzletron/anymodel/puzzformer/__init__.py +++ b/modelopt/torch/puzzletron/anymodel/puzzformer/__init__.py @@ -13,6 +13,12 @@ # See the License for the specific language governing permissions and # limitations under the License. +"""Utilities for patching and transforming HuggingFace models to work with AnyModel. + +Provides no-op modules for layer replacement and patching utilities for heterogeneous +per-layer configurations. +""" + from modelopt.torch.puzzletron.anymodel.puzzformer.no_op import ( MatchingZeros, Same, From ecd2341ce7d95b4a7162fa64c9cd26b25a0116d4 Mon Sep 17 00:00:00 2001 From: Daniel Korzekwa Date: Tue, 10 Mar 2026 01:14:15 -0700 Subject: [PATCH 25/30] Code clean up. Signed-off-by: Daniel Korzekwa --- modelopt/torch/puzzletron/anymodel/README.md | 2 +- .../torch/puzzletron/nas/plugins/puzzletron_nas_plugin.py | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/modelopt/torch/puzzletron/anymodel/README.md b/modelopt/torch/puzzletron/anymodel/README.md index 85393deec..9dea9d45f 100644 --- a/modelopt/torch/puzzletron/anymodel/README.md +++ b/modelopt/torch/puzzletron/anymodel/README.md @@ -46,7 +46,7 @@ from models. import * ## Usage ```python -from scripts.convert_any_model import convert_model +from modelopt.torch.puzzletron.anymodel import convert_model convert_model( input_dir="path/to/hf_checkpoint", diff --git a/modelopt/torch/puzzletron/nas/plugins/puzzletron_nas_plugin.py b/modelopt/torch/puzzletron/nas/plugins/puzzletron_nas_plugin.py index bd11837d7..e5025dea7 100644 --- a/modelopt/torch/puzzletron/nas/plugins/puzzletron_nas_plugin.py +++ b/modelopt/torch/puzzletron/nas/plugins/puzzletron_nas_plugin.py @@ -50,7 +50,7 @@ class PuzzletronModel(nn.Module): - pass # No model implementation is needed for the compress mode + pass # No model implementation is needed for the puzzletron mode class PuzzletronConfig(ModeloptBaseConfig): @@ -154,7 +154,7 @@ def convert_puzzletron_model(model: nn.Module, config: PuzzletronConfig) -> Conv def restore_puzzletron_model( model: nn.Module, config: PuzzletronConfig, metadata: MetadataDict ) -> nn.Module: - """Restore is not needed for the compress mode as we are not saving any model state""" + """Restore is not needed for the puzzletron mode as we are not saving any model state""" return model @@ -192,7 +192,7 @@ def restore(self) -> RestoreEntrypoint: def export_mode(self) -> str | None: """The mode that corresponds to the export mode. For now, this will be a no-op as there is no modelopt's concept of search space defined - for the compress algorithm. + for the puzzletron algorithm. """ return "export_nas" @@ -202,7 +202,7 @@ class PuzzletronSearcher(BaseSearcher): @property def default_state_dict(self) -> SearchStateDict: - """Not needed for the compress mode as we are not saving any model state""" + """Not needed for the puzzletron mode as we are not saving any model state""" return {} def run_search(self) -> None: From f9d845d4954edf85c439038d4103d5ee8ff5fee0 Mon Sep 17 00:00:00 2001 From: Daniel Korzekwa Date: Tue, 10 Mar 2026 01:19:09 -0700 Subject: [PATCH 26/30] code clean up Signed-off-by: Daniel Korzekwa --- tests/_test_utils/torch/puzzletron/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/_test_utils/torch/puzzletron/utils.py b/tests/_test_utils/torch/puzzletron/utils.py index 4779ee1f3..07d1565f4 100644 --- a/tests/_test_utils/torch/puzzletron/utils.py +++ b/tests/_test_utils/torch/puzzletron/utils.py @@ -39,7 +39,7 @@ def setup_test_model_and_data( hybrid_override_pattern: str | None = None, ) -> tuple[Path, Path, Path]: """ - Setup the test model and data for the compress NAS search. + Setup the test model and data for the puzzletron NAS search. Args: project_root_path (Path): the root path of the project From 934ab2fc1d4ff4b53cb08fda54c8b57fba831d60 Mon Sep 17 00:00:00 2001 From: Daniel Korzekwa Date: Tue, 10 Mar 2026 04:44:39 -0700 Subject: [PATCH 27/30] code clean up Signed-off-by: Daniel Korzekwa --- .../tools/bypassed_training/init_child_from_parent.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modelopt/torch/puzzletron/tools/bypassed_training/init_child_from_parent.py b/modelopt/torch/puzzletron/tools/bypassed_training/init_child_from_parent.py index 74ddb8d95..36e41c4b6 100644 --- a/modelopt/torch/puzzletron/tools/bypassed_training/init_child_from_parent.py +++ b/modelopt/torch/puzzletron/tools/bypassed_training/init_child_from_parent.py @@ -52,7 +52,7 @@ def init_child_from_parent( descriptor: ModelDescriptor, pruning_mixin, parent_checkpoint_dir: str, - model_config_overrides_dict: dict, + model_config_overrides_dict: dict | str, output_checkpoint_dir: str, gqa_init_mode: GQAInitMode, mlp_init_mode: MlpInitMode, From dcb9e02ddbbb9ab32cd21bdf8ac9a071ffddb211 Mon Sep 17 00:00:00 2001 From: Daniel Korzekwa Date: Tue, 10 Mar 2026 04:53:04 -0700 Subject: [PATCH 28/30] remove not needed comment Signed-off-by: Daniel Korzekwa --- .../replacement_library/build_replacement_library.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/modelopt/torch/puzzletron/replacement_library/build_replacement_library.py b/modelopt/torch/puzzletron/replacement_library/build_replacement_library.py index aec10e03b..0f5ecd215 100644 --- a/modelopt/torch/puzzletron/replacement_library/build_replacement_library.py +++ b/modelopt/torch/puzzletron/replacement_library/build_replacement_library.py @@ -29,10 +29,6 @@ --add_ffn_no_ops and --add_attention_no_ops are optional (default True), -Untrained puzzle run (with bypass): -=================================== -The subblock that doesn't interest you in the checkpoint should be no_op. - """ # mypy: ignore-errors From 176a4358fe993ecd10ffa6f8041d0de7df1ba22d Mon Sep 17 00:00:00 2001 From: Daniel Korzekwa Date: Tue, 10 Mar 2026 09:01:31 -0700 Subject: [PATCH 29/30] Fix a broken test_puzzletron test on 2 gpus. Signed-off-by: Daniel Korzekwa --- modelopt/torch/puzzletron/sewing_kit/utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/modelopt/torch/puzzletron/sewing_kit/utils.py b/modelopt/torch/puzzletron/sewing_kit/utils.py index 25ee8c9ea..19c1bd6c8 100644 --- a/modelopt/torch/puzzletron/sewing_kit/utils.py +++ b/modelopt/torch/puzzletron/sewing_kit/utils.py @@ -291,6 +291,7 @@ def create(cls, data: Tensor) -> MyFakeTensor: def fake_tensor(*args, **kwargs) -> Tensor: dtype: Optional[torch.dtype] = kwargs.get("dtype") use_meta = kwargs.get("use_meta", False) + device = kwargs.get("device", "meta") if len(args) == 1 and isinstance(args[0], Tensor): if use_meta: @@ -298,7 +299,7 @@ def fake_tensor(*args, **kwargs) -> Tensor: else: fake_tensor = MyFakeTensor.create(args[0]) else: - fake_tensor = torch.empty(*args, dtype=dtype, device="meta") + fake_tensor = torch.empty(*args, dtype=dtype, device=device) if not use_meta: fake_tensor = MyFakeTensor.create(fake_tensor) From 58a42cab8fc11188706acc5af80335e2a5f3a8fc Mon Sep 17 00:00:00 2001 From: Daniel Korzekwa Date: Thu, 12 Mar 2026 15:47:03 -0700 Subject: [PATCH 30/30] Fix tox -e build-docs issues Signed-off-by: Daniel Korzekwa --- modelopt/torch/puzzletron/tools/validate_model.py | 1 + .../tools/validate_puzzle_with_multi_replacements.py | 6 ++---- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/modelopt/torch/puzzletron/tools/validate_model.py b/modelopt/torch/puzzletron/tools/validate_model.py index 8461c6a5c..4a300fcd0 100644 --- a/modelopt/torch/puzzletron/tools/validate_model.py +++ b/modelopt/torch/puzzletron/tools/validate_model.py @@ -128,6 +128,7 @@ def validate_model( A tuple containing: - losses: Dictionary mapping loss names to loss statistics (avg, per_sample). - hidden_states_per_batch: Hidden states and LM head outputs if return_hidden_states is True, else None. + Returns (None, None) if not on master rank. """ descriptor = ModelDescriptorFactory.get(args.descriptor) diff --git a/modelopt/torch/puzzletron/tools/validate_puzzle_with_multi_replacements.py b/modelopt/torch/puzzletron/tools/validate_puzzle_with_multi_replacements.py index 7311e35e5..d253c9445 100644 --- a/modelopt/torch/puzzletron/tools/validate_puzzle_with_multi_replacements.py +++ b/modelopt/torch/puzzletron/tools/validate_puzzle_with_multi_replacements.py @@ -75,8 +75,7 @@ def validate_puzzle_solutions(args: DictConfig) -> None: Puzzle Configuration (Required): - replacement_library_path (Path): Path to the replacement library JSON file. - solutions_path (Path): Path to puzzle solutions JSON file or directory containing solution files. - - solutions_to_validate (list[int], optional): Indices of specific solutions to validate. - Validates all solutions if None. + - solutions_to_validate (list[int], optional): Indices of specific solutions to validate. Validates all solutions if None. - sort_solutions_by (str, optional): JSON field path to sort solutions by before validation. - bigger_is_better (bool): If True, sort solutions in descending order. Used with sort_solutions_by. - skip_validation (bool): If True, skip model validation and only save models if requested. @@ -109,8 +108,7 @@ def validate_puzzle_solutions(args: DictConfig) -> None: - fim_spm_rate (float): SPM-based fill-in-the-middle rate. Output Configuration: - - output_dir (Path, optional): Directory to save validation results. - Auto-generated from solutions_path if not provided. + - output_dir (Path, optional): Directory to save validation results. Auto-generated from solutions_path if not provided. Execution Options (Optional if skip_validation=False): - calc_losses_on_cpu (bool): Calculate losses on CPU to avoid OOM.