
Commit

Move the PyTorch distributed initialization from the agent base class to the ML framework configuration (#176)

* Move distributed initialization to ML framework configuration

* Update CHANGELOG
Toni-SM committed Jul 18, 2024
1 parent ca41dc9 commit e0f3ea5
Showing 3 changed files with 9 additions and 7 deletions.
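In practice, the change means that the NCCL process group is created when skrl's configuration is imported in a process launched with multiple workers (WORLD_SIZE > 1), rather than when the first agent is constructed. A minimal sketch of how the environment variables set by a launcher such as torchrun drive this behavior; the values and print statements are illustrative only and not part of the commit, while `config.torch.rank`, `config.torch.local_rank`, `config.torch.world_size` and `config.torch.is_distributed` are the attributes referenced in the diffs below:

    import os

    # Emulate what a launcher such as `torchrun --nproc_per_node=2 train.py`
    # would export before the process starts (illustrative values only;
    # WORLD_SIZE > 1 is what triggers the distributed branch).
    os.environ.setdefault("RANK", "0")
    os.environ.setdefault("LOCAL_RANK", "0")
    os.environ.setdefault("WORLD_SIZE", "1")

    # With this commit, importing skrl's configuration is where the process
    # group would be created (for WORLD_SIZE > 1), instead of inside the
    # agent base class constructor.
    from skrl import config

    print("rank:", config.torch.rank)
    print("local rank:", config.torch.local_rank)
    print("world size:", config.torch.world_size)
    print("distributed:", config.torch.is_distributed)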
CHANGELOG.md (3 changes: 2 additions & 1 deletion)
@@ -4,7 +4,8 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).

## [1.3.0] - Unreleased
### Changed
- - Move the KL reduction in distributed runs from the `KLAdaptiveLR` class to each agent using it
+ - Move the KL reduction from the PyTorch `KLAdaptiveLR` class to each agent using it in distributed runs
+ - Move the PyTorch distributed initialization from the agent base class to the ML framework configuration

## [1.2.0] - 2024-06-23
### Added
skrl/__init__.py (7 changes: 7 additions & 0 deletions)
@@ -56,6 +56,13 @@ def __init__(self) -> None:
self._world_size = int(os.getenv("WORLD_SIZE", "1"))
self._is_distributed = self._world_size > 1

+ # set up distributed runs
+ if self._is_distributed:
+     import torch
+     logger.info(f"Distributed (rank: {self._rank}, local rank: {self._local_rank}, world size: {self._world_size})")
+     torch.distributed.init_process_group("nccl", rank=self._rank, world_size=self._world_size)
+     torch.cuda.set_device(self._local_rank)

@property
def local_rank(self) -> int:
"""The rank of the worker/process (e.g.: GPU) within a local worker group (e.g.: node)
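For reference, the lines added above follow the standard PyTorch NCCL bootstrap. A self-contained sketch of the same pattern outside skrl, including the matching teardown a training script would typically add; the `main` function and the CPU-only "gloo" fallback are illustrative assumptions, not part of the commit:

    import os

    import torch
    import torch.distributed as dist


    def main() -> None:
        # hypothetical standalone reproduction of the bootstrap added above
        rank = int(os.getenv("RANK", "0"))
        local_rank = int(os.getenv("LOCAL_RANK", "0"))
        world_size = int(os.getenv("WORLD_SIZE", "1"))

        if world_size > 1:
            # "nccl" is what the commit uses; "gloo" is a CPU-only fallback (assumption)
            backend = "nccl" if torch.cuda.is_available() else "gloo"
            dist.init_process_group(backend, rank=rank, world_size=world_size)
            if backend == "nccl":
                # pin each process to its local GPU
                torch.cuda.set_device(local_rank)

        # ... training would go here ...

        if dist.is_initialized():
            dist.destroy_process_group()


    if __name__ == "__main__":
        main()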
skrl/agents/torch/base.py (6 changes: 0 additions & 6 deletions)
@@ -85,12 +85,6 @@ def __init__(self,
experiment_name = "{}_{}".format(datetime.datetime.now().strftime("%y-%m-%d_%H-%M-%S-%f"), self.__class__.__name__)
self.experiment_dir = os.path.join(directory, experiment_name)

- # set up distributed runs
- if config.torch.is_distributed:
-     logger.info(f"Distributed (rank: {config.torch.rank}, local rank: {config.torch.local_rank}, world size: {config.torch.world_size})")
-     torch.distributed.init_process_group("nccl", rank=config.torch.rank, world_size=config.torch.world_size)
-     torch.cuda.set_device(config.torch.local_rank)

def __str__(self) -> str:
"""Generate a representation of the agent as string

