Reduce KL in the multi-agent implementation
Toni-SM committed Jul 16, 2024
1 parent 0af1edc · commit 946a544
Showing 2 changed files with 12 additions and 2 deletions.
skrl/multi_agents/torch/ippo/ippo.py (6 additions & 1 deletion)
@@ -467,7 +467,12 @@ def compute_gae(rewards: torch.Tensor,
             # update learning rate
             if self._learning_rate_scheduler[uid]:
                 if isinstance(self.schedulers[uid], KLAdaptiveLR):
-                    self.schedulers[uid].step(torch.tensor(kl_divergences, device=self.device).mean())
+                    kl = torch.tensor(kl_divergences, device=self.device).mean()
+                    # reduce (collect from all workers/processes) KL in distributed runs
+                    if config.torch.is_distributed:
+                        torch.distributed.all_reduce(kl, op=torch.distributed.ReduceOp.SUM)
+                        kl /= config.torch.world_size
+                    self.schedulers[uid].step(kl.item())
                 else:
                     self.schedulers[uid].step()

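The patch above averages each process's mean KL divergence across all workers before passing it to the KL-adaptive scheduler, so every rank steps its learning rate from the same global statistic. A minimal, self-contained sketch of that all-reduce averaging pattern follows (illustrative only, not skrl code): it uses the CPU "gloo" backend, two processes, and made-up per-rank KL values.

# Sketch: average a per-process scalar (e.g. mean KL) across workers
# with an all-reduce, as the patch above does. Not skrl code; the
# world size, port, and KL values here are illustrative assumptions.
import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def worker(rank: int, world_size: int) -> None:
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29500"
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    # each process computes its own mean KL (here: a made-up value)
    kl = torch.tensor(0.01 * (rank + 1))

    # sum the scalar over all processes, then divide by the world size
    # so that every rank ends up holding the same global mean
    dist.all_reduce(kl, op=dist.ReduceOp.SUM)
    kl /= world_size

    print(f"rank {rank}: global mean KL = {kl.item():.4f}")
    dist.destroy_process_group()


if __name__ == "__main__":
    world_size = 2
    mp.spawn(worker, args=(world_size,), nprocs=world_size)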
skrl/multi_agents/torch/mappo/mappo.py (6 additions & 1 deletion)
@@ -487,7 +487,12 @@ def compute_gae(rewards: torch.Tensor,
             # update learning rate
             if self._learning_rate_scheduler[uid]:
                 if isinstance(self.schedulers[uid], KLAdaptiveLR):
-                    self.schedulers[uid].step(torch.tensor(kl_divergences, device=self.device).mean())
+                    kl = torch.tensor(kl_divergences, device=self.device).mean()
+                    # reduce (collect from all workers/processes) KL in distributed runs
+                    if config.torch.is_distributed:
+                        torch.distributed.all_reduce(kl, op=torch.distributed.ReduceOp.SUM)
+                        kl /= config.torch.world_size
+                    self.schedulers[uid].step(kl.item())
                 else:
                     self.schedulers[uid].step()

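The same fix is applied to MAPPO. For context on why the averaged value matters: the scalar handed to step() drives an adaptive-KL learning-rate rule of the kind commonly used with PPO, where the learning rate is lowered when the observed KL exceeds a threshold and raised when it falls well below it. The sketch below illustrates that generic heuristic; it is not skrl's KLAdaptiveLR source, and the threshold and adjustment factors are assumed values.

# Illustrative adaptive-KL learning-rate rule (a common PPO heuristic).
# Not a copy of skrl's KLAdaptiveLR; kl_threshold and the adjustment
# factors below are assumptions chosen for the example.
import torch


class AdaptiveKLScheduler:
    def __init__(self,
                 optimizer: torch.optim.Optimizer,
                 kl_threshold: float = 0.008,
                 kl_factor: float = 2.0,
                 lr_factor: float = 1.5,
                 min_lr: float = 1e-6,
                 max_lr: float = 1e-2) -> None:
        self.optimizer = optimizer
        self.kl_threshold = kl_threshold
        self.kl_factor = kl_factor
        self.lr_factor = lr_factor
        self.min_lr = min_lr
        self.max_lr = max_lr

    def step(self, kl: float) -> None:
        for group in self.optimizer.param_groups:
            lr = group["lr"]
            if kl > self.kl_threshold * self.kl_factor:    # policy moved too fast: slow down
                lr = max(lr / self.lr_factor, self.min_lr)
            elif kl < self.kl_threshold / self.kl_factor:  # policy barely moved: speed up
                lr = min(lr * self.lr_factor, self.max_lr)
            group["lr"] = lr


# usage: step with the globally averaged KL after each update
model = torch.nn.Linear(4, 2)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = AdaptiveKLScheduler(optimizer)
scheduler.step(kl=0.02)  # above 2x threshold -> learning rate is reduced

Without the all-reduce, each rank would step this rule with the KL of its own mini-batches, and per-rank learning rates could drift apart over a distributed run; averaging first keeps them synchronized.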
