# utils_network.py (forked from chl8856/AC_TPC)
import tensorflow as tf
import numpy as np
from tensorflow.contrib.layers import fully_connected as FC_Net

### CONSTRUCT MULTICELL FOR MULTI-LAYER RNNS
def create_rnn_cell(num_units, num_layers, keep_prob, RNN_type, activation_fn):
    '''
    GOAL : create a multi-cell (including a single cell) to construct a multi-layer RNN
        num_units     : number of units in each layer
        num_layers    : number of layers in the MultiRNNCell
        keep_prob     : keep probability in [0, 1] (if None, dropout is not employed)
        RNN_type      : either 'LSTM' or 'GRU'
        activation_fn : cell activation function (None or 'None' defaults to tanh)
    '''
    if activation_fn is None or activation_fn == 'None':
        activation_fn = tf.nn.tanh

    cells = []
    for _ in range(num_layers):
        if RNN_type == 'GRU':
            cell = tf.contrib.rnn.GRUCell(num_units, activation=activation_fn)
        elif RNN_type == 'LSTM':
            cell = tf.contrib.rnn.LSTMCell(num_units, activation=activation_fn, state_is_tuple=True)
        else:
            raise ValueError('ERROR: WRONG RNN TYPE')
        if keep_prob is not None:
            cell = tf.contrib.rnn.DropoutWrapper(cell, input_keep_prob=keep_prob, output_keep_prob=keep_prob)  # optionally also state_keep_prob=keep_prob
        cells.append(cell)
    cell = tf.contrib.rnn.MultiRNNCell(cells)

    return cell
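
# --- Usage sketch (illustrative, not part of the original file) ---
# A minimal TF 1.x graph-mode example of building the multi-layer cell and
# unrolling it with tf.nn.dynamic_rnn; the shapes and hyperparameters below
# are assumptions, not values taken from this repository.
def _demo_create_rnn_cell():
    x = tf.placeholder(tf.float32, [None, 10, 5])  # [batch, time, feature] (assumed)
    cell = create_rnn_cell(num_units=8, num_layers=2, keep_prob=0.8,
                           RNN_type='LSTM', activation_fn=None)
    outputs, state = tf.nn.dynamic_rnn(cell, x, dtype=tf.float32)
    return outputs, state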

### EXTRACT STATE OUTPUT OF MULTICELL-RNNS
def create_concat_state_h(state, num_layers, RNN_type, BiRNN=None):
    '''
    GOAL : concatenate the tuple-type tensor (state) into a single tensor
        state      : input state is the tuple output of the MultiRNNCell,
                     consisting of only hidden states h for GRU, and cell states c and hidden states h for LSTM
        num_layers : number of layers in the MultiRNNCell
        RNN_type   : either 'LSTM' or 'GRU'
        BiRNN      : if not None, state is a (forward, backward) pair from a bidirectional RNN
    '''
    for i in range(num_layers):
        if BiRNN is not None:
            if RNN_type == 'LSTM':
                tmp = tf.concat([state[0][i][1], state[1][i][1]], axis=1)  ## i-th layer, h state for LSTM (fw/bw)
            elif RNN_type == 'GRU':
                tmp = tf.concat([state[0][i], state[1][i]], axis=1)  ## i-th layer, h state for GRU (fw/bw)
            else:
                raise ValueError('ERROR: WRONG RNN CELL TYPE')
        else:
            if RNN_type == 'LSTM':
                tmp = state[i][1]  ## i-th layer, h state for LSTM
            elif RNN_type == 'GRU':
                tmp = state[i]  ## i-th layer, h state for GRU
            else:
                raise ValueError('ERROR: WRONG RNN CELL TYPE')

        if i == 0:
            rnn_state_out = tmp
        else:
            rnn_state_out = tf.concat([rnn_state_out, tmp], axis=1)

    return rnn_state_out

def create_concat_state_c(state, num_layers, RNN_type, BiRNN=None):
    '''
    GOAL : same as create_concat_state_h, but concatenates the cell states c
           (a GRU has no separate c, so its hidden state h is used instead)
    '''
    for i in range(num_layers):
        if BiRNN is not None:
            if RNN_type == 'LSTM':
                tmp = tf.concat([state[0][i][0], state[1][i][0]], axis=1)  ## i-th layer, c state for LSTM (fw/bw)
            elif RNN_type == 'GRU':
                tmp = tf.concat([state[0][i], state[1][i]], axis=1)  ## i-th layer, c=h state for GRU (fw/bw)
            else:
                raise ValueError('ERROR: WRONG RNN CELL TYPE')
        else:
            if RNN_type == 'LSTM':
                tmp = state[i][0]  ## i-th layer, c state for LSTM
            elif RNN_type == 'GRU':
                tmp = state[i]  ## i-th layer, c=h state for GRU
            else:
                raise ValueError('ERROR: WRONG RNN CELL TYPE')

        if i == 0:
            rnn_state_out = tmp
        else:
            rnn_state_out = tf.concat([rnn_state_out, tmp], axis=1)

    return rnn_state_out
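
# --- Usage sketch (illustrative, not part of the original file) ---
# Flattens the tuple state of a 2-layer LSTM into single tensors of shape
# [batch, num_layers * num_units]; the shapes below are assumptions.
def _demo_concat_states():
    x = tf.placeholder(tf.float32, [None, 10, 5])  # [batch, time, feature] (assumed)
    cell = create_rnn_cell(num_units=8, num_layers=2, keep_prob=None,
                           RNN_type='LSTM', activation_fn=None)
    _, state = tf.nn.dynamic_rnn(cell, x, dtype=tf.float32)
    h_concat = create_concat_state_h(state, num_layers=2, RNN_type='LSTM')  # shape [batch, 16]
    c_concat = create_concat_state_c(state, num_layers=2, RNN_type='LSTM')  # shape [batch, 16]
    return h_concat, c_concat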

### FEEDFORWARD NETWORK
def create_FCNet(inputs, num_layers, h_dim, h_fn, o_dim, o_fn, w_init, w_reg=None, keep_prob=1.0):
    '''
    GOAL : create an FC network with different specifications
        inputs (tensor) : input tensor
        num_layers      : number of layers in the FCNet
        h_dim (int)     : number of hidden units
        h_fn            : activation function for hidden layers (default: tf.nn.relu)
        o_dim (int)     : number of output units
        o_fn            : activation function for the output layer (default: None, i.e. linear)
        w_init          : initialization for the weight matrix (default: Xavier)
        keep_prob       : keep probability in [0, 1] (if None, dropout is not employed)
    '''
    # default activation function for hidden layers (relu); a None o_fn already means a linear output
    if h_fn is None:
        h_fn = tf.nn.relu

    # default initialization function (weight: Xavier, bias: None)
    if w_init is None:
        w_init = tf.contrib.layers.xavier_initializer()  # Xavier initialization

    for layer in range(num_layers):
        if num_layers == 1:
            out = FC_Net(inputs, o_dim, activation_fn=o_fn, weights_initializer=w_init, weights_regularizer=w_reg)
        else:
            if layer == 0:
                h = FC_Net(inputs, h_dim, activation_fn=h_fn, weights_initializer=w_init, weights_regularizer=w_reg)
                if keep_prob is not None:
                    h = tf.nn.dropout(h, keep_prob=keep_prob)
            elif layer > 0 and layer != (num_layers - 1):  # intermediate hidden layers
                h = FC_Net(h, h_dim, activation_fn=h_fn, weights_initializer=w_init, weights_regularizer=w_reg)
                if keep_prob is not None:
                    h = tf.nn.dropout(h, keep_prob=keep_prob)
            else:  # layer == num_layers - 1 (the last layer)
                out = FC_Net(h, o_dim, activation_fn=o_fn, weights_initializer=w_init, weights_regularizer=w_reg)

    return out
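
# --- Usage sketch (illustrative, not part of the original file) ---
# A 3-layer MLP: two ReLU hidden layers of 32 units with dropout, then a
# linear 4-unit output; all dimensions below are assumptions.
def _demo_create_FCNet():
    x = tf.placeholder(tf.float32, [None, 16])  # [batch, feature] (assumed)
    out = create_FCNet(inputs=x, num_layers=3, h_dim=32, h_fn=tf.nn.relu,
                       o_dim=4, o_fn=None, w_init=None, keep_prob=0.9)
    return out  # shape [batch, 4]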