diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 8394b8177..13075c111 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -26,7 +26,7 @@ repos: - id: double-quote-string-fixer - repo: https://github.com/pycqa/isort - rev: 5.12.0 + rev: 5.13.2 hooks: - id: isort name: isort @@ -45,7 +45,7 @@ repos: files: (^tests/|^safe_control_gym/math_and_models/transformations.py) - repo: https://github.com/PyCQA/flake8 - rev: 6.1.0 + rev: 7.0.0 hooks: - id: flake8 name: flake8_default diff --git a/examples/mpc/config_overrides/cartpole/cartpole_stabilization.yaml b/examples/mpc/config_overrides/cartpole/cartpole_stabilization.yaml new file mode 100644 index 000000000..b942acb4a --- /dev/null +++ b/examples/mpc/config_overrides/cartpole/cartpole_stabilization.yaml @@ -0,0 +1,39 @@ +task_config: + seed: 42 + info_in_reset: True + ctrl_freq: 15 + pyb_freq: 750 + physics: pyb + + init_state_randomization_info: + init_x: + distrib: 'uniform' + low: -2 + high: 2 + init_x_dot: + distrib: 'uniform' + low: -0.1 + high: 0.1 + init_theta: + distrib: 'uniform' + low: -0.2 + high: 0.2 + init_theta_dot: + distrib: 'uniform' + low: -0.1 + high: 0.1 + + task: stabilization + task_info: + stabilization_goal: [1.0, 0.0] + stabilization_goal_tolerance: 0.0 + + episode_len_sec: 6 + cost: quadratic + done_on_out_of_bound: True + + constraints: + - constraint_form: default_constraint + constrained_variable: input + - constraint_form: default_constraint + constrained_variable: state diff --git a/examples/mpc/config_overrides/cartpole/cartpole_tracking.yaml b/examples/mpc/config_overrides/cartpole/cartpole_tracking.yaml new file mode 100644 index 000000000..37c07aa09 --- /dev/null +++ b/examples/mpc/config_overrides/cartpole/cartpole_tracking.yaml @@ -0,0 +1,42 @@ +task_config: + seed: 42 + info_in_reset: True + ctrl_freq: 15 + pyb_freq: 750 + physics: pyb + + init_state_randomization_info: + init_x: + distrib: 'uniform' + low: -1 + high: 1 + init_x_dot: + distrib: 
'uniform' + low: -0.1 + high: 0.1 + init_theta: + distrib: 'uniform' + low: -0.2 + high: 0.2 + init_theta_dot: + distrib: 'uniform' + low: -0.1 + high: 0.1 + + task: traj_tracking + task_info: + trajectory_type: 'circle' + num_cycles: 2 + trajectory_plane: 'zx' + trajectory_position_offset: [0, 0] + trajectory_scale: 1 + + episode_len_sec: 6 + cost: quadratic + done_on_out_of_bound: True + + constraints: + - constraint_form: default_constraint + constrained_variable: input + - constraint_form: default_constraint + constrained_variable: state diff --git a/examples/mpc/config_overrides/cartpole/linear_mpc_cartpole_stabilization.yaml b/examples/mpc/config_overrides/cartpole/linear_mpc_cartpole_stabilization.yaml new file mode 100644 index 000000000..d32fe0779 --- /dev/null +++ b/examples/mpc/config_overrides/cartpole/linear_mpc_cartpole_stabilization.yaml @@ -0,0 +1,17 @@ +algo: linear_mpc +algo_config: + horizon: 20 + r_mpc: + - 0.1 + q_mpc: + - 5.0 + - 0.1 + - 5.0 + - 0.1 + # Prior info + prior_info: + prior_prop: null + randomize_prior_prop: False + prior_prop_rand_info: null + warmstart: True + solver: qrqp diff --git a/examples/mpc/config_overrides/cartpole/linear_mpc_cartpole_tracking.yaml b/examples/mpc/config_overrides/cartpole/linear_mpc_cartpole_tracking.yaml new file mode 100644 index 000000000..7e091ac13 --- /dev/null +++ b/examples/mpc/config_overrides/cartpole/linear_mpc_cartpole_tracking.yaml @@ -0,0 +1,17 @@ +algo: linear_mpc +algo_config: + horizon: 40 + r_mpc: + - 0.1 + q_mpc: + - 5.0 + - 0.1 + - 5.0 + - 0.1 + # Prior info + prior_info: + prior_prop: null + randomize_prior_prop: False + prior_prop_rand_info: null + warmstart: True + solver: qrqp diff --git a/examples/mpc/config_overrides/cartpole/mpc_cartpole_stabilization.yaml b/examples/mpc/config_overrides/cartpole/mpc_cartpole_stabilization.yaml new file mode 100644 index 000000000..9b52a28a7 --- /dev/null +++ b/examples/mpc/config_overrides/cartpole/mpc_cartpole_stabilization.yaml @@ -0,0 +1,17 
@@ +algo: mpc +algo_config: + horizon: 20 + r_mpc: + - 0.1 + q_mpc: + - 5.0 + - 0.1 + - 5.0 + - 0.1 + # Prior info + prior_info: + prior_prop: null + randomize_prior_prop: False + prior_prop_rand_info: null + warmstart: True + solver: ipopt diff --git a/examples/mpc/config_overrides/cartpole/mpc_cartpole_tracking.yaml b/examples/mpc/config_overrides/cartpole/mpc_cartpole_tracking.yaml new file mode 100644 index 000000000..9b52a28a7 --- /dev/null +++ b/examples/mpc/config_overrides/cartpole/mpc_cartpole_tracking.yaml @@ -0,0 +1,17 @@ +algo: mpc +algo_config: + horizon: 20 + r_mpc: + - 0.1 + q_mpc: + - 5.0 + - 0.1 + - 5.0 + - 0.1 + # Prior info + prior_info: + prior_prop: null + randomize_prior_prop: False + prior_prop_rand_info: null + warmstart: True + solver: ipopt diff --git a/examples/mpc/config_overrides/quadrotor_2D/linear_mpc_quadrotor_2D_stabilization.yaml b/examples/mpc/config_overrides/quadrotor_2D/linear_mpc_quadrotor_2D_stabilization.yaml new file mode 100644 index 000000000..dbb8a288f --- /dev/null +++ b/examples/mpc/config_overrides/quadrotor_2D/linear_mpc_quadrotor_2D_stabilization.yaml @@ -0,0 +1,20 @@ +algo: linear_mpc +algo_config: + horizon: 20 + r_mpc: + - 0.1 + - 0.1 + q_mpc: + - 5.0 + - 0.1 + - 5.0 + - 0.1 + - 0.1 + - 0.1 + # Prior info + prior_info: + prior_prop: null + randomize_prior_prop: False + prior_prop_rand_info: null + warmstart: True + solver: qrqp diff --git a/examples/mpc/config_overrides/quadrotor_2D/linear_mpc_quadrotor_2D_tracking.yaml b/examples/mpc/config_overrides/quadrotor_2D/linear_mpc_quadrotor_2D_tracking.yaml new file mode 100644 index 000000000..ee1853730 --- /dev/null +++ b/examples/mpc/config_overrides/quadrotor_2D/linear_mpc_quadrotor_2D_tracking.yaml @@ -0,0 +1,20 @@ +algo: linear_mpc +algo_config: + horizon: 40 + r_mpc: + - 0.1 + - 0.1 + q_mpc: + - 1.0 + - 0.1 + - 1.0 + - 0.1 + - 0.1 + - 0.1 + # Prior info + prior_info: + prior_prop: null + randomize_prior_prop: False + prior_prop_rand_info: null + warmstart: True 
+ solver: qrqp diff --git a/examples/mpc/config_overrides/quadrotor_2D/mpc_quadrotor_2D_stabilization.yaml b/examples/mpc/config_overrides/quadrotor_2D/mpc_quadrotor_2D_stabilization.yaml new file mode 100644 index 000000000..9e25529fc --- /dev/null +++ b/examples/mpc/config_overrides/quadrotor_2D/mpc_quadrotor_2D_stabilization.yaml @@ -0,0 +1,20 @@ +algo: mpc +algo_config: + horizon: 20 + r_mpc: + - 0.1 + - 0.1 + q_mpc: + - 5.0 + - 0.1 + - 5.0 + - 0.1 + - 0.1 + - 0.1 + # Prior info + prior_info: + prior_prop: null + randomize_prior_prop: False + prior_prop_rand_info: null + warmstart: True + solver: ipopt diff --git a/examples/mpc/config_overrides/quadrotor_2D/mpc_quadrotor_2D_tracking.yaml b/examples/mpc/config_overrides/quadrotor_2D/mpc_quadrotor_2D_tracking.yaml new file mode 100644 index 000000000..9e25529fc --- /dev/null +++ b/examples/mpc/config_overrides/quadrotor_2D/mpc_quadrotor_2D_tracking.yaml @@ -0,0 +1,20 @@ +algo: mpc +algo_config: + horizon: 20 + r_mpc: + - 0.1 + - 0.1 + q_mpc: + - 5.0 + - 0.1 + - 5.0 + - 0.1 + - 0.1 + - 0.1 + # Prior info + prior_info: + prior_prop: null + randomize_prior_prop: False + prior_prop_rand_info: null + warmstart: True + solver: ipopt diff --git a/examples/mpc/config_overrides/quadrotor_2D/quadrotor_2D_stabilization.yaml b/examples/mpc/config_overrides/quadrotor_2D/quadrotor_2D_stabilization.yaml new file mode 100644 index 000000000..494c9aefa --- /dev/null +++ b/examples/mpc/config_overrides/quadrotor_2D/quadrotor_2D_stabilization.yaml @@ -0,0 +1,51 @@ +task_config: + seed: 1337 + info_in_reset: True + ctrl_freq: 50 + pyb_freq: 1000 + gui: False + physics: pyb + quad_type: 2 + + init_state_randomization_info: + init_x: + distrib: 'uniform' + low: -1 + high: 1 + init_x_dot: + distrib: 'uniform' + low: -0.1 + high: 0.1 + init_z: + distrib: 'uniform' + low: 0.5 + high: 1.5 + init_z_dot: + distrib: 'uniform' + low: -0.1 + high: 0.1 + init_theta: + distrib: 'uniform' + low: -0.2 + high: 0.2 + init_theta_dot: + distrib: 
'uniform' + low: -0.1 + high: 0.1 + randomized_init: True + randomized_inertial_prop: False + + task: stabilization + task_info: + stabilization_goal: [0, 1] + stabilization_goal_tolerance: 0.0 + + episode_len_sec: 6 + cost: quadratic + done_on_out_of_bound: True + + constraints: + - constraint_form: default_constraint + constrained_variable: input + - constraint_form: default_constraint + constrained_variable: state diff --git a/examples/mpc/config_overrides/quadrotor_2D/quadrotor_2D_tracking.yaml b/examples/mpc/config_overrides/quadrotor_2D/quadrotor_2D_tracking.yaml new file mode 100644 index 000000000..2405a2238 --- /dev/null +++ b/examples/mpc/config_overrides/quadrotor_2D/quadrotor_2D_tracking.yaml @@ -0,0 +1,54 @@ +task_config: + seed: 1337 + info_in_reset: True + ctrl_freq: 50 + pyb_freq: 1000 + gui: False + physics: pyb + quad_type: 2 + + init_state_randomization_info: + init_x: + distrib: 'uniform' + low: -1 + high: 1 + init_x_dot: + distrib: 'uniform' + low: -0.1 + high: 0.1 + init_z: + distrib: 'uniform' + low: 0.5 + high: 1.5 + init_z_dot: + distrib: 'uniform' + low: -0.1 + high: 0.1 + init_theta: + distrib: 'uniform' + low: -0.2 + high: 0.2 + init_theta_dot: + distrib: 'uniform' + low: -0.1 + high: 0.1 + randomized_init: True + randomized_inertial_prop: False + + task: traj_tracking + task_info: + trajectory_type: figure8 + num_cycles: 1 + trajectory_plane: 'xz' + trajectory_position_offset: [0, 1] + trajectory_scale: 0.75 + + episode_len_sec: 6 + cost: quadratic + done_on_out_of_bound: True + + constraints: + - constraint_form: default_constraint + constrained_variable: input + - constraint_form: default_constraint + constrained_variable: state diff --git a/examples/mpc/config_overrides/quadrotor_3D/linear_mpc_quadrotor_3D_stabilization.yaml b/examples/mpc/config_overrides/quadrotor_3D/linear_mpc_quadrotor_3D_stabilization.yaml new file mode 100644 index 000000000..009f561b2 --- /dev/null +++ 
b/examples/mpc/config_overrides/quadrotor_3D/linear_mpc_quadrotor_3D_stabilization.yaml @@ -0,0 +1,28 @@ +algo: linear_mpc +algo_config: + horizon: 20 + r_mpc: + - 0.1 + - 0.1 + - 0.1 + - 0.1 + q_mpc: + - 5.0 + - 0.1 + - 5.0 + - 0.1 + - 5.0 + - 0.1 + - 0.1 + - 0.1 + - 0.1 + - 0.1 + - 0.1 + - 0.1 + # Prior info + prior_info: + prior_prop: null + randomize_prior_prop: False + prior_prop_rand_info: null + warmstart: True + solver: qrqp diff --git a/examples/mpc/config_overrides/quadrotor_3D/linear_mpc_quadrotor_3D_tracking.yaml b/examples/mpc/config_overrides/quadrotor_3D/linear_mpc_quadrotor_3D_tracking.yaml new file mode 100644 index 000000000..1102acce3 --- /dev/null +++ b/examples/mpc/config_overrides/quadrotor_3D/linear_mpc_quadrotor_3D_tracking.yaml @@ -0,0 +1,28 @@ +algo: linear_mpc +algo_config: + horizon: 20 + r_mpc: + - 0.1 + - 0.1 + - 0.1 + - 0.1 + q_mpc: + - 1.0 + - 0.1 + - 1.0 + - 0.1 + - 1.0 + - 0.1 + - 0.1 + - 0.1 + - 0.1 + - 0.1 + - 0.1 + - 0.1 + # Prior info + prior_info: + prior_prop: null + randomize_prior_prop: False + prior_prop_rand_info: null + warmstart: True + solver: qrqp diff --git a/examples/mpc/config_overrides/quadrotor_3D/mpc_quadrotor_3D_stabilization.yaml b/examples/mpc/config_overrides/quadrotor_3D/mpc_quadrotor_3D_stabilization.yaml new file mode 100644 index 000000000..359242b66 --- /dev/null +++ b/examples/mpc/config_overrides/quadrotor_3D/mpc_quadrotor_3D_stabilization.yaml @@ -0,0 +1,28 @@ +algo: mpc +algo_config: + horizon: 20 + r_mpc: + - 0.1 + - 0.1 + - 0.1 + - 0.1 + q_mpc: + - 5.0 + - 0.1 + - 5.0 + - 0.1 + - 5.0 + - 0.1 + - 0.1 + - 0.1 + - 0.1 + - 0.1 + - 0.1 + - 0.1 + # Prior info + prior_info: + prior_prop: null + randomize_prior_prop: False + prior_prop_rand_info: null + warmstart: True + solver: ipopt diff --git a/examples/mpc/config_overrides/quadrotor_3D/mpc_quadrotor_3D_tracking.yaml b/examples/mpc/config_overrides/quadrotor_3D/mpc_quadrotor_3D_tracking.yaml new file mode 100644 index 000000000..359242b66 --- /dev/null 
+++ b/examples/mpc/config_overrides/quadrotor_3D/mpc_quadrotor_3D_tracking.yaml @@ -0,0 +1,28 @@ +algo: mpc +algo_config: + horizon: 20 + r_mpc: + - 0.1 + - 0.1 + - 0.1 + - 0.1 + q_mpc: + - 5.0 + - 0.1 + - 5.0 + - 0.1 + - 5.0 + - 0.1 + - 0.1 + - 0.1 + - 0.1 + - 0.1 + - 0.1 + - 0.1 + # Prior info + prior_info: + prior_prop: null + randomize_prior_prop: False + prior_prop_rand_info: null + warmstart: True + solver: ipopt diff --git a/examples/mpc/config_overrides/quadrotor_3D/quadrotor_3D_stabilization.yaml b/examples/mpc/config_overrides/quadrotor_3D/quadrotor_3D_stabilization.yaml new file mode 100644 index 000000000..4c652d130 --- /dev/null +++ b/examples/mpc/config_overrides/quadrotor_3D/quadrotor_3D_stabilization.yaml @@ -0,0 +1,78 @@ +task_config: + seed: 1337 + info_in_reset: True + ctrl_freq: 50 + pyb_freq: 1000 + gui: False + physics: pyb + quad_type: 3 + + init_state_randomization_info: + init_x: + distrib: 'uniform' + low: -1 + high: 1 + init_x_dot: + distrib: 'uniform' + low: -0.1 + high: 0.1 + init_y: + distrib: 'uniform' + low: -1 + high: 1 + init_y_dot: + distrib: 'uniform' + low: -0.1 + high: 0.1 + init_z: + distrib: 'uniform' + low: 0.5 + high: 1.5 + init_z_dot: + distrib: 'uniform' + low: -0.1 + high: 0.1 + init_phi: + distrib: 'uniform' + low: -0.2 + high: 0.2 + init_theta: + distrib: 'uniform' + low: -0.2 + high: 0.2 + init_psi: + distrib: 'uniform' + low: -0.2 + high: 0.2 + init_p: + distrib: 'uniform' + low: -0.1 + high: 0.1 + init_q: + distrib: 'uniform' + low: -0.1 + high: 0.1 + init_r: + distrib: 'uniform' + low: -0.1 + high: 0.1 + randomized_init: True + randomized_inertial_prop: False + + task: stabilization + task_info: + stabilization_goal: [0, 0, 1] + stabilization_goal_tolerance: 0.0 + proj_point: [0, 0, 0.5] + proj_normal: [0, 1, 1] + + episode_len_sec: 6 + cost: quadratic + done_on_out_of_bound: True + constraints: + - constraint_form: default_constraint + constrained_variable: input + - constraint_form: default_constraint + 
constrained_variable: state + done_on_violation: False + disturbances: null diff --git a/examples/mpc/config_overrides/quadrotor_3D/quadrotor_3D_tracking.yaml b/examples/mpc/config_overrides/quadrotor_3D/quadrotor_3D_tracking.yaml new file mode 100644 index 000000000..e4263c249 --- /dev/null +++ b/examples/mpc/config_overrides/quadrotor_3D/quadrotor_3D_tracking.yaml @@ -0,0 +1,80 @@ +task_config: + seed: 1337 + info_in_reset: True + ctrl_freq: 50 + pyb_freq: 1000 + gui: False + physics: pyb + quad_type: 3 + + init_state_randomization_info: + init_x: + distrib: 'uniform' + low: -1 + high: 1 + init_x_dot: + distrib: 'uniform' + low: -0.1 + high: 0.1 + init_y: + distrib: 'uniform' + low: -1 + high: 1 + init_y_dot: + distrib: 'uniform' + low: -0.1 + high: 0.1 + init_z: + distrib: 'uniform' + low: 0.5 + high: 1.5 + init_z_dot: + distrib: 'uniform' + low: -0.1 + high: 0.1 + init_phi: + distrib: 'uniform' + low: -0.2 + high: 0.2 + init_theta: + distrib: 'uniform' + low: -0.2 + high: 0.2 + init_psi: + distrib: 'uniform' + low: -0.2 + high: 0.2 + init_p: + distrib: 'uniform' + low: -0.1 + high: 0.1 + init_q: + distrib: 'uniform' + low: -0.1 + high: 0.1 + init_r: + distrib: 'uniform' + low: -0.1 + high: 0.1 + randomized_init: True + randomized_inertial_prop: False + + task: traj_tracking + task_info: + trajectory_type: figure8 + num_cycles: 1 + trajectory_plane: 'xz' + trajectory_position_offset: [0, 1] + trajectory_scale: 0.75 + proj_point: [0, 0, 0.5] + proj_normal: [0, 1, 1] + episode_len_sec: 6 + cost: quadratic + done_on_out_of_bound: True + constraints: + - constraint_form: default_constraint + constrained_variable: input + - constraint_form: default_constraint + constrained_variable: state + done_on_violation: False + disturbances: null diff --git a/examples/mpc/mpc_experiment.py b/examples/mpc/mpc_experiment.py new file mode 100644 index 000000000..27b8874b7 --- /dev/null +++ b/examples/mpc/mpc_experiment.py @@ -0,0 +1,151 @@ +'''An MPC and Linear MPC example.''' + 
import os
import pickle
from collections import defaultdict
from functools import partial

