Implement distributed runs for on-policy agents
Toni-SM committed Jun 23, 2024
1 parent ffd6503 commit a336bbd
Showing 9 changed files with 139 additions and 7 deletions.
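
The same pattern is applied to every on-policy agent touched here: when config.torch.is_distributed is set, each model's parameters are broadcast once at construction so all workers start from identical weights, and gradients are reduced across workers after every backward() call, before gradient clipping and the optimizer step. The broadcast_parameters() and reduce_parameters() methods themselves are not part of this diff; the snippet below is only a rough sketch of what such helpers could look like when built on plain torch.distributed collectives, not skrl's actual implementation.

import torch.distributed as dist
import torch.nn as nn

def broadcast_parameters(model: nn.Module, src: int = 0) -> None:
    # copy the weights of rank `src` to every worker so all processes start identically
    for param in model.parameters():
        dist.broadcast(param.data, src=src)

def reduce_parameters(model: nn.Module) -> None:
    # sum gradients across workers and average them, so every rank takes the same optimizer step
    world_size = dist.get_world_size()
    for param in model.parameters():
        if param.grad is not None:
            dist.all_reduce(param.grad, op=dist.ReduceOp.SUM)
            param.grad.div_(world_size)

Both helpers assume torch.distributed has already been initialized (see the bootstrap sketch further down).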
18 changes: 17 additions & 1 deletion skrl/agents/torch/a2c/a2c.py
@@ -9,6 +9,7 @@
import torch.nn as nn
import torch.nn.functional as F

from skrl import config, logger
from skrl.agents.torch import Agent
from skrl.memories.torch import Memory
from skrl.models.torch import Model
@@ -128,6 +129,14 @@ def __init__(self,
self._rewards_shaper = self.cfg["rewards_shaper"]
self._time_limit_bootstrap = self.cfg["time_limit_bootstrap"]

# broadcast models' parameters in distributed runs
if config.torch.is_distributed:
logger.info(f"Broadcasting models' parameters")
if self.policy is not None:
self.policy.broadcast_parameters()
if self.value is not None and self.policy is not self.value:
self.value.broadcast_parameters()

# set up optimizer and learning rate scheduler
if self.policy is not None and self.value is not None:
if self.policy is self.value:
@@ -391,11 +400,18 @@ def compute_gae(rewards: torch.Tensor,
# optimization step
self.optimizer.zero_grad()
(policy_loss + entropy_loss + value_loss).backward()

if config.torch.is_distributed:
self.policy.reduce_parameters()
if self.policy is not self.value:
self.value.reduce_parameters()

if self._grad_norm_clip > 0:
if self.policy is self.value:
nn.utils.clip_grad_norm_(self.policy.parameters(), self._grad_norm_clip)
else:
nn.utils.clip_grad_norm_(itertools.chain(self.policy.parameters(), self.value.parameters()), self._grad_norm_clip)

self.optimizer.step()

# update cumulative losses
@@ -407,7 +423,7 @@ def compute_gae(rewards: torch.Tensor,
# update learning rate
if self._learning_rate_scheduler:
if isinstance(self.scheduler, KLAdaptiveLR):
self.scheduler.step(torch.tensor(kl_divergences).mean())
self.scheduler.step(torch.tensor(kl_divergences, device=self.device).mean())
else:
self.scheduler.step()

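Two details of the a2c.py hunks above recur in the RNN and RPO variants that follow. First, the single deleted line is the old torch.tensor(kl_divergences).mean() call; its replacement builds the tensor directly on self.device before feeding it to the KL-adaptive scheduler. Second, reduce_parameters() is called between backward() and the gradient-norm clipping, so clipping operates on the gradient averaged across workers rather than on each worker's local gradient. A toy example with hypothetical numbers (plain PyTorch, not skrl code) of why that order matters:

import torch

g_worker_0 = torch.tensor([3.0, 4.0])   # gradient on worker 0 (norm 5)
g_worker_1 = torch.tensor([0.0, 0.0])   # gradient on worker 1 (norm 0)
max_norm = 1.0

g_avg = (g_worker_0 + g_worker_1) / 2   # what averaging across workers yields
scale = min(1.0, max_norm / (g_avg.norm().item() + 1e-6))
print(g_avg * scale)                    # clipped once, against the global gradient norm

Clipping each worker's gradient first and averaging afterwards would generally give a different result, so reducing before nn.utils.clip_grad_norm_ keeps the distributed update consistent with the single-process one.
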
18 changes: 17 additions & 1 deletion skrl/agents/torch/a2c/a2c_rnn.py
@@ -9,6 +9,7 @@
import torch.nn as nn
import torch.nn.functional as F

from skrl import config, logger
from skrl.agents.torch import Agent
from skrl.memories.torch import Memory
from skrl.models.torch import Model
@@ -128,6 +129,14 @@ def __init__(self,
self._rewards_shaper = self.cfg["rewards_shaper"]
self._time_limit_bootstrap = self.cfg["time_limit_bootstrap"]

# broadcast models' parameters in distributed runs
if config.torch.is_distributed:
logger.info(f"Broadcasting models' parameters")
if self.policy is not None:
self.policy.broadcast_parameters()
if self.value is not None and self.policy is not self.value:
self.value.broadcast_parameters()

# set up optimizer and learning rate scheduler
if self.policy is not None and self.value is not None:
if self.policy is self.value:
@@ -462,11 +471,18 @@ def compute_gae(rewards: torch.Tensor,
# optimization step
self.optimizer.zero_grad()
(policy_loss + entropy_loss + value_loss).backward()

if config.torch.is_distributed:
self.policy.reduce_parameters()
if self.policy is not self.value:
self.value.reduce_parameters()

if self._grad_norm_clip > 0:
if self.policy is self.value:
nn.utils.clip_grad_norm_(self.policy.parameters(), self._grad_norm_clip)
else:
nn.utils.clip_grad_norm_(itertools.chain(self.policy.parameters(), self.value.parameters()), self._grad_norm_clip)

self.optimizer.step()

# update cumulative losses
@@ -478,7 +494,7 @@ def compute_gae(rewards: torch.Tensor,
# update learning rate
if self._learning_rate_scheduler:
if isinstance(self.scheduler, KLAdaptiveLR):
self.scheduler.step(torch.tensor(kl_divergences).mean())
self.scheduler.step(torch.tensor(kl_divergences, device=self.device).mean())
else:
self.scheduler.step()

20 changes: 19 additions & 1 deletion skrl/agents/torch/amp/amp.py
@@ -10,6 +10,7 @@
import torch.nn as nn
import torch.nn.functional as F

from skrl import config, logger
from skrl.agents.torch import Agent
from skrl.memories.torch import Memory
from skrl.models.torch import Model
@@ -188,6 +189,16 @@ def __init__(self,
self._rewards_shaper = self.cfg["rewards_shaper"]
self._time_limit_bootstrap = self.cfg["time_limit_bootstrap"]

# broadcast models' parameters in distributed runs
if config.torch.is_distributed:
logger.info(f"Broadcasting models' parameters")
if self.policy is not None:
self.policy.broadcast_parameters()
if self.value is not None:
self.value.broadcast_parameters()
if self.discriminator is not None:
self.discriminator.broadcast_parameters()

# set up optimizer and learning rate scheduler
if self.policy is not None and self.value is not None and self.discriminator is not None:
self.optimizer = torch.optim.Adam(itertools.chain(self.policy.parameters(),
@@ -554,10 +565,17 @@ def compute_gae(rewards: torch.Tensor,
# optimization step
self.optimizer.zero_grad()
(policy_loss + entropy_loss + value_loss + discriminator_loss).backward()

if config.torch.is_distributed:
self.policy.reduce_parameters()
self.value.reduce_parameters()
self.discriminator.reduce_parameters()

if self._grad_norm_clip > 0:
nn.utils.clip_grad_norm_(itertools.chain(self.policy.parameters(),
self.value.parameters(),
self.discriminator.parameters()), self._grad_norm_clip)

self.optimizer.step()

# update cumulative losses
@@ -571,7 +589,7 @@ def compute_gae(rewards: torch.Tensor,
if self._learning_rate_scheduler:
self.scheduler.step()

# update AMP repaly buffer
# update AMP replay buffer
self.reply_buffer.add_samples(states=amp_states.view(-1, amp_states.shape[-1]))

# record data
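None of the files above show how a run becomes distributed in the first place: config.torch.is_distributed is presumably derived from the environment that launchers such as torchrun set up (RANK, LOCAL_RANK, WORLD_SIZE), but that machinery lives outside this diff. A generic, hypothetical bootstrap under that assumption (not skrl's code) might look like:

import os
import torch
import torch.distributed as dist

def init_distributed() -> bool:
    # hypothetical helper: returns True when the script was launched by torchrun or similar
    world_size = int(os.environ.get("WORLD_SIZE", "1"))
    if world_size < 2:
        return False  # single-process run: broadcast/reduce are unnecessary
    local_rank = int(os.environ.get("LOCAL_RANK", "0"))
    dist.init_process_group(backend="nccl" if torch.cuda.is_available() else "gloo")
    if torch.cuda.is_available():
        torch.cuda.set_device(local_rank)
    return True

# example launch: torchrun --nnodes=1 --nproc_per_node=2 train.py

Note also that ppo.py below contributes only a one-line comment change ("in distributed runs"), since PPO apparently already had this support and its comment is merely aligned with the other agents.
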
2 changes: 1 addition & 1 deletion skrl/agents/torch/ppo/ppo.py
@@ -143,7 +143,7 @@ def __init__(self,
self._rewards_shaper = self.cfg["rewards_shaper"]
self._time_limit_bootstrap = self.cfg["time_limit_bootstrap"]

# broadcast models' parameters
# broadcast models' parameters in distributed runs
if config.torch.is_distributed:
logger.info(f"Broadcasting models' parameters")
if self.policy is not None:
18 changes: 17 additions & 1 deletion skrl/agents/torch/ppo/ppo_rnn.py
@@ -9,6 +9,7 @@
import torch.nn as nn
import torch.nn.functional as F

from skrl import config, logger
from skrl.agents.torch import Agent
from skrl.memories.torch import Memory
from skrl.models.torch import Model
@@ -142,6 +143,14 @@ def __init__(self,
self._rewards_shaper = self.cfg["rewards_shaper"]
self._time_limit_bootstrap = self.cfg["time_limit_bootstrap"]

# broadcast models' parameters in distributed runs
if config.torch.is_distributed:
logger.info(f"Broadcasting models' parameters")
if self.policy is not None:
self.policy.broadcast_parameters()
if self.value is not None and self.policy is not self.value:
self.value.broadcast_parameters()

# set up optimizer and learning rate scheduler
if self.policy is not None and self.value is not None:
if self.policy is self.value:
@@ -490,11 +499,18 @@ def compute_gae(rewards: torch.Tensor,
# optimization step
self.optimizer.zero_grad()
(policy_loss + entropy_loss + value_loss).backward()

if config.torch.is_distributed:
self.policy.reduce_parameters()
if self.policy is not self.value:
self.value.reduce_parameters()

if self._grad_norm_clip > 0:
if self.policy is self.value:
nn.utils.clip_grad_norm_(self.policy.parameters(), self._grad_norm_clip)
else:
nn.utils.clip_grad_norm_(itertools.chain(self.policy.parameters(), self.value.parameters()), self._grad_norm_clip)

self.optimizer.step()

# update cumulative losses
@@ -506,7 +522,7 @@ def compute_gae(rewards: torch.Tensor,
# update learning rate
if self._learning_rate_scheduler:
if isinstance(self.scheduler, KLAdaptiveLR):
self.scheduler.step(torch.tensor(kl_divergences).mean())
self.scheduler.step(torch.tensor(kl_divergences, device=self.device).mean())
else:
self.scheduler.step()

18 changes: 17 additions & 1 deletion skrl/agents/torch/rpo/rpo.py
@@ -9,6 +9,7 @@
import torch.nn as nn
import torch.nn.functional as F

from skrl import config, logger
from skrl.agents.torch import Agent
from skrl.memories.torch import Memory
from skrl.models.torch import Model
@@ -144,6 +145,14 @@ def __init__(self,
self._rewards_shaper = self.cfg["rewards_shaper"]
self._time_limit_bootstrap = self.cfg["time_limit_bootstrap"]

# broadcast models' parameters in distributed runs
if config.torch.is_distributed:
logger.info(f"Broadcasting models' parameters")
if self.policy is not None:
self.policy.broadcast_parameters()
if self.value is not None and self.policy is not self.value:
self.value.broadcast_parameters()

# set up optimizer and learning rate scheduler
if self.policy is not None and self.value is not None:
if self.policy is self.value:
@@ -420,11 +429,18 @@ def compute_gae(rewards: torch.Tensor,
# optimization step
self.optimizer.zero_grad()
(policy_loss + entropy_loss + value_loss).backward()

if config.torch.is_distributed:
self.policy.reduce_parameters()
if self.policy is not self.value:
self.value.reduce_parameters()

if self._grad_norm_clip > 0:
if self.policy is self.value:
nn.utils.clip_grad_norm_(self.policy.parameters(), self._grad_norm_clip)
else:
nn.utils.clip_grad_norm_(itertools.chain(self.policy.parameters(), self.value.parameters()), self._grad_norm_clip)

self.optimizer.step()

# update cumulative losses
@@ -436,7 +452,7 @@ def compute_gae(rewards: torch.Tensor,
# update learning rate
if self._learning_rate_scheduler:
if isinstance(self.scheduler, KLAdaptiveLR):
self.scheduler.step(torch.tensor(kl_divergences).mean())
self.scheduler.step(torch.tensor(kl_divergences, device=self.device).mean())
else:
self.scheduler.step()

18 changes: 17 additions & 1 deletion skrl/agents/torch/rpo/rpo_rnn.py
@@ -9,6 +9,7 @@
import torch.nn as nn
import torch.nn.functional as F

from skrl import config, logger
from skrl.agents.torch import Agent
from skrl.memories.torch import Memory
from skrl.models.torch import Model
@@ -144,6 +145,14 @@ def __init__(self,
self._rewards_shaper = self.cfg["rewards_shaper"]
self._time_limit_bootstrap = self.cfg["time_limit_bootstrap"]

# broadcast models' parameters in distributed runs
if config.torch.is_distributed:
logger.info(f"Broadcasting models' parameters")
if self.policy is not None:
self.policy.broadcast_parameters()
if self.value is not None and self.policy is not self.value:
self.value.broadcast_parameters()

# set up optimizer and learning rate scheduler
if self.policy is not None and self.value is not None:
if self.policy is self.value:
@@ -492,11 +501,18 @@ def compute_gae(rewards: torch.Tensor,
# optimization step
self.optimizer.zero_grad()
(policy_loss + entropy_loss + value_loss).backward()

if config.torch.is_distributed:
self.policy.reduce_parameters()
if self.policy is not self.value:
self.value.reduce_parameters()

if self._grad_norm_clip > 0:
if self.policy is self.value:
nn.utils.clip_grad_norm_(self.policy.parameters(), self._grad_norm_clip)
else:
nn.utils.clip_grad_norm_(itertools.chain(self.policy.parameters(), self.value.parameters()), self._grad_norm_clip)

self.optimizer.step()

# update cumulative losses
@@ -508,7 +524,7 @@ def compute_gae(rewards: torch.Tensor,
# update learning rate
if self._learning_rate_scheduler:
if isinstance(self.scheduler, KLAdaptiveLR):
self.scheduler.step(torch.tensor(kl_divergences).mean())
self.scheduler.step(torch.tensor(kl_divergences, device=self.device).mean())
else:
self.scheduler.step()

17 changes: 17 additions & 0 deletions skrl/agents/torch/trpo/trpo.py
@@ -9,6 +9,7 @@
import torch.nn.functional as F
from torch.nn.utils.convert_parameters import parameters_to_vector, vector_to_parameters

from skrl import config, logger
from skrl.agents.torch import Agent
from skrl.memories.torch import Memory
from skrl.models.torch import Model
@@ -143,6 +144,14 @@ def __init__(self,
self._rewards_shaper = self.cfg["rewards_shaper"]
self._time_limit_bootstrap = self.cfg["time_limit_bootstrap"]

# broadcast models' parameters in distributed runs
if config.torch.is_distributed:
logger.info(f"Broadcasting models' parameters")
if self.policy is not None:
self.policy.broadcast_parameters()
if self.value is not None:
self.value.broadcast_parameters()

# set up optimizer and learning rate scheduler
if self.policy is not None and self.value is not None:
self.value_optimizer = torch.optim.Adam(self.value.parameters(), lr=self._value_learning_rate)
@@ -521,6 +530,9 @@ def kl_divergence(policy_1: Model, policy_2: Model, states: torch.Tensor) -> tor
if restore_policy_flag:
self.policy.update_parameters(self.backup_policy)

if config.torch.is_distributed:
self.policy.reduce_parameters()

# sample mini-batches from memory
sampled_batches = self.memory.sample_all(names=self._tensors_names_value, mini_batches=self._mini_batches)

@@ -542,8 +554,13 @@ def kl_divergence(policy_1: Model, policy_2: Model, states: torch.Tensor) -> tor
# optimization step (value)
self.value_optimizer.zero_grad()
value_loss.backward()

if config.torch.is_distributed:
self.value.reduce_parameters()

if self._grad_norm_clip > 0:
nn.utils.clip_grad_norm_(self.value.parameters(), self._grad_norm_clip)

self.value_optimizer.step()

# update cumulative losses
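trpo.py is the outlier among these agents: only the value network has an optimizer, while the policy is updated through a conjugate-gradient step and line search built on flattened parameter vectors (parameters_to_vector / vector_to_parameters), which is why its policy and value gradients are reduced at different points in the update. Purely as an illustration of the flattened case (a hypothetical helper, not skrl's implementation), a whole gradient vector can be averaged across workers with a single collective:

import torch
import torch.distributed as dist
from torch.nn.utils.convert_parameters import parameters_to_vector

def averaged_flat_grad(params) -> torch.Tensor:
    # assumes torch.distributed is initialized and every parameter has a populated .grad
    flat_grad = parameters_to_vector([p.grad for p in params])
    dist.all_reduce(flat_grad, op=dist.ReduceOp.SUM)
    return flat_grad / dist.get_world_size()

Batching all parameters into one tensor keeps the number of collective calls constant instead of issuing one all-reduce per parameter.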
