Skip to content

Commit

Permalink
initial commit
Browse files Browse the repository at this point in the history
  • Loading branch information
tagomaru committed Dec 25, 2023
1 parent a9b1752 commit 9155844
Show file tree
Hide file tree
Showing 3 changed files with 317 additions and 3 deletions.
16 changes: 16 additions & 0 deletions maraboupy/Marabou.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
- Christopher Lazarus
- Kyle Julian
- Andrew Wu
- Teruhiro Tagomori
This file is part of the Marabou project.
Copyright (c) 2017-2019 by the authors listed in the file AUTHORS
Expand All @@ -29,6 +30,7 @@
from maraboupy.MarabouNetworkONNX import *
except ImportError:
warnings.warn("ONNX parser is unavailable because onnx or onnxruntime packages are not installed")
from maraboupy.MarabouNetworkComposition import MarabouNetworkComposition

def read_nnet(filename, normalize=False):
"""Constructs a MarabouNetworkNnet object from a .nnet file
Expand Down Expand Up @@ -74,6 +76,20 @@ def read_onnx(filename, inputNames=None, outputNames=None, reindexOutputVars=Tru
"""
return MarabouNetworkONNX(filename, inputNames, outputNames, reindexOutputVars=reindexOutputVars)

def read_onnx_with_threshould(filename, inputNames=None, outputNames=None, reindexOutputVars=True, threshold=None):
"""Constructs a MarabouNetworkComposition object from an ONNX file
Args:
filename (str): Path to the ONNX file
inputNames (list of str, optional): List of node names corresponding to inputs
outputNames (list of str, optional): List of node names corresponding to outputs
threshold (int, optional): Threshold for the number of linear equations
Returns:
:class:`~maraboupy.MarabouNetworkONNX.MarabouNetworkONNX`
"""
return MarabouNetworkComposition(filename, inputNames, outputNames, reindexOutputVars=reindexOutputVars, threshold=threshold)

