Skip to content

Commit

Permalink
remove compressed support; validate against ct models for tp=1,24
Browse files Browse the repository at this point in the history
  • Loading branch information
dsikka committed Dec 6, 2024
1 parent a8a1b57 commit 4e060df
Show file tree
Hide file tree
Showing 2 changed files with 121 additions and 202 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -31,16 +31,20 @@

SPARSITY_CONFIG_NAME: Literal["sparsity_config"] = "sparsity_config"
QUANTIZATION_SCHEME_MAP_TYPE = Dict[str, Optional[Dict[str, QuantizationArgs]]]


class CompressedTensorsConfig(QuantizationConfig):

def __init__(self,
target_scheme_map: Dict[str, Any],
ignore: List[str],
quant_format: str,
kv_cache_scheme: Optional[Dict[str, Any]] = None,
sparsity_scheme_map: Optional[Dict[str, SparsityCompressionConfig]] = None,
config: Optional[Dict[str, Any]] = None,
):
def __init__(
self,
target_scheme_map: Dict[str, Any],
ignore: List[str],
quant_format: str,
kv_cache_scheme: Optional[Dict[str, Any]] = None,
sparsity_scheme_map: Optional[Dict[str,
SparsityCompressionConfig]] = None,
config: Optional[Dict[str, Any]] = None,
):

