[Calibration] Add MoE Calibration Context #1596
base: main
File: new example script (Qwen3-30B-A3B quantized to NVFP4 with the MoE calibration context)

@@ -0,0 +1,86 @@
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer

from llmcompressor import oneshot
from llmcompressor.modifiers.quantization import QuantizationModifier
from llmcompressor.utils import dispatch_for_generation

MODEL_ID = "Qwen/Qwen3-30B-A3B"

# Load model.
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype="auto")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

DATASET_ID = "HuggingFaceH4/ultrachat_200k"
DATASET_SPLIT = "train_sft"

# Select number of samples.
NUM_CALIBRATION_SAMPLES = 200
MAX_SEQUENCE_LENGTH = 2048

# Load dataset and preprocess.
ds = load_dataset(DATASET_ID, split=f"{DATASET_SPLIT}[:{NUM_CALIBRATION_SAMPLES}]")
ds = ds.shuffle(seed=42)


def preprocess(example):
    return {
        "text": tokenizer.apply_chat_template(
            example["messages"],
            tokenize=False,
        )
    }


ds = ds.map(preprocess)


# Tokenize inputs.
def tokenize(sample):
    return tokenizer(
        sample["text"],
        padding=False,
        max_length=MAX_SEQUENCE_LENGTH,
        truncation=True,
        add_special_tokens=False,
    )


ds = ds.map(tokenize, remove_columns=ds.column_names)

# Configure the quantization algorithm and scheme.
# In this case, we:
#   * quantize the weights to fp4 with group size 16 via PTQ
#   * calibrate a global_scale for activations, which will be used to
#     quantize activations to fp4 on the fly
recipe = QuantizationModifier(
    targets="Linear", scheme="NVFP4", ignore=["lm_head", "re:.*mlp.gate$"]
)

# Apply quantization.
# We set `calibrate_moe_context` to True to update all `Qwen3MoeSparseMoeBlock`
# modules during calibration.
oneshot(
    model=model,
    dataset=ds,
    recipe=recipe,
    max_seq_length=MAX_SEQUENCE_LENGTH,
    num_calibration_samples=NUM_CALIBRATION_SAMPLES,
    calibrate_moe_context=True,
)


print("\n\n")
print("========== SAMPLE GENERATION ==============")
dispatch_for_generation(model)
input_ids = tokenizer("Hello my name is", return_tensors="pt").input_ids.to("cuda")
output = model.generate(input_ids, max_new_tokens=100)
print(tokenizer.decode(output[0]))
print("==========================================\n\n")


# Save to disk in compressed-tensors format.
SAVE_DIR = MODEL_ID.rstrip("/").split("/")[-1] + "-NVFP4"
model.save_pretrained(SAVE_DIR, save_compressed=True)
tokenizer.save_pretrained(SAVE_DIR)
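
Not part of this PR, but for completeness: the saved compressed-tensors checkpoint is typically consumed by vLLM. A minimal sketch, assuming a vLLM build with compressed-tensors/NVFP4 support for this model:

    from vllm import LLM, SamplingParams

    llm = LLM(model=SAVE_DIR)  # SAVE_DIR from the example above
    out = llm.generate("Hello my name is", SamplingParams(max_tokens=64))
    print(out[0].outputs[0].text)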
File: dataset arguments (class DatasetArguments)

@@ -117,6 +117,16 @@ class DatasetArguments(CustomDatasetArguments):
        default=512,
        metadata={"help": "Number of samples to use for one-shot calibration"},
    )
    calibrate_moe_context: bool = field(
        default=False,
        metadata={
            "help": "If during calibration, the MoE context should be enabled "
            "for the given model. This usually involves updating all MoE modules "
            "in the model for the duration of calibration. See moe_context under "
            "modeling/prepare.py for a list of supported MoEs and their updated "
            "module definitions"
        },
    )
    shuffle_calibration_samples: Optional[bool] = field(
        default=True,
        metadata={

Review thread on `calibrate_moe_context`:

Comment: Shouldn't this basically always be on? Is there ever a case where a user shouldn't use this?

Reply: When you want to call `prepare_for_calibration` instead.

Comment: Hm, but there's no conflict between

    prepare_for_calibration(my_moe_model)
    oneshot(my_moe_model, calibrate_moe_context=False)

Would it be better to just remove it and always run with the context enabled?

Reply: We require … But yes, there is no conflict between the two. I think we can technically run deepseek with both, but I haven't tested it with the context.
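
For readers following the thread above, the two mechanisms are independent and can be combined. A rough sketch; `ds` and `recipe` are assumed to come from the example script above, and the import path for the permanent helper is taken from modeling/prepare.py below (treat the exact location as an assumption):

    from llmcompressor import oneshot
    from llmcompressor.modeling.prepare import replace_modules_for_calibration

    # Permanent swaps (e.g. DeepseekV3MoE, Llama4TextMoe) are applied up front
    # and persist on the model.
    model = replace_modules_for_calibration(model)

    # The Qwen3 MoE swap is only applied inside the calibration context; with
    # the flag left False, that temporary step is simply skipped.
    oneshot(
        model=model,
        dataset=ds,
        recipe=recipe,
        calibrate_moe_context=False,
    )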
File: llmcompressor/modeling/prepare.py

@@ -3,20 +3,53 @@

from llmcompressor.modeling.deepseek_v3 import replace as replace_deepseekv3
from llmcompressor.modeling.llama4 import replace as replace_llama4
from llmcompressor.modeling.qwen3_moe import replace as replace_Qwen3MoE
from llmcompressor.utils.helpers import patch_attr

-__all__ = ["prepare_for_calibration"]
+__all__ = ["replace_modules_for_calibration"]

# ---------------------- module replacements; permanent -------------------------
replacements = {
    "DeepseekV3MoE": replace_deepseekv3,
    "Llama4TextMoe": replace_llama4,
}


-def prepare_for_calibration(model: PreTrainedModel) -> PreTrainedModel:
+def replace_modules_for_calibration(model: PreTrainedModel) -> PreTrainedModel:
    for name, module in model.named_modules():
        cls_name = module.__class__.__name__
        if cls_name in replacements:
            new_module = replacements[cls_name](config=model.config, module=module)
            replace_module(model, name, new_module)

    return model


# ------------------- module replacements; during calibration --------------------


def update_qwen3_moe(model, stack):
    for module in model.modules():
        cls_name = module.__class__.__name__
        if cls_name == "Qwen3MoeDecoderLayer":
            # Optionally update the model.config to pass in other arguments
            stack.enter_context(
                patch_attr(
                    module,
                    "mlp",
                    replace_Qwen3MoE(config=model.config, module=module.mlp),
                )
            )


moe_context = {
    "Qwen3MoeForCausalLM": update_qwen3_moe,
}


def moe_calibration_context(model: PreTrainedModel, stack):
    # Temporarily updates the MoE modules within the context.
    # Once the context exits, parameter updates persist.
    cls_name = model.__class__.__name__
    if cls_name in moe_context:
        moe_context.get(cls_name)(model, stack)

Review thread on the `Qwen3MoeDecoderLayer` match in `update_qwen3_moe`:

Comment: Could you use something like this pattern for matching? This way things don't break if the parent's structure changes, and we can also share matching logic between replacements.

Reply: I think if we want to use patch_attr in order to follow the other context set-up functionality, we need both the parent and the child, so it would still require setting "mlp". I think replace_modules finds the parent for you when replacing the module. We could expand patch_attr, I guess, to follow that potentially.

Review thread on `moe_calibration_context`:

Comment: Can you show what it's like to pass additional calibration options (…)?

Reply: I left a small comment but see other comment below.
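
To make the lifecycle concrete, here is a minimal sketch of how a caller could enter the context around calibration. The ExitStack wiring mirrors the `stack` argument above; the dataloader loop is purely illustrative and is not the actual llm-compressor pipeline code:

    import contextlib

    from llmcompressor.modeling.prepare import moe_calibration_context

    with contextlib.ExitStack() as stack:
        # Swaps each Qwen3MoeSparseMoeBlock for its calibration-friendly version.
        moe_calibration_context(model, stack)
        for batch in calibration_dataloader:  # hypothetical dataloader
            model(**batch)
    # On exit, patch_attr restores the original `mlp` modules; any observer
    # statistics or scales recorded on the expert Linears stay on the model.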
File: llmcompressor/modeling/qwen3_moe.py

@@ -0,0 +1,87 @@
# coding=utf-8
# Copyright 2025 The Qwen team, Alibaba Group and the HuggingFace Inc. team.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
from transformers.models import Qwen3MoeConfig
from transformers.models.qwen3_moe.modeling_qwen3_moe import (
    Qwen3MoeSparseMoeBlock as OriginalQwen3MoeSparseMoeBlock,
)


class Qwen3MoeSparseMoeBlock(torch.nn.Module):
    def __init__(
        self, config: Qwen3MoeConfig, original: OriginalQwen3MoeSparseMoeBlock
    ):
        super().__init__()
        self.num_experts = config.num_experts
        self.top_k = config.top_k
        self.norm_topk_prob = config.norm_topk_prob

        # gating
        self.gate = original.gate
        self.experts = original.experts

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        batch_size, sequence_length, hidden_dim = hidden_states.shape
        hidden_states = hidden_states.view(-1, hidden_dim)
        # router_logits: (batch * sequence_length, n_experts)
        router_logits = self.gate(hidden_states)

        routing_weights = torch.nn.functional.softmax(
            router_logits, dim=1, dtype=torch.float
        )
        routing_weights, selected_experts = torch.topk(
            routing_weights, self.top_k, dim=-1
        )
        if self.norm_topk_prob:  # only diff with mixtral sparse moe block!
            routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
        # we cast back to the input dtype
        routing_weights = routing_weights.to(hidden_states.dtype)
        final_hidden_states = torch.zeros(
            (batch_size * sequence_length, hidden_dim),
            dtype=hidden_states.dtype,
            device=hidden_states.device,
        )

        # One-hot encode the selected experts to create an expert mask;
        # this will be used to easily index which expert is going to be solicited
        expert_mask = torch.nn.functional.one_hot(
            selected_experts, num_classes=self.num_experts
        ).permute(2, 1, 0)

        for expert_idx in range(len(self.experts)):
            expert_layer = self.experts[expert_idx]
            idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0))
            # Index the correct hidden states and compute the expert hidden state for
            # the current expert. We need to make sure to multiply the output hidden
            # states by `routing_weights` on the corresponding tokens (top-1 and top-2)
            current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)
            expert_output = expert_layer(current_state)
            current_hidden_states = expert_output * routing_weights[top_x, idx, None]
            # However `index_add_` only supports torch tensors for indexing, so we'll
            # use the `top_x` tensor here.
            final_hidden_states.index_add_(
                0, top_x, current_hidden_states.to(hidden_states.dtype)
            )

        final_hidden_states = final_hidden_states.reshape(
            batch_size, sequence_length, hidden_dim
        )
        return final_hidden_states, router_logits


def replace(config: Qwen3MoeConfig, module: OriginalQwen3MoeSparseMoeBlock):
    return Qwen3MoeSparseMoeBlock(config=config, original=module)

Review thread on the file header:

Comment: May want to add the HF copyright and a note about the amendments.

Reply: I added the HF copyright - do you have a reference for what type of note should be made about the amendments?
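
As a sanity check, the block above can be exercised standalone. A minimal sketch, assuming `replace` is importable from llmcompressor.modeling.qwen3_moe; the tiny config values are made up for illustration and are not the real Qwen3-30B-A3B settings:

    import torch
    from transformers.models import Qwen3MoeConfig
    from transformers.models.qwen3_moe.modeling_qwen3_moe import (
        Qwen3MoeSparseMoeBlock as HFQwen3MoeSparseMoeBlock,
    )

    from llmcompressor.modeling.qwen3_moe import replace

    config = Qwen3MoeConfig(
        hidden_size=32,
        moe_intermediate_size=64,
        num_experts=4,
        num_experts_per_tok=2,
        norm_topk_prob=True,
    )
    # The calibration block reads `config.top_k`; set it explicitly in case the
    # installed Qwen3MoeConfig does not define that attribute.
    config.top_k = config.num_experts_per_tok

    original = HFQwen3MoeSparseMoeBlock(config)
    calibration_block = replace(config=config, module=original)

    hidden = torch.randn(2, 7, config.hidden_size)
    out, router_logits = calibration_block(hidden)
    assert out.shape == (2, 7, config.hidden_size)
    assert router_logits.shape == (2 * 7, config.num_experts)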