Skip to content

Commit

Permalink
[BugFix] Better return_log_prob=True for tensordict outputs
Browse files Browse the repository at this point in the history
ghstack-source-id: 977af3880f39cb341c1c715f1b8c9d59b7c580a0
Pull Request resolved: #1155
  • Loading branch information
vmoens committed Dec 20, 2024
1 parent 863ba10 commit b229c59
Showing 1 changed file with 6 additions and 3 deletions.
9 changes: 6 additions & 3 deletions tensordict/nn/probabilistic.py
Original file line number Diff line number Diff line change
Expand Up @@ -510,9 +510,12 @@ def forward(
kwargs = {"aggregate_probabilities": False}
log_prob = dist.log_prob(out_tensors, **kwargs)
if log_prob is not out_tensors:
# Composite dists return the tensordict_out directly when aggrgate_prob is False
out_tensors.set(self.log_prob_key, log_prob)
else:
if is_tensor_collection(log_prob):
out_tensors.update(log_prob)
else:
# Composite dists return the tensordict_out directly when aggrgate_prob is False
out_tensors.set(self.log_prob_key, log_prob)
elif dist.log_prob_key in out_tensors:
out_tensors.rename_key_(dist.log_prob_key, self.log_prob_key)
tensordict_out.update(out_tensors)
else:
Expand Down

0 comments on commit b229c59

Please sign in to comment.