Skip to content

Commit

Permalink
Merging in CPO code
Browse files Browse the repository at this point in the history
  • Loading branch information
pizarrob committed Nov 10, 2023
1 parent c21bf22 commit 49c4ad5
Show file tree
Hide file tree
Showing 473 changed files with 22,305 additions and 14 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ repos:
- id: check-yaml
- id: check-toml
- id: check-added-large-files
args: ['--maxkb=10000']
args: ['--maxkb=15000']
- id: check-docstring-first
- id: check-executables-have-shebangs
- id: check-shebang-scripts-are-executable
Expand Down
30 changes: 30 additions & 0 deletions experiments/mpsc/config_overrides/cartpole/cpo_cartpole.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
---
# CPO algorithm overrides for cartpole
# (experiments/mpsc/config_overrides/cartpole/cpo_cartpole.yaml).
# NOTE(review): the settings below are nested under `algo_config`; this
# nesting was restored from a view that had lost its indentation -- confirm
# against the repository copy.
algo: cpo
algo_config:
  # Model args
  hidden1: 64
  hidden2: 64

  # Optim args
  discount_factor: 0.98
  v_lr: 2.0e-4
  cost_v_lr: 2.0e-4
  num_conjugate: 10
  line_decay: 0.8
  max_kl: 0.001
  damping_coeff: 0.01
  gae_coeff: 0.97
  cost_d: 0.0

  # Runner args
  max_steps: 600
  num_epochs: 4000
  value_epochs: 100
  eval_batch_size: 20

  # Misc
  log_interval: 40
  save_interval: 0
  num_checkpoints: 0
  eval_interval: 40
  eval_save_best: true
  tensorboard: false
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
---
# CPO algorithm overrides (the target file name is not visible in this view).
# NOTE(review): the settings below are nested under `algo_config`; this
# nesting was restored from a view that had lost its indentation -- confirm
# against the repository copy.
algo: cpo
algo_config:
  # Model args
  hidden1: 64
  hidden2: 64

  # Optim args
  discount_factor: 0.98
  v_lr: 2.0e-4
  cost_v_lr: 2.0e-4
  num_conjugate: 10
  line_decay: 0.8
  max_kl: 0.001
  damping_coeff: 0.01
  gae_coeff: 0.97
  cost_d: 0.0

  # Runner args
  max_steps: 1000
  num_epochs: 4000
  value_epochs: 100
  eval_batch_size: 20

  # Misc
  log_interval: 40
  save_interval: 0
  num_checkpoints: 0
  eval_interval: 40
  eval_save_best: true
  tensorboard: false
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
---
# CPO algorithm overrides (the target file name is not visible in this view).
# NOTE(review): the settings below are nested under `algo_config`; this
# nesting was restored from a view that had lost its indentation -- confirm
# against the repository copy.
algo: cpo
algo_config:
  # Model args
  hidden1: 128
  hidden2: 128

  # Optim args
  discount_factor: 0.98
  v_lr: 2.0e-4
  cost_v_lr: 2.0e-4
  num_conjugate: 10
  line_decay: 0.8
  max_kl: 0.001
  damping_coeff: 0.01
  gae_coeff: 0.97
  cost_d: 0.0

  # Runner args
  max_steps: 1000
  num_epochs: 4000
  value_epochs: 150
  eval_batch_size: 20

  # Misc
  log_interval: 40
  save_interval: 0
  num_checkpoints: 0
  eval_interval: 40
  eval_save_best: true
  tensorboard: false
