About jax.PRNGKey: error reported when running #8
Comments
No idea about the second issue, but others have reported the same problem online: kuixu/alphafold#8
The first issue might be caused by a recent change to Flax. What version are you using? Can you try changing it to:
(ref: https://github.com/google/flax/blob/main/examples/imagenet/train.py#L74)
The second issue is likely due to a problem during setup. Could you try asking in https://github.com/google/jax?
The first issue is fixed for me by upgrading to the latest version of Flax.
@woctezuma @Hyeonjin1989 @kohjingyu Thanks for your help.
add
explanation may refer to GPU memory allocation.
Excuse me. When I tried to run this code, I ran into a problem with this line:
xmcgan_image_generation/xmcgan/train_utils.py
Line 167 in 22a7ef2
and the error is flax.errors.InvalidRngError: rngs should be a dictionary mapping strings to jax.PRNGKey. In fact, g_rng is an array of shape [2,]. Can anyone help me solve this problem?
By the way, I have configured CUDA, but it still tells me that CUDA was not found:
xla_bridge.py:232] Unable to initialize backend 'gpu': NOT_FOUND: Could not find registered platform with name: "cuda". Available platform names are: Interpreter Host
I even used TensorFlow to test the GPU, and it works there. I don't know what the problem is.
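A quick way to see what JAX itself detects, independently of TensorFlow, is to list its devices. This is only a diagnostic sketch. A log that lists only "Interpreter Host" as available platforms commonly points to a CPU-only jaxlib build, but that is an assumption about this setup, not a confirmed diagnosis. TensorFlow detecting the GPU does not by itself show that the installed jaxlib was built with CUDA support.

```python
import jax

# Report the installed JAX version and the devices JAX can actually use.
print("jax version:", jax.__version__)

devices = jax.devices()
print("devices:", devices)

# Name of the backend JAX will use by default (for example "cpu" or "gpu").
backend = jax.default_backend()
print("default backend:", backend)
```

If only CPU devices appear on a machine with a working GPU, reinstalling jaxlib with a CUDA-enabled build that matches the local CUDA and cuDNN versions is the usual next step.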