add support for all cases; update tests
dsikka committed Dec 6, 2024
1 parent 7a6d027 commit 0987c98
Showing 3 changed files with 34 additions and 21 deletions.
16 changes: 12 additions & 4 deletions tests/quantization/test_compressed_tensors.py
@@ -179,9 +179,17 @@ def test_compressed_tensors_kv_cache(vllm_runner):
output = llm.generate_greedy("Hello world!", max_tokens=20)
assert output

-@pytest.mark.parametrize(
-    "args_2of4",
-    [("nm-testing/Meta-Llama-3-8B-Instruct-FP8-Dynamic-2of4-testing", "channel", "token")])
+@pytest.mark.parametrize("args_2of4", [
+    ("nm-testing/Meta-Llama-3-8B-Instruct-FP8-Dynamic-2of4-testing", "channel",
+     "token"),
+    ("nm-testing/Meta-Llama-3-8B-Instruct-FP8-Static-Per-Tensor-testing",
+     "channel", "tensor"),
+    ("nm-testing/Meta-Llama-3-8B-Instruct-FP8-Static-testing", "tensor",
+     "tensor"),
+    ("nm-testing/Meta-Llama-3-8B-Instruct-FP8-Dynamic-IA-Per-Tensor-Weight-testing",
+     "tensor", "token")
+])
def test_compressed_tensors_2of4(vllm_runner, args_2of4):
model, weight_strategy, input_strategy = args_2of4
with vllm_runner(model) as llm:
@@ -202,4 +210,4 @@ def test_compressed_tensors_2of4(vllm_runner, args_2of4):

output = llm.generate_greedy("Hello my name is", max_tokens=20)
print(output)
assert output
assert output
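
As a point of reference, here is a minimal sketch of how one of the newly covered checkpoints could be loaded outside the test suite, using vLLM's public LLM API instead of the vllm_runner fixture (the model name comes from the parametrization above; the rest is illustrative and not part of this commit):

from vllm import LLM, SamplingParams

# Illustrative usage only: load a 2:4 sparse FP8 checkpoint from the new test
# matrix and run a short greedy generation, mirroring generate_greedy above.
llm = LLM(model="nm-testing/Meta-Llama-3-8B-Instruct-FP8-Static-Per-Tensor-testing")
outputs = llm.generate(["Hello my name is"],
                       SamplingParams(temperature=0.0, max_tokens=20))
print(outputs[0].outputs[0].text)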
@@ -388,13 +388,12 @@ def get_scheme(
if self.supports_cutlass_24(weight_quant=weight_quant,
input_quant=input_quant,
sparsity_scheme=sparsity_scheme):
# Have a valid sparsity scheme
+# Validate layer is supported by Cutlass 2:4 Kernel
-scheme = CompressedTensors24(
-    quantized=weight_quant is not None or input_quant is not None,
-    weight_quant=weight_quant,
-    input_quant=input_quant
-)
+scheme = CompressedTensors24(quantized=weight_quant is not None
+                             or input_quant is not None,
+                             weight_quant=weight_quant,
+                             input_quant=input_quant)
else:
# Find the quant_scheme
scheme = self._get_scheme_from_parts(
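The branch above routes any scheme that passes supports_cutlass_24 to CompressedTensors24; everything else falls through to _get_scheme_from_parts. As a rough illustration only (not vLLM's actual implementation), a check of that shape could validate the same weight/input strategy pairings the new test matrix exercises:

from compressed_tensors.quantization import QuantizationStrategy

def supports_cutlass_24_sketch(weight_quant, input_quant) -> bool:
    # Unquantized 2:4 sparsity needs no further checks.
    if weight_quant is None and input_quant is None:
        return True
    # Quantized 2:4 requires both weight and activation quantization args.
    if weight_quant is None or input_quant is None:
        return False
    # Weights: per-channel or per-tensor; activations: per-token (dynamic)
    # or per-tensor (static) -- the four cases parametrized in the tests above.
    weight_ok = weight_quant.strategy in (QuantizationStrategy.CHANNEL.value,
                                          QuantizationStrategy.TENSOR.value)
    input_ok = input_quant.strategy in (QuantizationStrategy.TOKEN.value,
                                        QuantizationStrategy.TENSOR.value)
    return weight_ok and input_ok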
@@ -4,20 +4,20 @@
from compressed_tensors.quantization import QuantizationType, QuantizationStrategy
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
CompressedTensorsScheme)
-from vllm.model_executor.parameter import ModelWeightParameter, ChannelQuantScaleParameter, PerTensorScaleParameter
+from vllm.model_executor.parameter import ModelWeightParameter, ChannelQuantScaleParameter, PerTensorScaleParameter, BasevLLMParameter
from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
convert_to_channelwise)

__all__ = ["CompressedTensors24"]


class CompressedTensors24(CompressedTensorsScheme):

-def __init__(
-    self,
-    quantized: bool = False,
-    weight_quant=None,
-    input_quant=None
-):
+def __init__(self,
+             quantized: bool = False,
+             weight_quant=None,
+             input_quant=None):
self.quantized = quantized
self.weight_quant = weight_quant
self.input_quant = input_quant
@@ -33,6 +33,7 @@ def create_weights(self, layer: torch.nn.Module, input_size: int,
**kwargs):

self.output_dtype = params_dtype
+layer.logical_widths = output_partition_sizes
weights_dtype: torch.dtype = self._get_params_dtype(params_dtype)

# parameter to store uncompressed weight
@@ -51,7 +51,7 @@ def create_weights(self, layer: torch.nn.Module, input_size: int,
output_dim=0,
weight_loader=weight_loader)
else:
-assert self.weight_quant.strategy == QuantizationStrategy.TOKEN.value
+assert self.weight_quant.strategy == QuantizationStrategy.TENSOR.value
weight_scale = PerTensorScaleParameter(data=torch.empty(
len(output_partition_sizes), dtype=torch.float32),
weight_loader=weight_loader)
@@ -62,9 +62,9 @@ def create_weights(self, layer: torch.nn.Module, input_size: int,
if not self.input_quant.dynamic:
# register input quant scale
assert self.input_quant.strategy == QuantizationStrategy.TENSOR.value
-input_scale = PerTensorScaleParameter(data=torch.empty(
-    len(output_partition_sizes), dtype=torch.float32),
-    weight_loader=weight_loader)
+input_scale = BasevLLMParameter(data=torch.empty(
+    1, dtype=torch.float32),
+    weight_loader=weight_loader)

layer.register_parameter("input_scale", input_scale)
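
For the static-activation case, the scale registered above is a single scalar shared by all fused partitions. As a rough, eager-mode illustration (assumed helper name; vLLM's real FP8 path runs in fused CUDA kernels), a static per-tensor scale is applied like this:

import torch

def quantize_fp8_static_sketch(x: torch.Tensor,
                               input_scale: torch.Tensor) -> torch.Tensor:
    # Divide by the pre-computed scale, clamp to the FP8 e4m3 range, and cast.
    finfo = torch.finfo(torch.float8_e4m3fn)
    return (x / input_scale).clamp(min=finfo.min,
                                   max=finfo.max).to(torch.float8_e4m3fn)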

@@ -81,6 +81,11 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
:param layer: The layer with the weights to be processed
"""
if self.weight_quant.strategy == QuantizationStrategy.TENSOR.value:
+layer.weight_scale = torch.nn.Parameter(convert_to_channelwise(
+    weight_scale=layer.weight_scale,
+    logical_widths=layer.logical_widths),
+    requires_grad=False)
w_compressed, meta = ops.cutlass_compress_entry(layer.weight.data)
layer.w_compressed = torch.nn.Parameter(w_compressed,
requires_grad=False)
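The added per-tensor branch above converts the loaded per-tensor weight scales to a channelwise layout (using the logical_widths recorded in create_weights) before handing the weight to cutlass_compress_entry, presumably so later kernels see a single scale format. A rough equivalent of that conversion (illustrative; the real helper lives in w8a8_utils):

import torch

def convert_to_channelwise_sketch(weight_scale: torch.Tensor,
                                  logical_widths: list) -> torch.Tensor:
    # Expand one scale per fused logical weight into one scale per output
    # channel by repeating each scale across its partition's rows.
    return torch.repeat_interleave(
        weight_scale, torch.tensor(logical_widths)).unsqueeze(-1)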
