About jax.PRNGKey:Error reporting when running #8

Open
euyy opened this issue Oct 16, 2021 · 5 comments

Comments

@euyy

euyy commented Oct 16, 2021

When I try to run this code, I get an error at this line:

generator_variables = generator(train=False).init(g_rng, (inputs, z))

The error is flax.errors.InvalidRngError: rngs should be a dictionary mapping strings to jax.PRNGKey. However, g_rng is an array of shape (2,), not a dictionary.
Can anyone help me solve this problem?

By the way, I have configured CUDA, but JAX still reports that it cannot find it:
xla_bridge.py:232] Unable to initialize backend 'gpu': NOT_FOUND: Could not find registered platform with name: "cuda". Available platform names are: Interpreter Host
I also tested the GPU with TensorFlow, and it works there, so I don't know what the problem is.

@woctezuma

No idea about the second issue, but you can find others having the same issue online: kuixu/alphafold#8

@kohjingyu
Contributor

The first issue might be due to a recent change in Flax. What version are you using? Can you try changing the line to:

generator_variables = generator(train=False).init({'params': g_rng}, (inputs, z))

(ref: https://github.com/google/flax/blob/main/examples/imagenet/train.py#L74)
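To illustrate the suggestion above: the error message says rngs should be a dictionary mapping strings to jax.PRNGKey, so a bare key has to be wrapped under a stream name such as 'params'. The helper below is a minimal sketch of that wrapping. The name as_rng_dict is hypothetical and not part of Flax, and the sketch assumes 'params' is the only RNG stream the model needs at init time.

```python
from collections.abc import Mapping


def as_rng_dict(rng, stream="params"):
    """Return `rng` in the {stream_name: key} form the error message asks for.

    Hypothetical helper, not part of Flax. A mapping is copied and returned
    unchanged; anything else is treated as a single key and wrapped under
    `stream`.
    """
    if isinstance(rng, Mapping):
        return dict(rng)
    return {stream: rng}


# Intended use with the line from this issue (names taken from the thread):
#   generator_variables = generator(train=False).init(as_rng_dict(g_rng), (inputs, z))
```

Writing {'params': g_rng} directly, as suggested above, is equivalent. The helper only matters if the same call site has to accept both forms.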

The second issue is likely due to a problem during setup. Could you try asking at https://github.com/google/jax?
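One setup problem worth ruling out for the "cuda" platform error is a CPU-only jaxlib install. The sketch below reads the installed jaxlib version with the standard library and looks for a cuda tag in it. It assumes the wheel naming used around the time of this thread (for example 0.1.71+cuda111). Other releases may package GPU support differently, so treat a negative result as a hint rather than proof. Both function names are hypothetical.

```python
from importlib import metadata


def installed_version(package):
    """Return the installed version string of `package`, or None if it is absent."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


def looks_like_cuda_wheel(version):
    """Heuristic: True if the version has a 'cuda' local tag, e.g. '0.1.71+cuda111'."""
    if version is None:
        return False
    # The local version segment is whatever follows the '+' sign, if any.
    _, _, local = version.partition("+")
    return "cuda" in local.lower()


jaxlib_version = installed_version("jaxlib")
print("jaxlib version:", jaxlib_version)
print("CUDA-tagged wheel:", looks_like_cuda_wheel(jaxlib_version))
```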

@hyeonjinXZ

hyeonjinXZ commented Oct 18, 2021

The first issue was fixed for me by upgrading to the latest version of Flax:
pip install --upgrade git+https://github.com/google/flax.git
(ref: https://pythonrepo.com/repo/google-flax-python-deep-learning)

@euyy
Author

euyy commented Oct 20, 2021

@woctezuma @Hyeonjin1989 @kohjingyu Thanks for your help.
But now I have a new problem:
UNKNOWN: Failed to determine best cudnn convolution algorithm: UNKNOWN: GetConvolveAlgorithms failed.
Convolution performance may be suboptimal. To ignore this failure and try to use a fallback algorithm, use XLA_FLAGS=--xla_gpu_strict_conv_algorithm_picker=false. Please also file a bug for the root cause of failing autotuning.
I don't know whether this error is caused by my device, so I would like to know if there are minimum configuration requirements for training. If anyone knows, please tell me. Thanks.
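The error message itself names a fallback flag. A minimal sketch of setting it from Python follows, assuming XLA reads XLA_FLAGS at startup, so the variable is set before jax is imported. The helper add_xla_flag is a hypothetical name. It appends to any existing XLA_FLAGS value instead of overwriting it. Per the message, this only ignores the autotuning failure and does not address its root cause.

```python
import os


def add_xla_flag(flag, env=os.environ):
    """Append `flag` to XLA_FLAGS unless it is already present; return the new value."""
    current = env.get("XLA_FLAGS", "").split()
    if flag not in current:
        current.append(flag)
    env["XLA_FLAGS"] = " ".join(current)
    return env["XLA_FLAGS"]


# Flag name copied from the error message above.
add_xla_flag("--xla_gpu_strict_conv_algorithm_picker=false")

# import jax  # import JAX only after the environment is set
```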

@adambot806

Add this near the top of the script:

import os
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = '.7'

For an explanation, refer to the JAX documentation on GPU memory allocation.
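The suggestion above limits how much GPU memory JAX preallocates, presumably to leave room for cuDNN. A slightly more general sketch follows. The helper configure_gpu_memory is a hypothetical name. XLA_PYTHON_CLIENT_MEM_FRACTION comes from the comment above, and XLA_PYTHON_CLIENT_PREALLOCATE is the related variable described on the same JAX documentation page. The sketch assumes both are set before JAX initializes the GPU, which in practice means before importing jax.

```python
import os


def configure_gpu_memory(fraction=None, preallocate=None, env=os.environ):
    """Set JAX's GPU memory environment variables; None leaves a setting untouched."""
    if fraction is not None:
        if not 0.0 < fraction <= 1.0:
            raise ValueError("fraction must be in (0, 1]")
        env["XLA_PYTHON_CLIENT_MEM_FRACTION"] = str(fraction)
    if preallocate is not None:
        env["XLA_PYTHON_CLIENT_PREALLOCATE"] = "true" if preallocate else "false"


# Same effect as the line suggested above: cap preallocation at 70% of GPU memory.
configure_gpu_memory(fraction=0.7)

# import jax  # import JAX only after the environment is set
```

If lowering the fraction does not help, calling configure_gpu_memory(preallocate=False) disables preallocation entirely, at the cost of possible memory fragmentation.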
