Skip to content

Commit 2fffde2

Browse files
committed
fixed update trajectories in c-step
1 parent b9c6cd7 commit 2fffde2

File tree

12 files changed

+2056
-122
lines changed

12 files changed

+2056
-122
lines changed

README.md

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,13 @@ algorithm['traj_opt'] = {
109109
}
110110
```
111111

112+
* Add the following key to `algorithm['policy_opt']` to set the weights-file prefix used for the robust policy:
113+
114+
```python
115+
algorithm['policy_opt'] = {
116+
'robust_weights_file_prefix': EXP_DIR + 'robust_policy',
117+
}
118+
```
112119

113120
### Docker Image
114121

experiments/mjc_mdgps_protagonist/hyperparams.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,7 @@
152152
'type': PolicyOptCaffe,
153153
'iterations': 4000,
154154
'weights_file_prefix': EXP_DIR + 'policy',
155+
'robust_weights_file_prefix': EXP_DIR + 'robust_policy',
155156
}
156157

157158
algorithm['policy_prior'] = {

python/gps/algorithm/algorithm.py

Lines changed: 40 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -64,8 +64,9 @@ def __init__(self, hyperparams):
6464
self._hyperparams['init_traj_distr'], self._cond_idx[m] #L84 hyperparams
6565
)
6666
# note that both prot and adv act in turns on adversary
67-
self.cur[m].traj_distr = init_traj_distr['type'](init_traj_distr) #will be init_lqr
68-
self.cur[m].traj_distr_adv = init_traj_distr['type'](init_traj_distr) #adv traj dist
67+
self.cur[m].traj_distr = init_traj_distr['type'](init_traj_distr) #will be init_lqr / init_lqr_robust
68+
self.cur[m].traj_distr_adv = init_traj_distr['type'](init_traj_distr) #adv traj dist
69+
self.cur[m].traj_distr_robust = init_traj_distr['type'](init_traj_distr) #robust traj dist
6970

7071
#init_lqr is defined in algorithm/policy/lin_gauss_init
7172
self.traj_opt = hyperparams['traj_opt']['type'](
@@ -98,7 +99,7 @@ def iteration_idg(self, sample_lists_prot, sample_list):
9899
""" Run iteration of the algorithm. """
99100
raise NotImplementedError("Must be implemented in subclass")
100101

101-
102+
102103
def _update_dynamics(self):
103104
"""
104105
Instantiate dynamics objects and update prior. Fit dynamics to
@@ -176,6 +177,42 @@ def _update_trajectories(self):
176177
self.new_traj_distr[cond], self.cur[cond].eta = \
177178
self.traj_opt.update(cond, self)
178179

180+
def _update_trajectories_robust(self):
    """
    Refresh the protagonist, adversary and robust trajectory
    distributions for every condition.

    On first use, new_traj_distr, new_traj_distr_adv and
    new_traj_distr_robust are each seeded from the matching attribute
    of self.cur[cond]. Then, for each condition in turn,
    traj_opt.update_protagonist, traj_opt.update_adversary and
    traj_opt.update_robust are called (in that order) and their
    results are stored on self.
    """
    # (list attribute on self, per-condition attribute it is seeded from)
    seed_pairs = (
        ('new_traj_distr', 'traj_distr'),
        ('new_traj_distr_adv', 'traj_distr_adv'),
        ('new_traj_distr_robust', 'traj_distr_robust'),
    )
    for list_name, cur_name in seed_pairs:
        if hasattr(self, list_name):
            # Already present from an earlier call; keep it.
            continue
        setattr(self, list_name, [
            getattr(self.cur[idx], cur_name) for idx in range(self.M)
        ])

    for cond in range(self.M):
        LOGGER.debug("updating protagonist trajectory")
        prot_result = self.traj_opt.update_protagonist(cond, self)
        self.new_traj_distr[cond], self.cur[cond].eta = prot_result

        LOGGER.debug("updating adversary trajectory")
        adv_result = self.traj_opt.update_adversary(cond, self)
        self.new_traj_distr_adv[cond], self.cur[cond].eta_adv = adv_result

        LOGGER.debug("Computing conditional of protagonist on adversary")
        # The arguments are re-read from self at call time (not taken
        # from the results above) so that anything the earlier update
        # calls changed on self is what gets passed here.
        robust_result = self.traj_opt.update_robust(
            cond, self,
            self.new_traj_distr[cond],
            self.new_traj_distr_adv[cond],
            self.cur[cond].eta,
            self.cur[cond].eta_adv
        )
        # NOTE(review): the second returned value is stored back into
        # cur[cond].eta, replacing the value written after the
        # protagonist update above -- confirm this is intended.
        self.new_traj_distr_robust[cond], self.cur[cond].eta = robust_result
179216
def _eval_cost(self, cond):
180217
"""
181218
Evaluate costs for all samples for a condition.

0 commit comments

Comments
 (0)