Skip to content

Commit

Permalink
Move examples to package.
Browse files Browse the repository at this point in the history
Added unit test script for build testing.
Added bash file for building and testing package.
Changed paths to be user controlled, results are not to be saved in package path.
  • Loading branch information
spencerteetaert committed Jun 9, 2022
1 parent 28bafb3 commit f747c9a
Show file tree
Hide file tree
Showing 12 changed files with 385 additions and 179 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ examples/test.py
# NumPy saves and videos
files/logs/*.npy
files/videos/*.mp4
results/

# Learning results
experiments/learning/results/save-*
Expand Down
7 changes: 7 additions & 0 deletions build_project.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
echo "Y" | pip uninstall gym_pybullet_drones
rm -rf dist/
poetry build
pip install dist/gym_pybullet_drones-1.0.0-py3-none-any.whl
cd tests
python test_build.py
cd ..
File renamed without changes.
53 changes: 36 additions & 17 deletions examples/compare.py → gym_pybullet_drones/examples/compare.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
The comparison is along a 2D trajectory in the X-Z plane, between x == +1 and -1.
"""
from filecmp import DEFAULT_IGNORES
import os
import time
import argparse
Expand All @@ -25,18 +26,22 @@
from gym_pybullet_drones.control.DSLPIDControl import DSLPIDControl
from gym_pybullet_drones.utils.Logger import Logger

if __name__ == "__main__":

#### Define and parse (optional) arguments for the script ##
parser = argparse.ArgumentParser(description='Trace comparison script using CtrlAviary and DSLPIDControl')
parser.add_argument('--physics', default="pyb", type=Physics, help='Physics updates (default: PYB)', metavar='', choices=Physics)
parser.add_argument('--gui', default=False, type=str2bool, help='Whether to use PyBullet GUI (default: False)', metavar='')
parser.add_argument('--record_video', default=False, type=str2bool, help='Whether to record a video (default: False)', metavar='')
parser.add_argument('--trace_file', default="example_trace.pkl", type=str, help='Pickle file with the trace to compare to (default: "example_trace.pkl")', metavar='')
ARGS = parser.parse_args()

DEFAULT_PHYICS = Physics('pyb')
DEFAULT_GUI = False
DEFAULT_RECORD_VIDEO = False
DEFAULT_TRACE_FILE = os.path.dirname(os.path.abspath(__file__))+"/../assets/example_trace.pkl"
DEFAULT_OUTPUT_FOLDER = 'results'

def run(
physics=DEFAULT_PHYICS,
gui=DEFAULT_GUI,
record_video=DEFAULT_RECORD_VIDEO,
trace_file=DEFAULT_TRACE_FILE,
output_folder=DEFAULT_OUTPUT_FOLDER,
plot=True
):
#### Load a trace and control reference from a .pkl file ###
with open(os.path.dirname(os.path.abspath(__file__))+"/../files/"+ARGS.trace_file, 'rb') as in_file:
with open(trace_file, 'rb') as in_file:
TRACE_TIMESTAMPS, TRACE_DATA, TRACE_CTRL_REFERENCE, _, _, _ = pickle.load(in_file)

#### Compute trace's parameters ############################
Expand All @@ -47,10 +52,10 @@
env = CtrlAviary(drone_model=DroneModel.CF2X,
num_drones=1,
initial_xyzs=np.array([0, 0, .1]).reshape(1, 3),
physics=ARGS.physics,
physics=physics,
freq=SIMULATION_FREQ_HZ,
gui=ARGS.gui,
record=ARGS.record_video,
gui=gui,
record=record_video,
obstacles=False
)
INITIAL_STATE = env.reset()
Expand All @@ -63,7 +68,8 @@
#### Initialize the logger #################################
logger = Logger(logging_freq_hz=SIMULATION_FREQ_HZ,
num_drones=2,
duration_sec=DURATION_SEC
duration_sec=DURATION_SEC,
output_folder=output_folder
)

#### Initialize the controller #############################
Expand Down Expand Up @@ -105,7 +111,7 @@
env.render()

#### Sync the simulation ###################################
if ARGS.gui:
if gui:
sync(i, START, env.TIMESTEP)

#### Close the environment #################################
Expand All @@ -115,4 +121,17 @@
logger.save()

#### Plot the simulation results ###########################
logger.plot(pwm=True)
if plot:
logger.plot(pwm=True)

if __name__ == "__main__":
#### Define and parse (optional) arguments for the script ##
parser = argparse.ArgumentParser(description='Trace comparison script using CtrlAviary and DSLPIDControl')
parser.add_argument('--physics', default=DEFAULT_PHYICS, type=Physics, help='Physics updates (default: PYB)', metavar='', choices=Physics)
parser.add_argument('--gui', default=DEFAULT_GUI, type=str2bool, help='Whether to use PyBullet GUI (default: False)', metavar='')
parser.add_argument('--record_video', default=DEFAULT_RECORD_VIDEO, type=str2bool, help='Whether to record a video (default: False)', metavar='')
parser.add_argument('--trace_file', default=DEFAULT_TRACE_FILE, type=str, help='Pickle file with the trace to compare to (default: "example_trace.pkl")', metavar='')
parser.add_argument('--output_folder', default=DEFAULT_OUTPUT_FOLDER, type=str, help='Folder where to save logs (default: "results")', metavar='')
ARGS = parser.parse_args()

run(**vars(ARGS))
77 changes: 51 additions & 26 deletions examples/downwash.py → gym_pybullet_drones/examples/downwash.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,56 +21,64 @@
from gym_pybullet_drones.control.DSLPIDControl import DSLPIDControl
from gym_pybullet_drones.utils.Logger import Logger

if __name__ == "__main__":

#### Define and parse (optional) arguments for the script ##
parser = argparse.ArgumentParser(description='Downwash example script using CtrlAviary and DSLPIDControl')
parser.add_argument('--drone', default="cf2x", type=DroneModel, help='Drone model (default: CF2X)', metavar='', choices=DroneModel)
parser.add_argument('--gui', default=True, type=str2bool, help='Whether to use PyBullet GUI (default: True)', metavar='')
parser.add_argument('--record_video', default=False, type=str2bool, help='Whether to record a video (default: False)', metavar='')
parser.add_argument('--simulation_freq_hz', default=240, type=int, help='Simulation frequency in Hz (default: 240)', metavar='')
parser.add_argument('--control_freq_hz', default=48, type=int, help='Control frequency in Hz (default: 48)', metavar='')
parser.add_argument('--aggregate', default=True, type=str2bool, help='Whether to aggregate physics steps (default: False)', metavar='')
parser.add_argument('--duration_sec', default=12, type=int, help='Duration of the simulation in seconds (default: 10)', metavar='')
ARGS = parser.parse_args()

DEFAULT_DRONE = DroneModel('cf2x')
DEFAULT_GUI = True
DEFAULT_RECORD_VIDEO = False
DEFAULT_SIMULATION_FREQ_HZ = 240
DEFAULT_CONTROL_FREQ_HZ = 48
DEFAULT_AGGREGATE = True
DEFAULT_DURATION_SEC = 12
DEFAULT_OUTPUT_FOLDER = 'results'

def run(
drone=DEFAULT_DRONE,
gui=DEFAULT_GUI,
record_video=DEFAULT_RECORD_VIDEO,
simulation_freq_hz=DEFAULT_SIMULATION_FREQ_HZ,
control_freq_hz=DEFAULT_CONTROL_FREQ_HZ,
aggregate=DEFAULT_AGGREGATE,
duration_sec=DEFAULT_DURATION_SEC,
output_folder=DEFAULT_OUTPUT_FOLDER,
plot=True
):
#### Initialize the simulation #############################
INIT_XYZS = np.array([[.5, 0, 1],[-.5, 0, .5]])
AGGR_PHY_STEPS = int(ARGS.simulation_freq_hz/ARGS.control_freq_hz) if ARGS.aggregate else 1
env = CtrlAviary(drone_model=ARGS.drone,
AGGR_PHY_STEPS = int(simulation_freq_hz/control_freq_hz) if aggregate else 1
env = CtrlAviary(drone_model=drone,
num_drones=2,
initial_xyzs=INIT_XYZS,
physics=Physics.PYB_DW,
neighbourhood_radius=10,
freq=ARGS.simulation_freq_hz,
freq=simulation_freq_hz,
aggregate_phy_steps=AGGR_PHY_STEPS,
gui=ARGS.gui,
record=ARGS.record_video,
gui=gui,
record=record_video,
obstacles=True
)

#### Initialize the trajectories ###########################
PERIOD = 5
NUM_WP = ARGS.control_freq_hz*PERIOD
NUM_WP = control_freq_hz*PERIOD
TARGET_POS = np.zeros((NUM_WP, 2))
for i in range(NUM_WP):
TARGET_POS[i, :] = [0.5*np.cos(2*np.pi*(i/NUM_WP)), 0]
wp_counters = np.array([0, int(NUM_WP/2)])

#### Initialize the logger #################################
logger = Logger(logging_freq_hz=int(ARGS.simulation_freq_hz/AGGR_PHY_STEPS),
logger = Logger(logging_freq_hz=int(simulation_freq_hz/AGGR_PHY_STEPS),
num_drones=2,
duration_sec=ARGS.duration_sec
duration_sec=duration_sec,
output_folder=output_folder
)

#### Initialize the controllers ############################
ctrl = [DSLPIDControl(drone_model=ARGS.drone) for i in range(2)]
ctrl = [DSLPIDControl(drone_model=drone) for i in range(2)]

#### Run the simulation ####################################
CTRL_EVERY_N_STEPS = int(np.floor(env.SIM_FREQ/ARGS.control_freq_hz))
CTRL_EVERY_N_STEPS = int(np.floor(env.SIM_FREQ/control_freq_hz))
action = {str(i): np.array([0, 0, 0, 0]) for i in range(2)}
START = time.time()
for i in range(0, int(ARGS.duration_sec*env.SIM_FREQ), AGGR_PHY_STEPS):
for i in range(0, int(duration_sec*env.SIM_FREQ), AGGR_PHY_STEPS):

#### Step the simulation ###################################
obs, reward, done, info = env.step(action)
Expand Down Expand Up @@ -102,7 +110,7 @@
env.render()

#### Sync the simulation ###################################
if ARGS.gui:
if gui:
sync(i, START, env.TIMESTEP)

#### Close the environment #################################
Expand All @@ -113,4 +121,21 @@
logger.save_as_csv("dw") # Optional CSV save

#### Plot the simulation results ###########################
logger.plot()
if plot:
logger.plot()


if __name__ == "__main__":
#### Define and parse (optional) arguments for the script ##
parser = argparse.ArgumentParser(description='Downwash example script using CtrlAviary and DSLPIDControl')
parser.add_argument('--drone', default=DEFAULT_DRONE, type=DroneModel, help='Drone model (default: CF2X)', metavar='', choices=DroneModel)
parser.add_argument('--gui', default=DEFAULT_GUI, type=str2bool, help='Whether to use PyBullet GUI (default: True)', metavar='')
parser.add_argument('--record_video', default=DEFAULT_RECORD_VIDEO, type=str2bool, help='Whether to record a video (default: False)', metavar='')
parser.add_argument('--simulation_freq_hz', default=DEFAULT_SIMULATION_FREQ_HZ, type=int, help='Simulation frequency in Hz (default: 240)', metavar='')
parser.add_argument('--control_freq_hz', default=DEFAULT_CONTROL_FREQ_HZ, type=int, help='Control frequency in Hz (default: 48)', metavar='')
parser.add_argument('--aggregate', default=DEFAULT_AGGREGATE, type=str2bool, help='Whether to aggregate physics steps (default: False)', metavar='')
parser.add_argument('--duration_sec', default=DEFAULT_DURATION_SEC, type=int, help='Duration of the simulation in seconds (default: 10)', metavar='')
parser.add_argument('--output_folder', default=DEFAULT_OUTPUT_FOLDER, type=str, help='Folder where to save logs (default: "results")', metavar='')
ARGS = parser.parse_args()

run(**vars(ARGS))
Loading

0 comments on commit f747c9a

Please sign in to comment.