148 changes: 148 additions & 0 deletions experiments/mpsc/models/rl_models/cartpole/stab/cpo/none/config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
---
# Run configuration stored with the CPO cartpole stabilization model
# (experiments/mpsc/models/rl_models/cartpole/stab/cpo/none/config.yaml).
# NOTE(review): the nesting was reconstructed from a view that had lost its
# indentation, using the alphabetical key order at each level -- confirm
# against the repository copy. Key order and values are unchanged.
algo: cpo
algo_config:
  cost_d: 0.0
  cost_v_lr: 0.0002
  damping_coeff: 0.01
  discount_factor: 0.98
  eval_batch_size: 20
  eval_interval: 50
  eval_save_best: true
  filter_train_actions: false
  gae_coeff: 0.97
  hidden1: 32
  hidden2: 32
  line_decay: 0.8
  log_interval: 50
  max_kl: 0.001
  max_steps: 150
  num_checkpoints: 0
  num_conjugate: 10
  num_epochs: 2000
  penalize_sf_diff: ''
  pretrained: ./models/rl_models/cartpole/stab/cpo_pretrain/
  save_interval: 0
  tensorboard: false
  training: true
  use_safe_reset: ''
  v_lr: 0.0002
  value_epochs: 100
device: cpu
kv_overrides:
- task_config.init_state=None
- task_config.use_constraint_penalty=False
- sf_config.cost_function=one_step_cost
- sf_config.mpsc_cost_horizon=2
- sf_config.decay_factor=0.85
- sf_config.soften_constraints=True
- algo_config.filter_train_actions=False
- algo_config.use_safe_reset=
- task_config.done_on_violation=
- algo_config.penalize_sf_diff=
- algo_config.pretrained=./models/rl_models/cartpole/stab/cpo_pretrain/
output_dir: ./models/rl_models/cartpole/stab/cpo/none/
overrides:
- ./config_overrides/cartpole/cpo_cartpole.yaml
- ./config_overrides/cartpole/cartpole_stab.yaml
- ./config_overrides/cartpole/nl_mpsc_cartpole_linear.yaml
restore: null
safety_filter: nl_mpsc
seed: 2
sf_config:
  cost_function: one_step_cost
  decay_factor: 0.85
  horizon: 5
  integration_algo: LTI
  mpsc_cost_horizon: 2
  n_samples: 6000
  prior_info:
    prior_prop: null
    prior_prop_rand_info: null
    randomize_prior_prop: false
  q_lin:
  - 0.02
  - 0.001
  - 10
  - 0.5
  r_lin:
  - 0.1
  slack_cost: 200
  soften_constraints: true
  use_terminal_set: false
  warmstart: true
tag: temp
task: cartpole
task_config:
  adversary_disturbance: null
  adversary_disturbance_offset: 0.0
  adversary_disturbance_scale: 0.01
  constraint_penalty: -1
  constraints:
  - constrained_variable: state
    constraint_form: default_constraint
    lower_bounds:
    - -2
    - -2
    - -0.16
    - -1
    upper_bounds:
    - 2
    - 2
    - 0.16
    - 1
  - constrained_variable: input
    constraint_form: default_constraint
  cost: rl_reward
  ctrl_freq: 15
  disturbances: null
  done_on_out_of_bound: true
  done_on_violation: ''
  episode_len_sec: 10
  gui: false
  inertial_prop:
    cart_mass: 1
    pole_length: 0.5
    pole_mass: 0.1
  inertial_prop_randomization_info: null
  info_in_reset: true
  init_state: null
  init_state_randomization_info:
    init_theta:
      distrib: uniform
      high: 0.16
      low: -0.16
    init_theta_dot:
      distrib: uniform
      high: 1
      low: -1
    init_x:
      distrib: uniform
      high: 2
      low: -2
    init_x_dot:
      distrib: uniform
      high: 2
      low: -2
  normalized_rl_action_space: true
  obs_goal_horizon: 0
  obs_wrap_angle: false
  physics: pyb
  pyb_freq: 750
  randomized_inertial_prop: false
  randomized_init: true
  rew_act_weight: 0.1
  rew_exponential: true
  rew_state_weight:
  - 1
  - 1
  - 1
  - 1
  # `seed` and `task` here belong to task_config; they must stay nested so
  # they do not collide with the top-level `seed` and `task` keys.
  seed: 4077
  task: stabilization
  task_info:
    stabilization_goal:
    - 0.7
    - 0
    stabilization_goal_tolerance: 0.0
  use_constraint_penalty: false
  verbose: false
