Skip to content

Commit

Permalink
Collective reduce model gradients
Browse files Browse the repository at this point in the history
  • Loading branch information
Toni-SM committed Jul 7, 2024
1 parent f231a85 commit a32073f
Showing 1 changed file with 41 additions and 0 deletions.
41 changes: 41 additions & 0 deletions skrl/models/jax/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,25 @@

import flax
import jax
import jax.numpy as jnp
import numpy as np

from skrl import config


@jax.jit
def _vectorize_leaves(leaves: Sequence[jax.Array]) -> jax.Array:
return jnp.expand_dims(jnp.concatenate(list(map(jnp.ravel, leaves)), axis=-1), 0)

@jax.jit
def _unvectorize_leaves(leaves: Sequence[jax.Array], vector: jax.Array) -> Sequence[jax.Array]:
offset = 0
for i, leaf in enumerate(leaves):
leaves[i] = leaves[i].at[:].set(vector.at[0, offset:offset + leaf.size].get().reshape(leaf.shape))
offset += leaf.size
return leaves


class StateDict(flax.struct.PyTreeNode):
apply_fn: Callable = flax.struct.field(pytree_node=False)
params: flax.core.FrozenDict[str, Any] = flax.struct.field(pytree_node=True)
Expand Down Expand Up @@ -590,3 +604,30 @@ def broadcast_parameters(self, rank: int = 0):
is_source = jax.process_index() == rank
params = jax.experimental.multihost_utils.broadcast_one_to_all(self.state_dict.params, is_source=is_source)
self.state_dict = self.state_dict.replace(params=params)

def reduce_parameters(self, tree: Any) -> Any:
    """All-reduce (mean) a pytree across all workers/processes in the whole group (e.g.: across all nodes)

    The reduction is NOT applied in place: the reduced pytree is returned,
    and after the call its leaves are bitwise identical on all workers/processes.

    :param tree: pytree to apply the collective reduction to (e.g.: model gradients)
    :type tree: Any

    :return: pytree with the same structure as ``tree`` whose leaves hold the
             element-wise mean of the values across all workers/processes
    :rtype: Any

    Example::

        # reduce gradients across all workers/processes
        >>> if config.jax.is_distributed:
        ...     grad = model.reduce_parameters(grad)
    """
    # flatten the pytree and all-reduce a single concatenated vector
    # (one collective call instead of one per leaf), then average by the
    # world size and restore the original structure
    leaves, treedef = jax.tree.flatten(tree)
    vector = jax.pmap(lambda x: jax.lax.psum(x, "i"), axis_name="i")(_vectorize_leaves(leaves))
    vector = vector / config.jax.world_size
    return jax.tree.unflatten(treedef, _unvectorize_leaves(leaves, vector))

0 comments on commit a32073f

Please sign in to comment.