critic.py
import tensorflow as tf
import tflearn

# Width of the fully connected and LSTM layers.
UNITS = 128
# Maximum length of the (state, action) history fed to the recurrent branch.
MAX_STEPS = 50
class Critic:
    def __init__(self, session, dim_state, dim_goal, dim_action, dim_env, env, tau, learning_rate, num_actor_vars, writer):
        self._sess = session
        self._dim_state = dim_state
        self._dim_action = dim_action
        self._dim_env = dim_env
        self._dim_goal = dim_goal
        self._action_bound = env.action_space.high
        self._learning_rate = learning_rate
        self._tau = tau
        self._sum_writer = writer

        # Online critic network.
        self._net_inputs, self._net_out = self.create_network()
        self._net_input_env, self._net_input_goal, self._net_input_action, self._net_input_state, self._net_input_history = self._net_inputs
        self._network_params = tf.trainable_variables()[num_actor_vars:]

        # Target critic network (same architecture, separate weights).
        self._target_inputs, self._target_out = self.create_network()
        self._target_input_env, self._target_input_goal, self._target_input_action, self._target_input_state, self._target_input_history = self._target_inputs
        self._target_network_params = tf.trainable_variables()[(len(self._network_params) + num_actor_vars):]

        # Op for initializing the target network with the online network weights.
        self._initialize_target_network_params = \
            [self._target_network_params[i].assign(self._network_params[i])
             for i in range(len(self._target_network_params))]

        # Op for periodically updating the target network with the online
        # network weights, using the soft-update rate tau.
        self._update_target_network_params = \
            [self._target_network_params[i].assign(
                tf.multiply(self._network_params[i], self._tau)
                + tf.multiply(self._target_network_params[i], 1. - self._tau))
             for i in range(len(self._target_network_params))]

        # Network target (y_i).
        self._predicted_q_value = tf.placeholder(tf.float32, [None, 1])

        # Define loss and optimization op.
        self._loss = tflearn.mean_square(self._predicted_q_value, self._net_out)
        self._grad = tf.gradients(self._loss, self._network_params)
        self._optimize = tf.train.AdamOptimizer(self._learning_rate).apply_gradients(zip(self._grad, self._network_params))
        #self._optimize = tf.train.AdamOptimizer(
        #    self._learning_rate).minimize(self._loss)

        # Summaries for the loss and the gradient norm of each network variable.
        self._loss_summary = tf.summary.scalar('loss', self._loss)
        self._gradients_summaries = [tf.summary.scalar("gradient_{}".format(i), tf.norm(grad))
                                     for i, grad in enumerate(self._grad)]
        self._merged_gradients = tf.summary.merge(self._gradients_summaries)

        # Get the gradient of the net w.r.t. the action.
        # For each action in the minibatch (i.e., for each x in xs),
        # this will sum up the gradients of each critic output in the minibatch
        # w.r.t. that action. Each output is independent of all
        # actions except for one.
        self._action_grads = tf.gradients(self._net_out, self._net_input_action)
    def create_network(self):
        # Placeholders for the current state, the goal, the candidate action,
        # the environment parameters and the (state, action) history.
        input_state = tflearn.input_data(shape=[None, self._dim_state])
        input_goal = tflearn.input_data(shape=[None, self._dim_goal])
        input_action = tflearn.input_data(shape=[None, self._dim_action])
        input_env = tflearn.input_data(shape=[None, self._dim_env])
        input_history = tflearn.input_data(shape=[None, MAX_STEPS, self._dim_action + self._dim_state])

        # Feed-forward branch over the concatenated non-sequential inputs.
        input_ff = tflearn.merge(
            [input_env, input_goal, input_action, input_state], 'concat')
        ff_branch = tflearn.fully_connected(input_ff, UNITS)
        ff_branch = tflearn.activations.relu(ff_branch)

        # Recurrent branch over the (state, action) history.
        #recurrent_branch = tflearn.fully_connected(inputs, UNITS)
        #recurrent_branch = tflearn.activations.relu(recurrent_branch)
        recurrent_branch = tflearn.lstm(input_history, UNITS, dynamic=True)

        # Merge both branches and map them to a single Q-value estimate.
        merged_branch = tflearn.merge([ff_branch, recurrent_branch], 'concat')
        merged_branch = tflearn.fully_connected(merged_branch, UNITS)
        merged_branch = tflearn.activations.relu(merged_branch)
        merged_branch = tflearn.fully_connected(merged_branch, UNITS)
        merged_branch = tflearn.activations.relu(merged_branch)

        # Final layer weights are initialized close to zero.
        weights_init = tflearn.initializations.uniform(minval=-0.003, maxval=0.003)
        out = tflearn.fully_connected(
            merged_branch, 1, activation='linear', weights_init=weights_init)
        return [input_env, input_goal, input_action, input_state, input_history], out
    def train(self, input_env, input_state, input_goal, input_action, input_history, predicted_q_value):
        # Run one optimization step and collect the loss / gradient summaries.
        net_out, optimize, loss_summary, gradients_summaries = self._sess.run(
            [self._net_out, self._optimize, self._loss_summary, self._merged_gradients],
            feed_dict={
                self._net_input_env: input_env,
                self._net_input_state: input_state,
                self._net_input_goal: input_goal,
                self._net_input_action: input_action,
                self._net_input_history: input_history,
                self._predicted_q_value: predicted_q_value
            })
        self._sum_writer.add_summary(loss_summary)
        self._sum_writer.add_summary(gradients_summaries)
        return net_out, optimize
    def predict(self, input_env, input_state, input_goal, input_action, input_history):
        return self._sess.run(self._net_out, feed_dict={
            self._net_input_env: input_env,
            self._net_input_state: input_state,
            self._net_input_goal: input_goal,
            self._net_input_action: input_action,
            self._net_input_history: input_history,
        })
    def predict_target(self, input_env, input_state, input_goal, input_action, input_history):
        # Q-value estimate from the target network, used to build the TD targets.
        return self._sess.run(self._target_out, feed_dict={
            self._target_input_env: input_env,
            self._target_input_state: input_state,
            self._target_input_goal: input_goal,
            self._target_input_action: input_action,
            self._target_input_history: input_history,
        })
    def action_gradients(self, input_env, input_state, input_goal, input_action, input_history):
        # Gradient of the critic output w.r.t. the action input, used for the actor update.
        return self._sess.run(self._action_grads, feed_dict={
            self._net_input_env: input_env,
            self._net_input_state: input_state,
            self._net_input_goal: input_goal,
            self._net_input_action: input_action,
            self._net_input_history: input_history
        })
    def update_target_network(self):
        # Soft update: target <- tau * online + (1 - tau) * target.
        self._sess.run(self._update_target_network_params)

    def initialize_target_network(self):
        # Hard copy of the online network weights into the target network.
        self._sess.run(self._initialize_target_network_params)
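
# --- Usage sketch (illustrative, not part of the original class) ---
# A minimal sketch of how a critic like this is typically driven in a
# DDPG-style update loop. The `actor`, `writer`, `sess` and minibatch
# variables below are hypothetical placeholders; only the Critic calls
# mirror the methods defined above.
#
#   critic = Critic(sess, dim_state, dim_goal, dim_action, dim_env, env,
#                   tau=0.001, learning_rate=0.001,
#                   num_actor_vars=num_actor_vars, writer=writer)
#   critic.initialize_target_network()
#
#   # One training step on a sampled minibatch:
#   target_q = critic.predict_target(envs, next_states, goals,
#                                    next_actions, histories)
#   y_i = rewards + gamma * target_q
#   critic.train(envs, states, goals, actions, histories, y_i)
#
#   # Gradients of Q w.r.t. the actions drive the actor update,
#   # then the target network is soft-updated with rate tau.
#   grads = critic.action_gradients(envs, states, goals, actions, histories)
#   critic.update_target_network()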