Skip to content

Commit

Permalink
initiail support of q4_1
Browse files Browse the repository at this point in the history
  • Loading branch information
rnwang04 committed Dec 3, 2024
1 parent 598603b commit 77dabd1
Show file tree
Hide file tree
Showing 5 changed files with 44 additions and 8 deletions.
1 change: 1 addition & 0 deletions python/llm/src/ipex_llm/ggml/quantize.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
"fp6_k": 30,
"sym_int4_rtn": 31,
"sym_int8_rtn": 32,
"asym_int4_rtn": 33,
}

# mixed precison from llama.cpp
Expand Down
16 changes: 11 additions & 5 deletions python/llm/src/ipex_llm/transformers/low_bit_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,8 +84,10 @@
FP6_K = ggml_tensor_qtype["fp6_k"]
SYM_INT4_RTN = ggml_tensor_qtype["sym_int4_rtn"]
SYM_INT8_RTN = ggml_tensor_qtype["sym_int8_rtn"]
ASYM_INT4_RTN = ggml_tensor_qtype["asym_int4_rtn"]
RTN_DTYPE = {
SYM_INT4_RTN: torch.uint8,
ASYM_INT4_RTN: torch.uint8,
SYM_INT8_RTN: torch.int8,
}

Expand Down Expand Up @@ -223,12 +225,16 @@ def ggml_convert_qtype(tensor: torch.Tensor, qtype: int,
f"Last dim of input tensor must be multiple of {QK}")

