Skip to content

Commit

Permalink
Get JAX device from string
Browse files Browse the repository at this point in the history
  • Loading branch information
Toni-SM committed Jun 22, 2024
1 parent 4a9bb1e commit 1ad585e
Show file tree
Hide file tree
Showing 7 changed files with 47 additions and 16 deletions.
5 changes: 4 additions & 1 deletion skrl/agents/jax/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,10 @@ def __init__(self,
if device is None:
self.device = jax.devices()[0]
else:
self.device = device if isinstance(device, jax.Device) else jax.devices(device)[0]
self.device = device
if type(device) == str:
device_type, device_index = f"{device}:0".split(':')[:2]
self.device = jax.devices(device_type)[int(device_index)]

if type(memory) is list:
self.memory = memory[0]
Expand Down
33 changes: 23 additions & 10 deletions skrl/envs/wrappers/jax/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,19 @@ def __init__(self, env: Any) -> None:
self._env = env

# device (faster than @property)
self.device = jax.devices()[0]
self.device = None
if hasattr(self._env, "device"):
try:
self.device = jax.devices(self._env.device.split(':')[0] if type(self._env.device) == str else self._env.device.type)[0]
except RuntimeError:
pass
if type(self._env.device) == str:
device_type, device_index = f"{self._env.device}:0".split(':')[:2]
try:
self.device = jax.devices(device_type)[int(device_index)]
except RuntimeError:
self.device = None
else:
self.device = self._env.device
if self.device is None:
self.device = jax.devices()[0]

# spaces
try:
self._action_space = self._env.single_action_space
Expand Down Expand Up @@ -135,12 +142,18 @@ def __init__(self, env: Any) -> None:
self._env = env

# device (faster than @property)
self.device = jax.devices()[0]
self.device = None
if hasattr(self._env, "device"):
try:
self.device = jax.devices(self._env.device.split(':')[0] if type(self._env.device) == str else self._env.device.type)[0]
except RuntimeError:
pass
if type(self._env.device) == str:
device_type, device_index = f"{self._env.device}:0".split(':')[:2]
try:
self.device = jax.devices(device_type)[int(device_index)]
except RuntimeError:
self.device = None
else:
self.device = self._env.device
if self.device is None:
self.device = jax.devices()[0]

self.possible_agents = []

Expand Down
5 changes: 4 additions & 1 deletion skrl/memories/jax/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,10 @@ def __init__(self,
if device is None:
self.device = jax.devices()[0]
else:
self.device = device if isinstance(device, jax.Device) else jax.devices(device)[0]
self.device = device
if type(device) == str:
device_type, device_index = f"{device}:0".split(':')[:2]
self.device = jax.devices(device_type)[int(device_index)]

# internal variables
self.filled = False
Expand Down
5 changes: 4 additions & 1 deletion skrl/models/jax/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,10 @@ def __call__(self, inputs, role):
if device is None:
self.device = jax.devices()[0]
else:
self.device = device if isinstance(device, jax.Device) else jax.devices(device)[0]
self.device = device
if type(device) == str:
device_type, device_index = f"{device}:0".split(':')[:2]
self.device = jax.devices(device_type)[int(device_index)]

self.observation_space = observation_space
self.action_space = action_space
Expand Down
5 changes: 4 additions & 1 deletion skrl/multi_agents/jax/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,10 @@ def __init__(self,
if device is None:
self.device = jax.devices()[0]
else:
self.device = device if isinstance(device, jax.Device) else jax.devices(device)[0]
self.device = device
if type(device) == str:
device_type, device_index = f"{device}:0".split(':')[:2]
self.device = jax.devices(device_type)[int(device_index)]

# convert the models to their respective device
for _models in self.models.values():
Expand Down
5 changes: 4 additions & 1 deletion skrl/resources/noises/jax/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,10 @@ def sample(self, size):
if device is None:
self.device = jax.devices()[0]
else:
self.device = device if isinstance(device, jax.Device) else jax.devices(device)[0]
self.device = device
if type(device) == str:
device_type, device_index = f"{device}:0".split(':')[:2]
self.device = jax.devices(device_type)[int(device_index)]

def sample_like(self, tensor: Union[np.ndarray, jax.Array]) -> Union[np.ndarray, jax.Array]:
"""Sample a noise with the same size (shape) as the input tensor
Expand Down
5 changes: 4 additions & 1 deletion skrl/resources/preprocessors/jax/running_standard_scaler.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,10 @@ def __init__(self,
if device is None:
self.device = jax.devices()[0]
else:
self.device = device if isinstance(device, jax.Device) else jax.devices(device)[0]
self.device = device
if type(device) == str:
device_type, device_index = f"{device}:0".split(':')[:2]
self.device = jax.devices(device_type)[int(device_index)]

size = self._get_space_size(size)

Expand Down

0 comments on commit 1ad585e

Please sign in to comment.