Skip to content

Commit

Permalink
Update NNX Module class docs in module.py
Browse files Browse the repository at this point in the history
  • Loading branch information
8bitmp3 committed Dec 16, 2024
1 parent 6bc9858 commit cccc806
Showing 1 changed file with 6 additions and 6 deletions.
12 changes: 6 additions & 6 deletions flax/nnx/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,15 +51,15 @@ class ModuleMeta(ObjectMeta):
class Module(Object, metaclass=ModuleMeta):
"""Base class for all neural network modules.
Layers and models should subclass this class.
Flax NNX layers and models should subclass this :class`flax.nnx.Module` class.
``Module``'s can contain submodules, and in this way can be nested in a tree
structure. Submodules can be assigned as regular attributes inside the
``__init__`` method.
An ``nnx.Module`` can contain sub-``Module``'s, allowing them to be nested in a
JAX pytree-like structure. Sub-``Module``'s can be assigned as regular attributes
inside the ``__init__`` method.
You can define arbitrary "forward pass" methods on your ``Module`` subclass.
You can define arbitrary "forward pass" methods on your ``nnx.Module`` subclass.
While no methods are special-cased, ``__call__`` is a popular choice since
you can call the ``Module`` directly::
you can call the ``nnx.Module`` directly::
>>> from flax import nnx
>>> import jax.numpy as jnp
Expand Down

0 comments on commit cccc806

Please sign in to comment.