diff --git a/examples/export_int8_llama.py b/examples/export_int8_llama.py
new file mode 100644
index 0000000..94274d5
--- /dev/null
+++ b/examples/export_int8_llama.py
@@ -0,0 +1,57 @@
+import torch
+import argparse
+import os
+
+from pathlib import Path
+
+from transformers import AutoTokenizer
+from transformers.models.llama.modeling_llama import LlamaForCausalLM
+
+from smoothquant.llama import Int8LlamaForCausalLM
+from smoothquant.smooth import smooth_lm
+
+from smoothquant.calibration import get_static_llama_decoder_layer_scales
+from torch.nn.functional import pad
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--model-name", type=str, default='fp16_models/llama-13b')
+    parser.add_argument("--num-samples", type=int, default=512)
+    parser.add_argument("--seq-len", type=int, default=512)
+    parser.add_argument("--act-scales", type=str,
+                        default='act_scales/llama-13b.pt')
+    parser.add_argument("--output-path", type=str, default='int8_models')
+    parser.add_argument('--dataset-path', type=str, default='dataset/val.jsonl.zst',
+                        help='location of the calibration dataset; we use the validation set of the Pile')
+    parser.add_argument('--export-FT', default=False, action="store_true")
+    args = parser.parse_args()
+    model = LlamaForCausalLM.from_pretrained(
+        args.model_name, device_map="auto", torch_dtype=torch.float16)
+    act_scales = torch.load(args.act_scales)
+    smooth_lm(model, act_scales, 0.5)
+    tokenizer = AutoTokenizer.from_pretrained(args.model_name)
+
+    if not os.path.exists(args.dataset_path):
+        print(f'Cannot find the dataset at {args.dataset_path}')
+        print('Please download the Pile dataset and place the validation set at that path.')
+        print('The Pile validation set is available at https://mystic.the-eye.eu/public/AI/pile/val.jsonl.zst')
+        raise FileNotFoundError
+
+    decoder_layer_scales, raw_scales = get_static_llama_decoder_layer_scales(model,
+                                                                             tokenizer,
+                                                                             args.dataset_path,
+                                                                             num_samples=args.num_samples,
+                                                                             seq_len=args.seq_len)
+    output_path = Path(args.output_path) / ("llama-" + Path(args.model_name).name + "-smoothquant-per-token-opt")
+    if args.export_FT:
+        model.save_pretrained(output_path)
+        print(f"Saved smoothed model at {output_path}")
+
+        output_path = Path(args.output_path) / (Path(args.model_name).name + "-smoothquant-scales.pt")
+        torch.save(raw_scales, output_path)
+        print(f"Saved scaling factors at {output_path}")
+    else:
+        int8_model = Int8LlamaForCausalLM.from_float(model, decoder_layer_scales)
+        int8_model.save_pretrained(output_path)
+        print(f"Saved int8 model at {output_path}")
\ No newline at end of file
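For reference, a minimal sketch of how the exported checkpoint produced by this script might be loaded back. The paths are illustrative, and `from_pretrained`/`generate` here rely on the inherited Hugging Face `PreTrainedModel` machinery plus a CUDA build of torch-int:

```python
import torch
from transformers import AutoTokenizer
from smoothquant.llama import Int8LlamaForCausalLM

# Hypothetical paths: adjust to wherever the script above wrote its outputs.
ckpt = "int8_models/llama-13b-smoothquant-per-token-opt"
tokenizer = AutoTokenizer.from_pretrained("fp16_models/llama-13b")
model = Int8LlamaForCausalLM.from_pretrained(ckpt).cuda()

inputs = tokenizer("The capital of France is", return_tensors="pt").to(model.device)
with torch.no_grad():
    out = model.generate(**inputs, max_new_tokens=16)
print(tokenizer.decode(out[0], skip_special_tokens=True))
```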
diff --git a/smoothquant/calibration.py b/smoothquant/calibration.py
index 8d7906e..b6e8278 100644
--- a/smoothquant/calibration.py
+++ b/smoothquant/calibration.py
@@ -118,3 +118,82 @@ def stat_io_hook(m, x, y, name):
         decoder_layer_scales.append(scale_dict)
 
     return decoder_layer_scales, act_dict
+
+# TODO: merge into the get_static_decoder_layer_scales method
+@torch.no_grad()
+def get_static_llama_decoder_layer_scales(model,
+                                          tokenizer,
+                                          dataset_path,
+                                          num_samples=512,
+                                          seq_len=512,
+                                          ):
+    model.eval()
+    device = next(model.parameters()).device
+
+    act_dict = defaultdict(dict)
+
+    def stat_io_hook(m, x, y, name):
+        if isinstance(x, tuple):
+            x = x[0]
+        if name not in act_dict or "input" not in act_dict[name]:
+            act_dict[name]["input"] = x.detach().abs().max().item()
+        else:
+            act_dict[name]["input"] = max(
+                act_dict[name]["input"], x.detach().abs().max().item())
+        if isinstance(y, tuple):
+            y = y[0]
+        if name not in act_dict or "output" not in act_dict[name]:
+            act_dict[name]["output"] = y.detach().abs().max().item()
+        else:
+            act_dict[name]["output"] = max(
+                act_dict[name]["output"], y.detach().abs().max().item())
+
+    hooks = []
+    for name, m in model.named_modules():
+        if isinstance(m, torch.nn.Linear):
+            hooks.append(m.register_forward_hook(
+                partial(stat_io_hook, name=name)))
+
+    print("Collecting activation scales...")
+    pbar = tqdm(range(num_samples))
+    dataset = load_dataset('json', data_files=dataset_path, split="train")
+    dataset = dataset.shuffle(seed=42)
+    for i in pbar:
+        input_ids = tokenizer(dataset[i]["text"], return_tensors="pt",
+                              max_length=seq_len, truncation=True).input_ids.to(device)
+        model(input_ids)
+        mean_scale = np.mean([v["input"] for v in act_dict.values()])
+        pbar.set_description(f"Mean input scale: {mean_scale:.2f}")
+    for hook in hooks:
+        hook.remove()
+
+    decoder_layer_scales = []
+    for idx in range(model.config.num_hidden_layers):
+        scale_dict = {}
+        # self-attention scales
+        scale_dict["attn_input_scale"] = act_dict[
+            f"model.layers.{idx}.self_attn.q_proj"]['input'] / 127
+        scale_dict["q_output_scale"] = act_dict[
+            f"model.layers.{idx}.self_attn.q_proj"]['output'] / 127
+        scale_dict["k_output_scale"] = act_dict[
+            f"model.layers.{idx}.self_attn.k_proj"]['output'] / 127
+        scale_dict["v_output_scale"] = act_dict[
+            f"model.layers.{idx}.self_attn.v_proj"]['output'] / 127
+        scale_dict["out_input_scale"] = act_dict[
+            f"model.layers.{idx}.self_attn.o_proj"]['input'] / 127
+        # mlp scales
+        scale_dict["gate_input_scale"] = act_dict[
+            f"model.layers.{idx}.mlp.gate_proj"]['input'] / 127
+        scale_dict["up_input_scale"] = act_dict[
+            f"model.layers.{idx}.mlp.up_proj"]["input"] / 127
+        scale_dict["down_input_scale"] = act_dict[
+            f"model.layers.{idx}.mlp.down_proj"]["input"] / 127
+        scale_dict["gate_output_scale"] = act_dict[
+            f"model.layers.{idx}.mlp.gate_proj"]['output'] / 127
+        scale_dict["up_output_scale"] = act_dict[
+            f"model.layers.{idx}.mlp.up_proj"]["output"] / 127
+        scale_dict["down_output_scale"] = act_dict[
+            f"model.layers.{idx}.mlp.down_proj"]["output"] / 127
+        decoder_layer_scales.append(scale_dict)
+
+    return decoder_layer_scales, act_dict
\ No newline at end of file
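The scales collected above follow the usual static per-tensor rule: maximum absolute activation value divided by 127. A self-contained sketch of how such a scale would quantize and dequantize a tensor (plain PyTorch, illustrative values only):

```python
import torch

x = torch.randn(4, 8) * 3.0                    # stand-in activation tensor
scale = x.abs().max().item() / 127             # same rule as act_dict[...]["input"] / 127

x_int8 = (x / scale).round().clamp(-128, 127).to(torch.int8)   # quantize
x_deq = x_int8.float() * scale                                 # dequantize
print((x - x_deq).abs().max().item())          # round-trip error is bounded by ~scale / 2
```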
diff --git a/smoothquant/llama.py b/smoothquant/llama.py
new file mode 100644
index 0000000..2f7f241
--- /dev/null
+++ b/smoothquant/llama.py
@@ -0,0 +1,413 @@
+import torch
+import math
+from torch import nn
+from transformers.models.llama.modeling_llama import (
+    LlamaRMSNorm,
+    LlamaRotaryEmbedding,
+    LlamaMLP,
+    LlamaAttention,
+    LlamaDecoderLayer,
+    LlamaPreTrainedModel,
+    LlamaModel,
+    LlamaForCausalLM,
+    LlamaForSequenceClassification,
+    BaseModelOutputWithPast
+)
+from transformers.models.llama.configuration_llama import LlamaConfig
+from transformers.activations import SiLUActivation
+from typing import Optional, Tuple, List
+# requires the llama-dev branch of https://github.com/AniZpZ/torch-int
+from torch_int.nn.linear import W8A8BFP32OFP32LinearWithSFactor, W8A8BFP32OFP32Linear
+from smoothquant.fake_quant import W8A8Linear
+from transformers.utils import logging
+logger = logging.get_logger(__name__)
+_CONFIG_FOR_DOC = "LlamaConfig"
+# the attention structure mirrors the OPT implementation
+class Int8LlamaAttention(nn.Module):
+    """Multi-headed attention from 'Attention Is All You Need' paper"""
+    def __init__(
+        self,
+        config: LlamaConfig
+    ):
+        super().__init__()
+        self.config = config
+        self.hidden_size = config.hidden_size
+        self.num_heads = config.num_attention_heads
+        self.head_dim = self.hidden_size // self.num_heads
+        self.max_position_embeddings = config.max_position_embeddings
+        self.out_input_scale = 0.
+        # hidden_size is called embed_dim in OPTAttention
+        if (self.head_dim * self.num_heads) != self.hidden_size:
+            raise ValueError(
+                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
+                f" and `num_heads`: {self.num_heads})."
+            )
+
+        self.k_proj = W8A8BFP32OFP32Linear(self.hidden_size, self.num_heads * self.head_dim)
+        self.v_proj = W8A8BFP32OFP32Linear(self.hidden_size, self.num_heads * self.head_dim)
+        self.q_proj = W8A8BFP32OFP32Linear(self.hidden_size, self.num_heads * self.head_dim)
+        # the output projection returns fp32
+        self.o_proj = W8A8BFP32OFP32LinearWithSFactor(self.num_heads * self.head_dim, self.hidden_size)
+
+        self.rotary_emb = LlamaRotaryEmbedding(self.head_dim, max_position_embeddings=self.max_position_embeddings)
+
+    _shape = LlamaAttention._shape
+
+    @staticmethod
+    @torch.no_grad()
+    def from_float(module: LlamaAttention,
+                   config: LlamaConfig,
+                   attn_input_scale: float,
+                   q_output_scale: float,
+                   k_output_scale: float,
+                   v_output_scale: float,
+                   out_input_scale: float):
+        int8_module = Int8LlamaAttention(config)
+
+        # we do not implement the attention kernel itself for now because we want to use paged attention
+
+        # FIXME: fuse the scaling into the q_proj output scale
+        # linearList = [module.q_proj, module.k_proj, module.v_proj]
+
+        # qkv_list = W8A8BFP32OFP32Linear.from_float_fuse(
+        #     linearList,
+        #     attn_input_scale)
+        # if len(qkv_list) != 3:
+        #     raise ValueError(
+        #         f"invalid qkv list len, must return 3 linears but got {len(qkv_list)}")
+
+        # int8_module.q_proj = qkv_list[0]
+        # int8_module.k_proj = qkv_list[1]
+        # int8_module.v_proj = qkv_list[2]
+
+        int8_module.q_proj = W8A8BFP32OFP32Linear.from_float(
+            module.q_proj,
+            attn_input_scale
+        )
+        int8_module.k_proj = W8A8BFP32OFP32Linear.from_float(
+            module.k_proj,
+            attn_input_scale
+        )
+        int8_module.v_proj = W8A8BFP32OFP32Linear.from_float(
+            module.v_proj,
+            attn_input_scale
+        )
+
+        int8_module.o_proj = W8A8BFP32OFP32LinearWithSFactor.from_float(
+            module.o_proj, out_input_scale)
+        return int8_module
+
+    @torch.no_grad()
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_value: Optional[Tuple[torch.Tensor]] = None,
+        output_attentions: bool = False,
+        use_cache: bool = False,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+        bsz, q_len, _ = hidden_states.size()
+        # the input is already quantized to int8 before the attention block
+        query_states = self.q_proj(hidden_states).to(torch.float16)
+        key_states = self.k_proj(hidden_states).to(torch.float16)
+        value_states = self.v_proj(hidden_states).to(torch.float16)
+        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+        key_states = key_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+        value_states = value_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+        kv_seq_len = key_states.shape[-2]
+        if past_key_value is not None:
+            kv_seq_len += past_key_value[0].shape[-2]
+        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
+        if past_key_value is not None:
+            # reuse k, v, self_attention
+            key_states = torch.cat([past_key_value[0], key_states], dim=2)
+            value_states = torch.cat([past_key_value[1], value_states], dim=2)
+        past_key_value = (key_states, value_states) if use_cache else None
+        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
+        if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
+            raise ValueError(
+                f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
+                f" {attn_weights.size()}"
+            )
+        if attention_mask is not None:
+            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
+                raise ValueError(
+                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
+                )
+            attn_weights = attn_weights + attention_mask
+            attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min))
+        # upcast attention to fp32
+        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
+        attn_output = torch.matmul(attn_weights, value_states)
+        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
+            raise ValueError(
+                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
+                f" {attn_output.size()}"
+            )
+        attn_output = attn_output.transpose(1, 2).contiguous()
+        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+        # quantized output projection from torch-int
+        attn_output = self.o_proj(attn_output)
+        if not output_attentions:
+            attn_weights = None
+        return attn_output, attn_weights, past_key_value
+
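The attention module above builds on torch-int's `W8A8BFP32OFP32Linear` (int8 activations and weights, fp32 output). The real kernel is CUDA-only; the pure-PyTorch stand-in below only illustrates the assumed arithmetic and is not the actual torch-int API:

```python
import torch

def w8a8_linear_ref(x_int8, w_int8, x_scale, w_scale, bias=None):
    # integer matmul with wide accumulation, then rescale to fp32
    # (the CUDA kernel accumulates in int32; int64 is used here to stay CPU-friendly)
    acc = x_int8.to(torch.int64) @ w_int8.to(torch.int64).t()
    y = acc.to(torch.float32) * (x_scale * w_scale)
    return y if bias is None else y + bias

x = (torch.randn(2, 16) / 0.1).round().clamp(-128, 127).to(torch.int8)    # fake quantized activations
w = (torch.randn(32, 16) / 0.02).round().clamp(-128, 127).to(torch.int8)  # fake quantized weights
print(w8a8_linear_ref(x, w, x_scale=0.1, w_scale=0.02).shape)             # torch.Size([2, 32])
```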
+# we keep the quantization scale in the RMSNorm layer for kernel fusion
+class Int8LlamaRMSNorm(nn.Module):
+    def __init__(self, hidden_size, eps=1e-6):
+        """
+        LlamaRMSNorm is equivalent to T5LayerNorm
+        """
+        super().__init__()
+        self.register_buffer('weight', torch.ones(hidden_size, dtype=torch.float32, requires_grad=False))
+        self.variance_epsilon = eps
+
+    def forward(self, hidden_states):
+        variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
+        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+        # convert into half-precision if necessary
+        if self.weight.dtype in [torch.float16, torch.bfloat16]:
+            hidden_states = hidden_states.to(self.weight.dtype)
+        out = self.weight * hidden_states
+        int8_out = out.round().clamp(-128, 127).to(torch.int8)
+        return int8_out
+
+    @staticmethod
+    def from_float(module: LlamaRMSNorm,
+                   output_scale: float):
+        int8_norm = Int8LlamaRMSNorm(module.weight.numel(), module.variance_epsilon)
+
+        int8_norm.weight.to(module.weight.dtype)
+        int8_norm.weight = module.weight / output_scale
+
+        return int8_norm
+
+def rotate_half(x):
+    """Rotates half the hidden dims of the input."""
+    x1 = x[..., : x.shape[-1] // 2]
+    x2 = x[..., x.shape[-1] // 2 :]
+    return torch.cat((-x2, x1), dim=-1)
+
+def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
+    # The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
+    cos = cos.squeeze(1).squeeze(0)  # [seq_len, dim]
+    sin = sin.squeeze(1).squeeze(0)  # [seq_len, dim]
+    cos = cos[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
+    sin = sin[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
+    q_embed = (q * cos) + (rotate_half(q) * sin)
+    k_embed = (k * cos) + (rotate_half(k) * sin)
+    return q_embed, k_embed
+
+
+class Int8LlamaMLP(nn.Module):
+
+    def __init__(self, config):
+        super().__init__()
+        self.config = config
+        self.hidden_size = config.hidden_size
+        self.intermediate_size = config.intermediate_size
+        self.down_input_scale = 0.
+        # need fp32 output because of the SiLU activation
+        self.gate_proj = W8A8BFP32OFP32Linear(self.hidden_size,
+                                              self.intermediate_size)
+
+        self.up_proj = W8A8BFP32OFP32Linear(self.hidden_size,
+                                            self.intermediate_size)
+        self.down_proj = W8A8BFP32OFP32LinearWithSFactor(
+            self.intermediate_size, self.hidden_size)
+        # the silu_and_mul_kernel in vLLM can serve as a reference for SwiGLU
+        self.act_fn = SiLUActivation()
+
+    @staticmethod
+    @torch.no_grad()
+    def from_float(module: LlamaMLP, config: LlamaConfig,
+                   gate_input_scale: float, gate_output_scale: float,
+                   up_input_scale: float, up_output_scale: float,
+                   down_input_scale: float, down_output_scale: float):
+        int8Mlp = Int8LlamaMLP(config)
+
+        # FIXME: fuse the scaling into the gate_proj output scale
+        # linearList = [module.gate_proj, module.up_proj]
+        # gateup_list = W8A8BFP32OFP32Linear.from_float_fuse(
+        #     linearList,
+        #     gate_input_scale)
+
+        # if len(gateup_list) != 2:
+        #     raise ValueError(
+        #         f"invalid gateup list len, must return 2 linears but got {len(gateup_list)}")
+
+        # int8Mlp.gate_proj = gateup_list[0]
+        # int8Mlp.up_proj = gateup_list[1]
+
+        int8Mlp.gate_proj = W8A8BFP32OFP32Linear.from_float(
+            module.gate_proj,
+            gate_input_scale
+        )
+        int8Mlp.up_proj = W8A8BFP32OFP32Linear.from_float(
+            module.up_proj,
+            gate_input_scale
+        )
+        int8Mlp.down_proj = W8A8BFP32OFP32LinearWithSFactor.from_float(
+            module.down_proj, down_input_scale)
+
+        return int8Mlp
+
+    def forward(self, x):
+        # TODO: support the self.config.pretraining_tp > 1 case; adapt from transformers' modeling_llama
+        hidden = self.act_fn(self.gate_proj(x).to(torch.float16))
+        hidden = hidden * self.up_proj(x)
+        return self.down_proj(hidden)
+
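`Int8LlamaRMSNorm.from_float` folds the next layer's static input scale into the norm weight, so the normalization itself emits int8 values. A rough CPU-only sketch of that idea (the epsilon and scale values are made up):

```python
import torch

hidden = torch.randn(2, 5, 16)              # stand-in hidden states
weight = torch.ones(16)                     # RMSNorm weight
attn_input_scale = 0.05                     # e.g. taken from the calibration step

variance = hidden.pow(2).mean(-1, keepdim=True)
normed = hidden * torch.rsqrt(variance + 1e-6)
folded_weight = weight / attn_input_scale   # what from_float stores in the weight buffer
int8_out = (folded_weight * normed).round().clamp(-128, 127).to(torch.int8)
print(int8_out.dtype, int8_out.shape)       # torch.int8 torch.Size([2, 5, 16])
```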
+class Int8LlamaDecoderLayer(nn.Module):
+    def __init__(self, config: LlamaConfig):
+        super().__init__()
+        self.hidden_size = config.hidden_size
+        self.self_attn = Int8LlamaAttention(config=config)
+        self.mlp = Int8LlamaMLP(config)
+        # FIXME: use int8 rmsnorm
+        self.input_layernorm = Int8LlamaRMSNorm(self.hidden_size, eps=config.rms_norm_eps)
+        self.post_attention_layernorm = Int8LlamaRMSNorm(self.hidden_size, eps=config.rms_norm_eps)
+    @staticmethod
+    def from_float(module: LlamaDecoderLayer,
+                   config: LlamaConfig,
+                   attn_input_scale: float,
+                   q_output_scale: float,
+                   k_output_scale: float,
+                   v_output_scale: float,
+                   out_input_scale: float,
+                   gate_input_scale: float,
+                   up_input_scale: float,
+                   down_input_scale: float,
+                   gate_output_scale: float,
+                   up_output_scale: float,
+                   down_output_scale: float
+                   ):
+        int8_module = Int8LlamaDecoderLayer(
+            config
+        )
+
+        int8_module.self_attn = Int8LlamaAttention.from_float(
+            module.self_attn,
+            config,
+            attn_input_scale,
+            q_output_scale,
+            k_output_scale,
+            v_output_scale,
+            out_input_scale
+        )
+
+        int8_module.mlp = Int8LlamaMLP.from_float(
+            module.mlp,
+            config,
+            gate_input_scale,
+            gate_output_scale,
+            up_input_scale,
+            up_output_scale,
+            down_input_scale,
+            down_output_scale
+        )
+        int8_module.input_layernorm = Int8LlamaRMSNorm.from_float(
+            module.input_layernorm,
+            attn_input_scale
+        )
+        int8_module.post_attention_layernorm = Int8LlamaRMSNorm.from_float(
+            module.post_attention_layernorm,
+            gate_input_scale
+        )
+        return int8_module
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_value: Optional[Tuple[torch.Tensor]] = None,
+        output_attentions: Optional[bool] = False,
+        use_cache: Optional[bool] = False,
+    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
+        # Self Attention
+        residual = hidden_states
+        hidden_states = self.input_layernorm(hidden_states)
+
+        hidden_states, self_attn_weights, present_key_value = self.self_attn(
+            hidden_states=hidden_states,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_value=past_key_value,
+            output_attentions=output_attentions,
+            use_cache=use_cache,
+        )
+        residual.add_(hidden_states.to(residual.dtype))
+
+        # MLP
+        hidden_states = self.post_attention_layernorm(residual)
+        hidden_states = self.mlp(hidden_states)
+        residual.add_(hidden_states.to(residual.dtype))
+        outputs = (residual,)
+        if output_attentions:
+            outputs += (self_attn_weights,)
+        if use_cache:
+            outputs += (present_key_value,)
+        return outputs
+
+class Int8LlamaModel(LlamaPreTrainedModel):
+    def __init__(self, config: LlamaConfig):
+        super().__init__(config)
+        self.config = config
+        self.padding_idx = config.pad_token_id
+        self.vocab_size = config.vocab_size
+        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
+        self.layers = nn.ModuleList([Int8LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)])
+        self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.gradient_checkpointing = False
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    get_input_embeddings = LlamaModel.get_input_embeddings
+    set_input_embeddings = LlamaModel.set_input_embeddings
+    _prepare_decoder_attention_mask = LlamaModel._prepare_decoder_attention_mask
+    # iterates over self.layers and computes the forward pass
+    forward = LlamaModel.forward
+
+    @staticmethod
+    def from_float(module, decoder_layer_scales):
+        int8_module = Int8LlamaModel(module.config)
+
+        int8_module.embed_tokens = module.embed_tokens
+        int8_module.norm = module.norm
+
+        for i, layer in enumerate(module.layers):
+            int8_module.layers[i] = Int8LlamaDecoderLayer.from_float(
+                layer, module.config, **decoder_layer_scales[i])
+        return int8_module
+
+class Int8LlamaForCausalLM(LlamaPreTrainedModel):
+    def __init__(self, config):
+        super().__init__(config)
+        self.config = config
+        self.model = Int8LlamaModel(config)
+        # no need to quantize the lm_head
+        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    @staticmethod
+    def from_float(module, decoder_layer_scales):
+        int8_module = Int8LlamaForCausalLM(module.config)
+        print("Converting the model to int8, this may take a while...")
+        int8_module.model = Int8LlamaModel.from_float(
+            module.model, decoder_layer_scales)
+        int8_module.lm_head = module.lm_head
+        return int8_module
+
+    get_input_embeddings = LlamaForCausalLM.get_input_embeddings
+    set_input_embeddings = LlamaForCausalLM.set_input_embeddings
+    get_output_embeddings = LlamaForCausalLM.get_output_embeddings
+    set_output_embeddings = LlamaForCausalLM.set_output_embeddings
+    set_decoder = LlamaForCausalLM.set_decoder
+    get_decoder = LlamaForCausalLM.get_decoder
+    forward = LlamaForCausalLM.forward
+    prepare_inputs_for_generation = LlamaForCausalLM.prepare_inputs_for_generation
+    _reorder_cache = LlamaForCausalLM._reorder_cache
\ No newline at end of file
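Rather than re-implementing the Hugging Face forward pass, `Int8LlamaModel` and `Int8LlamaForCausalLM` assign the upstream functions as class attributes. A tiny standalone illustration of that pattern:

```python
class Base:
    def greet(self):
        return f"hello from {type(self).__name__}"

class Reuser:
    greet = Base.greet      # borrow the function without inheriting from Base

print(Reuser().greet())     # "hello from Reuser"
```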
diff --git a/smoothquant/smooth.py b/smoothquant/smooth.py
index 84527f9..053314e 100644
--- a/smoothquant/smooth.py
+++ b/smoothquant/smooth.py
@@ -3,6 +3,7 @@
 
 from transformers.models.opt.modeling_opt import OPTDecoderLayer
 from transformers.models.bloom.modeling_bloom import BloomBlock
+from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaAttention, LlamaRMSNorm
 
 
 @torch.no_grad()
@@ -30,6 +31,29 @@ def smooth_ln_fcs(ln, fcs, act_scales, alpha=0.5):
         fc.weight.mul_(scales.view(1, -1))
 
 
+def smooth_ln_fcs_llama(ln, fcs, act_scales, alpha=0.5):
+    if not isinstance(fcs, list):
+        fcs = [fcs]
+    assert isinstance(ln, LlamaRMSNorm)  # llama uses RMSNorm instead of LayerNorm
+    for fc in fcs:
+        assert isinstance(fc, nn.Linear)
+        assert ln.weight.numel() == fc.in_features == act_scales.numel()
+    device, dtype = fcs[0].weight.device, fcs[0].weight.dtype
+    act_scales = act_scales.to(device=device, dtype=dtype)
+    weight_scales = torch.cat([fc.weight.abs().max(
+        dim=0, keepdim=True)[0] for fc in fcs], dim=0)
+    print(f"weight_scales shape: {weight_scales.shape}")
+    weight_scales = weight_scales.max(dim=0)[0].clamp(min=1e-5)
+    scales = (act_scales.pow(alpha) / weight_scales.pow(1 - alpha)
+              ).clamp(min=1e-5).to(device).to(dtype)
+
+    print(f'smoothed scales: {scales}')
+    # fold the smoothing factors into the RMSNorm weight
+    ln.weight.div_(scales)
+    for fc in fcs:
+        fc.weight.mul_(scales.view(1, -1))
+
+
 @torch.no_grad()
@@ -54,3 +78,17 @@ def smooth_lm(model, scales, alpha=0.5):
             fc1 = module.mlp.dense_h_to_4h
             fc1_input_scales = scales[name + '.mlp.dense_h_to_4h']
             smooth_ln_fcs(ffn_ln, fc1, fc1_input_scales, alpha)
+        elif isinstance(module, LlamaDecoderLayer):
+            print(f"smooth llama decoder: {name}")
+            attn_ln = module.input_layernorm  # norm before the attention block
+            qkv = [module.self_attn.q_proj,
+                   module.self_attn.k_proj, module.self_attn.v_proj]
+            qkv_input_scales = scales[name + '.self_attn.q_proj']
+            smooth_ln_fcs_llama(attn_ln, qkv, qkv_input_scales, alpha)
+
+            ffn_ln = module.post_attention_layernorm  # norm before the feed-forward block
+            fcs = [module.mlp.gate_proj, module.mlp.up_proj]
+            fcs_input_scales = scales[name + '.mlp.gate_proj']
+
+            smooth_ln_fcs_llama(ffn_ln, fcs, fcs_input_scales, alpha)
+
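As a sanity check on the math, a short standalone snippet showing that the smoothing performed by `smooth_ln_fcs_llama` is output-preserving: dividing the activations by `s` and multiplying the weights by `s` cancels exactly (in the real code the division is folded into the preceding RMSNorm weight rather than applied to the input directly):

```python
import torch

torch.manual_seed(0)
x = torch.randn(4, 8)                                   # stand-in activations
fc = torch.nn.Linear(8, 16, bias=False)
alpha = 0.5

act_scales = x.abs().max(dim=0)[0]                      # per-channel activation maxima
weight_scales = fc.weight.abs().max(dim=0)[0].clamp(min=1e-5)
s = (act_scales.pow(alpha) / weight_scales.pow(1 - alpha)).clamp(min=1e-5)

y_ref = fc(x)
y_smooth = (x / s) @ (fc.weight * s.view(1, -1)).t()    # smoothed activations and weights
print(torch.allclose(y_ref, y_smooth, atol=1e-5))       # True: the output is unchanged
```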