Skip to content

Commit 5d254c5

Browse files
Minor updates
1 parent 99142ae commit 5d254c5

File tree

7 files changed

+137
-69
lines changed

7 files changed

+137
-69
lines changed

.gitignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@ examples/pid/*data/
88
#
99
experiments/mpsc/temp-data/
1010
experiments/mpsc/unsafe_rl_temp_data/
11+
experiments/mpsc/models/rl_models/
12+
experiments/mpsc/results*/
1113
#
1214
results/
1315
z_docstring.py
Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
algo: cpo
22
algo_config:
33
# Model args
4-
hidden1: 128
5-
hidden2: 128
4+
hidden1: 256
5+
hidden2: 256
66

77
# Optim args
88
discount_factor: 0.98
@@ -16,15 +16,15 @@ algo_config:
1616
cost_d: 0.0
1717

1818
# Runner args
19-
max_steps: 1000
20-
num_epochs: 4000
21-
value_epochs: 150
19+
max_steps: 2000
20+
num_epochs: 5000
21+
value_epochs: 300
2222
eval_batch_size: 20
2323

2424
# Misc
25-
log_interval: 40
25+
log_interval: 50
2626
save_interval: 0
2727
num_checkpoints: 0
28-
eval_interval: 40
28+
eval_interval: 50
2929
eval_save_best: True
3030
tensorboard: False

experiments/mpsc/plotting_results.py

Lines changed: 96 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,7 @@
1212
from safe_control_gym.safety_filters.mpsc.mpsc_utils import get_discrete_derivative, high_frequency_content
1313
from safe_control_gym.utils.plotting import load_from_logs
1414

15-
plot = False
16-
save_figs = True
15+
plot = True # Saves figure if False
1716

