diff --git a/README.md b/README.md
index 077f7ec8d..288df8298 100644
--- a/README.md
+++ b/README.md
@@ -4,7 +4,7 @@ This is a minimalist refactoring of the original `gym-pybullet-drones` repositor
> **NOTE**: if you prefer to access the original codebase, presented at IROS in 2021, please `git checkout [paper|master]` after cloning the repo, and refer to the corresponding `README.md`'s.
-
+
## Installation
@@ -39,14 +39,15 @@ cd gym_pybullet_drones/examples/
python3 downwash.py
```
-### Reinforcement learning 10'-training example (SB3's PPO)
+### Reinforcement learning examples (SB3's PPO)
```sh
cd gym_pybullet_drones/examples/
python learn.py # task: single drone hover at z == 1.0
+python learn.py --multiagent true # task: 2-drone hover at z == 1.2 and 0.7
```
-
+
### Betaflight SITL example (Ubuntu only)
diff --git a/gym_pybullet_drones/assets/marl.gif b/gym_pybullet_drones/assets/marl.gif
new file mode 100644
index 000000000..5c7b23db4
Binary files /dev/null and b/gym_pybullet_drones/assets/marl.gif differ
diff --git a/gym_pybullet_drones/assets/rl.gif b/gym_pybullet_drones/assets/rl.gif
index efbace3a1..a967a177c 100644
Binary files a/gym_pybullet_drones/assets/rl.gif and b/gym_pybullet_drones/assets/rl.gif differ
diff --git a/gym_pybullet_drones/envs/MultiHoverAviary.py b/gym_pybullet_drones/envs/MultiHoverAviary.py
index f1616cf5f..6a8db25d5 100644
--- a/gym_pybullet_drones/envs/MultiHoverAviary.py
+++ b/gym_pybullet_drones/envs/MultiHoverAviary.py
@@ -68,7 +68,6 @@ def __init__(self,
obs=obs,
act=act
)
-
self.TARGET_POS = self.INIT_XYZS + np.array([[0,0,1/(i+1)] for i in range(num_drones)])
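+        # e.g., with 2 drones the 1/(i+1) offsets place the targets 1.0 m and 0.5 m above the start points (the README's z == 1.2 and 0.7)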
################################################################################
@@ -119,6 +118,12 @@ def _computeTruncated(self):
-            Whether the current episode timed out.
+            Whether the current episode timed out or ended early because a drone flew out of bounds or tilted too far.
"""
+ states = np.array([self._getDroneStateVector(i) for i in range(self.NUM_DRONES)])
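+        # states[i][0:3] is drone i's XYZ position; states[i][7] and states[i][8] are its roll and pitch in radians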
+ for i in range(self.NUM_DRONES):
+            if (abs(states[i][0]) > 2.0 or abs(states[i][1]) > 2.0 or states[i][2] > 2.0 # Truncate when a drone is too far away
+ or abs(states[i][7]) > .5 or abs(states[i][8]) > .5 # Truncate when a drone is too tilted
+ ):
+ return True
if self.step_counter/self.PYB_FREQ > self.EPISODE_LEN_SEC:
return True
else:
diff --git a/gym_pybullet_drones/examples/learn.py b/gym_pybullet_drones/examples/learn.py
index cac6f5793..5b585fc60 100644
--- a/gym_pybullet_drones/examples/learn.py
+++ b/gym_pybullet_drones/examples/learn.py
@@ -71,21 +71,20 @@ def run(multiagent=DEFAULT_MA, output_folder=DEFAULT_OUTPUT_FOLDER, gui=DEFAULT_
#### Train the model #######################################
model = PPO('MlpPolicy',
train_env,
- # policy_kwargs=dict(activation_fn=torch.nn.ReLU, net_arch=[512, 512, dict(vf=[256, 128], pi=[256, 128])]),
# tensorboard_log=filename+'/tb/',
verbose=1)
- callback_on_best = StopTrainingOnRewardThreshold(reward_threshold=np.inf,
+    callback_on_best = StopTrainingOnRewardThreshold(reward_threshold=465 if not multiagent else 920, # reward thresholds for the 3D case; use 474 and 950 for the 1D case
verbose=1)
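+    # passed as callback_on_new_best below, this stops training once the best mean evaluation reward crosses the threshold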
eval_callback = EvalCallback(eval_env,
callback_on_new_best=callback_on_best,
verbose=1,
best_model_save_path=filename+'/',
log_path=filename+'/',
- eval_freq=int(2000),
+ eval_freq=int(1000),
deterministic=True,
render=False)
- model.learn(total_timesteps=int(1e6) if local else int(1e2), # shorter training in GitHub Actions pytest
+ model.learn(total_timesteps=int(1e7) if local else int(1e2), # shorter training in GitHub Actions pytest
callback=eval_callback,
log_interval=100)
@@ -107,7 +106,9 @@ def run(multiagent=DEFAULT_MA, output_folder=DEFAULT_OUTPUT_FOLDER, gui=DEFAULT_
if local:
input("Press Enter to continue...")
- if os.path.isfile(filename+'/best_model.zip'):
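+    # prefer the final model saved at the end of training, else fall back to EvalCallback's best checkpoint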
+ if os.path.isfile(filename+'/final_model.zip'):
+ path = filename+'/final_model.zip'
+ elif os.path.isfile(filename+'/best_model.zip'):
path = filename+'/best_model.zip'
else:
print("[ERROR]: no model under the specified path", filename)