[None][feat] Qwen-Image: load pre-quantized ModelOpt NVFP4/FP8 checkpoints#15470
jingyu-ml wants to merge 1 commit into
Enable VisualGen to run Qwen-Image from a statically pre-quantized ModelOpt
checkpoint (NVFP4/FP8), and add the offline example + configs. Previously only
dynamic quantization (BF16 -> NVFP4 at load) worked; pointing --model at a
ModelOpt-exported NVFP4 checkpoint failed during weight loading.
transformer_qwen_image.py:
- Honor the checkpoint's quantization `ignore` list: clear quant_config on the
excluded Linear modules before create_weights() so they build the unquantized
method (ModelOpt stores those layers -- embedders, proj_out, norm_out,
time_text_embed, first/last blocks -- in BF16). get_quant_method() selects the
method purely from module.quant_config.
- Relax the strict weight-key check for FP8/NVFP4 helper buffers that are
derived at load time and never serialized (alpha, inv_input_scale, kv_scales,
inv_kv_scales).
Both changes are backward compatible with the dynamic-quant and BF16 paths.
Add examples/visual_gen/models/qwen_image.py and qwen-image-{fp4,bf16}-1gpu.yaml;
document static-checkpoint support in README and visual-generation.md; add a
unit test for the exclusion logic.
Validated on GB200 (sm_100): a static NVFP4 checkpoint loads (729/729 weights)
and renders a 1328x1328 image on par with BF16; qwen registry tests pass (7/7).
Signed-off-by: Jingyu Xin <jingyux@nvidia.com>
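As a rough illustration of the first loader change described above, the sketch below shows how clearing `quant_config` on excluded modules changes which method gets built. `ToyLinear`, `ToyQuantConfig`, and the glob-style name matching via `fnmatch` are assumptions made for illustration and are not the TensorRT-LLM classes; only the idea that the method is selected purely from `module.quant_config` comes from the commit message.

```python
from dataclasses import dataclass, field
from fnmatch import fnmatch
from typing import Optional


@dataclass
class ToyQuantConfig:
    """Hypothetical stand-in for a quantization config with an ignore list."""
    quant_algo: Optional[str] = "NVFP4"
    exclude_modules: list = field(default_factory=list)

    def is_module_excluded_from_quantization(self, name: str) -> bool:
        # Assumption: ignore entries are glob patterns over module names.
        return any(fnmatch(name, pat) for pat in self.exclude_modules)


class ToyLinear:
    """Hypothetical Linear whose method depends only on quant_config."""

    def __init__(self, quant_config):
        self.quant_config = quant_config
        self.method = None

    def create_weights(self):
        # The method is chosen purely from the module's own quant_config.
        if self.quant_config is None or self.quant_config.quant_algo is None:
            self.method = "unquantized"
        else:
            self.method = self.quant_config.quant_algo.lower()


def clear_quant_config_on_excluded(modules: dict, cfg: ToyQuantConfig) -> None:
    """Drop quant_config on every module named by the ignore list."""
    for name, module in modules.items():
        if cfg.is_module_excluded_from_quantization(name):
            module.quant_config = None


cfg = ToyQuantConfig(exclude_modules=["proj_out", "transformer_blocks.0.*"])
modules = {
    "proj_out": ToyLinear(cfg),
    "transformer_blocks.0.attn.to_q": ToyLinear(cfg),
    "transformer_blocks.5.attn.to_q": ToyLinear(cfg),
}

# Clear first, then build weights, mirroring the ordering the commit describes.
clear_quant_config_on_excluded(modules, cfg)
for m in modules.values():
    m.create_weights()
methods = {name: m.method for name, m in modules.items()}
```

In this toy setup, only the module outside the ignore list ends up with the quantized method; the two excluded modules fall back to the unquantized one.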
📝 Walkthrough
Adds support for loading pre-quantized ModelOpt FP8/NVFP4 checkpoints into …
Changes
Qwen-Image static quantization loading and examples
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~20 minutes
🚥 Pre-merge checks: ✅ 4 passed | ❌ 1 failed (1 warning)
Actionable comments posted: 1
🧹 Nitpick comments (1)
tests/unittest/_torch/visual_gen/test_qwen_image_registry.py (1)
68-109: ⚡ Quick win
Coverage is currently insufficient for the second loader behavior.
This test covers excluded-layer `quant_config` clearing, but it does not exercise the new strict key-check relaxation for derived suffixes (`.alpha`, `.inv_input_scale`, `.kv_scales`, `.inv_kv_scales`) added in `load_weights()`. Please add a focused unit test in `tests/unittest/_torch/visual_gen/test_qwen_image_registry.py` (or a dedicated `test_qwen_image_transformer_load.py`) that verifies:
- missing only derived keys does not raise in static mode, and
- missing a real weight key still does raise.
As per coding guidelines, tests under `tests/**` should provide actionable coverage assessment for changed behavior.
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@tests/unittest/_torch/visual_gen/test_qwen_image_registry.py` around lines 68 - 109, Add a new focused unit test (either in the existing test_qwen_image_registry.py file or in a dedicated test_qwen_image_transformer_load.py file) that validates the strict key-check relaxation in the load_weights() method for derived suffix keys. The test should verify two scenarios: first, loading a checkpoint with missing only derived keys like .alpha, .inv_input_scale, .kv_scales, and .inv_kv_scales in static quantization mode should not raise an error, and second, loading a checkpoint with missing actual weight keys should still raise an error as expected. This ensures the key-check relaxation only applies to the derived suffixes and not to real weight keys.
Source: Coding guidelines
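The two scenarios requested in this nitpick can be sketched with a small stand-alone check. `check_missing_keys` and the `blk.attn.*` key names are hypothetical and only mimic the described behavior; the real logic lives in `load_weights()` and may be structured differently.

```python
# Suffixes named in the PR as derived at load time and never serialized.
DERIVED_SUFFIXES = (".alpha", ".inv_input_scale", ".kv_scales", ".inv_kv_scales")


def check_missing_keys(expected_keys, loaded_keys, static_quant: bool) -> None:
    """Raise KeyError if a required key was not loaded.

    Hypothetical helper: in static-quant mode, keys ending in a derived
    suffix are tolerated because they are computed at load time.
    """
    missing = set(expected_keys) - set(loaded_keys)
    if static_quant:
        # str.endswith accepts a tuple of suffixes.
        missing = {k for k in missing if not k.endswith(DERIVED_SUFFIXES)}
    if missing:
        raise KeyError(f"Missing weights: {sorted(missing)}")


expected = [
    "blk.attn.weight",
    "blk.attn.weight_scale",
    "blk.attn.alpha",
    "blk.attn.inv_input_scale",
]

# Scenario 1: only derived keys are absent, so static mode does not raise.
check_missing_keys(expected, ["blk.attn.weight", "blk.attn.weight_scale"],
                   static_quant=True)

# Scenario 2: a real weight key is absent, so the check still raises.
try:
    check_missing_keys(expected, ["blk.attn.weight"], static_quant=True)
    raised_on_real_key = False
except KeyError:
    raised_on_real_key = True
```

Keeping the tolerated suffixes in one tuple makes it easy for a test to assert that nothing outside that list is ever silently skipped.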
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@tensorrt_llm/_torch/visual_gen/models/qwen_image/transformer_qwen_image.py`:
- Around line 889-909: The method `_clear_quant_config_on_excluded_layers()`
documents that it must run before any `create_weights()` call, but currently
runs in `load_weights()` which is too late. If
`skip_create_weights_in_init=False` allows excluded Linear modules to create
weights during `__init__`, clearing the quant_config later in `load_weights()`
won't switch them back to unquantized since `create_weights()` becomes a no-op.
Move the call to `_clear_quant_config_on_excluded_layers()` to run earlier in
the initialization sequence, such as right after the parent class `__init__`
completes or before weights can be created, to enforce the documented
precondition and ensure excluded layers use the correct unquantized
configuration.
📒 Files selected for processing (7)
- docs/source/models/visual-generation.md
- examples/visual_gen/README.md
- examples/visual_gen/configs/qwen-image-bf16-1gpu.yaml
- examples/visual_gen/configs/qwen-image-fp4-1gpu.yaml
- examples/visual_gen/models/qwen_image.py
- tensorrt_llm/_torch/visual_gen/models/qwen_image/transformer_qwen_image.py
- tests/unittest/_torch/visual_gen/test_qwen_image_registry.py
    def _clear_quant_config_on_excluded_layers(self) -> None:
        """Build the checkpoint's excluded layers as unquantized.

        ModelOpt keeps some layers in high precision (the ``ignore`` list in
        the checkpoint's ``quantization_config`` -- e.g. img_in / txt_in /
        proj_out / norm_out / time_text_embed and the first/last transformer
        blocks), storing plain BF16 weights with no scales for them.
        ``get_quant_method()`` selects the Linear method purely from
        ``module.quant_config``, so clear it on the excluded Linear modules to
        fall back to the unquantized method. Must run before any
        ``create_weights()`` call, since parent modules (e.g. ``MLP``) may
        materialize their child Linears eagerly.
        """
        quant_config = self.model_config.quant_config
        if quant_config is None or quant_config.quant_algo is None:
            return
        for name, module in self.named_modules():
            if (isinstance(module, Linear)
                    and quant_config.is_module_excluded_from_quantization(name)):
                module.quant_config = None
Enforce the “before create_weights” precondition to avoid stale quantized layouts.
_clear_quant_config_on_excluded_layers() runs in load_weights(), but if any excluded Linear already created weights during __init__ (e.g., skip_create_weights_in_init=False), clearing quant_config on Line 908 won’t switch it back to unquantized because later create_weights() is a no-op. That can still mis-handle BF16 excluded layers.
Suggested fix
     def _clear_quant_config_on_excluded_layers(self) -> None:
@@
-        for name, module in self.named_modules():
-            if (isinstance(module, Linear)
-                    and quant_config.is_module_excluded_from_quantization(name)):
-                module.quant_config = None
+        already_materialized: list[str] = []
+        for name, module in self.named_modules():
+            if not isinstance(module, Linear):
+                continue
+            if not quant_config.is_module_excluded_from_quantization(name):
+                continue
+            if getattr(module, "_weights_created", False):
+                already_materialized.append(name)
+            module.quant_config = None
+
+        if already_materialized:
+            raise RuntimeError(
+                "Excluded layers already materialized quantized weights before quant_config "
+                f"clear: {already_materialized[:5]}. Construct with "
+                "skip_create_weights_in_init=True for static quantized checkpoints."
+            )
Also applies to: 919-924
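The ordering concern raised in this comment can be reproduced with a toy module. This is a minimal sketch, assuming `create_weights()` is a no-op once weights exist, as the review states; `ToyLinear` and its `layout` attribute are hypothetical and do not reflect the real `Linear` implementation.

```python
class ToyLinear:
    """Hypothetical module: create_weights() is a no-op after the first call."""

    def __init__(self, quant_config, skip_create_weights_in_init: bool):
        self.quant_config = quant_config
        self._weights_created = False
        self.layout = None
        if not skip_create_weights_in_init:
            self.create_weights()

    def create_weights(self):
        if self._weights_created:
            return  # later calls cannot change the layout
        self.layout = "bf16" if self.quant_config is None else "nvfp4"
        self._weights_created = True


# Eager creation: clearing quant_config afterwards is too late.
eager = ToyLinear(quant_config="NVFP4", skip_create_weights_in_init=False)
eager.quant_config = None
eager.create_weights()

# Deferred creation: clearing first yields the unquantized layout.
deferred = ToyLinear(quant_config="NVFP4", skip_create_weights_in_init=True)
deferred.quant_config = None
deferred.create_weights()
```

The eager instance keeps its quantized layout even though its `quant_config` is now `None`, which is the stale state the suggested `RuntimeError` guard is meant to surface.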
Qwen-Image NVFP4 latency sweep: batch × resolution × GEMM backend (GB200 / sm_100)
Follow-up benchmarks for this PR. Per-image latency, median, 50 steps, …
TL;DR
[Latency tables for 1024×1024, 1328×1328, and 1536×1536, and the table "NVFP4 (best backend) vs BF16: speedup (×, <1 = NVFP4 faster)", were not preserved.]
Note: cutedsl backend omitted; its CuteDSL NVFP4 JIT build is impractical for this model (>8 min/shape). W4A4 NVFP4 has no trtllm-gen cubin backend; …
What does this PR do?
Type of change: New feature / bug fix
Enables TensorRT-LLM VisualGen to run Qwen-Image from a pre-quantized ModelOpt checkpoint (static NVFP4 / FP8), and adds the missing offline example + configs. Previously Qwen-Image only supported dynamic quantization (quantize a BF16 checkpoint at load); pointing `--model` at a ModelOpt-exported NVFP4 checkpoint failed during weight loading.
Loader fix (`transformer_qwen_image.py`):
- Honor the checkpoint's quantization `ignore` list. ModelOpt keeps the embedders (`img_in`/`txt_in`), `proj_out`, `norm_out`, `time_text_embed`, and the first/last transformer blocks in BF16. `get_quant_method()` selects the Linear method purely from `module.quant_config`, so excluded `Linear`s now have it cleared (→ unquantized method) before `create_weights()`. Without this they were built as NVFP4 and their BF16 weights failed to load.
- Relax the strict weight-key check for FP8/NVFP4 helper buffers that are derived at load time and never serialized (`alpha`, `inv_input_scale`, `kv_scales`, `inv_kv_scales`).
Both changes are backward compatible: the dynamic BF16→NVFP4 path (`dynamic_weight_quant=True`, no `ignore` list) is unaffected.
Examples / docs:
- `examples/visual_gen/models/qwen_image.py`: offline text-to-image example (mirrors `flux1.py`).
- `examples/visual_gen/configs/qwen-image-fp4-1gpu.yaml`, `qwen-image-bf16-1gpu.yaml`.
- README and `visual-generation.md`: static pre-quantized checkpoints are now supported for Qwen-Image.
Usage
Output samples
NVFP4 image quality is on par with BF16 (1328×1328, 50 steps, seed 42, same prompt):
Testing
On 1× GB200 (sm_100), TRT-LLM release container:
- New unit test `test_static_quant_excludes_high_precision_layers` + existing `tests/unittest/_torch/visual_gen/test_qwen_image_registry.py`: 7 passed.
- End-to-end: a ModelOpt static NVFP4 checkpoint loads (729/729 weights) and produces a prompt-faithful 1328² image, visually on par with BF16 (same seed).
- Latency on 1× GB200 (1328², 50 steps, torch.compile on, median per image):
NVFP4 gives a ~3× smaller transformer (13 GB vs ~40 GB BF16) and is ~12 % faster per image in the batched regime (batch ≥ 2) at equal quality. The batch-1 case is slower only because NVFP4's per-call FP4 activation-quant + graph-break overhead isn't amortized there. Among the W4A4 NVFP4 GEMM backends, `cublaslt` is fastest; there is no trtllm-gen cubin for W4A4 (that path is W4A8, needing a different checkpoint).
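The "~3× smaller" figure can be sanity-checked with back-of-the-envelope arithmetic. The sketch below uses only the two sizes quoted above plus one assumption that is not stated in this PR: that NVFP4 stores 4-bit values with one 8-bit scale per 16-element block, about 4.5 bits per weight.

```python
BF16_GB = 40.0   # approximate BF16 transformer size quoted above
NVFP4_GB = 13.0  # NVFP4 transformer size quoted above

# Observed compression ratio from the quoted sizes.
observed_ratio = BF16_GB / NVFP4_GB

# Assumption: 4-bit values plus one 8-bit scale per 16 weights.
NVFP4_BITS_PER_WEIGHT = 4 + 8 / 16
BF16_BITS_PER_WEIGHT = 16

# Size if every weight were stored in NVFP4 (no high-precision layers).
params = BF16_GB * 1e9 * 8 / BF16_BITS_PER_WEIGHT
fully_quantized_gb = params * NVFP4_BITS_PER_WEIGHT / 8 / 1e9

print(f"observed ratio: {observed_ratio:.2f}x")
print(f"fully quantized lower bound: {fully_quantized_gb:.2f} GB")
```

Under that assumption the fully quantized size comes out below the reported 13 GB, which is at least consistent with the excluded layers staying in BF16; the exact split is not given in the PR.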