Skip to content

Commit 7cf0d2e

Browse files
author
pizarrob
committed
Small clean-up and fixes
1 parent 49c4ad5 commit 7cf0d2e

File tree

3 files changed

+6
-22
lines changed

3 files changed

+6
-22
lines changed

safe_control_gym/controllers/ppo/ppo.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -240,12 +240,12 @@ def run(self,
240240
physical_action = env.denormalize_action(action)
241241
unextended_obs = np.squeeze(true_obs)[:env.symbolic.nx]
242242
certified_action, success = self.safety_filter.certify_action(unextended_obs, physical_action, info)
243-
if success and self.filter_train_actions is True:
243+
if success:
244244
action = env.normalize_action(certified_action)
245245
else:
246246
self.safety_filter.ocp_solver.reset()
247247
certified_action, success = self.safety_filter.certify_action(unextended_obs, physical_action, info)
248-
if success and self.filter_train_actions is True:
248+
if success:
249249
action = self.env.envs[0].normalize_action(certified_action)
250250

251251
action = np.atleast_2d(np.squeeze([action]))

safe_control_gym/controllers/sac/sac.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -251,12 +251,12 @@ def run(self, env=None, render=False, n_episodes=10, verbose=False, **kwargs):
251251
physical_action = env.denormalize_action(action)
252252
unextended_obs = np.squeeze(true_obs)[:env.symbolic.nx]
253253
certified_action, success = self.safety_filter.certify_action(unextended_obs, physical_action, info)
254-
if success and self.filter_train_actions is True:
254+
if success:
255255
applied_action = env.normalize_action(certified_action)
256256
else:
257257
self.safety_filter.ocp_solver.reset()
258258
certified_action, success = self.safety_filter.certify_action(unextended_obs, physical_action, info)
259-
if success and self.filter_train_actions is True:
259+
if success:
260260
applied_action = self.env.envs[0].normalize_action(certified_action)
261261

262262
action = np.atleast_2d(np.squeeze([applied_action]))

safe_control_gym/safety_filters/mpsc/mpsc_cost_function/precomputed_cost.py

Lines changed: 2 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,6 @@ def __init__(self,
3030

3131
self.output_dir = output_dir
3232
self.uncertified_controller = None
33-
self.skip_checks = False
3433

3534
def get_cost(self, opti_dict):
3635
'''Returns the cost function for the MPSC optimization in symbolic form.
@@ -106,10 +105,7 @@ def calculate_unsafe_path(self, obs, uncertified_action, iteration):
106105
# Concatenate goal info (goal state(s)) for RL
107106
extended_obs = self.env.extend_obs(obs, next_step + 1)
108107

109-
info = {
110-
'current_step': next_step,
111-
'constraint_values': np.concatenate([self.get_constraint_value(con, obs) for con in self.env.constraints.state_constraints])
112-
}
108+
info = {'current_step': next_step}
113109

114110
action = self.uncertified_controller.select_action(obs=extended_obs, info=info)
115111

@@ -121,7 +117,7 @@ def calculate_unsafe_path(self, obs, uncertified_action, iteration):
121117

122118
action = np.clip(action, self.env.physical_action_bounds[0], self.env.physical_action_bounds[1])
123119

124-
if h == 0 and np.linalg.norm(uncertified_action - action) >= 0.001 and not self.skip_checks:
120+
if h == 0 and np.linalg.norm(uncertified_action - action) >= 0.001:
125121
raise ValueError(f'[ERROR] Mismatch between unsafe controller and MPSC guess. Uncert: {uncertified_action}, Guess: {action}, Diff: {np.linalg.norm(uncertified_action - action)}.')
126122

127123
v_L[:, h:h + 1] = action.reshape((self.model.nu, 1))
@@ -133,15 +129,3 @@ def calculate_unsafe_path(self, obs, uncertified_action, iteration):
133129
self.uncertified_controller.save(f'{self.output_dir}/temp-data/saved_controller_prev.npy')
134130

135131
return v_L
136-
137-
def get_constraint_value(self, con, state):
138-
'''Gets the value of a constraint given the state.
139-
140-
Args:
141-
con (Constraint): The constraint.
142-
state (ndarray): The state to be tested.
143-
144-
Returns:
145-
value (float): The value of the constraint at the given state.
146-
'''
147-
return np.round(np.atleast_1d(np.squeeze(con.sym_func(np.array(state, ndmin=1)))), decimals=con.decimals)

0 commit comments

Comments
 (0)