forked from harvardnlp/seq2seq-attn
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdata_handle.py
79 lines (75 loc) · 2 KB
/
data_handle.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
# -*- coding: utf-8 -*-
import os
import sys
import argparse
import numpy as np
import h5py
import itertools
from collections import defaultdict
import matplotlib.pyplot as plt
import time
def _parse_log(lines):
    """Collect the four perplexity series from the lines of a training log.

    Nothing is collected until the first line that is exactly the header
    ``Normal flow result:`` followed by a tab and a newline.  Scanning stops
    at the first non-blank line whose first whitespace-separated token is
    ``Train`` (this check applies before the header has been seen, too).

    After the header, non-blank lines are consumed in a repeating cycle of
    six positions, with the header itself at position 0:

    * position 1: token 10 -> ``ppl`` (the line is also echoed to stdout)
    * position 2: token 13 -> ``joint_ppl``
    * position 4: token 10 -> ``ppl_en_en``
    * position 5: token 13 -> ``ppl_de_de``

    Lines at positions 0 and 3 are not read.  Each selected token has its
    last character dropped (presumably a trailing separator such as ','
    -- TODO confirm against the log format) before conversion to float.

    Args:
        lines: iterable of log lines, each still carrying its newline.

    Returns:
        Tuple ``(ppl, joint_ppl, ppl_en_en, ppl_de_de)`` of lists of floats.

    Raises:
        IndexError: a data line has fewer tokens than expected.
        ValueError: a selected token is not numeric.
    """
    ppl, joint_ppl, ppl_en_en, ppl_de_de = [], [], [], []
    # cycle position -> (destination list, whitespace-token index)
    columns = {
        1: (ppl, 10),
        2: (joint_ppl, 13),
        4: (ppl_en_en, 10),
        5: (ppl_de_de, 13),
    }
    started = False
    rep = 0
    for sent in lines:
        tokens = sent.split()
        if not tokens:
            # Blank lines hold no data; skipping them avoids an IndexError
            # on tokens[0] and keeps the six-line cycle aligned.
            continue
        if not started and sent == 'Normal flow result:\t\n':
            # A boolean flag rather than the line index, so a header on the
            # very first line (index 0) still starts parsing.
            started = True
            rep = 1
            continue
        if tokens[0] == 'Train':
            break
        if not started:
            continue
        if rep == 1:
            print(sent)
        if rep in columns:
            dest, idx = columns[rep]
            dest.append(float(tokens[idx][:-1]))
        rep = (rep + 1) % 6
    return ppl, joint_ppl, ppl_en_en, ppl_de_de


def main(arguments):
    """Extract perplexity curves from a training log and save them to HDF5.

    Args:
        arguments: command-line arguments without the program name, e.g.
            ``sys.argv[1:]``.  ``--srcfile`` is the training log to read;
            ``--savefile`` is the HDF5 file to create (opened in ``'w'``
            mode, so an existing file is truncated).

    The output file receives four datasets, ``ppl``, ``joint_ppl``,
    ``ppl_en_en`` and ``ppl_de_de``, filled as described in ``_parse_log``.
    argparse exits with status 2 if either option is missing.
    """
    parser = argparse.ArgumentParser(
        description=__doc__,
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--srcfile', help="path to log of training, ", required=True)
    parser.add_argument('--savefile', help="path to save file, ", required=True)
    args = parser.parse_args(arguments)

    with open(args.srcfile) as log:
        ppl, joint_ppl, ppl_en_en, ppl_de_de = _parse_log(log)

    # The context manager closes the HDF5 file even if a dataset write fails.
    with h5py.File(args.savefile, 'w') as out:
        out["ppl"] = ppl
        out["joint_ppl"] = joint_ppl
        out["ppl_en_en"] = ppl_en_en
        out["ppl_de_de"] = ppl_de_de
if __name__ == '__main__':
    # Script entry point: pass everything after the program name to main().
    main(arguments=sys.argv[1:])