diff --git a/README.md b/README.md index 142e827a1..54bdbe960 100644 --- a/README.md +++ b/README.md @@ -39,14 +39,15 @@ cd gym_pybullet_drones/examples/ python3 downwash.py ``` -### Reinforcement learning 3'-training examples (SB3's PPO) +### Reinforcement learning 10'-training example (SB3's PPO) ```sh cd gym_pybullet_drones/examples/ python learn.py # task: single drone hover at z == 1.0 -python learn.py --multiagent true # task: 2-drone hover at z == 1.2 and 0.7 ``` +rl example + ### Betaflight SITL example (Ubuntu only) ```sh diff --git a/gym_pybullet_drones/assets/rl.gif b/gym_pybullet_drones/assets/rl.gif new file mode 100644 index 000000000..a45e7d573 Binary files /dev/null and b/gym_pybullet_drones/assets/rl.gif differ diff --git a/gym_pybullet_drones/envs/HoverAviary.py b/gym_pybullet_drones/envs/HoverAviary.py index 165b35c32..449a37a57 100644 --- a/gym_pybullet_drones/envs/HoverAviary.py +++ b/gym_pybullet_drones/envs/HoverAviary.py @@ -106,6 +106,11 @@ def _computeTruncated(self): Whether the current episode timed out. """ + state = self._getDroneStateVector(0) + if (abs(state[0]) > 2.0 or abs(state[1]) > 2.0 or state[2] > 2.0 # Truncate when the drone is too far away + or abs(state[7]) > .5 or abs(state[8]) > .5 # Truncate when the drone is too tilted + ): + return True if self.step_counter/self.PYB_FREQ > self.EPISODE_LEN_SEC: return True else: diff --git a/gym_pybullet_drones/examples/learn.py b/gym_pybullet_drones/examples/learn.py index 6116d91dc..cac6f5793 100644 --- a/gym_pybullet_drones/examples/learn.py +++ b/gym_pybullet_drones/examples/learn.py @@ -39,7 +39,7 @@ DEFAULT_COLAB = False DEFAULT_OBS = ObservationType('kin') # 'kin' or 'rgb' -DEFAULT_ACT = ActionType('one_d_rpm') # 'rpm' or 'pid' or 'vel' or 'one_d_rpm' or 'one_d_pid' +DEFAULT_ACT = ActionType('rpm') # 'rpm' or 'pid' or 'vel' or 'one_d_rpm' or 'one_d_pid' DEFAULT_AGENTS = 2 DEFAULT_MA = False @@ -85,7 +85,7 @@ def run(multiagent=DEFAULT_MA, output_folder=DEFAULT_OUTPUT_FOLDER, gui=DEFAULT_ eval_freq=int(2000), deterministic=True, render=False) - model.learn(total_timesteps=3*int(1e5) if local else int(1e2), # shorter training in GitHub Actions pytest + model.learn(total_timesteps=int(1e6) if local else int(1e2), # shorter training in GitHub Actions pytest callback=eval_callback, log_interval=100)