-
Notifications
You must be signed in to change notification settings - Fork 331
Added support for MoE for vllm >= 0.14.0rc1 #1162
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -346,50 +346,46 @@ def invoke_fused_moe_quantized( | |
| # First layer of expert | ||
| A = self.w13_input_quantizer(A) # noqa: N806 | ||
| if self.w13_weight_quantizer.is_enabled: | ||
| original_weight = self.w13_weight | ||
| self.w13_weight = self.w13_weight_quantizer(self.w13_weight) | ||
| orig, self.w13_weight = self.w13_weight, self.w13_weight_quantizer(self.w13_weight) | ||
| vllm_fused_moe_package._invoke_fused_moe_kernel(A, B, C, *args, **kwargs) | ||
| self.w13_weight = original_weight | ||
| self.w13_weight = orig | ||
| else: | ||
| vllm_fused_moe_package._invoke_fused_moe_kernel(A, B, C, *args, **kwargs) | ||
| if self.w13_output_quantizer.is_enabled: | ||
| C[:] = self.w13_output_quantizer(C) | ||
| elif B is self.w2_weight: | ||
| A = self.w2_input_quantizer(A) # noqa: N806 | ||
| if self.w2_weight_quantizer.is_enabled: | ||
| original_weight = self.w2_weight | ||
| self.w2_weight = self.w2_weight_quantizer(self.w2_weight) | ||
| orig, self.w2_weight = self.w2_weight, self.w2_weight_quantizer(self.w2_weight) | ||
| vllm_fused_moe_package._invoke_fused_moe_kernel(A, B, C, *args, **kwargs) | ||
| self.w2_weight = original_weight | ||
| self.w2_weight = orig | ||
| else: | ||
| vllm_fused_moe_package._invoke_fused_moe_kernel(A, B, C, *args, **kwargs) | ||
| if self.w2_output_quantizer.is_enabled: | ||
| C[:] = self.w2_output_quantizer(C) | ||
| else: | ||
| raise ValueError("Cannot determine first or second layer of expert") | ||
|
|
||
| @contextmanager | ||
| def _patch_moe_kernel(self): | ||
| """Temporarily replace vLLM fused_moe kernel with quantized version.""" | ||
| # `invoke_fused_moe_kernel` was used through v0.14.0rc0; it was renamed | ||
| # to `invoke_fused_moe_triton_kernel` starting from v0.14.0rc1. | ||
| for attr in ["invoke_fused_moe_kernel", "invoke_fused_moe_triton_kernel"]: | ||
| if hasattr(vllm_fused_moe_package, attr): | ||
| orig = getattr(vllm_fused_moe_package, attr) | ||
| setattr(vllm_fused_moe_package, "_invoke_fused_moe_kernel", orig) | ||
| setattr(vllm_fused_moe_package, attr, self.invoke_fused_moe_quantized) | ||
| try: | ||
| yield | ||
| finally: | ||
| setattr(vllm_fused_moe_package, attr, orig) | ||
| return | ||
|
Comment on lines
+369
to
+383
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Module-level monkey patching here is not thread-safe and can break under concurrent forwards.
Proposed fix

+import threading
@@
+_MOE_KERNEL_PATCH_LOCK = threading.RLock()
+
`@contextmanager`
def _patch_moe_kernel(self):
@@
- for attr in ["invoke_fused_moe_kernel", "invoke_fused_moe_triton_kernel"]:
+ for attr in ["invoke_fused_moe_kernel", "invoke_fused_moe_triton_kernel"]:
if hasattr(vllm_fused_moe_package, attr):
- orig = getattr(vllm_fused_moe_package, attr)
- setattr(vllm_fused_moe_package, "_invoke_fused_moe_kernel", orig)
- setattr(vllm_fused_moe_package, attr, self.invoke_fused_moe_quantized)
- try:
- yield
- finally:
- setattr(vllm_fused_moe_package, attr, orig)
+ with _MOE_KERNEL_PATCH_LOCK:
+ orig = getattr(vllm_fused_moe_package, attr)
+ had_private = hasattr(vllm_fused_moe_package, "_invoke_fused_moe_kernel")
+ prev_private = getattr(vllm_fused_moe_package, "_invoke_fused_moe_kernel", None)
+ setattr(vllm_fused_moe_package, "_invoke_fused_moe_kernel", orig)
+ setattr(vllm_fused_moe_package, attr, self.invoke_fused_moe_quantized)
+ try:
+ yield
+ finally:
+ setattr(vllm_fused_moe_package, attr, orig)
+ if had_private:
+ setattr(vllm_fused_moe_package, "_invoke_fused_moe_kernel", prev_private)
+ elif hasattr(vllm_fused_moe_package, "_invoke_fused_moe_kernel"):
+ delattr(vllm_fused_moe_package, "_invoke_fused_moe_kernel")
                return

🤖 Prompt for AI Agents |
||
| raise ValueError("fused_moe_kernel is not found") | ||
|
|
||
| def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor): | ||
| # This is again due to the bad coding of vLLM | ||
| # fused_moe submodule is overwritten by the fused_moe function | ||
| # so we need to import the fused_moe module explicitly | ||
| assert vllm_fused_moe_package.invoke_fused_moe_kernel is not None | ||
| # This context manager will conflict with torch.compile | ||
| # with replace_function( | ||
| # vllm_fused_moe_package, | ||
| # "invoke_fused_moe_kernel", | ||
| # self.invoke_fused_moe_quantized, | ||
| # ): | ||
| try: | ||
| vllm_fused_moe_package._invoke_fused_moe_kernel = ( # type: ignore[attr-defined] | ||
| vllm_fused_moe_package.invoke_fused_moe_kernel | ||
| ) | ||
| vllm_fused_moe_package.invoke_fused_moe_kernel = self.invoke_fused_moe_quantized # type: ignore[attr-defined] | ||
| output = super().forward(hidden_states, router_logits) | ||
| return output | ||
| finally: | ||
| vllm_fused_moe_package.invoke_fused_moe_kernel = ( # type: ignore[attr-defined] | ||
| vllm_fused_moe_package._invoke_fused_moe_kernel | ||
| ) | ||
| with self._patch_moe_kernel(): | ||
| return super().forward(hidden_states, router_logits) | ||
|
|
||
| @torch.no_grad() | ||
| def fold_weight(self, keep_attrs: bool = False): | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
🧩 Analysis chain
🏁 Script executed:
Repository: NVIDIA/Model-Optimizer
Length of output: 3292
🏁 Script executed:
Repository: NVIDIA/Model-Optimizer
Length of output: 718
🏁 Script executed:
Repository: NVIDIA/Model-Optimizer
Length of output: 2318
Use the swapped quantized weight in the kernel call and always restore state on failure.
At lines 350 and 360, the kernel is invoked with `B`, which was bound to the original weight before the swap at lines 349 and 359. When quantization returns a new tensor, the kernel receives the unquantized weight, bypassing the quantization. Additionally, weight restoration at lines 351 and 361 is not guarded with `finally`, so an exception during kernel invocation leaves the module with a quantized weight permanently swapped, corrupting the module state.

