diff --git a/examples/lqr/config_overrides/cartpole/cartpole_stabilization.yaml b/examples/lqr/config_overrides/cartpole/cartpole_stabilization.yaml index 0cbb7226e..eefdf5dca 100644 --- a/examples/lqr/config_overrides/cartpole/cartpole_stabilization.yaml +++ b/examples/lqr/config_overrides/cartpole/cartpole_stabilization.yaml @@ -30,4 +30,6 @@ task_config: episode_len_sec: 6 cost: quadratic + rew_state_weight: [1, 1, 1, 1] # Match LQR weights + rew_act_weight: [0.1] done_on_out_of_bound: True diff --git a/examples/lqr/config_overrides/cartpole/cartpole_tracking.yaml b/examples/lqr/config_overrides/cartpole/cartpole_tracking.yaml index 7eefd0fb0..11b956747 100644 --- a/examples/lqr/config_overrides/cartpole/cartpole_tracking.yaml +++ b/examples/lqr/config_overrides/cartpole/cartpole_tracking.yaml @@ -33,4 +33,6 @@ task_config: episode_len_sec: 6 cost: quadratic + rew_state_weight: [1, 0.1, 0.1, 0.1] # Match LQR weights + rew_act_weight: [0.1] done_on_out_of_bound: True diff --git a/examples/lqr/config_overrides/quadrotor_2D/quadrotor_2D_stabilization.yaml b/examples/lqr/config_overrides/quadrotor_2D/quadrotor_2D_stabilization.yaml index 134575710..8132fbca0 100644 --- a/examples/lqr/config_overrides/quadrotor_2D/quadrotor_2D_stabilization.yaml +++ b/examples/lqr/config_overrides/quadrotor_2D/quadrotor_2D_stabilization.yaml @@ -42,4 +42,6 @@ task_config: episode_len_sec: 6 cost: quadratic + rew_state_weight: [1, 1, 1, 1, 1, 1] # Match LQR weights + rew_act_weight: [0.1] done_on_out_of_bound: True diff --git a/examples/lqr/config_overrides/quadrotor_2D/quadrotor_2D_tracking.yaml b/examples/lqr/config_overrides/quadrotor_2D/quadrotor_2D_tracking.yaml index aa68b6912..8cf000795 100644 --- a/examples/lqr/config_overrides/quadrotor_2D/quadrotor_2D_tracking.yaml +++ b/examples/lqr/config_overrides/quadrotor_2D/quadrotor_2D_tracking.yaml @@ -45,4 +45,6 @@ task_config: episode_len_sec: 6 cost: quadratic + rew_state_weight: [1, 0.1, 1, 0.1, 0.1, 0.1] # Match LQR weights + rew_act_weight: [0.1] done_on_out_of_bound: True diff --git a/examples/lqr/config_overrides/quadrotor_3D/quadrotor_3D_stabilization.yaml b/examples/lqr/config_overrides/quadrotor_3D/quadrotor_3D_stabilization.yaml index 8ebffccb7..878726f3f 100644 --- a/examples/lqr/config_overrides/quadrotor_3D/quadrotor_3D_stabilization.yaml +++ b/examples/lqr/config_overrides/quadrotor_3D/quadrotor_3D_stabilization.yaml @@ -68,4 +68,7 @@ task_config: episode_len_sec: 6 cost: quadratic + # Match LQR weights + rew_state_weight: [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1] + rew_act_weight: [0.1] done_on_out_of_bound: True diff --git a/examples/lqr/config_overrides/quadrotor_3D/quadrotor_3D_tracking.yaml b/examples/lqr/config_overrides/quadrotor_3D/quadrotor_3D_tracking.yaml index 4089e5359..42fab5a12 100644 --- a/examples/lqr/config_overrides/quadrotor_3D/quadrotor_3D_tracking.yaml +++ b/examples/lqr/config_overrides/quadrotor_3D/quadrotor_3D_tracking.yaml @@ -71,4 +71,7 @@ task_config: episode_len_sec: 6 cost: quadratic + # Match LQR weights + rew_state_weight: [1, 0.1, 1, 0.1, 1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1] + rew_act_weight: [0.1] done_on_out_of_bound: True diff --git a/examples/lqr/lqr_experiment.py b/examples/lqr/lqr_experiment.py index 4b6b46d86..e21807185 100644 --- a/examples/lqr/lqr_experiment.py +++ b/examples/lqr/lqr_experiment.py @@ -15,11 +15,12 @@ from safe_control_gym.utils.registration import make -def run(gui=True, n_episodes=1, n_steps=None, save_data=False): +def run(gui=False, plot=True, n_episodes=1, n_steps=None, save_data=False): '''The main function running LQR and iLQR experiments. Args: - gui (bool): Whether to display the gui and plot graphs. + gui (bool): Whether to display the gui. + plot (bool): Whether to plot graphs. n_episodes (int): The number of episodes to execute. n_steps (int): The total number of steps to execute. save_data (bool): Whether to save the collected experiment data. @@ -61,7 +62,7 @@ def run(gui=True, n_episodes=1, n_steps=None, save_data=False): else: trajs_data, _ = experiment.run_evaluation(training=True, n_steps=n_steps) - if gui: + if plot: post_analysis(trajs_data['obs'][0], trajs_data['action'][0], ctrl.env) # Close environments @@ -132,20 +133,5 @@ def post_analysis(state_stack, input_stack, env): plt.show() -def wrap2pi_vec(angle_vec): - '''Wraps a vector of angles between -pi and pi. - - Args: - angle_vec (ndarray): A vector of angles. - ''' - for k, angle in enumerate(angle_vec): - while angle > np.pi: - angle -= np.pi - while angle <= -np.pi: - angle += np.pi - angle_vec[k] = angle - return angle_vec - - if __name__ == '__main__': run() diff --git a/examples/mpc/config_overrides/cartpole/cartpole_stabilization.yaml b/examples/mpc/config_overrides/cartpole/cartpole_stabilization.yaml index b942acb4a..4e800fa31 100644 --- a/examples/mpc/config_overrides/cartpole/cartpole_stabilization.yaml +++ b/examples/mpc/config_overrides/cartpole/cartpole_stabilization.yaml @@ -30,6 +30,8 @@ task_config: episode_len_sec: 6 cost: quadratic + rew_state_weight: [5.0, 0.1, 5.0, 0.1] # Match MPC weights + rew_act_weight: [0.1] done_on_out_of_bound: True constraints: diff --git a/examples/mpc/config_overrides/cartpole/cartpole_tracking.yaml b/examples/mpc/config_overrides/cartpole/cartpole_tracking.yaml index 37c07aa09..cd4164f41 100644 --- a/examples/mpc/config_overrides/cartpole/cartpole_tracking.yaml +++ b/examples/mpc/config_overrides/cartpole/cartpole_tracking.yaml @@ -33,6 +33,8 @@ task_config: episode_len_sec: 6 cost: quadratic + rew_state_weight: [5.0, 0.1, 5.0, 0.1] # Match MPC weights + rew_act_weight: [0.1] done_on_out_of_bound: True constraints: diff --git a/examples/mpc/config_overrides/quadrotor_2D/linear_mpc_quadrotor_2D_tracking.yaml b/examples/mpc/config_overrides/quadrotor_2D/linear_mpc_quadrotor_2D_tracking.yaml index ee1853730..3786a1db4 100644 --- a/examples/mpc/config_overrides/quadrotor_2D/linear_mpc_quadrotor_2D_tracking.yaml +++ b/examples/mpc/config_overrides/quadrotor_2D/linear_mpc_quadrotor_2D_tracking.yaml @@ -5,9 +5,9 @@ algo_config: - 0.1 - 0.1 q_mpc: - - 1.0 + - 5.0 - 0.1 - - 1.0 + - 5.0 - 0.1 - 0.1 - 0.1 diff --git a/examples/mpc/config_overrides/quadrotor_2D/quadrotor_2D_stabilization.yaml b/examples/mpc/config_overrides/quadrotor_2D/quadrotor_2D_stabilization.yaml index 494c9aefa..8d380fc11 100644 --- a/examples/mpc/config_overrides/quadrotor_2D/quadrotor_2D_stabilization.yaml +++ b/examples/mpc/config_overrides/quadrotor_2D/quadrotor_2D_stabilization.yaml @@ -42,6 +42,8 @@ task_config: episode_len_sec: 6 cost: quadratic + rew_state_weight: [5.0, 0.1, 5.0, 0.1, 0.1, 0.1] # Match MPC weights + rew_act_weight: [0.1, 0.1] done_on_out_of_bound: True constraints: diff --git a/examples/mpc/config_overrides/quadrotor_2D/quadrotor_2D_tracking.yaml b/examples/mpc/config_overrides/quadrotor_2D/quadrotor_2D_tracking.yaml index 2405a2238..e3eb0959b 100644 --- a/examples/mpc/config_overrides/quadrotor_2D/quadrotor_2D_tracking.yaml +++ b/examples/mpc/config_overrides/quadrotor_2D/quadrotor_2D_tracking.yaml @@ -45,6 +45,8 @@ task_config: episode_len_sec: 6 cost: quadratic + rew_state_weight: [5.0, 0.1, 5.0, 0.1, 0.1, 0.1] # Match MPC weights + rew_act_weight: [0.1, 0.1] done_on_out_of_bound: True constraints: diff --git a/examples/mpc/config_overrides/quadrotor_3D/linear_mpc_quadrotor_3D_tracking.yaml b/examples/mpc/config_overrides/quadrotor_3D/linear_mpc_quadrotor_3D_tracking.yaml index 1102acce3..009f561b2 100644 --- a/examples/mpc/config_overrides/quadrotor_3D/linear_mpc_quadrotor_3D_tracking.yaml +++ b/examples/mpc/config_overrides/quadrotor_3D/linear_mpc_quadrotor_3D_tracking.yaml @@ -7,11 +7,11 @@ algo_config: - 0.1 - 0.1 q_mpc: - - 1.0 + - 5.0 - 0.1 - - 1.0 + - 5.0 - 0.1 - - 1.0 + - 5.0 - 0.1 - 0.1 - 0.1 diff --git a/examples/mpc/config_overrides/quadrotor_3D/quadrotor_3D_stabilization.yaml b/examples/mpc/config_overrides/quadrotor_3D/quadrotor_3D_stabilization.yaml index 4c652d130..305cda6eb 100644 --- a/examples/mpc/config_overrides/quadrotor_3D/quadrotor_3D_stabilization.yaml +++ b/examples/mpc/config_overrides/quadrotor_3D/quadrotor_3D_stabilization.yaml @@ -68,6 +68,9 @@ task_config: episode_len_sec: 6 cost: quadratic + # Match MPC weights + rew_state_weight: [5.0, 0.1, 5.0, 0.1, 5.0, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1] + rew_act_weight: [0.1, 0.1, 0.1, 0.1] done_on_out_of_bound: True constraints: - constraint_form: default_constraint diff --git a/examples/mpc/config_overrides/quadrotor_3D/quadrotor_3D_tracking.yaml b/examples/mpc/config_overrides/quadrotor_3D/quadrotor_3D_tracking.yaml index e4263c249..1890b5251 100644 --- a/examples/mpc/config_overrides/quadrotor_3D/quadrotor_3D_tracking.yaml +++ b/examples/mpc/config_overrides/quadrotor_3D/quadrotor_3D_tracking.yaml @@ -70,6 +70,9 @@ task_config: proj_normal: [0, 1, 1] episode_len_sec: 6 cost: quadratic + # Match MPC weights + rew_state_weight: [5.0, 0.1, 5.0, 0.1, 5.0, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1] + rew_act_weight: [0.1, 0.1, 0.1, 0.1] done_on_out_of_bound: True constraints: - constraint_form: default_constraint diff --git a/examples/mpc/mpc_experiment.py b/examples/mpc/mpc_experiment.py index 27b8874b7..5d69738a3 100644 --- a/examples/mpc/mpc_experiment.py +++ b/examples/mpc/mpc_experiment.py @@ -2,7 +2,6 @@ import os import pickle -from collections import defaultdict from functools import partial import matplotlib.pyplot as plt @@ -15,11 +14,12 @@ from safe_control_gym.utils.registration import make -def run(gui=True, n_episodes=1, n_steps=None, save_data=False): +def run(gui=False, plot=True, n_episodes=1, n_steps=None, save_data=False): '''The main function running MPC and Linear MPC experiments. Args: - gui (bool): Whether to display the gui and plot graphs. + gui (bool): Whether to display the gui. + plot (bool): Whether to plot graphs. n_episodes (int): The number of episodes to execute. n_steps (int): The total number of steps to execute. save_data (bool): Whether to save the collected experiment data. @@ -34,7 +34,7 @@ def run(gui=True, n_episodes=1, n_steps=None, save_data=False): config.task, **config.task_config ) - random_env = env_func(gui=False) + env = env_func(gui=gui) # Create controller. ctrl = make(config.algo, @@ -42,43 +42,19 @@ def run(gui=True, n_episodes=1, n_steps=None, save_data=False): **config.algo_config ) - all_trajs = defaultdict(list) - n_episodes = 1 if n_episodes is None else n_episodes - # Run the experiment. - for _ in range(n_episodes): - # Get initial state and create environments - init_state, _ = random_env.reset() - static_env = env_func(gui=gui, randomized_init=False, init_state=init_state) - static_train_env = env_func(gui=False, randomized_init=False, init_state=init_state) - - # Create experiment, train, and run evaluation - experiment = BaseExperiment(env=static_env, ctrl=ctrl, train_env=static_train_env) - experiment.launch_training() - - if n_steps is None: - trajs_data, _ = experiment.run_evaluation(training=True, n_episodes=1) - else: - trajs_data, _ = experiment.run_evaluation(training=True, n_steps=n_steps) - - if gui: - post_analysis(trajs_data['obs'][0], trajs_data['action'][0], ctrl.env) - - # Close environments - static_env.close() - static_train_env.close() + experiment = BaseExperiment(env=env, ctrl=ctrl) + trajs_data, metrics = experiment.run_evaluation(training=True, n_episodes=n_episodes, n_steps=n_steps) - # Merge in new trajectory data - for key, value in trajs_data.items(): - all_trajs[key] += value + if plot: + for i in range(len(trajs_data['obs'])): + post_analysis(trajs_data['obs'][i], trajs_data['action'][i], ctrl.env) ctrl.close() - random_env.close() - metrics = experiment.compute_metrics(all_trajs) - all_trajs = dict(all_trajs) + env.close() if save_data: - results = {'trajs_data': all_trajs, 'metrics': metrics} + results = {'trajs_data': trajs_data, 'metrics': metrics} path_dir = os.path.dirname('./temp-data/') os.makedirs(path_dir, exist_ok=True) with open(f'./temp-data/{config.algo}_data_{config.task}_{config.task_config.task}.pkl', 'wb') as file: @@ -132,20 +108,5 @@ def post_analysis(state_stack, input_stack, env): plt.show() -def wrap2pi_vec(angle_vec): - '''Wraps a vector of angles between -pi and pi. - - Args: - angle_vec (ndarray): A vector of angles. - ''' - for k, angle in enumerate(angle_vec): - while angle > np.pi: - angle -= np.pi - while angle <= -np.pi: - angle += np.pi - angle_vec[k] = angle - return angle_vec - - if __name__ == '__main__': run() diff --git a/safe_control_gym/controllers/lqr/ilqr.py b/safe_control_gym/controllers/lqr/ilqr.py index 3ebc484c7..7407a923f 100644 --- a/safe_control_gym/controllers/lqr/ilqr.py +++ b/safe_control_gym/controllers/lqr/ilqr.py @@ -64,7 +64,6 @@ def __init__( self.model = self.get_prior(self.env) self.Q = get_cost_weight_matrix(self.q_lqr, self.model.nx) self.R = get_cost_weight_matrix(self.r_lqr, self.model.nu) - self.env.set_cost_function_param(self.Q, self.R) self.gain = compute_lqr_gain(self.model, self.model.X_EQ, self.model.U_EQ, self.Q, self.R, self.discrete_dynamics) diff --git a/safe_control_gym/controllers/lqr/lqr.py b/safe_control_gym/controllers/lqr/lqr.py index 5e9597d51..f03069525 100644 --- a/safe_control_gym/controllers/lqr/lqr.py +++ b/safe_control_gym/controllers/lqr/lqr.py @@ -33,7 +33,6 @@ def __init__( self.discrete_dynamics = discrete_dynamics self.Q = get_cost_weight_matrix(q_lqr, self.model.nx) self.R = get_cost_weight_matrix(r_lqr, self.model.nu) - self.env.set_cost_function_param(self.Q, self.R) self.gain = compute_lqr_gain(self.model, self.model.X_EQ, self.model.U_EQ, self.Q, self.R, self.discrete_dynamics) diff --git a/safe_control_gym/controllers/mpc/mpc.py b/safe_control_gym/controllers/mpc/mpc.py index ac2ed4e59..61287b92f 100644 --- a/safe_control_gym/controllers/mpc/mpc.py +++ b/safe_control_gym/controllers/mpc/mpc.py @@ -454,10 +454,7 @@ def run(self, self.x_prev = None self.u_prev = None - if not env.initial_reset: - env.set_cost_function_param(self.Q, self.R) - # obs, info = env.reset() - obs = env.reset() + obs, info = env.reset() print('Init State:') print(obs) ep_returns, ep_lengths = [], [] diff --git a/safe_control_gym/envs/benchmark_env.py b/safe_control_gym/envs/benchmark_env.py index c60220974..61df37a61 100644 --- a/safe_control_gym/envs/benchmark_env.py +++ b/safe_control_gym/envs/benchmark_env.py @@ -176,10 +176,6 @@ def __init__(self, self.state_dim = self.state_space.shape[0] else: self.state_dim = self.obs_dim - # Default Q and R matrices for quadratic cost. - if self.COST == Cost.QUADRATIC: - self.Q = np.eye(self.observation_space.shape[0]) - self.R = np.eye(self.action_space.shape[0]) # Set constraint info. self.CONSTRAINTS = constraints self.DONE_ON_VIOLATION = done_on_violation @@ -221,25 +217,6 @@ def seed(self, disturbs.seed(self) return [seed] - def set_cost_function_param(self, - Q, - R - ): - '''Set the cost function parameters. - - Args: - Q (ndarray): State weight matrix (nx by nx). - R (ndarray): Input weight matrix (nu by nu). - ''' - - if not self.initial_reset: - self.Q = Q - self.R = R - else: - raise RuntimeError( - '[ERROR] env.set_cost_function_param() cannot be called after the first reset of the environment.' - ) - def set_adversary_control(self, action): '''Sets disturbance by an adversary controller, called before (each) step(). diff --git a/safe_control_gym/envs/gym_control/cartpole.py b/safe_control_gym/envs/gym_control/cartpole.py index 6a1c91317..7ed6ea442 100644 --- a/safe_control_gym/envs/gym_control/cartpole.py +++ b/safe_control_gym/envs/gym_control/cartpole.py @@ -150,7 +150,9 @@ def __init__(self, self.obs_goal_horizon = obs_goal_horizon self.obs_wrap_angle = obs_wrap_angle self.rew_state_weight = np.array(rew_state_weight, ndmin=1, dtype=float) + self.Q = np.diag(self.rew_state_weight) self.rew_act_weight = np.array(rew_act_weight, ndmin=1, dtype=float) + self.R = np.diag(self.rew_act_weight) self.rew_exponential = rew_exponential self.done_on_out_of_bound = done_on_out_of_bound # BenchmarkEnv constructor, called after defining the custom args, diff --git a/safe_control_gym/envs/gym_pybullet_drones/quadrotor.py b/safe_control_gym/envs/gym_pybullet_drones/quadrotor.py index 2e2cb1887..2619bdfe5 100644 --- a/safe_control_gym/envs/gym_pybullet_drones/quadrotor.py +++ b/safe_control_gym/envs/gym_pybullet_drones/quadrotor.py @@ -178,7 +178,9 @@ def __init__(self, self.norm_act_scale = norm_act_scale self.obs_goal_horizon = obs_goal_horizon self.rew_state_weight = np.array(rew_state_weight, ndmin=1, dtype=float) + self.Q = np.diag(self.rew_state_weight) self.rew_act_weight = np.array(rew_act_weight, ndmin=1, dtype=float) + self.R = np.diag(self.rew_act_weight) self.rew_exponential = rew_exponential self.done_on_out_of_bound = done_on_out_of_bound if info_mse_metric_state_weight is None: