Skip to content

Commit

Permalink
Update model instantiators test in jax
Browse files Browse the repository at this point in the history
  • Loading branch information
Toni-SM committed Dec 1, 2024
1 parent 510a633 commit 9dd3b3b
Showing 1 changed file with 119 additions and 88 deletions.
207 changes: 119 additions & 88 deletions tests/jax/test_jax_model_instantiators.py
Original file line number Diff line number Diff line change
@@ -1,104 +1,135 @@
import hypothesis
import hypothesis.strategies as st
import pytest

import gymnasium as gym
import yaml
from gymnasium import spaces

import jax
import jax.numpy as jnp
import numpy as np
from skrl.utils.model_instantiators.jax import categorical_model, deterministic_model, gaussian_model
from skrl.utils.spaces.jax import flatten_tensorized_space, sample_space

from skrl.utils.model_instantiators.jax import Shape, categorical_model, deterministic_model, gaussian_model

NETWORK_SPEC_OBSERVATION = {
spaces.Box: (
r"""
network:
- name: net
input: STATES
layers: [32, 32, 32]
activations: elu
""",
spaces.Box(low=-1, high=1, shape=(2,)),
),
spaces.Discrete: r"""
network:
- name: net
input: STATES
layers: [32, 32, 32]
activations: elu
""",
spaces.MultiDiscrete: r"""
network:
- name: net
input: STATES
layers: [32, 32, 32]
activations: elu
""",
spaces.Tuple: (
r"""
network:
- name: net_0
input: STATES[0]
layers: [32, 32, 32]
activations: elu
- name: net_1
input: STATES[1]
layers: [32, 32, 32]
activations: elu
- name: net
input: net_0 + net_1
layers: [32, 32, 32]
activations: elu
""",
spaces.Tuple((spaces.Box(low=-1, high=1, shape=(2,)), spaces.Box(low=-1, high=1, shape=(3,)))),
),
spaces.Dict: (
r"""
network:
- name: net_0
input: STATES["0"]
layers: [32, 32, 32]
activations: elu
- name: net_1
input: STATES["1"]
layers: [32, 32, 32]
activations: elu
- name: net
input: net_0 + net_1
layers: [32, 32, 32]
activations: elu
""",
spaces.Dict({"0": spaces.Box(low=-1, high=1, shape=(2,)), "1": spaces.Box(low=-1, high=1, shape=(3,))}),
),
}


@hypothesis.given(
observation_space_size=st.integers(min_value=1, max_value=10),
action_space_size=st.integers(min_value=1, max_value=10),
)
@hypothesis.settings(suppress_health_check=[hypothesis.HealthCheck.function_scoped_fixture], deadline=None)
@pytest.mark.parametrize("device", [None, "cpu", "cuda:0"])
def test_categorical_model(capsys, observation_space_size, action_space_size, device):
observation_space = gym.spaces.Box(np.array([-1] * observation_space_size), np.array([1] * observation_space_size))
action_space = gym.spaces.Discrete(action_space_size)
# TODO: randomize all parameters
model = categorical_model(
observation_space=observation_space,
action_space=action_space,
device=device,
unnormalized_log_prob=True,
input_shape=Shape.STATES,
hiddens=[256, 256],
hidden_activation=["relu", "relu"],
output_shape=Shape.ACTIONS,
output_activation=None,
)
model.init_state_dict("model")
def test_categorical_model(capsys, device):
# observation
action_space = spaces.Discrete(2)
for observation_space_type in [spaces.Box, spaces.Tuple, spaces.Dict]:
observation_space = NETWORK_SPEC_OBSERVATION[observation_space_type][1]
model = categorical_model(
observation_space=observation_space,
action_space=action_space,
device=device,
unnormalized_log_prob=True,
network=yaml.safe_load(NETWORK_SPEC_OBSERVATION[observation_space_type][0])["network"],
output="ACTIONS",
)
model.init_state_dict("model")

with jax.default_device(model.device):
observations = jnp.ones((10, model.num_observations))
output = model.act({"states": observations})
assert output[0].shape == (10, 1)
output = model.act({"states": flatten_tensorized_space(sample_space(observation_space, 10, "jax", device))})
assert output[0].shape == (10, 1)


