diff --git a/examples/lqr/lqr_experiment.py b/examples/lqr/lqr_experiment.py index 4b6b46d86..faff7934b 100644 --- a/examples/lqr/lqr_experiment.py +++ b/examples/lqr/lqr_experiment.py @@ -15,11 +15,12 @@ from safe_control_gym.utils.registration import make -def run(gui=True, n_episodes=1, n_steps=None, save_data=False): +def run(gui=False, plot=True, n_episodes=2, n_steps=None, save_data=False): '''The main function running LQR and iLQR experiments. Args: - gui (bool): Whether to display the gui and plot graphs. + gui (bool): Whether to display the gui. + plot (bool): Whether to plot graphs. n_episodes (int): The number of episodes to execute. n_steps (int): The total number of steps to execute. save_data (bool): Whether to save the collected experiment data. @@ -61,7 +62,7 @@ def run(gui=True, n_episodes=1, n_steps=None, save_data=False): else: trajs_data, _ = experiment.run_evaluation(training=True, n_steps=n_steps) - if gui: + if plot: post_analysis(trajs_data['obs'][0], trajs_data['action'][0], ctrl.env) # Close environments @@ -132,20 +133,5 @@ def post_analysis(state_stack, input_stack, env): plt.show() -def wrap2pi_vec(angle_vec): - '''Wraps a vector of angles between -pi and pi. - - Args: - angle_vec (ndarray): A vector of angles. - ''' - for k, angle in enumerate(angle_vec): - while angle > np.pi: - angle -= np.pi - while angle <= -np.pi: - angle += np.pi - angle_vec[k] = angle - return angle_vec - - if __name__ == '__main__': run() diff --git a/examples/mpc/mpc_experiment.py b/examples/mpc/mpc_experiment.py index 27b8874b7..139ad8032 100644 --- a/examples/mpc/mpc_experiment.py +++ b/examples/mpc/mpc_experiment.py @@ -2,7 +2,6 @@ import os import pickle -from collections import defaultdict from functools import partial import matplotlib.pyplot as plt @@ -15,11 +14,12 @@ from safe_control_gym.utils.registration import make -def run(gui=True, n_episodes=1, n_steps=None, save_data=False): +def run(gui=False, plot=True, n_episodes=2, n_steps=None, save_data=False): '''The main function running MPC and Linear MPC experiments. Args: - gui (bool): Whether to display the gui and plot graphs. + gui (bool): Whether to display the gui. + plot (bool): Whether to plot graphs. n_episodes (int): The number of episodes to execute. n_steps (int): The total number of steps to execute. save_data (bool): Whether to save the collected experiment data. @@ -34,7 +34,7 @@ def run(gui=True, n_episodes=1, n_steps=None, save_data=False): config.task, **config.task_config ) - random_env = env_func(gui=False) + env = env_func(gui=gui) # Create controller. ctrl = make(config.algo, @@ -42,43 +42,19 @@ def run(gui=True, n_episodes=1, n_steps=None, save_data=False): **config.algo_config ) - all_trajs = defaultdict(list) - n_episodes = 1 if n_episodes is None else n_episodes - # Run the experiment. - for _ in range(n_episodes): - # Get initial state and create environments - init_state, _ = random_env.reset() - static_env = env_func(gui=gui, randomized_init=False, init_state=init_state) - static_train_env = env_func(gui=False, randomized_init=False, init_state=init_state) - - # Create experiment, train, and run evaluation - experiment = BaseExperiment(env=static_env, ctrl=ctrl, train_env=static_train_env) - experiment.launch_training() - - if n_steps is None: - trajs_data, _ = experiment.run_evaluation(training=True, n_episodes=1) - else: - trajs_data, _ = experiment.run_evaluation(training=True, n_steps=n_steps) - - if gui: - post_analysis(trajs_data['obs'][0], trajs_data['action'][0], ctrl.env) - - # Close environments - static_env.close() - static_train_env.close() + experiment = BaseExperiment(env=env, ctrl=ctrl) + trajs_data, metrics = experiment.run_evaluation(training=True, n_episodes=n_episodes, n_steps=n_steps) - # Merge in new trajectory data - for key, value in trajs_data.items(): - all_trajs[key] += value + if plot: + for i in range(len(trajs_data['obs'])): + post_analysis(trajs_data['obs'][i], trajs_data['action'][i], ctrl.env) ctrl.close() - random_env.close() - metrics = experiment.compute_metrics(all_trajs) - all_trajs = dict(all_trajs) + env.close() if save_data: - results = {'trajs_data': all_trajs, 'metrics': metrics} + results = {'trajs_data': trajs_data, 'metrics': metrics} path_dir = os.path.dirname('./temp-data/') os.makedirs(path_dir, exist_ok=True) with open(f'./temp-data/{config.algo}_data_{config.task}_{config.task_config.task}.pkl', 'wb') as file: @@ -132,20 +108,5 @@ def post_analysis(state_stack, input_stack, env): plt.show() -def wrap2pi_vec(angle_vec): - '''Wraps a vector of angles between -pi and pi. - - Args: - angle_vec (ndarray): A vector of angles. - ''' - for k, angle in enumerate(angle_vec): - while angle > np.pi: - angle -= np.pi - while angle <= -np.pi: - angle += np.pi - angle_vec[k] = angle - return angle_vec - - if __name__ == '__main__': run()