@@ -301,7 +301,7 @@ def load_weights(weights, model_runner):
301301 )
302302 param_scale = torch .squeeze (param_scale , dim = - 1 )
303303 weights_quantized .append ([k , param_lp ])
304- weights_quantized .append ([k + "_scale " , param_scale ])
304+ weights_quantized .append ([k + "_scale_inv " , param_scale ])
305305 # Monkey patch the param class to their subclass, as certain models
306306 # will check the param type to call the proper weightloader
307307 for name , param in model .named_parameters ():
@@ -391,10 +391,6 @@ def cast_tensor_to_fp8_blockwise(
391391
392392def process_weights_after_loading (self , layer ) -> None :
393393 from torch .nn import Parameter
394- from vllm .model_executor .layers .quantization .utils .fp8_utils import (
395- maybe_post_process_fp8_weight_block ,
396- process_fp8_weight_block_strategy ,
397- )
398394 from vllm .model_executor .parameter import (
399395 BlockQuantScaleParameter ,
400396 ModelWeightParameter ,
@@ -420,34 +416,27 @@ def _create_param_from_subclass_attributes(custom_param):
420416 param .subclass_type = type (custom_param )
421417 return param
422418
423- weight_scale = (
424- layer .weight_scale_inv
425- if hasattr (layer , "weight_scale_inv" )
426- else layer .weight_scale
427- )
428- weight , weight_scale = process_fp8_weight_block_strategy (layer .weight , weight_scale )
419+ weight = layer .weight .data
420+ weight_scale_inv = layer .weight_scale_inv .data
421+ weight = self ._maybe_pad_weight (weight )
429422
430423 layer .weight = _create_param_from_subclass_attributes (
431424 ModelWeightParameter (
432- data = weight . data ,
425+ data = weight ,
433426 output_dim = 0 ,
434427 input_dim = 1 ,
435428 weight_loader = layer .weight .weight_loader ,
436429 )
437430 )
438- layer .weight_scale = _create_param_from_subclass_attributes (
431+ layer .weight_scale_inv = _create_param_from_subclass_attributes (
439432 BlockQuantScaleParameter (
440- data = weight_scale . data ,
433+ data = weight_scale_inv ,
441434 output_dim = 0 ,
442435 input_dim = 1 ,
443436 weight_loader = layer .weight_scale_inv .weight_loader ,
444437 )
445438 )
446439
447- del layer .weight_scale_inv
448-
449- maybe_post_process_fp8_weight_block (layer , self .cutlass_block_fp8_supported )
450-
451440
452441@triton .jit
453442def _per_token_group_quant_fp8 (
0 commit comments