utils.py
import json

import model_init
from transformers import pipeline
from langchain import HuggingFacePipeline


def read_json(path):
    """Load and return the contents of a JSON file."""
    with open(path, 'r') as fr:
        return json.load(fr)


def llm_init_langchain(config, max_new_tokens, seed):
    """Initialize a LangChain LLM from a model config dict.

    Supports a GPT-4 Turbo deployment via model_init.gpt, and local
    Llama 2/3 or Mixtral checkpoints wrapped in a HuggingFace
    text-generation pipeline. `seed` is only used by the GPT branch.
    """
    if config['model_type'] == 'gpt4-turbo-128k':
        llm = model_init.gpt(config['env_file_path'], config['deployment_name'],
                             config['model_version'], max_new_tokens, seed)
    elif config['model_type'] in ('llama-2-chat-70b', 'llama-3-instruct-70b'):
        # Load the Llama checkpoint in 4-bit to fit a 70B model in memory.
        model, tokenizer = model_init.llama(config['model_path'], load_in_4bit=True)
        text_pipeline = pipeline(task="text-generation",
                                 model=model,
                                 tokenizer=tokenizer,
                                 max_new_tokens=max_new_tokens,
                                 do_sample=False)  # greedy decoding
        llm = HuggingFacePipeline(pipeline=text_pipeline)
    elif config['model_type'] == 'Mixtral-8x7B-Instruct-v0.1':
        model, tokenizer = model_init.mixtral(config['model_path'], load_in_4bit=True)
        text_pipeline = pipeline(task="text-generation",
                                 model=model,
                                 tokenizer=tokenizer,
                                 max_new_tokens=max_new_tokens,
                                 do_sample=False)  # greedy decoding
        llm = HuggingFacePipeline(pipeline=text_pipeline)
    else:
        raise ValueError('Model type {} not supported'.format(config['model_type']))
    return llm
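
For reference, a minimal usage sketch (not part of the file): it assumes a config JSON whose keys match those read above, and that the returned LangChain LLM exposes the standard invoke interface. The file path and config values are hypothetical.

# Usage sketch under the assumptions above; 'configs/mixtral.json' is a
# hypothetical path, and its contents are inferred from the keys this
# module reads.
config = read_json('configs/mixtral.json')
# e.g. {"model_type": "Mixtral-8x7B-Instruct-v0.1",
#       "model_path": "/models/Mixtral-8x7B-Instruct-v0.1"}

llm = llm_init_langchain(config, max_new_tokens=256, seed=42)
output = llm.invoke("Summarize the following text: ...")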