Merge pull request #195 from prabhatnagarajan/loss_storage
Modifies loss storage in DDPG, TD3, and Soft Actor Critic
muupan committed Aug 4, 2024
2 parents 580ac4f + 41c0e92 commit c8cb332
Showing 3 changed files with 7 additions and 7 deletions.
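For context, the change replaces conversions that detach a scalar loss tensor, move it to the CPU, and round-trip through NumPy (or cast the tensor with float()) with torch.Tensor.item(), which returns the tensor's single value as a plain Python number. Below is a minimal standalone sketch of the two equivalent patterns; it is illustrative only and not part of the diff, and the variable name loss is hypothetical:

    import torch

    # A scalar loss tensor, e.g. as returned by F.mse_loss
    loss = torch.tensor(0.25, requires_grad=True)

    # Old pattern: detach from the autograd graph, move to CPU, convert via NumPy
    old_value = float(loss.detach().cpu().numpy())

    # New pattern: .item() performs the same extraction in a single call
    new_value = loss.item()

    assert old_value == new_value  # both are plain Python floats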
pfrl/agents/ddpg.py: 2 additions & 2 deletions
@@ -168,7 +168,7 @@ def compute_critic_loss(self, batch):
         loss = F.mse_loss(target_q, predict_q)

         # Update stats
-        self.critic_loss_record.append(float(loss.detach().cpu().numpy()))
+        self.critic_loss_record.append(loss.item())

         return loss

@@ -182,7 +182,7 @@ def compute_actor_loss(self, batch):

         # Update stats
         self.q_record.extend(q.detach().cpu().numpy())
-        self.actor_loss_record.append(float(loss.detach().cpu().numpy()))
+        self.actor_loss_record.append(loss.item())

         return loss

pfrl/agents/soft_actor_critic.py: 2 additions & 2 deletions
@@ -246,8 +246,8 @@ def update_q_func(self, batch):
         # Update stats
         self.q1_record.extend(predict_q1.detach().cpu().numpy())
         self.q2_record.extend(predict_q2.detach().cpu().numpy())
-        self.q_func1_loss_record.append(float(loss1))
-        self.q_func2_loss_record.append(float(loss2))
+        self.q_func1_loss_record.append(loss1.item())
+        self.q_func2_loss_record.append(loss2.item())

         self.q_func1_optimizer.zero_grad()
         loss1.backward()
pfrl/agents/td3.py: 3 additions & 3 deletions
@@ -213,8 +213,8 @@ def update_q_func(self, batch):
         # Update stats
         self.q1_record.extend(predict_q1.detach().cpu().numpy())
         self.q2_record.extend(predict_q2.detach().cpu().numpy())
-        self.q_func1_loss_record.append(float(loss1))
-        self.q_func2_loss_record.append(float(loss2))
+        self.q_func1_loss_record.append(loss1.item())
+        self.q_func2_loss_record.append(loss2.item())

         self.q_func1_optimizer.zero_grad()
         loss1.backward()
@@ -241,7 +241,7 @@ def update_policy(self, batch):
         # Since we want to maximize Q, loss is negation of Q
         loss = -torch.mean(q)

-        self.policy_loss_record.append(float(loss))
+        self.policy_loss_record.append(loss.item())
         self.policy_optimizer.zero_grad()
         loss.backward()
         if self.max_grad_norm is not None:
