Add: Support for Sparse24Bitmask Compressed Models #47

Closed · wants to merge 6 commits
Changes from 1 commit
Add: Support for Sparse24Bitmask Compressed Models
Signed-off-by: Rahul Tuli <[email protected]>
rahul-tuli committed Jan 22, 2025
commit 2898d76947f51d67e094370d0fecc8d53636718b
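For orientation: a 2:4 bitmask-compressed checkpoint stores, per weight row, only the surviving values plus a one-bit-per-element mask of their positions. Below is a toy sketch of that encoding, for illustration only; `bitmask_compress_2_4` is a hypothetical helper, not the compressed-tensors API.

```python
import numpy as np
import torch


def bitmask_compress_2_4(dense: torch.Tensor):
    """Toy 2:4 bitmask encoding: keep the nonzero values (half of each
    row) plus one bit per element marking which positions were kept."""
    mask = dense != 0
    # 2:4 sparsity keeps exactly 2 of every 4 elements -> half the columns.
    values = dense[mask].reshape(dense.shape[0], dense.shape[1] // 2)
    # Pack 8 mask bits into each uint8 byte -> one eighth of the columns.
    bitmask = torch.from_numpy(
        np.packbits(mask.numpy(), axis=-1, bitorder="little"))
    return values, bitmask


w = torch.tensor([[1.0, 0.0, 0.0, 2.0, 0.0, 3.0, 4.0, 0.0]])
values, bitmask = bitmask_compress_2_4(w)
print(values.shape, bitmask.shape)  # torch.Size([1, 4]) torch.Size([1, 1])
```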
@@ -396,10 +396,13 @@ def get_scheme(
                     sparsity_scheme=sparsity_scheme):
                 # Have a valid sparsity scheme
                 # Validate layer is supported by Cutlass 2:4 Kernel
-                scheme = CompressedTensors24(quantized=weight_quant is not None
-                                             or input_quant is not None,
-                                             weight_quant=weight_quant,
-                                             input_quant=input_quant)
+                scheme = CompressedTensors24(
+                    quantized=weight_quant is not None or input_quant is not None,
+                    weight_quant=weight_quant,
+                    input_quant=input_quant,
+                    model_compression_config=self._get_model_compression_config(
+                        sparsity_scheme),
+                )
             else:
                 # Find the quant_scheme
                 scheme = self._get_scheme_from_parts(  # type: ignore
@@ -447,10 +450,17 @@ def supports_cutlass_24(
         :return: True if the layer is supported by the Cutlass 2:4 Kernel
                  False otherwise
         """
-        is_valid_sparsity = (sparsity_scheme is not None
-                             and sparsity_scheme.sparsity_structure
-                             == SparsityStructure.TWO_FOUR.value
-                             and sparsity_scheme.format == "dense")
+        is_valid_sparsity_structure = (sparsity_scheme is not None
+                                       and sparsity_scheme.sparsity_structure
+                                       == SparsityStructure.TWO_FOUR.value)
+        valid_compressors = {
+            CompressionFormat.dense.value,
+            CompressionFormat.sparse_24_bitmask.value
+        }
+
+        is_valid_sparsity = (is_valid_sparsity_structure
+                             and sparsity_scheme.format in valid_compressors)
+
         if not is_valid_sparsity:
             return False
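The effect of the widened check: a 2:4 scheme now passes whether the checkpoint is stored dense or bitmask-compressed, while non-2:4 schemes are still rejected. A sketch with stand-in scheme objects; the string values mirror `SparsityStructure.TWO_FOUR.value` and the `CompressionFormat` members, assumed here to be "2:4", "dense", and "sparse-24-bitmask".

```python
from types import SimpleNamespace

# Stand-ins for compressed-tensors SparsityCompressionConfig objects.
dense_24 = SimpleNamespace(sparsity_structure="2:4", format="dense")
bitmask_24 = SimpleNamespace(sparsity_structure="2:4",
                             format="sparse-24-bitmask")
unstructured = SimpleNamespace(sparsity_structure="unstructured",
                               format="dense")

valid_compressors = {"dense", "sparse-24-bitmask"}
for scheme in (dense_24, bitmask_24, unstructured):
    ok = (scheme.sparsity_structure == "2:4"
          and scheme.format in valid_compressors)
    print(scheme.format, scheme.sparsity_structure, "->", ok)
# dense 2:4 -> True
# sparse-24-bitmask 2:4 -> True
# dense unstructured -> False
```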

@@ -481,6 +491,19 @@ def supports_cutlass_24(

         return weight_quant.num_bits == input_quant.num_bits == 8

+    def _get_model_compression_config(
+            self, sparsity_scheme: Optional[SparsityCompressionConfig] = None):
+        """
+        Get the model compressor config from the sparsity scheme
+        :param sparsity_scheme: The sparsity scheme
+        :return: The model compressor config
+        """
+        if sparsity_scheme is None or sparsity_scheme.format == "dense":
+            return None
+
+        return self.config
+

 class CompressedTensorsLinearMethod(LinearMethodBase):
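The helper's contract is narrow: forward the full config only when the checkpoint actually needs decompression at load time. A stand-alone sketch of that behavior, using stand-in objects rather than the real config classes:

```python
from types import SimpleNamespace


def get_model_compression_config(config, sparsity_scheme):
    # Dense-format 2:4 checkpoints need no decompression at load time.
    if sparsity_scheme is None or sparsity_scheme.format == "dense":
        return None
    return config


cfg = {"sparsity_config": {"format": "sparse-24-bitmask"}}  # illustrative
assert get_model_compression_config(cfg, None) is None
assert get_model_compression_config(
    cfg, SimpleNamespace(format="dense")) is None
assert get_model_compression_config(
    cfg, SimpleNamespace(format="sparse-24-bitmask")) is cfg
```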

@@ -1,16 +1,21 @@
-from typing import Callable, List, Optional
+from typing import Any, Callable, Dict, List, Optional

 import torch
+from compressed_tensors import CompressionFormat, ModelCompressor
 from compressed_tensors.quantization import (QuantizationArgs,
                                              QuantizationStrategy,
                                              QuantizationType)
+from compressed_tensors.utils import combine_shards

 from vllm import _custom_ops as ops
+from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
+                                               QKVParallelLinear)
 from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
     CompressedTensorsScheme)
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
     convert_to_channelwise, sparse_cutlass_supported)
 from vllm.model_executor.parameter import (BasevLLMParameter,
+                                           BitMaskShapeParameter,
                                            ChannelQuantScaleParameter,
                                            ModelWeightParameter,
                                            PerTensorScaleParameter)
@@ -20,14 +25,24 @@

 class CompressedTensors24(CompressedTensorsScheme):

-    def __init__(self,
-                 quantized: bool = False,
-                 weight_quant: Optional[QuantizationArgs] = None,
-                 input_quant: Optional[QuantizationArgs] = None):
+    def __init__(
+        self,
+        quantized: bool = False,
+        weight_quant: Optional[QuantizationArgs] = None,
+        input_quant: Optional[QuantizationArgs] = None,
+        model_compression_config: Optional[Dict[str, Any]] = None,
+    ):

         self.quantized = quantized
         self.weight_quant = weight_quant
         self.input_quant = input_quant
+        self.model_compressor = (
+            ModelCompressor.from_compression_config(model_compression_config)
+            if model_compression_config is not None else None)
+        self.do_sparse_decompress = (
+            self.model_compressor is not None
+            and self.model_compressor.sparsity_config.format
+            == CompressionFormat.sparse_24_bitmask.value)

     @classmethod
     def get_min_capability(cls) -> int:
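A quick sanity sketch of the default path, assuming `CompressedTensors24` is importable as in the diff above: without a `model_compression_config`, no `ModelCompressor` is built and decompression stays off, preserving the behavior of existing dense-format 2:4 checkpoints.

```python
# No compression config -> dense-format 2:4 path, no decompression.
scheme = CompressedTensors24(quantized=False)
assert scheme.model_compressor is None
assert scheme.do_sparse_decompress is False
```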
@@ -47,6 +62,8 @@ def create_weights(self, layer: torch.nn.Module, input_size: int,

         self.output_dtype = params_dtype
         layer.logical_widths = output_partition_sizes
+        layer.input_size = input_size
+        layer.input_size_per_partition = input_size_per_partition
         self.weights_dtype: torch.dtype = self._get_params_dtype(params_dtype)

         # parameter to store uncompressed weight
@@ -57,6 +74,34 @@ def create_weights(self, layer: torch.nn.Module, input_size: int,
                                       input_dim=1,
                                       output_dim=0,
                                       weight_loader=weight_loader)
+        if self.do_sparse_decompress:
+            assert all(
+                partition_size % 8 == 0
+                for partition_size in output_partition_sizes
+            ), "All partitions must be divisible by 8 for 2:4 compressed models"
+
+            shape = BitMaskShapeParameter(data=torch.empty(
+                2 * len(output_partition_sizes), 1, dtype=torch.uint64),
+                                          weight_loader=weight_loader)
+            compressed = ModelWeightParameter(data=torch.empty(
+                sum(output_partition_sizes),
+                input_size_per_partition // 2,
+                dtype=self.weights_dtype),
+                                              input_dim=1,
+                                              output_dim=0,
+                                              weight_loader=weight_loader)
+
+            bitmask = ModelWeightParameter(data=torch.empty(
+                sum(output_partition_sizes),
+                input_size_per_partition // 8,
+                dtype=torch.uint8),
+                                           input_dim=1,
+                                           output_dim=0,
+                                           weight_loader=weight_loader)
+
+            layer.register_parameter("shape", shape)
+            layer.register_parameter("compressed", compressed)
+            layer.register_parameter("bitmask", bitmask)
+
         # Check if quantized, not just 2:4 Sparse
         if self.quantized:
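The three extra parameters follow directly from the encoding: `compressed` holds the surviving half of each row's values (hence `// 2`), `bitmask` packs one bit per original element into uint8 bytes (hence `// 8`), and `shape` stores one `(rows, cols)` pair per output partition (hence `2 * len(output_partition_sizes)`). A worked example with assumed sizes:

```python
# Assumed fused-QKV partition sizes, for illustration only.
output_partition_sizes = [4096, 1024, 1024]
input_size_per_partition = 4096

rows = sum(output_partition_sizes)               # 6144 output rows
compressed_cols = input_size_per_partition // 2  # 2048 kept values per row
bitmask_cols = input_size_per_partition // 8     # 512 mask bytes per row
shape_entries = 2 * len(output_partition_sizes)  # 6 = (rows, cols) x 3 shards

print(rows, compressed_cols, bitmask_cols, shape_entries)
# 6144 2048 512 6
```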
@@ -112,6 +157,13 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         :param layer: The layer with the weights to be processed
         """
+        if self.do_sparse_decompress:
+            layer.weight.data = self._decompress_bitmask_compressed_weight(
+                compressed=layer.compressed,
+                bitmask=layer.bitmask,
+                layer=layer,
+            )
+
         # torch.compile workaround
         if hasattr(layer, "input_scale"):
             layer.input_scale = torch.nn.Parameter(layer.input_scale.data,
@@ -201,8 +253,42 @@ def _get_params_dtype(self, params_dtype: torch.dtype) -> torch.dtype:

         raise ValueError("Quantization type not supported by Cutlass")

-
-def check_24(tensor):
-    new_tensor = tensor.view(-1, 4)
-    zero_counts = (new_tensor == 0).sum(dim=1)
-    return (zero_counts >= 2).all().item()
+    def _decompress_bitmask_compressed_weight(
+            self, compressed: torch.Tensor, bitmask: torch.Tensor,
+            layer: torch.nn.Module) -> torch.Tensor:
+
+        sparsity_compressor = self.model_compressor.sparsity_compressor
+
+        def _process_split(bitmask_compressed_weight: torch.Tensor, shape,
+                           bitmask: torch.Tensor) -> torch.Tensor:
+            weight_data = dict(
+                compressed=bitmask_compressed_weight,
+                shape=shape,
+                bitmask=bitmask,
+            )
+            return sparsity_compressor.decompress_weight(weight_data)
+
+        split_weights = None
+        split_bitmask = None
+        split_shape = None
+
+        if isinstance(layer, (QKVParallelLinear, MergedColumnParallelLinear)):
+            split_weights = torch.split(compressed, layer.logical_widths)
+            split_bitmask = torch.split(bitmask, layer.logical_widths)
+            split_shape = [(out, layer.input_size_per_partition)
+                           for out in layer.logical_widths]
+
+        if split_weights is not None:
+            decompressed_shards = [
+                _process_split(compressed_weight, shape, bitmask)
+                for compressed_weight, shape, bitmask in zip(
+                    split_weights, split_shape, split_bitmask)
+            ]
+            decompressed = combine_shards(decompressed_shards)
+        else:
+            decompressed = sparsity_compressor.decompress_weight(
+                dict(compressed=compressed,
+                     shape=(layer.logical_widths[0],
+                            layer.input_size_per_partition),
+                     bitmask=bitmask))
+        return decompressed
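The split is needed because fused layers (QKV, merged gate/up) stack several logical projections into one parameter; each shard's compressed values and bitmask must be decompressed against that shard's own shape, then re-fused. A toy walk-through of the bookkeeping with assumed widths:

```python
import torch

logical_widths = [8, 4, 4]  # assumed q/k/v output widths
input_size_per_partition = 16

compressed = torch.zeros(sum(logical_widths), input_size_per_partition // 2)
bitmask = torch.zeros(sum(logical_widths),
                      input_size_per_partition // 8,
                      dtype=torch.uint8)

split_weights = torch.split(compressed, logical_widths)
split_bitmask = torch.split(bitmask, logical_widths)
split_shape = [(out, input_size_per_partition) for out in logical_widths]

for w, b, s in zip(split_weights, split_bitmask, split_shape):
    print(tuple(w.shape), tuple(b.shape), "->", s)
# (8, 8) (8, 2) -> (8, 16)
# (4, 8) (4, 2) -> (4, 16)
# (4, 8) (4, 2) -> (4, 16)
```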
vllm/model_executor/parameter.py (25 additions, 1 deletion)
@@ -11,7 +11,8 @@
 __all__ = [
     "BasevLLMParameter", "PackedvLLMParameter", "PerTensorScaleParameter",
     "ModelWeightParameter", "ChannelQuantScaleParameter",
-    "GroupQuantScaleParameter", "PackedColumnParameter", "RowvLLMParameter"
+    "GroupQuantScaleParameter", "PackedColumnParameter", "RowvLLMParameter",
+    "BitMaskShapeParameter"
 ]

 logger = init_logger(__name__)
@@ -429,3 +430,26 @@ def _adjust_shard_indexes_for_packing(shard_size, shard_offset, packed_factor,
                                           shard_offset=shard_offset,
                                           marlin_tile_size=marlin_tile_size)
     return shard_size, shard_offset
+
+
+class BitMaskShapeParameter(PerTensorScaleParameter):
+    """
+    Parameter class for the shape of the bitmask tensor.
+    """
+
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+
+    def _load_into_shard_id(self, loaded_weight: torch.Tensor,
+                            shard_id: Union[str, int], **kwargs):
+        """
+        Slice the parameter data based on the shard id for
+        loading.
+
+        Note: Assumes the loaded weight is a 1D tensor
+        with 2 elements.
+        """
+        param_data = self.data
+        shard_id = self._shard_id_as_int(shard_id)
+        start_index = shard_id * 2
+        param_data[start_index:start_index + 2].copy_(loaded_weight)
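Since every shard contributes a `(rows, cols)` pair, shard `i` occupies elements `[2*i, 2*i + 2)` of the flat shape tensor. A small sketch of that indexing with assumed shard shapes (int64 is used here for broad torch support; the diff itself uses `torch.uint64`):

```python
import torch

num_shards = 3  # e.g. q, k, v
shape_param = torch.zeros(2 * num_shards, 1, dtype=torch.int64)

# Assumed per-shard (rows, cols), for illustration.
loaded_shapes = {0: (4096, 4096), 1: (1024, 4096), 2: (1024, 4096)}
for shard_id, (rows, cols) in loaded_shapes.items():
    start_index = shard_id * 2
    shape_param[start_index:start_index + 2].copy_(
        torch.tensor([[rows], [cols]], dtype=torch.int64))

print(shape_param.flatten().tolist())
# [4096, 4096, 1024, 4096, 1024, 4096]
```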