Skip to content

[None][feat] Qwen-Image: load pre-quantized ModelOpt NVFP4/FP8 checkpoints#15470

Open
jingyu-ml wants to merge 1 commit into
NVIDIA:mainfrom
jingyu-ml:feat/qwen-image-modelopt-nvfp4-load
Open

[None][feat] Qwen-Image: load pre-quantized ModelOpt NVFP4/FP8 checkpoints#15470
jingyu-ml wants to merge 1 commit into
NVIDIA:mainfrom
jingyu-ml:feat/qwen-image-modelopt-nvfp4-load

Conversation

@jingyu-ml

@jingyu-ml jingyu-ml commented Jun 18, 2026

Copy link
Copy Markdown

What does this PR do?

Type of change: New feature / bug fix

Enables TensorRT-LLM VisualGen to run Qwen-Image from a pre-quantized
ModelOpt checkpoint
(static NVFP4 / FP8), and adds the missing offline
example + configs. Previously Qwen-Image only supported dynamic quantization
(quantize a BF16 checkpoint at load); pointing --model at a ModelOpt-exported
NVFP4 checkpoint failed during weight loading.

Loader fix (transformer_qwen_image.py):

  1. Honor the checkpoint's ignore list. ModelOpt keeps the embedders
    (img_in/txt_in), proj_out, norm_out, time_text_embed, and the
    first/last transformer blocks in BF16. get_quant_method() selects the
    Linear method purely from module.quant_config, so excluded Linears now
    have it cleared (→ unquantized method) before create_weights(). Without
    this they were built as NVFP4 and their BF16 weights failed to load.
  2. Relax the strict key check for FP8/NVFP4 helper buffers that are derived
    at load time and never serialized (alpha, inv_input_scale, kv_scales,
    inv_kv_scales).

Both changes are backward compatible: the dynamic BF16→NVFP4 path
(dynamic_weight_quant=True, no ignore list) is unaffected.

Examples / docs:

  • examples/visual_gen/models/qwen_image.py — offline text-to-image example
    (mirrors flux1.py).
  • examples/visual_gen/configs/qwen-image-fp4-1gpu.yaml,
    qwen-image-bf16-1gpu.yaml.
  • README + visual-generation.md note that pre-quantized ModelOpt checkpoints
    are now supported for Qwen-Image.

Usage

# NVFP4 — quantization read from the checkpoint's transformer/config.json
python examples/visual_gen/models/qwen_image.py \
    --model <qwen-image-nvfp4> \
    --visual_gen_args examples/visual_gen/configs/qwen-image-fp4-1gpu.yaml

# BF16 baseline
python examples/visual_gen/models/qwen_image.py --model Qwen/Qwen-Image

Output samples

NVFP4 image quality is on par with BF16 (1328×1328, 50 steps, seed 42, same prompt):

BF16 (Qwen/Qwen-Image) NVFP4 (ModelOpt checkpoint, this PR)

Testing

On 1× GB200 (sm_100), TRT-LLM release container:

  • New unit test test_static_quant_excludes_high_precision_layers + existing
    tests/unittest/_torch/visual_gen/test_qwen_image_registry.py7 passed.

  • End-to-end: a ModelOpt static NVFP4 checkpoint loads (729/729 weights) and
    produces a prompt-faithful 1328² image, visually on par with BF16 (same seed).

  • Latency on 1× GB200 (1328², 50 steps, torch.compile on, median per image):

    batch BF16 / img NVFP4 / img
    1 7.49 s 11.13 s
    2 7.39 s 6.63 s
    4 7.23 s 6.36 s

    NVFP4 gives a ~3× smaller transformer (13 GB vs ~40 GB BF16) and is ~12 %
    faster per image in the batched regime (batch ≥ 2)
    at equal quality. The
    batch-1 case is slower only because NVFP4's per-call FP4 activation-quant +
    graph-break overhead isn't amortized there. Among the W4A4 NVFP4 GEMM
    backends, cublaslt is fastest; there is no trtllm-gen cubin for W4A4 (that
    path is W4A8, needing a different checkpoint).

Before your PR is "Ready for review"

  • Backward compatible: ✅ (dynamic-quant and BF16 paths unchanged)
  • New tests: ✅
  • Changelog: ❌ (feature; add entry if required)

…oints

Enable VisualGen to run Qwen-Image from a statically pre-quantized ModelOpt
checkpoint (NVFP4/FP8), and add the offline example + configs. Previously only
dynamic quantization (BF16 -> NVFP4 at load) worked; pointing --model at a
ModelOpt-exported NVFP4 checkpoint failed during weight loading.

