Skip to content

Commit

Permalink
Improve model instantiators (#232)
Browse files Browse the repository at this point in the history
* Add multivariate Gaussian model to runner in docs

* Update shared model instantiator to allow specifying its structure

* Update torch runner to specify shared model structure

* Define model mixin from given structure

* Use spaces utils to initialize jax model state dictionary

* Add support for MultivariateGaussianMixin in shared models

* Update model instantiators test in torch

* Remove double argument definition

* Parse device in jax spaces utils

* Update model instantiators test in jax

* Update CHANGELOG
  • Loading branch information
Toni-SM authored Dec 2, 2024
1 parent bbe532d commit 6324e46
Show file tree
Hide file tree
Showing 7 changed files with 519 additions and 329 deletions.
5 changes: 4 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,19 +6,22 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).
### Added
- Utilities to operate on Gymnasium spaces (`Box`, `Discrete`, `MultiDiscrete`, `Tuple` and `Dict`)
- `parse_device` static method in ML framework configuration for JAX
- Model instantiator support for different shared model structures in PyTorch
- Support for other model types than Gaussian and Deterministic in runners

### Changed
- Call agent's `pre_interaction` method during evaluation
- Use spaces utilities to process states, observations and actions for all the library components
- Update model instantiators definitions to process supported fundamental and composite Gymnasium spaces
- Make flattened tensor storage in memory the default option (revert changed introduced in version 1.3.0)
- Drop support for PyTorch versions prior to 1.10 (the previous supported version was 1.9).
- Drop support for PyTorch versions prior to 1.10 (the previous supported version was 1.9)

