Skip to content

Commit 852663b

Browse files
authored
fix loss function
1 parent aef028b commit 852663b

File tree

1 file changed

+2
-0
lines changed

1 file changed

+2
-0
lines changed

learner.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,8 @@ def learner(model, data, ps, args):
3030
optimizer.zero_grad()
3131
logits, values = model(obs, actions, rewards, dones, hx=hx)
3232
bootstrap_value = values[-1]
33+
actions, behaviour_logits, dones, rewards = actions[1:], behaviour_logits[1:], dones[1:], rewards[1:]
34+
logits, values = logits[:-1], values[:-1]
3335
discounts = (~dones).float() * gamma
3436
vs, pg_advantages = vtrace.from_logits(
3537
behaviour_policy_logits=behaviour_logits,

0 commit comments

Comments
 (0)