Skip to content

Commit

Permalink
Apply imput preprocessor once
Browse files Browse the repository at this point in the history
  • Loading branch information
Toni-SM committed Jul 21, 2024
1 parent 4a89f80 commit 89f6bc4
Showing 1 changed file with 9 additions and 5 deletions.
14 changes: 9 additions & 5 deletions skrl/agents/torch/ppo/ppo.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,15 +213,20 @@ def act(self, states: torch.Tensor, timestep: int, timesteps: int) -> torch.Tens
:return: Actions
:rtype: torch.Tensor
"""
inputs = {"states": self._state_preprocessor(states)}
# sample random actions
# TODO, check for stochasticity
if timestep < self._random_timesteps:
return self.policy.random_act({"states": self._state_preprocessor(states)}, role="policy")
return self.policy.random_act(inputs, role="policy")

# sample stochastic actions
actions, log_prob, outputs = self.policy.act({"states": self._state_preprocessor(states)}, role="policy")
actions, log_prob, outputs = self.policy.act(inputs, role="policy")
self._current_log_prob = log_prob

# compute values
if self.value is not None and self.memory is not None:
self._current_values, _, _ = self.value.act(inputs, role="value")

return actions, log_prob, outputs

def record_transition(self,
Expand Down Expand Up @@ -264,9 +269,8 @@ def record_transition(self,
if self._rewards_shaper is not None:
rewards = self._rewards_shaper(rewards, timestep, timesteps)

# compute values
values, _, _ = self.value.act({"states": self._state_preprocessor(states)}, role="value")
values = self._value_preprocessor(values, inverse=True)
# apply value preprocessor
values = self._value_preprocessor(self._current_values, inverse=True)

# time-limit (truncation) boostrapping
if self._time_limit_bootstrap:
Expand Down

0 comments on commit 89f6bc4

Please sign in to comment.