
assert_is_replicated in Analytic policy gradients training #328

Open · wangyian-me opened this issue Apr 2, 2023 · 9 comments

@wangyian-me commented Apr 2, 2023

When I try to run Analytic policy gradients training in parallel on a 4-GPU machine, it raises an AssertionError at brax/training/agents/apg/train.py line 255. It seems that training_state diverges across the devices when it should stay replicated.
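
(For context: after jax.pmap, every leaf of training_state carries a leading device axis, and "replicated" means all slices along that axis are identical. A host-side sanity check could look like the rough sketch below; check_replicated is my own throwaway helper, not Brax's pmap utility.)

import jax
import jax.numpy as jnp

def check_replicated(tree):
  """Maps each array leaf to True iff every device holds the same copy."""
  def same_on_all_devices(x):
    x = jax.device_get(x)            # gather to host; axis 0 is the device axis
    # Note: this is False whenever x contains NaN, since NaN != NaN.
    return bool((x == x[:1]).all())  # compare each device's copy to device 0's
  return jax.tree_util.tree_map(same_on_all_devices, tree)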

I made only minimal changes relative to the example training code.

import functools

from datetime import datetime
# from brax.training.agents.apg.train import train as apgtrain
from train import train as apgtrain
from brax import envs

env_name = 'humanoidstandup'  # @param ['ant', 'halfcheetah', 'hopper', 'humanoid', 'humanoidstandup', 'inverted_pendulum', 'inverted_double_pendulum', 'pusher', 'reacher', 'walker2d']
backend = 'generalized'  # @param ['generalized', 'positional', 'spring']

env = envs.get_environment(env_name=env_name,
                           backend=backend)

train_fn = {
  'humanoidstandup': functools.partial(
      apgtrain,
      episode_length=320,
      action_repeat=1,
      num_envs=16,
      num_eval_envs=4,
      learning_rate=1e-4,
      seed=0,
      max_gradient_norm=1e8,
      num_evals=10,
      normalize_observations=True,
      deterministic_eval=False)
}[env_name]

xdata, ydata = [], []
times = [datetime.now()]

def progress(num_steps, metrics):
  times.append(datetime.now())
  print(num_steps, metrics['eval/episode_reward'])

print("begin")

make_inference_fn, params, _ = train_fn(environment=env, progress_fn=progress)

print("end")

To make the error appear sooner, I added pmap.assert_is_replicated(training_state) inside the training loop of brax/training/agents/apg/train.py.

  for it in range(num_evals_after_init):
    logging.info('starting iteration %s %s', it, time.time() - xt)

    # optimization
    epoch_key, local_key = jax.random.split(local_key)
    epoch_keys = jax.random.split(epoch_key, local_devices_to_use)
    (training_state,
     training_metrics) = training_epoch_with_timing(training_state, epoch_keys)
    ######################## I added it here ###########################
    pmap.assert_is_replicated(training_state)
    ####################################################################
    if process_id == 0:
      # Run evals.
      metrics = evaluator.run_evaluation(
          _unpmap(
              (training_state.normalizer_params, training_state.policy_params)),
          training_metrics)
      logging.info(metrics)
      progress_fn(it + 1, metrics)

And the full output is:

begin
0 2238.8042
1 2367.4116
Traceback (most recent call last):
  File "xxxxx.py", line 36, in <module>
    make_inference_fn, params, _ = train_fn(environment=env, progress_fn=progress)
  File "/home/vipuser/playbrax/train.py", line 227, in train
    pmap.assert_is_replicated(training_state)
  File "/home/vipuser/playbrax/brax/brax/training/pmap.py", line 70, in assert_is_replicated
    assert jax.pmap(f, axis_name='i')(x)[0], debug
AssertionError: None

If I instead use from brax.training.agents.apg.train import train as apgtrain (the unmodified trainer), the full output becomes:


begin                                                                                                                   
0 2233.8481
1 2273.3516
2 2460.1377
3 2319.5432
4 2250.9502
5 2289.2446                                                                                                             
6 nan
7 nan
8 nan
9 nan
Traceback (most recent call last):
  File "xxxxx.py", line 36, in <module>
    make_inference_fn, params, _ = train_fn(environment=env, progress_fn=progress)
  File "/home/vipuser/playbrax/train.py", line 227, in train
    pmap.assert_is_replicated(training_state)
  File "/home/vipuser/playbrax/brax/brax/training/pmap.py", line 255, in assert_is_replicated
    assert jax.pmap(f, axis_name='i')(x)[0], debug
AssertionError: None                                                                                                                   

@wangyian-me (Author) commented Apr 2, 2023

It reports the same error even when I use only one GPU.
Also, I get this warning; I don't know if it is relevant:

/home/vipuser/miniconda3/envs/brax/lib/python3.8/site-packages/flax/core/frozen_dict.py:169: FutureWarning: jax.tree_util.register_keypaths is deprecated, and will be removed in a future release. Please use register_pytree_with_keys() instead.
  jax.tree_util.register_keypaths(

@wangyian-me (Author)

I just realized that it might be because some elements are NaN, and NaN == NaN is false, so the replication check can return false.
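
A minimal demonstration of that, in plain JAX (nothing Brax-specific):

import jax.numpy as jnp

x = jnp.array(jnp.nan)
print(x == x)              # False: NaN never compares equal to itself
print(jnp.allclose(x, x))  # also False, unless equal_nan=True is passed

So any equality-based replication check reports "not replicated" as soon as a leaf goes NaN, even if every device holds identical bytes.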

@btaba (Collaborator) commented Apr 3, 2023

Hi @wangyian-me, indeed when assert_is_replicated fails it's usually because of a NaN in training. So it looks like humanoidstandup trained with APG causes a NaN?

@wangyian-me (Author)

Yeah, it happens when I use the "generalized" backend. I've also tried the "positional" backend, which works without this bug.

@wangyian-me (Author) commented Apr 6, 2023

Also, I've tried to locate where the NaN is produced: it first appears after this line. So I guess the gradient might explode during backpropagation with the "generalized" backend. @btaba
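
In case it helps with reproducing: JAX has a built-in NaN checker that raises at the first op producing a NaN. A rough sketch of using it (caveat: it re-runs the offending computation de-optimized, so it is slow, and per the JAX docs it does not work under pmap, so run on a single device):

import jax

# Raise FloatingPointError at the first primitive that produces a NaN;
# the resulting traceback points at the offending operation.
jax.config.update("jax_debug_nans", True)

# ...then run training as usual, e.g.:
# make_inference_fn, params, _ = train_fn(environment=env, progress_fn=progress)

Also worth noting: the repro above passes max_gradient_norm=1e8, which effectively disables clipping, so if the gradient really is exploding, a much smaller value should change the behavior.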

@queenxy commented Apr 12, 2023

I'm hitting the same issue when I use PPO: there are NaNs, but I don't know how to locate them. Could you please help me?

@btaba (Collaborator) commented Apr 13, 2023

@queenxy are you getting NaNs on humanoidstandup with PPO and the generalized backend (and on which device)? AFAIK this was tested on TPU, but it would be good to know.
Thanks @wangyian-me for confirming; we'll have to debug. But if you have some time, feel free to dig deeper.

@queenxy commented Apr 14, 2023

I am getting NaNs in my own environment with the PPO implementation provided by Brax. The device is GPU (both multi-GPU and single-GPU runs hit this issue). I have checked my environment and nothing seems wrong, so I am trying to determine whether the NaN originates inside PPO. @btaba
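
For reference, a quick host-side pytree NaN check looks like this (rough sketch; nan_report is a made-up helper name):

import jax
import jax.numpy as jnp

def nan_report(tree):
  # Map each array leaf to True if it contains at least one NaN.
  # Host-side only: bool() fails on traced arrays inside jit/pmap.
  return jax.tree_util.tree_map(lambda x: bool(jnp.isnan(x).any()), tree)

# e.g. nan_report(training_state.policy_params) after each training epoch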

@i1Cps commented Aug 1, 2024

Any further movement on this?
