Emulated dynamic MX quantization
Signed-off-by: mgoin <[email protected]>
mgoin committed Nov 6, 2024
1 parent 098f94d commit bdd69c6
Showing 3 changed files with 160 additions and 0 deletions.
14 changes: 14 additions & 0 deletions vllm/config.py
@@ -392,6 +392,20 @@ def _verify_quantization(self) -> None:
f"method specified in the `quantization` argument "
f"({self.quantization}).")

from vllm.model_executor.layers.quantization.mx_quant import (
SUPPORTED_MX_CONFIGS)
if self.quantization is not None and self.quantization in \
SUPPORTED_MX_CONFIGS:
dtypes = self.quantization.split("_")
            # Strip the leading 'w' / 'a' prefix from each dtype token
weight_dtype = dtypes[0][1:]
act_dtype = None if len(dtypes) == 1 else dtypes[1][1:]
self.hf_config.quantization_config = {
"weight_dtype": weight_dtype,
"act_dtype": act_dtype,
"quant_method": self.quantization,
}

if self.quantization is not None:
if self.quantization not in supported_quantization:
raise ValueError(
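As an illustration of the parsing above, a supported name such as "wfp6e3_afp6e3" (values here are only an example) is split into the quantization_config dict like this:

quantization = "wfp6e3_afp6e3"                   # one of SUPPORTED_MX_CONFIGS
dtypes = quantization.split("_")                 # ["wfp6e3", "afp6e3"]
weight_dtype = dtypes[0][1:]                     # "fp6e3"
act_dtype = None if len(dtypes) == 1 else dtypes[1][1:]  # "fp6e3"
# -> {"weight_dtype": "fp6e3", "act_dtype": "fp6e3", "quant_method": "wfp6e3_afp6e3"}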
5 changes: 5 additions & 0 deletions vllm/model_executor/layers/quantization/__init__.py
Expand Up @@ -24,6 +24,8 @@
from vllm.model_executor.layers.quantization.ipex_quant import IPEXConfig
from vllm.model_executor.layers.quantization.marlin import MarlinConfig
from vllm.model_executor.layers.quantization.modelopt import ModelOptFp8Config
from vllm.model_executor.layers.quantization.mx_quant import (
SUPPORTED_MX_CONFIGS, MXConfig)
from vllm.model_executor.layers.quantization.neuron_quant import (
NeuronQuantConfig)
from vllm.model_executor.layers.quantization.qqq import QQQConfig
@@ -51,6 +53,9 @@
"experts_int8": ExpertsInt8Config,
"neuron_quant": NeuronQuantConfig,
"ipex": IPEXConfig,
# Aliases for the different MX quant configs
**{config: MXConfig
for config in SUPPORTED_MX_CONFIGS},
}


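With these aliases registered, any name in SUPPORTED_MX_CONFIGS can be passed as the quantization argument through vLLM's usual entry points. A minimal usage sketch (the model name is only a placeholder):

from vllm import LLM

# MXFP8 (e4m3) weights with dynamically MX-quantized activations
llm = LLM(model="facebook/opt-125m", quantization="wfp8e4_afp8e4")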
141 changes: 141 additions & 0 deletions vllm/model_executor/layers/quantization/mx_quant.py
@@ -0,0 +1,141 @@
from typing import Any, Dict, List, Optional

import torch
from torch.nn import Module

from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.parameter import ModelWeightParameter

MXFP8_E5M2 = "fp8e5"
MXFP8_E4M3 = "fp8e4"
MXFP6_E3M2 = "fp6e3"
MXFP6_E2M3 = "fp6e2"
MXFP4_E2M1 = "fp4e2"

# Supported element dtypes
# TODO: add support for MXINT8
SUPPORTED_DTYPES = [
MXFP8_E5M2,
MXFP8_E4M3,
MXFP6_E3M2,
MXFP6_E2M3,
MXFP4_E2M1,
]

SUPPORTED_MX_CONFIGS = [
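    # "w<dtype>_a<dtype>" enables weight plus dynamic activation quantization
    # with matching dtypes (e.g. "wfp8e4_afp8e4"); "w<dtype>" is weight-only.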
*[
f"w{wdtype}_a{adtype}"
for wdtype, adtype in zip(SUPPORTED_DTYPES, SUPPORTED_DTYPES)
],
*[f"w{wdtype}" for wdtype in SUPPORTED_DTYPES],
]

# https://github.com/pytorch/ao/blob/71a442ae775e0ea5a541dcce637b128070d1243c/torchao/prototype/mx_formats/constants.py#L3-L18
DTYPE_MAP = {
MXFP8_E5M2: torch.float8_e5m2,
MXFP8_E4M3: torch.float8_e4m3fn,
MXFP6_E3M2: "fp6_e3m2",
MXFP6_E2M3: "fp6_e2m3",
MXFP4_E2M1: "fp4_e2m1",
}


class MXConfig(QuantizationConfig):
"""MX Quantization Configuration."""

def __init__(
self,
weight_dtype: str,
act_dtype: Optional[str],
) -> None:
if weight_dtype not in SUPPORTED_DTYPES:
raise ValueError(f"Unsupported weight scheme {weight_dtype}")
if act_dtype and act_dtype not in SUPPORTED_DTYPES:
raise ValueError(f"Unsupported activation scheme {act_dtype}")

self.weight_dtype = DTYPE_MAP[weight_dtype]
self.act_dtype = DTYPE_MAP[act_dtype] if act_dtype else None
# Hardcoded for the MX spec
self.block_size = 32

def get_name(self) -> str:
return "mx_quant"

def get_supported_act_dtypes(self) -> List[torch.dtype]:
return [torch.half, torch.bfloat16]

@classmethod
def get_min_capability(cls) -> int:
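        # The MX formats are emulated with regular PyTorch ops, so any GPU with
        # fp16/bf16 support (SM 7.0 / Volta or newer) suffices.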
return 70

@staticmethod
def get_config_filenames() -> List[str]:
return []

@classmethod
def from_config(cls, config: Dict[str, Any]) -> "MXConfig":
weight_dtype = cls.get_from_keys(config, ["weight_dtype"])
act_dtype = cls.get_from_keys(config, ["act_dtype"])
return cls(weight_dtype=weight_dtype, act_dtype=act_dtype)

def get_quant_method(self, layer: Module,
prefix: str) -> Optional["MXLinearMethod"]:
if isinstance(layer, LinearBase):
return MXLinearMethod(self)
return None


class MXLinearMethod(LinearMethodBase):
"""Linear method for MX quant. """

def __init__(self, quant_config: MXConfig):
try:
import torchao # noqa: F401
except ImportError as err:
raise ImportError("Please install torchao==0.6.1 via "
"`pip install torchao==0.6.1` to use "
"mx quantization.") from err
self.quant_config = quant_config

def create_weights(self, layer: Module, input_size_per_partition: int,
output_partition_sizes: List[int], input_size: int,
output_size: int, params_dtype: torch.dtype,
**extra_weight_attrs):
del input_size, output_size # Unused.
weight_loader = extra_weight_attrs.get("weight_loader")
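        # The weight is created (and loaded) in the original checkpoint dtype;
        # it is converted to MX later in process_weights_after_loading.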
weight = ModelWeightParameter(data=torch.empty(
sum(output_partition_sizes),
input_size_per_partition,
dtype=params_dtype),
input_dim=1,
output_dim=0,
weight_loader=weight_loader)
layer.register_parameter("weight", weight)

def process_weights_after_loading(self, layer: Module) -> None:
from torchao.prototype.mx_formats.mx_tensor import MXTensor
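        # Quantize the loaded weight once to the emulated MX representation and
        # drop the dense parameter so only the MX copy is kept.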
layer.weight_mx = MXTensor.to_mx(
layer.weight.data.t().contiguous().to(torch.float32),
self.quant_config.weight_dtype, self.quant_config.block_size)
layer.weight = None

def apply(self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
orig_dtype = x.dtype
from torchao.prototype.mx_formats.mx_tensor import MXTensor
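        # Weight-only path: dequantize the MX weight back to the activation
        # dtype and run a regular floating-point GEMM.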
if not self.quant_config.act_dtype:
weight = layer.weight_mx.to_dtype(orig_dtype).t().contiguous()
out = torch.nn.functional.linear(x, weight, bias)
else:
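            # Weight + activation path: dynamically quantize the activations to
            # MX as well and matmul the two emulated MX tensors.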
x = MXTensor.to_mx(x.to(torch.float32),
self.quant_config.act_dtype,
self.quant_config.block_size)
out = torch.mm(x, layer.weight_mx)
            if bias is not None:
                out += bias

return out.to(orig_dtype)
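
For reference, a minimal round-trip sketch of the torchao emulation that the linear method relies on (assumes torchao==0.6.1 is installed; shapes and dtypes are illustrative):

import torch
from torchao.prototype.mx_formats.mx_tensor import MXTensor

w = torch.randn(256, 256)
# Quantize to MXFP8 (e4m3) with the spec's 32-element block size ...
w_mx = MXTensor.to_mx(w, torch.float8_e4m3fn, 32)
# ... and dequantize back, as the weight-only path in apply() does.
w_deq = w_mx.to_dtype(torch.float32)
print((w - w_deq).abs().mean())  # mean error introduced by the emulated format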
