update #43

Draft · wants to merge 1 commit into base: rahul-bitmask-additions
@@ -7,6 +7,7 @@
QuantizationType)
from compressed_tensors.utils import combine_shards


from vllm import _custom_ops as ops
from vllm.model_executor.layers.linear import MergedColumnParallelLinear, QKVParallelLinear
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
@@ -18,7 +19,7 @@
ModelWeightParameter,
PerTensorScaleParameter,
_ColumnvLLMParameter as ColumnvLLMParameter,
BitMaskShapeParameter,)
BitMaskShapeParameter, BitMaskParameter)

__all__ = ["CompressedTensors24"]

@@ -51,14 +52,14 @@ def create_weights(self, layer: torch.nn.Module, input_size: int,
input_size_per_partition: int,
params_dtype: torch.dtype, weight_loader: Callable,
**kwargs):
print("Creating weights for:", self.layer_name)
self.output_dtype = params_dtype


# This is needed for tensor scale
layer.logical_widths = output_partition_sizes
layer.input_size_per_partition = input_size_per_partition
layer.input_size = input_size
self.weights_dtype: torch.dtype = self._get_params_dtype(params_dtype)

# parameter to store uncompressed weight
weight = ModelWeightParameter(data=torch.empty(
sum(output_partition_sizes),
@@ -67,79 +68,46 @@ def create_weights(self, layer: torch.nn.Module, input_size: int,
input_dim=1,
output_dim=0,
weight_loader=weight_loader)

shape = BitMaskShapeParameter(data=torch.empty(2 * len(output_partition_sizes), 1, dtype=torch.uint64),
weight_loader=weight_loader)

compressed = BitMaskParameter(data=torch.empty(
sum(out * input_size_per_partition for out in output_partition_sizes) // 2, 1,
dtype=self.weights_dtype),
output_dim=0,
input_size_per_partition=input_size_per_partition,
weight_loader=weight_loader)

bitmask = ModelWeightParameter(data=torch.empty(
sum(output_partition_sizes),
input_size_per_partition,
dtype=torch.uint8),
input_dim=1,
output_dim=0,
weight_loader=weight_loader)

if self.do_sparse_decompress:
sparsity_config = self.model_compressor.sparsity_config
if sparsity_config is not None and sparsity_config.format != CompressionFormat.sparse_bitmask.value:
raise ValueError("CompressedTensors24 only supports sparse_bitmask compression format")

# register compression specific parameters

shape = BitMaskShapeParameter(data=torch.empty(2 * len(output_partition_sizes), 1, dtype=torch.uint64),
weight_loader=weight_loader)
compressed = ColumnvLLMParameter(data=torch.empty(
sum(output_partition_size * input_size_per_partition for output_partition_size in output_partition_sizes) // 2, 1,
dtype=self.weights_dtype),
output_dim=0,
weight_loader=weight_loader)
bitmask = ModelWeightParameter(data=torch.empty(
sum(output_partition_sizes),
input_size_per_partition // 8,
dtype=torch.uint8),
input_dim=1,
output_dim=0,
weight_loader=weight_loader)
row_offsets = ColumnvLLMParameter(data=torch.empty(
sum(output_partition_sizes), 1, dtype=torch.uint64),
output_dim=0,
weight_loader=weight_loader,
)
row_offsets = ColumnvLLMParameter(data=torch.empty(
sum(output_partition_sizes), 1, dtype=torch.uint64),
output_dim=0,
weight_loader=weight_loader,
)

layer.register_parameter("shape", shape)
layer.register_parameter("compressed", compressed)
layer.register_parameter("bitmask", bitmask)
layer.register_parameter("row_offsets", row_offsets)
layer.register_parameter("shape", shape)
layer.register_parameter("compressed", compressed)
layer.register_parameter("bitmask", bitmask)
layer.register_parameter("row_offsets", row_offsets)

# Check if quantized, not just 2:4 Sparse
if self.quantized:
if (self.weight_quant and self.weight_quant.strategy
== QuantizationStrategy.CHANNEL.value):
weight_scale = ChannelQuantScaleParameter(
data=torch.empty((sum(output_partition_sizes), 1),
dtype=torch.float32),
output_dim=0,
weight_loader=weight_loader)
else:
assert (self.weight_quant and self.weight_quant.strategy
== QuantizationStrategy.TENSOR.value)
weight_scale = PerTensorScaleParameter(
data=torch.empty(len(output_partition_sizes),
dtype=torch.float32),
weight_loader=weight_loader)

layer.register_parameter("weight_scale", weight_scale)

# input quant will be non-none
if self.input_quant and not self.input_quant.dynamic:
# register input quant scale
assert (self.input_quant.strategy ==
QuantizationStrategy.TENSOR.value)
input_scale = BasevLLMParameter(data=torch.empty(
1, dtype=torch.float32),
weight_loader=weight_loader)

layer.register_parameter("input_scale", input_scale)

else:
# for sparse-only, pass in 1 for weight/input scales
weight_scale = torch.nn.Parameter(data=torch.ones(
1, dtype=torch.float32),
requires_grad=False)
input_scale = torch.nn.Parameter(data=torch.ones(
1, dtype=torch.float32),
requires_grad=False)
layer.register_parameter("input_scale", input_scale)
layer.register_parameter("weight_scale", weight_scale)
# for sparse-only, pass in 1 for weight/input scales
weight_scale = torch.nn.Parameter(data=torch.ones(
1, dtype=torch.float32),
requires_grad=False)
input_scale = torch.nn.Parameter(data=torch.ones(
1, dtype=torch.float32),
requires_grad=False)
layer.register_parameter("input_scale", input_scale)
layer.register_parameter("weight_scale", weight_scale)

layer.register_parameter("weight", weight)

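For reference, here is a rough sketch of the buffer sizes the new `do_sparse_decompress` branch allocates; the concrete dimensions below are hypothetical, not taken from a real model, and only the arithmetic mirrors the `create_weights` logic above. Under 2:4 sparsity the `compressed` buffer holds half of all weight elements, the `bitmask` stores one bit per original element (hence `input_size_per_partition // 8` uint8 columns), and `shape` keeps a (rows, cols) pair per partition.

```python
# Hypothetical sizes; only the arithmetic mirrors create_weights() above.
output_partition_sizes = [1024, 1024, 4096]   # e.g. merged q/k/v output widths
input_size_per_partition = 4096

# "compressed": surviving values, i.e. half of all elements under 2:4 sparsity
compressed_rows = sum(o * input_size_per_partition
                      for o in output_partition_sizes) // 2   # tensor shape (compressed_rows, 1)

# "bitmask": one bit per original element, packed into uint8 bytes
bitmask_shape = (sum(output_partition_sizes), input_size_per_partition // 8)

# "shape": a (rows, cols) pair per partition, stored as uint64
shape_len = 2 * len(output_partition_sizes)
```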
@@ -154,27 +122,8 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
:param layer: The layer with the weights to be processed

"""
print("Processing weights for:", self.layer_name)

if self.do_sparse_decompress:
layer.weight.data = self._decompress_bitmask_compressed_weight(
layer.compressed, layer.shape, layer.bitmask, layer.row_offsets, layer=layer)

# torch.compile workaround
if hasattr(layer, "input_scale"):
layer.input_scale = torch.nn.Parameter(layer.input_scale.data,
requires_grad=False)

if self.weight_quant:
if self.weight_quant.strategy == QuantizationStrategy.TENSOR.value:
layer.weight_scale = torch.nn.Parameter(convert_to_channelwise(
weight_scale=layer.weight_scale,
logical_widths=layer.logical_widths),
requires_grad=False)
else:
# torch.compile workaround
layer.weight_scale = torch.nn.Parameter(
layer.weight_scale.data, requires_grad=False)
layer.weight.data = self._decompress_bitmask_compressed_weight(
layer.compressed, layer.shape, layer.bitmask, layer.row_offsets, layer=layer)

w_compressed, meta = ops.cutlass_compress_entry(layer.weight.data)
layer.weight = torch.nn.Parameter(w_compressed, requires_grad=False)
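The flow above first rebuilds the dense weight from the bitmask format and then hands it to `ops.cutlass_compress_entry`, which packs it into the values-plus-metadata layout the 2:4 sparse GEMM expects. As a plain-PyTorch illustration of that layout — not the CUTLASS kernel, and assuming the input already satisfies the 2:4 pattern — the packing amounts to keeping two values per group of four and remembering their positions:

```python
import torch

def compress_2_4(dense: torch.Tensor):
    """Pack a weight that already satisfies the 2:4 pattern into
    (values, kept-index) form: 2 surviving values per group of 4."""
    rows, cols = dense.shape
    groups = dense.reshape(rows, cols // 4, 4)
    # indices of the two largest-magnitude entries per group, kept in order
    idx = groups.abs().topk(2, dim=-1).indices.sort(dim=-1).values
    vals = torch.gather(groups, -1, idx)
    return vals.reshape(rows, cols // 2), idx

dense = torch.tensor([[0., 3., 0., -1., 2., 0., 0., 5.]])
values, meta = compress_2_4(dense)
# values -> [[3., -1., 2., 5.]]; meta records which 2 of each 4 positions survived
```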
@@ -195,27 +144,9 @@ def apply_weights(self,
:param bias: The bias to be added to the output tensor
:return: The output tensor of the layer
"""
if self.quantized:
scale = None
if hasattr(layer, "input_scale"):
scale = layer.input_scale

if self.weights_dtype == torch.int8:
ops_output = ops.scaled_int8_quant(x, scale=scale)
q_input = ops_output[0]
input_scale = ops_output[1]
else:
assert self.weights_dtype == torch.float8_e4m3fn
if scale is not None:
q_input, input_scale = ops.scaled_fp8_quant(x, scale=scale)
else:
q_input, input_scale = ops.scaled_fp8_quant(
x, use_per_token_if_dynamic=True)

else:
# Not quantized, nothing to do with the input_scales, use as is
input_scale = layer.input_scale
q_input = x
# Not quantized, nothing to do with the input_scales, use as is
input_scale = layer.input_scale
q_input = x

out = ops.cutlass_scaled_sparse_mm(a=layer.weight,
e=layer.meta,
@@ -250,7 +181,6 @@ def _get_params_dtype(self, params_dtype: torch.dtype) -> torch.dtype:
raise ValueError("Quantization type not supported by Cutlass")

def _decompress_bitmask_compressed_weight(self, compressed, shape, bitmask, row_offsets, layer):
print("Decompressing weights for:", self.layer_name)
split_weights = None
def _process_split(bitmask_compressed_weight, shape, bitmask, row_offsets):
weight_data = {
@@ -261,24 +191,36 @@ def _process_split(bitmask_compressed_weight, shape, bitmask, row_offsets):
}
decompress = self.model_compressor.sparsity_compressor.decompress_weight(weight_data)
return decompress

if isinstance(layer, (QKVParallelLinear, MergedColumnParallelLinear)):
split_weights = torch.split(compressed, [partition_width * layer.input_size_per_partition // 2 for partition_width in layer.logical_widths])
split_bitmask = torch.split(bitmask, layer.logical_widths)
split_row_offsets = torch.split(row_offsets, layer.logical_widths)
split_shape = torch.split(shape, [2] * len(layer.logical_widths))
if split_weights:
split_shape = [(out, layer.input_size_per_partition) for out in layer.logical_widths]

"""
print(type(layer), layer.input_size, layer.input_size_per_partition, compressed.shape, bitmask.shape, row_offsets.shape)
print([x.shape for x in split_weights])
print([x.shape for x in split_bitmask])
print([x.shape for x in split_row_offsets])
print([x for x in split_shape])
print("\n")
"""

if isinstance(layer, (QKVParallelLinear, MergedColumnParallelLinear)):
all_compress = []
for i in range(len(split_weights)):
compress_i = _process_split(split_weights[i], split_shape[i], split_bitmask[i], split_row_offsets[i])
all_compress.append(compress_i)
decompressed = combine_shards(all_compress)
else:
print(type(layer), layer.input_size, layer.input_size_per_partition, compressed.shape, bitmask.shape, row_offsets.shape)
decompressed = self.model_compressor.sparsity_compressor.decompress_weight({
"compressed": compressed,
"shape": shape,
"bitmask": bitmask,
"row_offsets": row_offsets
})
"compressed": compressed,
"shape": (layer.logical_widths[0], layer.input_size_per_partition),
"bitmask": bitmask,
"row_offsets": row_offsets
})
return decompressed


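For merged layers (QKV / merged-column), the flattened compressed buffer concatenates the shards, so it is split per logical width before each shard is decompressed and the dense results are recombined with `combine_shards`. A small hypothetical example of that split, with assumed sizes:

```python
import torch

logical_widths = [8, 8, 16]        # assumed q/k/v output widths on this rank
input_size_per_partition = 4       # assumed input width on this rank

compressed = torch.empty(
    sum(w * input_size_per_partition for w in logical_widths) // 2, 1)

# one compressed chunk per shard: each row of a shard contributes cols // 2 values
split_sizes = [w * input_size_per_partition // 2 for w in logical_widths]
splits = torch.split(compressed, split_sizes)
shapes = [(w, input_size_per_partition) for w in logical_widths]
# each (chunk, shape, bitmask slice, row_offsets slice) is decompressed separately,
# then the dense shards are stacked back together with combine_shards.
```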
41 changes: 38 additions & 3 deletions vllm/model_executor/parameter.py
@@ -10,7 +10,8 @@
__all__ = [
"BasevLLMParameter", "PackedvLLMParameter", "PerTensorScaleParameter",
"ModelWeightParameter", "ChannelQuantScaleParameter",
"GroupQuantScaleParameter", "PackedColumnParameter", "RowvLLMParameter", "BitMaskShapeParameter"
"GroupQuantScaleParameter", "PackedColumnParameter", "RowvLLMParameter", "BitMaskShapeParameter",
"BitMaskParameter"
]

logger = init_logger(__name__)
@@ -134,6 +135,7 @@ def load_qkv_weight(self, loaded_weight: torch.Tensor, **kwargs):
param_data.copy_(loaded_weight)



class RowvLLMParameter(BasevLLMParameter):
"""
Parameter class defining weight_loading functionality
@@ -153,8 +155,8 @@ def input_dim(self):
def load_row_parallel_weight(self, loaded_weight: torch.Tensor):
tp_rank = get_tensor_model_parallel_rank()
shard_size = self.data.shape[self.input_dim]
loaded_weight = loaded_weight.narrow(self.input_dim,
tp_rank * shard_size, shard_size)
loaded_weight = loaded_weight.narrow(self.input_dim,
tp_rank * shard_size, shard_size)

if len(loaded_weight.shape) == 0:
loaded_weight = loaded_weight.reshape(1)
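For context, `load_row_parallel_weight` selects this rank's slice of the full checkpoint tensor with `narrow`; a minimal sketch of that selection, with an assumed tensor-parallel size and rank:

```python
import torch

full = torch.arange(16.).reshape(2, 8)   # full checkpoint weight, input dim = 8
tp_size, tp_rank = 2, 1                  # assumed 2-way tensor parallelism, this rank = 1
shard_size = full.shape[1] // tp_size    # 4 input columns per rank
shard = full.narrow(1, tp_rank * shard_size, shard_size)
# rank 1 receives columns 4..7 of the full weight
```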
@@ -286,6 +288,39 @@ def adjust_shard_indexes_for_packing(self, shard_size, shard_offset):
packed_factor=self.packed_factor,
marlin_tile_size=self.marlin_tile_size)

class BitMaskParameter(_ColumnvLLMParameter):
def __init__(self, output_dim: int, input_size_per_partition: int, **kwargs):
self.compressed_dim = output_dim
self.input_size_per_partition = input_size_per_partition
super().__init__(**kwargs, output_dim=output_dim)

def load_row_parallel_weight(self, loaded_weight: torch.Tensor):
tp_rank = get_tensor_model_parallel_rank()
shard_size = self.data.shape[self.compressed_dim]

loaded_weight = loaded_weight.narrow(self.compressed_dim,
tp_rank * shard_size, shard_size)

assert self.data.shape == loaded_weight.shape
self.data.copy_(loaded_weight)

def load_qkv_weight(self, loaded_weight: torch.Tensor, **kwargs):
shard_offset = kwargs.pop("shard_offset")
shard_size = kwargs.pop("shard_size")

shard_size = (shard_size * self.input_size_per_partition) // 2
shard_offset = (shard_offset * self.input_size_per_partition) // 2

super().load_qkv_weight(**kwargs, loaded_weight=loaded_weight, shard_size=shard_size, shard_offset=shard_offset)

def load_merged_column_weight(self, loaded_weight: torch.Tensor, **kwargs):
shard_offset = kwargs.pop("shard_offset")
shard_size = kwargs.pop("shard_size")

shard_size = (shard_size * self.input_size_per_partition) // 2
shard_offset = (shard_offset * self.input_size_per_partition) // 2

super().load_merged_column_weight(**kwargs, loaded_weight=loaded_weight, shard_size=shard_size, shard_offset=shard_offset)

class PackedvLLMParameter(ModelWeightParameter):
"""
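A note on the new `BitMaskParameter`: the weight loader hands it `shard_size` and `shard_offset` measured in output rows, but its data is the flattened compressed buffer of shape `(rows * input_size_per_partition // 2, 1)`, so both values are rescaled into compressed-element counts before the parent class slices the buffer. A sketch of that arithmetic, with assumed numbers:

```python
# Assumed, illustrative numbers; only the rescaling mirrors BitMaskParameter.
input_size_per_partition = 4096
shard_offset_rows, shard_size_rows = 1024, 1024   # e.g. the "k" shard of a QKV weight

elems_per_row = input_size_per_partition // 2     # 2:4 sparsity keeps half the values
shard_offset = shard_offset_rows * elems_per_row  # offset into the flat compressed buffer
shard_size = shard_size_rows * elems_per_row      # length of this shard in the flat buffer
# these scaled values are what the parent class then uses to index self.data
```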