Skip to content

Commit

Permalink
[NPU] initial support of asym_int4_rtn (#12484)
Browse files Browse the repository at this point in the history
* initiail support of q4_1

* fix

* fix

* update

* update min to Z1

* update

* fix

* update

* fix style

* fix

* support qwen2 optimize_model=True mp version

* temp save

* fix

* fix style

* replace min with zero

* support split linear for q4_1

* fix lm_head with mixed_precision=True

* fix style

* revert test code

* add down proj back for q4_0

* remove print
  • Loading branch information
rnwang04 authored Dec 5, 2024
1 parent 60bafab commit 49ab897
Show file tree
Hide file tree
Showing 12 changed files with 264 additions and 81 deletions.
1 change: 1 addition & 0 deletions python/llm/src/ipex_llm/ggml/quantize.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
"fp6_k": 30,
"sym_int4_rtn": 31,
"sym_int8_rtn": 32,
"asym_int4_rtn": 33,
}

# mixed precison from llama.cpp
Expand Down
16 changes: 11 additions & 5 deletions python/llm/src/ipex_llm/transformers/low_bit_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,8 +84,10 @@
FP6_K = ggml_tensor_qtype["fp6_k"]
SYM_INT4_RTN = ggml_tensor_qtype["sym_int4_rtn"]
SYM_INT8_RTN = ggml_tensor_qtype["sym_int8_rtn"]
ASYM_INT4_RTN = ggml_tensor_qtype["asym_int4_rtn"]
RTN_DTYPE = {
SYM_INT4_RTN: torch.uint8,
ASYM_INT4_RTN: torch.uint8,
SYM_INT8_RTN: torch.int8,
}

Expand Down Expand Up @@ -223,12 +225,16 @@ def ggml_convert_qtype(tensor: torch.Tensor, qtype: int,
f"Last dim of input tensor must be multiple of {QK}")

