From d483fd0a6523de2fb45b139afd1799afad5a4089 Mon Sep 17 00:00:00 2001 From: Kinjal Patel Date: Thu, 18 Dec 2025 23:20:56 +0000 Subject: [PATCH 1/3] Added MLAAttention Signed-off-by: Kinjal Patel --- examples/vllm_serve/fakequant_worker.py | 9 ++++--- examples/vllm_serve/vllm_serve_fakequant.py | 8 +++++- modelopt/torch/quantization/plugins/vllm.py | 28 +++++++++++++++++++-- 3 files changed, 38 insertions(+), 7 deletions(-) diff --git a/examples/vllm_serve/fakequant_worker.py b/examples/vllm_serve/fakequant_worker.py index 133f31a6c..7a9150afa 100644 --- a/examples/vllm_serve/fakequant_worker.py +++ b/examples/vllm_serve/fakequant_worker.py @@ -149,7 +149,7 @@ def disable_compilation(model): quant_config: dict[str, Any] = { "dataset": os.environ.get("QUANT_DATASET", "cnn_dailymail"), "calib_size": int(os.environ.get("QUANT_CALIB_SIZE", 512)), - "quant_cfg": os.environ.get("QUANT_CFG", "NVFP4_DEFAULT_CFG"), + "quant_cfg": os.environ.get("QUANT_CFG", None), "kv_quant_cfg": os.environ.get("KV_QUANT_CFG", None), "amax_file_path": os.environ.get("AMAX_FILE_PATH", None), } @@ -237,9 +237,10 @@ def calibrate_loop(model: Any = None) -> None: self.sample_tokens(None) quant_cfg = getattr(mtq, quant_config["quant_cfg"]) - if quant_config["kv_quant_cfg"] is not None: + quant_kv_cfg = getattr(mtq, quant_config["kv_quant_cfg"]) + if quant_kv_cfg: quant_cfg = mtq.utils.update_quant_cfg_with_kv_cache_quant( - quant_cfg, getattr(mtq, quant_config["kv_quant_cfg"])["quant_cfg"] + quant_cfg, quant_kv_cfg["quant_cfg"] ) model = self.model_runner.model @@ -314,6 +315,6 @@ def determine_available_memory(self) -> int: return super().determine_available_memory() def compile_or_warm_up_model(self) -> None: - if quant_config["quant_cfg"]: + if quant_config["quant_cfg"] or quant_config["kv_quant_cfg"]: _fakequant_run_prolog_worker(self) super().compile_or_warm_up_model() diff --git a/examples/vllm_serve/vllm_serve_fakequant.py b/examples/vllm_serve/vllm_serve_fakequant.py index b4b230ade..25483f2be 100644 --- a/examples/vllm_serve/vllm_serve_fakequant.py +++ b/examples/vllm_serve/vllm_serve_fakequant.py @@ -70,7 +70,13 @@ # Adding the envs you want to pass to the workers -additional_env_vars = {"QUANT_DATASET", "QUANT_CALIB_SIZE", "QUANT_CFG", "AMAX_FILE_PATH"} +additional_env_vars = { + "QUANT_DATASET", + "QUANT_CALIB_SIZE", + "QUANT_CFG", + "AMAX_FILE_PATH", + "KV_QUANT_CFG", +} RayDistributedExecutor.ADDITIONAL_ENV_VARS.update(additional_env_vars) diff --git a/modelopt/torch/quantization/plugins/vllm.py b/modelopt/torch/quantization/plugins/vllm.py index 99bc3d9ee..bd4b998f9 100644 --- a/modelopt/torch/quantization/plugins/vllm.py +++ b/modelopt/torch/quantization/plugins/vllm.py @@ -40,6 +40,11 @@ except ImportError: continue +try: + from vllm.attention.layer import MLAAttention as VllmMLAAttention +except ImportError: + VllmMLAAttention = None + vllm_fused_moe_package = importlib.import_module("vllm.model_executor.layers.fused_moe.fused_moe") @@ -262,13 +267,17 @@ class _QuantVLLMAttention(QuantModule): def _setup(self): self.q_bmm_quantizer = TensorQuantizer() self.k_bmm_quantizer = TensorQuantizer() - self.v_bmm_quantizer = TensorQuantizer() + # required for vllm < 0.11.1 + if not self.use_mla: + self.v_bmm_quantizer = TensorQuantizer() self.parallel_state = create_parallel_state() def forward(self, query, key, value, *args, **kwargs): query = self.q_bmm_quantizer(query) key = self.k_bmm_quantizer(key) - value = self.v_bmm_quantizer(value) + # required for vllm < 0.11.1 + if not self.use_mla: + value = 
self.v_bmm_quantizer(value) return super().forward(query, key, value, *args, **kwargs) @@ -281,3 +290,18 @@ class _QuantVLLMCrossAttention(_QuantVLLMAttention): @QuantModuleRegistry.register({EncoderOnlyAttention: "vllm_EncoderOnlyAttention"}) class _QuantVLLMEncoderOnlyAttention(_QuantVLLMAttention): pass + + +if VllmMLAAttention is not None: + + @QuantModuleRegistry.register({VllmMLAAttention: "vllm_MLAAttention"}) + class _QuantVLLMMLAAttention(QuantModule): + def _setup(self): + self.q_bmm_quantizer = TensorQuantizer() + self.kv_c_bmm_quantizer = TensorQuantizer() + self.parallel_state = create_parallel_state() + + def forward(self, query, kv_c, *args, **kwargs): + query = self.q_bmm_quantizer(query) + kv_c = self.kv_c_bmm_quantizer(kv_c) + return super().forward(query, kv_c, *args, **kwargs) From 4d2f50a4ea8893d36d1639c2ea17d07ffebbb35e Mon Sep 17 00:00:00 2001 From: Kinjal Patel Date: Sat, 20 Dec 2025 00:13:44 +0000 Subject: [PATCH 2/3] Updated kv config for mla in vllm script Signed-off-by: Kinjal Patel --- examples/vllm_serve/fakequant_worker.py | 35 ++++++++++++++++++++++--- 1 file changed, 31 insertions(+), 4 deletions(-) diff --git a/examples/vllm_serve/fakequant_worker.py b/examples/vllm_serve/fakequant_worker.py index 7a9150afa..9be41f72e 100644 --- a/examples/vllm_serve/fakequant_worker.py +++ b/examples/vllm_serve/fakequant_worker.py @@ -155,6 +155,28 @@ def disable_compilation(model): } +def update_kv_cfg_for_mla(model: torch.nn.Module, kv_quant_cfg: dict[str, Any]) -> dict[str, Any]: + """Update KV cache quantization config for MLA models. + + MLA uses `kv_c_bmm_quantizer` (compressed KV) instead of separate + `k_bmm_quantizer` and `v_bmm_quantizer`. This function copies the + config from `*[kv]_bmm_quantizer` to also cover `*kv_c_bmm_quantizer`. + """ + try: + from vllm.attention.layer import MLAAttention + except ImportError: + return kv_quant_cfg + + if not any(isinstance(m, MLAAttention) for m in model.modules()): + return kv_quant_cfg + + if kv_config := kv_quant_cfg.get("*[kv]_bmm_quantizer"): + kv_quant_cfg["*kv_c_bmm_quantizer"] = kv_config + print("MLA detected: added *kv_c_bmm_quantizer config") + + return kv_quant_cfg + + def _create_new_data_cls(data_cls, **kwargs): """vLLM's low-level API changes frequently. 
This function creates a class with parameters compatible with the different vLLM versions.""" @@ -238,15 +260,20 @@ def calibrate_loop(model: Any = None) -> None: quant_cfg = getattr(mtq, quant_config["quant_cfg"]) quant_kv_cfg = getattr(mtq, quant_config["kv_quant_cfg"]) - if quant_kv_cfg: - quant_cfg = mtq.utils.update_quant_cfg_with_kv_cache_quant( - quant_cfg, quant_kv_cfg["quant_cfg"] - ) model = self.model_runner.model if hasattr(model, "unwrap"): model = model.unwrap() + # Check if model has MLA and update KV config accordingly + if quant_kv_cfg: + quant_kv_cfg["quant_cfg"] = update_kv_cfg_for_mla(model, quant_kv_cfg["quant_cfg"]) + + if quant_kv_cfg: + quant_cfg = mtq.utils.update_quant_cfg_with_kv_cache_quant( + quant_cfg, quant_kv_cfg["quant_cfg"] + ) + with disable_compilation(model): print("quantizing model...") mtq.quantize(model, quant_cfg, forward_loop=calibrate_loop) From fdbfefe82776cb04125a685c078fdd3623ad7852 Mon Sep 17 00:00:00 2001 From: Kinjal Patel Date: Thu, 1 Jan 2026 01:47:47 +0000 Subject: [PATCH 3/3] updated to quantize k_pe_bmm for MLA Signed-off-by: Kinjal Patel --- examples/vllm_serve/fakequant_worker.py | 3 ++- modelopt/torch/quantization/plugins/vllm.py | 14 ++++++-------- 2 files changed, 8 insertions(+), 9 deletions(-) diff --git a/examples/vllm_serve/fakequant_worker.py b/examples/vllm_serve/fakequant_worker.py index 9be41f72e..1008e748a 100644 --- a/examples/vllm_serve/fakequant_worker.py +++ b/examples/vllm_serve/fakequant_worker.py @@ -172,7 +172,8 @@ def update_kv_cfg_for_mla(model: torch.nn.Module, kv_quant_cfg: dict[str, Any]) if kv_config := kv_quant_cfg.get("*[kv]_bmm_quantizer"): kv_quant_cfg["*kv_c_bmm_quantizer"] = kv_config - print("MLA detected: added *kv_c_bmm_quantizer config") + kv_quant_cfg["*k_pe_bmm_quantizer"] = kv_config + print("MLA detected: added *kv_c_bmm_quantizer and k_pe_bmm_quantizer config") return kv_quant_cfg diff --git a/modelopt/torch/quantization/plugins/vllm.py b/modelopt/torch/quantization/plugins/vllm.py index bd4b998f9..5ba12ef57 100644 --- a/modelopt/torch/quantization/plugins/vllm.py +++ b/modelopt/torch/quantization/plugins/vllm.py @@ -267,17 +267,13 @@ class _QuantVLLMAttention(QuantModule): def _setup(self): self.q_bmm_quantizer = TensorQuantizer() self.k_bmm_quantizer = TensorQuantizer() - # required for vllm < 0.11.1 - if not self.use_mla: - self.v_bmm_quantizer = TensorQuantizer() + self.v_bmm_quantizer = TensorQuantizer() self.parallel_state = create_parallel_state() def forward(self, query, key, value, *args, **kwargs): query = self.q_bmm_quantizer(query) key = self.k_bmm_quantizer(key) - # required for vllm < 0.11.1 - if not self.use_mla: - value = self.v_bmm_quantizer(value) + value = self.v_bmm_quantizer(value) return super().forward(query, key, value, *args, **kwargs) @@ -299,9 +295,11 @@ class _QuantVLLMMLAAttention(QuantModule): def _setup(self): self.q_bmm_quantizer = TensorQuantizer() self.kv_c_bmm_quantizer = TensorQuantizer() + self.k_pe_bmm_quantizer = TensorQuantizer() self.parallel_state = create_parallel_state() - def forward(self, query, kv_c, *args, **kwargs): + def forward(self, query, kv_c, k_pe, *args, **kwargs): query = self.q_bmm_quantizer(query) kv_c = self.kv_c_bmm_quantizer(kv_c) - return super().forward(query, kv_c, *args, **kwargs) + k_pe = self.k_pe_bmm_quantizer(k_pe) + return super().forward(query, kv_c, k_pe, *args, **kwargs)
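
Note for reviewers: below is a minimal, standalone sketch (not part of the patches) of the key remapping that update_kv_cfg_for_mla() performs. The quantizer settings shown are made-up placeholders rather than values from any real mtq KV-cache config; only the wildcard quantizer names (*[kv]_bmm_quantizer, *kv_c_bmm_quantizer, *k_pe_bmm_quantizer) are taken from this patch series.

# Sketch: reproduce the remapping done by update_kv_cfg_for_mla(), assuming a
# KV-cache quant config keyed by wildcard quantizer names as in the patches.
def remap_kv_cfg_for_mla(kv_quant_cfg: dict) -> dict:
    """Copy the standard K/V BMM quantizer entry onto the MLA quantizer names."""
    if kv_config := kv_quant_cfg.get("*[kv]_bmm_quantizer"):
        # MLA exposes a compressed KV projection (kv_c) and a decoupled rotary
        # key (k_pe) instead of separate K and V tensors, so the same settings
        # are reused for both MLA-specific quantizers.
        kv_quant_cfg["*kv_c_bmm_quantizer"] = kv_config
        kv_quant_cfg["*k_pe_bmm_quantizer"] = kv_config
    return kv_quant_cfg

# Placeholder entry: the num_bits/axis values are illustrative only.
example_cfg = {"*[kv]_bmm_quantizer": {"num_bits": (4, 3), "axis": None, "enable": True}}
print(remap_kv_cfg_for_mla(example_cfg))
# The result keeps the original wildcard entry and adds identical
# *kv_c_bmm_quantizer and *k_pe_bmm_quantizer entries, which is what lets
# mtq.quantize() match the quantizers created by _QuantVLLMMLAAttention.

Because the same dict object is assigned to both new keys, the compressed-KV and rotary-key quantizers always share identical settings with the standard K/V entry; diverging them would require copying the dict before assignment.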