Skip to content

Commit

Permalink
Multi hover example trained, other fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
JacopoPan committed Nov 26, 2023
1 parent c7bc897 commit 899d7b3
Show file tree
Hide file tree
Showing 6 changed files with 36 additions and 29 deletions.
17 changes: 13 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,18 +24,27 @@ pip3 install -e . # if needed, `sudo apt install build-essential` to install `gc

## Use

### PID position control example
### PID control examples

```sh
cd gym_pybullet_drones/examples/
python3 pid.py
python3 pid.py # position and velocity reference
python3 pid_velocity.py # desired velocity reference
```

### Stable-baselines3 PPO RL example
### Downwash effect examples

```sh
cd gym_pybullet_drones/examples/
python3 learn.py
python3 downwash.py
```

### Stable-baselines3 PPO RL examples (3-minute training)

```sh
cd gym_pybullet_drones/examples/
python learn.py # task: single drone hover at z == 1
python learn.py --multiagent true # task: 2-drone hover at z == 1.2 and 0.7
```

### Betaflight SITL example (Ubuntu only)
Expand Down
4 changes: 2 additions & 2 deletions gym_pybullet_drones/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,6 @@
)

register(
id='leaderfollower-aviary-v0',
entry_point='gym_pybullet_drones.envs:LeaderFollowerAviary',
id='multihover-aviary-v0',
entry_point='gym_pybullet_drones.envs:MultiHoverAviary',
)
6 changes: 2 additions & 4 deletions gym_pybullet_drones/assets/clone_bfs.sh
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ cd ../../
mkdir betaflight_sitl/
cd betaflight_sitl/

# Step 1: Clone and open betaflight's source:
# Step 1: Clone and open betaflight's source (at the time of writing, branch `master`, future release 4.5):
git clone https://github.com/betaflight/betaflight temp/


Expand All @@ -27,8 +27,6 @@ git clone https://github.com/betaflight/betaflight temp/
# from Betaflight's `SIMULATOR_BUILD`
cd temp/
sed -i "s/delayMicroseconds_real(50);/\/\/delayMicroseconds_real(50);/g" ./src/main/main.c
sed -i "s/ret = udpInit(\&stateLink, NULL, 9003, true);/\/\/ret = udpInit(\&stateLink, NULL, PORT_STATE, true);/g" ./src/main/target/SITL/sitl.c
sed -i "s/printf(\"start UDP server.../\/\/printf(\"start UDP server.../g" ./src/main/target/SITL/sitl.c

# Prepare
make arm_sdk_install
Expand All @@ -47,7 +45,7 @@ for ((i = 0; i < desired_max_num_drones; i++)); do
cp -r temp/ "bf${i}/"
cd "bf${i}/"

# Step 3: Change the UDP ports used by each Betaflight SITL instancet
# Step 3: Change the UDP ports used by each Betaflight SITL instance
replacement1="PORT_PWM_RAW 90${i}1"
sed -i "s/$pattern1/$replacement1/g" ./src/main/target/SITL/sitl.c
replacement2="PORT_PWM 90${i}2"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from gym_pybullet_drones.envs.BaseRLAviary import BaseRLAviary
from gym_pybullet_drones.utils.enums import DroneModel, Physics, ActionType, ObservationType

class LeaderFollowerAviary(BaseRLAviary):
class MultiHoverAviary(BaseRLAviary):
"""Multi-agent RL problem: multi-drone hover."""

