@@ -381,6 +381,16 @@ def __init__(
 
         self.attention_mask = []
 
+        self.wrapper_block = wrapper_block
+        if self.enable_alg_ext:
+            try:
+                logger.warning_once("using algorithm extension for quantization.")
+                from auto_round.alg_ext import wrapper_autoround
+
+                wrapper_autoround(self)
+            except (ImportError, ModuleNotFoundError):
+                logger.error("algorithm extension import error, fallback to default mode")
+
     def _gen_auto_scheme(
         self, model: torch.nn.Module, scheme: AutoScheme, dataset: str, device_map: Union[str, int, dict, torch.device]
     ) -> dict[str, dict]:
@@ -2516,6 +2526,32 @@ def quantize_block(
         input_ids, input_others = normalize_input(inputs)
         return self._quantize_block(block, input_ids, input_others, q_input, device, auto_offload)
 
+    def _get_loss(
+        self,
+        output_q: torch.Tensor,
+        current_output: torch.Tensor,
+        indices: torch.Tensor,
+        mse_loss: Callable,
+        device: Union[str, torch.device] = "cpu",
+    ):
+        if self.attention_mask:
+            tmp_attention_mask = [self.attention_mask[i] for i in indices]
+            tmp_attention_mask = torch.cat(tmp_attention_mask, dim=0).to(device)
+            tmp_attention_mask.unsqueeze_(-1)
+        else:
+            tmp_attention_mask = 1.0
+        if self.amp:
+            with autocast(device_type=str(device).split(":")[0], dtype=self.amp_dtype):
+                loss = mse_loss(  # pylint: disable=not-callable
+                    output_q * tmp_attention_mask, current_output * tmp_attention_mask
+                )
+        else:
+            loss = mse_loss(  # pylint: disable=not-callable
+                output_q.to(torch.float32) * tmp_attention_mask,
+                current_output.to(torch.float32) * tmp_attention_mask,
+            )
+        return loss
+
     def _quantize_block(
         self,
         block: torch.nn.Module,
@@ -2600,7 +2636,7 @@ def _quantize_block(
             clear_memory(device_list=self.device_list)
             input_ids = q_input
 
-        quantized_layer_names, unquantized_layer_names = wrapper_block(
+        quantized_layer_names, unquantized_layer_names = self.wrapper_block(
             block,
             self.enable_minmax_tuning,
             self.enable_norm_bias_tuning,
@@ -2675,6 +2711,9 @@ def _quantize_block(
         num_elm = self._get_current_num_elm(input_ids, whole_indices)
 
         for i in range(self.iters):
+            if self.enable_alg_ext and self.data_type.endswith("dq"):
+                for n, m in block.named_modules():
+                    m.cur_iter = i
             total_loss = 0
             if self.sampler == "rand":
                 whole_indices = torch.randperm(nsamples)[:global_batch_size]
@@ -2688,25 +2727,7 @@ def _quantize_block(
 
                 output_q = self._get_current_q_output(block, input_ids, input_others, indices, device, loss_device)
 
-                if self.attention_mask:
-                    tmp_attention_mask = [self.attention_mask[i] for i in indices]
-                    tmp_attention_mask = torch.cat(tmp_attention_mask, dim=0).to(loss_device)
-                    tmp_attention_mask.unsqueeze_(-1)
-                    num_elm = torch.sum(tmp_attention_mask).item()
-                    if num_elm == 0:
-                        num_elm = 1
-                else:
-                    tmp_attention_mask = 1.0
-                if self.amp:
-                    with autocast(device_type=str(loss_device).split(":")[0], dtype=self.amp_dtype):
-                        loss = mse_loss(  # pylint: disable=not-callable
-                            output_q * tmp_attention_mask, current_output * tmp_attention_mask
-                        )
-                else:
-                    loss = mse_loss(  # pylint: disable=not-callable
-                        output_q.to(torch.float32) * tmp_attention_mask,
-                        current_output.to(torch.float32) * tmp_attention_mask,
-                    )
+                loss = self._get_loss(output_q, current_output, indices, mse_loss, device)
 
                 total_loss += loss.item() / num_elm
 
@@ -2836,44 +2857,6 @@ def _quantize_blocks(
                 for i in range(len(input_others[key])):
                     to_dtype(input_others[key][i], tmp_dtype)
 
-        if (
-            self.sym
-            and self.enable_alg_ext
-            and self.super_group_size is None
-            and (
-                (self.data_type.startswith("int") and self.act_bits >= 8)
-                or self.data_type.startswith("mx")
-                or self.data_type.startswith("nv")
-            )
-        ):
-            try:
-                from auto_round.alg_ext import quantize_block_ext
-
-                BaseCompressor.quantize_block_ext = quantize_block_ext
-                quantize_block = self.quantize_block_ext  # must use self.quantize_block_ext
-                if self.bits > 2 and (not self.data_type.startswith("mx") or not self.data_type.startswith("nv")):
-                    logger.warning(
-                        "algorithm extension has only undergone limited validation on "
-                        "INT2,mxfp4 and nvfp4; use with caution."
-                    )
-                else:
-                    logger.info("using algorithm extension for quantization.")
-            except (ImportError, ModuleNotFoundError):
-                logger.error("algorithm extension import error, fallback to default mode")
-                quantize_block = self._quantize_block
-        elif self.enable_alg_ext and self.data_type.endswith("dq"):
-            try:
-                from auto_round.alg_ext import dq_quantize_block_ext
-
-                BaseCompressor.dq_quantize_block_ext = dq_quantize_block_ext
-                quantize_block = self.dq_quantize_block_ext
-                logger.info("using algorithm extension for quantization.")
-            except (ImportError, ModuleNotFoundError):
-                logger.error("algorithm extension import error, fallback to default mode")
-                quantize_block = self._quantize_block
-        else:
-            quantize_block = self._quantize_block
-
         if pbar is None:
             pbar = tqdm(range(0, len(block_names), nblocks))
 
@@ -2891,7 +2874,7 @@ def _quantize_blocks(
                 m = WrapperMultiblock(modules)
 
             m.config = model.config if hasattr(model, "config") else None
-            q_input, input_ids = quantize_block(
+            q_input, input_ids = self._quantize_block(
                 m,
                 input_ids,
                 input_others,
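
For context, the new _get_loss helper factors a masked-MSE computation out of _quantize_block: padded positions are zeroed in both the quantized block output and the reference output via the attention mask before the mean squared error is taken. Below is a minimal standalone sketch of that idea only; the masked_mse name, tensor shapes, and toy mask are illustrative assumptions, not code from the repository.

    # Sketch (assumed names/shapes): masked MSE in the spirit of _get_loss.
    import torch

    def masked_mse(output_q: torch.Tensor, current_output: torch.Tensor, attention_mask=None) -> torch.Tensor:
        mse_loss = torch.nn.MSELoss()
        if attention_mask is not None:
            # (batch, seq) -> (batch, seq, 1) so the mask broadcasts over the hidden dim;
            # padded positions then contribute zero to both operands.
            mask = attention_mask.unsqueeze(-1).to(torch.float32)
        else:
            mask = 1.0
        return mse_loss(output_q.to(torch.float32) * mask, current_output.to(torch.float32) * mask)

    q = torch.randn(2, 4, 8)       # quantized block output: (batch, seq, hidden)
    ref = torch.randn(2, 4, 8)     # reference float block output
    mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]])  # 1 = real token, 0 = padding
    print(masked_mse(q, ref, mask))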