Skip to content

Commit

Permalink
Catch invalid device index
Browse files Browse the repository at this point in the history
  • Loading branch information
Toni-SM committed Jun 22, 2024
1 parent 1ad585e commit 5559069
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions skrl/envs/wrappers/jax/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def __init__(self, env: Any) -> None:
device_type, device_index = f"{self._env.device}:0".split(':')[:2]
try:
self.device = jax.devices(device_type)[int(device_index)]
except RuntimeError:
except (RuntimeError, IndexError):
self.device = None
else:
self.device = self._env.device
Expand Down Expand Up @@ -148,7 +148,7 @@ def __init__(self, env: Any) -> None:
device_type, device_index = f"{self._env.device}:0".split(':')[:2]
try:
self.device = jax.devices(device_type)[int(device_index)]
except RuntimeError:
except (RuntimeError, IndexError):
self.device = None
else:
self.device = self._env.device
Expand Down

0 comments on commit 5559069

Please sign in to comment.