
Commit 55b9106

Add changes for 0.33.1 release
Signed-off-by: Keval Morabia <[email protected]>
1 parent 7a27f2a commit 55b9106

10 files changed: 64 additions & 28 deletions


CHANGELOG.rst

Lines changed: 3 additions & 2 deletions
@@ -1,7 +1,7 @@
 Model Optimizer Changelog (Linux)
 =================================

-0.33 (2025-07-xx)
+0.33 (2025-07-14)
 ^^^^^^^^^^^^^^^^^

 **Backward Breaking Changes**
@@ -20,7 +20,8 @@ Model Optimizer Changelog (Linux)
 - Add per node calibration support in ONNX quantization.
 - ModelOpt now supports quantization of tensor-parallel sharded Huggingface transformer models. This requires ``transformers>=4.52.0``.
 - Support quantization of FSDP2 wrapped models and add FSDP2 support in the ``llm_qat`` example.
-- Add NeMo 2 Simplified Flow examples for quantization aware training/distillation (QAT/QAD), speculative decoding, pruning & distilllation.
+- Add NeMo 2 Simplified Flow examples for quantization aware training/distillation (QAT/QAD), speculative decoding, pruning & distillation.
+- Fix a Qwen3 MOE model export issue.

 0.31 (2025-06-04)
 ^^^^^^^^^^^^^^^^^

modelopt/torch/_deploy/utils/torch_onnx.py

Lines changed: 3 additions & 1 deletion
@@ -410,14 +410,16 @@ def get_onnx_bytes_and_metadata(
     )
     with torch.inference_mode(), autocast, quantizer_context:
         if not dynamo_export or Version(torch.__version__) >= Version("2.6"):
+            additional_kwargs = {}
+            if not dynamo_export and Version(torch.__version__) >= Version("2.8"):
+                additional_kwargs["dynamic_axes"] = dynamic_axes
             torch.onnx.export(
                 model,
                 dummy_input,
                 onnx_save_path,
                 input_names=input_names,
                 output_names=output_names,
                 opset_version=onnx_opset,
-                dynamic_axes=dynamic_axes,
                 dynamo=dynamo_export,
             )
         else:  # torch < 2.6 with dynamo export
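
For context, this change builds the export kwargs dynamically so that dynamic_axes is only supplied where the exporter expects it. Below is a minimal standalone sketch of the same version-gated pattern, assuming torch >= 2.6 (for the dynamo flag) and using a throwaway linear model and file name rather than ModelOpt's real inputs. The hunk does not show where additional_kwargs is consumed; here it is simply splatted into the export call as an assumption.

import torch
from packaging.version import Version

# Hypothetical stand-ins for the model and dummy input used by ModelOpt.
model = torch.nn.Linear(16, 8)
dummy_input = torch.randn(2, 16)

additional_kwargs = {}
# Mirror the diff: only attach dynamic_axes on torch >= 2.8 for the non-dynamo path.
if Version(torch.__version__) >= Version("2.8"):
    additional_kwargs["dynamic_axes"] = {"input": {0: "batch"}}

torch.onnx.export(
    model,
    (dummy_input,),
    "tiny_linear.onnx",
    input_names=["input"],
    output_names=["output"],
    opset_version=17,
    dynamo=False,
    **additional_kwargs,
)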

modelopt/torch/export/unified_export_hf.py

Lines changed: 30 additions & 0 deletions
@@ -17,6 +17,7 @@

 import collections.abc
 import json
+import re
 import tempfile
 import warnings
 from collections import defaultdict
@@ -97,7 +98,12 @@ def _output_hook(module, input, output):
     handles = []
     model_type = type(model).__name__.lower()

+    fused_linears = {}
+    module_names = set()
+
     for name, module in model.named_modules():
+        module_names.add(name)
+
         # For MoE models update pre_quant_scale to average pre_quant_scale amongst experts
         if is_moe(module) and ("awq" in quantization_format):
             # update_experts_avg_prequant_scale(module)
@@ -151,6 +157,7 @@ def _output_hook(module, input, output):
         ]:
             # Fuse modules that have the same input
             preprocess_linear_fusion(modules)
+            fused_linears[modules[0].name] = [module.name for module in modules]

     # Fuse layernorms
     if (
@@ -161,6 +168,29 @@ def _output_hook(module, input, output):
         # Pre quant scale of modules is already updated to avg_pre_quant_scale
         fuse_prequant_layernorm(output_to_layernorm[tensor], modules)

+    # The dummy forward may not be able to activate all the experts.
+    # Process experts by naming rules like experts.0, experts.1, etc.
+    for name, modules_fused in fused_linears.items():
+        if re.search(r"experts?\.\d+", name):
+            expert_id = 0
+            while True:
+                new_expert_name = re.sub(r"(experts?\.)\d+", rf"\g<1>{expert_id}", name, count=1)
+                if new_expert_name in fused_linears:
+                    expert_id += 1
+                    continue
+                if new_expert_name not in module_names:
+                    break
+
+                new_expert_modules = []
+                for name_fused in modules_fused:
+                    new_expert_name = re.sub(r"(experts?\.)\d+", rf"\g<1>{expert_id}", name_fused)
+                    assert new_expert_name in module_names
+                    new_expert_modules.append(model.get_submodule(new_expert_name))
+
+                preprocess_linear_fusion(new_expert_modules)
+
+                expert_id += 1
+

 def _export_hf_checkpoint(
     model: nn.Module, dtype: torch.dtype | None = None
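
The new loop walks expert-numbered module names so that experts the dummy forward never activated still receive the same linear fusion as the expert that was traced. The following self-contained sketch shows only the name-iteration logic; the module names are made up, and preprocess_linear_fusion is replaced with a print, since only the re handling is being illustrated.

import re

# Hypothetical module names: the dummy forward only routed tokens to expert 0,
# so only its projection was fused; experts 1 and 2 still need the same treatment.
module_names = {
    "layers.0.mlp.experts.0.gate_proj",
    "layers.0.mlp.experts.1.gate_proj",
    "layers.0.mlp.experts.2.gate_proj",
}
fused_linears = {"layers.0.mlp.experts.0.gate_proj"}

name = "layers.0.mlp.experts.0.gate_proj"
expert_id = 0
while True:
    # Rewrite the expert index in the fused name, e.g. "experts.0" -> "experts.2".
    candidate = re.sub(r"(experts?\.)\d+", rf"\g<1>{expert_id}", name, count=1)
    if candidate in fused_linears:
        expert_id += 1
        continue
    if candidate not in module_names:
        break
    print("would fuse:", candidate)  # the real code calls preprocess_linear_fusion here
    expert_id += 1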

modelopt/torch/nas/modules/utils.py

Lines changed: 2 additions & 2 deletions
@@ -40,9 +40,9 @@ def get_sliced_tensor_by_slices(
     tensor_sliced = tensor
     for i, _ in enumerate(slices):
         if sum(not isinstance(s, slice) for s in slices) < 2:
-            tensor_sliced = tensor_sliced[slices]
+            tensor_sliced = tensor_sliced[tuple(slices)]
             break
-        tensor_sliced = tensor_sliced[slices[: i + 1]]
+        tensor_sliced = tensor_sliced[tuple(slices[: i + 1])]
         slices[i] = slice(None)  # replace with a vanilla slice ("[:]") for next slicing iteration

     # return sliced, contiguous tensor
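
The tuple() wrapping matters because PyTorch interprets a tuple index as one slice per dimension, whereas a bare Python list is treated as advanced indexing and is rejected or deprecated on recent releases. A tiny sketch, independent of ModelOpt, of the behavior the fix relies on:

import torch

t = torch.arange(24).reshape(4, 6)
slices = [slice(0, 2), slice(1, 4)]

# A tuple applies one slice per dimension, giving the expected 2x3 view.
print(t[tuple(slices)].shape)  # torch.Size([2, 3])
# Indexing with the bare list (t[slices]) is not per-dimension slicing and fails
# or warns on recent PyTorch releases, hence the tuple() conversion in the fix.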

modelopt/torch/quantization/qtensor/base_qtensor.py

Lines changed: 1 addition & 1 deletion
@@ -173,7 +173,7 @@ def pack_real_quantize_weight(module, force_quantize: bool = False):

     with SequentialQuantizer.convert_to_single_quantizer(module), torch.no_grad():
         for _, m in module.named_modules():
-            if hasattr(m, "weight") and m.weight.is_meta:
+            if hasattr(m, "weight") and (m.weight is None or m.weight.is_meta):
                 continue
             if (
                 hasattr(m, "weight_quantizer")
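
The added m.weight is None check covers modules that register a weight attribute without materializing it. A quick sketch of the guard against stock modules, nn.LayerNorm with elementwise_affine=False being one such case:

import torch.nn as nn

for m in [nn.Linear(8, 8), nn.LayerNorm(8, elementwise_affine=False)]:
    # Without the `is None` check, LayerNorm's weight (registered as None here)
    # would make `m.weight.is_meta` raise an AttributeError.
    if hasattr(m, "weight") and (m.weight is None or m.weight.is_meta):
        print(type(m).__name__, "skipped: no materialized weight")
    else:
        print(type(m).__name__, "weight shape:", tuple(m.weight.shape))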

modelopt/torch/quantization/tensor_quant.py

Lines changed: 8 additions & 1 deletion
@@ -25,7 +25,6 @@
 import modelopt.torch.quantization.triton as triton_kernel

 from .config import QuantizerAttributeConfig
-from .export_onnx import export_fp4, export_fp8, export_int8, export_mxfp8
 from .extensions import get_cuda_ext, get_cuda_ext_fp8, get_cuda_ext_mx

 mx_format_map = {
@@ -325,6 +324,8 @@ def symbolic(
         trt_high_precision_dtype=None,
     ):
         """ONNX symbolic function."""
+        from .export_onnx import export_int8
+
         return export_int8(
             g, inputs, amax, num_bits, unsigned, narrow_range, trt_high_precision_dtype
         )
@@ -395,6 +396,8 @@ class ScaledE4M3Function(Function):
     @symbolic_helper.parse_args("v", "t", "t", "i", "i", "s")
     def symbolic(g, inputs, amax=None, bias=None, E=4, M=3, trt_high_precision_dtype=None):  # noqa: N803
         """ONNX symbolic function."""
+        from .export_onnx import export_fp8
+
         return export_fp8(g, inputs, amax, trt_high_precision_dtype)

     @staticmethod
@@ -475,6 +478,8 @@ def symbolic(
         onnx_quantizer_type="dynamic",
     ):
         """ONNX symbolic function."""
+        from .export_onnx import export_fp4, export_mxfp8
+
         if num_bits == (2, 1) and scale_bits == (4, 3):
             return export_fp4(
                 g,
@@ -643,6 +648,8 @@ def symbolic(
         trt_high_precision_dtype=None,
     ):
         """ONNX symbolic function."""
+        from .export_onnx import export_int8
+
         return export_int8(
             g, inputs, amax, num_bits, unsigned, narrow_range, trt_high_precision_dtype
         )
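
Moving the export_onnx imports into the symbolic functions is a deferred-import pattern: the ONNX export helpers are only resolved when a symbolic function actually runs, so importing the quantization module stays cheap. A generic sketch of the idea, using json as a stand-in dependency rather than ModelOpt's export helpers:

def serialize(payload):
    """Illustrative only; json stands in for the ONNX export helpers."""
    import json  # resolved on first call, not when this module is imported

    return json.dumps(payload)

print(serialize({"bits": 8}))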

setup.py

Lines changed: 6 additions & 5 deletions
@@ -23,7 +23,7 @@
 # Package configuration ############################################################################
 name = "nvidia-modelopt"
 version = os.environ.get(
-    "SETUPTOOLS_SCM_PRETEND_VERSION", "0.33.0" if platform.system() == "Linux" else "0.27.0"
+    "SETUPTOOLS_SCM_PRETEND_VERSION", "0.33.1" if platform.system() == "Linux" else "0.27.0"
 )
 packages = setuptools.find_namespace_packages(include=["modelopt*"])
 package_dir = {"": "."}
@@ -56,11 +56,13 @@
         "cppimport",
         "cupy-cuda12x; platform_machine != 'aarch64' and platform_system != 'Darwin'",
         "ml_dtypes",  # for bfloat16 conversion
-        "onnx>=1.18.0",
         "onnx-graphsurgeon",
+        "onnx>=1.18.0",
+        "onnxconverter-common",
         "onnxruntime~=1.22.0 ; platform_machine == 'aarch64' or platform_system == 'Darwin'",
         "onnxruntime-gpu~=1.22.0 ; platform_machine != 'aarch64' and platform_system != 'Darwin' and platform_system != 'Windows'",  # noqa: E501
         "onnxruntime-gpu==1.20.0; platform_system == 'Windows'",
+        "onnxscript",  # For test_onnx_dynamo_export unit test
         "onnxsim ; python_version < '3.12' and platform_machine != 'aarch64'",
         "polygraphy>=0.49.22",
     ],
@@ -82,13 +84,12 @@
     # testing
     "dev-test": [
         "coverage",
-        "onnxscript",  # For test_onnx_dynamo_export unit test
         "pytest",
         "pytest-cov",
         "pytest-timeout",
         "timm",
-        "tox",
-        "tox-current-env>=0.0.12",  # Incompatible with tox==4.18.0
+        "tox>4.18",
+        "tox-current-env>=0.0.12",
     ],
     # docs
     "dev-docs": [

tests/_test_utils/torch_model/transformers_models.py

Lines changed: 3 additions & 3 deletions
@@ -27,7 +27,7 @@
     LlamaConfig,
     LlamaForCausalLM,
     T5Config,
-    T5Model,
+    T5ForConditionalGeneration,
     T5Tokenizer,
 )

@@ -50,7 +50,7 @@ def get_tiny_llama(**config_kwargs) -> LlamaForCausalLM:
     return tiny_llama


-def get_tiny_t5(**config_kwargs) -> T5Model:
+def get_tiny_t5(**config_kwargs) -> T5ForConditionalGeneration:
     kwargs = {
         "vocab_size": 32,
         "d_model": 32,
@@ -63,7 +63,7 @@ def get_tiny_t5(**config_kwargs) -> T5Model:
         "decoder_start_token_id": 0,
     }
     kwargs.update(**config_kwargs)
-    t5_model = T5Model(T5Config(**kwargs))
+    t5_model = T5ForConditionalGeneration(T5Config(**kwargs))

     return t5_model
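
Switching the helper to T5ForConditionalGeneration gives the tiny test model an LM head, so it can be driven end to end (for example via generate), which the bare T5Model encoder-decoder stack does not offer. A rough sketch of such a tiny model follows; the config fields beyond vocab_size, d_model, and decoder_start_token_id are filled in only to make the snippet self-contained and may differ from the test helper's values.

import torch
from transformers import T5Config, T5ForConditionalGeneration

config = T5Config(
    vocab_size=32,
    d_model=32,
    d_ff=32,
    d_kv=8,
    num_layers=2,
    num_heads=2,
    decoder_start_token_id=0,
)
model = T5ForConditionalGeneration(config)

input_ids = torch.randint(0, 32, (1, 8))
# generate() works because the LM-head variant is used.
out = model.generate(input_ids, max_new_tokens=4)
print(out.shape)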

tests/unit/torch/quantization/test_quant_rnn.py

Lines changed: 1 addition & 1 deletion
@@ -211,7 +211,7 @@ def test_fake_quant_per_channel(self, original_cls, bidirectional):

         out1 = quant_rnn_object(test_input)[0]
         out2 = rnn_object_original(test_input)[0]
-        assert torch.allclose(out1, out2)
+        assert torch.allclose(out1, out2, atol=1e-5)

     @pytest.mark.parametrize(
         ("original_cls", "bidirectional"),

tests/unit/torch/trace/test_symbol.py

Lines changed: 7 additions & 12 deletions
@@ -19,14 +19,6 @@
 from modelopt.torch.trace import RobustTracer, Symbol, SymMap
 from modelopt.torch.trace.modules.nn import get_conv_sym_info, get_linear_sym_info

-try:
-    import megatron  # noqa: F401
-    import transformer_engine  # noqa: F401
-
-    SKIP = True
-except ImportError:
-    SKIP = False
-

 def test_symbol_cls():
     sym = Symbol(elastic_dims={1, 2}, cl_type=Symbol.CLType.INCOMING)
@@ -117,11 +109,7 @@ def assert_num_symbols():
     assert_num_symbols()


-@pytest.mark.skipif(SKIP, reason="This cpu unit test will fail on GPU with Megatron/TE installed!")
 def test_sym_map_registry():
-    # NOTE: If running with transformer_engine or megatron-core installed, this test will fail.
-    # Ignoring this error for now, as it will only be there if running CPU tests on a GPU machine
-    # with the above packages installed.
     mods_in_registry = {
         nn.Linear,
         nn.BatchNorm1d,
@@ -151,6 +139,13 @@ def test_sym_map_registry():
     except ImportError:
         pass

+    try:
+        from megatron.core.models.gpt import GPTModel
+
+        mods_in_registry.add(GPTModel)
+    except ImportError:
+        pass
+
     not_a_leaf = {nn.Sequential}
     dependent_registry = set()
