diff --git a/gym_pybullet_drones/examples/learn.py b/gym_pybullet_drones/examples/learn.py index 40bd3f0ae..63809ab73 100644 --- a/gym_pybullet_drones/examples/learn.py +++ b/gym_pybullet_drones/examples/learn.py @@ -43,7 +43,7 @@ DEFAULT_AGENTS = 2 DEFAULT_MA = False -def run(multiagent=DEFAULT_MA, output_folder=DEFAULT_OUTPUT_FOLDER, gui=DEFAULT_GUI, plot=True, colab=DEFAULT_COLAB, record_video=DEFAULT_RECORD_VIDEO): +def run(multiagent=DEFAULT_MA, output_folder=DEFAULT_OUTPUT_FOLDER, gui=DEFAULT_GUI, plot=True, colab=DEFAULT_COLAB, record_video=DEFAULT_RECORD_VIDEO, local=True): filename = os.path.join(output_folder, 'save-'+datetime.now().strftime("%m.%d.%Y_%H.%M.%S")) if not os.path.exists(filename): @@ -85,7 +85,7 @@ def run(multiagent=DEFAULT_MA, output_folder=DEFAULT_OUTPUT_FOLDER, gui=DEFAULT_ eval_freq=int(2000), deterministic=True, render=False) - model.learn(total_timesteps=3*int(1e5), + model.learn(total_timesteps=3*int(1e5) if local else int(1e2), # shorter training in GitHub Actions pytest callback=eval_callback, log_interval=100) diff --git a/tests/test_examples.py b/tests/test_examples.py index 8df9b1ca0..3a04dd88e 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -12,4 +12,4 @@ def test_downwash(): def test_learn(): from gym_pybullet_drones.examples.learn import run - run(gui=False, plot=False, output_folder='tmp') + run(gui=False, plot=False, output_folder='tmp', local=False)