Skip to content

Commit

Permalink
action boundaries fix
Browse files (browse the repository at this point in the history)
  • Loading branch information
Rufaim committed Apr 28, 2020
1 parent 7bcbee7 commit ea559a0
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 4 deletions.
3 changes: 2 additions & 1 deletion pendulum_problem/ddpg.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,4 +97,5 @@ def _update_target_networks(self):
v2.assign(self.actor_net.tau * v1 + (1 - self.actor_net.tau) * v2)

for v2, v1 in zip(self.target_critic_net.trainable_variables, self.critic_net.trainable_variables):
v2.assign(self.critic_net.tau * v1 + (1 - self.critic_net.tau) * v2)
v2.assign(self.critic_net.tau * v1 + (1 - self.critic_net.tau) * v2)

8 changes: 5 additions & 3 deletions pendulum_problem/neural_nets.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@ class ActorNet(tf.keras.Model):
def __init__(self, net_structure, action_bounds, tau, learning_rate):
super(ActorNet, self).__init__()
self.net_structure = net_structure
self.action_bounds = tf.constant(action_bounds, shape=[1, len(action_bounds)], dtype=tf.float32)
action_bounds = tf.convert_to_tensor(action_bounds, dtype=tf.float32)
self.action_bounds = tf.constant(action_bounds, shape=[1, action_bounds.shape[-1]], dtype=tf.float32)
self.tau = tf.constant(tau, dtype=tf.float32)
self.learning_rate = learning_rate

Expand All @@ -27,7 +28,7 @@ def call(self, input, training=None, mask=None):

def clone(self):
structure = clone_net_structure(self.net_structure)
return ActorNet(structure, self.action_bounds, self.tau, self.learning_rate)
return ActorNet(structure, self.action_bounds[0], self.tau, self.learning_rate)


class CriticNet(tf.keras.Model):
Expand Down Expand Up @@ -57,4 +58,5 @@ def call(self, input, action, training=None, mask=None):

def clone(self):
structure = clone_net_structure(self.net_structure)
return CriticNet(structure, self.tau, self.learning_rate, self.grad_norm)
return CriticNet(structure, self.tau, self.learning_rate, self.grad_norm)

0 comments on commit ea559a0

Please sign in to comment.