diff --git a/src/configuration/GlobalConfiguration.cpp b/src/configuration/GlobalConfiguration.cpp index c70ce013e..b84f9e4a5 100644 --- a/src/configuration/GlobalConfiguration.cpp +++ b/src/configuration/GlobalConfiguration.cpp @@ -106,6 +106,7 @@ const unsigned GlobalConfiguration::REFACTORIZATION_THRESHOLD = 100; const GlobalConfiguration::BasisFactorizationType GlobalConfiguration::BASIS_FACTORIZATION_TYPE = GlobalConfiguration::SPARSE_FORREST_TOMLIN_FACTORIZATION; +const unsigned GlobalConfiguration::BABSR_CANDIDATES_THRESHOLD = 5; const unsigned GlobalConfiguration::POLARITY_CANDIDATES_THRESHOLD = 5; const unsigned GlobalConfiguration::DNC_DEPTH_THRESHOLD = 5; diff --git a/src/configuration/GlobalConfiguration.h b/src/configuration/GlobalConfiguration.h index c89c1d54b..dfa5814bc 100644 --- a/src/configuration/GlobalConfiguration.h +++ b/src/configuration/GlobalConfiguration.h @@ -228,6 +228,11 @@ class GlobalConfiguration }; static const BasisFactorizationType BASIS_FACTORIZATION_TYPE; + /* In the BaBSR-based branching heuristics, only this many earliest nodes are considered to + branch on. + */ + static const unsigned BABSR_CANDIDATES_THRESHOLD; + /* In the polarity-based branching heuristics, only this many earliest nodes are considered to branch on. */ diff --git a/src/engine/Engine.cpp b/src/engine/Engine.cpp index bffc09f78..51a4021aa 100644 --- a/src/engine/Engine.cpp +++ b/src/engine/Engine.cpp @@ -2733,6 +2733,8 @@ PiecewiseLinearConstraint *Engine::pickSplitPLConstraintBasedOnBaBsrHeuristic() // calculate heuristic score - bias calculation now happens inside computeBaBsr plConstraint->updateScoreBasedOnBaBsr(); scoreToConstraint[plConstraint->getScore()] = plConstraint; + if ( scoreToConstraint.size() >= GlobalConfiguration::BABSR_CANDIDATES_THRESHOLD ) + break; } } } diff --git a/src/engine/ReluConstraint.cpp b/src/engine/ReluConstraint.cpp index 512e6dbbd..8438599b6 100644 --- a/src/engine/ReluConstraint.cpp +++ b/src/engine/ReluConstraint.cpp @@ -1034,6 +1034,27 @@ unsigned ReluConstraint::getAux() const return _aux; } +std::unordered_map ReluConstraint::_biasCache; + +void ReluConstraint::initializeBiasCache( NLR::NetworkLevelReasoner &nlr ) +{ + // Loop through all constraints in topological order + for ( const auto *constraint : nlr.getConstraintsInTopologicalOrder() ) + { + // Only handle ReluConstraints and cache their biases + if ( const auto *reluConstraint = dynamic_cast( constraint ) ) + { + if ( _biasCache.find( reluConstraint ) == _biasCache.end() ) + { + // Compute the bias and store it in the cache + double bias = nlr.getPrevBiasForReluConstraint( reluConstraint ); + _biasCache[reluConstraint] = bias; + } + } + } +} + + double ReluConstraint::computeBaBsr() const { double biasTerm = calculateBias(); @@ -1057,10 +1078,11 @@ double ReluConstraint::computeBaBsr() const double ReluConstraint::calculateBias() const { - if ( !_networkLevelReasoner ) - throw NLRError( NLRError::RELU_NOT_FOUND ); + auto it = _biasCache.find( this ); + if ( it != _biasCache.end() ) + return it->second; - return _networkLevelReasoner->getPrevBiasForReluConstraint( this ); + throw NLRError( NLRError::RELU_NOT_FOUND, "Bias not found in cache." ); } double ReluConstraint::computePolarity() const diff --git a/src/engine/ReluConstraint.h b/src/engine/ReluConstraint.h index 367ae63eb..a1ceb7a65 100644 --- a/src/engine/ReluConstraint.h +++ b/src/engine/ReluConstraint.h @@ -73,8 +73,14 @@ class ReluConstraint : public PiecewiseLinearConstraint void setNetworkLevelReasoner( NLR::NetworkLevelReasoner *nlr ) { _networkLevelReasoner = nlr; + initializeBiasCache( *_networkLevelReasoner ); } + /* + Cache biases for the source layers for ReLU neurons. + */ + static void initializeBiasCache( NLR::NetworkLevelReasoner &nlr ); + /* Restore the state of this constraint from the given one. */ @@ -306,6 +312,10 @@ class ReluConstraint : public PiecewiseLinearConstraint */ void addTableauAuxVar( unsigned tableauAuxVar, unsigned constraintAuxVar ) override; + /* + cached biases for the source layers for ReLU neurons. + */ + static std::unordered_map _biasCache; double calculateBias() const; };