From bc5b14b45581e037af2818c12c5b018263b39ae0 Mon Sep 17 00:00:00 2001 From: wenhuach21 Date: Thu, 24 Apr 2025 18:14:09 +0800 Subject: [PATCH 1/8] remove unnecessary round in dq simulation --- auto_round/data_type/int.py | 78 ++++++++++++++++++++++++++++++++++--- 1 file changed, 73 insertions(+), 5 deletions(-) diff --git a/auto_round/data_type/int.py b/auto_round/data_type/int.py index 899abf17..49d6cfee 100644 --- a/auto_round/data_type/int.py +++ b/auto_round/data_type/int.py @@ -17,6 +17,69 @@ from auto_round.data_type.register import register_dtype + + +def soft_round(x, alpha=10): + return torch.sigmoid(alpha * (x - torch.floor(x) - 0.5)) + torch.floor(x) + + +class SigmoidRoundSTE(torch.autograd.Function): + @staticmethod + def forward(ctx, x, beta=10): + ctx.save_for_backward(x) + ctx.beta = beta + return torch.round(x) # 前向严格整数 + + @staticmethod + def backward(ctx, grad_output): + x, = ctx.saved_tensors + beta = ctx.beta + fractional = (x - torch.floor(x)-0.5)*beta + sigmoid_x = torch.sigmoid(fractional) + # Sigmoid 的导数作为梯度近似 + sigmoid_grad = sigmoid_x * (1.0-sigmoid_x) + return grad_output * (beta *sigmoid_grad), None + +# class NoisyRound(torch.autograd.Function): +# @staticmethod +# def forward(ctx, x): +# ctx.save_for_backward(x) +# return torch.round(x) +# +# @staticmethod +# def backward(ctx, grad_output): +# x, = ctx.saved_tensors +# # 反向时添加与小数部分相关的噪声 +# noise = (x - torch.round(x)).detach() # 仅反向传播噪声 +# return grad_output * (1.0+noise) + +class NoisyRound(torch.autograd.Function): + @staticmethod + def forward(ctx, x): + ctx.save_for_backward(x) + return torch.round(x) + + @staticmethod + def backward(ctx, grad_output): + x, = ctx.saved_tensors + # 反向时添加与小数部分相关的噪声 + noise = ((x - torch.round(x)).detach()) # 仅反向传播噪声 + return grad_output * (1.0+noise) + + +# class RoundSTE(torch.autograd.Function): +# @staticmethod +# def forward(ctx, x): +# ctx.save_for_backward(x) +# return torch.round(x) # 前向严格整数 +# +# @staticmethod +# def backward(ctx, grad_output): +# x, = ctx.saved_tensors +# # 梯度乘以输入与最近整数的距离(|x - round(x)|),避免平坦区梯度消失 +# grad_input = grad_output.clone() +# return grad_input * (1.0 - 2.0 * torch.abs(x - torch.round(x))) + @register_dtype("int_sym") def quant_tensor_sym(tensor, bits=4, group_size=-1, v=0, min_scale=1.0, max_scale=1.0, scale_dtype=torch.float16, tensor_min=None, @@ -38,7 +101,7 @@ def quant_tensor_sym(tensor, bits=4, group_size=-1, v=0, min_scale=1.0, max_scal Returns: Quantized and de-quantized tensor, scale, zero-point """ - + iters = kwargs.get("iters", -1) tensor, orig_shape, pad_len = reshape_pad_tensor_by_group_size(tensor, group_size) maxq = 2 ** (bits - 1) if tensor_min is None or tensor_max is None: @@ -56,7 +119,12 @@ def quant_tensor_sym(tensor, bits=4, group_size=-1, v=0, min_scale=1.0, max_scal zp = torch.full_like(scale, maxq) # pylint: disable=E1130 scale = scale.unsqueeze(dim=-1) zp = zp.unsqueeze(dim=-1) - int_w = round_ste(tensor / scale + v) + # if iters>=0 and iters<100: + # k = iters/100 *(10-1) + 1 + # int_w = soft_round(tensor / scale + v, alpha=k) + # else: + # int_w = round_ste(tensor / scale + v) + int_w = SigmoidRoundSTE.apply(tensor/scale+v) q = torch.clamp(int_w + zp, 0, 2 ** bits - 1) qdq_result = (scale * (q - zp)).to(tensor.dtype) qdq_result = revert_tensor_by_pad(qdq_result, orig_shape=orig_shape, pad_len=pad_len) @@ -123,11 +191,11 @@ def quant_tensor_asym_dq(tensor, bits=4, group_size=-1, v=0, min_scale=1.0, max_ scale = torch.clamp(scale, q_scale_thresh) wmin_m = wmin_m.view(-1, 1) - int_w = round_ste(tensor / scale + 
v) - q = torch.clamp(int_w + round_ste(wmin_m / scale), 0, maxq) + int_w = round_ste((tensor+ wmin_m) / scale + v ) + q = torch.clamp(int_w, 0, maxq) qdq_result = (scale * q - wmin_m).to(tensor.dtype) qdq_result = revert_tensor_by_pad(qdq_result, orig_shape=orig_shape, pad_len=pad_len) - zp = round_ste(wmin_m / scale) # remove this later + # zp = round_ste(wmin_m / scale) # remove this later return qdq_result, {"scale": scale, "d_scale": d_scale}, {"wmin_m": wmin_m, "d_wmin_m": d_wmin_m} From c951ac76435484fd48f7bfa3014bb7a4fe0188aa Mon Sep 17 00:00:00 2001 From: wenhuach21 Date: Thu, 24 Apr 2025 18:15:44 +0800 Subject: [PATCH 2/8] remove unnecessary round in dq simulation --- auto_round/data_type/int.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/auto_round/data_type/int.py b/auto_round/data_type/int.py index 899abf17..aedd204b 100644 --- a/auto_round/data_type/int.py +++ b/auto_round/data_type/int.py @@ -123,14 +123,15 @@ def quant_tensor_asym_dq(tensor, bits=4, group_size=-1, v=0, min_scale=1.0, max_ scale = torch.clamp(scale, q_scale_thresh) wmin_m = wmin_m.view(-1, 1) - int_w = round_ste(tensor / scale + v) - q = torch.clamp(int_w + round_ste(wmin_m / scale), 0, maxq) + int_w = round_ste((tensor+ wmin_m) / scale + v ) + q = torch.clamp(int_w, 0, maxq) qdq_result = (scale * q - wmin_m).to(tensor.dtype) qdq_result = revert_tensor_by_pad(qdq_result, orig_shape=orig_shape, pad_len=pad_len) - zp = round_ste(wmin_m / scale) # remove this later + # zp = round_ste(wmin_m / scale) # remove this later return qdq_result, {"scale": scale, "d_scale": d_scale}, {"wmin_m": wmin_m, "d_wmin_m": d_wmin_m} + @register_dtype("int_asym") def quant_tensor_asym(tensor, bits=4, group_size=-1, v=0, min_scale=1.0, max_scale=1.0, scale_dtype=torch.float16, tensor_min=None, tensor_max=None, q_scale_thresh=1e-5, **kwargs): From 17e27c54c05ab8db384649bcd01c357df5ee7dc2 Mon Sep 17 00:00:00 2001 From: wenhuach21 Date: Thu, 24 Apr 2025 18:17:40 +0800 Subject: [PATCH 3/8] revert the change --- auto_round/data_type/int.py | 75 ++----------------------------------- 1 file changed, 3 insertions(+), 72 deletions(-) diff --git a/auto_round/data_type/int.py b/auto_round/data_type/int.py index 37784b7b..39605364 100644 --- a/auto_round/data_type/int.py +++ b/auto_round/data_type/int.py @@ -17,69 +17,6 @@ from auto_round.data_type.register import register_dtype - - -def soft_round(x, alpha=10): - return torch.sigmoid(alpha * (x - torch.floor(x) - 0.5)) + torch.floor(x) - - -class SigmoidRoundSTE(torch.autograd.Function): - @staticmethod - def forward(ctx, x, beta=10): - ctx.save_for_backward(x) - ctx.beta = beta - return torch.round(x) # 前向严格整数 - - @staticmethod - def backward(ctx, grad_output): - x, = ctx.saved_tensors - beta = ctx.beta - fractional = (x - torch.floor(x)-0.5)*beta - sigmoid_x = torch.sigmoid(fractional) - # Sigmoid 的导数作为梯度近似 - sigmoid_grad = sigmoid_x * (1.0-sigmoid_x) - return grad_output * (beta *sigmoid_grad), None - -# class NoisyRound(torch.autograd.Function): -# @staticmethod -# def forward(ctx, x): -# ctx.save_for_backward(x) -# return torch.round(x) -# -# @staticmethod -# def backward(ctx, grad_output): -# x, = ctx.saved_tensors -# # 反向时添加与小数部分相关的噪声 -# noise = (x - torch.round(x)).detach() # 仅反向传播噪声 -# return grad_output * (1.0+noise) - -class NoisyRound(torch.autograd.Function): - @staticmethod - def forward(ctx, x): - ctx.save_for_backward(x) - return torch.round(x) - - @staticmethod - def backward(ctx, grad_output): - x, = ctx.saved_tensors - # 反向时添加与小数部分相关的噪声 - noise 
= ((x - torch.round(x)).detach()) # 仅反向传播噪声 - return grad_output * (1.0+noise) - - -# class RoundSTE(torch.autograd.Function): -# @staticmethod -# def forward(ctx, x): -# ctx.save_for_backward(x) -# return torch.round(x) # 前向严格整数 -# -# @staticmethod -# def backward(ctx, grad_output): -# x, = ctx.saved_tensors -# # 梯度乘以输入与最近整数的距离(|x - round(x)|),避免平坦区梯度消失 -# grad_input = grad_output.clone() -# return grad_input * (1.0 - 2.0 * torch.abs(x - torch.round(x))) - @register_dtype("int_sym") def quant_tensor_sym(tensor, bits=4, group_size=-1, v=0, min_scale=1.0, max_scale=1.0, scale_dtype=torch.float16, tensor_min=None, @@ -101,7 +38,7 @@ def quant_tensor_sym(tensor, bits=4, group_size=-1, v=0, min_scale=1.0, max_scal Returns: Quantized and de-quantized tensor, scale, zero-point """ - iters = kwargs.get("iters", -1) + tensor, orig_shape, pad_len = reshape_pad_tensor_by_group_size(tensor, group_size) maxq = 2 ** (bits - 1) if tensor_min is None or tensor_max is None: @@ -119,12 +56,7 @@ def quant_tensor_sym(tensor, bits=4, group_size=-1, v=0, min_scale=1.0, max_scal zp = torch.full_like(scale, maxq) # pylint: disable=E1130 scale = scale.unsqueeze(dim=-1) zp = zp.unsqueeze(dim=-1) - # if iters>=0 and iters<100: - # k = iters/100 *(10-1) + 1 - # int_w = soft_round(tensor / scale + v, alpha=k) - # else: - # int_w = round_ste(tensor / scale + v) - int_w = SigmoidRoundSTE.apply(tensor/scale+v) + int_w = round_ste(tensor / scale + v) q = torch.clamp(int_w + zp, 0, 2 ** bits - 1) qdq_result = (scale * (q - zp)).to(tensor.dtype) qdq_result = revert_tensor_by_pad(qdq_result, orig_shape=orig_shape, pad_len=pad_len) @@ -191,7 +123,7 @@ def quant_tensor_asym_dq(tensor, bits=4, group_size=-1, v=0, min_scale=1.0, max_ scale = torch.clamp(scale, q_scale_thresh) wmin_m = wmin_m.view(-1, 1) - int_w = round_ste((tensor+ wmin_m) / scale + v ) + int_w = round_ste((tensor + wmin_m) / scale + v) q = torch.clamp(int_w, 0, maxq) qdq_result = (scale * q - wmin_m).to(tensor.dtype) qdq_result = revert_tensor_by_pad(qdq_result, orig_shape=orig_shape, pad_len=pad_len) @@ -199,7 +131,6 @@ def quant_tensor_asym_dq(tensor, bits=4, group_size=-1, v=0, min_scale=1.0, max_ return qdq_result, {"scale": scale, "d_scale": d_scale}, {"wmin_m": wmin_m, "d_wmin_m": d_wmin_m} - @register_dtype("int_asym") def quant_tensor_asym(tensor, bits=4, group_size=-1, v=0, min_scale=1.0, max_scale=1.0, scale_dtype=torch.float16, tensor_min=None, tensor_max=None, q_scale_thresh=1e-5, **kwargs): From 4666fd6dee88925a0b4125628d462cf8d83dc436 Mon Sep 17 00:00:00 2001 From: wenhuach21 Date: Tue, 29 Apr 2025 10:18:56 +0800 Subject: [PATCH 4/8] tmp change for fp8 --- auto_round/data_type/fp8.py | 23 +- .../export/export_to_autoround/export.py | 210 ++++++++++++++---- auto_round/script/llm.py | 2 +- auto_round/wrapper.py | 2 + 4 files changed, 183 insertions(+), 54 deletions(-) diff --git a/auto_round/data_type/fp8.py b/auto_round/data_type/fp8.py index d92c311a..910b6a6b 100644 --- a/auto_round/data_type/fp8.py +++ b/auto_round/data_type/fp8.py @@ -76,26 +76,27 @@ def quant_fp8_sym(tensor, max_scale=1.0, tensor_max=None, **kwargs): - Placeholder for zp (None). 
""" orig_shape = tensor.shape - info = torch.finfo(torch.float8_e4m3fn) + info = torch.finfo(torch.float8_e5m2) orig_dtype = tensor.dtype - if tensor_max is None: ##dynamic per-token - tensor = tensor.reshape(-1, orig_shape[-1]) - max_tensor = torch.max(torch.abs(tensor), dim=-1)[ - 0] * max_scale - elif isinstance(tensor_max,torch.Tensor): - max_tensor = tensor_max.clone().detach().to(tensor.device) * max_scale - else: - max_tensor = torch.tensor(tensor_max).to(tensor.device) * max_scale + # if tensor_max is None: ##dynamic per-token + # tensor = tensor.reshape(-1, orig_shape[-1]) + # max_tensor = torch.max(torch.abs(tensor), dim=-1)[ + # 0] * max_scale + # elif isinstance(tensor_max,torch.Tensor): + # max_tensor = tensor_max.clone().detach().to(tensor.device) * max_scale + # else: + # max_tensor = torch.tensor(tensor_max).to(tensor.device) * max_scale + max_tensor =torch.max(torch.abs(tensor)) scale = max_tensor.to(torch.float32) / info.max min_scaling_factor = float(1.0 / (info.max * 512.0)) ##copy from vllm scale = torch.clip(scale, min=min_scaling_factor) if tensor.dtype == torch.float16: ## Avoid NaN gradients with float16 tensor = tensor.to(torch.bfloat16) - scale = scale.unsqueeze(dim=-1) + # scale = scale.unsqueeze(dim=-1) fp8_res = (tensor / scale) fp8_res = torch.clip(fp8_res, info.min, info.max) - fp8_res = float8_e4m3fn_ste(fp8_res) + fp8_res = fp8_res.to(torch.float8_e5m2).to(torch.bfloat16) qdq_res = fp8_res * scale qdq_res = qdq_res.to(orig_dtype).reshape(orig_shape) return qdq_res, scale, None diff --git a/auto_round/export/export_to_autoround/export.py b/auto_round/export/export_to_autoround/export.py index beffe62a..40530c92 100644 --- a/auto_round/export/export_to_autoround/export.py +++ b/auto_round/export/export_to_autoround/export.py @@ -133,6 +133,123 @@ def pack_qact_layer(name, model): qlayer.to(device) +# def pack_layer(layer_name, model, backend): +# """ +# Packs a model layer for quantization based on its type and configuration. +# +# This function retrieves the specified layer from the model, checks its +# compatibility for quantization, and replaces it with a quantized version +# if applicable. The quantization process depends on the layer's bit-width, +# group size, symmetry, and activation bits. +# +# Args: +# layer_name (str): The name of the layer to be packed. +# model (torch.nn.Module): The model containing the layer. +# backend (str): The backend framework to be used for quantization. +# +# Returns: +# None: The function modifies the model in place. 
+# """ +# layer = get_module(model, layer_name) +# if hasattr(layer, "orig_layer"): +# layer = layer.orig_layer +# +# if not isinstance(layer, supported_layer_types): ##already packed +# return +# +# if int(layer.act_bits) <= 8: +# return pack_qact_layer(layer_name, model) +# +# if not check_to_quantized(layer): +# return +# +# device = layer.weight.device +# bits = layer.bits +# group_size = layer.group_size +# sym = layer.sym +# act_bits = layer.act_bits +# +# scale = layer.scale +# zp = layer.zp +# QuantLinear = dynamic_import_quant_linear_for_packing(backend, bits, group_size, sym, act_bits) +# +# if isinstance(layer, nn.Linear): +# in_features = layer.in_features +# out_features = layer.out_features +# elif isinstance(layer, nn.Conv2d): +# in_features = layer.in_channels +# out_features = layer.out_channels +# elif isinstance(layer, transformers.pytorch_utils.Conv1D): +# in_features = layer.weight.shape[0] +# out_features = layer.weight.shape[1] +# bias = layer.bias is not None +# +# if "awq" not in backend: +# new_layer = QuantLinear( ##pylint: disable=E1123 +# bits, group_size, in_features, out_features, bias, weight_dtype=layer.weight.dtype +# ) +# new_layer.device = device +# set_module(model, layer_name, new_layer) +# qlayer = new_layer +# import auto_round.export.export_to_autoround.qlinear_triton +# if sym and isinstance(QuantLinear, (auto_round.export.export_to_autoround.qlinear_triton.QuantLinear, +# auto_round_extension.cuda.qlinear_tritonv2.QuantLinear)): +# zp = int(zp.flatten()[0]) +# +# qlayer.to("cpu") +# ##force to float32 to be compatible with torch 2.0 +# sig = inspect.signature(qlayer.pack) +# param_count = len(sig.parameters) +# if param_count == 2: +# qlayer.pack(layer, scale) +# else: +# qlayer.pack(layer, scale, zp, None) +# qlayer.to(device) +# else: +# scale, zp = scale.to(torch.float32), zp.to(torch.float32) +# scale = scale.t().contiguous() +# zp = zp.t().contiguous() +# if sym: +# zp = int(zp.flatten()[0]) +# +# if bits != 4: +# logger.error("AutoAWQ format only supports 4-bits quantization.") +# qlayer = QuantLinear.from_linear( +# linear=layer, +# w_bit=bits, +# group_size=group_size, +# init_only=False, +# scales=scale, +# zeros=zp, +# ) +# qlayer.to(device) +# set_module(model, layer_name, qlayer) + +torch.nn.Linear +class MyLinear(torch.nn.Module): + def __init__(self, in_features, out_features, bias=True, device=None, + dtype=None): + factory_kwargs = {"device": device} + super().__init__() + self.in_features = in_features + self.out_features = out_features + self.weight = torch.nn.Parameter( + torch.empty((out_features, in_features), dtype=torch.float8_e5m2, **factory_kwargs) + ) + if bias: + self.bias = torch.nn.Parameter(torch.empty(out_features, **factory_kwargs)) + else: + self.register_parameter("bias", None) + self.register_buffer('scale', torch.ones((1),dtype=torch.bfloat16)) + + # # 如果你需要forward里用scale的话,可以加在forward里 + # def forward(self, input): + # out = super().forward(input) + # # 例如,简单用scale做点什么(可选) + # # out = out * self.scale.sum() + # return out + + def pack_layer(layer_name, model, backend): """ Packs a model layer for quantization based on its type and configuration. 
@@ -171,7 +288,10 @@ def pack_layer(layer_name, model, backend): scale = layer.scale zp = layer.zp - QuantLinear = dynamic_import_quant_linear_for_packing(backend, bits, group_size, sym, act_bits) + weight = layer.weight + q_weight = weight / scale + + # QuantLinear = dynamic_import_quant_linear_for_packing(backend, bits, group_size, sym, act_bits) if isinstance(layer, nn.Linear): in_features = layer.in_features @@ -183,47 +303,53 @@ def pack_layer(layer_name, model, backend): in_features = layer.weight.shape[0] out_features = layer.weight.shape[1] bias = layer.bias is not None - - if "awq" not in backend: - new_layer = QuantLinear( ##pylint: disable=E1123 - bits, group_size, in_features, out_features, bias, weight_dtype=layer.weight.dtype - ) - new_layer.device = device - set_module(model, layer_name, new_layer) - qlayer = new_layer - import auto_round.export.export_to_autoround.qlinear_triton - if sym and isinstance(QuantLinear, (auto_round.export.export_to_autoround.qlinear_triton.QuantLinear, - auto_round_extension.cuda.qlinear_tritonv2.QuantLinear)): - zp = int(zp.flatten()[0]) - - qlayer.to("cpu") - ##force to float32 to be compatible with torch 2.0 - sig = inspect.signature(qlayer.pack) - param_count = len(sig.parameters) - if param_count == 2: - qlayer.pack(layer, scale) - else: - qlayer.pack(layer, scale, zp, None) - qlayer.to(device) - else: - scale, zp = scale.to(torch.float32), zp.to(torch.float32) - scale = scale.t().contiguous() - zp = zp.t().contiguous() - if sym: - zp = int(zp.flatten()[0]) - - if bits != 4: - logger.error("AutoAWQ format only supports 4-bits quantization.") - qlayer = QuantLinear.from_linear( - linear=layer, - w_bit=bits, - group_size=group_size, - init_only=False, - scales=scale, - zeros=zp, - ) - qlayer.to(device) - set_module(model, layer_name, qlayer) + my_linear = MyLinear(in_features, out_features, bias) + my_linear.scale.data.copy_(scale) + my_linear.weight.data.copy_(q_weight.to(torch.float8_e5m2)) + if bias: + my_linear.bias.data.copy_(layer.bias) + + # + # if "awq" not in backend: + # new_layer = QuantLinear( ##pylint: disable=E1123 + # bits, group_size, in_features, out_features, bias, weight_dtype=layer.weight.dtype + # ) + # new_layer.device = device + # set_module(model, layer_name, new_layer) + # qlayer = new_layer + # import auto_round.export.export_to_autoround.qlinear_triton + # if sym and isinstance(QuantLinear, (auto_round.export.export_to_autoround.qlinear_triton.QuantLinear, + # auto_round_extension.cuda.qlinear_tritonv2.QuantLinear)): + # zp = int(zp.flatten()[0]) + # + # qlayer.to("cpu") + # ##force to float32 to be compatible with torch 2.0 + # sig = inspect.signature(qlayer.pack) + # param_count = len(sig.parameters) + # if param_count == 2: + # qlayer.pack(layer, scale) + # else: + # qlayer.pack(layer, scale, zp, None) + # qlayer.to(device) + # else: + # scale, zp = scale.to(torch.float32), zp.to(torch.float32) + # scale = scale.t().contiguous() + # zp = zp.t().contiguous() + # if sym: + # zp = int(zp.flatten()[0]) + # + # if bits != 4: + # logger.error("AutoAWQ format only supports 4-bits quantization.") + # qlayer = QuantLinear.from_linear( + # linear=layer, + # w_bit=bits, + # group_size=group_size, + # init_only=False, + # scales=scale, + # zeros=zp, + # ) + my_linear.to(device) + set_module(model, layer_name, my_linear) def save_quantized_as_autoround(output_dir, inplace=True, backend="auto_round:exllamav2", **kwargs): diff --git a/auto_round/script/llm.py b/auto_round/script/llm.py index b3685d9f..aeec7808 100644 --- 
a/auto_round/script/llm.py +++ b/auto_round/script/llm.py @@ -525,7 +525,7 @@ def tune(args): for file in os.listdir(eval_folder): gguf_file = file user_model = AutoModelForCausalLM.from_pretrained( - eval_folder, gguf_file=gguf_file, device_map="auto" if use_auto_mapping else None) + eval_folder, gguf_file=gguf_file, device_map="auto") tokenizer = AutoTokenizer.from_pretrained(eval_folder, gguf_file=gguf_file) else: if hasattr(model, "hf_device_map") and len(model.hf_device_map) > 1: diff --git a/auto_round/wrapper.py b/auto_round/wrapper.py index a624d9f2..39f0d188 100644 --- a/auto_round/wrapper.py +++ b/auto_round/wrapper.py @@ -262,6 +262,8 @@ def _set_dict_attr(attr_dict, attr_name): if isinstance(scale, dict): _set_dict_attr(scale, "scale") + elif scale.numel()==1: + self.orig_layer.scale = scale.to("cpu") else: self.orig_layer.scale = scale.reshape(shape[0], -1).to("cpu") From cc49284332a55707a7f2de4c9f9b5abedd95d4ed Mon Sep 17 00:00:00 2001 From: wenhuach21 Date: Tue, 29 Apr 2025 10:23:43 +0800 Subject: [PATCH 5/8] refine --- auto_round/export/export_to_autoround/export.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/auto_round/export/export_to_autoround/export.py b/auto_round/export/export_to_autoround/export.py index 40530c92..9fa1250c 100644 --- a/auto_round/export/export_to_autoround/export.py +++ b/auto_round/export/export_to_autoround/export.py @@ -225,7 +225,7 @@ def pack_qact_layer(name, model): # qlayer.to(device) # set_module(model, layer_name, qlayer) -torch.nn.Linear + class MyLinear(torch.nn.Module): def __init__(self, in_features, out_features, bias=True, device=None, dtype=None): @@ -242,12 +242,6 @@ def __init__(self, in_features, out_features, bias=True, device=None, self.register_parameter("bias", None) self.register_buffer('scale', torch.ones((1),dtype=torch.bfloat16)) - # # 如果你需要forward里用scale的话,可以加在forward里 - # def forward(self, input): - # out = super().forward(input) - # # 例如,简单用scale做点什么(可选) - # # out = out * self.scale.sum() - # return out def pack_layer(layer_name, model, backend): From 1b2aa91cdd6f6f37e708a139fa7f2267a9afe619 Mon Sep 17 00:00:00 2001 From: wenhuach21 Date: Wed, 30 Apr 2025 09:28:25 +0800 Subject: [PATCH 6/8] change some configs --- auto_round/export/export_to_autoround/export.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/auto_round/export/export_to_autoround/export.py b/auto_round/export/export_to_autoround/export.py index 9fa1250c..6624f69d 100644 --- a/auto_round/export/export_to_autoround/export.py +++ b/auto_round/export/export_to_autoround/export.py @@ -240,7 +240,7 @@ def __init__(self, in_features, out_features, bias=True, device=None, self.bias = torch.nn.Parameter(torch.empty(out_features, **factory_kwargs)) else: self.register_parameter("bias", None) - self.register_buffer('scale', torch.ones((1),dtype=torch.bfloat16)) + self.register_buffer('weight_scale', torch.ones((1),dtype=torch.bfloat16)) @@ -298,7 +298,7 @@ def pack_layer(layer_name, model, backend): out_features = layer.weight.shape[1] bias = layer.bias is not None my_linear = MyLinear(in_features, out_features, bias) - my_linear.scale.data.copy_(scale) + my_linear.weight_sclae.data.copy_(scale) my_linear.weight.data.copy_(q_weight.to(torch.float8_e5m2)) if bias: my_linear.bias.data.copy_(layer.bias) @@ -380,7 +380,9 @@ def save_quantized_as_autoround(output_dir, inplace=True, backend="auto_round:ex layer_config = kwargs["layer_config"] quantization_config = kwargs["serialization_dict"] - 
quantization_config["quant_method"] = "auto-round" + quantization_config["quant_method"] = "fp8" + quantization_config["fmt"] = "e5m2" + quantization_config["activation_scheme"] = "dynamic" if quantization_config["bits"] == 3: backend = "auto_round:auto_gptq" quantization_config["packing_format"] = backend From 2cb71efd801d90e1acb53e4d7786e3a3b7af7b9d Mon Sep 17 00:00:00 2001 From: wenhuach21 Date: Wed, 30 Apr 2025 09:52:33 +0800 Subject: [PATCH 7/8] fix typo --- auto_round/export/export_to_autoround/export.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/auto_round/export/export_to_autoround/export.py b/auto_round/export/export_to_autoround/export.py index 6624f69d..763045f6 100644 --- a/auto_round/export/export_to_autoround/export.py +++ b/auto_round/export/export_to_autoround/export.py @@ -298,7 +298,7 @@ def pack_layer(layer_name, model, backend): out_features = layer.weight.shape[1] bias = layer.bias is not None my_linear = MyLinear(in_features, out_features, bias) - my_linear.weight_sclae.data.copy_(scale) + my_linear.weight_scale.data.copy_(scale) my_linear.weight.data.copy_(q_weight.to(torch.float8_e5m2)) if bias: my_linear.bias.data.copy_(layer.bias) From 1dcdae770904b75d5f01eba30d0df38392bf67cb Mon Sep 17 00:00:00 2001 From: wenhuach21 Date: Tue, 13 May 2025 14:01:56 +0800 Subject: [PATCH 8/8] change to unit scale for now --- auto_round/data_type/fp8.py | 1 + auto_round/export/export_to_autoround/export.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/auto_round/data_type/fp8.py b/auto_round/data_type/fp8.py index 910b6a6b..6541384b 100644 --- a/auto_round/data_type/fp8.py +++ b/auto_round/data_type/fp8.py @@ -94,6 +94,7 @@ def quant_fp8_sym(tensor, max_scale=1.0, tensor_max=None, **kwargs): if tensor.dtype == torch.float16: ## Avoid NaN gradients with float16 tensor = tensor.to(torch.bfloat16) # scale = scale.unsqueeze(dim=-1) + scale = torch.ones((1), device=tensor.device) fp8_res = (tensor / scale) fp8_res = torch.clip(fp8_res, info.min, info.max) fp8_res = fp8_res.to(torch.float8_e5m2).to(torch.bfloat16) diff --git a/auto_round/export/export_to_autoround/export.py b/auto_round/export/export_to_autoround/export.py index 763045f6..b3df745b 100644 --- a/auto_round/export/export_to_autoround/export.py +++ b/auto_round/export/export_to_autoround/export.py @@ -380,7 +380,7 @@ def save_quantized_as_autoround(output_dir, inplace=True, backend="auto_round:ex layer_config = kwargs["layer_config"] quantization_config = kwargs["serialization_dict"] - quantization_config["quant_method"] = "fp8" + quantization_config["quant_method"] = "auto-round" quantization_config["fmt"] = "e5m2" quantization_config["activation_scheme"] = "dynamic" if quantization_config["bits"] == 3:
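
PATCH 1 briefly experiments with a sigmoid-based straight-through estimator for the rounding itself (SigmoidRoundSTE), which PATCH 3 then reverts back to the existing round_ste. For reference, a self-contained sketch of that idea: the forward pass still returns hard integers, while the backward pass substitutes the derivative of a sharpened sigmoid centred on the .5 boundary for the zero gradient of torch.round. The class body mirrors the patch; the small usage check at the bottom is only an illustration and is not part of the PR:

import torch

class SigmoidRoundSTE(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, beta=10):
        ctx.save_for_backward(x)
        ctx.beta = beta
        return torch.round(x)  # forward stays strictly integer

    @staticmethod
    def backward(ctx, grad_output):
        x, = ctx.saved_tensors
        beta = ctx.beta
        # derivative of sigmoid(beta * (frac(x) - 0.5)) used as the surrogate gradient
        s = torch.sigmoid((x - torch.floor(x) - 0.5) * beta)
        return grad_output * beta * s * (1.0 - s), None

x = torch.tensor([0.2, 0.5, 0.8, 1.3], requires_grad=True)
y = SigmoidRoundSTE.apply(x)
y.sum().backward()
print(y)       # hard rounding in forward: tensor([0., 0., 1., 1.])
print(x.grad)  # largest near the .5 boundary, small near integers

Compared with the plain STE (gradient identically 1), this concentrates gradient mass on values sitting close to a rounding boundary.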
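
The lasting change of PATCH 1-3 is in quant_tensor_asym_dq: the pair round_ste(tensor / scale) plus round_ste(wmin_m / scale) becomes a single round_ste((tensor + wmin_m) / scale). A standalone numerical comparison of the two de-quant simulations is sketched below; the random weights and the per-group scale/wmin_m construction are assumptions for illustration, not the AutoRound pipeline itself, and the learnable rounding offset v is left at 0:

import torch

def qdq_two_rounds(w, scale, wmin_m, maxq=15):
    # original simulation: weight and zero-point rounded separately
    q = torch.clamp(torch.round(w / scale) + torch.round(wmin_m / scale), 0, maxq)
    return scale * q - wmin_m

def qdq_single_round(w, scale, wmin_m, maxq=15):
    # PATCH 1/2: one round on the shifted tensor, no separate zero-point round
    q = torch.clamp(torch.round((w + wmin_m) / scale), 0, maxq)
    return scale * q - wmin_m

torch.manual_seed(0)
w = torch.randn(8, 32)
wmin_m = (-w.min(dim=-1, keepdim=True)[0]).clamp(min=0)   # per-group asym offset
scale = (w.max(dim=-1, keepdim=True)[0] + wmin_m) / 15    # 4-bit, maxq = 15
diff = (qdq_two_rounds(w, scale, wmin_m) - qdq_single_round(w, scale, wmin_m)).abs()
print(diff.max())

The two expressions are not bit-identical in general, since round(a) + round(b) can differ from round(a + b) by one, so the printed difference may be nonzero but is bounded by one quantization step (the corresponding group's scale).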
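
PATCH 4 switches quant_fp8_sym from float8_e4m3fn with per-token scaling to float8_e5m2 with a single per-tensor scale, and PATCH 8 then pins that scale to 1.0 "for now". A minimal sketch of the resulting per-tensor E5M2 quantize/de-quantize path, with the AutoRound plumbing stripped out, follows. It needs a PyTorch build with the float8 dtypes (2.1+); the helper name and the unit_scale switch are assumptions for illustration:

import torch

def fp8_e5m2_qdq(tensor: torch.Tensor, unit_scale: bool = False):
    info = torch.finfo(torch.float8_e5m2)
    orig_dtype = tensor.dtype
    if unit_scale:
        # PATCH 8: force a unit scale for now
        scale = torch.ones((1), device=tensor.device)
    else:
        # PATCH 4: one scale from the global abs-max, with the vLLM-style minimum
        scale = torch.max(torch.abs(tensor)).to(torch.float32) / info.max
        scale = torch.clip(scale, min=1.0 / (info.max * 512.0))
    if tensor.dtype == torch.float16:  # avoid NaN gradients with float16
        tensor = tensor.to(torch.bfloat16)
    fp8 = torch.clip(tensor / scale, info.min, info.max).to(torch.float8_e5m2)
    qdq = (fp8.to(torch.bfloat16) * scale).to(orig_dtype)
    return qdq, scale

w = torch.randn(128, 256, dtype=torch.bfloat16)
qdq, scale = fp8_e5m2_qdq(w)
print(scale.item(), (w - qdq).abs().max().item())

Like the patched code, this casts straight to float8_e5m2 and back instead of going through the float8_e4m3fn_ste helper that the previous per-token path used.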
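
On the export side, PATCH 4 (refined in PATCHes 5-7) bypasses the regular QuantLinear packing and stores each quantized Linear in a small container module (MyLinear in the patch) holding a float8_e5m2 weight plus a scalar bfloat16 weight_scale buffer; the committed class has no forward. A sketch of how such a packed layer could be consumed at load time is below. PackedFP8Linear mirrors the committed module apart from the name, while dequantize_to_linear is a hypothetical helper, not something the PR adds:

import torch

class PackedFP8Linear(torch.nn.Module):
    def __init__(self, in_features, out_features, bias=True, device=None):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        # fp8 storage for the weight; no gradients are needed on a packed layer
        self.weight = torch.nn.Parameter(
            torch.empty((out_features, in_features), dtype=torch.float8_e5m2, device=device),
            requires_grad=False,
        )
        if bias:
            self.bias = torch.nn.Parameter(torch.empty(out_features, device=device))
        else:
            self.register_parameter("bias", None)
        self.register_buffer("weight_scale", torch.ones((1), dtype=torch.bfloat16))

def dequantize_to_linear(packed: PackedFP8Linear) -> torch.nn.Linear:
    # hypothetical load-time helper: weight_fp8 * weight_scale -> bf16 nn.Linear
    linear = torch.nn.Linear(packed.in_features, packed.out_features,
                             bias=packed.bias is not None, dtype=torch.bfloat16)
    with torch.no_grad():
        linear.weight.copy_(packed.weight.to(torch.bfloat16) * packed.weight_scale)
        if packed.bias is not None:
            linear.bias.copy_(packed.bias.to(torch.bfloat16))
    return linear

This matches how the patched pack_layer fills the module: weight / scale is cast to float8_e5m2 into .weight, the scale is copied into .weight_scale (after the attribute-name typo fix in PATCH 7), and the original bias is carried over unchanged.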