-
Notifications
You must be signed in to change notification settings - Fork 0
/
test_llama2_7b_torch_musa.py
105 lines (91 loc) · 4.02 KB
/
test_llama2_7b_torch_musa.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
import torch
import torch_musa
import time
from numpy import percentile
from modelscope import snapshot_download, AutoModelForCausalLM, AutoTokenizer, GenerationConfig
# ---------------------------------------------------------------------------
# Benchmark configuration and model loading.
# ---------------------------------------------------------------------------
# Device string that every input tensor below is moved to.
device = 'musa'
# Local directory the tokenizer, weights and generation config are read from.
model_dir = '/data/models/llama-2-7b-chat-hf-fp16/'
# Number of sequences processed together in the profiled pass.
batch_size = 48
# Length (in tokens) of each random profiling prompt; the profiled pass also
# runs this many single-token decode steps.
num_tokens = 256

# The tokenizer and the model are loaded with the same keyword arguments.
_load_options = {
    "device_map": "musa",
    "trust_remote_code": True,
    "torch_dtype": torch.float16,
}
tokenizer = AutoTokenizer.from_pretrained(model_dir, **_load_options)
model = AutoModelForCausalLM.from_pretrained(model_dir, **_load_options)
model.generation_config = GenerationConfig.from_pretrained(model_dir)
print('#----------------------------warmup start---------------------------')
# Warmup: greedily decode from one chat-style prompt before the timed section
# (presumably so one-time initialization cost is not counted in the profile).
warmup_steps = 256
prompts = "[Round 1]\n\n问:如何获得同事的认可\n\n答:\n"
with torch.no_grad():
    # Batch of one: wrap the token id list in an outer list.
    input_ids = torch.LongTensor([tokenizer(prompts).input_ids]).to(device)
    output_ids = []
    past_key_values = None
    # The first pass consumes the whole prompt with an empty cache; each of
    # the remaining `warmup_steps` passes feeds back only the token that was
    # just selected, reusing the cache returned by the previous pass.
    for _ in range(1 + warmup_steps):
        output = model(
            input_ids, use_cache=True, past_key_values=past_key_values)
        past_key_values = output.past_key_values
        logits = output.logits
        # Greedy choice: highest-scoring token at the last position.
        res = logits[:, -1, :].argmax(dim=-1, keepdim=True)
        output_ids.append(res.item())
        input_ids = res
print("output:", tokenizer.decode(output_ids).strip())
print('#----------------------------warmup end -----------------------------')
# ----------------------------profile start---------------------------
# Timed pass: one prefill over a random (batch_size, num_tokens) prompt batch,
# followed by num_tokens greedy single-token decode steps.
# To sweep prompt lengths, this section can be wrapped in a loop over
# e.g. [128, 256, 512, 1024, 2048, 3072]; `num_tokens` is deliberately left
# unmodified below so that repeated runs stay correct.


def _greedy_next_token(step_input, cache):
    """Run one forward pass and return (next_token_ids, updated_cache).

    next_token_ids keeps a trailing dimension of size 1 (keepdim=True) so it
    can be fed straight back in as the next step's input.
    """
    output = model(step_input, use_cache=True, past_key_values=cache)
    next_ids = torch.argmax(output.logits[:, -1, :], dim=-1, keepdim=True)
    return next_ids, output.past_key_values


with torch.no_grad():
    # Random token ids in [10240, 20480).
    # To benchmark with real text instead, tokenize a list of `batch_size`
    # prompts and move the resulting LongTensor to `device`.
    input_ids = torch.randint(
        10240, (batch_size, num_tokens)).to(device) + 10240

    # perf_counter is monotonic, which is what interval timing needs.
    st = time.perf_counter()
    generated = []

    # Prefill: process the whole prompt with an empty cache.
    prefill_start = time.perf_counter()
    res, past_key_values = _greedy_next_token(input_ids, None)
    # NOTE(review): the device-to-host copy is kept inside the timed region,
    # as in the original; it presumably also forces pending device work to
    # finish before the clock is read -- confirm for the MUSA backend.
    generated.append(res.cpu())
    input_ids = res
    prefill_end = time.perf_counter()
    prefill_time = prefill_end - prefill_start

    # Decode: one new token per sequence per step, reusing the cache.
    for _ in range(num_tokens):
        res, past_key_values = _greedy_next_token(input_ids, past_key_values)
        generated.append(res.cpu())
        input_ids = res
    decode_time = time.perf_counter() - prefill_end

consume_time = time.perf_counter() - st

# Assemble the (batch_size, 1 + num_tokens) result once, outside the timed
# region, instead of re-concatenating a growing tensor on every step.
output_ids = torch.cat(generated, dim=1)
# To inspect the results, decode each row:
#     for output_id in output_ids:
#         print(f"output: {tokenizer.decode(output_id.tolist())}")

# Tokens produced by the decode loop across the whole batch.
total_tokens = batch_size * num_tokens
fps = total_tokens / decode_time
print(f'generate token fps: {fps:.3f} tokens/s')
print(
    f'prefill latency :{prefill_time:.3f} s, single one batch prefill latency:{prefill_time/batch_size:.3f} s')
print(
    f'decode latency :{decode_time:.3f} s, single one batch decode latency:{decode_time/batch_size:.3f} s')
print(f'end-to-end time:{consume_time:.3f} s, num_tokens:{total_tokens}')