Added AutoSharding Distribution API #21583

Closed · wants to merge 5 commits

3 changes: 3 additions & 0 deletions keras/api/_tf_keras/keras/distribution/__init__.py
@@ -4,6 +4,9 @@
since your modifications would be overwritten.
"""

from keras.src.distribution.distribution_lib import (
AutoShardDistribution as AutoShardDistribution,
)
from keras.src.distribution.distribution_lib import DataParallel as DataParallel
from keras.src.distribution.distribution_lib import DeviceMesh as DeviceMesh
from keras.src.distribution.distribution_lib import LayoutMap as LayoutMap
3 changes: 3 additions & 0 deletions keras/api/distribution/__init__.py
@@ -4,6 +4,9 @@
since your modifications would be overwritten.
"""

from keras.src.distribution.distribution_lib import (
AutoShardDistribution as AutoShardDistribution,
)
from keras.src.distribution.distribution_lib import DataParallel as DataParallel
from keras.src.distribution.distribution_lib import DeviceMesh as DeviceMesh
from keras.src.distribution.distribution_lib import LayoutMap as LayoutMap
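
Both generated API files re-export the new class from `keras.src.distribution.distribution_lib`, so it becomes visible under the public `keras.distribution` namespace. A minimal, hypothetical smoke check under that assumption (nothing beyond the import is exercised here, since the class body is not part of these two files):

# Hypothetical smoke check: assumes a build that includes these generated
# API files. Only the re-export is verified; the AutoShardDistribution
# constructor is defined elsewhere in this PR.
from keras.distribution import AutoShardDistribution

print(AutoShardDistribution.__module__)  # keras.src.distribution.distribution_lib
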
301 changes: 296 additions & 5 deletions keras/src/backend/jax/distribution_lib.py
@@ -1,12 +1,11 @@
"""Utilities for distribution strategy with JAX backend."""

import collections
import itertools

import jax
import numpy as np

from keras.src.backend.common import global_state
from keras.src.random import seed_generator
from keras.src.utils import jax_utils
from keras.src.utils import rng_utils
from jax import tree_util


def list_devices(device_type=None):
@@ -63,6 +62,7 @@ def distribute_tensor(tensor, layout):
"""
# Avoid circular imports.
from keras.src.distribution import TensorLayout
from keras.src.utils import jax_utils

if isinstance(layout, TensorLayout):
layout = layout.backend_layout
@@ -120,6 +120,10 @@ def initialize_rng():

This is required for consistent initialization in multi-host settings.
"""
from keras.src.backend.common import global_state
from keras.src.random import seed_generator
from keras.src.utils import rng_utils

global_seed = rng_utils.get_random_seed()
# Only set a random seed if not already set
# via keras.config.set_random_seed()
@@ -242,3 +246,290 @@ def _to_backend_layout(tensor_layout):
partition_spec = jax.sharding.PartitionSpec(*tensor_layout.axes)
jax_mesh = tensor_layout.device_mesh.backend_mesh
return jax.sharding.NamedSharding(jax_mesh, partition_spec)


_JAX_CLASSES_DEFINED = False
JaxGraph = None
JaxShardingPlanner = None
JaxShardApplier = None


def _define_and_register_jax_classes():
global _JAX_CLASSES_DEFINED, JaxGraph, JaxShardingPlanner, JaxShardApplier
if _JAX_CLASSES_DEFINED:
return

from keras.src.distribution.autoshard_utils import MergeableGraph

def parse_jaxpr(jaxpr) -> MergeableGraph:
graph = MergeableGraph()

def same_axis(node1, node2):
var1, axis1 = node1
var2, axis2 = node2
if var1.aval.shape[axis1] != var2.aval.shape[axis2]:
return
graph.merge_nodes(node1, node2)

def parse_dot_general(eqn):
lhs, rhs = eqn.invars
out = eqn.outvars[0]
(lc, rc), (lb, rb) = eqn.params["dimension_numbers"]
for l, r in zip(lc, rc):
same_axis((lhs, l), (rhs, r))
o_offset = 0
for l, r in zip(lb, rb):
same_axis((lhs, l), (rhs, r))
same_axis((lhs, l), (out, o_offset))
o_offset += 1
for i in range(lhs.aval.ndim):
if i not in lb and i not in lc:
same_axis((lhs, i), (out, o_offset))
o_offset += 1
for j in range(rhs.aval.ndim):
if j not in rb and j not in rc:
same_axis((rhs, j), (out, o_offset))
o_offset += 1

def parse_reshape(eqn):
invar, out = eqn.invars[0], eqn.outvars[0]
in_idx, out_idx, in_prod, out_prod = 0, 0, 1, 1
while in_idx < invar.aval.ndim and out_idx < out.aval.ndim:
if (
in_prod == out_prod
and invar.aval.shape[in_idx] == out.aval.shape[out_idx]
):
if invar.aval.shape[in_idx] > 1:
same_axis((invar, in_idx), (out, out_idx))
in_prod *= invar.aval.shape[in_idx]
out_prod *= out.aval.shape[out_idx]
in_idx += 1
out_idx += 1
elif in_prod < out_prod:
in_prod *= invar.aval.shape[in_idx]
in_idx += 1
else:
out_prod *= out.aval.shape[out_idx]
out_idx += 1

def parse_transpose(eqn):
invar, out = eqn.invars[0], eqn.outvars[0]
for i, j in enumerate(eqn.params["permutation"]):
same_axis((invar, j), (out, i))

def parse_elementwise_with_broadcast(eqn):
out = eqn.outvars[0]
for invar in eqn.invars:
if invar.aval.ndim == 0:
continue
for i in range(1, min(invar.aval.ndim, out.aval.ndim) + 1):
in_axis, out_axis = -i, -i
if invar.aval.shape[in_axis] == out.aval.shape[out_axis]:
same_axis(
(invar, invar.aval.ndim + in_axis),
(out, out.aval.ndim + out_axis),
)

for var in jaxpr.jaxpr.invars:
for i, j in itertools.combinations(range(var.aval.ndim), 2):
graph.add_edge((var, i), (var, j))

for eqn in jaxpr.eqns:
for outvar in eqn.outvars:
for i, j in itertools.combinations(range(outvar.aval.ndim), 2):
graph.add_edge((outvar, i), (outvar, j))

primitive_parsers = {
"dot_general": parse_dot_general,
"reshape": parse_reshape,
"transpose": parse_transpose,
}
parser = primitive_parsers.get(
eqn.primitive.name, parse_elementwise_with_broadcast
)
parser(eqn)
return graph

def shard_model(
jaxpr,
out_avals,
trainable_params,
non_trainable_params,
args,
kwargs,
min_shard_size=1,
data_axis_name="data",
model_axis_name="model",
):
graph = parse_jaxpr(jaxpr)

t_params_flat, t_params_treedef = tree_util.tree_flatten(
trainable_params
)
nt_params_flat, nt_params_treedef = tree_util.tree_flatten(
non_trainable_params
)
args_flat, args_treedef = tree_util.tree_flatten(args)
kwargs_flat, kwargs_treedef = tree_util.tree_flatten(kwargs)
_, outputs_treedef = tree_util.tree_flatten(out_avals)

pos = 0
t_param_invars = jaxpr.jaxpr.invars[pos : pos + len(t_params_flat)]
pos += len(t_params_flat)
nt_param_invars = jaxpr.jaxpr.invars[pos : pos + len(nt_params_flat)]
pos += len(nt_params_flat)
arg_invars = jaxpr.jaxpr.invars[pos : pos + len(args_flat)]
pos += len(args_flat)
kwarg_invars = jaxpr.jaxpr.invars[pos:]

all_param_invars = t_param_invars + nt_param_invars
data_invars = arg_invars + kwarg_invars

seen = collections.Counter()
for var in all_param_invars:
for i in range(var.aval.ndim):
if var.aval.shape[i] >= min_shard_size:
seen.update([graph.get_root((var, i))])

model_axis_root = max(seen, key=seen.get) if seen else None

data_axes_roots = []
for var in data_invars:
for i in range(var.aval.ndim):
root = graph.get_root((var, i))
if root not in seen and root not in data_axes_roots:
data_axes_roots.append(root)

def assign_layouts(vars_flat, is_params=False):
assignments = []
for var in vars_flat:
layout = [None] * var.aval.ndim
for i in range(var.aval.ndim):
if var.aval.shape[i] < min_shard_size:
continue
root = graph.get_root((var, i))
if (
is_params
and model_axis_root
and root == model_axis_root
):
layout[i] = model_axis_name
elif not is_params and root in data_axes_roots:
name = data_axis_name
if len(data_axes_roots) > 1:
name += str(data_axes_roots.index(root))
layout[i] = name
assignments.append(layout)
return assignments

params_assignments = tree_util.tree_unflatten(
t_params_treedef, assign_layouts(t_param_invars, is_params=True)
)
return params_assignments

class _JaxGraph:
def __init__(
self,
jaxpr,
trainable_variables,
non_trainable_variables,
in_treedefs,
out_avals,
):
self.jaxpr = jaxpr
self.trainable_variables = trainable_variables
self.non_trainable_variables = non_trainable_variables
self.in_treedefs = in_treedefs
self.out_avals = out_avals

@classmethod
def from_model(cls, model, *args, **kwargs):
def stateless_fn(
trainable_vars, non_trainable_vars, f_args, f_kwargs
):
return model.stateless_call(
trainable_vars, non_trainable_vars, *f_args, **f_kwargs
)

trainable_vars = model.trainable_variables
non_trainable_vars = model.non_trainable_variables
in_treedefs = tree_util.tree_structure(
(trainable_vars, non_trainable_vars, args, kwargs)
)

closed_jaxpr, out_avals = jax.make_jaxpr(
stateless_fn, return_shape=True
)(trainable_vars, non_trainable_vars, args, kwargs)

return cls(
closed_jaxpr,
trainable_vars,
non_trainable_vars,
in_treedefs,
out_avals,
)

class _JaxShardingPlanner:
def plan(self, graph, device_mesh):
all_in_avals = [var.aval for var in graph.jaxpr.jaxpr.invars]
all_in_leaves = tree_util.tree_unflatten(
graph.in_treedefs, all_in_avals
)
_, _, args_aval_tree, kwargs_aval_tree = all_in_leaves

dummy_args = tree_util.tree_map(
lambda x: np.zeros(x.shape, x.dtype), args_aval_tree
)
dummy_kwargs = tree_util.tree_map(
lambda x: np.zeros(x.shape, x.dtype), kwargs_aval_tree
)

param_assignments = shard_model(
jaxpr=graph.jaxpr,
out_avals=graph.out_avals,
trainable_params=graph.trainable_variables,
non_trainable_params=graph.non_trainable_variables,
args=dummy_args,
kwargs=dummy_kwargs,
)

param_vars_flat, _ = tree_util.tree_flatten(
graph.trainable_variables
)
param_layouts_flat, _ = tree_util.tree_flatten(param_assignments)

parameter_layout_dict = {
var.path: tuple(layout) if layout else None
for var, layout in zip(param_vars_flat, param_layouts_flat)
}
return parameter_layout_dict

class _JaxShardApplier:
def apply(self, model, plan):
for var in model.variables:
layout = plan.get(var.path)
if layout:
var.layout = layout

JaxGraph = _JaxGraph
JaxShardingPlanner = _JaxShardingPlanner
JaxShardApplier = _JaxShardApplier
_JAX_CLASSES_DEFINED = True


def get_sharding_planner():
"""Returns an instance of the JAX sharding planner."""
_define_and_register_jax_classes()
return JaxShardingPlanner()


def get_shard_applier():
"""Returns an instance of the JAX shard applier."""
_define_and_register_jax_classes()
return JaxShardApplier()


def create_graph_from_model(model, *args, **kwargs):
"""Returns a JAX graph representation of the Keras model."""
_define_and_register_jax_classes()
return JaxGraph.from_model(model, *args, **kwargs)
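
The three helpers above are the module-level entry points this backend exposes for auto-sharding. A rough usage sketch under the JAX backend, using only the signatures defined in this file; the model is a stand-in, and the mesh wiring that the rest of the AutoShardDistribution machinery would normally supply is stubbed out:

# Sketch only: assumes KERAS_BACKEND=jax and a built Keras model.
import keras
import numpy as np

from keras.src.backend.jax import distribution_lib as jax_dist_lib

# Tiny stand-in model; any built Keras model would do.
model = keras.Sequential(
    [keras.Input(shape=(32,)), keras.layers.Dense(16), keras.layers.Dense(4)]
)
sample_batch = np.zeros((8, 32), dtype="float32")

# Trace the model into a jaxpr-backed graph representation.
graph = jax_dist_lib.create_graph_from_model(model, sample_batch)

# Derive a {variable.path: per-axis layout tuple or None} plan.
# The planner defined above does not consult its device_mesh argument,
# so a placeholder is passed purely for this sketch.
plan = jax_dist_lib.get_sharding_planner().plan(graph, device_mesh=None)

# Attach the planned layouts to the matching model variables.
jax_dist_lib.get_shard_applier().apply(model, plan)
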