From c6301fbdeb4a24ae6209d0cfdf34462221d2b574 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Antonio=20Serrano=20Mu=C3=B1oz?=
Date: Sun, 23 Jun 2024 22:19:34 -0400
Subject: [PATCH] Implement distributed runs for multi-agents

---
 skrl/multi_agents/torch/ippo/ippo.py   | 16 +++++++++++++++-
 skrl/multi_agents/torch/mappo/mappo.py | 16 +++++++++++++++-
 2 files changed, 30 insertions(+), 2 deletions(-)

diff --git a/skrl/multi_agents/torch/ippo/ippo.py b/skrl/multi_agents/torch/ippo/ippo.py
index 45913edd..3a1f7932 100644
--- a/skrl/multi_agents/torch/ippo/ippo.py
+++ b/skrl/multi_agents/torch/ippo/ippo.py
@@ -9,6 +9,7 @@
 import torch.nn as nn
 import torch.nn.functional as F
 
+from skrl import config, logger
 from skrl.memories.torch import Memory
 from skrl.models.torch import Model
 from skrl.multi_agents.torch import MultiAgent
@@ -109,9 +110,18 @@ def __init__(self,
         self.values = {uid: self.models[uid].get("value", None) for uid in self.possible_agents}
 
         for uid in self.possible_agents:
+            # checkpoint models
             self.checkpoint_modules[uid]["policy"] = self.policies[uid]
             self.checkpoint_modules[uid]["value"] = self.values[uid]
 
+            # broadcast models' parameters in distributed runs
+            if config.torch.is_distributed:
+                logger.info(f"Broadcasting models' parameters")
+                if self.policies[uid] is not None:
+                    self.policies[uid].broadcast_parameters()
+                if self.values[uid] is not None and self.policies[uid] is not self.values[uid]:
+                    self.values[uid].broadcast_parameters()
+
         # configuration
         self._learning_epochs = self._as_dict(self.cfg["learning_epochs"])
         self._mini_batches = self._as_dict(self.cfg["mini_batches"])
@@ -437,6 +447,10 @@ def compute_gae(rewards: torch.Tensor,
                 # optimization step
                 self.optimizers[uid].zero_grad()
                 (policy_loss + entropy_loss + value_loss).backward()
+                if config.torch.is_distributed:
+                    policy.reduce_parameters()
+                    if policy is not value:
+                        value.reduce_parameters()
                 if self._grad_norm_clip[uid] > 0:
                     if policy is value:
                         nn.utils.clip_grad_norm_(policy.parameters(), self._grad_norm_clip[uid])
@@ -453,7 +467,7 @@ def compute_gae(rewards: torch.Tensor,
             # update learning rate
             if self._learning_rate_scheduler[uid]:
                 if isinstance(self.schedulers[uid], KLAdaptiveLR):
-                    self.schedulers[uid].step(torch.tensor(kl_divergences).mean())
+                    self.schedulers[uid].step(torch.tensor(kl_divergences, device=self.device).mean())
                 else:
                     self.schedulers[uid].step()
 
diff --git a/skrl/multi_agents/torch/mappo/mappo.py b/skrl/multi_agents/torch/mappo/mappo.py
index 98fff05c..ef0e7bc2 100644
--- a/skrl/multi_agents/torch/mappo/mappo.py
+++ b/skrl/multi_agents/torch/mappo/mappo.py
@@ -9,6 +9,7 @@
 import torch.nn as nn
 import torch.nn.functional as F
 
+from skrl import config, logger
 from skrl.memories.torch import Memory
 from skrl.models.torch import Model
 from skrl.multi_agents.torch import MultiAgent
@@ -116,9 +117,18 @@ def __init__(self,
         self.values = {uid: self.models[uid].get("value", None) for uid in self.possible_agents}
 
         for uid in self.possible_agents:
+            # checkpoint models
             self.checkpoint_modules[uid]["policy"] = self.policies[uid]
             self.checkpoint_modules[uid]["value"] = self.values[uid]
 
+            # broadcast models' parameters in distributed runs
+            if config.torch.is_distributed:
+                logger.info(f"Broadcasting models' parameters")
+                if self.policies[uid] is not None:
+                    self.policies[uid].broadcast_parameters()
+                if self.values[uid] is not None and self.policies[uid] is not self.values[uid]:
+                    self.values[uid].broadcast_parameters()
+
         # configuration
         self._learning_epochs = self._as_dict(self.cfg["learning_epochs"])
         self._mini_batches = self._as_dict(self.cfg["mini_batches"])
@@ -457,6 +467,10 @@ def compute_gae(rewards: torch.Tensor,
                 # optimization step
                 self.optimizers[uid].zero_grad()
                 (policy_loss + entropy_loss + value_loss).backward()
+                if config.torch.is_distributed:
+                    policy.reduce_parameters()
+                    if policy is not value:
+                        value.reduce_parameters()
                 if self._grad_norm_clip[uid] > 0:
                     if policy is value:
                         nn.utils.clip_grad_norm_(policy.parameters(), self._grad_norm_clip[uid])
@@ -473,7 +487,7 @@ def compute_gae(rewards: torch.Tensor,
             # update learning rate
             if self._learning_rate_scheduler[uid]:
                 if isinstance(self.schedulers[uid], KLAdaptiveLR):
-                    self.schedulers[uid].step(torch.tensor(kl_divergences).mean())
+                    self.schedulers[uid].step(torch.tensor(kl_divergences, device=self.device).mean())
                 else:
                     self.schedulers[uid].step()
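
Notes:

The patch leans on two skrl Model methods that implement the standard
data-parallel pattern: broadcast_parameters() synchronizes the rank-0 weights
to every process once at construction, and reduce_parameters() aggregates
gradients across processes between backward() and the optimizer step. A
minimal sketch of that pattern with plain torch.distributed (a conceptual
illustration, not skrl's exact implementation, assuming the default process
group is already initialized) looks like this:

    import torch
    import torch.distributed as dist

    def broadcast_parameters(model: torch.nn.Module) -> None:
        # rank 0 is the source, so every process starts from identical weights
        for parameter in model.parameters():
            dist.broadcast(parameter.data, src=0)

    def reduce_parameters(model: torch.nn.Module) -> None:
        # sum gradients across processes, then scale by the world size so the
        # averaged gradient (and hence the optimizer step) matches on every rank
        world_size = dist.get_world_size()
        for parameter in model.parameters():
            if parameter.grad is not None:
                dist.all_reduce(parameter.grad, op=dist.ReduceOp.SUM)
                parameter.grad /= world_size

The `policy is not value` guards avoid broadcasting or reducing the same
module twice when a shared model serves as both policy and value function.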
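With these changes, a distributed multi-agent run is launched the same way as
skrl's single-agent distributed runs, via PyTorch's torchrun; for example, on
a single node with 2 GPUs (the script name below is a placeholder):

    torchrun --nnodes=1 --nproc_per_node=2 train_multi_agent.py

config.torch.is_distributed is derived from the environment variables that
torchrun sets (RANK, LOCAL_RANK, WORLD_SIZE), so the training script itself
needs no extra code to opt in.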