Skip to content

Commit

Permalink
Adding support for vnnlib files encoding properties + adding support …
Browse files Browse the repository at this point in the history
…for "Sub" operator in ONNX
  • Loading branch information
idan0610 committed Dec 28, 2023
1 parent c2bf020 commit 7066456
Show file tree
Hide file tree
Showing 2 changed files with 181 additions and 5 deletions.
5 changes: 3 additions & 2 deletions maraboupy/Marabou.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,19 +61,20 @@ def read_tf(filename, inputNames=None, outputNames=None, modelType="frozen", sav
"""
return MarabouNetworkTF(filename, inputNames, outputNames, modelType, savedModelTags)

def read_onnx(filename, inputNames=None, outputNames=None, reindexOutputVars=True):
def read_onnx(filename, inputNames=None, outputNames=None, reindexOutputVars=True, vnnlibFilename=None):
"""Constructs a MarabouNetworkONNX object from an ONNX file
Args:
filename (str): Path to the ONNX file
inputNames (list of str, optional): List of node names corresponding to inputs
outputNames (list of str, optional): List of node names corresponding to outputs
reindexOutputVars (bool): Reindex the variables so that the output variables are immediate after input variables
vnnlibFilename (str): Optional argument of filename to vnnlib file containing a property
Returns:
:class:`~maraboupy.MarabouNetworkONNX.MarabouNetworkONNX`
"""
return MarabouNetworkONNX(filename, inputNames, outputNames, reindexOutputVars=reindexOutputVars)
return MarabouNetworkONNX(filename, inputNames, outputNames, reindexOutputVars=reindexOutputVars, vnnlibFilename=vnnlibFilename)

def load_query(filename):
"""Load the serialized inputQuery from the given filename
Expand Down
181 changes: 178 additions & 3 deletions maraboupy/MarabouNetworkONNX.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
MarabouNetworkONNX represents neural networks with piecewise linear constraints derived from the ONNX format
'''
import ast

import numpy as np
import onnx
Expand All @@ -25,6 +26,7 @@
import itertools
import torch
import os
import re

class MarabouNetworkONNX(MarabouNetwork.MarabouNetwork):
"""Constructs a MarabouNetworkONNX object from an ONNX file
Expand All @@ -33,13 +35,14 @@ class MarabouNetworkONNX(MarabouNetwork.MarabouNetwork):
filename (str): Path to the ONNX file
inputNames: (list of str, optional): List of node names corresponding to inputs
outputNames: (list of str, optional): List of node names corresponding to outputs
vnnlibFilename (str): Optional argument of filename to vnnlib file containing a property
Returns:
:class:`~maraboupy.Marabou.marabouNetworkONNX.marabouNetworkONNX`
"""
def __init__(self, filename, inputNames=None, outputNames=None, reindexOutputVars=True):
def __init__(self, filename, inputNames=None, outputNames=None, reindexOutputVars=True, vnnlibFilename=None):
super().__init__()
self.readONNX(filename, inputNames, outputNames, reindexOutputVars=reindexOutputVars)
self.readONNX(filename, inputNames, outputNames, reindexOutputVars=reindexOutputVars, vnnlibFilename=vnnlibFilename)

def clear(self):
"""Reset values to represent empty network
Expand All @@ -49,6 +52,7 @@ def clear(self):
self.varMap = dict()
self.constantMap = dict()
self.shapeMap = dict()
self.vnnlibMap = dict()
self.inputNames = None
self.outputNames = None
self.graph = None
Expand All @@ -67,14 +71,15 @@ def shallowClear(self):
self.outputNames = None
self.graph = None

def readONNX(self, filename, inputNames, outputNames, reindexOutputVars=True):
def readONNX(self, filename, inputNames, outputNames, reindexOutputVars=True, vnnlibFilename=None):
"""Read an ONNX file and create a MarabouNetworkONNX object
Args:
filename: (str): Path to the ONNX file
inputNames: (list of str): List of node names corresponding to inputs
outputNames: (list of str): List of node names corresponding to outputs
reindexOutputVars: (bool): Reindex the variables so that the output variables are immediate after input variables.
vnnlibFilename (str): Optional argument of filename to vnnlib file containing a property
:meta private:
"""
Expand Down Expand Up @@ -120,6 +125,9 @@ def readONNX(self, filename, inputNames, outputNames, reindexOutputVars=True):
else:
self.outputVars = [self.varMap[outputName] for outputName in self.outputNames]

if vnnlibFilename:
self.loadPropertyWithVnnlib(vnnlibFilename)

def splitNetworkAtNode(self, nodeName, networkNamePreSplit=None,
networkNamePostSplit=None):
"""
Expand Down Expand Up @@ -272,6 +280,8 @@ def makeMarabouEquations(self, nodeName, makeEquations):
self.resizeEquations(node, makeEquations)
elif node.op_type == 'Tanh':
self.tanhEquations(node, makeEquations)
elif node.op_type == 'Sub':
self.subEquations(node, makeEquations)
else:
raise NotImplementedError("Operation {} not implemented".format(node.op_type))

Expand Down Expand Up @@ -1114,6 +1124,40 @@ def reluEquations(self, node, makeEquations):
for f in outputVars:
self.setLowerBound(f, 0.0)

def subEquations(self, node, makeEquations):
"""Function to generate equations corresponding to subtraction
Args:
node (node): ONNX node representing the Sub operation
makeEquations (bool): True if we need to create new variables and add new Relus
:meta private:
"""
nodeName = node.output[0]
inputName1, inputName2 = node.input[0], node.input[1]
assert inputName1 in self.shapeMap and inputName2 in self.shapeMap
assert self.shapeMap[inputName1] == self.shapeMap[inputName2]
self.shapeMap[nodeName] = self.shapeMap[inputName1]

if not makeEquations:
return

assert inputName1 in self.varMap and inputName2 in self.constantMap

# Get variables
inputVars = self.varMap[inputName1].reshape(-1)
outputVars = self.makeNewVariables(nodeName).reshape(-1)
constants = self.constantMap[inputName2].reshape(-1)
assert len(inputVars) == len(outputVars) == len(constants)

# Generate equations
for i in range(len(inputVars)):
e = MarabouUtils.Equation()
e.addAddend(1, inputVars[i])
e.addAddend(-1, outputVars[i])
e.setScalar(-constants[i])
self.addEquation(e)

def sigmoidEquations(self, node, makeEquations):
"""Function to generate equations corresponding to Sigmoid
Expand Down Expand Up @@ -1310,6 +1354,108 @@ def evaluateWithoutMarabou(self, inputValues):
input_dict[inputName] = inputValues[i].reshape(self.inputVars[i].shape).astype(inputType)
return sess.run(self.outputNames, input_dict)

def loadPropertyWithVnnlib(self, vnnlibFilename):
"""Loads the property from the given vnnlib file
Args:
vnnlibFilename (str): Filename for the vnnlib file
Returns:
None
"""
input_vars = self.inputVars[0].reshape(-1)
output_vars = self.outputVars[0][0].reshape(-1)

input_var_idx = input_vars[0]
output_var_idx = output_vars[0]

with open(vnnlibFilename, 'r') as f:
lines = f.readlines()

vnnlib_content = ""
for i in range(len(lines)):
if lines[i] == "" or lines[i].startswith(";"):
continue

vnnlib_content += lines[i].strip()

vnnlib_content = "(" + vnnlib_content + ")"
vnnlib_items = make_tree(vnnlib_content)

for statement in vnnlib_items:
statement_type = statement[0]

if statement_type == "declare-const":
var_name = statement[1]
if var_name[0] == 'X':
assert input_var_idx <= input_vars[-1]
self.vnnlibMap[statement[1]] = input_var_idx
input_var_idx += 1
elif var_name[0] == 'Y':
assert output_var_idx <= output_vars[-1]
self.vnnlibMap[statement[1]] = output_var_idx
output_var_idx += 1
else:
raise RuntimeError("All variable name should should begin with 'X_' for input variables, "
"and 'Y_' for output variables in variable declarations")

elif statement_type == "assert":
operator = statement[1][0]

if operator == "<=":
var_name = statement[1][1]
second_argument = statement[1][2]

if second_argument in self.vnnlibMap:
self.addInequality([self.vnnlibMap[var_name], self.vnnlibMap[second_argument]], [1, -1], 0)
else:
self.setUpperBound(self.vnnlibMap[var_name], float(second_argument))
elif operator == ">=":
var_name = statement[1][1]
second_argument = statement[1][2]

if second_argument in self.vnnlibMap:
self.addInequality([self.vnnlibMap[var_name], self.vnnlibMap[second_argument]], [-1, 1], 0)
else:
self.setLowerBound(self.vnnlibMap[var_name], float(second_argument))

elif operator == "or":
disjuncts = []
for disjunt_terms in statement[1][1:]:
assert disjunt_terms[0] == "and"

conjuncts = []

for conjunct_terms in disjunt_terms[1:]:
operator = conjunct_terms[0]
var_name = conjunct_terms[1]
second_argument = conjunct_terms[2]

e = MarabouUtils.Equation()
if operator == "<=":
e.addAddend(1, self.vnnlibMap[var_name])
if second_argument in self.vnnlibMap:
e.addAddend(-1, self.vnnlibMap[second_argument])
e.setScalar(0)
else:
e.setScalar(float(second_argument))
elif operator == ">=":
e.addAddend(-1, self.vnnlibMap[var_name])
if second_argument in self.vnnlibMap:
e.addAddend(1, self.vnnlibMap[second_argument])
e.setScalar(0)
else:
e.setScalar(-float(second_argument))
else:
raise NotImplementedError("'or' operator specified in vnnlib file supports only disjuncts "
"with one of the following operators: '>=', '<='")

conjuncts.append(e)

disjuncts.append(conjuncts)

self.addDisjunctionConstraint(disjuncts)

def getBroadcastShape(shape1, shape2):
"""Helper function to get the shape that results from broadcasting these shapes together
Expand All @@ -1323,3 +1469,32 @@ def getBroadcastShape(shape1, shape2):
:meta private:
"""
return [l1 if l1 == l2 else max(l1, l2) for l1, l2 in itertools.zip_longest(shape1[::-1], shape2[::-1], fillvalue=1)][::-1]


def make_tree(content):
"""Helper function to get the statements of given vnnlib content file split into lists
Args:
content (str): Content of vnnlib file (filtered after removing comments)
Returns:
(nested lists of str): list of statements in vnnlib content, possibly with more nested lists
:meta private:
"""
items = re.findall(r"\(|\)|[\w-]+|<=|>=", content)

def req(index):
result = []
item = items[index]
while item != ")":
if item == "(":
subtree, index = req(index + 1)
result.append(subtree)
else:
result.append(item)
index += 1
item = items[index]
return result, index

return req(1)[0]

0 comments on commit 7066456

Please sign in to comment.