Hqq support #21

Draft · wants to merge 17 commits into base: main
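At a glance, this PR threads a new boolean, `is_zp_float`, through the Marlin GEMM path (CUDA kernel, torch op schema, Python wrappers, benchmarks and tests) so that zero points can be per-group floating-point values, as HQQ produces them, rather than packed integer zero points, and it registers an `hqq` quantization method backed by `HQQMarlinConfig`. The snippet below is an illustrative summary (not taken verbatim from the diff) of how the three kernel tests in this diff drive `gptq_marlin_gemm`; `use_fp32_reduce` and `is_k_full` are parametrized in the tests and shown here with arbitrary values.

```python
# Illustrative keyword tails for ops.gptq_marlin_gemm, condensed from the tests in
# this diff. The new is_zp_float flag defaults to False, so existing callers are
# untouched; only the HQQ path opts in to floating-point zero points.
gptq_tail = dict(is_k_full=True, has_zp=False, use_fp32_reduce=True, is_zp_float=False)
awq_tail = dict(is_k_full=True, has_zp=True, use_fp32_reduce=True, is_zp_float=False)
hqq_tail = dict(is_k_full=True, has_zp=True, use_fp32_reduce=True, is_zp_float=True)
```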
3 changes: 2 additions & 1 deletion benchmarks/kernels/benchmark_machete.py
@@ -156,7 +156,8 @@ def marlinv1_permute_scales(w_s: torch.tensor) -> torch.tensor:
size_m=a.shape[0],
size_n=w_ref.shape[1],
size_k=w_ref.shape[0],
is_k_full=True))))
is_k_full=True,
is_zp_float=False))))

# machete
timers.append(
4 changes: 2 additions & 2 deletions benchmarks/kernels/benchmark_marlin.py
@@ -131,7 +131,7 @@ def bench_run(results: List[benchmark.Measurement], model: str,
results.append(
benchmark.Timer(
stmt=
"output = gptq_marlin_gemm(a, marlin_q_w, marlin_s, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, quant_type, size_m, size_n, size_k, is_k_full, False, False)", # noqa: E501
"output = gptq_marlin_gemm(a, marlin_q_w, marlin_s, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, quant_type, size_m, size_n, size_k, is_k_full, False, False, False)", # noqa: E501
globals=globals,
label=label,
sub_label=sub_label,
@@ -141,7 +141,7 @@
results.append(
benchmark.Timer(
stmt=
"output = gptq_marlin_gemm(a, marlin_q_w, marlin_s, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, quant_type, size_m, size_n, size_k, is_k_full, False, True)", # noqa: E501
"output = gptq_marlin_gemm(a, marlin_q_w, marlin_s, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, quant_type, size_m, size_n, size_k, is_k_full, False, True, False)", # noqa: E501
globals=globals,
label=label,
sub_label=sub_label,
297 changes: 220 additions & 77 deletions csrc/quantization/gptq_marlin/gptq_marlin.cu

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions csrc/quantization/gptq_marlin/marlin_dtypes.cuh
@@ -24,6 +24,7 @@ class ScalarType<half> {
using FragC = Vec<float, 4>;
using FragS = Vec<half2, 1>;
using FragZP = Vec<half2, 4>;
using FragZPF = Vec<half2, 1>;

static __device__ float inline num2float(const half x) {
return __half2float(x);
@@ -53,6 +54,7 @@ class ScalarType<nv_bfloat16> {
using FragC = Vec<float, 4>;
using FragS = Vec<nv_bfloat162, 1>;
using FragZP = Vec<nv_bfloat162, 4>;
using FragZPF = Vec<nv_bfloat162, 1>;

#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
static __device__ float inline num2float(const nv_bfloat16 x) {
2 changes: 1 addition & 1 deletion csrc/torch_bindings.cpp
@@ -197,7 +197,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
"Tensor b_zeros, Tensor g_idx, Tensor perm, Tensor workspace, "
"int b_q_type, "
"SymInt size_m, SymInt size_n, SymInt size_k, bool is_k_full, "
"bool has_zp, bool use_fp32_reduce) -> Tensor");
"bool has_zp, bool use_fp32_reduce, bool is_zp_float) -> Tensor");
// conditionally compiled so impl registration is in source file

// gptq_marlin repack from GPTQ.
86 changes: 85 additions & 1 deletion tests/kernels/test_marlin_gemm.py
@@ -29,6 +29,7 @@
marlin_qqq_quantize)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
awq_pack, gptq_pack, gptq_quantize_weights, quantize_weights, sort_weights)
from vllm.scalar_type import scalar_types

ACT_ORDER_OPTS = [False, True]
K_FULL_OPTS = [False, True]
@@ -226,7 +227,7 @@ def test_gptq_marlin_gemm(
torch.ops._C.gptq_marlin_gemm,
(a_input, marlin_q_w, marlin_s, marlin_zp, g_idx, sort_indices,
workspace.scratch, quant_type.id, a_input.shape[0], b_weight.shape[1],
a_input.shape[1], is_k_full, False, use_fp32_reduce),
a_input.shape[1], is_k_full, False, use_fp32_reduce, False),
test_utils=DEFAULT_OPCHECK_TEST_UTILS)

output = ops.gptq_marlin_gemm(
@@ -244,6 +245,7 @@
is_k_full=is_k_full,
has_zp=False,
use_fp32_reduce=use_fp32_reduce,
is_zp_float=False,
)
output_ref = torch.matmul(a_input, w_ref)

Expand Down Expand Up @@ -441,6 +443,7 @@ def test_awq_marlin_gemm(
is_k_full=is_k_full,
has_zp=has_zp,
use_fp32_reduce=use_fp32_reduce,
is_zp_float=False,
)
output_ref = torch.matmul(a_input, w_ref)

@@ -451,6 +454,87 @@
assert max_diff < 0.04


@pytest.mark.skipif(not is_quant_method_supported("gptq_marlin"),
reason="Marlin is not supported on this GPU type.")
@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS)
@pytest.mark.parametrize("group_size", [64])
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
@pytest.mark.parametrize("use_fp32_reduce", USE_FP32_REDUCE_OPTS)
def test_hqq_marlin_gemm(
k_chunk,
n_chunk,
group_size,
mnk_factors,
use_fp32_reduce,
):
m_factor, n_factor, k_factor = mnk_factors

size_m = m_factor
size_k = k_chunk * k_factor
size_n = n_chunk * n_factor

quant_type = scalar_types.uint4

a_input = rand_data((size_m, size_k))
dev = a_input.device

b_weight = torch.randint(0,
10, (size_n, size_k),
dtype=torch.uint8,
device=dev)
scale = rand_data((size_n, size_k // group_size))
zero = rand_data((size_n, size_k // group_size))

gptq_w_q = gptq_pack(b_weight.transpose(1, 0), 4, size_k, size_n)

sort_indices = torch.empty(0, dtype=torch.int, device=dev)
marlin_w_q = ops.gptq_marlin_repack(gptq_w_q, sort_indices, size_k, size_n,
4).to(dev)
marlin_s = marlin_permute_scales(scale.transpose(1, 0), size_k, size_n,
group_size).to(dev)
marlin_zp = marlin_permute_scales(zero.transpose(1, 0), size_k, size_n,
group_size).to(dev)

g_idx = marlin_make_empty_g_idx(dev)
g_idx_sort_indices = marlin_make_empty_g_idx(dev)

workspace = MarlinWorkspace(size_n, GPTQ_MARLIN_MIN_THREAD_N,
GPTQ_MARLIN_MAX_PARALLEL)

output = ops.gptq_marlin_gemm(
a_input,
marlin_w_q,
marlin_s,
marlin_zp,
g_idx,
g_idx_sort_indices,
workspace.scratch,
quant_type,
a_input.shape[0],
b_weight.shape[0],
a_input.shape[1],
is_k_full=True,
has_zp=True,
use_fp32_reduce=use_fp32_reduce,
is_zp_float=True,
)

b_flat = b_weight.reshape(-1, group_size)
zp_flat = zero.reshape(-1, 1)
s_flat = scale.reshape(-1, 1)
dequant = (b_flat - zp_flat) * s_flat

output_ref = torch.matmul(a_input,
dequant.reshape(b_weight.shape).transpose(1, 0))

torch.cuda.synchronize()

max_diff = compute_max_diff(output, output_ref)

assert max_diff < 0.04


@pytest.mark.skipif(not is_quant_method_supported("qqq"),
reason="Marlin is not supported on this GPU type.")
@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
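The new `test_hqq_marlin_gemm` above checks the kernel output against a plain group-wise dequantization in which the zero point is a float per group rather than a packed integer. Below is a minimal, self-contained sketch of that reference path; the shapes and group size are illustrative, not the test's parametrized values.

```python
import torch

# Group-wise dequantization with floating-point zero points, mirroring the
# reference path in test_hqq_marlin_gemm: each group of `group_size` weights
# along K shares one float scale and one float zero point, and
# w_ref = (w_q - zp) * s.
size_m, size_n, size_k, group_size = 8, 128, 256, 64

a = torch.rand(size_m, size_k)
w_q = torch.randint(0, 10, (size_n, size_k), dtype=torch.uint8)
scale = torch.rand(size_n, size_k // group_size)
zero = torch.rand(size_n, size_k // group_size)

# (q - zp) * s, applied group by group along K
dequant = (w_q.reshape(-1, group_size) - zero.reshape(-1, 1)) * scale.reshape(-1, 1)
output_ref = torch.matmul(a, dequant.reshape(size_n, size_k).transpose(1, 0))
```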
8 changes: 5 additions & 3 deletions vllm/_custom_ops.py
@@ -335,7 +335,8 @@ def _gptq_marlin_gemm_fake(a: torch.Tensor,
size_k: torch.SymInt,
is_k_full: bool,
has_zp: bool = False,
use_fp32_reduce: bool = False) -> torch.Tensor:
use_fp32_reduce: bool = False,
is_zp_float: bool = False) -> torch.Tensor:
return torch.empty((size_m, size_n), device=a.device, dtype=a.dtype)

@register_fake("_C::ggml_dequantize")
@@ -578,11 +579,12 @@ def gptq_marlin_gemm(a: torch.Tensor,
size_k: int,
is_k_full: bool,
has_zp: bool = False,
use_fp32_reduce: bool = False) -> torch.Tensor:
use_fp32_reduce: bool = False,
is_zp_float: bool = False) -> torch.Tensor:
return torch.ops._C.gptq_marlin_gemm(a, b_q_weight, b_scales, b_zeros,
g_idx, perm, workspace, b_q_type.id,
size_m, size_n, size_k, is_k_full,
has_zp, use_fp32_reduce)
has_zp, use_fp32_reduce, is_zp_float)


# fp8 marlin
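Both the fake (meta) registration and the real wrapper gain the same trailing parameter with a `False` default, so existing GPTQ/AWQ call sites are unaffected. The sketch below is a hedged illustration of how an HQQ caller would invoke the updated wrapper, assuming the inputs have already been prepared exactly as in `test_hqq_marlin_gemm` (repacked 4-bit weights, scales and float zero points run through the Marlin scale permutation); the helper function name is hypothetical.

```python
import torch

from vllm import _custom_ops as ops
from vllm.scalar_type import scalar_types


def hqq_marlin_matmul(a, marlin_w_q, marlin_s, marlin_zp, g_idx,
                      g_idx_sort_indices, workspace_scratch,
                      size_m, size_n, size_k) -> torch.Tensor:
    # Hypothetical helper: wraps the updated ops.gptq_marlin_gemm call for
    # HQQ-style weights (uint4 codes, per-group float zero points).
    return ops.gptq_marlin_gemm(
        a,
        marlin_w_q,
        marlin_s,
        marlin_zp,
        g_idx,
        g_idx_sort_indices,
        workspace_scratch,
        scalar_types.uint4,
        size_m,
        size_n,
        size_k,
        is_k_full=True,
        has_zp=True,            # zero points are present...
        use_fp32_reduce=False,
        is_zp_float=True,       # ...and stored as per-group floats (HQQ)
    )
```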
2 changes: 2 additions & 0 deletions vllm/model_executor/layers/quantization/__init__.py
@@ -21,6 +21,7 @@
GPTQMarlinConfig)
from vllm.model_executor.layers.quantization.gptq_marlin_24 import (
GPTQMarlin24Config)
from vllm.model_executor.layers.quantization.hqq_marlin import HQQMarlinConfig
from vllm.model_executor.layers.quantization.ipex_quant import IPEXConfig
from vllm.model_executor.layers.quantization.marlin import MarlinConfig
from vllm.model_executor.layers.quantization.modelopt import ModelOptFp8Config
@@ -48,6 +49,7 @@
"compressed-tensors": CompressedTensorsConfig,
"bitsandbytes": BitsAndBytesConfig,
"qqq": QQQConfig,
"hqq": HQQMarlinConfig,
"experts_int8": ExpertsInt8Config,
"neuron_quant": NeuronQuantConfig,
"ipex": IPEXConfig,
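Registering `HQQMarlinConfig` under the `hqq` key makes the method selectable by name. A hedged usage sketch follows: the model path is a placeholder, and loading it depends on the `HQQMarlinConfig` implementation this PR adds (not shown in this excerpt) plus a GPU supported by the Marlin kernels.

```python
from vllm import LLM

# "hqq" now resolves to HQQMarlinConfig via the mapping above; the checkpoint
# path below is a placeholder for an HQQ-quantized model.
llm = LLM(model="path/to/hqq-quantized-model", quantization="hqq")
outputs = llm.generate(["Hello, my name is"])
```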