
Commit

Merge branch 'develop' into toni/distributed_jax
Toni-SM committed Jul 18, 2024
2 parents e148788 + e0f3ea5 commit b208709
Showing 6 changed files with 38 additions and 15 deletions.
3 changes: 2 additions & 1 deletion CHANGELOG.md
@@ -8,7 +8,8 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).
 - Utilities to start multiple processes from a single program invocation for distributed learning using JAX

 ### Changed
-- Move the KL reduction in distributed runs from the `KLAdaptiveLR` class to each agent using it
+- Move the KL reduction from the PyTorch `KLAdaptiveLR` class to each agent using it in distributed runs
+- Move the PyTorch distributed initialization from the agent base class to the ML framework configuration

 ## [1.2.0] - 2024-06-23
 ### Added
7 changes: 7 additions & 0 deletions skrl/__init__.py
@@ -56,6 +56,13 @@ def __init__(self) -> None:
         self._world_size = int(os.getenv("WORLD_SIZE", "1"))
         self._is_distributed = self._world_size > 1

+        # set up distributed runs
+        if self._is_distributed:
+            import torch
+            logger.info(f"Distributed (rank: {self._rank}, local rank: {self._local_rank}, world size: {self._world_size})")
+            torch.distributed.init_process_group("nccl", rank=self._rank, world_size=self._world_size)
+            torch.cuda.set_device(self._local_rank)
+
     @property
     def device(self) -> "torch.device":
         """Default device
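For reference, a standalone sketch of the initialization that skrl now performs when its configuration object is created, assuming (as in the surrounding code) that the rank, local rank and world size come from the `RANK`, `LOCAL_RANK` and `WORLD_SIZE` environment variables that launchers such as `torchrun` export for each process:

```python
# Hedged sketch: equivalent standalone form of the setup added to skrl/__init__.py.
# The environment variables are set per process by a launcher such as `torchrun`.
import os
import torch

rank = int(os.getenv("RANK", "0"))
local_rank = int(os.getenv("LOCAL_RANK", "0"))
world_size = int(os.getenv("WORLD_SIZE", "1"))

if world_size > 1:
    # NCCL backend for GPU-to-GPU communication; each process drives the GPU given by its local rank
    torch.distributed.init_process_group("nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(local_rank)
```

A two-GPU run of a hypothetical `train.py` would then be launched with, for example, `torchrun --nproc_per_node=2 train.py`.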
6 changes: 0 additions & 6 deletions skrl/agents/torch/base.py
@@ -85,12 +85,6 @@ def __init__(self,
             experiment_name = "{}_{}".format(datetime.datetime.now().strftime("%y-%m-%d_%H-%M-%S-%f"), self.__class__.__name__)
         self.experiment_dir = os.path.join(directory, experiment_name)

-        # set up distributed runs
-        if config.torch.is_distributed:
-            logger.info(f"Distributed (rank: {config.torch.rank}, local rank: {config.torch.local_rank}, world size: {config.torch.world_size})")
-            torch.distributed.init_process_group("nccl", rank=config.torch.rank, world_size=config.torch.world_size)
-            torch.cuda.set_device(config.torch.local_rank)
-
     def __str__(self) -> str:
         """Generate a representation of the agent as string
23 changes: 17 additions & 6 deletions skrl/multi_agents/torch/base.py
@@ -11,7 +11,7 @@
 import torch
 from torch.utils.tensorboard import SummaryWriter

-from skrl import logger
+from skrl import config, logger
 from skrl.memories.torch import Memory
 from skrl.models.torch import Model

@@ -149,34 +149,45 @@ def init(self, trainer_cfg: Optional[Mapping[str, Any]] = None) -> None:
         """Initialize the agent

         This method should be called before the agent is used.
-        It will initialize the TensoBoard writer (and optionally Weights & Biases) and create the checkpoints directory
+        It will initialize the TensorBoard writer (and optionally Weights & Biases) and create the checkpoints directory

         :param trainer_cfg: Trainer configuration
         :type trainer_cfg: dict, optional
         """
+        trainer_cfg = trainer_cfg if trainer_cfg is not None else {}
+
+        # update agent configuration to avoid duplicated logging/checking in distributed runs
+        if config.torch.is_distributed and config.torch.rank:
+            self.write_interval = 0
+            self.checkpoint_interval = 0
+            # TODO: disable wandb
+
         # setup Weights & Biases
         if self.cfg.get("experiment", {}).get("wandb", False):
-            # save experiment config
-            trainer_cfg = trainer_cfg if trainer_cfg is not None else {}
+            # save experiment configuration
            try:
                 models_cfg = {uid: {k: v.net._modules for (k, v) in self.models[uid].items()} for uid in self.possible_agents}
             except AttributeError:
                 models_cfg = {uid: {k: v._modules for (k, v) in self.models[uid].items()} for uid in self.possible_agents}
-            config={**self.cfg, **trainer_cfg, **models_cfg}
+            wandb_config={**self.cfg, **trainer_cfg, **models_cfg}
             # set default values
             wandb_kwargs = copy.deepcopy(self.cfg.get("experiment", {}).get("wandb_kwargs", {}))
             wandb_kwargs.setdefault("name", os.path.split(self.experiment_dir)[-1])
             wandb_kwargs.setdefault("sync_tensorboard", True)
             wandb_kwargs.setdefault("config", {})
-            wandb_kwargs["config"].update(config)
+            wandb_kwargs["config"].update(wandb_config)
             # init Weights & Biases
             import wandb
             wandb.init(**wandb_kwargs)

         # main entry to log data for consumption and visualization by TensorBoard
+        if self.write_interval == "auto":
+            self.write_interval = int(trainer_cfg.get("timesteps", 0) / 100)
         if self.write_interval > 0:
             self.writer = SummaryWriter(log_dir=self.experiment_dir)

+        if self.checkpoint_interval == "auto":
+            self.checkpoint_interval = int(trainer_cfg.get("timesteps", 0) / 10)
         if self.checkpoint_interval > 0:
             os.makedirs(os.path.join(self.experiment_dir, "checkpoints"), exist_ok=True)

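Two behaviors added in this hunk are easy to miss: non-zero ranks zero out their write and checkpoint intervals so only one process logs and saves in distributed runs, and the `"auto"` intervals are derived from the trainer's total timesteps. A minimal sketch of that logic in isolation (the function name and arguments are illustrative; in skrl the rank comes from `config.torch.rank` and the timesteps from the trainer configuration):

```python
# Illustrative helper mirroring the interval handling above (not part of skrl's API).
def resolve_intervals(write_interval, checkpoint_interval, timesteps, rank, is_distributed):
    # only rank 0 writes TensorBoard data and checkpoints in distributed runs
    if is_distributed and rank:
        write_interval = 0
        checkpoint_interval = 0
    # "auto" maps to a fixed fraction of the total training timesteps
    if write_interval == "auto":
        write_interval = int(timesteps / 100)
    if checkpoint_interval == "auto":
        checkpoint_interval = int(timesteps / 10)
    return write_interval, checkpoint_interval

# e.g. a 100000-timestep run: rank 0 writes every 1000 steps and checkpoints every 10000,
# while any other rank does neither
print(resolve_intervals("auto", "auto", 100000, rank=0, is_distributed=True))  # (1000, 10000)
print(resolve_intervals("auto", "auto", 100000, rank=1, is_distributed=True))  # (0, 0)
```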
7 changes: 6 additions & 1 deletion skrl/multi_agents/torch/ippo/ippo.py
@@ -467,7 +467,12 @@ def compute_gae(rewards: torch.Tensor,
         # update learning rate
         if self._learning_rate_scheduler[uid]:
             if isinstance(self.schedulers[uid], KLAdaptiveLR):
-                self.schedulers[uid].step(torch.tensor(kl_divergences, device=self.device).mean())
+                kl = torch.tensor(kl_divergences, device=self.device).mean()
+                # reduce (collect from all workers/processes) KL in distributed runs
+                if config.torch.is_distributed:
+                    torch.distributed.all_reduce(kl, op=torch.distributed.ReduceOp.SUM)
+                    kl /= config.torch.world_size
+                self.schedulers[uid].step(kl.item())
             else:
                 self.schedulers[uid].step()

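The same reduction is applied in MAPPO below. Pulled out of the agent, the averaging works as in this sketch (the function name is illustrative; skrl takes the divisor from `config.torch.world_size`):

```python
# Sketch of the distributed KL averaging: every process contributes its local mean KL,
# the values are summed across processes, and the sum is divided by the world size so
# all workers step the KLAdaptiveLR scheduler with the same value.
import torch
import torch.distributed as dist

def mean_kl_across_processes(kl_divergences, device, world_size):
    kl = torch.tensor(kl_divergences, device=device).mean()
    if dist.is_available() and dist.is_initialized():
        dist.all_reduce(kl, op=dist.ReduceOp.SUM)
        kl /= world_size
    return kl.item()
```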
7 changes: 6 additions & 1 deletion skrl/multi_agents/torch/mappo/mappo.py
@@ -487,7 +487,12 @@ def compute_gae(rewards: torch.Tensor,
         # update learning rate
         if self._learning_rate_scheduler[uid]:
             if isinstance(self.schedulers[uid], KLAdaptiveLR):
-                self.schedulers[uid].step(torch.tensor(kl_divergences, device=self.device).mean())
+                kl = torch.tensor(kl_divergences, device=self.device).mean()
+                # reduce (collect from all workers/processes) KL in distributed runs
+                if config.torch.is_distributed:
+                    torch.distributed.all_reduce(kl, op=torch.distributed.ReduceOp.SUM)
+                    kl /= config.torch.world_size
+                self.schedulers[uid].step(kl.item())
             else:
                 self.schedulers[uid].step()

