Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions flax/nnx/spmd.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,10 +74,10 @@ def remove_field(fields, index, value):

def _remove_axis(x: tp.Any):
if isinstance(x, variablelib.Variable):
if hasattr(x, 'out_sharding') and x.out_sharding is not None:
x.set_metadata(
out_sharding=remove_field(x.out_sharding, index, axis_name)
)
if hasattr(x, 'out_sharding') and x.out_sharding:
x.set_metadata(
out_sharding=remove_field(x.out_sharding, index, axis_name)
)

for k, v in other_meta.items():
if hasattr(x, k) and (t := getattr(x, k)) and isinstance(t, tuple):
Expand Down Expand Up @@ -174,4 +174,4 @@ def get_abstract_model(init_fn, mesh):
lambda a, s: jax.ShapeDtypeStruct(a.shape, a.dtype, sharding=s),
abs_state, get_named_sharding(abs_state, mesh)
)
return gdef, abs_state
return gdef, abs_state
13 changes: 13 additions & 0 deletions flax/nnx/variablelib.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from flax.core import spmd as core_spmd
from flax.nnx import reprlib, tracers, visualization
from flax.typing import MISSING, Missing, SizeBytes
from jax.sharding import NamedSharding
import jax
from jax._src.state.types import AbstractRef
import jax.experimental
Expand Down Expand Up @@ -1418,6 +1419,10 @@ def __init__(
if ref:
value = jax.new_ref(value) # type: ignore
object.__setattr__(self, '_raw_value', value)
if (('out_sharding' not in metadata) and hasattr(value, 'sharding') and
isinstance(value.sharding, NamedSharding) and any(value.sharding.spec)):
metadata['out_sharding'] = value.sharding.spec
metadata['mesh'] = value.sharding.mesh

@property
def _can_update(self) -> bool:
Expand Down Expand Up @@ -1761,6 +1766,14 @@ def set_value(self, value: A, *, index: tp.Any = MISSING):
else:
self._raw_value[index] = value # type: ignore

if (
hasattr(self._raw_value, 'sharding')
and isinstance(self._raw_value.sharding, NamedSharding)
and any(self._raw_value.sharding.spec)
):
self._var_metadata['out_sharding'] = self._raw_value.sharding.spec
self._var_metadata['mesh'] = self._raw_value.sharding.mesh

def add_axis(self, axis_index: AxisIndex, axis_name: AxisName | None):
if 'on_add_axis' in self._var_metadata:
self._var_metadata['on_add_axis'](self, axis_index, axis_name)
Expand Down
35 changes: 35 additions & 0 deletions tests/nnx/spmd_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,41 @@ def create_module():
assert m.w.shape == (8, 2)
assert m.w.sharding.shard_shape(m.w.shape) == (4, 1)

def test_init_with_sharded_value(self):
mesh = jax.make_mesh(
(2, 2),
('model', 'data'),
axis_types=(jax.sharding.AxisType.Auto, jax.sharding.AxisType.Auto),
)

with jax.set_mesh(mesh):
data = jax.device_put(jnp.ones((4, 4)), NamedSharding(mesh, P('model', 'data')))
v = nnx.Variable(data)

self.assertEqual(v.out_sharding, P('model', 'data'))
self.assertEqual(v.mesh, mesh)

def test_set_value_with_different_sharding(self):
mesh = jax.make_mesh(
(2, 2),
('model', 'data'),
axis_types=(jax.sharding.AxisType.Auto, jax.sharding.AxisType.Auto),
)

with jax.set_mesh(mesh):
data = jax.device_put(jnp.ones((4, 4)), NamedSharding(mesh, P('model', 'data')))
v = nnx.Variable(data)

self.assertEqual(v.out_sharding, P('model', 'data'))
self.assertEqual(v.mesh, mesh)

with jax.set_mesh(mesh):
new_data = jax.device_put(jnp.zeros((4, 4)), NamedSharding(mesh, P('data', 'model')))
v.set_value(new_data)

self.assertEqual(v.out_sharding, P('data', 'model'))
self.assertEqual(v.mesh, mesh)

def test_init_all_devices(self):
class Foo(nnx.Module):
def __init__(self):
Expand Down
Loading