1817
U_EQs = {
1918
'cartpole': 0,
@@ -26,22 +25,23 @@
2625

2726

2827
def load_all_models(system, task, algo):
29-
'''Loads the results of every MPSC cost function for a specific experiment with every algo.
28+
'''Loads the results of every experiment.
3029
3130
Args:
32-
system (str): The system to be controlled.
33-
task (str): The task to be completed (either 'stab' or 'track').
31+
system (str): The system to be plotted.
32+
task (str): The task to be plotted (either 'stab' or 'track').
33+
algo (str): The controller to be plotted.
3434
3535
Returns:
3636
all_results (dict): A dictionary containing all the results.
3737
'''
3838

3939
all_results = {}
4040

41-
for model in os.listdir(f'./models/rl_models/{system}/{task}/{algo}/'):
41+
for model in ordered_models:
4242
all_results[model] = []
43-
for seed in os.listdir(f'./models/rl_models/{system}/{task}/{algo}/{model}/'):
44-
with open(f'./results_mpsc/{system}/{task}/{algo}/results_{system}_{task}_{algo}_{model}/{seed}.pkl', 'rb') as f:
43+
for seed in os.listdir(f'./results_mpsc/{system}/{task}/{algo}/results_{system}_{task}_{algo}_{model}/'):
44+
with open(f'./results_mpsc/{system}/{task}/{algo}/results_{system}_{task}_{algo}_{model}/{seed}', 'rb') as f:
4545
all_results[model].append(pickle.load(f))
4646
consolidate_multiple_seeds(all_results, model)
4747

@@ -497,21 +497,22 @@ def plot_model_comparisons(system, task, algo, data_extractor):
497497
'''Plots the constraint violations of every controller for a specific experiment.
498498
499499
Args:
500-
system (str): The system to be controlled.
501-
task (str): The task to be completed (either 'stab' or 'track').
502-
mpsc_cost_horizon (str): The cost horizon used by the smooth MPSC cost functions.
500+
system (str): The system to be plotted.
501+
task (str): The task to be plotted (either 'stab' or 'track').
502+
algo (str): The controller to be plotted.
503+
data_extractor (func): The function which extracts the desired data.
503504
'''
504505

505506
all_results = load_all_models(system, task, algo)
506507

507508
fig = plt.figure(figsize=(16.0, 10.0))
508509
ax = fig.add_subplot(111)
509510

510-
labels = sorted(os.listdir(f'./models/rl_models/{system}/{task}/{algo}/'))
511+
labels = ordered_models
511512

512513
data = []
513514

514-
for model in labels:
515+
for model in ordered_models:
515516
exp_data = all_results[model]
516517
data.append(data_extractor(exp_data))
517518

@@ -522,24 +523,71 @@ def plot_model_comparisons(system, task, algo, data_extractor):
522523
ax.set_xticks(x, labels, weight='bold', fontsize=15, rotation=30, ha='right')
523524

524525
medianprops = dict(linestyle='--', linewidth=2.5, color='black')
525-
bplot = ax.boxplot(data, patch_artist=True, labels=labels, medianprops=medianprops, widths=[0.75] * len(labels))
526-
527-
colors = {'mpsf_sr_pen_1': 'lightgreen', 'mpsf_sr_pen_10': 'limegreen', 'mpsf_sr_pen_100': 'forestgreen', 'mpsf_sr_pen_1000': 'darkgreen', 'none': 'cornflowerblue', 'none_cpen': 'plum'}
526+
bplot = ax.boxplot(data, patch_artist=True, labels=labels, medianprops=medianprops, widths=[0.75] * len(labels), showfliers=False)
528527

529528
for patch, color in zip(bplot['boxes'], colors.values()):
530529
patch.set_facecolor(color)
531530

532531
fig.tight_layout()
533532

534-
if data_extractor != extract_reward_cert:
535-
ax.set_ylim(ymin=0)
536533
ax.yaxis.grid(True)
537534

538535
if plot is True:
539536
plt.show()
540-
if save_figs:
537+
else:
541538
image_suffix = data_extractor.__name__.replace('extract_', '')
542-
fig.savefig(f'./results_mpsc/{system}/{task}/{algo}/graphs/{system}_{task}_{image_suffix}.png', dpi=300)
539+
fig.savefig(f'./results_mpsc/{image_suffix}.png', dpi=300)
540+
plt.close()
541+
542+
543+
def plot_step_time(system, task, algo):
544+
'''Plots the constraint violations of every controller for a specific experiment.
545+
546+
Args:
547+
system (str): The system to be plotted.
548+
task (str): The task to be plotted (either 'stab' or 'track').
549+
algo (str): The controller to be plotted.
550+
'''
551+
552+
all_results = {}
553+
for model in ordered_models:
554+
all_results[model] = []
555+
for seed in os.listdir(f'./models/rl_models/{system}/{task}/{algo}/{model}/'):
556+
all_results[model].append(load_from_logs(f'./models/rl_models/{system}/{task}/{algo}/{model}/{seed}/logs/'))
557+
558+
fig = plt.figure(figsize=(16.0, 10.0))
559+
ax = fig.add_subplot(111)
560+
561+
labels = ordered_models
562+
563+
data = []
564+
565+
for model in ordered_models:
566+
datum = np.array([values['stat/step_time'][3] for values in all_results[model]]).flatten()
567+
data.append(datum)
568+
569+
ylabel = 'Training Time per Step [ms]'
570+
ax.set_ylabel(ylabel, weight='bold', fontsize=45, labelpad=10)
571+
572+
x = np.arange(1, len(labels) + 1)
573+
ax.set_xticks(x, labels, weight='bold', fontsize=15, rotation=30, ha='right')
574+
575+
medianprops = dict(linestyle='--', linewidth=2.5, color='black')
576+
bplot = ax.boxplot(data, patch_artist=True, labels=labels, medianprops=medianprops, widths=[0.75] * len(labels), showfliers=False)
577+
578+
for patch, color in zip(bplot['boxes'], colors.values()):
579+
patch.set_facecolor(color)
580+
581+
fig.tight_layout()
582+
583+
ax.set_ylim(ymin=0)
584+
ax.yaxis.grid(True)
585+
586+
if plot is True:
587+
plt.show()
588+
else:
589+
image_suffix = 'step_time'
590+
fig.savefig(f'./results_mpsc/{image_suffix}.png', dpi=300)
543591
plt.close()
544592

545593

@@ -571,43 +619,40 @@ def plot_all_logs(system, task, algo):
571619
'''Plots comparative plots of all the logs.
572620
573621
Args:
574-
system (str): The system to be controlled.
575-
task (str): The task to be completed (either 'stab' or 'track').
576-
mpsc_cost_horizon (str): The cost horizon used by the smooth MPSC cost functions.
622+
system (str): The system to be plotted.
623+
task (str): The task to be plotted (either 'stab' or 'track').
624+
algo (str): The controller to be plotted.
577625
'''
578626
all_results = {}
579627

580-
for model in os.listdir(f'./models/rl_models/{system}/{task}/{algo}/'):
628+
for model in ordered_models:
581629
all_results[model] = []
582630
for seed in os.listdir(f'./models/rl_models/{system}/{task}/{algo}/{model}/'):
583631
all_results[model].append(load_from_logs(f'./models/rl_models/{system}/{task}/{algo}/{model}/{seed}/logs/'))
584632

585-
for key in all_results['none'][0].keys():
586-
plot_log(system, task, algo, key, all_results)
633+
for key in all_results[ordered_models[0]][0].keys():
634+
if key == 'stat_eval/ep_return':
635+
plot_log(key, all_results)
636+
if key == 'stat/constraint_violation':
637+
plot_log(key, all_results)
587638

588639

589-
def plot_log(system, task, algo, key, all_results):
640+
def plot_log(key, all_results):
590641
'''Plots a comparative plot of the log 'key'.
591642
592643
Args:
593-
system (str): The system to be controlled.
594-
task (str): The task to be completed (either 'stab' or 'track').
595-
mpsc_cost_horizon (str): The cost horizon used by the smooth MPSC cost functions.
596644
key (str): The name of the log to be plotted.
597645
all_results (dict): A dictionary of all the logged results for all models.
598646
'''
599647
fig = plt.figure(figsize=(16.0, 10.0))
600648
ax = fig.add_subplot(111)
601649

602-
labels = sorted(all_results.keys())
603-
labels = [label for label in labels if '_es' not in label]
650+
labels = ordered_models
604651

605-
colors = {'mpsf_sr_pen_1': 'lightgreen', 'mpsf_sr_pen_10': 'limegreen', 'mpsf_sr_pen_100': 'forestgreen', 'mpsf_sr_pen_1000': 'darkgreen', 'none': 'cornflowerblue', 'none_cpen': 'plum'}
606-
607-
for model in labels:
652+
for model, label in zip(ordered_models, labels):
608653
x = all_results[model][0][key][1] / 1000
609654
all_data = np.array([values[key][3] for values in all_results[model]])
610-
ax.plot(x, np.mean(all_data, axis=0), label=model, color=colors[model])
655+
ax.plot(x, np.mean(all_data, axis=0), label=label, color=colors[model])
611656
ax.fill_between(x, np.min(all_data, axis=0), np.max(all_data, axis=0), alpha=0.3, edgecolor=colors[model], facecolor=colors[model])
612657

613658
ax.set_ylabel(key, weight='bold', fontsize=45, labelpad=10)
@@ -619,14 +664,25 @@ def plot_log(system, task, algo, key, all_results):
619664

620665
if plot is True:
621666
plt.show()
622-
if save_figs:
667+
else:
623668
image_suffix = key.replace('/', '__')
624-
fig.savefig(f'./results_mpsc/{system}/{task}/{algo}/graphs/{system}_{task}_{image_suffix}.png', dpi=300)
669+
fig.savefig(f'./results_mpsc/{image_suffix}.png', dpi=300)
625670
plt.close()
626671

627672

628673
if __name__ == '__main__':
629-
ordered_costs = ['one_step', 'regularized', 'precomputed']
674+
ordered_models = ['none', 'none_cpen_0.01', 'none_cpen_0.1', 'none_cpen_1', 'mpsf_sr_pen_0.1', 'mpsf_sr_pen_1', 'mpsf_sr_pen_10', 'mpsf_sr_pen_100']
675+
676+
colors = {
677+
'none': 'cornflowerblue',
678+
'none_cpen_0.01': 'plum',
679+
'none_cpen_0.1': 'mediumorchid',
680+
'none_cpen_1': 'darkorchid',
681+
'mpsf_sr_pen_0.1': 'lightgreen',
682+
'mpsf_sr_pen_1': 'limegreen',
683+
'mpsf_sr_pen_10': 'forestgreen',
684+
'mpsf_sr_pen_100': 'darkgreen',
685+
}
630686

631687
def extract_rate_of_change_of_inputs(results_data, certified=True):
632688
return extract_rate_of_change(results_data, certified, order=1, mode='input')
@@ -682,6 +738,7 @@ def extract_length_uncert(results_data, certified=False):
682738
algo_name = sys.argv[3]
683739

684740
plot_all_logs(system_name, task_name, algo_name)
741+
plot_step_time(system_name, task_name, algo_name)
685742
plot_model_comparisons(system_name, task_name, algo_name, extract_magnitude_of_corrections)
686743
plot_model_comparisons(system_name, task_name, algo_name, extract_percent_magnitude_of_corrections)
687744
plot_model_comparisons(system_name, task_name, algo_name, extract_max_correction)

experiments/mpsc/train_all_models.sh

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,21 @@
22
for SYS in quadrotor_3D; do
33
for ALGO in ppo; do
44
for TASK in track; do
5-
for SEED in 42 62 821 99 4077; do # 1102 1014 14 960406 2031; do
6-
sbatch train_model.sbatch mpsf True True $SYS $TASK $ALGO False 1 $SEED #mpsf_sr_pen_1
7-
sbatch train_model.sbatch mpsf True True $SYS $TASK $ALGO False 10 $SEED #mpsf_sr_pen_10
8-
sbatch train_model.sbatch mpsf True True $SYS $TASK $ALGO False 100 $SEED #mpsf_sr_pen_100
9-
sbatch train_model.sbatch mpsf True True $SYS $TASK $ALGO False 1000 $SEED #mpsf_sr_pen_1000
10-
sbatch train_model.sbatch none False False $SYS $TASK $ALGO False False $SEED #none
11-
sbatch train_model.sbatch none False False $SYS $TASK $ALGO True False $SEED #none_cpen
5+
for SEED in 42 62 821 99 4077; do
6+
# MPSF Ablation
7+
./train_model.sbatch none False False $SYS $TASK $ALGO False False $SEED #none
8+
./train_model.sbatch none False True $SYS $TASK $ALGO False 1 $SEED #none_pen_1
9+
./train_model.sbatch none True False $SYS $TASK $ALGO False False $SEED #none_sr
10+
./train_model.sbatch none True True $SYS $TASK $ALGO False 1 $SEED #none_sr_pen_1
11+
./train_model.sbatch mpsf False False $SYS $TASK $ALGO False False $SEED #mpsf
12+
./train_model.sbatch mpsf False True $SYS $TASK $ALGO False 1 $SEED #mpsf_pen_1
13+
./train_model.sbatch mpsf True False $SYS $TASK $ALGO False False $SEED #mpsf_sr
14+
./train_model.sbatch mpsf True True $SYS $TASK $ALGO False 1 $SEED #mpsf_sr_pen_1
15+
16+
# Constr Pen
17+
./train_model.sbatch none False False $SYS $TASK $ALGO True 0.01 $SEED #none_cpen_0.01
18+
./train_model.sbatch none False False $SYS $TASK $ALGO True 0.1 $SEED #none_cpen_0.1
19+
./train_model.sbatch none False False $SYS $TASK $ALGO True 1 $SEED #none_cpen_1
1220
done
1321
done
1422
done

experiments/mpsc/train_model.sbatch

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,8 +70,10 @@ fi
7070

7171
if [ "$8" = False ]; then
7272
SF_PEN_TAG=''
73+
CONSTR_PEN_VAL=0
7374
else
7475
SF_PEN_TAG="_$8"
76+
CONSTR_PEN_VAL=$8
7577
fi
7678

7779
if [ -z "$9" ]; then
@@ -103,6 +105,7 @@ python3 train_rl.py \
103105
--kv_overrides \
104106
task_config.init_state=None \
105107
task_config.use_constraint_penalty=${CONSTR_PEN} \
108+
task_config.constraint_penalty=${CONSTR_PEN_VAL} \
106109
sf_config.cost_function=${MPSC_COST} \
107110
sf_config.mpsc_cost_horizon=${MPSC_COST_HORIZON} \
108111
sf_config.decay_factor=${DECAY_FACTOR} \
@@ -116,4 +119,3 @@ python3 train_rl.py \
116119
sf_config.seed=${SEED} \
117120

118121
./mpsc_experiment.sh $TAG $SYS $TASK $ALGO $SEED
119-
# python plotting_results.py $SYS $TASK $ALGO

0 commit comments

Comments
 (0)