Add: Support for Sparse24Bitmask Compressed Models
rahul-tuli committed Jan 15, 2025
1 parent ebd8c66 commit 4e65cde
Showing 3 changed files with 152 additions and 19 deletions.
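
Background for the change: a 2:4-sparse weight keeps at most two non-zero values in every contiguous group of four. The sparse-24-bitmask checkpoint format stores only the kept values plus a one-bit-per-element mask recording their positions. The sketch below illustrates the idea in plain PyTorch; it is a toy model of the layout, not the compressed-tensors implementation, and the packing order (row-major values, LSB-first bits) is an assumption.

import torch

def toy_bitmask_compress(dense: torch.Tensor):
    # Keep non-zero values in row-major order and pack a one-bit-per-element
    # mask into uint8, eight flags per byte, LSB-first.
    mask = dense != 0
    values = dense[mask]
    packed = torch.zeros((dense.numel() + 7) // 8, dtype=torch.uint8)
    for i, bit in enumerate(mask.flatten().tolist()):
        if bit:
            packed[i // 8] |= 1 << (i % 8)
    return values, packed, dense.shape

# A 2:4 pattern: at most two non-zeros in every group of four elements.
w = torch.tensor([[0.5, 0.0, -1.0, 0.0],
                  [0.0, 2.0, 0.0, 3.0]])
values, bitmask, shape = toy_bitmask_compress(w)
# values  -> tensor([ 0.5000, -1.0000,  2.0000,  3.0000])
# storage -> half the values, plus one mask bit per original element
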
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
@@ -396,10 +396,13 @@ def get_scheme(
                sparsity_scheme=sparsity_scheme):
            # Have a valid sparsity scheme
            # Validate layer is supported by Cutlass 2:4 Kernel
            scheme = CompressedTensors24(quantized=weight_quant is not None
                                         or input_quant is not None,
                                         weight_quant=weight_quant,
                                         input_quant=input_quant)
            scheme = CompressedTensors24(
                quantized=weight_quant is not None or input_quant is not None,
                weight_quant=weight_quant,
                input_quant=input_quant,
                model_compression_config=self._get_model_compression_config(
                    sparsity_scheme),
            )
        else:
            # Find the quant_scheme
            scheme = self._get_scheme_from_parts(  # type: ignore
@@ -447,10 +450,17 @@ def supports_cutlass_24(
        :return: True if the layer is supported by the Cutlass 2:4 Kernel
                 False otherwise
        """
        is_valid_sparsity = (sparsity_scheme is not None
                             and sparsity_scheme.sparsity_structure
                             == SparsityStructure.TWO_FOUR.value
                             and sparsity_scheme.format == "dense")
        is_valid_sparsity_structure = (sparsity_scheme is not None
                                       and sparsity_scheme.sparsity_structure
                                       == SparsityStructure.TWO_FOUR.value)
        valid_compressors = {
            CompressionFormat.dense.value,
            CompressionFormat.sparse_24_bitmask.value
        }

        is_valid_sparsity = (is_valid_sparsity_structure
                             and sparsity_scheme.format in valid_compressors)

        if not is_valid_sparsity:
            return False
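
Net effect of the reworked check: the 2:4 structure requirement is unchanged, while the accepted checkpoint formats grow from dense alone to dense plus sparse-24-bitmask. A condensed restatement of the gate, with the literal format strings written out (they are assumptions matching compressed-tensors' CompressionFormat values):

def sparsity_ok(structure: str, fmt: str) -> bool:
    # 2:4 structure, and a format this path can consume: dense weights,
    # or bitmask-compressed weights that get decompressed at load time.
    return structure == "2:4" and fmt in {"dense", "sparse-24-bitmask"}

assert sparsity_ok("2:4", "sparse-24-bitmask")   # newly supported
assert sparsity_ok("2:4", "dense")               # previously the only option
assert not sparsity_ok("unstructured", "dense")  # still rejected
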

@@ -481,6 +491,19 @@ def supports_cutlass_24(

        return weight_quant.num_bits == input_quant.num_bits == 8

    def _get_model_compression_config(
            self, sparsity_scheme: Optional[SparsityCompressionConfig] = None):
        """
        Get the model compressor config from the sparsity scheme
        :param sparsity_scheme: The sparsity scheme
        :return: The model compressor config
        """
        if sparsity_scheme is None or sparsity_scheme.format == "dense":
            return None

        return self.config


class CompressedTensorsLinearMethod(LinearMethodBase):

vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_24.py
@@ -1,16 +1,21 @@
from typing import Callable, List, Optional
from typing import Any, Callable, Dict, List, Optional

import torch
from compressed_tensors import CompressionFormat, ModelCompressor
from compressed_tensors.quantization import (QuantizationArgs,
                                             QuantizationStrategy,
                                             QuantizationType)
from compressed_tensors.utils import combine_shards

from vllm import _custom_ops as ops
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
                                               QKVParallelLinear)
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
    CompressedTensorsScheme)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
    convert_to_channelwise, sparse_cutlass_supported)
from vllm.model_executor.parameter import (BasevLLMParameter,
                                           BitMaskShapeParameter,
                                           ChannelQuantScaleParameter,
                                           ModelWeightParameter,
                                           PerTensorScaleParameter)
@@ -20,14 +25,24 @@

class CompressedTensors24(CompressedTensorsScheme):

    def __init__(self,
                 quantized: bool = False,
                 weight_quant: Optional[QuantizationArgs] = None,
                 input_quant: Optional[QuantizationArgs] = None):
    def __init__(
        self,
        quantized: bool = False,
        weight_quant: Optional[QuantizationArgs] = None,
        input_quant: Optional[QuantizationArgs] = None,
        model_compression_config: Optional[Dict[str, Any]] = None,
    ):

        self.quantized = quantized
        self.weight_quant = weight_quant
        self.input_quant = input_quant
        self.model_compressor = (
            ModelCompressor.from_compression_config(model_compression_config)
            if model_compression_config is not None else None)
        self.do_sparse_decompress = (
            self.model_compressor is not None
            and self.model_compressor.sparsity_config.format
            == CompressionFormat.sparse_24_bitmask.value)

    @classmethod
    def get_min_capability(cls) -> int:
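
A hedged usage sketch of the new keyword: the dict below imitates the shape of a checkpoint's compressed-tensors metadata; its keys and values are illustrative, not copied from a real config.

# Illustrative config of the kind _get_model_compression_config returns.
compression_config = {
    "sparsity_config": {
        "format": "sparse-24-bitmask",
        "sparsity_structure": "2:4",
    },
}
scheme = CompressedTensors24(quantized=False,
                             model_compression_config=compression_config)
# do_sparse_decompress is True only for sparse-24-bitmask checkpoints; a
# dense checkpoint (config of None) leaves decompression disabled.
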
@@ -47,6 +62,8 @@ def create_weights(self, layer: torch.nn.Module, input_size: int,

        self.output_dtype = params_dtype
        layer.logical_widths = output_partition_sizes
        layer.input_size = input_size
        layer.input_size_per_partition = input_size_per_partition
        self.weights_dtype: torch.dtype = self._get_params_dtype(params_dtype)

        # parameter to store uncompressed weight
@@ -57,6 +74,34 @@ def create_weights(self, layer: torch.nn.Module, input_size: int,
                                      input_dim=1,
                                      output_dim=0,
                                      weight_loader=weight_loader)
        if self.do_sparse_decompress:
            assert all(
                partition_size % 8 == 0
                for partition_size in output_partition_sizes
            ), "All partitions must be divisible by 8 for 2:4 compressed models"

            shape = BitMaskShapeParameter(data=torch.empty(
                2 * len(output_partition_sizes), 1, dtype=torch.uint64),
                                          weight_loader=weight_loader)
            compressed = ModelWeightParameter(data=torch.empty(
                sum(output_partition_sizes),
                input_size_per_partition // 2,
                dtype=self.weights_dtype),
                                              input_dim=1,
                                              output_dim=0,
                                              weight_loader=weight_loader)

            bitmask = ModelWeightParameter(data=torch.empty(
                sum(output_partition_sizes),
                input_size_per_partition // 8,
                dtype=torch.uint8),
                                           input_dim=1,
                                           output_dim=0,
                                           weight_loader=weight_loader)

            layer.register_parameter("shape", shape)
            layer.register_parameter("compressed", compressed)
            layer.register_parameter("bitmask", bitmask)

        # Check if quantized, not just 2:4 Sparse
        if self.quantized:
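
The storage arithmetic behind the three new parameters: 2:4 sparsity keeps exactly half the values in each row, so compressed needs input_size_per_partition // 2 columns; the bitmask packs one bit per original element, eight to a uint8, hence input_size_per_partition // 8 columns; and shape holds two uint64 entries, the (rows, cols) of each partition's dense weight. With illustrative numbers:

# Example figures only: a single 4096-wide partition.
output_partition_sizes = [4096]
input_size_per_partition = 4096
# compressed: 4096 x 2048 -- half the values survive 2:4 pruning
# bitmask:    4096 x 512  -- 4096 one-bit flags packed into 512 bytes per row
# shape:      (2 * 1, 1)  -- (rows, cols) for the one partition
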
@@ -112,6 +157,13 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        :param layer: The layer with the weights to be processed
        """
        if self.do_sparse_decompress:
            layer.weight.data = self._decompress_bitmask_compressed_weight(
                compressed=layer.compressed,
                bitmask=layer.bitmask,
                layer=layer,
            )

        # torch.compile workaround
        if hasattr(layer, "input_scale"):
            layer.input_scale = torch.nn.Parameter(layer.input_scale.data,
@@ -201,8 +253,42 @@ def _get_params_dtype(self, params_dtype: torch.dtype) -> torch.dtype:

        raise ValueError("Quantization type not supported by Cutlass")


def check_24(tensor):
    new_tensor = tensor.view(-1, 4)
    zero_counts = (new_tensor == 0).sum(dim=1)
    return (zero_counts >= 2).all().item()
    def _decompress_bitmask_compressed_weight(
            self, compressed: torch.Tensor, bitmask: torch.Tensor,
            layer: torch.nn.Module) -> torch.Tensor:

        sparsity_compressor = self.model_compressor.sparsity_compressor

        def _process_split(bitmask_compressed_weight: torch.Tensor, shape,
                           bitmask: torch.Tensor) -> torch.Tensor:
            weight_data = dict(
                compressed=bitmask_compressed_weight,
                shape=shape,
                bitmask=bitmask,
            )
            return sparsity_compressor.decompress_weight(weight_data)

        split_weights = None
        split_bitmask = None
        split_shape = None

        if isinstance(layer, (QKVParallelLinear, MergedColumnParallelLinear)):
            split_weights = torch.split(compressed, layer.logical_widths)
            split_bitmask = torch.split(bitmask, layer.logical_widths)
            split_shape = [(out, layer.input_size_per_partition)
                           for out in layer.logical_widths]

        if split_weights is not None:
            decompressed_shards = [
                _process_split(compressed_weight, shape, bitmask)
                for compressed_weight, shape, bitmask in zip(
                    split_weights, split_shape, split_bitmask)
            ]
            decompressed = combine_shards(decompressed_shards)
        else:
            decompressed = sparsity_compressor.decompress_weight(
                dict(compressed=compressed,
                     shape=(layer.logical_widths[0],
                            layer.input_size_per_partition),
                     bitmask=bitmask))
        return decompressed
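
Conceptually, decompress_weight inverts the bitmask layout: unpack the flags, scatter the stored values into a zero tensor, reshape. For fused modules (QKVParallelLinear, MergedColumnParallelLinear) each logical shard was compressed over its own rows, so the method splits along the output dimension, decompresses per shard, and stitches the results back with combine_shards. A toy inverse, under the same layout assumptions as the compression sketch at the top of this page:

import torch

def toy_bitmask_decompress(values: torch.Tensor, bitmask: torch.Tensor,
                           shape: tuple) -> torch.Tensor:
    # Unpack one flag per element (LSB-first), then scatter the kept values
    # back into a dense zero tensor at the flagged positions.
    numel = shape[0] * shape[1]
    flags = torch.zeros(numel, dtype=torch.bool)
    for i in range(numel):
        flags[i] = bool(bitmask[i // 8].item() & (1 << (i % 8)))
    dense = torch.zeros(numel, dtype=values.dtype)
    dense[flags] = values
    return dense.view(shape)

# Round-trips with toy_bitmask_compress: feeding its (values, packed, shape)
# output back through this function reproduces the original dense tensor.
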
26 changes: 25 additions & 1 deletion vllm/model_executor/parameter.py
@@ -11,7 +11,8 @@
__all__ = [
    "BasevLLMParameter", "PackedvLLMParameter", "PerTensorScaleParameter",
    "ModelWeightParameter", "ChannelQuantScaleParameter",
    "GroupQuantScaleParameter", "PackedColumnParameter", "RowvLLMParameter"
    "GroupQuantScaleParameter", "PackedColumnParameter", "RowvLLMParameter",
    "BitMaskShapeParameter"
]

logger = init_logger(__name__)
@@ -429,3 +430,26 @@ def _adjust_shard_indexes_for_packing(shard_size, shard_offset, packed_factor,
            shard_offset=shard_offset,
            marlin_tile_size=marlin_tile_size)
    return shard_size, shard_offset


class BitMaskShapeParameter(PerTensorScaleParameter):
    """
    Parameter class for the shape of the bitmask tensor.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def _load_into_shard_id(self, loaded_weight: torch.Tensor,
                            shard_id: Union[str, int], **kwargs):
        """
        Slice the parameter data based on the shard id for
        loading.
        Note: Assumes the loaded weight is a 1D tensor
        with 2 elements.
        """
        param_data = self.data
        shard_id = self._shard_id_as_int(shard_id)
        start_index = shard_id * 2
        param_data[start_index:start_index + 2].copy_(loaded_weight)
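
How the slicing plays out for a fused layer: with QKV merged into one parameter, each checkpoint shard supplies a two-element (rows, cols) tensor, and the integer shard id selects the destination slot at offset 2 * shard_id. A small sketch with assumed dimensions (int64 here for portability; the parameter above uses uint64):

import torch

# The fused parameter holds 2 entries per shard: 3 shards -> 6 rows.
param_data = torch.zeros(6, 1, dtype=torch.int64)

shard_shapes = [(4096, 4096),   # q -- illustrative dimensions
                (1024, 4096),   # k
                (1024, 4096)]   # v
for shard_id, shard_shape in enumerate(shard_shapes):
    loaded_weight = torch.tensor(shard_shape, dtype=torch.int64)
    start_index = shard_id * 2
    param_data[start_index:start_index + 2].copy_(loaded_weight.view(2, 1))
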
