From cccc8064d3ec6d3a2580b03a80ef4e189f714b2d Mon Sep 17 00:00:00 2001 From: 8bitmp3 <19637339+8bitmp3@users.noreply.github.com> Date: Mon, 16 Dec 2024 00:22:35 +0000 Subject: [PATCH] Update NNX Module class docs in module.py --- flax/nnx/module.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/flax/nnx/module.py b/flax/nnx/module.py index 795bb9a088..d63201edd7 100644 --- a/flax/nnx/module.py +++ b/flax/nnx/module.py @@ -51,15 +51,15 @@ class ModuleMeta(ObjectMeta): class Module(Object, metaclass=ModuleMeta): """Base class for all neural network modules. - Layers and models should subclass this class. + Flax NNX layers and models should subclass this :class`flax.nnx.Module` class. - ``Module``'s can contain submodules, and in this way can be nested in a tree - structure. Submodules can be assigned as regular attributes inside the - ``__init__`` method. + An ``nnx.Module`` can contain sub-``Module``'s, allowing them to be nested in a + JAX pytree-like structure. Sub-``Module``'s can be assigned as regular attributes + inside the ``__init__`` method. - You can define arbitrary "forward pass" methods on your ``Module`` subclass. + You can define arbitrary "forward pass" methods on your ``nnx.Module`` subclass. While no methods are special-cased, ``__call__`` is a popular choice since - you can call the ``Module`` directly:: + you can call the ``nnx.Module`` directly:: >>> from flax import nnx >>> import jax.numpy as jnp