Skip to content

Commit

Permalink
Fixed tests
Browse files Browse the repository at this point in the history
  • Loading branch information
MatthewDaggitt committed Feb 17, 2024
1 parent c64e6d8 commit 9a18b58
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 9 deletions.
5 changes: 2 additions & 3 deletions maraboupy/Marabou.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,19 +61,18 @@ def read_tf(filename, inputNames=None, outputNames=None, modelType="frozen", sav
"""
return MarabouNetworkTF(filename, inputNames, outputNames, modelType, savedModelTags)

def read_onnx(filename, inputNames=None, outputNames=None, reindexOutputVars=False):
def read_onnx(filename, inputNames=None, outputNames=None):
"""Constructs a MarabouNetworkONNX object from an ONNX file
Args:
filename (str): Path to the ONNX file
inputNames (list of str, optional): List of node names corresponding to inputs
outputNames (list of str, optional): List of node names corresponding to outputs
reindexOutputVars (bool): Reindex the variables so that the output variables are immediate after input variables
Returns:
:class:`~maraboupy.MarabouNetworkONNX.MarabouNetworkONNX`
"""
return MarabouNetworkONNX(filename, inputNames, outputNames, reindexOutputVars=reindexOutputVars)
return MarabouNetworkONNX(filename, inputNames, outputNames)

def load_query(filename):
"""Load the serialized inputQuery from the given filename
Expand Down
12 changes: 6 additions & 6 deletions maraboupy/test/test_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,11 +207,11 @@ def test_multiOutput():
def test_preserve_existing_constraints_clear():
filename = "tanh_test.onnx"
filename = os.path.join(os.path.dirname(__file__), NETWORK_FOLDER, filename)
network = Marabou.read_onnx(filename, reindexOutputVars=False)
network = Marabou.read_onnx(filename)
numVar1 = network.numVars
numEq1 = len(network.equList)
numSigmoid1 = len(network.sigmoidList)
network.readONNX(filename, None, None, reindexOutputVars=False, preserveExistingConstraints=True)
network.readONNX(filename, None, None, preserveExistingConstraints=True)
numVar2 = network.numVars
numEq2 = len(network.equList)
numSigmoid2 = len(network.sigmoidList)
Expand Down Expand Up @@ -308,24 +308,24 @@ def test_errors_do_not_reindex():

# Test that we catch if inputNames or outputNames are not in the model
with pytest.raises(RuntimeError, match=r"Input.*not found"):
Marabou.read_onnx(filename, inputNames = ['BAD_NAME'], outputNames = ['Y'], reindexOutputVars=True)
Marabou.read_onnx(filename, inputNames = ['BAD_NAME'], outputNames = ['Y'])
with pytest.raises(RuntimeError, match=r"Output.*not found"):
Marabou.read_onnx(filename, outputNames = ['BAD_NAME'])

# The layer "12" is in the graph, but refers to a constant, so it should not be used
# as the network input or output.
with pytest.raises(RuntimeError, match=r"input variables could not be found"):
Marabou.read_onnx(filename, inputNames = ['12'], outputNames = ['Y'], reindexOutputVars=True)
Marabou.read_onnx(filename, inputNames = ['12'], outputNames = ['Y'])
with pytest.raises(RuntimeError, match=r"Output variable.*is a constant"):
Marabou.read_onnx(filename, outputNames = ['12'])

# Evaluating with ONNX instead of Marabou gives errors when using a layer that is not already
# defined as part of the model inputs or outputs.
with pytest.raises(NotImplementedError, match=r"ONNX does not allow.*as inputs"):
network = Marabou.read_onnx(filename, inputNames = ['11'], outputNames = ['Y'], reindexOutputVars=True)
network = Marabou.read_onnx(filename, inputNames = ['11'], outputNames = ['Y'])
network.evaluateWithoutMarabou([])
with pytest.raises(NotImplementedError, match=r"ONNX does not allow.*the output"):
network = Marabou.read_onnx(filename, inputNames = ['X'], outputNames = ['11'], reindexOutputVars=True)
network = Marabou.read_onnx(filename, inputNames = ['X'], outputNames = ['11'])
network.evaluateWithoutMarabou([])

def evaluateFile(filename, inputNames = None, outputNames = None, testInputs = None, numPoints = NUM_RAND):
Expand Down

0 comments on commit 9a18b58

Please sign in to comment.