@@ -64,8 +64,9 @@ def __init__(self, hyperparams):
64
64
self ._hyperparams ['init_traj_distr' ], self ._cond_idx [m ] #L84 hyperparams
65
65
)
66
66
# note that both prot and adv act in turns on adversary
67
- self .cur [m ].traj_distr = init_traj_distr ['type' ](init_traj_distr ) #will be init_lqr
68
- self .cur [m ].traj_distr_adv = init_traj_distr ['type' ](init_traj_distr ) #adv traj dist
67
+ self .cur [m ].traj_distr = init_traj_distr ['type' ](init_traj_distr ) #will be init_lqr / init_lqr_robust
68
+ self .cur [m ].traj_distr_adv = init_traj_distr ['type' ](init_traj_distr ) #adv traj dist
69
+ self .cur [m ].traj_distr_robust = init_traj_distr ['type' ](init_traj_distr ) #robust traj dist
69
70
70
71
#init_lqr is defined in algorithm/policy/lin_gauss_init
71
72
self .traj_opt = hyperparams ['traj_opt' ]['type' ](
@@ -98,7 +99,7 @@ def iteration_idg(self, sample_lists_prot, sample_list):
98
99
""" Run iteration of the algorithm. """
99
100
raise NotImplementedError ("Must be implemented in subclass" )
100
101
101
-
102
+
102
103
def _update_dynamics (self ):
103
104
"""
104
105
Instantiate dynamics objects and update prior. Fit dynamics to
@@ -176,6 +177,42 @@ def _update_trajectories(self):
176
177
self .new_traj_distr [cond ], self .cur [cond ].eta = \
177
178
self .traj_opt .update (cond , self )
178
179
180
def _update_trajectories_robust(self):
    """
    Compute new linear Gaussian controllers for the protagonist, the
    adversary and the robust policy, for every condition.

    On first use, each ``new_traj_distr*`` list on ``self`` is seeded
    from the matching attribute of every entry in ``self.cur``. Then,
    for each condition, the trajectory optimizer is asked in turn for
    the protagonist update, the adversary update, and the robust
    update, which receives the two freshly stored controllers and
    their eta values.
    """
    # (list attribute on self, per-condition attribute that seeds it)
    seeds = (
        ('new_traj_distr', 'traj_distr'),
        ('new_traj_distr_adv', 'traj_distr_adv'),
        ('new_traj_distr_robust', 'traj_distr_robust'),
    )
    for list_name, cur_name in seeds:
        if hasattr(self, list_name):
            continue
        setattr(self, list_name, [
            getattr(self.cur[idx], cur_name) for idx in range(self.M)
        ])

    for cond in range(self.M):
        LOGGER.debug("updating protagonist trajectory")
        prot_distr, prot_eta = self.traj_opt.update_protagonist(cond, self)
        self.new_traj_distr[cond] = prot_distr
        self.cur[cond].eta = prot_eta

        LOGGER.debug("updating adversary trajectory")
        adv_distr, adv_eta = self.traj_opt.update_adversary(cond, self)
        self.new_traj_distr_adv[cond] = adv_distr
        self.cur[cond].eta_adv = adv_eta

        LOGGER.debug("Computing conditional of protagonist on adversary")
        # Arguments are read back from self (not from the locals above)
        # so that any state touched by the earlier updates is honoured.
        robust_distr, robust_eta = self.traj_opt.update_robust(
            cond,
            self,
            self.new_traj_distr[cond],
            self.new_traj_distr_adv[cond],
            self.cur[cond].eta,
            self.cur[cond].eta_adv,
        )
        self.new_traj_distr_robust[cond] = robust_distr
        # NOTE(review): this replaces the eta stored by the protagonist
        # update above; confirm that overwriting it is intended.
        self.cur[cond].eta = robust_eta
215
+
179
216
def _eval_cost (self , cond ):
180
217
"""
181
218
Evaluate costs for all samples for a condition.
0 commit comments