dst_size = (n // QK) * block_size_in_bytes
if qtype in [SYM_INT8_RTN, SYM_INT4_RTN]:
if qtype in [SYM_INT8_RTN, SYM_INT4_RTN, ASYM_INT4_RTN]:
dst_tensor = torch.empty(dst_size, dtype=RTN_DTYPE[qtype],
device=device)
dst_tensor = dst_tensor.reshape(tensor.shape[0], tensor.shape[-1] // QK)
scale = torch.empty(n // k, dtype=torch.float32,
device=device)
if qtype == ASYM_INT4_RTN:
scale = torch.empty((n // k) * 2, dtype=torch.float32,
device=device)
else:
scale = torch.empty(n // k, dtype=torch.float32,
device=device)
elif qtype == NF4:
# Deepspeed zero3 requires unified dtype,
# thus here uses bfloat16 consistent to other layers
Expand All @@ -244,7 +250,7 @@ def ggml_convert_qtype(tensor: torch.Tensor, qtype: int,
dst = ctypes.c_void_p(dst_tensor.data.data_ptr())
hist = (ctypes.c_int64 * 16)()
if qtype not in [IQ2_XXS, IQ2_XS, Q2_K, IQ1_S, Q4_K, Q6_K, Q5_K, FP6_K]:
if qtype in [SYM_INT8_RTN, SYM_INT4_RTN]:
if qtype in [SYM_INT8_RTN, SYM_INT4_RTN, ASYM_INT4_RTN]:
scale_ptr = ctypes.cast(scale.data.data_ptr(), ctypes.POINTER(ctypes.c_float))
if imatrix is None:
ggml.ggml_quantize_tensor_rtn(src, dst, scale_ptr, qtype, n,
Expand All @@ -269,7 +275,7 @@ def ggml_convert_qtype(tensor: torch.Tensor, qtype: int,
ggml.ggml_quantize_tensor_with_weights(src, dst, qtype,
n // in_features, in_features,
hist, imatrix)
if qtype in [SYM_INT8_RTN, SYM_INT4_RTN]:
if qtype in [SYM_INT8_RTN, SYM_INT4_RTN, ASYM_INT4_RTN]:
return dst_tensor, scale.type(torch.float16)
else:
return dst_tensor
Expand Down
8 changes: 5 additions & 3 deletions python/llm/src/ipex_llm/transformers/npu_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ def from_pretrained(cls, *args, **kwargs):
qtype_map = {
"sym_int4": "sym_int4_rtn",
"sym_int8": "sym_int8_rtn",
"asym_int4": "asym_int4_rtn",
}

invalidInputError(
Expand Down Expand Up @@ -154,7 +155,7 @@ def from_pretrained(cls, *args, **kwargs):
f"but got {quantization_group_size}"
)
)
_args = copy.deepcopy(args)

_kwargs = copy.deepcopy(kwargs)

try:
Expand Down Expand Up @@ -270,6 +271,7 @@ def optimize_npu_model(cls, *args, **kwargs):
with torch.no_grad():
model.config.update({"mixed_precision": mixed_precision})
model.config.update({"group_size": quantization_group_size})
model.config.update({"asym": qtype == "asym_int4_rtn"})
optimize_llm_pre(model, qtype, mixed_precision,
quantization_group_size=quantization_group_size)
cls.load_convert(qtype, model, "cpu", modules_to_not_convert,
Expand Down Expand Up @@ -416,9 +418,9 @@ def load_low_bit(cls, pretrained_model_name_or_path: str, *model_args, **kwargs)
)

invalidInputError(
qtype in ["sym_int8_rtn", "sym_int4_rtn"],
qtype in ["sym_int8_rtn", "sym_int4_rtn", "asym_int4_rtn"],
f"Unknown bigdl_transformers_low_bit value: {qtype},"
f" expected: sym_int8_rtn, sym_int4_rtn. "
f" expected: sym_int8_rtn, sym_int4_rtn, asym_int4_rtn. "
)

if enable_cpp_backend:
Expand Down
28 changes: 22 additions & 6 deletions python/llm/src/ipex_llm/transformers/npu_models/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,19 +88,26 @@ def replace_with_QuantizedLinear(layer, qtype, device, modules_to_not_convert,
from ipex_llm.ggml.quantize import ggml_tensor_qtype
iqtype = ggml_tensor_qtype[qtype]
if isinstance(layer, torch.nn.Linear) and not hasattr(layer, "qtype"):
if qtype == "sym_int4_rtn":
if qtype in ["sym_int4_rtn", "asym_int4_rtn"]:
# workaround for qwen2-7B & int4
if (layer.in_features == 3584 and layer.out_features == 152064) or \
(layer.in_features == 18944 and layer.out_features == 3584):
if (layer.in_features == 3584 and layer.out_features == 152064):
qtype = "sym_int8_rtn"
iqtype = ggml_tensor_qtype[qtype]
if qtype == "sym_int4_rtn":
if (layer.in_features == 18944 and layer.out_features == 3584):
qtype = "sym_int8_rtn"
iqtype = ggml_tensor_qtype[qtype]
enable_scale_search = os.environ.get("IPEX_LLM_NPU_QUANTIZATION_OPT", "0") != "0"
qweights, scale = ggml_convert_qtype(layer.weight.data.to(torch.float32),
iqtype, device=device,
enable_scale_search=enable_scale_search,
imatrix=imatrix)
return QuantizedLinear(qweights, scale, layer.bias,
group_size=group_size)
zero = None
# split scale to scale & zero
if qtype == "asym_int4_rtn":
scale, zero = torch.split(scale, scale.shape[0] // 2)
return QuantizedLinear(qweights, scale, zero, layer.bias,
group_size=group_size, qtype=qtype)


@module_optimization
Expand All @@ -111,12 +118,21 @@ def replace_with_DequantizedLinear(layer, qtype, device, modules_to_not_convert,
from ipex_llm.ggml.quantize import ggml_tensor_qtype
iqtype = ggml_tensor_qtype[qtype]
if isinstance(layer, torch.nn.Linear) and not hasattr(layer, "qtype"):
if qtype in ["sym_int4_rtn", "asym_int4_rtn"]:
# workaround for qwen2-7B & int4
if (layer.in_features == 3584 and layer.out_features == 152064):
qtype = "sym_int8_rtn"
iqtype = ggml_tensor_qtype[qtype]
enable_scale_search = os.environ.get("IPEX_LLM_NPU_QUANTIZATION_OPT", "0") != "0"
qweights, scale = ggml_convert_qtype(layer.weight.data.to(torch.float32),
iqtype, device=device,
enable_scale_search=enable_scale_search,
imatrix=imatrix)
return DequantizedLinear(qweights, scale, layer.bias)
zero = None
# split scale to scale & zero
if qtype == "asym_int4_rtn":
scale, zero = torch.split(scale, scale.shape[0] // 2)
return DequantizedLinear(qweights, scale, zero, layer.bias, qtype)


@module_optimization
Expand Down
17 changes: 11 additions & 6 deletions python/llm/src/ipex_llm/transformers/npu_models/convert_mp.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ def optimize_llm_pre(model: torch.nn.Module, qtype, mixed_precision,
from ipex_llm.transformers.npu_models.common import split_linears
if quantization_group_size == 0:
n_splits_linear = 1
if qtype == "sym_int8_rtn":
if qtype in ["sym_int8_rtn", "asym_int4_rtn"]:
# do not split mlp down_proj for Qwen2-7B & sym_int8
n_splits_down_proj = 1
else:
Expand All @@ -154,18 +154,21 @@ def optimize_llm_pre(model: torch.nn.Module, qtype, mixed_precision,
# workaround for MiniCPM-2B
new_lm_head_0 = SlicedLMHead(model.lm_head_0.weight, split_num=split_num,
bias=model.lm_head_0.bias, use_split=True,
group_size=quantization_group_size)
group_size=quantization_group_size,
asym=(qtype == "asym_int4_rtn"))
del model.lm_head_0
model.lm_head_0 = new_lm_head_0
new_lm_head_1 = SlicedLMHead(model.lm_head_1.weight, split_num=split_num,
bias=model.lm_head_1.bias, use_split=True,
group_size=quantization_group_size)
group_size=quantization_group_size,
asym=(qtype == "asym_int4_rtn"))
del model.lm_head_1
model.lm_head_1 = new_lm_head_1
else:
new_lm_head = SlicedLMHead(model.lm_head.weight, split_num=split_num,
bias=model.lm_head.bias, use_split=True,
group_size=quantization_group_size)
group_size=quantization_group_size,
asym=(qtype == "asym_int4_rtn"))
del model.lm_head
model.lm_head = new_lm_head

Expand All @@ -176,11 +179,13 @@ def optimize_llm_pre(model: torch.nn.Module, qtype, mixed_precision,
# Do not split lm_head and use sym_int8 instead when mixed_precison is True
if quantization_group_size == 0:
# Do not split lm_head and use sym_int8 instead when mixed_precison is True
is_split = (not mixed_precision) and qtype == "sym_int4_rtn"
is_split = (not mixed_precision) and qtype in ["sym_int4_rtn", "asym_int4_rtn"]
split_num = 14 if is_split else 1
new_lm_head = SlicedLMHead(model.lm_head.weight, split_num=split_num,
bias=model.lm_head.bias, use_split=True,
group_size=quantization_group_size)
group_size=quantization_group_size,
asym=((qtype == "asym_int4_rtn") and
(not mixed_precision)))
del model.lm_head
model.lm_head = new_lm_head

Expand Down
19 changes: 18 additions & 1 deletion python/llm/src/ipex_llm/transformers/npu_models/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,16 +129,20 @@ def __init__(
self,
weight: torch.Tensor,
scale: torch.Tensor,
zero: Optional[torch.Tensor] = None,
bias: Optional[torch.Tensor] = None,
qtype: Optional[str] = "sym_int4_rtn",
group_size: int = 0,
):
"""Initialize the QuantizedLinear class.
Args:
weight (torch.Tensor): Linear operation weight
scale (torch.Tensor): Quantization scale
zero (Optional[torch.Tensor], optional): Quantization zero for asym_int4_rtn
bias (Optional[torch.Tensor], optional): Linear operation optional bias.
Defaults to None.
qtype (Optional[str], optional): qtype of this Linear
Raises:
RuntimeError: Quantized weight must be in torch.int8 format
Expand All @@ -155,14 +159,19 @@ def __init__(
)
)
self.outC, self.inC = self.weight.shape
self.zero = None
if group_size != 0:
self.scale = Parameter(scale, requires_grad=False)
self.zero = Parameter(zero, requires_grad=False)
else:
if self.weight.dtype == torch.uint8:
# Int4 we need to double the input channels because weights are compressed
self.inC *= 2
self.scale = Parameter(scale * math.sqrt(self.inC), requires_grad=False)
if zero is not None:
self.zero = Parameter(zero * math.sqrt(self.inC), requires_grad=False)
self.bias = bias
self.qtype = qtype
self.op_id = str(uuid.uuid4())

def forward(self, x: torch.Tensor) -> torch.Tensor:
Expand Down Expand Up @@ -195,7 +204,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
)
)

out = run_matmul(x, self.weight.data, self.scale.data, self.op_id)
zero_data = self.zero.data if self.zero is not None else None
out = run_matmul(x, self.weight.data, self.scale.data, zero_data, self.op_id)

if self.bias is None:
return out
Expand All @@ -209,14 +219,18 @@ def __init__(
self,
weight: torch.Tensor,
scale: torch.Tensor,
zero: Optional[torch.Tensor] = None,
bias: Optional[torch.Tensor] = None,
qtype: Optional[str] = "sym_int4_rtn",
):
"""Initialize the DequantizedLinear class.
Args:
weight (torch.Tensor): Linear operation quantized weight
scale (torch.Tensor): Quantization scale
zero (Optional[torch.Tensor], optional): Quantization zero for asym_int4_rtn
bias (Optional[torch.Tensor], optional): Linear operation optional bias.
Defaults to None.
qtype (Optional[str], optional): qtype of this Linear
Raises:
RuntimeError: Quantized weight must be in torch.int8 format
"""
Expand All @@ -240,6 +254,9 @@ def __init__(
decompressed_weight = combined_weight.view(combined_weight.size(0), -1)
dequantized_weight = decompressed_weight.to(torch.float32) * \
torch.unsqueeze(scale.to(torch.float32), dim=1)
if qtype == "asym_int4_rtn" and zero is not None:
dequantized_weight = dequantized_weight + torch.unsqueeze(zero.to(torch.float32),
dim=1)
self.weight = Parameter(dequantized_weight, requires_grad=False).contiguous()
else:
dequantized_weight = weight.to(torch.float32) * \
Expand Down
35 changes: 25 additions & 10 deletions python/llm/src/ipex_llm/transformers/npu_models/lm_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ def __init__(
dtype: np.dtype = np.int8,
use_split: bool = False,
group_size: int = 0,
asym: bool = False,
):
"""Initialize the LMHeadLinear class.
Expand All @@ -54,11 +55,10 @@ def __init__(
self.batch = batch

self.split_num = split_num

if use_split:
input = self.parameter((1, self.batch, self.inC))
res = self.dq_split_linear(input, self.split_num, self.outC, self.inC, wt_dtype=dtype,
scale_factor=(group_size == 0))
scale_factor=(group_size == 0), asym=asym)
else:
input = self.parameter((self.batch, self.inC))
split_size = self.inC // split_num // 2 * 2
Expand All @@ -69,7 +69,7 @@ def __init__(
input_slice = self.slice(input, begin=[0, start_idx],
end=[self.batch, end_idx])
linear_slice = self.linear(input_slice, outC, split_size, bias=False,
wt_dtype=dtype)
wt_dtype=dtype, asym=asym)
if i == 0:
res = linear_slice
else:
Expand Down Expand Up @@ -109,7 +109,7 @@ def run(


class SlicedLMHead(nn.Module):
def __init__(self, weight, bias, split_num, use_split=False, group_size=0):
def __init__(self, weight, bias, split_num, use_split=False, group_size=0, asym=False):
super().__init__()
self.split_num = split_num
self.outC, self.inC = weight.shape
Expand All @@ -128,6 +128,7 @@ def __init__(self, weight, bias, split_num, use_split=False, group_size=0):
self.lm_heads.append(new_linear)
self.bias = bias
self.use_split = use_split
self.asym = asym

def forward(self, hidden_states):
if hidden_states.size(0) * hidden_states.size(1) == 1:
Expand Down Expand Up @@ -162,19 +163,33 @@ def get_fused_lm_head(self):
np_dtype = np.uint8 if self.get_weight_dtype() == torch.uint8 else np.int8
self.fused_lm_head = LMHeadLinear(self.inC, self.outC, 1, self.split_num,
False, "NPU", dtype=np_dtype, use_split=self.use_split,
group_size=self.group_size)
group_size=self.group_size, asym=self.asym)
if self.use_split:
weights = []
scales = []
zeros = []
for i in range(self.split_num):
weights.append(self.lm_heads[i].weight)
scales.append(self.lm_heads[i].scale)
fused_lm_head_weights = (torch.stack(weights, axis=0).numpy(),
torch.stack(scales, axis=0).numpy())
if self.lm_heads[i].zero is not None:
zeros.append(self.lm_heads[i].zero)
if len(zeros):
fused_lm_head_weights = [(torch.stack(weights, axis=0).numpy(),
torch.stack(scales, axis=0).numpy(),
torch.stack(zeros, axis=0).numpy())]
else:
fused_lm_head_weights = [(torch.stack(weights, axis=0).numpy(),
torch.stack(scales, axis=0).numpy())]
else:
fused_lm_head_weights = [(self.lm_heads[i].weight.data.numpy(),
self.lm_heads[i].scale.data.numpy())
for i in range(self.split_num)]
if self.asym:
fused_lm_head_weights = [(self.lm_heads[i].weight.data.numpy(),
self.lm_heads[i].scale.data.numpy(),
self.lm_heads[i].zero.data.numpy())
for i in range(self.split_num)]
else:
fused_lm_head_weights = [(self.lm_heads[i].weight.data.numpy(),
self.lm_heads[i].scale.data.numpy())
for i in range(self.split_num)]

self.fused_lm_head.set_weights(self.lm_heads[0].op_id,
fused_lm_head_weights)
Loading

0 comments on commit 49ab897

Please sign in to comment.