diff --git a/vllm/model_executor/layers/quantization/awq.py b/vllm/model_executor/layers/quantization/awq.py index 410b3cb5321cb..e564b18e7d323 100644 --- a/vllm/model_executor/layers/quantization/awq.py +++ b/vllm/model_executor/layers/quantization/awq.py @@ -1,14 +1,22 @@ -from typing import Any, Dict, List, Optional +from typing import Callable, Any, Dict, List, Optional import torch - +from torch.nn import Parameter from vllm import _custom_ops as ops -from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase +from vllm.model_executor.layers.fused_moe.layer import ( + FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported) +from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase, set_weight_attrs from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) from vllm.model_executor.parameter import (GroupQuantScaleParameter, PackedvLLMParameter) +from vllm.model_executor.layers.quantization.utils.marlin_utils import ( + apply_gptq_marlin_linear, check_marlin_supported, marlin_is_k_full, + marlin_make_empty_g_idx, marlin_make_workspace, marlin_moe_permute_scales, + marlin_permute_scales, marlin_repeat_scales_on_all_ranks, + marlin_sort_g_idx, replace_tensor, verify_marlin_supported, + verify_marlin_supports_shape) class AWQConfig(QuantizationConfig): """Config class for AWQ. @@ -64,9 +72,11 @@ def from_config(cls, config: Dict[str, Any]) -> "AWQConfig": return cls(weight_bits, group_size, zero_point) def get_quant_method(self, layer: torch.nn.Module, - prefix: str) -> Optional["AWQLinearMethod"]: + prefix: str) -> Optional["QuantizedMethodBase"]: if isinstance(layer, LinearBase): return AWQLinearMethod(self) + elif isinstance(layer, FusedMoE): + return AWQMoEMethod(self) return None def get_scaled_act_names(self) -> List[str]: @@ -170,3 +180,176 @@ def apply(self, if bias is not None: out.add_(bias) return out.reshape(out_shape) + +class AWQMoEMethod(FusedMoEMethodBase): + + def __init__(self, quant_config: AWQConfig): + self.quant_config = quant_config + self.num_bits = self.quant_config.weight_bits + self.packed_factor = self.quant_config.pack_factor + self.group_size = self.quant_config.group_size + + def create_weights(self, layer: torch.nn.Module, num_experts: int, + hidden_size: int, intermediate_size: int, + params_dtype: torch.dtype, **extra_weight_attrs): + extra_weight_attrs.update({ + "is_transposed": True, + "quant_method": "group", + }) + + w13_qweight = Parameter(torch.empty(num_experts, + hidden_size, + 2 * intermediate_size // + self.quant_config.pack_factor, + dtype=torch.int32), + requires_grad=False) + layer.register_parameter("w13_qweight", w13_qweight) + set_weight_attrs(w13_qweight, extra_weight_attrs) + + w2_qweight = Parameter(torch.empty(num_experts, + intermediate_size, + hidden_size // + self.quant_config.pack_factor, + dtype=torch.int32), + requires_grad=False) + layer.register_parameter("w2_qweight", w2_qweight) + set_weight_attrs(w2_qweight, extra_weight_attrs) + + num_groups_w13 = hidden_size // self.quant_config.group_size + num_groups_w2 = intermediate_size // self.quant_config.group_size + + # WEIGHT_SCALES + # Allocate 2 scales for w1 and w3 respectively. + w13_scales = Parameter(torch.empty(num_experts, + num_groups_w13, + intermediate_size * 2, + dtype=params_dtype), + requires_grad=False) + layer.register_parameter("w13_scales", w13_scales) + set_weight_attrs(w13_scales, extra_weight_attrs) + + w2_scales = Parameter(torch.empty(num_experts, + num_groups_w2, + hidden_size, + dtype=params_dtype), + requires_grad=False) + layer.register_parameter("w2_scales", w2_scales) + set_weight_attrs(w2_scales, extra_weight_attrs) + + # WEIGHT_ZERO_POINT + # Allocate 2 zero points for w1 and w3 respectively. + w13_qzeros = Parameter(torch.empty(num_experts, + num_groups_w13, + 2 * intermediate_size // + self.quant_config.pack_factor, + dtype=torch.int32), + requires_grad=False) + layer.register_parameter("w13_qzeros", w13_qzeros) + set_weight_attrs(w13_qzeros, extra_weight_attrs) + + w2_qzeros = Parameter(torch.empty(num_experts, + num_groups_w2, + hidden_size // + self.quant_config.pack_factor, + dtype=torch.int32), + requires_grad=False) + layer.register_parameter("w2_qzeros", w2_qzeros) + set_weight_attrs(w2_qzeros, extra_weight_attrs) + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + num_experts = layer.w13_qweight.shape[0] + device = layer.w13_qweight.device + + layer.w13_g_idx = torch.nn.Parameter( + torch.empty((num_experts, 0), dtype=torch.int32, + device=device), + requires_grad=False, + ) + layer.w2_g_idx = torch.nn.Parameter( + torch.empty((num_experts, 0), dtype=torch.int32, + device=device), + requires_grad=False, + ) + layer.w13_g_idx_sort_indices = torch.nn.Parameter( + torch.empty((num_experts, 0), dtype=torch.int32, + device=device), + requires_grad=False, + ) + layer.w2_g_idx_sort_indices = torch.nn.Parameter( + torch.empty((num_experts, 0), dtype=torch.int32, + device=device), + requires_grad=False, + ) + marlin_w13_qweight = ops.gptq_marlin_moe_repack( + layer.w13_qweight, + layer.w13_g_idx_sort_indices, + layer.w13_qweight.shape[1], + layer.w13_qweight.shape[2] * self.packed_factor, + self.num_bits, + ) + replace_tensor(layer, "w13_qweight", marlin_w13_qweight) + marlin_w2_qweight = ops.gptq_marlin_moe_repack( + layer.w2_qweight, + layer.w2_g_idx_sort_indices, + layer.w2_qweight.shape[1], + layer.w2_qweight.shape[2] * self.packed_factor, + self.num_bits, + ) + replace_tensor(layer, "w2_qweight", marlin_w2_qweight) + # Repack scales + marlin_w13_scales = marlin_moe_permute_scales( + s=layer.w13_scales, + size_k=layer.intermediate_size_per_partition, + size_n=layer.w13_scales.shape[2], + group_size=self.group_size + ) + + replace_tensor(layer, "w13_scales", marlin_w13_scales) + marlin_w2_scales = marlin_moe_permute_scales( + s=layer.w2_scales, + size_k=layer.w2_scales.shape[1] , + size_n=layer.w2_scales.shape[2] * self.packed_factor, + group_size=self.group_size, + ) + replace_tensor(layer, "w2_scales", marlin_w2_scales) + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + renormalize: bool = True, + use_grouped_topk: bool = False, + num_expert_group: Optional[int] = None, + topk_group: Optional[int] = None, + custom_routing_function: Optional[Callable] = None, + ) -> torch.Tensor: + + from vllm.model_executor.layers.fused_moe.fused_marlin_moe import ( + fused_marlin_moe) + + topk_weights, topk_ids = FusedMoE.select_experts( + hidden_states=x, + router_logits=router_logits, + use_grouped_topk=use_grouped_topk, + top_k=top_k, + renormalize=renormalize, + topk_group=topk_group, + num_expert_group=num_expert_group, + custom_routing_function=custom_routing_function) + + return fused_marlin_moe( + x, + layer.w13_qweight, + layer.w2_qweight, + layer.w13_scales, + layer.w2_scales, + router_logits, + topk_weights, + topk_ids, + g_idx1=layer.w13_g_idx, + g_idx2=layer.w2_g_idx, + sort_indices1=layer.w13_g_idx_sort_indices, + sort_indices2=layer.w2_g_idx_sort_indices, + num_bits=self.num_bits) \ No newline at end of file diff --git a/vllm/model_executor/model_loader/utils.py b/vllm/model_executor/model_loader/utils.py index 2bfe6ea09bd62..995bb253db8a1 100644 --- a/vllm/model_executor/model_loader/utils.py +++ b/vllm/model_executor/model_loader/utils.py @@ -23,7 +23,7 @@ def get_model_architecture( architectures = getattr(model_config.hf_config, "architectures", []) # Special handling for quantized Mixtral. # FIXME(woosuk): This is a temporary hack. - mixtral_supported = ["fp8", "compressed-tensors", "gptq_marlin"] + mixtral_supported = ["fp8", "compressed-tensors", "gptq_marlin", "awq"] if (model_config.quantization is not None and model_config.quantization not in mixtral_supported