transgenerator_translation.py
# Select the GPU to use
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

import pickle
import argparse
import torch
from transformers import GPT2Tokenizer, GPT2LMHeadModel, GPT2Config
def generation(model, tokenizer, condition):
    """Greedily generate a continuation of `condition` until EOS or the 1024-token context limit."""
    inp = torch.tensor(tokenizer.encode(condition)).unsqueeze(0)
    inp = inp.to("cuda")
    with torch.no_grad():
        for x in range(1024 - len(inp[0])):  # stop generation at the model's max context length
            outputs = model(inp)
            predictions = outputs[0]
            # Greedy decoding: take the highest-scoring next token
            new = torch.tensor([[torch.argmax(predictions[0, -1, :]).item()]])
            new = new.to("cuda")
            inp = torch.cat((inp, new), 1)
            if new[0][0].item() == 50256:  # EOS token (<|endoftext|>)
                break
    # Return only the newly generated tokens, stripping the conditioning prefix
    predicted_text = tokenizer.decode(inp.tolist()[0][len(tokenizer.encode(condition)):])
    return predicted_text
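# A minimal usage sketch for generation() (illustrative only; the checkpoint
# path and the conditioning string below are hypothetical placeholders):
#
#   tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
#   model = GPT2LMHeadModel.from_pretrained("gpt2")
#   model.load_state_dict(torch.load("model_files/checkpoint.pt"))
#   model.to("cuda")
#   model.eval()
#   print(generation(model, tokenizer, "source sentence ===="))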
# For the new data format: keep only examples that split into exactly three
# "====" - delimited fields, and rebuild the conditioning prefix from the first two
def split_train_2(train_data):
    continueSet = []
    for x in range(len(train_data)):
        train_data[x] = train_data[x].split("====")
        if len(train_data[x]) == 3:
            continueSet.append(train_data[x][0] + "====" + train_data[x][1] + "====")
    return continueSet
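# For example (the three-field layout is assumed from the "====" delimiter convention):
#   split_train_2(["src ==== ref ==== tgt"]) -> ["src ==== ref ===="]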
# Split each example on "====" and re-append the delimiter to the first field
def split_train(train_data):
    for x in range(len(train_data)):
        train_data[x] = train_data[x].split("====")
        train_data[x][0] = train_data[x][0] + "===="
    return train_data
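# For example (input format again assumed from the delimiter convention):
#   split_train(["src ==== tgt"]) -> [["src ====", " tgt"]]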
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--text_path", default="data_files/test.txt", type=str)
    parser.add_argument("--model_path", type=str)
    # argparse's type=bool treats any non-empty string as True, so use a flag instead
    parser.add_argument("--test_data", default=False, action="store_true")
    parser.add_argument("--n_data", default=50, type=int)
    parser.add_argument("--save_path", default="data_files/test.p", type=str)
    parser.add_argument("--n_layers", default=12, type=int)
    args = parser.parse_args()

    if args.n_layers == 12:
        # Standard 12-layer GPT-2 architecture
        model = GPT2LMHeadModel.from_pretrained('gpt2')
    else:
        # Custom depth: build the architecture from a config instead
        config = GPT2Config(n_layer=args.n_layers)
        model = GPT2LMHeadModel(config)
    model.load_state_dict(torch.load(args.model_path))
    model.to("cuda")
    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

    if args.test_data:
        data = pickle.load(open(args.text_path, "rb"))
        # Clamp n_data to the number of available conditions in data[0]
        if len(data[0]) < args.n_data:
            args.n_data = len(data[0])
    else:
        with open(args.text_path, "r", encoding="utf-8") as f:
            data = f.read()
        data = data.split("<|endoftext|>")
        data = split_train_2(data)
        if len(data) < args.n_data:
            args.n_data = len(data)

    predictions = []
    if args.test_data:
        for x in range(args.n_data):
            pred = generation(model, tokenizer, data[0][x])
            predictions.append(pred)
            print(x)
    else:
        for x in range(args.n_data):
            print(data[x])
            pred = generation(model, tokenizer, data[x])
            predictions.append(pred)
            print(x)
    pickle.dump(predictions, open(args.save_path, "wb"))

if __name__ == '__main__':
    main()
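# Example invocation (a sketch; the checkpoint and output paths are placeholders):
#   python transgenerator_translation.py --model_path model_files/checkpoint.pt \
#       --text_path data_files/test.txt --n_data 50 --save_path data_files/predictions.p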