################################################################################
Expand Down Expand Up @@ -54,7 +54,6 @@ def __init__(self,
The type of action space (1 or 3D; RPMS, thrust and torques, or waypoint with PID control)
"""
self.TARGET_POS = np.array([0,0,1])
self.EPISODE_LEN_SEC = 8
super().__init__(drone_model=drone_model,
num_drones=num_drones,
Expand All @@ -69,6 +68,8 @@ def __init__(self,
obs=obs,
act=act
)

self.TARGET_POS = self.INIT_XYZS + np.array([[0,0,1/(i+1)] for i in range(num_drones)])

################################################################################

Expand All @@ -82,9 +83,9 @@ def _computeReward(self):
"""
states = np.array([self._getDroneStateVector(i) for i in range(self.NUM_DRONES)])
ret = max(0, 2 - np.linalg.norm(self.TARGET_POS-states[0, 0:3])**4)
for i in range(1, self.NUM_DRONES):
ret += max(0, 2 - np.linalg.norm(states[i-1, 3]-states[i, 3])**4)
ret = 0
for i in range(self.NUM_DRONES):
ret += max(0, 2 - np.linalg.norm(self.TARGET_POS[i,:]-states[i][0:3])**4)
return ret

################################################################################
Expand All @@ -98,8 +99,11 @@ def _computeTerminated(self):
Whether the current episode is done.
"""
state = self._getDroneStateVector(0)
if np.linalg.norm(self.TARGET_POS-state[0:3]) < .001:
states = np.array([self._getDroneStateVector(i) for i in range(self.NUM_DRONES)])
dist = 0
for i in range(self.NUM_DRONES):
dist += np.linalg.norm(self.TARGET_POS[i,:]-states[i][0:3])
if dist < .0001:
return True
else:
return False
Expand Down
2 changes: 1 addition & 1 deletion gym_pybullet_drones/envs/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from gym_pybullet_drones.envs.BetaAviary import BetaAviary
from gym_pybullet_drones.envs.CtrlAviary import CtrlAviary
from gym_pybullet_drones.envs.HoverAviary import HoverAviary
from gym_pybullet_drones.envs.LeaderFollowerAviary import LeaderFollowerAviary
from gym_pybullet_drones.envs.MultiHoverAviary import MultiHoverAviary
from gym_pybullet_drones.envs.VelocityAviary import VelocityAviary
18 changes: 7 additions & 11 deletions gym_pybullet_drones/examples/learn.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@

from gym_pybullet_drones.utils.Logger import Logger
from gym_pybullet_drones.envs.HoverAviary import HoverAviary
from gym_pybullet_drones.envs.LeaderFollowerAviary import LeaderFollowerAviary
from gym_pybullet_drones.envs.MultiHoverAviary import MultiHoverAviary
from gym_pybullet_drones.utils.utils import sync, str2bool
from gym_pybullet_drones.utils.enums import ObservationType, ActionType

Expand All @@ -40,7 +40,7 @@

DEFAULT_OBS = ObservationType('kin') # 'kin' or 'rgb'
DEFAULT_ACT = ActionType('one_d_rpm') # 'rpm' or 'pid' or 'vel' or 'one_d_rpm' or 'one_d_pid'
DEFAULT_AGENTS = 3
DEFAULT_AGENTS = 2
DEFAULT_MA = False

def run(multiagent=DEFAULT_MA, output_folder=DEFAULT_OUTPUT_FOLDER, gui=DEFAULT_GUI, plot=True, colab=DEFAULT_COLAB, record_video=DEFAULT_RECORD_VIDEO):
Expand All @@ -57,12 +57,12 @@ def run(multiagent=DEFAULT_MA, output_folder=DEFAULT_OUTPUT_FOLDER, gui=DEFAULT_
)
eval_env = HoverAviary(obs=DEFAULT_OBS, act=DEFAULT_ACT)
else:
train_env = make_vec_env(LeaderFollowerAviary,
train_env = make_vec_env(MultiHoverAviary,
env_kwargs=dict(num_drones=DEFAULT_AGENTS, obs=DEFAULT_OBS, act=DEFAULT_ACT),
n_envs=1,
seed=0
)
eval_env = LeaderFollowerAviary(num_drones=DEFAULT_AGENTS, obs=DEFAULT_OBS, act=DEFAULT_ACT)
eval_env = MultiHoverAviary(num_drones=DEFAULT_AGENTS, obs=DEFAULT_OBS, act=DEFAULT_ACT)

#### Check the environment's spaces ########################
print('[INFO] Action space:', train_env.action_space)
Expand All @@ -85,11 +85,7 @@ def run(multiagent=DEFAULT_MA, output_folder=DEFAULT_OUTPUT_FOLDER, gui=DEFAULT_
eval_freq=int(2000),
deterministic=True,
render=False)
if not multiagent:
steps = 2 * int(1e5)
else:
steps = int(1e4)
model.learn(total_timesteps=steps,
model.learn(total_timesteps=3*int(1e5),
callback=eval_callback,
log_interval=100)

Expand Down Expand Up @@ -124,12 +120,12 @@ def run(multiagent=DEFAULT_MA, output_folder=DEFAULT_OUTPUT_FOLDER, gui=DEFAULT_
record=record_video)
test_env_nogui = HoverAviary(obs=DEFAULT_OBS, act=DEFAULT_ACT)
else:
test_env = LeaderFollowerAviary(gui=gui,
test_env = MultiHoverAviary(gui=gui,
num_drones=DEFAULT_AGENTS,
obs=DEFAULT_OBS,
act=DEFAULT_ACT,
record=record_video)
test_env_nogui = LeaderFollowerAviary(num_drones=DEFAULT_AGENTS, obs=DEFAULT_OBS, act=DEFAULT_ACT)
test_env_nogui = MultiHoverAviary(num_drones=DEFAULT_AGENTS, obs=DEFAULT_OBS, act=DEFAULT_ACT)
logger = Logger(logging_freq_hz=int(test_env.CTRL_FREQ),
num_drones=DEFAULT_AGENTS if multiagent else 1,
output_folder=output_folder,
Expand Down

0 comments on commit 899d7b3

Please sign in to comment.