From d0475c4682af7f2f36b9ee2ba896764945c5203c Mon Sep 17 00:00:00 2001
From: Rahul Tuli
Date: Tue, 10 Dec 2024 22:55:28 +0000
Subject: [PATCH] WIP

---
 .../compressed_tensors/compressed_tensors.py |  1 +
 .../schemes/compressed_tensors_24.py         | 91 ++++++++++++++++++-
 vllm/model_executor/parameter.py             | 26 +++++-
 3 files changed, 112 insertions(+), 6 deletions(-)

diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
index 6c1c465bc9bad..5302e610d5bbc 100644
--- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
@@ -402,6 +402,7 @@ def get_scheme(
                 input_quant=input_quant,
                 model_compression_config=self._get_model_compression_config(
                     sparsity_scheme),
+                layer_name=layer_name,
             )
         else:
             # Find the quant_scheme
diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_24.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_24.py
index 806b899ecd8cc..a3461a9ac1b0d 100644
--- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_24.py
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_24.py
@@ -1,12 +1,14 @@
 from typing import Any, Callable, Dict, List, Optional

 import torch
-from compressed_tensors import ModelCompressor
+from compressed_tensors import ModelCompressor, CompressionFormat
 from compressed_tensors.quantization import (QuantizationArgs,
                                              QuantizationStrategy,
                                              QuantizationType)
+from compressed_tensors.utils import combine_shards

 from vllm import _custom_ops as ops
+from vllm.model_executor.layers.linear import MergedColumnParallelLinear, QKVParallelLinear
 from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
     CompressedTensorsScheme)
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
@@ -14,7 +16,9 @@
 from vllm.model_executor.parameter import (BasevLLMParameter,
                                            ChannelQuantScaleParameter,
                                            ModelWeightParameter,
-                                           PerTensorScaleParameter)
+                                           PerTensorScaleParameter,
+                                           _ColumnvLLMParameter as ColumnvLLMParameter,
+                                           BitMaskShapeParameter,)

 __all__ = ["CompressedTensors24"]

@@ -25,7 +29,9 @@ def __init__(self,
                  quantized: bool = False,
                  weight_quant: Optional[QuantizationArgs] = None,
                  input_quant: Optional[QuantizationArgs] = None,
-                 model_compression_config: Optional[Dict[str, Any]] = None):
+                 model_compression_config: Optional[Dict[str, Any]] = None,
+                 layer_name: Optional[str] = None  # TODO: Remove
+                 ):

         self.quantized = quantized
         self.weight_quant = weight_quant
