
Commit

update

dsikka committed Dec 12, 2024
1 parent d0475c4 commit 153a8d6
Showing 2 changed files with 102 additions and 125 deletions.
@@ -7,6 +7,7 @@
QuantizationType)
from compressed_tensors.utils import combine_shards


from vllm import _custom_ops as ops
from vllm.model_executor.layers.linear import MergedColumnParallelLinear, QKVParallelLinear
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
@@ -18,7 +19,7 @@
ModelWeightParameter,
PerTensorScaleParameter,
_ColumnvLLMParameter as ColumnvLLMParameter,
BitMaskShapeParameter,)
BitMaskShapeParameter, BitMaskParameter)

__all__ = ["CompressedTensors24"]

@@ -51,14 +52,14 @@ def create_weights(self, layer: torch.nn.Module, input_size: int,
input_size_per_partition: int,
params_dtype: torch.dtype, weight_loader: Callable,
**kwargs):
print("Creating weights for:", self.layer_name)
self.output_dtype = params_dtype


# This is needed for tensor scale
layer.logical_widths = output_partition_sizes
layer.input_size_per_partition = input_size_per_partition
layer.input_size = input_size
self.weights_dtype: torch.dtype = self._get_params_dtype(params_dtype)

# parameter to store uncompressed weight
weight = ModelWeightParameter(data=torch.empty(
sum(output_partition_sizes),
@@ -67,79 +68,46 @@ def create_weights(self, layer: torch.nn.Module, input_size: int,
input_dim=1,
output_dim=0,
weight_loader=weight_loader)

shape = BitMaskShapeParameter(data=torch.empty(2 * len(output_partition_sizes), 1, dtype=torch.uint64),
weight_loader=weight_loader)

compressed = BitMaskParameter(data=torch.empty(
sum(out * input_size_per_partition for out in output_partition_sizes) // 2, 1,
dtype=self.weights_dtype),
output_dim=0,
input_size_per_partition=input_size_per_partition,
weight_loader=weight_loader)

bitmask = ModelWeightParameter(data=torch.empty(
sum(output_partition_sizes),
input_size_per_partition,
dtype=torch.uint8),
input_dim=1,
output_dim=0,
weight_loader=weight_loader)

if self.do_sparse_decompress:
sparsity_config = self.model_compressor.sparsity_config
if sparsity_config is not None and sparsity_config.format != CompressionFormat.sparse_bitmask.value:
raise ValueError("CompressedTensors24 only supports sparse_bitmask compression format")

# register compression specific parameters

shape = BitMaskShapeParameter(data=torch.empty(2 * len(output_partition_sizes), 1, dtype=torch.uint64),
weight_loader=weight_loader)
compressed = ColumnvLLMParameter(data=torch.empty(
sum(output_partition_size * input_size_per_partition for output_partition_size in output_partition_sizes) // 2, 1,
dtype=self.weights_dtype),
output_dim=0,
weight_loader=weight_loader)
bitmask = ModelWeightParameter(data=torch.empty(
sum(output_partition_sizes),
input_size_per_partition // 8,
dtype=torch.uint8),
input_dim=1,
output_dim=0,
weight_loader=weight_loader)
row_offsets = ColumnvLLMParameter(data=torch.empty(
sum(output_partition_sizes), 1, dtype=torch.uint64),
output_dim=0,
weight_loader=weight_loader,
)

layer.register_parameter("shape", shape)
layer.register_parameter("compressed", compressed)
layer.register_parameter("bitmask", bitmask)
layer.register_parameter("row_offsets", row_offsets)

# Check if quantized, not just 2:4 Sparse
if self.quantized:
if (self.weight_quant and self.weight_quant.strategy
== QuantizationStrategy.CHANNEL.value):
weight_scale = ChannelQuantScaleParameter(
data=torch.empty((sum(output_partition_sizes), 1),
dtype=torch.float32),
output_dim=0,
weight_loader=weight_loader)
else:
assert (self.weight_quant and self.weight_quant.strategy
== QuantizationStrategy.TENSOR.value)
weight_scale = PerTensorScaleParameter(
data=torch.empty(len(output_partition_sizes),
dtype=torch.float32),
weight_loader=weight_loader)

layer.register_parameter("weight_scale", weight_scale)

# input quant will be non-none
if self.input_quant and not self.input_quant.dynamic:
# register input quant scale
assert (self.input_quant.strategy ==
QuantizationStrategy.TENSOR.value)
input_scale = BasevLLMParameter(data=torch.empty(
1, dtype=torch.float32),
weight_loader=weight_loader)

layer.register_parameter("input_scale", input_scale)

else:
# for sparse-only, pass in 1 for weight/input scales
weight_scale = torch.nn.Parameter(data=torch.ones(
1, dtype=torch.float32),
requires_grad=False)
input_scale = torch.nn.Parameter(data=torch.ones(
1, dtype=torch.float32),
requires_grad=False)
layer.register_parameter("input_scale", input_scale)
layer.register_parameter("weight_scale", weight_scale)
# for sparse-only, pass in 1 for weight/input scales
weight_scale = torch.nn.Parameter(data=torch.ones(
1, dtype=torch.float32),
requires_grad=False)
input_scale = torch.nn.Parameter(data=torch.ones(
1, dtype=torch.float32),
requires_grad=False)
layer.register_parameter("input_scale", input_scale)
layer.register_parameter("weight_scale", weight_scale)

layer.register_parameter("weight", weight)
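For orientation, here is a small hypothetical sketch (not part of this commit) of how the sizes of the compression parameters registered above relate to a 2:4-sparse weight. The partition sizes below are made up, and the 2:4 ratio and 8-positions-per-byte bitmask packing are assumptions:

# Hypothetical sizes for the bitmask-compression tensors of a 2:4-sparse layer.
output_partition_sizes = [4096, 4096, 1024]  # e.g. fused q/k/v partitions (made up)
input_size_per_partition = 4096

# "shape": one (rows, cols) pair per partition -> 2 entries each.
shape_numel = 2 * len(output_partition_sizes)

# "compressed": 2:4 sparsity keeps half of the values, flattened across partitions.
compressed_numel = sum(out * input_size_per_partition for out in output_partition_sizes) // 2

# "bitmask": one bit per element, packed 8 per uint8 byte.
bitmask_numel = sum(output_partition_sizes) * (input_size_per_partition // 8)

# "row_offsets": one offset per output row.
row_offsets_numel = sum(output_partition_sizes)

These counts mirror the torch.empty shapes passed to BitMaskShapeParameter, BitMaskParameter, the bitmask ModelWeightParameter, and the row_offsets ColumnvLLMParameter above.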

@@ -154,27 +122,8 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
:param layer: The layer with the weights to be processed
"""
print("Processing weights for:", self.layer_name)

if self.do_sparse_decompress:
layer.weight.data = self._decompress_bitmask_compressed_weight(
layer.compressed, layer.shape, layer.bitmask, layer.row_offsets, layer=layer)

# torch.compile workaround
if hasattr(layer, "input_scale"):
layer.input_scale = torch.nn.Parameter(layer.input_scale.data,
requires_grad=False)

if self.weight_quant:
if self.weight_quant.strategy == QuantizationStrategy.TENSOR.value:
layer.weight_scale = torch.nn.Parameter(convert_to_channelwise(
weight_scale=layer.weight_scale,
logical_widths=layer.logical_widths),
requires_grad=False)
else:
# torch.compile workaround
layer.weight_scale = torch.nn.Parameter(
layer.weight_scale.data, requires_grad=False)
layer.weight.data = self._decompress_bitmask_compressed_weight(
layer.compressed, layer.shape, layer.bitmask, layer.row_offsets, layer=layer)

w_compressed, meta = ops.cutlass_compress_entry(layer.weight.data)
layer.weight = torch.nn.Parameter(w_compressed, requires_grad=False)
@@ -195,27 +144,9 @@ def apply_weights(self,
:param bias: The bias to be added to the output tensor
:return: The output tensor of the layer
"""
if self.quantized:
scale = None
if hasattr(layer, "input_scale"):
scale = layer.input_scale

if self.weights_dtype == torch.int8:
ops_output = ops.scaled_int8_quant(x, scale=scale)
q_input = ops_output[0]
input_scale = ops_output[1]
else:
assert self.weights_dtype == torch.float8_e4m3fn
if scale is not None:
q_input, input_scale = ops.scaled_fp8_quant(x, scale=scale)
else:
q_input, input_scale = ops.scaled_fp8_quant(
x, use_per_token_if_dynamic=True)

else:
# Not quantized, nothing to do with the input_scales, use as is
input_scale = layer.input_scale
q_input = x
# Not quantized, nothing to do with the input_scales, use as is
input_scale = layer.input_scale
q_input = x

out = ops.cutlass_scaled_sparse_mm(a=layer.weight,
e=layer.meta,
@@ -250,7 +181,6 @@ def _get_params_dtype(self, params_dtype: torch.dtype) -> torch.dtype:
raise ValueError("Quantization type not supported by Cutlass")

def _decompress_bitmask_compressed_weight(self, compressed, shape, bitmask, row_offsets, layer):
print("Decompressing weights for:", self.layer_name)
split_weights = None
def _process_split(bitmask_compressed_weight, shape, bitmask, row_offsets):
weight_data = {
@@ -261,24 +191,36 @@ def _process_split(bitmask_compressed_weight, shape, bitmask, row_offsets):
}
decompress = self.model_compressor.sparsity_compressor.decompress_weight(weight_data)
return decompress

if isinstance(layer, (QKVParallelLinear, MergedColumnParallelLinear)):
split_weights = torch.split(compressed, [partition_width * layer.input_size_per_partition // 2 for partition_width in layer.logical_widths])
split_bitmask = torch.split(bitmask, layer.logical_widths)
split_row_offsets = torch.split(row_offsets, layer.logical_widths)
split_shape = torch.split(shape, [2] * len(layer.logical_widths))
if split_weights:
split_shape = [(out, layer.input_size_per_partition) for out in layer.logical_widths]

"""
print(type(layer), layer.input_size, layer.input_size_per_partition, compressed.shape, bitmask.shape, row_offsets.shape)
print([x.shape for x in split_weights])
print([x.shape for x in split_bitmask])
print([x.shape for x in split_row_offsets])
print([x for x in split_shape])
print("\n")
"""

if isinstance(layer, (QKVParallelLinear, MergedColumnParallelLinear)):
all_compress = []
for i in range(len(split_weights)):
compress_i = _process_split(split_weights[i], split_shape[i], split_bitmask[i], split_row_offsets[i])
all_compress.append(compress_i)
decompressed = combine_shards(all_compress)
else:
print(type(layer), layer.input_size, layer.input_size_per_partition, compressed.shape, bitmask.shape, row_offsets.shape)
decompressed = self.model_compressor.sparsity_compressor.decompress_weight({
"compressed": compressed,
"shape": shape,
"bitmask": bitmask,
"row_offsets": row_offsets
})
"compressed": compressed,
"shape": (layer.logical_widths[0], layer.input_size_per_partition),
"bitmask": bitmask,
"row_offsets": row_offsets
})
return decompressed
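As a rough illustration of what _decompress_bitmask_compressed_weight produces for a single shard, here is a hedged, hypothetical sketch of bitmask decompression; it is not the compressed-tensors implementation, and the LSB-first bit order within each byte is an assumption:

import torch

def decompress_bitmask_sketch(compressed: torch.Tensor,
                              bitmask: torch.Tensor,
                              shape: tuple) -> torch.Tensor:
    # Hypothetical reference only: rebuild a dense (rows, cols) weight from
    # flattened kept values plus a packed uint8 bitmask (1 bit per element).
    rows, cols = shape
    # Unpack 8 positions per byte into a boolean mask of shape (rows, cols).
    bits = ((bitmask.to(torch.int32).unsqueeze(-1)
             >> torch.arange(8, dtype=torch.int32)) & 1).bool()
    mask = bits.reshape(rows, -1)[:, :cols]
    dense = torch.zeros(rows, cols, dtype=compressed.dtype)
    # Scatter the kept values back into their original positions (row-major).
    dense[mask] = compressed.flatten()[: int(mask.sum())]
    return dense

The actual path in this diff delegates each shard to self.model_compressor.sparsity_compressor.decompress_weight(...) and, for fused layers, recombines the per-shard results with combine_shards.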


41 changes: 38 additions & 3 deletions vllm/model_executor/parameter.py
@@ -10,7 +10,8 @@
__all__ = [
"BasevLLMParameter", "PackedvLLMParameter", "PerTensorScaleParameter",
"ModelWeightParameter", "ChannelQuantScaleParameter",
"GroupQuantScaleParameter", "PackedColumnParameter", "RowvLLMParameter", "BitMaskShapeParameter"
"GroupQuantScaleParameter", "PackedColumnParameter", "RowvLLMParameter", "BitMaskShapeParameter",
"BitMaskParameter"
]

logger = init_logger(__name__)
@@ -134,6 +135,7 @@ def load_qkv_weight(self, loaded_weight: torch.Tensor, **kwargs):
param_data.copy_(loaded_weight)



class RowvLLMParameter(BasevLLMParameter):
"""
Parameter class defining weight_loading functionality
@@ -153,8 +155,8 @@ def input_dim(self):
def load_row_parallel_weight(self, loaded_weight: torch.Tensor):
tp_rank = get_tensor_model_parallel_rank()
shard_size = self.data.shape[self.input_dim]
loaded_weight = loaded_weight.narrow(self.input_dim,
tp_rank * shard_size, shard_size)

if len(loaded_weight.shape) == 0:
loaded_weight = loaded_weight.reshape(1)
Expand Down Expand Up @@ -286,6 +288,39 @@ def adjust_shard_indexes_for_packing(self, shard_size, shard_offset):
packed_factor=self.packed_factor,
marlin_tile_size=self.marlin_tile_size)

class BitMaskParameter(_ColumnvLLMParameter):
def __init__(self, output_dim: int, input_size_per_partition: int, **kwargs):
self.compressed_dim = output_dim
self.input_size_per_partition = input_size_per_partition
super().__init__(**kwargs, output_dim=output_dim)

def load_row_parallel_weight(self, loaded_weight: torch.Tensor):
tp_rank = get_tensor_model_parallel_rank()
shard_size = self.data.shape[self.compressed_dim]

loaded_weight = loaded_weight.narrow(self.compressed_dim,
tp_rank * shard_size, shard_size)

assert self.data.shape == loaded_weight.shape
self.data.copy_(loaded_weight)

def load_qkv_weight(self, loaded_weight: torch.Tensor, **kwargs):
shard_offset = kwargs.pop("shard_offset")
shard_size = kwargs.pop("shard_size")

shard_size = (shard_size * self.input_size_per_partition) // 2
shard_offset = (shard_offset * self.input_size_per_partition) // 2

super().load_qkv_weight(**kwargs, loaded_weight=loaded_weight, shard_size=shard_size, shard_offset=shard_offset)

def load_merged_column_weight(self, loaded_weight: torch.Tensor, **kwargs):
shard_offset = kwargs.pop("shard_offset")
shard_size = kwargs.pop("shard_size")

shard_size = (shard_size * self.input_size_per_partition) // 2
shard_offset = (shard_offset * self.input_size_per_partition) // 2

super().load_merged_column_weight(**kwargs, loaded_weight=loaded_weight, shard_size=shard_size, shard_offset=shard_offset)
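To make the shard rescaling in BitMaskParameter concrete, a small hypothetical example (the row counts below are made up): the incoming shard_size/shard_offset are expressed in output rows, while the compressed tensor is a flat list of kept values with input_size_per_partition columns per row and only half of them stored under 2:4 sparsity, hence the "* input_size_per_partition // 2" conversion:

# Hypothetical numbers illustrating BitMaskParameter's shard rescaling.
input_size_per_partition = 4096
shard_size_rows = 1024     # rows of the shard being loaded (made up)
shard_offset_rows = 2048   # row offset of that shard (made up)

# Each row contributes input_size_per_partition // 2 compressed values,
# so row counts become flat element counts in the compressed tensor.
shard_size_elems = (shard_size_rows * input_size_per_partition) // 2      # 2097152
shard_offset_elems = (shard_offset_rows * input_size_per_partition) // 2  # 4194304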

class PackedvLLMParameter(ModelWeightParameter):
"""
