h2o for kv cache compression #1468

Status: Open — n1ck-guo wants to merge 84 commits into main from hengguo/h2o

Commits (84)
41d8647  h2o for kv cache compression (n1ck-guo, Apr 10, 2024)
eb7f564  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Apr 10, 2024)
c46ea7d  rebuild (BiaoFangAIA, Apr 23, 2024)
95ff9ae  merge (BiaoFangAIA, Apr 23, 2024)
9d27733  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Apr 23, 2024)
444490d  update (n1ck-guo, Apr 25, 2024)
4309089  update (n1ck-guo, Apr 25, 2024)
8c5272e  merge (n1ck-guo, Apr 25, 2024)
1b83e52  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Apr 25, 2024)
3fd73cb  update (n1ck-guo, May 7, 2024)
a2d3ae0  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], May 7, 2024)
ddf5445  Merge branch 'main' into hengguo/h2o (VincyZhang, May 13, 2024)
91d4394  Merge branch 'main' into hengguo/h2o (n1ck-guo, May 14, 2024)
a83e6d6  real drop (n1ck-guo, May 14, 2024)
92c8a62  modify real drop code (n1ck-guo, May 15, 2024)
70a1cf3  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], May 15, 2024)
bc9eade  fix (BiaoFangAIA, May 16, 2024)
9aa25f6  update for real drop and sim mode, using the same api (n1ck-guo, May 16, 2024)
03cdc8d  Merge branch 'main' into hengguo/h2o (n1ck-guo, May 16, 2024)
e51b5b9  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], May 16, 2024)
b8e9df2  support for sdpa and flash attention (n1ck-guo, May 16, 2024)
02f31b2  Merge branch 'hengguo/h2o' of https://github.com/intel/intel-extensio… (n1ck-guo, May 16, 2024)
274b7ed  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], May 16, 2024)
877329d  change to new api (n1ck-guo, May 17, 2024)
b435e3e  Merge branch 'hengguo/h2o' of https://github.com/intel/intel-extensio… (n1ck-guo, May 17, 2024)
5068552  clean code (n1ck-guo, May 17, 2024)
5e5f589  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], May 17, 2024)
d0dce7d  fix (n1ck-guo, May 20, 2024)
febb76a  Merge branch 'hengguo/h2o' of https://github.com/intel/intel-extensio… (n1ck-guo, May 20, 2024)
24c4725  add example (n1ck-guo, May 20, 2024)
5bd3f16  Merge branch 'main' into hengguo/h2o (n1ck-guo, May 20, 2024)
955e132  clean (n1ck-guo, May 20, 2024)
91efe57  pylint (n1ck-guo, May 20, 2024)
4190edb  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], May 20, 2024)
e71bf92  pylint (n1ck-guo, May 20, 2024)
41f016c  pylint (n1ck-guo, May 20, 2024)
d49487f  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], May 20, 2024)
9ac5eca  fix import error (n1ck-guo, May 21, 2024)
09def0b  update (n1ck-guo, May 21, 2024)
5cae1fd  Merge branch 'hengguo/h2o' of https://github.com/intel/intel-extensio… (n1ck-guo, May 21, 2024)
3042dd4  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], May 21, 2024)
3a992ab  pylint (n1ck-guo, May 21, 2024)
8c89cbc  pylint (n1ck-guo, May 21, 2024)
072ad76  merge (n1ck-guo, May 21, 2024)
4c26487  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], May 21, 2024)
741c7cd  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], May 21, 2024)
0800df2  add example readme (n1ck-guo, May 27, 2024)
558dfd9  update (n1ck-guo, May 27, 2024)
d6de2b3  merge (n1ck-guo, May 27, 2024)
693983f  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], May 27, 2024)
30bed25  fix acc bug (n1ck-guo, Jun 6, 2024)
f8a64fc  fix (n1ck-guo, Jun 7, 2024)
d892e74  Merge branch 'hengguo/h2o' of https://github.com/intel/intel-extensio… (n1ck-guo, Jun 11, 2024)
7a12ec6  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Jun 11, 2024)
124bb72  refactor code (n1ck-guo, Jun 18, 2024)
5181afc  fix (n1ck-guo, Jun 18, 2024)
93ad39b  merge (n1ck-guo, Jun 18, 2024)
58bfcd0  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Jun 18, 2024)
76656c9  Merge branch 'main' into hengguo/h2o (n1ck-guo, Jun 18, 2024)
91c5f3c  new api (n1ck-guo, Jun 20, 2024)
4884b3a  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Jun 20, 2024)
812a838  support for gaudi (n1ck-guo, Jun 24, 2024)
2cc6a8f  merge (n1ck-guo, Jun 24, 2024)
2d82bb5  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Jun 24, 2024)
a9488b4  update (n1ck-guo, Jun 25, 2024)
3c185ad  Merge branch 'hengguo/h2o' of https://github.com/intel/intel-extensio… (n1ck-guo, Jun 25, 2024)
a6d3fc7  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Jun 25, 2024)
9698230  pylint (n1ck-guo, Jun 26, 2024)
14f5a6d  pylint (n1ck-guo, Jun 27, 2024)
dd6ee3c  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Jun 27, 2024)
523ca76  Merge branch 'main' into hengguo/h2o (n1ck-guo, Jun 28, 2024)
2618e6f  Merge branch 'main' into hengguo/h2o (n1ck-guo, Jul 2, 2024)
cedcd43  Merge branch 'main' into hengguo/h2o (changwangss, Jul 3, 2024)
eb8441c  Merge branch 'main' into hengguo/h2o (n1ck-guo, Jul 12, 2024)
0894b6d  Merge branch 'main' into hengguo/h2o (n1ck-guo, Jul 15, 2024)
d241c25  add desc to h2o in readme (n1ck-guo, Jul 15, 2024)
0c547c5  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Jul 15, 2024)
3723158  add doc for h2o (n1ck-guo, Jul 15, 2024)
7da0cf5  Merge branch 'hengguo/h2o' of https://github.com/intel/intel-extensio… (n1ck-guo, Jul 15, 2024)
d600112  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Jul 15, 2024)
b1ab771  update (n1ck-guo, Jul 16, 2024)
4bee0f0  add ut (n1ck-guo, Jul 16, 2024)
3fad641  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Jul 16, 2024)
46fa4d7  Merge branch 'main' into hengguo/h2o (XuehaoSun, Jul 16, 2024)
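
For context: H2O (Heavy-Hitter Oracle, Zhang et al., NeurIPS 2023) compresses the KV cache by retaining only the "heavy hitter" entries — the tokens that have received the largest accumulated attention — plus a window of the most recent tokens. The heavy_ratio and recent_ratio arguments in the scripts below control those two budgets. The following is a minimal sketch of that eviction policy only; it is not this PR's implementation, and every name in it is illustrative.

    import torch

    def h2o_keep_mask(attn_weights: torch.Tensor, heavy_ratio: float, recent_ratio: float) -> torch.Tensor:
        """Illustrative H2O policy: keep heavy-hitter plus recent KV entries.

        attn_weights: (num_heads, q_len, kv_len) post-softmax attention.
        Returns a (num_heads, kv_len) bool mask of KV positions to retain.
        """
        num_heads, _, kv_len = attn_weights.shape
        n_heavy = max(1, int(kv_len * heavy_ratio))
        n_recent = max(1, int(kv_len * recent_ratio))

        # Accumulated attention each KV position has received: the heavy-hitter score.
        scores = attn_weights.sum(dim=1)                      # (num_heads, kv_len)
        heavy = torch.topk(scores, n_heavy, dim=-1).indices

        mask = torch.zeros(num_heads, kv_len, dtype=torch.bool)
        mask.scatter_(1, heavy, True)                          # keep heavy hitters
        mask[:, -n_recent:] = True                             # always keep the recent window
        return mask

    # Example: 4 heads, 16 queries over 16 KV slots; keep ~20% heavy + ~20% recent.
    w = torch.softmax(torch.randn(4, 16, 16), dim=-1)
    print(h2o_keep_mask(w, heavy_ratio=0.2, recent_ratio=0.2).sum(dim=1))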
@@ -0,0 +1,74 @@
import argparse
import json

from lm_eval import evaluator, tasks
from tasks import EvalHarnessAdaptor


if __name__ == '__main__':

    parser = argparse.ArgumentParser(
        description='Dump the lm-eval-harness requests for a task to a JSONL file.')

    parser.add_argument('--output-file', type=str, default='input.jsonl')
    parser.add_argument('--task-name', type=str, default='hellaswag')
    parser.add_argument('--num-fewshot', type=int, default=0)
    args = parser.parse_args()

    seq = 1024
    total_batch = 1
    pe = 'fixed'

    # Truncate the output file before the dry run appends to it.
    with open(args.output_file, 'w') as f:
        pass

    class DryRunner:
        """Fake runner: records each harness request instead of scoring it."""
        def eval(self, batch):
            with open(args.output_file, 'a') as f:
                for text in batch['text']:
                    item = {
                        "best_of": 1,
                        "echo": True,
                        "logprobs": 1,
                        "max_tokens": 0,
                        "model": "x",
                        "n": 1,
                        "prompt": text,
                        "request_type": "language-model-inference",
                        "stop": None,
                        "temperature": 0,
                        "top_p": 1
                    }
                    f.write(json.dumps(item) + '\n')

            # Dummy scores so the harness keeps iterating.
            out = {
                'mask_loss': [1.0] * len(batch),
                'each_correct': [True] * len(batch),
            }
            return out

    t = DryRunner()
    adaptor = EvalHarnessAdaptor(t, seq, total_batch, shrink=pe != "fixed")
    results = evaluator.evaluate(adaptor, tasks.get_task_dict([args.task_name
        # "lambada_openai",
        # "piqa",
        # "hellaswag",
        # "winogrande",
        # "mathqa",
        # "pubmedqa",
        # "boolq",
        # "cb",
        # "copa",
        # "multirc",
        # "record",
        # "wic",
        # "wsc",
    ]), False, args.num_fewshot, None)
    print('Finished')

    # dumped = json.dumps(results, indent=2)
    # print(dumped)
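
This first file is a dry-run driver: it replays an lm-eval-harness task through a fake runner and appends every request the harness would issue, one JSON object per line, to --output-file. A quick sanity check of the dump, assuming the script's default output name input.jsonl:

    import json

    with open('input.jsonl') as f:
        requests = [json.loads(line) for line in f]
    print(len(requests), 'requests; first prompt:', requests[0]['prompt'][:80])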
@@ -0,0 +1,221 @@
import argparse
import json, tqdm
import torch
import copy
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"

import sys
# Developer-local checkout paths; adjust or remove for your environment.
sys.path.insert(0, '/home/hengguo/code/intel-extension-for-transformers')
sys.path.insert(0, '/root/hengguo/intel-extension-for-transformers')

from lm_eval import evaluator, tasks
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig

from intel_extension_for_transformers.transformers.modeling.kv_cache_compression.h2o_sim_drop.modify_llama import convert_kvcache_llama_heavy_recent
from intel_extension_for_transformers.transformers.modeling.kv_cache_compression.h2o_sim_drop.modify_opt import convert_kvcache_opt_heavy_recent
from intel_extension_for_transformers.transformers.modeling.kv_cache_compression.h2o_sim_drop.modify_gptneox import convert_kvcache_gpt_neox_heavy_recent

from tasks import EvalHarnessAdaptor

ENABLE_Heavy_Hitter_FUNCTIONS = {
    "llama": convert_kvcache_llama_heavy_recent,
    "opt": convert_kvcache_opt_heavy_recent,
    "gpt_neox": convert_kvcache_gpt_neox_heavy_recent,
}

if __name__ == '__main__':

    parser = argparse.ArgumentParser(
        description='Evaluate an H2O-converted model on lm-eval-harness tasks.')
    parser.add_argument("--tasks", nargs='+', default=["lambada_openai",
                        "hellaswag", "winogrande", "piqa", "wikitext"],
                        type=str, help="tasks list for accuracy validation")
    parser.add_argument('--num_fewshot', type=int, default=0)

    parser.add_argument('--enable_small_cache', action='store_true')
    parser.add_argument('--model_name', type=str, default='facebook/opt-350m')
    parser.add_argument("--cache_dir", type=str, default=None)

    parser.add_argument("--heavy_ratio", type=float, default=0.1)
    parser.add_argument("--recent_ratio", type=float, default=0.1)
    parser.add_argument("--device", type=str, default='cpu')
    parser.add_argument("--seq_len", type=int, default=1024)

    parser.add_argument('--debug', action='store_true')
    args = parser.parse_args()

    batch_size = 1
    pe = 'fixed'
    seq = args.seq_len

    # build data: collect the harness requests with a dry run
    requests = []
    class DryRunner:
        def eval(self, batch):
            for text in batch['text']:
                item = {
                    "best_of": 1,
                    "echo": True,
                    "logprobs": 1,
                    "max_tokens": 0,
                    "model": "x",
                    "n": 1,
                    "prompt": text,
                    "request_type": "language-model-inference",
                    "stop": None,
                    "temperature": 0,
                    "top_p": 1
                }
                requests.append(item)
            out = {
                'mask_loss': [1.0] * len(batch),
                'each_correct': [True] * len(batch),
            }
            return out
    t = DryRunner()
    adaptor = EvalHarnessAdaptor(t, seq, batch_size, shrink=pe != "fixed")
    result = evaluator.evaluate(adaptor, tasks.get_task_dict(args.tasks), False, args.num_fewshot, None)

    model_name = args.model_name
    if 'cpu' in args.device:
        device = args.device
    else:
        device = f"cuda:{args.device}"

    config = AutoConfig.from_pretrained(model_name, cache_dir=args.cache_dir)
    tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=args.cache_dir)
    model = AutoModelForCausalLM.from_pretrained(model_name, cache_dir=args.cache_dir)

    if args.enable_small_cache:
        print('Enable Small Cache Size')
        # checkpoint = copy.deepcopy(model.state_dict())
        # model = ENABLE_Heavy_Hitter_FUNCTIONS[args.model_type](model, config)
        from intel_extension_for_transformers.transformers.modeling.kv_cache_compression import convert_model
        model = convert_model(model, heavy_ratio=args.heavy_ratio, recent_ratio=args.recent_ratio, h2o_min_seqlen=0)
        # model.load_state_dict(checkpoint)

    model = model.to(device)
    print('using device: ', device)
    model.eval()
    # model.half().eval()

    # Score every cached request: per-token logprobs and greedy top-1 tokens.
    results = []
    with torch.no_grad():
        for request in tqdm.tqdm(requests):
            result = {'request': request, 'result': {}}
            prompt = request['prompt']
            input_ids = tokenizer(prompt, add_special_tokens=False, return_tensors='pt').input_ids.to(model.device)

            logits = model(input_ids).logits.log_softmax(dim=-1)

            values, indices = logits.squeeze(0).topk(dim=-1, k=1)
            tokens = tokenizer.convert_ids_to_tokens(input_ids.squeeze(0))

            gold_indices = input_ids[:, 1:]  # skip first
            logprobs = [None] + torch.gather(logits, -1, gold_indices.unsqueeze(-1)).squeeze(-1).squeeze(0).detach().cpu().tolist()
            top_logprobs = [None] + [{tokenizer.convert_ids_to_tokens(i.item()): v.item()} for v, i in zip(values.squeeze(-1), indices.squeeze(-1))]

            result['result'] = {
                "choices": [
                    {
                        "text": prompt,
                        "logprobs": {
                            "tokens": tokens,
                            "token_logprobs": logprobs,
                            "top_logprobs": top_logprobs,
                            "text_offset": []
                        },
                        "finish_reason": "length"
                    }
                ],
                "request_time": {
                    "batch_time": 0,
                    "batch_size": 1}
            }

            results.append(result)

    # evaluate: answer the harness from the cached model responses
    class RealRunner:
        def __init__(self, args):
            self.results = {}
            for item in results:
                request = item['request']
                result = item['result']
                self.results[json.dumps(request)] = result
            print(f"{len(self.results)} items in the cache")

        def eval(self, batch):
            from tasks.eval_harness import tokenizer
            mask_loss = []
            each_correct = []
            for i, text in enumerate(batch['text']):
                request = {
                    "best_of": 1,
                    "echo": True,
                    "logprobs": 1,
                    "max_tokens": 0,
                    "model": "x",
                    "n": 1,
                    "prompt": text,
                    "request_type": "language-model-inference",
                    "stop": None,
                    "temperature": 0,
                    "top_p": 1
                }

                key = json.dumps(request)
                correct = True

                if key in self.results:
                    result = self.results[key]
                    token_logprobs = result['choices'][0]['logprobs']['token_logprobs']
                    tokens = result['choices'][0]['logprobs']['tokens']
                    top_logprobs = result['choices'][0]['logprobs']['top_logprobs']
                    assert token_logprobs[0] is None
                    token_ids = tokenizer.convert_tokens_to_ids(tokens)
                    obs = batch['obs'][i]
                    target = batch['target'][i]
                    eval_mask = batch['eval_mask'][i]

                    n_positive = 0
                    sum_logprob = 0
                    if args.debug:
                        print(target)
                    # Score only the masked (target) positions; position j
                    # predicts token j+1, hence the off-by-one indexing.
                    for j, mask in enumerate(eval_mask):
                        if j + 1 >= len(tokens):
                            break
                        if mask:
                            if args.debug:
                                print(tokens[j + 1], next(iter(top_logprobs[j + 1].keys())))
                            correct = correct and (tokens[j + 1] == next(iter(top_logprobs[j + 1].keys())))
                            sum_logprob += token_logprobs[j + 1]
                            n_positive += 1
                    # avg_logprob = sum(token_logprobs[1:]) / (len(token_logprobs) - 1)
                    avg_logprob = sum_logprob / n_positive
                    mask_loss.append(-avg_logprob)
                    each_correct.append(correct)
                else:
                    assert False, "request not found in cached results"

            out = {
                'mask_loss': mask_loss,
                'each_correct': each_correct,
            }
            return out

    t = RealRunner(args)

    adaptor = EvalHarnessAdaptor(t, seq, batch_size, shrink=pe != "fixed")
    results = evaluator.evaluate(adaptor, tasks.get_task_dict(args.tasks), False, args.num_fewshot, None)

    dumped = json.dumps(results, indent=2)
    print(dumped)
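
This second file runs in three phases: a dry run that collects the harness requests, a scoring pass that caches per-token logprobs from the (optionally H2O-converted) model, and a replay through RealRunner that answers the harness from that cache. To exercise the conversion step on its own, something like the sketch below should work; it assumes only the convert_model signature already used above and is not the PR's documented usage:

    from transformers import AutoModelForCausalLM
    from intel_extension_for_transformers.transformers.modeling.kv_cache_compression import convert_model

    # Keep ~10% heavy-hitter plus ~10% recent KV entries; compress from the first token on.
    model = AutoModelForCausalLM.from_pretrained('facebook/opt-350m')
    model = convert_model(model, heavy_ratio=0.1, recent_ratio=0.1, h2o_min_seqlen=0)
    model.eval()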
@@ -0,0 +1 @@
from tasks.eval_harness import EvalHarnessAdaptor