Skip to content

Commit 49c4ad5

Browse files
author
pizarrob
committed
Merging in CPO code
1 parent c21bf22 commit 49c4ad5

File tree

473 files changed

+22305
-14
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

473 files changed

+22305
-14
lines changed

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ repos:
1515
- id: check-yaml
1616
- id: check-toml
1717
- id: check-added-large-files
18-
args: ['--maxkb=10000']
18+
args: ['--maxkb=15000']
1919
- id: check-docstring-first
2020
- id: check-executables-have-shebangs
2121
- id: check-shebang-scripts-are-executable
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
algo: cpo
2+
algo_config:
3+
# Model args
4+
hidden1: 64
5+
hidden2: 64
6+
7+
# Optim args
8+
discount_factor: 0.98
9+
v_lr: 2.0e-4
10+
cost_v_lr: 2.0e-4
11+
num_conjugate: 10
12+
line_decay: 0.8
13+
max_kl: 0.001
14+
damping_coeff: 0.01
15+
gae_coeff: 0.97
16+
cost_d: 0.0
17+
18+
# Runner args
19+
max_steps: 600
20+
num_epochs: 4000
21+
value_epochs: 100
22+
eval_batch_size: 20
23+
24+
# Misc
25+
log_interval: 40
26+
save_interval: 0
27+
num_checkpoints: 0
28+
eval_interval: 40
29+
eval_save_best: True
30+
tensorboard: False
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
algo: cpo
2+
algo_config:
3+
# Model args
4+
hidden1: 64
5+
hidden2: 64
6+
7+
# Optim args
8+
discount_factor: 0.98
9+
v_lr: 2.0e-4
10+
cost_v_lr: 2.0e-4
11+
num_conjugate: 10
12+
line_decay: 0.8
13+
max_kl: 0.001
14+
damping_coeff: 0.01
15+
gae_coeff: 0.97
16+
cost_d: 0.0
17+
18+
# Runner args
19+
max_steps: 1000
20+
num_epochs: 4000
21+
value_epochs: 100
22+
eval_batch_size: 20
23+
24+
# Misc
25+
log_interval: 40
26+
save_interval: 0
27+
num_checkpoints: 0
28+
eval_interval: 40
29+
eval_save_best: True
30+
tensorboard: False
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
algo: cpo
2+
algo_config:
3+
# Model args
4+
hidden1: 128
5+
hidden2: 128
6+
7+
# Optim args
8+
discount_factor: 0.98
9+
v_lr: 2.0e-4
10+
cost_v_lr: 2.0e-4
11+
num_conjugate: 10
12+
line_decay: 0.8
13+
max_kl: 0.001
14+
damping_coeff: 0.01
15+
gae_coeff: 0.97
16+
cost_d: 0.0
17+
18+
# Runner args
19+
max_steps: 1000
20+
num_epochs: 4000
21+
value_epochs: 150
22+
eval_batch_size: 20
23+
24+
# Misc
25+
log_interval: 40
26+
save_interval: 0
27+
num_checkpoints: 0
28+
eval_interval: 40
29+
eval_save_best: True
30+
tensorboard: False
Lines changed: 148 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,148 @@
1+
algo: cpo
2+
algo_config:
3+
cost_d: 0.0
4+
cost_v_lr: 0.0002
5+
damping_coeff: 0.01
6+
discount_factor: 0.98
7+
eval_batch_size: 20
8+
eval_interval: 50
9+
eval_save_best: true
10+
filter_train_actions: false
11+
gae_coeff: 0.97
12+
hidden1: 32
13+
hidden2: 32
14+
line_decay: 0.8
15+
log_interval: 50
16+
max_kl: 0.001
17+
max_steps: 150
18+
num_checkpoints: 0
19+
num_conjugate: 10
20+
num_epochs: 2000
21+
penalize_sf_diff: ''
22+
pretrained: ./models/rl_models/cartpole/stab/cpo_pretrain/
23+
save_interval: 0
24+
tensorboard: false
25+
training: true
26+
use_safe_reset: ''
27+
v_lr: 0.0002
28+
value_epochs: 100
29+
device: cpu
30+
kv_overrides:
31+
- task_config.init_state=None
32+
- task_config.use_constraint_penalty=False
33+
- sf_config.cost_function=one_step_cost
34+
- sf_config.mpsc_cost_horizon=2
35+
- sf_config.decay_factor=0.85
36+
- sf_config.soften_constraints=True
37+
- algo_config.filter_train_actions=False
38+
- algo_config.use_safe_reset=
39+
- task_config.done_on_violation=
40+
- algo_config.penalize_sf_diff=
41+
- algo_config.pretrained=./models/rl_models/cartpole/stab/cpo_pretrain/
42+
output_dir: ./models/rl_models/cartpole/stab/cpo/none/
43+
overrides:
44+
- ./config_overrides/cartpole/cpo_cartpole.yaml
45+
- ./config_overrides/cartpole/cartpole_stab.yaml
46+
- ./config_overrides/cartpole/nl_mpsc_cartpole_linear.yaml
47+
restore: null
48+
safety_filter: nl_mpsc
49+
seed: 2
50+
sf_config:
51+
cost_function: one_step_cost
52+
decay_factor: 0.85
53+
horizon: 5
54+
integration_algo: LTI
55+
mpsc_cost_horizon: 2
56+
n_samples: 6000
57+
prior_info:
58+
prior_prop: null
59+
prior_prop_rand_info: null
60+
randomize_prior_prop: false
61+
q_lin:
62+
- 0.02
63+
- 0.001
64+
- 10
65+
- 0.5
66+
r_lin:
67+
- 0.1
68+
slack_cost: 200
69+
soften_constraints: true
70+
use_terminal_set: false
71+
warmstart: true
72+
tag: temp
73+
task: cartpole
74+
task_config:
75+
adversary_disturbance: null
76+
adversary_disturbance_offset: 0.0
77+
adversary_disturbance_scale: 0.01
78+
constraint_penalty: -1
79+
constraints:
80+
- constrained_variable: state
81+
constraint_form: default_constraint
82+
lower_bounds:
83+
- -2
84+
- -2
85+
- -0.16
86+
- -1
87+
upper_bounds:
88+
- 2
89+
- 2
90+
- 0.16
91+
- 1
92+
- constrained_variable: input
93+
constraint_form: default_constraint
94+
cost: rl_reward
95+
ctrl_freq: 15
96+
disturbances: null
97+
done_on_out_of_bound: true
98+
done_on_violation: ''
99+
episode_len_sec: 10
100+
gui: false
101+
inertial_prop:
102+
cart_mass: 1
103+
pole_length: 0.5
104+
pole_mass: 0.1
105+
inertial_prop_randomization_info: null
106+
info_in_reset: true
107+
init_state: null
108+
init_state_randomization_info:
109+
init_theta:
110+
distrib: uniform
111+
high: 0.16
112+
low: -0.16
113+
init_theta_dot:
114+
distrib: uniform
115+
high: 1
116+
low: -1
117+
init_x:
118+
distrib: uniform
119+
high: 2
120+
low: -2
121+
init_x_dot:
122+
distrib: uniform
123+
high: 2
124+
low: -2
125+
normalized_rl_action_space: true
126+
obs_goal_horizon: 0
127+
obs_wrap_angle: false
128+
physics: pyb
129+
pyb_freq: 750
130+
randomized_inertial_prop: false
131+
randomized_init: true
132+
rew_act_weight: 0.1
133+
rew_exponential: true
134+
rew_state_weight:
135+
- 1
136+
- 1
137+
- 1
138+
- 1
139+
seed: 4077
140+
task: stabilization
141+
task_info:
142+
stabilization_goal:
143+
- 0.7
144+
- 0
145+
stabilization_goal_tolerance: 0.0
146+
use_constraint_penalty: false
147+
verbose: false
148+
use_gpu: false
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
step,loss/approx_kl
2+
7500,0.0007637462695129216
3+
15000,0.0006288051372393966
4+
22500,0.000866930466145277
5+
30000,0.0008193494286388159
6+
37500,0.0009118398302234709
7+
45000,0.00093095499323681
8+
52500,0.0009214580641128123
9+
60000,0.0009172772406600416
10+
67500,0.0008124987361952662
11+
75000,0.0006498565780930221
12+
82500,0.0008506188751198351
13+
90000,0.0006444334285333753
14+
97500,0.0006475357222370803
15+
105000,0.0006404643645510077
16+
112500,0.0006560455076396465
17+
120000,0.0006578431348316371
18+
127500,0.0006441655568778515
19+
135000,0.0006810820777900517
20+
142500,0.0009484487818554044
21+
150000,0.0006787514430470765
22+
157500,1.9780182700102067e-11
23+
165000,0.0006446137558668852
24+
172500,0.0006754323840141296
25+
180000,0.0006670575239695609
26+
187500,0.0006698388606309891
27+
195000,0.0006462063174694777
28+
202500,0.0006958717713132501
29+
210000,0.0009495722479186952
30+
217500,0.000977855990640819
31+
225000,0.0009298958466388285
32+
232500,0.0006462940364144742
33+
240000,0.0009904190665110946
34+
247500,0.0005805980181321502
35+
255000,0.0009628583211451769
36+
262500,0.0006456329138018191
37+
270000,0.0006483818870037794
38+
277500,0.0009998353198170662
39+
285000,0.0009329646709375083
40+
292500,0.0009268599678762257
41+
300000,2.485004660813389e-13
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
step,loss/cost_surrogate
2+
7500,105.25373077392578
3+
15000,135.73704528808594
4+
22500,110.174072265625
5+
30000,41.73980712890625
6+
37500,7.966276168823242
7+
45000,1.0940775871276855
8+
52500,0.18334102630615234
9+
60000,0.2685624957084656
10+
67500,-0.04398208111524582
11+
75000,1.5512120723724365
12+
82500,0.2660902142524719
13+
90000,-0.021695159375667572
14+
97500,-0.006272792816162109
15+
105000,0.003720715641975403
16+
112500,0.035086289048194885
17+
120000,0.3914239704608917
18+
127500,-0.03873072564601898
19+
135000,-0.028513872995972633
20+
142500,-0.017245113849639893
21+
150000,-0.005148999392986298
22+
157500,0.00627563800662756
23+
165000,0.049609869718551636
24+
172500,0.061136096715927124
25+
180000,0.0738421231508255
26+
187500,-0.0001361016184091568
27+
195000,-0.014164653606712818
28+
202500,0.05187857151031494
29+
210000,0.1658032238483429
30+
217500,0.3076779544353485
31+
225000,0.0920017659664154
32+
232500,4.87026834487915
33+
240000,0.42213770747184753
34+
247500,3.4023852348327637
35+
255000,0.12418961524963379
36+
262500,-0.14210250973701477
37+
270000,0.13648156821727753
38+
277500,0.030916675925254822
39+
285000,-0.0545722097158432
40+
292500,-0.008491892367601395
41+
300000,-1.191438059322536e-07
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
step,loss/cost_value_loss
2+
7500,38.56172180175781
3+
15000,26.3615665435791
4+
22500,12.278301239013672
5+
30000,16.15060806274414
6+
37500,1.9610624313354492
7+
45000,0.07205027341842651
8+
52500,0.2520654797554016
9+
60000,0.0052938312292099
10+
67500,0.007631403394043446
11+
75000,0.007784270215779543
12+
82500,0.0006668045534752309
13+
90000,0.008956549689173698
14+
97500,0.12046217173337936
15+
105000,0.006401794496923685
16+
112500,0.0006805356242693961
17+
120000,0.014772334136068821
18+
127500,0.001862881937995553
19+
135000,0.00014308057143352926
20+
142500,0.00017357783508487046
21+
150000,0.003835005685687065
22+
157500,0.0040668887086212635
23+
165000,0.004117249511182308
24+
172500,0.004651198163628578
25+
180000,0.02515612728893757
26+
187500,0.0033181600738316774
27+
195000,0.0021050202194601297
28+
202500,0.004466979298740625
29+
210000,0.004811915569007397
30+
217500,0.0031447347719222307
31+
225000,0.00039776338962838054
32+
232500,0.234852135181427
33+
240000,0.012947328388690948
34+
247500,0.01838028058409691
35+
255000,0.003159591229632497
36+
262500,0.0009796313242986798
37+
270000,0.0005409414879977703
38+
277500,0.0007045501261018217
39+
285000,0.0032590143382549286
40+
292500,0.00015349579916801304
41+
300000,0.00018221144273411483

0 commit comments

Comments
 (0)