import matplotlib.pyplot as plt
import numpy as np
from matplotlib.ticker import FormatStrFormatter

from safe_control_gym.envs.benchmark_env import Task
from safe_control_gym.experiments.base_experiment import BaseExperiment
from safe_control_gym.utils.configuration import ConfigFactory
from safe_control_gym.utils.registration import make


def run(gui=True, n_episodes=1, n_steps=None, save_data=False):
    '''The main function running MPC and Linear MPC experiments.

    Args:
        gui (bool): Whether to display the gui and plot graphs.
        n_episodes (int): The number of episodes to execute. None is treated as 1.
        n_steps (int): The number of steps passed to each evaluation call. If None,
            one full evaluation episode is run per loop iteration instead.
        save_data (bool): Whether to pickle the collected trajectories and metrics
            into ./temp-data/.

    Raises:
        ValueError: If n_episodes is less than 1.
    '''
    n_episodes = 1 if n_episodes is None else n_episodes
    if n_episodes < 1:
        # The metrics below are computed by the experiment object created inside the
        # episode loop, so at least one episode is required. Fail before any
        # environment or controller is created.
        raise ValueError(f'n_episodes must be at least 1, got {n_episodes}.')

    # Create the configuration dictionary.
    # NOTE(review): presumably built from the command line (the accompanying test
    # drives it through sys.argv) -- confirm against ConfigFactory.
    CONFIG_FACTORY = ConfigFactory()
    config = CONFIG_FACTORY.merge()

    # Create an environment factory; this instance is only used to sample initial states.
    env_func = partial(make,
                       config.task,
                       **config.task_config
                       )
    random_env = env_func(gui=False)

    # Create controller.
    ctrl = make(config.algo,
                env_func,
                **config.algo_config
                )

    all_trajs = defaultdict(list)

    # Run the experiment.
    for _ in range(n_episodes):
        # Get initial state and create environments that both start from it.
        init_state, _ = random_env.reset()
        static_env = env_func(gui=gui, randomized_init=False, init_state=init_state)
        static_train_env = env_func(gui=False, randomized_init=False, init_state=init_state)

        # Create experiment, train, and run evaluation
        experiment = BaseExperiment(env=static_env, ctrl=ctrl, train_env=static_train_env)
        experiment.launch_training()

        if n_steps is None:
            trajs_data, _ = experiment.run_evaluation(training=True, n_episodes=1)
        else:
            trajs_data, _ = experiment.run_evaluation(training=True, n_steps=n_steps)

        if gui:
            post_analysis(trajs_data['obs'][0], trajs_data['action'][0], ctrl.env)

        # Close environments
        static_env.close()
        static_train_env.close()

        # Merge in new trajectory data
        for key, value in trajs_data.items():
            all_trajs[key] += value

    ctrl.close()
    random_env.close()
    # The experiment object of the last episode computes metrics over all episodes.
    metrics = experiment.compute_metrics(all_trajs)
    all_trajs = dict(all_trajs)

    if save_data:
        results = {'trajs_data': all_trajs, 'metrics': metrics}
        path_dir = os.path.dirname('./temp-data/')
        os.makedirs(path_dir, exist_ok=True)
        with open(f'./temp-data/{config.algo}_data_{config.task}_{config.task_config.task}.pkl', 'wb') as file:
            pickle.dump(results, file)

    print('FINAL METRICS - ' + ', '.join([f'{key}: {value}' for key, value in metrics.items()]))


