14
14
15
15
plot = False
16
16
save_figs = True
17
- ordered_algos = ['lqr' , 'ppo' , 'sac' ]
18
- # ordered_algos = ['lqr', 'pid', 'ppo', 'sac']
19
17
20
18
U_EQs = {
21
19
'cartpole' : 0 ,
27
25
met .verbose = False
28
26
29
27
30
- def load_one_experiment (system , task , algo , mpsc_cost_horizon ):
31
- '''Loads the results of every MPSC cost function for a specific experiment.
32
-
33
- Args:
34
- system (str): The system to be controlled.
35
- task (str): The task to be completed (either 'stab' or 'track').
36
- algo (str): The controller being used.
37
- mpsc_cost_horizon (str): The cost horizon used by the smooth MPSC cost functions.
38
-
39
- Returns:
40
- all_results (dict): A dictionary containing all the results.
41
- '''
42
-
43
- all_results = {}
44
-
45
- for cost in ordered_costs :
46
- with open (f'./results_mpsc/{ system } /{ task } /m{ mpsc_cost_horizon } /results_{ system } _{ task } _{ algo } _{ cost } _cost_m{ mpsc_cost_horizon } .pkl' , 'rb' ) as f :
47
- all_results [cost ] = pickle .load (f )
48
-
49
- return all_results
50
-
51
-
52
- def load_all_algos (system , task , mpsc_cost_horizon ):
53
- '''Loads the results of every MPSC cost function for a specific experiment with every algo.
54
-
55
- Args:
56
- system (str): The system to be controlled.
57
- task (str): The task to be completed (either 'stab' or 'track').
58
- mpsc_cost_horizon (str): The cost horizon used by the smooth MPSC cost functions.
59
-
60
- Returns:
61
- all_results (dict): A dictionary containing all the results.
62
- '''
63
-
64
- all_results = {}
65
-
66
- for algo in ['lqr' , 'pid' , 'ppo' , 'sac' ]:
67
- if system == 'cartpole' and algo == 'pid' :
68
- continue
69
-
70
- all_results [algo ] = load_one_experiment (system , task , algo , mpsc_cost_horizon )
71
-
72
- return all_results
73
-
74
-
75
28
def load_all_models (system , task , algo ):
76
29
'''Loads the results of every MPSC cost function for a specific experiment with every algo.
77
30
@@ -566,7 +519,7 @@ def plot_model_comparisons(system, task, algo, data_extractor):
566
519
ax .set_ylabel (ylabel , weight = 'bold' , fontsize = 45 , labelpad = 10 )
567
520
568
521
x = np .arange (1 , len (labels ) + 1 )
569
- ax .set_xticks (x , labels , weight = 'bold' , fontsize = 15 , rotation = 45 , ha = 'right' )
522
+ ax .set_xticks (x , labels , weight = 'bold' , fontsize = 15 , rotation = 30 , ha = 'right' )
570
523
571
524
medianprops = dict (linestyle = '--' , linewidth = 2.5 , color = 'black' )
572
525
bplot = ax .boxplot (data , patch_artist = True , labels = labels , medianprops = medianprops , widths = [0.75 ] * len (labels ))
@@ -629,9 +582,6 @@ def plot_all_logs(system, task, algo):
629
582
for seed in os .listdir (f'./models/rl_models/{ system } /{ task } /{ algo } /{ model } /' ):
630
583
all_results [model ].append (load_from_logs (f'./models/rl_models/{ system } /{ task } /{ algo } /{ model } /{ seed } /logs/' ))
631
584
632
- # all_results['safe_ppo'] = load_from_logs(f'./models/rl_models/{system}/{task}/safe_explorer_ppo/none/logs/')
633
- # all_results['cpo'] = load_from_logs(f'./models/rl_models/{system}/{task}/cpo/none/logs/')
634
-
635
585
for key in all_results ['none' ][0 ].keys ():
636
586
plot_log (system , task , algo , key , all_results )
637
587
@@ -655,12 +605,13 @@ def plot_log(system, task, algo, key, all_results):
655
605
colors = {'mpsf_sr_pen_1' : 'lightgreen' , 'mpsf_sr_pen_10' : 'limegreen' , 'mpsf_sr_pen_100' : 'forestgreen' , 'mpsf_sr_pen_1000' : 'darkgreen' , 'none' : 'cornflowerblue' , 'none_cpen' : 'plum' }
656
606
657
607
for model in labels :
658
- x = all_results [model ][0 ][key ][1 ]
608
+ x = all_results [model ][0 ][key ][1 ] / 1000
659
609
all_data = np .array ([values [key ][3 ] for values in all_results [model ]])
660
610
ax .plot (x , np .mean (all_data , axis = 0 ), label = model , color = colors [model ])
661
611
ax .fill_between (x , np .min (all_data , axis = 0 ), np .max (all_data , axis = 0 ), alpha = 0.3 , edgecolor = colors [model ], facecolor = colors [model ])
662
612
663
613
ax .set_ylabel (key , weight = 'bold' , fontsize = 45 , labelpad = 10 )
614
+ ax .set_xlabel ('Training Episodes' )
664
615
ax .legend ()
665
616
666
617
fig .tight_layout ()
0 commit comments