From 4e65cdeb5c377222aacb18fb0ed40c1f67f6ab59 Mon Sep 17 00:00:00 2001
From: Rahul Tuli
Date: Tue, 17 Dec 2024 05:38:24 +0000
Subject: [PATCH] Add: Support for Sparse24Bitmask Compressed Models

---
 .../compressed_tensors/compressed_tensors.py  |  39 +++++--
 .../schemes/compressed_tensors_24.py          | 106 ++++++++++++++++--
 vllm/model_executor/parameter.py              |  26 ++++-
 3 files changed, 152 insertions(+), 19 deletions(-)

diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
index b2fc2360f47f1..96bfde9655961 100644
--- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
@@ -396,10 +396,13 @@ def get_scheme(
                                     sparsity_scheme=sparsity_scheme):
             # Have a valid sparsity scheme
             # Validate layer is supported by Cutlass 2:4 Kernel
-            scheme = CompressedTensors24(quantized=weight_quant is not None
-                                         or input_quant is not None,
-                                         weight_quant=weight_quant,
-                                         input_quant=input_quant)
+            scheme = CompressedTensors24(
+                quantized=weight_quant is not None or input_quant is not None,
+                weight_quant=weight_quant,
+                input_quant=input_quant,
+                model_compression_config=self._get_model_compression_config(
+                    sparsity_scheme),
+            )
         else:
             # Find the quant_scheme
             scheme = self._get_scheme_from_parts(  # type: ignore
@@ -447,10 +450,17 @@ def supports_cutlass_24(
         :return: True if the layer is supported by the Cutlass 2:4 Kernel
             False otherwise
         """
-        is_valid_sparsity = (sparsity_scheme is not None
-                             and sparsity_scheme.sparsity_structure
-                             == SparsityStructure.TWO_FOUR.value
-                             and sparsity_scheme.format == "dense")
+        is_valid_sparsity_structure = (sparsity_scheme is not None
+                                       and sparsity_scheme.sparsity_structure
+                                       == SparsityStructure.TWO_FOUR.value)
+        valid_compressors = {
+            CompressionFormat.dense.value,
+            CompressionFormat.sparse_24_bitmask.value
+        }
+
+        is_valid_sparsity = (is_valid_sparsity_structure
+                             and sparsity_scheme.format in valid_compressors)
+
         if not is_valid_sparsity:
             return False
 
@@ -481,6 +491,19 @@ def supports_cutlass_24(
 
         return weight_quant.num_bits == input_quant.num_bits == 8
 
+    def _get_model_compression_config(
+            self, sparsity_scheme: Optional[SparsityCompressionConfig] = None):
+        """
+        Get the model compressor config from the sparsity scheme
+
+        :param sparsity_scheme: The sparsity scheme
+        :return: The model compressor config
+        """
+        if sparsity_scheme is None or sparsity_scheme.format == "dense":
+            return None
+
+        return self.config
+
 
 class CompressedTensorsLinearMethod(LinearMethodBase):
 
diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_24.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_24.py
index bc697ef93b34b..e9141d2a8f7f4 100644
--- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_24.py
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_24.py
@@ -1,16 +1,21 @@
-from typing import Callable, List, Optional
+from typing import Any, Callable, Dict, List, Optional
 
 import torch
+from compressed_tensors import CompressionFormat, ModelCompressor
 from compressed_tensors.quantization import (QuantizationArgs,
                                              QuantizationStrategy,
                                              QuantizationType)
+from compressed_tensors.utils import combine_shards
 
 from vllm import _custom_ops as ops
+from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
+                                               QKVParallelLinear)
 from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
     CompressedTensorsScheme)
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
     convert_to_channelwise, sparse_cutlass_supported)
 from vllm.model_executor.parameter import (BasevLLMParameter,
+                                           BitMaskShapeParameter,
                                            ChannelQuantScaleParameter,
                                            ModelWeightParameter,
                                            PerTensorScaleParameter)
@@ -20,14 +25,24 @@
 
 class CompressedTensors24(CompressedTensorsScheme):
 
-    def __init__(self,
-                 quantized: bool = False,
-                 weight_quant: Optional[QuantizationArgs] = None,
-                 input_quant: Optional[QuantizationArgs] = None):
+    def __init__(
+        self,
+        quantized: bool = False,
+        weight_quant: Optional[QuantizationArgs] = None,
+        input_quant: Optional[QuantizationArgs] = None,
+        model_compression_config: Optional[Dict[str, Any]] = None,
+    ):
 
         self.quantized = quantized
         self.weight_quant = weight_quant
         self.input_quant = input_quant
+        self.model_compressor = (
+            ModelCompressor.from_compression_config(model_compression_config)
+            if model_compression_config is not None else None)
+        self.do_sparse_decompress = (
+            self.model_compressor is not None
+            and self.model_compressor.sparsity_config.format
+            == CompressionFormat.sparse_24_bitmask.value)
 
     @classmethod
     def get_min_capability(cls) -> int:
@@ -47,6 +62,8 @@ def create_weights(self, layer: torch.nn.Module, input_size: int,
 
         self.output_dtype = params_dtype
         layer.logical_widths = output_partition_sizes
+        layer.input_size = input_size
+        layer.input_size_per_partition = input_size_per_partition
         self.weights_dtype: torch.dtype = self._get_params_dtype(params_dtype)
 
         # parameter to store uncompressed weight
@@ -57,6 +74,34 @@ def create_weights(self, layer: torch.nn.Module, input_size: int,
                                       input_dim=1,
                                       output_dim=0,
                                       weight_loader=weight_loader)
+        if self.do_sparse_decompress:
+            assert all(
+                partition_size % 8 == 0
+                for partition_size in output_partition_sizes
+            ), "All partitions must be divisible by 8 for 2:4 compressed models"
+
+            shape = BitMaskShapeParameter(data=torch.empty(
+                2 * len(output_partition_sizes), 1, dtype=torch.uint64),
+                                          weight_loader=weight_loader)
+            compressed = ModelWeightParameter(data=torch.empty(
+                sum(output_partition_sizes),
+                input_size_per_partition // 2,
+                dtype=self.weights_dtype),
+                                              input_dim=1,
+                                              output_dim=0,
+                                              weight_loader=weight_loader)
+
+            bitmask = ModelWeightParameter(data=torch.empty(
+                sum(output_partition_sizes),
+                input_size_per_partition // 8,
+                dtype=torch.uint8),
+                                           input_dim=1,
+                                           output_dim=0,
+                                           weight_loader=weight_loader)
+
+            layer.register_parameter("shape", shape)
+            layer.register_parameter("compressed", compressed)
+            layer.register_parameter("bitmask", bitmask)
 
         # Check if quantized, not just 2:4 Sparse
         if self.quantized:
@@ -112,6 +157,13 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
 
         :param layer: The layer with the weights to be processed
         """
+        if self.do_sparse_decompress:
+            layer.weight.data = self._decompress_bitmask_compressed_weight(
+                compressed=layer.compressed,
+                bitmask=layer.bitmask,
+                layer=layer,
+            )
+
         # torch.compile workaround
         if hasattr(layer, "input_scale"):
             layer.input_scale = torch.nn.Parameter(layer.input_scale.data,
@@ -201,8 +253,42 @@ def _get_params_dtype(self, params_dtype: torch.dtype) -> torch.dtype:
 
         raise ValueError("Quantization type not supported by Cutlass")
 
-
-def check_24(tensor):
-    new_tensor = tensor.view(-1, 4)
-    zero_counts = (new_tensor == 0).sum(dim=1)
-    return (zero_counts >= 2).all().item()
+    def _decompress_bitmask_compressed_weight(
+            self, compressed: torch.Tensor, bitmask: torch.Tensor,
+            layer: torch.nn.Module) -> torch.Tensor:
+
+        sparsity_compressor = self.model_compressor.sparsity_compressor
+
+        def _process_split(bitmask_compressed_weight: torch.Tensor, shape,
+                           bitmask: torch.Tensor) -> torch.Tensor:
+            weight_data = dict(
+                compressed=bitmask_compressed_weight,
+                shape=shape,
+                bitmask=bitmask,
+            )
+            return sparsity_compressor.decompress_weight(weight_data)
+
+        split_weights = None
+        split_bitmask = None
+        split_shape = None
+
+        if isinstance(layer, (QKVParallelLinear, MergedColumnParallelLinear)):
+            split_weights = torch.split(compressed, layer.logical_widths)
+            split_bitmask = torch.split(bitmask, layer.logical_widths)
+            split_shape = [(out, layer.input_size_per_partition)
+                           for out in layer.logical_widths]
+
+        if split_weights is not None:
+            decompressed_shards = [
+                _process_split(compressed_weight, shape, bitmask)
+                for compressed_weight, shape, bitmask in zip(
+                    split_weights, split_shape, split_bitmask)
+            ]
+            decompressed = combine_shards(decompressed_shards)
+        else:
+            decompressed = sparsity_compressor.decompress_weight(
+                dict(compressed=compressed,
+                     shape=(layer.logical_widths[0],
+                            layer.input_size_per_partition),
+                     bitmask=bitmask))
+        return decompressed
diff --git a/vllm/model_executor/parameter.py b/vllm/model_executor/parameter.py
index a9ce8af15d3bb..89d234d08545c 100644
--- a/vllm/model_executor/parameter.py
+++ b/vllm/model_executor/parameter.py
@@ -11,7 +11,8 @@
 __all__ = [
     "BasevLLMParameter", "PackedvLLMParameter", "PerTensorScaleParameter",
     "ModelWeightParameter", "ChannelQuantScaleParameter",
-    "GroupQuantScaleParameter", "PackedColumnParameter", "RowvLLMParameter"
+    "GroupQuantScaleParameter", "PackedColumnParameter", "RowvLLMParameter",
+    "BitMaskShapeParameter"
 ]
 
 logger = init_logger(__name__)
@@ -429,3 +430,26 @@ def _adjust_shard_indexes_for_packing(shard_size, shard_offset, packed_factor,
                                  shard_offset=shard_offset,
                                  marlin_tile_size=marlin_tile_size)
     return shard_size, shard_offset
+
+
+class BitMaskShapeParameter(PerTensorScaleParameter):
+    """
+    Parameter class for the shape of the bitmask tensor.
+    """
+
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+
+    def _load_into_shard_id(self, loaded_weight: torch.Tensor,
+                            shard_id: Union[str, int], **kwargs):
+        """
+        Slice the parameter data based on the shard id for
+        loading.
+
+        Note: Assumes the loaded weight is a 1D tensor
+        with 2 elements.
+        """
+        param_data = self.data
+        shard_id = self._shard_id_as_int(shard_id)
+        start_index = shard_id * 2
+        param_data[start_index:start_index + 2].copy_(loaded_weight)
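
A few reviewer-oriented sketches follow; they are illustrative only and use stand-in names and plain torch rather than the compressed-tensors APIs.

The updated sparsity gate in supports_cutlass_24 accepts two compressor formats for 2:4-structured layers instead of only "dense". A rough, self-contained truth table of that check, with plain strings standing in for the SparsityStructure / CompressionFormat enum values (the exact spellings are assumed, not quoted from the library):

def sparsity_ok(structure: str, fmt: str) -> bool:
    # Mirrors is_valid_sparsity_structure + valid_compressors from the patch.
    valid_compressors = {"dense", "sparse-24-bitmask"}
    return structure == "2:4" and fmt in valid_compressors

assert sparsity_ok("2:4", "dense")               # previously the only accepted case
assert sparsity_ok("2:4", "sparse-24-bitmask")   # newly accepted by this patch
assert not sparsity_ok("2:4", "other-format")    # other compressed formats still rejected
assert not sparsity_ok("unstructured", "dense")  # non-2:4 structures still rejected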
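
The fused-layer branch of _decompress_bitmask_compressed_weight splits the packed values and the bitmask row-wise by logical_widths, decompresses each shard on its own, and stitches the dense shards back together. Below is a minimal plain-torch sketch of that flow with made-up sizes: decompress_shard is a stand-in for sparsity_compressor.decompress_weight (it assumes an LSB-first bitmask layout, which may not match the real format), and torch.cat stands in for compressed_tensors.utils.combine_shards.

import torch

def decompress_shard(compressed, bitmask_rows, shape):
    # Stand-in decompressor: expand each row's bitmask bytes into a boolean
    # mask and scatter the packed values back into a dense (out, in) tensor.
    bit_weights = torch.tensor([1, 2, 4, 8, 16, 32, 64, 128], dtype=torch.uint8)
    bits = (bitmask_rows.unsqueeze(-1) & bit_weights) != 0   # (rows, bytes, 8)
    mask = bits.reshape(shape[0], -1)[:, :shape[1]]
    dense = torch.zeros(shape, dtype=compressed.dtype)
    dense[mask] = compressed.flatten()[:int(mask.sum())]
    return dense

logical_widths = [8, 8, 16]                  # e.g. fused q/k/v output partitions
input_size_per_partition = 16

# Shapes mirror the parameters registered in create_weights():
compressed = torch.randn(sum(logical_widths), input_size_per_partition // 2)
bitmask = torch.full((sum(logical_widths), input_size_per_partition // 8),
                     0b01010101, dtype=torch.uint8)          # exactly 2 of every 4

shards = [
    decompress_shard(w, b, (out, input_size_per_partition))
    for w, b, out in zip(torch.split(compressed, logical_widths),
                         torch.split(bitmask, logical_widths),
                         logical_widths)
]
decompressed = torch.cat(shards, dim=0)      # combine_shards() analogue
assert decompressed.shape == (sum(logical_widths), input_size_per_partition)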
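
BitMaskShapeParameter keeps one (rows, cols) pair per output shard in a flat buffer, which is why create_weights allocates 2 * len(output_partition_sizes) entries and _load_into_shard_id copies each shard's 2-element shape into param[2 * shard_id : 2 * shard_id + 2]. A small stand-alone illustration of that layout, with made-up sizes and int64 in place of the uint64 used in the patch:

import torch

output_partition_sizes = [4096, 1024, 1024]  # hypothetical fused q/k/v widths
input_size_per_partition = 4096

# Mirrors torch.empty(2 * len(output_partition_sizes), 1, dtype=torch.uint64)
shape_param = torch.empty(2 * len(output_partition_sizes), 1, dtype=torch.int64)

# Each checkpoint shard provides a 1D, 2-element shape tensor; the loader
# writes it at offset 2 * shard_id, as _load_into_shard_id does.
for shard_id, rows in enumerate(output_partition_sizes):
    loaded_shape = torch.tensor([rows, input_size_per_partition])
    start_index = shard_id * 2
    shape_param[start_index:start_index + 2, 0].copy_(loaded_shape)

print(shape_param.view(len(output_partition_sizes), 2))
# tensor([[4096, 4096],
#         [1024, 4096],
#         [1024, 4096]])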