From 13378ff6620b7e125aa0b2457411fc9fa0aab208 Mon Sep 17 00:00:00 2001 From: Daniel Korzekwa Date: Wed, 11 Mar 2026 01:08:21 -0700 Subject: [PATCH 1/3] Add gpt-oss model Signed-off-by: Daniel Korzekwa --- .../puzzletron/anymodel/models/__init__.py | 2 +- .../anymodel/models/gpt_oss/__init__.py | 22 + .../models/gpt_oss/gpt_oss_converter.py | 74 +++ .../gpt_oss/gpt_oss_model_descriptor.py | 236 ++++++++ .../models/gpt_oss/gpt_oss_pruned_to_mxfp4.py | 549 ++++++++++++++++++ .../nas/plugins/test_nas_convert.py | 1 - .../puzzletron/nas/plugins/test_nas_search.py | 2 - .../configs/gpt-oss-20b/gpt-oss-20b.yaml | 108 ++++ .../gpt-oss-20b/pruning/expert_removal.yaml | 20 + .../gpt-oss-20b/pruning/pruning_defaults.yaml | 35 ++ .../gpt-oss-20b/validate_model_defaults.yaml | 15 + .../validate_solutions_defaults.yaml | 10 + .../hf_configs/gpt-oss-20b/config.json | 76 +++ 13 files changed, 1146 insertions(+), 4 deletions(-) create mode 100644 modelopt/torch/puzzletron/anymodel/models/gpt_oss/__init__.py create mode 100644 modelopt/torch/puzzletron/anymodel/models/gpt_oss/gpt_oss_converter.py create mode 100644 modelopt/torch/puzzletron/anymodel/models/gpt_oss/gpt_oss_model_descriptor.py create mode 100644 modelopt/torch/puzzletron/anymodel/models/gpt_oss/gpt_oss_pruned_to_mxfp4.py create mode 100644 tests/gpu/torch/puzzletron/resources/configs/gpt-oss-20b/gpt-oss-20b.yaml create mode 100644 tests/gpu/torch/puzzletron/resources/configs/gpt-oss-20b/pruning/expert_removal.yaml create mode 100644 tests/gpu/torch/puzzletron/resources/configs/gpt-oss-20b/pruning/pruning_defaults.yaml create mode 100644 tests/gpu/torch/puzzletron/resources/configs/gpt-oss-20b/validate_model_defaults.yaml create mode 100644 tests/gpu/torch/puzzletron/resources/configs/gpt-oss-20b/validate_solutions_defaults.yaml create mode 100644 tests/gpu/torch/puzzletron/resources/hf_configs/gpt-oss-20b/config.json diff --git a/modelopt/torch/puzzletron/anymodel/models/__init__.py b/modelopt/torch/puzzletron/anymodel/models/__init__.py index 1f3fb477b..34d7ce5e5 100644 --- a/modelopt/torch/puzzletron/anymodel/models/__init__.py +++ b/modelopt/torch/puzzletron/anymodel/models/__init__.py @@ -14,7 +14,7 @@ # limitations under the License. # Import models to trigger factory registration -# from modelopt.torch.puzzletron.anymodel.models.gpt_oss_20b import * +from modelopt.torch.puzzletron.anymodel.models.gpt_oss import * from modelopt.torch.puzzletron.anymodel.models.llama import * from modelopt.torch.puzzletron.anymodel.models.mistral_small import * from modelopt.torch.puzzletron.anymodel.models.nemotron_h import * diff --git a/modelopt/torch/puzzletron/anymodel/models/gpt_oss/__init__.py b/modelopt/torch/puzzletron/anymodel/models/gpt_oss/__init__.py new file mode 100644 index 000000000..9f72b8dd7 --- /dev/null +++ b/modelopt/torch/puzzletron/anymodel/models/gpt_oss/__init__.py @@ -0,0 +1,22 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""GPT-OSS model support for AnyModel.""" + +from .gpt_oss_converter import GptOssConverter +from .gpt_oss_model_descriptor import GptOssModelDescriptor diff --git a/modelopt/torch/puzzletron/anymodel/models/gpt_oss/gpt_oss_converter.py b/modelopt/torch/puzzletron/anymodel/models/gpt_oss/gpt_oss_converter.py new file mode 100644 index 000000000..3e7371aae --- /dev/null +++ b/modelopt/torch/puzzletron/anymodel/models/gpt_oss/gpt_oss_converter.py @@ -0,0 +1,74 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# mypy: ignore-errors + +"""GPT-OSS-20B converter for AnyModel compression.""" + +from typing import List + +from transformers import PretrainedConfig + +from modelopt.torch.puzzletron.anymodel.converter import Converter, ConverterFactory +from modelopt.torch.puzzletron.decilm.deci_lm_hf_code.block_config import ( + AttentionConfig, + BlockConfig, + FFNConfig, + MoEConfig, +) + + +@ConverterFactory.register_decorator("gpt_oss") +class GptOssConverter(Converter): + """Converter for GPT-OSS models to AnyModel format. + + GPT-OSS is a pure MoE model with 32/128 experts per layer and 4/16 active experts. + All layers use MoE FFN (no standard dense FFN layers). + """ + + quantized = "mxfp4" + + @staticmethod + def create_block_configs_from_main_config(config: PretrainedConfig) -> List[BlockConfig]: + """Create block configs for GPT-OSS layers. + + GPT-OSS uses MoE for all FFN layers with: + - 32/128 local experts (num_local_experts) + - 4/16 active experts per token (experts_per_token) + - No dense/standard FFN layers + """ + num_hidden_layers = config.num_hidden_layers + num_local_experts = config.num_local_experts + experts_per_token = config.experts_per_token + intermediate_size = config.intermediate_size + + block_configs = [] + for layer_idx in range(num_hidden_layers): + block_config = BlockConfig( + attention=AttentionConfig( + no_op=False, num_key_value_heads=config.num_key_value_heads + ), + ffn=FFNConfig( + no_op=False, + intermediate_size=None, # MoE doesn't use this field + moe=MoEConfig( + num_local_experts=num_local_experts, + num_experts_per_tok=experts_per_token, + expert_intermediate_dim=intermediate_size, + ), + ), + ).to_dict() + block_configs.append(block_config) + + return block_configs diff --git a/modelopt/torch/puzzletron/anymodel/models/gpt_oss/gpt_oss_model_descriptor.py b/modelopt/torch/puzzletron/anymodel/models/gpt_oss/gpt_oss_model_descriptor.py new file mode 100644 index 000000000..c77a4547f --- /dev/null +++ b/modelopt/torch/puzzletron/anymodel/models/gpt_oss/gpt_oss_model_descriptor.py @@ -0,0 +1,236 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# mypy: ignore-errors + +"""GPT-OSS model descriptor for AnyModel compression.""" + +import re +from dataclasses import dataclass, field +from typing import Dict, List, Tuple, Type + +import torch.nn as nn +from transformers.models.gpt_oss.modeling_gpt_oss import GptOssDecoderLayer, GptOssRotaryEmbedding + +from modelopt.torch.puzzletron.anymodel.model_descriptor import ( + ModelDescriptor, + ModelDescriptorFactory, +) +from modelopt.torch.puzzletron.anymodel.puzzformer.no_op import ( + MatchingZeros, + Same, + return_tuple_of_size, +) +from modelopt.torch.puzzletron.decilm.deci_lm_hf_code.block_config import BlockConfig +from modelopt.torch.puzzletron.pruning.expert_removal_pruning_mixin import ( + ExpertRemovalLayerDescriptor, + ExpertRemovalPruningMixIn, +) + +# Expert removal is supported for unquantized models (test models). +# Production models use MXFP4 quantized MoE with combined tensors +# (gate_up_proj_blocks, down_proj_blocks), which is not yet supported. +from modelopt.torch.puzzletron.pruning.pruning_mixin import PruningMixIn +from modelopt.torch.puzzletron.utils.dummy_modules import DummyBlock + + +@ModelDescriptorFactory.register_decorator("gpt_oss") +class GptOssModelDescriptor(ModelDescriptor): + """Model descriptor for GPT-OSS (pure MoE model).""" + + _DECODER_LAYER_CLS: Type[nn.Module] = None + + @classmethod + def create_dummy_block(cls, original_layer: GptOssDecoderLayer, block_index: int) -> nn.Module: + dummy_block = DummyBlock(block_index=block_index) + # Required by `GptOssModel.forward`. + dummy_block.attention_type = original_layer.attention_type + return dummy_block + + @staticmethod + def decoder_layer_cls(): + """Get the decoder layer class for GPT-OSS models. + + GPT-OSS is a standard transformers model in recent versions. + Import directly from transformers.models.gpt_oss.modeling_gpt_oss. + """ + return GptOssDecoderLayer + + @staticmethod + def block_config_to_layer_overrides(block_config: BlockConfig): + """Map BlockConfig to layer constructor overrides.""" + override_kwargs = {} + + if block_config.attention.num_key_value_heads is not None: + override_kwargs["num_key_value_heads"] = block_config.attention.num_key_value_heads + + if block_config.ffn.moe is not None: + override_kwargs["moe_intermediate_size"] = block_config.ffn.moe.expert_intermediate_dim + override_kwargs["num_local_experts"] = block_config.ffn.moe.num_local_experts + override_kwargs["num_experts_per_tok"] = block_config.ffn.moe.num_experts_per_tok + + return override_kwargs + + @staticmethod + def attn_no_op_post_init(decoder_layer): + """Replace attention sublayers with no-op modules.""" + decoder_layer.input_layernorm = Same() + decoder_layer.self_attn = return_tuple_of_size(MatchingZeros, size=2)() + + @staticmethod + def mlp_no_op_post_init(decoder_layer): + """Replace MLP sublayers with no-op modules. + + Note: GPT-OSS MoE layers return (hidden_states, router_scores), so we need + to return a tuple of 2 values. + """ + decoder_layer.post_attention_layernorm = Same() + decoder_layer.mlp = return_tuple_of_size(MatchingZeros, size=2)() + + @staticmethod + def init_rotary_embedding(model, runtime): + """Initialize rotary embeddings on the correct device.""" + # GPT-OSS uses RoPE with YARN scaling + + model.model.rotary_emb = GptOssRotaryEmbedding( + config=model.config, + device=runtime.device, + ) + + @staticmethod + def input_embedding_name(): + return "model.embed_tokens" + + @staticmethod + def output_embedding_name(): + return "lm_head" + + @staticmethod + def final_norm_name(): + return "model.norm" + + @staticmethod + def layer_block_name(index: int): + return f"model.layers.{index}" + + @staticmethod + def layer_name_predicates(num_layers: int) -> Dict[str, re.Pattern]: + """Define regex patterns for grouping weights into subblocks.""" + layer_name_patterns = { + "embeddings": re.compile(r"^model\.embed_tokens\.weight$"), + "lm_head": re.compile(r"^(model\.norm\.weight|lm_head\.weight)$"), + } + + def build_ffn_predicates() -> Dict[str, re.Pattern]: + """FFN is MoE in GPT-OSS with MXFP4 quantization.""" + return { + f"block_{layer_idx}_ffn": re.compile( + rf"^model\.layers\.{layer_idx}\." + r"(post_attention_layernorm\.weight" + r"|mlp\.router\.weight" + r"|mlp\.router\.bias" + r"|mlp\.experts\.(gate_up_proj|down_proj)(_(bias|blocks|scales))?)$" + ) + for layer_idx in range(num_layers) + } + + def build_attention_predicates() -> Dict[str, re.Pattern]: + return { + f"block_{layer_idx}_attention": re.compile( + rf"^model\.layers\.{layer_idx}\." + r"(input_layernorm\.weight" + r"|self_attn\.q_proj\.weight" + r"|self_attn\.q_proj\.bias" + r"|self_attn\.k_proj\.weight" + r"|self_attn\.k_proj\.bias" + r"|self_attn\.v_proj\.weight" + r"|self_attn\.v_proj\.bias" + r"|self_attn\.o_proj\.weight" + r"|self_attn\.o_proj\.bias" + r"|self_attn\.sinks)$" + ) + for layer_idx in range(num_layers) + } + + layer_name_patterns.update( + **build_ffn_predicates(), + **build_attention_predicates(), + ) + + return layer_name_patterns + + @staticmethod + def pruning_mixins() -> Dict[str, PruningMixIn]: + """Return available pruning mixins for GPT-OSS. + + Note: Expert removal works for unquantized models (test models). + Production models use MXFP4 quantization which is not yet supported. + """ + return {"expert_removal": ExpertRemovalPruningMixIn(GptOssExpertRemovalLayerDescriptor())} + + +@dataclass +class GptOssExpertRemovalLayerDescriptor(ExpertRemovalLayerDescriptor): + """ + GPT-OSS MoE layer descriptor for expert removal. + + Note: This only works for unquantized models (e.g., test models). + Production GPT-OSS models use MXFP4 quantization with fused experts + (_blocks, _scales, _bias), which requires a different approach. + + Structure: + - Router: mlp.router with .weight and .bias + - Experts: mlp.experts.{idx}.{gate_up_proj,down_proj} with .weight and .bias + """ + + target_name: str = "mlp" + moe_prefix_name: str = "model.layers.{layer_idx}.mlp" + expert_prefix_name: str = "experts" + + # Router has both weight and bias + router_weights: List[str] = field(default_factory=lambda: ["router.weight"]) + router_biases: List[str] = field(default_factory=lambda: ["router.bias"]) + + # Fused format: experts stored as single tensors + is_fused_experts: bool = True + + # Fused format: single tensors containing all experts (test models) + fused_expert_weights: List[str] = field( + default_factory=lambda: [ + "experts.gate_up_proj", + "experts.gate_up_proj_bias", + "experts.down_proj", + "experts.down_proj_bias", + ] + ) + + # Not used for fused format, but kept for compatibility + expert_weights: List[str] = field(default_factory=lambda: ["gate_up_proj", "down_proj"]) + expert_biases: List[str] = field( + default_factory=lambda: ["gate_up_proj_bias", "down_proj_bias"] + ) + + def get_modules_names_to_hook(self, model) -> List[Tuple[int, str]]: + target_class_name = "GptOssTopKRouter" + + module_names_to_hook = [] + for module_name, module in model.named_modules(): + if ( + module_name.endswith(self.target_name) + and module.__class__.__name__ == target_class_name + ): + module_names_to_hook.append( + (self.block_idx_from_module_name(module_name), module_name) + ) + return module_names_to_hook diff --git a/modelopt/torch/puzzletron/anymodel/models/gpt_oss/gpt_oss_pruned_to_mxfp4.py b/modelopt/torch/puzzletron/anymodel/models/gpt_oss/gpt_oss_pruned_to_mxfp4.py new file mode 100644 index 000000000..4bab0b985 --- /dev/null +++ b/modelopt/torch/puzzletron/anymodel/models/gpt_oss/gpt_oss_pruned_to_mxfp4.py @@ -0,0 +1,549 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Adapted from https://github.com/EleutherAI/lm-evaluation-harness/tree/aa457edc3d64d81530159cd3a182932320c78f8c + +# MIT License +# +# Copyright (c) 2020 EleutherAI +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + + +""" +Create a HuggingFace checkpoint with MXFP4 MoE weights from the original gpt-oss-120b model. + +This script: +1. Copies non-MoE weights from the student model (trained attention, embeddings, etc.) +2. Extracts MoE expert weights from the original gpt-oss-120b in MXFP4 format +3. Deduces expert mappings by comparing weights +4. Outputs a new pruned (heterogeneous) checkpoint with PACKED MXFP4 expert weights +""" + +import argparse +import json +import os +import shutil +from typing import Any, Dict, List, Optional, TextIO, Tuple + +import torch +from safetensors import safe_open +from safetensors.torch import save_file +from tqdm import tqdm +from transformers.integrations.mxfp4 import convert_moe_packed_tensors + + +def deduce_experts_for_layer( + layer: int, + original_path: str, + original_index: Dict, + student_path: str, +) -> Tuple[List[int], int, int]: + """ + Deduce which original experts match the student experts by comparing weights. + + Compares dequantized MXFP4 weights from the original model against the student + model's BF16 weights using L2 distance. Finds the best 1-to-1 matching. + + Args: + layer: Layer index + original_path: Path to original model + original_index: Original model's safetensors index + student_path: Path to student model + num_student_experts: Number of experts in student model (if None, auto-detect) + + Returns: + Tuple of (expert_indices, num_student_experts, num_original_experts) + """ + # Load original tensors + orig_tensors = load_layer_tensors(original_path, layer, original_index) + mlp1_blocks = orig_tensors[f"model.layers.{layer}.mlp.experts.gate_up_proj_blocks"] + mlp1_scales = orig_tensors[f"model.layers.{layer}.mlp.experts.gate_up_proj_scales"] + mlp2_blocks = orig_tensors[f"model.layers.{layer}.mlp.experts.down_proj_blocks"] + mlp2_scales = orig_tensors[f"model.layers.{layer}.mlp.experts.down_proj_scales"] + + num_original_experts = mlp1_blocks.shape[0] + + # Load student tensors + student_subblocks = os.path.join(student_path, "subblocks_safetensors") + student_ffn = os.path.join(student_subblocks, f"block_{layer}_ffn.safetensors") + if not os.path.exists(student_ffn): + print(f"FFN file not found at {student_ffn} - fallback to no_op") + return [], 0, num_original_experts + + student_experts = {} + with safe_open(student_ffn, framework="pt") as f: + for key in f.keys(): + if "experts" in key or "router" in key: + student_experts[key] = f.get_tensor(key) + + # Auto-detect number of student experts + num_student_experts = student_experts[f"model.layers.{layer}.mlp.experts.gate_up_proj"].size(0) + print( + f" Layer {layer}: Comparing {num_student_experts} student experts against {num_original_experts} original experts" + ) + + # Pre-dequantize all original experts once (optimization) + print(f" Pre-dequantizing {num_original_experts} original experts...") + deqexpert_mlp1 = convert_moe_packed_tensors(mlp1_blocks, mlp1_scales).cpu() + deqexpert_mlp2 = convert_moe_packed_tensors(mlp2_blocks, mlp2_scales).cpu() + original_experts_dequant = [] + for orig_idx in range(num_original_experts): + original_experts_dequant.append( + {"up": deqexpert_mlp1[orig_idx], "down": deqexpert_mlp2[orig_idx]} + ) + + # For each student expert, find best matching original expert + experts_to_keep = [] + used_original_indices = set() + + # Number of values to use for quick comparison (tune this) + quick_compare_size = 8 + # Number of candidates to keep for full comparison + top_k_candidates = min(10, num_original_experts) + + for student_idx in range(num_student_experts): + # Get student expert weights + prefix = f"model.layers.{layer}.mlp" + student_up = student_experts.get(f"{prefix}.experts.gate_up_proj")[student_idx] # type: ignore[index] + student_down = student_experts.get(f"{prefix}.experts.down_proj")[student_idx] # type: ignore[index] + + # if student_gate is None or student_up is None or student_down is None: + if student_up is None or student_down is None: + raise ValueError( + f"Missing student expert weights for layer {layer} expert {student_idx}" + ) + + # Step 1: Quick filtering using first N values + candidate_scores = [] + for orig_idx in range(num_original_experts): + if orig_idx in used_original_indices: + continue + + orig_expert = original_experts_dequant[orig_idx] + + up_quick = ( + ( + orig_expert["up"].flatten()[:quick_compare_size] + - student_up.float().flatten()[:quick_compare_size] + ) + .pow(2) + .mean() + .sqrt() + ) + down_quick = ( + ( + orig_expert["down"].flatten()[:quick_compare_size] + - student_down.float().flatten()[:quick_compare_size] + ) + .pow(2) + .mean() + .sqrt() + ) + + quick_score = (up_quick + down_quick) / 2.0 + candidate_scores.append((orig_idx, quick_score.item())) + + # Step 2: Get top-k candidates based on quick comparison + candidate_scores.sort(key=lambda x: x[1]) + top_candidates = [idx for idx, _ in candidate_scores[:top_k_candidates]] + + # Step 3: Full comparison only on top candidates + best_match_idx = None + best_match_score = float("inf") + + for orig_idx in top_candidates: + orig_expert = original_experts_dequant[orig_idx] + + # Full comparison across all values + up_diff = (orig_expert["up"] - student_up.float()).pow(2).mean().sqrt() + down_diff = (orig_expert["down"] - student_down.float()).pow(2).mean().sqrt() + + score = (up_diff + down_diff) / 2.0 + + if score < best_match_score: + best_match_score = score + best_match_idx = orig_idx + + if best_match_idx is None: + raise ValueError( + f"Could not find match for student expert {student_idx} in layer {layer}" + ) + + experts_to_keep.append(best_match_idx) + used_original_indices.add(best_match_idx) + print( + f" Student expert {student_idx} -> Original expert {best_match_idx} (RMSE: {best_match_score:.6f})" + ) + + return experts_to_keep, num_student_experts, num_original_experts + + +def load_original_index(path: str) -> Dict[str, Any]: + """Load the original model's safetensors index.""" + with open(path, "r") as f: + return json.load(f) + + +def load_layer_tensors(original_path: str, layer: int, index: Dict) -> Dict[str, torch.Tensor]: + """Load all MoE-related tensors for a layer, potentially from multiple files.""" + keys_to_load = [ + f"model.layers.{layer}.mlp.experts.gate_up_proj_blocks", + f"model.layers.{layer}.mlp.experts.gate_up_proj_scales", + f"model.layers.{layer}.mlp.experts.gate_up_proj_bias", + f"model.layers.{layer}.mlp.experts.down_proj_blocks", + f"model.layers.{layer}.mlp.experts.down_proj_scales", + f"model.layers.{layer}.mlp.experts.down_proj_bias", + f"model.layers.{layer}.mlp.router.weight", # Router weight + f"model.layers.{layer}.mlp.router.bias", # Router bias + ] + + # Group by file + file_to_keys = {} + for key in keys_to_load: + if key in index["weight_map"]: + filename = index["weight_map"][key] + if filename not in file_to_keys: + file_to_keys[filename] = [] + file_to_keys[filename].append(key) + + # Load from each file + tensors = {} + for filename, keys in file_to_keys.items(): + filepath = os.path.join(original_path, filename) + with safe_open(filepath, framework="pt") as f: + for key in keys: + tensors[key] = f.get_tensor(key) + + return tensors + + +def copy_non_moe_weights(student_path: str, output_path: str, num_layers: int) -> Dict[str, str]: + """ + Copy non-MoE weights from student model. + Returns weight_map for the new index. + """ + weight_map = {} + subblocks_dir = os.path.join(output_path, "subblocks_safetensors") + os.makedirs(subblocks_dir, exist_ok=True) + + student_subblocks = os.path.join(student_path, "subblocks_safetensors") + + # Copy embeddings + src_emb = os.path.join(student_subblocks, "embeddings.safetensors") + dst_emb = os.path.join(subblocks_dir, "embeddings.safetensors") + shutil.copy2(src_emb, dst_emb) + with safe_open(src_emb, framework="pt") as f: + for key in f.keys(): + weight_map[key] = "subblocks_safetensors/embeddings.safetensors" + + # Copy lm_head + src_head = os.path.join(student_subblocks, "lm_head.safetensors") + dst_head = os.path.join(subblocks_dir, "lm_head.safetensors") + shutil.copy2(src_head, dst_head) + with safe_open(src_head, framework="pt") as f: + for key in f.keys(): + weight_map[key] = "subblocks_safetensors/lm_head.safetensors" + + # Copy attention blocks + for layer in range(num_layers): + src_attn = os.path.join(student_subblocks, f"block_{layer}_attention.safetensors") + dst_attn = os.path.join(subblocks_dir, f"block_{layer}_attention.safetensors") + shutil.copy2(src_attn, dst_attn) + with safe_open(src_attn, framework="pt") as f: + for key in f.keys(): + weight_map[key] = f"subblocks_safetensors/block_{layer}_attention.safetensors" + + return weight_map + + +def process_single_layer( + layer: int, + original_path: str, + original_index: Dict, + student_path: str, + output_path: str, + experts_to_keep: List[int], +) -> Tuple[Dict[str, str], List[str]]: + """ + Process a single layer - loads tensors from potentially multiple files. + Returns (weight_map, verification_errors). + """ + weight_map = {} + verification_errors = [] + subblocks_dir = os.path.join(output_path, "subblocks_safetensors") + student_subblocks = os.path.join(student_path, "subblocks_safetensors") + + # Load all tensors for this layer (may come from multiple files) + orig_tensors = load_layer_tensors(original_path, layer, original_index) + + # Load student FFN file + student_ffn = os.path.join(student_subblocks, f"block_{layer}_ffn.safetensors") + + tensors_to_save = {} + student_tensors = {} + + with safe_open(student_ffn, framework="pt") as f: + for key in f.keys(): + tensor = f.get_tensor(key) + if "experts" not in key and "router" not in key: + # Copy norm weights + tensors_to_save[key] = tensor + + # Get router from original model, sliced to kept experts + orig_router_weight = orig_tensors[f"model.layers.{layer}.mlp.router.weight"] + orig_router_bias = orig_tensors[f"model.layers.{layer}.mlp.router.bias"] + + kept_indices_tensor = torch.tensor(experts_to_keep, dtype=torch.long) + sliced_router_weight = orig_router_weight[kept_indices_tensor] + sliced_router_bias = orig_router_bias[kept_indices_tensor] + + tensors_to_save[f"model.layers.{layer}.mlp.router.weight"] = sliced_router_weight + tensors_to_save[f"model.layers.{layer}.mlp.router.bias"] = sliced_router_bias + + # Get MoE tensors + mlp1_blocks = orig_tensors[f"model.layers.{layer}.mlp.experts.gate_up_proj_blocks"] + mlp1_scales = orig_tensors[f"model.layers.{layer}.mlp.experts.gate_up_proj_scales"] + mlp2_blocks = orig_tensors[f"model.layers.{layer}.mlp.experts.down_proj_blocks"] + mlp2_scales = orig_tensors[f"model.layers.{layer}.mlp.experts.down_proj_scales"] + mlp1_bias = orig_tensors[f"model.layers.{layer}.mlp.experts.gate_up_proj_bias"] + mlp2_bias = orig_tensors[f"model.layers.{layer}.mlp.experts.down_proj_bias"] + + tensors_to_save[f"model.layers.{layer}.mlp.experts.gate_up_proj_blocks"] = mlp1_blocks[ + kept_indices_tensor + ] + tensors_to_save[f"model.layers.{layer}.mlp.experts.gate_up_proj_scales"] = mlp1_scales[ + kept_indices_tensor + ] + tensors_to_save[f"model.layers.{layer}.mlp.experts.gate_up_proj_bias"] = mlp1_bias[ + kept_indices_tensor + ] + + tensors_to_save[f"model.layers.{layer}.mlp.experts.down_proj_blocks"] = mlp2_blocks[ + kept_indices_tensor + ] + tensors_to_save[f"model.layers.{layer}.mlp.experts.down_proj_scales"] = mlp2_scales[ + kept_indices_tensor + ] + tensors_to_save[f"model.layers.{layer}.mlp.experts.down_proj_bias"] = mlp2_bias[ + kept_indices_tensor + ] + + # Save the FFN file + output_file = os.path.join(subblocks_dir, f"block_{layer}_ffn.safetensors") + save_file(tensors_to_save, output_file) + + # Build weight map + for key in tensors_to_save.keys(): + weight_map[key] = f"subblocks_safetensors/block_{layer}_ffn.safetensors" + + return weight_map, verification_errors + + +def copy_config_files(student_path: str, output_path: str): + """Copy configuration files from student model and update config.json.""" + files_to_copy = [ + "tokenizer.json", + "tokenizer_config.json", + "special_tokens_map.json", + "chat_template.jinja", + ] + + # Also copy transformers compatibility files + if os.path.exists(student_path): + for f in os.listdir(student_path): + if f.startswith("transformers_"): + files_to_copy.append(f) + + for filename in files_to_copy: + src = os.path.join(student_path, filename) + dst = os.path.join(output_path, filename) + + # Try student path first + if os.path.exists(src): + try: + shutil.copy2(src, dst) + continue + except PermissionError: + pass + + # If we get here, file doesn't exist or permission denied + if not os.path.exists(dst): + print(f" Warning: Could not copy {filename}") + + # Update config.json for DeciGptOssForCausalLM with MXFP4 + src_config = os.path.join(student_path, "config.json") + if not os.path.exists(src_config): + raise FileNotFoundError(f"config.json not found at {src_config}") + + with open(src_config, "r") as f: + config = json.load(f) # type: ignore[arg-type] + + # Set architecture to DeciGptOssForCausalLM for MXFP4 support + config["architectures"] = ["DeciGptOssForCausalLM"] + + # Add quantization_config so vllm calls _load_weights_mxfp4 + config["quantization_config"] = { + "quant_method": "mxfp4", + "modules_to_not_convert": [ + "model.layers.*.self_attn", + "model.layers.*.mlp.router", + "model.embed_tokens", + "lm_head", + ], + } + + dst_config = os.path.join(output_path, "config.json") + with open(dst_config, "w") as f: + json.dump(config, f, indent=2) # type: ignore[arg-type] + + +def main(): + parser = argparse.ArgumentParser(description="Create MXFP4 checkpoint from student model") + parser.add_argument( + "--student-path", type=str, required=True, help="Path to student model checkpoint" + ) + parser.add_argument( + "--original-path", + type=str, + required=True, + help="Path to original gpt-oss-120b model with MXFP4 weights", + ) + parser.add_argument( + "--output-path", type=str, required=True, help="Output path for the new checkpoint" + ) + parser.add_argument("--num-layers", type=int, default=36, help="Number of transformer layers") + args = parser.parse_args() + + print(f"Creating MXFP4 checkpoint...") + print(f" Student model: {args.student_path}") + print(f" Original model: {args.original_path}") + print(f" Output: {args.output_path}") + + # Load original model index + original_index = load_original_index( + os.path.join(args.original_path, "model.safetensors.index.json") + ) + + print("\nDeducing expert mappings by comparing weights...") + experts_to_keep = [] + layer_statistics = [] # Store (num_student, num_original) for each layer + + for layer in range(args.num_layers): + layer_experts, num_student, num_original = deduce_experts_for_layer( + layer, + args.original_path, + original_index, + args.student_path, + ) + experts_to_keep.append(layer_experts) + layer_statistics.append((num_student, num_original)) + + # Print statistics + print(f"\n{'=' * 70}") + print("EXPERT DEDUCTION STATISTICS") + print(f"{'=' * 70}") + print(f"{'Layer':<8} {'Student Experts':<18} {'Original Experts':<18} {'Kept %':<10}") + print(f"{'-' * 70}") + + total_student = 0 + total_original = 0 + for layer, (num_student, num_original) in enumerate(layer_statistics): + percentage = (num_student / num_original * 100) if num_original > 0 else 0 + print(f"{layer:<8} {num_student:<18} {num_original:<18} {percentage:<10.2f}") + total_student += num_student + total_original += num_original + + print(f"{'-' * 70}") + avg_percentage = (total_student / total_original * 100) if total_original > 0 else 0 + print(f"{'TOTAL':<8} {total_student:<18} {total_original:<18} {avg_percentage:<10.2f}") + print(f"{'=' * 70}") + print(f"\n Deduced experts_to_keep mapping for {len(experts_to_keep)} layers") + + # Create output directory + os.makedirs(args.output_path, exist_ok=True) + os.makedirs(os.path.join(args.output_path, "subblocks_safetensors"), exist_ok=True) + + # Copy config files + print("Copying configuration files...") + copy_config_files(args.student_path, args.output_path) + + # Save experts_to_keep.json + experts_to_keep_output = os.path.join(args.output_path, "experts_to_keep.json") + with open(experts_to_keep_output, "w") as f: + json.dump(experts_to_keep, f, indent=2) + print(f" Saved experts_to_keep mapping to {experts_to_keep_output}") + + # Copy non-MoE weights (embeddings, attention, lm_head) + print("Copying non-MoE weights...") + weight_map = copy_non_moe_weights(args.student_path, args.output_path, args.num_layers) + + # Load weights per layer (handles multi-file loading) + print(f"Processing {args.num_layers} layers...") + + all_verification_errors = [] + + # Process each layer + for layer in tqdm(range(args.num_layers), desc="Processing layers"): + if len(experts_to_keep[layer]) == 0: + print(f"Layer {layer} has no experts to keep - ffn->no_op") + continue + layer_weight_map, layer_errors = process_single_layer( + layer, + args.original_path, + original_index, + args.student_path, + args.output_path, + experts_to_keep[layer], + ) + weight_map.update(layer_weight_map) + all_verification_errors.extend(layer_errors) + + # Calculate total size + total_size = 0 + subblocks_dir = os.path.join(args.output_path, "subblocks_safetensors") + for filename in os.listdir(subblocks_dir): + filepath = os.path.join(subblocks_dir, filename) + total_size += os.path.getsize(filepath) + + # Create model.safetensors.index.json + index = {"metadata": {"total_size": total_size}, "weight_map": weight_map} + + index_path = os.path.join(args.output_path, "model.safetensors.index.json") + with open(index_path, "w") as f: + json.dump(index, f, indent=2) + + print(f"\nCheckpoint created successfully at: {args.output_path}") + print(f"Total size: {total_size / 1e9:.2f} GB") + + +if __name__ == "__main__": + main() diff --git a/tests/gpu/torch/puzzletron/nas/plugins/test_nas_convert.py b/tests/gpu/torch/puzzletron/nas/plugins/test_nas_convert.py index e2373676d..317043da2 100644 --- a/tests/gpu/torch/puzzletron/nas/plugins/test_nas_convert.py +++ b/tests/gpu/torch/puzzletron/nas/plugins/test_nas_convert.py @@ -28,7 +28,6 @@ from modelopt.torch.puzzletron.nas.plugins.puzzletron_nas_plugin import PuzzletronModel -@pytest.mark.skip(reason="Temporarily disabled") def test_nas_convert_ffn_pruning(project_root_path: Path, tmp_path: Path): spawn_multiprocess_job( size=torch.cuda.device_count(), diff --git a/tests/gpu/torch/puzzletron/nas/plugins/test_nas_search.py b/tests/gpu/torch/puzzletron/nas/plugins/test_nas_search.py index e39f1e1cb..c34f449d8 100644 --- a/tests/gpu/torch/puzzletron/nas/plugins/test_nas_search.py +++ b/tests/gpu/torch/puzzletron/nas/plugins/test_nas_search.py @@ -17,7 +17,6 @@ from functools import partial from pathlib import Path -import pytest import torch from _test_utils.torch.distributed.utils import spawn_multiprocess_job from _test_utils.torch.puzzletron.utils import setup_test_model_and_data @@ -27,7 +26,6 @@ from modelopt.torch.puzzletron.nas.plugins.puzzletron_nas_plugin import PuzzletronModel -@pytest.mark.skip(reason="Temporarily disabled") def test_nas_search(project_root_path: Path, tmp_path: Path): spawn_multiprocess_job( size=torch.cuda.device_count(), diff --git a/tests/gpu/torch/puzzletron/resources/configs/gpt-oss-20b/gpt-oss-20b.yaml b/tests/gpu/torch/puzzletron/resources/configs/gpt-oss-20b/gpt-oss-20b.yaml new file mode 100644 index 000000000..04e5fe90a --- /dev/null +++ b/tests/gpu/torch/puzzletron/resources/configs/gpt-oss-20b/gpt-oss-20b.yaml @@ -0,0 +1,108 @@ +defaults: + - pruning: expert_removal # TODO: Note: Works for unquantized test models, not MXFP4 quantized production models + - scoring: ../validate_solutions_defaults + - realize_model: ../validate_solutions_defaults + - bypass: + - override hydra/hydra_logging: disabled + - _self_ + +descriptor: gpt_oss + +puzzle_dir: ??? +teacher_dir: ${puzzle_dir}/ckpts/teacher/ +replacement_library_path: ${puzzle_dir}/replacement_library.json +dataset_path: ??? # path to v0.4_mini + +skip_realize_model: false + +build_replacement_library: + add_ffn_no_ops: true # TODO: Works for unquantized test models + add_attention_no_ops: true + +calc_subblock_stats: + batch_sizes: [64, 96, 128] + prefill_seq_len: 4096 + generation_seq_len: 4096 + num_active_tokens_override: # Optional override for sequence lengths + prefill_queue_size: 0 + allocate_prefill_query: false + benchmark_iterations: # Set to a number (e.g., 1000) to enable runtime benchmarking + merge_with_existing_stats: false + subblock_stats_filename: "subblock_stats.json" + moe_stats_filename: "moe_stats.json" + +scoring: + descriptor: ${descriptor} + solutions_to_validate: + skip_existing_solutions: true + + replacement_library_path: ${replacement_library_path} + solutions_path: ${to_path:${puzzle_dir}/single_sequence_replacement_solutions.json} + teacher_dir: ${to_path:${teacher_dir}} + output_dir: ${puzzle_dir}/single_sequence_replacement_solutions--validation + + eval_samples: 2 + micro_batch_size: 1 + dataset_path: ${dataset_path}/valid + seed: 42 + shuffle_seed: 444 + +mip: + single_block_replacement_validation_dir: ${to_path:${scoring.output_dir}} + subblock_stats_path: ${to_path:${puzzle_dir}/${calc_subblock_stats.subblock_stats_filename}} + output_path: ${to_path:${puzzle_dir}/mip/puzzle_solutions} + gathered_metrics_path: + puzzle_profile: + + # puzzle_profile: + objective: metrics.cosine_embedding_loss_hidden_states + bigger_is_better: false + + subblock_stats_args: + - batch_size: 96 + weights_dtype: torch.bfloat16 + activations_dtype: torch.bfloat16 + kv_cache_dtype: torch.bfloat16 + + report_additional_costs: + - stats.memory_mib + - stats.num_params + - stats.num_kv_heads + - stats.has_attention + - stats.has_ffn + - stats.kv_cache_memory_mib + - stats.attention_memory_mib + - stats.ffn_memory_mib + - stats.ffn_num_params + - stats.attention_num_params + + human_constraints: + target_memory: 780_000 # 78_000 + + mip_constraints: + - stats.num_local_experts: 48 # teacher has: 2 layers * 32 experts = 64 total experts + metric_overrides: + max_seconds_per_solution: 60 + +realize_model: + descriptor: ${descriptor} + teacher_dir: ${to_path:${teacher_dir}} + tokenizer_name: ${to_path:${teacher_dir}} + replacement_library_path: ${replacement_library_path} + save_models: true + solutions_path: # Filled dynamically + + # Validate params + skip_validation: false # To enable validation of the model solution set `skip_validation` as False + eval_samples: 2 + micro_batch_size: 1 + dataset_path: ${dataset_path}/valid + seed: 42 + shuffle_seed: 444 + +nccl_timeout_minutes: ${timedelta_minutes:10} + +# This section redirects Hydra outputs +hydra: + run: + dir: ${puzzle_dir}/hydra_logs/${now:%Y-%m-%d}/${now:%H-%M-%S} diff --git a/tests/gpu/torch/puzzletron/resources/configs/gpt-oss-20b/pruning/expert_removal.yaml b/tests/gpu/torch/puzzletron/resources/configs/gpt-oss-20b/pruning/expert_removal.yaml new file mode 100644 index 000000000..9377497f8 --- /dev/null +++ b/tests/gpu/torch/puzzletron/resources/configs/gpt-oss-20b/pruning/expert_removal.yaml @@ -0,0 +1,20 @@ +defaults: + - pruning_defaults + +eval_samples: 10 +activations_log_dir: ${puzzle_dir}/pruning/pruning_scores/expert_removal/${pruning.experiment_id} +pruning_mixin: + _target_: modelopt.torch.puzzletron.pruning.expert_removal_pruning_mixin.ExpertRemovalPruningMixIn + layer_descriptor: + _target_: modelopt.torch.puzzletron.anymodel.models.gpt_oss_20b.gpt_oss_20b_model_descriptor.GptOss20bExpertRemovalLayerDescriptor + +hook_class: ${get_object:modelopt.torch.nas.plugins.megatron_hooks.base_hooks.GptOssRemoveExpertsIndependentHook} +activation_hooks_kwargs: # Additional kwargs to pass to the hook init + +# num_experts_to_keep must be >= num_experts_per_tok (can't route to more experts than exist) +# Test model: num_local_experts=32, num_experts_per_tok=4 +num_experts_to_keep_list: [24, 16, 8] +mlp_init_mode: "ExpertRemoval" +mlp_init_config_yaml: + expert_scores_key: "expert_ranks_mse" + layer_prefix_template: "model.layers.{layer_idx}." diff --git a/tests/gpu/torch/puzzletron/resources/configs/gpt-oss-20b/pruning/pruning_defaults.yaml b/tests/gpu/torch/puzzletron/resources/configs/gpt-oss-20b/pruning/pruning_defaults.yaml new file mode 100644 index 000000000..441577b05 --- /dev/null +++ b/tests/gpu/torch/puzzletron/resources/configs/gpt-oss-20b/pruning/pruning_defaults.yaml @@ -0,0 +1,35 @@ +defaults: + - /validate_model_defaults + +model_name_or_path: ${teacher_dir} +experiment_id: ${pruning.eval_samples}samples_diverse_mini +activations_log_dir: ??? +activation_hooks_kwargs: ??? + +descriptor: ${descriptor} + +# Data: +eval_samples: 10 +micro_batch_size: 4 +dataset_path: ${dataset_path} +val_dataset_name: train + +# Prune ckpts +pruned_ckpts_output_dir: ${puzzle_dir}/pruning/${pruning.experiment_id} + +## FFN pruning (not used for MoE expert removal) +ffn_list: +mlp_init_mode: "Truncate" # PruneByActivationsLog + +## KV-heads pruning +n_heads_in_group_list: +gqa_init_mode: "AverageKV" + +## Hidden dimension pruning +hidden_size_list: +hidden_size_init_mode: "PruneByChannelRanking" +linear_init_mode: "FromTeacher" + +mlp_init_config_yaml: + activations_log_dir: ${pruning.activations_log_dir} + diff --git a/tests/gpu/torch/puzzletron/resources/configs/gpt-oss-20b/validate_model_defaults.yaml b/tests/gpu/torch/puzzletron/resources/configs/gpt-oss-20b/validate_model_defaults.yaml new file mode 100644 index 000000000..9dabef741 --- /dev/null +++ b/tests/gpu/torch/puzzletron/resources/configs/gpt-oss-20b/validate_model_defaults.yaml @@ -0,0 +1,15 @@ +block_size: 8192 +bos_rate: 0.5 +data_column: conversation +val_dataset_name: train +shuffle_seed: 81436 +seed: 42 +fim_rate: 0 +fim_spm_rate: 0 +source_datasets_to_discard: +varlen: false +write_results: false +calc_losses_on_cpu: false +activations_log_dir: +model_name_or_path: +load_dataset_fn: ${get_object:modelopt.torch.puzzletron.utils.data.dataloaders.load_from_disk_fn} diff --git a/tests/gpu/torch/puzzletron/resources/configs/gpt-oss-20b/validate_solutions_defaults.yaml b/tests/gpu/torch/puzzletron/resources/configs/gpt-oss-20b/validate_solutions_defaults.yaml new file mode 100644 index 000000000..ec1390237 --- /dev/null +++ b/tests/gpu/torch/puzzletron/resources/configs/gpt-oss-20b/validate_solutions_defaults.yaml @@ -0,0 +1,10 @@ +defaults: + - /validate_model_defaults + - _self_ + +solutions_to_validate: +skip_validation: false +save_models: false +bigger_is_better: false +sort_solutions_by: +calculate_full_score_ablations: false diff --git a/tests/gpu/torch/puzzletron/resources/hf_configs/gpt-oss-20b/config.json b/tests/gpu/torch/puzzletron/resources/hf_configs/gpt-oss-20b/config.json new file mode 100644 index 000000000..f2d75de66 --- /dev/null +++ b/tests/gpu/torch/puzzletron/resources/hf_configs/gpt-oss-20b/config.json @@ -0,0 +1,76 @@ +{ + "architectures": [ + "GptOssForCausalLM" + ], + "attention_bias": true, + "attention_dropout": 0.0, + "eos_token_id": 2, + "experts_per_token": 4, + "head_dim": 64, + "hidden_act": "silu", + "hidden_size": 2880, + "initial_context_length": 4096, + "initializer_range": 0.02, + "intermediate_size": 2880, + "layer_types": [ + "sliding_attention", + "full_attention", + "sliding_attention", + "full_attention", + "sliding_attention", + "full_attention", + "sliding_attention", + "full_attention", + "sliding_attention", + "full_attention", + "sliding_attention", + "full_attention", + "sliding_attention", + "full_attention", + "sliding_attention", + "full_attention", + "sliding_attention", + "full_attention", + "sliding_attention", + "full_attention", + "sliding_attention", + "full_attention", + "sliding_attention", + "full_attention" + ], + "max_position_embeddings": 131072, + "model_type": "gpt_oss", + "num_attention_heads": 64, + "num_experts_per_tok": 4, + "num_hidden_layers": 24, + "num_key_value_heads": 8, + "num_local_experts": 32, + "output_router_logits": false, + "pad_token_id": 0, + "quantization_config": { + "modules_to_not_convert": [ + "model.layers.*.self_attn", + "model.layers.*.mlp.router", + "model.embed_tokens", + "lm_head" + ], + "quant_method": "mxfp4" + }, + "rms_norm_eps": 1e-05, + "rope_scaling": { + "beta_fast": 32.0, + "beta_slow": 1.0, + "factor": 32.0, + "original_max_position_embeddings": 4096, + "rope_type": "yarn", + "truncate": false + }, + "rope_theta": 150000, + "router_aux_loss_coef": 0.9, + "sliding_window": 128, + "swiglu_limit": 7.0, + "tie_word_embeddings": false, + "transformers_version": "4.55.0.dev0", + "use_cache": true, + "vocab_size": 201088 +} \ No newline at end of file From 47ca0e37cd890e0ca8396b30f8b5f4b88dd2c8ff Mon Sep 17 00:00:00 2001 From: Daniel Korzekwa Date: Wed, 11 Mar 2026 01:10:24 -0700 Subject: [PATCH 2/3] Add comments about a broken test Signed-off-by: Daniel Korzekwa --- tests/gpu/torch/puzzletron/test_puzzletron.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/gpu/torch/puzzletron/test_puzzletron.py b/tests/gpu/torch/puzzletron/test_puzzletron.py index ca620eb68..6d5202f83 100644 --- a/tests/gpu/torch/puzzletron/test_puzzletron.py +++ b/tests/gpu/torch/puzzletron/test_puzzletron.py @@ -62,6 +62,7 @@ "*E", True, ), + # GPT-OSS test fails, @TODO: fixing in progress # ("gpt-oss-20b", "gpt_oss_20b", "gpt-oss-20b", None, True), ], ) From 96112f70649acfd1abc39e052417b73f097c2dec Mon Sep 17 00:00:00 2001 From: Daniel Korzekwa Date: Thu, 12 Mar 2026 08:48:32 -0700 Subject: [PATCH 3/3] Fix a broken gptoss test Signed-off-by: Daniel Korzekwa --- .../gpt-oss-20b/pruning/expert_removal.yaml | 15 +++++++-------- tests/gpu/torch/puzzletron/test_puzzletron.py | 3 +-- 2 files changed, 8 insertions(+), 10 deletions(-) diff --git a/tests/gpu/torch/puzzletron/resources/configs/gpt-oss-20b/pruning/expert_removal.yaml b/tests/gpu/torch/puzzletron/resources/configs/gpt-oss-20b/pruning/expert_removal.yaml index 9377497f8..a8925a59c 100644 --- a/tests/gpu/torch/puzzletron/resources/configs/gpt-oss-20b/pruning/expert_removal.yaml +++ b/tests/gpu/torch/puzzletron/resources/configs/gpt-oss-20b/pruning/expert_removal.yaml @@ -3,18 +3,17 @@ defaults: eval_samples: 10 activations_log_dir: ${puzzle_dir}/pruning/pruning_scores/expert_removal/${pruning.experiment_id} + pruning_mixin: _target_: modelopt.torch.puzzletron.pruning.expert_removal_pruning_mixin.ExpertRemovalPruningMixIn layer_descriptor: - _target_: modelopt.torch.puzzletron.anymodel.models.gpt_oss_20b.gpt_oss_20b_model_descriptor.GptOss20bExpertRemovalLayerDescriptor - -hook_class: ${get_object:modelopt.torch.nas.plugins.megatron_hooks.base_hooks.GptOssRemoveExpertsIndependentHook} + _target_: modelopt.torch.puzzletron.anymodel.models.gpt_oss.gpt_oss_model_descriptor.GptOssExpertRemovalLayerDescriptor + target_name: "mlp.router" +hook_class: ${get_object:modelopt.torch.nas.plugins.megatron_hooks.base_hooks.RankedChoiceVotingHook} activation_hooks_kwargs: # Additional kwargs to pass to the hook init -# num_experts_to_keep must be >= num_experts_per_tok (can't route to more experts than exist) -# Test model: num_local_experts=32, num_experts_per_tok=4 -num_experts_to_keep_list: [24, 16, 8] +num_experts_to_keep_list: [24, 16, 8] # num_experts in teacher is 128 mlp_init_mode: "ExpertRemoval" mlp_init_config_yaml: - expert_scores_key: "expert_ranks_mse" - layer_prefix_template: "model.layers.{layer_idx}." + expert_scores_key: "expert_ranks" + layer_prefix_template: "model.layers.{layer_idx}.mlp.router" diff --git a/tests/gpu/torch/puzzletron/test_puzzletron.py b/tests/gpu/torch/puzzletron/test_puzzletron.py index 6d5202f83..3e462e7da 100644 --- a/tests/gpu/torch/puzzletron/test_puzzletron.py +++ b/tests/gpu/torch/puzzletron/test_puzzletron.py @@ -62,8 +62,7 @@ "*E", True, ), - # GPT-OSS test fails, @TODO: fixing in progress - # ("gpt-oss-20b", "gpt_oss_20b", "gpt-oss-20b", None, True), + ("gpt-oss-20b", "gpt_oss", "gpt-oss-20b", None, True), ], ) def test_puzzletron(