
[Calibration] Add MoE Calibration Context #1596


Open · wants to merge 22 commits into base: main
4 changes: 2 additions & 2 deletions examples/multimodal_vision/llama4_example.py
@@ -3,7 +3,7 @@
from transformers import Llama4ForConditionalGeneration, Llama4Processor

from llmcompressor import oneshot
from llmcompressor.modeling import prepare_for_calibration
from llmcompressor.modeling import replace_modules_for_calibration
from llmcompressor.modifiers.quantization import GPTQModifier

# Select model and load it.
@@ -14,7 +14,7 @@
# This change allows compatibility with vllm.
# To apply your own custom module for experimentation, consider updating
# `SequentialLlama4TextMoe` under llmcompressor/modeling/llama4.py
model = prepare_for_calibration(model)
model = replace_modules_for_calibration(model)

DATASET_ID = "neuralmagic/calibration"
NUM_CALIBRATION_SAMPLES = 512
4 changes: 2 additions & 2 deletions examples/quantization_w4a4_fp4/llama4_example.py
@@ -3,7 +3,7 @@
from transformers import Llama4ForConditionalGeneration, Llama4Processor

from llmcompressor import oneshot
from llmcompressor.modeling import prepare_for_calibration
from llmcompressor.modeling import replace_modules_for_calibration
from llmcompressor.modifiers.quantization import QuantizationModifier

# Select model and load it.
@@ -14,7 +14,7 @@
# This change allows compatibility with vllm.
# To apply your own custom module for experimentation, consider updating
# `SequentialLlama4TextMoe` under llmcompressor/modeling/llama4.py
model = prepare_for_calibration(model)
model = replace_modules_for_calibration(model)

DATASET_ID = "neuralmagic/calibration"
NUM_CALIBRATION_SAMPLES = 20
86 changes: 86 additions & 0 deletions examples/quantization_w4a4_fp4/qwen_30b_a3b.py
@@ -0,0 +1,86 @@
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer

from llmcompressor import oneshot
from llmcompressor.modifiers.quantization import QuantizationModifier
from llmcompressor.utils import dispatch_for_generation

MODEL_ID = "Qwen/Qwen3-30B-A3B"

# Load model.
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype="auto")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)


DATASET_ID = "HuggingFaceH4/ultrachat_200k"
DATASET_SPLIT = "train_sft"

# Select number of samples
NUM_CALIBRATION_SAMPLES = 200
MAX_SEQUENCE_LENGTH = 2048

# Load dataset and preprocess.
ds = load_dataset(DATASET_ID, split=f"{DATASET_SPLIT}[:{NUM_CALIBRATION_SAMPLES}]")
ds = ds.shuffle(seed=42)


def preprocess(example):
return {
"text": tokenizer.apply_chat_template(
example["messages"],
tokenize=False,
)
}


ds = ds.map(preprocess)


# Tokenize inputs.
def tokenize(sample):
return tokenizer(
sample["text"],
padding=False,
max_length=MAX_SEQUENCE_LENGTH,
truncation=True,
add_special_tokens=False,
)


ds = ds.map(tokenize, remove_columns=ds.column_names)

# Configure the quantization algorithm and scheme.
# In this case, we:
# * quantize the weights to fp4 with a group size of 16 via PTQ
# * calibrate a global_scale for activations, which will be used to
# quantize activations to fp4 on the fly
recipe = QuantizationModifier(
targets="Linear", scheme="NVFP4", ignore=["lm_head", "re:.*mlp.gate$"]
)

# Apply quantization.
# We set `calibrate_moe_context` to True to update all `Qwen3MoeSparseMoeBlock`
# modules during calibration
oneshot(
model=model,
dataset=ds,
recipe=recipe,
max_seq_length=MAX_SEQUENCE_LENGTH,
num_calibration_samples=NUM_CALIBRATION_SAMPLES,
calibrate_moe_context=True,
)


print("\n\n")
print("========== SAMPLE GENERATION ==============")
dispatch_for_generation(model)
input_ids = tokenizer("Hello my name is", return_tensors="pt").input_ids.to("cuda")
output = model.generate(input_ids, max_new_tokens=100)
print(tokenizer.decode(output[0]))
print("==========================================\n\n")


# Save to disk in compressed-tensors format.
SAVE_DIR = MODEL_ID.rstrip("/").split("/")[-1] + "-NVFP4"
model.save_pretrained(SAVE_DIR, save_compressed=True)
tokenizer.save_pretrained(SAVE_DIR)
4 changes: 2 additions & 2 deletions examples/quantizing_moe/deepseek_r1_example.py
@@ -1,7 +1,7 @@
from datasets import load_dataset
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

from llmcompressor.modeling import prepare_for_calibration
from llmcompressor.modeling import replace_modules_for_calibration
from llmcompressor.modifiers.quantization import GPTQModifier
from llmcompressor.transformers import oneshot

@@ -20,7 +20,7 @@
model_id, torch_dtype="auto", config=config
)
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = prepare_for_calibration(model)
model = replace_modules_for_calibration(model)

# Select calibration dataset.
DATASET_ID = "HuggingFaceH4/ultrachat_200k"
10 changes: 10 additions & 0 deletions src/llmcompressor/args/dataset_arguments.py
@@ -117,6 +117,16 @@ class DatasetArguments(CustomDatasetArguments):
default=512,
metadata={"help": "Number of samples to use for one-shot calibration"},
)
calibrate_moe_context: bool = field(
Collaborator: Shouldn't this basically always be on? Is there ever a case where a user shouldn't use this?

@dsikka (Author, Jul 14, 2025): When you want to call prepare_for_calibration and permanently change the module definition, as opposed to changing it only for the duration of calibration.

Collaborator: Hm, but there's no conflict between prepare_for_calibration and calibrate_moe_context, right? I think it'd also look a little confusing to calibrate an MoE model but explicitly call

    prepare_for_calibration(my_moe_model)
    oneshot(my_moe_model, calibrate_moe_context=False)

@dsikka (Author, Jul 15, 2025): Would it be better to just remove it and always run with calibrate_moe_context=True? We require prepare_for_calibration to be explicitly applied, and the idea was to do the same with the MoE calibration context. But yes, there is no conflict between the two. I think we can technically run DeepSeek with both, but I haven't tested it with the context.

default=False,
metadata={
"help": "If during calibration, the MoE context should be enabled "
"for the given model. This usually involves updating all MoE modules "
"in the model for the duration of calibration. See moe_context under "
"modeling/prepare.py for a list of supported MoEs and their updated "
"module definitions"
},
)
shuffle_calibration_samples: Optional[bool] = field(
default=True,
metadata={
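Following up on the review thread above, a minimal sketch (not part of this PR's diff) of the two MoE calibration paths, assuming `model`, `ds`, and `recipe` are already defined as in the examples earlier in this PR:

    from llmcompressor import oneshot
    from llmcompressor.modeling import replace_modules_for_calibration

    # Choose one of the two paths below.

    # Path 1: permanently replace MoE modules before calibration
    # (used for DeepseekV3 / Llama4 in the examples above).
    model = replace_modules_for_calibration(model)
    oneshot(model=model, dataset=ds, recipe=recipe)

    # Path 2: swap MoE modules only for the duration of calibration
    # (used for Qwen3-MoE in the new example); the original module
    # definitions are restored once calibration ends.
    oneshot(model=model, dataset=ds, recipe=recipe, calibrate_moe_context=True)

Per the thread, the two mechanisms do not conflict: the first permanently changes the module definition, while the second only changes it for the duration of calibration.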
7 changes: 6 additions & 1 deletion src/llmcompressor/entrypoints/oneshot.py
@@ -189,7 +189,11 @@ def apply_recipe_modifiers(
user_pipeline = self.dataset_args.pipeline
modifiers = session.lifecycle.recipe.modifiers
pipeline = CalibrationPipeline.from_modifiers(modifiers, user=user_pipeline)
pipeline(self.model, calibration_dataloader, self.dataset_args)
pipeline(
self.model,
calibration_dataloader,
self.dataset_args,
)

session.finalize()

@@ -227,6 +231,7 @@ def oneshot(
overwrite_cache: bool = False,
preprocessing_num_workers: Optional[int] = None,
min_tokens_per_module: Optional[float] = None,
calibrate_moe_context: bool = False,
# Miscellaneous arguments
output_dir: Optional[str] = None,
log_dir: Optional[str] = "sparse_logs",
10 changes: 5 additions & 5 deletions src/llmcompressor/modeling/deepseek_v3.py
@@ -1,16 +1,16 @@
import torch
from transformers.models.deepseek_v3.configuration_deepseek_v3 import DeepseekV3Config
from transformers.models.deepseek_v3.modeling_deepseek_v3 import DeepseekV3MoE

__all__ = ["DeepseekV3MoECalibrate"]
from transformers.models.deepseek_v3.modeling_deepseek_v3 import (
DeepseekV3MoE as OriginalDeepseekV3MoE,
)


class DeepseekV3MoECalibrate(torch.nn.Module):
"""
Patched DeepseekV3MoE which sends all tokens to all experts for calibration
"""

def __init__(self, config: DeepseekV3Config, original: DeepseekV3MoE):
def __init__(self, config: DeepseekV3Config, original: OriginalDeepseekV3MoE):
super().__init__()
self.config = config
self.experts = original.experts
@@ -49,5 +49,5 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
return hidden_states


def replace(config: DeepseekV3Config, module: DeepseekV3MoE):
def replace(config: DeepseekV3Config, module: OriginalDeepseekV3MoE):
return DeepseekV3MoECalibrate(config=config, original=module)
2 changes: 0 additions & 2 deletions src/llmcompressor/modeling/llama4.py
@@ -11,8 +11,6 @@

from llmcompressor.utils.dev import skip_weights_initialize

__all__ = ["SequentialLlama4TextMoe"]


class SequentialLlama4TextMoe(torch.nn.Module):
def __init__(self, config: Llama4TextConfig, original: Llama4TextMoe):
37 changes: 35 additions & 2 deletions src/llmcompressor/modeling/prepare.py
@@ -3,20 +3,53 @@

from llmcompressor.modeling.deepseek_v3 import replace as replace_deepseekv3
from llmcompressor.modeling.llama4 import replace as replace_llama4
from llmcompressor.modeling.qwen3_moe import replace as replace_Qwen3MoE
from llmcompressor.utils.helpers import patch_attr

__all__ = ["prepare_for_calibration"]
__all__ = ["replace_modules_for_calibration"]

# ---------------------- module replacements; permanent -------------------------
replacements = {
"DeepseekV3MoE": replace_deepseekv3,
"Llama4TextMoe": replace_llama4,
}


def prepare_for_calibration(model: PreTrainedModel) -> PreTrainedModel:
def replace_modules_for_calibration(model: PreTrainedModel) -> PreTrainedModel:
for name, module in model.named_modules():
cls_name = module.__class__.__name__
if cls_name in replacements:
new_module = replacements[cls_name](config=model.config, module=module)
replace_module(model, name, new_module)

return model


# ------------------- module replacements; during calibration --------------------


def update_qwen3_moe(model, stack):
for module in model.modules():
cls_name = module.__class__.__name__
if cls_name == "Qwen3MoeDecoderLayer":
Collaborator: Could you use something like this pattern for matching? That way things don't break if the parent's structure changes, and we can also share matching logic between replacements:

    for name, module in model.named_modules():
        cls_name = module.__class__.__name__
        if cls_name in replacements:
            new_module = replacements[cls_name](module)
            replace_module(model, name, new_module)

@dsikka (Author, Jul 9, 2025): I think if we want to use patch_attr in order to follow the other context set-up functionality, we need both the parent and the child, so it would still require setting "mlp". I think replace_module finds the parent for you when replacing the module. We could expand patch_attr, I guess, to potentially follow that.
# Optionally update the model.config to pass in other arguments
stack.enter_context(
patch_attr(
module,
"mlp",
replace_Qwen3MoE(config=model.config, module=module.mlp),
)
)


moe_context = {
"Qwen3MoeForCausalLM": update_qwen3_moe,
}


def moe_calibration_context(model: PreTrainedModel, stack):
Collaborator: Can you show what it's like to pass additional calibration options (moe_calibrate_all_experts, moe_calibrate_gated_acts), if these are still options we want to supply to researchers/users?

@dsikka (Author): I left a small comment, but see the other comment below.
# Temporarily updates the MoE modules within the context
# Once the context exits, parameter updates persist
cls_name = model.__class__.__name__
if cls_name in moe_context:
moe_context.get(cls_name)(model, stack)
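For context, an illustrative sketch (an assumption, not the library's actual implementation) of what a patch_attr-style helper does when entered on the ExitStack above: it temporarily swaps an attribute and restores the original object when the stack closes. Because the calibration block reuses the original gate and experts submodules, parameter updates made during calibration persist even after the original module is restored.

    import contextlib

    @contextlib.contextmanager
    def patch_attr_sketch(obj, attr_name, new_value):
        # Temporarily set obj.<attr_name> to new_value, restoring the original
        # object when the context exits (e.g. when the ExitStack closes).
        original = getattr(obj, attr_name)
        setattr(obj, attr_name, new_value)
        try:
            yield
        finally:
            setattr(obj, attr_name, original)

    # Roughly what update_qwen3_moe does for each Qwen3MoeDecoderLayer:
    # stack.enter_context(patch_attr_sketch(layer, "mlp", calibration_mlp))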
87 changes: 87 additions & 0 deletions src/llmcompressor/modeling/qwen3_moe.py
@@ -0,0 +1,87 @@
# coding=utf-8
# Copyright 2025 The Qwen team, Alibaba Group and the HuggingFace Inc. team.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
Collaborator: May want to add the HF copyright and a note about the amendments.

@dsikka (Author): I added the HF copyright - do you have a reference for what type of note should be made about the amendments?
from transformers.models import Qwen3MoeConfig
from transformers.models.qwen3_moe.modeling_qwen3_moe import (
Qwen3MoeSparseMoeBlock as OriginalQwen3MoeSparseMoeBlock,
)


class Qwen3MoeSparseMoeBlock(torch.nn.Module):
def __init__(
self, config: Qwen3MoeConfig, original: OriginalQwen3MoeSparseMoeBlock
):
super().__init__()
self.num_experts = config.num_experts
self.top_k = config.top_k
self.norm_topk_prob = config.norm_topk_prob

# gating
self.gate = original.gate
self.experts = original.experts

def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
batch_size, sequence_length, hidden_dim = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_dim)
# router_logits: (batch * sequence_length, n_experts)
router_logits = self.gate(hidden_states)

routing_weights = torch.nn.functional.softmax(
router_logits, dim=1, dtype=torch.float
)
routing_weights, selected_experts = torch.topk(
routing_weights, self.top_k, dim=-1
)
if self.norm_topk_prob: # only diff with mixtral sparse moe block!
routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
# we cast back to the input dtype
routing_weights = routing_weights.to(hidden_states.dtype)
final_hidden_states = torch.zeros(
(batch_size * sequence_length, hidden_dim),
dtype=hidden_states.dtype,
device=hidden_states.device,
)

# One hot encode the selected experts to create an expert mask
# this will be used to easily index which expert is going to be solicited
expert_mask = torch.nn.functional.one_hot(
selected_experts, num_classes=self.num_experts
).permute(2, 1, 0)

for expert_idx in range(len(self.experts)):
expert_layer = self.experts[expert_idx]
idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0))
# Index the correct hidden states and compute the expert hidden state for
# the current expert. We need to make sure to multiply the output hidden
# states by `routing_weights` on the corresponding tokens (top-1 and top-2)
current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)
expert_output = expert_layer(current_state)
current_hidden_states = expert_output * routing_weights[top_x, idx, None]
# However `index_add_` only support torch tensors for indexing so we'll use
# the `top_x` tensor here.
final_hidden_states.index_add_(
0, top_x, current_hidden_states.to(hidden_states.dtype)
)

final_hidden_states = final_hidden_states.reshape(
batch_size, sequence_length, hidden_dim
)
return final_hidden_states, router_logits


def replace(config: Qwen3MoeConfig, module: OriginalQwen3MoeSparseMoeBlock):
return Qwen3MoeSparseMoeBlock(config=config, original=module)
9 changes: 8 additions & 1 deletion src/llmcompressor/pipelines/basic/pipeline.py
@@ -1,3 +1,4 @@
import contextlib
from typing import TYPE_CHECKING, Union

import torch
@@ -6,6 +7,7 @@
from torch.utils.data.dataloader import DataLoader

from llmcompressor.core import LifecycleCallbacks
from llmcompressor.modeling.prepare import moe_calibration_context
from llmcompressor.modifiers.utils.pytorch_helpers import apply_pad_mask_to_batch
from llmcompressor.pipelines.registry import CalibrationPipeline
from llmcompressor.pytorch.utils.helpers import tensors_to_device
@@ -42,7 +44,12 @@ def __call__(

LifecycleCallbacks.calibration_epoch_start()

with calibration_forward_context(model):
with contextlib.ExitStack() as stack:
stack.enter_context(calibration_forward_context(model))

if dataset_args is not None and dataset_args.calibrate_moe_context:
moe_calibration_context(model, stack)

for batch in tqdm.tqdm(dataloader, desc="Calibrating"):
batch = apply_pad_mask_to_batch(batch)
batch = tensors_to_device(batch, model_device)