
Commit 56bedf4

[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent 0ae992f commit 56bedf4

1 file changed: +26 -26 lines changed

neural_compressor/adaptor/ox_utils/weight_only.py

@@ -103,9 +103,7 @@ def make_matmul_weight_only_node(
     elif num_bits == 8:
         packed = q_weight
     else:
-        logger.error(
-            "MatMulNBits does not have kernel support for num_bits = {}.".format(num_bits)
-        )
+        logger.error("MatMulNBits does not have kernel support for num_bits = {}.".format(num_bits))
 
     packed = np.reshape(packed, (-1, k_blocks, blob_size))
 
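For context on the hunk above: the 8-bit branch leaves the tensor unpacked because each 8-bit code already fills a byte, while the 4-bit path packs two codes per byte before the reshape into (k_blocks, blob_size). A minimal illustrative sketch of such nibble packing follows; the byte order here is an assumption, not necessarily the exact layout MatMulNBits expects.

import numpy as np

# Illustrative only (not from weight_only.py): pack two 4-bit codes per byte.
q_weight = np.array([1, 7, 15, 0, 3, 12], dtype=np.uint8)  # 4-bit codes in [0, 15]

# Even-indexed code goes in the low nibble, odd-indexed in the high nibble
# (assumed order for this sketch).
packed = (q_weight[1::2] << 4) | q_weight[0::2]

# Unpack to verify the round-trip.
restored = np.empty_like(q_weight)
restored[0::2] = packed & 0x0F
restored[1::2] = packed >> 4
assert np.array_equal(restored, q_weight)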
@@ -272,44 +270,44 @@ def quant_tensor_k_quant_cpu(data, num_bits=4, group_size=32):
         scale: scale
         zero_point: zero point
     """
-    data = np.reshape(data, (-1, group_size)).astype(np.float32) # nb = data.shape[0], (nb, group_size)
+    data = np.reshape(data, (-1, group_size)).astype(np.float32)  # nb = data.shape[0], (nb, group_size)
     maxq = 2**num_bits - 1
     minq = 0
-    sum_x2 = np.sum(data**2, axis=1, keepdims=True) # (nb, 1)
-    av_x = np.sqrt(sum_x2 / group_size) # (nb, 1)
-    weights = np.add(av_x, np.abs(data)) # (nb, group_size)
-    rmin = np.min(data, axis=1, keepdims=True) # (nb, 1)
-    rmax = np.max(data, axis=1, keepdims=True) # (nb, 1)
-    sum_w = np.sum(weights, axis=1, keepdims=True) # (nb, 1)
-    sum_x = np.sum(weights * data, axis=1, keepdims=True) # (nb, group_size)
-    iscale = np.ones(rmax.shape, dtype=data.dtype) # (nb, 1)
+    sum_x2 = np.sum(data**2, axis=1, keepdims=True)  # (nb, 1)
+    av_x = np.sqrt(sum_x2 / group_size)  # (nb, 1)
+    weights = np.add(av_x, np.abs(data))  # (nb, group_size)
+    rmin = np.min(data, axis=1, keepdims=True)  # (nb, 1)
+    rmax = np.max(data, axis=1, keepdims=True)  # (nb, 1)
+    sum_w = np.sum(weights, axis=1, keepdims=True)  # (nb, 1)
+    sum_x = np.sum(weights * data, axis=1, keepdims=True)  # (nb, group_size)
+    iscale = np.ones(rmax.shape, dtype=data.dtype)  # (nb, 1)
     mask = rmin != rmax
     iscale[mask] = (maxq - minq) / (rmax[mask] - rmin[mask])
     scale = 1 / iscale
-    quant_data = np.clip(np.round(iscale * (data - rmin)), minq, maxq) # (nb, group_size)
-    diff = scale * quant_data + rmin - data # (nb, group_size)
-    best_mad = np.sum(weights * diff ** 2, axis=1, keepdims=True) # (nb, 1)
+    quant_data = np.clip(np.round(iscale * (data - rmin)), minq, maxq)  # (nb, group_size)
+    diff = scale * quant_data + rmin - data  # (nb, group_size)
+    best_mad = np.sum(weights * diff**2, axis=1, keepdims=True)  # (nb, 1)
     nstep = 20
     rdelta = 0.1
     # nstep * rdelta = -2 * rrmin, maxq - minq = 2**num_bits - 1
     rrmin = -1
     for is_ in range(nstep):
-        iscale_new = np.ones(rmax.shape, dtype=data.dtype) # (nb, 1)
+        iscale_new = np.ones(rmax.shape, dtype=data.dtype)  # (nb, 1)
         factor = np.array([rrmin + rdelta * is_ + maxq - minq]).astype(data.dtype)[0]
         mask = rmin != rmax
         iscale_new[mask] = factor / (rmax[mask] - rmin[mask])
-        quant_data_new = np.clip(np.round(iscale_new * (data - rmin)), minq, maxq) # (nb, group_size)
+        quant_data_new = np.clip(np.round(iscale_new * (data - rmin)), minq, maxq)  # (nb, group_size)
         mul_weights_quant_data_new = weights * quant_data_new
-        sum_l = np.sum(mul_weights_quant_data_new, axis=1, keepdims=True) # (nb, 1)
-        sum_l2 = np.sum(mul_weights_quant_data_new * quant_data_new, axis=1, keepdims=True) # (nb, 1)
-        sum_xl = np.sum(mul_weights_quant_data_new * data, axis=1, keepdims=True) # (nb, 1)
-        D = np.subtract(sum_w * sum_l2, sum_l ** 2) # (nb, 1)
+        sum_l = np.sum(mul_weights_quant_data_new, axis=1, keepdims=True)  # (nb, 1)
+        sum_l2 = np.sum(mul_weights_quant_data_new * quant_data_new, axis=1, keepdims=True)  # (nb, 1)
+        sum_xl = np.sum(mul_weights_quant_data_new * data, axis=1, keepdims=True)  # (nb, 1)
+        D = np.subtract(sum_w * sum_l2, sum_l**2)  # (nb, 1)
 
-        this_scale = (sum_w * sum_xl - sum_x * sum_l) / D # (nb, 1)
-        this_min = (sum_l2 * sum_x - sum_l * sum_xl) / D # (nb, 1)
+        this_scale = (sum_w * sum_xl - sum_x * sum_l) / D  # (nb, 1)
+        this_min = (sum_l2 * sum_x - sum_l * sum_xl) / D  # (nb, 1)
 
-        diff = this_scale * quant_data_new + this_min - data # (nb, group_size)
-        mad = np.sum(weights * diff ** 2, axis=1, keepdims=True) # (nb, 1)
+        diff = this_scale * quant_data_new + this_min - data  # (nb, group_size)
+        mad = np.sum(weights * diff**2, axis=1, keepdims=True)  # (nb, 1)
 
         mad_1 = np.array(mad)
         best_mad_1 = np.array(best_mad)
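For context on the hunk above: inside the loop, quant_tensor_k_quant_cpu solves, for each candidate rounding, the weighted least-squares problem min over (s, m) of sum_i w_i * (s * q_i + m - x_i)^2, and this_scale / this_min are exactly its normal-equation solution. A standalone sketch (variable names assumed, not part of the file) checking that closed form against numpy.polyfit:

import numpy as np

# Standalone sketch: verify the closed-form weighted least-squares fit
# behind this_scale / this_min in the k-quant loop.
rng = np.random.default_rng(0)
x = rng.normal(size=32)                          # one group of float weights
q = rng.integers(0, 16, size=32).astype(float)   # candidate 4-bit codes
w = np.abs(x) + np.sqrt(np.mean(x**2))           # importance weights, as in the diff

sum_w = np.sum(w)
sum_l = np.sum(w * q)
sum_l2 = np.sum(w * q * q)
sum_x = np.sum(w * x)
sum_xl = np.sum(w * q * x)
D = sum_w * sum_l2 - sum_l**2

# Normal-equation solution of min_{s,m} sum_i w_i * (s * q_i + m - x_i)**2
this_scale = (sum_w * sum_xl - sum_x * sum_l) / D
this_min = (sum_l2 * sum_x - sum_l * sum_xl) / D

# Cross-check against numpy's weighted degree-1 fit; polyfit multiplies the
# residuals by its weights before squaring, hence sqrt(w).
s_ref, m_ref = np.polyfit(q, x, deg=1, w=np.sqrt(w))
assert np.allclose([this_scale, this_min], [s_ref, m_ref])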
@@ -538,7 +536,9 @@ def rtn_quantize(
         weight = pad_tensor(weight, group_size, k_blocks)
 
         enable_MatMulNBits_8bits = True
-        satisfy_MatMulNBits_condition = (Version(ort.__version__) > ONNXRT1161_VERSION and num_bits == 4) or (enable_MatMulNBits_8bits and num_bits == 8)
+        satisfy_MatMulNBits_condition = (Version(ort.__version__) > ONNXRT1161_VERSION and num_bits == 4) or (
+            enable_MatMulNBits_8bits and num_bits == 8
+        )
         satisfy_MatMulFpQ4_condition = (
             Version(ort.__version__) >= ONNXRT116_VERSION and num_bits == 4 and group_size == 32
         )
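The wrapped condition above gates kernel choice: MatMulNBits for 4-bit on onnxruntime releases newer than 1.16.1, or for 8-bit when the feature flag is set, with MatMulFpQ4 covering the 4-bit, group_size-32 case from 1.16.0 on. A self-contained sketch of that dispatch; the version constants mirror ONNXRT116_VERSION / ONNXRT1161_VERSION and the fallback label is an assumption for illustration:

from packaging.version import Version

ONNXRT116_VERSION = Version("1.16.0")   # assumed values for illustration
ONNXRT1161_VERSION = Version("1.16.1")


def pick_kernel(ort_version, num_bits, group_size, enable_8bits=True):
    # Mirror the two conditions from rtn_quantize, in priority order.
    v = Version(ort_version)
    if (v > ONNXRT1161_VERSION and num_bits == 4) or (enable_8bits and num_bits == 8):
        return "MatMulNBits"
    if v >= ONNXRT116_VERSION and num_bits == 4 and group_size == 32:
        return "MatMulFpQ4"
    return "fallback"


assert pick_kernel("1.17.0", num_bits=4, group_size=128) == "MatMulNBits"
assert pick_kernel("1.16.0", num_bits=4, group_size=32) == "MatMulFpQ4"
assert pick_kernel("1.15.0", num_bits=8, group_size=32) == "MatMulNBits"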
