Add method to broadcast distributed model parameters
Toni-SM committed Jun 30, 2024
1 parent 4897e80 commit 490c285
Showing 2 changed files with 19 additions and 1 deletion.
18 changes: 18 additions & 0 deletions skrl/models/jax/base.py
@@ -572,3 +572,21 @@ def update_parameters(self, model: flax.linen.Module, polyak: float = 1) -> None
         params = jax.tree_util.tree_map(lambda params, model_params: polyak * model_params + (1 - polyak) * params,
                                         self.state_dict.params, model.state_dict.params)
         self.state_dict = self.state_dict.replace(params=params)
+
+    def broadcast_parameters(self, rank: int = 0):
+        """Broadcast model parameters to the whole group (e.g.: across all nodes) in distributed runs
+
+        After calling this method, the distributed model will contain the broadcasted parameters from ``rank``
+
+        :param rank: Worker/process rank from which to broadcast model parameters (default: ``0``)
+        :type rank: int
+
+        Example::
+
+            # broadcast model parameter from worker/process with rank 1
+            >>> if config.jax.is_distributed:
+            ...     model.broadcast_parameters(rank=1)
+        """
+        is_source = jax.process_index() == rank
+        params = jax.experimental.multihost_utils.broadcast_one_to_all(self.state_dict.params, is_source=is_source)
+        self.state_dict = self.state_dict.replace(params=params)
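
A minimal usage sketch of the new JAX method (illustrative only, not part of the commit; it assumes a multi-process JAX run in which skrl's config.jax.is_distributed is True and an already-instantiated skrl model):

    from skrl import config

    # collective call: every worker/process participates, and afterwards all of
    # them hold the parameters that the worker/process with rank 0 had
    if config.jax.is_distributed:
        model.broadcast_parameters(rank=0)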
2 changes: 1 addition & 1 deletion skrl/models/torch/base.py
@@ -756,7 +756,7 @@ def broadcast_parameters(self, rank: int = 0):
             # broadcast model parameter from worker/process with rank 1
             >>> if config.torch.is_distributed:
-            ...     model.update_parameters(source_model, rank=1)
+            ...     model.broadcast_parameters(rank=1)
         """
         object_list = [self.state_dict()]
         torch.distributed.broadcast_object_list(object_list, rank)
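
For reference, a standalone sketch of the PyTorch collective used above (not part of the commit; it assumes one process per rank launched with torchrun and the gloo backend):

    import torch.distributed as dist

    dist.init_process_group(backend="gloo")          # torchrun provides the rank/world-size env vars
    payload = [{"produced_by_rank": dist.get_rank()}]
    dist.broadcast_object_list(payload, src=0)       # in-place: every rank now holds rank 0's object
    print(dist.get_rank(), payload[0])               # all ranks print {'produced_by_rank': 0}
    dist.destroy_process_group()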
