diff --git a/README.md b/README.md
index 077f7ec8d..288df8298 100644
--- a/README.md
+++ b/README.md
@@ -4,7 +4,7 @@ This is a minimalist refactoring of the original `gym-pybullet-drones` repositor
 
 > **NOTE**: if you prefer to access the original codebase, presented at IROS in 2021, please `git checkout [paper|master]` after cloning the repo, and refer to the corresponding `README.md`'s.
 
-formation flight control info
+formation flight control info
 
 ## Installation
 
@@ -39,14 +39,15 @@ cd gym_pybullet_drones/examples/
 python3 downwash.py
 ```
 
-### Reinforcement learning 10'-training example (SB3's PPO)
+### Reinforcement learning examples (SB3's PPO)
 
 ```sh
 cd gym_pybullet_drones/examples/
 python learn.py # task: single drone hover at z == 1.0
+python learn.py --multiagent true # task: 2-drone hover at z == 1.2 and 0.7
 ```
 
-rl example
+rl example marl example
 
 ### Betaflight SITL example (Ubuntu only)
 
diff --git a/gym_pybullet_drones/assets/marl.gif b/gym_pybullet_drones/assets/marl.gif
new file mode 100644
index 000000000..5c7b23db4
Binary files /dev/null and b/gym_pybullet_drones/assets/marl.gif differ
diff --git a/gym_pybullet_drones/assets/rl.gif b/gym_pybullet_drones/assets/rl.gif
index efbace3a1..a967a177c 100644
Binary files a/gym_pybullet_drones/assets/rl.gif and b/gym_pybullet_drones/assets/rl.gif differ
diff --git a/gym_pybullet_drones/envs/MultiHoverAviary.py b/gym_pybullet_drones/envs/MultiHoverAviary.py
index f1616cf5f..6a8db25d5 100644
--- a/gym_pybullet_drones/envs/MultiHoverAviary.py
+++ b/gym_pybullet_drones/envs/MultiHoverAviary.py
@@ -68,7 +68,6 @@ def __init__(self,
                          obs=obs,
                          act=act
                          )
-        self.TARGET_POS = self.INIT_XYZS + np.array([[0,0,1/(i+1)] for i in range(num_drones)])
 
 ################################################################################
 
@@ -119,6 +118,12 @@ def _computeTruncated(self):
""" + states = np.array([self._getDroneStateVector(i) for i in range(self.NUM_DRONES)]) + for i in range(self.NUM_DRONES): + if (abs(states[i][0]) > 2.0 or abs(states[i][1]) > 2.0 or states[i][2] > 2.0 # Truncate when a drones is too far away + or abs(states[i][7]) > .5 or abs(states[i][8]) > .5 # Truncate when a drone is too tilted + ): + return True if self.step_counter/self.PYB_FREQ > self.EPISODE_LEN_SEC: return True else: diff --git a/gym_pybullet_drones/examples/learn.py b/gym_pybullet_drones/examples/learn.py index cac6f5793..5b585fc60 100644 --- a/gym_pybullet_drones/examples/learn.py +++ b/gym_pybullet_drones/examples/learn.py @@ -71,21 +71,20 @@ def run(multiagent=DEFAULT_MA, output_folder=DEFAULT_OUTPUT_FOLDER, gui=DEFAULT_ #### Train the model ####################################### model = PPO('MlpPolicy', train_env, - # policy_kwargs=dict(activation_fn=torch.nn.ReLU, net_arch=[512, 512, dict(vf=[256, 128], pi=[256, 128])]), # tensorboard_log=filename+'/tb/', verbose=1) - callback_on_best = StopTrainingOnRewardThreshold(reward_threshold=np.inf, + callback_on_best = StopTrainingOnRewardThreshold(reward_threshold=465 if not multiagent else 920, # reward thresholds for the 3D case, use 474 and 950 for the 1D case verbose=1) eval_callback = EvalCallback(eval_env, callback_on_new_best=callback_on_best, verbose=1, best_model_save_path=filename+'/', log_path=filename+'/', - eval_freq=int(2000), + eval_freq=int(1000), deterministic=True, render=False) - model.learn(total_timesteps=int(1e6) if local else int(1e2), # shorter training in GitHub Actions pytest + model.learn(total_timesteps=int(1e7) if local else int(1e2), # shorter training in GitHub Actions pytest callback=eval_callback, log_interval=100) @@ -107,7 +106,9 @@ def run(multiagent=DEFAULT_MA, output_folder=DEFAULT_OUTPUT_FOLDER, gui=DEFAULT_ if local: input("Press Enter to continue...") - if os.path.isfile(filename+'/best_model.zip'): + if os.path.isfile(filename+'/final_model.zip'): + path = filename+'/final_model.zip' + elif os.path.isfile(filename+'/best_model.zip'): path = filename+'/best_model.zip' else: print("[ERROR]: no model under the specified path", filename)