use_gpu: false
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
step,loss/approx_kl
7500,0.0007637462695129216
15000,0.0006288051372393966
22500,0.000866930466145277
30000,0.0008193494286388159
37500,0.0009118398302234709
45000,0.00093095499323681
52500,0.0009214580641128123
60000,0.0009172772406600416
67500,0.0008124987361952662
75000,0.0006498565780930221
82500,0.0008506188751198351
90000,0.0006444334285333753
97500,0.0006475357222370803
105000,0.0006404643645510077
112500,0.0006560455076396465
120000,0.0006578431348316371
127500,0.0006441655568778515
135000,0.0006810820777900517
142500,0.0009484487818554044
150000,0.0006787514430470765
157500,1.9780182700102067e-11
165000,0.0006446137558668852
172500,0.0006754323840141296
180000,0.0006670575239695609
187500,0.0006698388606309891
195000,0.0006462063174694777
202500,0.0006958717713132501
210000,0.0009495722479186952
217500,0.000977855990640819
225000,0.0009298958466388285
232500,0.0006462940364144742
240000,0.0009904190665110946
247500,0.0005805980181321502
255000,0.0009628583211451769
262500,0.0006456329138018191
270000,0.0006483818870037794
277500,0.0009998353198170662
285000,0.0009329646709375083
292500,0.0009268599678762257
300000,2.485004660813389e-13
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
step,loss/cost_surrogate
7500,105.25373077392578
15000,135.73704528808594
22500,110.174072265625
30000,41.73980712890625
37500,7.966276168823242
45000,1.0940775871276855
52500,0.18334102630615234
60000,0.2685624957084656
67500,-0.04398208111524582
75000,1.5512120723724365
82500,0.2660902142524719
90000,-0.021695159375667572
97500,-0.006272792816162109
105000,0.003720715641975403
112500,0.035086289048194885
120000,0.3914239704608917
127500,-0.03873072564601898
135000,-0.028513872995972633
142500,-0.017245113849639893
150000,-0.005148999392986298
157500,0.00627563800662756
165000,0.049609869718551636
172500,0.061136096715927124
180000,0.0738421231508255
187500,-0.0001361016184091568
195000,-0.014164653606712818
202500,0.05187857151031494
210000,0.1658032238483429
217500,0.3076779544353485
225000,0.0920017659664154
232500,4.87026834487915
240000,0.42213770747184753
247500,3.4023852348327637
255000,0.12418961524963379
262500,-0.14210250973701477
270000,0.13648156821727753
277500,0.030916675925254822
285000,-0.0545722097158432
292500,-0.008491892367601395
300000,-1.191438059322536e-07
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
step,loss/cost_value_loss
7500,38.56172180175781
15000,26.3615665435791
22500,12.278301239013672
30000,16.15060806274414
37500,1.9610624313354492
45000,0.07205027341842651
52500,0.2520654797554016
60000,0.0052938312292099
67500,0.007631403394043446
75000,0.007784270215779543
82500,0.0006668045534752309
90000,0.008956549689173698
97500,0.12046217173337936
105000,0.006401794496923685
112500,0.0006805356242693961
120000,0.014772334136068821
127500,0.001862881937995553
135000,0.00014308057143352926
142500,0.00017357783508487046
150000,0.003835005685687065
157500,0.0040668887086212635
165000,0.004117249511182308
172500,0.004651198163628578
180000,0.02515612728893757
187500,0.0033181600738316774
195000,0.0021050202194601297
202500,0.004466979298740625
210000,0.004811915569007397
217500,0.0031447347719222307
225000,0.00039776338962838054
232500,0.234852135181427
240000,0.012947328388690948
247500,0.01838028058409691
255000,0.003159591229632497
262500,0.0009796313242986798
270000,0.0005409414879977703
277500,0.0007045501261018217
285000,0.0032590143382549286
292500,0.00015349579916801304
300000,0.00018221144273411483
Loading

0 comments on commit 49c4ad5

Please sign in to comment.