From a32073fc8968e0aee0ac96180f432e3b507b1a1f Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Antonio=20Serrano=20Mu=C3=B1oz?=
Date: Sun, 7 Jul 2024 10:50:12 -0400
Subject: [PATCH] Collective reduce model gradients

---
Revised after review: the pmap'd all-reduce is created once at module level,
0-dimensional leaves and empty pytrees are supported, the helpers no longer
assign into their input, and the docstrings describe the returned pytree.
The commit and blob ids are those of the original patch.

 skrl/models/jax/base.py | 67 +++++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 67 insertions(+)

diff --git a/skrl/models/jax/base.py b/skrl/models/jax/base.py
index 6f637902..4bf50513 100644
--- a/skrl/models/jax/base.py
+++ b/skrl/models/jax/base.py
@@ -5,11 +5,52 @@
 
 import flax
 import jax
+import jax.numpy as jnp
 import numpy as np
 
 from skrl import config
 
 
+# collective sum over the mapped axis. It is created only once: wrapping a new lambda with ``jax.pmap``
+# on every call would prevent reusing the previously traced computation
+_all_reduce_sum = jax.pmap(lambda x: jax.lax.psum(x, "i"), axis_name="i")
+
+
+@jax.jit
+def _vectorize_leaves(leaves: Sequence[jax.Array]) -> jax.Array:
+    """Flatten the leaves and concatenate them into a single array with shape (1, N)
+
+    :param leaves: Arrays to concatenate (at least one)
+    :type leaves: sequence of jax.Array
+
+    :return: Concatenated array. The leading axis (of size 1) is the one mapped by ``jax.pmap``
+    :rtype: jax.Array
+    """
+    return jnp.expand_dims(jnp.concatenate([jnp.ravel(leaf) for leaf in leaves], axis=-1), 0)
+
+
+@jax.jit
+def _unvectorize_leaves(leaves: Sequence[jax.Array], vector: jax.Array) -> Sequence[jax.Array]:
+    """Split an array with shape (1, N) into arrays with the shape and dtype of the given leaves
+
+    :param leaves: Arrays used as reference for the size, shape and dtype of each chunk
+    :type leaves: sequence of jax.Array
+    :param vector: Array with shape (1, N), where N is the sum of the leaves' sizes
+    :type vector: jax.Array
+
+    :return: New list of arrays (the given sequence is not modified)
+    :rtype: list of jax.Array
+    """
+    chunks = []
+    offset = 0
+    for leaf in leaves:
+        # plain slicing instead of ``leaf.at[:].set(...)``, which fails for 0-dimensional leaves
+        chunk = vector[0, offset : offset + leaf.size]
+        chunks.append(chunk.reshape(leaf.shape).astype(leaf.dtype))
+        offset += leaf.size
+    return chunks
+
+
 class StateDict(flax.struct.PyTreeNode):
     apply_fn: Callable = flax.struct.field(pytree_node=False)
     params: flax.core.FrozenDict[str, Any] = flax.struct.field(pytree_node=True)
@@ -590,3 +631,29 @@ def broadcast_parameters(self, rank: int = 0):
         is_source = jax.process_index() == rank
         params = jax.experimental.multihost_utils.broadcast_one_to_all(self.state_dict.params, is_source=is_source)
         self.state_dict = self.state_dict.replace(params=params)
+
+    def reduce_parameters(self, tree: Any) -> Any:
+        """Reduce (average) a pytree across all workers/processes in the whole group (e.g.: across all nodes)
+
+        The leaves are concatenated into a single array so that only one collective operation is performed.
+        The input is not modified: the reduced values are returned as a new pytree
+
+        :param tree: pytree to apply collective reduction (e.g.: model gradients)
+        :type tree: Any
+
+        :return: pytree, with the structure of the input, holding the collective sum of each leaf
+                 divided by ``config.jax.world_size``
+        :rtype: Any
+
+        Example::
+
+            # reduce model gradients across all workers/processes
+            >>> if config.jax.is_distributed:
+            ...     grad = model.reduce_parameters(grad)
+        """
+        leaves, treedef = jax.tree.flatten(tree)
+        # nothing to reduce (``jnp.concatenate`` rejects an empty sequence)
+        if not leaves:
+            return tree
+        vector = _all_reduce_sum(_vectorize_leaves(leaves)) / config.jax.world_size
+        return jax.tree.unflatten(treedef, _unvectorize_leaves(leaves, vector))