Implement distributed runs for off-policy agents
Toni-SM committed Jun 24, 2024
1 parent 1954ab3 commit b821879
Showing 8 changed files with 108 additions and 2 deletions.
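
The pattern is the same for every agent below: broadcast each model's parameters once at construction so all processes start from the same weights, then reduce (average) each model's gradients between backward() and the optimizer step so every process applies the same update. The sketch below gives a rough idea of what such helpers typically do, written against plain torch.distributed; it is an illustration only, not skrl's actual Model.broadcast_parameters() / Model.reduce_parameters() implementation (broadcasting from rank 0 and simple gradient averaging are assumptions here, and `model` is a placeholder):

import torch
import torch.distributed as dist

def broadcast_parameters(model: torch.nn.Module) -> None:
    # copy rank 0's weights to all other processes so every replica
    # starts the run from identical parameters
    for param in model.parameters():
        dist.broadcast(param.data, src=0)

def reduce_parameters(model: torch.nn.Module) -> None:
    # sum the gradients computed on each process and average them,
    # so the subsequent optimizer step is identical everywhere
    world_size = dist.get_world_size()
    for param in model.parameters():
        if param.grad is not None:
            dist.all_reduce(param.grad, op=dist.ReduceOp.SUM)
            param.grad /= world_size

In the hunks below, reduce_parameters() is called right after backward() and before gradient clipping, so clipping (when enabled) is applied to the already-averaged gradients.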
13 changes: 13 additions & 0 deletions skrl/agents/torch/ddpg/ddpg.py
@@ -8,6 +8,7 @@
import torch.nn as nn
import torch.nn.functional as F

from skrl import config, logger
from skrl.agents.torch import Agent
from skrl.memories.torch import Memory
from skrl.models.torch import Model
@@ -109,6 +110,14 @@ def __init__(self,
self.checkpoint_modules["critic"] = self.critic
self.checkpoint_modules["target_critic"] = self.target_critic

# broadcast models' parameters in distributed runs
if config.torch.is_distributed:
logger.info(f"Broadcasting models' parameters")
if self.policy is not None:
self.policy.broadcast_parameters()
if self.critic is not None:
self.critic.broadcast_parameters()

if self.target_policy is not None and self.target_critic is not None:
# freeze target networks with respect to optimizers (update via .update_parameters())
self.target_policy.freeze_parameters(True)
@@ -341,6 +350,8 @@ def _update(self, timestep: int, timesteps: int) -> None:
# optimization step (critic)
self.critic_optimizer.zero_grad()
critic_loss.backward()
if config.torch.is_distributed:
self.critic.reduce_parameters()
if self._grad_norm_clip > 0:
nn.utils.clip_grad_norm_(self.critic.parameters(), self._grad_norm_clip)
self.critic_optimizer.step()
@@ -354,6 +365,8 @@ def _update(self, timestep: int, timesteps: int) -> None:
# optimization step (policy)
self.policy_optimizer.zero_grad()
policy_loss.backward()
if config.torch.is_distributed:
self.policy.reduce_parameters()
if self._grad_norm_clip > 0:
nn.utils.clip_grad_norm_(self.policy.parameters(), self._grad_norm_clip)
self.policy_optimizer.step()
13 changes: 13 additions & 0 deletions skrl/agents/torch/ddpg/ddpg_rnn.py
@@ -8,6 +8,7 @@
import torch.nn as nn
import torch.nn.functional as F

from skrl import config, logger
from skrl.agents.torch import Agent
from skrl.memories.torch import Memory
from skrl.models.torch import Model
@@ -109,6 +110,14 @@ def __init__(self,
self.checkpoint_modules["critic"] = self.critic
self.checkpoint_modules["target_critic"] = self.target_critic

# broadcast models' parameters in distributed runs
if config.torch.is_distributed:
logger.info(f"Broadcasting models' parameters")
if self.policy is not None:
self.policy.broadcast_parameters()
if self.critic is not None:
self.critic.broadcast_parameters()

if self.target_policy is not None and self.target_critic is not None:
# freeze target networks with respect to optimizers (update via .update_parameters())
self.target_policy.freeze_parameters(True)
@@ -382,6 +391,8 @@ def _update(self, timestep: int, timesteps: int) -> None:
# optimization step (critic)
self.critic_optimizer.zero_grad()
critic_loss.backward()
if config.torch.is_distributed:
self.critic.reduce_parameters()
if self._grad_norm_clip > 0:
nn.utils.clip_grad_norm_(self.critic.parameters(), self._grad_norm_clip)
self.critic_optimizer.step()
@@ -395,6 +406,8 @@ def _update(self, timestep: int, timesteps: int) -> None:
# optimization step (policy)
self.policy_optimizer.zero_grad()
policy_loss.backward()
if config.torch.is_distributed:
self.policy.reduce_parameters()
if self._grad_norm_clip > 0:
nn.utils.clip_grad_norm_(self.policy.parameters(), self._grad_norm_clip)
self.policy_optimizer.step()
9 changes: 9 additions & 0 deletions skrl/agents/torch/dqn/ddqn.py
@@ -8,6 +8,7 @@
import torch
import torch.nn.functional as F

from skrl import config, logger
from skrl.agents.torch import Agent
from skrl.memories.torch import Memory
from skrl.models.torch import Model
@@ -104,6 +105,12 @@ def __init__(self,
self.checkpoint_modules["q_network"] = self.q_network
self.checkpoint_modules["target_q_network"] = self.target_q_network

# broadcast models' parameters in distributed runs
if config.torch.is_distributed:
logger.info(f"Broadcasting models' parameters")
if self.q_network is not None:
self.q_network.broadcast_parameters()

if self.target_q_network is not None:
# freeze target networks with respect to optimizers (update via .update_parameters())
self.target_q_network.freeze_parameters(True)
@@ -303,6 +310,8 @@ def _update(self, timestep: int, timesteps: int) -> None:
# optimize Q-network
self.optimizer.zero_grad()
q_network_loss.backward()
if config.torch.is_distributed:
self.q_network.reduce_parameters()
self.optimizer.step()

# update target network
9 changes: 9 additions & 0 deletions skrl/agents/torch/dqn/dqn.py
@@ -8,6 +8,7 @@
import torch
import torch.nn.functional as F

from skrl import config, logger
from skrl.agents.torch import Agent
from skrl.memories.torch import Memory
from skrl.models.torch import Model
@@ -104,6 +105,12 @@ def __init__(self,
self.checkpoint_modules["q_network"] = self.q_network
self.checkpoint_modules["target_q_network"] = self.target_q_network

# broadcast models' parameters in distributed runs
if config.torch.is_distributed:
logger.info(f"Broadcasting models' parameters")
if self.q_network is not None:
self.q_network.broadcast_parameters()

if self.target_q_network is not None:
# freeze target networks with respect to optimizers (update via .update_parameters())
self.target_q_network.freeze_parameters(True)
@@ -303,6 +310,8 @@ def _update(self, timestep: int, timesteps: int) -> None:
# optimize Q-network
self.optimizer.zero_grad()
q_network_loss.backward()
if config.torch.is_distributed:
self.q_network.reduce_parameters()
self.optimizer.step()

# update target network
16 changes: 16 additions & 0 deletions skrl/agents/torch/sac/sac.py
@@ -10,6 +10,7 @@
import torch.nn as nn
import torch.nn.functional as F

from skrl import config, logger
from skrl.agents.torch import Agent
from skrl.memories.torch import Memory
from skrl.models.torch import Model
@@ -111,6 +112,16 @@ def __init__(self,
self.checkpoint_modules["target_critic_1"] = self.target_critic_1
self.checkpoint_modules["target_critic_2"] = self.target_critic_2

# broadcast models' parameters in distributed runs
if config.torch.is_distributed:
logger.info(f"Broadcasting models' parameters")
if self.policy is not None:
self.policy.broadcast_parameters()
if self.critic_1 is not None:
self.critic_1.broadcast_parameters()
if self.critic_2 is not None:
self.critic_2.broadcast_parameters()

if self.target_critic_1 is not None and self.target_critic_2 is not None:
# freeze target networks with respect to optimizers (update via .update_parameters())
self.target_critic_1.freeze_parameters(True)
@@ -325,6 +336,9 @@ def _update(self, timestep: int, timesteps: int) -> None:
# optimization step (critic)
self.critic_optimizer.zero_grad()
critic_loss.backward()
if config.torch.is_distributed:
self.critic_1.reduce_parameters()
self.critic_2.reduce_parameters()
if self._grad_norm_clip > 0:
nn.utils.clip_grad_norm_(itertools.chain(self.critic_1.parameters(), self.critic_2.parameters()), self._grad_norm_clip)
self.critic_optimizer.step()
@@ -339,6 +353,8 @@ def _update(self, timestep: int, timesteps: int) -> None:
# optimization step (policy)
self.policy_optimizer.zero_grad()
policy_loss.backward()
if config.torch.is_distributed:
self.policy.reduce_parameters()
if self._grad_norm_clip > 0:
nn.utils.clip_grad_norm_(self.policy.parameters(), self._grad_norm_clip)
self.policy_optimizer.step()
16 changes: 16 additions & 0 deletions skrl/agents/torch/sac/sac_rnn.py
@@ -10,6 +10,7 @@
import torch.nn as nn
import torch.nn.functional as F

from skrl import config, logger
from skrl.agents.torch import Agent
from skrl.memories.torch import Memory
from skrl.models.torch import Model
@@ -111,6 +112,16 @@ def __init__(self,
self.checkpoint_modules["target_critic_1"] = self.target_critic_1
self.checkpoint_modules["target_critic_2"] = self.target_critic_2

# broadcast models' parameters in distributed runs
if config.torch.is_distributed:
logger.info(f"Broadcasting models' parameters")
if self.policy is not None:
self.policy.broadcast_parameters()
if self.critic_1 is not None:
self.critic_1.broadcast_parameters()
if self.critic_2 is not None:
self.critic_2.broadcast_parameters()

if self.target_critic_1 is not None and self.target_critic_2 is not None:
# freeze target networks with respect to optimizers (update via .update_parameters())
self.target_critic_1.freeze_parameters(True)
@@ -367,6 +378,9 @@ def _update(self, timestep: int, timesteps: int) -> None:
# optimization step (critic)
self.critic_optimizer.zero_grad()
critic_loss.backward()
if config.torch.is_distributed:
self.critic_1.reduce_parameters()
self.critic_2.reduce_parameters()
if self._grad_norm_clip > 0:
nn.utils.clip_grad_norm_(itertools.chain(self.critic_1.parameters(), self.critic_2.parameters()), self._grad_norm_clip)
self.critic_optimizer.step()
@@ -381,6 +395,8 @@ def _update(self, timestep: int, timesteps: int) -> None:
# optimization step (policy)
self.policy_optimizer.zero_grad()
policy_loss.backward()
if config.torch.is_distributed:
self.policy.reduce_parameters()
if self._grad_norm_clip > 0:
nn.utils.clip_grad_norm_(self.policy.parameters(), self._grad_norm_clip)
self.policy_optimizer.step()
17 changes: 16 additions & 1 deletion skrl/agents/torch/td3/td3.py
@@ -9,7 +9,7 @@
import torch.nn as nn
import torch.nn.functional as F

from skrl import logger
from skrl import config, logger
from skrl.agents.torch import Agent
from skrl.memories.torch import Memory
from skrl.models.torch import Model
@@ -119,6 +119,16 @@ def __init__(self,
self.checkpoint_modules["target_critic_1"] = self.target_critic_1
self.checkpoint_modules["target_critic_2"] = self.target_critic_2

# broadcast models' parameters in distributed runs
if config.torch.is_distributed:
logger.info(f"Broadcasting models' parameters")
if self.policy is not None:
self.policy.broadcast_parameters()
if self.critic_1 is not None:
self.critic_1.broadcast_parameters()
if self.critic_2 is not None:
self.critic_2.broadcast_parameters()

if self.target_policy is not None and self.target_critic_1 is not None and self.target_critic_2 is not None:
# freeze target networks with respect to optimizers (update via .update_parameters())
self.target_policy.freeze_parameters(True)
@@ -372,6 +382,9 @@ def _update(self, timestep: int, timesteps: int) -> None:
# optimization step (critic)
self.critic_optimizer.zero_grad()
critic_loss.backward()
if config.torch.is_distributed:
self.critic_1.reduce_parameters()
self.critic_2.reduce_parameters()
if self._grad_norm_clip > 0:
nn.utils.clip_grad_norm_(itertools.chain(self.critic_1.parameters(), self.critic_2.parameters()), self._grad_norm_clip)
self.critic_optimizer.step()
@@ -389,6 +402,8 @@ def _update(self, timestep: int, timesteps: int) -> None:
# optimization step (policy)
self.policy_optimizer.zero_grad()
policy_loss.backward()
if config.torch.is_distributed:
self.policy.reduce_parameters()
if self._grad_norm_clip > 0:
nn.utils.clip_grad_norm_(self.policy.parameters(), self._grad_norm_clip)
self.policy_optimizer.step()
17 changes: 16 additions & 1 deletion skrl/agents/torch/td3/td3_rnn.py
@@ -9,7 +9,7 @@
import torch.nn as nn
import torch.nn.functional as F

from skrl import logger
from skrl import config, logger
from skrl.agents.torch import Agent
from skrl.memories.torch import Memory
from skrl.models.torch import Model
@@ -119,6 +119,16 @@ def __init__(self,
self.checkpoint_modules["target_critic_1"] = self.target_critic_1
self.checkpoint_modules["target_critic_2"] = self.target_critic_2

# broadcast models' parameters in distributed runs
if config.torch.is_distributed:
logger.info(f"Broadcasting models' parameters")
if self.policy is not None:
self.policy.broadcast_parameters()
if self.critic_1 is not None:
self.critic_1.broadcast_parameters()
if self.critic_2 is not None:
self.critic_2.broadcast_parameters()

if self.target_policy is not None and self.target_critic_1 is not None and self.target_critic_2 is not None:
# freeze target networks with respect to optimizers (update via .update_parameters())
self.target_policy.freeze_parameters(True)
@@ -413,6 +423,9 @@ def _update(self, timestep: int, timesteps: int) -> None:
# optimization step (critic)
self.critic_optimizer.zero_grad()
critic_loss.backward()
if config.torch.is_distributed:
self.critic_1.reduce_parameters()
self.critic_2.reduce_parameters()
if self._grad_norm_clip > 0:
nn.utils.clip_grad_norm_(itertools.chain(self.critic_1.parameters(), self.critic_2.parameters()), self._grad_norm_clip)
self.critic_optimizer.step()
@@ -430,6 +443,8 @@ def _update(self, timestep: int, timesteps: int) -> None:
# optimization step (policy)
self.policy_optimizer.zero_grad()
policy_loss.backward()
if config.torch.is_distributed:
self.policy.reduce_parameters()
if self._grad_norm_clip > 0:
nn.utils.clip_grad_norm_(self.policy.parameters(), self._grad_norm_clip)
self.policy_optimizer.step()
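
Since every new branch above is guarded by config.torch.is_distributed, single-process runs are unaffected. For reference, a minimal sketch of how a distributed run might be detected and initialized on each process, assuming the standard environment variables set by torchrun (this mirrors, but is not copied from, skrl's config.torch setup):

import os

import torch
import torch.distributed as dist

local_rank = int(os.environ.get("LOCAL_RANK", "0"))
world_size = int(os.environ.get("WORLD_SIZE", "1"))
is_distributed = world_size > 1

if is_distributed:
    # NCCL on CUDA machines, Gloo otherwise
    backend = "nccl" if torch.cuda.is_available() else "gloo"
    dist.init_process_group(backend=backend)
    if torch.cuda.is_available():
        # one process per GPU
        torch.cuda.set_device(local_rank)

With, for example, `torchrun --nnodes=1 --nproc_per_node=2 script.py` (script name and process count are placeholders), the agents above broadcast their models once at construction and keep gradients synchronized on every update.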