self.ignore = ignore
self.quant_format = quant_format
Expand Down Expand Up @@ -92,8 +96,10 @@ def get_quant_method(
def from_config(cls, config: Dict[str, Any]) -> "CompressedTensorsConfig":
ignore: List[str] = cast(List[str], config.get("ignore", []))
quant_format = cast(str, config.get("format"))
target_scheme_map = cls._quantization_scheme_map_from_config(config=config)
sparsity_scheme_map = cls._sparsity_scheme_map_from_config(config=config)
target_scheme_map = cls._quantization_scheme_map_from_config(
config=config)
sparsity_scheme_map = cls._sparsity_scheme_map_from_config(
config=config)

return cls(
target_scheme_map=target_scheme_map,
Expand All @@ -102,26 +108,30 @@ def from_config(cls, config: Dict[str, Any]) -> "CompressedTensorsConfig":
sparsity_scheme_map=sparsity_scheme_map,
config=config,
)

@classmethod
def _sparsity_scheme_map_from_config(cls, config: Dict[str, Any]) -> Dict[str, SparsityCompressionConfig]:
def _sparsity_scheme_map_from_config(
cls, config: Dict[str,
Any]) -> Dict[str, SparsityCompressionConfig]:
"""
:param config: The `quantization_config` dictionary from config.json
:return: A dictionary mapping target layer names to their corresponding
sparsity compression configurations
"""
if (sparsity_config:=config.get(SPARSITY_CONFIG_NAME)) is None:
if (sparsity_config := config.get(SPARSITY_CONFIG_NAME)) is None:
return dict()

sparsity_config = SparsityCompressionConfig.model_validate(sparsity_config)

sparsity_config = SparsityCompressionConfig.model_validate(
sparsity_config)
sparse_scheme_map: Dict[str, SparsityCompressionConfig] = {
target: sparsity_config
for target in sparsity_config.targets or list()
}
return sparse_scheme_map

@classmethod
def _quantization_scheme_map_from_config(cls, config: Dict[str, Any]) -> QUANTIZATION_SCHEME_MAP_TYPE:
def _quantization_scheme_map_from_config(
cls, config: Dict[str, Any]) -> QUANTIZATION_SCHEME_MAP_TYPE:
"""
:param config: The `quantization_config` dictionary from config.json
:return: A dictionary mapping target layer names to their corresponding
Expand All @@ -144,7 +154,8 @@ def _quantization_scheme_map_from_config(cls, config: Dict[str, Any]) -> QUANTIZ
targets = quant_config.get("targets")
for target in targets:
target_scheme_map[target] = {}
target_scheme_map[target]["weights"] = QuantizationArgs.model_validate(
target_scheme_map[target][
"weights"] = QuantizationArgs.model_validate(
quant_config.get("weights"))

target_scheme_map[target]["input_activations"] = None
Expand All @@ -158,7 +169,8 @@ def _quantization_scheme_map_from_config(cls, config: Dict[str, Any]) -> QUANTIZ
assert target_scheme_map[target][
"weights"].type == QuantizationType.FLOAT
else:
target_scheme_map[target]["input_activations"] = QuantizationArgs.model_validate(
target_scheme_map[target][
"input_activations"] = QuantizationArgs.model_validate(
quant_config.get("input_activations"))
return target_scheme_map

Expand Down Expand Up @@ -359,7 +371,7 @@ def get_scheme(
# TODO (@robertgshaw): add compressed-tensors as dep
# so we do not have to re-write these functions
# need to make accelerate optional in ct to do this

matched_target = find_matched_target(
layer_name=layer_name,
module=layer,
Expand All @@ -369,42 +381,38 @@ def get_scheme(
weight_quant = scheme_dict.get("weights")
input_quant = scheme_dict.get("input_activations")

sparsity_scheme: Optional[SparsityCompressionConfig] = self.sparsity_scheme_map.get(matched_target)
sparsity_scheme: Optional[
SparsityCompressionConfig] = self.sparsity_scheme_map.get(
matched_target)

if self.supports_cutlass_24(
weight_quant=weight_quant,
input_quant=input_quant,
sparsity_scheme=sparsity_scheme
):
# Have a valid sparsity scheme and the layer is supported by the Cutlass 2:4 Kernel
needs_decompression = sparsity_scheme.format != CompressionFormat.dense.value
is_quantized = weight_quant is not None or input_quant is not None
if self.supports_cutlass_24(weight_quant=weight_quant,
input_quant=input_quant,
sparsity_scheme=sparsity_scheme):
# Have a valid sparsity scheme
# Validate layer is supported by Cutlass 2:4 Kernel
scheme = CompressedTensors24(
layer_name=layer_name,
quantized=is_quantized,
do_decompress=needs_decompression,
quantized=weight_quant is not None or input_quant is not None,
weight_quant=weight_quant,
input_quant=input_quant,
config=self.config,
input_quant=input_quant
)
else:
# Find the quant_scheme
# Find the quant_scheme
scheme = self._get_scheme_from_parts(
weight_quant=weight_quant,
input_quant=input_quant,
)
)

# Raise error if device does not support the scheme
# (e.g. fp8 needs ada lovelace)
self._check_scheme_supported(scheme.get_min_capability())
return scheme

@staticmethod
def supports_cutlass_24(
weight_quant: Optional[QuantizationArgs],
input_quant: Optional[QuantizationArgs],
sparsity_scheme: Optional[SparsityCompressionConfig]=None
) -> bool:
weight_quant: Optional[QuantizationArgs],
input_quant: Optional[QuantizationArgs],
sparsity_scheme: Optional[SparsityCompressionConfig] = None
) -> bool:
"""
Check if the layer is supported by the Cutlass 2:4 Kernel
Conditions:
Expand All @@ -418,39 +426,37 @@ def supports_cutlass_24(
:return: True if the layer is supported by the Cutlass 2:4 Kernel
False otherwise
"""

if (
sparsity_scheme is None or
sparsity_scheme.sparsity_structure != SparsityStructure.TWO_FOUR.value
):
is_valid_sparsity = (sparsity_scheme is not None
and sparsity_scheme.sparsity_structure
== SparsityStructure.TWO_FOUR.value
and sparsity_scheme.format == "dense")
if not is_valid_sparsity:
return False

# Unquantized cases are supported
if weight_quant is None and input_quant is None:
return True

# Weight only quantization is not-supported
if weight_quant is not None and input_quant is None:
return False

supported_weight_quant_strategies = [
QuantizationStrategy.TENSOR.value,
QuantizationStrategy.CHANNEL.value
]

if weight_quant.strategy not in supported_weight_quant_strategies:
return False

supported_input_quant_strategies = [
QuantizationStrategy.TENSOR.value,
QuantizationStrategy.TOKEN.value
QuantizationStrategy.TENSOR.value, QuantizationStrategy.TOKEN.value
]

if input_quant.strategy not in supported_input_quant_strategies:
return False

return weight_quant.num_bits == input_quant.num_bits == 8

return weight_quant.num_bits == input_quant.num_bits == 8


class CompressedTensorsLinearMethod(LinearMethodBase):
Expand Down
Loading

0 comments on commit 4e060df

Please sign in to comment.