def post_analysis(state_stack, input_stack, env):
    '''Plots the states and inputs of the latest run against the reference.

    Args:
        state_stack (ndarray): The observations recorded in the latest run.
        input_stack (ndarray): The inputs applied in the latest run.
        env: The environment providing the symbolic model (`symbolic`), the
            reference (`X_GOAL`), the task type (`TASK`) and the state/action
            labels and units used to annotate the plots.
    '''
    model = env.symbolic
    stepsize = model.dt

    # States and inputs may differ in length; only plot the common prefix.
    plot_length = np.min([np.shape(input_stack)[0], np.shape(state_stack)[0]])
    # Sample k is taken at time k * dt, so consecutive samples are exactly dt apart.
    times = np.arange(plot_length) * stepsize

    reference = env.X_GOAL
    if env.TASK == Task.STABILIZATION:
        # A stabilization goal is a single state; repeat it for every plotted step.
        reference = np.tile(reference.reshape(1, model.nx), (plot_length, 1))

    # Plot states
    fig, axs = plt.subplots(model.nx)
    if model.nx == 1:
        # plt.subplots returns a bare Axes (not an array) for a single subplot.
        axs = [axs]
    for k in range(model.nx):
        axs[k].plot(times, np.array(state_stack).transpose()[k, 0:plot_length], label='actual')
        axs[k].plot(times, reference.transpose()[k, 0:plot_length], color='r', label='desired')
        axs[k].set(ylabel=env.STATE_LABELS[k] + f'\n[{env.STATE_UNITS[k]}]')
        axs[k].yaxis.set_major_formatter(FormatStrFormatter('%.1f'))
        if k != model.nx - 1:
            # Only the bottom subplot keeps its time ticks.
            axs[k].set_xticks([])
    axs[0].set_title('State Trajectories')
    axs[-1].legend(ncol=3, bbox_transform=fig.transFigure, bbox_to_anchor=(1, 0), loc='lower right')
    axs[-1].set(xlabel='time (sec)')

    # Plot inputs
    _, axs = plt.subplots(model.nu)
    if model.nu == 1:
        axs = [axs]
    for k in range(model.nu):
        axs[k].plot(times, np.array(input_stack).transpose()[k, 0:plot_length])
        axs[k].set(ylabel=env.ACTION_LABELS[k] + f'\n[{env.ACTION_UNITS[k]}]')
        axs[k].yaxis.set_major_formatter(FormatStrFormatter('%.1f'))
    axs[0].set_title('Input Trajectories')
    axs[-1].set(xlabel='time (sec)')

    plt.show()


