-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathstarchat_inf.py
69 lines (55 loc) · 1.98 KB
/
starchat_inf.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import os
import json
import torch
import logging
from tqdm import tqdm
from transformers import pipeline
from prompt import sp1, sp2
# Timestamped INFO-level logging for progress and error messages.
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
# Directory layout (relative to the current working directory):
#   INPUT_DIR  - listed by main() for the input JSON files to process
#   RESULT_DIR - receives "starchat_<name>" output files
#   ERROR_DIR  - receives "st__<name>" dumps when processing a file raises
INPUT_DIR = "data"
RESULT_DIR = "results"
ERROR_DIR = "error"
# Output directories are created at import time so the later writes in
# process_file() do not fail on a missing directory.
os.makedirs(RESULT_DIR, exist_ok=True)
os.makedirs(ERROR_DIR, exist_ok=True)
logging.info("Loading text generation model...")
# The model is loaded once, at import time, and the same pipeline object is
# reused for every generation call. Weights are loaded in bfloat16;
# device_map="auto" leaves device placement to transformers (presumably
# needs `accelerate` installed - TODO confirm for the target environment).
pipe = pipeline(
    "text-generation",
    model="HuggingFaceH4/starchat-beta",
    torch_dtype=torch.bfloat16,
    device_map="auto"
)
# Chat-style prompt with an empty system turn. `{query}` is the only format
# field and is filled via str.format, so braces inside the substituted text
# are safe; literal braces added to this template itself would need doubling.
# NOTE(review): the special tokens presumably follow StarChat's dialogue
# format - verify against the model card.
PROMPT_TEMPLATE = "<|system|>\n<|end|>\n<|user|>\n{query}<|end|>\n<|assistant|>"
def get_valid_files(directory):
    """Return the names of the visible entries in *directory*, sorted.

    Names starting with '.' (hidden files) or '~' (typically editor or
    office lock/backup files) are skipped.
    """
    skipped_prefixes = ('.', '~')
    visible_names = (
        name for name in os.listdir(directory)
        if not name.startswith(skipped_prefixes)
    )
    return sorted(visible_names)
def _build_prompt(query, element):
    """Return the chat prompt that asks *query* about one pair of snippets.

    ``element`` must be indexable by ``'codeA'`` and ``'codeB'``. The user
    message is the query followed by the two snippets under "Code A:" /
    "Code B:" headers, wrapped in ``PROMPT_TEMPLATE``.
    """
    user_message = (
        f"{query}\n"
        "Code A:\n"
        f"{element['codeA']}\n"
        "Code B:\n"
        f"{element['codeB']}\n"
    )
    return PROMPT_TEMPLATE.format(query=user_message)


def _write_json(path, data):
    """Write *data* to *path* as 4-space-indented JSON."""
    with open(path, "w", encoding="utf-8") as f:
        json.dump(data, f, indent=4)


def process_file(file_path, output_path, error_path, query):
    """Run the model on every code pair in a JSON file and save the results.

    The input is iterated element by element; each element must provide
    'codeA' and 'codeB' (presumably the file holds a JSON list of such
    objects - confirm against the data). Every element gets a "Result" key
    holding the raw pipeline output, and the annotated data is written to
    *output_path*.

    Failures are logged, not raised, so a batch over many files keeps going:
      * input cannot be opened/decoded/parsed -> logged, nothing written;
      * generation or output writing raises -> the partially annotated data
        (elements processed so far carry "Result") is written to *error_path*.

    Args:
        file_path: Input JSON file.
        output_path: Destination for the fully annotated data.
        error_path: Destination for partial data when processing fails.
        query: Instruction text placed before each code pair.
    """
    try:
        with open(file_path, "r", encoding="utf-8") as f:
            data = json.load(f)
    except (OSError, ValueError) as e:
        # Nothing was loaded, so there is no partial data to dump; dumping
        # here would reference an unbound `data` and crash the whole batch.
        logging.error(f"Error processing {file_path}: {e}")
        return

    try:
        for element in tqdm(data, desc=f"Processing {os.path.basename(file_path)}"):
            element["Result"] = pipe(
                _build_prompt(query, element),
                max_new_tokens=4000,
                do_sample=True,
                temperature=0.3,
                top_k=50,
                top_p=0.95,
                # Presumably the id of StarChat's <|end|> turn terminator
                # (matches the model-card example) - verify against the tokenizer.
                eos_token_id=49155,
                num_return_sequences=1,
            )
        _write_json(output_path, data)
    except Exception as e:
        # Preserve whatever was generated before the failure.
        logging.exception(f"Error processing {file_path}: {e}")
        _write_json(error_path, data)
def main(query=None):
    """Compare every code pair in every visible file under INPUT_DIR.

    Args:
        query: Instruction text placed before each code pair in the prompt.
            Defaults to ``sp1`` from the ``prompt`` module; pass ``sp2`` to
            run the alternative instruction instead.

    Each input ``<name>`` is passed to process_file() with the output path
    ``RESULT_DIR/starchat_<name>`` and the error path ``ERROR_DIR/st__<name>``.
    """
    if query is None:
        # Resolved at call time rather than as a default argument value.
        # A hard-coded placeholder here would send the bare word "prompt"
        # to the model instead of the real instruction.
        query = sp1
    for file_name in get_valid_files(INPUT_DIR):
        file_path = os.path.join(INPUT_DIR, file_name)
        output_path = os.path.join(RESULT_DIR, f"starchat_{file_name}")
        error_path = os.path.join(ERROR_DIR, f"st__{file_name}")
        logging.info(f"Processing file: {file_name}")
        process_file(file_path, output_path, error_path, query)
# Run the batch only when executed as a script. Note that merely importing
# this module still loads the model, because the pipeline is created at
# module level.
if __name__ == "__main__":
    main()