[nnx] fast jit
cgarciae committed Dec 21, 2024
1 parent 53bde74 commit d192618
Showing 8 changed files with 272 additions and 76 deletions.
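
At a high level, the diffs below switch graph flattening to produce a flat sequence of (path, variable) pairs (the new FlatState) instead of eagerly building a nested State, nesting only where callers need it. The public split/merge round trip keeps its shape; a minimal sketch (the toy module below is illustrative, not part of this commit):

```python
from flax import nnx

class Model(nnx.Module):
  def __init__(self, rngs: nnx.Rngs):
    self.linear = nnx.Linear(2, 3, rngs=rngs)

model = Model(nnx.Rngs(0))
graphdef, state = nnx.split(model)     # still a GraphDef plus a nested State
restored = nnx.merge(graphdef, state)  # rebuilds an equivalent module
```
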
flax/nnx/extract.py (14 changes: 9 additions & 5 deletions)
@@ -254,7 +254,7 @@ class GraphDefState(struct.PyTreeNode):

class NodeStates(struct.PyTreeNode):
_graphdef: graph.GraphDef[tp.Any] | None
states: tuple[graph.GraphState, ...]
states: tuple[graph.GraphState | graph.GraphFlatState, ...]
metadata: tp.Any = struct.field(pytree_node=False)

@property
@@ -264,7 +264,7 @@ def graphdef(self) -> graph.GraphDef[tp.Any]:
return self._graphdef

@property
def state(self) -> graph.GraphState:
def state(self) -> graph.GraphState | graph.GraphFlatState:
if len(self.states) != 1:
raise ValueError(
f'Expected exactly one GraphDefState, got {len(self.states)}'
@@ -275,15 +275,19 @@ def state(self) -> graph.GraphState:
def from_split(
cls,
graphdef: graph.GraphDef[tp.Any],
state: graph.GraphState,
state: graph.GraphState | graph.GraphFlatState,
/,
*states: graph.GraphState,
*states: graph.GraphState | graph.GraphFlatState,
metadata: tp.Any = None,
):
return cls(_graphdef=graphdef, states=(state, *states), metadata=metadata)

@classmethod
def from_states(cls, state: graph.GraphState, *states: graph.GraphState):
def from_states(
cls,
state: graph.GraphState | graph.GraphFlatState,
*states: graph.GraphState | graph.GraphFlatState,
):
return cls(_graphdef=None, states=(state, *states), metadata=None)

@classmethod
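A hedged sketch of the widened NodeStates annotations in use; extract and graph are internal modules, and graph.flatten returning a FlatState is part of the flax/nnx/graph.py change below, so treat these call sites as assumptions rather than public API:

```python
from flax import nnx
from flax.nnx import extract, graph

node = nnx.Linear(2, 3, rngs=nnx.Rngs(0))
graphdef, flat_state = graph.flatten(node)  # FlatState after this commit

ns = extract.NodeStates.from_split(graphdef, flat_state)
assert ns.graphdef is graphdef
assert ns.state is flat_state  # single-state case is returned as-is
```
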
flax/nnx/graph.py (115 changes: 100 additions & 15 deletions)
@@ -30,7 +30,7 @@
CallableProxy,
DelayedAccessor,
)
from flax.nnx.statelib import State
from flax.nnx.statelib import FlatState, State
from flax.nnx import variablelib
from flax.nnx.variablelib import Variable, VariableState
from flax.typing import Key, PathParts, is_key_like
@@ -53,6 +53,7 @@
StateLeaf = VariableState[tp.Any]
NodeLeaf = Variable[tp.Any]
GraphState = State[Key, StateLeaf]
GraphFlatState = FlatState[StateLeaf]


def is_state_leaf(x: tp.Any) -> tpe.TypeGuard[StateLeaf]:
@@ -377,7 +378,9 @@ def _apply(
module = merge(self, state, *states)
fn = accessor(module)
out = fn(*args, **kwargs)
return out, flatten(module)
graphdef, flat_state = flatten(module)
state_ = State.from_flat_path(flat_state)
return out, (graphdef, state_)

return CallableProxy(_apply, accessor) # type: ignore

@@ -389,7 +392,7 @@ def _apply(

def flatten(
node: Node, /, ref_index: RefMap[tp.Any, Index] | None = None
) -> tuple[GraphDef[Node], GraphState]:
) -> tuple[GraphDef[Node], FlatState[tp.Any]]:
"""Flattens a graph node into a (graphdef, state) pair.
Args:
@@ -402,7 +405,7 @@ def flatten(
ref_index = RefMap()
flat_state: list[tuple[PathParts, StateLeaf]] = []
graphdef = _graph_flatten((), ref_index, flat_state, node)
return graphdef, GraphState.from_flat_path(flat_state)
return graphdef, FlatState(flat_state)


def _graph_flatten(
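
A minimal sketch of the new flatten contract shown above: it now returns a FlatState, and callers build a nested State explicitly when they need one (the import paths are the internal modules named in this diff):

```python
from flax import nnx
from flax.nnx import graph
from flax.nnx.statelib import State

node = nnx.Linear(2, 3, rngs=nnx.Rngs(0))
graphdef, flat_state = graph.flatten(node)  # FlatState of (path, VariableState) pairs
state = State.from_flat_path(flat_state)    # nested State, as flatten used to return
```
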
@@ -811,8 +814,11 @@ def split(
ctx = (
current_update_context(self.ctxtag) if self.ctxtag is not None else None
)
graphdef, state = flatten(node, self.ref_index)
states = _split_state(state, filters)
graphdef, flat_state = flatten(node, self.ref_index)
flat_states = _split_state(flat_state, filters)
states = tuple(
State.from_flat_path(flat_state) for flat_state in flat_states
)
if ctx is not None:
if ctx.index_ref is not None and isinstance(graphdef, NodeDef):
index_to_index = compose_mapping(ctx.index_ref, self.ref_index)
@@ -822,6 +828,47 @@

return graphdef, *states

@tp.overload
def flatten(
self, graph_node: A, /
) -> tuple[GraphDef[A], FlatState[VariableState[tp.Any]]]: ...
@tp.overload
def flatten(
self, graph_node: A, first: filterlib.Filter, /
) -> tuple[GraphDef[A], FlatState[VariableState[tp.Any]]]: ...
@tp.overload
def flatten(
self,
graph_node: A,
first: filterlib.Filter,
second: filterlib.Filter,
/,
*filters: filterlib.Filter,
) -> tuple[
GraphDef[A],
FlatState[VariableState[tp.Any]],
tpe.Unpack[tuple[FlatState[VariableState[tp.Any]], ...]],
]: ...
def flatten(
self, node: A, *filters: filterlib.Filter
) -> tuple[
GraphDef[A], tpe.Unpack[tuple[FlatState[VariableState[tp.Any]], ...]]
]:
ctx = (
current_update_context(self.ctxtag) if self.ctxtag is not None else None
)
graphdef, flat_state = flatten(node, self.ref_index)
flat_states = _split_state(flat_state, filters)

if ctx is not None:
if ctx.index_ref is not None and isinstance(graphdef, NodeDef):
index_to_index = compose_mapping(ctx.index_ref, self.ref_index)
graphdef = dataclasses.replace(
graphdef, index_mapping=HashableMapping(index_to_index, copy=False)
)

return graphdef, *flat_states
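
A hedged sketch of the new SplitContext.flatten fast path, assuming split_context yields the SplitContext defined in this file; with no filters it returns the GraphDef plus a single FlatState, with filters one FlatState per filter:

```python
from flax import nnx
from flax.nnx import graph

node = nnx.Linear(2, 3, rngs=nnx.Rngs(0))

with graph.split_context() as ctx:
  graphdef, flat_state = ctx.flatten(node)  # no nested State is built here

with graph.split_context() as ctx:
  graphdef, params, rest = ctx.flatten(node, nnx.Param, ...)  # filtered FlatStates
```
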


@contextlib.contextmanager
def split_context(ctxtag: str | None = None):
@@ -874,6 +921,39 @@ def merge(
)
return node

def unflatten(
self,
graphdef: GraphDef[A],
flat_state: GraphFlatState,
/,
*flat_states: GraphFlatState,
) -> A:
ctx = (
current_update_context(self.ctxtag) if self.ctxtag is not None else None
)
if (
ctx is not None
and isinstance(graphdef, NodeDef)
and graphdef.index_mapping is not None
):
# outer merge (4), create index_ref_cache
assert ctx.ref_index is not None
index_ref_cache = compose_mapping_reversed(
ctx.ref_index, graphdef.index_mapping
)
else:
# inner merge (2)
index_ref_cache = None

state = FlatState.merge(flat_state, *flat_states).to_nested_state()
node = unflatten(
graphdef,
state,
index_ref=self.index_ref,
index_ref_cache=index_ref_cache,
)
return node
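
The matching hedged sketch for MergeContext.unflatten, assuming merge_context yields the MergeContext defined in this file; it merges one or more FlatStates, nests them once, and unflattens:

```python
from flax import nnx
from flax.nnx import graph

node = nnx.Linear(2, 3, rngs=nnx.Rngs(0))

with graph.split_context() as ctx:
  graphdef, params, rest = ctx.flatten(node, nnx.Param, ...)

with graph.merge_context() as ctx:
  restored = ctx.unflatten(graphdef, params, rest)  # rebuilds the module from FlatStates
```
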


@contextlib.contextmanager
def merge_context(ctxtag: str | None = None):
@@ -1001,9 +1081,11 @@ def split(
filters are passed, a single :class:`State` is returned.
"""
ref_index: RefMap[tp.Any, Index] = RefMap()
graphdef, state = flatten(node, ref_index)
states = _split_state(state, filters)

graphdef, flat_state = flatten(node, ref_index)
states = tuple(
State.from_flat_path(flat_state)
for flat_state in _split_state(flat_state, filters)
)
if self.index_ref is not None and isinstance(graphdef, NodeDef):
index_to_index = compose_mapping(self.index_ref, ref_index)
graphdef = dataclasses.replace(
@@ -1195,13 +1277,13 @@ def current_update_context(tag: str) -> UpdateContext:
# --------------------------------------------------------

def _split_state(
state: GraphState,
state: FlatState[tp.Any],
filters: tuple[filterlib.Filter, ...],
) -> tuple[GraphState, tpe.Unpack[tuple[GraphState, ...]]]:
) -> tuple[FlatState[tp.Any], tpe.Unpack[tuple[FlatState[tp.Any], ...]]]:
if not filters:
return (state,)
states = state.split(*filters)
if isinstance(states, State):
if not isinstance(states, tuple):
return (states,)
assert len(states) > 0
return states # type: ignore[return-value]
@@ -1292,9 +1374,11 @@ def split(
``GraphDef`` and one or more ``States`` equal to the number of filters passed. If no
filters are passed, a single ``State`` is returned.
"""
graphdef, state = flatten(node)
states = _split_state(state, filters)
return graphdef, *states
graphdef, flat_state = flatten(node)
flat_states = _split_state(flat_state, filters)
states = tuple(State.from_flat_path(flat_state) for flat_state in flat_states)
return graphdef, *states # type: ignore[return-value]
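
From a caller's perspective the public split above is unchanged: it still returns a GraphDef and nested States, one per filter. A short sketch:

```python
from flax import nnx

model = nnx.Linear(2, 3, rngs=nnx.Rngs(0))
graphdef, params, rest = nnx.split(model, nnx.Param, ...)  # nested States, one per filter
model2 = nnx.merge(graphdef, params, rest)
```
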


def merge(
graphdef: GraphDef[A],
@@ -1486,6 +1570,7 @@ def state(
One or more :class:`State` mappings.
"""
_, state = flatten(node)
state = state.to_nested_state()

states: GraphState | tuple[GraphState, ...]
if len(filters) == 0:
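The public state helper above now nests the flat state before filtering, so its behavior is unchanged for users; for example:

```python
from flax import nnx

model = nnx.Linear(2, 3, rngs=nnx.Rngs(0))
params = nnx.state(model, nnx.Param)  # nested State of the model's Param variables
everything = nnx.state(model)         # no filter: a single State
```
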
flax/nnx/reprlib.py (8 changes: 8 additions & 0 deletions)
@@ -111,6 +111,14 @@ def __nnx_repr__(self):
for key, value in self.items():
yield Attr(repr(key), value)

class SequenceReprMixin(tp.Sequence[A], Representable):
def __nnx_repr__(self):
yield Object(type='', value_sep='', start='[', end=']')

for value in self:
yield Attr('', value)
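
A small, hypothetical sketch of a sequence type adopting the new SequenceReprMixin, mirroring how FlatState uses it in statelib.py below; IntSeq is made up for illustration:

```python
import typing as tp
from flax.nnx import reprlib

class IntSeq(reprlib.SequenceReprMixin[int]):
  def __init__(self, values: tp.Iterable[int]):
    self._values = list(values)

  def __getitem__(self, index):
    return self._values[index]

  def __len__(self) -> int:
    return len(self._values)

print(IntSeq([1, 2, 3]))  # rendered via __nnx_repr__ as a bracket-delimited sequence
```
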


@dataclasses.dataclass(repr=False)
class PrettyMapping(Representable):
mapping: tp.Mapping
flax/nnx/statelib.py (91 changes: 86 additions & 5 deletions)
@@ -54,7 +54,7 @@ def __treescope_repr__(self, path, subtree_renderer):
# Render as the dictionary itself at the same path.
return subtree_renderer(children, path=path)

class FlatState(tp.Sequence[tuple[PathParts, V]], reprlib.PrettySequence):
class FlatState(reprlib.SequenceReprMixin[tuple[PathParts, V]]):
_keys: tuple[PathParts, ...]
_values: list[V]

@@ -83,6 +83,85 @@ def __len__(self) -> int:
def __iter__(self) -> tp.Iterator[tuple[PathParts, V]]:
return iter(zip(self._keys, self._values))

def to_nested_state(self) -> State[PathParts, V]:
return State.from_flat_path(self)

@tp.overload
def split(self, first: filterlib.Filter, /) -> FlatState[V]: ...

@tp.overload
def split(
self,
first: filterlib.Filter,
second: filterlib.Filter,
/,
*filters: filterlib.Filter,
) -> tuple[FlatState[V], ...]: ...

@tp.overload
def split(
self, /, *filters: filterlib.Filter
) -> tp.Union[FlatState[V], tuple[FlatState[V], ...]]: ...

def split( # type: ignore[misc]
self, first: filterlib.Filter, /, *filters: filterlib.Filter
) -> tp.Union[FlatState[V], tuple[FlatState[V], ...]]:
filters = (first, *filters)
*flat_states_, rest = _split_state(self, *filters)

if rest:
raise ValueError(
'Non-exhaustive filters, got a non-empty remainder: '
f'{rest}.\nUse `...` to match all remaining elements.'
)

flat_states: FlatState[V] | tuple[FlatState[V], ...]
if len(flat_states_) == 1:
flat_states = flat_states_[0]
else:
flat_states = tuple(flat_states_)
return flat_states # type: ignore

@tp.overload
def filter(self, first: filterlib.Filter, /) -> FlatState[V]: ...

@tp.overload
def filter(
self,
first: filterlib.Filter,
second: filterlib.Filter,
/,
*filters: filterlib.Filter,
) -> tuple[FlatState[V], ...]: ...

def filter(
self,
first: filterlib.Filter,
/,
*filters: filterlib.Filter,
) -> tp.Union[FlatState[V], tuple[FlatState[V], ...]]:
*flat_states_, _rest = _split_state(self, first, *filters)

assert len(flat_states_) == len(filters) + 1

flat_states: FlatState[V] | tuple[FlatState[V], ...]
if len(flat_states_) == 1:
flat_states = flat_states_[0]
else:
flat_states = tuple(flat_states_)

return flat_states # type: ignore

@staticmethod
def merge(
flat_state: tp.Iterable[tuple[PathParts, V]],
/,
*flat_states: tp.Iterable[tuple[PathParts, V]],
) -> FlatState[V]:
flat_states = (flat_state, *flat_states)

return FlatState(elem for flat_state in flat_states for elem in flat_state)
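
A hedged usage sketch of the new FlatState helpers (split, filter, merge, to_nested_state); FlatState lives in the internal statelib module, so the import paths are assumptions:

```python
from flax import nnx
from flax.nnx import graph
from flax.nnx.statelib import FlatState

model = nnx.Linear(2, 3, rngs=nnx.Rngs(0))
_, flat_state = graph.flatten(model)

params, rest = flat_state.split(nnx.Param, ...)  # exhaustive split into FlatStates
only_params = flat_state.filter(nnx.Param)       # non-exhaustive; remainder is dropped
merged = FlatState.merge(params, rest)           # concatenate back into one FlatState
nested = merged.to_nested_state()                # nested State when mapping access is needed
```
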


def _flat_state_pytree_flatten(x: FlatState[V]):
return x._values, x._keys
@@ -291,7 +370,8 @@ def split( # type: ignore[misc]
One or more ``States`` equal to the number of filters passed.
"""
filters = (first, *filters)
*states_, rest = _split_state(self.flat_state(), *filters)
flat_states = _split_state(self.flat_state(), *filters)
*states_, rest = (state.to_nested_state() for state in flat_states)

if rest:
raise ValueError(
@@ -356,7 +436,8 @@ def filter(
Returns:
One or more ``States`` equal to the number of filters passed.
"""
*states_, _rest = _split_state(self.flat_state(), first, *filters)
flat_states = _split_state(self.flat_state(), first, *filters)
*states_, _rest = (state.to_nested_state() for state in flat_states)

assert len(states_) == len(filters) + 1

@@ -456,7 +537,7 @@ def _state_unflatten(
def _split_state(
flat_state: FlatState[V],
*filters: filterlib.Filter,
) -> tuple[State[PathParts, V], ...]:
) -> tuple[FlatState[V], ...]:
for i, filter_ in enumerate(filters):
if filter_ in (..., True) and i != len(filters) - 1:
remaining_filters = filters[i + 1 :]
@@ -482,7 +563,7 @@ def _split_state(
# if we didn't break, set leaf to last state
flat_states[-1].append((path, value)) # type: ignore[index] # mypy is wrong here?

return tuple(State.from_flat_path(flat_state) for flat_state in flat_states)
return tuple(FlatState(flat_state) for flat_state in flat_states)


def create_path_filters(state: State):
(Diffs for the remaining 4 changed files are not shown here.)
