Skip to content

Commit

Permalink
Update runner to support the definition of any agent and its models in JAX
Browse files Browse the repository at this point in the history
  • Loading branch information
Toni-SM committed Jan 15, 2025
1 parent d2bdfab commit a43b4f8
Showing 1 changed file with 134 additions and 148 deletions.
282 changes: 134 additions & 148 deletions skrl/utils/runner/jax/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,12 @@

from skrl import logger
from skrl.agents.jax import Agent
from skrl.agents.jax.ppo import PPO, PPO_DEFAULT_CONFIG
from skrl.envs.wrappers.jax import MultiAgentEnvWrapper, Wrapper
from skrl.memories.jax import RandomMemory
from skrl.models.jax import Model
from skrl.multi_agents.jax.ippo import IPPO, IPPO_DEFAULT_CONFIG
from skrl.multi_agents.jax.mappo import MAPPO, MAPPO_DEFAULT_CONFIG
from skrl.resources.preprocessors.jax import RunningStandardScaler # noqa
from skrl.resources.schedulers.jax import KLAdaptiveLR # noqa
from skrl.trainers.jax import SequentialTrainer, Trainer
from skrl.trainers.jax import Trainer
from skrl.utils import set_seed
from skrl.utils.model_instantiators.jax import categorical_model, deterministic_model, gaussian_model


