From ea559a02db51e2ce2efff8a2ce6834427de0d651 Mon Sep 17 00:00:00 2001 From: Pavel Rumiantsev Date: Tue, 28 Apr 2020 17:39:52 -0400 Subject: [PATCH] action boundaries fix --- pendulum_problem/ddpg.py | 3 ++- pendulum_problem/neural_nets.py | 8 +++++--- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/pendulum_problem/ddpg.py b/pendulum_problem/ddpg.py index 5efd48c..a59c3b2 100644 --- a/pendulum_problem/ddpg.py +++ b/pendulum_problem/ddpg.py @@ -97,4 +97,5 @@ def _update_target_networks(self): v2.assign(self.actor_net.tau * v1 + (1 - self.actor_net.tau) * v2) for v2, v1 in zip(self.target_critic_net.trainable_variables, self.critic_net.trainable_variables): - v2.assign(self.critic_net.tau * v1 + (1 - self.critic_net.tau) * v2) \ No newline at end of file + v2.assign(self.critic_net.tau * v1 + (1 - self.critic_net.tau) * v2) + diff --git a/pendulum_problem/neural_nets.py b/pendulum_problem/neural_nets.py index 4985882..6499c06 100644 --- a/pendulum_problem/neural_nets.py +++ b/pendulum_problem/neural_nets.py @@ -8,7 +8,8 @@ class ActorNet(tf.keras.Model): def __init__(self, net_structure, action_bounds, tau, learning_rate): super(ActorNet, self).__init__() self.net_structure = net_structure - self.action_bounds = tf.constant(action_bounds, shape=[1, len(action_bounds)], dtype=tf.float32) + action_bounds = tf.convert_to_tensor(action_bounds, dtype=tf.float32) + self.action_bounds = tf.constant(action_bounds, shape=[1, action_bounds.shape[-1]], dtype=tf.float32) self.tau = tf.constant(tau, dtype=tf.float32) self.learning_rate = learning_rate @@ -27,7 +28,7 @@ def call(self, input, training=None, mask=None): def clone(self): structure = clone_net_structure(self.net_structure) - return ActorNet(structure, self.action_bounds, self.tau, self.learning_rate) + return ActorNet(structure, self.action_bounds[0], self.tau, self.learning_rate) class CriticNet(tf.keras.Model): @@ -57,4 +58,5 @@ def call(self, input, action, training=None, mask=None): def clone(self): structure = clone_net_structure(self.net_structure) - return CriticNet(structure, self.tau, self.learning_rate, self.grad_norm) \ No newline at end of file + return CriticNet(structure, self.tau, self.learning_rate, self.grad_norm) +