From 71549e36e890d43810b21debf41cead0e129afac Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Antonio=20Serrano=20Mu=C3=B1oz?=
Date: Mon, 13 Jan 2025 18:27:47 -0500
Subject: [PATCH] Allow to evaluate using stochastic actions

---
 skrl/trainers/torch/base.py       | 11 +++++++++--
 skrl/trainers/torch/parallel.py   | 16 +++++++++++++---
 skrl/trainers/torch/sequential.py | 11 ++++++++---
 skrl/trainers/torch/step.py       | 11 ++++++++---
 4 files changed, 38 insertions(+), 11 deletions(-)

diff --git a/skrl/trainers/torch/base.py b/skrl/trainers/torch/base.py
index 16b61161..6c7ef417 100644
--- a/skrl/trainers/torch/base.py
+++ b/skrl/trainers/torch/base.py
@@ -64,6 +64,7 @@ def __init__(
         self.disable_progressbar = self.cfg.get("disable_progressbar", False)
         self.close_environment_at_exit = self.cfg.get("close_environment_at_exit", True)
         self.environment_info = self.cfg.get("environment_info", "episode")
+        self.stochastic_evaluation = self.cfg.get("stochastic_evaluation", False)
 
         self.initial_timestep = 0
 
@@ -255,7 +256,8 @@ def single_agent_eval(self) -> None:
 
             with torch.no_grad():
                 # compute actions
-                actions = self.agents.act(states, timestep=timestep, timesteps=self.timesteps)[0]
+                outputs = self.agents.act(states, timestep=timestep, timesteps=self.timesteps)
+                actions = outputs[0] if self.stochastic_evaluation else outputs[-1].get("mean_actions", outputs[0])
 
                 # step the environments
                 next_states, rewards, terminated, truncated, infos = self.env.step(actions)
@@ -394,7 +396,12 @@ def multi_agent_eval(self) -> None:
 
             with torch.no_grad():
                 # compute actions
-                actions = self.agents.act(states, timestep=timestep, timesteps=self.timesteps)[0]
+                outputs = self.agents.act(states, timestep=timestep, timesteps=self.timesteps)
+                actions = (
+                    outputs[0]
+                    if self.stochastic_evaluation
+                    else {k: outputs[-1][k].get("mean_actions", outputs[0][k]) for k in outputs[-1]}
+                )
 
                 # step the environments
                 next_states, rewards, terminated, truncated, infos = self.env.step(actions)
diff --git a/skrl/trainers/torch/parallel.py b/skrl/trainers/torch/parallel.py
index b409ae7f..c0a6f6a5 100644
--- a/skrl/trainers/torch/parallel.py
+++ b/skrl/trainers/torch/parallel.py
@@ -19,7 +19,8 @@
     "headless": False,  # whether to use headless mode (no rendering)
     "disable_progressbar": False,  # whether to disable the progressbar. If None, disable on non-TTY
     "close_environment_at_exit": True,  # whether to close the environment on normal program termination
-    "environment_info": "episode",  # key used to get and log environment info
+    "environment_info": "episode",  # key used to get and log environment info
+    "stochastic_evaluation": False,  # whether to use actions rather than (deterministic) mean actions during evaluation
 }
 # [end-config-dict-torch]
 # fmt: on
@@ -65,7 +66,9 @@ def fn_processor(process_index, *args):
         elif task == "act":
             _states = queue.get()[scope[0] : scope[1]]
             with torch.no_grad():
-                _actions = agent.act(_states, timestep=msg["timestep"], timesteps=msg["timesteps"])[0]
+                stochastic_evaluation = msg["stochastic_evaluation"]
+                _outputs = agent.act(_states, timestep=msg["timestep"], timesteps=msg["timesteps"])
+                _actions = _outputs[0] if stochastic_evaluation else _outputs[-1].get("mean_actions", _outputs[0])
                 if not _actions.is_cuda:
                     _actions.share_memory_()
                 queue.put(_actions)
@@ -363,7 +366,14 @@ def eval(self) -> None:
             # compute actions
             with torch.no_grad():
                 for pipe, queue in zip(producer_pipes, queues):
-                    pipe.send({"task": "act", "timestep": timestep, "timesteps": self.timesteps})
+                    pipe.send(
+                        {
+                            "task": "act",
+                            "timestep": timestep,
+                            "timesteps": self.timesteps,
+                            "stochastic_evaluation": self.stochastic_evaluation,
+                        }
+                    )
                     queue.put(states)
                 barrier.wait()
 
diff --git a/skrl/trainers/torch/sequential.py b/skrl/trainers/torch/sequential.py
index 8304b1fd..913bf3cc 100644
--- a/skrl/trainers/torch/sequential.py
+++ b/skrl/trainers/torch/sequential.py
@@ -18,7 +18,8 @@
     "headless": False,  # whether to use headless mode (no rendering)
     "disable_progressbar": False,  # whether to disable the progressbar. If None, disable on non-TTY
     "close_environment_at_exit": True,  # whether to close the environment on normal program termination
-    "environment_info": "episode",  # key used to get and log environment info
+    "environment_info": "episode",  # key used to get and log environment info
+    "stochastic_evaluation": False,  # whether to use actions rather than (deterministic) mean actions during evaluation
 }
 # [end-config-dict-torch]
 # fmt: on
@@ -187,10 +188,14 @@ def eval(self) -> None:
 
             with torch.no_grad():
                 # compute actions
+                outputs = [
+                    agent.act(states[scope[0] : scope[1]], timestep=timestep, timesteps=self.timesteps)
+                    for agent, scope in zip(self.agents, self.agents_scope)
+                ]
                 actions = torch.vstack(
                     [
-                        agent.act(states[scope[0] : scope[1]], timestep=timestep, timesteps=self.timesteps)[0]
-                        for agent, scope in zip(self.agents, self.agents_scope)
+                        output[0] if self.stochastic_evaluation else output[-1].get("mean_actions", output[0])
+                        for output in outputs
                     ]
                 )
 
diff --git a/skrl/trainers/torch/step.py b/skrl/trainers/torch/step.py
index 77987598..2744be10 100644
--- a/skrl/trainers/torch/step.py
+++ b/skrl/trainers/torch/step.py
@@ -18,7 +18,8 @@
     "headless": False,  # whether to use headless mode (no rendering)
     "disable_progressbar": False,  # whether to disable the progressbar. If None, disable on non-TTY
     "close_environment_at_exit": True,  # whether to close the environment on normal program termination
-    "environment_info": "episode",  # key used to get and log environment info
+    "environment_info": "episode",  # key used to get and log environment info
+    "stochastic_evaluation": False,  # whether to use actions rather than (deterministic) mean actions during evaluation
 }
 # [end-config-dict-torch]
 # fmt: on
@@ -212,10 +213,14 @@ def eval(
 
             with torch.no_grad():
                 # compute actions
+                outputs = [
+                    agent.act(self.states[scope[0] : scope[1]], timestep=timestep, timesteps=timesteps)
+                    for agent, scope in zip(self.agents, self.agents_scope)
+                ]
                 actions = torch.vstack(
                     [
-                        agent.act(self.states[scope[0] : scope[1]], timestep=timestep, timesteps=timesteps)[0]
-                        for agent, scope in zip(self.agents, self.agents_scope)
+                        output[0] if self.stochastic_evaluation else output[-1].get("mean_actions", output[0])
+                        for output in outputs
                     ]
                 )
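
Note (not part of the patch): a minimal usage sketch of the new "stochastic_evaluation" option with the skrl SequentialTrainer. It assumes an already-wrapped skrl environment `env` and a trained agent `agent` (both placeholders here); the other config keys keep their defaults.

    from skrl.trainers.torch import SequentialTrainer

    cfg_trainer = {
        "timesteps": 10000,
        "headless": True,
        # False (default): evaluate with the policy's deterministic mean actions
        #                  whenever the agent's act() output provides "mean_actions"
        # True: evaluate with stochastically sampled actions, as during training
        "stochastic_evaluation": True,
    }

    # env and agent are assumed to exist (placeholders for this sketch)
    trainer = SequentialTrainer(env=env, agents=agent, cfg=cfg_trainer)
    trainer.eval()

With "stochastic_evaluation" left at False, the trainers fall back to the sampled actions whenever the act() output dictionary does not expose "mean_actions" (e.g. non-Gaussian or already-deterministic policies), as implemented by the .get("mean_actions", ...) calls above.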