-
Notifications
You must be signed in to change notification settings - Fork 285
/
utils.py
101 lines (87 loc) · 2.75 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
import os
import sys
import json
import time
from functools import partial
import numpy as np
# import tensorflow as tf
# from tensorflow.python.framework import function
from tqdm import tqdm
def encode_dataset(*splits, encoder):
    """Encode every text field of every dataset split.

    Args:
        *splits: Dataset splits; each split is an iterable of fields, and
            each field is an indexable sequence.
        encoder: Object exposing ``encode(field)``; it is called once per
            field whose first element is a ``str``.

    Returns:
        A list with one entry per split. Each entry is the list of that
        split's fields, where text fields have been replaced by the result
        of ``encoder.encode`` and all other fields are passed through as-is.
    """
    encoded_splits = []
    for split in splits:
        fields = []
        for field in split:
            # Guard the field[0] probe: an empty field has no first element
            # to inspect, so it is passed through unchanged instead of
            # raising IndexError.
            if len(field) > 0 and isinstance(field[0], str):
                field = encoder.encode(field)
            fields.append(field)
        encoded_splits.append(fields)
    return encoded_splits
def stsb_label_encoding(labels, nclass=6):
    """
    Label encoding from Tree LSTM paper (Tai, Socher, Manning)

    Each label ``y`` contributes weight ``floor(y) - y + 1`` to class
    ``floor(y)`` and weight ``y - floor(y)`` to class ``floor(y) + 1``.
    A class index outside ``range(nclass)`` is simply never written.

    Returns a float32 array of shape ``(len(labels), nclass)``.
    """
    encoded = np.zeros((len(labels), nclass)).astype(np.float32)
    for row, score in enumerate(labels):
        lower = np.floor(score)
        for cls in range(nclass):
            if cls == lower + 1:
                # Fractional part goes to the class just above the label.
                encoded[row, cls] = score - lower
            if cls == lower:
                # Remaining mass goes to the class at the label's floor.
                encoded[row, cls] = lower - score + 1
    return encoded
def np_softmax(x, t=1):
    """Softmax over the last axis of ``x`` after dividing by temperature ``t``.

    The per-row maximum is subtracted before exponentiating so that large
    inputs do not overflow; this leaves the result unchanged.
    """
    scaled = x / t
    shifted = scaled - np.max(scaled, axis=-1, keepdims=True)
    weights = np.exp(shifted)
    normaliser = np.sum(weights, axis=-1, keepdims=True)
    return weights / normaliser
def make_path(f):
    """Ensure the parent directory of ``f`` exists, then return ``f`` unchanged.

    Args:
        f: File path whose directory component should exist.

    Returns:
        The same ``f`` that was passed in, so the call can be nested inside
        ``open(...)``.
    """
    parent = os.path.dirname(f)
    # An empty dirname means f has no directory component: nothing to create.
    if parent and not os.path.exists(parent):
        # exist_ok=True tolerates another process creating the directory
        # between the existence check above and this call.
        os.makedirs(parent, exist_ok=True)
    return f
def _identity_init(shape, dtype, partition_info, scale):
n = shape[-1]
w = np.eye(n)*scale
if len([s for s in shape if s != 1]) == 2:
w = w.reshape(shape)
return w.astype(np.float32)
def identity_init(scale=1.0):
    """Return ``_identity_init`` with its ``scale`` argument pre-bound."""
    initializer = partial(_identity_init, scale=scale)
    return initializer
def _np_init(shape, dtype, partition_info, w):
return w
def np_init(w):
    """Return ``_np_init`` with its ``w`` argument pre-bound."""
    initializer = partial(_np_init, w=w)
    return initializer
class ResultLogger(object):
    """Logger that writes one JSON object per line to a file.

    The constructor opens ``path`` for writing (creating its parent
    directory through ``make_path``) and writes the keyword arguments as the
    first record. Each ``log`` call appends one more record. Every record
    gets a ``'time'`` key set to ``time.time()`` unless the caller supplied
    one.

    The logger can be used as a context manager; leaving the ``with`` block
    closes the file.
    """

    def __init__(self, path, *args, **kwargs):
        """Open the log file and write ``kwargs`` as the first record.

        Positional ``*args`` are accepted and ignored.
        """
        self.f_log = open(make_path(path), 'w')
        try:
            self._write_record(kwargs)
        except Exception:
            # Do not leak the handle if the first record cannot be written
            # (for example when kwargs are not JSON-serialisable).
            self.f_log.close()
            raise

    def _write_record(self, record):
        """Timestamp ``record`` if needed, write it as one line, and flush."""
        if 'time' not in record:
            record['time'] = time.time()
        self.f_log.write(json.dumps(record)+'\n')
        # Flush every record so the file on disk is current after each write.
        self.f_log.flush()

    def log(self, **kwargs):
        """Append ``kwargs`` as one JSON line."""
        self._write_record(kwargs)

    def close(self):
        """Close the underlying file."""
        self.f_log.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()
def flatten(outer):
    """Concatenate the elements of every inner iterable into one flat list.

    Only one level of nesting is removed.
    """
    flat = []
    for inner in outer:
        flat.extend(inner)
    return flat
def remove_none(l):
    """Return a new list with every ``None`` entry of ``l`` dropped.

    Other falsy values (``0``, ``''``, ``False``, empty containers) are kept.
    """
    kept = []
    for item in l:
        if item is None:
            continue
        kept.append(item)
    return kept
def iter_data(*datas, n_batch=128, truncate=False, verbose=False, max_batches=float("inf")):
    """Yield consecutive mini-batches sliced from one or more parallel sequences.

    Args:
        *datas: Sliceable sequences. The length of the first one decides how
            many items are batched; every sequence is sliced with the same
            bounds.
        n_batch: Number of items per batch.
        truncate: If true, drop the trailing partial batch.
        verbose: If true, draw a tqdm progress bar on stderr; otherwise the
            bar is disabled.
        max_batches: Upper bound on the number of batches. A finite float is
            truncated to an int.

    Yields:
        ``datas[0][i:i + n_batch]`` when a single sequence was given,
        otherwise a tuple holding the matching slice of every sequence.
    """
    n = len(datas[0])
    if truncate:
        n = (n // n_batch) * n_batch
    if max_batches != float("inf"):
        # Capping n bounds the number of batches, so no per-iteration counter
        # is needed. int() keeps n an integer for range() even when the
        # caller passes a finite float.
        n = min(n, int(max_batches) * n_batch)
    starts = range(0, n, n_batch)
    # len(starts) counts the trailing partial batch, which n // n_batch would
    # miss. disable= silences the bar without opening (and leaking) a handle
    # on os.devnull.
    progress = tqdm(starts, total=len(starts), file=sys.stderr, ncols=80,
                    leave=False, disable=not verbose)
    for i in progress:
        if len(datas) == 1:
            yield datas[0][i:i + n_batch]
        else:
            # Slice eagerly into a tuple: a lazy generator expression would
            # read `i` only when consumed, so batches stored for later would
            # all resolve to the final slice.
            yield tuple(d[i:i + n_batch] for d in datas)