@hypothesis.given(
observation_space_size=st.integers(min_value=1, max_value=10),
action_space_size=st.integers(min_value=1, max_value=10),
)
@hypothesis.settings(suppress_health_check=[hypothesis.HealthCheck.function_scoped_fixture], deadline=None)
@pytest.mark.parametrize("device", [None, "cpu", "cuda:0"])
def test_deterministic_model(capsys, observation_space_size, action_space_size, device):
observation_space = gym.spaces.Box(np.array([-1] * observation_space_size), np.array([1] * observation_space_size))
action_space = gym.spaces.Box(np.array([-1] * action_space_size), np.array([1] * action_space_size))
# TODO: randomize all parameters
model = deterministic_model(
observation_space=observation_space,
action_space=action_space,
device=device,
clip_actions=False,
input_shape=Shape.STATES,
hiddens=[256, 256],
hidden_activation=["relu", "relu"],
output_shape=Shape.ACTIONS,
output_activation=None,
output_scale=1,
)
model.init_state_dict("model")
def test_deterministic_model(capsys, device):
# observation
action_space = spaces.Box(low=-1, high=1, shape=(2,))
for observation_space_type in [spaces.Box, spaces.Tuple, spaces.Dict]:
observation_space = NETWORK_SPEC_OBSERVATION[observation_space_type][1]
model = deterministic_model(
observation_space=observation_space,
action_space=action_space,
device=device,
clip_actions=False,
network=yaml.safe_load(NETWORK_SPEC_OBSERVATION[observation_space_type][0])["network"],
output="ACTIONS",
)
model.init_state_dict("model")

with jax.default_device(model.device):
observations = jnp.ones((10, model.num_observations))
output = model.act({"states": observations})
assert output[0].shape == (10, model.num_actions)
output = model.act({"states": flatten_tensorized_space(sample_space(observation_space, 10, "jax", device))})
assert output[0].shape == (10, 2)


@hypothesis.given(
observation_space_size=st.integers(min_value=1, max_value=10),
action_space_size=st.integers(min_value=1, max_value=10),
)
@hypothesis.settings(suppress_health_check=[hypothesis.HealthCheck.function_scoped_fixture], deadline=None)
@pytest.mark.parametrize("device", [None, "cpu", "cuda:0"])
def test_gaussian_model(capsys, observation_space_size, action_space_size, device):
observation_space = gym.spaces.Box(np.array([-1] * observation_space_size), np.array([1] * observation_space_size))
action_space = gym.spaces.Box(np.array([-1] * action_space_size), np.array([1] * action_space_size))
# TODO: randomize all parameters
model = gaussian_model(
observation_space=observation_space,
action_space=action_space,
device=device,
clip_actions=False,
clip_log_std=True,
min_log_std=-20,
max_log_std=2,
initial_log_std=0,
input_shape=Shape.STATES,
hiddens=[256, 256],
hidden_activation=["relu", "relu"],
output_shape=Shape.ACTIONS,
output_activation=None,
output_scale=1,
)
model.init_state_dict("model")
def test_gaussian_model(capsys, device):
# observation
action_space = spaces.Box(low=-1, high=1, shape=(2,))
for observation_space_type in [spaces.Box, spaces.Tuple, spaces.Dict]:
observation_space = NETWORK_SPEC_OBSERVATION[observation_space_type][1]
model = gaussian_model(
observation_space=observation_space,
action_space=action_space,
device=device,
clip_actions=False,
clip_log_std=True,
min_log_std=-20,
max_log_std=2,
initial_log_std=0,
network=yaml.safe_load(NETWORK_SPEC_OBSERVATION[observation_space_type][0])["network"],
output="ACTIONS",
)
model.init_state_dict("model")

with jax.default_device(model.device):
observations = jnp.ones((10, model.num_observations))
output = model.act({"states": observations})
assert output[0].shape == (10, model.num_actions)
output = model.act({"states": flatten_tensorized_space(sample_space(observation_space, 10, "jax", device))})
assert output[0].shape == (10, 2)

0 comments on commit 9dd3b3b

Please sign in to comment.