transformer_qwen_image.py:
- Honor the checkpoint's quantization `ignore` list: clear quant_config on the
  excluded Linear modules before create_weights() so they build the unquantized
  method (ModelOpt stores those layers -- embedders, proj_out, norm_out,
  time_text_embed, first/last blocks -- in BF16). get_quant_method() selects the
  method purely from module.quant_config.
- Relax the strict weight-key check for FP8/NVFP4 helper buffers that are
  derived at load time and never serialized (alpha, inv_input_scale, kv_scales,
  inv_kv_scales).

Both changes are backward compatible with the dynamic-quant and BF16 paths.

Add examples/visual_gen/models/qwen_image.py and qwen-image-{fp4,bf16}-1gpu.yaml;
document static-checkpoint support in README and visual-generation.md; add a
unit test for the exclusion logic.

Validated on GB200 (sm_100): a static NVFP4 checkpoint loads (729/729 weights)
and renders a 1328x1328 image on par with BF16; qwen registry tests pass (7/7).

Signed-off-by: Jingyu Xin <jingyux@nvidia.com>
@jingyu-ml jingyu-ml requested review from a team as code owners June 18, 2026 00:56
@jingyu-ml jingyu-ml requested review from chang-l and nv-guomingz June 18, 2026 00:56
@coderabbitai

coderabbitai Bot commented Jun 18, 2026

Copy link
Copy Markdown
Contributor

Review Change Stack

📝 Walkthrough

Walkthrough

Adds support for loading pre-quantized ModelOpt FP8/NVFP4 checkpoints into QwenImageTransformer2DModel by introducing a method to clear quant_config on excluded layers before weight creation and filtering derived parameter suffixes from strict missing-key checks. Includes a new Qwen-Image CLI example script, BF16 and NVFP4 YAML configs, and updated documentation.

Changes

Qwen-Image static quantization loading and examples

Layer / File(s) Summary
Static quantization loading in transformer model
tensorrt_llm/_torch/visual_gen/models/qwen_image/transformer_qwen_image.py
Adds _QUANT_DERIVED_PARAM_SUFFIXES constant, _clear_quant_config_on_excluded_layers() method that nullifies quant_config on excluded Linear modules, and updates load_weights() to call that method before create_weights() and exclude derived-param suffixes from strict missing-key validation.
Registry smoke test for excluded-layer quant clearing
tests/unittest/_torch/visual_gen/test_qwen_image_registry.py
Adds test_static_quant_excludes_high_precision_layers asserting that after _clear_quant_config_on_excluded_layers(), excluded submodules have quant_config is None while non-excluded transformer blocks retain their NVFP4 quant_config.
CLI example script and YAML configs
examples/visual_gen/models/qwen_image.py, examples/visual_gen/configs/qwen-image-bf16-1gpu.yaml, examples/visual_gen/configs/qwen-image-fp4-1gpu.yaml
Adds qwen_image.py CLI with _output_paths helper and main() using VisualGen/VisualGenArgs.from_yaml, plus BF16 and NVFP4 single-GPU YAML configs (VANILLA attention backend, CUDA graphs disabled).
README and model docs updates
examples/visual_gen/README.md, docs/source/models/visual-generation.md
Adds python models/qwen_image.py invocations and NVFP4 guidance to the VisualGen README; expands the [^2] footnote in visual-generation.md to document FP8/NVFP4 dynamic vs. static quantization and the ignore list behavior.

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~20 minutes

Suggested labels

VisualGen

Suggested reviewers

  • schetlur-nv
  • FrankD412
  • dpitman-nvda
  • yiqingy0
  • QiJune
  • crazydemo
  • Wanli-Jiang
  • juney-nvidia
  • chang-l
🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 66.67% which is insufficient. The required threshold is 80.00%. Write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (4 passed)
Check name Status Explanation
Title check ✅ Passed The title clearly summarizes the main feature: enabling Qwen-Image to load pre-quantized ModelOpt NVFP4/FP8 checkpoints, which aligns with the primary objective of the PR.
Linked Issues check ✅ Passed Check skipped because no linked issues were found for this pull request.
Out of Scope Changes check ✅ Passed Check skipped because no linked issues were found for this pull request.
Description check ✅ Passed The PR description is comprehensive, well-structured, and covers all required sections with clear explanations of what, why, and how.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.

✨ Finishing Touches
🧪 Generate unit tests (beta)
  • Create PR with unit tests

Comment @coderabbitai help to get the list of available commands and usage tips.

@coderabbitai coderabbitai Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 1

🧹 Nitpick comments (1)
tests/unittest/_torch/visual_gen/test_qwen_image_registry.py (1)

68-109: ⚡ Quick win

Coverage is currently insufficient for the second loader behavior.

