inference.py

"""Generate text from a pretrained PaLM checkpoint, compiled with the hidet backend."""

import argparse

import torch
import hidet
from transformers import AutoTokenizer
from einops._torch_specific import allow_ops_in_compiled_graph
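
# Dependencies (an assumption inferred from the imports above, not stated in
# the original file): torch, hidet, transformers, and einops, e.g.
#   pip install torch hidet transformers einops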


def main():
    # Allow einops operations to be traced inside torch.compile graphs.
    allow_ops_in_compiled_graph()

    # Skip torch.hub's forked-repo validation so the checkpoint loads without
    # extra GitHub API calls (which can be rate limited).
    torch.hub._validate_not_a_forked_repo = lambda a, b, c: True

    parser = argparse.ArgumentParser(description="Generate text using a PaLM model")
    parser.add_argument("prompt", type=str, help="Text prompt to condition generation on")
    parser.add_argument(
        "--seq_len", type=int, default=256, help="Sequence length for the generated text"
    )
    parser.add_argument(
        "--temperature", type=float, default=0.8, help="Sampling temperature"
    )
    parser.add_argument(
        "--filter_thres", type=float, default=0.9, help="Filter threshold for sampling"
    )
    parser.add_argument(
        "--model",
        type=str,
        default="palm_1b_8k_v0",
        help="Model checkpoint to use for generation",
    )
    parser.add_argument(
        "--dtype",
        type=str,
        default="fp32",
        help="Data type for the model: 'bf16' or 'fp32'",
    )
    args = parser.parse_args()

    # Default to fp32; bf16 reduces memory use on GPUs that support it.
    dtype = torch.float32
    if args.dtype == "bf16":
        dtype = torch.bfloat16

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Load the pretrained checkpoint from torch.hub and put it in eval mode.
    model = torch.hub.load("conceptofmind/PaLM", args.model).to(device).to(dtype).eval()

    # Configure hidet's dynamo backend (tensor cores on, wider kernel search
    # space) and compile the model with it.
    hidet.torch.dynamo_config.use_tensor_core(True)
    hidet.torch.dynamo_config.search_space(2)
    opt_model = torch.compile(model, backend="hidet")

    # Tokenize the prompt with the GPT-NeoX-20B tokenizer used by this script.
    tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
    encoded_text = tokenizer(args.prompt, return_tensors="pt")
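
    # Optional warm-up (a sketch, not part of the original script): a single
    # forward pass triggers hidet's kernel search up front rather than during
    # generation. Uncomment to enable:
    # with torch.no_grad():
    #     opt_model(encoded_text["input_ids"].to(device))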

    output_tensor = opt_model.generate(
        seq_len=args.seq_len,
        prompt=encoded_text["input_ids"].to(device),
        temperature=args.temperature,
        filter_thres=args.filter_thres,
        pad_value=0.0,
        eos_token=tokenizer.eos_token_id,
        return_seq_without_prompt=False,
        use_tqdm=True,
    )

    decoded_output = tokenizer.batch_decode(output_tensor, skip_special_tokens=True)
    return decoded_output


if __name__ == "__main__":
    generated_text = main()
    for text in generated_text:
        print(text)
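
# Example invocation (a sketch; flag names mirror the argparse definitions
# above, and bf16 assumes a CUDA GPU with bfloat16 support):
#
#   python inference.py "Your prompt here" --seq_len 256 --temperature 0.8 --dtype bf16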