diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_24.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_24.py
index a3461a9ac1b0d..dc47833e6a6de 100644
--- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_24.py
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_24.py
@@ -7,6 +7,7 @@
     QuantizationType)
 from compressed_tensors.utils import combine_shards
+
 from vllm import _custom_ops as ops
 from vllm.model_executor.layers.linear import MergedColumnParallelLinear, QKVParallelLinear
 from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
@@ -18,7 +19,7 @@
     ModelWeightParameter, PerTensorScaleParameter,
     _ColumnvLLMParameter as ColumnvLLMParameter,
-    BitMaskShapeParameter,)
+    BitMaskShapeParameter, BitMaskParameter)
 
 __all__ = ["CompressedTensors24"]
 
@@ -51,14 +52,14 @@ def create_weights(self, layer: torch.nn.Module, input_size: int,
                        input_size_per_partition: int,
                        params_dtype: torch.dtype, weight_loader: Callable,
                        **kwargs):
-        print("Creating weights for:", self.layer_name)
         self.output_dtype = params_dtype
+        # This is needed for tensor scale
         layer.logical_widths = output_partition_sizes
         layer.input_size_per_partition = input_size_per_partition
+        layer.input_size = input_size
         self.weights_dtype: torch.dtype = self._get_params_dtype(params_dtype)
 
