12
12
from safe_control_gym .safety_filters .mpsc .mpsc_utils import get_discrete_derivative , high_frequency_content
13
13
from safe_control_gym .utils .plotting import load_from_logs
14
14
15
- plot = False
16
- save_figs = True
15
+ plot = True # Saves figure if False
17
16
18
17
U_EQs = {
19
18
'cartpole' : 0 ,
26
25
27
26
28
27
def load_all_models (system , task , algo ):
29
- '''Loads the results of every MPSC cost function for a specific experiment with every algo .
28
+ '''Loads the results of every experiment.
30
29
31
30
Args:
32
- system (str): The system to be controlled.
33
- task (str): The task to be completed (either 'stab' or 'track').
31
+ system (str): The system to be plotted.
32
+ task (str): The task to be plotted (either 'stab' or 'track').
33
+ algo (str): The controller to be plotted.
34
34
35
35
Returns:
36
36
all_results (dict): A dictionary containing all the results.
37
37
'''
38
38
39
39
all_results = {}
40
40
41
- for model in os . listdir ( f'./models/rl_models/ { system } / { task } / { algo } /' ) :
41
+ for model in ordered_models :
42
42
all_results [model ] = []
43
- for seed in os .listdir (f'./models/rl_models/ { system } /{ task } /{ algo } /{ model } /' ):
44
- with open (f'./results_mpsc/{ system } /{ task } /{ algo } /results_{ system } _{ task } _{ algo } _{ model } /{ seed } .pkl ' , 'rb' ) as f :
43
+ for seed in os .listdir (f'./results_mpsc/ { system } /{ task } /{ algo } /results_ { system } _ { task } _ { algo } _ { model } /' ):
44
+ with open (f'./results_mpsc/{ system } /{ task } /{ algo } /results_{ system } _{ task } _{ algo } _{ model } /{ seed } ' , 'rb' ) as f :
45
45
all_results [model ].append (pickle .load (f ))
46
46
consolidate_multiple_seeds (all_results , model )
47
47
@@ -497,21 +497,22 @@ def plot_model_comparisons(system, task, algo, data_extractor):
497
497
'''Plots the constraint violations of every controller for a specific experiment.
498
498
499
499
Args:
500
- system (str): The system to be controlled.
501
- task (str): The task to be completed (either 'stab' or 'track').
502
- mpsc_cost_horizon (str): The cost horizon used by the smooth MPSC cost functions.
500
+ system (str): The system to be plotted.
501
+ task (str): The task to be plotted (either 'stab' or 'track').
502
+ algo (str): The controller to be plotted.
503
+ data_extractor (func): The function which extracts the desired data.
503
504
'''
504
505
505
506
all_results = load_all_models (system , task , algo )
506
507
507
508
fig = plt .figure (figsize = (16.0 , 10.0 ))
508
509
ax = fig .add_subplot (111 )
509
510
510
- labels = sorted ( os . listdir ( f'./models/rl_models/ { system } / { task } / { algo } /' ))
511
+ labels = ordered_models
511
512
512
513
data = []
513
514
514
- for model in labels :
515
+ for model in ordered_models :
515
516
exp_data = all_results [model ]
516
517
data .append (data_extractor (exp_data ))
517
518
@@ -522,24 +523,71 @@ def plot_model_comparisons(system, task, algo, data_extractor):
522
523
ax .set_xticks (x , labels , weight = 'bold' , fontsize = 15 , rotation = 30 , ha = 'right' )
523
524
524
525
medianprops = dict (linestyle = '--' , linewidth = 2.5 , color = 'black' )
525
- bplot = ax .boxplot (data , patch_artist = True , labels = labels , medianprops = medianprops , widths = [0.75 ] * len (labels ))
526
-
527
- colors = {'mpsf_sr_pen_1' : 'lightgreen' , 'mpsf_sr_pen_10' : 'limegreen' , 'mpsf_sr_pen_100' : 'forestgreen' , 'mpsf_sr_pen_1000' : 'darkgreen' , 'none' : 'cornflowerblue' , 'none_cpen' : 'plum' }
526
+ bplot = ax .boxplot (data , patch_artist = True , labels = labels , medianprops = medianprops , widths = [0.75 ] * len (labels ), showfliers = False )
528
527
529
528
for patch , color in zip (bplot ['boxes' ], colors .values ()):
530
529
patch .set_facecolor (color )
531
530
532
531
fig .tight_layout ()
533
532
534
- if data_extractor != extract_reward_cert :
535
- ax .set_ylim (ymin = 0 )
536
533
ax .yaxis .grid (True )
537
534
538
535
if plot is True :
539
536
plt .show ()
540
- if save_figs :
537
+ else :
541
538
image_suffix = data_extractor .__name__ .replace ('extract_' , '' )
542
- fig .savefig (f'./results_mpsc/{ system } /{ task } /{ algo } /graphs/{ system } _{ task } _{ image_suffix } .png' , dpi = 300 )
539
+ fig .savefig (f'./results_mpsc/{ image_suffix } .png' , dpi = 300 )
540
+ plt .close ()
541
+
542
+
543
+ def plot_step_time (system , task , algo ):
544
+ '''Plots the constraint violations of every controller for a specific experiment.
545
+
546
+ Args:
547
+ system (str): The system to be plotted.
548
+ task (str): The task to be plotted (either 'stab' or 'track').
549
+ algo (str): The controller to be plotted.
550
+ '''
551
+
552
+ all_results = {}
553
+ for model in ordered_models :
554
+ all_results [model ] = []
555
+ for seed in os .listdir (f'./models/rl_models/{ system } /{ task } /{ algo } /{ model } /' ):
556
+ all_results [model ].append (load_from_logs (f'./models/rl_models/{ system } /{ task } /{ algo } /{ model } /{ seed } /logs/' ))
557
+
558
+ fig = plt .figure (figsize = (16.0 , 10.0 ))
559
+ ax = fig .add_subplot (111 )
560
+
561
+ labels = ordered_models
562
+
563
+ data = []
564
+
565
+ for model in ordered_models :
566
+ datum = np .array ([values ['stat/step_time' ][3 ] for values in all_results [model ]]).flatten ()
567
+ data .append (datum )
568
+
569
+ ylabel = 'Training Time per Step [ms]'
570
+ ax .set_ylabel (ylabel , weight = 'bold' , fontsize = 45 , labelpad = 10 )
571
+
572
+ x = np .arange (1 , len (labels ) + 1 )
573
+ ax .set_xticks (x , labels , weight = 'bold' , fontsize = 15 , rotation = 30 , ha = 'right' )
574
+
575
+ medianprops = dict (linestyle = '--' , linewidth = 2.5 , color = 'black' )
576
+ bplot = ax .boxplot (data , patch_artist = True , labels = labels , medianprops = medianprops , widths = [0.75 ] * len (labels ), showfliers = False )
577
+
578
+ for patch , color in zip (bplot ['boxes' ], colors .values ()):
579
+ patch .set_facecolor (color )
580
+
581
+ fig .tight_layout ()
582
+
583
+ ax .set_ylim (ymin = 0 )
584
+ ax .yaxis .grid (True )
585
+
586
+ if plot is True :
587
+ plt .show ()
588
+ else :
589
+ image_suffix = 'step_time'
590
+ fig .savefig (f'./results_mpsc/{ image_suffix } .png' , dpi = 300 )
543
591
plt .close ()
544
592
545
593
@@ -571,43 +619,40 @@ def plot_all_logs(system, task, algo):
571
619
'''Plots comparative plots of all the logs.
572
620
573
621
Args:
574
- system (str): The system to be controlled .
575
- task (str): The task to be completed (either 'stab' or 'track').
576
- mpsc_cost_horizon (str): The cost horizon used by the smooth MPSC cost functions .
622
+ system (str): The system to be plotted .
623
+ task (str): The task to be plotted (either 'stab' or 'track').
624
+ algo (str): The controller to be plotted .
577
625
'''
578
626
all_results = {}
579
627
580
- for model in os . listdir ( f'./models/rl_models/ { system } / { task } / { algo } /' ) :
628
+ for model in ordered_models :
581
629
all_results [model ] = []
582
630
for seed in os .listdir (f'./models/rl_models/{ system } /{ task } /{ algo } /{ model } /' ):
583
631
all_results [model ].append (load_from_logs (f'./models/rl_models/{ system } /{ task } /{ algo } /{ model } /{ seed } /logs/' ))
584
632
585
- for key in all_results ['none' ][0 ].keys ():
586
- plot_log (system , task , algo , key , all_results )
633
+ for key in all_results [ordered_models [0 ]][0 ].keys ():
634
+ if key == 'stat_eval/ep_return' :
635
+ plot_log (key , all_results )
636
+ if key == 'stat/constraint_violation' :
637
+ plot_log (key , all_results )
587
638
588
639
589
- def plot_log (system , task , algo , key , all_results ):
640
+ def plot_log (key , all_results ):
590
641
'''Plots a comparative plot of the log 'key'.
591
642
592
643
Args:
593
- system (str): The system to be controlled.
594
- task (str): The task to be completed (either 'stab' or 'track').
595
- mpsc_cost_horizon (str): The cost horizon used by the smooth MPSC cost functions.
596
644
key (str): The name of the log to be plotted.
597
645
all_results (dict): A dictionary of all the logged results for all models.
598
646
'''
599
647
fig = plt .figure (figsize = (16.0 , 10.0 ))
600
648
ax = fig .add_subplot (111 )
601
649
602
- labels = sorted (all_results .keys ())
603
- labels = [label for label in labels if '_es' not in label ]
650
+ labels = ordered_models
604
651
605
- colors = {'mpsf_sr_pen_1' : 'lightgreen' , 'mpsf_sr_pen_10' : 'limegreen' , 'mpsf_sr_pen_100' : 'forestgreen' , 'mpsf_sr_pen_1000' : 'darkgreen' , 'none' : 'cornflowerblue' , 'none_cpen' : 'plum' }
606
-
607
- for model in labels :
652
+ for model , label in zip (ordered_models , labels ):
608
653
x = all_results [model ][0 ][key ][1 ] / 1000
609
654
all_data = np .array ([values [key ][3 ] for values in all_results [model ]])
610
- ax .plot (x , np .mean (all_data , axis = 0 ), label = model , color = colors [model ])
655
+ ax .plot (x , np .mean (all_data , axis = 0 ), label = label , color = colors [model ])
611
656
ax .fill_between (x , np .min (all_data , axis = 0 ), np .max (all_data , axis = 0 ), alpha = 0.3 , edgecolor = colors [model ], facecolor = colors [model ])
612
657
613
658
ax .set_ylabel (key , weight = 'bold' , fontsize = 45 , labelpad = 10 )
@@ -619,14 +664,25 @@ def plot_log(system, task, algo, key, all_results):
619
664
620
665
if plot is True :
621
666
plt .show ()
622
- if save_figs :
667
+ else :
623
668
image_suffix = key .replace ('/' , '__' )
624
- fig .savefig (f'./results_mpsc/{ system } / { task } / { algo } /graphs/ { system } _ { task } _ { image_suffix } .png' , dpi = 300 )
669
+ fig .savefig (f'./results_mpsc/{ image_suffix } .png' , dpi = 300 )
625
670
plt .close ()
626
671
627
672
628
673
if __name__ == '__main__' :
629
- ordered_costs = ['one_step' , 'regularized' , 'precomputed' ]
674
+ ordered_models = ['none' , 'none_cpen_0.01' , 'none_cpen_0.1' , 'none_cpen_1' , 'mpsf_sr_pen_0.1' , 'mpsf_sr_pen_1' , 'mpsf_sr_pen_10' , 'mpsf_sr_pen_100' ]
675
+
676
+ colors = {
677
+ 'none' : 'cornflowerblue' ,
678
+ 'none_cpen_0.01' : 'plum' ,
679
+ 'none_cpen_0.1' : 'mediumorchid' ,
680
+ 'none_cpen_1' : 'darkorchid' ,
681
+ 'mpsf_sr_pen_0.1' : 'lightgreen' ,
682
+ 'mpsf_sr_pen_1' : 'limegreen' ,
683
+ 'mpsf_sr_pen_10' : 'forestgreen' ,
684
+ 'mpsf_sr_pen_100' : 'darkgreen' ,
685
+ }
630
686
631
687
def extract_rate_of_change_of_inputs (results_data , certified = True ):
632
688
return extract_rate_of_change (results_data , certified , order = 1 , mode = 'input' )
@@ -682,6 +738,7 @@ def extract_length_uncert(results_data, certified=False):
682
738
algo_name = sys .argv [3 ]
683
739
684
740
plot_all_logs (system_name , task_name , algo_name )
741
+ plot_step_time (system_name , task_name , algo_name )
685
742
plot_model_comparisons (system_name , task_name , algo_name , extract_magnitude_of_corrections )
686
743
plot_model_comparisons (system_name , task_name , algo_name , extract_percent_magnitude_of_corrections )
687
744
plot_model_comparisons (system_name , task_name , algo_name , extract_max_correction )
0 commit comments