Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update NNX State.state.filter method docs in statelib.py #4452

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 10 additions & 8 deletions flax/nnx/statelib.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,10 +328,11 @@ def filter(
/,
*filters: filterlib.Filter,
) -> tp.Union[State[K, V], tuple[State[K, V], ...]]:
"""Filter a ``State`` into one or more ``State``'s. The
user must pass at least one ``Filter`` (i.e. :class:`Variable`).
This method is similar to :meth:`split() <flax.nnx.State.state.split>`,
except the filters can be non-exhaustive.
"""Filters a :class:`flax.nnx.State` into one or more ``nnx.State``'s.
You must pass at least one NNX ``Filter`` (``flax.nnx.filterlib``)
(i.e. :class:`flax.nnx. Variable`).
This method is similar to :func:`flax.nnx.State.state.split`,
except the ``Filter``'s can be non-exhaustive.

Example usage::

Expand All @@ -351,10 +352,11 @@ def filter(
>>> param, batch_stats = state.filter(nnx.Param, nnx.BatchStat)

Arguments:
first: The first filter
*filters: The optional, additional filters to group the state into mutually exclusive substates.
first: The first NNX ``Filter``.
*filters: The optional, additional NNX ``Filter``'s to group the :class:`flax.nnx.State`
into mutually exclusive sub-``State``'s.
Returns:
One or more ``States`` equal to the number of filters passed.
One or more ``nnx.State``'s equal to the number of NNX ``Filter``'s passed.
"""
*states_, _rest = _split_state(self.flat_state(), first, *filters)

Expand Down Expand Up @@ -492,4 +494,4 @@ def create_path_filters(state: State):
if isinstance(value, (variablelib.Variable, variablelib.VariableState)):
value = value.value
value_paths.setdefault(value, set()).add(path)
return {filterlib.PathIn(*value_paths[value]): value for value in value_paths}
return {filterlib.PathIn(*value_paths[value]): value for value in value_paths}
Loading