Allow evaluation using stochastic actions
Toni-SM committed Jan 13, 2025
1 parent e5c6b81 commit 71549e3
Showing 4 changed files with 38 additions and 11 deletions.
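The commit adds a `stochastic_evaluation` entry to the trainer configuration. A minimal usage sketch (the `env` and `agent` names below are placeholders for an already-wrapped environment and a trained skrl agent, and are not part of this commit):

from skrl.trainers.torch import SequentialTrainer

# `env` and `agent` are assumed placeholders: a wrapped environment and a
# trained skrl agent created elsewhere (not shown in this commit)
cfg = {
    "timesteps": 10000,
    "headless": True,
    # new in this commit: sample from the policy distribution during
    # evaluation instead of taking the (deterministic) mean actions
    "stochastic_evaluation": True,
}
trainer = SequentialTrainer(env=env, agents=agent, cfg=cfg)
trainer.eval()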
11 changes: 9 additions & 2 deletions skrl/trainers/torch/base.py
@@ -64,6 +64,7 @@ def __init__(
self.disable_progressbar = self.cfg.get("disable_progressbar", False)
self.close_environment_at_exit = self.cfg.get("close_environment_at_exit", True)
self.environment_info = self.cfg.get("environment_info", "episode")
self.stochastic_evaluation = self.cfg.get("stochastic_evaluation", False)

self.initial_timestep = 0

@@ -255,7 +256,8 @@ def single_agent_eval(self) -> None:

with torch.no_grad():
# compute actions
actions = self.agents.act(states, timestep=timestep, timesteps=self.timesteps)[0]
outputs = self.agents.act(states, timestep=timestep, timesteps=self.timesteps)
actions = outputs[0] if self.stochastic_evaluation else outputs[-1].get("mean_actions", outputs[0])

# step the environments
next_states, rewards, terminated, truncated, infos = self.env.step(actions)
@@ -394,7 +396,12 @@ def multi_agent_eval(self) -> None:

with torch.no_grad():
# compute actions
actions = self.agents.act(states, timestep=timestep, timesteps=self.timesteps)[0]
outputs = self.agents.act(states, timestep=timestep, timesteps=self.timesteps)
actions = (
outputs[0]
if self.stochastic_evaluation
else {k: outputs[-1][k].get("mean_actions", outputs[0][k]) for k in outputs[-1]}
)

# step the environments
next_states, rewards, terminated, truncated, infos = self.env.step(actions)
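All four trainers now apply the same selection rule to the tuple returned by `agent.act()`: element 0 holds the sampled actions and the last element is a dictionary of model outputs which, for stochastic (e.g. Gaussian) policies, can contain "mean_actions". A self-contained sketch of that rule, with an illustrative helper name:

import torch

def select_eval_actions(outputs, stochastic_evaluation: bool) -> torch.Tensor:
    # `outputs` is assumed to be the (actions, log_prob, outputs_dict) tuple
    # returned by a skrl agent's act(); the helper name is illustrative only
    if stochastic_evaluation:
        return outputs[0]  # keep the sampled actions
    # deterministic evaluation: prefer the distribution mean, falling back to
    # the sampled actions when the model exposes no "mean_actions" entry
    return outputs[-1].get("mean_actions", outputs[0])

# tiny check with a fake act() result mimicking a Gaussian policy
fake = (torch.zeros(1, 2), None, {"mean_actions": torch.ones(1, 2)})
print(select_eval_actions(fake, stochastic_evaluation=False))  # tensor([[1., 1.]])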
16 changes: 13 additions & 3 deletions skrl/trainers/torch/parallel.py
@@ -19,7 +19,8 @@
"headless": False, # whether to use headless mode (no rendering)
"disable_progressbar": False, # whether to disable the progressbar. If None, disable on non-TTY
"close_environment_at_exit": True, # whether to close the environment on normal program termination
"environment_info": "episode", # key used to get and log environment info
"environment_info": "episode", # key used to get and log environment info
"stochastic_evaluation": False, # whether to use actions rather than (deterministic) mean actions during evaluation
}
# [end-config-dict-torch]
# fmt: on
@@ -65,7 +66,9 @@ def fn_processor(process_index, *args):
elif task == "act":
_states = queue.get()[scope[0] : scope[1]]
with torch.no_grad():
_actions = agent.act(_states, timestep=msg["timestep"], timesteps=msg["timesteps"])[0]
stochastic_evaluation = msg["stochastic_evaluation"]
_outputs = agent.act(_states, timestep=msg["timestep"], timesteps=msg["timesteps"])
_actions = _outputs[0] if stochastic_evaluation else _outputs[-1].get("mean_actions", _outputs[0])
if not _actions.is_cuda:
_actions.share_memory_()
queue.put(_actions)
@@ -363,7 +366,14 @@ def eval(self) -> None:
# compute actions
with torch.no_grad():
for pipe, queue in zip(producer_pipes, queues):
pipe.send({"task": "act", "timestep": timestep, "timesteps": self.timesteps})
pipe.send(
{
"task": "act",
"timestep": timestep,
"timesteps": self.timesteps,
"stochastic_evaluation": self.stochastic_evaluation,
}
)
queue.put(states)

barrier.wait()
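In the parallel trainer the flag cannot be read from `self` inside the worker processes, so it is carried in the "act" task message. A simplified, runnable illustration of what each worker does with that message (the dummy agent and inline execution are stand-ins for the real multiprocessing pipes, queues and barrier):

import torch

class DummyAgent:
    # stand-in for a skrl agent: act() returns (actions, log_prob, outputs_dict)
    def act(self, states, timestep, timesteps):
        sampled = torch.randn(states.shape[0], 1)
        return sampled, None, {"mean_actions": torch.zeros(states.shape[0], 1)}

# main-process side: the evaluation mode travels inside the task message
msg = {"task": "act", "timestep": 0, "timesteps": 100, "stochastic_evaluation": False}

# worker-process side (executed inline here for illustration)
agent, states = DummyAgent(), torch.zeros(4, 3)
outputs = agent.act(states, timestep=msg["timestep"], timesteps=msg["timesteps"])
actions = outputs[0] if msg["stochastic_evaluation"] else outputs[-1].get("mean_actions", outputs[0])
print(actions.shape)  # torch.Size([4, 1])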
11 changes: 8 additions & 3 deletions skrl/trainers/torch/sequential.py
@@ -18,7 +18,8 @@
"headless": False, # whether to use headless mode (no rendering)
"disable_progressbar": False, # whether to disable the progressbar. If None, disable on non-TTY
"close_environment_at_exit": True, # whether to close the environment on normal program termination
"environment_info": "episode", # key used to get and log environment info
"environment_info": "episode", # key used to get and log environment info
"stochastic_evaluation": False, # whether to use actions rather than (deterministic) mean actions during evaluation
}
# [end-config-dict-torch]
# fmt: on
@@ -187,10 +188,14 @@ def eval(self) -> None:

with torch.no_grad():
# compute actions
outputs = [
agent.act(states[scope[0] : scope[1]], timestep=timestep, timesteps=self.timesteps)
for agent, scope in zip(self.agents, self.agents_scope)
]
actions = torch.vstack(
[
agent.act(states[scope[0] : scope[1]], timestep=timestep, timesteps=self.timesteps)[0]
for agent, scope in zip(self.agents, self.agents_scope)
output[0] if self.stochastic_evaluation else output[-1].get("mean_actions", output[0])
for output in outputs
]
)

11 changes: 8 additions & 3 deletions skrl/trainers/torch/step.py
@@ -18,7 +18,8 @@
"headless": False, # whether to use headless mode (no rendering)
"disable_progressbar": False, # whether to disable the progressbar. If None, disable on non-TTY
"close_environment_at_exit": True, # whether to close the environment on normal program termination
"environment_info": "episode", # key used to get and log environment info
"environment_info": "episode", # key used to get and log environment info
"stochastic_evaluation": False, # whether to use actions rather than (deterministic) mean actions during evaluation
}
# [end-config-dict-torch]
# fmt: on
@@ -212,10 +213,14 @@ def eval(

with torch.no_grad():
# compute actions
outputs = [
agent.act(self.states[scope[0] : scope[1]], timestep=timestep, timesteps=timesteps)
for agent, scope in zip(self.agents, self.agents_scope)
]
actions = torch.vstack(
[
agent.act(self.states[scope[0] : scope[1]], timestep=timestep, timesteps=timesteps)[0]
for agent, scope in zip(self.agents, self.agents_scope)
output[0] if self.stochastic_evaluation else output[-1].get("mean_actions", output[0])
for output in outputs
]
)

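The step-based trainer reads the same option; a minimal usage sketch for a manual evaluation loop (`env` and `agent` are again placeholders for an already-wrapped environment and a trained skrl agent):

from skrl.trainers.torch import StepTrainer

# `env` and `agent` are assumed placeholders created elsewhere
cfg = {"timesteps": 1000, "headless": True, "stochastic_evaluation": False}  # deterministic eval (the default)
trainer = StepTrainer(env=env, agents=agent, cfg=cfg)

for timestep in range(cfg["timesteps"]):
    trainer.eval(timestep=timestep, timesteps=cfg["timesteps"])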