dst_size = (n // QK) * block_size_in_bytes
if qtype in [SYM_INT8_RTN, SYM_INT4_RTN]:
if qtype in [SYM_INT8_RTN, SYM_INT4_RTN, ASYM_INT4_RTN]:
dst_tensor = torch.empty(dst_size, dtype=RTN_DTYPE[qtype],
device=device)
dst_tensor = dst_tensor.reshape(tensor.shape[0], tensor.shape[-1] // QK)
scale = torch.empty(n // k, dtype=torch.float32,
device=device)
if qtype == ASYM_INT4_RTN:
scale = torch.empty((n // k) * 2, dtype=torch.float32,
device=device)
else:
scale = torch.empty(n // k, dtype=torch.float32,
device=device)
elif qtype == NF4:
# Deepspeed zero3 requires unified dtype,
# thus here uses bfloat16 consistent to other layers
Expand All @@ -244,7 +250,7 @@ def ggml_convert_qtype(tensor: torch.Tensor, qtype: int,
dst = ctypes.c_void_p(dst_tensor.data.data_ptr())
hist = (ctypes.c_int64 * 16)()
if qtype not in [IQ2_XXS, IQ2_XS, Q2_K, IQ1_S, Q4_K, Q6_K, Q5_K, FP6_K]:
if qtype in [SYM_INT8_RTN, SYM_INT4_RTN]:
if qtype in [SYM_INT8_RTN, SYM_INT4_RTN, ASYM_INT4_RTN]:
scale_ptr = ctypes.cast(scale.data.data_ptr(), ctypes.POINTER(ctypes.c_float))
if imatrix is None:
ggml.ggml_quantize_tensor_rtn(src, dst, scale_ptr, qtype, n,
Expand All @@ -269,7 +275,7 @@ def ggml_convert_qtype(tensor: torch.Tensor, qtype: int,
ggml.ggml_quantize_tensor_with_weights(src, dst, qtype,
n // in_features, in_features,
hist, imatrix)
if qtype in [SYM_INT8_RTN, SYM_INT4_RTN]:
if qtype in [SYM_INT8_RTN, SYM_INT4_RTN, ASYM_INT4_RTN]:
return dst_tensor, scale.type(torch.float16)
else:
return dst_tensor
Expand Down
1 change: 1 addition & 0 deletions python/llm/src/ipex_llm/transformers/npu_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ def from_pretrained(cls, *args, **kwargs):
qtype_map = {
"sym_int4": "sym_int4_rtn",
"sym_int8": "sym_int8_rtn",
"asym_int4": "asym_int4_rtn",
}

invalidInputError(
Expand Down
19 changes: 16 additions & 3 deletions python/llm/src/ipex_llm/transformers/npu_models/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,8 +98,12 @@ def replace_with_QuantizedLinear(layer, qtype, device, modules_to_not_convert,
iqtype, device=device,
enable_scale_search=enable_scale_search,
imatrix=imatrix)
return QuantizedLinear(qweights, scale, layer.bias,
group_size=group_size)
min = None
# split scale to scale & min
if qtype == "asym_int4_rtn":
scale, min = torch.split(scale, scale.shape[0] // 2)
return QuantizedLinear(qweights, scale, min, layer.bias,
group_size=group_size, qtype=qtype)


@module_optimization
Expand All @@ -110,12 +114,21 @@ def replace_with_DequantizedLinear(layer, qtype, device, modules_to_not_convert,
from ipex_llm.ggml.quantize import ggml_tensor_qtype
iqtype = ggml_tensor_qtype[qtype]
if isinstance(layer, torch.nn.Linear) and not hasattr(layer, "qtype"):
if qtype == "sym_int4_rtn":
# workaround for qwen2-7B & int4
if (layer.in_features == 3584 and layer.out_features == 152064):
qtype = "sym_int8_rtn"
iqtype = ggml_tensor_qtype[qtype]
enable_scale_search = os.environ.get("IPEX_LLM_NPU_QUANTIZATION_OPT", "0") != "0"
qweights, scale = ggml_convert_qtype(layer.weight.data.to(torch.float32),
iqtype, device=device,
enable_scale_search=enable_scale_search,
imatrix=imatrix)
return DequantizedLinear(qweights, scale, layer.bias)
min = None
# split scale to scale & min
if qtype == "asym_int4_rtn":
scale, min = torch.split(scale, scale.shape[0] // 2)
return DequantizedLinear(qweights, scale, min, layer.bias, qtype)


@module_optimization
Expand Down
15 changes: 15 additions & 0 deletions python/llm/src/ipex_llm/transformers/npu_models/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,16 +129,20 @@ def __init__(
self,
weight: torch.Tensor,
scale: torch.Tensor,
min: Optional[torch.Tensor] = None,
bias: Optional[torch.Tensor] = None,
qtype: Optional[str] = "sym_int4_rtn",
group_size: int = 0,
):
"""Initialize the QuantizedLinear class.
Args:
weight (torch.Tensor): Linear operation weight
scale (torch.Tensor): Quantization scale
min (Optional[torch.Tensor], optional): Quantization min for asym_int4_rtn
bias (Optional[torch.Tensor], optional): Linear operation optional bias.
Defaults to None.
qtype (Optional[str], optional): qtype of this Linear
Raises:
RuntimeError: Quantized weight must be in torch.int8 format
Expand All @@ -163,6 +167,8 @@ def __init__(
self.inC *= 2
self.scale = Parameter(scale * math.sqrt(self.inC), requires_grad=False)
self.bias = bias
self.min = min
self.qtype = qtype
self.op_id = str(uuid.uuid4())

def forward(self, x: torch.Tensor) -> torch.Tensor:
Expand Down Expand Up @@ -197,6 +203,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:

out = run_matmul(x, self.weight.data, self.scale.data, self.op_id)

if self.qtype == "asym_int4_rtn" and self.min is not None:
out = out + self.min

if self.bias is None:
return out
return out + self.bias
Expand All @@ -209,14 +218,18 @@ def __init__(
self,
weight: torch.Tensor,
scale: torch.Tensor,
min: Optional[torch.Tensor] = None,
bias: Optional[torch.Tensor] = None,
qtype: Optional[str] = "sym_int4_rtn",
):
"""Initialize the DequantizedLinear class.
Args:
weight (torch.Tensor): Linear operation quantized weight
scale (torch.Tensor): Quantization scale
min (Optional[torch.Tensor], optional): Quantization min for asym_int4_rtn
bias (Optional[torch.Tensor], optional): Linear operation optional bias.
Defaults to None.
qtype (Optional[str], optional): qtype of this Linear
Raises:
RuntimeError: Quantized weight must be in torch.int8 format
"""
Expand All @@ -240,6 +253,8 @@ def __init__(
decompressed_weight = combined_weight.view(combined_weight.size(0), -1)
dequantized_weight = decompressed_weight.to(torch.float32) * \
torch.unsqueeze(scale.to(torch.float32), dim=1)
if qtype == "asym_int4_rtn" and min is not None:
dequantized_weight = dequantized_weight + torch.unsqueeze(min.to(torch.float32), dim=1)
self.weight = Parameter(dequantized_weight, requires_grad=False).contiguous()
else:
dequantized_weight = weight.to(torch.float32) * \
Expand Down

0 comments on commit 77dabd1

Please sign in to comment.