Merge pull request #37 from neuralmagic/dipika/sem-struc-uncomp
Update 2:4 Support
dsikka authored Dec 6, 2024
2 parents a8a1b57 + 0987c98 commit c7d1cc3
Showing 4 changed files with 167 additions and 210 deletions.
35 changes: 34 additions & 1 deletion tests/quantization/test_compressed_tensors.py
@@ -11,7 +11,7 @@
from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import ( # noqa: E501
    CompressedTensorsLinearMethod, CompressedTensorsW4A16Sparse24,
    CompressedTensorsW8A8Fp8, CompressedTensorsW8A8Int8,
    CompressedTensorsW8A16Fp8, CompressedTensorsWNA16)
    CompressedTensorsW8A16Fp8, CompressedTensorsWNA16, CompressedTensors24)


@pytest.mark.parametrize(
@@ -178,3 +178,36 @@ def test_compressed_tensors_kv_cache(vllm_runner):
    with vllm_runner(model_path, kv_cache_dtype="fp8") as llm:
        output = llm.generate_greedy("Hello world!", max_tokens=20)
        assert output


@pytest.mark.parametrize("args_2of4", [
    ("nm-testing/Meta-Llama-3-8B-Instruct-FP8-Dynamic-2of4-testing", "channel",
     "token"),
    ("nm-testing/Meta-Llama-3-8B-Instruct-FP8-Static-Per-Tensor-testing",
     "channel", "tensor"),
    ("nm-testing/Meta-Llama-3-8B-Instruct-FP8-Static-testing", "tensor",
     "tensor"),
    ("nm-testing/Meta-Llama-3-8B-Instruct-FP8-Dynamic-IA-Per-Tensor-Weight-testing",
     "tensor", "token")
])
def test_compressed_tensors_2of4(vllm_runner, args_2of4):
    model, weight_strategy, input_strategy = args_2of4
    with vllm_runner(model) as llm:
        model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model  # noqa: E501
        layer = model.model.layers[0]

        qkv_proj = layer.self_attn.qkv_proj
        assert isinstance(qkv_proj.quant_method, CompressedTensorsLinearMethod)
        assert isinstance(qkv_proj.scheme, CompressedTensors24)

        assert qkv_proj.scheme.weight_quant.strategy == weight_strategy
        assert qkv_proj.scheme.input_quant.strategy == input_strategy
        assert qkv_proj.scheme.quantized == True
        assert qkv_proj.quant_method.quantization_config.sparsity_scheme_map
        sparsity_map = qkv_proj.quant_method.quantization_config.sparsity_scheme_map
        assert sparsity_map.get("Linear").format == "dense"
        assert sparsity_map.get("Linear").sparsity_structure == "2:4"

        output = llm.generate_greedy("Hello my name is", max_tokens=20)
        print(output)
        assert output
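
For context, the vllm_runner fixture used in this test wraps vLLM's offline LLM API. Outside the test harness, exercising one of these 2:4 sparse FP8 checkpoints would look roughly like the sketch below; this is an illustrative sketch and not part of the diff, with the model name taken from the parametrize list above and the sampling settings approximating the test's greedy decoding.

# Illustrative sketch (not part of this diff): running one of the 2:4 sparse
# FP8 checkpoints from the test above with vLLM's offline API.
from vllm import LLM, SamplingParams

llm = LLM(model="nm-testing/Meta-Llama-3-8B-Instruct-FP8-Dynamic-2of4-testing")
sampling = SamplingParams(temperature=0.0, max_tokens=20)  # greedy, like the test
outputs = llm.generate(["Hello my name is"], sampling)
print(outputs[0].outputs[0].text)
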
1 change: 1 addition & 0 deletions tests/weight_loading/models.txt
@@ -21,6 +21,7 @@ compressed-tensors, nm-testing/Phi-3-mini-128k-instruct-FP8, main
compressed-tensors, neuralmagic/Phi-3-medium-128k-instruct-quantized.w4a16, main
compressed-tensors, nm-testing/TinyLlama-1.1B-Chat-v1.0-actorder-group, main
compressed-tensors, mgoin/DeepSeek-Coder-V2-Lite-Instruct-FP8, main
compressed-tensors, nm-testing/SparseLlama-3.1-8B-gsm8k-pruned.2of4-FP8-Dynamic-testing, main
awq, casperhansen/mixtral-instruct-awq, main
awq_marlin, casperhansen/mixtral-instruct-awq, main
fp8, neuralmagic/Meta-Llama-3-8B-Instruct-FP8-KV, main
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
Expand Up @@ -31,16 +31,20 @@

SPARSITY_CONFIG_NAME: Literal["sparsity_config"] = "sparsity_config"
QUANTIZATION_SCHEME_MAP_TYPE = Dict[str, Optional[Dict[str, QuantizationArgs]]]


class CompressedTensorsConfig(QuantizationConfig):

def __init__(self,
target_scheme_map: Dict[str, Any],
ignore: List[str],
quant_format: str,
kv_cache_scheme: Optional[Dict[str, Any]] = None,
sparsity_scheme_map: Optional[Dict[str, SparsityCompressionConfig]] = None,
config: Optional[Dict[str, Any]] = None,
):
def __init__(
self,
target_scheme_map: Dict[str, Any],
ignore: List[str],
quant_format: str,
kv_cache_scheme: Optional[Dict[str, Any]] = None,
sparsity_scheme_map: Optional[Dict[str,
SparsityCompressionConfig]] = None,
config: Optional[Dict[str, Any]] = None,
):

self.ignore = ignore
self.quant_format = quant_format
@@ -92,8 +96,10 @@ def get_quant_method(
    def from_config(cls, config: Dict[str, Any]) -> "CompressedTensorsConfig":
        ignore: List[str] = cast(List[str], config.get("ignore", []))
        quant_format = cast(str, config.get("format"))
        target_scheme_map = cls._quantization_scheme_map_from_config(config=config)
        sparsity_scheme_map = cls._sparsity_scheme_map_from_config(config=config)
        target_scheme_map = cls._quantization_scheme_map_from_config(
            config=config)
        sparsity_scheme_map = cls._sparsity_scheme_map_from_config(
            config=config)

        return cls(
            target_scheme_map=target_scheme_map,
@@ -102,26 +108,30 @@ def from_config(cls, config: Dict[str, Any]) -> "CompressedTensorsConfig":
            sparsity_scheme_map=sparsity_scheme_map,
            config=config,
        )

    @classmethod
    def _sparsity_scheme_map_from_config(cls, config: Dict[str, Any]) -> Dict[str, SparsityCompressionConfig]:
    def _sparsity_scheme_map_from_config(
            cls, config: Dict[str,
                              Any]) -> Dict[str, SparsityCompressionConfig]:
        """
        :param config: The `quantization_config` dictionary from config.json
        :return: A dictionary mapping target layer names to their corresponding
            sparsity compression configurations
        """
        if (sparsity_config:=config.get(SPARSITY_CONFIG_NAME)) is None:
        if (sparsity_config := config.get(SPARSITY_CONFIG_NAME)) is None:
            return dict()

        sparsity_config = SparsityCompressionConfig.model_validate(sparsity_config)

        sparsity_config = SparsityCompressionConfig.model_validate(
            sparsity_config)
        sparse_scheme_map: Dict[str, SparsityCompressionConfig] = {
            target: sparsity_config
            for target in sparsity_config.targets or list()
        }
        return sparse_scheme_map

    @classmethod
    def _quantization_scheme_map_from_config(cls, config: Dict[str, Any]) -> QUANTIZATION_SCHEME_MAP_TYPE:
    def _quantization_scheme_map_from_config(
            cls, config: Dict[str, Any]) -> QUANTIZATION_SCHEME_MAP_TYPE:
        """
        :param config: The `quantization_config` dictionary from config.json
        :return: A dictionary mapping target layer names to their corresponding
@@ -144,7 +154,8 @@ def _quantization_scheme_map_from_config(cls, config: Dict[str, Any]) -> QUANTIZ
            targets = quant_config.get("targets")
            for target in targets:
                target_scheme_map[target] = {}
                target_scheme_map[target]["weights"] = QuantizationArgs.model_validate(
                target_scheme_map[target][
                    "weights"] = QuantizationArgs.model_validate(
                        quant_config.get("weights"))

                target_scheme_map[target]["input_activations"] = None
@@ -158,7 +169,8 @@ def _quantization_scheme_map_from_config(cls, config: Dict[str, Any]) -> QUANTIZ
                        assert target_scheme_map[target][
                            "weights"].type == QuantizationType.FLOAT
                    else:
                        target_scheme_map[target]["input_activations"] = QuantizationArgs.model_validate(
                        target_scheme_map[target][
                            "input_activations"] = QuantizationArgs.model_validate(
                                quant_config.get("input_activations"))
        return target_scheme_map

@@ -359,7 +371,7 @@ def get_scheme(
        # TODO (@robertgshaw): add compressed-tensors as dep
        # so we do not have to re-write these functions
        # need to make accelerate optional in ct to do this

        matched_target = find_matched_target(
            layer_name=layer_name,
            module=layer,
@@ -369,42 +381,37 @@ def get_scheme(
        weight_quant = scheme_dict.get("weights")
        input_quant = scheme_dict.get("input_activations")

        sparsity_scheme: Optional[SparsityCompressionConfig] = self.sparsity_scheme_map.get(matched_target)

        if self.supports_cutlass_24(
            weight_quant=weight_quant,
            input_quant=input_quant,
            sparsity_scheme=sparsity_scheme
        ):
            # Have a valid sparsity scheme and the layer is supported by the Cutlass 2:4 Kernel
            needs_decompression = sparsity_scheme.format != CompressionFormat.dense.value
            is_quantized = weight_quant is not None or input_quant is not None
            scheme = CompressedTensors24(
                layer_name=layer_name,
                quantized=is_quantized,
                do_decompress=needs_decompression,
                weight_quant=weight_quant,
                input_quant=input_quant,
                config=self.config,
            )
        sparsity_scheme: Optional[
            SparsityCompressionConfig] = self.sparsity_scheme_map.get(
                matched_target)

        if self.supports_cutlass_24(weight_quant=weight_quant,
                                    input_quant=input_quant,
                                    sparsity_scheme=sparsity_scheme):
            # Have a valid sparsity scheme
            # Validate layer is supported by Cutlass 2:4 Kernel
            scheme = CompressedTensors24(quantized=weight_quant is not None
                                         or input_quant is not None,
                                         weight_quant=weight_quant,
                                         input_quant=input_quant)
        else:
            # Find the quant_scheme
            # Find the quant_scheme
            scheme = self._get_scheme_from_parts(
                weight_quant=weight_quant,
                input_quant=input_quant,
            )
            )

        # Raise error if device does not support the scheme
        # (e.g. fp8 needs ada lovelace)
        self._check_scheme_supported(scheme.get_min_capability())
        return scheme

    @staticmethod
    def supports_cutlass_24(
        weight_quant: Optional[QuantizationArgs],
        input_quant: Optional[QuantizationArgs],
        sparsity_scheme: Optional[SparsityCompressionConfig]=None
    ) -> bool:
            weight_quant: Optional[QuantizationArgs],
            input_quant: Optional[QuantizationArgs],
            sparsity_scheme: Optional[SparsityCompressionConfig] = None
    ) -> bool:
        """
        Check if the layer is supported by the Cutlass 2:4 Kernel
        Conditions:
@@ -418,39 +425,37 @@ def supports_cutlass_24(
        :return: True if the layer is supported by the Cutlass 2:4 Kernel
            False otherwise
        """

        if (
            sparsity_scheme is None or
            sparsity_scheme.sparsity_structure != SparsityStructure.TWO_FOUR.value
        ):
        is_valid_sparsity = (sparsity_scheme is not None
                             and sparsity_scheme.sparsity_structure
                             == SparsityStructure.TWO_FOUR.value
                             and sparsity_scheme.format == "dense")
        if not is_valid_sparsity:
            return False

        # Unquantized cases are supported
        if weight_quant is None and input_quant is None:
            return True

        # Weight only quantization is not-supported
        if weight_quant is not None and input_quant is None:
            return False

        supported_weight_quant_strategies = [
            QuantizationStrategy.TENSOR.value,
            QuantizationStrategy.CHANNEL.value
        ]

        if weight_quant.strategy not in supported_weight_quant_strategies:
            return False

        supported_input_quant_strategies = [
            QuantizationStrategy.TENSOR.value,
            QuantizationStrategy.TOKEN.value
            QuantizationStrategy.TENSOR.value, QuantizationStrategy.TOKEN.value
        ]

        if input_quant.strategy not in supported_input_quant_strategies:
            return False

        return weight_quant.num_bits == input_quant.num_bits == 8

        return weight_quant.num_bits == input_quant.num_bits == 8


class CompressedTensorsLinearMethod(LinearMethodBase):
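
To make the conditions checked by supports_cutlass_24 concrete, here is a hypothetical sketch (not part of the commit) of inputs under which the check should return True. It mirrors the model_validate calls used elsewhere in this file; the exact field names accepted by compressed-tensors' SparsityCompressionConfig and QuantizationArgs are assumptions based on how they are used above.

# Hypothetical illustration only: inputs that should satisfy supports_cutlass_24,
# i.e. 2:4 sparsity stored in the "dense" format plus 8-bit float quantization
# with a supported weight/activation strategy pair.
from compressed_tensors.config import SparsityCompressionConfig
from compressed_tensors.quantization import QuantizationArgs

from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import (  # noqa: E501
    CompressedTensorsConfig)

sparsity_scheme = SparsityCompressionConfig.model_validate({
    "format": "dense",            # weights stored dense, sparsified at load time
    "sparsity_structure": "2:4",  # semi-structured 2:4 sparsity
    "targets": ["Linear"],
})
weight_quant = QuantizationArgs.model_validate({
    "num_bits": 8, "type": "float", "strategy": "channel"})
input_quant = QuantizationArgs.model_validate({
    "num_bits": 8, "type": "float", "strategy": "token", "dynamic": True})

assert CompressedTensorsConfig.supports_cutlass_24(
    weight_quant=weight_quant,
    input_quant=input_quant,
    sparsity_scheme=sparsity_scheme)
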
