diff --git a/.github/workflows/script/unitTest/coverage/.optimize-coveragerc b/.github/workflows/script/unitTest/coverage/.optimize-coveragerc
index 7503552f189..469ea6c0e38 100644
--- a/.github/workflows/script/unitTest/coverage/.optimize-coveragerc
+++ b/.github/workflows/script/unitTest/coverage/.optimize-coveragerc
@@ -18,6 +18,7 @@ omit =
     */intel_extension_for_transformers/langchain/**
     */intel_extension_for_transformers/llama_index/**
     */intel_extension_for_transformers/transformers/utils/get_throughput.py
+    */intel_extension_for_transformers/transformers/kv_cache_compression/**
 exclude_lines =
     pragma: no cover
     raise NotImplementedError
diff --git a/docs/h2o.md b/docs/h2o.md
new file mode 100644
index 00000000000..d7b80f0108f
--- /dev/null
+++ b/docs/h2o.md
@@ -0,0 +1,49 @@
+# H2O: Heavy-Hitter Oracle for Efficient Generative Inference of Large Language Models
+1. [Introduction](#introduction)
+2. [Usage](#usage)
+
+## Introduction
+**Heavy-Hitter Oracle (H2O)** is a novel approach to KV cache eviction that significantly reduces the memory footprint of generative inference.
+
+This method is based on the observation that the accumulated attention scores of all tokens in attention blocks adhere to a power-law distribution. This suggests that there exists a small set of influential tokens, named heavy-hitters (H2), that are critical during generation. H2 provides an opportunity to step away from the combinatorial search problem and identify an eviction policy that maintains accuracy.
+
+H2O dynamically retains a balance of recent and H2 tokens, significantly increasing model throughput while preserving accuracy.
+
+
+For more information, please refer to the paper [H2O: Heavy-Hitter Oracle for Efficient Generative Inference of Large Language Models](https://arxiv.org/pdf/2306.14048).
+
+
+![](./imgs/h2o.png)
+
+
+## Usage
+To run the simulation mode
+```python
+from intel_extension_for_transformers.transformers.kv_cache_compression import H2OConfig, LlamaForCausalLM
+h2o_config = H2OConfig(
+    heavy_ratio=heavy_ratio,
+    recent_ratio=recent_ratio,
+    h2o_min_seqlen=h2o_min_seqlen,
+    real_drop=False,
+)
+user_model = LlamaForCausalLM.from_pretrained(
+    args.model,
+    prune_config=h2o_config,
+    trust_remote_code=args.trust_remote_code)
+```
+To run the real_drop mode
+```python
+from intel_extension_for_transformers.transformers.kv_cache_compression import H2OConfig, LlamaForCausalLM
+h2o_config = H2OConfig(
+    heavy_ratio=heavy_ratio,
+    recent_ratio=recent_ratio,
+    h2o_min_seqlen=h2o_min_seqlen,
+    real_drop=True,
+)
+user_model = LlamaForCausalLM.from_pretrained(
+    args.model,
+    prune_config=h2o_config,
+    trust_remote_code=args.trust_remote_code)
+```
+
+Please refer to the [h2o example](../examples/huggingface/pytorch/text-generation/h2o/run_generation.py) for more details.
diff --git a/docs/imgs/h2o.png b/docs/imgs/h2o.png
new file mode 100644
index 00000000000..3cd5c8ff156
Binary files /dev/null and b/docs/imgs/h2o.png differ
diff --git a/examples/huggingface/pytorch/text-generation/h2o/README.md b/examples/huggingface/pytorch/text-generation/h2o/README.md
new file mode 100644
index 00000000000..22064c00b29
--- /dev/null
+++ b/examples/huggingface/pytorch/text-generation/h2o/README.md
@@ -0,0 +1,47 @@
+# H2O: Heavy-Hitter Oracle for Efficient Generative Inference of Large Language Models
+
+**Heavy-Hitter Oracle (H2O)** is a novel approach to KV cache eviction that significantly reduces the memory footprint of generative inference.
+
+This method is based on the observation that the accumulated attention scores of all tokens in attention blocks adhere to a power-law distribution. This suggests that there exists a small set of influential tokens, named heavy-hitters (H2), that are critical during generation. H2 provides an opportunity to step away from the combinatorial search problem and identify an eviction policy that maintains accuracy.
+
+H2O dynamically retains a balance of recent and H2 tokens, significantly increasing model throughput while preserving accuracy.
+
+
+For more information, please refer to the paper [H2O: Heavy-Hitter Oracle for Efficient Generative Inference of Large Language Models](https://arxiv.org/pdf/2306.14048).
+
+
+![](./imgs/1.png)
+
+
+## Usage and Examples
+### Evaluation on tasks from the [lm-eval-harness](https://github.com/EleutherAI/lm-evaluation-harness) framework
+To run the simulation mode
+```bash
+python run_generation.py \
+    --model meta-llama/Meta-Llama-3-8B \
+    --accuracy \
+    --batch_size 16 \
+    --h2o \
+    --heavy_ratio 0.1 \
+    --recent_ratio 0.1 \
+    --device 0
+```
+To run the real_drop mode
+```bash
+python run_generation.py \
+    --model meta-llama/Meta-Llama-3-8B \
+    --accuracy \
+    --batch_size 16 \
+    --h2o \
+    --heavy_ratio 0.1 \
+    --recent_ratio 0.1 \
+    --device 0 \
+    --real_drop
+```
+To get the accuracy of the dense model
+```bash
+python run_generation.py \
+    --model meta-llama/Meta-Llama-3-8B \
+    --accuracy \
+    --batch_size 16
+```
\ No newline at end of file
diff --git a/examples/huggingface/pytorch/text-generation/h2o/imgs/1.png b/examples/huggingface/pytorch/text-generation/h2o/imgs/1.png
new file mode 100644
index 00000000000..3cd5c8ff156
Binary files /dev/null and b/examples/huggingface/pytorch/text-generation/h2o/imgs/1.png differ
diff --git a/examples/huggingface/pytorch/text-generation/h2o/run_generation.py b/examples/huggingface/pytorch/text-generation/h2o/run_generation.py
new file mode 100644
index 00000000000..225275d6323
--- /dev/null
+++ b/examples/huggingface/pytorch/text-generation/h2o/run_generation.py
@@ -0,0 +1,238 @@
+import argparse
+import sys
+import time
+import json
+import torch
+from transformers import AutoConfig, AutoTokenizer, AutoModel, AutoModelForCausalLM
+from transformers.utils import check_min_version
+
+parser = argparse.ArgumentParser()
+parser.add_argument("--model", default=None)
+parser.add_argument(
+    "--dataset", nargs="?", default="NeelNanda/pile-10k", const="NeelNanda/pile-10k"
+)
+parser.add_argument(
+    "--max_new_tokens", default=32, type=int, help="output max new tokens"
+)
+parser.add_argument("--output_dir", nargs="?", default="./saved_results")
+parser.add_argument("--int8", action="store_true")
+parser.add_argument(
+    "--int8_bf16_mixed",
+    action="store_true",
+    help="by default it is int8-fp32 mixed, to enable int8 mixed amp bf16 (work on platforms like SPR)",
+)
+parser.add_argument(
+    "--restore",
+    action="store_true",
+    help="restore ipex quantized model from output_dir/best_configure.json",
+)
+parser.add_argument(
+    "--peft_model_id", type=str, default=None, help="model_name_or_path of peft model"
+)
+parser.add_argument("--_commit_hash", default=None, type=str)
+parser.add_argument("--trust_remote_code", action="store_true")
+parser.add_argument("--use_neural_speed", action="store_true")
+# ============Benchmark configs==============
+parser.add_argument("--benchmark", action="store_true")
+parser.add_argument("--iters", default=100, type=int, help="number of iterations")
+parser.add_argument("--num_warmup", default=10, type=int, help="number of warmup iterations")
+# 
============Accuracy configs============== +parser.add_argument("--accuracy", action="store_true") +parser.add_argument("--batch_size", default=16, type=int, help="batch size num.") +parser.add_argument( + "--save_accuracy_path", default=None, help="Save accuracy results path." +) +parser.add_argument("--output_excel", default=None, type=str) +parser.add_argument("--eval_bs", default=4, type=int, + help="eval batch size") +parser.add_argument("--tasks", nargs='+', default=["winogrande", "copa", "piqa", "rte", "hellaswag", \ + "openbookqa", "lambada_openai", "lambada_standard", "wikitext"], type=str, \ + help="tasks list for accuracy validation") +parser.add_argument("--num_fewshot", default=0, type=int, help="num few shot.") +# ============MixedPrecision configs============== +parser.add_argument("--mixed_precision", action="store_true") + +# ============h2o configs============== +parser.add_argument('--h2o', action='store_true') +parser.add_argument('--is_gen', action='store_true') +parser.add_argument('--real_drop', action='store_true') +parser.add_argument("--heavy_ratio", type=float, default=0.1) +parser.add_argument("--recent_ratio", type=float, default=0.1) +parser.add_argument("--device", type=str, default='cpu') +parser.add_argument("--h2o_min_seqlen", type=int, default=0) + +args = parser.parse_args() +# transformers version >= 4.32.0 contained the mpt modeling definition. +# https://github.com/huggingface/transformers/blob/main/src/transformers/models/mpt/modeling_mpt.py +# 4.31.0 for ipex.optimize_transformers +# get model config +if args.peft_model_id: + from peft import PeftConfig + + peft_config = PeftConfig.from_pretrained(args.peft_model_id) + if args.model is None: + args.model = peft_config.base_model_name_or_path + print("we will use peft base_model_name_or_path to get tokenizer.") + +config = AutoConfig.from_pretrained( + args.model, + torchscript=False, + use_cache=True, # to use kv cache. 
+ trust_remote_code=args.trust_remote_code, + _commit_hash=args._commit_hash, +) + +# chatglm +if config.model_type == "chatglm": + AutoModelForCausalLM = AutoModel +# tokenizer +if config.model_type == "llama": + from transformers import LlamaTokenizer + + # tokenizer = LlamaTokenizer.from_pretrained(args.model) + tokenizer = AutoTokenizer.from_pretrained(args.model) +else: + tokenizer = AutoTokenizer.from_pretrained( + args.model, trust_remote_code=args.trust_remote_code + ) + +# use peft +args.model = args.peft_model_id if args.peft_model_id is not None else args.model + +# Generation +if args.use_neural_speed: + generate_kwargs = dict(do_sample=False, temperature=0.9, num_beams=1) +else: + generate_kwargs = dict(do_sample=False, temperature=0.9, num_beams=4) + +if 'cpu' in args.device: + device = args.device +else: + device = f"cuda:{args.device}" + +# get optimized model +if args.h2o: + print('Enable Small Cache Size') + from intel_extension_for_transformers.transformers.kv_cache_compression import H2OConfig, LlamaForCausalLM + h2o_config = H2OConfig( + heavy_ratio=args.heavy_ratio, + recent_ratio=args.recent_ratio, + h2o_min_seqlen=args.h2o_min_seqlen, + real_drop=args.real_drop, + mean=False, + ) + user_model = LlamaForCausalLM.from_pretrained( + args.model, + prune_config=h2o_config, + trust_remote_code=args.trust_remote_code) + print("converted model: ", user_model) +else: + user_model = AutoModelForCausalLM.from_pretrained(args.model, trust_remote_code=args.trust_remote_code) +user_model.to(device) + +# save model +# if args.output_dir is not None: +# tokenizer.save_pretrained(args.output_dir) +# user_model.save_pretrained(args.output_dir) + +if args.benchmark: + user_model = ( + user_model.eval() if (not (args.int8 or args.int8_bf16_mixed) and hasattr(user_model, "eval")) else user_model + ) + prompt = "Once upon a time, there existed a little girl, who liked to have adventures. She wanted to go to places and meet new people, and have fun." + input_size = tokenizer(prompt, return_tensors="pt").input_ids.size(dim=1) + print("---- Prompt size:", input_size) + + # start + total_time = 0.0 + num_iter = args.iters + num_warmup = args.num_warmup + total_token_num = 0 + eos_token_id = tokenizer.eos_token_id + with torch.inference_mode(), torch.no_grad(): + for i in range(num_iter): + tic = time.time() + if hasattr(tokenizer, "build_chat_input"): + input_ids = tokenizer.build_chat_input(prompt)["input_ids"] + input_ids = input_ids.repeat(args.batch_size, 1) + eos_token_id = [ + tokenizer.eos_token_id, + tokenizer.get_command("<|user|>"), + tokenizer.get_command("<|observation|>"), + ] + elif hasattr(tokenizer, "build_prompt"): + build_prompt = tokenizer.build_prompt(prompt) + input_ids = tokenizer( + [build_prompt] * args.batch_size, return_tensors="pt" + ).input_ids + else: + input_ids = tokenizer( + [prompt] * args.batch_size, return_tensors="pt" + ).input_ids + gen_ids = user_model.generate( + input_ids, + max_new_tokens=args.max_new_tokens, + **generate_kwargs, + eos_token_id=eos_token_id + ) + gen_text = tokenizer.batch_decode(gen_ids, skip_special_tokens=True) + toc = time.time() + # please check the gen_ids if include input_ids. + input_tokens_num = input_ids.numel() + output_tokens_num = torch.tensor(gen_ids).numel() - input_tokens_num + print(gen_text, flush=True) + if i >= num_warmup: + total_time += toc - tic + total_token_num += output_tokens_num + + print("\n", "-" * 10, "Summary:", "-" * 10) + latency = total_time / total_token_num + print("Inference latency: %.3f sec." 
% latency) + throughput = total_token_num / total_time + print("Throughput: {} samples/sec".format(throughput)) + +if args.accuracy: + user_model = (user_model.eval() if (not (args.int8 or args.int8_bf16_mixed) and hasattr(user_model, "eval")) \ + else user_model) + # from intel_extension_for_transformers.transformers.llm.evaluation.lm_eval import evaluate, LMEvalParser + # model_args="pretrained="+args.model+",trust_remote_code="+str(args.trust_remote_code) + # args.tasks = ",".join(args.tasks) + # tokenizer.pad_token = tokenizer.eos_token + # eval_args = LMEvalParser(model = "hf", + # user_model=user_model, + # tokenizer=tokenizer, + # model_args=model_args, + # tasks = args.tasks, + # device = device, + # num_fewshot=args.num_fewshot, + # output_path=args.save_accuracy_path, + # batch_size = args.batch_size) + # print("using device:", device) + # results = evaluate(eval_args) + + + # original lm_eval + from lm_eval.evaluator import simple_evaluate + from lm_eval.tasks import TaskManager + import lm_eval + + verbosity = 'INFO' + task_manager = TaskManager(verbosity) + limit = None + cache_requests = False + lm = lm_eval.api.registry.get_model("hf")( + pretrained=user_model, + batch_size=args.batch_size, + max_batch_size=None, + ) + model_args="pretrained="+ args.model+ ",tokenizer="+ args.model + ",dtype=float32" + use_cache = None + results = simple_evaluate( + model=lm, + model_args=model_args, + tasks=args.tasks, + num_fewshot=args.num_fewshot, + device=device + ) + import pprint + pprint.pprint(results["results"]) diff --git a/examples/huggingface/pytorch/text-generation/h2o/run_summarization.py b/examples/huggingface/pytorch/text-generation/h2o/run_summarization.py new file mode 100644 index 00000000000..ba89c5a73a6 --- /dev/null +++ b/examples/huggingface/pytorch/text-generation/h2o/run_summarization.py @@ -0,0 +1,169 @@ +import argparse +import json +import os.path +import sys +sys.path.insert(0, '/home/hengguo/code/intel-extension-for-transformers') + +import tqdm +import torch + +from rouge import Rouge +import numpy as np + +from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig +from transformers.models.llama.configuration_llama import LlamaConfig + +from intel_extension_for_transformers.transformers.modeling.kv_cache_compression.h2o import convert_model + + + +os.environ['CUDA_LAUNCH_BLOCKING'] = "1" + +MAX_LENGTH = int(10000) # Hardcoded max length to avoid infinite loop + +def set_seed(args): + np.random.seed(args.seed) + torch.manual_seed(args.seed) + +def main(): + + parser = argparse.ArgumentParser() + + parser.add_argument("--input_path", type=str, default="") + parser.add_argument("--output_path", type=str, default="") + + parser.add_argument("--model_name", type=str, default="") + parser.add_argument("--cache_dir", type=str, default=None) + + parser.add_argument('--h2o', action='store_true') + parser.add_argument("--heavy_ratio", type=float, default=0.1) + parser.add_argument("--recent_ratio", type=float, default=0.1) + parser.add_argument("--h2o_min_seqlen", type=int, default=0) + parser.add_argument("--device", type=str, default="cpu") + parser.add_argument('--mean', action='store_true') + + + parser.add_argument("--sample_num", type=int, default=100) + parser.add_argument("--k", type=int, default=0) + parser.add_argument("--seed", type=int, default=42, help="random seed for initialization") + parser.add_argument("--batch_size", type=int, default=1) + parser.add_argument( + "--fp16", + action="store_true", + help="Whether to use 16-bit (mixed) 
precision (through NVIDIA apex) instead of 32-bit", + ) + args = parser.parse_args() + + try: + device_str = int(args.device) + device_str = f'cuda:{args.device}' + except: + device_str = args.device + set_seed(args) + + model_name = args.model_name + input_path = args.input_path + output_path = args.output_path + + config = AutoConfig.from_pretrained(model_name, cache_dir=args.cache_dir) + tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True, cache_dir=args.cache_dir) + model = AutoModelForCausalLM.from_pretrained(model_name, cache_dir=args.cache_dir) + + if args.batch_size>1: + tokenizer.pad_token = tokenizer.eos_token + + if args.h2o: + print('Enabling H2O KV cache') + model = convert_model( + model, + heavy_ratio=args.heavy_ratio, + recent_ratio=args.recent_ratio, + h2o_min_seqlen=args.h2o_min_seqlen, + real_drop=True, + is_gen=True, + mean=args.mean) + model.clean_cache() + + model = model.half().eval().to(device_str) + + requests = [] + with open(input_path, 'r') as f: + for line in f: + if line.strip() != '': + requests.append(json.loads(line)) + + print(len(requests)) + if args.sample_num < len(requests): + print('Sample {} Examples from {} samples'.format(args.sample_num, len(requests))) + requests = requests[:args.sample_num] + + results = [] + rouge = Rouge() + rouge1_score_list = [] + rouge2_score_list = [] + rougel_score_list = [] + + with torch.no_grad(): + for request in tqdm.tqdm(requests): + result = {'request': request, 'result': {}} + prompt = request['article'] + label = request['summary_gt'] + temperature = request['temperature'] + stop = request['stop'] + + input_ids = tokenizer(prompt, add_special_tokens=False, return_tensors='pt').input_ids.to(model.device) + + output_sequences = model.generate( + input_ids=input_ids, + max_length=request['max_tokens'] + len(input_ids[0]), + temperature=temperature, + top_k=args.k, + top_p=request['top_p'], + do_sample=True, + num_return_sequences=request['n'], + return_dict_in_generate=True, output_scores=True, + pad_token_id=tokenizer.eos_token_id + ) + + if args.h2o: + model.clean_cache() + + tokens = tokenizer.convert_ids_to_tokens(output_sequences['sequences'].squeeze(0))[len(input_ids[0]):] + logprobs = [logits.log_softmax(dim=-1).max().item() for logits in output_sequences['scores']] + top_logprobs = [{i: v for i, v in zip(tokens, logprobs)}] + + generate_text = tokenizer.decode(output_sequences['sequences'].squeeze(0)[len(input_ids[0]):]) + generate_text = generate_text[: generate_text.find(stop[0])] + + scores = rouge.get_scores(generate_text, label)[0] + rouge1_score_list.append(scores['rouge-1']['f']) + rouge2_score_list.append(scores['rouge-2']['f']) + rougel_score_list.append(scores['rouge-l']['f']) + + result['result'] = { + "choices": [ + { + "text": generate_text, + "logprobs": { + "tokens": tokens, + "token_logprobs": logprobs, + "top_logprobs": top_logprobs, + "text_offset": [] + }, + "finish_reason": "length" + } + ], + "request_time": { + "batch_time": 0, + "batch_size": 1} + } + + results.append(result) + print('rouge-1: {:.6f}, rouge-2: {:.6f}, rouge-l: {:.6f}, prompt length: {}, generate text length: {}'.format(np.mean(rouge1_score_list), np.mean(rouge2_score_list), np.mean(rougel_score_list), input_ids.size(-1), output_sequences['sequences'].size(-1))) + + with open(output_path, 'w') as f: + for result in results: + f.write(json.dumps(result) + '\n') + +if __name__ == "__main__": + main() \ No newline at end of file diff --git 
a/intel_extension_for_transformers/transformers/kv_cache_compression/__init__.py b/intel_extension_for_transformers/transformers/kv_cache_compression/__init__.py new file mode 100644 index 00000000000..f73db323734 --- /dev/null +++ b/intel_extension_for_transformers/transformers/kv_cache_compression/__init__.py @@ -0,0 +1,22 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# +# Copyright (c) 2024 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .prune.h2o import H2OConfig, H2OKVPruner +from .models.modeling_llama import LlamaForCausalLM +from intel_extension_for_transformers.transformers.utils.utility import LazyImport + +GaudiLlamaForCausalLM = LazyImport(".models.modeling_gaudi_llama.GaudiLlamaForCausalLM") diff --git a/intel_extension_for_transformers/transformers/kv_cache_compression/models/__init__.py b/intel_extension_for_transformers/transformers/kv_cache_compression/models/__init__.py new file mode 100644 index 00000000000..28f108cb636 --- /dev/null +++ b/intel_extension_for_transformers/transformers/kv_cache_compression/models/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2024 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/intel_extension_for_transformers/transformers/kv_cache_compression/models/modeling_gaudi_llama.py b/intel_extension_for_transformers/transformers/kv_cache_compression/models/modeling_gaudi_llama.py new file mode 100644 index 00000000000..bc8fdde5b95 --- /dev/null +++ b/intel_extension_for_transformers/transformers/kv_cache_compression/models/modeling_gaudi_llama.py @@ -0,0 +1,1249 @@ +# Copyright (c) 2024 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
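+
+# Gaudi (HPU) Llama modeling classes (GaudiLlamaAttention, GaudiLlamaDecoderLayer,
+# GaudiLlamaModel, GaudiLlamaForCausalLM) extended with H2O KV-cache pruning hooks.
+# Each attention layer is wired to a KV pruner (see ..prune.H2OKVPruner): in real_drop
+# mode evicted key/value entries are removed before they are written to the KV cache,
+# while in simulation mode the pruner only masks the attention scores of evicted tokens.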
+ +import math +import warnings +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +from transformers.cache_utils import Cache, DynamicCache, StaticCache # pylint: disable=E0611 +from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast +from transformers.models.llama.configuration_llama import LlamaConfig +from transformers.models.llama.modeling_llama import ( + LlamaAttention, + LlamaDecoderLayer, + LlamaForCausalLM, + repeat_kv, + LlamaMLP, + LlamaModel, + LlamaRMSNorm, + LlamaPreTrainedModel, + apply_rotary_pos_emb, + logger, +) + +from ...modeling.modeling_gaudi.models.modeling_attn_mask_utils import( + _gaudi_prepare_4d_causal_attention_mask, +) + + +try: + from habana_frameworks.torch.hpex.kernels import RotaryPosEmbeddingHelperV2 as FusedRoPE + + has_fused_rope = True +except ImportError: + has_fused_rope = False + print("Not using HPU fused kernel for apply_rotary_pos_emb") + +try: + from habana_frameworks.torch.hpex.normalization import FusedRMSNorm as FusedRMSNorm + + has_fused_rms_norm = True +except ImportError: + has_fused_rms_norm = False + print("Not using HPU fused kernel for RMSNorm") + +try: + from habana_frameworks.torch.hpex.kernels import FusedSDPA +except ImportError: + print("Not using HPU fused scaled dot-product attention kernel.") + FusedSDPA = None + +import habana_frameworks.torch.core as htcore # pylint: disable=E0401 + +from ..prune import PruneConfig, H2OConfig + + +def gaudi_llama_rmsnorm_forward(self, hidden_states): + """ + + The only differences are: + - override RMSNorm with Habana fused RMSNorm + """ + if hidden_states.device.type == "hpu" and has_fused_rms_norm: + # mixed dtypes are not good for FusedRMSNorm, both inputs need to have same dtype + if hidden_states.dtype != self.weight.dtype: + orig_dtype = hidden_states.dtype + hidden_states = FusedRMSNorm.apply( + hidden_states.to(self.weight.dtype), self.weight, self.variance_epsilon + ) + return hidden_states.to(orig_dtype) + else: + hidden_states = FusedRMSNorm.apply(hidden_states, self.weight, self.variance_epsilon) + return hidden_states + else: + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + +class GaudiLlamaMLP(LlamaMLP): + def pre_mlp_forward(self, x): + if self.config.pretraining_tp > 1: + slice = self.intermediate_size // self.config.pretraining_tp + gate_proj_slices = self.gate_proj.weight.split(slice, dim=0) + up_proj_slices = self.up_proj.weight.split(slice, dim=0) + down_proj_slices = self.down_proj.weight.split(slice, dim=1) + + gate_proj = torch.cat( + [F.linear(x, gate_proj_slices[i]) \ + for i in range(self.config.pretraining_tp)], dim=-1) # pylint: disable=E1102 + up_proj = torch.cat([F.linear(x, up_proj_slices[i]) \ + for i in range(self.config.pretraining_tp)], dim=-1) # pylint: disable=E1102 + + intermediate_states = (self.act_fn(gate_proj) * up_proj).split(slice, dim=2) + down_proj = [ + F.linear(intermediate_states[i], down_proj_slices[i]) \ + for i in range(self.config.pretraining_tp)] # pylint: disable=E1102 + output = sum(down_proj) + else: + input = self.act_fn(self.gate_proj(x)) * self.up_proj(x) + output = self.down_proj(input) + return output + + def mlp_all_reduce(self, x): + if hasattr(self.down_proj, "all_reduce"): + 
self.down_proj.all_reduce(x) + + def post_mlp_forward(self, x): + if self.config.pretraining_tp > 1: + return x + if hasattr(self.down_proj, "post_all_reduce"): + return self.down_proj.post_all_reduce(x) + return x + + +def gaudi_llama_repeat_kv( + query_states: torch.Tensor, + key_states: torch.Tensor, + value_states: torch.Tensor, + attention_mask: torch.Tensor, + n_rep: int, +): + """ + + The only differences are: + - Append num_key_value_heads == 1 check as kv states can be broadcasted during + matmuls so need to expand and reshape them. + - Add new args query_states, key_states, value_states and attention_mask and + update the logic for expansion. + The query states go from (batch, num_heads, seqlen, head_dim) to + (batch, num_key_value_heads, n_rep, seqlen, head_dim) + The key/value states go from (batch, num_key_value_heads, seqlen, head_dim) to + (batch, num_key_value_heads, 1, seqlen, head_dim) + """ + batch, num_key_value_heads, kv_len, head_dim = key_states.shape + if n_rep == 1 or num_key_value_heads == 1: + return query_states, key_states, value_states, attention_mask + + new_kv_shape = (batch, num_key_value_heads, 1, kv_len, head_dim) + key_states = key_states.reshape(new_kv_shape) + value_states = value_states.reshape(new_kv_shape) + + batch, _, q_len, head_dim = query_states.shape + new_q_shape = (batch, num_key_value_heads, n_rep, q_len, head_dim) + query_states = query_states.reshape(new_q_shape) + + if attention_mask is not None: + # Add groups dim and set to 1 + attention_mask = attention_mask.unsqueeze(1) + + return query_states, key_states, value_states, attention_mask + + +class Matmul(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x, y): + return torch.matmul(x, y) + + +class KVCache(torch.nn.Module): + def __init__(self): + super(KVCache, self).__init__() + self.cache = None + self.inp_seq_len = -1 + + def allocate(self, inp_seq_len, dtype, device, shape): + if self.cache is None or self.cache.shape != shape: + self.inp_seq_len = inp_seq_len + self.cache = torch.zeros(shape, dtype=dtype, device=device) + else: + assert ( + self.inp_seq_len == inp_seq_len + ), f"inp_seq_len must be the same. self.inp_seq_len:{self.inp_seq_len} inp_seq_len:{inp_seq_len}" + self.cache.fill_(0) + + def update(self, prev, cur, dim, idx, inp_seq_len): + orig_cur = cur + if prev.shape == cur.shape: + prev.copy_(cur) + return orig_cur + if cur.shape[2] > 1 and cur.shape[2] <= prev.shape[2]: + # Initialize + prev[:, :, :inp_seq_len, :].copy_(cur) + return orig_cur + assert cur.shape[2] == 1, f"Cannot update kv-cache. Unsupported shapes. prev:{prev.shape} cur:{cur.shape}" + if idx is not None: + prev.index_copy_(dim, idx - 1, cur) + return prev + else: + return torch.cat((prev, cur), dim=dim) + + def get_shape(self): + if self.cache is None: + return None + return self.cache.shape + + def forward(self, cur, dim, idx): + return self.update(self.cache, cur, dim, idx, self.inp_seq_len) + + +class GaudiLlamaRotaryEmbedding(torch.nn.Module): + def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0): + super().__init__() + + self.scaling_factor = scaling_factor + self.dim = dim + self.max_position_embeddings = max_position_embeddings + self.base = base + inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + + # Build here to make `torch.jit.trace` work. 
+ self._set_cos_sin_cache( + seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype() + ) + + def _set_cos_sin_cache(self, seq_len, device, dtype): + self.max_seq_len_cached = seq_len + t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) + + freqs = torch.outer(t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer("_cos_cached", emb.cos().to(dtype), persistent=False) + self.register_buffer("_sin_cached", emb.sin().to(dtype), persistent=False) + + def forward(self, x, seq_len=None): + # x: [bs, num_attention_heads, seq_len, head_size] + if seq_len > self.max_seq_len_cached: + self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) + + return ( + self._cos_cached[:seq_len].to(dtype=x.dtype), + self._sin_cached[:seq_len].to(dtype=x.dtype), + ) + + +class GaudiLlamaLinearScalingRotaryEmbedding(GaudiLlamaRotaryEmbedding): + def _set_cos_sin_cache(self, seq_len, device, dtype): + self.max_seq_len_cached = seq_len + t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) + t = t / self.scaling_factor + + freqs = torch.outer(t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer("_cos_cached", emb.cos().to(dtype), persistent=False) + self.register_buffer("_sin_cached", emb.sin().to(dtype), persistent=False) + + +class GaudiLlamaDynamicNTKScalingRotaryEmbedding(GaudiLlamaRotaryEmbedding): + def _set_cos_sin_cache(self, seq_len, device, dtype): + self.max_seq_len_cached = seq_len + + if seq_len > self.max_position_embeddings: + base = self.base * ( + (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1) + ) ** (self.dim / (self.dim - 2)) + inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + + t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) + + freqs = torch.outer(t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer("_cos_cached", emb.cos().to(dtype), persistent=False) + self.register_buffer("_sin_cached", emb.sin().to(dtype), persistent=False) + + +class GaudiLlamaAttention(LlamaAttention): + def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None): + super().__init__(config, layer_idx) + + self.matmul_qk = Matmul() + self.matmul_av = Matmul() + self.k_cache = KVCache() + self.v_cache = KVCache() + self.inp_seq_len = -1 + self.norm_factor = 1.0 / math.sqrt(self.head_dim) + + self._init_func = [] + + def register_init_func(self, func): + self._init_func.append(func) + + def post_init(self): + for func in self._init_func: + func(self) + + def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len): + cache_shape = (batch_size, self.num_key_value_heads, max_seq_len, self.head_dim) + device = self.k_proj.weight.device + dtype = self.config.torch_dtype + self.k_cache.allocate(inp_seq_len, dtype, device, cache_shape) + self.v_cache.allocate(inp_seq_len, dtype, device, cache_shape) + + def update_sincos_cache(self, seq_len): + # Call rotary emb forward() to update cos/sin cache when inferring more than 
self.max_position_embeddings + # This helps in avoiding creation of these caches during actual model forward pass and + # reduce memory consumption and improve performance. + if seq_len > self.max_position_embeddings: # pylint: disable=E0203 + self.max_position_embeddings = seq_len + _, _ = self.rotary_emb(self.k_proj.weight, seq_len=seq_len) + + def reorder(self, tensor, beam_idx, dim_a, dim_b): + updated = tensor.index_select(0, beam_idx) + tensor.copy_(updated) + + def reorder_kv_cache(self, beam_idx: torch.LongTensor): + if self.k_cache.cache is None: + return (None, None) + + head_dim = self.k_cache.cache.size(-1) + seq_length = self.k_cache.cache.size(-2) + self.reorder(self.k_cache.cache, beam_idx, seq_length, head_dim) + self.reorder(self.v_cache.cache, beam_idx, seq_length, head_dim) + return (self.k_cache.cache.shape, self.v_cache.cache.shape) + + def pre_attn_forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + token_idx: Optional[torch.Tensor] = None, + attn_softmax_bf16: Optional[bool] = False, + reuse_cache: Optional[bool] = False, + use_flash_attention: Optional[bool] = False, + flash_attention_recompute: Optional[bool] = False, + flash_attention_causal_mask: Optional[bool] = False, + cache_idx: int = None, + cache_prune_num: int = 0, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """ + + The only differences are: + - add new args token_idx + - optimize KV cache + - add new args attn_softmax_bf16 + - add new args reuse_cache + - add new args use_flash_attention + - add new arg flash_attention_recompute + - add new arg flash_attention_causal_mask + - add new arg cache_prune_num for attention_sinks + """ + bsz, q_len, _ = hidden_states.size() + + if self.config.pretraining_tp > 1: + key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp + query_slices = self.q_proj.weight.split( + (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0 + ) + key_slices = self.k_proj.weight.split(key_value_slicing, dim=0) + value_slices = self.v_proj.weight.split(key_value_slicing, dim=0) + + query_states = [F.linear(hidden_states, query_slices[i]) \ + for i in range(self.config.pretraining_tp)] # pylint: disable=E1102 + query_states = torch.cat(query_states, dim=-1) + + key_states = [F.linear(hidden_states, key_slices[i]) \ + for i in range(self.config.pretraining_tp)] # pylint: disable=E1102 + key_states = torch.cat(key_states, dim=-1) + + value_states = [F.linear(hidden_states, value_slices[i]) \ + for i in range(self.config.pretraining_tp)] # pylint: disable=E1102 + value_states = torch.cat(value_states, dim=-1) + + else: + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + if token_idx is None: + if hasattr(past_key_value, "get_usable_length"): + kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + else: + kv_seq_len += 
past_key_value[0].shape[-2] + else: + if reuse_cache: + kv_seq_len = past_key_value[0][-2] + else: + kv_seq_len = past_key_value[0].shape[-2] + + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + query_states, key_states = apply_customized_rope(query_states, key_states, cos, sin, position_ids) + + if use_cache: + # reuse k, v, self_attention + if reuse_cache: + # key_states = self.k_cache(key_states, 2, token_idx) + # value_states = self.v_cache(value_states, 2, token_idx) + # past_key_value = (self.k_cache.get_shape(), self.v_cache.get_shape()) + + # pruning kv cache + if self.pruner.real_drop: + if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask + if cache_position is not None: + causal_mask = attention_mask[:, :, cache_position, : key_states.shape[-2]] + if self.layer_idx == 0: + self.pruner.past_length += query_states.size(-2) + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + new_key_states, new_value_states = self.pruner.prune( + self, + query_states, + key_states, + value_states, + causal_mask=causal_mask + ) + new_key_states = self.pruner.remove_repeat_kv(new_key_states, self.num_key_value_groups) + new_value_states = self.pruner.remove_repeat_kv(new_value_states, self.num_key_value_groups) + key_states = self.k_cache(new_key_states, 2, token_idx) + value_states = self.v_cache(new_value_states, 2, token_idx) + past_key_value = (self.k_cache.get_shape(), self.v_cache.get_shape()) + else: + if past_key_value is None: + past_key = torch.zeros(key_states.shape, dtype=self.k_proj.weight.dtype, device=key_states.device) + past_value = torch.zeros( + key_states.shape, dtype=self.k_proj.weight.dtype, device=key_states.device + ) + past_key_value = (past_key, past_value) + key_states = self.k_cache.update(past_key_value[0], key_states, 2, token_idx, self.inp_seq_len) + value_states = self.v_cache.update(past_key_value[1], value_states, 2, token_idx, self.inp_seq_len) + if token_idx is None: + past_key_value = (key_states, value_states) + + if cache_idx is not None and q_len == 1: + key_states = key_states[:, :, :cache_idx, :] + value_states = value_states[:, :, :cache_idx, :] + if attention_mask is not None: + attention_mask = attention_mask[:, :, :, :cache_idx] + kv_seq_len = key_states.shape[-2] + else: + past_key_value = None + + if use_flash_attention and FusedSDPA: + import habana_frameworks.torch.hpu as ht # pylint: disable=E0401 + + if q_len == 1: + # next token + with ht.sdp_kernel(enable_recompute=False): + attn_output = FusedSDPA.apply( + query_states, key_states, value_states, attention_mask, 0.0, False, None + ) + else: + # first token + if flash_attention_causal_mask: + # causal masking on first token requires inputs to be of the same length + with ht.sdp_kernel(enable_recompute=flash_attention_recompute): + attn_output = FusedSDPA.apply(query_states, key_states, value_states, None, 0.0, True, None) + else: + with ht.sdp_kernel(enable_recompute=flash_attention_recompute): + attn_output = FusedSDPA.apply( + query_states, key_states, value_states, attention_mask, 0.0, False, None + ) + + else: + query_states, key_states, value_states, attention_mask = gaudi_llama_repeat_kv( + query_states, key_states, value_states, attention_mask, self.num_key_value_groups + ) + + attn_weights = self.matmul_qk(query_states, key_states.transpose(-2, -1)) * self.norm_factor + + if not self.pruner.real_drop: + if attention_mask is not None: # no matter 
the length, we just slice it + causal_mask = attention_mask + if cache_position is not None: + causal_mask = attention_mask[:, :, cache_position, : key_states.shape[-2]] + mask = self.pruner.get_mask(self, query_states, key_states, value_states, + causal_mask=causal_mask) + attn_weights[~mask] = torch.finfo(attn_weights.dtype).min + + if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask + if cache_position is not None: + causal_mask = attention_mask[:, :, cache_position, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + if attn_softmax_bf16: + attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1, dtype=query_states.dtype) + else: + # upcast attention to fp32 + attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to( + query_states.dtype + ) + attn_weights = torch.nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) + attn_output = self.matmul_av(attn_weights, value_states) + attn_output = attn_output.reshape(bsz, -1, q_len, self.head_dim) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + def attention_all_reduce(self, attn_output): + if hasattr(self.o_proj, "all_reduce"): + self.o_proj.all_reduce(attn_output) + + def post_attn_forward(self, attn_output): + if hasattr(self.o_proj, "post_all_reduce"): + self.o_proj.post_all_reduce(attn_output) + return attn_output + + +class GaudiLlamaDecoderLayer(LlamaDecoderLayer): + def __init__(self, config: LlamaConfig, layer_idx: int): + super(GaudiLlamaDecoderLayer, self).__init__(config, layer_idx) + self.hidden_size = config.hidden_size + + self.self_attn = GaudiLlamaAttention(config=config, layer_idx=layer_idx) + + self.mlp = GaudiLlamaMLP(config) + self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len): + self.self_attn.allocate_kv_cache(batch_size, max_seq_len, inp_seq_len) + + def reorder_kv_cache(self, beam_idx: torch.LongTensor): + return self.self_attn.reorder_kv_cache(beam_idx) + + def update_sincos_cache(self, seq_len): + self.self_attn.update_sincos_cache(seq_len) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + token_idx: Optional[torch.Tensor] = None, + attn_softmax_bf16: Optional[bool] = False, + reuse_cache: Optional[bool] = False, + use_flash_attention: Optional[bool] = False, + flash_attention_recompute: Optional[bool] = False, + flash_attention_causal_mask: Optional[bool] = False, + cache_idx: int = None, + cache_prune_num: int = 0, + **kwargs, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + + The only differences are: + - 
add new args token_idx + - add new args attn_softmax_bf16 + - add new args reuse_cache + - add new args use_flash_attention + - add new arg flash_attention_recompute + - add new arg flash_attention_causal_mask + - add new arg cache_prune_num for attention_sinks + """ + residual = hidden_states + hidden_states, self_attn_weights, present_key_value = self.pre_attn( + hidden_states, + attention_mask, + position_ids, + past_key_value, + output_attentions, + use_cache, + cache_position, + token_idx, + attn_softmax_bf16, + reuse_cache, + use_flash_attention=use_flash_attention, + flash_attention_recompute=flash_attention_recompute, + flash_attention_causal_mask=flash_attention_causal_mask, + cache_idx=cache_idx, + cache_prune_num = cache_prune_num, + **kwargs, + ) + self.self_attn.attention_all_reduce(hidden_states) + hidden_states, residual = self.post_attn_pre_mlp(hidden_states, residual) + self.mlp.mlp_all_reduce(hidden_states) + hidden_states = self.post_mlp(hidden_states, residual) + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + if use_cache: + outputs += (present_key_value,) + + return outputs + + def pre_attn( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + token_idx: Optional[torch.Tensor] = None, + attn_softmax_bf16: Optional[bool] = False, + reuse_cache: Optional[bool] = False, + use_flash_attention: Optional[bool] = False, + flash_attention_recompute: Optional[bool] = False, + flash_attention_causal_mask: Optional[bool] = False, + cache_idx: int = None, + cache_prune_num: int = 0, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + hidden_states = self.input_layernorm(hidden_states) + hidden_states, attn_weights, present_key_value = self.self_attn.pre_attn_forward( + hidden_states, + attention_mask, + position_ids, + past_key_value, + output_attentions, + use_cache, + cache_position, + token_idx, + attn_softmax_bf16, + reuse_cache, + use_flash_attention, + flash_attention_recompute, + flash_attention_causal_mask, + cache_idx=cache_idx, + cache_prune_num = cache_prune_num, + ) + return hidden_states, attn_weights, present_key_value + + def post_attn_pre_mlp(self, hidden_states, residual): + hidden_states = self.self_attn.post_attn_forward(hidden_states) + + if self.training: + hidden_states = hidden_states + residual + residual = hidden_states + else: + residual.add_(hidden_states) + hidden_states = residual + + hidden_states = self.post_attention_layernorm(hidden_states) + + hidden_states = self.mlp.pre_mlp_forward(hidden_states) + return hidden_states, residual + + def post_mlp(self, hidden_states, residual): + hidden_states = self.mlp.post_mlp_forward(hidden_states) + + if self.training: + hidden_states = hidden_states + residual + else: + residual.add_(hidden_states) + hidden_states = residual + + return hidden_states + + +class GaudiLlamaModel(LlamaModel): + """""" + + def __init__(self, config: LlamaConfig): + """ + + 1. set fill_value to 1 instead of True + 2. 
add device=self.device + """ + super(GaudiLlamaModel, self).__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = torch.nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = torch.nn.ModuleList( + [GaudiLlamaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.gradient_checkpointing = False + + # Register a causal mask to separate causal and padding mask creation. Merging happens in the attention class. + causal_mask = torch.full( + (config.max_position_embeddings, config.max_position_embeddings), + fill_value=1, + dtype=torch.bool, + ) + self.register_buffer("causal_mask", torch.triu(causal_mask, diagonal=1), persistent=False) + # Initialize weights and apply final processing + self.post_init() + + def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len): + for layer in self.layers: + layer.allocate_kv_cache(batch_size, max_seq_len, inp_seq_len) + + def reorder_kv_cache(self, beam_idx: torch.LongTensor): + return tuple(layer.reorder_kv_cache(beam_idx) for layer in self.layers) + + def update_sincos_cache(self, seq_len): + for layer in self.layers: + layer.update_sincos_cache(seq_len) + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + token_idx: Optional[torch.Tensor] = None, + attn_softmax_bf16: Optional[bool] = False, + reuse_cache: Optional[bool] = False, + use_flash_attention: Optional[bool] = False, + flash_attention_recompute: Optional[bool] = False, + flash_attention_causal_mask: Optional[bool] = False, + cache_idx: int = None, + lazy_mode: Optional[bool] = True, + cache_prune_num: int = 0, + ) -> Union[Tuple, BaseModelOutputWithPast]: + """ + + The only differences are: + - add new args token_idx + - add new args attn_softmax_bf16 + - add new args reuse_cache + - add new args use_flash_attention + - add new arg flash_attention_recompute + - add new arg flash_attention_causal_mask + - add new arg lazy_mode + - add new arg cache_prune_num for attention_sinks + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError( + "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one" + ) + elif input_ids is not None: + batch_size, seq_length = input_ids.shape[:2] + elif inputs_embeds is not None: + batch_size, seq_length = inputs_embeds.shape[:2] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + if self.gradient_checkpointing and self.training and use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. 
Setting `use_cache=False`." + ) + use_cache = False + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + ignore_cache_position = True # Ignoring cache position for HPU + use_new_cache = False # Ignoring new Cache path for HPU + past_seen_tokens = 0 + + if past_key_values is not None and use_cache: # kept for BC (cache positions) + if reuse_cache: + past_seen_tokens = past_key_values[0][0][2] + else: + if use_new_cache: + if not isinstance(past_key_values, StaticCache): + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + past_seen_tokens = past_key_values.get_seq_length() + else: + past_seen_tokens = past_key_values[0][0].shape[2] + + if ignore_cache_position is False: + if cache_position is None: + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + if position_ids is None and cache_position: + position_ids = cache_position.unsqueeze(0) + + else: + if position_ids is None: + position_ids = torch.arange( + past_seen_tokens, seq_length + past_seen_tokens, dtype=torch.long, device=inputs_embeds.device + ) + position_ids = position_ids.unsqueeze(0) + cache_position = None + + # HPU specific mask generation + if ignore_cache_position: + # workaround for attention_sinks attention_mask which has fixed seq_len at dim -1 + if hasattr(self, "attention_sink_size") and hasattr(self, "attention_sink_window_size"): + past_seen_tokens = self.attention_sink_size + self.attention_sink_window_size - seq_length + causal_mask = _gaudi_prepare_4d_causal_attention_mask( + attention_mask, + input_ids.shape if input_ids is not None else (batch_size, seq_length), + inputs_embeds, + past_seen_tokens, + ) + else: + causal_mask = self._update_causal_mask(attention_mask, inputs_embeds) + # embed positions + hidden_states = inputs_embeds + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = () if not use_new_cache else None + + if lazy_mode: + htcore.mark_step() + + for layer_idx, decoder_layer in enumerate(self.layers): + if ( + lazy_mode + and not self.training + and (torch.distributed.is_initialized() is False or torch.distributed.get_world_size() == 1) + ): + htcore.mark_step() + + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + causal_mask, + position_ids, + past_key_values, + output_attentions, + use_cache, + cache_position, + None, + attn_softmax_bf16, + False, + use_flash_attention, + flash_attention_recompute, + flash_attention_causal_mask, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=None if past_key_values is None else past_key_values[layer_idx], + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + token_idx=token_idx, + attn_softmax_bf16=attn_softmax_bf16, + reuse_cache=reuse_cache, + use_flash_attention=use_flash_attention, + flash_attention_recompute=flash_attention_recompute, + flash_attention_causal_mask=flash_attention_causal_mask, + cache_idx=cache_idx, + cache_prune_num = cache_prune_num, + ) + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) + + if output_attentions: + all_self_attns += (layer_outputs[1],) + 
+ hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = None + if use_cache: + next_cache = ( + next_decoder_cache.to_legacy_cache() if isinstance(next_decoder_cache, Cache) else next_decoder_cache + ) + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + +class GaudiLlamaForCausalLM(LlamaPreTrainedModel): + """ + + The only differences are: + - add new args token_idx + - add token_idx into model_inputs + - from step2 when enable KV cache, slice next_input_ids from input_ids base on the token_idx + - from step2 when enable KV cache, slice next_position_ids from position_ids base on the token_idx + - add new args attn_softmax_bf16 + - add new args reuse_cache + - add new arg cache_prune_num for attention_sinks + """ + def __init__( + self, + config: LlamaConfig, + prune_config: PruneConfig, + ): + super().__init__(config) + self.model = LlamaModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + if isinstance(prune_config, H2OConfig): + from ..prune import H2OKVPruner + self.pruner = H2OKVPruner(prune_config) + else: + from ..prune import KVPruner + self.pruner = KVPruner(prune_config) + + num_layers = len(self.model.layers) + for layer_idx in range(num_layers): + module = self.model.layers[layer_idx].self_attn + self.model.layers[layer_idx].self_attn = GaudiLlamaAttention( + config, + layer_idx + ) + self.model.layers[layer_idx].self_attn.register_init_func(self.pruner.self_attn_init) + self.model.layers[layer_idx].self_attn.post_init() + + self.model.layers[layer_idx].self_attn.pruner = self.pruner + + # Initialize weights and apply final processing + self.post_init() + + def _generate(*args, **kwargs): + self.pruner.before_generate(self, *args, **kwargs) + result = self.ori_generate(*args, **kwargs) + self.pruner.after_generate(self,*args, **kwargs) + return result + + self.ori_generate = self.generate + self.generate = _generate + + def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len): + self.model.allocate_kv_cache(batch_size, max_seq_len, inp_seq_len) + + def reorder_kv_cache(self, beam_idx: torch.LongTensor): + return self.model.reorder_kv_cache(beam_idx) + + def update_sincos_cache(self, seq_len): + self.model.update_sincos_cache(seq_len) + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + token_idx: Optional[torch.Tensor] = None, + trim_logits: Optional[bool] = False, + attn_softmax_bf16: Optional[bool] = False, + reuse_cache: Optional[bool] = False, + use_flash_attention: Optional[bool] = False, + flash_attention_recompute: Optional[bool] = False, + flash_attention_causal_mask: Optional[bool] = False, + cache_idx: int = None, + lazy_mode: 
Optional[bool] = True, + cache_prune_num: int = 0, + ) -> Union[Tuple, CausalLMOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + if self.generation_config.use_fused_rope is False: + global has_fused_rope + has_fused_rope = False + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + token_idx=token_idx, + attn_softmax_bf16=attn_softmax_bf16, + reuse_cache=reuse_cache, + use_flash_attention=use_flash_attention, + flash_attention_recompute=flash_attention_recompute, + flash_attention_causal_mask=flash_attention_causal_mask, + cache_idx=cache_idx, + lazy_mode=lazy_mode, + cache_prune_num=cache_prune_num, + ) + hidden_states = outputs[0] + _, seq_len, _ = hidden_states.shape + if seq_len > 1 and trim_logits and not self.training: + if token_idx is not None: + hidden_states = hidden_states.index_select(1, token_idx - 1) + else: + hidden_states = hidden_states[:, -1, :] + + if self.config.pretraining_tp > 1: + lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0) + logits = [F.linear(hidden_states, lm_head_slices[i]) \ + for i in range(self.config.pretraining_tp)] # pylint: disable=E1102 + logits = torch.cat(logits, dim=-1) + else: + logits = self.lm_head(hidden_states) + logits = logits.float() + + loss = None + if labels is not None: + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = torch.nn.CrossEntropyLoss() + shift_logits = shift_logits.view(-1, self.config.vocab_size) + shift_labels = shift_labels.view(-1) + # Enable model parallelism + shift_labels = shift_labels.to(shift_logits.device) + loss = loss_fct(shift_logits, shift_labels) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def prepare_inputs_for_generation( + self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, token_idx=None, **kwargs + ): + past_length = 0 + if not hasattr(self, "kv_past_token_length"): + self.kv_past_token_length = 0 + using_attention_sinks = (hasattr(self, "attention_sink_size") and + hasattr(self, "attention_sink_window_size")) + + reuse_cache = kwargs.get("reuse_cache") + if past_key_values is not None: + if token_idx is not None: + input_ids = torch.index_select(input_ids, 1, token_idx - 1) + else: + if isinstance(past_key_values, Cache): + cache_length = past_key_values.get_seq_length() + past_length = past_key_values.seen_tokens + max_cache_length = past_key_values.get_max_length() + else: + cache_length = past_length = past_key_values[0][0].shape[2] + max_cache_length = None + + # Keep only the unprocessed tokens: 
+ # 1 - If the length of the attention_mask exceeds the length of input_ids, + # then we are in a setting where + # some of the inputs are exclusively passed as part of the cache + # (e.g. when passing input_embeds as input) + if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]: + input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :] + # 2 - If the past_length is smaller than input_ids', + # then input_ids holds all input tokens. We can discard input_ids based on the past_length. + elif past_length < input_ids.shape[1]: + input_ids = input_ids[:, past_length:] + # 3 - Otherwise (past_length >= input_ids.shape[1]), + # let's assume input_ids only has unprocessed tokens. + + # If we are about to go beyond the maximum cache length, + # we need to crop the input attention mask. + if ( + max_cache_length is not None + and attention_mask is not None + and cache_length + input_ids.shape[1] > max_cache_length + ): + attention_mask = attention_mask[:, -max_cache_length:] # pylint: disable=E1130 + elif reuse_cache and token_idx is not None: + # With reuse_cache, KV cache is pre allocated hence for the 1st token + # we can slice the inputs till token idx for the fwd pass + input_ids = input_ids[:, :token_idx] + attention_mask = attention_mask[:, :token_idx] + + # prepare postion_ids and attention_mask for attention_sinks + cache_prune_num = 0 + kv_cache_len = kwargs.get("kv_cache_len", None) + position_ids = kwargs.get("position_ids", None) + q_len = input_ids.shape[-1] + if using_attention_sinks: + assert (kv_cache_len and kv_cache_len == self.attention_sink_size + self.attention_sink_window_size) + self.kv_past_token_length = min(self.kv_past_token_length, kv_cache_len) + position_ids = torch.arange(self.kv_past_token_length, + self.kv_past_token_length + q_len, + device=input_ids.device) + attn_sink_mask = torch.ones((q_len, kv_cache_len), device=input_ids.device) + if self.kv_past_token_length < kv_cache_len: + attn_sink_mask[:, self.kv_past_token_length:] = 0 + mask = torch.zeros((q_len, q_len), device=input_ids.device) + mask_cond = torch.arange(mask.size(-1), device=mask.device) + mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 1) + if self.kv_past_token_length + q_len > kv_cache_len: + cache_prune_num = (self.kv_past_token_length + q_len) - kv_cache_len + position_ids = position_ids - cache_prune_num + attn_sink_mask.index_copy_(-1, position_ids, mask) + attention_mask= attn_sink_mask[None, None, :, :].expand(input_ids.shape[0], 1, q_len, kv_cache_len) + position_ids = position_ids.unsqueeze(0) + self.kv_past_token_length += q_len + + if attention_mask is not None and position_ids is None and not using_attention_sinks: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if past_key_values: + if token_idx is not None: + position_ids = torch.index_select(position_ids, 1, token_idx - 1) + else: + position_ids = position_ids[:, -input_ids.shape[1] :] + # TODO: we are using token_idx, disable this for now + # if self.generation_config.cache_implementation == "static": + # generation with static cache + # cache_position = kwargs.get("cache_position", None) + # if cache_position is None: + # past_length = 0 + # else: + # past_length = cache_position[-1] + 1 + # input_ids = input_ids[:, past_length:] + # position_ids = position_ids[:, past_length:] + + # TODO @gante we should only keep a `cache_position` in generate, and do 
+=1. + # same goes for position ids. Could also help with continued generation. + # cache_position = torch.arange(past_length, past_length + position_ids.shape[-1], device=position_ids.device) + # keep cache_position implementation as None for HPU + cache_position = None + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and past_key_values is None: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise + # TODO: use `next_tokens` directly instead. + model_inputs = {"input_ids": input_ids.contiguous()} + + model_inputs.update( + { + "position_ids": position_ids.contiguous(), + "cache_position": cache_position, + "past_key_values": past_key_values, + "use_cache": kwargs.get("use_cache"), + "attention_mask": attention_mask, + "token_idx": token_idx, + "trim_logits": kwargs.get("trim_logits"), + "attn_softmax_bf16": kwargs.get("attn_softmax_bf16"), + "reuse_cache": reuse_cache, + "use_flash_attention": kwargs.get("use_flash_attention"), + "flash_attention_recompute": kwargs.get("flash_attention_recompute"), + "flash_attention_causal_mask": kwargs.get("flash_attention_causal_mask"), + "cache_idx": kwargs.get("cache_idx"), + "lazy_mode": kwargs.get("lazy_mode"), + "cache_prune_num": cache_prune_num, + } + ) + return model_inputs + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): + # Separate Attention Sink kwargs from regular kwargs + attention_sink_kwargs = {key: value for key, value in kwargs.items() if key.startswith("attention_sink")} + for key in attention_sink_kwargs: + v = kwargs.pop(key) + assert isinstance(v, int) + + model = super().from_pretrained( + pretrained_model_name_or_path, + *model_args, + **kwargs, + ) + + if len(attention_sink_kwargs) > 0: + from intel_extension_for_transformers.transformers.modeling.modeling_gaudi.streaming_llm \ + import enable_streaming_llm + + enable_streaming_llm(model, **attention_sink_kwargs) + model.attention_sink_size = attention_sink_kwargs.get("attention_sink_size") + model.attention_sink_window_size = attention_sink_kwargs.get("attention_sink_window_size") + model.model.attention_sink_size = attention_sink_kwargs.get("attention_sink_size") + model.model.attention_sink_window_size = attention_sink_kwargs.get("attention_sink_window_size") + + return model + + +def apply_customized_rope(q, k, cos, sin, position_ids): + if q.device.type == "hpu" and has_fused_rope: + # TODO: remove `.clone()` when it is fixed in SynapseAI + return FusedRoPE.apply( + q, cos.unsqueeze(0).unsqueeze(0).clone(), sin.unsqueeze(0).unsqueeze(0).clone(), position_ids + ), FusedRoPE.apply( + k, cos.unsqueeze(0).unsqueeze(0).clone(), sin.unsqueeze(0).unsqueeze(0).clone(), position_ids + ) + else: + + return apply_rotary_pos_emb(q, k, cos[position_ids], sin[position_ids]) # pylint: disable=E1120 diff --git a/intel_extension_for_transformers/transformers/kv_cache_compression/models/modeling_llama.py b/intel_extension_for_transformers/transformers/kv_cache_compression/models/modeling_llama.py new file mode 100644 index 00000000000..91b56a031ec --- /dev/null +++ b/intel_extension_for_transformers/transformers/kv_cache_compression/models/modeling_llama.py @@ -0,0 +1,974 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# +# Copyright (c) 2024 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file 
except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch llama model.""" + +import math +from typing import List, Optional, Tuple, Union +from functools import partial + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn import CrossEntropyLoss + +from transformers.cache_utils import Cache, StaticCache # pylint: disable=E0611 +from transformers.modeling_attn_mask_utils import AttentionMaskConverter +from transformers.utils import ( + add_start_docstrings_to_model_forward, + is_flash_attn_2_available, + is_flash_attn_greater_or_equal_2_10, + logging, + replace_return_docstrings, + ) +from transformers.models.llama.configuration_llama import LlamaConfig +from transformers.models.llama.modeling_llama import ( + rotate_half, + repeat_kv, + _get_unpad_data, + LlamaRotaryEmbedding, + LlamaLinearScalingRotaryEmbedding, + LlamaDynamicNTKScalingRotaryEmbedding, + LlamaPreTrainedModel, + LlamaModel, + LLAMA_INPUTS_DOCSTRING +) +from transformers.modeling_outputs import ( + CausalLMOutputWithPast, +) + +from ..prune import PruneConfig, H2OConfig + +logger = logging.get_logger(__name__) +_CONFIG_FOR_DOC = "LlamaConfig" + +from packaging import version +import transformers +if version.parse(transformers.__version__) > version.parse("4.33.0"): + from transformers.utils import is_flash_attn_greater_or_equal_2_10, is_flash_attn_2_available + if is_flash_attn_2_available(): + from flash_attn import ( # pylint: disable=E0401 + flash_attn_func, + flash_attn_varlen_func) # pylint: disable=E1101 + from flash_attn.bert_padding import ( # pylint: disable=E0401 + index_first_axis, + pad_input, + unpad_input) # pylint: disable=E1101 + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. 
+ """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + +class LlamaAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper.""" + + def __init__( + self, + config: LlamaConfig, + layer_idx: Optional[int] = None, + ): + super().__init__() + self.config = config + self.layer_idx = layer_idx + if self.layer_idx is None: + logger.warning_once( + f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will " + "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." + ) + + self.attention_dropout = config.attention_dropout + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.hidden_size // self.num_heads + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.max_position_embeddings = config.max_position_embeddings + self.rope_theta = config.rope_theta + self.is_causal = True + + if (self.head_dim * self.num_heads) != self.hidden_size: + raise ValueError( + f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" + f" and `num_heads`: {self.num_heads})." + ) + + self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias) + self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) + self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) + self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=config.attention_bias) + self._init_rope() + + self._init_func = [] + + def register_init_func(self, func): + self._init_func.append(func) + + def post_init(self): + for func in self._init_func: + func(self) + + def _init_rope(self): + if self.config.rope_scaling is None: + self.rotary_emb = LlamaRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + base=self.rope_theta, + ) + else: + scaling_type = self.config.rope_scaling["type"] + scaling_factor = self.config.rope_scaling["factor"] + if scaling_type == "linear": + self.rotary_emb = LlamaLinearScalingRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + scaling_factor=scaling_factor, + base=self.rope_theta, + ) + elif scaling_type == "dynamic": + self.rotary_emb = LlamaDynamicNTKScalingRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + scaling_factor=scaling_factor, + base=self.rope_theta, + ) + else: + raise ValueError(f"Unknown RoPE scaling type {scaling_type}") + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, # transformers.cache_utils.DynamicCache + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + if self.config.pretraining_tp > 1: + key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp + query_slices = self.q_proj.weight.split( + (self.num_heads 
* self.head_dim) // self.config.pretraining_tp, dim=0 + ) + key_slices = self.k_proj.weight.split(key_value_slicing, dim=0) + value_slices = self.v_proj.weight.split(key_value_slicing, dim=0) + + query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)] + query_states = torch.cat(query_states, dim=-1) + + key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)] + key_states = torch.cat(key_states, dim=-1) + + value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)] + value_states = torch.cat(value_states, dim=-1) + + else: + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + past_key_value = getattr(self, "past_key_value", past_key_value) + try: # pylint: disable=E1120 + cos, sin = self.rotary_emb(value_states, position_ids) + except: # for old version + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value[0].shape[-2] + position_length = kv_seq_len + if not position_ids.nelement() > 1: + if position_length < position_ids.item()+1: + position_length = position_ids.item()+1 + cos, sin = self.rotary_emb(value_states, seq_len=position_length) # pylint: disable=E1120 + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + + causal_mask = None + if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + # pruning kv cache + if self.pruner.real_drop: + if self.layer_idx == 0: + self.pruner.past_length += query_states.size(-2) + if past_key_value is not None: + new_key_states, new_value_states = self.pruner.prune( + self, + query_states, + key_states, + value_states, + causal_mask=causal_mask + ) + # reshape kv cache + if self.num_key_value_groups > 1: + n_rep = self.num_key_value_groups + drop_mask = torch.tensor( + [True if i % n_rep == 0 else False for i in range(0, new_key_states.size(1))] + ).repeat(new_key_states.size(0), 1).to(new_key_states.device) + new_key_states = new_key_states[drop_mask].view( + new_key_states.shape[0], + int(new_key_states.shape[1] / n_rep), + -1, + new_key_states.shape[-1]) + new_value_states = new_value_states[drop_mask].view( + new_value_states.shape[0], + int(new_value_states.shape[1] / n_rep), + -1, + new_value_states.shape[-1]) + + past_key_value.key_cache[self.layer_idx] = new_key_states + past_key_value.value_cache[self.layer_idx] = new_value_states + else: # similuate pruning to calculate acc + mask = 
self.pruner.get_mask(self, query_states, key_states, value_states, + causal_mask=causal_mask) + attn_weights[~mask] = torch.finfo(attn_weights.dtype).min + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) + + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + + if self.config.pretraining_tp > 1: + attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2) + o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1) + attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)]) + else: + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + +class LlamaFlashAttention2(LlamaAttention): + """Llama flash attention module. + + This module inherits from `LlamaAttention` as the weights of the module stays + untouched. The only required change would be on the forward pass where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any of them. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. + # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, + # that was made default for flash_attn>=2.1. This attribute is used to handle this difference. + # Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. + # Beware that with flash_attn<2.1, + # using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). 
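For the eager attention path above, simulation mode (`real_drop=False`) leaves the KV cache untouched and only masks out the positions the pruner would evict, by setting their scores to the dtype minimum before softmax; accuracy can then be measured without changing memory use. A self-contained illustration of that masking step (toy shapes and values, not taken from the PR):
```python
import torch

# Toy pre-softmax scores: batch=1, heads=1, q_len=1, kv_len=6
attn_weights = torch.tensor([[[[0.3, 1.2, -0.5, 0.8, 0.1, 2.0]]]])

# Pretend the pruner decided to keep only key positions 1, 3 and 5
keep = torch.zeros_like(attn_weights, dtype=torch.bool)
keep[..., [1, 3, 5]] = True

# Same step as the simulation branch: evicted positions get the dtype minimum
attn_weights[~keep] = torch.finfo(attn_weights.dtype).min
probs = torch.softmax(attn_weights.float(), dim=-1)

print(probs)  # evicted positions contribute ~0 probability after softmax
```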
+ self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + output_attentions = False + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + # Flash attention requires the input to have the shape + # batch_size x seq_length x head_dim x hidden_dim + # therefore we just need to keep the original shape + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + cos, sin = self.rotary_emb(value_states, position_ids) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + past_key_value = getattr(self, "past_key_value", past_key_value) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + dropout_rate = self.attention_dropout if self.training else 0.0 + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in the correct dtype just to be sure everything works as expected. + # This might slowdown training & inference so it is recommended to not cast the LayerNorms + # in fp32. (LlamaRMSNorm handles it correctly) + + input_dtype = query_states.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." 
+ ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + + if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + + # h2o + # pruning kv cache + if self.layer_idx == 0: + self.pruner.past_length += query_states.size(-2) + if past_key_value is not None: + new_key_states, new_value_states = self.pruner.prune( + self, + query_states, + key_states, + value_states, + causal_mask=causal_mask + ) + + past_key_value.key_cache[self.layer_idx] = new_key_states + past_key_value.value_cache[self.layer_idx] = new_value_states + + attn_output = self._flash_attention_forward( + query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate + ) + + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous() + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + def _flash_attention_forward( + self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None + ): + """ + Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token + first unpad the input, then computes the attention scores and pad the final attention scores. + + Args: + query_states (`torch.Tensor`): + Input query states to be passed to Flash Attention API + key_states (`torch.Tensor`): + Input key states to be passed to Flash Attention API + value_states (`torch.Tensor`): + Input value states to be passed to Flash Attention API + attention_mask (`torch.Tensor`): + The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the + position of padding tokens and 1 for the position of non-padding tokens. + dropout (`float`): + Attention dropout + softmax_scale (`float`, *optional*): + The scaling of QK^T before applying softmax. 
Default to 1 / sqrt(head_dim) + """ + if not self._flash_attn_uses_top_left_mask: + causal = self.is_causal + else: + causal = self.is_causal and query_length != 1 + + # Contains at least one padding token in the sequence + if attention_mask is not None: + batch_size = query_states.shape[0] + query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input( + query_states, key_states, value_states, attention_mask, query_length + ) + + cu_seqlens_q, cu_seqlens_k = cu_seq_lens + max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens + + attn_output_unpad = flash_attn_varlen_func( + query_states, + key_states, + value_states, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_in_batch_q, + max_seqlen_k=max_seqlen_in_batch_k, + dropout_p=dropout, + softmax_scale=softmax_scale, + causal=causal, + ) + + attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length) + else: + attn_output = flash_attn_func( + query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal + ) + + return attn_output + + def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length): + indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask) + batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape + + key_layer = index_first_axis( + key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k + ) + value_layer = index_first_axis( + value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k + ) + if query_length == kv_seq_len: + query_layer = index_first_axis( + query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k + ) + cu_seqlens_q = cu_seqlens_k + max_seqlen_in_batch_q = max_seqlen_in_batch_k + indices_q = indices_k + elif query_length == 1: + max_seqlen_in_batch_q = 1 + cu_seqlens_q = torch.arange( + batch_size + 1, dtype=torch.int32, device=query_layer.device + ) # There is a memcpy here, that is very bad. + indices_q = cu_seqlens_q[:-1] + query_layer = query_layer.squeeze(1) + else: + # The -q_len: slice assumes left padding. + attention_mask = attention_mask[:, -query_length:] + query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask) + + return ( + query_layer, + key_layer, + value_layer, + indices_q, + (cu_seqlens_q, cu_seqlens_k), + (max_seqlen_in_batch_q, max_seqlen_in_batch_k), + ) + +class LlamaSdpaAttention(LlamaAttention): + """Llama attention module using torch.nn.functional.scaled_dot_product_attention. + + This module inherits from + `LlamaAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to + SDPA API. + """ + + # Adapted from LlamaAttention.forward + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if output_attentions: + logger.warning_once( + "LlamaModel is using LlamaSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` " + "does not support `output_attentions=True`. 
Falling back to the manual attention implementation, " + 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards.' + '"This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + return super().forward( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + ) + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + try: + cos, sin = self.rotary_emb(value_states, position_ids) # pylint: disable=E1120 + except: + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value[0].shape[-2] + position_length = kv_seq_len + if not position_ids.nelement() > 1: + if position_length < position_ids.item()+1: + position_length = position_ids.item()+1 + cos, sin = self.rotary_emb(value_states, seq_len=position_length) # pylint: disable=E1120 + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + # In case static cache is used, it is an instance attribute. + past_key_value = getattr(self, "past_key_value", past_key_value) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + causal_mask = attention_mask + if attention_mask is not None: + causal_mask = causal_mask[:, :, :, : key_states.shape[-2]] + + if query_states.device.type == "cuda" and causal_mask is not None: + query_states = query_states.contiguous() + key_states = key_states.contiguous() + value_states = value_states.contiguous() + + # pruning kv cache + if self.layer_idx == 0: + self.pruner.past_length += query_states.size(-2) + if past_key_value is not None: + new_key_states, new_value_states = self.pruner.prune( + self, + query_states, + key_states, + value_states, + causal_mask=causal_mask + ) + # reshape kv cache + if self.num_key_value_groups > 1: + n_rep = self.num_key_value_groups + drop_mask = torch.tensor( + [True if i % n_rep == 0 else False for i in range(0, new_key_states.size(1))] + ).repeat(new_key_states.size(0), 1).to(new_key_states.device) + new_key_states = new_key_states[drop_mask].view( + new_key_states.shape[0], + int(new_key_states.shape[1] / n_rep), + -1, + new_key_states.shape[-1]) + new_value_states = new_value_states[drop_mask].view( + new_value_states.shape[0], + int(new_value_states.shape[1] / n_rep), + -1, + new_value_states.shape[-1]) + + past_key_value.key_cache[self.layer_idx] = new_key_states + past_key_value.value_cache[self.layer_idx] = new_value_states + + # In case we are not compiling, we may set `causal_mask` to None, + # which is required to dispatch to SDPA's Flash Attention 2 backend, 
rather + # relying on the `is_causal` argument. + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=causal_mask, + dropout_p=self.attention_dropout if self.training else 0.0, + is_causal=causal_mask is None and q_len > 1, + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.view(bsz, q_len, self.hidden_size) + + attn_output = self.o_proj(attn_output) + # if self.layer_idx == 0: + # print(attn_output.shape) + + return attn_output, None, past_key_value + +class LlamaForCausalLM(LlamaPreTrainedModel): + _tied_weights_keys = ["lm_head.weight"] + + def __init__( + self, + config: LlamaConfig, + prune_config: PruneConfig, + ): + super().__init__(config) + self.model = LlamaModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + if isinstance(prune_config, H2OConfig): + from ..prune import H2OKVPruner + self.pruner = H2OKVPruner(prune_config) + else: + from ..prune import KVPruner + self.pruner = KVPruner(prune_config) + + num_layers = len(self.model.layers) + for layer_idx in range(num_layers): + module = self.model.layers[layer_idx].self_attn + cls_name = module.__class__.__name__ + if not prune_config.real_drop: + cls = LlamaAttention + elif cls_name == "LlamaFlashAttention2": + cls = LlamaFlashAttention2 + elif cls_name == "LlamaSdpaAttention": + cls = LlamaSdpaAttention + else: + cls = LlamaAttention + self.model.layers[layer_idx].self_attn = cls( + config, + layer_idx + ) + self.model.layers[layer_idx].self_attn.register_init_func(self.pruner.self_attn_init) + self.model.layers[layer_idx].self_attn.post_init() + + self.model.layers[layer_idx].self_attn.pruner = self.pruner + + # Initialize weights and apply final processing + self.post_init() + + def _generate(*args, **kwargs): + self.pruner.before_generate(self, *args, **kwargs) + result = self.ori_generate(*args, **kwargs) + self.pruner.after_generate(self, *args, **kwargs) + return result + + self.ori_generate = self.generate + self.generate = _generate + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + ) -> Union[Tuple, CausalLMOutputWithPast]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). 
Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, LlamaForCausalLM + + >>> model = LlamaForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf") + >>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf") + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." + ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + ) + + hidden_states = outputs[0] + if self.config.pretraining_tp > 1: + lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0) + logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)] + logits = torch.cat(logits, dim=-1) + else: + logits = self.lm_head(hidden_states) + logits = logits.float() + + loss = None + if labels is not None: + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + shift_logits = shift_logits.view(-1, self.config.vocab_size) + shift_labels = shift_labels.view(-1) + # Enable model parallelism + shift_labels = shift_labels.to(shift_logits.device) + loss = loss_fct(shift_logits, shift_labels) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + cache_position=None, + use_cache=True, + **kwargs, + ): + past_length = 0 + if past_key_values is not None: + if isinstance(past_key_values, Cache): + past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length() + max_cache_length = ( + torch.tensor(past_key_values.get_max_length(), device=input_ids.device) + if past_key_values.get_max_length() is not None + else None + ) + cache_length = past_length if max_cache_length is None else torch.min(max_cache_length, past_length) + # TODO joao: remove this `else` after `generate` prioritizes `Cache` objects + else: + cache_length = past_length = past_key_values[0][0].shape[2] + max_cache_length = None + + # cache_length = 
past_length = input_ids.shape[-1] - 1 + cache_length = past_length = self.pruner.past_length + max_cache_length = None + # Keep only the unprocessed tokens: + # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where + # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as input) + if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]: + input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :] + # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard + # input_ids based on the past_length. + elif past_length < input_ids.shape[1]: + input_ids = input_ids[:, past_length:] + # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens. + + # If we are about to go beyond the maximum cache length, we need to crop the input attention mask. + if ( + max_cache_length is not None + and attention_mask is not None + and cache_length + input_ids.shape[1] > max_cache_length + ): + attention_mask = attention_mask[:, -max_cache_length:] # pylint: disable=E1130 + + position_ids = kwargs.get("position_ids", None) + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if past_key_values: + position_ids = position_ids[:, -input_ids.shape[1] :] + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and past_key_values is None: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise + # recompiles graphs as the stride of the inputs is a guard. + # Ref: https://github.com/huggingface/transformers/pull/29114 + # TODO: use `next_tokens` directly instead. + model_inputs = {"input_ids": input_ids.contiguous()} + + input_length = position_ids.shape[-1] if position_ids is not None else input_ids.shape[-1] + if cache_position is None: + cache_position = torch.arange(past_length, past_length + input_length, device=input_ids.device) + elif use_cache: + cache_position = cache_position[-input_length:] + + model_inputs.update( + { + "position_ids": position_ids, + "cache_position": cache_position, + "past_key_values": past_key_values, + "use_cache": use_cache, + "attention_mask": attention_mask, + } + ) + return model_inputs + + def _update_causal_mask( + self, + attention_mask: torch.Tensor, + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, + output_attentions: bool, + ): + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and 0.0 in attention_mask: + return attention_mask + return None + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. 
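Because `real_drop` physically shrinks the KV cache, the cache length no longer tells us how many tokens have been processed, which is why `prepare_inputs_for_generation` above overrides `past_length` with `self.pruner.past_length`. A purely illustrative walk-through of that bookkeeping (hypothetical numbers, no Transformers objects involved):
```python
# Editor's sketch of the past-length bookkeeping used above.
heavy_budget, recent_budget = 20, 20
cache_size = heavy_budget + recent_budget   # pruned KV cache is capped at 40 entries

past_length = 0          # what self.pruner.past_length tracks
kv_cache_entries = 0     # what past_key_values actually holds after pruning

prompt_len = 100
past_length += prompt_len
kv_cache_entries = min(prompt_len, cache_size)

for _ in range(5):       # five decode steps
    past_length += 1
    kv_cache_entries = min(kv_cache_entries + 1, cache_size)

print(past_length, kv_cache_entries)  # 105 vs. 40: cache length != processed tokens
```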
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + using_static_cache = isinstance(past_key_values, StaticCache) + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + if self.config._attn_implementation == "sdpa" \ + and not using_static_cache and not output_attentions: # pylint: disable=E1101 + if AttentionMaskConverter._ignore_causal_mask_sdpa( # pylint: disable=E1101 + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + is_training=self.training, + ): + return None + + dtype, device = input_tensor.dtype, input_tensor.device + min_dtype = torch.finfo(dtype).min + sequence_length = input_tensor.shape[1] + if using_static_cache: + target_length = past_key_values.get_max_length() + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 + ) + + if attention_mask is not None and attention_mask.dim() == 4: + # in this case we assume that the mask comes already in inverted form and requires no inversion or slicing + if attention_mask.max() != 0: + raise ValueError("Custom 4D attention mask should be passed in inverted form with max==0`") + causal_mask = attention_mask + else: + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device + ) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type == "cuda" + and not output_attentions + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. + # Details: https://github.com/pytorch/pytorch/issues/110213 + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) # pylint: disable=E1120 + + return causal_mask diff --git a/intel_extension_for_transformers/transformers/kv_cache_compression/prune/__init__.py b/intel_extension_for_transformers/transformers/kv_cache_compression/prune/__init__.py new file mode 100644 index 00000000000..7f9ede223e8 --- /dev/null +++ b/intel_extension_for_transformers/transformers/kv_cache_compression/prune/__init__.py @@ -0,0 +1,16 @@ +# Copyright (c) 2024 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +from .base import PruneConfig, KVPruner +from .h2o import H2OConfig, H2OKVPruner diff --git a/intel_extension_for_transformers/transformers/kv_cache_compression/prune/base.py b/intel_extension_for_transformers/transformers/kv_cache_compression/prune/base.py new file mode 100644 index 00000000000..4cc8f617657 --- /dev/null +++ b/intel_extension_for_transformers/transformers/kv_cache_compression/prune/base.py @@ -0,0 +1,57 @@ +# Copyright (c) 2024 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import torch + +class PruneConfig(dict): + def __init__(self, real_drop=True): + self.real_drop = real_drop + +class KVPruner: + def __init__(self, prune_config) -> None: + self._past_length = 0 + self.prune_kv_cache_size = None + + def self_attn_init(self, module): + pass + + def prune(self, module, query_states, key_states, value_states, **kwargs): + pass + + def before_generate(self, model, inputs, *args, **kwargs): + self.past_length = 0 + + def after_generate(self, model, inputs, *args, **kwargs): + pass + + def get_mask(self, model, **kwargs): + pass + + @property + def past_length(self): + return self._past_length + + @past_length.setter + def past_length(self, value): + self._past_length = value + + def remove_repeat_kv(self, kv_tensor, n_rep): + if n_rep == 1: + return kv_tensor + drop_mask = torch.tensor( + [True if i % n_rep == 0 else False for i in range(0, kv_tensor.size(1))] + ).repeat(kv_tensor.size(0), 1).to(kv_tensor.device) + new_shape = list(kv_tensor.shape) + new_shape[1] = int(new_shape[1] / n_rep) + kv_tensor = kv_tensor[drop_mask].view(new_shape) + return kv_tensor diff --git a/intel_extension_for_transformers/transformers/kv_cache_compression/prune/h2o.py b/intel_extension_for_transformers/transformers/kv_cache_compression/prune/h2o.py new file mode 100644 index 00000000000..ecb2f8ca3fe --- /dev/null +++ b/intel_extension_for_transformers/transformers/kv_cache_compression/prune/h2o.py @@ -0,0 +1,269 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# +# Copyright (c) 2024 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
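The `remove_repeat_kv` helper in the base pruner above (and the equivalent inline `drop_mask` code in the attention classes) undoes `repeat_kv` by keeping every `n_rep`-th head, so pruned keys/values can be written back into a cache that stores only `num_key_value_heads` heads. A standalone sketch of the same boolean-mask trick on random data (all names are local to this example):
```python
import torch

bsz, num_kv_heads, n_rep, seq_len, head_dim = 2, 4, 2, 8, 16
kv = torch.randn(bsz, num_kv_heads, seq_len, head_dim)

# repeat_kv-style expansion: (bsz, kv_heads, s, d) -> (bsz, kv_heads * n_rep, s, d)
expanded = kv[:, :, None].expand(bsz, num_kv_heads, n_rep, seq_len, head_dim)
expanded = expanded.reshape(bsz, num_kv_heads * n_rep, seq_len, head_dim)

# Same selection as remove_repeat_kv: keep one copy out of every n_rep repeated heads
drop_mask = torch.tensor(
    [i % n_rep == 0 for i in range(expanded.size(1))]
).repeat(expanded.size(0), 1)
collapsed = expanded[drop_mask].view(bsz, num_kv_heads, seq_len, head_dim)

assert torch.equal(collapsed, kv)  # the original per-KV-head tensors are recovered
```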
+import math +import torch +import torch.nn as nn + +from .base import KVPruner, PruneConfig + +def local_heavy_hitter_mask(attn_weights, heavy_budget, no_padding_seq_length=None): + + # attn_weights (BS, head, query, keys) + dtype_attn_weights = attn_weights.dtype + seq_length = attn_weights.shape[-1] + if no_padding_seq_length is None: + padding_length = 0 + else: + padding_length = seq_length - no_padding_seq_length + + offset = torch.finfo(attn_weights.dtype).min + tmp_attn = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(dtype_attn_weights) + + accumulated_attention_score = torch.sum( + tmp_attn[:,:,padding_length:heavy_budget+padding_length,:], dim=-2) #(head, keys) + accumulated_attention_score[:,:,heavy_budget+padding_length:] = 0 + accumulated_attention_score[:,:,:padding_length] = 0 + + mask_bottom = torch.zeros_like(attn_weights, dtype=torch.bool) + mask_bottom[:,:, padding_length:heavy_budget+padding_length, + padding_length:heavy_budget+padding_length] = True + for token_index in range(heavy_budget+padding_length, seq_length): + + tmp_attn_index = nn.functional.softmax( + attn_weights[:,:,token_index,:], dim=-1, dtype=torch.float32).to(dtype_attn_weights) + _, tmp_topk_index = accumulated_attention_score.topk(k=heavy_budget-1, dim=-1) + zeros_index = torch.zeros_like(tmp_attn_index, dtype=torch.bool) + mask_bottom_index = zeros_index.scatter(-1, tmp_topk_index, True) #(head, keys) + mask_bottom_index[:,:, token_index] = True + mask_bottom[:,:,token_index,:] = mask_bottom_index + accumulated_attention_score += tmp_attn_index + accumulated_attention_score = accumulated_attention_score * mask_bottom_index + + mask_bottom = torch.tril(mask_bottom, diagonal=0) + + return mask_bottom + + +def get_hh_mask(heavy_budget_ratio, recent_budget_ratio, attn_weights, local=True): + heavy_budget = int(heavy_budget_ratio * attn_weights.shape[-1]) + recent_budget = int(recent_budget_ratio * attn_weights.shape[-1]) + if heavy_budget > 0: + # Default: No padding applied to input + if local: + mask_bottom = local_heavy_hitter_mask(attn_weights, heavy_budget, None) + else: + tmp_attn = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(attn_weights.dtype) + tmp_sum = torch.sum(tmp_attn, dim=-2) + _, tmp_topk = tmp_sum.topk(k=heavy_budget, dim=-1) + + zeros = torch.zeros_like(tmp_sum, dtype=torch.bool) + mask_bottom = zeros.scatter(-1, tmp_topk, True).unsqueeze(2) + mask_bottom = mask_bottom.expand( + mask_bottom.shape[0], + mask_bottom.shape[1], + attn_weights.shape[-2], + mask_bottom.shape[-1]) + else: + mask_bottom = torch.zeros_like(attn_weights, dtype=torch.bool) + + # Recent Mask + ones = torch.ones_like(attn_weights, dtype=torch.bool) + ones = torch.triu(ones, diagonal=-recent_budget) + mask_bottom = torch.logical_or(mask_bottom, ones) + # Combine h2o+recent and apply casual mask + mask_bottom = torch.tril(mask_bottom, diagonal=0) + + return mask_bottom + +class H2OConfig(PruneConfig): + def __init__( + self, + heavy_ratio: float = None, + recent_ratio: float = None, + heavy_budget: int = None, + recent_budget: int = None, + h2o_min_seqlen: int = -1, + real_drop: bool = True, + mean: bool = False, + local: bool = True + ): + super().__init__() + self.heavy_ratio = heavy_ratio + self.recent_ratio = recent_ratio + self.heavy_budget = heavy_budget + self.recent_budget = recent_budget + self.h2o_min_seqlen = h2o_min_seqlen + self.real_drop = real_drop + self.mean = mean + self.local = local + + +class H2OKVCache: + def __init__( + self, + heavy_ratio=0.2, + 
recent_ratio=0.2, + heavy_budget=None, + recent_budget=None, + min_seqlen=-1, + mean=False + ): + ## bsz, num_heads, seq_len, head_dim + assert 0 <= heavy_ratio <= 1 and 0 <= recent_ratio <= 1, "ratio should be in [0, 1]" + assert heavy_budget is None or heavy_budget >= 0, "heavy_budget should be non-negative" + assert recent_budget is None or recent_budget >= 0, "recent_budget should be non-negative" + self.heavy_ratio = heavy_ratio + self.recent_ratio = recent_ratio + self.heavy_budget = heavy_budget + self.recent_budget = recent_budget + self.hh_score = None + self.min_seqlen = min_seqlen + self.mean = mean + self.idx = 0 + + def __call__(self, attn_score, key_states, value_states, **kwargs): + seq_len = key_states.size(-2) + if self.heavy_budget is None: + self.heavy_budget = int(self.heavy_ratio * seq_len) + if self.recent_budget is None: + self.recent_budget = int(self.recent_ratio * seq_len) + cache_size = self.heavy_budget + self.recent_budget + if seq_len <= self.min_seqlen or seq_len <= cache_size: + return key_states, value_states + self.idx += 1 + # attn_score shape (bsz, num_heads, seq_len, head_dim) + if len(attn_score.shape) == 3: + attn_score = attn_score.unsqueeze(0) + if len(attn_score.shape) == 5: + attn_score = attn_score.reshape( + attn_score.shape[0], + attn_score.shape[1] * attn_score.shape[2], + attn_score.shape[3], + attn_score.shape[4] + ) + self._update_hh_score(attn_score, mean=self.mean) + + # hh-selection + mask = torch.zeros(self.hh_score.shape, dtype=attn_score.dtype).to(key_states.device) + if not self.recent_budget == 0: + mask[:,:,-self.recent_budget:] = 1 # pylint: disable=E1130 + select_hh_scores = self.hh_score[:,:,:seq_len - self.recent_budget] + + if not self.heavy_budget == 0: + _, keep_topk = torch.topk(select_hh_scores, self.heavy_budget, dim=-1, largest=True) + mask = mask.scatter(-1, keep_topk, 1) + + mask = mask.bool() + self.hh_score = self.hh_score[mask].view(self.hh_score.shape[0], self.hh_score.shape[1], cache_size) + + # if use repeat_kv, need to reshape mask + n_rep = mask.size(1) / key_states.size(1) + if n_rep > 1: + drop_mask = torch.tensor( + [True if i % n_rep == 0 else False for i in range(0, mask.size(1))] + ).repeat(mask.size(0), 1).to(mask.device) + mask = mask[drop_mask].view(key_states.shape[:-1]) + + k_hh_recent = key_states[mask].view(key_states.shape[0], key_states.shape[1], cache_size, -1) + v_hh_recent = value_states[mask].view(value_states.shape[0], value_states.shape[1], cache_size, -1) + + return k_hh_recent, v_hh_recent + + def _update_hh_score(self, attn_score_cache, mean=False): + # attn_score_cache (bsz, num_heads, seq_len, head_dim) + # hh_score size (bsz, num_heads, head_dim) + + attn_score_cache = attn_score_cache.sum(-2) + if self.hh_score is not None: + # clean self.hh_score if not generation mode + if attn_score_cache.size(-1) < self.hh_score.size(-1): + self.clean_scores() + if not mean: + attn_score_cache[:, :, :self.hh_score.shape[-1]] += self.hh_score + else: + attn_score_cache[:,:,:self.hh_score.shape[-1]] = attn_score_cache[:,:,:self.hh_score.shape[-1]] \ + * (self.idx - 1) + self.hh_score + attn_score_cache /= self.idx + + self.hh_score = attn_score_cache + + def clean_scores(self): + self.idx = 0 + self.hh_score = None + + +class H2OKVPruner(KVPruner): + def __init__(self, config: H2OConfig) -> None: + self.config = config + self.real_drop = self.config.real_drop + self.prune_kv_cache_size = None + + + def self_attn_init(self, module): + module.h2o_kv_cache = H2OKVCache( + self.config.heavy_ratio, + 
+            self.config.recent_ratio,
+            self.config.heavy_budget,
+            self.config.recent_budget,
+            self.config.h2o_min_seqlen,
+            self.config.mean
+        )
+
+    def before_generate(self, model, inputs, *args, **kwargs):
+        assert self.real_drop is True, 'H2O only supports real_drop mode when using the generate function.'
+        self.past_length = 0
+        if kwargs.get('max_new_tokens', None):
+            max_length = kwargs['max_new_tokens'] + inputs.size(-1)
+        elif kwargs.get('max_length', None):
+            max_length = kwargs['max_length']
+        else:
+            max_length = model.config.max_length
+        if max_length <= inputs.size(-1):
+            max_length += inputs.size(-1)
+        for _, module in model.named_modules():
+            if "Attention" in module.__class__.__name__:
+                if module.h2o_kv_cache.heavy_budget is None:
+                    module.h2o_kv_cache.heavy_budget = round(max_length * module.h2o_kv_cache.heavy_ratio)
+                if module.h2o_kv_cache.recent_budget is None:
+                    module.h2o_kv_cache.recent_budget = round(max_length * module.h2o_kv_cache.recent_ratio)
+                if self.prune_kv_cache_size is None:
+                    self.prune_kv_cache_size = module.h2o_kv_cache.recent_budget + module.h2o_kv_cache.heavy_budget
+
+    def after_generate(self, model, inputs, *args, **kwargs):
+        for _, module in model.named_modules():
+            if "Attention" in module.__class__.__name__:
+                module.h2o_kv_cache.clean_scores()
+        self.prune_kv_cache_size = None
+
+    def prune(self, module, query_states, key_states, value_states, causal_mask=None, **kwargs):
+        attn_weights = torch.matmul(query_states, key_states.transpose(-2, -1)) / math.sqrt(module.head_dim)
+        if causal_mask is not None:  # no matter the length, we just slice it
+            attn_weights = attn_weights + causal_mask
+        if not self.config.real_drop:
+            module.h2o_kv_cache.clean_scores()
+        return module.h2o_kv_cache(attn_weights, key_states, value_states, **kwargs)
+
+    def get_mask(self, module, query_states, key_states, value_states, causal_mask=None, **kwargs):
+        attn_weights = torch.matmul(query_states, key_states.transpose(-2, -1)) / math.sqrt(module.head_dim)
+        if causal_mask is not None:  # no matter the length, we just slice it
+            attn_weights = attn_weights + causal_mask
+        mask = get_hh_mask(
+            self.config.heavy_ratio,
+            self.config.recent_ratio,
+            attn_weights,
+            local=self.config.local)
+        return mask
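To make the eviction policy concrete, the sketch below runs the two entry points defined above on random scores: `get_hh_mask` returns the keep-mask used in simulation mode (heavy-hitter columns plus a recent band, under a causal constraint), and `H2OKVCache` performs the real drop, shrinking the key/value tensors to `heavy_budget + recent_budget` entries. The import path is an assumption based on the file location in this PR; shapes and values are synthetic.
```python
import torch

# Assumed import path, mirroring the new file
# intel_extension_for_transformers/transformers/kv_cache_compression/prune/h2o.py
from intel_extension_for_transformers.transformers.kv_cache_compression.prune.h2o import (
    H2OKVCache,
    get_hh_mask,
)

torch.manual_seed(0)
bsz, num_heads, seq_len, head_dim = 1, 2, 12, 4
attn_scores = torch.randn(bsz, num_heads, seq_len, seq_len)   # raw pre-softmax scores
key_states = torch.randn(bsz, num_heads, seq_len, head_dim)
value_states = torch.randn(bsz, num_heads, seq_len, head_dim)

# Simulation mode: boolean keep-mask (heavy hitters + recent window, causal)
keep_mask = get_hh_mask(0.25, 0.25, attn_scores, local=True)
print(keep_mask.shape)           # torch.Size([1, 2, 12, 12])

# Real-drop mode: KV tensors are physically shrunk to heavy + recent budgets
kv_cache = H2OKVCache(heavy_ratio=0.25, recent_ratio=0.25)
new_k, new_v = kv_cache(attn_scores, key_states, value_states)
print(new_k.shape, new_v.shape)  # seq dim reduced from 12 to 6 (3 heavy + 3 recent)
```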