This test covers excluded-layer quant_config clearing, but it does not exercise the new strict key-check relaxation for derived suffixes (.alpha, .inv_input_scale, .kv_scales, .inv_kv_scales) added in load_weights(). Please add a focused unit test in tests/unittest/_torch/visual_gen/test_qwen_image_registry.py (or a dedicated test_qwen_image_transformer_load.py) that verifies:

  1. missing only derived keys does not raise in static mode, and
  2. missing a real weight key still does raise.

As per coding guidelines, tests under tests/** should provide actionable coverage assessment for changed behavior.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tests/unittest/_torch/visual_gen/test_qwen_image_registry.py` around lines 68
- 109, Add a new focused unit test (either in the existing
test_qwen_image_registry.py file or in a dedicated
test_qwen_image_transformer_load.py file) that validates the strict key-check
relaxation in the load_weights() method for derived suffix keys. The test should
verify two scenarios: first, loading a checkpoint with missing only derived keys
like .alpha, .inv_input_scale, .kv_scales, and .inv_kv_scales in static
quantization mode should not raise an error, and second, loading a checkpoint
with missing actual weight keys should still raise an error as expected. This
ensures the key-check relaxation only applies to the derived suffixes and not to
real weight keys.

Source: Coding guidelines

🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@tensorrt_llm/_torch/visual_gen/models/qwen_image/transformer_qwen_image.py`:
- Around line 889-909: The method `_clear_quant_config_on_excluded_layers()`
documents that it must run before any `create_weights()` call, but currently
runs in `load_weights()` which is too late. If
`skip_create_weights_in_init=False` allows excluded Linear modules to create
weights during `__init__`, clearing the quant_config later in `load_weights()`
won't switch them back to unquantized since `create_weights()` becomes a no-op.
Move the call to `_clear_quant_config_on_excluded_layers()` to run earlier in
the initialization sequence, such as right after the parent class `__init__`
completes or before weights can be created, to enforce the documented
precondition and ensure excluded layers use the correct unquantized
configuration.

---

Nitpick comments:
In `@tests/unittest/_torch/visual_gen/test_qwen_image_registry.py`:
- Around line 68-109: Add a new focused unit test (either in the existing
test_qwen_image_registry.py file or in a dedicated
test_qwen_image_transformer_load.py file) that validates the strict key-check
relaxation in the load_weights() method for derived suffix keys. The test should
verify two scenarios: first, loading a checkpoint with missing only derived keys
like .alpha, .inv_input_scale, .kv_scales, and .inv_kv_scales in static
quantization mode should not raise an error, and second, loading a checkpoint
with missing actual weight keys should still raise an error as expected. This
ensures the key-check relaxation only applies to the derived suffixes and not to
real weight keys.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Enterprise

Run ID: 173a44f7-5ae5-44ca-897f-f445dda939bb

📥 Commits

Reviewing files that changed from the base of the PR and between 79ea125 and 6f7265d.

📒 Files selected for processing (7)
  • docs/source/models/visual-generation.md
  • examples/visual_gen/README.md
  • examples/visual_gen/configs/qwen-image-bf16-1gpu.yaml
  • examples/visual_gen/configs/qwen-image-fp4-1gpu.yaml
  • examples/visual_gen/models/qwen_image.py
  • tensorrt_llm/_torch/visual_gen/models/qwen_image/transformer_qwen_image.py
  • tests/unittest/_torch/visual_gen/test_qwen_image_registry.py

Comment on lines +889 to +909
def _clear_quant_config_on_excluded_layers(self) -> None:
"""Build the checkpoint's excluded layers as unquantized.

ModelOpt keeps some layers in high precision (the ``ignore`` list in
the checkpoint's ``quantization_config`` -- e.g. img_in / txt_in /
proj_out / norm_out / time_text_embed and the first/last transformer
blocks), storing plain BF16 weights with no scales for them.
``get_quant_method()`` selects the Linear method purely from
``module.quant_config``, so clear it on the excluded Linear modules to
fall back to the unquantized method. Must run before any
``create_weights()`` call, since parent modules (e.g. ``MLP``) may
materialize their child Linears eagerly.
"""
quant_config = self.model_config.quant_config
if quant_config is None or quant_config.quant_algo is None:
return
for name, module in self.named_modules():
if (isinstance(module, Linear)
and quant_config.is_module_excluded_from_quantization(name)):
module.quant_config = None

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Enforce the “before create_weights” precondition to avoid stale quantized layouts.

_clear_quant_config_on_excluded_layers() runs in load_weights(), but if any excluded Linear already created weights during __init__ (e.g., skip_create_weights_in_init=False), clearing quant_config on Line 908 won’t switch it back to unquantized because later create_weights() is a no-op. That can still mis-handle BF16 excluded layers.

Suggested fix
 def _clear_quant_config_on_excluded_layers(self) -> None:
@@
-        for name, module in self.named_modules():
-            if (isinstance(module, Linear)
-                    and quant_config.is_module_excluded_from_quantization(name)):
-                module.quant_config = None
+        already_materialized: list[str] = []
+        for name, module in self.named_modules():
+            if not isinstance(module, Linear):
+                continue
+            if not quant_config.is_module_excluded_from_quantization(name):
+                continue
+            if getattr(module, "_weights_created", False):
+                already_materialized.append(name)
+            module.quant_config = None
+
+        if already_materialized:
+            raise RuntimeError(
+                "Excluded layers already materialized quantized weights before quant_config "
+                f"clear: {already_materialized[:5]}. Construct with "
+                "skip_create_weights_in_init=True for static quantized checkpoints."
+            )

Also applies to: 919-924

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tensorrt_llm/_torch/visual_gen/models/qwen_image/transformer_qwen_image.py`
around lines 889 - 909, The method `_clear_quant_config_on_excluded_layers()`
documents that it must run before any `create_weights()` call, but currently
runs in `load_weights()` which is too late. If
`skip_create_weights_in_init=False` allows excluded Linear modules to create
weights during `__init__`, clearing the quant_config later in `load_weights()`
won't switch them back to unquantized since `create_weights()` becomes a no-op.
Move the call to `_clear_quant_config_on_excluded_layers()` to run earlier in
the initialization sequence, such as right after the parent class `__init__`
completes or before weights can be created, to enforce the documented
precondition and ensure excluded layers use the correct unquantized
configuration.

@jingyu-ml

Copy link
Copy Markdown
Author

Qwen-Image NVFP4 latency sweep — batch × resolution × GEMM backend (GB200 / sm_100)

Follow-up benchmarks for this PR. Per-image latency, median, 50 steps, true_cfg
off, torch.compile on, TRT-LLM release container (1.3.0rc18). NVFP4 = this PR's
pre-quantized ModelOpt checkpoint (13 GB transformer); BF16 = Qwen/Qwen-Image
(~40 GB). Image quality is equivalent (same seed).

TL;DR

  • NVFP4 carries a roughly fixed per-GEMM-call overhead (FP4 activation-quant
    AutoTuner dispatch + torch.compile graph breaks). Once GEMM-compute-per-call is
    large enough to amortize it — via higher resolution and/or larger batch
    NVFP4 beats BF16, up to ~1.28× faster (0.78× latency) at 1024²/batch 16.
  • Crossover: NVFP4 wins at 1536² for all batches (incl. batch 1), 1328²
    for batch ≥ 2
    , 1024² for batch ≥ 4. It only loses in the small-GEMM
    corner (low res + low batch).
  • Best NVFP4 GEMM backend: cublaslt for batch ≥ 2, cutlass at batch 1.
    cutedsl omitted (its CuteDSL JIT build is impractical for this model,

    8 min/shape). W4A4 NVFP4 has no trtllm-gen cubin backend
    fp4_fp8_gemm_trtllmgen is W4A8 (different checkpoint).

  • Plus the standing benefit: ~3× smaller transformer (13 GB vs ~40 GB).

1024×1024

config batch 1 batch 2 batch 4 batch 8 batch 16
BF16 4.93 4.09 3.87 3.73 3.73
NVFP4 (auto) 10.80 6.63 3.34 3.16 3.08
NVFP4 (cublaslt) 11.40 5.82 3.14 3.00 2.92
NVFP4 (cutlass) 9.12 4.57 3.29 3.14 3.07

1328×1328

config batch 1 batch 2 batch 4 batch 8 batch 16
BF16 7.46 7.44 7.17 7.11 7.23
NVFP4 (auto) 13.44 6.61 6.36 6.23 6.16
NVFP4 (cublaslt) 10.33 6.31 6.08 5.95 5.90
NVFP4 (cutlass) 9.92 6.62 6.38 6.24 6.17

1536×1536

config batch 1 batch 2 batch 4 batch 8 batch 16
BF16 10.91 12.02 11.63 11.56 11.74
NVFP4 (auto) 9.35 10.77 10.35 10.16 10.06
NVFP4 (cublaslt) 9.59 10.35 9.94 9.76 9.67
NVFP4 (cutlass) 11.09 10.77 10.35 10.16 10.06

NVFP4 (best backend) vs BF16 — speedup (×, <1 = NVFP4 faster)

res batch 1 batch 2 batch 4 batch 8 batch 16
1024² 1.85 1.12 0.81 0.80 0.78
1328² 1.33 0.85 0.85 0.84 0.82
1536² 0.86 0.86 0.85 0.84 0.82

Note: cutedsl backend omitted — its CuteDSL NVFP4 JIT build is impractical for this model (>8 min/shape). W4A4 NVFP4 has no trtllm-gen cubin backend; fp4_fp8_gemm_trtllmgen is W4A8.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant