Removing set_cost_function_param and setting it via rew_state_weight (#181)

* Removing set_cost_function_param and setting it via rew_state_weight and rew_action_weight

* Re-adding the MPC run function as it is necessary for GP-MPC

* Cleaning up MPC and LQR example scripts

* Reverting to n_episodes=1
Federico-PizarroBejarano authored Dec 4, 2024
1 parent 253f25c commit 55f3db2
Showing 22 changed files with 53 additions and 102 deletions.
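
As the commit message describes, controllers no longer push their cost matrices into the environment through set_cost_function_param(); the environment now builds Q and R once, at construction time, from the rew_state_weight and rew_act_weight entries of the task config. A minimal sketch of that wiring, mirroring the cartpole.py and quadrotor.py hunks below (the weight values are illustrative):

import numpy as np

# Reward weights as they would appear in a task_config YAML override.
rew_state_weight = [5.0, 0.1, 5.0, 0.1]
rew_act_weight = [0.1]

# The env constructor now derives its quadratic-cost matrices directly:
Q = np.diag(np.array(rew_state_weight, ndmin=1, dtype=float))
R = np.diag(np.array(rew_act_weight, ndmin=1, dtype=float))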
@@ -30,4 +30,6 @@ task_config:
 
   episode_len_sec: 6
   cost: quadratic
+  rew_state_weight: [1, 1, 1, 1] # Match LQR weights
+  rew_act_weight: [0.1]
   done_on_out_of_bound: True
2 changes: 2 additions & 0 deletions examples/lqr/config_overrides/cartpole/cartpole_tracking.yaml
@@ -33,4 +33,6 @@ task_config:
 
   episode_len_sec: 6
   cost: quadratic
+  rew_state_weight: [1, 0.1, 0.1, 0.1] # Match LQR weights
+  rew_act_weight: [0.1]
   done_on_out_of_bound: True
@@ -42,4 +42,6 @@ task_config:
 
   episode_len_sec: 6
   cost: quadratic
+  rew_state_weight: [1, 1, 1, 1, 1, 1] # Match LQR weights
+  rew_act_weight: [0.1]
   done_on_out_of_bound: True
@@ -45,4 +45,6 @@ task_config:
 
   episode_len_sec: 6
   cost: quadratic
+  rew_state_weight: [1, 0.1, 1, 0.1, 0.1, 0.1] # Match LQR weights
+  rew_act_weight: [0.1]
   done_on_out_of_bound: True
@@ -68,4 +68,7 @@ task_config:
 
   episode_len_sec: 6
   cost: quadratic
+  # Match LQR weights
+  rew_state_weight: [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
+  rew_act_weight: [0.1]
   done_on_out_of_bound: True
@@ -71,4 +71,7 @@ task_config:
 
   episode_len_sec: 6
   cost: quadratic
+  # Match LQR weights
+  rew_state_weight: [1, 0.1, 1, 0.1, 1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]
+  rew_act_weight: [0.1]
   done_on_out_of_bound: True
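
Each rew_state_weight vector carries one weight per state dimension (4 entries for cartpole, 6 for the 2D quadrotor, 12 for the 3D quadrotor, in the state order defined by the corresponding environment), and rew_act_weight is sized to the action dimension in the same way.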
22 changes: 4 additions & 18 deletions examples/lqr/lqr_experiment.py
@@ -15,11 +15,12 @@
 from safe_control_gym.utils.registration import make
 
 
-def run(gui=True, n_episodes=1, n_steps=None, save_data=False):
+def run(gui=False, plot=True, n_episodes=1, n_steps=None, save_data=False):
     '''The main function running LQR and iLQR experiments.
 
     Args:
-        gui (bool): Whether to display the gui and plot graphs.
+        gui (bool): Whether to display the gui.
+        plot (bool): Whether to plot graphs.
         n_episodes (int): The number of episodes to execute.
         n_steps (int): The total number of steps to execute.
         save_data (bool): Whether to save the collected experiment data.
@@ -61,7 +62,7 @@ def run(gui=True, n_episodes=1, n_steps=None, save_data=False):
     else:
         trajs_data, _ = experiment.run_evaluation(training=True, n_steps=n_steps)
 
-    if gui:
+    if plot:
         post_analysis(trajs_data['obs'][0], trajs_data['action'][0], ctrl.env)
 
     # Close environments
@@ -132,20 +133,5 @@ def post_analysis(state_stack, input_stack, env):
     plt.show()
 
 
-def wrap2pi_vec(angle_vec):
-    '''Wraps a vector of angles between -pi and pi.
-
-    Args:
-        angle_vec (ndarray): A vector of angles.
-    '''
-    for k, angle in enumerate(angle_vec):
-        while angle > np.pi:
-            angle -= np.pi
-        while angle <= -np.pi:
-            angle += np.pi
-        angle_vec[k] = angle
-    return angle_vec
-
-
 if __name__ == '__main__':
     run()
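
The new plot flag decouples figure generation from GUI rendering. A hedged sketch, assuming the call is made from this script's own __main__ block as the file already does:

# Headless batch run: no GUI, no figures, but persist the collected data.
run(gui=False, plot=False, n_episodes=1, n_steps=None, save_data=True)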
@@ -30,6 +30,8 @@ task_config:
 
   episode_len_sec: 6
   cost: quadratic
+  rew_state_weight: [5.0, 0.1, 5.0, 0.1] # Match MPC weights
+  rew_act_weight: [0.1]
   done_on_out_of_bound: True
 
   constraints:
2 changes: 2 additions & 0 deletions examples/mpc/config_overrides/cartpole/cartpole_tracking.yaml
@@ -33,6 +33,8 @@ task_config:
 
   episode_len_sec: 6
   cost: quadratic
+  rew_state_weight: [5.0, 0.1, 5.0, 0.1] # Match MPC weights
+  rew_act_weight: [0.1]
   done_on_out_of_bound: True
 
   constraints:
@@ -5,9 +5,9 @@ algo_config:
   - 0.1
   - 0.1
   q_mpc:
-  - 1.0
+  - 5.0
   - 0.1
-  - 1.0
+  - 5.0
   - 0.1
   - 0.1
   - 0.1
@@ -42,6 +42,8 @@ task_config:
 
   episode_len_sec: 6
   cost: quadratic
+  rew_state_weight: [5.0, 0.1, 5.0, 0.1, 0.1, 0.1] # Match MPC weights
+  rew_act_weight: [0.1, 0.1]
   done_on_out_of_bound: True
 
   constraints:
@@ -45,6 +45,8 @@ task_config:
 
   episode_len_sec: 6
   cost: quadratic
+  rew_state_weight: [5.0, 0.1, 5.0, 0.1, 0.1, 0.1] # Match MPC weights
+  rew_act_weight: [0.1, 0.1]
   done_on_out_of_bound: True
 
   constraints:
@@ -7,11 +7,11 @@ algo_config:
   - 0.1
   - 0.1
   q_mpc:
-  - 1.0
+  - 5.0
   - 0.1
-  - 1.0
+  - 5.0
   - 0.1
-  - 1.0
+  - 5.0
   - 0.1
   - 0.1
   - 0.1
@@ -68,6 +68,9 @@ task_config:
 
   episode_len_sec: 6
   cost: quadratic
+  # Match MPC weights
+  rew_state_weight: [5.0, 0.1, 5.0, 0.1, 5.0, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]
+  rew_act_weight: [0.1, 0.1, 0.1, 0.1]
   done_on_out_of_bound: True
   constraints:
   - constraint_form: default_constraint
@@ -70,6 +70,9 @@ task_config:
   proj_normal: [0, 1, 1]
   episode_len_sec: 6
   cost: quadratic
+  # Match MPC weights
+  rew_state_weight: [5.0, 0.1, 5.0, 0.1, 5.0, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]
+  rew_act_weight: [0.1, 0.1, 0.1, 0.1]
   done_on_out_of_bound: True
   constraints:
   - constraint_form: default_constraint
61 changes: 11 additions & 50 deletions examples/mpc/mpc_experiment.py
@@ -2,7 +2,6 @@
 
 import os
 import pickle
-from collections import defaultdict
 from functools import partial
 
 import matplotlib.pyplot as plt
@@ -15,11 +14,12 @@
 from safe_control_gym.utils.registration import make
 
 
-def run(gui=True, n_episodes=1, n_steps=None, save_data=False):
+def run(gui=False, plot=True, n_episodes=1, n_steps=None, save_data=False):
     '''The main function running MPC and Linear MPC experiments.
 
     Args:
-        gui (bool): Whether to display the gui and plot graphs.
+        gui (bool): Whether to display the gui.
+        plot (bool): Whether to plot graphs.
         n_episodes (int): The number of episodes to execute.
         n_steps (int): The total number of steps to execute.
         save_data (bool): Whether to save the collected experiment data.
@@ -34,51 +34,27 @@ def run(gui=True, n_episodes=1, n_steps=None, save_data=False):
                        config.task,
                        **config.task_config
                        )
-    random_env = env_func(gui=False)
+    env = env_func(gui=gui)
 
     # Create controller.
     ctrl = make(config.algo,
                 env_func,
                 **config.algo_config
                 )
 
-    all_trajs = defaultdict(list)
-    n_episodes = 1 if n_episodes is None else n_episodes
-
     # Run the experiment.
-    for _ in range(n_episodes):
-        # Get initial state and create environments
-        init_state, _ = random_env.reset()
-        static_env = env_func(gui=gui, randomized_init=False, init_state=init_state)
-        static_train_env = env_func(gui=False, randomized_init=False, init_state=init_state)
-
-        # Create experiment, train, and run evaluation
-        experiment = BaseExperiment(env=static_env, ctrl=ctrl, train_env=static_train_env)
-        experiment.launch_training()
-
-        if n_steps is None:
-            trajs_data, _ = experiment.run_evaluation(training=True, n_episodes=1)
-        else:
-            trajs_data, _ = experiment.run_evaluation(training=True, n_steps=n_steps)
-
-        if gui:
-            post_analysis(trajs_data['obs'][0], trajs_data['action'][0], ctrl.env)
-
-        # Close environments
-        static_env.close()
-        static_train_env.close()
+    experiment = BaseExperiment(env=env, ctrl=ctrl)
+    trajs_data, metrics = experiment.run_evaluation(training=True, n_episodes=n_episodes, n_steps=n_steps)
 
-        # Merge in new trajectory data
-        for key, value in trajs_data.items():
-            all_trajs[key] += value
+    if plot:
+        for i in range(len(trajs_data['obs'])):
+            post_analysis(trajs_data['obs'][i], trajs_data['action'][i], ctrl.env)
 
     ctrl.close()
-    random_env.close()
-    metrics = experiment.compute_metrics(all_trajs)
-    all_trajs = dict(all_trajs)
+    env.close()
 
     if save_data:
-        results = {'trajs_data': all_trajs, 'metrics': metrics}
+        results = {'trajs_data': trajs_data, 'metrics': metrics}
         path_dir = os.path.dirname('./temp-data/')
         os.makedirs(path_dir, exist_ok=True)
         with open(f'./temp-data/{config.algo}_data_{config.task}_{config.task_config.task}.pkl', 'wb') as file:
@@ -132,20 +108,5 @@ def post_analysis(state_stack, input_stack, env):
     plt.show()
 
 
-def wrap2pi_vec(angle_vec):
-    '''Wraps a vector of angles between -pi and pi.
-
-    Args:
-        angle_vec (ndarray): A vector of angles.
-    '''
-    for k, angle in enumerate(angle_vec):
-        while angle > np.pi:
-            angle -= np.pi
-        while angle <= -np.pi:
-            angle += np.pi
-        angle_vec[k] = angle
-    return angle_vec
-
-
 if __name__ == '__main__':
     run()
1 change: 0 additions & 1 deletion safe_control_gym/controllers/lqr/ilqr.py
@@ -64,7 +64,6 @@ def __init__(
         self.model = self.get_prior(self.env)
         self.Q = get_cost_weight_matrix(self.q_lqr, self.model.nx)
         self.R = get_cost_weight_matrix(self.r_lqr, self.model.nu)
-        self.env.set_cost_function_param(self.Q, self.R)
 
         self.gain = compute_lqr_gain(self.model, self.model.X_EQ, self.model.U_EQ,
                                      self.Q, self.R, self.discrete_dynamics)
1 change: 0 additions & 1 deletion safe_control_gym/controllers/lqr/lqr.py
@@ -33,7 +33,6 @@ def __init__(
         self.discrete_dynamics = discrete_dynamics
         self.Q = get_cost_weight_matrix(q_lqr, self.model.nx)
         self.R = get_cost_weight_matrix(r_lqr, self.model.nu)
-        self.env.set_cost_function_param(self.Q, self.R)
 
         self.gain = compute_lqr_gain(self.model, self.model.X_EQ, self.model.U_EQ,
                                      self.Q, self.R, self.discrete_dynamics)
5 changes: 1 addition & 4 deletions safe_control_gym/controllers/mpc/mpc.py
@@ -454,10 +454,7 @@ def run(self,
 
         self.x_prev = None
         self.u_prev = None
-        if not env.initial_reset:
-            env.set_cost_function_param(self.Q, self.R)
-        # obs, info = env.reset()
-        obs = env.reset()
+        obs, info = env.reset()
         print('Init State:')
         print(obs)
         ep_returns, ep_lengths = [], []
23 changes: 0 additions & 23 deletions safe_control_gym/envs/benchmark_env.py
@@ -176,10 +176,6 @@ def __init__(self,
             self.state_dim = self.state_space.shape[0]
         else:
             self.state_dim = self.obs_dim
-        # Default Q and R matrices for quadratic cost.
-        if self.COST == Cost.QUADRATIC:
-            self.Q = np.eye(self.observation_space.shape[0])
-            self.R = np.eye(self.action_space.shape[0])
         # Set constraint info.
         self.CONSTRAINTS = constraints
         self.DONE_ON_VIOLATION = done_on_violation
@@ -221,25 +217,6 @@ def seed(self,
         disturbs.seed(self)
         return [seed]
 
-    def set_cost_function_param(self,
-                                Q,
-                                R
-                                ):
-        '''Set the cost function parameters.
-
-        Args:
-            Q (ndarray): State weight matrix (nx by nx).
-            R (ndarray): Input weight matrix (nu by nu).
-        '''
-
-        if not self.initial_reset:
-            self.Q = Q
-            self.R = R
-        else:
-            raise RuntimeError(
-                '[ERROR] env.set_cost_function_param() cannot be called after the first reset of the environment.'
-            )
-
     def set_adversary_control(self, action):
         '''Sets disturbance by an adversary controller, called before (each) step().
 
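
With the setter removed, Q and R are fixed for the lifetime of the environment. A self-contained sketch of the quadratic stage cost they parameterize (the function name and goal arguments are illustrative, not part of this diff):

import numpy as np

def quadratic_stage_cost(x, u, x_goal, u_goal, Q, R):
    # (x - x_goal)^T Q (x - x_goal) + (u - u_goal)^T R (u - u_goal)
    dx = np.asarray(x) - np.asarray(x_goal)
    du = np.asarray(u) - np.asarray(u_goal)
    return dx @ Q @ dx + du @ R @ du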
2 changes: 2 additions & 0 deletions safe_control_gym/envs/gym_control/cartpole.py
@@ -150,7 +150,9 @@ def __init__(self,
         self.obs_goal_horizon = obs_goal_horizon
         self.obs_wrap_angle = obs_wrap_angle
         self.rew_state_weight = np.array(rew_state_weight, ndmin=1, dtype=float)
+        self.Q = np.diag(self.rew_state_weight)
         self.rew_act_weight = np.array(rew_act_weight, ndmin=1, dtype=float)
+        self.R = np.diag(self.rew_act_weight)
         self.rew_exponential = rew_exponential
         self.done_on_out_of_bound = done_on_out_of_bound
         # BenchmarkEnv constructor, called after defining the custom args,
2 changes: 2 additions & 0 deletions safe_control_gym/envs/gym_pybullet_drones/quadrotor.py
@@ -178,7 +178,9 @@ def __init__(self,
         self.norm_act_scale = norm_act_scale
         self.obs_goal_horizon = obs_goal_horizon
         self.rew_state_weight = np.array(rew_state_weight, ndmin=1, dtype=float)
+        self.Q = np.diag(self.rew_state_weight)
         self.rew_act_weight = np.array(rew_act_weight, ndmin=1, dtype=float)
+        self.R = np.diag(self.rew_act_weight)
         self.rew_exponential = rew_exponential
         self.done_on_out_of_bound = done_on_out_of_bound
         if info_mse_metric_state_weight is None:
