Perform JAX computation on the selected device
Toni-SM committed Jun 23, 2024
1 parent 5559069 commit 2d06c89
Showing 17 changed files with 60 additions and 38 deletions.
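Every hunk below applies the same pattern: operations that create JAX arrays or PRNG keys (optimizer construction, memory tensor allocation, model state-dict initialization, per-step subkey derivation, running-statistics buffers) are wrapped in a jax.default_device(...) context so the data is committed to the selected device rather than JAX's process-wide default. The one exception is the global seed key in skrl/__init__.py, which is deliberately created on the CPU device. A minimal sketch of the context-manager pattern (illustration only, not part of the diff; `device` stands in for the agent's self.device and is taken from jax.devices() here so the snippet is self-contained):

import jax
import jax.numpy as jnp

device = jax.devices()[0]  # stand-in for the selected device (GPU/TPU if available, otherwise CPU)

# jax.default_device is a context manager: arrays and PRNG keys created inside
# the block are placed on `device` instead of JAX's global default device.
with jax.default_device(device):
    key = jax.random.PRNGKey(0)                       # key lives on `device`
    buffer = jnp.zeros((16, 4), dtype=jnp.float32)    # storage allocated on `device`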
6 changes: 4 additions & 2 deletions skrl/__init__.py
@@ -72,7 +72,8 @@ def key(self) -> "jax.Array":
if isinstance(self._key, np.ndarray):
try:
import jax
self._key = jax.random.PRNGKey(self._key[1])
with jax.default_device(jax.devices("cpu")[0]):
self._key = jax.random.PRNGKey(self._key[1])
except ImportError:
pass
return self._key
@@ -83,7 +84,8 @@ def key(self, value: Union[int, "jax.Array"]) -> None:
# don't import JAX if it has not been imported before
if "jax" in sys.modules:
import jax
value = jax.random.PRNGKey(value)
with jax.default_device(jax.devices("cpu")[0]):
value = jax.random.PRNGKey(value)
else:
value = np.array([0, value], dtype=np.uint32)
self._key = value
5 changes: 3 additions & 2 deletions skrl/agents/jax/a2c/a2c.py
@@ -251,8 +251,9 @@ def __init__(self,
else:
self._learning_rate = self._learning_rate_scheduler(self._learning_rate, **self.cfg["learning_rate_scheduler_kwargs"])
# optimizer
self.policy_optimizer = Adam(model=self.policy, lr=self._learning_rate, grad_norm_clip=self._grad_norm_clip, scale=scale)
self.value_optimizer = Adam(model=self.value, lr=self._learning_rate, grad_norm_clip=self._grad_norm_clip, scale=scale)
with jax.default_device(self.device):
self.policy_optimizer = Adam(model=self.policy, lr=self._learning_rate, grad_norm_clip=self._grad_norm_clip, scale=scale)
self.value_optimizer = Adam(model=self.value, lr=self._learning_rate, grad_norm_clip=self._grad_norm_clip, scale=scale)

self.checkpoint_modules["policy_optimizer"] = self.policy_optimizer
self.checkpoint_modules["value_optimizer"] = self.value_optimizer
4 changes: 3 additions & 1 deletion skrl/agents/jax/base.py
@@ -477,7 +477,9 @@ def post_interaction(self, timestep: int, timesteps: int) -> None:
self.checkpoint_best_modules["timestep"] = timestep
self.checkpoint_best_modules["reward"] = reward
self.checkpoint_best_modules["saved"] = False
self.checkpoint_best_modules["modules"] = {k: copy.deepcopy(self._get_internal_value(v)) for k, v in self.checkpoint_modules.items()}
with jax.default_device(self.device):
self.checkpoint_best_modules["modules"] = \
{k: copy.deepcopy(self._get_internal_value(v)) for k, v in self.checkpoint_modules.items()}
# write checkpoints
self.write_checkpoint(timestep, timesteps)

3 changes: 2 additions & 1 deletion skrl/agents/jax/cem/cem.py
@@ -117,7 +117,8 @@ def __init__(self,

# set up optimizer and learning rate scheduler
if self.policy is not None:
self.optimizer = Adam(model=self.policy, lr=self._learning_rate)
with jax.default_device(self.device):
self.optimizer = Adam(model=self.policy, lr=self._learning_rate)
if self._learning_rate_scheduler is not None:
self.scheduler = self._learning_rate_scheduler(self.optimizer, **self.cfg["learning_rate_scheduler_kwargs"])

5 changes: 3 additions & 2 deletions skrl/agents/jax/ddpg/ddpg.py
@@ -188,8 +188,9 @@ def __init__(self,

# set up optimizers and learning rate schedulers
if self.policy is not None and self.critic is not None:
self.policy_optimizer = Adam(model=self.policy, lr=self._actor_learning_rate, grad_norm_clip=self._grad_norm_clip)
self.critic_optimizer = Adam(model=self.critic, lr=self._critic_learning_rate, grad_norm_clip=self._grad_norm_clip)
with jax.default_device(self.device):
self.policy_optimizer = Adam(model=self.policy, lr=self._actor_learning_rate, grad_norm_clip=self._grad_norm_clip)
self.critic_optimizer = Adam(model=self.critic, lr=self._critic_learning_rate, grad_norm_clip=self._grad_norm_clip)
if self._learning_rate_scheduler is not None:
self.policy_scheduler = self._learning_rate_scheduler(self.policy_optimizer, **self.cfg["learning_rate_scheduler_kwargs"])
self.critic_scheduler = self._learning_rate_scheduler(self.critic_optimizer, **self.cfg["learning_rate_scheduler_kwargs"])
3 changes: 2 additions & 1 deletion skrl/agents/jax/dqn/ddqn.py
@@ -161,7 +161,8 @@ def __init__(self,

# set up optimizer and learning rate scheduler
if self.q_network is not None:
self.optimizer = Adam(model=self.q_network, lr=self._learning_rate)
with jax.default_device(self.device):
self.optimizer = Adam(model=self.q_network, lr=self._learning_rate)
if self._learning_rate_scheduler is not None:
self.scheduler = self._learning_rate_scheduler(self.optimizer, **self.cfg["learning_rate_scheduler_kwargs"])

3 changes: 2 additions & 1 deletion skrl/agents/jax/dqn/dqn.py
@@ -158,7 +158,8 @@ def __init__(self,

# set up optimizer and learning rate scheduler
if self.q_network is not None:
self.optimizer = Adam(model=self.q_network, lr=self._learning_rate)
with jax.default_device(self.device):
self.optimizer = Adam(model=self.q_network, lr=self._learning_rate)
if self._learning_rate_scheduler is not None:
self.scheduler = self._learning_rate_scheduler(self.optimizer, **self.cfg["learning_rate_scheduler_kwargs"])

5 changes: 3 additions & 2 deletions skrl/agents/jax/ppo/ppo.py
@@ -277,8 +277,9 @@ def __init__(self,
else:
self._learning_rate = self._learning_rate_scheduler(self._learning_rate, **self.cfg["learning_rate_scheduler_kwargs"])
# optimizer
self.policy_optimizer = Adam(model=self.policy, lr=self._learning_rate, grad_norm_clip=self._grad_norm_clip, scale=scale)
self.value_optimizer = Adam(model=self.value, lr=self._learning_rate, grad_norm_clip=self._grad_norm_clip, scale=scale)
with jax.default_device(self.device):
self.policy_optimizer = Adam(model=self.policy, lr=self._learning_rate, grad_norm_clip=self._grad_norm_clip, scale=scale)
self.value_optimizer = Adam(model=self.value, lr=self._learning_rate, grad_norm_clip=self._grad_norm_clip, scale=scale)

self.checkpoint_modules["policy_optimizer"] = self.policy_optimizer
self.checkpoint_modules["value_optimizer"] = self.value_optimizer
5 changes: 3 additions & 2 deletions skrl/agents/jax/rpo/rpo.py
@@ -281,8 +281,9 @@ def __init__(self,
else:
self._learning_rate = self._learning_rate_scheduler(self._learning_rate, **self.cfg["learning_rate_scheduler_kwargs"])
# optimizer
self.policy_optimizer = Adam(model=self.policy, lr=self._learning_rate, grad_norm_clip=self._grad_norm_clip, scale=scale)
self.value_optimizer = Adam(model=self.value, lr=self._learning_rate, grad_norm_clip=self._grad_norm_clip, scale=scale)
with jax.default_device(self.device):
self.policy_optimizer = Adam(model=self.policy, lr=self._learning_rate, grad_norm_clip=self._grad_norm_clip, scale=scale)
self.value_optimizer = Adam(model=self.value, lr=self._learning_rate, grad_norm_clip=self._grad_norm_clip, scale=scale)

self.checkpoint_modules["policy_optimizer"] = self.policy_optimizer
self.checkpoint_modules["value_optimizer"] = self.value_optimizer
12 changes: 7 additions & 5 deletions skrl/agents/jax/sac/sac.py
@@ -219,16 +219,18 @@ class StateDict(flax.struct.PyTreeNode):
def value(self):
return self.state_dict.params["params"]

self.log_entropy_coefficient = _LogEntropyCoefficient(self._entropy_coefficient)
self.entropy_optimizer = Adam(model=self.log_entropy_coefficient, lr=self._entropy_learning_rate)
with jax.default_device(self.device):
self.log_entropy_coefficient = _LogEntropyCoefficient(self._entropy_coefficient)
self.entropy_optimizer = Adam(model=self.log_entropy_coefficient, lr=self._entropy_learning_rate)

self.checkpoint_modules["entropy_optimizer"] = self.entropy_optimizer

# set up optimizers and learning rate schedulers
if self.policy is not None and self.critic_1 is not None and self.critic_2 is not None:
self.policy_optimizer = Adam(model=self.policy, lr=self._actor_learning_rate, grad_norm_clip=self._grad_norm_clip)
self.critic_1_optimizer = Adam(model=self.critic_1, lr=self._critic_learning_rate, grad_norm_clip=self._grad_norm_clip)
self.critic_2_optimizer = Adam(model=self.critic_2, lr=self._critic_learning_rate, grad_norm_clip=self._grad_norm_clip)
with jax.default_device(self.device):
self.policy_optimizer = Adam(model=self.policy, lr=self._actor_learning_rate, grad_norm_clip=self._grad_norm_clip)
self.critic_1_optimizer = Adam(model=self.critic_1, lr=self._critic_learning_rate, grad_norm_clip=self._grad_norm_clip)
self.critic_2_optimizer = Adam(model=self.critic_2, lr=self._critic_learning_rate, grad_norm_clip=self._grad_norm_clip)
if self._learning_rate_scheduler is not None:
self.policy_scheduler = self._learning_rate_scheduler(self.policy_optimizer, **self.cfg["learning_rate_scheduler_kwargs"])
self.critic_1_scheduler = self._learning_rate_scheduler(self.critic_1_optimizer, **self.cfg["learning_rate_scheduler_kwargs"])
7 changes: 4 additions & 3 deletions skrl/agents/jax/td3/td3.py
@@ -219,9 +219,10 @@ def __init__(self,

# set up optimizers and learning rate schedulers
if self.policy is not None and self.critic_1 is not None and self.critic_2 is not None:
self.policy_optimizer = Adam(model=self.policy, lr=self._actor_learning_rate, grad_norm_clip=self._grad_norm_clip)
self.critic_1_optimizer = Adam(model=self.critic_1, lr=self._critic_learning_rate, grad_norm_clip=self._grad_norm_clip)
self.critic_2_optimizer = Adam(model=self.critic_2, lr=self._critic_learning_rate, grad_norm_clip=self._grad_norm_clip)
with jax.default_device(self.device):
self.policy_optimizer = Adam(model=self.policy, lr=self._actor_learning_rate, grad_norm_clip=self._grad_norm_clip)
self.critic_1_optimizer = Adam(model=self.critic_1, lr=self._critic_learning_rate, grad_norm_clip=self._grad_norm_clip)
self.critic_2_optimizer = Adam(model=self.critic_2, lr=self._critic_learning_rate, grad_norm_clip=self._grad_norm_clip)
if self._learning_rate_scheduler is not None:
self.policy_scheduler = self._learning_rate_scheduler(self.policy_optimizer, **self.cfg["learning_rate_scheduler_kwargs"])
self.critic_1_scheduler = self._learning_rate_scheduler(self.critic_1_optimizer, **self.cfg["learning_rate_scheduler_kwargs"])
9 changes: 6 additions & 3 deletions skrl/memories/jax/base.py
@@ -239,18 +239,21 @@ def create_tensor(self,
view_shape = (-1, *size) if keep_dimensions else (-1, size)
# create tensor (_tensor_<name>) and add it to the internal storage
if self._jax:
setattr(self, f"_tensor_{name}", jnp.zeros(tensor_shape, dtype=dtype))
with jax.default_device(self.device):
setattr(self, f"_tensor_{name}", jnp.zeros(tensor_shape, dtype=dtype))
else:
setattr(self, f"_tensor_{name}", np.zeros(tensor_shape, dtype=dtype))
# update internal variables
self.tensors[name] = getattr(self, f"_tensor_{name}")
self.tensors_view[name] = self.tensors[name].reshape(*view_shape)
with jax.default_device(self.device):
self.tensors_view[name] = self.tensors[name].reshape(*view_shape)
self.tensors_keep_dimensions[name] = keep_dimensions
# fill the tensors (float tensors) with NaN
for name, tensor in self.tensors.items():
if tensor.dtype == np.float32 or tensor.dtype == np.float64:
if self._jax:
self.tensors[name] = _copyto(self.tensors[name], float("nan"))
with jax.default_device(self.device):
self.tensors[name] = _copyto(self.tensors[name], float("nan"))
else:
self.tensors[name].fill(float("nan"))
# check views
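The memory changes above place the tensor allocation (jnp.zeros), the reshaped views, and the NaN fill (skrl's internal _copyto helper) under the device context. A hedged sketch of the underlying idea: JAX arrays are immutable, so the fill is a functional, out-of-place update (illustrative names and shapes only, not the library's code):

import jax
import jax.numpy as jnp

device = jax.devices()[0]  # stand-in for the memory's selected device

with jax.default_device(device):
    tensor = jnp.zeros((8, 3), dtype=jnp.float32)  # storage allocated on `device`
    view = tensor.reshape(-1, 3)                   # reshaped view follows the same placement
    tensor = tensor.at[:].set(float("nan"))        # out-of-place NaN fill; result stays on `device`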
3 changes: 2 additions & 1 deletion skrl/models/jax/base.py
@@ -122,7 +122,8 @@ def init_state_dict(self,
if isinstance(inputs["states"], (int, np.int32, np.int64)):
inputs["states"] = np.array(inputs["states"]).reshape(-1,1)
# init internal state dict
self.state_dict = StateDict.create(apply_fn=self.apply, params=self.init(key, inputs, role))
with jax.default_device(self.device):
self.state_dict = StateDict.create(apply_fn=self.apply, params=self.init(key, inputs, role))

def _get_space_size(self,
space: Union[int, Sequence[int], gym.Space, gymnasium.Space],
7 changes: 4 additions & 3 deletions skrl/models/jax/categorical.py
@@ -122,9 +122,10 @@ def act(self,
>>> print(actions.shape, log_prob.shape, outputs["net_output"].shape)
(4096, 1) (4096, 1) (4096, 2)
"""
self._i += 1
subkey = jax.random.fold_in(self._key, self._i)
inputs["key"] = subkey
with jax.default_device(self.device):
self._i += 1
subkey = jax.random.fold_in(self._key, self._i)
inputs["key"] = subkey

# map from states/observations to normalized probabilities or unnormalized log probabilities
net_output, outputs = self.apply(self.state_dict.params if params is None else params, inputs, role)
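This model (like the gaussian and multicategorical models below) derives a fresh subkey for every act() call by folding an incrementing counter into the base key, and that derivation now also happens under the device context. A brief sketch of the fold_in pattern (illustrative counter, logits, and device lookup; not skrl code):

import jax
import jax.numpy as jnp

device = jax.devices()[0]          # stand-in for the model's selected device
base_key = jax.random.PRNGKey(42)  # analogous to the model's self._key

with jax.default_device(device):
    for i in range(3):
        # fold_in deterministically derives a distinct subkey from the base key and
        # the counter, so each call samples fresh randomness on the selected device.
        subkey = jax.random.fold_in(base_key, i)
        sample = jax.random.categorical(subkey, jnp.zeros(5))  # uniform draw over 5 classes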
7 changes: 4 additions & 3 deletions skrl/models/jax/gaussian.py
@@ -173,9 +173,10 @@ def act(self,
>>> print(actions.shape, log_prob.shape, outputs["mean_actions"].shape)
(4096, 8) (4096, 1) (4096, 8)
"""
self._i += 1
subkey = jax.random.fold_in(self._key, self._i)
inputs["key"] = subkey
with jax.default_device(self.device):
self._i += 1
subkey = jax.random.fold_in(self._key, self._i)
inputs["key"] = subkey

# map from states/observations to mean actions and log standard deviations
mean_actions, log_std, outputs = self.apply(self.state_dict.params if params is None else params, inputs, role)
7 changes: 4 additions & 3 deletions skrl/models/jax/multicategorical.py
@@ -136,9 +136,10 @@ def act(self,
>>> print(actions.shape, log_prob.shape, outputs["net_output"].shape)
(4096, 2) (4096, 1) (4096, 5)
"""
self._i += 1
subkey = jax.random.fold_in(self._key, self._i)
inputs["key"] = subkey
with jax.default_device(self.device):
self._i += 1
subkey = jax.random.fold_in(self._key, self._i)
inputs["key"] = subkey

# map from states/observations to normalized probabilities or unnormalized log probabilities
net_output, outputs = self.apply(self.state_dict.params if params is None else params, inputs, role)
7 changes: 4 additions & 3 deletions skrl/resources/preprocessors/jax/running_standard_scaler.py
@@ -103,9 +103,10 @@ def __init__(self,
size = self._get_space_size(size)

if self._jax:
self.running_mean = jnp.zeros(size, dtype=jnp.float32)
self.running_variance = jnp.ones(size, dtype=jnp.float32)
self.current_count = jnp.ones((1,), dtype=jnp.float32)
with jax.default_device(self.device):
self.running_mean = jnp.zeros(size, dtype=jnp.float32)
self.running_variance = jnp.ones(size, dtype=jnp.float32)
self.current_count = jnp.ones((1,), dtype=jnp.float32)
else:
self.running_mean = np.zeros(size, dtype=np.float32)
self.running_variance = np.ones(size, dtype=np.float32)