Provide as much information as possible. At least, this should include a description of your issue and steps to reproduce the problem. If possible also provide a summary of what steps or workarounds you have already tried.
System information
OS Platform and Distribution (e.g., Linux Ubuntu 16.04): Ubuntu 22.04.3 LTS
Flax, jax, jaxlib versions (obtain with pip show flax jax jaxlib): 0.9.0, 0.4.33, 0.4.33
Python version: 3.10
GPU/TPU model and memory: Colab T4 GPU
CUDA version (if applicable): N/A
Problem you have encountered:
Passing a tuple as in_features and/or out_features to nnx.Conv results in an error saying that the // operator cannot be used on a tuple.
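The error itself comes from Python: floor division (`//`) is defined for numbers, not tuples. The sketch below reproduces it in plain Python, assuming (as the error message suggests) that the layer divides in_features by an integer during construction; `feature_group_count` here is only an illustrative stand-in for that integer.

```python
# Plain-Python reproduction of the error class; no Flax needed.
in_features = (8, 8)       # tuple, as passed in the report
feature_group_count = 1    # illustrative integer divisor (assumption)

try:
    in_features // feature_group_count
except TypeError as err:
    message = str(err)

print(message)  # unsupported operand type(s) for //: 'tuple' and 'int'
```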
What you expected to happen:
Creating an nnx.Conv layer with tuples for in_features/out_features should not produce an error. Applying Conv(in_features=(x, x), out_features=(y, y)) to an input of shape (a, b, x, x) should produce an output of shape (a, b, y, y), as the documentation suggests.
Logs, error messages, etc:
Steps to reproduce:
Whenever possible, please provide a minimal example. Please consider submitting it as a Colab link.
```python
from flax import nnx
import jax.numpy as jnp

rngs = nnx.Rngs(0)
x = jnp.ones((2, 3, 8, 8))
# Fails at construction: the // operator cannot be applied to a tuple.
conv1 = nnx.Conv(in_features=(8, 8), out_features=(4, 4), kernel_size=(3, 3), rngs=rngs)
```
@8bitmp3 can you please update the docstrings for nnx.Conv's in_features and out_features to specify that only ints are valid? I think we might have copied this from LinearGeneral, but it's obviously wrong.
Colab reproduction: https://colab.research.google.com/drive/1jIIowlJaQ-SyS59nfy-Wn8dYcgz4mUwH?usp=sharing