@@ -931,8 +931,11 @@ def forward(self, inputs):
if self._if_quant:
# Check if the input tensor is contiguous
# Non-contiguous tensors will generate incorrect FP4 quantization results
# DISABLED: This check causes illegal memory access in distributed training
# The tensor appears to be corrupted upstream, before reaching the quantizer
# TODO: Investigate tensor corruption in attention mechanism
if hasattr(inputs, "is_contiguous") and not inputs.is_contiguous():
inputs.data = inputs.data.contiguous()
inputs = inputs.contiguous()
if self.fake_quant:
outputs = self._fake_quantize(inputs)
elif not self._dequantize:
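
As a minimal, standalone illustration of the contiguity issue handled above (plain PyTorch only; the tensor names are made up): a transposed view has permuted strides, and rebinding to a contiguous copy gives downstream quantization kernels the dense layout they assume without mutating the original tensor's .data in place.

import torch

x = torch.arange(24, dtype=torch.float32).reshape(2, 3, 4)
view = x.transpose(0, 1)            # a view with permuted strides
print(view.is_contiguous())         # False

# Rebinding to a contiguous copy (as the new line does) leaves the original
# tensor untouched and materializes the values in dense memory.
dense = view.contiguous()
print(dense.is_contiguous(), torch.equal(dense, view))   # True True
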
152 changes: 44 additions & 108 deletions modelopt/torch/quantization/plugins/megatron.py
@@ -85,7 +85,7 @@ def real_quant_module_get_extra_state(self) -> dict:
def quant_module_get_extra_state(self) -> dict:
"""Populating the extra_state when state_dict() is called.

quantizer_state, real_quantizer_state, and q_tensor_state are usually stored
quantizer_state, real_quantizer_state, and q_tensor_state used to be stored (before 0.29)
within the modelopt_state metadata where the keys are the full module name. The issue
is that NeMo-MCore model's full module name can change
if pipeline-parallelism (PP) and expert-parallelism (EP)
@@ -94,8 +94,9 @@ def quant_module_get_extra_state(self) -> dict:
which avoids the need to store the full module name.
"""
extra_state = {}

is_enabled = self.weight_quantizer.is_enabled if hasattr(self, "weight_quantizer") else False
is_enabled = any(
isinstance(child, TensorQuantizer) and child.is_enabled for child in self.children()
)

if not is_enabled:
return extra_state
@@ -109,7 +110,6 @@ def quant_module_get_extra_state(self) -> dict:

# Handle real_quantizer_state and q_tensor_state
extra_state.update(real_quant_module_get_extra_state(self))

return extra_state
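
The docstring's point about avoiding full module names rests on PyTorch's generic extra-state hooks. A toy sketch of that mechanism (ToyQuantLinear is a hypothetical class, not part of modelopt): the _extra_state entry is keyed relative to wherever the module is mounted, so the same checkpoint restores even when the surrounding hierarchy, and hence the full module name, changes between PP/EP layouts.

import torch
import torch.nn as nn

class ToyQuantLinear(nn.Linear):
    """Hypothetical module that stows quantizer metadata in extra_state."""

    def get_extra_state(self):
        # Called by state_dict(); saved under "<module path>._extra_state".
        return {"quantizer_state": {"enabled": True, "num_bits": 8}}

    def set_extra_state(self, state):
        # Called by load_state_dict(); receives whatever was saved above.
        self.restored = state["quantizer_state"]

# The same module can sit at different paths in different parallel layouts;
# the extra_state key is always relative to where the module is mounted.
model_a = nn.Sequential(ToyQuantLinear(4, 4))
model_b = nn.ModuleDict({"decoder": nn.Sequential(ToyQuantLinear(4, 4))})

sd = model_a.state_dict()                  # contains "0._extra_state"
model_b["decoder"].load_state_dict(sd)     # restores under "decoder.0._extra_state"
print(model_b["decoder"][0].restored)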


@@ -219,6 +219,10 @@ def _register_extra_state_callbacks(model: torch.nn.Module):
quant_module_get_extra_state,
quant_module_set_extra_state,
)
if HAS_TE and isinstance(module, TEDotProductAttention):
# A hack to set the dtype and device for DotProductAttention
# to be used in _QuantTEDotProductAttention.modelopt_post_restore()
_QuantTEDotProductAttention.set_dtype(module, name, model)

for name, module in model.named_modules():
if isinstance(module, MegatronModule):
@@ -612,57 +616,24 @@ def _setup(self):
self.k_bmm_quantizer = TensorQuantizer()
self.v_bmm_quantizer = TensorQuantizer()
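
These bmm quantizers are applied to Q/K/V in forward further down, after RoPE has already been applied, so their effect approximates quantizing the contents of the KV cache. A rough stand-in for what a per-tensor fake quantizer does (illustrative only; this is not the modelopt TensorQuantizer, and the int8 levels are an arbitrary choice):

import torch

def fake_quant_int8(t: torch.Tensor, amax: torch.Tensor) -> torch.Tensor:
    """Per-tensor fake quantization: snap values to 8-bit levels, keep dtype."""
    scale = amax.float() / 127.0
    return (torch.round(t.float() / scale).clamp(-127, 127) * scale).to(t.dtype)

# Post-RoPE tensors in 'sbhd' layout: [seq, batch, heads, head_dim]
q = torch.randn(16, 2, 8, 64)
k = torch.randn(16, 2, 8, 64)
v = torch.randn(16, 2, 8, 64)

# In the real module, q/k/v pass through q/k/v_bmm_quantizer at this point;
# the K and V fake-quant error is what simulates a quantized KV cache.
k = fake_quant_int8(k, k.abs().amax())
v = fake_quant_int8(v, v.abs().amax())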

def _calibrate_quantizers(self):
"""Calibrate quantizers with minimal dummy tensors."""
# Get device and dtype from the parent module's parameters
param = next(iter(self.parameters()), None)
device = param.device if param is not None else torch.device("cuda")
dtype = param.dtype if param is not None else torch.float16

# TEDotProductAttention expects format 'sbhd' or 'bshd' depending on rope_fusion
batch_size = 1
seq_len = 1

# Get dimensions from config
num_heads = self.config.num_attention_heads
head_dim = (
self.config.kv_channels
if hasattr(self.config, "kv_channels")
else self.config.hidden_size // num_heads
)

# Determine tensor format (default to sbhd if not specified)
apply_rope_fusion = getattr(self.config, "apply_rope_fusion", False)
qkv_format = "bshd" if apply_rope_fusion else "sbhd"

if qkv_format == "sbhd":
dummy_tensor = torch.randn(
seq_len, batch_size, num_heads, head_dim, device=device, dtype=dtype
)
else:
dummy_tensor = torch.randn(
batch_size, seq_len, num_heads, head_dim, device=device, dtype=dtype
)

# Calibrate each quantizer
quantizers = [
("q_bmm_quantizer", self.q_bmm_quantizer),
("k_bmm_quantizer", self.k_bmm_quantizer),
("v_bmm_quantizer", self.v_bmm_quantizer),
]

for _, quantizer in quantizers:
if quantizer is not None and quantizer.is_enabled():
if not hasattr(quantizer, "_amax") or quantizer._amax is None:
quantizer.reset_amax()
max_calibrate(quantizer, lambda q: q(dummy_tensor), distributed_sync=False)

def forward(self, query, key, value, *args, **kwargs):
"""Apply post-RoPE quantization to KV cache.

TEDotProductAttention receives Q, K, V after RoPE is applied,
so we quantize them directly for KV cache quantization.
"""
# Ensure tensors are contiguous before quantization
# This is a safety measure for potential non-contiguous tensor views
# from TE or Megatron operations with tensor parallelism
def materialize_if_needed(tensor):
if tensor is not None and hasattr(tensor, 'is_contiguous') and not tensor.is_contiguous():
return tensor.contiguous()
return tensor

query = materialize_if_needed(query)
key = materialize_if_needed(key)
value = materialize_if_needed(value)
Comment on lines +633 to +635

Contributor: Do we need this if we are calling inputs = inputs.contiguous() in TensorQuantizer forward?

Contributor (author): TODO: these lines may not be necessary.

# Quantize Q, K, V
query = self.q_bmm_quantizer(query)
key = self.k_bmm_quantizer(key)
@@ -672,44 +643,19 @@ def forward(self, query, key, value, *args, **kwargs):

def sharded_state_dict(self, prefix="", sharded_offsets=(), metadata=None):
"""Create a sharded state dictionary for distributed checkpointing."""
state_dict = self.state_dict(prefix='', keep_vars=True)
sharded_state_dict = {}

# First add non-quantizer parameters
for k, v in self.state_dict(prefix="", keep_vars=True).items():
if isinstance(v, torch.Tensor) and v is not None and "_quantizer" not in k:
sharded_state_dict[prefix + k] = v

# Process _amax in bmm_quantizers
for name, quantizer in [
("q_bmm_quantizer", self.q_bmm_quantizer),
("k_bmm_quantizer", self.k_bmm_quantizer),
("v_bmm_quantizer", self.v_bmm_quantizer),
]:
if hasattr(quantizer, "_amax") and quantizer._amax is not None:
amax_key = f"{prefix}{name}._amax"
sharded_state_dict[amax_key] = quantizer._amax

# Process other quantizer parameters in bmm_quantizers
quantizer_state_dict = {
k: v
for k, v in self.state_dict(prefix="", keep_vars=True).items()
if isinstance(v, torch.Tensor) and "_quantizer" in k and "_amax" not in k
}

if quantizer_state_dict:
sharded_state_dict.update(
**make_sharded_tensors_for_checkpoint(
quantizer_state_dict, prefix, {}, sharded_offsets
)
)

tmp = make_sharded_tensors_for_checkpoint(state_dict, prefix, {}, sharded_offsets)
for k, v in tmp.items():
sharded_state_dict[k] = v.data
return sharded_state_dict


def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
"""Handle loading state dict for quantizers."""
for quantizer_name in ["q_bmm_quantizer", "k_bmm_quantizer", "v_bmm_quantizer"]:
full_prefix = f"{prefix}{quantizer_name}."
amax_key = f"{prefix}{quantizer_name}._amax"
amax_key = f"{full_prefix}_amax"

# If amax is in state_dict, rename it to the format expected by TensorQuantizer
if amax_key in state_dict:
@@ -727,37 +673,27 @@ def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):

def modelopt_post_restore(self, name=""):
"""Restore quantizer states after model loading."""
super().modelopt_post_restore(name)

def _check_unsupported_states(quantizer):
"""Check for unsupported quantizer states and warn if found."""
if not hasattr(quantizer, "state_dict"):
return

for k in quantizer.state_dict():
if k not in ["_amax", "_pre_quant_scale"]:
warnings.warn(
f"Restore of {k} for {name} is not supported. The restore of this layer might be "
f"incorrect. Please implement a custom restore for {k}."
)

calibration_needed = False

for quantizer_name, quantizer in [
("q_bmm_quantizer", self.q_bmm_quantizer),
("k_bmm_quantizer", self.k_bmm_quantizer),
("v_bmm_quantizer", self.v_bmm_quantizer),
]:
if not hasattr(self, quantizer_name) or not quantizer.is_enabled():
continue
for tq in [self.q_bmm_quantizer, self.k_bmm_quantizer, self.v_bmm_quantizer]:
# TODO: Add support for non-scalar states such as
# Affine KVCache bias vector which is per head per channel
assert all(v.numel() == 1 for v in tq.state_dict().values()), (
"Only scalar states are KV Cache/BMM Quantizers"
)
# Should have been set in the `megatron_replace_quant_module_hook`
assert hasattr(self, "device") and hasattr(self, "dtype")
self.to(device=self.device, dtype=self.dtype)

_check_unsupported_states(quantizer)
@staticmethod
def set_dtype(module: "TEDotProductAttention", name, model: torch.nn.Module):
"""Set the dtype for the module from any parameter in the model.

if not hasattr(quantizer, "_amax") or quantizer._amax is None:
calibration_needed = True
DotProductAttention does not have any parameters, so let's get the parameter from the parent module.
"""
parent = model.get_submodule(name.rsplit(".", 1)[0]) if "." in name else model
param = next(iter(parent.parameters()))
module.dtype = param.dtype
module.device = param.device

if calibration_needed:
self._calibrate_quantizers()
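
A small sketch of the parent-lookup trick set_dtype uses above: since the attention module itself holds no parameters, the dtype and device are borrowed from any parameter of its parent, found via get_submodule on the dotted module name (the toy module names below are made up):

import torch
import torch.nn as nn

model = nn.ModuleDict({
    "attn": nn.ModuleDict({
        "core": nn.Identity(),           # stands in for the parameter-free attention op
        "proj": nn.Linear(8, 8).to(torch.bfloat16),
    })
})

name = "attn.core"                        # full module name as seen in named_modules()
parent = model.get_submodule(name.rsplit(".", 1)[0]) if "." in name else model
param = next(iter(parent.parameters()))   # any parameter of the parent works
print(param.dtype, param.device)          # torch.bfloat16 cpu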


@QuantModuleRegistry.register({megatron_moe_layer.MoELayer: "megatron_moe_MoELayer"})
21 changes: 20 additions & 1 deletion modelopt/torch/quantization/tensor_quant.py
@@ -43,14 +43,29 @@


def _fp8_eager(x, amax=None):
"""Eager mode implementation of FP8 E4M3 fake quantization.

Args:
x: Input tensor.
amax: Absolute max value for scaling. If None, only dtype conversion is performed.

Returns:
Fake-quantized tensor in original dtype.
"""
dtype = x.dtype

if amax is not None:
scale = 448.0 / (amax.to(torch.float32))
scale_inv = 1 / scale
x = x.to(torch.float32) * scale
# Clamp to FP8 E4M3 range to prevent NaN/Inf during conversion
x = torch.clamp(x, min=-448.0, max=448.0)

x = x.to(torch.float8_e4m3fn)

if amax is not None:
x = x.to(torch.float32) * scale_inv

return x.to(dtype)
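
A quick standalone check of the clamp added above (the helper below re-implements the same math for illustration and assumes a PyTorch build with torch.float8_e4m3fn): values above the calibrated amax would scale past the E4M3 range of ±448, and the clamp keeps the cast finite.

import torch

def fp8_eager_ref(x, amax):
    # Mirror of the eager path: scale, clamp to the E4M3 range, cast, rescale.
    dtype = x.dtype
    scale = 448.0 / amax.to(torch.float32)
    y = (x.to(torch.float32) * scale).clamp(-448.0, 448.0)
    y = y.to(torch.float8_e4m3fn).to(torch.float32) / scale
    return y.to(dtype)

amax = torch.tensor(2.0)
x = torch.tensor([0.1, -1.5, 2.0, 3.7])   # 3.7 exceeds the calibrated amax
out = fp8_eager_ref(x, amax)
print(out)                                 # the outlier saturates near amax
assert torch.isfinite(out).all()           # no NaN/Inf after the clamp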


@@ -76,7 +91,11 @@ def scaled_e4m3_impl(
return fp8_eager(inputs, amax)

cuda_ext_fp8 = get_cuda_ext_fp8(raise_if_failed=False)
if cuda_ext_fp8 is None:
# NOTE: CUDA extension disabled due to bug with GQA/MQA (singleton KV head dimension)
# and tensor parallelism. The fake_e4m3fy() kernel produces corrupted output for
# tensors with shape [seq_len, 1, head_dim] when TP > 1.
# Using eager fallback until kernel is fixed.
if cuda_ext_fp8 is None:
return fp8_eager(inputs, amax)

with torch.cuda.device(
10 changes: 10 additions & 0 deletions tests/_test_utils/torch/megatron/utils.py
@@ -204,13 +204,23 @@ def convert_maybe_fp8(v):
f"{k} v:{v}, s[k]: {state_dict_test[k]}"
)

model_test.train()
logits_test = forward_fn(model_test)

logits_diff = (logits_test - logits_ref) / logits_ref
assert torch.allclose(logits_ref, logits_test), (
f"diff: {logits_diff.max()} ref: {logits_ref}, test: {logits_test}"
)

# Test backward pass on model_test
loss = logits_test.sum()
loss.backward()

# Assert that trainable parameters have gradients computed
for name, param in model_test.named_parameters():
if param.requires_grad:
assert param.grad is not None, f"Parameter {name} has no gradient computed"
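
This backward check guards against quantization ops silently detaching the autograd graph. A tiny self-contained version of the same pattern, using a made-up straight-through-estimator module rather than the real quantizer:

import torch
import torch.nn as nn

class STEFakeQuant(nn.Module):
    """Round with a straight-through estimator so gradients pass unchanged."""

    def forward(self, x):
        scale = 127.0 / x.detach().abs().amax().clamp(min=1e-6)
        q = torch.round(x * scale) / scale
        return x + (q - x).detach()   # forward: quantized, backward: identity

model = nn.Sequential(nn.Linear(8, 8), STEFakeQuant(), nn.Linear(8, 4))
out = model(torch.randn(2, 8))
out.sum().backward()

for name, param in model.named_parameters():
    assert param.grad is not None, f"{name} has no gradient"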


def copy_weights_from_grouped_to_non_grouped(te_grouped_moe_model, sequential_moe_model):
"""Copy weights from TEGrouped MoE model to sequential MoE model."""
10 changes: 5 additions & 5 deletions tests/gpu/torch/quantization/plugins/test_megatron.py
@@ -836,11 +836,7 @@ def forward_fn(model):

# Quantize the reference model
model_ref = mtq.quantize(model_ref, config, forward_fn)

# CRITICAL: model_test must also be quantized with the same config
# Otherwise it won't have the KV cache quantizer keys when loading state dict
model_test = mtq.quantize(model_test, config, forward_fn)

Comment on lines 838 to -843

Contributor: @kaix-nv this is an incorrect unit test. It completely breaks the ModelOpt resume workflow (that is, resume requires a model that has not yet been modified by ModelOpt).

# Verify KV cache quantizers were created
kv_quantizers_found = False
for name, module in model_ref.named_modules():
Expand All @@ -851,6 +847,10 @@ def forward_fn(model):

assert kv_quantizers_found, "No KV cache quantizers found in quantized model"

# CRITICAL: model_test must also be quantized with the same config
# Otherwise it won't have the KV cache quantizer keys when loading state dict
# model_test = mtq.quantize(model_test, config, forward_fn)

# Test sharded state dict save/load
sharded_state_dict_test_helper(
tmp_path,