Commit 8ad1c7d

* Pipe through model_compression_config
* Instantiate ModelCompressor in vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_24.py
rahul-tuli committed Dec 9, 2024
1 parent a27ca81 commit 8ad1c7d
Showing 2 changed files with 28 additions and 8 deletions.
@@ -396,10 +396,13 @@ def get_scheme(
                 sparsity_scheme=sparsity_scheme):
             # Have a valid sparsity scheme
             # Validate layer is supported by Cutlass 2:4 Kernel
-            scheme = CompressedTensors24(quantized=weight_quant is not None
-                                         or input_quant is not None,
-                                         weight_quant=weight_quant,
-                                         input_quant=input_quant)
+            scheme = CompressedTensors24(
+                quantized=weight_quant is not None or input_quant is not None,
+                weight_quant=weight_quant,
+                input_quant=input_quant,
+                model_compression_config=self._get_model_compression_config(
+                    sparsity_scheme),
+            )
         else:
             # Find the quant_scheme
             scheme = self._get_scheme_from_parts(  # type: ignore
@@ -433,8 +436,7 @@ def supports_cutlass_24(
         """
         is_valid_sparsity = (sparsity_scheme is not None
                              and sparsity_scheme.sparsity_structure
-                             == SparsityStructure.TWO_FOUR.value
-                             and sparsity_scheme.format == "dense")
+                             == SparsityStructure.TWO_FOUR.value)
         if not is_valid_sparsity:
             return False

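Note: dropping the format == "dense" clause means supports_cutlass_24 now accepts 2:4 sparsity schemes whether the checkpoint stores its weights dense or compressed; the dense/compressed distinction moves into the new _get_model_compression_config helper below. A minimal, self-contained sketch of the relaxed predicate (the enum and dataclass are stand-ins for compressed-tensors' SparsityStructure and SparsityCompressionConfig, and the format strings are illustrative):

from dataclasses import dataclass
from enum import Enum
from typing import Optional

class SparsityStructure(Enum):  # stand-in mirroring compressed_tensors
    TWO_FOUR = "2:4"

@dataclass
class SparsityScheme:  # stand-in for SparsityCompressionConfig
    sparsity_structure: str
    format: str

def is_valid_sparsity(scheme: Optional[SparsityScheme]) -> bool:
    # After this commit only the 2:4 structure is required; the on-disk
    # format no longer gates the Cutlass 2:4 kernel path.
    return (scheme is not None
            and scheme.sparsity_structure == SparsityStructure.TWO_FOUR.value)

assert not is_valid_sparsity(None)
assert is_valid_sparsity(SparsityScheme("2:4", "dense"))
assert is_valid_sparsity(SparsityScheme("2:4", "sparse-24-bitmask"))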
@@ -465,6 +467,19 @@ def supports_cutlass_24(
 
         return weight_quant.num_bits == input_quant.num_bits == 8
 
+    def _get_model_compression_config(
+            self, sparsity_scheme: Optional[SparsityCompressionConfig] = None):
+        """
+        Get the model compressor config from the sparsity scheme
+
+        :param sparsity_scheme: The sparsity scheme
+        :return: The model compressor config
+        """
+        if sparsity_scheme is None or sparsity_scheme.format == "dense":
+            return None
+
+        return self.config
+
 
 class CompressedTensorsLinearMethod(LinearMethodBase):
 
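The new _get_model_compression_config helper is the piece that actually pipes the config through: it returns the full compression config (self.config) only when the sparsity format is something other than "dense", i.e. only when the checkpoint's weights need decompressing at load time. A hedged sketch of that gating logic as a free function (the dataclass and config dict are stand-ins for the real SparsityCompressionConfig and config object):

from dataclasses import dataclass
from typing import Any, Dict, Optional

@dataclass
class SparsityScheme:  # stand-in for SparsityCompressionConfig
    format: str

def get_model_compression_config(
        sparsity_scheme: Optional[SparsityScheme],
        full_config: Dict[str, Any]) -> Optional[Dict[str, Any]]:
    # "dense" means the weights are stored uncompressed, so the scheme
    # needs no ModelCompressor; any other format forwards the whole config.
    if sparsity_scheme is None or sparsity_scheme.format == "dense":
        return None
    return full_config

config = {"sparsity_config": {"format": "sparse-24-bitmask"}}  # illustrative
assert get_model_compression_config(None, config) is None
assert get_model_compression_config(SparsityScheme("dense"), config) is None
assert get_model_compression_config(SparsityScheme("sparse-24-bitmask"),
                                    config) == config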
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_24.py
@@ -1,6 +1,7 @@
-from typing import Callable, List, Optional
+from typing import Any, Callable, Dict, List, Optional
 
 import torch
+from compressed_tensors import ModelCompressor
 from compressed_tensors.quantization import (QuantizationArgs,
                                              QuantizationStrategy,
                                              QuantizationType)
@@ -23,11 +24,15 @@ class CompressedTensors24(CompressedTensorsScheme):
     def __init__(self,
                  quantized: bool = False,
                  weight_quant: Optional[QuantizationArgs] = None,
-                 input_quant: Optional[QuantizationArgs] = None):
+                 input_quant: Optional[QuantizationArgs] = None,
+                 model_compression_config: Optional[Dict[str, Any]] = None):
 
         self.quantized = quantized
         self.weight_quant = weight_quant
         self.input_quant = input_quant
+        self.model_compressor = (
+            ModelCompressor.from_compression_config(model_compression_config)
+            if model_compression_config is not None else None)
 
     @classmethod
     def get_min_capability(cls) -> int:
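On the scheme side, CompressedTensors24 now builds a ModelCompressor from the piped-through dict only when one was provided. A minimal sketch of that conditional construction (requires the compressed-tensors package; build_model_compressor is our illustrative name, not part of the diff):

from typing import Any, Dict, Optional

from compressed_tensors import ModelCompressor

def build_model_compressor(
        model_compression_config: Optional[Dict[str, Any]]
) -> Optional[ModelCompressor]:
    # Mirrors CompressedTensors24.__init__: no config means a dense
    # checkpoint, so there is nothing to decompress and no compressor.
    if model_compression_config is None:
        return None
    return ModelCompressor.from_compression_config(model_compression_config)

assert build_model_compressor(None) is None  # dense path builds nothing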