@@ -33,6 +39,8 @@ def __init__(self,
         self.model_compressor = (
             ModelCompressor.from_compression_config(model_compression_config)
             if model_compression_config is not None else None)
+        self.layer_name = layer_name  # TODO: Remove
+        self.do_sparse_decompress = self.model_compressor is not None

     @classmethod
     def get_min_capability(cls) -> int:
@@ -43,9 +51,12 @@ def create_weights(self, layer: torch.nn.Module, input_size: int,
                        input_size_per_partition: int,
                        params_dtype: torch.dtype, weight_loader: Callable,
                        **kwargs):
-
+        print("Creating weights for:", self.layer_name)
         self.output_dtype = params_dtype
+
+        # This is needed for tensor scale
         layer.logical_widths = output_partition_sizes
+        layer.input_size_per_partition = input_size_per_partition
         self.weights_dtype: torch.dtype = self._get_params_dtype(params_dtype)

         # parameter to store uncompressed weight
@@ -56,6 +67,38 @@ def create_weights(self, layer: torch.nn.Module, input_size: int,
             input_dim=1,
             output_dim=0,
             weight_loader=weight_loader)
+
+        if self.do_sparse_decompress:
+            sparsity_config = self.model_compressor.sparsity_config
+            if sparsity_config is not None and sparsity_config.format != CompressionFormat.sparse_bitmask.value:
+                raise ValueError("CompressedTensors24 only supports sparse_bitmask compression format")
+
+            # register compression specific parameters
+
+            shape = BitMaskShapeParameter(data=torch.empty(2 * len(output_partition_sizes), 1, dtype=torch.uint64),
+                                          weight_loader=weight_loader)
+            compressed = ColumnvLLMParameter(data=torch.empty(
+                sum(output_partition_size * input_size_per_partition for output_partition_size in output_partition_sizes) // 2, 1,
+                dtype=self.weights_dtype),
+                output_dim=0,
+                weight_loader=weight_loader)
+            bitmask = ModelWeightParameter(data=torch.empty(
+                sum(output_partition_sizes),
+                input_size_per_partition // 8,
+                dtype=torch.uint8),
+                input_dim=1,
+                output_dim=0,
+                weight_loader=weight_loader)
+            row_offsets = ColumnvLLMParameter(data=torch.empty(
+                sum(output_partition_sizes), 1, dtype=torch.uint64),
+                output_dim=0,
+                weight_loader=weight_loader,
+            )
+
+            layer.register_parameter("shape", shape)
+            layer.register_parameter("compressed", compressed)
+            layer.register_parameter("bitmask", bitmask)
+            layer.register_parameter("row_offsets", row_offsets)

         # Check if quantized, not just 2:4 Sparse
         if self.quantized:
@@ -111,6 +154,12 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:

         :param layer: The layer with the weights to be processed
         """
+        print("Processing weights for:", self.layer_name)
+
+        if self.do_sparse_decompress:
+            layer.weight.data = self._decompress_bitmask_compressed_weight(
+                layer.compressed, layer.shape, layer.bitmask, layer.row_offsets, layer=layer)
+
         # torch.compile workaround
         if hasattr(layer, "input_scale"):
             layer.input_scale = torch.nn.Parameter(layer.input_scale.data,
@@ -199,6 +248,40 @@ def _get_params_dtype(self, params_dtype: torch.dtype) -> torch.dtype:
             return torch.int8

         raise ValueError("Quantization type not supported by Cutlass")
+
+    def _decompress_bitmask_compressed_weight(self, compressed, shape, bitmask, row_offsets, layer):
+        print("Decompressing weights for:", self.layer_name)
+        split_weights = None
+        def _process_split(bitmask_compressed_weight, shape, bitmask, row_offsets):
+            weight_data = {
+                "compressed": bitmask_compressed_weight,
+                "shape": shape,
+                "bitmask": bitmask,
+                "row_offsets": row_offsets
+            }
+            decompress = self.model_compressor.sparsity_compressor.decompress_weight(weight_data)
+            return decompress
+        if isinstance(layer, (QKVParallelLinear, MergedColumnParallelLinear)):
+            split_weights = torch.split(compressed, [partition_width * layer.input_size_per_partition // 2 for partition_width in layer.logical_widths])
+            split_bitmask = torch.split(bitmask, layer.logical_widths)
+            split_row_offsets = torch.split(row_offsets, layer.logical_widths)
+            split_shape = torch.split(shape, [2] * len(layer.logical_widths))
+        if split_weights:
+            all_compress = []
+            for i in range(len(split_weights)):
+                compress_i = _process_split(split_weights[i], split_shape[i], split_bitmask[i], split_row_offsets[i])
+                all_compress.append(compress_i)
+            decompressed = combine_shards(all_compress)
+        else:
+            decompressed = self.model_compressor.sparsity_compressor.decompress_weight({
+                "compressed": compressed,
+                "shape": shape,
+                "bitmask": bitmask,
+                "row_offsets": row_offsets
+            })
+        return decompressed
+
+


 def check_24(tensor):
diff --git a/vllm/model_executor/parameter.py b/vllm/model_executor/parameter.py
index 7a6d7c90f34d5..82be12446fb5c 100644
--- a/vllm/model_executor/parameter.py
+++ b/vllm/model_executor/parameter.py
@@ -10,7 +10,7 @@
 __all__ = [
     "BasevLLMParameter", "PackedvLLMParameter", "PerTensorScaleParameter",
     "ModelWeightParameter", "ChannelQuantScaleParameter",
-    "GroupQuantScaleParameter", "PackedColumnParameter", "RowvLLMParameter"
+    "GroupQuantScaleParameter", "PackedColumnParameter", "RowvLLMParameter", "BitMaskShapeParameter"
 ]

 logger = init_logger(__name__)
@@ -238,7 +238,7 @@ def _load_into_shard_id(self, loaded_weight: torch.Tensor,

         param_data = self.data
         shard_id = self._shard_id_as_int(shard_id)
-
+
         # AutoFP8 scales do not have a shape
         # compressed-tensors scales do have a shape
         if len(loaded_weight.shape) != 0:
@@ -401,3 +401,25 @@ def _adjust_shard_indexes_for_packing(shard_size, shard_offset, packed_factor,
                                            shard_offset=shard_offset,
                                            marlin_tile_size=marlin_tile_size)
     return shard_size, shard_offset
+
+
+class BitMaskShapeParameter(PerTensorScaleParameter):
+    """
+    Parameter class for the shape of the bitmask tensor.
+    """
+
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+
+    def _load_into_shard_id(self, loaded_weight: torch.Tensor,
+                            shard_id: Union[str, int], **kwargs):
+        """
+        Slice the parameter data based on the shard id for
+        loading.
+        """
+
+        param_data = self.data
+        shard_id = self._shard_id_as_int(shard_id)
+
+        start_index = shard_id * 2
+        param_data[start_index: start_index+2].copy_(loaded_weight)
\ No newline at end of file