Skip to content

Commit

Permalink
Adjusted BaBSR Branching Strategy to cache all Biases and add candidate threshold
Browse files Browse the repository at this point in the history
  • Loading branch information
liamjdavis committed Oct 31, 2024
1 parent fae9fa1 commit 697cbf0
Show file tree
Hide file tree
Showing 5 changed files with 43 additions and 3 deletions.
1 change: 1 addition & 0 deletions src/configuration/GlobalConfiguration.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ const unsigned GlobalConfiguration::REFACTORIZATION_THRESHOLD = 100;
const GlobalConfiguration::BasisFactorizationType GlobalConfiguration::BASIS_FACTORIZATION_TYPE =
GlobalConfiguration::SPARSE_FORREST_TOMLIN_FACTORIZATION;

// Cap on how many candidate constraints the BaBSR branching heuristic collects
// before it stops scoring (used in Engine::pickSplitPLConstraintBasedOnBaBsrHeuristic).
const unsigned GlobalConfiguration::BABSR_CANDIDATES_THRESHOLD = 5;
// Analogous cap for the polarity-based branching heuristic.
const unsigned GlobalConfiguration::POLARITY_CANDIDATES_THRESHOLD = 5;

const unsigned GlobalConfiguration::DNC_DEPTH_THRESHOLD = 5;
Expand Down
5 changes: 5 additions & 0 deletions src/configuration/GlobalConfiguration.h
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,11 @@ class GlobalConfiguration
};
static const BasisFactorizationType BASIS_FACTORIZATION_TYPE;

/* In the BaBSR-based branching heuristic, only this many candidate constraints
   are scored and considered for branching; collection stops once this count is
   reached (value defined in GlobalConfiguration.cpp).
*/
static const unsigned BABSR_CANDIDATES_THRESHOLD;

/* In the polarity-based branching heuristics, only this many earliest nodes
are considered to branch on.
*/
Expand Down
2 changes: 2 additions & 0 deletions src/engine/Engine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2733,6 +2733,8 @@ PiecewiseLinearConstraint *Engine::pickSplitPLConstraintBasedOnBaBsrHeuristic()
// Calculate the BaBSR heuristic score for this constraint; the bias
// computation happens inside computeBaBsr (via the constraint's bias cache).
plConstraint->updateScoreBasedOnBaBsr();
// NOTE(review): constraints with identical scores overwrite one another in
// this map, so the size check below counts distinct scores — confirm that
// is the intended candidate-counting semantics.
scoreToConstraint[plConstraint->getScore()] = plConstraint;
// Stop collecting once we have enough branching candidates.
if ( scoreToConstraint.size() >= GlobalConfiguration::BABSR_CANDIDATES_THRESHOLD )
break;
}
}
}
Expand Down
28 changes: 25 additions & 3 deletions src/engine/ReluConstraint.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1034,6 +1034,27 @@ unsigned ReluConstraint::getAux() const
return _aux;
}

std::unordered_map<const ReluConstraint *, double> ReluConstraint::_biasCache;

void ReluConstraint::initializeBiasCache( NLR::NetworkLevelReasoner &nlr )
{
    /*
      Walk the network's piecewise-linear constraints in topological order
      and record, for every ReLU, the bias of its source neuron. Entries
      already present in the cache are left untouched.
    */
    for ( const auto *candidate : nlr.getConstraintsInTopologicalOrder() )
    {
        // Skip anything that is not a ReluConstraint
        const auto *relu = dynamic_cast<const ReluConstraint *>( candidate );
        if ( !relu )
            continue;

        // Already cached — nothing to do for this constraint
        if ( _biasCache.find( relu ) != _biasCache.end() )
            continue;

        // Compute and store the source-layer bias
        _biasCache[relu] = nlr.getPrevBiasForReluConstraint( relu );
    }
}


double ReluConstraint::computeBaBsr() const
{
double biasTerm = calculateBias();
Expand All @@ -1057,10 +1078,11 @@ double ReluConstraint::computeBaBsr() const

double ReluConstraint::calculateBias() const
{
    /*
      Return the source-layer bias for this ReLU.

      Fast path: the bias cached by initializeBiasCache(). On a cache miss we
      fall back to querying the network-level reasoner directly; if neither
      source is available, throw.

      The original span interleaved pre- and post-change diff lines, leaving a
      throw statement unreachable after an unconditional return — the logic is
      reconstructed here so every path is reachable.
    */
    auto it = _biasCache.find( this );
    if ( it != _biasCache.end() )
        return it->second;

    // Cache miss: recompute through the network-level reasoner when present
    if ( _networkLevelReasoner )
        return _networkLevelReasoner->getPrevBiasForReluConstraint( this );

    throw NLRError( NLRError::RELU_NOT_FOUND, "Bias not found in cache." );
}

double ReluConstraint::computePolarity() const
Expand Down
10 changes: 10 additions & 0 deletions src/engine/ReluConstraint.h
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,14 @@ class ReluConstraint : public PiecewiseLinearConstraint
/*
  Attach (or detach, with nullptr) the network-level reasoner and warm the
  shared bias cache from it.
*/
void setNetworkLevelReasoner( NLR::NetworkLevelReasoner *nlr )
{
    _networkLevelReasoner = nlr;
    // Guard against null: the original unconditionally dereferenced
    // *_networkLevelReasoner, which is undefined behavior when a caller
    // passes nullptr to clear the reasoner.
    if ( _networkLevelReasoner )
        initializeBiasCache( *_networkLevelReasoner );
}

/*
Cache biases for the source layers for ReLU neurons.
*/
static void initializeBiasCache( NLR::NetworkLevelReasoner &nlr );

/*
Restore the state of this constraint from the given one.
*/
Expand Down Expand Up @@ -306,6 +312,10 @@ class ReluConstraint : public PiecewiseLinearConstraint
*/
void addTableauAuxVar( unsigned tableauAuxVar, unsigned constraintAuxVar ) override;

/*
  Cached source-layer biases for ReLU neurons, keyed by constraint and
  shared across all ReluConstraint instances (populated by
  initializeBiasCache()).
*/
static std::unordered_map<const ReluConstraint *, double> _biasCache;
// Look up this constraint's bias (defined in ReluConstraint.cpp).
double calculateBias() const;
};

Expand Down

0 comments on commit 697cbf0

Please sign in to comment.