Add distributed runs implementation to other agents
Toni-SM committed Jul 13, 2024
1 parent 2b335ae commit 91c6ec6
Showing 7 changed files with 99 additions and 3 deletions.
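
The commit extends distributed-runs support to the JAX A2C, DDPG, DDQN, DQN, RPO, SAC and TD3 agents. Every file follows the same pattern: import config (and logger) from skrl; at construction time, when config.jax.is_distributed is set, broadcast each model's parameters so that all processes start from the same weights; before each optimizer step, all-reduce the gradients across processes with the model's reduce_parameters() method; and, in the agents that use a KL-adaptive learning-rate scheduler (A2C and RPO), sum the per-worker mean KL divergence with psum and divide it by config.jax.world_size so the scheduler steps on the average over all workers.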
skrl/agents/jax/a2c/a2c.py (20 changes: 19 additions & 1 deletion)

@@ -9,6 +9,7 @@
 import jax.numpy as jnp
 import numpy as np

+from skrl import config, logger
 from skrl.agents.jax import Agent
 from skrl.memories.jax import Memory
 from skrl.models.jax import Model
@@ -215,6 +216,14 @@ def __init__(self,
         self.checkpoint_modules["policy"] = self.policy
         self.checkpoint_modules["value"] = self.value

+        # broadcast models' parameters in distributed runs
+        if config.jax.is_distributed:
+            logger.info(f"Broadcasting models' parameters")
+            if self.policy is not None:
+                self.policy.broadcast_parameters()
+            if self.value is not None:
+                self.value.broadcast_parameters()
+
         # configuration
         self._mini_batches = self.cfg["mini_batches"]
         self._rollouts = self.cfg["rollouts"]
@@ -475,6 +484,8 @@ def _update(self, timestep: int, timesteps: int) -> None:
             kl_divergences.append(kl_divergence.item())

             # optimization step (policy)
+            if config.jax.is_distributed:
+                grad = self.policy.reduce_parameters(grad)
             self.policy_optimizer = self.policy_optimizer.step(grad, self.policy, self.scheduler._lr if self.scheduler else None)

             # compute value loss
@@ -484,6 +495,8 @@ def _update(self, timestep: int, timesteps: int) -> None:
                                             sampled_returns)

             # optimization step (value)
+            if config.jax.is_distributed:
+                grad = self.value.reduce_parameters(grad)
             self.value_optimizer = self.value_optimizer.step(grad, self.value, self.scheduler._lr if self.scheduler else None)

             # update cumulative losses
@@ -495,7 +508,12 @@ def _update(self, timestep: int, timesteps: int) -> None:
         # update learning rate
         if self._learning_rate_scheduler:
             if isinstance(self.scheduler, KLAdaptiveLR):
-                self.scheduler.step(np.mean(kl_divergences))
+                kl = np.mean(kl_divergences)
+                # reduce (collect from all workers/processes) KL in distributed runs
+                if config.jax.is_distributed:
+                    kl = jax.pmap(lambda x: jax.lax.psum(x, 'i'), axis_name='i')(kl.reshape(1)).item()
+                    kl /= config.jax.world_size
+                self.scheduler.step(kl)

         # record data
         self.track_data("Loss / Policy loss", cumulative_policy_loss / len(sampled_batches))
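
The hunks above only call the models' reduce_parameters() method; its implementation lives in skrl's JAX Model class and is not part of this commit. As a rough, illustrative sketch of the idea (an assumption about the mechanism, not skrl's actual code), the same pmap/psum pattern used for the KL divergence above can average a whole gradient pytree across processes, assuming one local device per process. The helper name below is hypothetical.

import jax
import numpy as np

from skrl import config


def allreduce_mean_grads(grad):
    """Hypothetical helper (not part of skrl): average a gradient pytree over all processes."""
    world_size = config.jax.world_size

    def _reduce(leaf):
        # sum this leaf over the distributed axis, then divide by the number of
        # processes so that every worker applies the same averaged gradient
        leaf = np.asarray(leaf)
        summed = jax.pmap(lambda x: jax.lax.psum(x, "i"), axis_name="i")(leaf.reshape(1, *leaf.shape))[0]
        return summed / world_size

    # one pmap call per leaf is acceptable for a sketch, not for performance
    return jax.tree_util.tree_map(_reduce, grad)

In the agents, the reduced gradient is what reaches optimizer.step(), so the workers stay synchronized after starting from identical, broadcast parameters. Whether skrl sums or averages inside reduce_parameters() is not visible in this diff; the sketch assumes averaging, which keeps the effective learning rate independent of the number of workers.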
skrl/agents/jax/ddpg/ddpg.py (13 changes: 13 additions & 0 deletions)

@@ -9,6 +9,7 @@
 import jax.numpy as jnp
 import numpy as np

+from skrl import config, logger
 from skrl.agents.jax import Agent
 from skrl.memories.jax import Memory
 from skrl.models.jax import Model
@@ -161,6 +162,14 @@ def __init__(self,
         self.checkpoint_modules["critic"] = self.critic
         self.checkpoint_modules["target_critic"] = self.target_critic

+        # broadcast models' parameters in distributed runs
+        if config.jax.is_distributed:
+            logger.info(f"Broadcasting models' parameters")
+            if self.policy is not None:
+                self.policy.broadcast_parameters()
+            if self.critic is not None:
+                self.critic.broadcast_parameters()
+
         # configuration
         self._gradient_steps = self.cfg["gradient_steps"]
         self._batch_size = self.cfg["batch_size"]
@@ -412,6 +421,8 @@ def _update(self, timestep: int, timesteps: int) -> None:
                                             self._discount_factor)

             # optimization step (critic)
+            if config.jax.is_distributed:
+                grad = self.critic.reduce_parameters(grad)
             self.critic_optimizer = self.critic_optimizer.step(grad, self.critic)

             # compute policy (actor) loss
@@ -422,6 +433,8 @@ def _update(self, timestep: int, timesteps: int) -> None:
                                             sampled_states)

             # optimization step (policy)
+            if config.jax.is_distributed:
+                grad = self.policy.reduce_parameters(grad)
             self.policy_optimizer = self.policy_optimizer.step(grad, self.policy)

             # update target networks
skrl/agents/jax/dqn/ddqn.py (9 changes: 9 additions & 0 deletions)

@@ -9,6 +9,7 @@
 import jax.numpy as jnp
 import numpy as np

+from skrl import config, logger
 from skrl.agents.jax import Agent
 from skrl.memories.jax import Memory
 from skrl.models.jax import Model
@@ -135,6 +136,12 @@ def __init__(self,
         self.checkpoint_modules["q_network"] = self.q_network
         self.checkpoint_modules["target_q_network"] = self.target_q_network

+        # broadcast models' parameters in distributed runs
+        if config.jax.is_distributed:
+            logger.info(f"Broadcasting models' parameters")
+            if self.q_network is not None:
+                self.q_network.broadcast_parameters()
+
         # configuration
         self._gradient_steps = self.cfg["gradient_steps"]
         self._batch_size = self.cfg["batch_size"]
@@ -354,6 +361,8 @@ def _update(self, timestep: int, timesteps: int) -> None:
                                          self._discount_factor)

             # optimization step (Q-network)
+            if config.jax.is_distributed:
+                grad = self.q_network.reduce_parameters(grad)
             self.optimizer = self.optimizer.step(grad, self.q_network)

             # update target network
skrl/agents/jax/dqn/dqn.py (9 changes: 9 additions & 0 deletions)

@@ -9,6 +9,7 @@
 import jax.numpy as jnp
 import numpy as np

+from skrl import config, logger
 from skrl.agents.jax import Agent
 from skrl.memories.jax import Memory
 from skrl.models.jax import Model
@@ -132,6 +133,12 @@ def __init__(self,
         self.checkpoint_modules["q_network"] = self.q_network
         self.checkpoint_modules["target_q_network"] = self.target_q_network

+        # broadcast models' parameters in distributed runs
+        if config.jax.is_distributed:
+            logger.info(f"Broadcasting models' parameters")
+            if self.q_network is not None:
+                self.q_network.broadcast_parameters()
+
         # configuration
         self._gradient_steps = self.cfg["gradient_steps"]
         self._batch_size = self.cfg["batch_size"]
@@ -350,6 +357,8 @@ def _update(self, timestep: int, timesteps: int) -> None:
                                          self._discount_factor)

             # optimization step (Q-network)
+            if config.jax.is_distributed:
+                grad = self.q_network.reduce_parameters(grad)
             self.optimizer = self.optimizer.step(grad, self.q_network)

             # update target network
skrl/agents/jax/rpo/rpo.py (20 changes: 19 additions & 1 deletion)

@@ -9,6 +9,7 @@
 import jax.numpy as jnp
 import numpy as np

+from skrl import config, logger
 from skrl.agents.jax import Agent
 from skrl.memories.jax import Memory
 from skrl.models.jax import Model
@@ -237,6 +238,14 @@ def __init__(self,
         self.checkpoint_modules["policy"] = self.policy
         self.checkpoint_modules["value"] = self.value

+        # broadcast models' parameters in distributed runs
+        if config.jax.is_distributed:
+            logger.info(f"Broadcasting models' parameters")
+            if self.policy is not None:
+                self.policy.broadcast_parameters()
+            if self.value is not None:
+                self.value.broadcast_parameters()
+
         # configuration
         self._learning_epochs = self.cfg["learning_epochs"]
         self._mini_batches = self.cfg["mini_batches"]
@@ -513,6 +522,8 @@ def _update(self, timestep: int, timesteps: int) -> None:
                     break

                 # optimization step (policy)
+                if config.jax.is_distributed:
+                    grad = self.policy.reduce_parameters(grad)
                 self.policy_optimizer = self.policy_optimizer.step(grad, self.policy, self.scheduler._lr if self.scheduler else None)

                 # compute value loss
@@ -527,6 +538,8 @@ def _update(self, timestep: int, timesteps: int) -> None:
                                              self._alpha)

                 # optimization step (value)
+                if config.jax.is_distributed:
+                    grad = self.value.reduce_parameters(grad)
                 self.value_optimizer = self.value_optimizer.step(grad, self.value, self.scheduler._lr if self.scheduler else None)

                 # update cumulative losses
@@ -538,7 +551,12 @@ def _update(self, timestep: int, timesteps: int) -> None:
             # update learning rate
             if self._learning_rate_scheduler:
                 if isinstance(self.scheduler, KLAdaptiveLR):
-                    self.scheduler.step(np.mean(kl_divergences))
+                    kl = np.mean(kl_divergences)
+                    # reduce (collect from all workers/processes) KL in distributed runs
+                    if config.jax.is_distributed:
+                        kl = jax.pmap(lambda x: jax.lax.psum(x, 'i'), axis_name='i')(kl.reshape(1)).item()
+                        kl /= config.jax.world_size
+                    self.scheduler.step(kl)

         # record data
         self.track_data("Loss / Policy loss", cumulative_policy_loss / (self._learning_epochs * self._mini_batches))
skrl/agents/jax/sac/sac.py (15 changes: 15 additions & 0 deletions)

@@ -10,6 +10,7 @@
 import jax.numpy as jnp
 import numpy as np

+from skrl import config, logger
 from skrl.agents.jax import Agent
 from skrl.memories.jax import Memory
 from skrl.models.jax import Model
@@ -174,6 +175,16 @@ def __init__(self,
         self.checkpoint_modules["target_critic_1"] = self.target_critic_1
         self.checkpoint_modules["target_critic_2"] = self.target_critic_2

+        # broadcast models' parameters in distributed runs
+        if config.jax.is_distributed:
+            logger.info(f"Broadcasting models' parameters")
+            if self.policy is not None:
+                self.policy.broadcast_parameters()
+            if self.critic_1 is not None:
+                self.critic_1.broadcast_parameters()
+            if self.critic_2 is not None:
+                self.critic_2.broadcast_parameters()
+
         # configuration
         self._gradient_steps = self.cfg["gradient_steps"]
         self._batch_size = self.cfg["batch_size"]
@@ -418,6 +429,8 @@ def _update(self, timestep: int, timesteps: int) -> None:
                                                         self._discount_factor)

             # optimization step (critic)
+            if config.jax.is_distributed:
+                grad = self.critic_1.reduce_parameters(grad)
             self.critic_1_optimizer = self.critic_1_optimizer.step(grad, self.critic_1)
             self.critic_2_optimizer = self.critic_2_optimizer.step(grad, self.critic_2)

@@ -432,6 +445,8 @@ def _update(self, timestep: int, timesteps: int) -> None:
                                             sampled_states)

             # optimization step (policy)
+            if config.jax.is_distributed:
+                grad = self.policy.reduce_parameters(grad)
             self.policy_optimizer = self.policy_optimizer.step(grad, self.policy)

             # entropy learning
skrl/agents/jax/td3/td3.py (16 changes: 15 additions & 1 deletion)

@@ -9,7 +9,7 @@
 import jax.numpy as jnp
 import numpy as np

-from skrl import logger
+from skrl import config, logger
 from skrl.agents.jax import Agent
 from skrl.memories.jax import Memory
 from skrl.models.jax import Model
@@ -184,6 +184,16 @@ def __init__(self,
         self.checkpoint_modules["target_critic_1"] = self.target_critic_1
         self.checkpoint_modules["target_critic_2"] = self.target_critic_2

+        # broadcast models' parameters in distributed runs
+        if config.jax.is_distributed:
+            logger.info(f"Broadcasting models' parameters")
+            if self.policy is not None:
+                self.policy.broadcast_parameters()
+            if self.critic_1 is not None:
+                self.critic_1.broadcast_parameters()
+            if self.critic_2 is not None:
+                self.critic_2.broadcast_parameters()
+
         # configuration
         self._gradient_steps = self.cfg["gradient_steps"]
         self._batch_size = self.cfg["batch_size"]
@@ -461,6 +471,8 @@ def _update(self, timestep: int, timesteps: int) -> None:
                                                         self._discount_factor)

             # optimization step (critic)
+            if config.jax.is_distributed:
+                grad = self.critic_1.reduce_parameters(grad)
             self.critic_1_optimizer = self.critic_1_optimizer.step(grad, self.critic_1)
             self.critic_2_optimizer = self.critic_2_optimizer.step(grad, self.critic_2)

@@ -476,6 +488,8 @@ def _update(self, timestep: int, timesteps: int) -> None:
                                             sampled_states)

             # optimization step (policy)
+            if config.jax.is_distributed:
+                grad = self.policy.reduce_parameters(grad)
             self.policy_optimizer = self.policy_optimizer.step(grad, self.policy)

             # update target networks
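
The constructor-time broadcast in each agent guarantees that every process begins training from identical weights, so the per-step gradient reduction keeps them identical afterwards. Model.broadcast_parameters() is likewise implemented inside skrl's JAX Model class and not shown in this commit; the sketch below is only an assumption about what such a broadcast amounts to, reusing the single-device-per-process pmap/psum pattern, with the process rank passed in explicitly because this commit does not show where skrl stores it.

import jax
import numpy as np


def broadcast_from_rank0(params, rank):
    """Hypothetical sketch (not skrl's code): give every process the parameters of process 0."""

    def _broadcast(leaf):
        # every process except rank 0 contributes zeros, so a psum over the
        # distributed axis leaves each process holding rank 0's values
        leaf = np.asarray(leaf)
        contribution = leaf if rank == 0 else np.zeros_like(leaf)
        return jax.pmap(lambda x: jax.lax.psum(x, "i"), axis_name="i")(contribution.reshape(1, *contribution.shape))[0]

    return jax.tree_util.tree_map(_broadcast, params)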
