diff --git a/.buildkite/lm-eval-harness/configs/SparseLlama3.1_2of4_fp8_compressed.yaml b/.buildkite/lm-eval-harness/configs/SparseLlama3.1_2of4_fp8_compressed.yaml
new file mode 100644
index 0000000000000..2928d75ce4469
--- /dev/null
+++ b/.buildkite/lm-eval-harness/configs/SparseLlama3.1_2of4_fp8_compressed.yaml
@@ -0,0 +1,11 @@
+# bash ./run-lm-eval-gsm-vllm-baseline.sh -m nm-testing/SparseLlama-3.1-8B-gsm8k-pruned.2of4-chnl_wts_per_tok_dyn_act_fp8-BitM -b "auto" -t 2
+model_name: "nm-testing/SparseLlama-3.1-8B-gsm8k-pruned.2of4-chnl_wts_per_tok_dyn_act_fp8-BitM"
+tasks:
+- name: "gsm8k"
+  metrics:
+  - name: "exact_match,strict-match"
+    value: 0.6353
+  - name: "exact_match,flexible-extract"
+    value: 0.637
+limit: null
+num_fewshot: null
diff --git a/tests/quantization/test_compressed_tensors.py b/tests/quantization/test_compressed_tensors.py
index 0cd86cef0a475..ac9cdd0aa5ae1 100644
--- a/tests/quantization/test_compressed_tensors.py
+++ b/tests/quantization/test_compressed_tensors.py
@@ -242,7 +242,10 @@ def test_compressed_tensors_kv_cache(vllm_runner):
 
 @pytest.mark.skipif(not sparse_cutlass_supported(),
                     reason="Sparse FP8 is not yet supported on this GPU type.")
-def _test_2of4_quant_models(qkv_proj, weight_strategy, input_strategy):
+def _test_2of4_quant_models(qkv_proj,
+                            weight_strategy,
+                            input_strategy,
+                            format="dense"):
     assert isinstance(qkv_proj.quant_method, CompressedTensorsLinearMethod)
     assert isinstance(qkv_proj.scheme, CompressedTensors24)
 
@@ -251,7 +254,7 @@ def _test_2of4_quant_models(qkv_proj, weight_strategy, input_strategy):
     assert qkv_proj.scheme.quantized
     assert qkv_proj.quant_method.quantization_config.sparsity_scheme_map
     sparsity_map = qkv_proj.quant_method.quantization_config.sparsity_scheme_map  # noqa: E501
-    assert sparsity_map.get("Linear").format == "dense"
+    assert sparsity_map.get("Linear").format == format
 
     assert sparsity_map.get("Linear").sparsity_structure == "2:4"
 
@@ -285,6 +288,72 @@ def check_model(model):
     assert output
 
 
+@pytest.mark.skipif(not current_platform.has_device_capability(90),
+                    reason="Sparse FP8 is not yet supported on this GPU type.")
+@pytest.mark.parametrize("args_2of4", [
+    ("nm-testing/TinyLlama-1.1B-Chat-v1.0-gsm8k-pruned.2of4-chnl_wts_per_tok_dyn_act_fp8-BitM",
+     "channel", "token"),
+    ("nm-testing/TinyLlama-1.1B-Chat-v1.0-gsm8k-pruned.2of4-chnl_wts_tensor_act_fp8-BitM",
+     "channel", "tensor"),
+    ("nm-testing/TinyLlama-1.1B-Chat-v1.0-gsm8k-pruned.2of4-tensor_wts_per_tok_dyn_act_fp8-BitM",
+     "tensor", "token"),
+    ("nm-testing/TinyLlama-1.1B-Chat-v1.0-gsm8k-pruned.2of4-tensor_wts_tensor_act_fp8-BitM",
+     "tensor", "tensor"),
+])
+def test_compressed_tensors_2of4_quant_fp8_compressed(vllm_runner, args_2of4):
+    model, weight_strategy, input_strategy = args_2of4
+    with vllm_runner(model) as llm:
+
+        def check_model(model):
+            layer = model.model.layers[0]
+
+            qkv_proj = layer.self_attn.qkv_proj
+            assert qkv_proj.scheme.weights_dtype == torch.float8_e4m3fn
+            _test_2of4_quant_models(qkv_proj,
+                                    weight_strategy,
+                                    input_strategy,
+                                    format="sparse-24-bitmask")
+
+        llm.apply_model(check_model)
+
+        output = llm.generate_greedy("Hello my name is", max_tokens=20)
+        print(output)
+        assert output
+
+
+@pytest.mark.skipif(not sparse_cutlass_supported(),
+                    reason="cutlass is not yet supported on this GPU type.")
+@pytest.mark.parametrize("args_2of4", [
+    ("nm-testing/TinyLlama-1.1B-Chat-v1.0-gsm8k-pruned.2of4-chnl_wts_per_tok_dyn_act_int8-BitM",
+     "channel", "token"),
+    ("nm-testing/TinyLlama-1.1B-Chat-v1.0-gsm8k-pruned.2of4-chnl_wts_tensor_act_int8-BitM",
+     "channel", "tensor"),
+    ("nm-testing/TinyLlama-1.1B-Chat-v1.0-gsm8k-pruned.2of4-tensor_wts_per_tok_dyn_act_int8-BitM",
+     "tensor", "token"),
+    ("nm-testing/TinyLlama-1.1B-Chat-v1.0-gsm8k-pruned.2of4-tensor_wts_tensor_act_int8-BitM",
+     "tensor", "tensor"),
+])
+def test_compressed_tensors_2of4_quant_int8_compressed(vllm_runner, args_2of4):
+    model, weight_strategy, input_strategy = args_2of4
+    with vllm_runner(model) as llm:
+
+        def check_model(model):
+            layer = model.model.layers[0]
+
+            qkv_proj = layer.self_attn.qkv_proj
+            assert qkv_proj.scheme.weights_dtype == torch.int8
+            _test_2of4_quant_models(qkv_proj,
+                                    weight_strategy,
+                                    input_strategy,
+                                    format="sparse-24-bitmask")
+
+        llm.apply_model(check_model)
+
+        output = llm.generate_greedy("Hello my name is", max_tokens=20)
+        print(output)
+        assert output
+
+
 @pytest.mark.skipif(not sparse_cutlass_supported(),
                     reason="Sparse FP8 is not yet supported on this GPU type.")
 @pytest.mark.parametrize("args_2of4", [
@@ -343,3 +412,35 @@ def check_model(model):
         output = llm.generate_greedy("Hello my name is", max_tokens=20)
         print(output)
         assert output
+
+
+@pytest.mark.skipif(not sparse_cutlass_supported(),
+                    reason="Cutlass is not yet supported on this GPU type.")
+@pytest.mark.parametrize(
+    "args_2of4",
+    [("nm-testing/llama2.c-stories42M-pruned2.4-compressed")])
+def test_compressed_tensors_2of4_sparse_compressed(vllm_runner, args_2of4):
+    model = args_2of4
+    with vllm_runner(model) as llm:
+
+        def check_model(model):
+            layer = model.model.layers[0]
+
+            qkv_proj = layer.self_attn.qkv_proj
+            assert isinstance(qkv_proj.quant_method,
+                              CompressedTensorsLinearMethod)
+            assert isinstance(qkv_proj.scheme, CompressedTensors24)
+
+            assert qkv_proj.scheme.weight_quant is None
+            assert qkv_proj.scheme.input_quant is None
+            assert not qkv_proj.scheme.quantized
+            assert qkv_proj.quant_method.quantization_config.sparsity_scheme_map
+            sparsity_map = qkv_proj.quant_method.quantization_config.sparsity_scheme_map  # noqa: E501
+            assert sparsity_map.get("Linear").format == "sparse-24-bitmask"
+            assert sparsity_map.get("Linear").sparsity_structure == "2:4"
+
+        llm.apply_model(check_model)
+
+        output = llm.generate_greedy("Hello my name is", max_tokens=20)
+        print(output)
+        assert output
\ No newline at end of file
diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
index b2fc2360f47f1..0298526fea3c5 100644
--- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
@@ -396,10 +396,16 @@ def get_scheme(
                                     sparsity_scheme=sparsity_scheme):
             # Have a valid sparsity scheme
             # Validate layer is supported by Cutlass 2:4 Kernel
-            scheme = CompressedTensors24(quantized=weight_quant is not None
-                                         or input_quant is not None,
-                                         weight_quant=weight_quant,
-                                         input_quant=input_quant)
+            model_compression_config = (None if sparsity_scheme is None
+                                        or sparsity_scheme.format == "dense"
+                                        else self.config)
+
+            scheme = CompressedTensors24(
+                quantized=weight_quant is not None or input_quant is not None,
+                weight_quant=weight_quant,
+                input_quant=input_quant,
+                model_compression_config=model_compression_config,
+            )
         else:
             # Find the quant_scheme
             scheme = self._get_scheme_from_parts(  # type: ignore
@@ -447,10 +453,21 @@ def supports_cutlass_24(
         :return: True if the layer is supported by the Cutlass 2:4 Kernel
             False otherwise
         """
-        is_valid_sparsity = (sparsity_scheme is not None
-                             and sparsity_scheme.sparsity_structure
-                             == SparsityStructure.TWO_FOUR.value
-                             and sparsity_scheme.format == "dense")
+        if sparsity_scheme is None:
+            return False
+
+        is_valid_sparsity_structure: bool = (
+            sparsity_scheme.sparsity_structure ==
+            SparsityStructure.TWO_FOUR.value)
+
+        valid_compressors = {
+            CompressionFormat.dense.value,
+            CompressionFormat.sparse_24_bitmask.value
+        }
+
+        is_valid_sparsity = (is_valid_sparsity_structure
+                             and sparsity_scheme.format in valid_compressors)
+
         if not is_valid_sparsity:
             return False
 
diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_24.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_24.py
index 21e6fe7a22616..da652e00bde8d 100644
--- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_24.py
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_24.py
@@ -1,11 +1,15 @@
-from typing import Callable, List, Optional
+from typing import Any, Callable, Dict, List, Optional, Tuple
 
 import torch
+from compressed_tensors import CompressionFormat, ModelCompressor
 from compressed_tensors.quantization import (QuantizationArgs,
                                              QuantizationStrategy,
                                              QuantizationType)
+from compressed_tensors.utils import combine_shards
 
 from vllm import _custom_ops as ops
+from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
+                                               QKVParallelLinear)
 from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
     CompressedTensorsScheme)
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
@@ -20,14 +24,24 @@
 
 class CompressedTensors24(CompressedTensorsScheme):
 
-    def __init__(self,
-                 quantized: bool = False,
-                 weight_quant: Optional[QuantizationArgs] = None,
-                 input_quant: Optional[QuantizationArgs] = None):
+    def __init__(
+        self,
+        quantized: bool = False,
+        weight_quant: Optional[QuantizationArgs] = None,
+        input_quant: Optional[QuantizationArgs] = None,
+        model_compression_config: Optional[Dict[str, Any]] = None,
+    ):
 
         self.quantized = quantized
         self.weight_quant = weight_quant
         self.input_quant = input_quant
+        self.model_compressor = (
+            ModelCompressor.from_compression_config(model_compression_config)
+            if model_compression_config is not None else None)
+        self.do_sparse_decompress = (
+            self.model_compressor is not None
+            and self.model_compressor.sparsity_config.format
+            == CompressionFormat.sparse_24_bitmask.value)
 
     @classmethod
     def get_min_capability(cls) -> int:
@@ -47,6 +61,8 @@ def create_weights(self, layer: torch.nn.Module, input_size: int,
 
         self.output_dtype = params_dtype
         layer.logical_widths = output_partition_sizes
+        layer.input_size = input_size
+        layer.input_size_per_partition = input_size_per_partition
         self.weights_dtype: torch.dtype = self._get_params_dtype(params_dtype)
 
         # parameter to store uncompressed weight
@@ -57,6 +73,34 @@ def create_weights(self, layer: torch.nn.Module, input_size: int,
                                       input_dim=1,
                                       output_dim=0,
                                       weight_loader=weight_loader)
+        if self.do_sparse_decompress:
+            assert all(partition_size % 8 == 0
+                       for partition_size in output_partition_sizes
+                       ), "All partitions must be divisible by 8 for "
+            "2:4 sparse compressed models"
+
+            shape = BasevLLMParameter(data=torch.empty(2, 1,
+                                                       dtype=torch.int64),
+                                      weight_loader=weight_loader)
+            compressed_weight = ModelWeightParameter(
+                data=torch.empty(sum(output_partition_sizes),
+                                 input_size_per_partition // 2,
+                                 dtype=self.weights_dtype),
+                input_dim=1,
+                output_dim=0,
+                weight_loader=weight_loader)
+
+            bitmask = ModelWeightParameter(data=torch.empty(
+                sum(output_partition_sizes),
+                input_size_per_partition // 8,
+                dtype=torch.uint8),
+                                           input_dim=1,
+                                           output_dim=0,
+                                           weight_loader=weight_loader)
+
+            layer.register_parameter("shape", shape)
+            layer.register_parameter("compressed", compressed_weight)
+            layer.register_parameter("bitmask", bitmask)
 
         # Check if quantized, not just 2:4 Sparse
         if self.quantized:
@@ -112,6 +156,13 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
 
         :param layer: The layer with the weights to be processed
         """
+        if self.do_sparse_decompress:
+            layer.weight.data = self._decompress_bitmask_compressed_weight(
+                compressed=layer.compressed,
+                bitmask=layer.bitmask,
+                layer=layer,
+            )
+
         # torch.compile workaround
         if hasattr(layer, "input_scale"):
             layer.input_scale = torch.nn.Parameter(layer.input_scale.data,
@@ -201,8 +252,55 @@ def _get_params_dtype(self, params_dtype: torch.dtype) -> torch.dtype:
 
         raise ValueError("Quantization type not supported by Cutlass")
 
+    def _decompress_bitmask_compressed_weight(
+            self, compressed: torch.Tensor, bitmask: torch.Tensor,
+            layer: torch.nn.Module) -> torch.Tensor:
+        """
+        Decompress a compressed 2:4 sparse weight tensor
+        using the bitmask and return the result.
+
+        This function also supports sharded decompression.
+
+        :param compressed: The 2:4 sparse weight tensor
+            compressed using the sparse-24-bitmask compressor.
+        :param bitmask: The 2:4 bitmask associated with the compressed weights.
+        :param layer: The layer whose weights need to be processed
+            after loading.
+        :return: The decompressed 2:4 sparse weight tensor.
+        """
 
-def check_24(tensor):
-    new_tensor = tensor.view(-1, 4)
-    zero_counts = (new_tensor == 0).sum(dim=1)
-    return (zero_counts >= 2).all().item()
+        sparsity_compressor = self.model_compressor.sparsity_compressor
+
+        def _process_split(bitmask_compressed_weight: torch.Tensor, shape,
+                           bitmask: torch.Tensor) -> torch.Tensor:
+            weight_data = dict(
+                compressed=bitmask_compressed_weight,
+                shape=shape,
+                bitmask=bitmask,
+            )
+            return sparsity_compressor.decompress_weight(weight_data)
+
+        split_weights: List[torch.Tensor] = []
+        split_bitmask: List[torch.Tensor] = []
+        split_shape: List[Tuple[int, int]] = []
+
+        if isinstance(layer, (QKVParallelLinear, MergedColumnParallelLinear)):
+            split_weights = torch.split(compressed, layer.logical_widths)
+            split_bitmask = torch.split(bitmask, layer.logical_widths)
+            split_shape = [(out, layer.input_size_per_partition)
+                           for out in layer.logical_widths]
+
+        if split_weights:
+            decompressed_shards = [
+                _process_split(compressed_weight, shape, bitmask)
+                for compressed_weight, shape, bitmask in zip(
+                    split_weights, split_shape, split_bitmask)
+            ]
+            decompressed = combine_shards(decompressed_shards)
+        else:
+            decompressed = sparsity_compressor.decompress_weight(
+                dict(compressed=compressed,
+                     shape=(layer.logical_widths[0],
+                            layer.input_size_per_partition),
+                     bitmask=bitmask))
+        return decompressed
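A note on the storage layout behind the decompression path added in compressed_tensors_24.py: a `sparse-24-bitmask` checkpoint keeps, for each row of a 2:4-sparse weight, only the retained values (half of the elements) plus a packed bitmask with one bit per dense element. The sketch below is a minimal, self-contained illustration of that idea and of what `sparsity_compressor.decompress_weight` does conceptually; it is not the compressed-tensors implementation, and the little-endian bit order and row-major packing it uses are assumptions of the sketch.

```python
from typing import Tuple

import torch


def decompress_2of4_bitmask(compressed: torch.Tensor, bitmask: torch.Tensor,
                            shape: Tuple[int, int]) -> torch.Tensor:
    """Rebuild a dense 2:4-sparse weight from packed values and a bitmask.

    compressed: [rows, cols // 2], only the kept (nonzero) values.
    bitmask:    [rows, cols // 8] uint8, one bit per dense element.
    shape:      the dense (rows, cols) shape to reconstruct.
    """
    rows, cols = shape
    # Unpack each byte into 8 flags; bit i marks element i of that byte's span.
    bits = (bitmask.unsqueeze(-1) >> torch.arange(8, dtype=torch.uint8)) & 1
    mask = bits.bool().reshape(rows, -1)[:, :cols]
    dense = torch.zeros(rows, cols, dtype=compressed.dtype)
    # Boolean assignment fills True positions in row-major order, matching the
    # row-major packing assumed for `compressed`.
    dense[mask] = compressed.reshape(-1)
    return dense


# Round-trip one row of eight elements (two 2:4 groups of four).
w = torch.tensor([[0.0, 1.5, 0.0, -2.0, 3.0, 0.0, 0.25, 0.0]])
kept = w != 0
packed = w[kept].reshape(1, -1)
byte = torch.zeros(1, 1, dtype=torch.uint8)
for i in range(8):
    byte |= kept[:, i:i + 1].to(torch.uint8) << i
assert torch.equal(decompress_2of4_bitmask(packed, byte, (1, 8)), w)
```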
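For fused modules (`QKVParallelLinear`, `MergedColumnParallelLinear`) the `compressed` and `bitmask` parameters hold all shards stacked along the output dimension, which is why `_decompress_bitmask_compressed_weight` splits them by `layer.logical_widths` before decompressing and then recombines the results. Below is a rough sketch of that flow, reusing the illustrative helper above; treating the recombination as a concatenation along dim 0 is an assumption of this sketch about what `combine_shards` amounts to here, not a statement about the library API.

```python
from typing import List

import torch


def decompress_merged_weight(compressed: torch.Tensor, bitmask: torch.Tensor,
                             logical_widths: List[int],
                             input_size_per_partition: int) -> torch.Tensor:
    """Decompress a fused (e.g. QKV) 2:4 weight shard by shard."""
    shards = []
    for comp, bm, rows in zip(torch.split(compressed, logical_widths),
                              torch.split(bitmask, logical_widths),
                              logical_widths):
        shards.append(
            decompress_2of4_bitmask(comp, bm,
                                    (rows, input_size_per_partition)))
    # Stitch the per-shard dense weights back together along the output dim,
    # which is the role combine_shards plays in the patch.
    return torch.cat(shards, dim=0)
```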