-        # parameter to store uncompressed weight
         weight = ModelWeightParameter(data=torch.empty(
             sum(output_partition_sizes),
@@ -67,79 +68,46 @@ def create_weights(self, layer: torch.nn.Module, input_size: int,
             input_dim=1,
             output_dim=0,
             weight_loader=weight_loader)
+
+        shape = BitMaskShapeParameter(data=torch.empty(2 * len(output_partition_sizes), 1, dtype=torch.uint64),
+                                      weight_loader=weight_loader)
+
+        compressed = BitMaskParameter(data=torch.empty(
+            sum(out * input_size_per_partition for out in output_partition_sizes) // 2, 1,
+            dtype=self.weights_dtype),
+            output_dim=0,
+            input_size_per_partition=input_size_per_partition,
+            weight_loader=weight_loader)
+
+        bitmask = ModelWeightParameter(data=torch.empty(
+            sum(output_partition_sizes),
+            input_size_per_partition,
+            dtype=torch.uint8),
+            input_dim=1,
+            output_dim=0,
+            weight_loader=weight_loader)
 
-        if self.do_sparse_decompress:
-            sparsity_config = self.model_compressor.sparsity_config
-            if sparsity_config is not None and sparsity_config.format != CompressionFormat.sparse_bitmask.value:
-                raise ValueError("CompressedTensors24 only supports sparse_bitmask compression format")
-
-            # register compression specific parameters
-
-            shape = BitMaskShapeParameter(data=torch.empty(2 * len(output_partition_sizes), 1, dtype=torch.uint64),
-                                          weight_loader=weight_loader)
-            compressed = ColumnvLLMParameter(data=torch.empty(
-                sum(output_partition_size * input_size_per_partition for output_partition_size in output_partition_sizes) // 2, 1,
-                dtype=self.weights_dtype),
-                output_dim=0,
-                weight_loader=weight_loader)
-            bitmask = ModelWeightParameter(data=torch.empty(
-                sum(output_partition_sizes),
-                input_size_per_partition // 8,
-                dtype=torch.uint8),
-                input_dim=1,
-                output_dim=0,
-                weight_loader=weight_loader)
-            row_offsets = ColumnvLLMParameter(data=torch.empty(
-                sum(output_partition_sizes), 1, dtype=torch.uint64),
-                output_dim=0,
-                weight_loader=weight_loader,
-                )
+        row_offsets = ColumnvLLMParameter(data=torch.empty(
+            sum(output_partition_sizes), 1, dtype=torch.uint64),
+            output_dim=0,
+            weight_loader=weight_loader,
+            )
 
-        layer.register_parameter("shape", shape)
-        layer.register_parameter("compressed", compressed)
-        layer.register_parameter("bitmask", bitmask)
-        layer.register_parameter("row_offsets", row_offsets)
+        layer.register_parameter("shape", shape)
+        layer.register_parameter("compressed", compressed)
+        layer.register_parameter("bitmask", bitmask)
+        layer.register_parameter("row_offsets", row_offsets)
 
         # Check if quantized, not just 2:4 Sparse
-        if self.quantized:
-            if (self.weight_quant and self.weight_quant.strategy
-                    == QuantizationStrategy.CHANNEL.value):
-                weight_scale = ChannelQuantScaleParameter(
-                    data=torch.empty((sum(output_partition_sizes), 1),
-                                     dtype=torch.float32),
-                    output_dim=0,
-                    weight_loader=weight_loader)
-            else:
-                assert (self.weight_quant and self.weight_quant.strategy
-                        == QuantizationStrategy.TENSOR.value)
-                weight_scale = PerTensorScaleParameter(
-                    data=torch.empty(len(output_partition_sizes),
-                                     dtype=torch.float32),
-                    weight_loader=weight_loader)
-
-            layer.register_parameter("weight_scale", weight_scale)
-
-            # input quant will be non-none
-            if self.input_quant and not self.input_quant.dynamic:
-                # register input quant scale
-                assert (self.input_quant.strategy ==
-                        QuantizationStrategy.TENSOR.value)
-                input_scale = BasevLLMParameter(data=torch.empty(
-                    1, dtype=torch.float32),
-                                                weight_loader=weight_loader)
-
-                layer.register_parameter("input_scale", input_scale)
-
-        else:
-            # for sparse-only, pass in 1 for weight/input scales
-            weight_scale = torch.nn.Parameter(data=torch.ones(
-                1, dtype=torch.float32),
-                                              requires_grad=False)
-            input_scale = torch.nn.Parameter(data=torch.ones(
-                1, dtype=torch.float32),
-                                             requires_grad=False)
-            layer.register_parameter("input_scale", input_scale)
-            layer.register_parameter("weight_scale", weight_scale)
+        # for sparse-only, pass in 1 for weight/input scales
+        weight_scale = torch.nn.Parameter(data=torch.ones(
+            1, dtype=torch.float32),
+                                          requires_grad=False)
+        input_scale = torch.nn.Parameter(data=torch.ones(
+            1, dtype=torch.float32),
+                                         requires_grad=False)
+        layer.register_parameter("input_scale", input_scale)
+        layer.register_parameter("weight_scale", weight_scale)
 
         layer.register_parameter("weight", weight)
 
@@ -154,27 +122,8 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
 
         :param layer: The layer with the weights to be processed
         """
-        print("Processing weights for:", self.layer_name)
-
-        if self.do_sparse_decompress:
-            layer.weight.data = self._decompress_bitmask_compressed_weight(
-                layer.compressed, layer.shape, layer.bitmask, layer.row_offsets, layer=layer)
-
-        # torch.compile workaround
-        if hasattr(layer, "input_scale"):
-            layer.input_scale = torch.nn.Parameter(layer.input_scale.data,
-                                                   requires_grad=False)
-
-        if self.weight_quant:
-            if self.weight_quant.strategy == QuantizationStrategy.TENSOR.value:
-                layer.weight_scale = torch.nn.Parameter(convert_to_channelwise(
-                    weight_scale=layer.weight_scale,
-                    logical_widths=layer.logical_widths),
-                                                        requires_grad=False)
-            else:
-                # torch.compile workaround
-                layer.weight_scale = torch.nn.Parameter(
-                    layer.weight_scale.data, requires_grad=False)
+        layer.weight.data = self._decompress_bitmask_compressed_weight(
+            layer.compressed, layer.shape, layer.bitmask, layer.row_offsets, layer=layer)
 
         w_compressed, meta = ops.cutlass_compress_entry(layer.weight.data)
         layer.weight = torch.nn.Parameter(w_compressed, requires_grad=False)
@@ -195,27 +144,9 @@ def apply_weights(self,
         :param bias: The bias to be added to the output tensor
         :return: The output tensor of the layer
         """
-        if self.quantized:
-            scale = None
-            if hasattr(layer, "input_scale"):
-                scale = layer.input_scale
-
-            if self.weights_dtype == torch.int8:
-                ops_output = ops.scaled_int8_quant(x, scale=scale)
-                q_input = ops_output[0]
-                input_scale = ops_output[1]
-            else:
-                assert self.weights_dtype == torch.float8_e4m3fn
-                if scale is not None:
-                    q_input, input_scale = ops.scaled_fp8_quant(x, scale=scale)
-                else:
-                    q_input, input_scale = ops.scaled_fp8_quant(
-                        x, use_per_token_if_dynamic=True)
-
-        else:
-            # Not quantized, nothing to do with the input_scales, use as is
-            input_scale = layer.input_scale
-            q_input = x
+        # Not quantized, nothing to do with the input_scales, use as is
+        input_scale = layer.input_scale
+        q_input = x
 
         out = ops.cutlass_scaled_sparse_mm(a=layer.weight,
                                            e=layer.meta,
@@ -250,7 +181,6 @@ def _get_params_dtype(self, params_dtype: torch.dtype) -> torch.dtype:
         raise ValueError("Quantization type not supported by Cutlass")
 
     def _decompress_bitmask_compressed_weight(self, compressed, shape, bitmask, row_offsets, layer):
-        print("Decompressing weights for:", self.layer_name)
        split_weights = None
        def _process_split(bitmask_compressed_weight, shape, bitmask, row_offsets):
            weight_data = {
@@ -261,24 +191,36 @@ def _process_split(bitmask_compressed_weight, shape, bitmask, row_offsets):
             }
             decompress = self.model_compressor.sparsity_compressor.decompress_weight(weight_data)
             return decompress
+
         if isinstance(layer, (QKVParallelLinear, MergedColumnParallelLinear)):
             split_weights = torch.split(compressed, [partition_width * layer.input_size_per_partition // 2 for partition_width in layer.logical_widths])
             split_bitmask = torch.split(bitmask, layer.logical_widths)
             split_row_offsets = torch.split(row_offsets, layer.logical_widths)
-            split_shape = torch.split(shape, [2] * len(layer.logical_widths))
-        if split_weights:
+            split_shape = [(out, layer.input_size_per_partition) for out in layer.logical_widths]
+
+        """
+        print(type(layer), layer.input_size, layer.input_size_per_partition, compressed.shape, bitmask.shape, row_offsets.shape)
+        print([x.shape for x in split_weights])
+        print([x.shape for x in split_bitmask])
+        print([x.shape for x in split_row_offsets])
+        print([x for x in split_shape])
+        print("\n")
+        """
+
+        if isinstance(layer, (QKVParallelLinear, MergedColumnParallelLinear)):
             all_compress = []
             for i in range(len(split_weights)):
                 compress_i = _process_split(split_weights[i], split_shape[i], split_bitmask[i], split_row_offsets[i])
                 all_compress.append(compress_i)
             decompressed = combine_shards(all_compress)
         else:
+            print(type(layer), layer.input_size, layer.input_size_per_partition, compressed.shape, bitmask.shape, row_offsets.shape)
             decompressed = self.model_compressor.sparsity_compressor.decompress_weight({
-                "compressed": compressed,
-                "shape": shape,
-                "bitmask": bitmask,
-                "row_offsets": row_offsets
-            })
+                "compressed": compressed,
+                "shape": (layer.logical_widths[0], layer.input_size_per_partition),
+                "bitmask": bitmask,
+                "row_offsets": row_offsets
+            })
 
         return decompressed
diff --git a/vllm/model_executor/parameter.py b/vllm/model_executor/parameter.py
index 82be12446fb5c..ec5ebb6c858a2 100644
--- a/vllm/model_executor/parameter.py
+++ b/vllm/model_executor/parameter.py
@@ -10,7 +10,8 @@
 __all__ = [
     "BasevLLMParameter", "PackedvLLMParameter", "PerTensorScaleParameter",
     "ModelWeightParameter", "ChannelQuantScaleParameter",
-    "GroupQuantScaleParameter", "PackedColumnParameter", "RowvLLMParameter", "BitMaskShapeParameter"
+    "GroupQuantScaleParameter", "PackedColumnParameter", "RowvLLMParameter", "BitMaskShapeParameter",
+    "BitMaskParameter"
 ]
 
 logger = init_logger(__name__)
@@ -134,6 +135,7 @@ def load_qkv_weight(self, loaded_weight: torch.Tensor, **kwargs):
 
         param_data.copy_(loaded_weight)
 
+
 class RowvLLMParameter(BasevLLMParameter):
     """
     Parameter class defining weight_loading functionality
@@ -153,8 +155,8 @@ def input_dim(self):
     def load_row_parallel_weight(self, loaded_weight: torch.Tensor):
         tp_rank = get_tensor_model_parallel_rank()
         shard_size = self.data.shape[self.input_dim]
-        loaded_weight = loaded_weight.narrow(self.input_dim,
-                                             tp_rank * shard_size, shard_size)
+        loaded_weight = loaded_weight.narrow(self.input_dim,
+                                             tp_rank * shard_size, shard_size)
 
         if len(loaded_weight.shape) == 0:
             loaded_weight = loaded_weight.reshape(1)
@@ -286,6 +288,39 @@ def adjust_shard_indexes_for_packing(self, shard_size, shard_offset):
             packed_factor=self.packed_factor,
             marlin_tile_size=self.marlin_tile_size)
 
+class BitMaskParameter(_ColumnvLLMParameter):
+    def __init__(self, output_dim: int, input_size_per_partition: int, **kwargs):
+        self.compressed_dim = output_dim
+        self.input_size_per_partition = input_size_per_partition
+        super().__init__(**kwargs, output_dim=output_dim)
+
+    def load_row_parallel_weight(self, loaded_weight: torch.Tensor):
+        tp_rank = get_tensor_model_parallel_rank()
+        shard_size = self.data.shape[self.compressed_dim]
+
+        loaded_weight = loaded_weight.narrow(self.compressed_dim,
+                                             tp_rank * shard_size, shard_size)
+
+        assert self.data.shape == loaded_weight.shape
+        self.data.copy_(loaded_weight)
+
+    def load_qkv_weight(self, loaded_weight: torch.Tensor, **kwargs):
+        shard_offset = kwargs.pop("shard_offset")
+        shard_size = kwargs.pop("shard_size")
+
+        shard_size = (shard_size * self.input_size_per_partition) // 2
+        shard_offset = (shard_offset * self.input_size_per_partition) // 2
+
+        super().load_qkv_weight(**kwargs, loaded_weight=loaded_weight, shard_size=shard_size, shard_offset=shard_offset)
+
+    def load_merged_column_weight(self, loaded_weight: torch.Tensor, **kwargs):
+        shard_offset = kwargs.pop("shard_offset")
+        shard_size = kwargs.pop("shard_size")
+
+        shard_size = (shard_size * self.input_size_per_partition) // 2
+        shard_offset = (shard_offset * self.input_size_per_partition) // 2
+
+        super().load_merged_column_weight(**kwargs, loaded_weight=loaded_weight, shard_size=shard_size, shard_offset=shard_offset)
+
 class PackedvLLMParameter(ModelWeightParameter):
     """