Skip to content

Commit

Permalink
fix redundant code in evaluateRobustness
Browse files Browse the repository at this point in the history
  • Loading branch information
wu-haoze committed Aug 22, 2024
1 parent f88d30e commit 85c963c
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 21 deletions.
28 changes: 8 additions & 20 deletions maraboupy/MarabouNetwork.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,29 +132,17 @@ def evaluateLocalRobustness(self, input, epsilon, originalClass, verbose=True, o
- stats (:class:`~maraboupy.MarabouCore.Statistics`): A Statistics object to how Marabou performed
- maxClass (int): Output class which value is max within outputs if SAT.
"""
inputVars = None
if (type(self.inputVars) is list):
if (len(self.inputVars) != 1):
raise NotImplementedError("Operation for %d inputs is not implemented" % len(self.inputVars))
inputVars = self.inputVars[0][0]
elif (type(self.inputVars) is np.ndarray):
inputVars = self.inputVars[0]
else:
err_msg = "Unpexpected type of input vars."
raise RuntimeError(err_msg)
assert(type(self.inputVars) is list)
if (len(self.inputVars) != 1):
raise NotImplementedError("Operation for %d inputs is not implemented" % len(self.inputVars))
inputVars = self.inputVars[0][0]

if inputVars.shape != input.shape:
raise RuntimeError("Input shape of the model should be same as the input shape\n input shape of the model: {0}, shape of the input: {1}".format(inputVars.shape, input.shape))

if (type(self.outputVars) is list):
if (len(self.outputVars) != 1):
raise NotImplementedError("Operation for %d outputs is not implemented" % len(self.outputVars))
elif (type(self.outputVars) is np.ndarray):
if (len(self.outputVars) != 1):
raise NotImplementedError("Operation for %d outputs is not implemented" % len(self.outputVars))
else:
err_msg = "Unpexpected type of output vars."
raise RuntimeError(err_msg)
assert(type(self.outputVars) is list)
if (len(self.outputVars) != 1):
raise NotImplementedError("Operation for %d outputs is not implemented" % len(self.outputVars))

if options == None:
options = MarabouCore.Options()
Expand All @@ -165,7 +153,7 @@ def evaluateLocalRobustness(self, input, epsilon, originalClass, verbose=True, o
for i in range(flattenInput.size):
self.setLowerBound(flattenInputVars[i], flattenInput[i] - epsilon)
self.setUpperBound(flattenInputVars[i], flattenInput[i] + epsilon)

maxClass = None
outputStartIndex = self.outputVars[0][0][0]

Expand Down
1 change: 0 additions & 1 deletion maraboupy/test/test_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -805,4 +805,3 @@ def evaluateNetwork(network, testInputs, testOutputs):
for testInput, testOutput in zip(testInputs, testOutputs):
marabouEval = network.evaluateWithMarabou([testInput], options = OPT, filename = "")[0].flatten()
assert max(abs(marabouEval - testOutput)) < TOL

0 comments on commit 85c963c

Please sign in to comment.