Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 0 additions & 11 deletions flax/nnx/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,17 +300,6 @@ def get_node_impl_for_type(
else:
return None

# use type-aware sorting to support int keys
def _type_aware_sort(item: tuple[tp.Any, tp.Any]) -> tuple[int, tp.Any]:
key, _ = item
if isinstance(key, int):
return (0, key)
elif isinstance(key, str):
return (1, key)
else:
raise ValueError(f'Unsupported key type: {type(key)!r}')


@jax.tree_util.register_static
@dataclasses.dataclass(frozen=True, repr=False)
class NodeRef(tp.Generic[Node], reprlib.Representable):
Expand Down
41 changes: 23 additions & 18 deletions flax/nnx/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,20 +117,17 @@ def __init__(self, it: tp.Iterable[A] | None = None, /):
for value in it:
self.append(value)

def _get_elem(self, key: int) -> A:
return getattr(self, str(key))
def _getattr(self, key) -> A:
return vars(self)[key] # type: ignore[unsupported-operands]

def _set_elem(self, key: int, value: A) -> None:
setattr(self, str(key), value)

def _del_elem(self, key: int) -> None:
delattr(self, str(key))
def _delattr(self, key) -> None:
vars(self).pop(key)

def __len__(self) -> int:
return self._length

def append(self, value: A) -> None:
self._set_elem(self._length, value)
self._setattr(self._length, value)
self._length += 1

def insert(self, index: int, value: A) -> None:
Expand All @@ -143,15 +140,15 @@ def insert(self, index: int, value: A) -> None:

# Shift elements to the right
for i in range(self._length, index, -1):
self._set_elem(i, self._get_elem(i - 1))
self._setattr(i, self._getattr(i - 1))

# Insert the new value
self._set_elem(index, value)
self._setattr(index, value)
self._length += 1

def __iter__(self) -> tp.Iterator[A]:
for i in range(self._length):
yield self._get_elem(i)
yield self._getattr(i)

@tp.overload
def __getitem__(self, index: int) -> A: ...
Expand All @@ -163,10 +160,10 @@ def __getitem__(self, index: int | slice) -> A | tp.List[A]:
index += self._length
if index < 0 or index >= self._length:
raise IndexError('Index out of bounds')
return self._get_elem(index)
return self._getattr(index)
elif isinstance(index, slice):
idxs = list(range(self._length))[index]
return [self._get_elem(i) for i in idxs]
return [self._getattr(i) for i in idxs]
else:
raise TypeError('Invalid index type')

Expand All @@ -176,7 +173,7 @@ def __setitem__(self, index: int | slice, value: A | tp.Iterable[A]) -> None:
index += self._length
if index < 0 or index >= self._length:
raise IndexError('Index out of bounds')
self._set_elem(index, value) # type: ignore[arg-type]
self._setattr(index, value)
elif isinstance(index, slice):
if not isinstance(value, tp.Iterable):
raise TypeError('Expected an iterable')
Expand All @@ -185,7 +182,7 @@ def __setitem__(self, index: int | slice, value: A | tp.Iterable[A]) -> None:
if len(idxs) != len(values):
raise ValueError('Length mismatch')
for i, v in zip(idxs, values):
self._set_elem(i, v)
self._setattr(i, v)
else:
raise TypeError('Invalid index type')

Expand All @@ -206,9 +203,9 @@ def __delitem__(self, index: int | slice) -> None:
index += self._length
if index < 0 or index >= self._length:
raise IndexError('Index out of bounds')
self._del_elem(index)
self._delattr(index)
for i in range(index + 1, self._length):
self._set_elem(i - 1, self._get_elem(i))
self._setattr(i - 1, self._getattr(i))
self._length -= 1
elif isinstance(index, slice):
idxs = list(range(self._length))[index]
Expand All @@ -218,7 +215,15 @@ def __delitem__(self, index: int | slice) -> None:
else:
raise TypeError('Invalid index type')

_pytree__has_int_keys = True
@staticmethod
def _pytree__key_sort_fn(item: tuple[tp.Any, tp.Any]) -> tuple[int, tp.Any]:
key, _ = item
if isinstance(key, int):
return (0, key)
elif isinstance(key, str):
return (1, key)
else:
raise ValueError(f'Unsupported key type: {type(key)!r}')


class Sequential(Module):
Expand Down
112 changes: 41 additions & 71 deletions flax/nnx/pytreelib.py
Original file line number Diff line number Diff line change
Expand Up @@ -407,18 +407,16 @@ def _pytree_meta_construct(cls, self, *args, **kwargs):
def _graph_node_meta_call(cls: tp.Type[P], *args, **kwargs) -> P:
node = cls.__new__(cls, *args, **kwargs)
vars_obj = vars(node)
object.__setattr__(node, '_pytree__state', PytreeState())
object.__setattr__(node, '_pytree__nodes', cls._pytree__nodes)
vars_obj['_pytree__state'] = PytreeState()
vars_obj['_pytree__nodes'] = cls._pytree__nodes
cls._pytree_meta_construct(node, *args, **kwargs)
if cls._pytree__is_pytree:
missing: dict[str, bool] = {}
for name, value in vars(node).items():
if name not in node._pytree__nodes:
if name not in vars_obj['_pytree__nodes']:
missing[name] = is_data(value)
if missing:
object.__setattr__(
node, '_pytree__nodes', node._pytree__nodes.update(missing)
)
vars_obj['_pytree__nodes'] = vars_obj['_pytree__nodes'].update(missing)
check_pytree(node)

return node
Expand Down Expand Up @@ -639,10 +637,11 @@ def _setattr(self, name, value: tp.Any) -> None:
if name not in self._pytree__nodes or (
explicit and self._pytree__nodes[name] != data
):
object.__setattr__(
self, '_pytree__nodes', self._pytree__nodes.update({name: data})
)
object.__setattr__(self, name, value)
vars(self)['_pytree__nodes'] = self._pytree__nodes.update({name: data})
if isinstance(name, str):
object.__setattr__(self, name, value)
else:
vars(self)[name] = value

def _check_value(self, key, value, new_status: AttributeStatus | None):
def _has_data(leaves):
Expand Down Expand Up @@ -859,26 +858,20 @@ def __getstate__(self):
return vars(self).copy()

def __setstate__(self, state):
for key, value in state.items():
object.__setattr__(self, key, value)
vars(self).update(state)

# -------------------------
# Pytree Definition
# -------------------------
_pytree__has_int_keys: bool = False
_pytree__key_sort_fn: tp.Callable | None = None

def _pytree__flatten_with_paths(self):
obj_items = vars(self).items()
if self._pytree__has_int_keys:
obj_items = ((_maybe_int(name), value) for name, value in obj_items)
key_fn = graph._type_aware_sort
else:
key_fn = None
obj_vars = vars(self)
node_attributes = self._pytree__nodes
node_names: list[str] = []
node_attrs: list[tuple[tp.Any, tp.Any]] = []
static_attrs: list[tuple[str, tp.Any]] = []
for name, value in sorted(obj_items, key=key_fn):
for name, value in sorted(obj_vars.items(), key=self._pytree__key_sort_fn):
if name in node_attributes and node_attributes[name]:
node_names.append(name)
node_attrs.append((
Expand All @@ -893,17 +886,12 @@ def _pytree__flatten_with_paths(self):
return node_attrs, (tuple(node_names), tuple(static_attrs))

def _pytree__flatten(self):
obj_items = vars(self).items()
if self._pytree__has_int_keys:
obj_items = ((_maybe_int(name), value) for name, value in obj_items)
key_fn = graph._type_aware_sort
else:
key_fn = None
obj_vars = vars(self)
node_attributes = self._pytree__nodes
node_names: list[str] = []
node_attrs: list[tp.Any] = []
static_attrs: list[tuple[str, tp.Any]] = []
for name, value in sorted(obj_items, key=key_fn):
for name, value in sorted(obj_vars.items(), key=self._pytree__key_sort_fn):
if name in node_attributes and node_attributes[name]:
node_names.append(name)
node_attrs.append(value)
Expand All @@ -921,47 +909,42 @@ def _pytree__unflatten(
node_names, static_attrs = static
obj = object.__new__(cls)
vars_obj = vars(obj)
if cls._pytree__has_int_keys:
node_names = tuple(
str(name) if isinstance(name, int) else name for name in node_names
)
for name, value in zip(node_names, node_attrs, strict=True):
object.__setattr__(obj, name, value)
for name, value in static_attrs:
object.__setattr__(obj, name, value)
vars_obj.update(zip(node_names, node_attrs, strict=True))
vars_obj.update(static_attrs)
return obj

# -------------------------
# Graph Definition
# -------------------------
def _graph_node_flatten(self):
obj_items = vars(self).items()
if self._pytree__has_int_keys:
obj_items = ((_maybe_int(name), value) for name, value in obj_items)
key_fn = graph._type_aware_sort
else:
key_fn = None
nodes = sorted(obj_items, key=key_fn)
nodes = vars(self)
nodes = sorted(nodes.items(), key=self._pytree__key_sort_fn)
return nodes, type(self)

def _graph_node_set_key(self, key, value: tp.Any):
if self._pytree__has_int_keys and isinstance(key, int):
key = str(key)
setattr(self, key, value)
def _graph_node_set_key(self, key: str, value: tp.Any):
if not isinstance(key, str):
raise KeyError(f'Invalid key: {key!r}')
elif (
hasattr(self, key)
and isinstance(variable := getattr(self, key), Variable)
and isinstance(value, Variable)
):
variable.update_from_state(value)
else:
setattr(self, key, value)

def _graph_node_pop_key(self, key):
if self._pytree__has_int_keys and isinstance(key, int):
key = str(key)
value = getattr(self, key)
def _graph_node_pop_key(self, key: str):
if not isinstance(key, str):
raise KeyError(f'Invalid key: {key!r}')
delattr(self, key)
return value
return self

def __delattr__(self, name: str) -> None:
if name in self._pytree__nodes:
mapping = {k: v for k, v in self._pytree__nodes.items() if k != name}
object.__setattr__(
self, '_pytree__nodes', graph.HashableMapping(mapping, copy=False)
)
if isinstance(self._pytree__nodes._mapping, tp.MutableMapping):
del self._pytree__nodes._mapping[name]
else:
self._pytree__nodes._mapping = {k: v for k, v in self._pytree__nodes.items() if k != name}

super().__delattr__(name)

Expand All @@ -971,17 +954,10 @@ def _graph_node_create_empty(node_type: tp.Type[P]) -> P:
return node

def _graph_node_clear(self):
for name in list(vars(self)):
delattr(self, name)
vars(self).clear()

def _graph_node_init(self, attributes: tp.Iterable[tuple[str, tp.Any]]):
if self._pytree__has_int_keys:
attributes = (
(str(name) if isinstance(name, int) else name, value)
for name, value in attributes
)
for name, value in attributes:
object.__setattr__(self, name, value)
vars(self).update(attributes)

if tp.TYPE_CHECKING:
def __call__(self, *args: tp.Any, **kwargs: tp.Any) -> tp.Any: ...
Expand All @@ -997,10 +973,4 @@ def __init_subclass__(cls, **kwargs):
"Object is not a pytree, but 'pytree' was explicitly set to "
f'{pytree!r} for type {cls}.'
)
super().__init_subclass__(pytree=pytree, **kwargs)

def _maybe_int(x):
try:
return int(x)
except (ValueError, TypeError):
return x
super().__init_subclass__(pytree=pytree, **kwargs)
3 changes: 3 additions & 0 deletions tests/nnx/partitioning_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,9 @@ def test_get_paritition(self):
d=5.0,
)

# test Variables not shared
self.assertIsNot(vars(m.a)[0], vars(m)['b'])

state = nnx.state(m, nnx.Variable)
self.assertEqual(state['a'][0][...], m.a[0][...])
self.assertEqual(state['a'][1][...], m.a[1][...])
Expand Down
Loading