def load_query(filename):
"""Load the serialized inputQuery from the given filename
Expand Down
284 changes: 284 additions & 0 deletions maraboupy/MarabouNetworkComposition.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,284 @@
'''
Top contributors (to current version):
- Teruhiro Tagomori
This file is part of the Marabou project.
Copyright (c) 2017-2019 by the authors listed in the file AUTHORS
in the top-level source directory) and their institutional affiliations.
All rights reserved. See the file COPYING in the top-level source
directory for licensing information.
MarabouNetworkComposition represents neural networks with piecewise linear constraints derived from the ONNX format
'''

import numpy as np
import onnx
import sys
import pathlib
import os
sys.path.insert(0, os.path.join(str(pathlib.Path(__file__).parent.absolute()), "../"))

from maraboupy import Marabou
from maraboupy import MarabouCore
from maraboupy import MarabouNetwork
from maraboupy import MarabouNetworkONNX

class MarabouNetworkComposition(MarabouNetwork.MarabouNetwork):
"""Constructs a MarabouNetworkComposition object from an ONNX file
Args:
filename (str): Path to the ONNX file
inputNames: (list of str, optional): List of node names corresponding to inputs
outputNames: (list of str, optional): List of node names corresponding to outputs
reindexOutputVars: (bool, optional): Reindex the variables so that the output variables are immediate after input variables.
threshold: (float, optional): Threshold for equation functions. If the absolute value of the lower bound of a variable is less than this value, the variable is considered to be 0.
Returns:
:class:`~maraboupy.Marabou.MarabouNetworkComposition.MarabouNetworkComposition`
"""
def __init__(self, filename, inputNames=None, outputNames=None, reindexOutputVars=True, threshold=None):
super().__init__()
self.shapeMap = {}
self.madeGraphEquations = []
self.ipqs = []
self.ipqToInVars = {}
self.ipqToOutVars = {}
self.inputVars, self.outputVars = self.getInputOutputVars(filename, inputNames, outputNames, reindexOutputVars=True)

network = MarabouNetworkONNX.MarabouNetworkONNX(filename, reindexOutputVars=reindexOutputVars, threshold=threshold)
print(f'TG: input vars: {network.inputVars}')
print(f'TG: output vars: {network.outputVars}')

network.saveQuery('q1.ipq')
self.ipqs.append('q1.ipq')
self.ipqToInVars['q1.ipq'] = network.inputVars
self.ipqToOutVars['q1.ipq'] = network.outputVars

index = 2
# while if post_split.onnx exists
while os.path.exists('post_split.onnx'):
# delete network
del network
network = MarabouNetworkONNX.MarabouNetworkONNX('post_split.onnx', reindexOutputVars=reindexOutputVars, threshold=threshold)
network.saveQuery(f'q{index}.ipq')
self.ipqs.append(f'q{index}.ipq')
self.ipqToInVars[f'q{index}.ipq'] = network.inputVars
self.ipqToOutVars[f'q{index}.ipq'] = network.outputVars
# self.ipq_to_inVars['q{index}.ipq'] = network.inputVars
index += 1
print(network.inputVars)
print(network.outputVars)
print(self.ipqs)
# print(inputQuery.inputVars)
# MarabouNetworkONNX.readONNX(filename, reindexOutputVars=reindexOutputVars, threshold=threshold)
# network, post_network_file = MarabouNetworkONNX.readONNX(filename, reindexOutputVars=reindexOutputVars, threshold=threshold)

def solve(self):
# https://github.com/wu-haoze/Marabou/blob/1a3ca6010b51bba792ef8ddd5e1ccf9119121bd8/resources/runVerify.py#L200-L225
options = Marabou.createOptions(verbosity = 1)
for i, ipqFile in enumerate(self.ipqs):
# ipq = Marabou.load_query(ipqFile)
# load inputquery
ipq = Marabou.loadQuery(ipqFile)

if i == 0:
self.encodeInput(ipq)

if i > 0:
self.encodeCalculateInputBounds(ipq, i, bounds)
# print(bounds)

if i == len(self.ipqs) - 1:
self.encodeOutput(ipq, i)
# exitCode, vals, _ = MarabouCore.solve(ipq, options)
# print(f'TG: exit code: {exitCode}')
# if exitCode == "sat":
# for j in range(len(self.inputVars)):
# for i in range(self.inputVars[j].size):
# print("input {} = {}".format(i, vals[self.inputVars[j].item(i)]))

# for j in range(len(self.outputVars)): #TG: Original のものを表示してしまっているので修正
# for i in range(self.outputVars[j].size):
# print("output {} = {}".format(i, vals[self.outputVars[j].item(i)]))
ret, bounds, stats = MarabouCore.calculateBounds(ipq, options)
# print(f'TG: bounds: {bounds}')
else:
ret, bounds, stats = MarabouCore.calculateBounds(ipq, options)
print(f'TG: bounds: {bounds}')
print(f'TG: bounds: {bounds}')

def encodeCalculateInputBounds(self, ipq, i, bounds):
print('TG: encodeCalculateInputBounds')
previousOutputVars = self.ipqToOutVars[f'q{i}.ipq']
currentInputVars = self.ipqToInVars[f'q{i+1}.ipq']
print('TG: ', self.ipqToOutVars[f'q{i+1}.ipq'])
print(f'TG: previous output vars: {previousOutputVars}')
print(f'TG: current input vars: {currentInputVars}')

for previousOutputVar, currentInputVar in zip(previousOutputVars, currentInputVars):
for previousOutputVarElement, currentInputVarElement in zip(previousOutputVar.flatten(), currentInputVar.flatten()):
print(f'TG: previous output var element: {previousOutputVarElement}')
print(f'TG: current input var element: {currentInputVarElement}')
print(f'TG: bounds: {bounds[previousOutputVarElement]}')
ipq.setLowerBound(currentInputVarElement, bounds[previousOutputVarElement][0])
ipq.setUpperBound(currentInputVarElement, bounds[previousOutputVarElement][1])

def encodeInput(self, ipq):
inputVars = self.ipqToInVars['q1.ipq']
# print('TG: ', self.ipqToOutVars[f'q1.ipq'])
# print('TG: ', self.ipqToInVars[f'q2.ipq'])
for array in inputVars:
for var in array.flatten():
ipq.setLowerBound(var, self.lowerBounds[var])
ipq.setUpperBound(var, self.upperBounds[var])

def encodeOutput(self, ipq, i):
outputVars = self.ipqToOutVars[f'q{i+1}.ipq']
originalOutputVars = self.outputVars
print(f'TG: original output vars: {originalOutputVars}')
print(f'TG: output vars: {outputVars}')

for originalOutputVar, outputVar in zip(originalOutputVars, outputVars):
for originalOutputVarElement, outputVarElement in zip(originalOutputVar.flatten(), outputVar.flatten()):
if originalOutputVarElement in self.lowerBounds:
ipq.setLowerBound(outputVarElement, self.lowerBounds[originalOutputVarElement])
if originalOutputVarElement in self.upperBounds:
ipq.setUpperBound(outputVarElement, self.upperBounds[originalOutputVarElement])


# def getInputOutVars(self, filename, inputNames, outputNames, reindexOutputVars=True, threshold=None):
def getInputOutputVars(self, filename, inputNames, outputNames, reindexOutputVars=True):
"""Read an ONNX file and create a MarabouNetworkONNX object
Args:
filename: (str): Path to the ONNX file
inputNames: (list of str): List of node names corresponding to inputs
outputNames: (list of str): List of node names corresponding to outputs
reindexOutputVars: (bool): Reindex the variables so that the output variables are immediate after input variables.
:meta private:
"""
self.filename = filename
self.graph = onnx.load(filename).graph

# Get default inputs/outputs if no names are provided
if not inputNames:
assert len(self.graph.input) >= 1
initNames = [node.name for node in self.graph.initializer]
inputNames = [inp.name for inp in self.graph.input if inp.name not in initNames]
if not outputNames:
assert len(self.graph.output) >= 1
initNames = [node.name for node in self.graph.initializer]
outputNames = [out.name for out in self.graph.output if out.name not in initNames]
elif isinstance(outputNames, str):
outputNames = [outputNames]

# Check that input/outputs are in the graph
for name in inputNames:
if not len([nde for nde in self.graph.node if name in nde.input]):
raise RuntimeError("Input %s not found in graph!" % name)
for name in outputNames:
if not len([nde for nde in self.graph.node if name in nde.output]):
raise RuntimeError("Output %s not found in graph!" % name)

self.inputNames = inputNames
self.outputNames = outputNames

# Process the shapes and values of the graph while making Marabou equations and constraints
self.foundnInputFlags = 0

# Add shapes for the graph's inputs
inputVars = []
for node in self.graph.input:
self.shapeMap[node.name] = list([dim.dim_value if dim.dim_value > 0 else 1 for dim in node.type.tensor_type.shape.dim])

# If we find one of the specified inputs, create new variables
if node.name in self.inputNames:
self.madeGraphEquations += [node.name]
self.foundnInputFlags += 1
v = self.makeNewVariables(node.name)
# self.inputVars += [np.array(self.varMap[node.name])]
inputVars += [v]
print(f'TG: get input vars: {inputVars}')

# Add shapes for the graph's outputs
outputVars = []
for node in self.graph.output:
self.shapeMap[node.name] = list([dim.dim_value if dim.dim_value > 0 else 1 for dim in node.type.tensor_type.shape.dim])

# If we find one of the specified inputs, create new variables
if node.name in self.outputNames:
self.madeGraphEquations += [node.name]
self.foundnInputFlags += 1
v = self.makeNewVariables(node.name)
outputVars += [v]
print(f'TG: get ouput vars: {outputVars}')

return inputVars, outputVars

# # Recursively create remaining shapes and equations as needed
# for outputName in self.outputNames: # TG:
# print(f'TG: outputName: {outputName}')
# self.makeGraphEquations(outputName, True)



# If the given inputNames/outputNames specify only a portion of the network, then we will have
# shape information saved not relevant to the portion of the network. Remove extra shapes.
# self.cleanShapes() # TG: needed?

# TG: needed?

# if reindexOutputVars:
# # Other Marabou input parsers assign output variables immediately after input variables and before any
# # intermediate variables. This function reassigns variable numbering to match other parsers.
# # If this is skipped, the output variables will be the last variables defined.
# self.reassignOutputVariables()
# else:
# self.outputVars = [self.varMap[outputName] for outputName in self.outputNames]



def makeNewVariables(self, nodeName):
"""Assuming the node's shape is known, return a set of new variables in the same shape
Args:
nodeName (str): Name of node
Returns:
(numpy array): Array of variable numbers
:meta private:
"""
# assert nodeName not in self.varMap
shape = self.shapeMap[nodeName]
size = np.prod(shape)
v = np.array([self.getNewVariable() for _ in range(size)]).reshape(shape)
# self.varMap[nodeName] = v
assert all([np.equal(np.mod(i, 1), 0) for i in v.reshape(-1)]) # check if integers
return v

def setLowerBound(self, x, v):
"""Function to set lower bound for variable
Args:
x (int): Variable number to set
v (float): Value representing lower bound
"""
if any(x in arr for arr in self.inputVars) or any(x in arr for arr in self.outputVars):
self.lowerBounds[x] = v
else:
raise RuntimeError("Cannot set bounds on input or output variables")


def setUpperBound(self, x, v):
"""Function to set upper bound for variable
Args:
x (int): Variable number to set
v (float): Value representing upper bound
"""
if any(x in arr for arr in self.inputVars) or any(x in arr for arr in self.outputVars):
self.upperBounds[x] = v
else:
raise RuntimeError("Cannot set bounds on input or output variables")
20 changes: 17 additions & 3 deletions maraboupy/MarabouNetworkONNX.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,9 @@ class MarabouNetworkONNX(MarabouNetwork.MarabouNetwork):
Returns:
:class:`~maraboupy.Marabou.marabouNetworkONNX.marabouNetworkONNX`
"""
def __init__(self, filename, inputNames=None, outputNames=None, reindexOutputVars=True):
def __init__(self, filename, inputNames=None, outputNames=None, reindexOutputVars=True, threshold=None):
super().__init__()
self.readONNX(filename, inputNames, outputNames, reindexOutputVars=reindexOutputVars)
self.readONNX(filename, inputNames, outputNames, reindexOutputVars=reindexOutputVars, threshold=threshold)