def wrap2pi_vec(angle_vec):
    '''Wraps a vector of angles into the interval (-pi, pi].

    Args:
        angle_vec (ndarray): A vector of angles. It is modified in place.

    Returns:
        angle_vec (ndarray): The same vector, with every angle wrapped.
    '''
    # Angles are equivalent modulo a full turn (2 * pi); shifting by only pi
    # would map an angle onto the opposite direction.
    full_turn = 2 * np.pi
    for k, angle in enumerate(angle_vec):
        while angle > np.pi:
            angle -= full_turn
        while angle <= -np.pi:
            angle += full_turn
        angle_vec[k] = angle
    return angle_vec


if __name__ == '__main__':
    run()


# ---------------------------------------------------------------------------
# Patch residue: the source lines replaced by this block also carried the
# header and the first lines of the next file in the diff. They are kept
# verbatim here so that no content is lost:
#
#   diff --git a/examples/mpc/mpc_experiment.sh b/examples/mpc/mpc_experiment.sh
#   new file mode 100755
#   index 000000000..4d2ddaed6
#   --- /dev/null
#   +++ b/examples/mpc/mpc_experiment.sh
#   @@ -0,0 +1,26 @@
#   +#!/bin/bash
#   +
#   +# MPC and Linear MPC Experiment.
# ---------------------------------------------------------------------------
+ +#SYS='cartpole' +#SYS='quadrotor_2D' +SYS='quadrotor_3D' + +#TASK='stabilization' +TASK='tracking' + +#ALGO='mpc' +ALGO='linear_mpc' + +if [ "$SYS" == 'cartpole' ]; then + SYS_NAME=$SYS +else + SYS_NAME='quadrotor' +fi + +python3 ./mpc_experiment.py \ + --task ${SYS_NAME} \ + --algo ${ALGO} \ + --overrides \ + ./config_overrides/${SYS}/${SYS}_${TASK}.yaml \ + ./config_overrides/${SYS}/${ALGO}_${SYS}_${TASK}.yaml diff --git a/tests/test_examples/test_mpc.py b/tests/test_examples/test_mpc.py new file mode 100644 index 000000000..a5bb459ab --- /dev/null +++ b/tests/test_examples/test_mpc.py @@ -0,0 +1,21 @@ +import sys + +import pytest + +from examples.mpc.mpc_experiment import run + + +@pytest.mark.parametrize('SYS', ['cartpole', 'quadrotor_2D', 'quadrotor_3D']) +@pytest.mark.parametrize('TASK', ['stabilization', 'tracking']) +@pytest.mark.parametrize('ALGO', ['mpc', 'linear_mpc']) +def test_lqr(SYS, TASK, ALGO): + SYS_NAME = 'quadrotor' if 'quadrotor' in SYS else SYS + sys.argv[1:] = ['--algo', ALGO, + '--task', SYS_NAME, + '--overrides', + f'./examples/mpc/config_overrides/{SYS}/{SYS}_{TASK}.yaml', + f'./examples/mpc/config_overrides/{SYS}/{ALGO}_{SYS}_{TASK}.yaml', + '--kv_overrides', + 'algo_config.max_iterations=2' + ] + run(gui=False, n_episodes=None, n_steps=10, save_data=False)