Skip to content

Commit

Permalink
new gifs, trained examples
Browse files Browse the repository at this point in the history
  • Loading branch information
JacopoPan committed Dec 9, 2023
1 parent e83f827 commit 075eee7
Show file tree
Hide file tree
Showing 5 changed files with 16 additions and 9 deletions.
7 changes: 4 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ This is a minimalist refactoring of the original `gym-pybullet-drones` repositor

> **NOTE**: if you prefer to access the original codebase, presented at IROS in 2021, please `git checkout [paper|master]` after cloning the repo, and refer to the corresponding `README.md`'s.
<img src="gym_pybullet_drones/assets/helix.gif" alt="formation flight" width="350"> <img src="gym_pybullet_drones/assets/helix.png" alt="control info" width="450">
<img src="gym_pybullet_drones/assets/helix.gif" alt="formation flight" width="325"> <img src="gym_pybullet_drones/assets/helix.png" alt="control info" width="425">

## Installation

Expand Down Expand Up @@ -39,14 +39,15 @@ cd gym_pybullet_drones/examples/
python3 downwash.py
```

### Reinforcement learning 10'-training example (SB3's PPO)
### Reinforcement learning examples (SB3's PPO)

```sh
cd gym_pybullet_drones/examples/
python learn.py # task: single drone hover at z == 1.0
python learn.py --multiagent true # task: 2-drone hover at z == 1.2 and 0.7
```

<img src="gym_pybullet_drones/assets/rl.gif" alt="rl example" width="800">
<img src="gym_pybullet_drones/assets/rl.gif" alt="rl example" width="375"> <img src="gym_pybullet_drones/assets/marl.gif" alt="marl example" width="375">

### Betaflight SITL example (Ubuntu only)

Expand Down
Binary file added gym_pybullet_drones/assets/marl.gif
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified gym_pybullet_drones/assets/rl.gif
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
7 changes: 6 additions & 1 deletion gym_pybullet_drones/envs/MultiHoverAviary.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,6 @@ def __init__(self,
obs=obs,
act=act
)

self.TARGET_POS = self.INIT_XYZS + np.array([[0,0,1/(i+1)] for i in range(num_drones)])

################################################################################
Expand Down Expand Up @@ -119,6 +118,12 @@ def _computeTruncated(self):
Whether the current episode timed out.
"""
states = np.array([self._getDroneStateVector(i) for i in range(self.NUM_DRONES)])
for i in range(self.NUM_DRONES):
if (abs(states[i][0]) > 2.0 or abs(states[i][1]) > 2.0 or states[i][2] > 2.0 # Truncate when a drones is too far away
or abs(states[i][7]) > .5 or abs(states[i][8]) > .5 # Truncate when a drone is too tilted
):
return True
if self.step_counter/self.PYB_FREQ > self.EPISODE_LEN_SEC:
return True
else:
Expand Down
11 changes: 6 additions & 5 deletions gym_pybullet_drones/examples/learn.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,21 +71,20 @@ def run(multiagent=DEFAULT_MA, output_folder=DEFAULT_OUTPUT_FOLDER, gui=DEFAULT_
#### Train the model #######################################
model = PPO('MlpPolicy',
train_env,
# policy_kwargs=dict(activation_fn=torch.nn.ReLU, net_arch=[512, 512, dict(vf=[256, 128], pi=[256, 128])]),
# tensorboard_log=filename+'/tb/',
verbose=1)

callback_on_best = StopTrainingOnRewardThreshold(reward_threshold=np.inf,
callback_on_best = StopTrainingOnRewardThreshold(reward_threshold=465 if not multiagent else 920, # reward thresholds for the 3D case, use 474 and 950 for the 1D case
verbose=1)
eval_callback = EvalCallback(eval_env,
callback_on_new_best=callback_on_best,
verbose=1,
best_model_save_path=filename+'/',
log_path=filename+'/',
eval_freq=int(2000),
eval_freq=int(1000),
deterministic=True,
render=False)
model.learn(total_timesteps=int(1e6) if local else int(1e2), # shorter training in GitHub Actions pytest
model.learn(total_timesteps=int(1e7) if local else int(1e2), # shorter training in GitHub Actions pytest
callback=eval_callback,
log_interval=100)

Expand All @@ -107,7 +106,9 @@ def run(multiagent=DEFAULT_MA, output_folder=DEFAULT_OUTPUT_FOLDER, gui=DEFAULT_
if local:
input("Press Enter to continue...")

if os.path.isfile(filename+'/best_model.zip'):
if os.path.isfile(filename+'/final_model.zip'):
path = filename+'/final_model.zip'
elif os.path.isfile(filename+'/best_model.zip'):
path = filename+'/best_model.zip'
else:
print("[ERROR]: no model under the specified path", filename)
Expand Down

0 comments on commit 075eee7

Please sign in to comment.