Skip to content

Commit

Permalink
Merge branch 'develop' into toni/distributed_torch_initialization
Browse files Browse the repository at this point in the history
  • Loading branch information
Toni-SM committed Jul 17, 2024
2 parents 708598e + ca41dc9 commit b5ca6d6
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 8 deletions.
23 changes: 17 additions & 6 deletions skrl/multi_agents/torch/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import torch
from torch.utils.tensorboard import SummaryWriter

from skrl import logger
from skrl import config, logger
from skrl.memories.torch import Memory
from skrl.models.torch import Model

Expand Down Expand Up @@ -149,34 +149,45 @@ def init(self, trainer_cfg: Optional[Mapping[str, Any]] = None) -> None:
"""Initialize the agent
This method should be called before the agent is used.
It will initialize the TensoBoard writer (and optionally Weights & Biases) and create the checkpoints directory
It will initialize the TensorBoard writer (and optionally Weights & Biases) and create the checkpoints directory
:param trainer_cfg: Trainer configuration
:type trainer_cfg: dict, optional
"""
trainer_cfg = trainer_cfg if trainer_cfg is not None else {}

# update agent configuration to avoid duplicated logging/checking in distributed runs
if config.torch.is_distributed and config.torch.rank:
self.write_interval = 0
self.checkpoint_interval = 0
# TODO: disable wandb

# setup Weights & Biases
if self.cfg.get("experiment", {}).get("wandb", False):
# save experiment config
trainer_cfg = trainer_cfg if trainer_cfg is not None else {}
# save experiment configuration
try:
models_cfg = {uid: {k: v.net._modules for (k, v) in self.models[uid].items()} for uid in self.possible_agents}
except AttributeError:
models_cfg = {uid: {k: v._modules for (k, v) in self.models[uid].items()} for uid in self.possible_agents}
config={**self.cfg, **trainer_cfg, **models_cfg}
wandb_config={**self.cfg, **trainer_cfg, **models_cfg}
# set default values
wandb_kwargs = copy.deepcopy(self.cfg.get("experiment", {}).get("wandb_kwargs", {}))
wandb_kwargs.setdefault("name", os.path.split(self.experiment_dir)[-1])
wandb_kwargs.setdefault("sync_tensorboard", True)
wandb_kwargs.setdefault("config", {})
wandb_kwargs["config"].update(config)
wandb_kwargs["config"].update(wandb_config)
# init Weights & Biases
import wandb
wandb.init(**wandb_kwargs)

# main entry to log data for consumption and visualization by TensorBoard
if self.write_interval == "auto":
self.write_interval = int(trainer_cfg.get("timesteps", 0) / 100)
if self.write_interval > 0:
self.writer = SummaryWriter(log_dir=self.experiment_dir)

if self.checkpoint_interval == "auto":
self.checkpoint_interval = int(trainer_cfg.get("timesteps", 0) / 10)
if self.checkpoint_interval > 0:
os.makedirs(os.path.join(self.experiment_dir, "checkpoints"), exist_ok=True)

Expand Down
7 changes: 6 additions & 1 deletion skrl/multi_agents/torch/ippo/ippo.py
Original file line number Diff line number Diff line change
Expand Up @@ -467,7 +467,12 @@ def compute_gae(rewards: torch.Tensor,
# update learning rate
if self._learning_rate_scheduler[uid]:
if isinstance(self.schedulers[uid], KLAdaptiveLR):
self.schedulers[uid].step(torch.tensor(kl_divergences, device=self.device).mean())
kl = torch.tensor(kl_divergences, device=self.device).mean()
# reduce (collect from all workers/processes) KL in distributed runs
if config.torch.is_distributed:
torch.distributed.all_reduce(kl, op=torch.distributed.ReduceOp.SUM)
kl /= config.torch.world_size
self.schedulers[uid].step(kl.item())
else:
self.schedulers[uid].step()

Expand Down
7 changes: 6 additions & 1 deletion skrl/multi_agents/torch/mappo/mappo.py
Original file line number Diff line number Diff line change
Expand Up @@ -487,7 +487,12 @@ def compute_gae(rewards: torch.Tensor,
# update learning rate
if self._learning_rate_scheduler[uid]:
if isinstance(self.schedulers[uid], KLAdaptiveLR):
self.schedulers[uid].step(torch.tensor(kl_divergences, device=self.device).mean())
kl = torch.tensor(kl_divergences, device=self.device).mean()
# reduce (collect from all workers/processes) KL in distributed runs
if config.torch.is_distributed:
torch.distributed.all_reduce(kl, op=torch.distributed.ReduceOp.SUM)
kl /= config.torch.world_size
self.schedulers[uid].step(kl.item())
else:
self.schedulers[uid].step()

Expand Down

0 comments on commit b5ca6d6

Please sign in to comment.