@@ -103,9 +103,7 @@ def make_matmul_weight_only_node(
103
103
elif num_bits == 8 :
104
104
packed = q_weight
105
105
else :
106
- logger .error (
107
- "MatMulNBits does not have kernel support for num_bits = {}." .format (num_bits )
108
- )
106
+ logger .error ("MatMulNBits does not have kernel support for num_bits = {}." .format (num_bits ))
109
107
110
108
packed = np .reshape (packed , (- 1 , k_blocks , blob_size ))
111
109
@@ -272,44 +270,44 @@ def quant_tensor_k_quant_cpu(data, num_bits=4, group_size=32):
272
270
scale: scale
273
271
zero_point: zero point
274
272
"""
275
- data = np .reshape (data , (- 1 , group_size )).astype (np .float32 ) # nb = data.shape[0], (nb, group_size)
273
+ data = np .reshape (data , (- 1 , group_size )).astype (np .float32 ) # nb = data.shape[0], (nb, group_size)
276
274
maxq = 2 ** num_bits - 1
277
275
minq = 0
278
- sum_x2 = np .sum (data ** 2 , axis = 1 , keepdims = True ) # (nb, 1)
279
- av_x = np .sqrt (sum_x2 / group_size ) # (nb, 1)
280
- weights = np .add (av_x , np .abs (data )) # (nb, group_size)
281
- rmin = np .min (data , axis = 1 , keepdims = True ) # (nb, 1)
282
- rmax = np .max (data , axis = 1 , keepdims = True ) # (nb, 1)
283
- sum_w = np .sum (weights , axis = 1 , keepdims = True ) # (nb, 1)
284
- sum_x = np .sum (weights * data , axis = 1 , keepdims = True ) # (nb, group_size)
285
- iscale = np .ones (rmax .shape , dtype = data .dtype ) # (nb, 1)
276
+ sum_x2 = np .sum (data ** 2 , axis = 1 , keepdims = True ) # (nb, 1)
277
+ av_x = np .sqrt (sum_x2 / group_size ) # (nb, 1)
278
+ weights = np .add (av_x , np .abs (data )) # (nb, group_size)
279
+ rmin = np .min (data , axis = 1 , keepdims = True ) # (nb, 1)
280
+ rmax = np .max (data , axis = 1 , keepdims = True ) # (nb, 1)
281
+ sum_w = np .sum (weights , axis = 1 , keepdims = True ) # (nb, 1)
282
+ sum_x = np .sum (weights * data , axis = 1 , keepdims = True ) # (nb, 1)
283
+ iscale = np .ones (rmax .shape , dtype = data .dtype ) # (nb, 1)
286
284
mask = rmin != rmax
287
285
iscale [mask ] = (maxq - minq ) / (rmax [mask ] - rmin [mask ])
288
286
scale = 1 / iscale
289
- quant_data = np .clip (np .round (iscale * (data - rmin )), minq , maxq ) # (nb, group_size)
290
- diff = scale * quant_data + rmin - data # (nb, group_size)
291
- best_mad = np .sum (weights * diff ** 2 , axis = 1 , keepdims = True ) # (nb, 1)
287
+ quant_data = np .clip (np .round (iscale * (data - rmin )), minq , maxq ) # (nb, group_size)
288
+ diff = scale * quant_data + rmin - data # (nb, group_size)
289
+ best_mad = np .sum (weights * diff ** 2 , axis = 1 , keepdims = True ) # (nb, 1)
292
290
nstep = 20
293
291
rdelta = 0.1
294
292
# nstep * rdelta = -2 * rrmin, maxq - minq = 2**num_bits - 1
295
293
rrmin = - 1
296
294
for is_ in range (nstep ):
297
- iscale_new = np .ones (rmax .shape , dtype = data .dtype ) # (nb, 1)
295
+ iscale_new = np .ones (rmax .shape , dtype = data .dtype ) # (nb, 1)
298
296
factor = np .array ([rrmin + rdelta * is_ + maxq - minq ]).astype (data .dtype )[0 ]
299
297
mask = rmin != rmax
300
298
iscale_new [mask ] = factor / (rmax [mask ] - rmin [mask ])
301
- quant_data_new = np .clip (np .round (iscale_new * (data - rmin )), minq , maxq ) # (nb, group_size)
299
+ quant_data_new = np .clip (np .round (iscale_new * (data - rmin )), minq , maxq ) # (nb, group_size)
302
300
mul_weights_quant_data_new = weights * quant_data_new
303
- sum_l = np .sum (mul_weights_quant_data_new , axis = 1 , keepdims = True ) # (nb, 1)
304
- sum_l2 = np .sum (mul_weights_quant_data_new * quant_data_new , axis = 1 , keepdims = True ) # (nb, 1)
305
- sum_xl = np .sum (mul_weights_quant_data_new * data , axis = 1 , keepdims = True ) # (nb, 1)
306
- D = np .subtract (sum_w * sum_l2 , sum_l ** 2 ) # (nb, 1)
301
+ sum_l = np .sum (mul_weights_quant_data_new , axis = 1 , keepdims = True ) # (nb, 1)
302
+ sum_l2 = np .sum (mul_weights_quant_data_new * quant_data_new , axis = 1 , keepdims = True ) # (nb, 1)
303
+ sum_xl = np .sum (mul_weights_quant_data_new * data , axis = 1 , keepdims = True ) # (nb, 1)
304
+ D = np .subtract (sum_w * sum_l2 , sum_l ** 2 ) # (nb, 1)
307
305
308
- this_scale = (sum_w * sum_xl - sum_x * sum_l ) / D # (nb, 1)
309
- this_min = (sum_l2 * sum_x - sum_l * sum_xl ) / D # (nb, 1)
306
+ this_scale = (sum_w * sum_xl - sum_x * sum_l ) / D # (nb, 1)
307
+ this_min = (sum_l2 * sum_x - sum_l * sum_xl ) / D # (nb, 1)
310
308
311
- diff = this_scale * quant_data_new + this_min - data # (nb, group_size)
312
- mad = np .sum (weights * diff ** 2 , axis = 1 , keepdims = True ) # (nb, 1)
309
+ diff = this_scale * quant_data_new + this_min - data # (nb, group_size)
310
+ mad = np .sum (weights * diff ** 2 , axis = 1 , keepdims = True ) # (nb, 1)
313
311
314
312
mad_1 = np .array (mad )
315
313
best_mad_1 = np .array (best_mad )
@@ -538,7 +536,9 @@ def rtn_quantize(
538
536
weight = pad_tensor (weight , group_size , k_blocks )
539
537
540
538
enable_MatMulNBits_8bits = True
541
- satisfy_MatMulNBits_condition = (Version (ort .__version__ ) > ONNXRT1161_VERSION and num_bits == 4 ) or (enable_MatMulNBits_8bits and num_bits == 8 )
539
+ satisfy_MatMulNBits_condition = (Version (ort .__version__ ) > ONNXRT1161_VERSION and num_bits == 4 ) or (
540
+ enable_MatMulNBits_8bits and num_bits == 8
541
+ )
542
542
satisfy_MatMulFpQ4_condition = (
543
543
Version (ort .__version__ ) >= ONNXRT116_VERSION and num_bits == 4 and group_size == 32
544
544
)
0 commit comments