Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions flax/nnx/transforms/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,9 +231,21 @@ def to_value_metadata(x):


def _to_variable(node):
# import here to avoid circular imports
from flax.nnx.spmd import get_var_pspec

def to_variable(x):
if isinstance(x, ValueMetadata):
var = x.var_type._new(x.value, x.metadata)

global_mesh = jax.sharding.get_mesh()
if global_mesh.axis_sizes == ():
global_mesh = None
mesh = var.get_metadata("mesh", None) or global_mesh
if mesh is not None:
pspec = get_var_pspec(var)
sharding = jax.sharding.NamedSharding(mesh=mesh, spec=pspec)
var.set_value(jax.ShapeDtypeStruct(shape=var.shape, dtype=var.dtype, sharding=sharding))
Comment on lines +241 to +248
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The sharding logic in `_to_variable` prefers a mesh supplied in the variable's metadata and falls back to the global mesh when none is given, so the resulting `jax.ShapeDtypeStruct` carries the correct `NamedSharding`.

return var
return x

Expand Down
30 changes: 30 additions & 0 deletions tests/nnx/spmd_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -359,6 +359,36 @@ def test_sharding_axis_types(self, mode):
self.assertEqual(v.sharding.mesh, mesh)
self.assertEqual(v.sharding.spec, P('row', 'col'))

def test_eval_shape_with_sharding0(self):
    """eval_shape honors a per-variable mesh supplied via kernel_metadata.

    Regression test for https://github.com/google/flax/issues/5110: each
    kernel must pick up the mesh from its own metadata, with no global
    mesh set.
    """
    axis_types = (jax.sharding.AxisType.Auto, jax.sharding.AxisType.Auto)
    mesh1 = jax.make_mesh((2, 2), ('a', 'b'), axis_types)
    mesh2 = jax.make_mesh((1, 4), ('c', 'd'), axis_types)

    class Model(nnx.Module):
      def __init__(self):
        self.p1 = nnx.Linear(
            16, 16, rngs=nnx.Rngs(0),
            kernel_metadata={'out_sharding': ('a', 'b'), 'mesh': mesh1},
        )
        self.p2 = nnx.Linear(
            16, 16, rngs=nnx.Rngs(0),
            kernel_metadata={'out_sharding': ('c', 'd'), 'mesh': mesh2},
        )

    abs_model = nnx.eval_shape(lambda: Model())
    # unittest-style assertions (matching the rest of this test class) are
    # not stripped under -O and report actual values on failure.
    self.assertIsInstance(
        abs_model.p1.kernel.sharding, jax.sharding.NamedSharding)
    self.assertIs(abs_model.p1.kernel.sharding.mesh, mesh1)
    self.assertEqual(abs_model.p1.kernel.sharding.spec, jax.P('a', 'b'))
    self.assertIsInstance(
        abs_model.p2.kernel.sharding, jax.sharding.NamedSharding)
    self.assertIs(abs_model.p2.kernel.sharding.mesh, mesh2)
    self.assertEqual(abs_model.p2.kernel.sharding.spec, jax.P('c', 'd'))
Comment on lines +362 to +378
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The `test_eval_shape_with_sharding0` case verifies sharding propagation when each linear layer's `kernel_metadata` specifies its own explicit mesh, ensuring per-variable mesh configurations are respected.


def test_eval_shape_with_sharding1(self):
    """eval_shape falls back to the global mesh set with jax.set_mesh.

    No mesh is given in the variable metadata here, so the sharding on the
    abstract kernel must come from the ambient mesh context.
    """
    class Model(nnx.Module):
      def __init__(self):
        self.linear = nnx.Linear(
            10, 10, rngs=nnx.Rngs(0),
            kernel_metadata={'out_sharding': ('a', 'b')},
        )

    mesh = jax.make_mesh(
        (2, 2), ('a', 'b'),
        (jax.sharding.AxisType.Auto, jax.sharding.AxisType.Auto),
    )
    with jax.set_mesh(mesh):
      abs_model = nnx.eval_shape(lambda: Model())
      # unittest-style assertions (matching the rest of this test class)
      # are not stripped under -O and report actual values on failure.
      self.assertIsInstance(
          abs_model.linear.kernel.sharding, jax.sharding.NamedSharding)
      self.assertIs(abs_model.linear.kernel.sharding.mesh, mesh)
      self.assertEqual(
          abs_model.linear.kernel.sharding.spec, jax.P('a', 'b'))
Comment on lines +380 to +390
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The `test_eval_shape_with_sharding1` case validates sharding propagation when a global mesh is active via `jax.set_mesh`, covering the common case where sharding is derived from the ambient mesh context.


def has_sharding_spec(array):
sharding = array.sharding
if hasattr(sharding, 'spec'):
Expand Down
Loading