Skip to content

Commit

Permalink
Move the conversion of deprecated parameters to common file
Browse files Browse the repository at this point in the history
  • Loading branch information
Toni-SM committed Sep 5, 2024
1 parent 6a9734e commit d811651
Show file tree
Hide file tree
Showing 6 changed files with 51 additions and 65 deletions.
18 changes: 2 additions & 16 deletions skrl/utils/model_instantiators/torch/categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,9 @@
import torch
import torch.nn as nn # noqa

from skrl import logger
from skrl.models.torch import CategoricalMixin # noqa
from skrl.models.torch import Model
from skrl.utils.model_instantiators.torch.common import generate_containers
from skrl.utils.model_instantiators.torch.common import convert_deprecated_parameters, generate_containers


def categorical_model(observation_space: Optional[Union[int, Tuple[int], gym.Space, gymnasium.Space]] = None,
Expand Down Expand Up @@ -51,20 +50,7 @@ def categorical_model(observation_space: Optional[Union[int, Tuple[int], gym.Spa
"""
# compatibility with versions prior to 1.3.0
if not network and kwargs:
logger.warning(f'The following parameters ({", ".join(list(kwargs.keys()))}) are deprecated. '
"See https://skrl.readthedocs.io/en/latest/api/utils/model_instantiators.html")
network = [
{
"name": "net",
"input": str(kwargs.get("input_shape", "STATES")),
"layers": kwargs.get("hiddens", []),
"activations": kwargs.get("hidden_activation", []),
}
]
if kwargs.get("output_activation", None):
output = f'{kwargs["output_activation"]}({str(kwargs.get("output_shape", "ACTIONS"))})'
else:
output = f'{str(kwargs.get("output_shape", "ACTIONS"))}'
network, output = convert_deprecated_parameters(kwargs)

# parse model definition
containers, output = generate_containers(network, output, embed_output=True, indent=1)
Expand Down
30 changes: 30 additions & 0 deletions skrl/utils/model_instantiators/torch/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

import ast

from skrl import logger


def _get_activation_function(activation: Union[str, None], as_module: bool = True) -> Union[str, None]:
"""Get the activation function
Expand Down Expand Up @@ -269,3 +271,31 @@ def generate_containers(network: Sequence[Mapping[str, Any]], output: Union[str,
output = output.replace("PLACEHOLDER", container["name"] if embed_output else "output")
output = {"output": output, "modules": output_modules, "size": output_size}
return containers, output

def convert_deprecated_parameters(parameters: Mapping[str, Any]) -> Tuple[Mapping[str, Any], str]:
"""Function to convert deprecated parameters to network-output format
:param parameters: Deprecated parameters and their values.
:return: Network and output definitions
"""
logger.warning(f'The following parameters ({", ".join(list(parameters.keys()))}) are deprecated. '
"See https://skrl.readthedocs.io/en/latest/api/utils/model_instantiators.html")
# network definition
network = [
{
"name": "net",
"input": str(parameters.get("input_shape", "STATES")),
"layers": parameters.get("hiddens", []),
"activations": parameters.get("hidden_activation", []),
}
]
# output
output_scale = parameters.get("output_scale", 1.0)
scale_operation = f"{output_scale} * " if output_scale != 1.0 else ""
if parameters.get("output_activation", None):
output = f'{scale_operation}{parameters["output_activation"]}({str(parameters.get("output_shape", "ACTIONS"))})'
else:
output = f'{scale_operation}{str(parameters.get("output_shape", "ACTIONS"))}'

return network, output
18 changes: 2 additions & 16 deletions skrl/utils/model_instantiators/torch/deterministic.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,9 @@
import torch
import torch.nn as nn # noqa

from skrl import logger
from skrl.models.torch import DeterministicMixin # noqa
from skrl.models.torch import Model
from skrl.utils.model_instantiators.torch.common import generate_containers
from skrl.utils.model_instantiators.torch.common import convert_deprecated_parameters, generate_containers


def deterministic_model(observation_space: Optional[Union[int, Tuple[int], gym.Space, gymnasium.Space]] = None,
Expand Down Expand Up @@ -48,20 +47,7 @@ def deterministic_model(observation_space: Optional[Union[int, Tuple[int], gym.S
"""
# compatibility with versions prior to 1.3.0
if not network and kwargs:
logger.warning(f'The following parameters ({", ".join(list(kwargs.keys()))}) are deprecated. '
"See https://skrl.readthedocs.io/en/latest/api/utils/model_instantiators.html")
network = [
{
"name": "net",
"input": str(kwargs.get("input_shape", "STATES")),
"layers": kwargs.get("hiddens", []),
"activations": kwargs.get("hidden_activation", []),
}
]
if kwargs.get("output_activation", None):
output = f'{kwargs.get("output_scale", 1.0)} * {kwargs["output_activation"]}({str(kwargs.get("output_shape", "ACTIONS"))})'
else:
output = f'{kwargs.get("output_scale", 1.0)} * {str(kwargs.get("output_shape", "ACTIONS"))}'
network, output = convert_deprecated_parameters(kwargs)

# parse model definition
containers, output = generate_containers(network, output, embed_output=True, indent=1)
Expand Down
18 changes: 2 additions & 16 deletions skrl/utils/model_instantiators/torch/gaussian.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,9 @@
import torch
import torch.nn as nn # noqa

from skrl import logger
from skrl.models.torch import GaussianMixin # noqa
from skrl.models.torch import Model
from skrl.utils.model_instantiators.torch.common import generate_containers
from skrl.utils.model_instantiators.torch.common import convert_deprecated_parameters, generate_containers


def gaussian_model(observation_space: Optional[Union[int, Tuple[int], gym.Space, gymnasium.Space]] = None,
Expand Down Expand Up @@ -60,20 +59,7 @@ def gaussian_model(observation_space: Optional[Union[int, Tuple[int], gym.Space,
"""
# compatibility with versions prior to 1.3.0
if not network and kwargs:
logger.warning(f'The following parameters ({", ".join(list(kwargs.keys()))}) are deprecated. '
"See https://skrl.readthedocs.io/en/latest/api/utils/model_instantiators.html")
network = [
{
"name": "net",
"input": str(kwargs.get("input_shape", "STATES")),
"layers": kwargs.get("hiddens", []),
"activations": kwargs.get("hidden_activation", []),
}
]
if kwargs.get("output_activation", None):
output = f'{kwargs.get("output_scale", 1.0)} * {kwargs["output_activation"]}({str(kwargs.get("output_shape", "ACTIONS"))})'
else:
output = f'{kwargs.get("output_scale", 1.0)} * {str(kwargs.get("output_shape", "ACTIONS"))}'
network, output = convert_deprecated_parameters(kwargs)

# parse model definition
containers, output = generate_containers(network, output, embed_output=True, indent=1)
Expand Down
18 changes: 2 additions & 16 deletions skrl/utils/model_instantiators/torch/multivariate_gaussian.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,9 @@
import torch
import torch.nn as nn # noqa

from skrl import logger
from skrl.models.torch import MultivariateGaussianMixin # noqa
from skrl.models.torch import Model
from skrl.utils.model_instantiators.torch.common import generate_containers
from skrl.utils.model_instantiators.torch.common import convert_deprecated_parameters, generate_containers


def multivariate_gaussian_model(observation_space: Optional[Union[int, Tuple[int], gym.Space, gymnasium.Space]] = None,
Expand Down Expand Up @@ -60,20 +59,7 @@ def multivariate_gaussian_model(observation_space: Optional[Union[int, Tuple[int
"""
# compatibility with versions prior to 1.3.0
if not network and kwargs:
logger.warning(f'The following parameters ({", ".join(list(kwargs.keys()))}) are deprecated. '
"See https://skrl.readthedocs.io/en/latest/api/utils/model_instantiators.html")
network = [
{
"name": "net",
"input": str(kwargs.get("input_shape", "STATES")),
"layers": kwargs.get("hiddens", []),
"activations": kwargs.get("hidden_activation", []),
}
]
if kwargs.get("output_activation", None):
output = f'{kwargs.get("output_scale", 1.0)} * {kwargs["output_activation"]}({str(kwargs.get("output_shape", "ACTIONS"))})'
else:
output = f'{kwargs.get("output_scale", 1.0)} * {str(kwargs.get("output_shape", "ACTIONS"))}'
network, output = convert_deprecated_parameters(kwargs)

# parse model definition
containers, output = generate_containers(network, output, embed_output=True, indent=1)
Expand Down
14 changes: 13 additions & 1 deletion skrl/utils/model_instantiators/torch/shared.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

from skrl.models.torch import Model # noqa
from skrl.models.torch import DeterministicMixin, GaussianMixin # noqa
from skrl.utils.model_instantiators.torch.common import generate_containers
from skrl.utils.model_instantiators.torch.common import convert_deprecated_parameters, generate_containers


def shared_model(observation_space: Optional[Union[int, Tuple[int], gym.Space, gymnasium.Space]] = None,
Expand Down Expand Up @@ -47,6 +47,18 @@ def shared_model(observation_space: Optional[Union[int, Tuple[int], gym.Space, g
:return: Shared model instance or definition source
:rtype: Model
"""
# compatibility with versions prior to 1.3.0
if not "network" in parameters[0]:
parameters[0]["network"], parameters[0]["output"] = convert_deprecated_parameters(parameters[0])
parameters[1]["network"], parameters[1]["output"] = convert_deprecated_parameters(parameters[1])
# delete deprecated parameters
for parameter in ["input_shape", "hiddens", "hidden_activation", "output_shape", "output_activation", "output_scale"]:
if parameter in parameters[0]:
del parameters[0][parameter]
if parameter in parameters[1]:
del parameters[1][parameter]

# parse model definitions
containers_gaussian, output_gaussian = generate_containers(parameters[0]["network"], parameters[0]["output"], embed_output=False, indent=1)
containers_deterministic, output_deterministic = generate_containers(parameters[1]["network"], parameters[1]["output"], embed_output=False, indent=1)

Expand Down

0 comments on commit d811651

Please sign in to comment.