Skip to content

Commit f4430b7

Browse files
committed
Track explicit sharding in Variables
1 parent 7811acc commit f4430b7

2 files changed

Lines changed: 10 additions & 5 deletions

File tree

flax/nnx/spmd.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -74,10 +74,10 @@ def remove_field(fields, index, value):
7474

7575
def _remove_axis(x: tp.Any):
7676
if isinstance(x, variablelib.Variable):
77-
if hasattr(x, 'out_sharding') and x.out_sharding is not None:
78-
x.set_metadata(
79-
out_sharding=remove_field(x.out_sharding, index, axis_name)
80-
)
77+
if hasattr(x, 'out_sharding') and x.out_sharding:
78+
x.set_metadata(
79+
out_sharding=remove_field(x.out_sharding, index, axis_name)
80+
)
8181

8282
for k, v in other_meta.items():
8383
if hasattr(x, k) and (t := getattr(x, k)) and isinstance(t, tuple):
@@ -174,4 +174,4 @@ def get_abstract_model(init_fn, mesh):
174174
lambda a, s: jax.ShapeDtypeStruct(a.shape, a.dtype, sharding=s),
175175
abs_state, get_named_sharding(abs_state, mesh)
176176
)
177-
return gdef, abs_state
177+
return gdef, abs_state

flax/nnx/variablelib.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
from flax.core import spmd as core_spmd
2929
from flax.nnx import reprlib, tracers, visualization
3030
from flax.typing import MISSING, Missing, SizeBytes
31+
from jax.sharding import NamedSharding
3132
import jax
3233
from jax._src.state.types import AbstractRef
3334
import jax.experimental
@@ -1418,6 +1419,10 @@ def __init__(
14181419
if ref:
14191420
value = jax.new_ref(value) # type: ignore
14201421
object.__setattr__(self, '_raw_value', value)
1422+
if ((not hasattr(metadata, 'out_sharding')) and hasattr(value, 'sharding') and
1423+
isinstance(value.sharding, NamedSharding) and any(value.sharding.spec)):
1424+
metadata['out_sharding'] = value.sharding.spec
1425+
metadata['mesh'] = value.sharding.mesh
14211426

14221427
@property
14231428
def _can_update(self) -> bool:

0 commit comments

Comments
 (0)