WIP
rahul-tuli committed Dec 10, 2024
1 parent 8ad1c7d commit d0475c4
Showing 3 changed files with 112 additions and 6 deletions.
@@ -402,6 +402,7 @@ def get_scheme(
input_quant=input_quant,
model_compression_config=self._get_model_compression_config(
sparsity_scheme),
layer_name=layer_name,
)
else:
# Find the quant_scheme
@@ -1,20 +1,24 @@
from typing import Any, Callable, Dict, List, Optional

import torch
from compressed_tensors import ModelCompressor
from compressed_tensors import ModelCompressor, CompressionFormat
from compressed_tensors.quantization import (QuantizationArgs,
QuantizationStrategy,
QuantizationType)
from compressed_tensors.utils import combine_shards

from vllm import _custom_ops as ops
from vllm.model_executor.layers.linear import MergedColumnParallelLinear, QKVParallelLinear
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
CompressedTensorsScheme)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
convert_to_channelwise)
from vllm.model_executor.parameter import (BasevLLMParameter,
ChannelQuantScaleParameter,
ModelWeightParameter,
PerTensorScaleParameter)
PerTensorScaleParameter,
_ColumnvLLMParameter as ColumnvLLMParameter,
BitMaskShapeParameter,)

__all__ = ["CompressedTensors24"]

@@ -25,14 +29,18 @@ def __init__(self,
quantized: bool = False,
weight_quant: Optional[QuantizationArgs] = None,
input_quant: Optional[QuantizationArgs] = None,
model_compression_config: Optional[Dict[str, Any]] = None):
model_compression_config: Optional[Dict[str, Any]] = None,
layer_name: Optional[str] = None # TODO: Remove
):

self.quantized = quantized
self.weight_quant = weight_quant
self.input_quant = input_quant
self.model_compressor = (
ModelCompressor.from_compression_config(model_compression_config)
if model_compression_config is not None else None)
self.layer_name = layer_name # TODO: Remove
self.do_sparse_decompress = self.model_compressor is not None

@classmethod
def get_min_capability(cls) -> int:
@@ -43,9 +51,12 @@ def create_weights(self, layer: torch.nn.Module, input_size: int,
input_size_per_partition: int,
params_dtype: torch.dtype, weight_loader: Callable,
**kwargs):

print("Creating weights for:", self.layer_name)
self.output_dtype = params_dtype

# This is needed for tensor scale
layer.logical_widths = output_partition_sizes
layer.input_size_per_partition = input_size_per_partition
self.weights_dtype: torch.dtype = self._get_params_dtype(params_dtype)

# parameter to store uncompressed weight
@@ -56,6 +67,38 @@ def create_weights(self, layer: torch.nn.Module, input_size: int,
input_dim=1,
output_dim=0,
weight_loader=weight_loader)

        if self.do_sparse_decompress:
            sparsity_config = self.model_compressor.sparsity_config
            if (sparsity_config is not None and sparsity_config.format
                    != CompressionFormat.sparse_bitmask.value):
                raise ValueError("CompressedTensors24 only supports "
                                 "sparse_bitmask compression format")

            # Register compression-specific parameters.
            shape = BitMaskShapeParameter(
                data=torch.empty(2 * len(output_partition_sizes),
                                 1,
                                 dtype=torch.uint64),
                weight_loader=weight_loader)
            compressed = ColumnvLLMParameter(
                data=torch.empty(
                    sum(output_partition_size * input_size_per_partition
                        for output_partition_size in output_partition_sizes)
                    // 2,
                    1,
                    dtype=self.weights_dtype),
                output_dim=0,
                weight_loader=weight_loader)
            bitmask = ModelWeightParameter(
                data=torch.empty(sum(output_partition_sizes),
                                 input_size_per_partition // 8,
                                 dtype=torch.uint8),
                input_dim=1,
                output_dim=0,
                weight_loader=weight_loader)
            row_offsets = ColumnvLLMParameter(
                data=torch.empty(sum(output_partition_sizes),
                                 1,
                                 dtype=torch.uint64),
                output_dim=0,
                weight_loader=weight_loader)

            layer.register_parameter("shape", shape)
            layer.register_parameter("compressed", compressed)
            layer.register_parameter("bitmask", bitmask)
            layer.register_parameter("row_offsets", row_offsets)

# Check if quantized, not just 2:4 Sparse
if self.quantized:
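
As a rough worked example of the buffer sizes registered above (illustrative only, with made-up partition sizes; not part of this diff): for 2:4 sparsity the packed value buffer holds half of the dense elements, the bitmask packs one bit per dense element into uint8 bytes, and the shape parameter stores one (rows, cols) pair per output shard.

# Illustrative sketch only; the sizes below are assumed, not taken from the diff.
output_partition_sizes = [4096, 1024, 1024]   # e.g. a fused QKV projection
input_size_per_partition = 4096

total_out = sum(output_partition_sizes)
compressed_numel = total_out * input_size_per_partition // 2   # half the elements survive 2:4
bitmask_shape = (total_out, input_size_per_partition // 8)     # 1 bit per element, 8 per byte
shape_numel = 2 * len(output_partition_sizes)                  # (rows, cols) per output shard

print(compressed_numel, bitmask_shape, shape_numel)
# 12582912 (6144, 512) 6
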
@@ -111,6 +154,12 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
:param layer: The layer with the weights to be processed
"""
print("Processing weights for:", self.layer_name)

        if self.do_sparse_decompress:
            layer.weight.data = self._decompress_bitmask_compressed_weight(
                layer.compressed, layer.shape, layer.bitmask,
                layer.row_offsets, layer=layer)

# torch.compile workaround
if hasattr(layer, "input_scale"):
layer.input_scale = torch.nn.Parameter(layer.input_scale.data,
@@ -199,6 +248,40 @@ def _get_params_dtype(self, params_dtype: torch.dtype) -> torch.dtype:
return torch.int8

raise ValueError("Quantization type not supported by Cutlass")

    def _decompress_bitmask_compressed_weight(self, compressed, shape, bitmask,
                                              row_offsets, layer):
        print("Decompressing weights for:", self.layer_name)
        split_weights = None

        def _process_split(bitmask_compressed_weight, shape, bitmask,
                           row_offsets):
            weight_data = {
                "compressed": bitmask_compressed_weight,
                "shape": shape,
                "bitmask": bitmask,
                "row_offsets": row_offsets,
            }
            decompress = self.model_compressor.sparsity_compressor.decompress_weight(
                weight_data)
            return decompress

        if isinstance(layer, (QKVParallelLinear, MergedColumnParallelLinear)):
            split_weights = torch.split(compressed, [
                partition_width * layer.input_size_per_partition // 2
                for partition_width in layer.logical_widths
            ])
            split_bitmask = torch.split(bitmask, layer.logical_widths)
            split_row_offsets = torch.split(row_offsets, layer.logical_widths)
            split_shape = torch.split(shape, [2] * len(layer.logical_widths))

        if split_weights:
            all_compress = []
            for i in range(len(split_weights)):
                compress_i = _process_split(split_weights[i], split_shape[i],
                                            split_bitmask[i],
                                            split_row_offsets[i])
                all_compress.append(compress_i)
            decompressed = combine_shards(all_compress)
        else:
            decompressed = self.model_compressor.sparsity_compressor.decompress_weight({
                "compressed": compressed,
                "shape": shape,
                "bitmask": bitmask,
                "row_offsets": row_offsets,
            })
        return decompressed




def check_24(tensor):
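
For intuition only, a minimal self-contained sketch of what a bitmask decompression step does. This is not the compressed-tensors implementation; it assumes one bit per dense element packed little-endian along the input dimension, which may differ from the library's actual convention. For fused layers, the method above first splits the flattened compressed buffer, bitmask, row offsets, and shape per shard (by logical_widths), decompresses each shard, and recombines them with combine_shards.

import torch

def toy_bitmask_decompress(values: torch.Tensor, bitmask: torch.Tensor,
                           rows: int, cols: int) -> torch.Tensor:
    # Unpack one bit per element; bit i of byte j covers column j * 8 + i
    # (assumed little-endian packing).
    bits = (bitmask.to(torch.int64).unsqueeze(-1) >>
            torch.arange(8, dtype=torch.int64)) & 1
    mask = bits.reshape(rows, -1)[:, :cols].bool()
    dense = torch.zeros(rows, cols, dtype=values.dtype)
    dense[mask] = values.flatten()   # scatter the kept values back into place
    return dense

# 2:4 example: one row of 8 columns with non-zeros at columns 0, 3, 5 and 6.
vals = torch.tensor([1.0, 2.0, 3.0, 4.0])
bm = torch.tensor([[0b01101001]], dtype=torch.uint8)
print(toy_bitmask_decompress(vals, bm, rows=1, cols=8))
# tensor([[1., 0., 0., 2., 0., 3., 4., 0.]])
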
vllm/model_executor/parameter.py: 24 additions & 2 deletions
@@ -10,7 +10,7 @@
__all__ = [
"BasevLLMParameter", "PackedvLLMParameter", "PerTensorScaleParameter",
"ModelWeightParameter", "ChannelQuantScaleParameter",
"GroupQuantScaleParameter", "PackedColumnParameter", "RowvLLMParameter"
"GroupQuantScaleParameter", "PackedColumnParameter", "RowvLLMParameter", "BitMaskShapeParameter"
]

logger = init_logger(__name__)
@@ -238,7 +238,7 @@ def _load_into_shard_id(self, loaded_weight: torch.Tensor,

param_data = self.data
shard_id = self._shard_id_as_int(shard_id)

# AutoFP8 scales do not have a shape
# compressed-tensors scales do have a shape
if len(loaded_weight.shape) != 0:
@@ -401,3 +401,25 @@ def _adjust_shard_indexes_for_packing(shard_size, shard_offset, packed_factor,
shard_offset=shard_offset,
marlin_tile_size=marlin_tile_size)
return shard_size, shard_offset


class BitMaskShapeParameter(PerTensorScaleParameter):
    """
    Parameter class for the shape of the bitmask tensor.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def _load_into_shard_id(self, loaded_weight: torch.Tensor,
                            shard_id: Union[str, int], **kwargs):
        """
        Slice the parameter data based on the shard id for loading.
        """
        param_data = self.data
        shard_id = self._shard_id_as_int(shard_id)

        start_index = shard_id * 2
        param_data[start_index:start_index + 2].copy_(loaded_weight)
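
A small illustration of the slicing in _load_into_shard_id (made-up shard shapes, int64 instead of uint64 for simplicity; not part of this diff): each shard writes its two-element (rows, cols) bitmask shape into its own slice of the fused parameter.

import torch

num_shards = 3   # e.g. the q, k and v shards of a fused QKV layer
param_data = torch.zeros(2 * num_shards, 1, dtype=torch.int64)

# Simulate the weight loader handing over each shard's shape (assumed values).
shard_shapes = {0: (4096, 4096), 1: (1024, 4096), 2: (1024, 4096)}
for shard_id, (rows, cols) in shard_shapes.items():
    loaded_weight = torch.tensor([[rows], [cols]], dtype=torch.int64)
    start_index = shard_id * 2
    param_data[start_index:start_index + 2].copy_(loaded_weight)

print(param_data.flatten().tolist())
# [4096, 4096, 1024, 4096, 1024, 4096]
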