def clear(self):
"""Reset values to represent empty network
Expand Down Expand Up @@ -67,7 +67,7 @@ def shallowClear(self):
self.outputNames = None
self.graph = None

def readONNX(self, filename, inputNames, outputNames, reindexOutputVars=True):
def readONNX(self, filename, inputNames, outputNames, reindexOutputVars=True, threshold=None):
"""Read an ONNX file and create a MarabouNetworkONNX object
Args:
Expand All @@ -80,6 +80,11 @@ def readONNX(self, filename, inputNames, outputNames, reindexOutputVars=True):
"""
self.filename = filename
self.graph = onnx.load(filename).graph
self.threshold = threshold
self.thresholdReached = False

if os.path.exists('post_split.onnx'):
os.remove('post_split.onnx')

# Get default inputs/outputs if no names are provided
if not inputNames:
Expand Down Expand Up @@ -185,6 +190,8 @@ def processGraph(self):
# Recursively create remaining shapes and equations as needed
for outputName in self.outputNames:
self.makeGraphEquations(outputName, True)
# if self.thresholdReached:
# return

def makeGraphEquations(self, nodeName, makeEquations):
"""Recursively populates self.shapeMap, self.varMap, and self.constantMap while adding equations and constraints
Expand Down Expand Up @@ -217,8 +224,15 @@ def makeGraphEquations(self, nodeName, makeEquations):
raise RuntimeError(err_msg)

# Compute node's shape and create Marabou equations as needed
if self.thresholdReached:
return
self.makeMarabouEquations(nodeName, makeEquations)

if self.threshold is not None:
if not self.thresholdReached and len(self.equList) > self.threshold:
if self.splitNetworkAtNode(nodeName, networkNamePostSplit='post_split.onnx'):
self.thresholdReached = True

# Create new variables when we find one of the inputs
if nodeName in self.inputNames:
self.makeNewVariables(nodeName)
Expand Down

0 comments on commit 9155844

Please sign in to comment.