
Added support for MoE for vllm >= 0.14.0rc1 #1162

Draft
kinjalpatel27 wants to merge 2 commits into main from
kinjal/vllm_super_nano_support

Conversation


@kinjalpatel27 kinjalpatel27 commented Apr 1, 2026

What does this PR do?

Type of change: Bug fix

_QuantFusedMoEBase.forward() previously replaced vllm_fused_moe_package.invoke_fused_moe_kernel, which was renamed to invoke_fused_moe_triton_kernel starting in vLLM v0.14.0rc1. This caused an AttributeError / assertion failure for any MoE model quantized with vLLM ≥ v0.14.0rc1.

The fix refactors the kernel-patching logic into a _patch_moe_kernel() context manager that probes for both attribute names (the two names are mutually exclusive across vLLM versions — confirmed by inspecting every release from v0.10.0 to v0.18.1).
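The probe-both-names pattern can be sketched in isolation. This is a minimal sketch, not the PR's actual code: `vllm_fused_moe_package` below is a stand-in `SimpleNamespace` rather than the real vLLM module, and `invoke_fused_moe_quantized` is a placeholder for the quantized kernel.

```python
from contextlib import contextmanager
from types import SimpleNamespace

# Stand-in for the real vLLM fused_moe module; only one of the two
# kernel names exists in any given vLLM release.
vllm_fused_moe_package = SimpleNamespace(
    invoke_fused_moe_triton_kernel=lambda *a, **k: "original"
)

def invoke_fused_moe_quantized(*args, **kwargs):
    return "quantized"

@contextmanager
def _patch_moe_kernel():
    # Probe both names: releases before v0.14.0rc1 expose
    # `invoke_fused_moe_kernel`, later ones `invoke_fused_moe_triton_kernel`.
    for attr in ("invoke_fused_moe_kernel", "invoke_fused_moe_triton_kernel"):
        if hasattr(vllm_fused_moe_package, attr):
            orig = getattr(vllm_fused_moe_package, attr)
            setattr(vllm_fused_moe_package, attr, invoke_fused_moe_quantized)
            try:
                yield
            finally:
                setattr(vllm_fused_moe_package, attr, orig)
            return
    raise AttributeError("no fused MoE kernel symbol found")

with _patch_moe_kernel():
    patched = vllm_fused_moe_package.invoke_fused_moe_triton_kernel()
restored = vllm_fused_moe_package.invoke_fused_moe_triton_kernel()
```

Inside the context the quantized placeholder is invoked; after exit the original symbol is back.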

Usage

NA

Testing

docker run --gpus all -it --shm-size=160GB --network host --rm -v <modelopt path>:/home/modelopt \
  vllm/vllm-openai:v0.15.0 bash -c "cd /home/modelopt && pip install . && pip install datasets && \
  QUANT_CFG=NVFP4_DEFAULT_CFG python3 /home/modelopt/examples/vllm_serve/vllm_serve_fakequant.py \
  nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16 -tp 1 --served-model-name NVIDIA-Nemotron-3-Nano-30B-A3B-BF16 \
  --host 0.0.0.0 --port 8001 --trust-remote-code --enforce-eager --disable-custom-all-reduce \
  --gpu-memory-utilization 0.8"

Before your PR is "Ready for review"

Make sure you read and follow Contributor guidelines and your commits are signed (git commit -s -S).

Make sure you read and follow the Security Best Practices (e.g. avoiding hardcoded trust_remote_code=True, torch.load(..., weights_only=False), pickle, etc.).

  • Is this change backward compatible?: ✅
  • If you copied code from any other sources or added a new PIP dependency, did you follow guidance in CONTRIBUTING.md: N/A
  • Did you write any new necessary tests?: N/A
  • Did you update Changelog?: N/A

Additional Information

Summary by CodeRabbit

  • Bug Fixes
  • Improved reliability of quantized Mixture of Experts inference by fixing how temporary weights are swapped in and restored around fused-kernel execution.

Signed-off-by: Kinjal Patel <kinjalpravin@nvidia.com>
Signed-off-by: Kinjal Patel <kinjalpravin@nvidia.com>

copy-pr-bot bot commented Apr 1, 2026

Auto-sync is disabled for draft pull requests in this repository. Workflows must be run manually.



coderabbitai bot commented Apr 1, 2026

Important

Review skipped

Draft detected.

Please check the settings in the CodeRabbit UI or the .coderabbit.yaml file in this repository. To trigger a single review, invoke the @coderabbitai review command.

⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: 7315c1d1-79b8-4d6a-ad63-27a90a7a6aa9




github-actions bot commented Apr 1, 2026

PR Preview Action v1.8.1


🚀 View preview at
https://NVIDIA.github.io/Model-Optimizer/pr-preview/pr-1162/

Built to branch gh-pages at 2026-04-01 22:45 UTC.
Preview will be ready when the GitHub Pages deployment is complete.

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 2

🧹 Nitpick comments (1)
modelopt/torch/quantization/plugins/vllm.py (1)

387-388: Please add regression tests for both symbol paths and patch restore behavior.

Given the compatibility branch and runtime patching, add tests that cover: (1) invoke_fused_moe_kernel, (2) invoke_fused_moe_triton_kernel, and (3) restoration on exceptions during super().forward(...).

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `modelopt/torch/quantization/plugins/vllm.py` around lines 387-388, add
regression tests that exercise both symbol paths and verify the runtime patching
in _patch_moe_kernel is applied and always restored: write tests that (1)
trigger the branch where invoke_fused_moe_kernel is used, (2) trigger the branch
where invoke_fused_moe_triton_kernel is used, and (3) simulate an exception
raised during super().forward(...) to assert the original symbols are restored
after the exception. Locate and call the class/method that uses
_patch_moe_kernel and forward to run these cases, patch or monkeypatch the
target symbols to observable fakes, and assert pre-/post-conditions on the
original functions to confirm restore behavior.
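The third case the comment asks for (restore on exception) can be sketched with a simplified stand-in for `_patch_moe_kernel`. The `pkg` namespace and the simplified context manager below are illustrative, not ModelOpt's actual code.

```python
from contextlib import contextmanager
from types import SimpleNamespace

# Stand-in for the vLLM fused_moe module (illustrative).
pkg = SimpleNamespace(invoke_fused_moe_kernel=lambda: "original")

@contextmanager
def _patch_moe_kernel(replacement):
    # Probe both symbol names, mirroring the PR's compatibility branch.
    for attr in ("invoke_fused_moe_kernel", "invoke_fused_moe_triton_kernel"):
        if hasattr(pkg, attr):
            orig = getattr(pkg, attr)
            setattr(pkg, attr, replacement)
            try:
                yield
            finally:
                # Restoration must run even if forward() raises.
                setattr(pkg, attr, orig)
            return

# Simulate an exception raised during super().forward(...):
try:
    with _patch_moe_kernel(lambda: "quantized"):
        raise RuntimeError("forward failed")
except RuntimeError:
    pass

restored = pkg.invoke_fused_moe_kernel()  # original symbol is back
```

A real regression test would assert `restored == "original"` after the simulated failure.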
ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: 476612da-4375-40a3-b6c6-f1dfd76df7b5

📥 Commits

Reviewing files that changed from the base of the PR and between de55e8a and 122b935.

📒 Files selected for processing (1)
  • modelopt/torch/quantization/plugins/vllm.py

Comment on lines +349 to +351
orig, self.w13_weight = self.w13_weight, self.w13_weight_quantizer(self.w13_weight)
vllm_fused_moe_package._invoke_fused_moe_kernel(A, B, C, *args, **kwargs)
self.w13_weight = orig

⚠️ Potential issue | 🟠 Major


Use the swapped quantized weight in the kernel call and always restore state on failure.

At lines 350 and 360, the kernel is invoked with B, which was bound to the original weight before the swap at lines 349 and 359. When quantization returns a new tensor, the kernel receives the unquantized weight, bypassing the quantization. Additionally, weight restoration at lines 351 and 361 is not guarded with finally, so an exception during kernel invocation leaves the module with a quantized weight permanently swapped, corrupting the module state.

Proposed fix
         if B is self.w13_weight:
             # First layer of expert
             A = self.w13_input_quantizer(A)  # noqa: N806
             if self.w13_weight_quantizer.is_enabled:
-                orig, self.w13_weight = self.w13_weight, self.w13_weight_quantizer(self.w13_weight)
-                vllm_fused_moe_package._invoke_fused_moe_kernel(A, B, C, *args, **kwargs)
-                self.w13_weight = orig
+                orig, self.w13_weight = self.w13_weight, self.w13_weight_quantizer(self.w13_weight)
+                try:
+                    vllm_fused_moe_package._invoke_fused_moe_kernel(
+                        A, self.w13_weight, C, *args, **kwargs
+                    )
+                finally:
+                    self.w13_weight = orig
             else:
                 vllm_fused_moe_package._invoke_fused_moe_kernel(A, B, C, *args, **kwargs)
@@
         elif B is self.w2_weight:
             A = self.w2_input_quantizer(A)  # noqa: N806
             if self.w2_weight_quantizer.is_enabled:
-                orig, self.w2_weight = self.w2_weight, self.w2_weight_quantizer(self.w2_weight)
-                vllm_fused_moe_package._invoke_fused_moe_kernel(A, B, C, *args, **kwargs)
-                self.w2_weight = orig
+                orig, self.w2_weight = self.w2_weight, self.w2_weight_quantizer(self.w2_weight)
+                try:
+                    vllm_fused_moe_package._invoke_fused_moe_kernel(
+                        A, self.w2_weight, C, *args, **kwargs
+                    )
+                finally:
+                    self.w2_weight = orig
             else:
                 vllm_fused_moe_package._invoke_fused_moe_kernel(A, B, C, *args, **kwargs)

Also applies to: 359-361

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `modelopt/torch/quantization/plugins/vllm.py` around lines 349-351, the
kernel is currently called with the original weight because B was bound before
swapping, and restoration isn't in a finally block so exceptions leave the
quantized tensor in place; fix by assigning orig = self.w13_weight, replacing
self.w13_weight with self.w13_weight_quantizer(self.w13_weight), then invoke
vllm_fused_moe_package._invoke_fused_moe_kernel using the swapped
self.w13_weight (not the previously bound B), and always restore self.w13_weight
= orig in a finally block; apply the identical pattern to the other symmetric
block (the one that swaps/restores the other weight at lines 359-361).
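The binding-order pitfall this comment flags can be demonstrated with plain Python objects. This is a minimal sketch: `FakeMoE` and its string "weights" are illustrative stand-ins, not the module under review.

```python
class FakeMoE:
    """Illustrative stand-in for the quantized MoE module."""

    def __init__(self):
        self.w13_weight = "fp16_weight"

    def quantize(self, w):
        return f"quantized({w})"

m = FakeMoE()
B = m.w13_weight  # B is bound *before* the swap, as in the flagged code
orig, m.w13_weight = m.w13_weight, m.quantize(m.w13_weight)

seen_by_buggy_call = B             # still the original, unquantized weight
seen_by_fixed_call = m.w13_weight  # the quantized weight the kernel should get

m.w13_weight = orig  # restore (the proposed fix wraps this in try/finally)
```

Because `B` aliases the tensor that existed before reassignment, passing `B` to the kernel silently bypasses quantization whenever the quantizer returns a new tensor.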

Comment on lines +369 to +383
@contextmanager
def _patch_moe_kernel(self):
    """Temporarily replace vLLM fused_moe kernel with quantized version."""
    # `invoke_fused_moe_kernel` was used through v0.14.0rc0; it was renamed
    # to `invoke_fused_moe_triton_kernel` starting from v0.14.0rc1.
    for attr in ["invoke_fused_moe_kernel", "invoke_fused_moe_triton_kernel"]:
        if hasattr(vllm_fused_moe_package, attr):
            orig = getattr(vllm_fused_moe_package, attr)
            setattr(vllm_fused_moe_package, "_invoke_fused_moe_kernel", orig)
            setattr(vllm_fused_moe_package, attr, self.invoke_fused_moe_quantized)
            try:
                yield
            finally:
                setattr(vllm_fused_moe_package, attr, orig)
            return

⚠️ Potential issue | 🔴 Critical

Module-level monkey patching here is not thread-safe and can break under concurrent forwards.

_patch_moe_kernel() mutates global symbols on vllm_fused_moe_package without synchronization. Concurrent entries can overwrite _invoke_fused_moe_kernel with an already patched function, leading to recursion/wrong restore behavior and nondeterministic failures.

Proposed fix
+import threading
@@
+_MOE_KERNEL_PATCH_LOCK = threading.RLock()
+
 @contextmanager
 def _patch_moe_kernel(self):
@@
-        for attr in ["invoke_fused_moe_kernel", "invoke_fused_moe_triton_kernel"]:
+        for attr in ["invoke_fused_moe_kernel", "invoke_fused_moe_triton_kernel"]:
             if hasattr(vllm_fused_moe_package, attr):
-                orig = getattr(vllm_fused_moe_package, attr)
-                setattr(vllm_fused_moe_package, "_invoke_fused_moe_kernel", orig)
-                setattr(vllm_fused_moe_package, attr, self.invoke_fused_moe_quantized)
-                try:
-                    yield
-                finally:
-                    setattr(vllm_fused_moe_package, attr, orig)
+                with _MOE_KERNEL_PATCH_LOCK:
+                    orig = getattr(vllm_fused_moe_package, attr)
+                    had_private = hasattr(vllm_fused_moe_package, "_invoke_fused_moe_kernel")
+                    prev_private = getattr(vllm_fused_moe_package, "_invoke_fused_moe_kernel", None)
+                    setattr(vllm_fused_moe_package, "_invoke_fused_moe_kernel", orig)
+                    setattr(vllm_fused_moe_package, attr, self.invoke_fused_moe_quantized)
+                    try:
+                        yield
+                    finally:
+                        setattr(vllm_fused_moe_package, attr, orig)
+                        if had_private:
+                            setattr(vllm_fused_moe_package, "_invoke_fused_moe_kernel", prev_private)
+                        elif hasattr(vllm_fused_moe_package, "_invoke_fused_moe_kernel"):
+                            delattr(vllm_fused_moe_package, "_invoke_fused_moe_kernel")
                 return
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `modelopt/torch/quantization/plugins/vllm.py` around lines 369-383, the
current _patch_moe_kernel contextmanager mutates vllm_fused_moe_package globals
unsafely; make it concurrency-safe by serializing patch/unpatch with a
module-level threading.RLock and per-attribute reference counting (or a
reentrancy counter) so nested/concurrent entries don't clobber originals: on
entry, acquire the lock, for each attr
("invoke_fused_moe_kernel","invoke_fused_moe_triton_kernel") save the original
into a local map only if not already saved and replace the attr with
self.invoke_fused_moe_quantized while incrementing a refcount; yield; in
finally, decrement the refcount and only when it reaches zero restore the
original to vllm_fused_moe_package[attr] and remove the saved original, then
release the lock—use the symbols _patch_moe_kernel, vllm_fused_moe_package,
invoke_fused_moe_kernel, invoke_fused_moe_triton_kernel,
_invoke_fused_moe_kernel and invoke_fused_moe_quantized to locate and implement
this change.
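One way to realize the suggested lock-plus-refcount scheme is sketched below, under this prompt's assumptions; `pkg`, `patched_kernel`, and the `_STATE` dict are illustrative names, not ModelOpt code.

```python
import threading
from contextlib import contextmanager
from types import SimpleNamespace

# Stand-in for the vLLM fused_moe module (illustrative).
pkg = SimpleNamespace(invoke_fused_moe_triton_kernel=lambda: "original")

_LOCK = threading.RLock()
_STATE = {"count": 0, "orig": None}

@contextmanager
def patched_kernel(attr, replacement):
    # First entry saves the original and installs the replacement; nested or
    # concurrent entries only bump the refcount, so the saved original is
    # never clobbered by an already-patched function.
    with _LOCK:
        if _STATE["count"] == 0:
            _STATE["orig"] = getattr(pkg, attr)
            setattr(pkg, attr, replacement)
        _STATE["count"] += 1
    try:
        yield
    finally:
        with _LOCK:
            _STATE["count"] -= 1
            if _STATE["count"] == 0:
                setattr(pkg, attr, _STATE["orig"])
                _STATE["orig"] = None

attr = "invoke_fused_moe_triton_kernel"
with patched_kernel(attr, lambda: "quantized"):
    with patched_kernel(attr, lambda: "other"):  # nested entry: refcount only
        inner = getattr(pkg, attr)()
    still_patched = getattr(pkg, attr)()
after = getattr(pkg, attr)()
```

With this shape, nested entries cannot overwrite the saved original, and the symbol is only restored when the last context exits.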


codecov bot commented Apr 1, 2026

Codecov Report

❌ Patch coverage is 0% with 17 lines in your changes missing coverage. Please review.
✅ Project coverage is 54.54%. Comparing base (2ae407c) to head (122b935).
⚠️ Report is 1 commits behind head on main.

Files with missing lines Patch % Lines
modelopt/torch/quantization/plugins/vllm.py 0.00% 17 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1162      +/-   ##
==========================================
- Coverage   54.55%   54.54%   -0.02%     
==========================================
  Files         348      348              
  Lines       39755    39770      +15     
==========================================
+ Hits        21689    21691       +2     
- Misses      18066    18079      +13     
Flag Coverage Δ
unit 54.54% <0.00%> (-0.02%) ⬇️

Flags with carried forward coverage won't be shown.

☔ View full report in Codecov by Sentry.

