Skip to content

Commit

Permalink
refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
tagomaru committed Dec 27, 2023
1 parent 9155844 commit db5ff08
Showing 1 changed file with 9 additions and 67 deletions.
76 changes: 9 additions & 67 deletions maraboupy/MarabouNetworkComposition.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,19 +43,17 @@ def __init__(self, filename, inputNames=None, outputNames=None, reindexOutputVar
self.ipqs = []
self.ipqToInVars = {}
self.ipqToOutVars = {}
self.inputVars, self.outputVars = self.getInputOutputVars(filename, inputNames, outputNames, reindexOutputVars=True)
self.inputVars, self.outputVars = self.getInputOutputVars(filename, inputNames, outputNames)

network = MarabouNetworkONNX.MarabouNetworkONNX(filename, reindexOutputVars=reindexOutputVars, threshold=threshold)
print(f'TG: input vars: {network.inputVars}')
print(f'TG: output vars: {network.outputVars}')

network.saveQuery('q1.ipq')
self.ipqs.append('q1.ipq')
self.ipqToInVars['q1.ipq'] = network.inputVars
self.ipqToOutVars['q1.ipq'] = network.outputVars

index = 2
# while if post_split.onnx exists

while os.path.exists('post_split.onnx'):
# delete network
del network
Expand All @@ -64,29 +62,20 @@ def __init__(self, filename, inputNames=None, outputNames=None, reindexOutputVar
self.ipqs.append(f'q{index}.ipq')
self.ipqToInVars[f'q{index}.ipq'] = network.inputVars
self.ipqToOutVars[f'q{index}.ipq'] = network.outputVars
# self.ipq_to_inVars['q{index}.ipq'] = network.inputVars
index += 1
print(network.inputVars)
print(network.outputVars)
print(self.ipqs)
# print(inputQuery.inputVars)
# MarabouNetworkONNX.readONNX(filename, reindexOutputVars=reindexOutputVars, threshold=threshold)
# network, post_network_file = MarabouNetworkONNX.readONNX(filename, reindexOutputVars=reindexOutputVars, threshold=threshold)

def solve(self):
# https://github.com/wu-haoze/Marabou/blob/1a3ca6010b51bba792ef8ddd5e1ccf9119121bd8/resources/runVerify.py#L200-L225
options = Marabou.createOptions(verbosity = 1)
options = Marabou.createOptions(verbosity = 1) # TG: Option を引数でとる
for i, ipqFile in enumerate(self.ipqs):
# ipq = Marabou.load_query(ipqFile)
# load inputquery
# load input query
ipq = Marabou.loadQuery(ipqFile)

if i == 0:
self.encodeInput(ipq)

if i > 0:
self.encodeCalculateInputBounds(ipq, i, bounds)
# print(bounds)

if i == len(self.ipqs) - 1:
self.encodeOutput(ipq, i)
Expand All @@ -101,32 +90,20 @@ def solve(self):
# for i in range(self.outputVars[j].size):
# print("output {} = {}".format(i, vals[self.outputVars[j].item(i)]))
ret, bounds, stats = MarabouCore.calculateBounds(ipq, options)
# print(f'TG: bounds: {bounds}')
else:
ret, bounds, stats = MarabouCore.calculateBounds(ipq, options)
print(f'TG: bounds: {bounds}')
print(f'TG: bounds: {bounds}')

def encodeCalculateInputBounds(self, ipq, i, bounds):
print('TG: encodeCalculateInputBounds')
previousOutputVars = self.ipqToOutVars[f'q{i}.ipq']
currentInputVars = self.ipqToInVars[f'q{i+1}.ipq']
print('TG: ', self.ipqToOutVars[f'q{i+1}.ipq'])
print(f'TG: previous output vars: {previousOutputVars}')
print(f'TG: current input vars: {currentInputVars}')
currentInputVars = self.ipqToInVars[f'q{i+1}.ipq']

for previousOutputVar, currentInputVar in zip(previousOutputVars, currentInputVars):
for previousOutputVarElement, currentInputVarElement in zip(previousOutputVar.flatten(), currentInputVar.flatten()):
print(f'TG: previous output var element: {previousOutputVarElement}')
print(f'TG: current input var element: {currentInputVarElement}')
print(f'TG: bounds: {bounds[previousOutputVarElement]}')
ipq.setLowerBound(currentInputVarElement, bounds[previousOutputVarElement][0])
ipq.setUpperBound(currentInputVarElement, bounds[previousOutputVarElement][1])

def encodeInput(self, ipq):
inputVars = self.ipqToInVars['q1.ipq']
# print('TG: ', self.ipqToOutVars[f'q1.ipq'])
# print('TG: ', self.ipqToInVars[f'q2.ipq'])
for array in inputVars:
for var in array.flatten():
ipq.setLowerBound(var, self.lowerBounds[var])
Expand All @@ -135,8 +112,6 @@ def encodeInput(self, ipq):
def encodeOutput(self, ipq, i):
outputVars = self.ipqToOutVars[f'q{i+1}.ipq']
originalOutputVars = self.outputVars
print(f'TG: original output vars: {originalOutputVars}')
print(f'TG: output vars: {outputVars}')

for originalOutputVar, outputVar in zip(originalOutputVars, outputVars):
for originalOutputVarElement, outputVarElement in zip(originalOutputVar.flatten(), outputVar.flatten()):
Expand All @@ -145,16 +120,13 @@ def encodeOutput(self, ipq, i):
if originalOutputVarElement in self.upperBounds:
ipq.setUpperBound(outputVarElement, self.upperBounds[originalOutputVarElement])


# def getInputOutVars(self, filename, inputNames, outputNames, reindexOutputVars=True, threshold=None):
def getInputOutputVars(self, filename, inputNames, outputNames, reindexOutputVars=True):
"""Read an ONNX file and create a MarabouNetworkONNX object
def getInputOutputVars(self, filename, inputNames, outputNames):
"""Get input and output variables of an original network
Args:
filename: (str): Path to the ONNX file
inputNames: (list of str): List of node names corresponding to inputs
outputNames: (list of str): List of node names corresponding to outputs
reindexOutputVars: (bool): Reindex the variables so that the output variables are immediate after input variables.
:meta private:
"""
Expand Down Expand Up @@ -197,9 +169,7 @@ def getInputOutputVars(self, filename, inputNames, outputNames, reindexOutputVar
self.madeGraphEquations += [node.name]
self.foundnInputFlags += 1
v = self.makeNewVariables(node.name)
# self.inputVars += [np.array(self.varMap[node.name])]
inputVars += [v]
print(f'TG: get input vars: {inputVars}')

# Add shapes for the graph's outputs
outputVars = []
Expand All @@ -212,33 +182,8 @@ def getInputOutputVars(self, filename, inputNames, outputNames, reindexOutputVar
self.foundnInputFlags += 1
v = self.makeNewVariables(node.name)
outputVars += [v]
print(f'TG: get ouput vars: {outputVars}')

return inputVars, outputVars

# # Recursively create remaining shapes and equations as needed
# for outputName in self.outputNames: # TG:
# print(f'TG: outputName: {outputName}')
# self.makeGraphEquations(outputName, True)



# If the given inputNames/outputNames specify only a portion of the network, then we will have
# shape information saved not relevant to the portion of the network. Remove extra shapes.
# self.cleanShapes() # TG: needed?

# TG: needed?

# if reindexOutputVars:
# # Other Marabou input parsers assign output variables immediately after input variables and before any
# # intermediate variables. This function reassigns variable numbering to match other parsers.
# # If this is skipped, the output variables will be the last variables defined.
# self.reassignOutputVariables()
# else:
# self.outputVars = [self.varMap[outputName] for outputName in self.outputNames]



def makeNewVariables(self, nodeName):
"""Assuming the node's shape is known, return a set of new variables in the same shape
Expand All @@ -250,11 +195,9 @@ def makeNewVariables(self, nodeName):
:meta private:
"""
# assert nodeName not in self.varMap
shape = self.shapeMap[nodeName]
size = np.prod(shape)
v = np.array([self.getNewVariable() for _ in range(size)]).reshape(shape)
# self.varMap[nodeName] = v
assert all([np.equal(np.mod(i, 1), 0) for i in v.reshape(-1)]) # check if integers
return v

Expand All @@ -268,9 +211,8 @@ def setLowerBound(self, x, v):
if any(x in arr for arr in self.inputVars) or any(x in arr for arr in self.outputVars):
self.lowerBounds[x] = v
else:
raise RuntimeError("Cannot set bounds on input or output variables")
raise RuntimeError("Can set bounds only on either input or output variables")


def setUpperBound(self, x, v):
"""Function to set upper bound for variable
Expand All @@ -281,4 +223,4 @@ def setUpperBound(self, x, v):
if any(x in arr for arr in self.inputVars) or any(x in arr for arr in self.outputVars):
self.upperBounds[x] = v
else:
raise RuntimeError("Cannot set bounds on input or output variables")
raise RuntimeError("Can set bounds only on either input or output variables")

0 comments on commit db5ff08

Please sign in to comment.