Add distributed implementation to PPO agent
Toni-SM committed Jun 21, 2024
1 parent 104313d commit 36d4a57
Showing 1 changed file with 17 additions and 1 deletion.
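This commit adds data-parallel training to the PPO agent: model parameters are broadcast from rank 0 when the agent is constructed, and gradients are averaged across ranks during each optimization step. Both paths are guarded by the `config.torch.is_distributed` flag. As a point of orientation (an assumption for illustration; the detection logic is not part of this diff), such a flag is commonly derived from the environment variables that launchers like `torchrun` export for each worker:

```python
import os

# torchrun (and similar launchers) export RANK/WORLD_SIZE per worker;
# a single-process run leaves them unset, so world_size defaults to 1
world_size = int(os.environ.get("WORLD_SIZE", "1"))
is_distributed = world_size > 1
```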
18 changes: 17 additions & 1 deletion skrl/agents/torch/ppo/ppo.py
@@ -9,6 +9,7 @@
 import torch.nn as nn
 import torch.nn.functional as F
 
+from skrl import config, logger
 from skrl.agents.torch import Agent
 from skrl.memories.torch import Memory
 from skrl.models.torch import Model
@@ -142,6 +143,14 @@ def __init__(self,
         self._rewards_shaper = self.cfg["rewards_shaper"]
         self._time_limit_bootstrap = self.cfg["time_limit_bootstrap"]
 
+        # broadcast models' parameters
+        if config.torch.is_distributed:
+            logger.info("Broadcasting models' parameters")
+            if self.policy is not None:
+                self.policy.broadcast_parameters()
+            if self.value is not None and self.policy is not self.value:
+                self.value.broadcast_parameters()
+
         # set up optimizer and learning rate scheduler
         if self.policy is not None and self.value is not None:
             if self.policy is self.value:
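In a multi-worker run, each process builds its own copy of the policy and value networks, so rank 0's weights are broadcast once at construction time to ensure all workers start from identical parameters (the `self.policy is not self.value` check avoids broadcasting a shared model twice). A minimal sketch of what a broadcast step like this typically does under `torch.distributed`; this is an illustration of the pattern, not skrl's actual `Model.broadcast_parameters` implementation:

```python
import torch
import torch.distributed as dist

def broadcast_parameters(model: torch.nn.Module, src: int = 0) -> None:
    """Copy the parameters of rank `src` to every other rank."""
    for param in model.parameters():
        # in-place broadcast: after this call, all ranks hold rank-src weights
        dist.broadcast(param.data, src=src)
```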
@@ -418,11 +427,18 @@ def compute_gae(rewards: torch.Tensor,
                 # optimization step
                 self.optimizer.zero_grad()
                 (policy_loss + entropy_loss + value_loss).backward()
+
+                if config.torch.is_distributed:
+                    self.policy.reduce_parameters()
+                    if self.policy is not self.value:
+                        self.value.reduce_parameters()
+
                 if self._grad_norm_clip > 0:
                     if self.policy is self.value:
                         nn.utils.clip_grad_norm_(self.policy.parameters(), self._grad_norm_clip)
                     else:
                         nn.utils.clip_grad_norm_(itertools.chain(self.policy.parameters(), self.value.parameters()), self._grad_norm_clip)
+
                 self.optimizer.step()
 
                 # update cumulative losses
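Gradients are reduced across workers after `backward()` but before gradient clipping and `optimizer.step()`, so every rank clips and applies the same averaged gradient. A hedged sketch of the all-reduce-and-average pattern that a method like `reduce_parameters` can implement (illustrative only; skrl's actual implementation is not shown in this diff):

```python
import torch
import torch.distributed as dist

def reduce_parameters(model: torch.nn.Module) -> None:
    """Average gradients across all ranks (data-parallel SGD)."""
    world_size = dist.get_world_size()
    for param in model.parameters():
        if param.grad is not None:
            # sum gradients from all ranks, then divide to get the mean
            dist.all_reduce(param.grad, op=dist.ReduceOp.SUM)
            param.grad /= world_size
```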
@@ -434,7 +450,7 @@ def compute_gae(rewards: torch.Tensor,
             # update learning rate
             if self._learning_rate_scheduler:
                 if isinstance(self.scheduler, KLAdaptiveLR):
-                    self.scheduler.step(torch.tensor(kl_divergences).mean())
+                    self.scheduler.step(torch.tensor(kl_divergences, device=self.device).mean())
                 else:
                     self.scheduler.step()
 
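The one-line change in the last hunk builds the KL tensor directly on the agent's device. `torch.tensor(...)` defaults to the CPU, so without `device=self.device` the mean is computed on the CPU even when training runs on the GPU, and mixing it with GPU-resident state in the scheduler can force an implicit transfer or a device mismatch. A minimal illustration with hypothetical values:

```python
import torch

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
kl_divergences = [0.009, 0.012, 0.010]  # hypothetical per-mini-batch KL values

# created on the CPU regardless of where training runs
cpu_mean = torch.tensor(kl_divergences).mean()
# created directly on the training device, as in the patched line
device_mean = torch.tensor(kl_divergences, device=device).mean()
```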

