-
Notifications
You must be signed in to change notification settings - Fork 1
/
merge_lora.py
43 lines (36 loc) · 1.24 KB
/
merge_lora.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
import torch
from peft import PeftModel,PeftConfig
from transformers import AutoModel,AutoTokenizer,AutoConfig
import argparse
def main():
    """Merge a LoRA adapter into its base model and save the result.

    Command-line flags (all required, all strings):
        --base_model_path    location passed to ``AutoModel`` and
                             ``AutoTokenizer`` to load the base checkpoint.
        --lora_model_path    location passed to ``PeftModel`` to load the
                             adapter on top of the base model.
        --merged_model_path  destination for both the tokenizer and the
                             merged model.

    The base model and the adapter are both loaded with
    ``device_map='auto'`` and ``torch_dtype=torch.float16``; the base model
    and tokenizer are loaded with ``trust_remote_code=True``.
    """
    cli = argparse.ArgumentParser()
    for flag in ("--base_model_path", "--lora_model_path", "--merged_model_path"):
        cli.add_argument(flag, default=None, type=str, required=True)
    args = cli.parse_args()
    print(args)

    # The same placement/precision options are used for both load calls.
    load_opts = {"device_map": 'auto', "torch_dtype": torch.float16}

    print("loading lora model")
    backbone = AutoModel.from_pretrained(
        args.base_model_path,
        trust_remote_code=True,
        **load_opts,
    )
    tok = AutoTokenizer.from_pretrained(
        args.base_model_path,
        trust_remote_code=True,
    )
    adapted = PeftModel.from_pretrained(
        backbone,
        args.lora_model_path,
        **load_opts,
    )
    adapted.eval()

    print('merging...')
    # merge_and_unload() hands back the model that is written out below.
    merged = adapted.merge_and_unload()

    print('saving...')
    # Tokenizer first, then weights, both into the same output location.
    tok.save_pretrained(args.merged_model_path)
    merged.save_pretrained(args.merged_model_path)
    print("done")
# Run the merge only when executed as a script, not when imported.
if __name__ == '__main__':
    main()