From 4e060dfc6fd4dd7b52a9b10ec93a0797af7b37b6 Mon Sep 17 00:00:00 2001 From: Dipika Date: Fri, 6 Dec 2024 17:05:30 +0000 Subject: [PATCH 1/3] remove compressed support; validate against ct models for tp=1,24 --- .../compressed_tensors/compressed_tensors.py | 114 +++++----- .../schemes/compressed_tensors_24.py | 209 +++++------------- 2 files changed, 121 insertions(+), 202 deletions(-) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py index 8146e08912f69..b8337fd250a8e 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py @@ -31,16 +31,20 @@ SPARSITY_CONFIG_NAME: Literal["sparsity_config"] = "sparsity_config" QUANTIZATION_SCHEME_MAP_TYPE = Dict[str, Optional[Dict[str, QuantizationArgs]]] + + class CompressedTensorsConfig(QuantizationConfig): - def __init__(self, - target_scheme_map: Dict[str, Any], - ignore: List[str], - quant_format: str, - kv_cache_scheme: Optional[Dict[str, Any]] = None, - sparsity_scheme_map: Optional[Dict[str, SparsityCompressionConfig]] = None, - config: Optional[Dict[str, Any]] = None, - ): + def __init__( + self, + target_scheme_map: Dict[str, Any], + ignore: List[str], + quant_format: str, + kv_cache_scheme: Optional[Dict[str, Any]] = None, + sparsity_scheme_map: Optional[Dict[str, + SparsityCompressionConfig]] = None, + config: Optional[Dict[str, Any]] = None, + ): self.ignore = ignore self.quant_format = quant_format @@ -92,8 +96,10 @@ def get_quant_method( def from_config(cls, config: Dict[str, Any]) -> "CompressedTensorsConfig": ignore: List[str] = cast(List[str], config.get("ignore", [])) quant_format = cast(str, config.get("format")) - target_scheme_map = cls._quantization_scheme_map_from_config(config=config) - sparsity_scheme_map = cls._sparsity_scheme_map_from_config(config=config) + target_scheme_map = cls._quantization_scheme_map_from_config( + config=config) + sparsity_scheme_map = cls._sparsity_scheme_map_from_config( + config=config) return cls( target_scheme_map=target_scheme_map, @@ -102,18 +108,21 @@ def from_config(cls, config: Dict[str, Any]) -> "CompressedTensorsConfig": sparsity_scheme_map=sparsity_scheme_map, config=config, ) - + @classmethod - def _sparsity_scheme_map_from_config(cls, config: Dict[str, Any]) -> Dict[str, SparsityCompressionConfig]: + def _sparsity_scheme_map_from_config( + cls, config: Dict[str, + Any]) -> Dict[str, SparsityCompressionConfig]: """ :param config: The `quantization_config` dictionary from config.json :return: A dictionary mapping target layer names to their corresponding sparsity compression configurations """ - if (sparsity_config:=config.get(SPARSITY_CONFIG_NAME)) is None: + if (sparsity_config := config.get(SPARSITY_CONFIG_NAME)) is None: return dict() - - sparsity_config = SparsityCompressionConfig.model_validate(sparsity_config) + + sparsity_config = SparsityCompressionConfig.model_validate( + sparsity_config) sparse_scheme_map: Dict[str, SparsityCompressionConfig] = { target: sparsity_config for target in sparsity_config.targets or list() @@ -121,7 +130,8 @@ def _sparsity_scheme_map_from_config(cls, config: Dict[str, Any]) -> Dict[str, S return sparse_scheme_map @classmethod - def _quantization_scheme_map_from_config(cls, config: Dict[str, Any]) -> QUANTIZATION_SCHEME_MAP_TYPE: + def _quantization_scheme_map_from_config( + cls, config: Dict[str, Any]) -> 
QUANTIZATION_SCHEME_MAP_TYPE: """ :param config: The `quantization_config` dictionary from config.json :return: A dictionary mapping target layer names to their corresponding @@ -144,7 +154,8 @@ def _quantization_scheme_map_from_config(cls, config: Dict[str, Any]) -> QUANTIZ targets = quant_config.get("targets") for target in targets: target_scheme_map[target] = {} - target_scheme_map[target]["weights"] = QuantizationArgs.model_validate( + target_scheme_map[target][ + "weights"] = QuantizationArgs.model_validate( quant_config.get("weights")) target_scheme_map[target]["input_activations"] = None @@ -158,7 +169,8 @@ def _quantization_scheme_map_from_config(cls, config: Dict[str, Any]) -> QUANTIZ assert target_scheme_map[target][ "weights"].type == QuantizationType.FLOAT else: - target_scheme_map[target]["input_activations"] = QuantizationArgs.model_validate( + target_scheme_map[target][ + "input_activations"] = QuantizationArgs.model_validate( quant_config.get("input_activations")) return target_scheme_map @@ -359,7 +371,7 @@ def get_scheme( # TODO (@robertgshaw): add compressed-tensors as dep # so we do not have to re-write these functions # need to make accelerate optional in ct to do this - + matched_target = find_matched_target( layer_name=layer_name, module=layer, @@ -369,42 +381,38 @@ def get_scheme( weight_quant = scheme_dict.get("weights") input_quant = scheme_dict.get("input_activations") - sparsity_scheme: Optional[SparsityCompressionConfig] = self.sparsity_scheme_map.get(matched_target) + sparsity_scheme: Optional[ + SparsityCompressionConfig] = self.sparsity_scheme_map.get( + matched_target) - if self.supports_cutlass_24( - weight_quant=weight_quant, - input_quant=input_quant, - sparsity_scheme=sparsity_scheme - ): - # Have a valid sparsity scheme and the layer is supported by the Cutlass 2:4 Kernel - needs_decompression = sparsity_scheme.format != CompressionFormat.dense.value - is_quantized = weight_quant is not None or input_quant is not None + if self.supports_cutlass_24(weight_quant=weight_quant, + input_quant=input_quant, + sparsity_scheme=sparsity_scheme): + # Have a valid sparsity scheme + # Validate layer is supported by Cutlass 2:4 Kernel scheme = CompressedTensors24( - layer_name=layer_name, - quantized=is_quantized, - do_decompress=needs_decompression, + quantized=weight_quant is not None or input_quant is not None, weight_quant=weight_quant, - input_quant=input_quant, - config=self.config, + input_quant=input_quant ) else: - # Find the quant_scheme + # Find the quant_scheme scheme = self._get_scheme_from_parts( weight_quant=weight_quant, input_quant=input_quant, - ) + ) # Raise error if device does not support the scheme # (e.g. 
fp8 needs ada lovelace) self._check_scheme_supported(scheme.get_min_capability()) return scheme - + @staticmethod def supports_cutlass_24( - weight_quant: Optional[QuantizationArgs], - input_quant: Optional[QuantizationArgs], - sparsity_scheme: Optional[SparsityCompressionConfig]=None - ) -> bool: + weight_quant: Optional[QuantizationArgs], + input_quant: Optional[QuantizationArgs], + sparsity_scheme: Optional[SparsityCompressionConfig] = None + ) -> bool: """ Check if the layer is supported by the Cutlass 2:4 Kernel Conditions: @@ -418,21 +426,21 @@ def supports_cutlass_24( :return: True if the layer is supported by the Cutlass 2:4 Kernel False otherwise """ - - if ( - sparsity_scheme is None or - sparsity_scheme.sparsity_structure != SparsityStructure.TWO_FOUR.value - ): + is_valid_sparsity = (sparsity_scheme is not None + and sparsity_scheme.sparsity_structure + == SparsityStructure.TWO_FOUR.value + and sparsity_scheme.format == "dense") + if not is_valid_sparsity: return False - + # Unquantized cases are supported if weight_quant is None and input_quant is None: return True - + # Weight only quantization is not-supported if weight_quant is not None and input_quant is None: return False - + supported_weight_quant_strategies = [ QuantizationStrategy.TENSOR.value, QuantizationStrategy.CHANNEL.value @@ -440,17 +448,15 @@ def supports_cutlass_24( if weight_quant.strategy not in supported_weight_quant_strategies: return False - + supported_input_quant_strategies = [ - QuantizationStrategy.TENSOR.value, - QuantizationStrategy.TOKEN.value + QuantizationStrategy.TENSOR.value, QuantizationStrategy.TOKEN.value ] - + if input_quant.strategy not in supported_input_quant_strategies: return False - - return weight_quant.num_bits == input_quant.num_bits == 8 + return weight_quant.num_bits == input_quant.num_bits == 8 class CompressedTensorsLinearMethod(LinearMethodBase): diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_24.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_24.py index 8607464bd9dc0..4b23d9c05582a 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_24.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_24.py @@ -1,128 +1,79 @@ from typing import Any, Dict, List, Callable, Optional import torch -from compressed_tensors.compressors import ModelCompressor from compressed_tensors.quantization import QuantizationType, QuantizationStrategy from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( CompressedTensorsScheme) from vllm.model_executor.parameter import ModelWeightParameter, ChannelQuantScaleParameter, PerTensorScaleParameter from vllm import _custom_ops as ops -from vllm.model_executor.layers.linear import ( - MergedColumnParallelLinear, - QKVParallelLinear, - ) __all__ = ["CompressedTensors24"] + class CompressedTensors24(CompressedTensorsScheme): + def __init__( - self, - layer_name: Optional[str] = None, - quantized: bool = False, - do_decompress: bool = False, - weight_quant = None, - input_quant = None, - config: Optional[Dict[str, Any]] = None, - ): - self.layer_name = layer_name + self, + quantized: bool = False, + weight_quant=None, + input_quant=None + ): self.quantized = quantized - self.do_decompress = do_decompress self.weight_quant = weight_quant self.input_quant = input_quant - self.model_compressor = ( - 
ModelCompressor.from_compression_config(compression_config=config) - if self.do_decompress and config is not None - else None - ) - @classmethod def get_min_capability(cls) -> int: return 90 def create_weights(self, layer: torch.nn.Module, input_size: int, - output_partition_sizes: List[int], - input_size_per_partition: int, - params_dtype: torch.dtype, weight_loader: Callable, - **kwargs): - layer.logical_widths = output_partition_sizes - self.output_dtype=params_dtype + output_partition_sizes: List[int], + input_size_per_partition: int, + params_dtype: torch.dtype, weight_loader: Callable, + **kwargs): + self.output_dtype = params_dtype weights_dtype: torch.dtype = self._get_params_dtype(params_dtype) - # parameter to store uncompressed weight or decompressed weight - weight = ModelWeightParameter( - data=torch.empty(sum(output_partition_sizes), - input_size_per_partition, - dtype=weights_dtype), - input_dim=1, - output_dim=0, - weight_loader=weight_loader) - - if self.do_decompress: - # store compression specific things to be used - # later during decompression - - # compressed weight for 2:4 sparse (compressed-tensors) - sparse_24_packed_weight = ModelWeightParameter(data=torch.empty( - sum(output_partition_sizes), - input_size_per_partition // 2, - dtype=weights_dtype), - input_dim=1, - output_dim=0, - weight_loader=weight_loader - ) - - bits_per_weight_element = weight.itemsize * 8 - meta_dtype = torch.int32 if bits_per_weight_element == 8 else torch.int16 - meta_input_size = ( - input_size_per_partition // 32 - if bits_per_weight_element == 8 - else input_size_per_partition // 16 - ) - - # meta tensor for 2:4 decompression - meta = ModelWeightParameter(data=torch.empty( - sum(output_partition_sizes), - meta_input_size, - dtype=meta_dtype), - input_dim=1, + # parameter to store uncompressed weight + weight = ModelWeightParameter(data=torch.empty( + sum(output_partition_sizes), + input_size_per_partition, + dtype=weights_dtype), + input_dim=1, + output_dim=0, + weight_loader=weight_loader) + + if self.weight_quant.strategy == QuantizationStrategy.CHANNEL.value: + weight_scale = ChannelQuantScaleParameter( + data=torch.empty((sum(output_partition_sizes), 1), + dtype=torch.float32), output_dim=0, weight_loader=weight_loader) + else: + assert self.weight_quant.strategy == QuantizationStrategy.TOKEN.value + weight_scale = PerTensorScaleParameter(data=torch.empty( + len(output_partition_sizes), dtype=torch.float32), + weight_loader=weight_loader) - layer.register_parameter("sparse_24_packed_weight", sparse_24_packed_weight) - layer.register_parameter("meta", meta) - - if self.quantized: + layer.register_parameter("weight_scale", weight_scale) - if self.weight_quant.strategy == QuantizationStrategy.CHANNEL.value: - weight_scale = ChannelQuantScaleParameter( - data=torch.empty((sum(output_partition_sizes), 1), - dtype=torch.float32), - output_dim=0, - weight_loader=weight_loader) - else: - weight_scale = PerTensorScaleParameter(data=torch.empty( - len(output_partition_sizes), dtype=torch.float32), - weight_loader=weight_loader) - - layer.register_parameter("weight_scale", weight_scale) - - # input quant will be non-none - if not self.input_quant.dynamic: - # register input quant scale - input_scale = PerTensorScaleParameter(data=torch.empty( + # input quant will be non-none + if not self.input_quant.dynamic: + # register input quant scale + assert self.input_quant.strategy == QuantizationStrategy.TENSOR.value + input_scale = PerTensorScaleParameter(data=torch.empty( 
len(output_partition_sizes), dtype=torch.float32), - weight_loader=weight_loader) - - layer.register_parameter("input_scale", input_scale) + weight_loader=weight_loader) + + layer.register_parameter("input_scale", input_scale) layer.register_parameter("weight", weight) def process_weights_after_loading(self, layer: torch.nn.Module) -> None: """ - Apply any transformations to the weights after loading - them from disk + Compress weights after loading. Store compressed weight and meta + tensor :post-condition: layer.w_compressed and layer.meta are set to the compressed weight and meta tensor in the @@ -130,12 +81,9 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: :param layer: The layer with the weights to be processed """ - weight_to_compress = ( - layer.weight.data if not self.do_decompress - else self._decompress_layer_weight(layer) - ) - w_compressed, meta = ops.cutlass_compress_entry(weight_to_compress) - layer.w_compressed = torch.nn.Parameter(w_compressed, requires_grad=False) + w_compressed, meta = ops.cutlass_compress_entry(layer.weight.data) + layer.w_compressed = torch.nn.Parameter(w_compressed, + requires_grad=False) layer.meta = torch.nn.Parameter(meta, requires_grad=False) def apply_weights(self, @@ -153,81 +101,46 @@ def apply_weights(self, :param bias: The bias to be added to the output tensor :return: The output tensor of the layer """ + print("running") if hasattr(layer, "input_scale"): q_input, input_scale = ops.scaled_fp8_quant( x, scale=layer.input_scale) else: q_input, input_scale = ops.scaled_fp8_quant( - x, use_per_token_if_dynamic=True) - - out = ops.cutlass_scaled_sparse_mm( - a=layer.w_compressed, - e=layer.meta, - b=q_input.t(), - scale_a=layer.weight_scale, - scale_b=input_scale, - out_dtype=self.output_dtype, - bias=bias - ) + x, use_per_token_if_dynamic=True) + + out = ops.cutlass_scaled_sparse_mm(a=layer.w_compressed, + e=layer.meta, + b=q_input.t(), + scale_a=layer.weight_scale, + scale_b=input_scale, + out_dtype=self.output_dtype, + bias=bias) assert out.is_contiguous() return out - - def _decompress_layer_weight(self, layer: torch.nn.Module) -> torch.Tensor: - - sparse_24_packed_weight = layer.sparse_24_packed_weight.data - meta = layer.meta.data - - split_weights = None - split_meta = None - - def _process_split(input_weight, input_meta): - weight_data = { - "sparse_24_packed_weight": input_weight, - "meta": input_meta - } - decompress = self.model_compressor.sparsity_compressor.decompress_weight(weight_data) - return decompress - - if isinstance(layer, (QKVParallelLinear, MergedColumnParallelLinear)): - split_weights = torch.split(sparse_24_packed_weight, layer.logical_widths) - split_meta = torch.split(meta, layer.logical_widths) - - if split_weights: - all_compress = [] - for i in range(len(split_weights)): - compress_i = _process_split(split_weights[i], split_meta[i]) - all_compress.append(compress_i) - - decompressed = torch.cat(all_compress) - else: - decompressed = _process_split(sparse_24_packed_weight, meta) - return decompressed - def _get_params_dtype(self, params_dtype: torch.dtype) -> torch.dtype: if not self.quantized: return params_dtype - + is_8_bits = self.weight_quant.num_bits == self.input_quant.num_bits == 8 if not is_8_bits: raise ValueError("Cutlass only supports 8-bit quantization") - + if (self.weight_quant.type == QuantizationType.FLOAT - and self.input_quant.type == QuantizationType.FLOAT): + and self.input_quant.type == QuantizationType.FLOAT): return torch.float8_e4m3fn - + if (self.weight_quant.type 
== QuantizationType.INT - and self.input_quant.type == QuantizationType.INT): + and self.input_quant.type == QuantizationType.INT): return torch.int8 - - raise ValueError("Quantization type not supported by Cutlass") + raise ValueError("Quantization type not supported by Cutlass") def check_24(tensor): - new_tensor = tensor.view(-1, 4) + new_tensor = tensor.view(-1, 4) zero_counts = (new_tensor == 0).sum(dim=1) return (zero_counts >= 2).all().item() - From 7a6d0271402cbc1faa21ffedf264c26e66489bbf Mon Sep 17 00:00:00 2001 From: Dipika Date: Fri, 6 Dec 2024 17:37:10 +0000 Subject: [PATCH 2/3] add testing cases --- tests/quantization/test_compressed_tensors.py | 27 ++++++++++++++++++- tests/weight_loading/models.txt | 1 + .../schemes/compressed_tensors_24.py | 1 - 3 files changed, 27 insertions(+), 2 deletions(-) diff --git a/tests/quantization/test_compressed_tensors.py b/tests/quantization/test_compressed_tensors.py index 03097569b2b3b..d7deedb2dc49e 100644 --- a/tests/quantization/test_compressed_tensors.py +++ b/tests/quantization/test_compressed_tensors.py @@ -11,7 +11,7 @@ from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import ( # noqa: E501 CompressedTensorsLinearMethod, CompressedTensorsW4A16Sparse24, CompressedTensorsW8A8Fp8, CompressedTensorsW8A8Int8, - CompressedTensorsW8A16Fp8, CompressedTensorsWNA16) + CompressedTensorsW8A16Fp8, CompressedTensorsWNA16, CompressedTensors24) @pytest.mark.parametrize( @@ -178,3 +178,28 @@ def test_compressed_tensors_kv_cache(vllm_runner): with vllm_runner(model_path, kv_cache_dtype="fp8") as llm: output = llm.generate_greedy("Hello world!", max_tokens=20) assert output + +@pytest.mark.parametrize( + "args_2of4", + [("nm-testing/Meta-Llama-3-8B-Instruct-FP8-Dynamic-2of4-testing", "channel", "token")]) +def test_compressed_tensors_2of4(vllm_runner, args_2of4): + model, weight_strategy, input_strategy = args_2of4 + with vllm_runner(model) as llm: + model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model # noqa: E501 + layer = model.model.layers[0] + + qkv_proj = layer.self_attn.qkv_proj + assert isinstance(qkv_proj.quant_method, CompressedTensorsLinearMethod) + assert isinstance(qkv_proj.scheme, CompressedTensors24) + + assert qkv_proj.scheme.weight_quant.strategy == weight_strategy + assert qkv_proj.scheme.input_quant.strategy == input_strategy + assert qkv_proj.scheme.quantized == True + assert qkv_proj.quant_method.quantization_config.sparsity_scheme_map + sparsity_map = qkv_proj.quant_method.quantization_config.sparsity_scheme_map + assert sparsity_map.get("Linear").format == "dense" + assert sparsity_map.get("Linear").sparsity_structure == "2:4" + + output = llm.generate_greedy("Hello my name is", max_tokens=20) + print(output) + assert output \ No newline at end of file diff --git a/tests/weight_loading/models.txt b/tests/weight_loading/models.txt index a4ee9538d646b..ea0fa57f9b242 100644 --- a/tests/weight_loading/models.txt +++ b/tests/weight_loading/models.txt @@ -21,6 +21,7 @@ compressed-tensors, nm-testing/Phi-3-mini-128k-instruct-FP8, main compressed-tensors, neuralmagic/Phi-3-medium-128k-instruct-quantized.w4a16, main compressed-tensors, nm-testing/TinyLlama-1.1B-Chat-v1.0-actorder-group, main compressed-tensors, mgoin/DeepSeek-Coder-V2-Lite-Instruct-FP8, main +compressed-tensors, nm-testing/SparseLlama-3.1-8B-gsm8k-pruned.2of4-FP8-Dynamic-testing, main awq, casperhansen/mixtral-instruct-awq, main awq_marlin, casperhansen/mixtral-instruct-awq, main fp8, 
neuralmagic/Meta-Llama-3-8B-Instruct-FP8-KV, main diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_24.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_24.py index 4b23d9c05582a..133e8a236f722 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_24.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_24.py @@ -101,7 +101,6 @@ def apply_weights(self, :param bias: The bias to be added to the output tensor :return: The output tensor of the layer """ - print("running") if hasattr(layer, "input_scale"): q_input, input_scale = ops.scaled_fp8_quant( x, scale=layer.input_scale) From 0987c9825850c3d617bd4bdf253a2ffbbcca4fe9 Mon Sep 17 00:00:00 2001 From: Dipika Date: Fri, 6 Dec 2024 19:56:02 +0000 Subject: [PATCH 3/3] add support for all cases; update tests --- tests/quantization/test_compressed_tensors.py | 16 ++++++++--- .../compressed_tensors/compressed_tensors.py | 11 ++++---- .../schemes/compressed_tensors_24.py | 28 +++++++++++-------- 3 files changed, 34 insertions(+), 21 deletions(-) diff --git a/tests/quantization/test_compressed_tensors.py b/tests/quantization/test_compressed_tensors.py index d7deedb2dc49e..ea6d019872654 100644 --- a/tests/quantization/test_compressed_tensors.py +++ b/tests/quantization/test_compressed_tensors.py @@ -179,9 +179,17 @@ def test_compressed_tensors_kv_cache(vllm_runner): output = llm.generate_greedy("Hello world!", max_tokens=20) assert output -@pytest.mark.parametrize( - "args_2of4", - [("nm-testing/Meta-Llama-3-8B-Instruct-FP8-Dynamic-2of4-testing", "channel", "token")]) + +@pytest.mark.parametrize("args_2of4", [ + ("nm-testing/Meta-Llama-3-8B-Instruct-FP8-Dynamic-2of4-testing", "channel", + "token"), + ("nm-testing/Meta-Llama-3-8B-Instruct-FP8-Static-Per-Tensor-testing", + "channel", "tensor"), + ("nm-testing/Meta-Llama-3-8B-Instruct-FP8-Static-testing", "tensor", + "tensor"), + ("nm-testing/Meta-Llama-3-8B-Instruct-FP8-Dynamic-IA-Per-Tensor-Weight-testing", + "tensor", "token") +]) def test_compressed_tensors_2of4(vllm_runner, args_2of4): model, weight_strategy, input_strategy = args_2of4 with vllm_runner(model) as llm: @@ -202,4 +210,4 @@ def test_compressed_tensors_2of4(vllm_runner, args_2of4): output = llm.generate_greedy("Hello my name is", max_tokens=20) print(output) - assert output \ No newline at end of file + assert output diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py index b8337fd250a8e..da34b9b9aa68c 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py @@ -388,13 +388,12 @@ def get_scheme( if self.supports_cutlass_24(weight_quant=weight_quant, input_quant=input_quant, sparsity_scheme=sparsity_scheme): - # Have a valid sparsity scheme + # Have a valid sparsity scheme # Validate layer is supported by Cutlass 2:4 Kernel - scheme = CompressedTensors24( - quantized=weight_quant is not None or input_quant is not None, - weight_quant=weight_quant, - input_quant=input_quant - ) + scheme = CompressedTensors24(quantized=weight_quant is not None + or input_quant is not None, + weight_quant=weight_quant, + input_quant=input_quant) else: # Find the quant_scheme scheme = self._get_scheme_from_parts( diff --git 
a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_24.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_24.py index 133e8a236f722..5fee61d340f7c 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_24.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_24.py @@ -4,20 +4,20 @@ from compressed_tensors.quantization import QuantizationType, QuantizationStrategy from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( CompressedTensorsScheme) -from vllm.model_executor.parameter import ModelWeightParameter, ChannelQuantScaleParameter, PerTensorScaleParameter +from vllm.model_executor.parameter import ModelWeightParameter, ChannelQuantScaleParameter, PerTensorScaleParameter, BasevLLMParameter from vllm import _custom_ops as ops +from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( + convert_to_channelwise) __all__ = ["CompressedTensors24"] class CompressedTensors24(CompressedTensorsScheme): - def __init__( - self, - quantized: bool = False, - weight_quant=None, - input_quant=None - ): + def __init__(self, + quantized: bool = False, + weight_quant=None, + input_quant=None): self.quantized = quantized self.weight_quant = weight_quant self.input_quant = input_quant @@ -33,6 +33,7 @@ def create_weights(self, layer: torch.nn.Module, input_size: int, **kwargs): self.output_dtype = params_dtype + layer.logical_widths = output_partition_sizes weights_dtype: torch.dtype = self._get_params_dtype(params_dtype) # parameter to store uncompressed weight @@ -51,7 +52,7 @@ def create_weights(self, layer: torch.nn.Module, input_size: int, output_dim=0, weight_loader=weight_loader) else: - assert self.weight_quant.strategy == QuantizationStrategy.TOKEN.value + assert self.weight_quant.strategy == QuantizationStrategy.TENSOR.value weight_scale = PerTensorScaleParameter(data=torch.empty( len(output_partition_sizes), dtype=torch.float32), weight_loader=weight_loader) @@ -62,9 +63,9 @@ def create_weights(self, layer: torch.nn.Module, input_size: int, if not self.input_quant.dynamic: # register input quant scale assert self.input_quant.strategy == QuantizationStrategy.TENSOR.value - input_scale = PerTensorScaleParameter(data=torch.empty( - len(output_partition_sizes), dtype=torch.float32), - weight_loader=weight_loader) + input_scale = BasevLLMParameter(data=torch.empty( + 1, dtype=torch.float32), + weight_loader=weight_loader) layer.register_parameter("input_scale", input_scale) @@ -81,6 +82,11 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: :param layer: The layer with the weights to be processed """ + if self.weight_quant.strategy == QuantizationStrategy.TENSOR.value: + layer.weight_scale = torch.nn.Parameter(convert_to_channelwise( + weight_scale=layer.weight_scale, + logical_widths=layer.logical_widths), + requires_grad=False) w_compressed, meta = ops.cutlass_compress_entry(layer.weight.data) layer.w_compressed = torch.nn.Parameter(w_compressed, requires_grad=False)
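
Reviewer note, not part of the patch: a minimal end-to-end smoke test that mirrors the new test_compressed_tensors_2of4 case through the public vllm.LLM API instead of the test harness. It assumes the nm-testing 2:4 FP8 checkpoint referenced in the added test is reachable and that the GPU meets the compute capability 90 floor enforced by CompressedTensors24.get_min_capability(); treat it as a sketch, not part of the validated test suite.

# Hypothetical smoke test: load a 2:4 sparse FP8 checkpoint and check that the
# CompressedTensors24 path (cutlass_scaled_sparse_mm) produces greedy output.
from vllm import LLM, SamplingParams

llm = LLM(model="nm-testing/Meta-Llama-3-8B-Instruct-FP8-Dynamic-2of4-testing")
params = SamplingParams(temperature=0.0, max_tokens=20)  # greedy, like the test
outputs = llm.generate(["Hello my name is"], params)
print(outputs[0].outputs[0].text)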