diff --git a/skrl/agents/torch/ppo/ppo.py b/skrl/agents/torch/ppo/ppo.py
index 7e1dc24f..9f49f834 100644
--- a/skrl/agents/torch/ppo/ppo.py
+++ b/skrl/agents/torch/ppo/ppo.py
@@ -9,6 +9,7 @@
 import torch.nn as nn
 import torch.nn.functional as F
 
+from skrl import config, logger
 from skrl.agents.torch import Agent
 from skrl.memories.torch import Memory
 from skrl.models.torch import Model
@@ -142,6 +143,14 @@ def __init__(self,
         self._rewards_shaper = self.cfg["rewards_shaper"]
         self._time_limit_bootstrap = self.cfg["time_limit_bootstrap"]
 
+        # broadcast models' parameters
+        if config.torch.is_distributed:
+            logger.info(f"Broadcasting models' parameters")
+            if self.policy is not None:
+                self.policy.broadcast_parameters()
+            if self.value is not None and self.policy is not self.value:
+                self.value.broadcast_parameters()
+
         # set up optimizer and learning rate scheduler
         if self.policy is not None and self.value is not None:
             if self.policy is self.value:
@@ -418,11 +427,18 @@ def compute_gae(rewards: torch.Tensor,
                 # optimization step
                 self.optimizer.zero_grad()
                 (policy_loss + entropy_loss + value_loss).backward()
+
+                if config.torch.is_distributed:
+                    self.policy.reduce_parameters()
+                    if self.policy is not self.value:
+                        self.value.reduce_parameters()
+
                 if self._grad_norm_clip > 0:
                     if self.policy is self.value:
                         nn.utils.clip_grad_norm_(self.policy.parameters(), self._grad_norm_clip)
                     else:
                         nn.utils.clip_grad_norm_(itertools.chain(self.policy.parameters(), self.value.parameters()), self._grad_norm_clip)
+
                 self.optimizer.step()
 
             # update cumulative losses
@@ -434,7 +450,7 @@ def compute_gae(rewards: torch.Tensor,
             # update learning rate
             if self._learning_rate_scheduler:
                 if isinstance(self.scheduler, KLAdaptiveLR):
-                    self.scheduler.step(torch.tensor(kl_divergences).mean())
+                    self.scheduler.step(torch.tensor(kl_divergences, device=self.device).mean())
                 else:
                     self.scheduler.step()
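
For context, the diff follows the usual data-parallel pattern: broadcast parameters from rank 0 at construction time so all workers start identical, then all-reduce (average) gradients after `backward()` and before `optimizer.step()`. The sketch below is a minimal, hedged illustration of that pattern using plain `torch.distributed`; the function names mirror the `broadcast_parameters()` / `reduce_parameters()` calls in the diff, but this is not skrl's implementation and it assumes a process group has already been initialized (e.g. via `torch.distributed.init_process_group`).

```python
# Minimal sketch of the broadcast/reduce pattern assumed by the diff.
# Assumption: torch.distributed is initialized (NCCL or Gloo backend).
import torch
import torch.distributed as dist
import torch.nn as nn


def broadcast_parameters(module: nn.Module, src: int = 0) -> None:
    """Copy rank-`src` parameters to all other ranks so every worker starts identical."""
    for param in module.parameters():
        dist.broadcast(param.data, src=src)


def reduce_parameters(module: nn.Module) -> None:
    """Average gradients across ranks; call after backward() and before optimizer.step()."""
    world_size = dist.get_world_size()
    for param in module.parameters():
        if param.grad is not None:
            dist.all_reduce(param.grad, op=dist.ReduceOp.SUM)
            param.grad /= world_size
```

The `device=self.device` change in the last hunk follows the same distributed concern: building the KL tensor on the agent's device avoids an implicit CPU tensor when the scheduler value is later used alongside CUDA tensors.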