Skip to content

Commit

Permalink
Refactor C++ ONNXParser to pure style that takes in an InputQueryBuil…
Browse files Browse the repository at this point in the history
…der (#757)
  • Loading branch information
MatthewDaggitt authored Feb 17, 2024
1 parent c7f150d commit c27a7c3
Show file tree
Hide file tree
Showing 8 changed files with 275 additions and 323 deletions.
5 changes: 3 additions & 2 deletions src/engine/DnCMarabou.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,9 @@ void DnCMarabou::run()

if ( ( (String)networkFilePath ).endsWith( ".onnx" ) )
{
OnnxParser *_onnxParser = new OnnxParser( networkFilePath );
_onnxParser->generateQuery( _inputQuery );
InputQueryBuilder queryBuilder;
OnnxParser::parse( queryBuilder, networkFilePath, {}, {} );
queryBuilder.generateQuery( _inputQuery );
}
else
{
Expand Down
5 changes: 3 additions & 2 deletions src/engine/Marabou.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -111,8 +111,9 @@ void Marabou::prepareInputQuery()

if ( ( (String)networkFilePath ).endsWith( ".onnx" ) )
{
_onnxParser = new OnnxParser( networkFilePath );
_onnxParser->generateQuery( _inputQuery );
InputQueryBuilder queryBuilder;
OnnxParser::parse( queryBuilder, networkFilePath, {}, {} );
queryBuilder.generateQuery( _inputQuery );
}
else
{
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/********************* */
/*! \file NetworkParser.cpp
/*! \file InputQueryBuilder.cpp
** \verbatim
** Top contributors (to current version):
** Matthew Daggitt, Luca Arnaboldi
Expand All @@ -16,7 +16,7 @@
** Future parsers for individual network formats should extend this interface.
**/

#include "NetworkParser.h"
#include "InputQueryBuilder.h"

#include "Debug.h"
#include "FloatUtils.h"
Expand All @@ -30,51 +30,61 @@

#include <assert.h>

NetworkParser::NetworkParser()
InputQueryBuilder::InputQueryBuilder()
{
_numVars = 0;
}

Variable NetworkParser::getNewVariable()
Variable InputQueryBuilder::getNewVariable()
{
_numVars += 1;
return _numVars - 1;
}

void NetworkParser::addEquation( Equation &eq )
void InputQueryBuilder::markInputVariable( Variable var )
{
_inputVars.append( var );
}

void InputQueryBuilder::markOutputVariable( Variable var )
{
_outputVars.append( var );
}

void InputQueryBuilder::addEquation( Equation &eq )
{
_equationList.append( eq );
}

void NetworkParser::setLowerBound( Variable var, float value )
void InputQueryBuilder::setLowerBound( Variable var, float value )
{
_lowerBounds[var] = value;
}

void NetworkParser::setUpperBound( Variable var, float value )
void InputQueryBuilder::setUpperBound( Variable var, float value )
{
_upperBounds[var] = value;
}

void NetworkParser::addRelu( Variable inputVar, Variable outputVar )
void InputQueryBuilder::addRelu( Variable inputVar, Variable outputVar )
{
_reluList.append( new ReluConstraint( inputVar, outputVar ) );
setLowerBound( outputVar, 0.0f );
}

void NetworkParser::addLeakyRelu( Variable inputVar, Variable outputVar, float alpha )
void InputQueryBuilder::addLeakyRelu( Variable inputVar, Variable outputVar, float alpha )
{
_leakyReluList.append( new LeakyReluConstraint( inputVar, outputVar, alpha ) );
}

void NetworkParser::addSigmoid( Variable inputVar, Variable outputVar )
void InputQueryBuilder::addSigmoid( Variable inputVar, Variable outputVar )
{
_sigmoidList.append( new SigmoidConstraint( inputVar, outputVar ) );
setLowerBound( outputVar, 0.0 );
setUpperBound( outputVar, 1.0 );
}

void NetworkParser::addTanh( Variable inputVar, Variable outputVar )
void InputQueryBuilder::addTanh( Variable inputVar, Variable outputVar )
{
// Uses the identity `tanh(x) = 2 * sigmoid(2x) - 1` to implement
// it terms of a sigmoid constraint.
Expand All @@ -98,22 +108,22 @@ void NetworkParser::addTanh( Variable inputVar, Variable outputVar )
setUpperBound( outputVar, 1.0 );
}

void NetworkParser::addMaxConstraint( Variable var, Set<Variable> elements )
void InputQueryBuilder::addMaxConstraint( Variable var, Set<Variable> elements )
{
_maxList.append( new MaxConstraint( var, elements ) );
}

void NetworkParser::addSignConstraint( Variable inputVar, Variable outputVar )
void InputQueryBuilder::addSignConstraint( Variable inputVar, Variable outputVar )
{
_signList.append( new SignConstraint( inputVar, outputVar ) );
}

void NetworkParser::addAbsConstraint( Variable inputVar, Variable outputVar )
void InputQueryBuilder::addAbsConstraint( Variable inputVar, Variable outputVar )
{
_absList.append( new AbsoluteValueConstraint( inputVar, outputVar ) );
}

void NetworkParser::getMarabouQuery( InputQuery &query )
void InputQueryBuilder::generateQuery( InputQuery &query )
{
query.setNumberOfVariables( _numVars );

Expand Down Expand Up @@ -208,18 +218,16 @@ void NetworkParser::getMarabouQuery( InputQuery &query )
}
}

int NetworkParser::findEquationWithOutputVariable( Variable variable )
Equation *InputQueryBuilder::findEquationWithOutputVariable( Variable variable )
{
int i = 0;
for ( Equation &equation : _equationList )
{
Equation::Addend outputAddend = equation._addends.back();
if ( variable == outputAddend._variable )
{
ASSERT( outputAddend._coefficient == -1 );
return i;
return &equation;
}
i++;
}
return -1;
return NULL;
}
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/********************* */
/*! \file NetworkParser.cpp
/*! \file InputQueryBuilder.cpp
** \verbatim
** Top contributors (to current version):
** Matthew Daggitt
Expand All @@ -16,8 +16,8 @@
** Future parsers for individual network formats should extend this interface.
**/

#ifndef __NetworkParser_h__
#define __NetworkParser_h__
#ifndef __InputQueryBuilder_h__
#define __InputQueryBuilder_h__

#include "DisjunctionConstraint.h"
#include "Equation.h"
Expand All @@ -37,12 +37,10 @@

typedef unsigned int Variable;

class NetworkParser
class InputQueryBuilder
{
private:
unsigned int _numVars;

protected:
List<Variable> _inputVars;
List<Variable> _outputVars;

Expand All @@ -56,8 +54,13 @@ class NetworkParser
Map<Variable, float> _lowerBounds;
Map<Variable, float> _upperBounds;

NetworkParser();
public:
InputQueryBuilder();

Variable getNewVariable();

void markInputVariable( Variable var );
void markOutputVariable( Variable var );
void addEquation( Equation &eq );
void setLowerBound( Variable var, float value );
void setUpperBound( Variable var, float value );
Expand All @@ -69,10 +72,9 @@ class NetworkParser
void addMaxConstraint( Variable maxVar, Set<Variable> elements );
void addAbsConstraint( Variable var1, Variable var2 );

Variable getNewVariable();
void getMarabouQuery( InputQuery &query );
void generateQuery( InputQuery &query );

int findEquationWithOutputVariable( Variable variable );
Equation *findEquationWithOutputVariable( Variable variable );
};

#endif // __NetworkParser_h__
#endif // __InputQueryBuilder_h__
Loading

0 comments on commit c27a7c3

Please sign in to comment.