From ae775d9ba713a4d519a433442985343ab041b2d1 Mon Sep 17 00:00:00 2001 From: zhangpeng Date: Thu, 21 Sep 2023 16:00:16 +0800 Subject: [PATCH 1/7] add int8 llama --- examples/export_int8_llama.py | 56 ++++++ smoothquant/calibration.py | 79 ++++++++ smoothquant/llama.py | 368 ++++++++++++++++++++++++++++++++++ smoothquant/smooth.py | 38 ++++ 4 files changed, 541 insertions(+) create mode 100644 examples/export_int8_llama.py create mode 100644 smoothquant/llama.py diff --git a/examples/export_int8_llama.py b/examples/export_int8_llama.py new file mode 100644 index 0000000..b9c53ba --- /dev/null +++ b/examples/export_int8_llama.py @@ -0,0 +1,56 @@ +import torch +import argparse +import os + +from pathlib import Path + +from transformers import AutoTokenizer + +from smoothquant.llama import Int8LlamaForCausalLM +from smoothquant.smooth import smooth_lm + +from smoothquant.calibration import get_static_llama_decoder_layer_scales +from torch.nn.functional import pad + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument("--model-name", type=str, default='fp16_models/llama-13b') + parser.add_argument("--num-samples", type=int, default=512) + parser.add_argument("--seq-len", type=int, default=512) + parser.add_argument("--act-scales", type=str, + default='act_scales/llama-13b.pt') + parser.add_argument("--output-path", type=str, default='int8_models') + parser.add_argument('--dataset-path', type=str, default='dataset/val.jsonl.zst', + help='location of the calibration dataset, we use the validation set of the Pile dataset') + parser.add_argument('--export-FT', default=False, action="store_true") + args = parser.parse_args() + model = OPTForCausalLM.from_pretrained( + args.model_name, device_map="auto", torch_dtype=torch.float16) + act_scales = torch.load(args.act_scales) + smooth_lm(model, act_scales, 0.5) + tokenizer = AutoTokenizer.from_pretrained(args.model_name) + + if not os.path.exists(args.dataset_path): + print(f'Cannot find the dataset at {args.dataset_path}') + print('Please download the Pile dataset and put the validation set at the path') + print('You can download the validation dataset of the Pile at https://mystic.the-eye.eu/public/AI/pile/val.jsonl.zst') + raise FileNotFoundError + + decoder_layer_scales, raw_scales = get_static_llama_decoder_layer_scales(model, + tokenizer, + args.dataset_path, + num_samples=args.num_samples, + seq_len=args.seq_len) + output_path = Path(args.output_path) / (Path(args.model_name).name + "-smoothquant.pt") + if args.export_FT: + model.save_pretrained(output_path) + print(f"Saved smoothed model at {output_path}") + + output_path = Path(args.output_path) / (Path(args.model_name).name + "-smoothquant-scales.pt") + torch.save(raw_scales, output_path) + print(f"Saved scaling factors at {output_path}") + else: + int8_model = Int8LlamaForCausalLM.from_float(model, decoder_layer_scales) + int8_model.save_pretrained(output_path) + print(f"Saved int8 model at {output_path}") \ No newline at end of file diff --git a/smoothquant/calibration.py b/smoothquant/calibration.py index 8d7906e..b6e8278 100644 --- a/smoothquant/calibration.py +++ b/smoothquant/calibration.py @@ -118,3 +118,82 @@ def stat_io_hook(m, x, y, name): decoder_layer_scales.append(scale_dict) return decoder_layer_scales, act_dict + +#TODO: merge to get_static_decoder_layer_scales method +@torch.no_grad() +def get_static_llama_decoder_layer_scales(model, + tokenizer, + dataset_path, + num_samples=512, + seq_len=512, + ): + model.eval() + device = next(model.parameters()).device + + act_dict = defaultdict(dict) + + def stat_io_hook(m, x, y, name): + if isinstance(x, tuple): + x = x[0] + if name not in act_dict or "input" not in act_dict[name]: + act_dict[name]["input"] = x.detach().abs().max().item() + else: + act_dict[name]["input"] = max( + act_dict[name]["input"], x.detach().abs().max().item()) + if isinstance(y, tuple): + y = y[0] + if name not in act_dict or "output" not in act_dict[name]: + act_dict[name]["output"] = y.detach().abs().max().item() + else: + act_dict[name]["output"] = max( + act_dict[name]["output"], y.detach().abs().max().item()) + + hooks = [] + for name, m in model.named_modules(): + if isinstance(m, torch.nn.Linear): + hooks.append(m.register_forward_hook( + partial(stat_io_hook, name=name))) + + print("Collecting activation scales...") + pbar = tqdm(range(num_samples)) + dataset = load_dataset('json', data_files=dataset_path, split="train") + dataset = dataset.shuffle(seed=42) + for i in pbar: + input_ids = tokenizer(dataset[i]["text"], return_tensors="pt", + max_length=seq_len, truncation=True).input_ids.to(device) + model(input_ids) + mean_scale = np.mean([v["input"] for v in act_dict.values()]) + pbar.set_description(f"Mean input scale: {mean_scale:.2f}") + for hook in hooks: + hook.remove() + + decoder_layer_scales = [] + for idx in range(model.config.num_hidden_layers): + scale_dict = {} + # self attenion scales + scale_dict["attn_input_scale"] = act_dict[ + f"model.layers.{idx}.self_attn.q_proj"]['input'] / 127 + scale_dict["q_output_scale"] = act_dict[ + f"model.layers.{idx}.self_attn.q_proj"]['output'] / 127 + scale_dict["k_output_scale"] = act_dict[ + f"model.layers.{idx}.self_attn.k_proj"]['output'] / 127 + scale_dict["v_output_scale"] = act_dict[ + f"model.layers.{idx}.self_attn.v_proj"]['output'] / 127 + scale_dict["out_input_scale"] = act_dict[ + f"model.layers.{idx}.self_attn.o_proj"]['input'] / 127 + # mlp scales + scale_dict["gate_input_scale"] = act_dict[ + f"model.layers.{idx}.mlp.gate_proj"]['input'] / 127 + scale_dict["up_input_scale"] = act_dict[ + f"model.layers.{idx}.mlp.up_proj"]["input"] / 127 + scale_dict["down_input_scale"] = act_dict[ + f"model.layers.{idx}.mlp.down_proj"]["input"] / 127 + scale_dict["gate_output_scale"] = act_dict[ + f"model.layers.{idx}.mlp.gate_proj"]['output'] / 127 + scale_dict["up_output_scale"] = act_dict[ + f"model.layers.{idx}.mlp.up_proj"]["output"] / 127 + scale_dict["down_output_scale"] = act_dict[ + f"model.layers.{idx}.mlp.down_proj"]["output"] / 127 + decoder_layer_scales.append(scale_dict) + + return decoder_layer_scales, act_dict \ No newline at end of file diff --git a/smoothquant/llama.py b/smoothquant/llama.py new file mode 100644 index 0000000..bb315a9 --- /dev/null +++ b/smoothquant/llama.py @@ -0,0 +1,368 @@ +import torch +import math +from torch import nn +from transformers.models.llama.modeling_llama import ( + LlamaRMSNorm, + LlamaRotaryEmbedding, + LlamaMLP, + LlamaAttention, + LlamaDecoderLayer, + LlamaPreTrainedModel, + LlamaModel, + LlamaForCausalLM, + LlamaForSequenceClassification, + BaseModelOutputWithPast +) +from transformers.models.llama.configuration_llama import LlamaConfig +from transformers.activations import SiLUActivation +from typing import Optional, Tuple, List +# down, out: W8A8BFP32OFP32Linear, q, k, v: W8A8BFP32OFP32Linear, gate: W8A8BFP32OFP32Linear, up: W8A8B8O8Linear +from torch_int.nn.linear import W8A8B8O8LinearWithSFactor, W8A8BFP32OFP32LinearWithSFactor +from smoothquant.fake_quant import W8A8Linear +from transformers.utils import logging +logger = logging.get_logger(__name__) +_CONFIG_FOR_DOC = "LlamaConfig" +# attention is the same as opt +class Int8LlamaAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + def __init__( + self, + config: LlamaConfig + ): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.hidden_size // self.num_heads + self.max_position_embeddings = config.max_position_embeddings + self.out_input_scale = 0. + # hidden_size is embed_dim in OptAttetion + if (self.head_dim * self.num_heads) != self.hidden_size: + raise ValueError( + f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}" + f" and `num_heads`: {self.num_heads})." + ) + + self.k_proj = W8A8BFP32OFP32LinearWithSFactor(self.hidden_size, self.num_heads * self.head_dim) + self.v_proj = W8A8BFP32OFP32LinearWithSFactor(self.hidden_size, self.num_heads * self.head_dim) + self.q_proj = W8A8BFP32OFP32LinearWithSFactor(self.hidden_size, self.num_heads * self.head_dim) + # out is fp32 + self.o_proj = W8A8BFP32OFP32LinearWithSFactor(self.num_heads * self.head_dim, self.hidden_size) + + self.rotary_emb = LlamaRotaryEmbedding(self.head_dim, max_position_embeddings=self.max_position_embeddings) + _shape = LlamaAttention._shape + @staticmethod + @torch.no_grad() + def from_float(module: LlamaAttention, + config: LlamaConfig, + attn_input_scale: float, + q_output_scale: float, + k_output_scale: float, + v_output_scale: float, + out_input_scale: float): + int8_module = Int8LlamaAttention(config) + + print("turning attention into w8a8liner") + logger.info("turning attention into w8a8liner") + # Fuse the scaling into the q_proj output scale + + # we do not impelement attn for now bacuase we want use paged attention + int8_module.q_proj = W8A8BFP32OFP32LinearWithSFactor.from_float( + module.q_proj, attn_input_scale) + int8_module.k_proj = W8A8BFP32OFP32LinearWithSFactor.from_float( + module.k_proj, attn_input_scale) + int8_module.v_proj = W8A8BFP32OFP32LinearWithSFactor.from_float( + module.v_proj, attn_input_scale) + int8_module.o_proj = W8A8BFP32OFP32LinearWithSFactor.from_float( + module.o_proj, out_input_scale) + + return int8_module + @torch.no_grad() + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: bool = False, + use_cache: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + # already quant before attention + query_states = self.q_proj(hidden_states).to(torch.float16) + key_states = self.k_proj(hidden_states).to(torch.float16) + value_states = self.v_proj(hidden_states).to(torch.float16) + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value[0].shape[-2] + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + if past_key_value is not None: + # reuse k, v, self_attention + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + past_key_value = (key_states, value_states) if use_cache else None + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): + raise ValueError( + f"Attention weights should be of size {(bsz * self.num_heads, q_len, kv_seq_len)}, but is" + f" {attn_weights.size()}" + ) + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" + ) + attn_weights = attn_weights + attention_mask + attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min)) + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_output = torch.matmul(attn_weights, value_states) + if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + # quant method from torch-int + attn_output = self.o_proj(attn_output) + if not output_attentions: + attn_weights = None + return attn_output, attn_weights, past_key_value + +class Int8LlamaRMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + LlamaRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.register_buffer('weight', torch.ones(hidden_size, dtype=torch.float32, requires_grad=False)) + self.variance_epsilon = eps + def forward(self, hidden_states): + variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + # convert into half-precision if necessary + if self.weight.dtype in [torch.float16, torch.bfloat16]: + hidden_states = hidden_states.to(self.weight.dtype) + out = self.weight * hidden_states + return int8_out + @staticmethod + def from_float(module: LlamaRMSNorm, + output_scale: float): + int8_norm = Int8LlamaRMSNorm(module.weight.numel(), module.variance_epsilon) + + int8_norm.weight.to(module.weight.dtype) + int8_norm.weight = module.weight / output_scale + + return int8_norm + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids): + # The first two dimensions of cos and sin are always 1, so we can `squeeze` them. + cos = cos.squeeze(1).squeeze(0) # [seq_len, dim] + sin = sin.squeeze(1).squeeze(0) # [seq_len, dim] + cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] + sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + +class Int8LlamaMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.down_input_scale = 0. + # need fp32 out bcause silu + self.gate_proj = W8A8BFP32OFP32LinearWithSFactor(self.hidden_size, self.intermediate_size) + + self.up_proj = W8A8BFP32OFP32LinearWithSFactor(self.hidden_size, self.intermediate_size) + self.down_proj = W8A8BFP32OFP32LinearWithSFactor(self.intermediate_size, self.hidden_size) + # silu_and_mul_kernel in vLLM can be a reference of SwiGLU + self.act_fn = SiLUActivation() + @staticmethod + @torch.no_grad() + def from_float(module: LlamaMLP, + config: LlamaConfig, + gate_input_scale: float, + gate_output_scale: float, + up_input_scale: float, + up_output_scale: float, + down_input_scale: float, + down_output_scale: float): + int8Mlp = Int8LlamaMLP(config) + # TODO: kernel fusion + int8Mlp.gate_proj = W8A8BFP32OFP32LinearWithSFactor.from_float(module.gate_proj, gate_input_scale) + int8Mlp.up_proj = W8A8BFP32OFP32LinearWithSFactor.from_float(module.up_proj, up_input_scale) + int8Mlp.down_proj = W8A8BFP32OFP32LinearWithSFactor.from_float(module.down_proj, down_input_scale) + + return int8Mlp + + def forward(self, x): + # TODO: supprot self.config.pretraining_tp > 1 condition, adapt from transformer.modeling_llama + hidden = self.act_fn(self.gate_proj(x).to(torch.float16)) + hidden = hidden * self.up_proj(x) + return self.down_proj(hidden) + +class Int8LlamaDecoderLayer(nn.Module): + def __init__(self, config: LlamaConfig): + super().__init__() + self.hidden_size = config.hidden_size + self.self_attn = Int8LlamaAttention(config=config) + self.mlp = Int8LlamaMLP(config) + #FIXME: use int8 rmsnorm + self.input_layernorm = Int8LlamaRMSNorm(self.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = Int8LlamaRMSNorm(self.hidden_size, eps=config.rms_norm_eps) + @staticmethod + def from_float(module: LlamaDecoderLayer, + config: LlamaConfig, + attn_input_scale: float, + q_output_scale: float, + k_output_scale: float, + v_output_scale: float, + out_input_scale: float, + gate_input_scale: float, + up_input_scale: float, + down_input_scale: float, + gate_output_scale: float, + up_output_scale: float, + down_output_scale: float + ): + int8_module = Int8LlamaDecoderLayer( + config + ) + print("turn each layer mlp and attention to int8") + logger.info("turn each layer mlp and attention to int8") + #FIXME: use int8 rmsnorm, torch_int LayerNormQ can be a reference + int8_module.self_attn = Int8LlamaAttention.from_float( + module.self_attn, + config, + attn_input_scale, + q_output_scale, + k_output_scale, + v_output_scale, + out_input_scale + ) + + int8_module.mlp = Int8LlamaMLP.from_float( + module.mlp, + config, + gate_input_scale, + gate_output_scale, + up_input_scale, + up_output_scale, + down_input_scale, + down_output_scale + ) + int8_module.input_layernorm = Int8LlamaRMSNorm.from_float( + module.input_layernorm, + attn_input_scale + ) + int8_module.post_attention_layernorm = Int8LlamaRMSNorm.from_float( + module.post_attention_layernorm, + gate_input_scale + ) + return int8_module + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + # Self Attention + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) + residual.add_(hidden_states.to(residual.dtype)) + + # mlp + hidden_states = self.post_attention_layernorm(residual) + hidden_states = self.mlp(hidden_states) + residual.add_(hidden_states.to(residual.dtype)) + outputs = (residual,) + if output_attentions: + outputs += (self_attn_weights,) + if use_cache: + outputs += (present_key_value,) + return outputs +class Int8LlamaModel(LlamaPreTrainedModel): + def __init__(self, config: LlamaConfig): + super().__init__(config) + self.config = config + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList([Int8LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)]) + self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.gradient_checkpointing = False + # Initialize weights and apply final processing + self.post_init() + + get_input_embeddings = LlamaModel.get_input_embeddings + set_input_embeddings = LlamaModel.set_input_embeddings + _prepare_decoder_attention_mask = LlamaModel._prepare_decoder_attention_mask + # iter self.layers and calcu forward + forward = LlamaModel.forward + + @staticmethod + def from_float(module, decoder_layer_scales): + int8_module = Int8LlamaModel(module.config) + + int8_module.embed_tokens = module.embed_tokens + int8_module.norm = module.norm + + print("turn layers from float to int8") + logger.info("turn layers from float to int8") + for i, layer in enumerate(module.layers): + int8_module.layers[i] = Int8LlamaDecoderLayer.from_float( + layer, module.config, **decoder_layer_scales[i]) + return int8_module +class Int8LlamaForCausalLM(LlamaPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.config = config + self.model = Int8LlamaModel(config) + # no need to quant + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + # Initialize weights and apply final processing + self.post_init() + + @staticmethod + def from_float(module, decoder_layer_scales): + print("create int8 model") + int8_module = Int8LlamaForCausalLM(module.config) + print("start turn into int8") + int8_module.model = Int8LlamaModel.from_float( + module.model, decoder_layer_scales) + int8_module.lm_head = module.lm_head + return int8_module + get_input_embeddings = LlamaForCausalLM.get_input_embeddings + set_input_embeddings = LlamaForCausalLM.set_input_embeddings + get_output_embeddings = LlamaForCausalLM.get_output_embeddings + set_output_embeddings = LlamaForCausalLM.set_output_embeddings + set_decoder = LlamaForCausalLM.set_decoder + get_decoder = LlamaForCausalLM.get_decoder + forward = LlamaForCausalLM.forward + prepare_inputs_for_generation = LlamaForCausalLM.prepare_inputs_for_generation + _reorder_cache = LlamaForCausalLM._reorder_cache \ No newline at end of file diff --git a/smoothquant/smooth.py b/smoothquant/smooth.py index 84527f9..053314e 100644 --- a/smoothquant/smooth.py +++ b/smoothquant/smooth.py @@ -3,6 +3,7 @@ from transformers.models.opt.modeling_opt import OPTDecoderLayer from transformers.models.bloom.modeling_bloom import BloomBlock +from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaAttention, LlamaRMSNorm @torch.no_grad() @@ -30,6 +31,29 @@ def smooth_ln_fcs(ln, fcs, act_scales, alpha=0.5): fc.weight.mul_(scales.view(1, -1)) +def smooth_ln_fcs_llama(ln, fcs, act_scales, alpha=0.5): + if not isinstance(fcs, list): + fcs = [fcs] + assert isinstance(ln, LlamaRMSNorm) #llama use rmsnorm + for fc in fcs: + assert isinstance(fc, nn.Linear) + assert ln.weight.numel() == fc.in_features == act_scales.numel() + device, dtype = fcs[0].weight.device, fcs[0].weight.dtype + act_scales = act_scales.to(device=device, dtype=dtype) + weight_scales = torch.cat([fc.weight.abs().max( + dim=0, keepdim=True)[0] for fc in fcs], dim=0) + print(f"weight_scales shape: {weight_scales.shape}") + weight_scales = weight_scales.max(dim=0)[0].clamp(min=1e-5) + scales = (act_scales.pow(alpha) / weight_scales.pow(1-alpha) + ).clamp(min=1e-5).to(device).to(dtype) + + print(f'smoothed scales:{scales}') + # do layer norm smooth + ln.weight.div_(scales) + for fc in fcs: + fc.weight.mul_(scales.view(1, -1)) + + @torch.no_grad() def smooth_lm(model, scales, alpha=0.5): for name, module in model.named_modules(): @@ -54,3 +78,17 @@ def smooth_lm(model, scales, alpha=0.5): fc1 = module.mlp.dense_h_to_4h fc1_input_scales = scales[name + '.mlp.dense_h_to_4h'] smooth_ln_fcs(ffn_ln, fc1, fc1_input_scales, alpha) + elif isinstance(module, LlamaDecoderLayer): + print(f"smooth llama decoder: {name}") + attn_ln = module.input_layernorm #attention forward norm + qkv = [module.self_attn.q_proj, + module.self_attn.k_proj, module.self_attn.v_proj] + qkv_input_scales = scales[name + '.self_attn.q_proj'] + smooth_ln_fcs_llama(attn_ln, qkv, qkv_input_scales, alpha) + + ffn_ln = module.post_attention_layernorm #feed forward norm + fcs = [module.mlp.gate_proj, module.mlp.up_proj] + fcs_input_scales = scales[name + '.mlp.gate_proj'] + + smooth_ln_fcs_llama(ffn_ln, fcs, fcs_input_scales, alpha) + From 3f8a0d0aa31aebe96ef117cff9ce6f5d5ef8d034 Mon Sep 17 00:00:00 2001 From: zhangpeng Date: Thu, 21 Sep 2023 16:15:43 +0800 Subject: [PATCH 2/7] remove debug info --- smoothquant/llama.py | 15 +++++---------- 1 file changed, 5 insertions(+), 10 deletions(-) diff --git a/smoothquant/llama.py b/smoothquant/llama.py index bb315a9..8b5f6ce 100644 --- a/smoothquant/llama.py +++ b/smoothquant/llama.py @@ -16,8 +16,8 @@ from transformers.models.llama.configuration_llama import LlamaConfig from transformers.activations import SiLUActivation from typing import Optional, Tuple, List -# down, out: W8A8BFP32OFP32Linear, q, k, v: W8A8BFP32OFP32Linear, gate: W8A8BFP32OFP32Linear, up: W8A8B8O8Linear -from torch_int.nn.linear import W8A8B8O8LinearWithSFactor, W8A8BFP32OFP32LinearWithSFactor +# must use branch llama-dev in https://github.com/AniZpZ/torch-int +from torch_int.nn.linear import W8A8BFP32OFP32LinearWithSFactor from smoothquant.fake_quant import W8A8Linear from transformers.utils import logging logger = logging.get_logger(__name__) @@ -62,10 +62,7 @@ def from_float(module: LlamaAttention, out_input_scale: float): int8_module = Int8LlamaAttention(config) - print("turning attention into w8a8liner") - logger.info("turning attention into w8a8liner") # Fuse the scaling into the q_proj output scale - # we do not impelement attn for now bacuase we want use paged attention int8_module.q_proj = W8A8BFP32OFP32LinearWithSFactor.from_float( module.q_proj, attn_input_scale) @@ -241,9 +238,7 @@ def from_float(module: LlamaDecoderLayer, int8_module = Int8LlamaDecoderLayer( config ) - print("turn each layer mlp and attention to int8") - logger.info("turn each layer mlp and attention to int8") - #FIXME: use int8 rmsnorm, torch_int LayerNormQ can be a reference + int8_module.self_attn = Int8LlamaAttention.from_float( module.self_attn, config, @@ -350,9 +345,9 @@ def __init__(self, config): @staticmethod def from_float(module, decoder_layer_scales): - print("create int8 model") + # print("create int8 model") int8_module = Int8LlamaForCausalLM(module.config) - print("start turn into int8") + # print("start turn into int8") int8_module.model = Int8LlamaModel.from_float( module.model, decoder_layer_scales) int8_module.lm_head = module.lm_head From 8d19ff0458939791907c588dfdf8280fc482824e Mon Sep 17 00:00:00 2001 From: zhangpeng Date: Thu, 21 Sep 2023 17:36:16 +0800 Subject: [PATCH 3/7] change linears for further fusion --- smoothquant/llama.py | 40 ++++++++++++++++++++++++---------------- 1 file changed, 24 insertions(+), 16 deletions(-) diff --git a/smoothquant/llama.py b/smoothquant/llama.py index 8b5f6ce..a59c101 100644 --- a/smoothquant/llama.py +++ b/smoothquant/llama.py @@ -17,7 +17,7 @@ from transformers.activations import SiLUActivation from typing import Optional, Tuple, List # must use branch llama-dev in https://github.com/AniZpZ/torch-int -from torch_int.nn.linear import W8A8BFP32OFP32LinearWithSFactor +from torch_int.nn.linear import W8A8BFP32OFP32LinearWithSFactor, W8A8BFP32OFP32Linear from smoothquant.fake_quant import W8A8Linear from transformers.utils import logging logger = logging.get_logger(__name__) @@ -43,14 +43,16 @@ def __init__( f" and `num_heads`: {self.num_heads})." ) - self.k_proj = W8A8BFP32OFP32LinearWithSFactor(self.hidden_size, self.num_heads * self.head_dim) - self.v_proj = W8A8BFP32OFP32LinearWithSFactor(self.hidden_size, self.num_heads * self.head_dim) - self.q_proj = W8A8BFP32OFP32LinearWithSFactor(self.hidden_size, self.num_heads * self.head_dim) + self.k_proj = W8A8BFP32OFP32Linear(self.hidden_size, self.num_heads * self.head_dim) + self.v_proj = W8A8BFP32OFP32Linear(self.hidden_size, self.num_heads * self.head_dim) + self.q_proj = W8A8BFP32OFP32Linear(self.hidden_size, self.num_heads * self.head_dim) # out is fp32 self.o_proj = W8A8BFP32OFP32LinearWithSFactor(self.num_heads * self.head_dim, self.hidden_size) self.rotary_emb = LlamaRotaryEmbedding(self.head_dim, max_position_embeddings=self.max_position_embeddings) + _shape = LlamaAttention._shape + @staticmethod @torch.no_grad() def from_float(module: LlamaAttention, @@ -64,16 +66,16 @@ def from_float(module: LlamaAttention, # Fuse the scaling into the q_proj output scale # we do not impelement attn for now bacuase we want use paged attention - int8_module.q_proj = W8A8BFP32OFP32LinearWithSFactor.from_float( + int8_module.q_proj = W8A8BFP32OFP32Linear.from_float( module.q_proj, attn_input_scale) - int8_module.k_proj = W8A8BFP32OFP32LinearWithSFactor.from_float( + int8_module.k_proj = W8A8BFP32OFP32Linear.from_float( module.k_proj, attn_input_scale) - int8_module.v_proj = W8A8BFP32OFP32LinearWithSFactor.from_float( + int8_module.v_proj = W8A8BFP32OFP32Linear.from_float( module.v_proj, attn_input_scale) int8_module.o_proj = W8A8BFP32OFP32LinearWithSFactor.from_float( module.o_proj, out_input_scale) - return int8_module + @torch.no_grad() def forward( self, @@ -131,6 +133,7 @@ def forward( attn_weights = None return attn_output, attn_weights, past_key_value +# we keep scale in LlamaRMSNorm layer for kernel fusion class Int8LlamaRMSNorm(nn.Module): def __init__(self, hidden_size, eps=1e-6): """ @@ -139,6 +142,7 @@ def __init__(self, hidden_size, eps=1e-6): super().__init__() self.register_buffer('weight', torch.ones(hidden_size, dtype=torch.float32, requires_grad=False)) self.variance_epsilon = eps + def forward(self, hidden_states): variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True) hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) @@ -146,7 +150,9 @@ def forward(self, hidden_states): if self.weight.dtype in [torch.float16, torch.bfloat16]: hidden_states = hidden_states.to(self.weight.dtype) out = self.weight * hidden_states + int8_out = out.round().clamp(-128, 127).to(torch.int8) return int8_out + @staticmethod def from_float(module: LlamaRMSNorm, output_scale: float): @@ -181,12 +187,13 @@ def __init__(self, config): self.intermediate_size = config.intermediate_size self.down_input_scale = 0. # need fp32 out bcause silu - self.gate_proj = W8A8BFP32OFP32LinearWithSFactor(self.hidden_size, self.intermediate_size) + self.gate_proj = W8A8BFP32OFP32Linear(self.hidden_size, self.intermediate_size) - self.up_proj = W8A8BFP32OFP32LinearWithSFactor(self.hidden_size, self.intermediate_size) + self.up_proj = W8A8BFP32OFP32Linear(self.hidden_size, self.intermediate_size) self.down_proj = W8A8BFP32OFP32LinearWithSFactor(self.intermediate_size, self.hidden_size) # silu_and_mul_kernel in vLLM can be a reference of SwiGLU self.act_fn = SiLUActivation() + @staticmethod @torch.no_grad() def from_float(module: LlamaMLP, @@ -199,8 +206,8 @@ def from_float(module: LlamaMLP, down_output_scale: float): int8Mlp = Int8LlamaMLP(config) # TODO: kernel fusion - int8Mlp.gate_proj = W8A8BFP32OFP32LinearWithSFactor.from_float(module.gate_proj, gate_input_scale) - int8Mlp.up_proj = W8A8BFP32OFP32LinearWithSFactor.from_float(module.up_proj, up_input_scale) + int8Mlp.gate_proj = W8A8BFP32OFP32Linear.from_float(module.gate_proj, gate_input_scale) + int8Mlp.up_proj = W8A8BFP32OFP32Linear.from_float(module.up_proj, up_input_scale) int8Mlp.down_proj = W8A8BFP32OFP32LinearWithSFactor.from_float(module.down_proj, down_input_scale) return int8Mlp @@ -268,6 +275,7 @@ def from_float(module: LlamaDecoderLayer, gate_input_scale ) return int8_module + def forward( self, hidden_states: torch.Tensor, @@ -301,6 +309,7 @@ def forward( if use_cache: outputs += (present_key_value,) return outputs + class Int8LlamaModel(LlamaPreTrainedModel): def __init__(self, config: LlamaConfig): super().__init__(config) @@ -327,12 +336,11 @@ def from_float(module, decoder_layer_scales): int8_module.embed_tokens = module.embed_tokens int8_module.norm = module.norm - print("turn layers from float to int8") - logger.info("turn layers from float to int8") for i, layer in enumerate(module.layers): int8_module.layers[i] = Int8LlamaDecoderLayer.from_float( layer, module.config, **decoder_layer_scales[i]) return int8_module + class Int8LlamaForCausalLM(LlamaPreTrainedModel): def __init__(self, config): super().__init__(config) @@ -345,13 +353,13 @@ def __init__(self, config): @staticmethod def from_float(module, decoder_layer_scales): - # print("create int8 model") int8_module = Int8LlamaForCausalLM(module.config) - # print("start turn into int8") + print("start trans into int8, this might take a while") int8_module.model = Int8LlamaModel.from_float( module.model, decoder_layer_scales) int8_module.lm_head = module.lm_head return int8_module + get_input_embeddings = LlamaForCausalLM.get_input_embeddings set_input_embeddings = LlamaForCausalLM.set_input_embeddings get_output_embeddings = LlamaForCausalLM.get_output_embeddings From 0a26dab348c6a276f6a4fec607a9b8706d65b7f0 Mon Sep 17 00:00:00 2001 From: zhangpeng Date: Tue, 26 Sep 2023 20:18:17 +0800 Subject: [PATCH 4/7] quant fusion --- smoothquant/llama.py | 43 ++++++++++++++++++++++++++++++++----------- 1 file changed, 32 insertions(+), 11 deletions(-) diff --git a/smoothquant/llama.py b/smoothquant/llama.py index a59c101..db3a9bf 100644 --- a/smoothquant/llama.py +++ b/smoothquant/llama.py @@ -64,14 +64,22 @@ def from_float(module: LlamaAttention, out_input_scale: float): int8_module = Int8LlamaAttention(config) - # Fuse the scaling into the q_proj output scale # we do not impelement attn for now bacuase we want use paged attention - int8_module.q_proj = W8A8BFP32OFP32Linear.from_float( - module.q_proj, attn_input_scale) - int8_module.k_proj = W8A8BFP32OFP32Linear.from_float( - module.k_proj, attn_input_scale) - int8_module.v_proj = W8A8BFP32OFP32Linear.from_float( - module.v_proj, attn_input_scale) + + # FIXME: Fuse the scaling into the q_proj output scale + linearList = [module.q_proj, module.k_proj, module.v_proj] + + qkv_list = W8A8BFP32OFP32Linear.from_float_fuse( + linearList, + attn_input_scale) + if len(qkv_list) != 3: + raise ValueError( + f"invalid qkv list len, must return 3 linears but get {len(qkv_list)}") + + int8_module.q_proj = qkv_list[0] + int8_module.k_proj = qkv_list[1] + int8_module.v_proj = qkv_list[2] + int8_module.o_proj = W8A8BFP32OFP32LinearWithSFactor.from_float( module.o_proj, out_input_scale) return int8_module @@ -205,10 +213,23 @@ def from_float(module: LlamaMLP, down_input_scale: float, down_output_scale: float): int8Mlp = Int8LlamaMLP(config) - # TODO: kernel fusion - int8Mlp.gate_proj = W8A8BFP32OFP32Linear.from_float(module.gate_proj, gate_input_scale) - int8Mlp.up_proj = W8A8BFP32OFP32Linear.from_float(module.up_proj, up_input_scale) - int8Mlp.down_proj = W8A8BFP32OFP32LinearWithSFactor.from_float(module.down_proj, down_input_scale) + + # FIXME: Fuse the scaling into the q_proj output scale + print(f"gate in {gate_input_scale}, up in {up_input_scale}") + linearList = [module.gate_proj, module.up_proj] + gateup_list = W8A8BFP32OFP32Linear.from_float_fuse( + linearList, + gate_input_scale) + + if len(gateup_list) != 2: + raise ValueError( + f"invalid qkv gateup len, must return 2 linears but get {len(qkv_list)}") + + int8Mlp.gate_proj = gateup_list[0] + int8Mlp.up_proj = gateup_list[1] + int8Mlp.down_proj = W8A8BFP32OFP32LinearWithSFactor.from_float( + module.down_proj, + down_input_scale) return int8Mlp From ab15d591b151deb0457eb71dbf0045b8ee1058f7 Mon Sep 17 00:00:00 2001 From: zhangpeng Date: Wed, 27 Sep 2023 11:05:58 +0800 Subject: [PATCH 5/7] rm debug code --- smoothquant/llama.py | 1 - 1 file changed, 1 deletion(-) diff --git a/smoothquant/llama.py b/smoothquant/llama.py index db3a9bf..ee175e4 100644 --- a/smoothquant/llama.py +++ b/smoothquant/llama.py @@ -215,7 +215,6 @@ def from_float(module: LlamaMLP, int8Mlp = Int8LlamaMLP(config) # FIXME: Fuse the scaling into the q_proj output scale - print(f"gate in {gate_input_scale}, up in {up_input_scale}") linearList = [module.gate_proj, module.up_proj] gateup_list = W8A8BFP32OFP32Linear.from_float_fuse( linearList, From 3f063c4f830af769fc4c8c9d4c3efa722feba39a Mon Sep 17 00:00:00 2001 From: zhangying169 Date: Fri, 10 Nov 2023 15:02:10 +0800 Subject: [PATCH 6/7] fix bugs --- examples/export_int8_llama.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/examples/export_int8_llama.py b/examples/export_int8_llama.py index b9c53ba..7f48355 100644 --- a/examples/export_int8_llama.py +++ b/examples/export_int8_llama.py @@ -5,6 +5,7 @@ from pathlib import Path from transformers import AutoTokenizer +from transformers.models.llama.modeling_llama import LlamaForCausalLM from smoothquant.llama import Int8LlamaForCausalLM from smoothquant.smooth import smooth_lm @@ -25,7 +26,7 @@ help='location of the calibration dataset, we use the validation set of the Pile dataset') parser.add_argument('--export-FT', default=False, action="store_true") args = parser.parse_args() - model = OPTForCausalLM.from_pretrained( + model = LlamaForCausalLM.from_pretrained( args.model_name, device_map="auto", torch_dtype=torch.float16) act_scales = torch.load(args.act_scales) smooth_lm(model, act_scales, 0.5) @@ -42,7 +43,7 @@ args.dataset_path, num_samples=args.num_samples, seq_len=args.seq_len) - output_path = Path(args.output_path) / (Path(args.model_name).name + "-smoothquant.pt") + output_path = Path(args.output_path) / (Path(args.model_name).name + "-smoothquant") if args.export_FT: model.save_pretrained(output_path) print(f"Saved smoothed model at {output_path}") From bf6c4c0425679ce0345574bc697af088a394d3fe Mon Sep 17 00:00:00 2001 From: zhangying169 Date: Wed, 13 Dec 2023 11:38:53 +0800 Subject: [PATCH 7/7] optimize dequant scale --- examples/export_int8_llama.py | 2 +- smoothquant/llama.py | 128 ++++++++++++++++++++-------------- 2 files changed, 76 insertions(+), 54 deletions(-) diff --git a/examples/export_int8_llama.py b/examples/export_int8_llama.py index 7f48355..94274d5 100644 --- a/examples/export_int8_llama.py +++ b/examples/export_int8_llama.py @@ -43,7 +43,7 @@ args.dataset_path, num_samples=args.num_samples, seq_len=args.seq_len) - output_path = Path(args.output_path) / (Path(args.model_name).name + "-smoothquant") + output_path = Path(args.output_path) / ("llama-" + Path(args.model_name).name + "-smoothquant-per-token-opt") if args.export_FT: model.save_pretrained(output_path) print(f"Saved smoothed model at {output_path}") diff --git a/smoothquant/llama.py b/smoothquant/llama.py index ee175e4..2f7f241 100644 --- a/smoothquant/llama.py +++ b/smoothquant/llama.py @@ -50,9 +50,9 @@ def __init__( self.o_proj = W8A8BFP32OFP32LinearWithSFactor(self.num_heads * self.head_dim, self.hidden_size) self.rotary_emb = LlamaRotaryEmbedding(self.head_dim, max_position_embeddings=self.max_position_embeddings) - + _shape = LlamaAttention._shape - + @staticmethod @torch.no_grad() def from_float(module: LlamaAttention, @@ -63,27 +63,40 @@ def from_float(module: LlamaAttention, v_output_scale: float, out_input_scale: float): int8_module = Int8LlamaAttention(config) - + # we do not impelement attn for now bacuase we want use paged attention - + # FIXME: Fuse the scaling into the q_proj output scale - linearList = [module.q_proj, module.k_proj, module.v_proj] - - qkv_list = W8A8BFP32OFP32Linear.from_float_fuse( - linearList, - attn_input_scale) - if len(qkv_list) != 3: - raise ValueError( - f"invalid qkv list len, must return 3 linears but get {len(qkv_list)}") + # linearList = [module.q_proj, module.k_proj, module.v_proj] + + # qkv_list = W8A8BFP32OFP32Linear.from_float_fuse( + # linearList, + # attn_input_scale) + # if len(qkv_list) != 3: + # raise ValueError( + # f"invalid qkv list len, must return 3 linears but get {len(qkv_list)}") - int8_module.q_proj = qkv_list[0] - int8_module.k_proj = qkv_list[1] - int8_module.v_proj = qkv_list[2] + # int8_module.q_proj = qkv_list[0] + # int8_module.k_proj = qkv_list[1] + # int8_module.v_proj = qkv_list[2] + + int8_module.q_proj = W8A8BFP32OFP32Linear.from_float( + module.q_proj, + attn_input_scale + ) + int8_module.k_proj = W8A8BFP32OFP32Linear.from_float( + module.k_proj, + attn_input_scale + ) + int8_module.v_proj = W8A8BFP32OFP32Linear.from_float( + module.v_proj, + attn_input_scale + ) int8_module.o_proj = W8A8BFP32OFP32LinearWithSFactor.from_float( module.o_proj, out_input_scale) return int8_module - + @torch.no_grad() def forward( self, @@ -150,7 +163,7 @@ def __init__(self, hidden_size, eps=1e-6): super().__init__() self.register_buffer('weight', torch.ones(hidden_size, dtype=torch.float32, requires_grad=False)) self.variance_epsilon = eps - + def forward(self, hidden_states): variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True) hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) @@ -160,7 +173,7 @@ def forward(self, hidden_states): out = self.weight * hidden_states int8_out = out.round().clamp(-128, 127).to(torch.int8) return int8_out - + @staticmethod def from_float(module: LlamaRMSNorm, output_scale: float): @@ -187,7 +200,9 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids): k_embed = (k * cos) + (rotate_half(k) * sin) return q_embed, k_embed + class Int8LlamaMLP(nn.Module): + def __init__(self, config): super().__init__() self.config = config @@ -195,43 +210,50 @@ def __init__(self, config): self.intermediate_size = config.intermediate_size self.down_input_scale = 0. # need fp32 out bcause silu - self.gate_proj = W8A8BFP32OFP32Linear(self.hidden_size, self.intermediate_size) + self.gate_proj = W8A8BFP32OFP32Linear(self.hidden_size, + self.intermediate_size) - self.up_proj = W8A8BFP32OFP32Linear(self.hidden_size, self.intermediate_size) - self.down_proj = W8A8BFP32OFP32LinearWithSFactor(self.intermediate_size, self.hidden_size) + self.up_proj = W8A8BFP32OFP32Linear(self.hidden_size, + self.intermediate_size) + self.down_proj = W8A8BFP32OFP32LinearWithSFactor( + self.intermediate_size, self.hidden_size) # silu_and_mul_kernel in vLLM can be a reference of SwiGLU self.act_fn = SiLUActivation() - + @staticmethod @torch.no_grad() - def from_float(module: LlamaMLP, - config: LlamaConfig, - gate_input_scale: float, - gate_output_scale: float, - up_input_scale: float, - up_output_scale: float, - down_input_scale: float, - down_output_scale: float): + def from_float(module: LlamaMLP, config: LlamaConfig, + gate_input_scale: float, gate_output_scale: float, + up_input_scale: float, up_output_scale: float, + down_input_scale: float, down_output_scale: float): int8Mlp = Int8LlamaMLP(config) # FIXME: Fuse the scaling into the q_proj output scale - linearList = [module.gate_proj, module.up_proj] - gateup_list = W8A8BFP32OFP32Linear.from_float_fuse( - linearList, - gate_input_scale) + # linearList = [module.gate_proj, module.up_proj] + # gateup_list = W8A8BFP32OFP32Linear.from_float_fuse( + # linearList, + # gate_input_scale) - if len(gateup_list) != 2: - raise ValueError( - f"invalid qkv gateup len, must return 2 linears but get {len(qkv_list)}") + # if len(gateup_list) != 2: + # raise ValueError( + # f"invalid qkv gateup len, must return 2 linears but get {len(qkv_list)}") - int8Mlp.gate_proj = gateup_list[0] - int8Mlp.up_proj = gateup_list[1] + # int8Mlp.gate_proj = gateup_list[0] + # int8Mlp.up_proj = gateup_list[1] + + int8Mlp.gate_proj = W8A8BFP32OFP32Linear.from_float( + module.gate_proj, + gate_input_scale + ) + int8Mlp.up_proj = W8A8BFP32OFP32Linear.from_float( + module.up_proj, + gate_input_scale + ) int8Mlp.down_proj = W8A8BFP32OFP32LinearWithSFactor.from_float( - module.down_proj, - down_input_scale) + module.down_proj, down_input_scale) return int8Mlp - + def forward(self, x): # TODO: supprot self.config.pretraining_tp > 1 condition, adapt from transformer.modeling_llama hidden = self.act_fn(self.gate_proj(x).to(torch.float16)) @@ -267,7 +289,7 @@ def from_float(module: LlamaDecoderLayer, ) int8_module.self_attn = Int8LlamaAttention.from_float( - module.self_attn, + module.self_attn, config, attn_input_scale, q_output_scale, @@ -275,9 +297,9 @@ def from_float(module: LlamaDecoderLayer, v_output_scale, out_input_scale ) - + int8_module.mlp = Int8LlamaMLP.from_float( - module.mlp, + module.mlp, config, gate_input_scale, gate_output_scale, @@ -295,7 +317,7 @@ def from_float(module: LlamaDecoderLayer, gate_input_scale ) return int8_module - + def forward( self, hidden_states: torch.Tensor, @@ -318,7 +340,7 @@ def forward( use_cache=use_cache, ) residual.add_(hidden_states.to(residual.dtype)) - + # mlp hidden_states = self.post_attention_layernorm(residual) hidden_states = self.mlp(hidden_states) @@ -342,20 +364,20 @@ def __init__(self, config: LlamaConfig): self.gradient_checkpointing = False # Initialize weights and apply final processing self.post_init() - + get_input_embeddings = LlamaModel.get_input_embeddings set_input_embeddings = LlamaModel.set_input_embeddings _prepare_decoder_attention_mask = LlamaModel._prepare_decoder_attention_mask # iter self.layers and calcu forward forward = LlamaModel.forward - + @staticmethod def from_float(module, decoder_layer_scales): int8_module = Int8LlamaModel(module.config) - + int8_module.embed_tokens = module.embed_tokens int8_module.norm = module.norm - + for i, layer in enumerate(module.layers): int8_module.layers[i] = Int8LlamaDecoderLayer.from_float( layer, module.config, **decoder_layer_scales[i]) @@ -370,7 +392,7 @@ def __init__(self, config): self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) # Initialize weights and apply final processing self.post_init() - + @staticmethod def from_float(module, decoder_layer_scales): int8_module = Int8LlamaForCausalLM(module.config) @@ -379,7 +401,7 @@ def from_float(module, decoder_layer_scales): module.model, decoder_layer_scales) int8_module.lm_head = module.lm_head return int8_module - + get_input_embeddings = LlamaForCausalLM.get_input_embeddings set_input_embeddings = LlamaForCausalLM.set_input_embeddings get_output_embeddings = LlamaForCausalLM.get_output_embeddings