diff --git a/skrl/__init__.py b/skrl/__init__.py
index 7a027642..068206ee 100644
--- a/skrl/__init__.py
+++ b/skrl/__init__.py
@@ -72,7 +72,8 @@ def key(self) -> "jax.Array":
                 if isinstance(self._key, np.ndarray):
                     try:
                         import jax
-                        self._key = jax.random.PRNGKey(self._key[1])
+                        with jax.default_device(jax.devices("cpu")[0]):
+                            self._key = jax.random.PRNGKey(self._key[1])
                     except ImportError:
                         pass
                 return self._key
@@ -83,7 +84,8 @@ def key(self, value: Union[int, "jax.Array"]) -> None:
                     # don't import JAX if it has not been imported before
                     if "jax" in sys.modules:
                         import jax
-                        value = jax.random.PRNGKey(value)
+                        with jax.default_device(jax.devices("cpu")[0]):
+                            value = jax.random.PRNGKey(value)
                     else:
                         value = np.array([0, value], dtype=np.uint32)
                 self._key = value
diff --git a/skrl/agents/jax/a2c/a2c.py b/skrl/agents/jax/a2c/a2c.py
index b7d59308..378a84b7 100644
--- a/skrl/agents/jax/a2c/a2c.py
+++ b/skrl/agents/jax/a2c/a2c.py
@@ -251,8 +251,9 @@ def __init__(self,
                 else:
                     self._learning_rate = self._learning_rate_scheduler(self._learning_rate, **self.cfg["learning_rate_scheduler_kwargs"])
             # optimizer
-            self.policy_optimizer = Adam(model=self.policy, lr=self._learning_rate, grad_norm_clip=self._grad_norm_clip, scale=scale)
-            self.value_optimizer = Adam(model=self.value, lr=self._learning_rate, grad_norm_clip=self._grad_norm_clip, scale=scale)
+            with jax.default_device(self.device):
+                self.policy_optimizer = Adam(model=self.policy, lr=self._learning_rate, grad_norm_clip=self._grad_norm_clip, scale=scale)
+                self.value_optimizer = Adam(model=self.value, lr=self._learning_rate, grad_norm_clip=self._grad_norm_clip, scale=scale)
 
             self.checkpoint_modules["policy_optimizer"] = self.policy_optimizer
             self.checkpoint_modules["value_optimizer"] = self.value_optimizer
diff --git a/skrl/agents/jax/base.py b/skrl/agents/jax/base.py
index fbcb9fa5..f5de7d2a 100644
--- a/skrl/agents/jax/base.py
+++ b/skrl/agents/jax/base.py
@@ -477,7 +477,9 @@ def post_interaction(self, timestep: int, timesteps: int) -> None:
                 self.checkpoint_best_modules["timestep"] = timestep
                 self.checkpoint_best_modules["reward"] = reward
                 self.checkpoint_best_modules["saved"] = False
-                self.checkpoint_best_modules["modules"] = {k: copy.deepcopy(self._get_internal_value(v)) for k, v in self.checkpoint_modules.items()}
+                with jax.default_device(self.device):
+                    self.checkpoint_best_modules["modules"] = \
+                        {k: copy.deepcopy(self._get_internal_value(v)) for k, v in self.checkpoint_modules.items()}
 
             # write checkpoints
             self.write_checkpoint(timestep, timesteps)
diff --git a/skrl/agents/jax/cem/cem.py b/skrl/agents/jax/cem/cem.py
index a5606a70..a597cc0e 100644
--- a/skrl/agents/jax/cem/cem.py
+++ b/skrl/agents/jax/cem/cem.py
@@ -117,7 +117,8 @@ def __init__(self,
 
         # set up optimizer and learning rate scheduler
         if self.policy is not None:
-            self.optimizer = Adam(model=self.policy, lr=self._learning_rate)
+            with jax.default_device(self.device):
+                self.optimizer = Adam(model=self.policy, lr=self._learning_rate)
             if self._learning_rate_scheduler is not None:
                 self.scheduler = self._learning_rate_scheduler(self.optimizer, **self.cfg["learning_rate_scheduler_kwargs"])
 
diff --git a/skrl/agents/jax/ddpg/ddpg.py b/skrl/agents/jax/ddpg/ddpg.py
index 6b7f3a03..7c1a811e 100644
--- a/skrl/agents/jax/ddpg/ddpg.py
+++ b/skrl/agents/jax/ddpg/ddpg.py
@@ -188,8 +188,9 @@ def __init__(self,
 
         # set up optimizers and learning rate schedulers
         if self.policy is not None and self.critic is not None:
-            self.policy_optimizer = Adam(model=self.policy, lr=self._actor_learning_rate, grad_norm_clip=self._grad_norm_clip)
-            self.critic_optimizer = Adam(model=self.critic, lr=self._critic_learning_rate, grad_norm_clip=self._grad_norm_clip)
+            with jax.default_device(self.device):
+                self.policy_optimizer = Adam(model=self.policy, lr=self._actor_learning_rate, grad_norm_clip=self._grad_norm_clip)
+                self.critic_optimizer = Adam(model=self.critic, lr=self._critic_learning_rate, grad_norm_clip=self._grad_norm_clip)
             if self._learning_rate_scheduler is not None:
                 self.policy_scheduler = self._learning_rate_scheduler(self.policy_optimizer, **self.cfg["learning_rate_scheduler_kwargs"])
                 self.critic_scheduler = self._learning_rate_scheduler(self.critic_optimizer, **self.cfg["learning_rate_scheduler_kwargs"])
diff --git a/skrl/agents/jax/dqn/ddqn.py b/skrl/agents/jax/dqn/ddqn.py
index 6de9520d..5bd17fc4 100644
--- a/skrl/agents/jax/dqn/ddqn.py
+++ b/skrl/agents/jax/dqn/ddqn.py
@@ -161,7 +161,8 @@ def __init__(self,
 
         # set up optimizer and learning rate scheduler
         if self.q_network is not None:
-            self.optimizer = Adam(model=self.q_network, lr=self._learning_rate)
+            with jax.default_device(self.device):
+                self.optimizer = Adam(model=self.q_network, lr=self._learning_rate)
             if self._learning_rate_scheduler is not None:
                 self.scheduler = self._learning_rate_scheduler(self.optimizer, **self.cfg["learning_rate_scheduler_kwargs"])
 
diff --git a/skrl/agents/jax/dqn/dqn.py b/skrl/agents/jax/dqn/dqn.py
index 331532be..6bcfac72 100644
--- a/skrl/agents/jax/dqn/dqn.py
+++ b/skrl/agents/jax/dqn/dqn.py
@@ -158,7 +158,8 @@ def __init__(self,
 
         # set up optimizer and learning rate scheduler
         if self.q_network is not None:
-            self.optimizer = Adam(model=self.q_network, lr=self._learning_rate)
+            with jax.default_device(self.device):
+                self.optimizer = Adam(model=self.q_network, lr=self._learning_rate)
             if self._learning_rate_scheduler is not None:
                 self.scheduler = self._learning_rate_scheduler(self.optimizer, **self.cfg["learning_rate_scheduler_kwargs"])
 
diff --git a/skrl/agents/jax/ppo/ppo.py b/skrl/agents/jax/ppo/ppo.py
index 4352ccc5..f8a09dd1 100644
--- a/skrl/agents/jax/ppo/ppo.py
+++ b/skrl/agents/jax/ppo/ppo.py
@@ -277,8 +277,9 @@ def __init__(self,
                 else:
                     self._learning_rate = self._learning_rate_scheduler(self._learning_rate, **self.cfg["learning_rate_scheduler_kwargs"])
             # optimizer
-            self.policy_optimizer = Adam(model=self.policy, lr=self._learning_rate, grad_norm_clip=self._grad_norm_clip, scale=scale)
-            self.value_optimizer = Adam(model=self.value, lr=self._learning_rate, grad_norm_clip=self._grad_norm_clip, scale=scale)
+            with jax.default_device(self.device):
+                self.policy_optimizer = Adam(model=self.policy, lr=self._learning_rate, grad_norm_clip=self._grad_norm_clip, scale=scale)
+                self.value_optimizer = Adam(model=self.value, lr=self._learning_rate, grad_norm_clip=self._grad_norm_clip, scale=scale)
 
             self.checkpoint_modules["policy_optimizer"] = self.policy_optimizer
             self.checkpoint_modules["value_optimizer"] = self.value_optimizer
diff --git a/skrl/agents/jax/rpo/rpo.py b/skrl/agents/jax/rpo/rpo.py
index cbbc31cc..deba2dcd 100644
--- a/skrl/agents/jax/rpo/rpo.py
+++ b/skrl/agents/jax/rpo/rpo.py
@@ -281,8 +281,9 @@ def __init__(self,
                 else:
                     self._learning_rate = self._learning_rate_scheduler(self._learning_rate, **self.cfg["learning_rate_scheduler_kwargs"])
             # optimizer
-            self.policy_optimizer = Adam(model=self.policy, lr=self._learning_rate, grad_norm_clip=self._grad_norm_clip, scale=scale)
-            self.value_optimizer = Adam(model=self.value, lr=self._learning_rate, grad_norm_clip=self._grad_norm_clip, scale=scale)
+            with jax.default_device(self.device):
+                self.policy_optimizer = Adam(model=self.policy, lr=self._learning_rate, grad_norm_clip=self._grad_norm_clip, scale=scale)
+                self.value_optimizer = Adam(model=self.value, lr=self._learning_rate, grad_norm_clip=self._grad_norm_clip, scale=scale)
 
             self.checkpoint_modules["policy_optimizer"] = self.policy_optimizer
             self.checkpoint_modules["value_optimizer"] = self.value_optimizer
diff --git a/skrl/agents/jax/sac/sac.py b/skrl/agents/jax/sac/sac.py
index 0dbfc480..fd157616 100644
--- a/skrl/agents/jax/sac/sac.py
+++ b/skrl/agents/jax/sac/sac.py
@@ -219,16 +219,18 @@ class StateDict(flax.struct.PyTreeNode):
                 def value(self):
                     return self.state_dict.params["params"]
 
-            self.log_entropy_coefficient = _LogEntropyCoefficient(self._entropy_coefficient)
-            self.entropy_optimizer = Adam(model=self.log_entropy_coefficient, lr=self._entropy_learning_rate)
+            with jax.default_device(self.device):
+                self.log_entropy_coefficient = _LogEntropyCoefficient(self._entropy_coefficient)
+                self.entropy_optimizer = Adam(model=self.log_entropy_coefficient, lr=self._entropy_learning_rate)
 
             self.checkpoint_modules["entropy_optimizer"] = self.entropy_optimizer
 
         # set up optimizers and learning rate schedulers
         if self.policy is not None and self.critic_1 is not None and self.critic_2 is not None:
-            self.policy_optimizer = Adam(model=self.policy, lr=self._actor_learning_rate, grad_norm_clip=self._grad_norm_clip)
-            self.critic_1_optimizer = Adam(model=self.critic_1, lr=self._critic_learning_rate, grad_norm_clip=self._grad_norm_clip)
-            self.critic_2_optimizer = Adam(model=self.critic_2, lr=self._critic_learning_rate, grad_norm_clip=self._grad_norm_clip)
+            with jax.default_device(self.device):
+                self.policy_optimizer = Adam(model=self.policy, lr=self._actor_learning_rate, grad_norm_clip=self._grad_norm_clip)
+                self.critic_1_optimizer = Adam(model=self.critic_1, lr=self._critic_learning_rate, grad_norm_clip=self._grad_norm_clip)
+                self.critic_2_optimizer = Adam(model=self.critic_2, lr=self._critic_learning_rate, grad_norm_clip=self._grad_norm_clip)
             if self._learning_rate_scheduler is not None:
                 self.policy_scheduler = self._learning_rate_scheduler(self.policy_optimizer, **self.cfg["learning_rate_scheduler_kwargs"])
                 self.critic_1_scheduler = self._learning_rate_scheduler(self.critic_1_optimizer, **self.cfg["learning_rate_scheduler_kwargs"])
diff --git a/skrl/agents/jax/td3/td3.py b/skrl/agents/jax/td3/td3.py
index 34cd0846..3248aebf 100644
--- a/skrl/agents/jax/td3/td3.py
+++ b/skrl/agents/jax/td3/td3.py
@@ -219,9 +219,10 @@ def __init__(self,
 
         # set up optimizers and learning rate schedulers
         if self.policy is not None and self.critic_1 is not None and self.critic_2 is not None:
-            self.policy_optimizer = Adam(model=self.policy, lr=self._actor_learning_rate, grad_norm_clip=self._grad_norm_clip)
-            self.critic_1_optimizer = Adam(model=self.critic_1, lr=self._critic_learning_rate, grad_norm_clip=self._grad_norm_clip)
-            self.critic_2_optimizer = Adam(model=self.critic_2, lr=self._critic_learning_rate, grad_norm_clip=self._grad_norm_clip)
+            with jax.default_device(self.device):
+                self.policy_optimizer = Adam(model=self.policy, lr=self._actor_learning_rate, grad_norm_clip=self._grad_norm_clip)
+                self.critic_1_optimizer = Adam(model=self.critic_1, lr=self._critic_learning_rate, grad_norm_clip=self._grad_norm_clip)
+                self.critic_2_optimizer = Adam(model=self.critic_2, lr=self._critic_learning_rate, grad_norm_clip=self._grad_norm_clip)
             if self._learning_rate_scheduler is not None:
                 self.policy_scheduler = self._learning_rate_scheduler(self.policy_optimizer, **self.cfg["learning_rate_scheduler_kwargs"])
                 self.critic_1_scheduler = self._learning_rate_scheduler(self.critic_1_optimizer, **self.cfg["learning_rate_scheduler_kwargs"])
diff --git a/skrl/memories/jax/base.py b/skrl/memories/jax/base.py
index 90783836..c87af33b 100644
--- a/skrl/memories/jax/base.py
+++ b/skrl/memories/jax/base.py
@@ -239,18 +239,21 @@ def create_tensor(self,
         view_shape = (-1, *size) if keep_dimensions else (-1, size)
         # create tensor (_tensor_<name>) and add it to the internal storage
         if self._jax:
-            setattr(self, f"_tensor_{name}", jnp.zeros(tensor_shape, dtype=dtype))
+            with jax.default_device(self.device):
+                setattr(self, f"_tensor_{name}", jnp.zeros(tensor_shape, dtype=dtype))
         else:
             setattr(self, f"_tensor_{name}", np.zeros(tensor_shape, dtype=dtype))
         # update internal variables
         self.tensors[name] = getattr(self, f"_tensor_{name}")
-        self.tensors_view[name] = self.tensors[name].reshape(*view_shape)
+        with jax.default_device(self.device):
+            self.tensors_view[name] = self.tensors[name].reshape(*view_shape)
         self.tensors_keep_dimensions[name] = keep_dimensions
         # fill the tensors (float tensors) with NaN
         for name, tensor in self.tensors.items():
             if tensor.dtype == np.float32 or tensor.dtype == np.float64:
                 if self._jax:
-                    self.tensors[name] = _copyto(self.tensors[name], float("nan"))
+                    with jax.default_device(self.device):
+                        self.tensors[name] = _copyto(self.tensors[name], float("nan"))
                 else:
                     self.tensors[name].fill(float("nan"))
         # check views
diff --git a/skrl/models/jax/base.py b/skrl/models/jax/base.py
index 7d121e25..9a3738a5 100644
--- a/skrl/models/jax/base.py
+++ b/skrl/models/jax/base.py
@@ -122,7 +122,8 @@ def init_state_dict(self,
         if isinstance(inputs["states"], (int, np.int32, np.int64)):
             inputs["states"] = np.array(inputs["states"]).reshape(-1,1)
         # init internal state dict
-        self.state_dict = StateDict.create(apply_fn=self.apply, params=self.init(key, inputs, role))
+        with jax.default_device(self.device):
+            self.state_dict = StateDict.create(apply_fn=self.apply, params=self.init(key, inputs, role))
 
     def _get_space_size(self,
                         space: Union[int, Sequence[int], gym.Space, gymnasium.Space],
diff --git a/skrl/models/jax/categorical.py b/skrl/models/jax/categorical.py
index 95b0331a..c9aa11a2 100644
--- a/skrl/models/jax/categorical.py
+++ b/skrl/models/jax/categorical.py
@@ -122,9 +122,10 @@ def act(self,
         >>> print(actions.shape, log_prob.shape, outputs["net_output"].shape)
         (4096, 1) (4096, 1) (4096, 2)
         """
-        self._i += 1
-        subkey = jax.random.fold_in(self._key, self._i)
-        inputs["key"] = subkey
+        with jax.default_device(self.device):
+            self._i += 1
+            subkey = jax.random.fold_in(self._key, self._i)
+            inputs["key"] = subkey
 
         # map from states/observations to normalized probabilities or unnormalized log probabilities
         net_output, outputs = self.apply(self.state_dict.params if params is None else params, inputs, role)
diff --git a/skrl/models/jax/gaussian.py b/skrl/models/jax/gaussian.py
index fce2d90c..53245372 100644
--- a/skrl/models/jax/gaussian.py
+++ b/skrl/models/jax/gaussian.py
@@ -173,9 +173,10 @@ def act(self,
         >>> print(actions.shape, log_prob.shape, outputs["mean_actions"].shape)
         (4096, 8) (4096, 1) (4096, 8)
         """
-        self._i += 1
-        subkey = jax.random.fold_in(self._key, self._i)
-        inputs["key"] = subkey
+        with jax.default_device(self.device):
+            self._i += 1
+            subkey = jax.random.fold_in(self._key, self._i)
+            inputs["key"] = subkey
 
         # map from states/observations to mean actions and log standard deviations
         mean_actions, log_std, outputs = self.apply(self.state_dict.params if params is None else params, inputs, role)
diff --git a/skrl/models/jax/multicategorical.py b/skrl/models/jax/multicategorical.py
index 9ad1bb3e..b3a91e84 100644
--- a/skrl/models/jax/multicategorical.py
+++ b/skrl/models/jax/multicategorical.py
@@ -136,9 +136,10 @@ def act(self,
         >>> print(actions.shape, log_prob.shape, outputs["net_output"].shape)
         (4096, 2) (4096, 1) (4096, 5)
         """
-        self._i += 1
-        subkey = jax.random.fold_in(self._key, self._i)
-        inputs["key"] = subkey
+        with jax.default_device(self.device):
+            self._i += 1
+            subkey = jax.random.fold_in(self._key, self._i)
+            inputs["key"] = subkey
 
         # map from states/observations to normalized probabilities or unnormalized log probabilities
         net_output, outputs = self.apply(self.state_dict.params if params is None else params, inputs, role)
diff --git a/skrl/resources/preprocessors/jax/running_standard_scaler.py b/skrl/resources/preprocessors/jax/running_standard_scaler.py
index ace65f06..97d5eb63 100644
--- a/skrl/resources/preprocessors/jax/running_standard_scaler.py
+++ b/skrl/resources/preprocessors/jax/running_standard_scaler.py
@@ -103,9 +103,10 @@ def __init__(self,
         size = self._get_space_size(size)
 
         if self._jax:
-            self.running_mean = jnp.zeros(size, dtype=jnp.float32)
-            self.running_variance = jnp.ones(size, dtype=jnp.float32)
-            self.current_count = jnp.ones((1,), dtype=jnp.float32)
+            with jax.default_device(self.device):
+                self.running_mean = jnp.zeros(size, dtype=jnp.float32)
+                self.running_variance = jnp.ones(size, dtype=jnp.float32)
+                self.current_count = jnp.ones((1,), dtype=jnp.float32)
         else:
             self.running_mean = np.zeros(size, dtype=np.float32)
             self.running_variance = np.ones(size, dtype=np.float32)
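
Note (not part of the patch): every hunk above applies the same JAX mechanism, the jax.default_device context manager, which controls the device that arrays created inside its scope (PRNG keys, memory tensors, model and optimizer state) are committed to. The minimal sketch below illustrates that behavior; the variable names are illustrative, jax.default_device, jax.devices and jax.random.PRNGKey are the JAX APIs actually used in the patch, and the .devices() inspection assumes the unified jax.Array API (JAX >= 0.4).

    import jax
    import jax.numpy as jnp

    # without a device scope, new arrays are committed to JAX's global default device
    # (the first accelerator if one is present, otherwise the CPU)
    x = jnp.zeros((3,))

    # inside the scope, new arrays land on the selected device instead; this is how
    # the patch pins PRNG keys to the CPU and agent/model state to self.device
    with jax.default_device(jax.devices("cpu")[0]):
        key = jax.random.PRNGKey(0)

    print(x.devices())    # global default device, e.g. a GPU device if available
    print(key.devices())  # {CpuDevice(id=0)}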