Skip to content

Commit

Permalink
Update shared model instantiator to allow specifying its structure
Browse files Browse the repository at this point in the history
  • Loading branch information
Toni-SM committed Nov 29, 2024
1 parent ece2ae1 commit 11588c8
Showing 1 changed file with 43 additions and 41 deletions.
84 changes: 43 additions & 41 deletions skrl/utils/model_instantiators/torch/shared.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import torch.nn as nn # noqa

from skrl.models.torch import Model # noqa
from skrl.models.torch import DeterministicMixin, GaussianMixin # noqa
from skrl.models.torch import CategoricalMixin, DeterministicMixin, GaussianMixin # noqa
from skrl.utils.model_instantiators.torch.common import convert_deprecated_parameters, generate_containers
from skrl.utils.spaces.torch import unflatten_tensorized_space # noqa

Expand All @@ -16,7 +16,7 @@ def shared_model(
observation_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None,
action_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None,
device: Optional[Union[str, torch.device]] = None,
structure: str = "",
structure: Sequence[str] = ["GaussianMixin", "DeterministicMixin"],
roles: Sequence[str] = [],
parameters: Sequence[Mapping[str, Any]] = [],
single_forward_pass: bool = True,
Expand All @@ -33,9 +33,8 @@ def shared_model(
:param device: Device on which a tensor/array is or will be allocated (default: ``None``).
If None, the device will be either ``"cuda"`` if available or ``"cpu"``
:type device: str or torch.device, optional
:param structure: Shared model structure (default: ``""``).
Note: this parameter is ignored for the moment
:type structure: str, optional
:param structure: Shared model structure (default: Gaussian-Deterministic).
:type structure: sequence of strings, optional
:param roles: Organized list of model roles (default: ``[]``)
:type roles: sequence of strings, optional
:param parameters: Organized list of model instantiator parameters (default: ``[]``)
Expand Down Expand Up @@ -67,18 +66,25 @@ def shared_model(
if parameter in parameters[1]:
del parameters[1][parameter]

# checking
assert (
len(structure) == len(roles) == len(parameters)
), f"Invalid configuration: structures ({len(structure)}), roles ({len(roles)}) and parameters ({len(parameters)}) have different lengths"

models = [{"class": item} for item in structure]

# parse model definitions
containers_gaussian, output_gaussian = generate_containers(
parameters[0]["network"], parameters[0]["output"], embed_output=False, indent=1
)
containers_deterministic, output_deterministic = generate_containers(
parameters[1]["network"], parameters[1]["output"], embed_output=False, indent=1
)
for i, model in enumerate(models):
model["forward"] = []
model["networks"] = []
model["containers"], model["output"] = generate_containers(
parameters[i]["network"], parameters[i]["output"], embed_output=False, indent=1
)

# network definitions
networks_common = []
forward_common = []
for container in containers_gaussian:
for container in models[0]["containers"]:
networks_common.append(f'self.{container["name"]}_container = {container["sequential"]}')
forward_common.append(f'{container["name"]} = self.{container["name"]}_container({container["input"]})')
forward_common.insert(
Expand All @@ -87,35 +93,31 @@ def shared_model(
forward_common.insert(0, 'states = unflatten_tensorized_space(self.observation_space, inputs.get("states"))')

# process output
networks_gaussian = []
forward_gaussian = []
if output_gaussian["modules"]:
networks_gaussian.append(f'self.{roles[0]}_layer = {output_gaussian["modules"][0]}')
forward_gaussian.append(f'output = self.{roles[0]}_layer({container["name"]})')
if output_gaussian["output"]:
forward_gaussian.append(f'output = {output_gaussian["output"]}')
if models[0]["output"]["modules"]:
models[0]["networks"].append(f'self.{roles[0]}_layer = {models[0]["output"]["modules"][0]}')
models[0]["forward"].append(f'output = self.{roles[0]}_layer({container["name"]})')
if models[0]["output"]["output"]:
models[0]["forward"].append(f'output = {models[0]["output"]["output"]}')
else:
forward_gaussian[-1] = forward_gaussian[-1].replace(f'{container["name"]} =', "output =", 1)
models[0]["forward"][-1] = models[0]["forward"][-1].replace(f'{container["name"]} =', "output =", 1)

networks_deterministic = []
forward_deterministic = []
if output_deterministic["modules"]:
networks_deterministic.append(f'self.{roles[1]}_layer = {output_deterministic["modules"][0]}')
forward_deterministic.append(
if models[1]["output"]["modules"]:
models[1]["networks"].append(f'self.{roles[1]}_layer = {models[1]["output"]["modules"][0]}')
models[1]["forward"].append(
f'output = self.{roles[1]}_layer({"shared_output" if single_forward_pass else container["name"]})'
)
if output_deterministic["output"]:
forward_deterministic.append(f'output = {output_deterministic["output"]}')
if models[1]["output"]["output"]:
models[1]["forward"].append(f'output = {models[1]["output"]["output"]}')
else:
forward_deterministic[-1] = forward_deterministic[-1].replace(f'{container["name"]} =', "output =", 1)
models[1]["forward"][-1] = models[1]["forward"][-1].replace(f'{container["name"]} =', "output =", 1)

# build substitutions and indent content
networks_common = textwrap.indent("\n".join(networks_common), prefix=" " * 8)[8:]
networks_gaussian = textwrap.indent("\n".join(networks_gaussian), prefix=" " * 8)[8:]
networks_deterministic = textwrap.indent("\n".join(networks_deterministic), prefix=" " * 8)[8:]
models[0]["networks"] = textwrap.indent("\n".join(models[0]["networks"]), prefix=" " * 8)[8:]
models[1]["networks"] = textwrap.indent("\n".join(models[1]["networks"]), prefix=" " * 8)[8:]

if single_forward_pass:
forward_deterministic = (
models[1]["forward"] = (
[
"if self._shared_output is None:",
]
Expand All @@ -126,15 +128,15 @@ def shared_model(
" shared_output = self._shared_output",
"self._shared_output = None",
]
+ forward_deterministic
+ models[1]["forward"]
)
forward_common.append(f'self._shared_output = {container["name"]}')
forward_common = textwrap.indent("\n".join(forward_common), prefix=" " * 12)[12:]
else:
forward_common = textwrap.indent("\n".join(forward_common), prefix=" " * 8)[8:]

forward_gaussian = textwrap.indent("\n".join(forward_gaussian), prefix=" " * 12)[12:]
forward_deterministic = textwrap.indent("\n".join(forward_deterministic), prefix=" " * 12)[12:]
models[0]["forward"] = textwrap.indent("\n".join(models[0]["forward"]), prefix=" " * 12)[12:]
models[1]["forward"] = textwrap.indent("\n".join(models[1]["forward"]), prefix=" " * 12)[12:]

template = f"""class GaussianDeterministicModel(GaussianMixin, DeterministicMixin, Model):
def __init__(self, observation_space, action_space, device):
Expand All @@ -148,9 +150,9 @@ def __init__(self, observation_space, action_space, device):
DeterministicMixin.__init__(self, clip_actions={parameters[1]["clip_actions"]}, role="{roles[1]}")
{networks_common}
{networks_gaussian}
{networks_deterministic}
self.log_std_parameter = nn.Parameter({parameters[0]["initial_log_std"]} * torch.ones({output_gaussian["size"]}))
{models[0]["networks"]}
{models[1]["networks"]}
self.log_std_parameter = nn.Parameter({parameters[0]["initial_log_std"]} * torch.ones({models[0]["output"]["size"]}))
def act(self, inputs, role):
if role == "{roles[0]}":
Expand All @@ -163,21 +165,21 @@ def act(self, inputs, role):
def compute(self, inputs, role=""):
if role == "{roles[0]}":
{forward_common}
{forward_gaussian}
{models[0]["forward"]}
return output, self.log_std_parameter, {{}}
elif role == "{roles[1]}":
{forward_deterministic}
{models[1]["forward"]}
return output, {{}}
"""
else:
template += f"""
def compute(self, inputs, role=""):
{forward_common}
if role == "{roles[0]}":
{forward_gaussian}
{models[0]["forward"]}
return output, self.log_std_parameter, {{}}
elif role == "{roles[1]}":
{forward_deterministic}
{models[1]["forward"]}
return output, {{}}
"""
# return source
Expand Down

0 comments on commit 11588c8

Please sign in to comment.