Skip to content

Commit 8356814

Browse files
committed
added unpickling code; retiring this algorithm
1 parent be7cbc3 commit 8356814

File tree

10 files changed

+51
-272
lines changed

10 files changed

+51
-272
lines changed

experiments/mjc_mdgps_antagonist_y1.5/costs.txt

Lines changed: 0 additions & 12 deletions
This file was deleted.

experiments/mjc_mdgps_antagonist_y1.5/hyperparams.py

Lines changed: 0 additions & 176 deletions
This file was deleted.

experiments/mjc_mdgps_antagonist_y1.5/off_policy/hyperparams.py

Lines changed: 0 additions & 35 deletions
This file was deleted.

experiments/mjc_mdgps_antagonist_y1.5/on_policy/hyperparams.py

Lines changed: 0 additions & 35 deletions
This file was deleted.
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
[333.14990159275749, 105.55982639501148, 270.04224750159727, 490.9910864456142]
2+
[323.13856062991363, 139.47001812421871, 198.60506190905406, 420.8664796818822]
3+
[387.33118832232225, 148.61315571763311, 349.95029068871497, 503.77058506373817]
4+
[200.7535898640904, 98.895497645003914, 180.72339157065679, 433.19929197006229]
5+
[173.00828416462969, 238.81158756623827, 331.51755687009233, 469.91999711398682]
6+
[394.44788274020698, 29.335397251512735, 265.38325900729978, 493.7384856762643]
7+
[241.98367493269262, 78.317647319468051, 232.52723852392091, 553.55410262641976]
8+
[280.90363953187068, 99.389410014524998, 263.64504145422512, 356.87591690251475]
9+
[273.02722779560753, 148.72838490336156, 166.72022888418604, 589.30149040488106]
10+
[201.66028565401888, 86.422899322473071, 257.56239471809255, 500.76789954972594]
11+
[185.64478659031516, 118.41621005433652, 195.84033539927924, 462.28908174901045]
12+
[306.38877164277369, 177.21711302114193, 270.23949188389054, 446.50687476985456]
13+
[306.38877164277369, 177.21711302114193, 270.23949188389054, 446.50687476985456]

experiments/mjc_mdgps_idg_y1.5/hyperparams.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@
4040
EXP_DIR = BASE_DIR + '/../experiments/mjc_mdgps_idg_y1.5/'
4141

