From 17ba19f6e69a26e74bb214e69e8b99b8849a27a7 Mon Sep 17 00:00:00 2001 From: Sam Anklesaria Date: Thu, 12 Feb 2026 16:59:40 -0600 Subject: [PATCH 1/2] Track explicit sharding in Variables --- flax/nnx/spmd.py | 10 +++++----- flax/nnx/variablelib.py | 5 +++++ 2 files changed, 10 insertions(+), 5 deletions(-) diff --git a/flax/nnx/spmd.py b/flax/nnx/spmd.py index cd4a215b6..be3d0d1b7 100644 --- a/flax/nnx/spmd.py +++ b/flax/nnx/spmd.py @@ -74,10 +74,10 @@ def remove_field(fields, index, value): def _remove_axis(x: tp.Any): if isinstance(x, variablelib.Variable): - if hasattr(x, 'out_sharding') and x.out_sharding is not None: - x.set_metadata( - out_sharding=remove_field(x.out_sharding, index, axis_name) - ) + if hasattr(x, 'out_sharding') and x.out_sharding: + x.set_metadata( + out_sharding=remove_field(x.out_sharding, index, axis_name) + ) for k, v in other_meta.items(): if hasattr(x, k) and (t := getattr(x, k)) and isinstance(t, tuple): @@ -174,4 +174,4 @@ def get_abstract_model(init_fn, mesh): lambda a, s: jax.ShapeDtypeStruct(a.shape, a.dtype, sharding=s), abs_state, get_named_sharding(abs_state, mesh) ) - return gdef, abs_state \ No newline at end of file + return gdef, abs_state diff --git a/flax/nnx/variablelib.py b/flax/nnx/variablelib.py index 9a3f2b536..6c021ef97 100644 --- a/flax/nnx/variablelib.py +++ b/flax/nnx/variablelib.py @@ -28,6 +28,7 @@ from flax.core import spmd as core_spmd from flax.nnx import reprlib, tracers, visualization from flax.typing import MISSING, Missing, SizeBytes +from jax.sharding import NamedSharding import jax from jax._src.state.types import AbstractRef import jax.experimental @@ -1418,6 +1419,10 @@ def __init__( if ref: value = jax.new_ref(value) # type: ignore object.__setattr__(self, '_raw_value', value) + if (('out_sharding' not in metadata) and hasattr(value, 'sharding') and + isinstance(value.sharding, NamedSharding) and any(value.sharding.spec)): + metadata['out_sharding'] = value.sharding.spec + metadata['mesh'] = value.sharding.mesh @property def _can_update(self) -> bool: From eecb11311ffc5dd6b3b5bfa3d70941659c78a235 Mon Sep 17 00:00:00 2001 From: Sam Anklesaria Date: Tue, 17 Feb 2026 12:41:08 -0600 Subject: [PATCH 2/2] Add support for setting values with different sharding --- flax/nnx/variablelib.py | 8 ++++++++ tests/nnx/spmd_test.py | 35 +++++++++++++++++++++++++++++++++++ 2 files changed, 43 insertions(+) diff --git a/flax/nnx/variablelib.py b/flax/nnx/variablelib.py index 6c021ef97..62c562404 100644 --- a/flax/nnx/variablelib.py +++ b/flax/nnx/variablelib.py @@ -1766,6 +1766,14 @@ def set_value(self, value: A, *, index: tp.Any = MISSING): else: self._raw_value[index] = value # type: ignore + if ( + hasattr(self._raw_value, 'sharding') + and isinstance(self._raw_value.sharding, NamedSharding) + and any(self._raw_value.sharding.spec) + ): + self._var_metadata['out_sharding'] = self._raw_value.sharding.spec + self._var_metadata['mesh'] = self._raw_value.sharding.mesh + def add_axis(self, axis_index: AxisIndex, axis_name: AxisName | None): if 'on_add_axis' in self._var_metadata: self._var_metadata['on_add_axis'](self, axis_index, axis_name) diff --git a/tests/nnx/spmd_test.py b/tests/nnx/spmd_test.py index 0a86505fe..7d38a8fd2 100644 --- a/tests/nnx/spmd_test.py +++ b/tests/nnx/spmd_test.py @@ -61,6 +61,41 @@ def create_module(): assert m.w.shape == (8, 2) assert m.w.sharding.shard_shape(m.w.shape) == (4, 1) + def test_init_with_sharded_value(self): + mesh = jax.make_mesh( + (2, 2), + ('model', 'data'), + axis_types=(jax.sharding.AxisType.Auto, jax.sharding.AxisType.Auto), + ) + + with jax.set_mesh(mesh): + data = jax.device_put(jnp.ones((4, 4)), NamedSharding(mesh, P('model', 'data'))) + v = nnx.Variable(data) + + self.assertEqual(v.out_sharding, P('model', 'data')) + self.assertEqual(v.mesh, mesh) + + def test_set_value_with_different_sharding(self): + mesh = jax.make_mesh( + (2, 2), + ('model', 'data'), + axis_types=(jax.sharding.AxisType.Auto, jax.sharding.AxisType.Auto), + ) + + with jax.set_mesh(mesh): + data = jax.device_put(jnp.ones((4, 4)), NamedSharding(mesh, P('model', 'data'))) + v = nnx.Variable(data) + + self.assertEqual(v.out_sharding, P('model', 'data')) + self.assertEqual(v.mesh, mesh) + + with jax.set_mesh(mesh): + new_data = jax.device_put(jnp.zeros((4, 4)), NamedSharding(mesh, P('data', 'model'))) + v.set_value(new_data) + + self.assertEqual(v.out_sharding, P('data', 'model')) + self.assertEqual(v.mesh, mesh) + def test_init_all_devices(self): class Foo(nnx.Module): def __init__(self):