@@ -30,7 +30,6 @@ def __init__(self,
30
30
31
31
self .output_dir = output_dir
32
32
self .uncertified_controller = None
33
- self .skip_checks = False
34
33
35
34
def get_cost (self , opti_dict ):
36
35
'''Returns the cost function for the MPSC optimization in symbolic form.
@@ -106,10 +105,7 @@ def calculate_unsafe_path(self, obs, uncertified_action, iteration):
106
105
# Concatenate goal info (goal state(s)) for RL
107
106
extended_obs = self .env .extend_obs (obs , next_step + 1 )
108
107
109
- info = {
110
- 'current_step' : next_step ,
111
- 'constraint_values' : np .concatenate ([self .get_constraint_value (con , obs ) for con in self .env .constraints .state_constraints ])
112
- }
108
+ info = {'current_step' : next_step }
113
109
114
110
action = self .uncertified_controller .select_action (obs = extended_obs , info = info )
115
111
@@ -121,7 +117,7 @@ def calculate_unsafe_path(self, obs, uncertified_action, iteration):
121
117
122
118
action = np .clip (action , self .env .physical_action_bounds [0 ], self .env .physical_action_bounds [1 ])
123
119
124
- if h == 0 and np .linalg .norm (uncertified_action - action ) >= 0.001 and not self . skip_checks :
120
+ if h == 0 and np .linalg .norm (uncertified_action - action ) >= 0.001 :
125
121
raise ValueError (f'[ERROR] Mismatch between unsafe controller and MPSC guess. Uncert: { uncertified_action } , Guess: { action } , Diff: { np .linalg .norm (uncertified_action - action )} .' )
126
122
127
123
v_L [:, h :h + 1 ] = action .reshape ((self .model .nu , 1 ))
@@ -133,15 +129,3 @@ def calculate_unsafe_path(self, obs, uncertified_action, iteration):
133
129
self .uncertified_controller .save (f'{ self .output_dir } /temp-data/saved_controller_prev.npy' )
134
130
135
131
return v_L
136
-
137
- def get_constraint_value (self , con , state ):
138
- '''Gets the value of a constraint given the state.
139
-
140
- Args:
141
- con (Constraint): The constraint.
142
- state (ndarray): The state to be tested.
143
-
144
- Returns:
145
- value (float): The value of the constraint at the given state.
146
- '''
147
- return np .round (np .atleast_1d (np .squeeze (con .sym_func (np .array (state , ndmin = 1 )))), decimals = con .decimals )
0 commit comments