Added support for MoE for vLLM >= 0.14.0rc1 #1162
Conversation
Signed-off-by: Kinjal Patel <kinjalpravin@nvidia.com>
Actionable comments posted: 2
🧹 Nitpick comments (1)
modelopt/torch/quantization/plugins/vllm.py (1)
**387-388: Please add regression tests for both symbol paths and patch restore behavior.**

Given the compatibility branch and runtime patching, add tests that cover: (1) `invoke_fused_moe_kernel`, (2) `invoke_fused_moe_triton_kernel`, and (3) restoration on exceptions during `super().forward(...)`.
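Such a test could be sketched with a stand-in module object; the function names below mirror the vLLM symbols, but the harness itself is illustrative and is not the actual modelopt test suite:

```python
import types
from contextlib import contextmanager

# Stand-in for vllm's fused_moe module; only the newer symbol exists here,
# mimicking vLLM >= 0.14.0rc1.
fake_pkg = types.SimpleNamespace(invoke_fused_moe_triton_kernel=lambda *a: "orig")

@contextmanager
def patch_kernel(replacement):
    # Mirrors the probe-both-names logic under test.
    for attr in ["invoke_fused_moe_kernel", "invoke_fused_moe_triton_kernel"]:
        if hasattr(fake_pkg, attr):
            orig = getattr(fake_pkg, attr)
            setattr(fake_pkg, attr, replacement)
            try:
                yield
            finally:
                setattr(fake_pkg, attr, orig)
            return

def test_restore_on_exception():
    orig = fake_pkg.invoke_fused_moe_triton_kernel
    try:
        with patch_kernel(lambda *a: "patched"):
            assert fake_pkg.invoke_fused_moe_triton_kernel() == "patched"
            raise RuntimeError("simulated failure in super().forward(...)")
    except RuntimeError:
        pass
    # The original symbol must be restored even though forward raised.
    assert fake_pkg.invoke_fused_moe_triton_kernel is orig

test_restore_on_exception()
```

The same pattern extends to the other symbol path by renaming the attribute on the fake module.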
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: 476612da-4375-40a3-b6c6-f1dfd76df7b5
📒 Files selected for processing (1)
modelopt/torch/quantization/plugins/vllm.py
```diff
 orig, self.w13_weight = self.w13_weight, self.w13_weight_quantizer(self.w13_weight)
 vllm_fused_moe_package._invoke_fused_moe_kernel(A, B, C, *args, **kwargs)
-self.w13_weight = original_weight
+self.w13_weight = orig
```
🧩 Analysis chain

🏁 Script executed:

```shell
cat -n modelopt/torch/quantization/plugins/vllm.py | sed -n '340,395p'
```

Repository: NVIDIA/Model-Optimizer
Length of output: 3292

🏁 Script executed:

```shell
cat -n modelopt/torch/quantization/plugins/vllm.py | sed -n '330,345p'
```

Repository: NVIDIA/Model-Optimizer
Length of output: 718

🏁 Script executed:

```shell
cat -n modelopt/torch/quantization/plugins/vllm.py | sed -n '1,50p'
```

Repository: NVIDIA/Model-Optimizer
Length of output: 2318
Use the swapped quantized weight in the kernel call and always restore state on failure.
At lines 350 and 360, the kernel is invoked with B, which was bound to the original weight before the swap at lines 349 and 359. When quantization returns a new tensor, the kernel receives the unquantized weight, bypassing the quantization. Additionally, weight restoration at lines 351 and 361 is not guarded with finally, so an exception during kernel invocation leaves the module with a quantized weight permanently swapped, corrupting the module state.
Proposed fix

```diff
 if B is self.w13_weight:
     # First layer of expert
     A = self.w13_input_quantizer(A)  # noqa: N806
     if self.w13_weight_quantizer.is_enabled:
-        orig, self.w13_weight = self.w13_weight, self.w13_weight_quantizer(self.w13_weight)
-        vllm_fused_moe_package._invoke_fused_moe_kernel(A, B, C, *args, **kwargs)
-        self.w13_weight = orig
+        orig, self.w13_weight = self.w13_weight, self.w13_weight_quantizer(self.w13_weight)
+        try:
+            vllm_fused_moe_package._invoke_fused_moe_kernel(
+                A, self.w13_weight, C, *args, **kwargs
+            )
+        finally:
+            self.w13_weight = orig
     else:
         vllm_fused_moe_package._invoke_fused_moe_kernel(A, B, C, *args, **kwargs)
@@
 elif B is self.w2_weight:
     A = self.w2_input_quantizer(A)  # noqa: N806
     if self.w2_weight_quantizer.is_enabled:
-        orig, self.w2_weight = self.w2_weight, self.w2_weight_quantizer(self.w2_weight)
-        vllm_fused_moe_package._invoke_fused_moe_kernel(A, B, C, *args, **kwargs)
-        self.w2_weight = orig
+        orig, self.w2_weight = self.w2_weight, self.w2_weight_quantizer(self.w2_weight)
+        try:
+            vllm_fused_moe_package._invoke_fused_moe_kernel(
+                A, self.w2_weight, C, *args, **kwargs
+            )
+        finally:
+            self.w2_weight = orig
     else:
         vllm_fused_moe_package._invoke_fused_moe_kernel(A, B, C, *args, **kwargs)
```

Also applies to: 359-361
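The aliasing pitfall the comment describes can be reproduced in isolation; this is a minimal sketch with a list standing in for the weight tensor (nothing here is modelopt code):

```python
class FakeMoE:
    def __init__(self):
        self.w13_weight = [1, 2, 3]  # stand-in for the original weight tensor

def fake_quantize(w):
    # Like a weight quantizer, returns a NEW object rather than mutating in place.
    return [x * 10 for x in w]

m = FakeMoE()
B = m.w13_weight                 # the kernel argument, bound BEFORE the swap
orig, m.w13_weight = m.w13_weight, fake_quantize(m.w13_weight)

# B still aliases the original, unquantized weight; only the attribute changed.
assert B is orig
assert m.w13_weight == [10, 20, 30]

m.w13_weight = orig              # restore
```

A kernel that receives `B` therefore computes on the unquantized weight, which is why the fix must pass the swapped attribute instead.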
```python
@contextmanager
def _patch_moe_kernel(self):
    """Temporarily replace vLLM fused_moe kernel with quantized version."""
    # `invoke_fused_moe_kernel` was used through v0.14.0rc0; it was renamed
    # to `invoke_fused_moe_triton_kernel` starting from v0.14.0rc1.
    for attr in ["invoke_fused_moe_kernel", "invoke_fused_moe_triton_kernel"]:
        if hasattr(vllm_fused_moe_package, attr):
            orig = getattr(vllm_fused_moe_package, attr)
            setattr(vllm_fused_moe_package, "_invoke_fused_moe_kernel", orig)
            setattr(vllm_fused_moe_package, attr, self.invoke_fused_moe_quantized)
            try:
                yield
            finally:
                setattr(vllm_fused_moe_package, attr, orig)
            return
```
Module-level monkey patching here is not thread-safe and can break under concurrent forwards.
_patch_moe_kernel() mutates global symbols on vllm_fused_moe_package without synchronization. Concurrent entries can overwrite _invoke_fused_moe_kernel with an already patched function, leading to recursion/wrong restore behavior and nondeterministic failures.
Proposed fix

```diff
+import threading
 @@
+_MOE_KERNEL_PATCH_LOCK = threading.RLock()
+
 @contextmanager
 def _patch_moe_kernel(self):
 @@
-    for attr in ["invoke_fused_moe_kernel", "invoke_fused_moe_triton_kernel"]:
+    for attr in ["invoke_fused_moe_kernel", "invoke_fused_moe_triton_kernel"]:
         if hasattr(vllm_fused_moe_package, attr):
-            orig = getattr(vllm_fused_moe_package, attr)
-            setattr(vllm_fused_moe_package, "_invoke_fused_moe_kernel", orig)
-            setattr(vllm_fused_moe_package, attr, self.invoke_fused_moe_quantized)
-            try:
-                yield
-            finally:
-                setattr(vllm_fused_moe_package, attr, orig)
+            with _MOE_KERNEL_PATCH_LOCK:
+                orig = getattr(vllm_fused_moe_package, attr)
+                had_private = hasattr(vllm_fused_moe_package, "_invoke_fused_moe_kernel")
+                prev_private = getattr(vllm_fused_moe_package, "_invoke_fused_moe_kernel", None)
+                setattr(vllm_fused_moe_package, "_invoke_fused_moe_kernel", orig)
+                setattr(vllm_fused_moe_package, attr, self.invoke_fused_moe_quantized)
+                try:
+                    yield
+                finally:
+                    setattr(vllm_fused_moe_package, attr, orig)
+                    if had_private:
+                        setattr(vllm_fused_moe_package, "_invoke_fused_moe_kernel", prev_private)
+                    elif hasattr(vllm_fused_moe_package, "_invoke_fused_moe_kernel"):
+                        delattr(vllm_fused_moe_package, "_invoke_fused_moe_kernel")
             return
```
Codecov Report

❌ Patch coverage is

Additional details and impacted files

```
@@           Coverage Diff            @@
##             main    #1162    +/-   ##
==========================================
- Coverage   54.55%   54.54%   -0.02%
==========================================
  Files         348      348
  Lines       39755    39770      +15
==========================================
+ Hits        21689    21691       +2
- Misses      18066    18079      +13
```

Flags with carried forward coverage won't be shown. ☔ View full report in Codecov by Sentry.
What does this PR do?
Type of change: Bug fix

`_QuantFusedMoEBase.forward()` previously replaced `vllm_fused_moe_package.invoke_fused_moe_kernel`, which was renamed to `invoke_fused_moe_triton_kernel` starting in vLLM v0.14.0rc1. This caused an `AttributeError` / assertion failure for any MoE model quantized with vLLM ≥ v0.14.0rc1.

The fix refactors the kernel-patching logic into a `_patch_moe_kernel()` context manager that probes for both attribute names (the two names are mutually exclusive across vLLM versions; confirmed by inspecting every release from v0.10.0 to v0.18.1).

Usage
NA
Testing
Before your PR is "Ready for review"
- Make sure you read and follow Contributor guidelines and your commits are signed (`git commit -s -S`).
- Make sure you read and follow the Security Best Practices (e.g. avoiding hardcoded `trust_remote_code=True`, `torch.load(..., weights_only=False)`, `pickle`, etc.).
- CONTRIBUTING.md: N/A

Additional Information
Summary by CodeRabbit