assert_is_replicated in Analytic policy gradients training #328
It reports the same bug even when I use only one GPU.
I just realized that it might be because some elements are NaN and
Hi @wangyian-me, indeed when assert_is_replicated fails it's usually because of a NaN in training. So it looks like humanoidstandup trained with APG causes a NaN?
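To illustrate why a NaN trips a replication check even on a single GPU, here is a minimal sketch of a fingerprint-style check. It assumes, for illustration only, that the check reduces each device's copy of the state to one scalar and compares those scalars across devices; the helpers `fingerprint_per_device` and `looks_replicated` are hypothetical and are not brax's API. Because NaN compares unequal to everything, including itself, any NaN in the state makes the comparison fail regardless of how many devices there are.

```python
import jax
import jax.numpy as jnp


def fingerprint_per_device(tree):
    """Hypothetical helper: sum every leaf over all non-device axes.

    Assumes each leaf has a leading device axis, as pmap-replicated state does.
    Returns an array with one scalar fingerprint per device.
    """
    sums = [
        jnp.sum(leaf.reshape(leaf.shape[0], -1), axis=1)
        for leaf in jax.tree_util.tree_leaves(tree)
    ]
    return sum(sums)


def looks_replicated(tree) -> bool:
    """Hypothetical helper: True if every device has the same fingerprint."""
    fp = fingerprint_per_device(tree)
    # NaN == NaN is False, so a NaN anywhere fails this even with one device.
    return bool(jnp.all(fp == fp[0]))
```

With a single device and a NaN somewhere in the state, `fp` is `[nan]`, and `nan == nan` evaluates to False, which matches the single-GPU failure reported above.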
Yes, it happens when I use the "generalized" backend. I also tried the "positional" backend, which works without this bug.
I run into the same problem when I use PPO: there are NaNs, but I don't know how to locate them. Could you please help me?
@queenxy are you getting NaNs on humanoidstandup with PPO with the generalized backend (and on which device)? Afaik this was tested on TPU, but it would be good to know.
I am getting NaNs in my own environment with the PPO implementation provided by brax. The device is GPU (both multi-GPU and single-GPU runs hit this problem). I have checked my environment and there seems to be nothing wrong, so I am trying to determine whether the NaN originates in PPO. @btaba
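One way to narrow down where a NaN first appears is to scan the training state (parameters, optimizer state, metrics) after each iteration and report which leaves contain NaNs. The sketch below does this with JAX's pytree utilities; `find_nan_leaves` is a hypothetical helper, and calling it on brax's training state is an assumption about how you would wire it in. JAX also provides the `jax_debug_nans` flag (`jax.config.update("jax_debug_nans", True)`), which raises an error at the first operation that produces a NaN; it slows execution, so it is usually enabled only while debugging.

```python
import jax
import jax.numpy as jnp


def find_nan_leaves(tree):
    """Hypothetical helper: return the key paths of leaves that contain a NaN."""
    leaves_with_paths, _ = jax.tree_util.tree_flatten_with_path(tree)
    return [
        jax.tree_util.keystr(path)
        for path, leaf in leaves_with_paths
        # isnan is all-False for integer leaves, so they are never reported.
        if bool(jnp.any(jnp.isnan(jnp.asarray(leaf))))
    ]
```

Calling this on the state at the end of every iteration shows both the first iteration in which a NaN appears and which part of the state (for example policy parameters versus normalizer statistics) it appears in.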
Any further progress on this?
When I try to run Analytic Policy Gradients (APG) training in parallel on a 4-GPU machine, it reports an AssertionError in
brax/training/agents/apg/train.py
line 255. It seems that this is because training_state
differs across devices even though it should be replicated. I made only minimal changes to the example training code.
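For context on why the state should stay replicated: in data-parallel training with `jax.pmap`, each device computes gradients on its own data shard, and averaging them with `jax.lax.pmean` gives every replica the same update. The toy sketch below shows that pattern together with a per-iteration replication check. It is not brax's APG code; the least-squares `loss_fn`, the learning rate, and the helper `replicas_agree` are hypothetical stand-ins chosen only to make the example self-contained.

```python
import jax
import jax.numpy as jnp

N_DEV = jax.local_device_count()


def loss_fn(w, x, y):
    # Hypothetical toy loss standing in for the real policy objective.
    return jnp.mean((x @ w - y) ** 2)


def train_step(w, x, y):
    grads = jax.grad(loss_fn)(w, x, y)
    # Average gradients across devices so every replica applies the same update.
    grads = jax.lax.pmean(grads, axis_name="i")
    return w - 0.1 * grads


p_train_step = jax.pmap(train_step, axis_name="i")


def replicas_agree(w) -> bool:
    # w has a leading device axis; compare every replica with replica 0.
    return bool(jnp.all(w == w[0]))


def run(num_steps=5):
    kx, ky = jax.random.split(jax.random.PRNGKey(0))
    # Each device gets a different data shard: shape (N_DEV, batch, features).
    x = jax.random.normal(kx, (N_DEV, 8, 3))
    y = jax.random.normal(ky, (N_DEV, 8))
    w = jnp.zeros((N_DEV, 3))  # identical starting parameters on every device
    for step in range(num_steps):
        w = p_train_step(w, x, y)
        assert replicas_agree(w), f"replicas diverged at step {step}"
    return w
```

Checking after every step, as done in the issue, reports the first iteration at which the replicas stop agreeing instead of waiting for the check at the end of training.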
To make the error appear sooner, I added
pmap.assert_is_replicated(training_state)
inside the training loop of brax/training/agents/apg/train.py.
The full output is:
If I use
from brax.training.agents.apg.train import train as apgtrain
the full output becomes: