From a43b4f8ecec4b90fea558789bf61479471b69e02 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Antonio=20Serrano=20Mu=C3=B1oz?= Date: Wed, 15 Jan 2025 11:11:13 -0500 Subject: [PATCH] Update runner to support the definition of any agent and its models in JAX --- skrl/utils/runner/jax/runner.py | 282 +++++++++++++++----------------- 1 file changed, 134 insertions(+), 148 deletions(-) diff --git a/skrl/utils/runner/jax/runner.py b/skrl/utils/runner/jax/runner.py index 46fa429e..872e571d 100644 --- a/skrl/utils/runner/jax/runner.py +++ b/skrl/utils/runner/jax/runner.py @@ -4,17 +4,12 @@ from skrl import logger from skrl.agents.jax import Agent -from skrl.agents.jax.ppo import PPO, PPO_DEFAULT_CONFIG from skrl.envs.wrappers.jax import MultiAgentEnvWrapper, Wrapper -from skrl.memories.jax import RandomMemory from skrl.models.jax import Model -from skrl.multi_agents.jax.ippo import IPPO, IPPO_DEFAULT_CONFIG -from skrl.multi_agents.jax.mappo import MAPPO, MAPPO_DEFAULT_CONFIG from skrl.resources.preprocessors.jax import RunningStandardScaler # noqa from skrl.resources.schedulers.jax import KLAdaptiveLR # noqa -from skrl.trainers.jax import SequentialTrainer, Trainer +from skrl.trainers.jax import Trainer from skrl.utils import set_seed -from skrl.utils.model_instantiators.jax import categorical_model, deterministic_model, gaussian_model class Runner: @@ -32,22 +27,6 @@ def __init__(self, env: Union[Wrapper, MultiAgentEnvWrapper], cfg: Mapping[str, # set random seed set_seed(self._cfg.get("seed", None)) - self._class_mapping = { - # model - "gaussianmixin": gaussian_model, - "categoricalmixin": categorical_model, - "deterministicmixin": deterministic_model, - "shared": None, - # memory - "randommemory": RandomMemory, - # agent - "ppo": PPO, - "ippo": IPPO, - "mappo": MAPPO, - # trainer - "sequentialtrainer": SequentialTrainer, - } - self._cfg["agent"]["rewards_shaper"] = None # FIXME: avoid 'dictionary changed size during iteration' self._models = self._generate_models(self._env, copy.deepcopy(self._cfg)) @@ -85,14 +64,76 @@ def load_cfg_from_yaml(path: str) -> dict: logger.error(f"Loading yaml error: {e}") return {} - def _class(self, value: str) -> Type: - """Get skrl component class (e.g.: agent, trainer, etc..) from string identifier + def _component(self, name: str) -> Type: + """Get skrl component (e.g.: agent, trainer, etc..) from string identifier - :return: skrl component class + :return: skrl component """ - if value.lower() in self._class_mapping: - return self._class_mapping[value.lower()] - raise ValueError(f"Unknown class '{value}' in runner cfg") + component = None + name = name.lower() + # model + if name == "gaussianmixin": + from skrl.utils.model_instantiators.jax import gaussian_model as component + elif name == "categoricalmixin": + from skrl.utils.model_instantiators.jax import categorical_model as component + elif name == "deterministicmixin": + from skrl.utils.model_instantiators.jax import deterministic_model as component + # memory + elif name == "randommemory": + from skrl.memories.jax import RandomMemory as component + # agent + elif name in ["a2c", "a2c_default_config"]: + from skrl.agents.jax.a2c import A2C, A2C_DEFAULT_CONFIG + + component = A2C_DEFAULT_CONFIG if "default_config" in name else A2C + elif name in ["cem", "cem_default_config"]: + from skrl.agents.jax.cem import CEM, CEM_DEFAULT_CONFIG + + component = CEM_DEFAULT_CONFIG if "default_config" in name else CEM + elif name in ["ddpg", "ddpg_default_config"]: + from skrl.agents.jax.ddpg import DDPG, DDPG_DEFAULT_CONFIG + + component = DDPG_DEFAULT_CONFIG if "default_config" in name else DDPG + elif name in ["ddqn", "ddqn_default_config"]: + from skrl.agents.jax.dqn import DDQN, DDQN_DEFAULT_CONFIG + + component = DDQN_DEFAULT_CONFIG if "default_config" in name else DDQN + elif name in ["dqn", "dqn_default_config"]: + from skrl.agents.jax.dqn import DQN, DQN_DEFAULT_CONFIG + + component = DQN_DEFAULT_CONFIG if "default_config" in name else DQN + elif name in ["ppo", "ppo_default_config"]: + from skrl.agents.jax.ppo import PPO, PPO_DEFAULT_CONFIG + + component = PPO_DEFAULT_CONFIG if "default_config" in name else PPO + elif name in ["rpo", "rpo_default_config"]: + from skrl.agents.jax.rpo import RPO, RPO_DEFAULT_CONFIG + + component = RPO_DEFAULT_CONFIG if "default_config" in name else RPO + elif name in ["sac", "sac_default_config"]: + from skrl.agents.jax.sac import SAC, SAC_DEFAULT_CONFIG + + component = SAC_DEFAULT_CONFIG if "default_config" in name else SAC + elif name in ["td3", "td3_default_config"]: + from skrl.agents.jax.td3 import TD3, TD3_DEFAULT_CONFIG + + component = TD3_DEFAULT_CONFIG if "default_config" in name else TD3 + # multi-agent + elif name in ["ippo", "ippo_default_config"]: + from skrl.multi_agents.jax.ippo import IPPO, IPPO_DEFAULT_CONFIG + + component = IPPO_DEFAULT_CONFIG if "default_config" in name else IPPO + elif name in ["mappo", "mappo_default_config"]: + from skrl.multi_agents.jax.mappo import MAPPO, MAPPO_DEFAULT_CONFIG + + component = MAPPO_DEFAULT_CONFIG if "default_config" in name else MAPPO + # trainer + elif name == "sequentialtrainer": + from skrl.trainers.jax import SequentialTrainer as component + + if component is None: + raise ValueError(f"Unknown component '{name}' in runner cfg") + return component def _process_cfg(self, cfg: dict) -> dict: """Convert simple types to skrl classes/components @@ -106,6 +147,7 @@ def _process_cfg(self, cfg: dict) -> dict: "shared_state_preprocessor", "state_preprocessor", "value_preprocessor", + "amp_state_preprocessor", ] def reward_shaper_function(scale): @@ -150,119 +192,66 @@ def _generate_models( # override cfg cfg["models"]["separate"] = True # shared model is not supported in JAX - try: - agent_class = self._class(cfg["agent"]["class"]) - del cfg["agent"]["class"] - except KeyError: - agent_class = self._class("PPO") - logger.warning("No 'class' field defined in 'agent' cfg. 'PPO' will be used as default") + agent_class = cfg.get("agent", {}).get("class", "").lower() # instantiate models models = {} for agent_id in possible_agents: _cfg = copy.deepcopy(cfg) models[agent_id] = {} + models_cfg = _cfg.get("models") + if not models_cfg: + raise ValueError("No 'models' are defined in cfg") + # get separate (non-shared) configuration and remove 'separate' key + try: + separate = models_cfg["separate"] + del models_cfg["separate"] + except KeyError: + separate = True + logger.warning("No 'separate' field defined in 'models' cfg. Defining it as True by default") # non-shared models - if _cfg["models"]["separate"]: - # get instantiator function and remove 'class' field - try: - model_class = self._class(_cfg["models"]["policy"]["class"]) - del _cfg["models"]["policy"]["class"] - except KeyError: - model_class = self._class("GaussianMixin") - logger.warning( - "No 'class' field defined in 'models:policy' cfg. 'GaussianMixin' will be used as default" + if separate: + for role in models_cfg: + # get instantiator function and remove 'class' key + model_class = models_cfg[role].get("class") + if not model_class: + raise ValueError(f"No 'class' field defined in 'models:{role}' cfg") + del models_cfg[role]["class"] + model_class = self._component(model_class) + # get specific spaces according to agent/model cfg + observation_space = observation_spaces[agent_id] + if agent_class == "mappo" and role == "value": + observation_space = state_spaces[agent_id] + if agent_class == "amp" and role == "discriminator": + try: + observation_space = env.amp_observation_space + except Exception as e: + logger.warning( + "Unable to get AMP space via 'env.amp_observation_space'. Using 'env.observation_space' instead" + ) + # print model source + source = model_class( + observation_space=observation_space, + action_space=action_spaces[agent_id], + device=device, + **self._process_cfg(models_cfg[role]), + return_source=True, ) - # print model source - source = model_class( - observation_space=observation_spaces[agent_id], - action_space=action_spaces[agent_id], - device=device, - **self._process_cfg(_cfg["models"]["policy"]), - return_source=True, - ) - print("--------------------------------------------------\n") - print(source) - print("--------------------------------------------------") - # instantiate model - models[agent_id]["policy"] = model_class( - observation_space=observation_spaces[agent_id], - action_space=action_spaces[agent_id], - device=device, - **self._process_cfg(_cfg["models"]["policy"]), - ) - # get instantiator function and remove 'class' field - try: - model_class = self._class(_cfg["models"]["value"]["class"]) - del _cfg["models"]["value"]["class"] - except KeyError: - model_class = self._class("DeterministicMixin") - logger.warning( - "No 'class' field defined in 'models:value' cfg. 'DeterministicMixin' will be used as default" + print("==================================================") + print(f"Model (role): {role}") + print("==================================================\n") + print(source) + print("--------------------------------------------------") + # instantiate model + models[agent_id][role] = model_class( + observation_space=observation_space, + action_space=action_spaces[agent_id], + device=device, + **self._process_cfg(models_cfg[role]), ) - # print model source - source = model_class( - observation_space=(state_spaces if agent_class in [MAPPO] else observation_spaces)[agent_id], - action_space=action_spaces[agent_id], - device=device, - **self._process_cfg(_cfg["models"]["value"]), - return_source=True, - ) - print("--------------------------------------------------\n") - print(source) - print("--------------------------------------------------") - # instantiate model - models[agent_id]["value"] = model_class( - observation_space=(state_spaces if agent_class in [MAPPO] else observation_spaces)[agent_id], - action_space=action_spaces[agent_id], - device=device, - **self._process_cfg(_cfg["models"]["value"]), - ) # shared models else: - # remove 'class' field - try: - del _cfg["models"]["policy"]["class"] - except KeyError: - logger.warning( - "No 'class' field defined in 'models:policy' cfg. 'GaussianMixin' will be used as default" - ) - try: - del _cfg["models"]["value"]["class"] - except KeyError: - logger.warning( - "No 'class' field defined in 'models:value' cfg. 'DeterministicMixin' will be used as default" - ) - model_class = self._class("Shared") - # print model source - source = model_class( - observation_space=observation_spaces[agent_id], - action_space=action_spaces[agent_id], - device=device, - structure=None, - roles=["policy", "value"], - parameters=[ - self._process_cfg(_cfg["models"]["policy"]), - self._process_cfg(_cfg["models"]["value"]), - ], - return_source=True, - ) - print("--------------------------------------------------\n") - print(source) - print("--------------------------------------------------") - # instantiate model - models[agent_id]["policy"] = model_class( - observation_space=observation_spaces[agent_id], - action_space=action_spaces[agent_id], - device=device, - structure=None, - roles=["policy", "value"], - parameters=[ - self._process_cfg(_cfg["models"]["policy"]), - self._process_cfg(_cfg["models"]["value"]), - ], - ) - models[agent_id]["value"] = models[agent_id]["policy"] + raise NotImplementedError # initialize models' state dict for agent_id in possible_agents: @@ -293,6 +282,10 @@ def _generate_agent( observation_spaces = env.observation_spaces if multi_agent else {"agent": env.observation_space} action_spaces = env.action_spaces if multi_agent else {"agent": env.action_space} + agent_class = cfg.get("agent", {}).get("class", "").lower() + if not agent_class: + raise ValueError(f"No 'class' field defined in 'agent' cfg") + # check for memory configuration (backward compatibility) if not "memory" in cfg: logger.warning( @@ -301,10 +294,10 @@ def _generate_agent( cfg["memory"] = {"class": "RandomMemory", "memory_size": -1} # get memory class and remove 'class' field try: - memory_class = self._class(cfg["memory"]["class"]) + memory_class = self._component(cfg["memory"]["class"]) del cfg["memory"]["class"] except KeyError: - memory_class = self._class("RandomMemory") + memory_class = self._component("RandomMemory") logger.warning("No 'class' field defined in 'memory' cfg. 'RandomMemory' will be used as default") memories = {} # instantiate memory @@ -313,17 +306,10 @@ def _generate_agent( for agent_id in possible_agents: memories[agent_id] = memory_class(num_envs=num_envs, device=device, **self._process_cfg(cfg["memory"])) - # instantiate agent - try: - agent_class = self._class(cfg["agent"]["class"]) - del cfg["agent"]["class"] - except KeyError: - agent_class = self._class("PPO") - logger.warning("No 'class' field defined in 'agent' cfg. 'PPO' will be used as default") # single-agent configuration and instantiation - if agent_class in [PPO]: + if agent_class in ["ppo"]: agent_id = possible_agents[0] - agent_cfg = PPO_DEFAULT_CONFIG.copy() + agent_cfg = self._component(f"{agent_class}_DEFAULT_CONFIG").copy() agent_cfg.update(self._process_cfg(cfg["agent"])) agent_cfg["state_preprocessor_kwargs"].update({"size": observation_spaces[agent_id], "device": device}) agent_cfg["value_preprocessor_kwargs"].update({"size": 1, "device": device}) @@ -334,8 +320,8 @@ def _generate_agent( "action_space": action_spaces[agent_id], } # multi-agent configuration and instantiation - elif agent_class in [IPPO]: - agent_cfg = IPPO_DEFAULT_CONFIG.copy() + elif agent_class in ["ippo"]: + agent_cfg = self._component(f"{agent_class}_DEFAULT_CONFIG").copy() agent_cfg.update(self._process_cfg(cfg["agent"])) agent_cfg["state_preprocessor_kwargs"].update( {agent_id: {"size": observation_spaces[agent_id], "device": device} for agent_id in possible_agents} @@ -348,8 +334,8 @@ def _generate_agent( "action_spaces": action_spaces, "possible_agents": possible_agents, } - elif agent_class in [MAPPO]: - agent_cfg = MAPPO_DEFAULT_CONFIG.copy() + elif agent_class in ["mappo"]: + agent_cfg = self._component(f"{agent_class}_DEFAULT_CONFIG").copy() agent_cfg.update(self._process_cfg(cfg["agent"])) agent_cfg["state_preprocessor_kwargs"].update( {agent_id: {"size": observation_spaces[agent_id], "device": device} for agent_id in possible_agents} @@ -366,7 +352,7 @@ def _generate_agent( "shared_observation_spaces": state_spaces, "possible_agents": possible_agents, } - return agent_class(cfg=agent_cfg, device=device, **agent_kwargs) + return self._component(agent_class)(cfg=agent_cfg, device=device, **agent_kwargs) def _generate_trainer( self, env: Union[Wrapper, MultiAgentEnvWrapper], cfg: Mapping[str, Any], agent: Agent @@ -381,10 +367,10 @@ def _generate_trainer( """ # get trainer class and remove 'class' field try: - trainer_class = self._class(cfg["trainer"]["class"]) + trainer_class = self._component(cfg["trainer"]["class"]) del cfg["trainer"]["class"] except KeyError: - trainer_class = self._class("SequentialTrainer") + trainer_class = self._component("SequentialTrainer") logger.warning("No 'class' field defined in 'trainer' cfg. 'SequentialTrainer' will be used as default") # instantiate trainer return trainer_class(env=env, agents=agent, cfg=cfg["trainer"])