Proposed fix
  if B is self.w13_weight:
      # First layer of expert
      A = self.w13_input_quantizer(A)  # noqa: N806
      if self.w13_weight_quantizer.is_enabled:
-         orig, self.w13_weight = self.w13_weight, self.w13_weight_quantizer(self.w13_weight)
-         vllm_fused_moe_package._invoke_fused_moe_kernel(A, B, C, *args, **kwargs)
-         self.w13_weight = orig
+         orig, self.w13_weight = self.w13_weight, self.w13_weight_quantizer(self.w13_weight)
+         try:
+             vllm_fused_moe_package._invoke_fused_moe_kernel(
+                 A, self.w13_weight, C, *args, **kwargs
+             )
+         finally:
+             self.w13_weight = orig
      else:
          vllm_fused_moe_package._invoke_fused_moe_kernel(A, B, C, *args, **kwargs)
@@
  elif B is self.w2_weight:
      A = self.w2_input_quantizer(A)  # noqa: N806
      if self.w2_weight_quantizer.is_enabled:
-         orig, self.w2_weight = self.w2_weight, self.w2_weight_quantizer(self.w2_weight)
-         vllm_fused_moe_package._invoke_fused_moe_kernel(A, B, C, *args, **kwargs)
-         self.w2_weight = orig
+         orig, self.w2_weight = self.w2_weight, self.w2_weight_quantizer(self.w2_weight)
+         try:
+             vllm_fused_moe_package._invoke_fused_moe_kernel(
+                 A, self.w2_weight, C, *args, **kwargs
+             )
+         finally:
+             self.w2_weight = orig
      else:
          vllm_fused_moe_package._invoke_fused_moe_kernel(A, B, C, *args, **kwargs)

Also applies to: 359-361
🤖 Prompt for AI Agents