From 6cdd1c7aed8a2ca6760e9b765fcd66f78f5e9fde Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Mon, 7 Oct 2024 07:05:36 -0400 Subject: [PATCH 01/14] pre-merge --- vllm/attention/layer.py | 3 + vllm/engine/llm_engine.py | 7 + vllm/model_executor/layers/fused_moe/layer.py | 9 +- vllm/model_executor/layers/linear.py | 2 + .../layers/quantization/__init__.py | 2 + .../layers/quantization/hqq_marlin.py | 180 ++++++++++++++++++ vllm/model_executor/model_loader/loader.py | 5 + .../model_loader/weight_utils.py | 11 ++ vllm/model_executor/models/llama.py | 6 + 9 files changed, 224 insertions(+), 1 deletion(-) create mode 100644 vllm/model_executor/layers/quantization/hqq_marlin.py diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index ecf964fa49d9b..d6c8f7164c435 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -49,6 +49,8 @@ def __init__( if num_kv_heads is None: num_kv_heads = num_heads + print("QUANT CONFIG:", quant_config) + # The default k/v_scale is set to 1.0. This is ignored # when kv-cache is not fp8, and should be used with # kv-cache in fp8_e5m2. For kv-cache in fp8_e4m3, we @@ -59,6 +61,7 @@ def __init__( self._v_scale = 1.0 quant_method = quant_config.get_quant_method( self, prefix=prefix) if quant_config else None + print("QUANT METHOD:", quant_method) if quant_method is not None: assert isinstance(quant_method, BaseKVCacheMethod) # TODO (mgoin): kv cache dtype should be specified in the FP8 diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index e3cd822f648fe..f02acb5766f6c 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -314,6 +314,8 @@ def __init__( self.detokenizer = None tokenizer_group = None + print(self.load_config) + # Ensure that the function doesn't contain a reference to self, # to avoid engine GC issues def get_tokenizer_for_seq(sequence: Sequence) -> AnyTokenizer: @@ -332,6 +334,11 @@ def get_tokenizer_for_seq(sequence: Sequence) -> AnyTokenizer: self.input_processor = input_registry.create_input_processor( model_config) + # from pprint import pprint + # pprint(vars(model_config)) + # pprint(vars(executor_class)) + print("QUANT:", model_config.quantization) + self.model_executor = executor_class( model_config=model_config, cache_config=cache_config, diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index bce740d0db750..cd333a41b8c71 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -11,6 +11,8 @@ from vllm.model_executor.custom_op import CustomOp from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig, QuantizeMethodBase) +from vllm.model_executor.layers.quantization.qqq import ( + QQQConfig, QQQLinearMethod) from vllm.model_executor.utils import set_weight_attrs logger = init_logger(__name__) @@ -44,9 +46,12 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int, hidden_size: int, intermediate_size: int, params_dtype: torch.dtype, **extra_weight_attrs): + print("WEIGHT SIZES:", num_experts, intermediate_size, hidden_size, + params_dtype) + # Fused gate_up_proj (column parallel) w13_weight = torch.nn.Parameter(torch.empty(num_experts, - 2 * intermediate_size, + intermediate_size, hidden_size, dtype=params_dtype), requires_grad=False) @@ -323,6 +328,8 @@ def weight_loader(self, param: torch.nn.Parameter, loaded_weight: torch.Tensor, weight_name: str, shard_id: str, expert_id: int) -> None: + # print("LOADING:", weight_name) + # 
compressed-tensors checkpoints with packed weights are stored flipped # TODO (mgoin): check self.quant_method.quant_config.quant_format # against known CompressionFormat enum values that have this quality diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index 568892778abe2..f7af9f328e887 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -980,6 +980,8 @@ def __init__(self, self.input_size_per_partition = divide(input_size, self.tp_size) assert self.quant_method is not None + print("rowpar", self.quant_method) + self.quant_method.create_weights( layer=self, input_size_per_partition=self.input_size_per_partition, diff --git a/vllm/model_executor/layers/quantization/__init__.py b/vllm/model_executor/layers/quantization/__init__.py index 3c38f0a006070..0966895f51229 100644 --- a/vllm/model_executor/layers/quantization/__init__.py +++ b/vllm/model_executor/layers/quantization/__init__.py @@ -27,6 +27,7 @@ NeuronQuantConfig) from vllm.model_executor.layers.quantization.qqq import QQQConfig from vllm.model_executor.layers.quantization.tpu_int8 import Int8TpuConfig +from vllm.model_executor.layers.quantization.hqq_marlin import HQQMarlinConfig QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = { "aqlm": AQLMConfig, @@ -47,6 +48,7 @@ "compressed-tensors": CompressedTensorsConfig, "bitsandbytes": BitsAndBytesConfig, "qqq": QQQConfig, + "hqq_marlin": HQQMarlinConfig, "experts_int8": ExpertsInt8Config, "neuron_quant": NeuronQuantConfig, } diff --git a/vllm/model_executor/layers/quantization/hqq_marlin.py b/vllm/model_executor/layers/quantization/hqq_marlin.py new file mode 100644 index 0000000000000..ea7a98e9ff7de --- /dev/null +++ b/vllm/model_executor/layers/quantization/hqq_marlin.py @@ -0,0 +1,180 @@ +from typing import Any, Callable, Dict, List, Optional, Set, Union + +import torch + +from vllm import _custom_ops as ops +from vllm.logger import init_logger +from vllm.model_executor.layers.fused_moe.layer import ( + FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported) +from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, + set_weight_attrs) +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.model_executor.layers.quantization.kernels import ( + MPLinearLayerConfig, choose_mp_linear_kernel) +from vllm.model_executor.layers.quantization.utils import replace_parameter +from vllm.model_executor.layers.quantization.utils.marlin_utils import ( + check_marlin_supported, marlin_moe_permute_scales, + marlin_repeat_scales_on_all_ranks, verify_marlin_supported) +from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead +from vllm.model_executor.parameter import (ChannelQuantScaleParameter, + GroupQuantScaleParameter, + PackedColumnParameter, + PackedvLLMParameter, + RowvLLMParameter) +from vllm.scalar_type import scalar_types + +logger = init_logger(__name__) + +class HQQMarlinConfig(QuantizationConfig): + """Config class for HQQ Marlin""" + + # (num_bits, is_sym) -> quant_type + TYPE_MAP = { + 4: scalar_types.uint4, + 8: scalar_types.uint8, + } + + def __init__( + self, + weight_bits: int, + group_size: int, + ) -> None: + self.pack_factor = 32 // weight_bits # packed into int32 + self.group_size = group_size + self.quant_type = self.TYPE_MAP[(weight_bits)] + + def __repr__(self) -> str: + return (f"HQQMarlinConfig(quant_type={self.quant_type}, " + f"group_size={self.group_size})") + + @classmethod + def 
get_name(cls) -> str: + return "hqq_marlin" + + @classmethod + def get_supported_act_dtypes(cls) -> List[torch.dtype]: + return [torch.half, torch.bfloat16] + + @classmethod + def get_min_capability(cls) -> int: + return 80 + + @classmethod + def get_config_filenames(cls) -> List[str]: + return ["quantize_config.json"] + + @classmethod + def from_config(cls, config: Dict[str, Any]) -> "HQQMarlinConfig": + weight_bits = cls.get_from_keys(config, ["bits"]) + group_size = cls.get_from_keys(config, ["group_size"]) + return cls(weight_bits, group_size) + + @classmethod + def override_quantization_method(cls, hf_quant_cfg, + user_quant) -> Optional[str]: + #TODO + return None + + def get_quant_method( + self, layer: torch.nn.Module, prefix: str + ) -> "HQQMarlinMethod": + return HQQMarlinMethod(self) + + def get_scaled_act_names(self) -> List[str]: + return [] + + +class HQQMarlinMethod(LinearMethodBase): + """Linear method for HQQ Marlin. + """ + + def __init__( + self, + quant_config: HQQMarlinConfig, + ): + self.quant_config = quant_config + + def create_weights( + self, + layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: List[int], + input_size: int, + output_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ) -> None: + output_size_per_partition = sum(output_partition_sizes) + is_row_parallel = input_size != input_size_per_partition + weight_loader = extra_weight_attrs.get("weight_loader") + + scales_and_zp_size = input_size_per_partition // self.quant_config.group_size + + # Quantized weights + qweight = PackedvLLMParameter( + data=torch.empty( + input_size_per_partition // self.quant_config.pack_factor, + output_size_per_partition, + dtype=torch.int32, + ), + input_dim=0, + output_dim=1, + packed_dim=0, + packed_factor=self.quant_config.pack_factor, + weight_loader=weight_loader) + + zeros_args = { + "data": + torch.empty( + scales_and_zp_size, + output_size_per_partition // self.quant_config.pack_factor, + dtype=torch.int32, + ), + "weight_loader": + weight_loader + } + scales_args = { + "data": + torch.empty( + scales_and_zp_size, + output_size_per_partition, + dtype=params_dtype, + ), + "weight_loader": + weight_loader + } + + zeros = PackedvLLMParameter( + input_dim=0, + output_dim=1, + packed_dim=1, + packed_factor=self.quant_config.pack_factor, + **zeros_args) + + scales = PackedvLLMParameter( + input_dim=0, + output_dim=1, + packed_dim=1, + packed_factor=self.quant_config.pack_factor, + **scales_args) + + layer.register_parameter("qweight", qweight) + layer.register_parameter("zeros", zeros) + layer.register_parameter("scales", scales) + + # self.kernel = '.' #TODO + + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + print("TODO") + # self.kernel.process_weights_after_loading(layer) + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + return torch.empty((0), dtype=x.dtype, device=x.device) + # return self.kernel.apply_weights(layer, x, bias) diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py index c21b10d661ecc..2ee1e93100a9f 100644 --- a/vllm/model_executor/model_loader/loader.py +++ b/vllm/model_executor/model_loader/loader.py @@ -272,17 +272,22 @@ def _prepare_weights(self, model_name_or_path: str, # Some quantized models use .pt files for storing the weights. 
if load_format == LoadFormat.AUTO: allow_patterns = ["*.safetensors", "*.bin"] + print("AUTO") elif load_format == LoadFormat.SAFETENSORS: use_safetensors = True allow_patterns = ["*.safetensors"] + print("SAFETENSORS") elif load_format == LoadFormat.MISTRAL: use_safetensors = True allow_patterns = ["consolidated*.safetensors"] index_file = "consolidated.safetensors.index.json" + print("MISTRAL") elif load_format == LoadFormat.PT: allow_patterns = ["*.pt"] + print("PT") elif load_format == LoadFormat.NPCACHE: allow_patterns = ["*.bin"] + print("NPCACHE") else: raise ValueError(f"Unknown load_format: {load_format}") diff --git a/vllm/model_executor/model_loader/weight_utils.py b/vllm/model_executor/model_loader/weight_utils.py index 5051d45dd1154..653d5d671f1a0 100644 --- a/vllm/model_executor/model_loader/weight_utils.py +++ b/vllm/model_executor/model_loader/weight_utils.py @@ -123,9 +123,20 @@ def get_quant_config(model_config: ModelConfig, quant_cls = get_quantization_config(model_config.quantization) + # print(vars(model_config)) + # print(vars(quant_cls)) + # GGUF doesn't have config file if model_config.quantization == "gguf": return quant_cls.from_config({}) + + if model_config.quantization == "hqq_marlin": + # print("=======================================") + # print(vars(model_config)) + # print(vars(load_config)) + # print("=======================================") + # TODO shouldn't be done like this + return quant_cls.from_config({"bits": 4, "group_size": 64}) # Read the quantization config from the HF model config, if available. hf_quant_config = getattr(model_config.hf_config, "quantization_config", diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index 5ff31e3833ec9..805b56552cb12 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -67,6 +67,7 @@ def __init__( prefix: str = "", ) -> None: super().__init__() + print("Quant config:", quant_config) self.gate_up_proj = MergedColumnParallelLinear( input_size=hidden_size, output_sizes=[intermediate_size] * 2, @@ -404,6 +405,8 @@ def __init__( ) -> None: super().__init__() + print("===== LLAMA FOR CAUSAL LM =====") + self.config = config self.lora_config = lora_config @@ -490,6 +493,9 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): (".gate_up_proj", ".up_proj", 1), ] params_dict = dict(self.named_parameters()) + # print(*[(n, w['meta'] if 'meta' in w else "") for n, w in weights], sep="\n") + # print(*[(n, w) for n, w in weights], sep="\n") + for name, loaded_weight in weights: name, loaded_weight = self.maybe_remap_mistral(name, loaded_weight) From ee54bcaf77d0e664b4dcd3b0c435485ece39b665 Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Mon, 14 Oct 2024 07:03:48 -0400 Subject: [PATCH 02/14] try different shapes --- vllm/attention/layer.py | 3 - vllm/model_executor/layers/linear.py | 60 ++++++- .../layers/quantization/hqq_marlin.py | 145 +++++++++++----- vllm/model_executor/model_loader/loader.py | 11 +- vllm/model_executor/models/llama.py | 157 +++++++++++++++++- 5 files changed, 308 insertions(+), 68 deletions(-) diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index d6c8f7164c435..ecf964fa49d9b 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -49,8 +49,6 @@ def __init__( if num_kv_heads is None: num_kv_heads = num_heads - print("QUANT CONFIG:", quant_config) - # The default k/v_scale is set to 1.0. This is ignored # when kv-cache is not fp8, and should be used with # kv-cache in fp8_e5m2. 
For kv-cache in fp8_e4m3, we @@ -61,7 +59,6 @@ def __init__( self._v_scale = 1.0 quant_method = quant_config.get_quant_method( self, prefix=prefix) if quant_config else None - print("QUANT METHOD:", quant_method) if quant_method is not None: assert isinstance(quant_method, BaseKVCacheMethod) # TODO (mgoin): kv cache dtype should be specified in the FP8 diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index 77ba6200e7842..05e46e9ae4b67 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -114,6 +114,8 @@ def apply(self, class UnquantizedLinearMethod(LinearMethodBase): """Linear method without quantization.""" + global_print_ctr = 0 + def create_weights(self, layer: torch.nn.Module, input_size_per_partition: int, output_partition_sizes: List[int], input_size: int, @@ -132,6 +134,13 @@ def apply(self, x: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor: + # if UnquantizedLinearMethod.global_print_ctr < 3: + # torch.set_printoptions(edgeitems=128) + # torch.set_printoptions(sci_mode=False) + + # print("apply to weight:", layer.weight, layer.weight.shape) + # # # print("and to bias:", bias) + # UnquantizedLinearMethod.global_print_ctr += 1 return F.linear(x, layer.weight, bias) @@ -371,12 +380,16 @@ def forward(self, input_): else: output = output_parallel output_bias = self.bias if self.skip_bias_add else None + # print("=== ColumnParallelLinear ===") + # print("forward's io:", input_.shape, output.shape) + # print("for input:", input_) + # print("got output:", output) return output, output_bias def extra_repr(self) -> str: s = f"in_features={self.input_size}" s += f", output_features={self.output_size_per_partition}" - s += f", bias={self.bias is not None}" + s += f", bias={hasattr(self, 'bias') and self.bias is not None}" s += f", tp_size={get_tensor_model_parallel_world_size()}" s += f", gather_output={self.gather_output}" return s @@ -431,6 +444,8 @@ def weight_loader(self, loaded_weight: torch.Tensor, loaded_shard_id: Optional[int] = None): + # print("weight loader", param.shape, loaded_weight.shape, loaded_shard_id) + # Special case for GGUF # initialize GGUF param after we know the quantize type is_gguf_weight = getattr(param, "is_gguf_weight", False) @@ -506,14 +521,19 @@ def weight_loader(self, # Special case for quantization. # If quantized, we need to adjust the offset and size to account # for the packing. + # print("shard_size1:", shard_size) packed_dim = getattr(param, "packed_dim", None) if packed_dim == output_dim: - shard_size = shard_size // param.pack_factor - shard_offset = shard_offset // param.pack_factor + # print(vars(param)) + pack_factor = getattr(param, "packed_factor", None) + if pack_factor is None: + pack_factor = param.pack_factor + shard_size = shard_size // pack_factor + shard_offset = shard_offset // pack_factor # Special case for Marlin. 
shard_size, shard_offset = adjust_marlin_shard( param, shard_size, shard_offset) - + # print("shard_size2:", shard_size) use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False) if use_bitsandbytes_4bit: @@ -521,8 +541,11 @@ def weight_loader(self, shard_offset = loaded_weight.shape[output_dim] * \ loaded_shard_id + # print("shard_size3:", shard_size) + # print("param data pre:", param_data.shape) param_data = param_data.narrow(output_dim, shard_offset, shard_size) + # print("param data post:", param_data.shape) start_idx = tp_rank * shard_size # bitsandbytes loads the weights of the specific portion # no need to narrow here @@ -549,6 +572,8 @@ def weight_loader(self, "MergedColumnParallelLinear, assume the weight is " "the same for all partitions.") + # if param_data.shape != loaded_weight.shape: + # print("FAIL", param_data.shape, loaded_weight.shape) assert param_data.shape == loaded_weight.shape param_data.copy_(loaded_weight) @@ -767,6 +792,8 @@ def weight_loader(self, loaded_weight: torch.Tensor, loaded_shard_id: Optional[str] = None): + # print("LOAD", param.shape, loaded_weight.shape, loaded_shard_id) + # Special case for GGUF # initialize GGUF param after we know the quantize type is_gguf_weight = getattr(param, "is_gguf_weight", False) @@ -859,8 +886,11 @@ def weight_loader(self, # for the packing. packed_dim = getattr(param, "packed_dim", None) if packed_dim == output_dim: - shard_size = shard_size // param.pack_factor - shard_offset = shard_offset // param.pack_factor + pack_factor = getattr(param, "packed_factor", None) + if pack_factor is None: + pack_factor = param.pack_factor + shard_size = shard_size // pack_factor + shard_offset = shard_offset // pack_factor # Special case for Marlin. shard_size, shard_offset = adjust_marlin_shard( @@ -916,6 +946,7 @@ def weight_loader(self, "QKVParallelLinear, assume the weight is the same " "for all partitions.") + # print(param_data.shape, loaded_weight.shape) assert param_data.shape == loaded_weight.shape param_data.copy_(loaded_weight) @@ -959,6 +990,8 @@ def __init__(self, super().__init__(input_size, output_size, skip_bias_add, params_dtype, quant_config, prefix) + # print("RPL", input_size, output_size) + self.input_is_parallel = input_is_parallel self.reduce_results = reduce_results @@ -968,8 +1001,6 @@ def __init__(self, self.input_size_per_partition = divide(input_size, self.tp_size) assert self.quant_method is not None - print("rowpar", self.quant_method) - self.quant_method.create_weights( layer=self, input_size_per_partition=self.input_size_per_partition, @@ -1014,11 +1045,15 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): param.materialize(tuple(weight_shape), dtype=loaded_weight.dtype) param_data = param.data + # print("PRE", param_data.shape, loaded_weight.shape, input_dim) + # bitsandbytes loads the weights of the specific portion # no need to narrow here if input_dim is not None and not use_bitsandbytes_4bit: shard_size = param_data.shape[input_dim] start_idx = tp_rank * shard_size + # print("input_dim:", input_dim, "start_idx:", start_idx, + # "shard_size:", shard_size) loaded_weight = loaded_weight.narrow(input_dim, start_idx, shard_size) @@ -1027,6 +1062,8 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): if len(loaded_weight.shape) == 0: loaded_weight = loaded_weight.reshape(1) + # print("POST", param_data.shape, loaded_weight.shape) + assert param_data.shape == loaded_weight.shape param_data.copy_(loaded_weight) @@ -1065,12 +1102,17 @@ def 
forward(self, input_): output_bias = self.bias if self.skip_bias_add else None + # print("=== RowParallelLinear ===") + # print("forward's io:", input_.shape, output.shape) + # print("for input:", input_) + # print("got output:", output) + return output, output_bias def extra_repr(self) -> str: s = f"input_features={self.input_size_per_partition}" s += f", output_features={self.output_size}" - s += f", bias={self.bias is not None}" + s += f", bias={hasattr(self, 'bias') and self.bias is not None}" s += f", tp_size={self.tp_size}" s += f", reduce_results={self.reduce_results}" return s diff --git a/vllm/model_executor/layers/quantization/hqq_marlin.py b/vllm/model_executor/layers/quantization/hqq_marlin.py index ea7a98e9ff7de..775371e101776 100644 --- a/vllm/model_executor/layers/quantization/hqq_marlin.py +++ b/vllm/model_executor/layers/quantization/hqq_marlin.py @@ -2,6 +2,7 @@ import torch +import torch.nn.functional as F from vllm import _custom_ops as ops from vllm.logger import init_logger from vllm.model_executor.layers.fused_moe.layer import ( @@ -26,6 +27,8 @@ logger = init_logger(__name__) +big_printing_counter = 0 + class HQQMarlinConfig(QuantizationConfig): """Config class for HQQ Marlin""" @@ -79,7 +82,9 @@ def override_quantization_method(cls, hf_quant_cfg, def get_quant_method( self, layer: torch.nn.Module, prefix: str ) -> "HQQMarlinMethod": - return HQQMarlinMethod(self) + if isinstance(layer, LinearBase): + return HQQMarlinMethod(self) + return None def get_scaled_act_names(self) -> List[str]: return [] @@ -88,6 +93,8 @@ def get_scaled_act_names(self) -> List[str]: class HQQMarlinMethod(LinearMethodBase): """Linear method for HQQ Marlin. """ + + global_print_ctr = 0 def __init__( self, @@ -105,59 +112,68 @@ def create_weights( params_dtype: torch.dtype, **extra_weight_attrs, ) -> None: - output_size_per_partition = sum(output_partition_sizes) + self.output_size_per_partition = sum(output_partition_sizes) is_row_parallel = input_size != input_size_per_partition + + self.input_size_per_partition = input_size_per_partition + weight_loader = extra_weight_attrs.get("weight_loader") + # print("WEIGHT LOADER:", weight_loader) + scales_and_zp_size = input_size_per_partition // self.quant_config.group_size + group_in_tensor_size = (self.output_size_per_partition * self.input_size_per_partition) // self.quant_config.group_size + # Quantized weights qweight = PackedvLLMParameter( + # data=torch.empty( + # self.output_size_per_partition // 2, + # input_size_per_partition, + # dtype=torch.uint8, + # ), data=torch.empty( - input_size_per_partition // self.quant_config.pack_factor, - output_size_per_partition, - dtype=torch.int32, + group_in_tensor_size // 2, + self.quant_config.group_size, + dtype=torch.uint8, ), - input_dim=0, - output_dim=1, + input_dim=1, + output_dim=0, packed_dim=0, - packed_factor=self.quant_config.pack_factor, + packed_factor=2,#self.quant_config.pack_factor, weight_loader=weight_loader) - zeros_args = { - "data": - torch.empty( - scales_and_zp_size, - output_size_per_partition // self.quant_config.pack_factor, - dtype=torch.int32, + zeros = GroupQuantScaleParameter( + # data=torch.empty( + # self.output_size_per_partition, + # scales_and_zp_size, + # dtype=params_dtype, + # ), + data=torch.empty( + group_in_tensor_size, + 1, + dtype=params_dtype, ), - "weight_loader": - weight_loader - } - scales_args = { - "data": - torch.empty( - scales_and_zp_size, - output_size_per_partition, + input_dim=1, + output_dim=0, + weight_loader=weight_loader) + + scales = 
GroupQuantScaleParameter( + # data=torch.empty( + # self.output_size_per_partition, + # scales_and_zp_size, + # dtype=params_dtype, + # ), + data=torch.empty( + group_in_tensor_size, + 1, dtype=params_dtype, ), - "weight_loader": - weight_loader - } - - zeros = PackedvLLMParameter( - input_dim=0, - output_dim=1, - packed_dim=1, - packed_factor=self.quant_config.pack_factor, - **zeros_args) - - scales = PackedvLLMParameter( - input_dim=0, - output_dim=1, - packed_dim=1, - packed_factor=self.quant_config.pack_factor, - **scales_args) + input_dim=1, + output_dim=0, + weight_loader=weight_loader) + + # print("qweight size:", qweight.shape) layer.register_parameter("qweight", qweight) layer.register_parameter("zeros", zeros) @@ -167,8 +183,12 @@ def create_weights( def process_weights_after_loading(self, layer: torch.nn.Module) -> None: - print("TODO") + torch.set_printoptions(edgeitems=128) + print("layer qweight:", layer.qweight.shape) + print(layer.qweight.data) # self.kernel.process_weights_after_loading(layer) + raise ValueError("stop") + return def apply( self, @@ -176,5 +196,46 @@ def apply( x: torch.Tensor, bias: Optional[torch.Tensor] = None, ) -> torch.Tensor: - return torch.empty((0), dtype=x.dtype, device=x.device) - # return self.kernel.apply_weights(layer, x, bias) + # print("input size:", x.shape) + # (layer.unpack() - meta['zero'])*meta['scale]).reshape(meta['shape']) + + ## this is unpack function copied from hqq repo + def unpack_4bit_u8(W_q: torch.Tensor, dtype=torch.uint8) ->torch.Tensor: # uint8/2 > uint8 + step = W_q.shape[0] + tmp = torch.empty([2 * step, W_q.shape[1]], dtype=dtype, device=W_q.device) + + tmp[:step] = (W_q & 0b11110000) >> 4 + tmp[step:] = W_q & 0b00001111 + + return tmp + ## + + # lowbits = torch.full((layer.qweight.shape), 15, device=x.device) + # shifts = torch.full((layer.qweight.shape), 4, device=x.device) + # unpacked = torch.concat([layer.qweight.bitwise_and(lowbits).to(torch.int8), + # layer.qweight.bitwise_right_shift(shifts).to(torch.int8)], dim=0) + unpacked = unpack_4bit_u8(layer.qweight, dtype=x.dtype) + scales = layer.scales.repeat_interleave(64, dim=1) + zeros = layer.zeros.repeat_interleave(64, dim=1) + # torch.set_printoptions(sci_mode=False) + # print("scales:", scales, scales.shape) + # print("zeros:", zeros, zeros.shape) + # # print(unpacked.shape, zeros.shape, scales.shape) + # print("mydeq:", unpacked) + b = (unpacked - zeros) * scales + # b = b.reshape(self.output_size_per_partition, self.input_size_per_partition) + # b = b.transpose(1, 0) + # print("unpacked:", unpacked, unpacked.shape) + # if HQQMarlinMethod.global_print_ctr < 3: + # torch.set_printoptions(edgeitems=128) + # torch.set_printoptions(sci_mode=False) + # print("mydeq:", b, b.shape) + # HQQMarlinMethod.global_print_ctr += 1 + # # print(x.shape, b.shape) + # print("act wq:", layer.qweight) + # return torch.matmul(x, b) + # print(x.dtype, b.dtype) + return F.linear(x, b, bias) + # return torch.empty((x.shape[0], self.output_size_per_partition), + # dtype=x.dtype, + # device=x.device) diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py index 444bff430bc16..5da3dd6696d92 100644 --- a/vllm/model_executor/model_loader/loader.py +++ b/vllm/model_executor/model_loader/loader.py @@ -226,6 +226,7 @@ class Source: def __init__(self, load_config: LoadConfig): super().__init__(load_config) + print("========= INIT DEFAULT MODEL LOADER =========") if load_config.model_loader_extra_config: raise ValueError(f"Model loader extra config is 
not supported for " f"load format {load_config.load_format}") @@ -271,22 +272,22 @@ def _prepare_weights(self, model_name_or_path: str, # Some quantized models use .pt files for storing the weights. if load_format == LoadFormat.AUTO: allow_patterns = ["*.safetensors", "*.bin"] - print("AUTO") + # print("AUTO") elif load_format == LoadFormat.SAFETENSORS: use_safetensors = True allow_patterns = ["*.safetensors"] - print("SAFETENSORS") + # print("SAFETENSORS") elif load_format == LoadFormat.MISTRAL: use_safetensors = True allow_patterns = ["consolidated*.safetensors"] index_file = "consolidated.safetensors.index.json" - print("MISTRAL") + # print("MISTRAL") elif load_format == LoadFormat.PT: allow_patterns = ["*.pt"] - print("PT") + # print("PT") elif load_format == LoadFormat.NPCACHE: allow_patterns = ["*.bin"] - print("NPCACHE") + # print("NPCACHE") else: raise ValueError(f"Unknown load_format: {load_format}") diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index 74a4afbd36550..67787f218bf35 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -38,6 +38,7 @@ RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.quantization.hqq_marlin import HQQMarlinConfig from vllm.model_executor.layers.quantization.compressed_tensors.utils import ( get_compressed_tensors_cache_scale) from vllm.model_executor.layers.rotary_embedding import get_rope @@ -68,7 +69,9 @@ def __init__( prefix: str = "", ) -> None: super().__init__() - print("Quant config:", quant_config) + # print("gate_proj:", hidden_size, intermediate_size * 2) + # print("up_proj:", hidden_size, intermediate_size * 2) + # print("down_proj:", intermediate_size, hidden_size) self.gate_up_proj = MergedColumnParallelLinear( input_size=hidden_size, output_sizes=[intermediate_size] * 2, @@ -89,14 +92,18 @@ def __init__( self.act_fn = SiluAndMul() def forward(self, x): + # print("start forward mlp:", x) gate_up, _ = self.gate_up_proj(x) x = self.act_fn(gate_up) x, _ = self.down_proj(x) + # print("end forward mlp:", x) return x class LlamaAttention(nn.Module): + global_print_ctr = 0 + def __init__( self, config: LlamaConfig, @@ -183,9 +190,23 @@ def forward( attn_metadata: AttentionMetadata, ) -> torch.Tensor: qkv, _ = self.qkv_proj(hidden_states) + if LlamaAttention.global_print_ctr < 1: + torch.set_printoptions(edgeitems=2048) + torch.set_printoptions(sci_mode=False) + print("qkv:", qkv[0]) + LlamaAttention.global_print_ctr += 1 + # print("split params:", self.q_size, self.kv_size, self.kv_size, + # qkv.dtype) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + # if LlamaAttention.global_print_ctr < 1: + # torch.set_printoptions(edgeitems=100) + # torch.set_printoptions(sci_mode=False) + # print("q k v 1:", q, k, v) + # LlamaAttention.global_print_ctr += 1 q, k = self.rotary_emb(positions, q, k) + # print("q k v 2:", q, k, v) attn_output = self.attn(q, k, v, kv_cache, attn_metadata) + # print("attn out:", attn_output) output, _ = self.o_proj(attn_output) return output @@ -269,6 +290,8 @@ def forward( class LlamaModel(nn.Module): + global_print_ctr = 0 + def __init__( self, config: LlamaConfig, @@ -286,6 +309,7 @@ def __init__( self.org_vocab_size = config.vocab_size if get_pp_group().is_first_rank or (config.tie_word_embeddings and get_pp_group().is_last_rank): + # print("et VocabParallelEmbedding") 
self.embed_tokens = VocabParallelEmbedding( self.vocab_size, config.hidden_size, @@ -293,6 +317,7 @@ def __init__( quant_config=quant_config, ) else: + # print("et PPMissingLayer") self.embed_tokens = PPMissingLayer() self.start_layer, self.end_layer, self.layers = make_layers( config.num_hidden_layers, @@ -307,6 +332,9 @@ def __init__( else: self.norm = PPMissingLayer() + self.is_hqq = (quant_config is not None and + isinstance(quant_config, HQQMarlinConfig)) + self.make_empty_intermediate_tensors = ( make_empty_intermediate_tensors_factory( ["hidden_states", "residual"], config.hidden_size)) @@ -358,8 +386,92 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): (".gate_up_proj", ".gate_proj", 0), (".gate_up_proj", ".up_proj", 1), ] + + # print("load weights LlamaModel") params_dict = dict(self.named_parameters()) + # print(*[(k, v.shape) for k, v in params_dict.items()], sep="\n") + + hqq_map = [ + (".qweight", "W_q", False), + (".zeros", "zero", True), + (".scales", "scale", True), + ] + + ### this is unpack function copied from hqq repo + def unpack_4bit_u8(W_q: torch.Tensor, dtype=torch.uint8) ->torch.Tensor: # uint8/2 > uint8 + step = W_q.shape[0] + tmp = torch.empty([2 * step, W_q.shape[1]], dtype=dtype, device=W_q.device) + + tmp[:step] = (W_q & 0b11110000) >> 4 + tmp[step:] = W_q & 0b00001111 + + return tmp + ### + for name, loaded_weight in weights: + + if self.is_hqq: + # print("START WITH NAME", name) + pick_shard_id = None + for param_name, weight_name, shard_id in stacked_params_mapping: + # print("is", weight_name, "in", name, "?") + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + pick_shard_id = shard_id + break + if name.endswith("_proj"): + to_shape = loaded_weight["shape"] + group_size = loaded_weight["group_size"] + for c, k, should_scale in hqq_map: + new_name = name + c + if new_name not in params_dict: + continue + param = params_dict[new_name] + weight_loader = param.weight_loader + if k == "W_q" and LlamaModel.global_print_ctr < 3: + torch.set_printoptions(edgeitems=128) + print("load:", new_name, param.shape, param.dtype, + loaded_weight[k].shape, loaded_weight[k].dtype, + to_shape) + print(loaded_weight[k]) + LlamaModel.global_print_ctr += 1 + # if should_scale: + # loaded = loaded_weight[k].reshape(-1, to_shape[1] // group_size) + # else: + # loaded = loaded_weight[k].reshape(to_shape[0] // 2, to_shape[1]) + + #TODO try this + loaded = loaded_weight[k] + + if pick_shard_id is not None: + weight_loader(param, loaded, pick_shard_id) + else: + weight_loader(param, loaded) + + # unpack: unpack_4bit_u8 + param_wq = loaded_weight["W_q"] + param_zp = loaded_weight["zero"] + param_s = loaded_weight["scale"] + param_w = ((unpack_4bit_u8(param_wq, dtype=torch.bfloat16) - param_zp) * param_s + ).reshape(to_shape) + torch.set_printoptions(sci_mode=False) + # print("load wq orig shape:", param_wq) + # print("load wq:", param_wq.reshape(to_shape[0] // 2, to_shape[1])) + # print("deq:", unpack_4bit_u8(param_wq)) + # print("zps:", param_zp, param_zp.shape) + # print("s:", param_s, param_s.shape) + # print("w:", param_w) + else: + name = name + ".weight" + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + # print("load:", name, param.shape, param.dtype, + # loaded_weight["weight"].shape, loaded_weight["weight"].dtype) + weight_loader(param, loaded_weight["weight"]) + continue + if "rotary_emb.inv_freq" in name: continue if ("rotary_emb.cos_cached" in name @@ -390,6 
+502,11 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): weight_loader = param.weight_loader weight_loader(param, loaded_weight, shard_id) + # if "proj" in name: + # print("load:", name, weight_loader) + # torch.set_printoptions(sci_mode=False) + # print("unq:", loaded_weight) + break else: # Skip loading extra bias for GPTQ models. @@ -408,6 +525,11 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): default_weight_loader) weight_loader(param, loaded_weight) + # if "proj" in name: + # print("load:", name, weight_loader) + # torch.set_printoptions(sci_mode=False) + # print("unq:", loaded_weight) + # If this function is called, it should always initialize KV cache scale # factors (or else raise an exception). Thus, handled exceptions should # make sure to leave KV cache scale factors in a known good (dummy) state @@ -488,7 +610,7 @@ def __init__( ) -> None: super().__init__() - print("===== LLAMA FOR CAUSAL LM =====") + # print("===== LLAMA FOR CAUSAL LM =====") self.config = config self.lora_config = lora_config @@ -554,12 +676,16 @@ def sample(self, logits: torch.Tensor, return next_tokens def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + # print("load weights LlamaForCausalLM") + # print(*[(n, w['W_q'] if 'W_q' in w else "") for n, w in weights], sep="\n") + # print(*[(n, w) for n, w in weights], sep="\n") + # weights = self.maybe_remap_hqq(weights) weights = [ self.maybe_remap_mistral(name, loaded_weight) for name, loaded_weight in weights ] - # print(*[(n, w['meta'] if 'meta' in w else "") for n, w in weights], sep="\n") # print(*[(n, w) for n, w in weights], sep="\n") + # raise ValueError(".") weights_group = group_weights_with_prefix(weights) @@ -568,13 +694,26 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): if not self.config.tie_word_embeddings: lm_head_dict = dict(self.lm_head.named_parameters()) for name, loaded_weight in weights_group["lm_head"]: - if is_pp_missing_parameter(name, self.lm_head): - continue - param = lm_head_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - weight_loader(param, loaded_weight) + if name == '': + lw = loaded_weight + for name, loaded_weight in lw.items(): + if is_pp_missing_parameter(name, self.lm_head): + continue + + param = lm_head_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + + else: + if is_pp_missing_parameter(name, self.lm_head): + continue + + param = lm_head_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) def load_kv_cache_scales(self, quantization_param_path: str) -> None: self.model.load_kv_cache_scales(quantization_param_path) From d3b5c1255300d2b53af84dfc40b69e00f59db715 Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Tue, 15 Oct 2024 11:01:23 -0400 Subject: [PATCH 03/14] works with a hack --- vllm/model_executor/layers/linear.py | 26 ++---- .../layers/quantization/hqq_marlin.py | 81 +++++++++-------- vllm/model_executor/models/llama.py | 89 ++++++++++++------- 3 files changed, 108 insertions(+), 88 deletions(-) diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index 05e46e9ae4b67..9875547c94032 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -134,12 +134,15 @@ def apply(self, x: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor: - # if 
UnquantizedLinearMethod.global_print_ctr < 3: + # if UnquantizedLinearMethod.global_print_ctr < 1: # torch.set_printoptions(edgeitems=128) # torch.set_printoptions(sci_mode=False) - # print("apply to weight:", layer.weight, layer.weight.shape) - # # # print("and to bias:", bias) + # torch.set_printoptions(profile="full") + # torch.set_printoptions(sci_mode=False) + # print("weight:", layer.weight.transpose(1, 0)[0], layer.weight.shape) + # # raise ValueError("stop") + # # print("and to bias:", bias) # UnquantizedLinearMethod.global_print_ctr += 1 return F.linear(x, layer.weight, bias) @@ -521,10 +524,8 @@ def weight_loader(self, # Special case for quantization. # If quantized, we need to adjust the offset and size to account # for the packing. - # print("shard_size1:", shard_size) packed_dim = getattr(param, "packed_dim", None) if packed_dim == output_dim: - # print(vars(param)) pack_factor = getattr(param, "packed_factor", None) if pack_factor is None: pack_factor = param.pack_factor @@ -533,7 +534,6 @@ def weight_loader(self, # Special case for Marlin. shard_size, shard_offset = adjust_marlin_shard( param, shard_size, shard_offset) - # print("shard_size2:", shard_size) use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False) if use_bitsandbytes_4bit: @@ -541,11 +541,8 @@ def weight_loader(self, shard_offset = loaded_weight.shape[output_dim] * \ loaded_shard_id - # print("shard_size3:", shard_size) - # print("param data pre:", param_data.shape) param_data = param_data.narrow(output_dim, shard_offset, shard_size) - # print("param data post:", param_data.shape) start_idx = tp_rank * shard_size # bitsandbytes loads the weights of the specific portion # no need to narrow here @@ -792,8 +789,6 @@ def weight_loader(self, loaded_weight: torch.Tensor, loaded_shard_id: Optional[str] = None): - # print("LOAD", param.shape, loaded_weight.shape, loaded_shard_id) - # Special case for GGUF # initialize GGUF param after we know the quantize type is_gguf_weight = getattr(param, "is_gguf_weight", False) @@ -856,7 +851,6 @@ def weight_loader(self, if packed_dim == output_dim: shard_size = shard_size // param.pack_factor shard_offset = shard_offset // param.pack_factor - # Special case for Marlin. 
shard_size, shard_offset = adjust_marlin_shard( param, shard_size, shard_offset) @@ -946,7 +940,6 @@ def weight_loader(self, "QKVParallelLinear, assume the weight is the same " "for all partitions.") - # print(param_data.shape, loaded_weight.shape) assert param_data.shape == loaded_weight.shape param_data.copy_(loaded_weight) @@ -990,8 +983,6 @@ def __init__(self, super().__init__(input_size, output_size, skip_bias_add, params_dtype, quant_config, prefix) - # print("RPL", input_size, output_size) - self.input_is_parallel = input_is_parallel self.reduce_results = reduce_results @@ -1045,15 +1036,12 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): param.materialize(tuple(weight_shape), dtype=loaded_weight.dtype) param_data = param.data - # print("PRE", param_data.shape, loaded_weight.shape, input_dim) # bitsandbytes loads the weights of the specific portion # no need to narrow here if input_dim is not None and not use_bitsandbytes_4bit: shard_size = param_data.shape[input_dim] start_idx = tp_rank * shard_size - # print("input_dim:", input_dim, "start_idx:", start_idx, - # "shard_size:", shard_size) loaded_weight = loaded_weight.narrow(input_dim, start_idx, shard_size) @@ -1062,8 +1050,6 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): if len(loaded_weight.shape) == 0: loaded_weight = loaded_weight.reshape(1) - # print("POST", param_data.shape, loaded_weight.shape) - assert param_data.shape == loaded_weight.shape param_data.copy_(loaded_weight) diff --git a/vllm/model_executor/layers/quantization/hqq_marlin.py b/vllm/model_executor/layers/quantization/hqq_marlin.py index 775371e101776..72f2fcf836165 100644 --- a/vllm/model_executor/layers/quantization/hqq_marlin.py +++ b/vllm/model_executor/layers/quantization/hqq_marlin.py @@ -127,48 +127,48 @@ def create_weights( # Quantized weights qweight = PackedvLLMParameter( - # data=torch.empty( - # self.output_size_per_partition // 2, - # input_size_per_partition, - # dtype=torch.uint8, - # ), data=torch.empty( - group_in_tensor_size // 2, - self.quant_config.group_size, + self.output_size_per_partition, + input_size_per_partition, dtype=torch.uint8, ), + # data=torch.empty( + # group_in_tensor_size // 2, + # self.quant_config.group_size, + # dtype=torch.uint8, + # ), input_dim=1, output_dim=0, packed_dim=0, - packed_factor=2,#self.quant_config.pack_factor, + packed_factor=1,#self.quant_config.pack_factor, weight_loader=weight_loader) zeros = GroupQuantScaleParameter( - # data=torch.empty( - # self.output_size_per_partition, - # scales_and_zp_size, - # dtype=params_dtype, - # ), data=torch.empty( - group_in_tensor_size, - 1, + self.output_size_per_partition, + scales_and_zp_size, dtype=params_dtype, ), + # data=torch.empty( + # group_in_tensor_size, + # 1, + # dtype=params_dtype, + # ), input_dim=1, output_dim=0, weight_loader=weight_loader) scales = GroupQuantScaleParameter( - # data=torch.empty( - # self.output_size_per_partition, - # scales_and_zp_size, - # dtype=params_dtype, - # ), data=torch.empty( - group_in_tensor_size, - 1, + self.output_size_per_partition, + scales_and_zp_size, dtype=params_dtype, ), + # data=torch.empty( + # group_in_tensor_size, + # 1, + # dtype=params_dtype, + # ), input_dim=1, output_dim=0, weight_loader=weight_loader) @@ -183,11 +183,11 @@ def create_weights( def process_weights_after_loading(self, layer: torch.nn.Module) -> None: - torch.set_printoptions(edgeitems=128) - print("layer qweight:", layer.qweight.shape) - print(layer.qweight.data) - # 
self.kernel.process_weights_after_loading(layer) - raise ValueError("stop") + # torch.set_printoptions(profile="full") + # print("layer qweight:", layer.qweight.shape) + # print(layer.qweight.data.transpose(1, 0)[0]) + # # self.kernel.process_weights_after_loading(layer) + # raise ValueError("stop") return def apply( @@ -214,26 +214,31 @@ def unpack_4bit_u8(W_q: torch.Tensor, dtype=torch.uint8) ->torch.Tensor: # uint # shifts = torch.full((layer.qweight.shape), 4, device=x.device) # unpacked = torch.concat([layer.qweight.bitwise_and(lowbits).to(torch.int8), # layer.qweight.bitwise_right_shift(shifts).to(torch.int8)], dim=0) - unpacked = unpack_4bit_u8(layer.qweight, dtype=x.dtype) - scales = layer.scales.repeat_interleave(64, dim=1) - zeros = layer.zeros.repeat_interleave(64, dim=1) + unpacked = layer.qweight.reshape(-1, 64) #unpack_4bit_u8(layer.qweight.reshape(-1, 64), dtype=x.dtype) + scales = layer.scales.reshape(-1, 1)#.repeat_interleave(64, dim=1) + zeros = layer.zeros.reshape(-1, 1)#.repeat_interleave(64, dim=1) # torch.set_printoptions(sci_mode=False) # print("scales:", scales, scales.shape) # print("zeros:", zeros, zeros.shape) # # print(unpacked.shape, zeros.shape, scales.shape) # print("mydeq:", unpacked) b = (unpacked - zeros) * scales - # b = b.reshape(self.output_size_per_partition, self.input_size_per_partition) - # b = b.transpose(1, 0) + b = b.reshape(self.output_size_per_partition, self.input_size_per_partition) # print("unpacked:", unpacked, unpacked.shape) - # if HQQMarlinMethod.global_print_ctr < 3: - # torch.set_printoptions(edgeitems=128) - # torch.set_printoptions(sci_mode=False) - # print("mydeq:", b, b.shape) - # HQQMarlinMethod.global_print_ctr += 1 + if HQQMarlinMethod.global_print_ctr < 1: + torch.set_printoptions(profile="full") + torch.set_printoptions(sci_mode=False) + # print("unpacked size:", layer.qweight.reshape(-1, 64).shape, "->", unpacked.shape) + # print(layer.qweight.reshape(-1, 64).transpose(1, 0)[0]) + # print(unpacked.transpose(1, 0)[0]) + # print("act wq:", layer.qweight[0]) + # print("scales:", layer.scales.reshape(-1, 1).transpose(1, 0)) + # print("zeros:", layer.zeros.reshape(-1, 1).transpose(1, 0)) + # print("mydeq:", b.transpose(1, 0)[0], b.shape) + HQQMarlinMethod.global_print_ctr += 1 + # raise ValueError("stop") # # print(x.shape, b.shape) # print("act wq:", layer.qweight) - # return torch.matmul(x, b) # print(x.dtype, b.dtype) return F.linear(x, b, bias) # return torch.empty((x.shape[0], self.output_size_per_partition), diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index 67787f218bf35..8c7f4c5dc8806 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -190,19 +190,21 @@ def forward( attn_metadata: AttentionMetadata, ) -> torch.Tensor: qkv, _ = self.qkv_proj(hidden_states) - if LlamaAttention.global_print_ctr < 1: - torch.set_printoptions(edgeitems=2048) - torch.set_printoptions(sci_mode=False) - print("qkv:", qkv[0]) - LlamaAttention.global_print_ctr += 1 + # if LlamaAttention.global_print_ctr < 1: + # torch.set_printoptions(profile="full") + # torch.set_printoptions(sci_mode=False) + # print("qkv:", qkv[0]) + # LlamaAttention.global_print_ctr += 1 # print("split params:", self.q_size, self.kv_size, self.kv_size, # qkv.dtype) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) # if LlamaAttention.global_print_ctr < 1: - # torch.set_printoptions(edgeitems=100) + # torch.set_printoptions(profile="full") # 
torch.set_printoptions(sci_mode=False) - # print("q k v 1:", q, k, v) + # print("q k v 1:", q[0], k[0], v[0]) + # print("shapes of all:", qkv.shape, "->", q.shape, k.shape, v.shape) # LlamaAttention.global_print_ctr += 1 + # raise ValueError("stop") q, k = self.rotary_emb(positions, q, k) # print("q k v 2:", q, k, v) attn_output = self.attn(q, k, v, kv_cache, attn_metadata) @@ -429,20 +431,28 @@ def unpack_4bit_u8(W_q: torch.Tensor, dtype=torch.uint8) ->torch.Tensor: # uint continue param = params_dict[new_name] weight_loader = param.weight_loader - if k == "W_q" and LlamaModel.global_print_ctr < 3: - torch.set_printoptions(edgeitems=128) - print("load:", new_name, param.shape, param.dtype, - loaded_weight[k].shape, loaded_weight[k].dtype, - to_shape) - print(loaded_weight[k]) - LlamaModel.global_print_ctr += 1 - # if should_scale: - # loaded = loaded_weight[k].reshape(-1, to_shape[1] // group_size) - # else: - # loaded = loaded_weight[k].reshape(to_shape[0] // 2, to_shape[1]) + if should_scale: + loaded = loaded_weight[k].reshape(-1, to_shape[1] // group_size) + else: + loaded = unpack_4bit_u8(loaded_weight[k], dtype=torch.bfloat16).reshape(to_shape).to(torch.uint8) + # loaded1 = loaded[:to_shape[0]] + # loaded2 = loaded[to_shape[0]:] + # if (pick_shard_id == "q" or pick_shard_id == "k" or + # pick_shard_id == "v"): + # pass + + # if k == "W_q" and LlamaModel.global_print_ctr < 3: + # torch.set_printoptions(profile="full") + # print("load:", new_name, param.shape, param.dtype, + # loaded_weight[k].shape, loaded_weight[k].dtype, + # to_shape) + # print(loaded.transpose(1, 0)[0]) + # LlamaModel.global_print_ctr += 1 #TODO try this - loaded = loaded_weight[k] + # loaded = loaded_weight[k] + + # print(pick_shard_id) if pick_shard_id is not None: weight_loader(param, loaded, pick_shard_id) @@ -455,13 +465,26 @@ def unpack_4bit_u8(W_q: torch.Tensor, dtype=torch.uint8) ->torch.Tensor: # uint param_s = loaded_weight["scale"] param_w = ((unpack_4bit_u8(param_wq, dtype=torch.bfloat16) - param_zp) * param_s ).reshape(to_shape) + torch.set_printoptions(profile="full") torch.set_printoptions(sci_mode=False) - # print("load wq orig shape:", param_wq) - # print("load wq:", param_wq.reshape(to_shape[0] // 2, to_shape[1])) + if LlamaModel.global_print_ctr < 3: + # print("load wq orig shape:", param_wq.shape, + # param_wq.reshape(to_shape[0] // 2, to_shape[1]).shape) + # # print("load wq:", param_wq.reshape(to_shape[0] // 2, to_shape[1])[0]) + # print("param s:", param_s.shape, param_s.reshape(-1, to_shape[1] // group_size).shape) + # print("param zp:", param_zp.shape, param_zp.reshape(-1, to_shape[1] // group_size).shape) + # print("s:", param_s.transpose(1, 0)) + # if LlamaModel.global_print_ctr > 0: + # print("zp:", param_zp.transpose(1, 0)) + # print(name) + # print("w:", param_w.transpose(1, 0)[0]) + # print("wq shape:", param_wq.shape, "->", unpack_4bit_u8(param_wq, dtype=torch.bfloat16).shape) + # print(param_wq.transpose(1, 0)[0]) + # print(unpack_4bit_u8(param_wq, dtype=torch.bfloat16).transpose(1, 0)[0]) + LlamaModel.global_print_ctr += 1 # print("deq:", unpack_4bit_u8(param_wq)) # print("zps:", param_zp, param_zp.shape) # print("s:", param_s, param_s.shape) - # print("w:", param_w) else: name = name + ".weight" param = params_dict[name] @@ -502,10 +525,13 @@ def unpack_4bit_u8(W_q: torch.Tensor, dtype=torch.uint8) ->torch.Tensor: # uint weight_loader = param.weight_loader weight_loader(param, loaded_weight, shard_id) - # if "proj" in name: - # print("load:", name, weight_loader) - # 
torch.set_printoptions(sci_mode=False) - # print("unq:", loaded_weight) + if LlamaModel.global_print_ctr < 3 and "layers.0.self_attn.qkv_proj" in name: + torch.set_printoptions(profile="full") + torch.set_printoptions(sci_mode=False) + print("load:", name, weight_loader) + torch.set_printoptions(sci_mode=False) + print("unq:", loaded_weight.transpose(1, 0)[0]) + LlamaModel.global_print_ctr += 1 break else: @@ -525,10 +551,13 @@ def unpack_4bit_u8(W_q: torch.Tensor, dtype=torch.uint8) ->torch.Tensor: # uint default_weight_loader) weight_loader(param, loaded_weight) - # if "proj" in name: - # print("load:", name, weight_loader) - # torch.set_printoptions(sci_mode=False) - # print("unq:", loaded_weight) + if LlamaModel.global_print_ctr < 3 and "layers.0.self_attn.qkv_proj" in name: + torch.set_printoptions(profile="full") + torch.set_printoptions(sci_mode=False) + print("load:", name, weight_loader) + torch.set_printoptions(sci_mode=False) + print("unq:", loaded_weight.transpose(1, 0)[0]) + LlamaModel.global_print_ctr += 1 # If this function is called, it should always initialize KV cache scale # factors (or else raise an exception). Thus, handled exceptions should From 45a993b87aad1a7f3ec17f4d87d7cb52fa1c9e2b Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Tue, 15 Oct 2024 11:27:33 -0400 Subject: [PATCH 04/14] cleanup --- vllm/model_executor/layers/fused_moe/layer.py | 7 - vllm/model_executor/layers/linear.py | 26 --- .../layers/quantization/__init__.py | 2 +- .../layers/quantization/hqq_marlin.py | 160 ++++-------------- .../model_loader/weight_utils.py | 2 +- vllm/model_executor/models/llama.py | 133 +++------------ 6 files changed, 59 insertions(+), 271 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index cd333a41b8c71..a994c11ddba83 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -11,8 +11,6 @@ from vllm.model_executor.custom_op import CustomOp from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig, QuantizeMethodBase) -from vllm.model_executor.layers.quantization.qqq import ( - QQQConfig, QQQLinearMethod) from vllm.model_executor.utils import set_weight_attrs logger = init_logger(__name__) @@ -46,9 +44,6 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int, hidden_size: int, intermediate_size: int, params_dtype: torch.dtype, **extra_weight_attrs): - print("WEIGHT SIZES:", num_experts, intermediate_size, hidden_size, - params_dtype) - # Fused gate_up_proj (column parallel) w13_weight = torch.nn.Parameter(torch.empty(num_experts, intermediate_size, @@ -328,8 +323,6 @@ def weight_loader(self, param: torch.nn.Parameter, loaded_weight: torch.Tensor, weight_name: str, shard_id: str, expert_id: int) -> None: - # print("LOADING:", weight_name) - # compressed-tensors checkpoints with packed weights are stored flipped # TODO (mgoin): check self.quant_method.quant_config.quant_format # against known CompressionFormat enum values that have this quality diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index 9875547c94032..602d578b78ff4 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -114,8 +114,6 @@ def apply(self, class UnquantizedLinearMethod(LinearMethodBase): """Linear method without quantization.""" - global_print_ctr = 0 - def create_weights(self, layer: torch.nn.Module, input_size_per_partition: int, output_partition_sizes: 
List[int], input_size: int, @@ -133,17 +131,6 @@ def apply(self, layer: torch.nn.Module, x: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor: - - # if UnquantizedLinearMethod.global_print_ctr < 1: - # torch.set_printoptions(edgeitems=128) - # torch.set_printoptions(sci_mode=False) - - # torch.set_printoptions(profile="full") - # torch.set_printoptions(sci_mode=False) - # print("weight:", layer.weight.transpose(1, 0)[0], layer.weight.shape) - # # raise ValueError("stop") - # # print("and to bias:", bias) - # UnquantizedLinearMethod.global_print_ctr += 1 return F.linear(x, layer.weight, bias) @@ -383,10 +370,6 @@ def forward(self, input_): else: output = output_parallel output_bias = self.bias if self.skip_bias_add else None - # print("=== ColumnParallelLinear ===") - # print("forward's io:", input_.shape, output.shape) - # print("for input:", input_) - # print("got output:", output) return output, output_bias def extra_repr(self) -> str: @@ -447,8 +430,6 @@ def weight_loader(self, loaded_weight: torch.Tensor, loaded_shard_id: Optional[int] = None): - # print("weight loader", param.shape, loaded_weight.shape, loaded_shard_id) - # Special case for GGUF # initialize GGUF param after we know the quantize type is_gguf_weight = getattr(param, "is_gguf_weight", False) @@ -569,8 +550,6 @@ def weight_loader(self, "MergedColumnParallelLinear, assume the weight is " "the same for all partitions.") - # if param_data.shape != loaded_weight.shape: - # print("FAIL", param_data.shape, loaded_weight.shape) assert param_data.shape == loaded_weight.shape param_data.copy_(loaded_weight) @@ -1088,11 +1067,6 @@ def forward(self, input_): output_bias = self.bias if self.skip_bias_add else None - # print("=== RowParallelLinear ===") - # print("forward's io:", input_.shape, output.shape) - # print("for input:", input_) - # print("got output:", output) - return output, output_bias def extra_repr(self) -> str: diff --git a/vllm/model_executor/layers/quantization/__init__.py b/vllm/model_executor/layers/quantization/__init__.py index 0966895f51229..2692b6efd6bd2 100644 --- a/vllm/model_executor/layers/quantization/__init__.py +++ b/vllm/model_executor/layers/quantization/__init__.py @@ -21,13 +21,13 @@ GPTQMarlinConfig) from vllm.model_executor.layers.quantization.gptq_marlin_24 import ( GPTQMarlin24Config) +from vllm.model_executor.layers.quantization.hqq_marlin import HQQMarlinConfig from vllm.model_executor.layers.quantization.marlin import MarlinConfig from vllm.model_executor.layers.quantization.modelopt import ModelOptFp8Config from vllm.model_executor.layers.quantization.neuron_quant import ( NeuronQuantConfig) from vllm.model_executor.layers.quantization.qqq import QQQConfig from vllm.model_executor.layers.quantization.tpu_int8 import Int8TpuConfig -from vllm.model_executor.layers.quantization.hqq_marlin import HQQMarlinConfig QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = { "aqlm": AQLMConfig, diff --git a/vllm/model_executor/layers/quantization/hqq_marlin.py b/vllm/model_executor/layers/quantization/hqq_marlin.py index 72f2fcf836165..df5d9ba15fa56 100644 --- a/vllm/model_executor/layers/quantization/hqq_marlin.py +++ b/vllm/model_executor/layers/quantization/hqq_marlin.py @@ -1,38 +1,23 @@ -from typing import Any, Callable, Dict, List, Optional, Set, Union +from typing import Any, Dict, List, Optional import torch - import torch.nn.functional as F -from vllm import _custom_ops as ops + from vllm.logger import init_logger -from vllm.model_executor.layers.fused_moe.layer import ( 
- FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported) -from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, - set_weight_attrs) +from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) -from vllm.model_executor.layers.quantization.kernels import ( - MPLinearLayerConfig, choose_mp_linear_kernel) -from vllm.model_executor.layers.quantization.utils import replace_parameter -from vllm.model_executor.layers.quantization.utils.marlin_utils import ( - check_marlin_supported, marlin_moe_permute_scales, - marlin_repeat_scales_on_all_ranks, verify_marlin_supported) -from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead -from vllm.model_executor.parameter import (ChannelQuantScaleParameter, - GroupQuantScaleParameter, - PackedColumnParameter, - PackedvLLMParameter, - RowvLLMParameter) +from vllm.model_executor.parameter import (GroupQuantScaleParameter, + PackedvLLMParameter) from vllm.scalar_type import scalar_types logger = init_logger(__name__) -big_printing_counter = 0 class HQQMarlinConfig(QuantizationConfig): """Config class for HQQ Marlin""" - # (num_bits, is_sym) -> quant_type + # (num_bits, is_sym) -> quant_type TYPE_MAP = { 4: scalar_types.uint4, 8: scalar_types.uint8, @@ -79,23 +64,20 @@ def override_quantization_method(cls, hf_quant_cfg, #TODO return None - def get_quant_method( - self, layer: torch.nn.Module, prefix: str - ) -> "HQQMarlinMethod": + def get_quant_method(self, layer: torch.nn.Module, + prefix: str) -> Optional["HQQMarlinMethod"]: if isinstance(layer, LinearBase): return HQQMarlinMethod(self) return None def get_scaled_act_names(self) -> List[str]: return [] - + class HQQMarlinMethod(LinearMethodBase): """Linear method for HQQ Marlin. 
""" - global_print_ctr = 0 - def __init__( self, quant_config: HQQMarlinConfig, @@ -113,17 +95,13 @@ def create_weights( **extra_weight_attrs, ) -> None: self.output_size_per_partition = sum(output_partition_sizes) - is_row_parallel = input_size != input_size_per_partition self.input_size_per_partition = input_size_per_partition - - weight_loader = extra_weight_attrs.get("weight_loader") - - # print("WEIGHT LOADER:", weight_loader) - scales_and_zp_size = input_size_per_partition // self.quant_config.group_size + weight_loader = extra_weight_attrs.get("weight_loader") - group_in_tensor_size = (self.output_size_per_partition * self.input_size_per_partition) // self.quant_config.group_size + scales_and_zp_size = (input_size_per_partition // + self.quant_config.group_size) # Quantized weights qweight = PackedvLLMParameter( @@ -132,62 +110,38 @@ def create_weights( input_size_per_partition, dtype=torch.uint8, ), - # data=torch.empty( - # group_in_tensor_size // 2, - # self.quant_config.group_size, - # dtype=torch.uint8, - # ), input_dim=1, output_dim=0, packed_dim=0, - packed_factor=1,#self.quant_config.pack_factor, - weight_loader=weight_loader) - - zeros = GroupQuantScaleParameter( - data=torch.empty( - self.output_size_per_partition, - scales_and_zp_size, - dtype=params_dtype, - ), - # data=torch.empty( - # group_in_tensor_size, - # 1, - # dtype=params_dtype, - # ), - input_dim=1, - output_dim=0, + packed_factor=1, #self.quant_config.pack_factor, weight_loader=weight_loader) - scales = GroupQuantScaleParameter( - data=torch.empty( - self.output_size_per_partition, - scales_and_zp_size, - dtype=params_dtype, - ), - # data=torch.empty( - # group_in_tensor_size, - # 1, - # dtype=params_dtype, - # ), - input_dim=1, - output_dim=0, - weight_loader=weight_loader) + zeros = GroupQuantScaleParameter(data=torch.empty( + self.output_size_per_partition, + scales_and_zp_size, + dtype=params_dtype, + ), + input_dim=1, + output_dim=0, + weight_loader=weight_loader) + + scales = GroupQuantScaleParameter(data=torch.empty( + self.output_size_per_partition, + scales_and_zp_size, + dtype=params_dtype, + ), + input_dim=1, + output_dim=0, + weight_loader=weight_loader) - # print("qweight size:", qweight.shape) - layer.register_parameter("qweight", qweight) layer.register_parameter("zeros", zeros) layer.register_parameter("scales", scales) # self.kernel = '.' 
#TODO - def process_weights_after_loading(self, layer: torch.nn.Module) -> None: - # torch.set_printoptions(profile="full") - # print("layer qweight:", layer.qweight.shape) - # print(layer.qweight.data.transpose(1, 0)[0]) - # # self.kernel.process_weights_after_loading(layer) - # raise ValueError("stop") + # TODO marlin format return def apply( @@ -196,51 +150,11 @@ def apply( x: torch.Tensor, bias: Optional[torch.Tensor] = None, ) -> torch.Tensor: - # print("input size:", x.shape) - # (layer.unpack() - meta['zero'])*meta['scale]).reshape(meta['shape']) - - ## this is unpack function copied from hqq repo - def unpack_4bit_u8(W_q: torch.Tensor, dtype=torch.uint8) ->torch.Tensor: # uint8/2 > uint8 - step = W_q.shape[0] - tmp = torch.empty([2 * step, W_q.shape[1]], dtype=dtype, device=W_q.device) - - tmp[:step] = (W_q & 0b11110000) >> 4 - tmp[step:] = W_q & 0b00001111 - - return tmp - ## - - # lowbits = torch.full((layer.qweight.shape), 15, device=x.device) - # shifts = torch.full((layer.qweight.shape), 4, device=x.device) - # unpacked = torch.concat([layer.qweight.bitwise_and(lowbits).to(torch.int8), - # layer.qweight.bitwise_right_shift(shifts).to(torch.int8)], dim=0) - unpacked = layer.qweight.reshape(-1, 64) #unpack_4bit_u8(layer.qweight.reshape(-1, 64), dtype=x.dtype) - scales = layer.scales.reshape(-1, 1)#.repeat_interleave(64, dim=1) - zeros = layer.zeros.reshape(-1, 1)#.repeat_interleave(64, dim=1) - # torch.set_printoptions(sci_mode=False) - # print("scales:", scales, scales.shape) - # print("zeros:", zeros, zeros.shape) - # # print(unpacked.shape, zeros.shape, scales.shape) - # print("mydeq:", unpacked) + # TODO marlin kernel + unpacked = layer.qweight.reshape(-1, 64) + scales = layer.scales.reshape(-1, 1) + zeros = layer.zeros.reshape(-1, 1) b = (unpacked - zeros) * scales - b = b.reshape(self.output_size_per_partition, self.input_size_per_partition) - # print("unpacked:", unpacked, unpacked.shape) - if HQQMarlinMethod.global_print_ctr < 1: - torch.set_printoptions(profile="full") - torch.set_printoptions(sci_mode=False) - # print("unpacked size:", layer.qweight.reshape(-1, 64).shape, "->", unpacked.shape) - # print(layer.qweight.reshape(-1, 64).transpose(1, 0)[0]) - # print(unpacked.transpose(1, 0)[0]) - # print("act wq:", layer.qweight[0]) - # print("scales:", layer.scales.reshape(-1, 1).transpose(1, 0)) - # print("zeros:", layer.zeros.reshape(-1, 1).transpose(1, 0)) - # print("mydeq:", b.transpose(1, 0)[0], b.shape) - HQQMarlinMethod.global_print_ctr += 1 - # raise ValueError("stop") - # # print(x.shape, b.shape) - # print("act wq:", layer.qweight) - # print(x.dtype, b.dtype) + b = b.reshape(self.output_size_per_partition, + self.input_size_per_partition) return F.linear(x, b, bias) - # return torch.empty((x.shape[0], self.output_size_per_partition), - # dtype=x.dtype, - # device=x.device) diff --git a/vllm/model_executor/model_loader/weight_utils.py b/vllm/model_executor/model_loader/weight_utils.py index 653d5d671f1a0..660bee780dc3c 100644 --- a/vllm/model_executor/model_loader/weight_utils.py +++ b/vllm/model_executor/model_loader/weight_utils.py @@ -129,7 +129,7 @@ def get_quant_config(model_config: ModelConfig, # GGUF doesn't have config file if model_config.quantization == "gguf": return quant_cls.from_config({}) - + if model_config.quantization == "hqq_marlin": # print("=======================================") # print(vars(model_config)) diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index 8c7f4c5dc8806..adddd2d16859f 100644 --- 
a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -38,9 +38,9 @@ RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig -from vllm.model_executor.layers.quantization.hqq_marlin import HQQMarlinConfig from vllm.model_executor.layers.quantization.compressed_tensors.utils import ( get_compressed_tensors_cache_scale) +from vllm.model_executor.layers.quantization.hqq_marlin import HQQMarlinConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( @@ -69,9 +69,6 @@ def __init__( prefix: str = "", ) -> None: super().__init__() - # print("gate_proj:", hidden_size, intermediate_size * 2) - # print("up_proj:", hidden_size, intermediate_size * 2) - # print("down_proj:", intermediate_size, hidden_size) self.gate_up_proj = MergedColumnParallelLinear( input_size=hidden_size, output_sizes=[intermediate_size] * 2, @@ -92,18 +89,14 @@ def __init__( self.act_fn = SiluAndMul() def forward(self, x): - # print("start forward mlp:", x) gate_up, _ = self.gate_up_proj(x) x = self.act_fn(gate_up) x, _ = self.down_proj(x) - # print("end forward mlp:", x) return x class LlamaAttention(nn.Module): - global_print_ctr = 0 - def __init__( self, config: LlamaConfig, @@ -190,25 +183,9 @@ def forward( attn_metadata: AttentionMetadata, ) -> torch.Tensor: qkv, _ = self.qkv_proj(hidden_states) - # if LlamaAttention.global_print_ctr < 1: - # torch.set_printoptions(profile="full") - # torch.set_printoptions(sci_mode=False) - # print("qkv:", qkv[0]) - # LlamaAttention.global_print_ctr += 1 - # print("split params:", self.q_size, self.kv_size, self.kv_size, - # qkv.dtype) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) - # if LlamaAttention.global_print_ctr < 1: - # torch.set_printoptions(profile="full") - # torch.set_printoptions(sci_mode=False) - # print("q k v 1:", q[0], k[0], v[0]) - # print("shapes of all:", qkv.shape, "->", q.shape, k.shape, v.shape) - # LlamaAttention.global_print_ctr += 1 - # raise ValueError("stop") q, k = self.rotary_emb(positions, q, k) - # print("q k v 2:", q, k, v) attn_output = self.attn(q, k, v, kv_cache, attn_metadata) - # print("attn out:", attn_output) output, _ = self.o_proj(attn_output) return output @@ -292,8 +269,6 @@ def forward( class LlamaModel(nn.Module): - global_print_ctr = 0 - def __init__( self, config: LlamaConfig, @@ -311,7 +286,6 @@ def __init__( self.org_vocab_size = config.vocab_size if get_pp_group().is_first_rank or (config.tie_word_embeddings and get_pp_group().is_last_rank): - # print("et VocabParallelEmbedding") self.embed_tokens = VocabParallelEmbedding( self.vocab_size, config.hidden_size, @@ -319,7 +293,6 @@ def __init__( quant_config=quant_config, ) else: - # print("et PPMissingLayer") self.embed_tokens = PPMissingLayer() self.start_layer, self.end_layer, self.layers = make_layers( config.num_hidden_layers, @@ -334,8 +307,8 @@ def __init__( else: self.norm = PPMissingLayer() - self.is_hqq = (quant_config is not None and - isinstance(quant_config, HQQMarlinConfig)) + self.is_hqq = (quant_config is not None + and isinstance(quant_config, HQQMarlinConfig)) self.make_empty_intermediate_tensors = ( make_empty_intermediate_tensors_factory( @@ -389,9 +362,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): (".gate_up_proj", ".up_proj", 1), ] - # 
print("load weights LlamaModel") params_dict = dict(self.named_parameters()) - # print(*[(k, v.shape) for k, v in params_dict.items()], sep="\n") hqq_map = [ (".qweight", "W_q", False), @@ -399,24 +370,25 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): (".scales", "scale", True), ] - ### this is unpack function copied from hqq repo - def unpack_4bit_u8(W_q: torch.Tensor, dtype=torch.uint8) ->torch.Tensor: # uint8/2 > uint8 + # unpack function from https://github.com/mobiusml/hqq + def unpack_4bit_u8( + W_q: torch.Tensor, + dtype=torch.uint8) -> torch.Tensor: # uint8/2 > uint8 step = W_q.shape[0] - tmp = torch.empty([2 * step, W_q.shape[1]], dtype=dtype, device=W_q.device) + tmp = torch.empty([2 * step, W_q.shape[1]], + dtype=dtype, + device=W_q.device) tmp[:step] = (W_q & 0b11110000) >> 4 tmp[step:] = W_q & 0b00001111 return tmp - ### for name, loaded_weight in weights: if self.is_hqq: - # print("START WITH NAME", name) pick_shard_id = None for param_name, weight_name, shard_id in stacked_params_mapping: - # print("is", weight_name, "in", name, "?") if weight_name not in name: continue name = name.replace(weight_name, param_name) @@ -432,66 +404,25 @@ def unpack_4bit_u8(W_q: torch.Tensor, dtype=torch.uint8) ->torch.Tensor: # uint param = params_dict[new_name] weight_loader = param.weight_loader if should_scale: - loaded = loaded_weight[k].reshape(-1, to_shape[1] // group_size) + loaded = loaded_weight[k].reshape( + -1, to_shape[1] // group_size) else: - loaded = unpack_4bit_u8(loaded_weight[k], dtype=torch.bfloat16).reshape(to_shape).to(torch.uint8) - # loaded1 = loaded[:to_shape[0]] - # loaded2 = loaded[to_shape[0]:] - # if (pick_shard_id == "q" or pick_shard_id == "k" or - # pick_shard_id == "v"): - # pass - - # if k == "W_q" and LlamaModel.global_print_ctr < 3: - # torch.set_printoptions(profile="full") - # print("load:", new_name, param.shape, param.dtype, - # loaded_weight[k].shape, loaded_weight[k].dtype, - # to_shape) - # print(loaded.transpose(1, 0)[0]) - # LlamaModel.global_print_ctr += 1 - - #TODO try this - # loaded = loaded_weight[k] - - # print(pick_shard_id) + # TODO we should unpack inside the quantization + # method / kernel + loaded = unpack_4bit_u8( + loaded_weight[k], + dtype=torch.bfloat16).reshape(to_shape).to( + torch.uint8) if pick_shard_id is not None: weight_loader(param, loaded, pick_shard_id) else: weight_loader(param, loaded) - - # unpack: unpack_4bit_u8 - param_wq = loaded_weight["W_q"] - param_zp = loaded_weight["zero"] - param_s = loaded_weight["scale"] - param_w = ((unpack_4bit_u8(param_wq, dtype=torch.bfloat16) - param_zp) * param_s - ).reshape(to_shape) - torch.set_printoptions(profile="full") - torch.set_printoptions(sci_mode=False) - if LlamaModel.global_print_ctr < 3: - # print("load wq orig shape:", param_wq.shape, - # param_wq.reshape(to_shape[0] // 2, to_shape[1]).shape) - # # print("load wq:", param_wq.reshape(to_shape[0] // 2, to_shape[1])[0]) - # print("param s:", param_s.shape, param_s.reshape(-1, to_shape[1] // group_size).shape) - # print("param zp:", param_zp.shape, param_zp.reshape(-1, to_shape[1] // group_size).shape) - # print("s:", param_s.transpose(1, 0)) - # if LlamaModel.global_print_ctr > 0: - # print("zp:", param_zp.transpose(1, 0)) - # print(name) - # print("w:", param_w.transpose(1, 0)[0]) - # print("wq shape:", param_wq.shape, "->", unpack_4bit_u8(param_wq, dtype=torch.bfloat16).shape) - # print(param_wq.transpose(1, 0)[0]) - # print(unpack_4bit_u8(param_wq, dtype=torch.bfloat16).transpose(1, 0)[0]) - 
LlamaModel.global_print_ctr += 1 - # print("deq:", unpack_4bit_u8(param_wq)) - # print("zps:", param_zp, param_zp.shape) - # print("s:", param_s, param_s.shape) else: name = name + ".weight" param = params_dict[name] weight_loader = getattr(param, "weight_loader", - default_weight_loader) - # print("load:", name, param.shape, param.dtype, - # loaded_weight["weight"].shape, loaded_weight["weight"].dtype) + default_weight_loader) weight_loader(param, loaded_weight["weight"]) continue @@ -525,14 +456,6 @@ def unpack_4bit_u8(W_q: torch.Tensor, dtype=torch.uint8) ->torch.Tensor: # uint weight_loader = param.weight_loader weight_loader(param, loaded_weight, shard_id) - if LlamaModel.global_print_ctr < 3 and "layers.0.self_attn.qkv_proj" in name: - torch.set_printoptions(profile="full") - torch.set_printoptions(sci_mode=False) - print("load:", name, weight_loader) - torch.set_printoptions(sci_mode=False) - print("unq:", loaded_weight.transpose(1, 0)[0]) - LlamaModel.global_print_ctr += 1 - break else: # Skip loading extra bias for GPTQ models. @@ -551,14 +474,6 @@ def unpack_4bit_u8(W_q: torch.Tensor, dtype=torch.uint8) ->torch.Tensor: # uint default_weight_loader) weight_loader(param, loaded_weight) - if LlamaModel.global_print_ctr < 3 and "layers.0.self_attn.qkv_proj" in name: - torch.set_printoptions(profile="full") - torch.set_printoptions(sci_mode=False) - print("load:", name, weight_loader) - torch.set_printoptions(sci_mode=False) - print("unq:", loaded_weight.transpose(1, 0)[0]) - LlamaModel.global_print_ctr += 1 - # If this function is called, it should always initialize KV cache scale # factors (or else raise an exception). Thus, handled exceptions should # make sure to leave KV cache scale factors in a known good (dummy) state @@ -639,8 +554,6 @@ def __init__( ) -> None: super().__init__() - # print("===== LLAMA FOR CAUSAL LM =====") - self.config = config self.lora_config = lora_config @@ -705,16 +618,10 @@ def sample(self, logits: torch.Tensor, return next_tokens def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): - # print("load weights LlamaForCausalLM") - # print(*[(n, w['W_q'] if 'W_q' in w else "") for n, w in weights], sep="\n") - # print(*[(n, w) for n, w in weights], sep="\n") - # weights = self.maybe_remap_hqq(weights) weights = [ self.maybe_remap_mistral(name, loaded_weight) for name, loaded_weight in weights ] - # print(*[(n, w) for n, w in weights], sep="\n") - # raise ValueError(".") weights_group = group_weights_with_prefix(weights) @@ -734,7 +641,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) - + else: if is_pp_missing_parameter(name, self.lm_head): continue From 3324b2ec9db0f8a0e60da95cdb5b3cc76ab71704 Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Tue, 15 Oct 2024 11:40:28 -0400 Subject: [PATCH 05/14] more cleanup --- vllm/engine/llm_engine.py | 7 ------- vllm/model_executor/layers/fused_moe/layer.py | 2 +- vllm/model_executor/layers/linear.py | 4 +++- vllm/model_executor/model_loader/loader.py | 6 ------ vllm/model_executor/model_loader/weight_utils.py | 9 +-------- 5 files changed, 5 insertions(+), 23 deletions(-) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index fdc8ca6405375..6372d4b5d2117 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -314,8 +314,6 @@ def __init__( self.detokenizer = None tokenizer_group = None - print(self.load_config) - # Ensure that the 
function doesn't contain a reference to self, # to avoid engine GC issues def get_tokenizer_for_seq(sequence: Sequence) -> AnyTokenizer: @@ -334,11 +332,6 @@ def get_tokenizer_for_seq(sequence: Sequence) -> AnyTokenizer: self.input_processor = input_registry.create_input_processor( model_config) - # from pprint import pprint - # pprint(vars(model_config)) - # pprint(vars(executor_class)) - print("QUANT:", model_config.quantization) - self.model_executor = executor_class( model_config=model_config, cache_config=cache_config, diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index a994c11ddba83..bce740d0db750 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -46,7 +46,7 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int, # Fused gate_up_proj (column parallel) w13_weight = torch.nn.Parameter(torch.empty(num_experts, - intermediate_size, + 2 * intermediate_size, hidden_size, dtype=params_dtype), requires_grad=False) diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index 602d578b78ff4..db323294f4441 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -131,6 +131,7 @@ def apply(self, layer: torch.nn.Module, x: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor: + return F.linear(x, layer.weight, bias) @@ -515,6 +516,7 @@ def weight_loader(self, # Special case for Marlin. shard_size, shard_offset = adjust_marlin_shard( param, shard_size, shard_offset) + use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False) if use_bitsandbytes_4bit: @@ -830,6 +832,7 @@ def weight_loader(self, if packed_dim == output_dim: shard_size = shard_size // param.pack_factor shard_offset = shard_offset // param.pack_factor + # Special case for Marlin. shard_size, shard_offset = adjust_marlin_shard( param, shard_size, shard_offset) @@ -1015,7 +1018,6 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): param.materialize(tuple(weight_shape), dtype=loaded_weight.dtype) param_data = param.data - # bitsandbytes loads the weights of the specific portion # no need to narrow here if input_dim is not None and not use_bitsandbytes_4bit: diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py index 5da3dd6696d92..8d4163ec88490 100644 --- a/vllm/model_executor/model_loader/loader.py +++ b/vllm/model_executor/model_loader/loader.py @@ -226,7 +226,6 @@ class Source: def __init__(self, load_config: LoadConfig): super().__init__(load_config) - print("========= INIT DEFAULT MODEL LOADER =========") if load_config.model_loader_extra_config: raise ValueError(f"Model loader extra config is not supported for " f"load format {load_config.load_format}") @@ -272,22 +271,17 @@ def _prepare_weights(self, model_name_or_path: str, # Some quantized models use .pt files for storing the weights. 
if load_format == LoadFormat.AUTO: allow_patterns = ["*.safetensors", "*.bin"] - # print("AUTO") elif load_format == LoadFormat.SAFETENSORS: use_safetensors = True allow_patterns = ["*.safetensors"] - # print("SAFETENSORS") elif load_format == LoadFormat.MISTRAL: use_safetensors = True allow_patterns = ["consolidated*.safetensors"] index_file = "consolidated.safetensors.index.json" - # print("MISTRAL") elif load_format == LoadFormat.PT: allow_patterns = ["*.pt"] - # print("PT") elif load_format == LoadFormat.NPCACHE: allow_patterns = ["*.bin"] - # print("NPCACHE") else: raise ValueError(f"Unknown load_format: {load_format}") diff --git a/vllm/model_executor/model_loader/weight_utils.py b/vllm/model_executor/model_loader/weight_utils.py index 660bee780dc3c..1fad224fd4b72 100644 --- a/vllm/model_executor/model_loader/weight_utils.py +++ b/vllm/model_executor/model_loader/weight_utils.py @@ -123,19 +123,12 @@ def get_quant_config(model_config: ModelConfig, quant_cls = get_quantization_config(model_config.quantization) - # print(vars(model_config)) - # print(vars(quant_cls)) - # GGUF doesn't have config file if model_config.quantization == "gguf": return quant_cls.from_config({}) if model_config.quantization == "hqq_marlin": - # print("=======================================") - # print(vars(model_config)) - # print(vars(load_config)) - # print("=======================================") - # TODO shouldn't be done like this + # TODO don't hardcode params return quant_cls.from_config({"bits": 4, "group_size": 64}) # Read the quantization config from the HF model config, if available. From 271ff95666badfac99f6de415f3aef4edfefd75a Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Tue, 22 Oct 2024 01:23:19 -0400 Subject: [PATCH 06/14] it works --- csrc/quantization/gptq_marlin/gptq_marlin.cu | 382 ++++++++++++++---- .../gptq_marlin/marlin_dtypes.cuh | 2 + csrc/torch_bindings.cpp | 2 +- vllm/_custom_ops.py | 5 +- .../layers/quantization/hqq_marlin.py | 101 ++++- .../quantization/utils/marlin_utils_test.py | 6 + .../layers/quantization/utils/quant_utils.py | 18 + 7 files changed, 428 insertions(+), 88 deletions(-) diff --git a/csrc/quantization/gptq_marlin/gptq_marlin.cu b/csrc/quantization/gptq_marlin/gptq_marlin.cu index 5efe15d2b2f6b..6e7788006870b 100644 --- a/csrc/quantization/gptq_marlin/gptq_marlin.cu +++ b/csrc/quantization/gptq_marlin/gptq_marlin.cu @@ -54,9 +54,10 @@ template shared // fetch pipeline - const bool has_act_order, // whether act_order is enabled - const int group_blocks = -1 // number of consecutive 16x16 blocks - // with a separate quantization scale + const bool has_act_order, // whether act_order is enabled + const int group_blocks = -1, // number of consecutive 16x16 blocks + // with a separate quantization scale + const bool is_float_zp // is zero point of float16 type? > __global__ void Marlin( const int4* __restrict__ A, // fp16 input matrix of shape mxk @@ -82,7 +83,7 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, torch::Tensor& workspace, vllm::ScalarTypeTorchPtr const& b_q_type, int64_t size_m, int64_t size_n, int64_t size_k, - bool is_k_full, bool has_zp) { + bool is_k_full, bool has_zp, bool is_float_zp) { TORCH_CHECK_NOT_IMPLEMENTED(false, "marlin_gemm(..) 
requires CUDA_ARCH >= 8.0"); return torch::empty({1, 1}); @@ -386,6 +387,17 @@ __device__ inline void sub_zp(typename ScalarType::FragB& frag_b, using scalar_t2 = typename ScalarType::scalar_t2; scalar_t2 zp = ScalarType::num2num2(reinterpret_cast(&frag_zp)[i]); + frag_b[0] = (frag_b[0], zp); + frag_b[1] = __hsub2(frag_b[1], zp); +} + +template +__device__ inline void sub_zpf(typename ScalarType::FragB& frag_b, + typename ScalarType::FragZPF& frag_zpf, + int i) { + using scalar_t2 = typename ScalarType::scalar_t2; + scalar_t2 zp = + ScalarType::num2num2(reinterpret_cast(&frag_zpf)[i]); frag_b[0] = __hsub2(frag_b[0], zp); frag_b[1] = __hsub2(frag_b[1], zp); } @@ -420,6 +432,15 @@ __device__ inline void scale_float(float* c, c[1] = __fmul_rn(c[1], ScalarType::num2float(s_ptr[1])); } +// Given 2 floats subtract by 2 zero points (halves) +template +__device__ inline void sub_zpf_float( + float* c, typename ScalarType::FragZPF& zp) { + scalar_t* zp_ptr = reinterpret_cast(&zp); + c[0] = __fsub_rn(c[0], ScalarType::num2float(zp_ptr[0])); + c[1] = __fsub_rn(c[1], ScalarType::num2float(zp_ptr[1])); +} + // Wait until barrier reaches `count`, then lock for current threadblock. __device__ inline void barrier_acquire(int* lock, int count) { if (threadIdx.x == 0) { @@ -516,10 +537,11 @@ template shared // fetch pipeline - const bool has_act_order, // whether act_order is enabled - const bool has_zp, // whether zero-points are enabled - const int group_blocks = -1 // number of consecutive 16x16 blocks - // with a separate quantization scale + const bool has_act_order, // whether act_order is enabled + const bool has_zp, // whether zero-points are enabled + const int group_blocks = -1, // number of consecutive 16x16 blocks + // with a separate quantization scale + const bool is_float_zp // is zero point of float16 type? > __global__ void Marlin( const int4* __restrict__ A, // fp16 input matrix of shape mxk @@ -556,6 +578,7 @@ __global__ void Marlin( using FragC = typename ScalarType::FragC; using FragS = typename ScalarType::FragS; using FragZP = typename ScalarType::FragZP; + using FragZPF = typename ScalarType::FragZPF; static constexpr auto w_type = vllm::ScalarType::from_id(w_type_id); @@ -692,8 +715,10 @@ __global__ void Marlin( int act_s_col_tb_stride = act_s_col_warp_stride * tb_n_warps; // Zero-points sizes/strides - int zp_gl_stride = (prob_n / pack_factor) / 4; - constexpr int zp_sh_stride = ((16 * thread_n_blocks) / pack_factor) / 4; + int zp_gl_stride = is_float_zp ? prob_n / 8 : (prob_n / pack_factor) / 4; + constexpr int zp_sh_stride = is_float_zp + ? 16 * thread_n_blocks / 8 + : ((16 * thread_n_blocks) / pack_factor) / 4; constexpr int zp_tb_groups = s_tb_groups; constexpr int zp_sh_stage = has_zp ? 
zp_tb_groups * zp_sh_stride : 0; int zp_gl_rd_delta = zp_gl_stride; @@ -768,9 +793,18 @@ __global__ void Marlin( constexpr int num_ints_per_thread = 8 / pack_factor; int zp_sh_rd; if constexpr (has_zp) { - zp_sh_rd = num_ints_per_thread * num_col_threads * - ((threadIdx.x / 32) % (thread_n_blocks / 4)) + - num_ints_per_thread * ((threadIdx.x % 32) / num_row_threads); + if constexpr (is_float_zp) { + if constexpr (group_blocks != -1) + zp_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + + (threadIdx.x % 32) / 4; + else + zp_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + + (threadIdx.x % 32) % 4; + } else { + zp_sh_rd = num_ints_per_thread * num_col_threads * + ((threadIdx.x / 32) % (thread_n_blocks / 4)) + + num_ints_per_thread * ((threadIdx.x % 32) / num_row_threads); + } } // Precompute which thread should not read memory in which iterations; this is @@ -832,6 +866,7 @@ __global__ void Marlin( FragS act_frag_s[2][4][4]; // For act-order int frag_qzp[2][num_ints_per_thread]; // Zero-points FragZP frag_zp; // Zero-points in fp16 + FragZPF frag_zpf[2][4]; // float16 zero-points // Zero accumulators. auto zero_accums = [&]() { @@ -1126,7 +1161,7 @@ __global__ void Marlin( // has_zp implies AWQ, which doesn't have act_order, static_assert(!has_zp || group_blocks != 0); - if constexpr (has_zp) { + if constexpr (has_zp && !is_float_zp) { int pipe = full_pipe % stages; if constexpr (group_blocks == -1) { @@ -1170,11 +1205,40 @@ __global__ void Marlin( } } } + + if constexpr (has_zp && is_float_zp) { + int pipe = full_pipe % stages; + + if constexpr (group_blocks != -1) { + if constexpr (group_blocks >= thread_k_blocks) { + int4* sh_zp_stage = + sh_zp + zp_sh_stage * ((group_blocks / thread_k_blocks) * + (pipe / (group_blocks / thread_k_blocks))); + reinterpret_cast(&frag_zpf[k % 2])[0] = sh_zp_stage[zp_sh_rd]; + } else { + int warp_id = threadIdx.x / 32; + int n_warps = thread_n_blocks / 4; + + int warp_row = warp_id / n_warps; + + int cur_k = warp_row * 16; + cur_k += k_iter_size * (k % b_sh_wr_iters); + + int k_blocks = cur_k / 16; + int cur_group_id = k_blocks / group_blocks; + + int4* sh_zp_stage = sh_zp + zp_sh_stage * pipe; + + reinterpret_cast(&frag_zpf[k % 2])[0] = + sh_zp_stage[zp_sh_rd + cur_group_id * zp_sh_stride]; + } + } + } }; // Execute the actual tensor core matmul of a sub-tile. 
auto matmul = [&](int k) { - if constexpr (has_zp) { + if constexpr (has_zp && !is_float_zp) { FragB frag_zp_0; FragB frag_zp_1; int zp_quant_0, zp_quant_1; @@ -1219,10 +1283,14 @@ __global__ void Marlin( frag_b1 = dequant(b_quant_1); // Apply zero-point to frag_b0 - if constexpr (has_zp) { + if constexpr (has_zp && !is_float_zp) { sub_zp(frag_b0, frag_zp[j], 0); } + if constexpr (has_zp && is_float_zp && group_blocks != -1) { + sub_zpf(frag_b0, frag_zpf[k % 2][j], 0); + } + // Apply scale to frag_b0 if constexpr (has_act_order) { scale4(frag_b0, act_frag_s[k % 2][0][j], @@ -1235,10 +1303,14 @@ __global__ void Marlin( } // Apply zero-point to frag_b1 - if constexpr (has_zp) { + if constexpr (has_zp && !is_float_zp) { sub_zp(frag_b1, frag_zp[j], 1); } + if constexpr (has_zp && is_float_zp && group_blocks != -1) { + sub_zpf(frag_b1, frag_zpf[k % 2][j], 1); + } + // Apply scale to frag_b1 if constexpr (has_act_order) { scale4(frag_b1, act_frag_s[k % 2][0][j], @@ -1451,10 +1523,16 @@ __global__ void Marlin( // We first reorder in shared memory to guarantee the most efficient final // global write patterns - auto write = [&](int idx, float c0, float c1, FragS& s) { + auto write = [&](int idx, float c0, float c1, FragS& s, FragZPF& zp) { scalar_t2 res = Dtype::nums2num2(Dtype::float2num(c0), Dtype::float2num(c1)); + // apply float zp + if constexpr (has_zp && is_float_zp && group_blocks == -1 && + w_type.size_bits() == 4) { + res = __hsub2(res, zp[0]); + } + // For per-column quantization we finally apply the scale here (only for // 4-bit) if constexpr (!has_act_order && group_blocks == -1 && @@ -1472,13 +1550,17 @@ __global__ void Marlin( for (int j = 0; j < 4; j++) { int wr = c_sh_wr + 8 * j; write(wr + (4 * c_sh_stride) * 0 + 0, frag_c[i][j][0][0], - frag_c[i][j][0][1], frag_s[j / 2][2 * (j % 2) + 0]); + frag_c[i][j][0][1], frag_s[j / 2][2 * (j % 2) + 0], + frag_zpf[j / 2][2 * (j % 2) + 0]); write(wr + (4 * c_sh_stride) * 8 + 0, frag_c[i][j][0][2], - frag_c[i][j][0][3], frag_s[j / 2][2 * (j % 2) + 0]); + frag_c[i][j][0][3], frag_s[j / 2][2 * (j % 2) + 0], + frag_zpf[j / 2][2 * (j % 2) + 0]); write(wr + (4 * c_sh_stride) * 0 + 4, frag_c[i][j][1][0], - frag_c[i][j][1][1], frag_s[j / 2][2 * (j % 2) + 1]); + frag_c[i][j][1][1], frag_s[j / 2][2 * (j % 2) + 1], + frag_zpf[j / 2][2 * (j % 2) + 1]); write(wr + (4 * c_sh_stride) * 8 + 4, frag_c[i][j][1][2], - frag_c[i][j][1][3], frag_s[j / 2][2 * (j % 2) + 1]); + frag_c[i][j][1][3], frag_s[j / 2][2 * (j % 2) + 1], + frag_zpf[j / 2][2 * (j % 2) + 1]); } c_sh_wr += 16 * (4 * c_sh_stride); } @@ -1510,7 +1592,7 @@ __global__ void Marlin( fetch_scales_to_shared(true, g_idx[slice_k_start], g_idx[last_g_idx]); } - if constexpr (has_zp && group_blocks == -1) { + if constexpr (has_zp && !is_float_zp && group_blocks == -1) { if (i == 0) { fetch_zp_to_shared(); } @@ -1601,6 +1683,22 @@ __global__ void Marlin( } } + if constexpr (has_zp && is_float_zp && group_blocks == -1) { + if constexpr (w_type.size_bits() == 8) { + if (zp_sh_wr_pred) { + cp_async4(&sh_zp[zp_sh_wr], &zp_ptr[zp_gl_rd]); + } + cp_async_fence(); + } else { + if (last) { + if (zp_sh_wr_pred) { + cp_async4(&sh_zp[zp_sh_wr], &zp_ptr[zp_gl_rd]); + } + cp_async_fence(); + } + } + } + thread_block_reduce(); if constexpr (!has_act_order && group_blocks == -1) { if constexpr (w_type.size_bits() == 8) { @@ -1623,6 +1721,29 @@ __global__ void Marlin( } } + if constexpr (has_zp && is_float_zp && group_blocks == -1) { + if constexpr (w_type.size_bits() == 8) { + cp_async_wait<0>(); + __syncthreads(); + if 
(threadIdx.x / 32 < thread_n_blocks / 4) { + reinterpret_cast(&frag_zpf)[0] = sh_zp[zp_sh_rd + 0]; + reinterpret_cast(&frag_zpf)[1] = sh_zp[zp_sh_rd + 4]; + } + + } else { + if (last) { + cp_async_wait<0>(); + __syncthreads(); + if (threadIdx.x / 32 < thread_n_blocks / 4) { + reinterpret_cast(&frag_zpf)[0] = sh_zp[zp_sh_rd + 0]; + reinterpret_cast(&frag_zpf)[1] = sh_zp[zp_sh_rd + 4]; + } + } + } + } + + // TODO also this code block below + // For 8-bit channelwise, we apply the scale before the global reduction // that converts the fp32 results to fp16 (so that we avoid possible // overflow in fp16) @@ -1651,6 +1772,30 @@ __global__ void Marlin( } } + if constexpr (has_zp && is_float_zp && group_blocks == -1 && + w_type.size_bits() == 8) { + if (threadIdx.x / 32 < thread_n_blocks / 4) { + #pragma unroll + for (int i = 0; i < thread_m_blocks; i++) { + #pragma unroll + for (int j = 0; j < 4; j++) { + sub_zpf_float( + reinterpret_cast(&frag_c[i][j][0][0]), + frag_zpf[j / 2][2 * (j % 2) + 0]); + sub_zpf_float( + reinterpret_cast(&frag_c[i][j][0][2]), + frag_zpf[j / 2][2 * (j % 2) + 0]); + sub_zpf_float( + reinterpret_cast(&frag_c[i][j][1][0]), + frag_zpf[j / 2][2 * (j % 2) + 1]); + sub_zpf_float( + reinterpret_cast(&frag_c[i][j][1][2]), + frag_zpf[j / 2][2 * (j % 2) + 1]); + } + } + } + } + if (slice_count > 1) { // only globally reduce if there is more than one // block in a slice barrier_acquire(&locks[slice_col], slice_idx); @@ -1697,20 +1842,22 @@ __global__ void Marlin( } #define __CALL_IF(W_TYPE, THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \ - HAS_ACT_ORDER, HAS_ZP, GROUP_BLOCKS, NUM_THREADS) \ + HAS_ACT_ORDER, HAS_ZP, GROUP_BLOCKS, NUM_THREADS, \ + IS_FLOAT_ZP) \ else if (q_type == W_TYPE && thread_m_blocks == THREAD_M_BLOCKS && \ thread_n_blocks == THREAD_N_BLOCKS && \ thread_k_blocks == THREAD_K_BLOCKS && \ has_act_order == HAS_ACT_ORDER && has_zp == HAS_ZP && \ - group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS) { \ + group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS && \ + is_float_zp == IS_FLOAT_ZP) { \ cudaFuncSetAttribute( \ Marlin, \ + HAS_ZP, GROUP_BLOCKS, IS_FLOAT_ZP>, \ cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \ Marlin \ + HAS_ZP, GROUP_BLOCKS, IS_FLOAT_ZP> \ <<>>( \ A_ptr, B_ptr, C_ptr, C_tmp_ptr, s_ptr, zp_ptr, g_idx_ptr, \ num_groups, prob_m, prob_n, prob_k, locks, use_fp32_reduce); \ @@ -1905,51 +2052,122 @@ exec_config_t determine_thread_config(int prob_m, int prob_n, int prob_k, } #define GPTQ_CALL_IF(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - __CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS) \ - __CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS) \ - __CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS) \ - __CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS) \ + __CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS, \ + false) \ + __CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS, \ + false) \ + __CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS, \ + false) \ + __CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS, \ + false) \ \ - __CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS) \ - __CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS) \ - __CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS) \ - __CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS) \ + __CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, -1, 
NUM_THREADS, \ + false) \ + __CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS, \ + false) \ + __CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS, \ + false) \ + __CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS, \ + false) \ \ - __CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS) \ - __CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS) \ - __CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS) \ - __CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS) \ + __CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS, \ + false) \ + __CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS, \ + false) \ + __CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS, \ + false) \ + __CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS, \ + false) \ \ - __CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS) \ - __CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS) \ - __CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS) \ - __CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS) \ + __CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS, \ + false) \ + __CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS, \ + false) \ + __CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS, \ + false) \ + __CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS, \ + false) \ \ - __CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS) \ - __CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS) \ - __CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS) \ - __CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS) + __CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS, \ + false) \ + __CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS, \ + false) \ + __CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS, \ + false) \ + __CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS, \ + false) #define AWQ_CALL_IF(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - __CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS) \ - __CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS) \ - __CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS) \ - __CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS) \ + __CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS, \ + false) \ + __CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS, \ + false) \ + __CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS, \ + false) \ + __CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS, \ + false) \ + \ + __CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS, \ + false) \ + __CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS, \ + false) \ + __CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS, \ + false) \ + __CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS, \ + false) \ \ - __CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS) \ - __CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS) \ - __CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS) \ - __CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, true, 8, 
NUM_THREADS) \ + __CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS, \ + false) \ + __CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS, \ + false) \ + __CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS, \ + false) \ + __CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS, \ + false) \ \ - __CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS) \ - __CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS) \ - __CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS) \ - __CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS) \ + __CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS, \ + false) \ + __CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS, \ + false) \ + __CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS, \ + false) \ + __CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS, false) + + #define HQQ_CALL_IF(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + __CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS, \ + true) \ + __CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS, \ + true) \ + __CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS, \ + true) \ + __CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS, \ + true) \ \ - __CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS) \ - __CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS) \ - __CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS) \ - __CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS) + __CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS, \ + true) \ + __CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS, \ + true) \ + __CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS, \ + true) \ + __CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS, \ + true) \ + \ + __CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS, \ + true) \ + __CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS, \ + true) \ + __CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS, \ + true) \ + __CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS, \ + true) \ + \ + __CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS, \ + true) \ + __CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS, \ + true) \ + __CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS, \ + true) \ + __CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS, true) template void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s, @@ -1958,7 +2176,7 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s, vllm::ScalarType const& q_type, bool has_act_order, bool is_k_full, bool has_zp, int num_groups, int group_size, int dev, cudaStream_t stream, int thread_k, int thread_n, - int sms, int max_par, bool use_fp32_reduce) { + int sms, int max_par, bool use_fp32_reduce, bool is_float_zp) { if (has_zp) { TORCH_CHECK( q_type == vllm::kU4 || q_type == vllm::kU8, @@ -2111,6 +2329,15 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s, AWQ_CALL_IF(vllm::kU8, 8, 8, 256) AWQ_CALL_IF(vllm::kU8, 8, 4, 128) AWQ_CALL_IF(vllm::kU8, 4, 8, 128) + + HQQ_CALL_IF(vllm::kU4, 16, 4, 256) + HQQ_CALL_IF(vllm::kU4, 8, 8, 256) + HQQ_CALL_IF(vllm::kU4, 8, 4, 128) 
+ HQQ_CALL_IF(vllm::kU4, 4, 8, 128) + HQQ_CALL_IF(vllm::kU8, 16, 4, 256) + HQQ_CALL_IF(vllm::kU8, 8, 8, 256) + HQQ_CALL_IF(vllm::kU8, 8, 4, 128) + HQQ_CALL_IF(vllm::kU8, 4, 8, 128) else { TORCH_CHECK(false, "Unsupported shapes: MNK = [", prob_m, ", ", prob_n, ", ", prob_k, "]", ", has_act_order = ", has_act_order, @@ -2135,7 +2362,7 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, vllm::ScalarTypeTorchPtr const& b_q_type, int64_t size_m, int64_t size_n, int64_t size_k, bool is_k_full, bool has_zp, - bool use_fp32_reduce) { + bool use_fp32_reduce, bool is_float_zp) { if (has_zp) { TORCH_CHECK(*b_q_type == vllm::kU4 || *b_q_type == vllm::kU8, "b_q_type must be u4 or u8 when has_zp = True. Got = ", @@ -2256,12 +2483,21 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, if (has_zp) { int rank = b_zeros.sizes().size(); TORCH_CHECK(rank == 2, "b_zeros rank = ", rank, " is not 2"); - TORCH_CHECK(b_zeros.size(0) == num_groups, - "b_zeros dim 0 = ", b_zeros.size(0), - " is not num_groups = ", num_groups); - TORCH_CHECK(b_zeros.size(1) == size_n / pack_factor, - "b_zeros dim 1 = ", b_zeros.size(1), - " is not size_n / pack_factor = ", size_n / pack_factor); + if (is_float_zp) { + TORCH_CHECK(b_zeros.size(1) == size_n, + "b_zeros dim 1 = ", b_zeros.size(1), + " is not size_n = ", size_n); + TORCH_CHECK(num_groups == b_zeros.size(0), + "b_zeros dim 0 = ", b_zeros.size(0), + " is not num_groups = ", num_groups); + } else { + TORCH_CHECK(b_zeros.size(0) == num_groups, + "b_zeros dim 0 = ", b_zeros.size(0), + " is not num_groups = ", num_groups); + TORCH_CHECK(b_zeros.size(1) == size_n / pack_factor, + "b_zeros dim 1 = ", b_zeros.size(1), + " is not size_n / pack_factor = ", size_n / pack_factor); + } } // Verify workspace size @@ -2281,7 +2517,7 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, a_tmp.data_ptr(), size_m, size_n, size_k, workspace.data_ptr(), *b_q_type, has_act_order, is_k_full, has_zp, num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev), - thread_k, thread_n, sms, marlin::max_par, use_fp32_reduce); + thread_k, thread_n, sms, marlin::max_par, use_fp32_reduce, is_float_zp); } else if (a.scalar_type() == at::ScalarType::BFloat16) { marlin::marlin_mm( a.data_ptr(), b_q_weight.data_ptr(), @@ -2290,7 +2526,7 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, perm.data_ptr(), a_tmp.data_ptr(), size_m, size_n, size_k, workspace.data_ptr(), *b_q_type, has_act_order, is_k_full, has_zp, num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev), - thread_k, thread_n, sms, marlin::max_par, use_fp32_reduce); + thread_k, thread_n, sms, marlin::max_par, use_fp32_reduce, is_float_zp); } else { TORCH_CHECK(false, "gpt_marlin_gemm only supports bfloat16 and float16"); } diff --git a/csrc/quantization/gptq_marlin/marlin_dtypes.cuh b/csrc/quantization/gptq_marlin/marlin_dtypes.cuh index be06c09bee331..5b0fe94e7971c 100644 --- a/csrc/quantization/gptq_marlin/marlin_dtypes.cuh +++ b/csrc/quantization/gptq_marlin/marlin_dtypes.cuh @@ -24,6 +24,7 @@ class ScalarType { using FragC = Vec; using FragS = Vec; using FragZP = Vec; + using FragZPF = Vec; static __device__ float inline num2float(const half x) { return __half2float(x); @@ -53,6 +54,7 @@ class ScalarType { using FragC = Vec; using FragS = Vec; using FragZP = Vec; + using FragZPF = Vec; #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 static __device__ float inline num2float(const nv_bfloat16 x) { diff --git 
a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index a0100b4a85edd..a986a277df9ee 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -203,7 +203,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { "Tensor b_zeros, Tensor g_idx, Tensor perm, Tensor workspace, " "__torch__.torch.classes._core_C.ScalarType b_q_type, " "int size_m, int size_n, int size_k, bool is_k_full, " - "bool has_zp, bool use_fp32_reduce) -> Tensor"); + "bool has_zp, bool use_fp32_reduce, bool is_float_zp) -> Tensor"); // conditionally compiled so impl registration is in source file // gptq_marlin repack from GPTQ. diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 24e008dc38022..cc796c8542e54 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -595,11 +595,12 @@ def gptq_marlin_gemm(a: torch.Tensor, size_k: int, is_k_full: bool, has_zp: bool = False, - use_fp32_reduce: bool = False) -> torch.Tensor: + use_fp32_reduce: bool = False, + is_float_zp: bool = False) -> torch.Tensor: return torch.ops._C.gptq_marlin_gemm(a, b_q_weight, b_scales, b_zeros, g_idx, perm, workspace, b_q_type, size_m, size_n, size_k, is_k_full, - has_zp, use_fp32_reduce) + has_zp, use_fp32_reduce, is_float_zp) # fp8 marlin diff --git a/vllm/model_executor/layers/quantization/hqq_marlin.py b/vllm/model_executor/layers/quantization/hqq_marlin.py index df5d9ba15fa56..1edc5c006c04d 100644 --- a/vllm/model_executor/layers/quantization/hqq_marlin.py +++ b/vllm/model_executor/layers/quantization/hqq_marlin.py @@ -3,10 +3,17 @@ import torch import torch.nn.functional as F +from vllm import _custom_ops as ops from vllm.logger import init_logger from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) +from vllm.model_executor.layers.quantization.utils.marlin_utils import ( + GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_N, + marlin_make_empty_g_idx, marlin_permute_scales) +from vllm.model_executor.layers.quantization.utils.marlin_utils_test import ( + MarlinWorkspace) +from vllm.model_executor.layers.quantization.utils.quant_utils import gptq_pack from vllm.model_executor.parameter import (GroupQuantScaleParameter, PackedvLLMParameter) from vllm.scalar_type import scalar_types @@ -138,11 +145,38 @@ def create_weights( layer.register_parameter("zeros", zeros) layer.register_parameter("scales", scales) - # self.kernel = '.' 
#TODO - def process_weights_after_loading(self, layer: torch.nn.Module) -> None: - # TODO marlin format - return + dev = layer.qweight.device + qweight_t = layer.qweight.transpose(1, 0) + + gptq_w_q = gptq_pack(qweight_t, 4, self.input_size_per_partition, + self.output_size_per_partition) + + sort_indices = torch.empty(0, dtype=torch.int, device=gptq_w_q.device) + marlin_w_q = ops.gptq_marlin_repack( + gptq_w_q, + sort_indices, + self.input_size_per_partition, + self.output_size_per_partition, + 4, + ).to(dev) + marlin_s = marlin_permute_scales(layer.scales.transpose(1, 0), + self.input_size_per_partition, + self.output_size_per_partition, + 64).to(dev) + marlin_zp = marlin_permute_scales(layer.zeros.transpose(1, 0), + self.input_size_per_partition, + self.output_size_per_partition, + 64).to(dev) + # print(layer.zeros) + # print(marlin_zp) + + layer.g_idx = marlin_make_empty_g_idx(dev) + layer.g_idx_sort_indices = marlin_make_empty_g_idx(dev) + + layer.marlin_qweight = marlin_w_q + layer.marlin_zeros = marlin_zp + layer.marlin_scales = marlin_s def apply( self, @@ -150,11 +184,54 @@ def apply( x: torch.Tensor, bias: Optional[torch.Tensor] = None, ) -> torch.Tensor: - # TODO marlin kernel - unpacked = layer.qweight.reshape(-1, 64) - scales = layer.scales.reshape(-1, 1) - zeros = layer.zeros.reshape(-1, 1) - b = (unpacked - zeros) * scales - b = b.reshape(self.output_size_per_partition, - self.input_size_per_partition) - return F.linear(x, b, bias) + # unpacked = layer.qweight.reshape(-1, 64) + # scales = layer.scales.reshape(-1, 1) + # zeros = layer.zeros.reshape(-1, 1) + # b = (unpacked - zeros) * scales + # b = b.reshape(self.output_size_per_partition, + # self.input_size_per_partition) + + workspace = MarlinWorkspace(self.output_size_per_partition, + GPTQ_MARLIN_MIN_THREAD_N, + GPTQ_MARLIN_MAX_PARALLEL) + + # print(x) + # print(layer.marlin_qweight) + # print(layer.marlin_scales) + # print(layer.marlin_zeros) + # print(layer.g_idx) + # print(layer.g_idx_sort_indices) + # print(workspace.scratch) + + # print(x.shape, layer.marlin_qweight.shape, layer.marlin_scales.shape, + # layer.marlin_zeros.shape) + + marlin_out = ops.gptq_marlin_gemm( + x, + layer.marlin_qweight, + layer.marlin_scales, + layer.marlin_zeros, + layer.g_idx, + layer.g_idx_sort_indices, + workspace.scratch, + scalar_types.uint4, + x.shape[0], + self.output_size_per_partition, + self.input_size_per_partition, + True, # is_k_full + True, # has_zp + False, # use 32-bit reduce + True, # use float zp + ) + + # deq_out = F.linear(x, b, bias) + + # print("gptq:", gptq_out) + # print("marlin:", marlin_out.shape, marlin_out) + # print("deq:", deq_out.shape, deq_out) + # print("***") + + if bias is not None: + marlin_out.add_(bias) + + return marlin_out diff --git a/vllm/model_executor/layers/quantization/utils/marlin_utils_test.py b/vllm/model_executor/layers/quantization/utils/marlin_utils_test.py index 4a06c5d63d52d..b5baf407e114a 100644 --- a/vllm/model_executor/layers/quantization/utils/marlin_utils_test.py +++ b/vllm/model_executor/layers/quantization/utils/marlin_utils_test.py @@ -43,22 +43,28 @@ def marlin_permute_weights(q_w, size_k, size_n, perm, tile=GPTQ_MARLIN_TILE): def marlin_weights(q_w, size_k, size_n, num_bits, perm): + # print("before permute:", q_w) # Permute q_w = marlin_permute_weights(q_w, size_k, size_n, perm) + # print("after permute:", q_w) # Pack pack_factor = get_pack_factor(num_bits) orig_device = q_w.device q_w = q_w.cpu().numpy().astype(np.uint32) + # print("astype:", q_w) q_packed = 
np.zeros((q_w.shape[0], q_w.shape[1] // pack_factor), dtype=np.uint32) for i in range(pack_factor): q_packed |= q_w[:, i::pack_factor] << num_bits * i + # print("packed:", q_packed) q_packed = torch.from_numpy(q_packed.astype(np.int32)).to(orig_device) + # print("packed2:", q_packed) + return q_packed diff --git a/vllm/model_executor/layers/quantization/utils/quant_utils.py b/vllm/model_executor/layers/quantization/utils/quant_utils.py index 833d00073564e..e9ad04006fbcb 100644 --- a/vllm/model_executor/layers/quantization/utils/quant_utils.py +++ b/vllm/model_executor/layers/quantization/utils/quant_utils.py @@ -132,6 +132,8 @@ def quantize_weights(w: torch.Tensor, assert quant_type.is_integer(), \ "Floating point quantization may work but has not been tested" + torch.set_printoptions(sci_mode="False") + orig_device = w.device orig_type = w.dtype size_k, size_n = w.shape @@ -142,12 +144,18 @@ def quantize_weights(w: torch.Tensor, group_size = size_k assert group_size <= size_k + # print("orig w:", w.shape, w) + # Reshape to [groupsize, -1] if group_size < size_k: w = w.reshape((-1, group_size, size_n)) + # print("reshape1 w:", w.shape, w) w = w.permute(1, 0, 2) + # print("permute w:", w.shape, w) w = w.reshape((group_size, -1)) + # print("reshape2 w:", w.shape, w) + # Compute scale for each group max_val = torch.max(w, 0, keepdim=True).values min_val = torch.min(w, 0, keepdim=True).values @@ -168,10 +176,17 @@ def quantize_weights(w: torch.Tensor, abs(min_val / (min_q_val if min_q_val != 0 else torch.inf))) maybe_w_zp = None + # print("w_s:", w_s.shape, w_s) + # Quantize w_q = torch.round(w / w_s).int() + (maybe_w_zp if zero_points else 0) w_q = torch.clamp(w_q, min_q_val, max_q_val) + # print("w_q:", w_q.shape, w_q) + + # TODO hqq has shape of w_s and w_q as in up to now in this code (also zp) + # but we need to transpose them first + # Compute ref (dequantized) # For some kernels (namely Machete) the zero-points are applied after the # scales are applied, for this case computing the reference in similar way @@ -198,6 +213,9 @@ def reshape_w(w): w_s = w_s.reshape((-1, size_n)).contiguous() + # print("final w_s:", w_s.shape, w_s) + # print("final w_q:", w_q.shape, w_q) + if zero_points: maybe_w_zp = maybe_w_zp.reshape((-1, size_n)).contiguous() maybe_w_zp = maybe_w_zp.to(device=orig_device) From de8c5f0d32f17f83fee69b2515daed18b75094b6 Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Tue, 22 Oct 2024 03:07:12 -0400 Subject: [PATCH 07/14] cleanup --- benchmarks/kernels/benchmark_marlin.py | 4 +- csrc/quantization/gptq_marlin/gptq_marlin.cu | 48 ++++-------------- tests/kernels/test_marlin_gemm.py | 4 +- vllm/_custom_ops.py | 3 +- vllm/model_executor/layers/linear.py | 14 ++---- .../layers/quantization/hqq_marlin.py | 49 ++++--------------- .../layers/quantization/utils/marlin_utils.py | 6 ++- .../quantization/utils/marlin_utils_test.py | 6 --- .../layers/quantization/utils/quant_utils.py | 18 ------- vllm/model_executor/models/llama.py | 4 +- 10 files changed, 36 insertions(+), 120 deletions(-) diff --git a/benchmarks/kernels/benchmark_marlin.py b/benchmarks/kernels/benchmark_marlin.py index 536c133bb3341..8fb44e3a3dbd8 100644 --- a/benchmarks/kernels/benchmark_marlin.py +++ b/benchmarks/kernels/benchmark_marlin.py @@ -131,7 +131,7 @@ def bench_run(results: List[benchmark.Measurement], model: str, results.append( benchmark.Timer( stmt= - "output = gptq_marlin_gemm(a, marlin_q_w, marlin_s, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, quant_type, size_m, size_n, 
size_k, is_k_full, False, False)", # noqa: E501 + "output = gptq_marlin_gemm(a, marlin_q_w, marlin_s, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, quant_type, size_m, size_n, size_k, is_k_full, False, False, False)", # noqa: E501 globals=globals, label=label, sub_label=sub_label, @@ -141,7 +141,7 @@ def bench_run(results: List[benchmark.Measurement], model: str, results.append( benchmark.Timer( stmt= - "output = gptq_marlin_gemm(a, marlin_q_w, marlin_s, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, quant_type, size_m, size_n, size_k, is_k_full, False, True)", # noqa: E501 + "output = gptq_marlin_gemm(a, marlin_q_w, marlin_s, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, quant_type, size_m, size_n, size_k, is_k_full, False, True, False)", # noqa: E501 globals=globals, label=label, sub_label=sub_label, diff --git a/csrc/quantization/gptq_marlin/gptq_marlin.cu b/csrc/quantization/gptq_marlin/gptq_marlin.cu index 6e7788006870b..881e32165caa8 100644 --- a/csrc/quantization/gptq_marlin/gptq_marlin.cu +++ b/csrc/quantization/gptq_marlin/gptq_marlin.cu @@ -387,7 +387,7 @@ __device__ inline void sub_zp(typename ScalarType::FragB& frag_b, using scalar_t2 = typename ScalarType::scalar_t2; scalar_t2 zp = ScalarType::num2num2(reinterpret_cast(&frag_zp)[i]); - frag_b[0] = (frag_b[0], zp); + frag_b[0] = __hsub2(frag_b[0], zp); frag_b[1] = __hsub2(frag_b[1], zp); } @@ -1742,8 +1742,6 @@ __global__ void Marlin( } } - // TODO also this code block below - // For 8-bit channelwise, we apply the scale before the global reduction // that converts the fp32 results to fp16 (so that we avoid possible // overflow in fp16) @@ -2133,41 +2131,15 @@ exec_config_t determine_thread_config(int prob_m, int prob_n, int prob_k, false) \ __CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS, false) - #define HQQ_CALL_IF(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - __CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS, \ - true) \ - __CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS, \ - true) \ - __CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS, \ - true) \ - __CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS, \ - true) \ - \ - __CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS, \ - true) \ - __CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS, \ - true) \ - __CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS, \ - true) \ - __CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS, \ - true) \ - \ - __CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS, \ - true) \ - __CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS, \ - true) \ - __CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS, \ - true) \ - __CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS, \ - true) \ - \ - __CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS, \ - true) \ - __CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS, \ - true) \ - __CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS, \ - true) \ - __CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS, true) + // We currently have 4-bit models only with group_blocks == 4 + #define HQQ_CALL_IF(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + __CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS, \ + true) \ + __CALL_IF(W_TYPE, 2, N_BLOCKS, 
K_BLOCKS, false, true, 4, NUM_THREADS, \ + true) \ + __CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS, \ + true) \ + __CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS, true) template void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s, diff --git a/tests/kernels/test_marlin_gemm.py b/tests/kernels/test_marlin_gemm.py index a9bb72156c39e..5faf3b6fdafa0 100644 --- a/tests/kernels/test_marlin_gemm.py +++ b/tests/kernels/test_marlin_gemm.py @@ -226,7 +226,7 @@ def test_gptq_marlin_gemm( torch.ops._C.gptq_marlin_gemm, (a_input, marlin_q_w, marlin_s, marlin_zp, g_idx, sort_indices, workspace.scratch, quant_type, a_input.shape[0], b_weight.shape[1], - a_input.shape[1], is_k_full, False, use_fp32_reduce), + a_input.shape[1], is_k_full, False, use_fp32_reduce, False), test_utils=DEFAULT_OPCHECK_TEST_UTILS) output = ops.gptq_marlin_gemm( @@ -244,6 +244,7 @@ def test_gptq_marlin_gemm( is_k_full=is_k_full, has_zp=False, use_fp32_reduce=use_fp32_reduce, + is_float_zp=False, ) output_ref = torch.matmul(a_input, w_ref) @@ -431,6 +432,7 @@ def test_awq_marlin_gemm( is_k_full=is_k_full, has_zp=has_zp, use_fp32_reduce=use_fp32_reduce, + is_float_zp=False, ) output_ref = torch.matmul(a_input, w_ref) diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index cc796c8542e54..eaef5a605320d 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -323,7 +323,8 @@ def _gptq_marlin_gemm_fake(a: torch.Tensor, size_k: int, is_k_full: bool, has_zp: bool = False, - use_fp32_reduce: bool = False) -> torch.Tensor: + use_fp32_reduce: bool = False, + is_float_zp: bool = False) -> torch.Tensor: return torch.empty((size_m, size_n), device=a.device, dtype=a.dtype) @torch.library.register_fake("_C::ggml_dequantize") diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index db323294f4441..74d25e1b32b01 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -508,11 +508,8 @@ def weight_loader(self, # for the packing. packed_dim = getattr(param, "packed_dim", None) if packed_dim == output_dim: - pack_factor = getattr(param, "packed_factor", None) - if pack_factor is None: - pack_factor = param.pack_factor - shard_size = shard_size // pack_factor - shard_offset = shard_offset // pack_factor + shard_size = shard_size // param.pack_factor + shard_offset = shard_offset // param.pack_factor # Special case for Marlin. shard_size, shard_offset = adjust_marlin_shard( param, shard_size, shard_offset) @@ -862,11 +859,8 @@ def weight_loader(self, # for the packing. packed_dim = getattr(param, "packed_dim", None) if packed_dim == output_dim: - pack_factor = getattr(param, "packed_factor", None) - if pack_factor is None: - pack_factor = param.pack_factor - shard_size = shard_size // pack_factor - shard_offset = shard_offset // pack_factor + shard_size = shard_size // param.pack_factor + shard_offset = shard_offset // param.pack_factor # Special case for Marlin. 
shard_size, shard_offset = adjust_marlin_shard( diff --git a/vllm/model_executor/layers/quantization/hqq_marlin.py b/vllm/model_executor/layers/quantization/hqq_marlin.py index 1edc5c006c04d..6e5c09896527d 100644 --- a/vllm/model_executor/layers/quantization/hqq_marlin.py +++ b/vllm/model_executor/layers/quantization/hqq_marlin.py @@ -1,7 +1,6 @@ from typing import Any, Dict, List, Optional import torch -import torch.nn.functional as F from vllm import _custom_ops as ops from vllm.logger import init_logger @@ -15,7 +14,7 @@ MarlinWorkspace) from vllm.model_executor.layers.quantization.utils.quant_utils import gptq_pack from vllm.model_executor.parameter import (GroupQuantScaleParameter, - PackedvLLMParameter) + ModelWeightParameter) from vllm.scalar_type import scalar_types logger = init_logger(__name__) @@ -111,17 +110,14 @@ def create_weights( self.quant_config.group_size) # Quantized weights - qweight = PackedvLLMParameter( - data=torch.empty( - self.output_size_per_partition, - input_size_per_partition, - dtype=torch.uint8, - ), - input_dim=1, - output_dim=0, - packed_dim=0, - packed_factor=1, #self.quant_config.pack_factor, - weight_loader=weight_loader) + qweight = ModelWeightParameter(data=torch.empty( + self.output_size_per_partition, + input_size_per_partition, + dtype=torch.uint8, + ), + input_dim=1, + output_dim=0, + weight_loader=weight_loader) zeros = GroupQuantScaleParameter(data=torch.empty( self.output_size_per_partition, @@ -168,8 +164,6 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: self.input_size_per_partition, self.output_size_per_partition, 64).to(dev) - # print(layer.zeros) - # print(marlin_zp) layer.g_idx = marlin_make_empty_g_idx(dev) layer.g_idx_sort_indices = marlin_make_empty_g_idx(dev) @@ -184,28 +178,10 @@ def apply( x: torch.Tensor, bias: Optional[torch.Tensor] = None, ) -> torch.Tensor: - # unpacked = layer.qweight.reshape(-1, 64) - # scales = layer.scales.reshape(-1, 1) - # zeros = layer.zeros.reshape(-1, 1) - # b = (unpacked - zeros) * scales - # b = b.reshape(self.output_size_per_partition, - # self.input_size_per_partition) - workspace = MarlinWorkspace(self.output_size_per_partition, GPTQ_MARLIN_MIN_THREAD_N, GPTQ_MARLIN_MAX_PARALLEL) - # print(x) - # print(layer.marlin_qweight) - # print(layer.marlin_scales) - # print(layer.marlin_zeros) - # print(layer.g_idx) - # print(layer.g_idx_sort_indices) - # print(workspace.scratch) - - # print(x.shape, layer.marlin_qweight.shape, layer.marlin_scales.shape, - # layer.marlin_zeros.shape) - marlin_out = ops.gptq_marlin_gemm( x, layer.marlin_qweight, @@ -224,13 +200,6 @@ def apply( True, # use float zp ) - # deq_out = F.linear(x, b, bias) - - # print("gptq:", gptq_out) - # print("marlin:", marlin_out.shape, marlin_out) - # print("deq:", deq_out.shape, deq_out) - # print("***") - if bias is not None: marlin_out.add_(bias) diff --git a/vllm/model_executor/layers/quantization/utils/marlin_utils.py b/vllm/model_executor/layers/quantization/utils/marlin_utils.py index 9a1defa409714..cb81bd549c6b6 100644 --- a/vllm/model_executor/layers/quantization/utils/marlin_utils.py +++ b/vllm/model_executor/layers/quantization/utils/marlin_utils.py @@ -303,7 +303,8 @@ def apply_gptq_marlin_linear( size_k=input_size_per_partition, is_k_full=is_k_full, has_zp=False, - use_fp32_reduce=use_fp32_reduce) + use_fp32_reduce=use_fp32_reduce, + is_float_zp=False) if bias is not None: output.add_(bias) # In-place add @@ -340,7 +341,8 @@ def apply_awq_marlin_linear( size_k=input_size_per_partition, is_k_full=True, 
has_zp=True, - use_fp32_reduce=use_fp32_reduce) + use_fp32_reduce=use_fp32_reduce, + is_float_zp=False) if bias is not None: output.add_(bias) # In-place add diff --git a/vllm/model_executor/layers/quantization/utils/marlin_utils_test.py b/vllm/model_executor/layers/quantization/utils/marlin_utils_test.py index b5baf407e114a..4a06c5d63d52d 100644 --- a/vllm/model_executor/layers/quantization/utils/marlin_utils_test.py +++ b/vllm/model_executor/layers/quantization/utils/marlin_utils_test.py @@ -43,28 +43,22 @@ def marlin_permute_weights(q_w, size_k, size_n, perm, tile=GPTQ_MARLIN_TILE): def marlin_weights(q_w, size_k, size_n, num_bits, perm): - # print("before permute:", q_w) # Permute q_w = marlin_permute_weights(q_w, size_k, size_n, perm) - # print("after permute:", q_w) # Pack pack_factor = get_pack_factor(num_bits) orig_device = q_w.device q_w = q_w.cpu().numpy().astype(np.uint32) - # print("astype:", q_w) q_packed = np.zeros((q_w.shape[0], q_w.shape[1] // pack_factor), dtype=np.uint32) for i in range(pack_factor): q_packed |= q_w[:, i::pack_factor] << num_bits * i - # print("packed:", q_packed) q_packed = torch.from_numpy(q_packed.astype(np.int32)).to(orig_device) - # print("packed2:", q_packed) - return q_packed diff --git a/vllm/model_executor/layers/quantization/utils/quant_utils.py b/vllm/model_executor/layers/quantization/utils/quant_utils.py index e9ad04006fbcb..833d00073564e 100644 --- a/vllm/model_executor/layers/quantization/utils/quant_utils.py +++ b/vllm/model_executor/layers/quantization/utils/quant_utils.py @@ -132,8 +132,6 @@ def quantize_weights(w: torch.Tensor, assert quant_type.is_integer(), \ "Floating point quantization may work but has not been tested" - torch.set_printoptions(sci_mode="False") - orig_device = w.device orig_type = w.dtype size_k, size_n = w.shape @@ -144,18 +142,12 @@ def quantize_weights(w: torch.Tensor, group_size = size_k assert group_size <= size_k - # print("orig w:", w.shape, w) - # Reshape to [groupsize, -1] if group_size < size_k: w = w.reshape((-1, group_size, size_n)) - # print("reshape1 w:", w.shape, w) w = w.permute(1, 0, 2) - # print("permute w:", w.shape, w) w = w.reshape((group_size, -1)) - # print("reshape2 w:", w.shape, w) - # Compute scale for each group max_val = torch.max(w, 0, keepdim=True).values min_val = torch.min(w, 0, keepdim=True).values @@ -176,17 +168,10 @@ def quantize_weights(w: torch.Tensor, abs(min_val / (min_q_val if min_q_val != 0 else torch.inf))) maybe_w_zp = None - # print("w_s:", w_s.shape, w_s) - # Quantize w_q = torch.round(w / w_s).int() + (maybe_w_zp if zero_points else 0) w_q = torch.clamp(w_q, min_q_val, max_q_val) - # print("w_q:", w_q.shape, w_q) - - # TODO hqq has shape of w_s and w_q as in up to now in this code (also zp) - # but we need to transpose them first - # Compute ref (dequantized) # For some kernels (namely Machete) the zero-points are applied after the # scales are applied, for this case computing the reference in similar way @@ -213,9 +198,6 @@ def reshape_w(w): w_s = w_s.reshape((-1, size_n)).contiguous() - # print("final w_s:", w_s.shape, w_s) - # print("final w_q:", w_q.shape, w_q) - if zero_points: maybe_w_zp = maybe_w_zp.reshape((-1, size_n)).contiguous() maybe_w_zp = maybe_w_zp.to(device=orig_device) diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index adddd2d16859f..b7b9220371213 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -407,8 +407,8 @@ def unpack_4bit_u8( loaded = loaded_weight[k].reshape( 
-1, to_shape[1] // group_size) else: - # TODO we should unpack inside the quantization - # method / kernel + # TODO should we unpack inside the quantization + # method / kernel? loaded = unpack_4bit_u8( loaded_weight[k], dtype=torch.bfloat16).reshape(to_shape).to( From 1521370c815115d6845df750bb26035775168288 Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Tue, 22 Oct 2024 07:07:08 -0400 Subject: [PATCH 08/14] further cleanup --- benchmarks/kernels/benchmark_machete.py | 3 +- csrc/quantization/gptq_marlin/gptq_marlin.cu | 131 ++++-------------- csrc/torch_bindings.cpp | 2 +- tests/kernels/test_marlin_gemm.py | 4 +- vllm/_custom_ops.py | 6 +- .../layers/quantization/utils/marlin_utils.py | 4 +- 6 files changed, 39 insertions(+), 111 deletions(-) diff --git a/benchmarks/kernels/benchmark_machete.py b/benchmarks/kernels/benchmark_machete.py index b70c4b94c97a1..1388efce88956 100644 --- a/benchmarks/kernels/benchmark_machete.py +++ b/benchmarks/kernels/benchmark_machete.py @@ -156,7 +156,8 @@ def marlinv1_permute_scales(w_s: torch.tensor) -> torch.tensor: size_m=a.shape[0], size_n=w_ref.shape[1], size_k=w_ref.shape[0], - is_k_full=True)))) + is_k_full=True, + is_zp_float=False)))) # machete timers.append( diff --git a/csrc/quantization/gptq_marlin/gptq_marlin.cu b/csrc/quantization/gptq_marlin/gptq_marlin.cu index 881e32165caa8..2346fb91b8ca1 100644 --- a/csrc/quantization/gptq_marlin/gptq_marlin.cu +++ b/csrc/quantization/gptq_marlin/gptq_marlin.cu @@ -57,7 +57,7 @@ template __global__ void Marlin( const int4* __restrict__ A, // fp16 input matrix of shape mxk @@ -83,7 +83,7 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, torch::Tensor& workspace, vllm::ScalarTypeTorchPtr const& b_q_type, int64_t size_m, int64_t size_n, int64_t size_k, - bool is_k_full, bool has_zp, bool is_float_zp) { + bool is_k_full, bool has_zp, bool is_zp_float) { TORCH_CHECK_NOT_IMPLEMENTED(false, "marlin_gemm(..) requires CUDA_ARCH >= 8.0"); return torch::empty({1, 1}); @@ -541,7 +541,7 @@ template __global__ void Marlin( const int4* __restrict__ A, // fp16 input matrix of shape mxk @@ -715,8 +715,8 @@ __global__ void Marlin( int act_s_col_tb_stride = act_s_col_warp_stride * tb_n_warps; // Zero-points sizes/strides - int zp_gl_stride = is_float_zp ? prob_n / 8 : (prob_n / pack_factor) / 4; - constexpr int zp_sh_stride = is_float_zp + int zp_gl_stride = is_zp_float ? prob_n / 8 : (prob_n / pack_factor) / 4; + constexpr int zp_sh_stride = is_zp_float ? 
16 * thread_n_blocks / 8 : ((16 * thread_n_blocks) / pack_factor) / 4; constexpr int zp_tb_groups = s_tb_groups; @@ -793,13 +793,10 @@ __global__ void Marlin( constexpr int num_ints_per_thread = 8 / pack_factor; int zp_sh_rd; if constexpr (has_zp) { - if constexpr (is_float_zp) { + if constexpr (is_zp_float) { if constexpr (group_blocks != -1) zp_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + (threadIdx.x % 32) / 4; - else - zp_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + - (threadIdx.x % 32) % 4; } else { zp_sh_rd = num_ints_per_thread * num_col_threads * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + @@ -1161,7 +1158,7 @@ __global__ void Marlin( // has_zp implies AWQ, which doesn't have act_order, static_assert(!has_zp || group_blocks != 0); - if constexpr (has_zp && !is_float_zp) { + if constexpr (has_zp && !is_zp_float) { int pipe = full_pipe % stages; if constexpr (group_blocks == -1) { @@ -1206,7 +1203,7 @@ __global__ void Marlin( } } - if constexpr (has_zp && is_float_zp) { + if constexpr (has_zp && is_zp_float) { int pipe = full_pipe % stages; if constexpr (group_blocks != -1) { @@ -1238,7 +1235,7 @@ __global__ void Marlin( // Execute the actual tensor core matmul of a sub-tile. auto matmul = [&](int k) { - if constexpr (has_zp && !is_float_zp) { + if constexpr (has_zp && !is_zp_float) { FragB frag_zp_0; FragB frag_zp_1; int zp_quant_0, zp_quant_1; @@ -1283,11 +1280,11 @@ __global__ void Marlin( frag_b1 = dequant(b_quant_1); // Apply zero-point to frag_b0 - if constexpr (has_zp && !is_float_zp) { + if constexpr (has_zp && !is_zp_float) { sub_zp(frag_b0, frag_zp[j], 0); } - if constexpr (has_zp && is_float_zp && group_blocks != -1) { + if constexpr (has_zp && is_zp_float && group_blocks != -1) { sub_zpf(frag_b0, frag_zpf[k % 2][j], 0); } @@ -1303,11 +1300,11 @@ __global__ void Marlin( } // Apply zero-point to frag_b1 - if constexpr (has_zp && !is_float_zp) { + if constexpr (has_zp && !is_zp_float) { sub_zp(frag_b1, frag_zp[j], 1); } - if constexpr (has_zp && is_float_zp && group_blocks != -1) { + if constexpr (has_zp && is_zp_float && group_blocks != -1) { sub_zpf(frag_b1, frag_zpf[k % 2][j], 1); } @@ -1523,16 +1520,10 @@ __global__ void Marlin( // We first reorder in shared memory to guarantee the most efficient final // global write patterns - auto write = [&](int idx, float c0, float c1, FragS& s, FragZPF& zp) { + auto write = [&](int idx, float c0, float c1, FragS& s) { scalar_t2 res = Dtype::nums2num2(Dtype::float2num(c0), Dtype::float2num(c1)); - // apply float zp - if constexpr (has_zp && is_float_zp && group_blocks == -1 && - w_type.size_bits() == 4) { - res = __hsub2(res, zp[0]); - } - // For per-column quantization we finally apply the scale here (only for // 4-bit) if constexpr (!has_act_order && group_blocks == -1 && @@ -1550,17 +1541,13 @@ __global__ void Marlin( for (int j = 0; j < 4; j++) { int wr = c_sh_wr + 8 * j; write(wr + (4 * c_sh_stride) * 0 + 0, frag_c[i][j][0][0], - frag_c[i][j][0][1], frag_s[j / 2][2 * (j % 2) + 0], - frag_zpf[j / 2][2 * (j % 2) + 0]); + frag_c[i][j][0][1], frag_s[j / 2][2 * (j % 2) + 0]); write(wr + (4 * c_sh_stride) * 8 + 0, frag_c[i][j][0][2], - frag_c[i][j][0][3], frag_s[j / 2][2 * (j % 2) + 0], - frag_zpf[j / 2][2 * (j % 2) + 0]); + frag_c[i][j][0][3], frag_s[j / 2][2 * (j % 2) + 0]); write(wr + (4 * c_sh_stride) * 0 + 4, frag_c[i][j][1][0], - frag_c[i][j][1][1], frag_s[j / 2][2 * (j % 2) + 1], - frag_zpf[j / 2][2 * (j % 2) + 1]); + frag_c[i][j][1][1], frag_s[j / 2][2 * (j % 2) + 1]); write(wr + (4 * 
c_sh_stride) * 8 + 4, frag_c[i][j][1][2], - frag_c[i][j][1][3], frag_s[j / 2][2 * (j % 2) + 1], - frag_zpf[j / 2][2 * (j % 2) + 1]); + frag_c[i][j][1][3], frag_s[j / 2][2 * (j % 2) + 1]); } c_sh_wr += 16 * (4 * c_sh_stride); } @@ -1592,7 +1579,7 @@ __global__ void Marlin( fetch_scales_to_shared(true, g_idx[slice_k_start], g_idx[last_g_idx]); } - if constexpr (has_zp && !is_float_zp && group_blocks == -1) { + if constexpr (has_zp && !is_zp_float && group_blocks == -1) { if (i == 0) { fetch_zp_to_shared(); } @@ -1683,22 +1670,6 @@ __global__ void Marlin( } } - if constexpr (has_zp && is_float_zp && group_blocks == -1) { - if constexpr (w_type.size_bits() == 8) { - if (zp_sh_wr_pred) { - cp_async4(&sh_zp[zp_sh_wr], &zp_ptr[zp_gl_rd]); - } - cp_async_fence(); - } else { - if (last) { - if (zp_sh_wr_pred) { - cp_async4(&sh_zp[zp_sh_wr], &zp_ptr[zp_gl_rd]); - } - cp_async_fence(); - } - } - } - thread_block_reduce(); if constexpr (!has_act_order && group_blocks == -1) { if constexpr (w_type.size_bits() == 8) { @@ -1721,27 +1692,6 @@ __global__ void Marlin( } } - if constexpr (has_zp && is_float_zp && group_blocks == -1) { - if constexpr (w_type.size_bits() == 8) { - cp_async_wait<0>(); - __syncthreads(); - if (threadIdx.x / 32 < thread_n_blocks / 4) { - reinterpret_cast(&frag_zpf)[0] = sh_zp[zp_sh_rd + 0]; - reinterpret_cast(&frag_zpf)[1] = sh_zp[zp_sh_rd + 4]; - } - - } else { - if (last) { - cp_async_wait<0>(); - __syncthreads(); - if (threadIdx.x / 32 < thread_n_blocks / 4) { - reinterpret_cast(&frag_zpf)[0] = sh_zp[zp_sh_rd + 0]; - reinterpret_cast(&frag_zpf)[1] = sh_zp[zp_sh_rd + 4]; - } - } - } - } - // For 8-bit channelwise, we apply the scale before the global reduction // that converts the fp32 results to fp16 (so that we avoid possible // overflow in fp16) @@ -1770,30 +1720,6 @@ __global__ void Marlin( } } - if constexpr (has_zp && is_float_zp && group_blocks == -1 && - w_type.size_bits() == 8) { - if (threadIdx.x / 32 < thread_n_blocks / 4) { - #pragma unroll - for (int i = 0; i < thread_m_blocks; i++) { - #pragma unroll - for (int j = 0; j < 4; j++) { - sub_zpf_float( - reinterpret_cast(&frag_c[i][j][0][0]), - frag_zpf[j / 2][2 * (j % 2) + 0]); - sub_zpf_float( - reinterpret_cast(&frag_c[i][j][0][2]), - frag_zpf[j / 2][2 * (j % 2) + 0]); - sub_zpf_float( - reinterpret_cast(&frag_c[i][j][1][0]), - frag_zpf[j / 2][2 * (j % 2) + 1]); - sub_zpf_float( - reinterpret_cast(&frag_c[i][j][1][2]), - frag_zpf[j / 2][2 * (j % 2) + 1]); - } - } - } - } - if (slice_count > 1) { // only globally reduce if there is more than one // block in a slice barrier_acquire(&locks[slice_col], slice_idx); @@ -1841,21 +1767,21 @@ __global__ void Marlin( #define __CALL_IF(W_TYPE, THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \ HAS_ACT_ORDER, HAS_ZP, GROUP_BLOCKS, NUM_THREADS, \ - IS_FLOAT_ZP) \ + IS_ZP_FLOAT) \ else if (q_type == W_TYPE && thread_m_blocks == THREAD_M_BLOCKS && \ thread_n_blocks == THREAD_N_BLOCKS && \ thread_k_blocks == THREAD_K_BLOCKS && \ has_act_order == HAS_ACT_ORDER && has_zp == HAS_ZP && \ group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS && \ - is_float_zp == IS_FLOAT_ZP) { \ + is_zp_float == IS_ZP_FLOAT) { \ cudaFuncSetAttribute( \ Marlin, \ + HAS_ZP, GROUP_BLOCKS, IS_ZP_FLOAT>, \ cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \ Marlin \ + HAS_ZP, GROUP_BLOCKS, IS_ZP_FLOAT> \ <<>>( \ A_ptr, B_ptr, C_ptr, C_tmp_ptr, s_ptr, zp_ptr, g_idx_ptr, \ num_groups, prob_m, prob_n, prob_k, locks, use_fp32_reduce); \ @@ -2148,7 +2074,7 @@ void marlin_mm(const void* A, 
const void* B, void* C, void* C_tmp, void* s, vllm::ScalarType const& q_type, bool has_act_order, bool is_k_full, bool has_zp, int num_groups, int group_size, int dev, cudaStream_t stream, int thread_k, int thread_n, - int sms, int max_par, bool use_fp32_reduce, bool is_float_zp) { + int sms, int max_par, bool use_fp32_reduce, bool is_zp_float) { if (has_zp) { TORCH_CHECK( q_type == vllm::kU4 || q_type == vllm::kU8, @@ -2334,7 +2260,7 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, vllm::ScalarTypeTorchPtr const& b_q_type, int64_t size_m, int64_t size_n, int64_t size_k, bool is_k_full, bool has_zp, - bool use_fp32_reduce, bool is_float_zp) { + bool use_fp32_reduce, bool is_zp_float) { if (has_zp) { TORCH_CHECK(*b_q_type == vllm::kU4 || *b_q_type == vllm::kU8, "b_q_type must be u4 or u8 when has_zp = True. Got = ", @@ -2455,13 +2381,14 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, if (has_zp) { int rank = b_zeros.sizes().size(); TORCH_CHECK(rank == 2, "b_zeros rank = ", rank, " is not 2"); - if (is_float_zp) { + if (is_zp_float) { TORCH_CHECK(b_zeros.size(1) == size_n, "b_zeros dim 1 = ", b_zeros.size(1), " is not size_n = ", size_n); TORCH_CHECK(num_groups == b_zeros.size(0), "b_zeros dim 0 = ", b_zeros.size(0), " is not num_groups = ", num_groups); + TORCH_CHECK(num_groups != -1, "num_groups must be != -1"); } else { TORCH_CHECK(b_zeros.size(0) == num_groups, "b_zeros dim 0 = ", b_zeros.size(0), @@ -2489,7 +2416,7 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, a_tmp.data_ptr(), size_m, size_n, size_k, workspace.data_ptr(), *b_q_type, has_act_order, is_k_full, has_zp, num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev), - thread_k, thread_n, sms, marlin::max_par, use_fp32_reduce, is_float_zp); + thread_k, thread_n, sms, marlin::max_par, use_fp32_reduce, is_zp_float); } else if (a.scalar_type() == at::ScalarType::BFloat16) { marlin::marlin_mm( a.data_ptr(), b_q_weight.data_ptr(), @@ -2498,7 +2425,7 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, perm.data_ptr(), a_tmp.data_ptr(), size_m, size_n, size_k, workspace.data_ptr(), *b_q_type, has_act_order, is_k_full, has_zp, num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev), - thread_k, thread_n, sms, marlin::max_par, use_fp32_reduce, is_float_zp); + thread_k, thread_n, sms, marlin::max_par, use_fp32_reduce, is_zp_float); } else { TORCH_CHECK(false, "gpt_marlin_gemm only supports bfloat16 and float16"); } diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index a986a277df9ee..709244b14a9bb 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -203,7 +203,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { "Tensor b_zeros, Tensor g_idx, Tensor perm, Tensor workspace, " "__torch__.torch.classes._core_C.ScalarType b_q_type, " "int size_m, int size_n, int size_k, bool is_k_full, " - "bool has_zp, bool use_fp32_reduce, bool is_float_zp) -> Tensor"); + "bool has_zp, bool use_fp32_reduce, bool is_zp_float) -> Tensor"); // conditionally compiled so impl registration is in source file // gptq_marlin repack from GPTQ. 
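The is_zp_float flag threaded through the binding above selects the Marlin code path where zero points are float16 group values (the HQQ layout) instead of packed integers, and dequantization is w = (w_q - zp) * s per group before the matmul. Below is a minimal PyTorch sketch of that reference semantics, mirroring the output_ref check added to the unit tests later in this series; it is not part of the patch and the helper names are illustrative only.

import torch

def dequant_ref(w_q: torch.Tensor, scale: torch.Tensor, zero: torch.Tensor,
                group_size: int) -> torch.Tensor:
    # w_q:   (size_n, size_k)               unpacked 4-bit values as uint8
    # scale: (size_n, size_k // group_size) float16 group scales
    # zero:  (size_n, size_k // group_size) float16 group zero points
    w = w_q.to(scale.dtype).reshape(-1, group_size)
    w = (w - zero.reshape(-1, 1)) * scale.reshape(-1, 1)
    return w.reshape(w_q.shape)

def gemm_ref(a: torch.Tensor, w_q: torch.Tensor, scale: torch.Tensor,
             zero: torch.Tensor, group_size: int) -> torch.Tensor:
    # Same reference as the tests: dequantize, then a @ W^T.
    return a @ dequant_ref(w_q, scale, zero, group_size).transpose(1, 0)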
diff --git a/tests/kernels/test_marlin_gemm.py b/tests/kernels/test_marlin_gemm.py index 5faf3b6fdafa0..afcb24feffc8b 100644 --- a/tests/kernels/test_marlin_gemm.py +++ b/tests/kernels/test_marlin_gemm.py @@ -244,7 +244,7 @@ def test_gptq_marlin_gemm( is_k_full=is_k_full, has_zp=False, use_fp32_reduce=use_fp32_reduce, - is_float_zp=False, + is_zp_float=False, ) output_ref = torch.matmul(a_input, w_ref) @@ -432,7 +432,7 @@ def test_awq_marlin_gemm( is_k_full=is_k_full, has_zp=has_zp, use_fp32_reduce=use_fp32_reduce, - is_float_zp=False, + is_zp_float=False, ) output_ref = torch.matmul(a_input, w_ref) diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index eaef5a605320d..854e758bf5fbf 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -324,7 +324,7 @@ def _gptq_marlin_gemm_fake(a: torch.Tensor, is_k_full: bool, has_zp: bool = False, use_fp32_reduce: bool = False, - is_float_zp: bool = False) -> torch.Tensor: + is_zp_float: bool = False) -> torch.Tensor: return torch.empty((size_m, size_n), device=a.device, dtype=a.dtype) @torch.library.register_fake("_C::ggml_dequantize") @@ -597,11 +597,11 @@ def gptq_marlin_gemm(a: torch.Tensor, is_k_full: bool, has_zp: bool = False, use_fp32_reduce: bool = False, - is_float_zp: bool = False) -> torch.Tensor: + is_zp_float: bool = False) -> torch.Tensor: return torch.ops._C.gptq_marlin_gemm(a, b_q_weight, b_scales, b_zeros, g_idx, perm, workspace, b_q_type, size_m, size_n, size_k, is_k_full, - has_zp, use_fp32_reduce, is_float_zp) + has_zp, use_fp32_reduce, is_zp_float) # fp8 marlin diff --git a/vllm/model_executor/layers/quantization/utils/marlin_utils.py b/vllm/model_executor/layers/quantization/utils/marlin_utils.py index cb81bd549c6b6..c9366ca97d149 100644 --- a/vllm/model_executor/layers/quantization/utils/marlin_utils.py +++ b/vllm/model_executor/layers/quantization/utils/marlin_utils.py @@ -304,7 +304,7 @@ def apply_gptq_marlin_linear( is_k_full=is_k_full, has_zp=False, use_fp32_reduce=use_fp32_reduce, - is_float_zp=False) + is_zp_float=False) if bias is not None: output.add_(bias) # In-place add @@ -342,7 +342,7 @@ def apply_awq_marlin_linear( is_k_full=True, has_zp=True, use_fp32_reduce=use_fp32_reduce, - is_float_zp=False) + is_zp_float=False) if bias is not None: output.add_(bias) # In-place add From 9972c88e29fa1095dff750f87803c0d643b295f2 Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Thu, 24 Oct 2024 10:21:24 -0400 Subject: [PATCH 09/14] Adapt to model format prepared with transformers --- vllm/model_executor/layers/linear.py | 4 +- .../layers/quantization/__init__.py | 2 +- .../layers/quantization/hqq_marlin.py | 38 ++++---- .../model_loader/weight_utils.py | 10 ++- vllm/model_executor/models/llama.py | 89 +++++++++---------- vllm/model_executor/models/utils.py | 63 +++---------- 6 files changed, 82 insertions(+), 124 deletions(-) diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index 4821d256cdee8..94f30412e43b3 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -380,7 +380,7 @@ def forward(self, input_): def extra_repr(self) -> str: s = f"in_features={self.input_size}" s += f", output_features={self.output_size_per_partition}" - s += f", bias={hasattr(self, 'bias') and self.bias is not None}" + s += f", bias={self.bias is not None}" s += f", tp_size={get_tensor_model_parallel_world_size()}" s += f", gather_output={self.gather_output}" return s @@ -1092,7 +1092,7 @@ def forward(self, input_): def extra_repr(self) -> str: s = 
f"input_features={self.input_size_per_partition}" s += f", output_features={self.output_size}" - s += f", bias={hasattr(self, 'bias') and self.bias is not None}" + s += f", bias={self.bias is not None}" s += f", tp_size={self.tp_size}" s += f", reduce_results={self.reduce_results}" return s diff --git a/vllm/model_executor/layers/quantization/__init__.py b/vllm/model_executor/layers/quantization/__init__.py index fd9492e8059f9..ff342c4f9479e 100644 --- a/vllm/model_executor/layers/quantization/__init__.py +++ b/vllm/model_executor/layers/quantization/__init__.py @@ -49,7 +49,7 @@ "compressed-tensors": CompressedTensorsConfig, "bitsandbytes": BitsAndBytesConfig, "qqq": QQQConfig, - "hqq_marlin": HQQMarlinConfig, + "hqq": HQQMarlinConfig, "experts_int8": ExpertsInt8Config, "neuron_quant": NeuronQuantConfig, "ipex": IPEXConfig, diff --git a/vllm/model_executor/layers/quantization/hqq_marlin.py b/vllm/model_executor/layers/quantization/hqq_marlin.py index 6e5c09896527d..8ab99aedb123e 100644 --- a/vllm/model_executor/layers/quantization/hqq_marlin.py +++ b/vllm/model_executor/layers/quantization/hqq_marlin.py @@ -44,7 +44,7 @@ def __repr__(self) -> str: @classmethod def get_name(cls) -> str: - return "hqq_marlin" + return "hqq" @classmethod def get_supported_act_dtypes(cls) -> List[torch.dtype]: @@ -60,7 +60,7 @@ def get_config_filenames(cls) -> List[str]: @classmethod def from_config(cls, config: Dict[str, Any]) -> "HQQMarlinConfig": - weight_bits = cls.get_from_keys(config, ["bits"]) + weight_bits = cls.get_from_keys(config, ["nbits"]) group_size = cls.get_from_keys(config, ["group_size"]) return cls(weight_bits, group_size) @@ -106,8 +106,8 @@ def create_weights( weight_loader = extra_weight_attrs.get("weight_loader") - scales_and_zp_size = (input_size_per_partition // - self.quant_config.group_size) + self.scales_and_zp_size = (input_size_per_partition // + self.quant_config.group_size) # Quantized weights qweight = ModelWeightParameter(data=torch.empty( @@ -121,7 +121,7 @@ def create_weights( zeros = GroupQuantScaleParameter(data=torch.empty( self.output_size_per_partition, - scales_and_zp_size, + self.scales_and_zp_size, dtype=params_dtype, ), input_dim=1, @@ -130,20 +130,20 @@ def create_weights( scales = GroupQuantScaleParameter(data=torch.empty( self.output_size_per_partition, - scales_and_zp_size, + self.scales_and_zp_size, dtype=params_dtype, ), input_dim=1, output_dim=0, weight_loader=weight_loader) - layer.register_parameter("qweight", qweight) - layer.register_parameter("zeros", zeros) - layer.register_parameter("scales", scales) + layer.register_parameter("W_q", qweight) + layer.register_parameter("zero", zeros) + layer.register_parameter("scale", scales) def process_weights_after_loading(self, layer: torch.nn.Module) -> None: - dev = layer.qweight.device - qweight_t = layer.qweight.transpose(1, 0) + dev = layer.W_q.device + qweight_t = layer.W_q.transpose(1, 0) gptq_w_q = gptq_pack(qweight_t, 4, self.input_size_per_partition, self.output_size_per_partition) @@ -156,14 +156,14 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: self.output_size_per_partition, 4, ).to(dev) - marlin_s = marlin_permute_scales(layer.scales.transpose(1, 0), - self.input_size_per_partition, - self.output_size_per_partition, - 64).to(dev) - marlin_zp = marlin_permute_scales(layer.zeros.transpose(1, 0), - self.input_size_per_partition, - self.output_size_per_partition, - 64).to(dev) + marlin_s = marlin_permute_scales( + layer.scale.reshape(-1, self.scales_and_zp_size).transpose(1, 
0), + self.input_size_per_partition, self.output_size_per_partition, + self.quant_config.group_size).to(dev) + marlin_zp = marlin_permute_scales( + layer.zero.reshape(-1, self.scales_and_zp_size).transpose(1, 0), + self.input_size_per_partition, self.output_size_per_partition, + self.quant_config.group_size).to(dev) layer.g_idx = marlin_make_empty_g_idx(dev) layer.g_idx_sort_indices = marlin_make_empty_g_idx(dev) diff --git a/vllm/model_executor/model_loader/weight_utils.py b/vllm/model_executor/model_loader/weight_utils.py index ebd8ca976f4d3..7bc530a37df09 100644 --- a/vllm/model_executor/model_loader/weight_utils.py +++ b/vllm/model_executor/model_loader/weight_utils.py @@ -128,9 +128,13 @@ def get_quant_config(model_config: ModelConfig, if model_config.quantization == "gguf": return quant_cls.from_config({}) - if model_config.quantization == "hqq_marlin": - # TODO don't hardcode params - return quant_cls.from_config({"bits": 4, "group_size": 64}) + if model_config.quantization == "hqq": + wq_params = (model_config.hf_config.quantization_config["quant_config"] + ["weight_quant_params"]) + return quant_cls.from_config({ + "nbits": wq_params["nbits"], + "group_size": wq_params["group_size"] + }) # Read the quantization config from the HF model config, if available. hf_quant_config = getattr(model_config.hf_config, "quantization_config", diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index 77dd1f4bb0618..431be5802f8ba 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -371,12 +371,6 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): (".gate_up_proj", ".up_proj", 1), ] - hqq_map = [ - (".qweight", "W_q", False), - (".zeros", "zero", True), - (".scales", "scale", True), - ] - # unpack function from https://github.com/mobiusml/hqq def unpack_4bit_u8( W_q: torch.Tensor, @@ -389,48 +383,14 @@ def unpack_4bit_u8( tmp[step:] = W_q & 0b00001111 return tmp + def rescale_hqq_wq(loaded_weight: torch.Tensor, param) -> torch.Tensor: + # TODO don't hardcode type + return unpack_4bit_u8(loaded_weight, dtype=torch.bfloat16).reshape( + (-1, param.shape[1])).to(torch.uint8) + params_dict = dict(self.named_parameters()) - for name, loaded_weight in weights: - if self.is_hqq: - pick_shard_id = None - for param_name, weight_name, shard_id in stacked_params_mapping: - if weight_name not in name: - continue - name = name.replace(weight_name, param_name) - pick_shard_id = shard_id - break - if name.endswith("_proj"): - to_shape = loaded_weight["shape"] - group_size = loaded_weight["group_size"] - for c, k, should_scale in hqq_map: - new_name = name + c - if new_name not in params_dict: - continue - param = params_dict[new_name] - weight_loader = param.weight_loader - if should_scale: - loaded = loaded_weight[k].reshape( - -1, to_shape[1] // group_size) - else: - # TODO should we unpack inside the quantization - # method / kernel? 
- loaded = unpack_4bit_u8( - loaded_weight[k], - dtype=torch.bfloat16).reshape(to_shape).to( - torch.uint8) - - if pick_shard_id is not None: - weight_loader(param, loaded, pick_shard_id) - else: - weight_loader(param, loaded) - else: - name = name + ".weight" - param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - weight_loader(param, loaded_weight["weight"]) - continue + for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name: continue @@ -458,9 +418,27 @@ def unpack_4bit_u8( if is_pp_missing_parameter(name, self): continue + # TODO should input/output dim in hqq_marlin.py depend on this? + ignore_hqq = (".axis", ".channel_wise", ".compute_dtype", + ".encoded_state_dict", ".group_size", ".nbits", + ".offload_meta", ".optimize", ".packing", + ".quant_scale", ".quant_zero", ".round_zero", + ".shape", ".stores_quant_config", + ".unpack_view_dtype", ".view_as_float") + if name.endswith(ignore_hqq) and name not in params_dict: + continue + param = params_dict[name] weight_loader = param.weight_loader - weight_loader(param, loaded_weight, shard_id) + if self.is_hqq and name.endswith(".W_q"): + weight_loader(param, rescale_hqq_wq(loaded_weight, param), + shard_id) + elif self.is_hqq and name.endswith((".scale", ".zero")): + weight_loader(param, + loaded_weight.reshape(-1, param.shape[1]), + shard_id) + else: + weight_loader(param, loaded_weight, shard_id) break else: @@ -475,13 +453,26 @@ def unpack_4bit_u8( if is_pp_missing_parameter(name, self): continue - if name not in params_dict: + # TODO should input/output dim in hqq_marlin.py depend on this? + ignore_hqq = (".axis", ".channel_wise", ".compute_dtype", + ".encoded_state_dict", ".group_size", ".nbits", + ".offload_meta", ".optimize", ".packing", + ".quant_scale", ".quant_zero", ".round_zero", + ".shape", ".stores_quant_config", + ".unpack_view_dtype", ".view_as_float") + if name.endswith(ignore_hqq) and name not in params_dict: continue param = params_dict[name] weight_loader = getattr(param, "weight_loader", default_weight_loader) - weight_loader(param, loaded_weight) + if self.is_hqq and name.endswith(".W_q"): + weight_loader(param, rescale_hqq_wq(loaded_weight, param)) + elif self.is_hqq and name.endswith((".scale", ".zero")): + weight_loader(param, + loaded_weight.reshape(-1, param.shape[1])) + else: + weight_loader(param, loaded_weight) # If this function is called, it should always initialize KV cache scale # factors (or else raise an exception). 
Thus, handled exceptions should diff --git a/vllm/model_executor/models/utils.py b/vllm/model_executor/models/utils.py index 6bc4cb61cb63d..ec1d76d2117f3 100644 --- a/vllm/model_executor/models/utils.py +++ b/vllm/model_executor/models/utils.py @@ -103,8 +103,6 @@ def _groupby_prefix( for prefix, group in itertools.groupby(weights_by_parts, key=lambda x: x[0][0]): - # for parts, weights_data in group: - # print("part: ", parts, weights_data) yield ( prefix, # Because maxsplit=1 in weight_name.split(...), @@ -135,52 +133,24 @@ def _load_param( weights: Iterable[Tuple[str, torch.Tensor]], ) -> Iterable[str]: for weight_name, weight_data in weights: + weight_qualname = self._get_qualname(base_prefix, weight_name) - if torch.is_tensor(weight_data): - weight_qualname = self._get_qualname(base_prefix, weight_name) - - if self._can_skip(weight_qualname): - continue - - if weight_name != "": - if not self._can_ignore_unexpected(weight_qualname): - raise ValueError( - f"Attempted to load nested weight " - f"'{weight_qualname}' " - f"into a single parameter '{base_prefix}'") - continue - - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - weight_loader(param, weight_data) - yield weight_qualname - else: - # TODO remove this when we get a new hqq dataset format - for wn, wd in weight_data.items(): - - weight_qualname = self._get_qualname(base_prefix, wn) - - if self._can_skip(weight_qualname): - continue + if self._can_skip(weight_qualname): + continue - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - weight_loader(param, wd) + if weight_name != "": + if not self._can_ignore_unexpected(weight_qualname): + raise ValueError( + f"Attempted to load nested weight '{weight_qualname}' " + f"into a single parameter '{base_prefix}'") - yield weight_qualname + continue - def _load_one_param( - self, - base_prefix: str, - param: nn.Parameter, - weight_name: str, - weight_data: torch.Tensor, - ) -> Iterable[str]: - weight_qualname = self._get_qualname(base_prefix, weight_name) - weight_loader = getattr(param, "weight_loader", default_weight_loader) - weight_loader(param, weight_data) + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, weight_data) - yield weight_qualname + yield weight_qualname def _load_module( self, @@ -208,13 +178,6 @@ def _load_module( if self._can_skip(prefix): continue - # TODO remove this when we get a new hqq dataset format - if child_prefix == "" and isinstance(child_params, dict): - for _, c_weight in child_params.items(): - yield from self._load_param(prefix, c_weight, - child_weights) - continue - if child_prefix in child_modules: yield from self._load_module(prefix, child_modules[child_prefix], From 7717b552410f7d2c7466b7f2dea2b94abfb27d10 Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Thu, 24 Oct 2024 10:24:19 -0400 Subject: [PATCH 10/14] small cleanup --- vllm/model_executor/models/llama.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index 431be5802f8ba..d4be3818544d5 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -389,9 +389,7 @@ def rescale_hqq_wq(loaded_weight: torch.Tensor, param) -> torch.Tensor: (-1, param.shape[1])).to(torch.uint8) params_dict = dict(self.named_parameters()) - for name, loaded_weight in weights: - if "rotary_emb.inv_freq" in name: continue if ("rotary_emb.cos_cached" in name From b9106bd8644387e27bb2cb645c01cf1cb1e59f39 Mon Sep 17 
00:00:00 2001 From: ElizaWszola Date: Thu, 24 Oct 2024 11:24:51 -0400 Subject: [PATCH 11/14] reshape cleanup --- .../layers/quantization/hqq_marlin.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/vllm/model_executor/layers/quantization/hqq_marlin.py b/vllm/model_executor/layers/quantization/hqq_marlin.py index 8ab99aedb123e..ecc88a6e1ea87 100644 --- a/vllm/model_executor/layers/quantization/hqq_marlin.py +++ b/vllm/model_executor/layers/quantization/hqq_marlin.py @@ -156,14 +156,14 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: self.output_size_per_partition, 4, ).to(dev) - marlin_s = marlin_permute_scales( - layer.scale.reshape(-1, self.scales_and_zp_size).transpose(1, 0), - self.input_size_per_partition, self.output_size_per_partition, - self.quant_config.group_size).to(dev) - marlin_zp = marlin_permute_scales( - layer.zero.reshape(-1, self.scales_and_zp_size).transpose(1, 0), - self.input_size_per_partition, self.output_size_per_partition, - self.quant_config.group_size).to(dev) + marlin_s = marlin_permute_scales(layer.scale.transpose(1, 0), + self.input_size_per_partition, + self.output_size_per_partition, + self.quant_config.group_size).to(dev) + marlin_zp = marlin_permute_scales(layer.zero.transpose(1, 0), + self.input_size_per_partition, + self.output_size_per_partition, + self.quant_config.group_size).to(dev) layer.g_idx = marlin_make_empty_g_idx(dev) layer.g_idx_sort_indices = marlin_make_empty_g_idx(dev) From 5ef4b805909e582c62d12ec7e31890a84b46ab1e Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Fri, 25 Oct 2024 02:55:36 -0400 Subject: [PATCH 12/14] hqq unit tests --- tests/kernels/test_marlin_gemm.py | 82 +++++++++++++++++++++++++++++++ 1 file changed, 82 insertions(+) diff --git a/tests/kernels/test_marlin_gemm.py b/tests/kernels/test_marlin_gemm.py index 2ee5d8c7e4e94..7286438dfd76f 100644 --- a/tests/kernels/test_marlin_gemm.py +++ b/tests/kernels/test_marlin_gemm.py @@ -29,6 +29,7 @@ marlin_qqq_quantize) from vllm.model_executor.layers.quantization.utils.quant_utils import ( awq_pack, gptq_pack, gptq_quantize_weights, quantize_weights, sort_weights) +from vllm.scalar_type import scalar_types ACT_ORDER_OPTS = [False, True] K_FULL_OPTS = [False, True] @@ -453,6 +454,87 @@ def test_awq_marlin_gemm( assert max_diff < 0.04 +@pytest.mark.skipif(not is_quant_method_supported("gptq_marlin"), + reason="Marlin is not supported on this GPU type.") +@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS) +@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS) +@pytest.mark.parametrize("group_size", [64]) +@pytest.mark.parametrize("mnk_factors", MNK_FACTORS) +@pytest.mark.parametrize("use_fp32_reduce", USE_FP32_REDUCE_OPTS) +def test_hqq_marlin_gemm( + k_chunk, + n_chunk, + group_size, + mnk_factors, + use_fp32_reduce, +): + m_factor, n_factor, k_factor = mnk_factors + + size_m = m_factor + size_k = k_chunk * k_factor + size_n = n_chunk * n_factor + + quant_type = scalar_types.uint4 + + a_input = rand_data((size_m, size_k)) + dev = a_input.device + + b_weight = torch.randint(0, + 10, (size_n, size_k), + dtype=torch.uint8, + device=dev) + scale = rand_data((size_n, size_k // group_size)) + zero = rand_data((size_n, size_k // group_size)) + + gptq_w_q = gptq_pack(b_weight.transpose(1, 0), 4, size_k, size_n) + + sort_indices = torch.empty(0, dtype=torch.int, device=dev) + marlin_w_q = ops.gptq_marlin_repack(gptq_w_q, sort_indices, size_k, size_n, + 4).to(dev) + marlin_s = marlin_permute_scales(scale.transpose(1, 0), size_k, 
size_n, + group_size).to(dev) + marlin_zp = marlin_permute_scales(zero.transpose(1, 0), size_k, size_n, + group_size).to(dev) + + g_idx = marlin_make_empty_g_idx(dev) + g_idx_sort_indices = marlin_make_empty_g_idx(dev) + + workspace = MarlinWorkspace(size_n, GPTQ_MARLIN_MIN_THREAD_N, + GPTQ_MARLIN_MAX_PARALLEL) + + output = ops.gptq_marlin_gemm( + a_input, + marlin_w_q, + marlin_s, + marlin_zp, + g_idx, + g_idx_sort_indices, + workspace.scratch, + quant_type, + a_input.shape[0], + b_weight.shape[0], + a_input.shape[1], + is_k_full=True, + has_zp=True, + use_fp32_reduce=use_fp32_reduce, + is_zp_float=True, + ) + + b_flat = b_weight.reshape(-1, group_size) + zp_flat = zero.reshape(-1, 1) + s_flat = scale.reshape(-1, 1) + dequant = (b_flat - zp_flat) * s_flat + + output_ref = torch.matmul(a_input, + dequant.reshape(b_weight.shape).transpose(1, 0)) + + torch.cuda.synchronize() + + max_diff = compute_max_diff(output, output_ref) + + assert max_diff < 0.04 + + @pytest.mark.skipif(not is_quant_method_supported("qqq"), reason="Marlin is not supported on this GPU type.") @pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS) From 5340ce84afa13bae53a9843fa6cfe44c7f516dd1 Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Fri, 25 Oct 2024 08:27:10 -0400 Subject: [PATCH 13/14] remove hardcoded type --- vllm/model_executor/models/llama.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index d4be3818544d5..3b2fb6d8ad131 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -383,11 +383,6 @@ def unpack_4bit_u8( tmp[step:] = W_q & 0b00001111 return tmp - def rescale_hqq_wq(loaded_weight: torch.Tensor, param) -> torch.Tensor: - # TODO don't hardcode type - return unpack_4bit_u8(loaded_weight, dtype=torch.bfloat16).reshape( - (-1, param.shape[1])).to(torch.uint8) - params_dict = dict(self.named_parameters()) for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name: @@ -429,8 +424,10 @@ def rescale_hqq_wq(loaded_weight: torch.Tensor, param) -> torch.Tensor: param = params_dict[name] weight_loader = param.weight_loader if self.is_hqq and name.endswith(".W_q"): - weight_loader(param, rescale_hqq_wq(loaded_weight, param), - shard_id) + weight_loader( + param, + unpack_4bit_u8(loaded_weight).reshape( + -1, param.shape[1]), shard_id) elif self.is_hqq and name.endswith((".scale", ".zero")): weight_loader(param, loaded_weight.reshape(-1, param.shape[1]), @@ -465,7 +462,10 @@ def rescale_hqq_wq(loaded_weight: torch.Tensor, param) -> torch.Tensor: weight_loader = getattr(param, "weight_loader", default_weight_loader) if self.is_hqq and name.endswith(".W_q"): - weight_loader(param, rescale_hqq_wq(loaded_weight, param)) + weight_loader( + param, + unpack_4bit_u8(loaded_weight).reshape( + -1, param.shape[1])) elif self.is_hqq and name.endswith((".scale", ".zero")): weight_loader(param, loaded_weight.reshape(-1, param.shape[1])) From 1da2d973322dd3bd7bd2a4683d59a94f8c04cdab Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Fri, 25 Oct 2024 09:55:05 -0400 Subject: [PATCH 14/14] force fp16 type in kernel to reduce wheel size --- csrc/quantization/gptq_marlin/gptq_marlin.cu | 30 ++++++++++++------- .../layers/quantization/hqq_marlin.py | 18 +++++++++-- 2 files changed, 34 insertions(+), 14 deletions(-) diff --git a/csrc/quantization/gptq_marlin/gptq_marlin.cu b/csrc/quantization/gptq_marlin/gptq_marlin.cu index 2ff437cfc94b9..7bd246d73b1e4 100644 --- 
a/csrc/quantization/gptq_marlin/gptq_marlin.cu +++ b/csrc/quantization/gptq_marlin/gptq_marlin.cu @@ -1774,17 +1774,19 @@ __global__ void Marlin( has_act_order == HAS_ACT_ORDER && has_zp == HAS_ZP && \ group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS && \ is_zp_float == IS_ZP_FLOAT) { \ - cudaFuncSetAttribute( \ - Marlin, \ - cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \ - Marlin \ - <<>>( \ - A_ptr, B_ptr, C_ptr, C_tmp_ptr, s_ptr, zp_ptr, g_idx_ptr, \ - num_groups, prob_m, prob_n, prob_k, locks, use_fp32_reduce); \ + if constexpr (!IS_ZP_FLOAT || std::is_same::value) { \ + cudaFuncSetAttribute( \ + Marlin, \ + cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \ + Marlin \ + <<>>( \ + A_ptr, B_ptr, C_ptr, C_tmp_ptr, s_ptr, zp_ptr, g_idx_ptr, \ + num_groups, prob_m, prob_n, prob_k, locks, use_fp32_reduce); \ + } \ } typedef struct { @@ -2273,6 +2275,12 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, b_q_type.str()); } + if (has_zp && is_zp_float) { + TORCH_CHECK(a.scalar_type() == at::ScalarType::Half, + "Computation type must be float16 (half) when using float zero " + "points."); + } + int pack_factor = 32 / b_q_type.size_bits(); // Verify A diff --git a/vllm/model_executor/layers/quantization/hqq_marlin.py b/vllm/model_executor/layers/quantization/hqq_marlin.py index ecc88a6e1ea87..35c4cb00fb298 100644 --- a/vllm/model_executor/layers/quantization/hqq_marlin.py +++ b/vllm/model_executor/layers/quantization/hqq_marlin.py @@ -182,11 +182,20 @@ def apply( GPTQ_MARLIN_MIN_THREAD_N, GPTQ_MARLIN_MAX_PARALLEL) + scales = layer.marlin_scales + zeros = layer.marlin_zeros + orig_type = x.dtype + + if orig_type != torch.float16: + x = x.to(torch.float16) + scales = scales.to(torch.float16) + zeros = zeros.to(torch.float16) + marlin_out = ops.gptq_marlin_gemm( x, layer.marlin_qweight, - layer.marlin_scales, - layer.marlin_zeros, + scales, + zeros, layer.g_idx, layer.g_idx_sort_indices, workspace.scratch, @@ -203,4 +212,7 @@ def apply( if bias is not None: marlin_out.add_(bias) - return marlin_out + if orig_type != torch.float16: + return marlin_out.to(orig_type) + else: + return marlin_out
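The apply() path above assumes the checkpoint's W_q has already been unpacked from HQQ's two-nibbles-per-byte layout by unpack_4bit_u8 in llama.py before the weight loader runs; HQQMarlinLinearMethod then repacks it to GPTQ and finally Marlin layout. The round-trip sketch below illustrates that packing convention, assuming the 4-bit, axis-0 packing used in this series; the pack_4bit_u8 helper is illustrative and not part of the patch.

import torch

def pack_4bit_u8(w_q: torch.Tensor) -> torch.Tensor:
    # Illustrative inverse of unpack_4bit_u8: the first half of the rows goes
    # into the high nibble, the second half into the low nibble.
    step = w_q.shape[0] // 2
    return (w_q[:step] << 4) | (w_q[step:] & 0b00001111)

def unpack_4bit_u8(w_packed: torch.Tensor) -> torch.Tensor:
    # Mirrors the helper in llama.py: split each byte back into two 4-bit rows.
    step = w_packed.shape[0]
    tmp = torch.empty((2 * step, w_packed.shape[1]),
                      dtype=torch.uint8, device=w_packed.device)
    tmp[:step] = (w_packed & 0b11110000) >> 4
    tmp[step:] = w_packed & 0b00001111
    return tmp

w_q = torch.randint(0, 16, (128, 64), dtype=torch.uint8)
assert torch.equal(unpack_4bit_u8(pack_4bit_u8(w_q)), w_q)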