Skip to content

Commit

Permalink
Use ML framework configuration device parsing method to parse devices
Browse files Browse the repository at this point in the history
  • Loading branch information
Toni-SM committed Dec 5, 2024
1 parent a08bf91 commit cf05a50
Show file tree
Hide file tree
Showing 14 changed files with 31 additions and 90 deletions.
8 changes: 1 addition & 7 deletions skrl/agents/jax/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,13 +51,7 @@ def __init__(
self.action_space = action_space
self.cfg = cfg if cfg is not None else {}

if device is None:
self.device = jax.devices()[0]
else:
self.device = device
if type(device) == str:
device_type, device_index = f"{device}:0".split(":")[:2]
self.device = jax.devices(device_type)[int(device_index)]
self.device = config.jax.parse_device(device)

if type(memory) is list:
self.memory = memory[0]
Expand Down
5 changes: 2 additions & 3 deletions skrl/agents/torch/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,8 @@ def __init__(
self.observation_space = observation_space
self.action_space = action_space
self.cfg = cfg if cfg is not None else {}
self.device = (
torch.device("cuda:0" if torch.cuda.is_available() else "cpu") if device is None else torch.device(device)
)

self.device = config.torch.parse_device(device)

if type(memory) is list:
self.memory = memory[0]
Expand Down
28 changes: 6 additions & 22 deletions skrl/envs/wrappers/jax/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,18 +24,10 @@ def __init__(self, env: Any) -> None:
self._unwrapped = env

# device
self._device = None
if hasattr(self._unwrapped, "device"):
if type(self._unwrapped.device) == str:
device_type, device_index = f"{self._unwrapped.device}:0".split(":")[:2]
try:
self._device = jax.devices(device_type)[int(device_index)]
except (RuntimeError, IndexError):
self._device = None
else:
self._device = self._unwrapped.device
if self._device is None:
self._device = jax.devices()[0]
self._device = config.jax.parse_device(self._unwrapped.device)
else:
self._device = config.jax.parse_device(None)

def __getattr__(self, key: str) -> Any:
"""Get an attribute from the wrapped environment
Expand Down Expand Up @@ -172,18 +164,10 @@ def __init__(self, env: Any) -> None:
self._unwrapped = env

# device
self._device = None
if hasattr(self._unwrapped, "device"):
if type(self._unwrapped.device) == str:
device_type, device_index = f"{self._unwrapped.device}:0".split(":")[:2]
try:
self._device = jax.devices(device_type)[int(device_index)]
except (RuntimeError, IndexError):
self._device = None
else:
self._device = self._unwrapped.device
if self._device is None:
self._device = jax.devices()[0]
self._device = config.jax.parse_device(self._unwrapped.device)
else:
self._device = config.jax.parse_device(None)

def __getattr__(self, key: str) -> Any:
"""Get an attribute from the wrapped environment
Expand Down
10 changes: 6 additions & 4 deletions skrl/envs/wrappers/torch/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

import torch

from skrl import config


class Wrapper(object):
def __init__(self, env: Any) -> None:
Expand All @@ -20,9 +22,9 @@ def __init__(self, env: Any) -> None:

# device
if hasattr(self._unwrapped, "device"):
self._device = torch.device(self._unwrapped.device)
self._device = config.torch.parse_device(self._unwrapped.device)
else:
self._device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
self._device = config.torch.parse_device(None)

def __getattr__(self, key: str) -> Any:
"""Get an attribute from the wrapped environment
Expand Down Expand Up @@ -152,9 +154,9 @@ def __init__(self, env: Any) -> None:

# device
if hasattr(self._unwrapped, "device"):
self._device = torch.device(self._unwrapped.device)
self._device = config.torch.parse_device(self._unwrapped.device)
else:
self._device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
self._device = config.torch.parse_device(None)

def __getattr__(self, key: str) -> Any:
"""Get an attribute from the wrapped environment
Expand Down
8 changes: 1 addition & 7 deletions skrl/memories/jax/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,13 +70,7 @@ def __init__(

self.memory_size = memory_size
self.num_envs = num_envs
if device is None:
self.device = jax.devices()[0]
else:
self.device = device
if type(device) == str:
device_type, device_index = f"{device}:0".split(":")[:2]
self.device = jax.devices(device_type)[int(device_index)]
self.device = config.jax.parse_device(device)

# internal variables
self.filled = False
Expand Down
5 changes: 2 additions & 3 deletions skrl/memories/torch/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import torch
from torch.utils.data.sampler import BatchSampler

from skrl import config
from skrl.utils.spaces.torch import compute_space_size


Expand Down Expand Up @@ -50,9 +51,7 @@ def __init__(
"""
self.memory_size = memory_size
self.num_envs = num_envs
self.device = (
torch.device("cuda:0" if torch.cuda.is_available() else "cpu") if device is None else torch.device(device)
)
self.device = config.torch.parse_device(device)

# internal variables
self.filled = False
Expand Down
8 changes: 1 addition & 7 deletions skrl/models/jax/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,13 +93,7 @@ def __call__(self, inputs, role):
"""
self._jax = config.jax.backend == "jax"

if device is None:
self.device = jax.devices()[0]
else:
self.device = device
if type(device) == str:
device_type, device_index = f"{device}:0".split(":")[:2]
self.device = jax.devices(device_type)[int(device_index)]
self.device = config.jax.parse_device(device)

self.observation_space = observation_space
self.action_space = action_space
Expand Down
4 changes: 1 addition & 3 deletions skrl/models/torch/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,9 +56,7 @@ def act(self, inputs, role=""):
"""
super(Model, self).__init__()

self.device = (
torch.device("cuda:0" if torch.cuda.is_available() else "cpu") if device is None else torch.device(device)
)
self.device = config.torch.parse_device(device)

self.observation_space = observation_space
self.action_space = action_space
Expand Down
9 changes: 1 addition & 8 deletions skrl/multi_agents/jax/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,16 +55,9 @@ def __init__(
self.memories = memories
self.observation_spaces = observation_spaces
self.action_spaces = action_spaces

self.cfg = cfg if cfg is not None else {}

if device is None:
self.device = jax.devices()[0]
else:
self.device = device
if type(device) == str:
device_type, device_index = f"{device}:0".split(":")[:2]
self.device = jax.devices(device_type)[int(device_index)]
self.device = config.jax.parse_device(device)

# convert the models to their respective device
for _models in self.models.values():
Expand Down
6 changes: 2 additions & 4 deletions skrl/multi_agents/torch/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,11 +53,9 @@ def __init__(
self.memories = memories
self.observation_spaces = observation_spaces
self.action_spaces = action_spaces

self.cfg = cfg if cfg is not None else {}
self.device = (
torch.device("cuda:0" if torch.cuda.is_available() else "cpu") if device is None else torch.device(device)
)

self.device = config.torch.parse_device(device)

# convert the models to their respective device
for _models in self.models.values():
Expand Down
8 changes: 1 addition & 7 deletions skrl/resources/noises/jax/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,7 @@ def sample(self, size):
"""
self._jax = config.jax.backend == "jax"

if device is None:
self.device = jax.devices()[0]
else:
self.device = device
if type(device) == str:
device_type, device_index = f"{device}:0".split(":")[:2]
self.device = jax.devices(device_type)[int(device_index)]
self.device = config.jax.parse_device(device)

def sample_like(self, tensor: Union[np.ndarray, jax.Array]) -> Union[np.ndarray, jax.Array]:
"""Sample a noise with the same size (shape) as the input tensor
Expand Down
7 changes: 3 additions & 4 deletions skrl/resources/noises/torch/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

import torch

from skrl import config


class Noise:
def __init__(self, device: Optional[Union[str, torch.device]] = None) -> None:
Expand All @@ -23,10 +25,7 @@ def __init__(self, device=None):
def sample(self, size):
return torch.rand(size, device=self.device)
"""
if device is None:
self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
else:
self.device = torch.device(device)
self.device = config.torch.parse_device(device)

def sample_like(self, tensor: torch.Tensor) -> torch.Tensor:
"""Sample a noise with the same size (shape) as the input tensor
Expand Down
8 changes: 1 addition & 7 deletions skrl/resources/preprocessors/jax/running_standard_scaler.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,13 +92,7 @@ def __init__(

self.epsilon = epsilon
self.clip_threshold = clip_threshold
if device is None:
self.device = jax.devices()[0]
else:
self.device = device
if type(device) == str:
device_type, device_index = f"{device}:0".split(":")[:2]
self.device = jax.devices(device_type)[int(device_index)]
self.device = config.jax.parse_device(device)

size = compute_space_size(size, occupied_size=True)

Expand Down
7 changes: 3 additions & 4 deletions skrl/resources/preprocessors/torch/running_standard_scaler.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import torch
import torch.nn as nn

from skrl import config
from skrl.utils.spaces.torch import compute_space_size


Expand Down Expand Up @@ -44,10 +45,8 @@ def __init__(

self.epsilon = epsilon
self.clip_threshold = clip_threshold
if device is None:
self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
else:
self.device = torch.device(device)

self.device = config.torch.parse_device(device)

size = compute_space_size(size, occupied_size=True)

Expand Down

0 comments on commit cf05a50

Please sign in to comment.