nnx.vmap example from documentation raises an IndexError #4355
Comments
Same behavior after updating to flax==0.10.0 and jax[cuda12_local]==0.4.35.
Hey @jhn-nt, thanks for reporting this! I'm very curious why our CI isn't failing. The easiest fix is to split the key for the Rngs:
keys = jax.random.split(jax.random.key(0), 5)
model = create_model(nnx.Rngs(keys))
Will fix this quickly.
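The split-key fix appears to work because `nnx.vmap` maps over the leading axis of the key held by `Rngs`, so that key needs a leading axis of length 5 instead of being a single scalar key. The following is a minimal plain-JAX sketch of the same idea, not the Flax NNX API: `init_linear` is a hypothetical stand-in for the initialization done by `nnx.Linear` inside the issue's `MLP`.

```python
import jax
import jax.numpy as jnp


def init_linear(key, din=10, dout=32):
    # Hypothetical stand-in for a Linear layer's initializer (not the nnx.Linear API).
    kernel = jax.random.normal(key, (din, dout)) * 0.1
    bias = jnp.zeros((dout,))
    return {"kernel": kernel, "bias": bias}


# One scalar key split into 5 keys: shape (5,), one per ensemble member.
keys = jax.random.split(jax.random.key(0), 5)

# vmap maps over the leading axis of `keys`, so every member gets its own key
# and the resulting parameters are stacked along a new leading axis of size 5.
params = jax.vmap(init_linear)(keys)
print(params["kernel"].shape)  # (5, 10, 32)
```

Because each member receives a different key, the five stacked kernels differ from one another, which is the point of creating an ensemble this way.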
Thanks a lot again for the prompt help! Giovanni
Oh wait, the link you posted is for the old experimental docs in the 0.8.3 version of the site; this is fixed in the new version: https://flax.readthedocs.io/en/latest/nnx_basics.html#scan-over-layers. Did you find this via Google?
I see, that explains it then. Apologies for opening the issue. But yes, I found it through Google by searching for "flax nnx". Giovanni
I am encountering an IndexError when running this example from the documentation.
I am running the code in a Docker environment using an NVIDIA image for JAX.
Best
Giovanni
System information
Problem you have encountered:
Error while going through the NNX tutorial
What you expected to happen:
Logs, error messages, etc:
IndexError Traceback (most recent call last)
Cell In[8], line 23
19 @partial(nnx.vmap, axis_size=5)
20 def create_model(rngs: nnx.Rngs):
21 return MLP(10, 32, 10, rngs=rngs)
---> 23 model = create_model(nnx.Rngs(0))
File /usr/local/lib/python3.10/dist-packages/flax/nnx/nnx/graph.py:1158, in UpdateContextManager.__call__.<locals>.update_context_manager_wrapper(*args, **kwargs)
1155 @functools.wraps(f)
1156 def update_context_manager_wrapper(*args, **kwargs):
1157 with self:
-> 1158 return f(*args, **kwargs)
File /usr/local/lib/python3.10/dist-packages/flax/nnx/nnx/transforms/iteration.py:339, in vmap.<locals>.vmap_wrapper(*args, **kwargs)
335 args = resolve_kwargs(f, args, kwargs)
336 pure_args = extract.to_tree(
337 args, prefix=in_axes, split_fn=_vmap_split_fn, ctxtag='vmap'
338 )
--> 339 pure_args_out, pure_out = vmapped_fn(*pure_args)
340 _args_out, out = extract.from_tree((pure_args_out, pure_out), ctxtag='vmap')
341 return out
File /usr/local/lib/python3.10/dist-packages/flax/nnx/nnx/transforms/iteration.py:164, in VmapFn.__call__(self, *pure_args)
159 pure_args = _update_variable_sharding_metadata(
160 pure_args, self.transform_metadata, spmd.remove_axis
161 )
162 args = extract.from_tree(pure_args, ctxtag='vmap')
--> 164 out = self.f(*args)
166 args_out = extract.clear_non_graph_nodes(args)
167 pure_args_out, pure_out = extract.to_tree(
168 (args_out, out),
169 prefix=(self.in_axes, self.out_axes),
170 split_fn=_vmap_split_fn,
171 ctxtag='vmap',
172 )
Cell In[8], line 21
19 @partial(nnx.vmap, axis_size=5)
20 def create_model(rngs: nnx.Rngs):
---> 21 return MLP(10, 32, 10, rngs=rngs)
File /usr/local/lib/python3.10/dist-packages/flax/nnx/nnx/object.py:79, in ObjectMeta.__call__(cls, *args, **kwargs)
78 def __call__(cls, *args: Any, **kwargs: Any) -> Any:
---> 79 return _graph_node_meta_call(cls, *args, **kwargs)
File /usr/local/lib/python3.10/dist-packages/flax/nnx/nnx/object.py:88, in _graph_node_meta_call(cls, *args, **kwargs)
86 node = cls.__new__(cls, *args, **kwargs)
87 vars(node)['_object__state'] = ObjectState()
---> 88 cls._object_meta_construct(node, *args, **kwargs)
90 return node
File /usr/local/lib/python3.10/dist-packages/flax/nnx/nnx/object.py:82, in ObjectMeta._object_meta_construct(cls, self, *args, **kwargs)
81 def _object_meta_construct(cls, self, *args, **kwargs):
---> 82 self.__init__(*args, **kwargs)
Cell In[8], line 8
7 def __init__(self, din: int, dmid: int, dout: int, *, rngs: nnx.Rngs):
----> 8 self.linear1 = nnx.Linear(din, dmid, rngs=rngs)
9 self.dropout = nnx.Dropout(rate=0.1, rngs=rngs)
10 self.bn = nnx.BatchNorm(dmid, rngs=rngs)
File /usr/local/lib/python3.10/dist-packages/flax/nnx/nnx/object.py:79, in ObjectMeta.__call__(cls, *args, **kwargs)
78 def __call__(cls, *args: Any, **kwargs: Any) -> Any:
---> 79 return _graph_node_meta_call(cls, *args, **kwargs)
File /usr/local/lib/python3.10/dist-packages/flax/nnx/nnx/object.py:88, in _graph_node_meta_call(cls, *args, **kwargs)
86 node = cls.__new__(cls, *args, **kwargs)
87 vars(node)['_object__state'] = ObjectState()
---> 88 cls._object_meta_construct(node, *args, **kwargs)
90 return node
File /usr/local/lib/python3.10/dist-packages/flax/nnx/nnx/object.py:82, in ObjectMeta._object_meta_construct(cls, self, *args, **kwargs)
81 def _object_meta_construct(cls, self, *args, **kwargs):
---> 82 self.__init__(*args, **kwargs)
File /usr/local/lib/python3.10/dist-packages/flax/nnx/nnx/nn/linear.py:346, in Linear.__init__(self, in_features, out_features, use_bias, dtype, param_dtype, precision, kernel_init, bias_init, dot_general, rngs)
332 def __init__(
333 self,
334 in_features: int,
(...)
344 rngs: rnglib.Rngs,
345 ):
--> 346 kernel_key = rngs.params()
347 self.kernel = nnx.Param(
348 kernel_init(kernel_key, (in_features, out_features), param_dtype)
349 )
350 if use_bias:
File /usr/local/lib/python3.10/dist-packages/flax/nnx/nnx/rnglib.py:84, in RngStream.__call__(self)
80 def __call__(self) -> jax.Array:
81 self.check_valid_context(
82 lambda: 'Cannot call RngStream from a different trace level'
83 )
---> 84 key = jax.random.fold_in(self.key.value, self.count.value)
85 self.count.value += 1
86 return key
File /usr/local/lib/python3.10/dist-packages/jax/_src/random.py:262, in fold_in(key, data)
251 def fold_in(key: KeyArrayLike, data: IntegerArray) -> KeyArray:
252 """Folds in data to a PRNG key to form a new PRNG key.
253
254 Args:
(...)
260 statistically safe for producing a stream of new pseudo-random values.
261 """
--> 262 key, wrapped = _check_prng_key("fold_in", key)
263 if np.ndim(data):
264 raise TypeError("fold_in accepts a scalar, but was given an array of"
265 f"shape {np.shape(data)} != (). Use jax.vmap for batching.")
File /usr/local/lib/python3.10/dist-packages/jax/_src/random.py:74, in _check_prng_key(name, key, allow_batched)
72 def _check_prng_key(name: str, key: KeyArrayLike, *,
73 allow_batched: bool = False) -> tuple[KeyArray, bool]:
---> 74 if isinstance(key, Array) and dtypes.issubdtype(key.dtype, dtypes.prng_key):
75 wrapped_key = key
76 wrapped = False
File /usr/local/lib/python3.10/dist-packages/jax/_src/interpreters/batching.py:346, in BatchTracer.aval(self)
344 return aval
345 elif type(self.batch_dim) is int:
--> 346 return core.mapped_aval(aval.shape[self.batch_dim], self.batch_dim, aval)
347 elif type(self.batch_dim) is RaggedAxis:
348 new_aval = core.mapped_aval(
349 aval.shape[self.batch_dim.stacked_axis], self.batch_dim.stacked_axis, aval)
IndexError: tuple index out of range
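The bottom frames of the traceback suggest the cause: `RngStream.__call__` passes the stored key to `jax.random.fold_in`, and the batch tracer then appears to read `aval.shape[self.batch_dim]` on a key whose shape is `()`, hence `tuple index out of range`. The sketch below reproduces the shape problem with plain `jax.vmap`, which, unlike the nested NNX path here, seems to reject a rank-0 argument up front with a `ValueError`. The function `draw` is a hypothetical stand-in for `rngs.params()`, not Flax code.

```python
import jax


def draw(key):
    # Hypothetical stand-in for rngs.params(): fold a counter into the key, then sample.
    return jax.random.normal(jax.random.fold_in(key, 0), ())


scalar_key = jax.random.key(0)                  # shape (): nothing to map over
batched_keys = jax.random.split(scalar_key, 5)  # shape (5,): one key per member

# Mapping over axis 0 of a rank-0 key is rejected by plain jax.vmap.
try:
    jax.vmap(draw, in_axes=0)(scalar_key)
    scalar_ok = True
except ValueError as err:
    scalar_ok = False
    print(f"scalar key rejected: {err}")

# With a leading axis of size 5, each mapped call receives its own key.
samples = jax.vmap(draw, in_axes=0)(batched_keys)
print(samples.shape)  # (5,)
```

This is consistent with the maintainer's fix of passing `jax.random.split(jax.random.key(0), 5)` to `nnx.Rngs`: the key then has the leading axis that the vmapped function expects.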
Steps to reproduce: