diff --git a/maraboupy/MarabouNetwork.py b/maraboupy/MarabouNetwork.py
index ea4ff8c1b..601f700d2 100644
--- a/maraboupy/MarabouNetwork.py
+++ b/maraboupy/MarabouNetwork.py
@@ -132,29 +132,17 @@ def evaluateLocalRobustness(self, input, epsilon, originalClass, verbose=True, o
                 - stats (:class:`~maraboupy.MarabouCore.Statistics`): A Statistics object to how Marabou performed
                 - maxClass (int): Output class which value is max within outputs if SAT.
         """
-        inputVars = None
-        if (type(self.inputVars) is list):
-            if (len(self.inputVars) != 1):
-                raise NotImplementedError("Operation for %d inputs is not implemented" % len(self.inputVars))
-            inputVars = self.inputVars[0][0]
-        elif (type(self.inputVars) is np.ndarray):
-            inputVars = self.inputVars[0]
-        else:
-            err_msg = "Unpexpected type of input vars."
-            raise RuntimeError(err_msg)
+        assert(type(self.inputVars) is list)
+        if (len(self.inputVars) != 1):
+            raise NotImplementedError("Operation for %d inputs is not implemented" % len(self.inputVars))
+        inputVars = self.inputVars[0][0]
 
         if inputVars.shape != input.shape:
             raise RuntimeError("Input shape of the model should be same as the input shape\n input shape of the model: {0}, shape of the input: {1}".format(inputVars.shape, input.shape))
 
-        if (type(self.outputVars) is list):
-            if (len(self.outputVars) != 1):
-                raise NotImplementedError("Operation for %d outputs is not implemented" % len(self.outputVars))
-        elif (type(self.outputVars) is np.ndarray):
-            if (len(self.outputVars) != 1):
-                raise NotImplementedError("Operation for %d outputs is not implemented" % len(self.outputVars))
-        else:
-            err_msg = "Unpexpected type of output vars."
-            raise RuntimeError(err_msg)
+        assert(type(self.outputVars) is list)
+        if (len(self.outputVars) != 1):
+            raise NotImplementedError("Operation for %d outputs is not implemented" % len(self.outputVars))
 
         if options == None:
             options = MarabouCore.Options()
@@ -165,7 +153,7 @@ def evaluateLocalRobustness(self, input, epsilon, originalClass, verbose=True, o
         for i in range(flattenInput.size):
             self.setLowerBound(flattenInputVars[i], flattenInput[i] - epsilon)
             self.setUpperBound(flattenInputVars[i], flattenInput[i] + epsilon)
-        
+
         maxClass = None
         outputStartIndex = self.outputVars[0][0][0]
 
diff --git a/maraboupy/test/test_network.py b/maraboupy/test/test_network.py
index b57da11ee..530d54b43 100644
--- a/maraboupy/test/test_network.py
+++ b/maraboupy/test/test_network.py
@@ -805,4 +805,3 @@ def evaluateNetwork(network, testInputs, testOutputs):
     for testInput, testOutput in zip(testInputs, testOutputs):
         marabouEval = network.evaluateWithMarabou([testInput], options = OPT, filename = "")[0].flatten()
         assert max(abs(marabouEval - testOutput)) < TOL
-
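
Note for reviewers: a minimal sketch of how the patched evaluateLocalRobustness could be exercised end to end. The model path, epsilon, and originalClass below are placeholders, and loading via Marabou.read_onnx is just one assumed entry point; options is left as None so the patched default (MarabouCore.Options()) is used.

import numpy as np
from maraboupy import Marabou

network = Marabou.read_onnx("model.onnx")  # placeholder path; any supported loader works

# The patch requires the query point to match inputVars[0][0].shape exactly.
input_point = np.zeros(network.inputVars[0][0].shape, dtype=np.float64)

# Ask whether any point within an L-infinity ball of radius epsilon around
# input_point is classified differently from originalClass.
vals, stats, maxClass = network.evaluateLocalRobustness(
    input_point, epsilon=0.01, originalClass=0, verbose=False)

# Per the docstring, vals is empty on UNSAT (locally robust); otherwise maxClass
# is the output class attaining the maximum value in the counterexample.
print("robust" if len(vals) == 0 else "counterexample found, class %d" % maxClass)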