Skip to content

Commit 99142ae

Browse files
Linting
1 parent 53143d2 commit 99142ae

File tree

37 files changed

+7
-56
lines changed

37 files changed

+7
-56
lines changed

experiments/mpsc/mpsc_experiment.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -163,7 +163,7 @@ def determine_feasible_starting_points(num_points=100):
163163
# Define arguments.
164164
fac = ConfigFactory()
165165
config = fac.merge()
166-
config.sf_config.cost_function='one_step_cost'
166+
config.sf_config.cost_function = 'one_step_cost'
167167

168168
task = 'stab' if config.task_config.task == Task.STABILIZATION else 'track'
169169
if config.task == Environment.QUADROTOR:

experiments/mpsc/plotting_results.py

Lines changed: 3 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,6 @@
1414

1515
plot = False
1616
save_figs = True
17-
ordered_algos = ['lqr', 'ppo', 'sac']
18-
# ordered_algos = ['lqr', 'pid', 'ppo', 'sac']
1917

2018
U_EQs = {
2119
'cartpole': 0,
@@ -27,51 +25,6 @@
2725
met.verbose = False
2826

2927

30-
def load_one_experiment(system, task, algo, mpsc_cost_horizon):
31-
'''Loads the results of every MPSC cost function for a specific experiment.
32-
33-
Args:
34-
system (str): The system to be controlled.
35-
task (str): The task to be completed (either 'stab' or 'track').
36-
algo (str): The controller being used.
37-
mpsc_cost_horizon (str): The cost horizon used by the smooth MPSC cost functions.
38-
39-
Returns:
40-
all_results (dict): A dictionary containing all the results.
41-
'''
42-
43-
all_results = {}
44-
45-
for cost in ordered_costs:
46-
with open(f'./results_mpsc/{system}/{task}/m{mpsc_cost_horizon}/results_{system}_{task}_{algo}_{cost}_cost_m{mpsc_cost_horizon}.pkl', 'rb') as f:
47-
all_results[cost] = pickle.load(f)
48-
49-
return all_results
50-
51-
52-
def load_all_algos(system, task, mpsc_cost_horizon):
53-
'''Loads the results of every MPSC cost function for a specific experiment with every algo.
54-
55-
Args:
56-
system (str): The system to be controlled.
57-
task (str): The task to be completed (either 'stab' or 'track').
58-
mpsc_cost_horizon (str): The cost horizon used by the smooth MPSC cost functions.
59-
60-
Returns:
61-
all_results (dict): A dictionary containing all the results.
62-
'''
63-
64-
all_results = {}
65-
66-
for algo in ['lqr', 'pid', 'ppo', 'sac']:
67-
if system == 'cartpole' and algo == 'pid':
68-
continue
69-
70-
all_results[algo] = load_one_experiment(system, task, algo, mpsc_cost_horizon)
71-
72-
return all_results
73-
74-
7528
def load_all_models(system, task, algo):
7629
'''Loads the results of every MPSC cost function for a specific experiment with every algo.
7730
@@ -566,7 +519,7 @@ def plot_model_comparisons(system, task, algo, data_extractor):
566519
ax.set_ylabel(ylabel, weight='bold', fontsize=45, labelpad=10)
567520

568521
x = np.arange(1, len(labels) + 1)
569-
ax.set_xticks(x, labels, weight='bold', fontsize=15, rotation=45, ha='right')
522+
ax.set_xticks(x, labels, weight='bold', fontsize=15, rotation=30, ha='right')
570523

571524
medianprops = dict(linestyle='--', linewidth=2.5, color='black')
572525
bplot = ax.boxplot(data, patch_artist=True, labels=labels, medianprops=medianprops, widths=[0.75] * len(labels))
@@ -629,9 +582,6 @@ def plot_all_logs(system, task, algo):
629582
for seed in os.listdir(f'./models/rl_models/{system}/{task}/{algo}/{model}/'):
630583
all_results[model].append(load_from_logs(f'./models/rl_models/{system}/{task}/{algo}/{model}/{seed}/logs/'))
631584

632-
# all_results['safe_ppo'] = load_from_logs(f'./models/rl_models/{system}/{task}/safe_explorer_ppo/none/logs/')
633-
# all_results['cpo'] = load_from_logs(f'./models/rl_models/{system}/{task}/cpo/none/logs/')
634-
635585
for key in all_results['none'][0].keys():
636586
plot_log(system, task, algo, key, all_results)
637587

@@ -655,12 +605,13 @@ def plot_log(system, task, algo, key, all_results):
655605
colors = {'mpsf_sr_pen_1': 'lightgreen', 'mpsf_sr_pen_10': 'limegreen', 'mpsf_sr_pen_100': 'forestgreen', 'mpsf_sr_pen_1000': 'darkgreen', 'none': 'cornflowerblue', 'none_cpen': 'plum'}
656606

657607
for model in labels:
658-
x = all_results[model][0][key][1]
608+
x = all_results[model][0][key][1] / 1000
659609
all_data = np.array([values[key][3] for values in all_results[model]])
660610
ax.plot(x, np.mean(all_data, axis=0), label=model, color=colors[model])
661611
ax.fill_between(x, np.min(all_data, axis=0), np.max(all_data, axis=0), alpha=0.3, edgecolor=colors[model], facecolor=colors[model])
662612

663613
ax.set_ylabel(key, weight='bold', fontsize=45, labelpad=10)
614+
ax.set_xlabel('Training Episodes')
664615
ax.legend()
665616

666617
fig.tight_layout()
22.2 KB
Loading
22.9 KB
Loading
20.9 KB
Loading
21 KB
Loading
21.7 KB
Loading
21.1 KB
Loading
22 KB
Loading
6.56 KB
Loading

0 commit comments

Comments
 (0)