Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 23 additions & 27 deletions modelopt/torch/quantization/plugins/vllm.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,50 +346,46 @@ def invoke_fused_moe_quantized(
# First layer of expert
A = self.w13_input_quantizer(A) # noqa: N806
if self.w13_weight_quantizer.is_enabled:
original_weight = self.w13_weight
self.w13_weight = self.w13_weight_quantizer(self.w13_weight)
orig, self.w13_weight = self.w13_weight, self.w13_weight_quantizer(self.w13_weight)
vllm_fused_moe_package._invoke_fused_moe_kernel(A, B, C, *args, **kwargs)
self.w13_weight = original_weight
self.w13_weight = orig
Comment on lines +349 to +351
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

cat -n modelopt/torch/quantization/plugins/vllm.py | sed -n '340,395p'

Repository: NVIDIA/Model-Optimizer

Length of output: 3292


🏁 Script executed:

cat -n modelopt/torch/quantization/plugins/vllm.py | sed -n '330,345p'

Repository: NVIDIA/Model-Optimizer

Length of output: 718


🏁 Script executed:

cat -n modelopt/torch/quantization/plugins/vllm.py | sed -n '1,50p'

Repository: NVIDIA/Model-Optimizer

Length of output: 2318


Use the swapped quantized weight in the kernel call, and always restore the original weight, even when the kernel call fails.

At lines 350 and 360, the kernel is invoked with B, which was bound to the original weight before the swap at lines 349 and 359. Rebinding self.w13_weight / self.w2_weight does not change what B refers to, so whenever the quantizer returns a new tensor the kernel still receives the unquantized weight and the quantization is bypassed. Additionally, the weight restoration at lines 351 and 361 is not wrapped in try/finally, so an exception raised during the kernel invocation leaves the quantized weight permanently assigned to the module, corrupting the module state.

Proposed fix
         if B is self.w13_weight:
             # First layer of expert
             A = self.w13_input_quantizer(A)  # noqa: N806
             if self.w13_weight_quantizer.is_enabled:
-                orig, self.w13_weight = self.w13_weight, self.w13_weight_quantizer(self.w13_weight)
-                vllm_fused_moe_package._invoke_fused_moe_kernel(A, B, C, *args, **kwargs)
-                self.w13_weight = orig
+                orig, self.w13_weight = self.w13_weight, self.w13_weight_quantizer(self.w13_weight)
+                try:
+                    vllm_fused_moe_package._invoke_fused_moe_kernel(
+                        A, self.w13_weight, C, *args, **kwargs
+                    )
+                finally:
+                    self.w13_weight = orig
             else:
                 vllm_fused_moe_package._invoke_fused_moe_kernel(A, B, C, *args, **kwargs)
@@
         elif B is self.w2_weight:
             A = self.w2_input_quantizer(A)  # noqa: N806
             if self.w2_weight_quantizer.is_enabled:
-                orig, self.w2_weight = self.w2_weight, self.w2_weight_quantizer(self.w2_weight)
-                vllm_fused_moe_package._invoke_fused_moe_kernel(A, B, C, *args, **kwargs)
-                self.w2_weight = orig
+                orig, self.w2_weight = self.w2_weight, self.w2_weight_quantizer(self.w2_weight)
+                try:
+                    vllm_fused_moe_package._invoke_fused_moe_kernel(
+                        A, self.w2_weight, C, *args, **kwargs
+                    )
+                finally:
+                    self.w2_weight = orig
             else:
                 vllm_fused_moe_package._invoke_fused_moe_kernel(A, B, C, *args, **kwargs)

Also applies to: 359-361

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/quantization/plugins/vllm.py` around lines 349 - 351, The
kernel is currently called with the original weight because B was bound before
swapping, and restoration isn't in a finally block so exceptions leave the
quantized tensor in place; fix by assigning orig = self.w13_weight, replacing
self.w13_weight with self.w13_weight_quantizer(self.w13_weight), then invoke
vllm_fused_moe_package._invoke_fused_moe_kernel using the swapped
self.w13_weight (not the previously bound B), and always restore self.w13_weight
= orig in a finally block; apply the identical pattern to the other symmetric
block (the one that swaps/restores the other weight at lines 359-361).

else:
vllm_fused_moe_package._invoke_fused_moe_kernel(A, B, C, *args, **kwargs)
if self.w13_output_quantizer.is_enabled:
C[:] = self.w13_output_quantizer(C)
elif B is self.w2_weight:
A = self.w2_input_quantizer(A) # noqa: N806
if self.w2_weight_quantizer.is_enabled:
original_weight = self.w2_weight
self.w2_weight = self.w2_weight_quantizer(self.w2_weight)
orig, self.w2_weight = self.w2_weight, self.w2_weight_quantizer(self.w2_weight)
vllm_fused_moe_package._invoke_fused_moe_kernel(A, B, C, *args, **kwargs)
self.w2_weight = original_weight
self.w2_weight = orig
else:
vllm_fused_moe_package._invoke_fused_moe_kernel(A, B, C, *args, **kwargs)
if self.w2_output_quantizer.is_enabled:
C[:] = self.w2_output_quantizer(C)
else:
raise ValueError("Cannot determine first or second layer of expert")

@contextmanager
def _patch_moe_kernel(self):
"""Temporarily replace vLLM fused_moe kernel with quantized version."""
# `invoke_fused_moe_kernel` was used through v0.14.0rc0; it was renamed
# to `invoke_fused_moe_triton_kernel` starting from v0.14.0rc1.
for attr in ["invoke_fused_moe_kernel", "invoke_fused_moe_triton_kernel"]:
if hasattr(vllm_fused_moe_package, attr):
orig = getattr(vllm_fused_moe_package, attr)
setattr(vllm_fused_moe_package, "_invoke_fused_moe_kernel", orig)
setattr(vllm_fused_moe_package, attr, self.invoke_fused_moe_quantized)
try:
yield
finally:
setattr(vllm_fused_moe_package, attr, orig)
return
Comment on lines +369 to +383
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🔴 Critical

Module-level monkey patching here is not thread-safe and can break under concurrent forwards.

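_patch_moe_kernel() mutates global symbols on vllm_fused_moe_package without synchronization. If a second caller enters while the first is still inside the context, the getattr() it performs returns the first caller's already-patched invoke_fused_moe_quantized rather than the real kernel; that wrapper is then stored in _invoke_fused_moe_kernel, so the "original" kernel the wrapper calls is itself a wrapper. This can lead to recursion, and whichever caller exits last restores the public attribute to a patched function instead of the real kernel, causing wrong restore behavior and nondeterministic failures.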

Proposed fix
+import threading
@@
+_MOE_KERNEL_PATCH_LOCK = threading.RLock()
+
 `@contextmanager`
 def _patch_moe_kernel(self):
@@
-        for attr in ["invoke_fused_moe_kernel", "invoke_fused_moe_triton_kernel"]:
+        for attr in ["invoke_fused_moe_kernel", "invoke_fused_moe_triton_kernel"]:
             if hasattr(vllm_fused_moe_package, attr):
-                orig = getattr(vllm_fused_moe_package, attr)
-                setattr(vllm_fused_moe_package, "_invoke_fused_moe_kernel", orig)
-                setattr(vllm_fused_moe_package, attr, self.invoke_fused_moe_quantized)
-                try:
-                    yield
-                finally:
-                    setattr(vllm_fused_moe_package, attr, orig)
+                with _MOE_KERNEL_PATCH_LOCK:
+                    orig = getattr(vllm_fused_moe_package, attr)
+                    had_private = hasattr(vllm_fused_moe_package, "_invoke_fused_moe_kernel")
+                    prev_private = getattr(vllm_fused_moe_package, "_invoke_fused_moe_kernel", None)
+                    setattr(vllm_fused_moe_package, "_invoke_fused_moe_kernel", orig)
+                    setattr(vllm_fused_moe_package, attr, self.invoke_fused_moe_quantized)
+                    try:
+                        yield
+                    finally:
+                        setattr(vllm_fused_moe_package, attr, orig)
+                        if had_private:
+                            setattr(vllm_fused_moe_package, "_invoke_fused_moe_kernel", prev_private)
+                        elif hasattr(vllm_fused_moe_package, "_invoke_fused_moe_kernel"):
+                            delattr(vllm_fused_moe_package, "_invoke_fused_moe_kernel")
                 return
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/quantization/plugins/vllm.py` around lines 369 - 383, The
current _patch_moe_kernel context manager mutates vllm_fused_moe_package globals
unsafely; make it concurrency-safe by serializing patch/unpatch with a
module-level threading.RLock and per-attribute reference counting (or a
reentrancy counter) so nested/concurrent entries don't clobber originals: on
entry, acquire the lock, for each attr
("invoke_fused_moe_kernel", "invoke_fused_moe_triton_kernel") save the original
into a module-level map (shared across entries) only if not already saved and
replace the attr with self.invoke_fused_moe_quantized while incrementing a
refcount; yield; in finally, decrement the refcount and only when it reaches
zero restore the original with setattr(vllm_fused_moe_package, attr, original)
and remove the saved original, then release the lock—use the symbols
_patch_moe_kernel, vllm_fused_moe_package, invoke_fused_moe_kernel,
invoke_fused_moe_triton_kernel, _invoke_fused_moe_kernel and
invoke_fused_moe_quantized to locate and implement this change.

raise ValueError("fused_moe_kernel is not found")

def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
# This is again due to the bad coding of vLLM
# fused_moe submodule is overwritten by the fused_moe function
# so we need to import the fused_moe module explicitly
assert vllm_fused_moe_package.invoke_fused_moe_kernel is not None
# This context manager will conflict with torch.compile
# with replace_function(
# vllm_fused_moe_package,
# "invoke_fused_moe_kernel",
# self.invoke_fused_moe_quantized,
# ):
try:
vllm_fused_moe_package._invoke_fused_moe_kernel = ( # type: ignore[attr-defined]
vllm_fused_moe_package.invoke_fused_moe_kernel
)
vllm_fused_moe_package.invoke_fused_moe_kernel = self.invoke_fused_moe_quantized # type: ignore[attr-defined]
output = super().forward(hidden_states, router_logits)
return output
finally:
vllm_fused_moe_package.invoke_fused_moe_kernel = ( # type: ignore[attr-defined]
vllm_fused_moe_package._invoke_fused_moe_kernel
)
with self._patch_moe_kernel():
return super().forward(hidden_states, router_logits)

@torch.no_grad()
def fold_weight(self, keep_attrs: bool = False):
Expand Down
Loading