Skip to content

Commit

Permalink
fix unit test and CI
Browse files Browse the repository at this point in the history
  • Loading branch information
wu-haoze committed Aug 22, 2024
1 parent 1905555 commit 85322dd
Show file tree
Hide file tree
Showing 33 changed files with 96 additions and 1,258 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Changelog

## Next Release
- Dropped support for parsing Tensorflow network format. Newest Marabou version that supports Tensorflow is at commit 190555573e4702.

## Version 2.0.0

Expand Down
22 changes: 0 additions & 22 deletions maraboupy/Marabou.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,6 @@
from maraboupy.MarabouNetworkNNet import *
except ImportError:
warnings.warn("NNet parser is unavailable because the numpy package is not installed")
try:
from maraboupy.MarabouNetworkTF import *
except ImportError:
warnings.warn("Tensorflow parser is unavailable because tensorflow package is not installed")
try:
from maraboupy.MarabouNetworkONNX import *
except ImportError:
Expand All @@ -43,24 +39,6 @@ def read_nnet(filename, normalize=False):
"""
return MarabouNetworkNNet(filename, normalize=normalize)


def read_tf(filename, inputNames=None, outputNames=None, modelType="frozen", savedModelTags=[]):
"""Constructs a MarabouNetworkTF object from a frozen Tensorflow protobuf
Args:
filename (str): Path to tensorflow network
inputNames (list of str, optional): List of operation names corresponding to inputs
outputNames (list of str, optional): List of operation names corresponding to outputs
modelType (str, optional): Type of model to read. The default is "frozen" for a frozen graph.
Can also use "savedModel_v1" or "savedModel_v2" for the SavedModel format
created from either tensorflow versions 1.X or 2.X respectively.
savedModelTags (list of str, optional): If loading a SavedModel, the user must specify tags used, default is []
Returns:
:class:`~maraboupy.MarabouNetworkTF.MarabouNetworkTF`
"""
return MarabouNetworkTF(filename, inputNames, outputNames, modelType, savedModelTags)

def read_onnx(filename, inputNames=None, outputNames=None):
"""Constructs a MarabouNetworkONNX object from an ONNX file
Expand Down
28 changes: 8 additions & 20 deletions maraboupy/MarabouNetwork.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,29 +132,17 @@ def evaluateLocalRobustness(self, input, epsilon, originalClass, verbose=True, o
- stats (:class:`~maraboupy.MarabouCore.Statistics`): A Statistics object to how Marabou performed
- maxClass (int): Output class which value is max within outputs if SAT.
"""
inputVars = None
if (type(self.inputVars) is list):
if (len(self.inputVars) != 1):
raise NotImplementedError("Operation for %d inputs is not implemented" % len(self.inputVars))
inputVars = self.inputVars[0][0]
elif (type(self.inputVars) is np.ndarray):
inputVars = self.inputVars[0]
else:
err_msg = "Unpexpected type of input vars."
raise RuntimeError(err_msg)
assert(type(self.inputVars) is list)
if (len(self.inputVars) != 1):
raise NotImplementedError("Operation for %d inputs is not implemented" % len(self.inputVars))
inputVars = self.inputVars[0][0]

if inputVars.shape != input.shape:
raise RuntimeError("Input shape of the model should be same as the input shape\n input shape of the model: {0}, shape of the input: {1}".format(inputVars.shape, input.shape))

if (type(self.outputVars) is list):
if (len(self.outputVars) != 1):
raise NotImplementedError("Operation for %d outputs is not implemented" % len(self.outputVars))
elif (type(self.outputVars) is np.ndarray):
if (len(self.outputVars) != 1):
raise NotImplementedError("Operation for %d outputs is not implemented" % len(self.outputVars))
else:
err_msg = "Unpexpected type of output vars."
raise RuntimeError(err_msg)
assert(type(self.outputVars) is list)
if (len(self.outputVars) != 1):
raise NotImplementedError("Operation for %d outputs is not implemented" % len(self.outputVars))

if options == None:
options = MarabouCore.Options()
Expand All @@ -165,7 +153,7 @@ def evaluateLocalRobustness(self, input, epsilon, originalClass, verbose=True, o
for i in range(flattenInput.size):
self.setLowerBound(flattenInputVars[i], flattenInput[i] - epsilon)
self.setUpperBound(flattenInputVars[i], flattenInput[i] + epsilon)

maxClass = None
outputStartIndex = self.outputVars[0][0][0]

Expand Down
Loading

0 comments on commit 85322dd

Please sign in to comment.