Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update NNX Module class docs in module.py #4442

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions flax/nnx/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,15 +51,15 @@ class ModuleMeta(ObjectMeta):
class Module(Object, metaclass=ModuleMeta):
"""Base class for all neural network modules.

Layers and models should subclass this class.
Flax NNX layers and models should subclass this :class`flax.nnx.Module` class.

``Module``'s can contain submodules, and in this way can be nested in a tree
structure. Submodules can be assigned as regular attributes inside the
``__init__`` method.
An ``nnx.Module`` can contain sub-``Module``'s, allowing them to be nested in a
JAX pytree-like structure. Sub-``Module``'s can be assigned as regular attributes
inside the ``__init__`` method.

You can define arbitrary "forward pass" methods on your ``Module`` subclass.
You can define arbitrary "forward pass" methods on your ``nnx.Module`` subclass.
While no methods are special-cased, ``__call__`` is a popular choice since
you can call the ``Module`` directly::
you can call the ``nnx.Module`` directly::

>>> from flax import nnx
>>> import jax.numpy as jnp
Expand Down
Loading