4242
common = {
43-
'experiment_name': 'idg_y4' + '_' + \
43+
'experiment_name': 'idg_1.5' + '_' + \
4444
datetime.strftime(datetime.now(), '%m-%d-%y_%H-%M'),
4545
'experiment_dir': EXP_DIR,
4646
'data_files_dir': EXP_DIR + 'data_files/',
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
-59.24
2+
-323.11
3+
-409.02
4+
-438.53
5+
-441.80
6+
-454.83
7+
-463.19
8+
-464.45
9+
-462.74
10+
-450.40
11+
-461.33
12+
-460.38
13+
-461.33
14+
-460.38
15+
-461.33
16+
-460.38
17+
-461.33
18+
-460.38
19+
-461.33
20+
-460.38
21+
-461.33

python/gps/algorithm/policy_opt/policy_opt_caffe.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -329,7 +329,6 @@ def prob(self, obs):
329329

330330
return output, pol_sigma, pol_prec, pol_det_sigma
331331

332-
333332
def prob_v(self, obs):
334333
"""
335334
Run policy forward.
@@ -378,6 +377,7 @@ def __getstate__(self):
378377
'hyperparams': self._hyperparams,
379378
'dO': self._dO,
380379
'dU': self._dU,
380+
'dV': self._dV,
381381
'scale': self.policy.scale,
382382
'bias': self.policy.bias,
383383
'caffe_iter': self.caffe_iter,
@@ -386,7 +386,7 @@ def __getstate__(self):
386386

387387
# For unpickling.
388388
def __setstate__(self, state):
389-
self.__init__(state['hyperparams'], state['dO'], state['dU'])
389+
self.__init__(state['hyperparams'], state['dO'], state['dU'], state['dV'])
390390
self.policy.scale = state['scale']
391391
self.policy.bias = state['bias']
392392
self.caffe_iter = state['caffe_iter']

python/gps/algorithm/traj_opt/traj_opt_lqr_python.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -665,25 +665,23 @@ def forward_robust(self, traj_distr, traj_info):
665665
sigma_u[t, idx_x, idx_x],
666666
sigma_u[t, idx_x, idx_x].dot(traj_distr.Gu[t, :, :].T),
667667
# pad with v terms
668-
sigma_v[t, idx_x, idx_x].dot(traj_distr.Gv[t, :, :].T),
668+
np.zeros_like(sigma_v[t, idx_x, idx_x].dot(traj_distr.Gv[t, :, :].T))
669669
]),
670670
np.hstack([
671671
traj_distr.Gu[t, :, :].dot(sigma_u[t, idx_x, idx_x]),
672672
traj_distr.Gu[t, :, :].dot(sigma_u[t, idx_x, idx_x]).dot(
673673
traj_distr.Gu[t, :, :].T
674674
) + traj_distr.pol_covar_u[t, :, :],
675675
# pad with adversarial terms
676-
traj_distr.Gv[t, :, :].dot(sigma_v[t, idx_x, idx_x]).dot(
676+
np.zeros_like(traj_distr.Gv[t, :, :].dot(sigma_v[t, idx_x, idx_x]).dot(
677677
traj_distr.Gv[t, :, :].T
678-
) + traj_distr.pol_covar_v[t, :, :]
678+
) + traj_distr.pol_covar_v[t, :, :])
679679
]),
680680
# pad dU terms with zero
681-
# np.zeros([dU, dX+dU+dV])
682681
np.hstack([
683-
traj_distr.Gv[t, :, :].dot(sigma_v[t, idx_x, idx_x]),
684-
traj_distr.Gv[t, :, :].dot(sigma_v[t, idx_x, idx_x]).dot(
685-
traj_distr.Gv[t, :, :].T
686-
) + traj_distr.pol_covar_v[t, :, :],
682+
np.zeros_like(traj_distr.Gv[t, :, :].dot(sigma_v[t, idx_x, idx_x])),
683+
np.zeros_like(traj_distr.Gv[t, :, :].dot(sigma_v[t, idx_x, idx_x]).dot(
684+
traj_distr.Gv[t, :, :].T) + traj_distr.pol_covar_v[t, :, :]),
687685
# pad with control terms
688686
traj_distr.Gu[t, :, :].dot(sigma_u[t, idx_x, idx_x]).dot(
689687
traj_distr.Gu[t, :, :].T
@@ -693,7 +691,8 @@ def forward_robust(self, traj_distr, traj_info):
693691
mu_u[t, :] = np.hstack([
694692
mu_u[t, idx_x],
695693
traj_distr.Gu[t, :, :].dot(mu_u[t, idx_x]) + traj_distr.gu[t, :],
696-
traj_distr.Gv[t, :, :].dot(mu_v[t, idx_x]) + traj_distr.gv[t, :]
694+
np.zeros_like(traj_distr.Gv[t, :, :].dot(mu_v[t, idx_x]) + \
695+
traj_distr.gv[t, :])
697696
])
698697

699698
if t < T - 1:

python/gps/gps_main.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -206,8 +206,12 @@ def test_policy(self, itr, N):
206206
)
207207

208208
if self.gui:
209-
self.gui.update(itr, self.algorithm, self.agent,
210-
traj_sample_lists, pol_sample_lists, protag_pol_samples=self.protag_pol_samples)
209+
if self.robust:
210+
self.gui.update(itr, self.algorithm, self.agent,
211+
traj_sample_lists, pol_sample_lists, protag_pol_samples=None)
212+
else:
213+
self.gui.update(itr, self.algorithm, self.agent,
214+
traj_sample_lists, pol_sample_lists, protag_pol_samples=self.protag_pol_samples)
211215
self.gui.set_status_text(('Took %d policy sample(s) from ' +
212216
'algorithm state at iteration %d.\n' +
213217
'Saved to: data_files/pol_sample_itr_%02d.pkl.\n') % (N, itr, itr))

0 commit comments

Comments
 (0)