class Runner:
Expand All @@ -32,22 +27,6 @@ def __init__(self, env: Union[Wrapper, MultiAgentEnvWrapper], cfg: Mapping[str,
# set random seed
set_seed(self._cfg.get("seed", None))

self._class_mapping = {
# model
"gaussianmixin": gaussian_model,
"categoricalmixin": categorical_model,
"deterministicmixin": deterministic_model,
"shared": None,
# memory
"randommemory": RandomMemory,
# agent
"ppo": PPO,
"ippo": IPPO,
"mappo": MAPPO,
# trainer
"sequentialtrainer": SequentialTrainer,
}

self._cfg["agent"]["rewards_shaper"] = None # FIXME: avoid 'dictionary changed size during iteration'

self._models = self._generate_models(self._env, copy.deepcopy(self._cfg))
Expand Down Expand Up @@ -85,14 +64,76 @@ def load_cfg_from_yaml(path: str) -> dict:
logger.error(f"Loading yaml error: {e}")
return {}

def _class(self, value: str) -> Type:
"""Get skrl component class (e.g.: agent, trainer, etc..) from string identifier
def _component(self, name: str) -> Type:
"""Get skrl component (e.g.: agent, trainer, etc..) from string identifier
:return: skrl component class
:return: skrl component
"""
if value.lower() in self._class_mapping:
return self._class_mapping[value.lower()]
raise ValueError(f"Unknown class '{value}' in runner cfg")
component = None
name = name.lower()
# model
if name == "gaussianmixin":
from skrl.utils.model_instantiators.jax import gaussian_model as component
elif name == "categoricalmixin":
from skrl.utils.model_instantiators.jax import categorical_model as component
elif name == "deterministicmixin":
from skrl.utils.model_instantiators.jax import deterministic_model as component
# memory
elif name == "randommemory":
from skrl.memories.jax import RandomMemory as component
# agent
elif name in ["a2c", "a2c_default_config"]:
from skrl.agents.jax.a2c import A2C, A2C_DEFAULT_CONFIG

component = A2C_DEFAULT_CONFIG if "default_config" in name else A2C
elif name in ["cem", "cem_default_config"]:
from skrl.agents.jax.cem import CEM, CEM_DEFAULT_CONFIG

component = CEM_DEFAULT_CONFIG if "default_config" in name else CEM
elif name in ["ddpg", "ddpg_default_config"]:
from skrl.agents.jax.ddpg import DDPG, DDPG_DEFAULT_CONFIG

component = DDPG_DEFAULT_CONFIG if "default_config" in name else DDPG
elif name in ["ddqn", "ddqn_default_config"]:
from skrl.agents.jax.dqn import DDQN, DDQN_DEFAULT_CONFIG

component = DDQN_DEFAULT_CONFIG if "default_config" in name else DDQN
elif name in ["dqn", "dqn_default_config"]:
from skrl.agents.jax.dqn import DQN, DQN_DEFAULT_CONFIG

component = DQN_DEFAULT_CONFIG if "default_config" in name else DQN
elif name in ["ppo", "ppo_default_config"]:
from skrl.agents.jax.ppo import PPO, PPO_DEFAULT_CONFIG

component = PPO_DEFAULT_CONFIG if "default_config" in name else PPO
elif name in ["rpo", "rpo_default_config"]:
from skrl.agents.jax.rpo import RPO, RPO_DEFAULT_CONFIG

component = RPO_DEFAULT_CONFIG if "default_config" in name else RPO
elif name in ["sac", "sac_default_config"]:
from skrl.agents.jax.sac import SAC, SAC_DEFAULT_CONFIG

component = SAC_DEFAULT_CONFIG if "default_config" in name else SAC
elif name in ["td3", "td3_default_config"]:
from skrl.agents.jax.td3 import TD3, TD3_DEFAULT_CONFIG

component = TD3_DEFAULT_CONFIG if "default_config" in name else TD3
# multi-agent
elif name in ["ippo", "ippo_default_config"]:
from skrl.multi_agents.jax.ippo import IPPO, IPPO_DEFAULT_CONFIG

component = IPPO_DEFAULT_CONFIG if "default_config" in name else IPPO
elif name in ["mappo", "mappo_default_config"]:
from skrl.multi_agents.jax.mappo import MAPPO, MAPPO_DEFAULT_CONFIG

component = MAPPO_DEFAULT_CONFIG if "default_config" in name else MAPPO
# trainer
elif name == "sequentialtrainer":
from skrl.trainers.jax import SequentialTrainer as component

if component is None:
raise ValueError(f"Unknown component '{name}' in runner cfg")
return component

def _process_cfg(self, cfg: dict) -> dict:
"""Convert simple types to skrl classes/components
Expand All @@ -106,6 +147,7 @@ def _process_cfg(self, cfg: dict) -> dict:
"shared_state_preprocessor",
"state_preprocessor",
"value_preprocessor",
"amp_state_preprocessor",
]

def reward_shaper_function(scale):
Expand Down Expand Up @@ -150,119 +192,66 @@ def _generate_models(
# override cfg
cfg["models"]["separate"] = True # shared model is not supported in JAX

try:
agent_class = self._class(cfg["agent"]["class"])
del cfg["agent"]["class"]
except KeyError:
agent_class = self._class("PPO")
logger.warning("No 'class' field defined in 'agent' cfg. 'PPO' will be used as default")
agent_class = cfg.get("agent", {}).get("class", "").lower()

# instantiate models
models = {}
for agent_id in possible_agents:
_cfg = copy.deepcopy(cfg)
models[agent_id] = {}
models_cfg = _cfg.get("models")
if not models_cfg:
raise ValueError("No 'models' are defined in cfg")
# get separate (non-shared) configuration and remove 'separate' key
try:
separate = models_cfg["separate"]
del models_cfg["separate"]
except KeyError:
separate = True
logger.warning("No 'separate' field defined in 'models' cfg. Defining it as True by default")
# non-shared models
if _cfg["models"]["separate"]:
# get instantiator function and remove 'class' field
try:
model_class = self._class(_cfg["models"]["policy"]["class"])
del _cfg["models"]["policy"]["class"]
except KeyError:
model_class = self._class("GaussianMixin")
logger.warning(
"No 'class' field defined in 'models:policy' cfg. 'GaussianMixin' will be used as default"
if separate:
for role in models_cfg:
# get instantiator function and remove 'class' key
model_class = models_cfg[role].get("class")
if not model_class:
raise ValueError(f"No 'class' field defined in 'models:{role}' cfg")
del models_cfg[role]["class"]
model_class = self._component(model_class)
# get specific spaces according to agent/model cfg
observation_space = observation_spaces[agent_id]
if agent_class == "mappo" and role == "value":
observation_space = state_spaces[agent_id]
if agent_class == "amp" and role == "discriminator":
try:
observation_space = env.amp_observation_space
except Exception as e:
logger.warning(
"Unable to get AMP space via 'env.amp_observation_space'. Using 'env.observation_space' instead"
)
# print model source
source = model_class(
observation_space=observation_space,
action_space=action_spaces[agent_id],
device=device,
**self._process_cfg(models_cfg[role]),
return_source=True,
)
# print model source
source = model_class(
observation_space=observation_spaces[agent_id],
action_space=action_spaces[agent_id],
device=device,
**self._process_cfg(_cfg["models"]["policy"]),
return_source=True,
)
print("--------------------------------------------------\n")
print(source)
print("--------------------------------------------------")
# instantiate model
models[agent_id]["policy"] = model_class(
observation_space=observation_spaces[agent_id],
action_space=action_spaces[agent_id],
device=device,
**self._process_cfg(_cfg["models"]["policy"]),
)
# get instantiator function and remove 'class' field
try:
model_class = self._class(_cfg["models"]["value"]["class"])
del _cfg["models"]["value"]["class"]
except KeyError:
model_class = self._class("DeterministicMixin")
logger.warning(
"No 'class' field defined in 'models:value' cfg. 'DeterministicMixin' will be used as default"
print("==================================================")
print(f"Model (role): {role}")
print("==================================================\n")
print(source)
print("--------------------------------------------------")
# instantiate model
models[agent_id][role] = model_class(
observation_space=observation_space,
action_space=action_spaces[agent_id],
device=device,
**self._process_cfg(models_cfg[role]),
)
# print model source
source = model_class(
observation_space=(state_spaces if agent_class in [MAPPO] else observation_spaces)[agent_id],
action_space=action_spaces[agent_id],
device=device,
**self._process_cfg(_cfg["models"]["value"]),
return_source=True,
)
print("--------------------------------------------------\n")
print(source)
print("--------------------------------------------------")
# instantiate model
models[agent_id]["value"] = model_class(
observation_space=(state_spaces if agent_class in [MAPPO] else observation_spaces)[agent_id],
action_space=action_spaces[agent_id],
device=device,
**self._process_cfg(_cfg["models"]["value"]),
)
# shared models
else:
# remove 'class' field
try:
del _cfg["models"]["policy"]["class"]
except KeyError:
logger.warning(
"No 'class' field defined in 'models:policy' cfg. 'GaussianMixin' will be used as default"
)
try:
del _cfg["models"]["value"]["class"]
except KeyError:
logger.warning(
"No 'class' field defined in 'models:value' cfg. 'DeterministicMixin' will be used as default"
)
model_class = self._class("Shared")
# print model source
source = model_class(
observation_space=observation_spaces[agent_id],
action_space=action_spaces[agent_id],
device=device,
structure=None,
roles=["policy", "value"],
parameters=[
self._process_cfg(_cfg["models"]["policy"]),
self._process_cfg(_cfg["models"]["value"]),
],
return_source=True,
)
print("--------------------------------------------------\n")
print(source)
print("--------------------------------------------------")
# instantiate model
models[agent_id]["policy"] = model_class(
observation_space=observation_spaces[agent_id],
action_space=action_spaces[agent_id],
device=device,
structure=None,
roles=["policy", "value"],
parameters=[
self._process_cfg(_cfg["models"]["policy"]),
self._process_cfg(_cfg["models"]["value"]),
],
)
models[agent_id]["value"] = models[agent_id]["policy"]
raise NotImplementedError

# initialize models' state dict
for agent_id in possible_agents:
Expand Down Expand Up @@ -293,6 +282,10 @@ def _generate_agent(
observation_spaces = env.observation_spaces if multi_agent else {"agent": env.observation_space}
action_spaces = env.action_spaces if multi_agent else {"agent": env.action_space}

agent_class = cfg.get("agent", {}).get("class", "").lower()
if not agent_class:
raise ValueError(f"No 'class' field defined in 'agent' cfg")

# check for memory configuration (backward compatibility)
if not "memory" in cfg:
logger.warning(
Expand All @@ -301,10 +294,10 @@ def _generate_agent(
cfg["memory"] = {"class": "RandomMemory", "memory_size": -1}
# get memory class and remove 'class' field
try:
memory_class = self._class(cfg["memory"]["class"])
memory_class = self._component(cfg["memory"]["class"])
del cfg["memory"]["class"]
except KeyError:
memory_class = self._class("RandomMemory")
memory_class = self._component("RandomMemory")
logger.warning("No 'class' field defined in 'memory' cfg. 'RandomMemory' will be used as default")
memories = {}
# instantiate memory
Expand All @@ -313,17 +306,10 @@ def _generate_agent(
for agent_id in possible_agents:
memories[agent_id] = memory_class(num_envs=num_envs, device=device, **self._process_cfg(cfg["memory"]))

# instantiate agent
try:
agent_class = self._class(cfg["agent"]["class"])
del cfg["agent"]["class"]
except KeyError:
agent_class = self._class("PPO")
logger.warning("No 'class' field defined in 'agent' cfg. 'PPO' will be used as default")
# single-agent configuration and instantiation
if agent_class in [PPO]:
if agent_class in ["ppo"]:
agent_id = possible_agents[0]
agent_cfg = PPO_DEFAULT_CONFIG.copy()
agent_cfg = self._component(f"{agent_class}_DEFAULT_CONFIG").copy()
agent_cfg.update(self._process_cfg(cfg["agent"]))
agent_cfg["state_preprocessor_kwargs"].update({"size": observation_spaces[agent_id], "device": device})
agent_cfg["value_preprocessor_kwargs"].update({"size": 1, "device": device})
Expand All @@ -334,8 +320,8 @@ def _generate_agent(
"action_space": action_spaces[agent_id],
}
# multi-agent configuration and instantiation
elif agent_class in [IPPO]:
agent_cfg = IPPO_DEFAULT_CONFIG.copy()
elif agent_class in ["ippo"]:
agent_cfg = self._component(f"{agent_class}_DEFAULT_CONFIG").copy()
agent_cfg.update(self._process_cfg(cfg["agent"]))
agent_cfg["state_preprocessor_kwargs"].update(
{agent_id: {"size": observation_spaces[agent_id], "device": device} for agent_id in possible_agents}
Expand All @@ -348,8 +334,8 @@ def _generate_agent(
"action_spaces": action_spaces,
"possible_agents": possible_agents,
}
elif agent_class in [MAPPO]:
agent_cfg = MAPPO_DEFAULT_CONFIG.copy()
elif agent_class in ["mappo"]:
agent_cfg = self._component(f"{agent_class}_DEFAULT_CONFIG").copy()
agent_cfg.update(self._process_cfg(cfg["agent"]))
agent_cfg["state_preprocessor_kwargs"].update(
{agent_id: {"size": observation_spaces[agent_id], "device": device} for agent_id in possible_agents}
Expand All @@ -366,7 +352,7 @@ def _generate_agent(
"shared_observation_spaces": state_spaces,
"possible_agents": possible_agents,
}
return agent_class(cfg=agent_cfg, device=device, **agent_kwargs)
return self._component(agent_class)(cfg=agent_cfg, device=device, **agent_kwargs)

def _generate_trainer(
self, env: Union[Wrapper, MultiAgentEnvWrapper], cfg: Mapping[str, Any], agent: Agent
Expand All @@ -381,10 +367,10 @@ def _generate_trainer(
"""
# get trainer class and remove 'class' field
try:
trainer_class = self._class(cfg["trainer"]["class"])
trainer_class = self._component(cfg["trainer"]["class"])
del cfg["trainer"]["class"]
except KeyError:
trainer_class = self._class("SequentialTrainer")
trainer_class = self._component("SequentialTrainer")
logger.warning("No 'class' field defined in 'trainer' cfg. 'SequentialTrainer' will be used as default")
# instantiate trainer
return trainer_class(env=env, agents=agent, cfg=cfg["trainer"])
Expand Down

0 comments on commit a43b4f8

Please sign in to comment.