diff --git a/skrl/models/jax/base.py b/skrl/models/jax/base.py
index 9a3738a5..6f637902 100644
--- a/skrl/models/jax/base.py
+++ b/skrl/models/jax/base.py
@@ -572,3 +572,21 @@ def update_parameters(self, model: flax.linen.Module, polyak: float = 1) -> None
             params = jax.tree_util.tree_map(lambda params, model_params: polyak * model_params + (1 - polyak) * params,
                                             self.state_dict.params, model.state_dict.params)
             self.state_dict = self.state_dict.replace(params=params)
+
+    def broadcast_parameters(self, rank: int = 0):
+        """Broadcast model parameters to the whole group (e.g.: across all nodes) in distributed runs
+
+        After calling this method, the distributed model will contain the broadcasted parameters from ``rank``
+
+        :param rank: Worker/process rank from which to broadcast model parameters (default: ``0``)
+        :type rank: int
+
+        Example::
+
+            # broadcast model parameter from worker/process with rank 1
+            >>> if config.jax.is_distributed:
+            ...     model.broadcast_parameters(rank=1)
+        """
+        is_source = jax.process_index() == rank
+        params = jax.experimental.multihost_utils.broadcast_one_to_all(self.state_dict.params, is_source=is_source)
+        self.state_dict = self.state_dict.replace(params=params)
diff --git a/skrl/models/torch/base.py b/skrl/models/torch/base.py
index bf90ba8a..ec64feb0 100644
--- a/skrl/models/torch/base.py
+++ b/skrl/models/torch/base.py
@@ -756,7 +756,7 @@ def broadcast_parameters(self, rank: int = 0):
 
             # broadcast model parameter from worker/process with rank 1
             >>> if config.torch.is_distributed:
-            ...     model.update_parameters(source_model, rank=1)
+            ...     model.broadcast_parameters(rank=1)
         """
         object_list = [self.state_dict()]
         torch.distributed.broadcast_object_list(object_list, rank)
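
For context, the new JAX method is built on `jax.experimental.multihost_utils.broadcast_one_to_all`, which copies a pytree held by the source process to every other process in the distributed group. Below is a minimal standalone sketch of that pattern outside of skrl; the `params` pytree and the choice of `rank = 0` are illustrative assumptions, not part of the diff, and the snippet presumes the JAX distributed runtime has already been initialized (one process per host).

```python
# Minimal sketch of the broadcast pattern used by the new JAX method
# (illustrative only; parameter names are not taken from skrl).
import jax
import jax.numpy as jnp
from jax.experimental import multihost_utils

# each process starts with its own (possibly different) parameters
params = {"dense": {"kernel": jnp.ones((4, 2)) * jax.process_index(),
                    "bias": jnp.zeros((2,))}}

# broadcast the parameters held by process (rank) 0 to all processes;
# is_source marks the process whose values are sent to the group
rank = 0
is_source = jax.process_index() == rank
params = multihost_utils.broadcast_one_to_all(params, is_source=is_source)

# after the call, every process holds identical (rank-0) parameters
```

In the diff itself the same idea is applied to `self.state_dict.params`, and the resulting pytree is written back with `self.state_dict.replace(params=params)`, mirroring what the existing PyTorch `broadcast_parameters` does with `torch.distributed.broadcast_object_list`.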