diff --git a/modelopt/torch/quantization/plugins/vllm.py b/modelopt/torch/quantization/plugins/vllm.py
index 107e225cd0..7ec3ebea2a 100644
--- a/modelopt/torch/quantization/plugins/vllm.py
+++ b/modelopt/torch/quantization/plugins/vllm.py
@@ -346,10 +346,9 @@ def invoke_fused_moe_quantized(
             # First layer of expert
             A = self.w13_input_quantizer(A)  # noqa: N806
             if self.w13_weight_quantizer.is_enabled:
-                original_weight = self.w13_weight
-                self.w13_weight = self.w13_weight_quantizer(self.w13_weight)
+                orig, self.w13_weight = self.w13_weight, self.w13_weight_quantizer(self.w13_weight)
                 vllm_fused_moe_package._invoke_fused_moe_kernel(A, B, C, *args, **kwargs)
-                self.w13_weight = original_weight
+                self.w13_weight = orig
             else:
                 vllm_fused_moe_package._invoke_fused_moe_kernel(A, B, C, *args, **kwargs)
             if self.w13_output_quantizer.is_enabled:
@@ -357,10 +356,9 @@ def invoke_fused_moe_quantized(
         elif B is self.w2_weight:
             A = self.w2_input_quantizer(A)  # noqa: N806
             if self.w2_weight_quantizer.is_enabled:
-                original_weight = self.w2_weight
-                self.w2_weight = self.w2_weight_quantizer(self.w2_weight)
+                orig, self.w2_weight = self.w2_weight, self.w2_weight_quantizer(self.w2_weight)
                 vllm_fused_moe_package._invoke_fused_moe_kernel(A, B, C, *args, **kwargs)
-                self.w2_weight = original_weight
+                self.w2_weight = orig
             else:
                 vllm_fused_moe_package._invoke_fused_moe_kernel(A, B, C, *args, **kwargs)
             if self.w2_output_quantizer.is_enabled:
@@ -368,28 +366,26 @@ def invoke_fused_moe_quantized(
         else:
             raise ValueError("Cannot determine first or second layer of expert")
 
+    @contextmanager
+    def _patch_moe_kernel(self):
+        """Temporarily replace vLLM fused_moe kernel with quantized version."""
+        # `invoke_fused_moe_kernel` was used through v0.14.0rc0; it was renamed
+        # to `invoke_fused_moe_triton_kernel` starting from v0.14.0rc1.
+        for attr in ["invoke_fused_moe_kernel", "invoke_fused_moe_triton_kernel"]:
+            if hasattr(vllm_fused_moe_package, attr):
+                orig = getattr(vllm_fused_moe_package, attr)
+                setattr(vllm_fused_moe_package, "_invoke_fused_moe_kernel", orig)
+                setattr(vllm_fused_moe_package, attr, self.invoke_fused_moe_quantized)
+                try:
+                    yield
+                finally:
+                    setattr(vllm_fused_moe_package, attr, orig)
+                return
+        raise ValueError("fused_moe_kernel is not found")
+
     def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
-        # This is again due to the bad coding of vLLM
-        # fused_moe submodule is overwritten by the fused_moe function
-        # so we need to import the fused_moe module explicitly
-        assert vllm_fused_moe_package.invoke_fused_moe_kernel is not None
-        # This context manager will conflict with torch.compile
-        # with replace_function(
-        #     vllm_fused_moe_package,
-        #     "invoke_fused_moe_kernel",
-        #     self.invoke_fused_moe_quantized,
-        # ):
-        try:
-            vllm_fused_moe_package._invoke_fused_moe_kernel = (  # type: ignore[attr-defined]
-                vllm_fused_moe_package.invoke_fused_moe_kernel
-            )
-            vllm_fused_moe_package.invoke_fused_moe_kernel = self.invoke_fused_moe_quantized  # type: ignore[attr-defined]
-            output = super().forward(hidden_states, router_logits)
-            return output
-        finally:
-            vllm_fused_moe_package.invoke_fused_moe_kernel = (  # type: ignore[attr-defined]
-                vllm_fused_moe_package._invoke_fused_moe_kernel
-            )
+        with self._patch_moe_kernel():
+            return super().forward(hidden_states, router_logits)
 
     @torch.no_grad()
     def fold_weight(self, keep_attrs: bool = False):