NNX Conv Layer Input Tuple Error #4295

Open
riverliway opened this issue Oct 14, 2024 · 1 comment

riverliway commented Oct 14, 2024

System information

  • OS Platform and Distribution (e.g., Linux Ubuntu 16.04): Ubuntu 22.04.3 LTS
  • Flax, jax, jaxlib versions (obtain with pip show flax jax jaxlib): 0.9.0, 0.4.33, 0.4.33
  • Python version: 3.10
  • GPU/TPU model and memory: Colab T4 GPU
  • CUDA version (if applicable): N/A

Problem you have encountered:

Passing a tuple as in_features and/or out_features to nnx.Conv raises an error saying that the // operator cannot be used on a tuple.
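
The failure can be illustrated without Flax. The sketch below assumes the layer computes its kernel shape with floor division on in_features; kernel_shape is a hypothetical stand-in written for this illustration, and the actual nnx.Conv internals may differ.

```python
# Hypothetical stand-in for the kernel-shape arithmetic inside a conv layer.
# This is an assumption for illustration, not the actual nnx.Conv source.
def kernel_shape(kernel_size, in_features, out_features, feature_group_count=1):
    # Floor division is defined for ints but not for tuples.
    return tuple(kernel_size) + (in_features // feature_group_count, out_features)


# Integer channel counts work.
print(kernel_shape((3, 3), 8, 4))

# Tuple features, as in this report, fail on the // step.
try:
    kernel_shape((3, 3), (8, 8), (4, 4))
except TypeError as err:
    print(f"TypeError: {err}")
```

Under this assumption, any tuple passed as in_features fails before the layer is ever applied to an input, which matches the error appearing at construction time.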

What you expected to happen:

Creating an nnx.Conv layer with tuples for in_features and out_features should not raise an error. Applying Conv(in_features=(x, x), out_features=(y, y)) to an input of shape (a, b, x, x) should produce an output of shape (a, b, y, y), as the documentation suggests.

Image

Logs, error messages, etc:

Image

Steps to reproduce:

from flax import nnx
import jax.numpy as jnp

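rngs = nnx.Rngs(0)
x = jnp.ones((2, 3, 8, 8))
# Fails at construction: in_features and out_features are tuples,
# and the constructor's use of // does not accept tuples.
conv1 = nnx.Conv(in_features=(8, 8), out_features=(4, 4), kernel_size=(3, 3), rngs=rngs)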

https://colab.research.google.com/drive/1jIIowlJaQ-SyS59nfy-Wn8dYcgz4mUwH?usp=sharing

riverliway changed the title from "NNX Conv Layer" to "NNX Conv Layer Input Tuple Error" on Oct 14, 2024
cgarciae (Collaborator) commented

@8bitmp3 can you please update the docstrings for nnx.Conv's in_features and out_features to specify that only ints are valid? I think we might have copied this from LinearGeneral, but it's obviously wrong.
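
Until the docstrings are corrected, a small guard can surface the int-only requirement with a clearer message. This is only a sketch: as_channel_count and make_conv are hypothetical helpers, not part of Flax, and the only Flax calls used are nnx.Conv and nnx.Rngs as they appear in the reproduction above.

```python
def as_channel_count(name, features):
    """Return `features` if it is an int, else raise a descriptive error.

    Hypothetical helper for illustration; not part of Flax.
    """
    if not isinstance(features, int):
        raise TypeError(
            f"{name} must be an int (the channel count), got {features!r}"
        )
    return features


def make_conv(in_features, out_features, kernel_size, seed=0):
    """Build an nnx.Conv after validating the channel counts (hypothetical wrapper)."""
    # Imported here so the validation helper can be used without Flax installed.
    from flax import nnx

    return nnx.Conv(
        in_features=as_channel_count("in_features", in_features),
        out_features=as_channel_count("out_features", out_features),
        kernel_size=kernel_size,
        rngs=nnx.Rngs(seed),
    )
```

With a channels-last input such as jnp.ones((2, 8, 8, 3)), make_conv(3, 4, (3, 3)) should give a layer that maps it to shape (2, 8, 8, 4) under the default padding, while make_conv((8, 8), (4, 4), (3, 3)) fails immediately with the message from as_channel_count.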

cgarciae added the "Priority: P1 - soon" label (Response within 5 business days. Resolution within 30 days. Assignee required) on Oct 17, 2024