### Changed (breaking changes: style)
- Format code using Black code formatter (it's ugly, yes, but it does its job)

### Fixed
- Moved the batch sampling inside gradient step loop for DQN, DDQN, DDPG (RNN), TD3 (RNN), SAC and SAC (RNN)
- Model state dictionary initialization for composite Gymnasium spaces in JAX

### Removed
- Remove OpenAI Gym (`gym`) from dependencies and source code. **skrl** continues to support gym environments,
Expand Down
11 changes: 9 additions & 2 deletions skrl/models/jax/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import numpy as np

from skrl import config
from skrl.utils.spaces.torch import compute_space_size, unflatten_tensorized_space
from skrl.utils.spaces.jax import compute_space_size, flatten_tensorized_space, sample_space, unflatten_tensorized_space


@jax.jit
Expand Down Expand Up @@ -132,7 +132,14 @@ def init_state_dict(
:type key: jax.Array, optional
"""
if not inputs:
inputs = {"states": self.observation_space.sample(), "taken_actions": self.action_space.sample()}
inputs = {
"states": flatten_tensorized_space(
sample_space(self.observation_space, backend="jax", device=self.device), self._jax
),
"taken_actions": flatten_tensorized_space(
sample_space(self.action_space, backend="jax", device=self.device), self._jax
),
}
if key is None:
key = config.jax.key
if isinstance(inputs["states"], (int, np.int32, np.int64)):
Expand Down
164 changes: 106 additions & 58 deletions skrl/utils/model_instantiators/torch/shared.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import torch.nn as nn # noqa

from skrl.models.torch import Model # noqa
from skrl.models.torch import DeterministicMixin, GaussianMixin # noqa
from skrl.models.torch import CategoricalMixin, DeterministicMixin, GaussianMixin, MultivariateGaussianMixin # noqa
from skrl.utils.model_instantiators.torch.common import convert_deprecated_parameters, generate_containers
from skrl.utils.spaces.torch import unflatten_tensorized_space # noqa

Expand All @@ -16,7 +16,7 @@ def shared_model(
observation_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None,
action_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None,
device: Optional[Union[str, torch.device]] = None,
structure: str = "",
structure: Sequence[str] = ["GaussianMixin", "DeterministicMixin"],
roles: Sequence[str] = [],
parameters: Sequence[Mapping[str, Any]] = [],
single_forward_pass: bool = True,
Expand All @@ -33,9 +33,8 @@ def shared_model(
:param device: Device on which a tensor/array is or will be allocated (default: ``None``).
If None, the device will be either ``"cuda"`` if available or ``"cpu"``
:type device: str or torch.device, optional
:param structure: Shared model structure (default: ``""``).
Note: this parameter is ignored for the moment
:type structure: str, optional
:param structure: Shared model structure (default: Gaussian-Deterministic).
:type structure: sequence of strings, optional
:param roles: Organized list of model roles (default: ``[]``)
:type roles: sequence of strings, optional
:param parameters: Organized list of model instantiator parameters (default: ``[]``)
Expand All @@ -49,6 +48,54 @@ def shared_model(
:return: Shared model instance or definition source
:rtype: Model
"""

def get_init(class_name, parameter, role):
if class_name.lower() == "categoricalmixin":
return f'CategoricalMixin.__init__(self, unnormalized_log_prob={parameter["unnormalized_log_prob"]}, role="{role}")'
elif class_name.lower() == "deterministicmixin":
return f'DeterministicMixin.__init__(self, clip_actions={parameter["clip_actions"]}, role="{role}")'
elif class_name.lower() == "gaussianmixin":
return f"""GaussianMixin.__init__(
self,
clip_actions={parameter["clip_actions"]},
clip_log_std={parameter["clip_log_std"]},
min_log_std={parameter["min_log_std"]},
max_log_std={parameter["max_log_std"]},
role="{role}",
)"""
elif class_name.lower() == "multivariategaussianmixin":
return f"""MultivariateGaussianMixin.__init__(
self,
clip_actions={parameter["clip_actions"]},
clip_log_std={parameter["clip_log_std"]},
min_log_std={parameter["min_log_std"]},
max_log_std={parameter["max_log_std"]},
role="{role}",
)"""
raise ValueError(f"Unknown class: {class_name}")

def get_return(class_name):
if class_name.lower() == "categoricalmixin":
return r"output, {}"
elif class_name.lower() == "deterministicmixin":
return r"output, {}"
elif class_name.lower() == "gaussianmixin":
return r"output, self.log_std_parameter, {}"
elif class_name.lower() == "multivariategaussianmixin":
return r"output, self.log_std_parameter, {}"
raise ValueError(f"Unknown class: {class_name}")

def get_extra(class_name, parameter, role, model):
if class_name.lower() == "categoricalmixin":
return ""
elif class_name.lower() == "deterministicmixin":
return ""
elif class_name.lower() == "gaussianmixin":
return f'self.log_std_parameter = nn.Parameter(torch.full(size=({model["output"]["size"]},), fill_value={float(parameter["initial_log_std"])}))'
elif class_name.lower() == "multivariategaussianmixin":
return f'self.log_std_parameter = nn.Parameter(torch.full(size=({model["output"]["size"]},), fill_value={float(parameter["initial_log_std"])}))'
raise ValueError(f"Unknown class: {class_name}")

# compatibility with versions prior to 1.3.0
if not "network" in parameters[0]:
parameters[0]["network"], parameters[0]["output"] = convert_deprecated_parameters(parameters[0])
Expand All @@ -67,18 +114,25 @@ def shared_model(
if parameter in parameters[1]:
del parameters[1][parameter]

# checking
assert (
len(structure) == len(roles) == len(parameters)
), f"Invalid configuration: structures ({len(structure)}), roles ({len(roles)}) and parameters ({len(parameters)}) have different lengths"

models = [{"class": item} for item in structure]

# parse model definitions
containers_gaussian, output_gaussian = generate_containers(
parameters[0]["network"], parameters[0]["output"], embed_output=False, indent=1
)
containers_deterministic, output_deterministic = generate_containers(
parameters[1]["network"], parameters[1]["output"], embed_output=False, indent=1
)
for i, model in enumerate(models):
model["forward"] = []
model["networks"] = []
model["containers"], model["output"] = generate_containers(
parameters[i]["network"], parameters[i]["output"], embed_output=False, indent=1
)

# network definitions
networks_common = []
forward_common = []
for container in containers_gaussian:
for container in models[0]["containers"]:
networks_common.append(f'self.{container["name"]}_container = {container["sequential"]}')
forward_common.append(f'{container["name"]} = self.{container["name"]}_container({container["input"]})')
forward_common.insert(
Expand All @@ -87,35 +141,37 @@ def shared_model(
forward_common.insert(0, 'states = unflatten_tensorized_space(self.observation_space, inputs.get("states"))')

# process output
networks_gaussian = []
forward_gaussian = []
if output_gaussian["modules"]:
networks_gaussian.append(f'self.{roles[0]}_layer = {output_gaussian["modules"][0]}')
forward_gaussian.append(f'output = self.{roles[0]}_layer({container["name"]})')
if output_gaussian["output"]:
forward_gaussian.append(f'output = {output_gaussian["output"]}')
if models[0]["output"]["modules"]:
models[0]["networks"].append(f'self.{roles[0]}_layer = {models[0]["output"]["modules"][0]}')
models[0]["forward"].append(f'output = self.{roles[0]}_layer({container["name"]})')
if models[0]["output"]["output"]:
models[0]["forward"].append(f'output = {models[0]["output"]["output"]}')
else:
forward_gaussian[-1] = forward_gaussian[-1].replace(f'{container["name"]} =', "output =", 1)
models[0]["forward"][-1] = models[0]["forward"][-1].replace(f'{container["name"]} =', "output =", 1)

networks_deterministic = []
forward_deterministic = []
if output_deterministic["modules"]:
networks_deterministic.append(f'self.{roles[1]}_layer = {output_deterministic["modules"][0]}')
forward_deterministic.append(
if models[1]["output"]["modules"]:
models[1]["networks"].append(f'self.{roles[1]}_layer = {models[1]["output"]["modules"][0]}')
models[1]["forward"].append(
f'output = self.{roles[1]}_layer({"shared_output" if single_forward_pass else container["name"]})'
)
if output_deterministic["output"]:
forward_deterministic.append(f'output = {output_deterministic["output"]}')
if models[1]["output"]["output"]:
models[1]["forward"].append(f'output = {models[1]["output"]["output"]}')
else:
forward_deterministic[-1] = forward_deterministic[-1].replace(f'{container["name"]} =', "output =", 1)
models[1]["forward"][-1] = models[1]["forward"][-1].replace(f'{container["name"]} =', "output =", 1)

# build substitutions and indent content
networks_common = textwrap.indent("\n".join(networks_common), prefix=" " * 8)[8:]
networks_gaussian = textwrap.indent("\n".join(networks_gaussian), prefix=" " * 8)[8:]
networks_deterministic = textwrap.indent("\n".join(networks_deterministic), prefix=" " * 8)[8:]
models[0]["networks"] = textwrap.indent("\n".join(models[0]["networks"]), prefix=" " * 8)[8:]
extra = get_extra(structure[0], parameters[0], roles[0], models[0])
if extra:
models[0]["networks"] += "\n" + textwrap.indent(extra, prefix=" " * 8)
models[1]["networks"] = textwrap.indent("\n".join(models[1]["networks"]), prefix=" " * 8)[8:]
extra = get_extra(structure[1], parameters[1], roles[1], models[1])
if extra:
models[1]["networks"] += "\n" + textwrap.indent(extra, prefix=" " * 8)

if single_forward_pass:
forward_deterministic = (
models[1]["forward"] = (
[
"if self._shared_output is None:",
]
Expand All @@ -126,59 +182,53 @@ def shared_model(
" shared_output = self._shared_output",
"self._shared_output = None",
]
+ forward_deterministic
+ models[1]["forward"]
)
forward_common.append(f'self._shared_output = {container["name"]}')
forward_common = textwrap.indent("\n".join(forward_common), prefix=" " * 12)[12:]
else:
forward_common = textwrap.indent("\n".join(forward_common), prefix=" " * 8)[8:]

forward_gaussian = textwrap.indent("\n".join(forward_gaussian), prefix=" " * 12)[12:]
forward_deterministic = textwrap.indent("\n".join(forward_deterministic), prefix=" " * 12)[12:]
models[0]["forward"] = textwrap.indent("\n".join(models[0]["forward"]), prefix=" " * 12)[12:]
models[1]["forward"] = textwrap.indent("\n".join(models[1]["forward"]), prefix=" " * 12)[12:]

template = f"""class GaussianDeterministicModel(GaussianMixin, DeterministicMixin, Model):
template = f"""class SharedModel({",".join(structure)}, Model):
def __init__(self, observation_space, action_space, device):
Model.__init__(self, observation_space, action_space, device)
GaussianMixin.__init__(self,
clip_actions={parameters[0]["clip_actions"]},
clip_log_std={parameters[0]["clip_log_std"]},
min_log_std={parameters[0]["min_log_std"]},
max_log_std={parameters[0]["max_log_std"]},
role="{roles[0]}")
DeterministicMixin.__init__(self, clip_actions={parameters[1]["clip_actions"]}, role="{roles[1]}")
{get_init(structure[0], parameters[0], roles[0])}
{get_init(structure[1], parameters[1], roles[1])}
{networks_common}
{networks_gaussian}
{networks_deterministic}
self.log_std_parameter = nn.Parameter({parameters[0]["initial_log_std"]} * torch.ones({output_gaussian["size"]}))
{models[0]["networks"]}
{models[1]["networks"]}
def act(self, inputs, role):
if role == "{roles[0]}":
return GaussianMixin.act(self, inputs, role)
return {structure[0]}.act(self, inputs, role)
elif role == "{roles[1]}":
return DeterministicMixin.act(self, inputs, role)
return {structure[1]}.act(self, inputs, role)
"""
if single_forward_pass:
template += f"""
def compute(self, inputs, role=""):
if role == "{roles[0]}":
{forward_common}
{forward_gaussian}
return output, self.log_std_parameter, {{}}
{models[0]["forward"]}
return {get_return(structure[0])}
elif role == "{roles[1]}":
{forward_deterministic}
return output, {{}}
{models[1]["forward"]}
return {get_return(structure[1])}
"""
else:
template += f"""
def compute(self, inputs, role=""):
{forward_common}
if role == "{roles[0]}":
{forward_gaussian}
return output, self.log_std_parameter, {{}}
{models[0]["forward"]}
return {get_return(structure[0])}
elif role == "{roles[1]}":
{forward_deterministic}
return output, {{}}
{models[1]["forward"]}
return {get_return(structure[1])}
"""
# return source
if return_source:
Expand All @@ -187,6 +237,4 @@ def compute(self, inputs, role=""):
# instantiate model
_locals = {}
exec(template, globals(), _locals)
return _locals["GaussianDeterministicModel"](
observation_space=observation_space, action_space=action_space, device=device
)
return _locals["SharedModel"](observation_space=observation_space, action_space=action_space, device=device)
15 changes: 12 additions & 3 deletions skrl/utils/runner/torch/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,13 @@
from skrl.resources.schedulers.torch import KLAdaptiveLR # noqa
from skrl.trainers.torch import SequentialTrainer, Trainer
from skrl.utils import set_seed
from skrl.utils.model_instantiators.torch import categorical_model, deterministic_model, gaussian_model, shared_model
from skrl.utils.model_instantiators.torch import (
categorical_model,
deterministic_model,
gaussian_model,
multivariate_gaussian_model,
shared_model,
)


class Runner:
Expand All @@ -37,6 +43,7 @@ def __init__(self, env: Union[Wrapper, MultiAgentEnvWrapper], cfg: Mapping[str,
"gaussianmixin": gaussian_model,
"categoricalmixin": categorical_model,
"deterministicmixin": deterministic_model,
"multivariategaussianmixin": multivariate_gaussian_model,
"shared": shared_model,
# memory
"randommemory": RandomMemory,
Expand Down Expand Up @@ -218,6 +225,8 @@ def _generate_models(
# shared models
else:
# remove 'class' field
policy_class_name = _cfg["models"]["policy"].get("class", "GaussianMixin")
value_class_name = _cfg["models"]["value"].get("class", "DeterministicMixin")
try:
del _cfg["models"]["policy"]["class"]
except KeyError:
Expand All @@ -236,7 +245,7 @@ def _generate_models(
observation_space=observation_spaces[agent_id],
action_space=action_spaces[agent_id],
device=device,
structure=None,
structure=[policy_class_name, value_class_name],
roles=["policy", "value"],
parameters=[
self._process_cfg(_cfg["models"]["policy"]),
Expand All @@ -252,7 +261,7 @@ def _generate_models(
observation_space=observation_spaces[agent_id],
action_space=action_spaces[agent_id],
device=device,
structure=None,
structure=[policy_class_name, value_class_name],
roles=["policy", "value"],
parameters=[
self._process_cfg(_cfg["models"]["policy"]),
Expand Down
Loading

0 comments on commit 6324e46

Please sign in to comment.