diff --git a/pdns/dnsdistdist/dnsdist-async.cc b/pdns/dnsdistdist/dnsdist-async.cc index 9cb96d83a226..0424abe8a0d5 100644 --- a/pdns/dnsdistdist/dnsdist-async.cc +++ b/pdns/dnsdistdist/dnsdist-async.cc @@ -222,8 +222,7 @@ static bool resumeResponse(std::unique_ptr&& response) auto& ids = response->query.d_idstate; DNSResponse dnsResponse = response->getDR(); - LocalHolders holders; - auto result = processResponseAfterRules(response->query.d_buffer, *holders.cacheInsertedRespRuleActions, dnsResponse, ids.cs->muted); + auto result = processResponseAfterRules(response->query.d_buffer, dnsResponse, ids.cs->muted); if (!result) { /* easy */ return true; diff --git a/pdns/dnsdistdist/dnsdist-configuration.hh b/pdns/dnsdistdist/dnsdist-configuration.hh index 9ee53a8483d6..91e84a29c4eb 100644 --- a/pdns/dnsdistdist/dnsdist-configuration.hh +++ b/pdns/dnsdistdist/dnsdist-configuration.hh @@ -26,6 +26,7 @@ #include #include "dnsdist-query-count.hh" +#include "dnsdist-rule-chains.hh" #include "iputils.hh" /* so what could you do: @@ -184,6 +185,16 @@ struct Configuration a RCU-like mechanism */ struct RuntimeConfiguration { + // ca tient pas la route: meilleure option: stocker un type plus opaque dans la configuration (dnsdist::rules::RuleChains) et + // laisser le soin a dnsdist::rules de le gerer + /* std::vector d_cacheMissRuleActions; + std::vector d_respruleactions; + std::vector d_cachehitrespruleactions; + std::vector d_selfansweredrespruleactions; + std::vector d_cacheInsertedRespRuleActions; + std::vector d_XFRRespRuleActions; + */ + rules::RuleChains d_ruleChains; servers_t d_backends; std::map> d_pools; std::shared_ptr d_lbPolicy; diff --git a/pdns/dnsdistdist/dnsdist-lua-actions.cc b/pdns/dnsdistdist/dnsdist-lua-actions.cc index 5a73fc4c0290..9133aa77e2c0 100644 --- a/pdns/dnsdistdist/dnsdist-lua-actions.cc +++ b/pdns/dnsdistdist/dnsdist-lua-actions.cc @@ -2376,8 +2376,8 @@ class SetExtendedDNSErrorResponseAction : public DNSResponseAction EDNSExtendedError d_ede; }; -template -static void addAction(GlobalStateHolder>* someRuleActions, const luadnsrule_t& var, const std::shared_ptr& action, boost::optional& params) +template +static void addAction(IdentifierT identifier, const luadnsrule_t& var, const std::shared_ptr& action, boost::optional& params) { setLuaSideEffect(); @@ -2388,8 +2388,8 @@ static void addAction(GlobalStateHolder>* someRuleActions, const luadn checkAllParametersConsumed("addAction", params); auto rule = makeRule(var, "addAction"); - someRuleActions->modify([&rule, &action, &uuid, creationOrder, &name](vector& ruleactions) { - ruleactions.push_back({std::move(rule), std::move(action), std::move(name), uuid, creationOrder}); + dnsdist::configuration::updateRuntimeConfiguration([identifier, &rule, &action, &name, &uuid, creationOrder](dnsdist::configuration::RuntimeConfiguration& config) { + dnsdist::rules::add(config.d_ruleChains, identifier, std::move(rule), action, std::move(name), uuid, creationOrder); }); } @@ -2418,20 +2418,21 @@ void setupLuaActions(LuaContext& luaCtx) return std::make_shared(ruleaction); }); - for (const auto& chain : dnsdist::rules::getRuleChains()) { + for (const auto& chain : dnsdist::rules::getRuleChainDescriptions()) { auto fullName = std::string("add") + chain.prefix + std::string("Action"); luaCtx.writeFunction(fullName, [&fullName, &chain](const luadnsrule_t& var, boost::variant, std::shared_ptr> era, boost::optional params) { if (era.type() != typeid(std::shared_ptr)) { throw std::runtime_error(fullName + "() can only be called with query-related actions, not response-related ones. Are you looking for addResponseAction()?"); } - addAction(&chain.holder, var, boost::get>(era), params); + addAction(chain.identifier, var, boost::get>(era), params); }); fullName = std::string("get") + chain.prefix + std::string("Action"); luaCtx.writeFunction(fullName, [&chain](unsigned int num) { setLuaNoSideEffect(); boost::optional> ret; - auto ruleactions = chain.holder.getCopy(); + const auto& chains = dnsdist::configuration::getCurrentRuntimeConfiguration().d_ruleChains; + const auto& ruleactions = dnsdist::rules::getRuleChain(chains, chain.identifier); if (num < ruleactions.size()) { ret = ruleactions[num].d_action; } @@ -2439,14 +2440,14 @@ void setupLuaActions(LuaContext& luaCtx) }); } - for (const auto& chain : dnsdist::rules::getResponseRuleChains()) { + for (const auto& chain : dnsdist::rules::getResponseRuleChainDescriptions()) { const auto fullName = std::string("add") + chain.prefix + std::string("ResponseAction"); luaCtx.writeFunction(fullName, [&fullName, &chain](const luadnsrule_t& var, boost::variant, std::shared_ptr> era, boost::optional params) { if (era.type() != typeid(std::shared_ptr)) { throw std::runtime_error(fullName + "() can only be called with response-related actions, not query-related ones. Are you looking for addAction()?"); } - addAction(&chain.holder, var, boost::get>(era), params); + addAction(chain.identifier, var, boost::get>(era), params); }); } diff --git a/pdns/dnsdistdist/dnsdist-lua-bindings.cc b/pdns/dnsdistdist/dnsdist-lua-bindings.cc index d485cfad65e8..a4b74ebc1c5f 100644 --- a/pdns/dnsdistdist/dnsdist-lua-bindings.cc +++ b/pdns/dnsdistdist/dnsdist-lua-bindings.cc @@ -113,7 +113,7 @@ void setupLuaBindings(LuaContext& luaCtx, bool client, bool configCheck) /* DownstreamState */ luaCtx.registerFunction("setQPS", [](DownstreamState& state, int lim) { state.qps = lim > 0 ? QPSLimiter(lim, lim) : QPSLimiter(); }); luaCtx.registerFunction::*)(string)>("addPool", [](const std::shared_ptr& state, const string& pool) { - addServerToPool( pool, state); + addServerToPool(pool, state); state->d_config.pools.insert(pool); }); luaCtx.registerFunction::*)(string)>("rmPool", [](const std::shared_ptr& state, const string& pool) { diff --git a/pdns/dnsdistdist/dnsdist-lua-rules.cc b/pdns/dnsdistdist/dnsdist-lua-rules.cc index 2abe6f1359f8..6343dba70630 100644 --- a/pdns/dnsdistdist/dnsdist-lua-rules.cc +++ b/pdns/dnsdistdist/dnsdist-lua-rules.cc @@ -128,91 +128,118 @@ static std::string rulesToString(const std::vector& rules, boost::optional -static void showRules(GlobalStateHolder>* someRuleActions, boost::optional& vars) +template +static void showRules(IdentifierT identifier, boost::optional& vars) { setLuaNoSideEffect(); - auto rules = someRuleActions->getLocal(); - g_outputBuffer += rulesToString(*rules, vars); + const auto& chains = dnsdist::configuration::getCurrentRuntimeConfiguration().d_ruleChains; + const auto& rules = dnsdist::rules::getRuleChain(chains, identifier); + g_outputBuffer += rulesToString(rules, vars); } -template -static void rmRule(GlobalStateHolder>* someRuleActions, const boost::variant& ruleID) +template +static bool removeRuleFromChain(ChainTypeT& rules, const std::function& matchFunction) +{ + auto removeIt = std::remove_if(rules.begin(), + rules.end(), + matchFunction); + if (removeIt == rules.end()) { + g_outputBuffer = "Error: no rule matched\n"; + return false; + } + rules.erase(removeIt, + rules.end()); + return true; +} + +template +static void rmRule(ChainIdentifierT chainIdentifier, const boost::variant& ruleID) { - setLuaSideEffect(); - auto rules = someRuleActions->getCopy(); if (const auto* str = boost::get(&ruleID)) { try { const auto uuid = getUniqueID(*str); - auto removeIt = std::remove_if(rules.begin(), - rules.end(), - [&uuid](const T& rule) { return rule.d_id == uuid; }); - if (removeIt == rules.end()) { - g_outputBuffer = "Error: no rule matched\n"; - return; - } - rules.erase(removeIt, - rules.end()); + dnsdist::configuration::updateRuntimeConfiguration([chainIdentifier, &uuid](dnsdist::configuration::RuntimeConfiguration& config) { + constexpr bool isResponseChain = std::is_same_v; + if constexpr (isResponseChain) { + auto& rules = dnsdist::rules::getResponseRuleChain(config.d_ruleChains, chainIdentifier); + std::function matchFunction = [&uuid](const dnsdist::rules::ResponseRuleAction& rule) -> bool { return rule.d_id == uuid; }; + removeRuleFromChain(rules, matchFunction); + } + else { + auto& rules = dnsdist::rules::getRuleChain(config.d_ruleChains, chainIdentifier); + std::function matchFunction = [&uuid](const dnsdist::rules::RuleAction& rule) -> bool { return rule.d_id == uuid; }; + removeRuleFromChain(rules, matchFunction); + } + }); } catch (const std::runtime_error& e) { - /* it was not an UUID, let's see if it was a name instead */ - auto removeIt = std::remove_if(rules.begin(), - rules.end(), - [&str](const T& rule) { return rule.d_name == *str; }); - if (removeIt == rules.end()) { - g_outputBuffer = "Error: no rule matched\n"; - return; - } - rules.erase(removeIt, - rules.end()); + dnsdist::configuration::updateRuntimeConfiguration([chainIdentifier, &str](dnsdist::configuration::RuntimeConfiguration& config) { + constexpr bool isResponseChain = std::is_same_v; + if constexpr (isResponseChain) { + auto& rules = dnsdist::rules::getResponseRuleChain(config.d_ruleChains, chainIdentifier); + std::function matchFunction = [&str](const dnsdist::rules::ResponseRuleAction& rule) -> bool { return rule.d_name == *str; }; + removeRuleFromChain(rules, matchFunction); + } + else { + auto& rules = dnsdist::rules::getRuleChain(config.d_ruleChains, chainIdentifier); + std::function matchFunction = [&str](const dnsdist::rules::RuleAction& rule) -> bool { return rule.d_name == *str; }; + removeRuleFromChain(rules, matchFunction); + } + }); } } else if (const auto* pos = boost::get(&ruleID)) { - if (*pos >= rules.size()) { - g_outputBuffer = "Error: attempt to delete non-existing rule\n"; - return; - } - rules.erase(rules.begin() + *pos); + dnsdist::configuration::updateRuntimeConfiguration([chainIdentifier, pos](dnsdist::configuration::RuntimeConfiguration& config) { + auto& rules = dnsdist::rules::getRuleChain(config.d_ruleChains, chainIdentifier); + if (*pos >= rules.size()) { + g_outputBuffer = "Error: attempt to delete non-existing rule\n"; + return; + } + rules.erase(rules.begin() + *pos); + }); } - someRuleActions->setState(std::move(rules)); + setLuaSideEffect(); } -template -static void moveRuleToTop(GlobalStateHolder>* someRuleActions) +template +static void moveRuleToTop(IdentifierTypeT chainIdentifier) { setLuaSideEffect(); - auto rules = someRuleActions->getCopy(); - if (rules.empty()) { - return; - } - auto subject = *rules.rbegin(); - rules.erase(std::prev(rules.end())); - rules.insert(rules.begin(), subject); - someRuleActions->setState(std::move(rules)); + dnsdist::configuration::updateRuntimeConfiguration([chainIdentifier](dnsdist::configuration::RuntimeConfiguration& config) { + auto& rules = dnsdist::rules::getRuleChain(config.d_ruleChains, chainIdentifier); + if (rules.empty()) { + return; + } + auto subject = *rules.rbegin(); + rules.erase(std::prev(rules.end())); + rules.insert(rules.begin(), subject); + }); + setLuaSideEffect(); } -template -static void mvRule(GlobalStateHolder>* someRespRuleActions, unsigned int from, unsigned int destination) +template +static void mvRule(IdentifierTypeT chainIdentifier, unsigned int from, unsigned int destination) { - setLuaSideEffect(); - auto rules = someRespRuleActions->getCopy(); - if (from >= rules.size() || destination > rules.size()) { - g_outputBuffer = "Error: attempt to move rules from/to invalid index\n"; - return; - } - auto subject = rules[from]; - rules.erase(rules.begin() + from); - if (destination > rules.size()) { - rules.push_back(subject); - } - else { - if (from < destination) { - --destination; + dnsdist::configuration::updateRuntimeConfiguration([chainIdentifier, from, &destination](dnsdist::configuration::RuntimeConfiguration& config) { + auto& rules = dnsdist::rules::getRuleChain(config.d_ruleChains, chainIdentifier); + if (from >= rules.size() || destination > rules.size()) { + g_outputBuffer = "Error: attempt to move rules from/to invalid index\n"; + return; } - rules.insert(rules.begin() + destination, subject); - } - someRespRuleActions->setState(std::move(rules)); + auto subject = rules[from]; + rules.erase(rules.begin() + from); + if (destination > rules.size()) { + rules.push_back(subject); + } + else { + if (from < destination) { + --destination; + } + rules.insert(rules.begin() + destination, subject); + } + }); + setLuaSideEffect(); } template @@ -333,90 +360,100 @@ void setupLuaRules(LuaContext& luaCtx) luaCtx.registerFunction (dnsdist::rules::ResponseRuleAction::*)() const>("getAction", [](const dnsdist::rules::ResponseRuleAction& rule) { return rule.d_action; }); - for (const auto& chain : dnsdist::rules::getResponseRuleChains()) { + for (const auto& chain : dnsdist::rules::getResponseRuleChainDescriptions()) { luaCtx.writeFunction("show" + chain.prefix + "ResponseRules", [&chain](boost::optional vars) { - showRules(&chain.holder, vars); + showRules(chain.identifier, vars); }); luaCtx.writeFunction("rm" + chain.prefix + "ResponseRule", [&chain](const boost::variant& identifier) { - rmRule(&chain.holder, identifier); + rmRule(chain.identifier, identifier); }); luaCtx.writeFunction("mv" + chain.prefix + "ResponseRuleToTop", [&chain]() { - moveRuleToTop(&chain.holder); + moveRuleToTop(chain.identifier); }); luaCtx.writeFunction("mv" + chain.prefix + "ResponseRule", [&chain](unsigned int from, unsigned int dest) { - mvRule(&chain.holder, from, dest); + mvRule(chain.identifier, from, dest); }); luaCtx.writeFunction("get" + chain.prefix + "ResponseRule", [&chain](const boost::variant& selector) -> boost::optional { - auto rules = chain.holder.getLocal(); - return getRuleFromSelector(*rules, selector); + const auto& chains = dnsdist::configuration::getCurrentRuntimeConfiguration().d_ruleChains; + const auto& rules = dnsdist::rules::getResponseRuleChain(chains, chain.identifier); + return getRuleFromSelector(rules, selector); }); luaCtx.writeFunction("getTop" + chain.prefix + "ResponseRules", [&chain](boost::optional top) { setLuaNoSideEffect(); - auto rules = chain.holder.getLocal(); - return toLuaArray(getTopRules(*rules, (top ? *top : 10))); + const auto& chains = dnsdist::configuration::getCurrentRuntimeConfiguration().d_ruleChains; + const auto& rules = dnsdist::rules::getResponseRuleChain(chains, chain.identifier); + return toLuaArray(getTopRules(rules, (top ? *top : 10))); }); luaCtx.writeFunction("top" + chain.prefix + "ResponseRules", [&chain](boost::optional top, boost::optional vars) { setLuaNoSideEffect(); - auto rules = chain.holder.getLocal(); - return rulesToString(getTopRules(*rules, (top ? *top : 10)), vars); + const auto& chains = dnsdist::configuration::getCurrentRuntimeConfiguration().d_ruleChains; + const auto& rules = dnsdist::rules::getResponseRuleChain(chains, chain.identifier); + return rulesToString(getTopRules(rules, (top ? *top : 10)), vars); }); luaCtx.writeFunction("clear" + chain.prefix + "ResponseRules", [&chain]() { setLuaSideEffect(); - chain.holder.modify([](std::remove_reference_t::value_type& ruleactions) { - ruleactions.clear(); + dnsdist::configuration::updateRuntimeConfiguration([&chain](dnsdist::configuration::RuntimeConfiguration& config) { + auto& rules = dnsdist::rules::getRuleChain(config.d_ruleChains, chain.identifier); + rules.clear(); }); }); } - for (const auto& chain : dnsdist::rules::getRuleChains()) { + for (const auto& chain : dnsdist::rules::getRuleChainDescriptions()) { luaCtx.writeFunction("show" + chain.prefix + "Rules", [&chain](boost::optional vars) { - showRules(&chain.holder, vars); + showRules(chain.identifier, vars); }); luaCtx.writeFunction("rm" + chain.prefix + "Rule", [&chain](const boost::variant& identifier) { - rmRule(&chain.holder, identifier); + rmRule(chain.identifier, identifier); }); luaCtx.writeFunction("mv" + chain.prefix + "RuleToTop", [&chain]() { - moveRuleToTop(&chain.holder); + moveRuleToTop(chain.identifier); }); luaCtx.writeFunction("mv" + chain.prefix + "Rule", [&chain](unsigned int from, unsigned int dest) { - mvRule(&chain.holder, from, dest); + mvRule(chain.identifier, from, dest); }); luaCtx.writeFunction("get" + chain.prefix + "Rule", [&chain](const boost::variant& selector) -> boost::optional { - auto rules = chain.holder.getLocal(); - return getRuleFromSelector(*rules, selector); + const auto& chains = dnsdist::configuration::getCurrentRuntimeConfiguration().d_ruleChains; + const auto& rules = dnsdist::rules::getRuleChain(chains, chain.identifier); + return getRuleFromSelector(rules, selector); }); luaCtx.writeFunction("getTop" + chain.prefix + "Rules", [&chain](boost::optional top) { setLuaNoSideEffect(); - auto rules = chain.holder.getLocal(); - return toLuaArray(getTopRules(*rules, (top ? *top : 10))); + const auto& chains = dnsdist::configuration::getCurrentRuntimeConfiguration().d_ruleChains; + const auto& rules = dnsdist::rules::getRuleChain(chains, chain.identifier); + return toLuaArray(getTopRules(rules, (top ? *top : 10))); }); luaCtx.writeFunction("top" + chain.prefix + "Rules", [&chain](boost::optional top, boost::optional vars) { setLuaNoSideEffect(); - auto rules = chain.holder.getLocal(); - return rulesToString(getTopRules(*rules, (top ? *top : 10)), vars); + const auto& chains = dnsdist::configuration::getCurrentRuntimeConfiguration().d_ruleChains; + const auto& rules = dnsdist::rules::getRuleChain(chains, chain.identifier); + + return rulesToString(getTopRules(rules, (top ? *top : 10)), vars); }); luaCtx.writeFunction("clear" + chain.prefix + "Rules", [&chain]() { setLuaSideEffect(); - chain.holder.modify([](std::remove_reference_t::value_type& ruleactions) { - ruleactions.clear(); + dnsdist::configuration::updateRuntimeConfiguration([&chain](dnsdist::configuration::RuntimeConfiguration& config) { + auto& rules = dnsdist::rules::getRuleChain(config.d_ruleChains, chain.identifier); + rules.clear(); }); }); luaCtx.writeFunction("set" + chain.prefix + "Rules", [&chain](const LuaArray>& newruleactions) { setLuaSideEffect(); - chain.holder.modify([newruleactions](std::remove_reference_t::value_type& gruleactions) { - gruleactions.clear(); + dnsdist::configuration::updateRuntimeConfiguration([&chain, &newruleactions](dnsdist::configuration::RuntimeConfiguration& config) { + auto& rules = dnsdist::rules::getRuleChain(config.d_ruleChains, chain.identifier); + rules.clear(); for (const auto& pair : newruleactions) { const auto& newruleaction = pair.second; if (newruleaction->d_action) { auto rule = newruleaction->d_rule; - gruleactions.push_back({std::move(rule), newruleaction->d_action, newruleaction->d_name, newruleaction->d_id, newruleaction->d_creationOrder}); + rules.push_back({std::move(rule), newruleaction->d_action, newruleaction->d_name, newruleaction->d_id, newruleaction->d_creationOrder}); } } }); diff --git a/pdns/dnsdistdist/dnsdist-rule-chains.cc b/pdns/dnsdistdist/dnsdist-rule-chains.cc index 1c79fd0a24e9..340e4006fb4c 100644 --- a/pdns/dnsdistdist/dnsdist-rule-chains.cc +++ b/pdns/dnsdistdist/dnsdist-rule-chains.cc @@ -24,44 +24,101 @@ namespace dnsdist::rules { -GlobalStateHolder> s_ruleActions; -GlobalStateHolder> s_cacheMissRuleActions; -GlobalStateHolder> s_respruleactions; -GlobalStateHolder> s_cachehitrespruleactions; -GlobalStateHolder> s_selfansweredrespruleactions; -GlobalStateHolder> s_cacheInsertedRespRuleActions; -GlobalStateHolder> s_XFRRespRuleActions; - static const std::vector s_responseRuleChains{ - {"", "response-rules", s_respruleactions}, - {"CacheHit", "cache-hit-response-rules", s_cachehitrespruleactions}, - {"CacheInserted", "cache-inserted-response-rules", s_selfansweredrespruleactions}, - {"SelfAnswered", "self-answered-response-rules", s_cacheInsertedRespRuleActions}, - {"XFR", "xfr-response-rules", s_XFRRespRuleActions}, + {"", "response-rules", ResponseRuleChain::ResponseRules}, + {"CacheHit", "cache-hit-response-rules", ResponseRuleChain::CacheHitResponseRules}, + {"CacheInserted", "cache-inserted-response-rules", ResponseRuleChain::CacheInsertedResponseRules}, + {"SelfAnswered", "self-answered-response-rules", ResponseRuleChain::SelfAnsweredResponseRules}, + {"XFR", "xfr-response-rules", ResponseRuleChain::XFRResponseRules}, }; -const std::vector& getResponseRuleChains() +const std::vector& getResponseRuleChainDescriptions() { return s_responseRuleChains; } -GlobalStateHolder>& getResponseRuleChainHolder(ResponseRuleChain chain) -{ - return s_responseRuleChains.at(static_cast(chain)).holder; -} - static const std::vector s_ruleChains{ - {"", "rules", s_ruleActions}, - {"CacheMiss", "cache-miss-rules", s_cacheMissRuleActions}, + {"", "rules", RuleChain::Rules}, + {"CacheMiss", "cache-miss-rules", RuleChain::CacheMissRules}, }; -const std::vector& getRuleChains() +const std::vector& getRuleChainDescriptions() { return s_ruleChains; } -GlobalStateHolder>& getRuleChainHolder(RuleChain chain) +std::vector& getRuleChain(RuleChains& chains, RuleChain chain) +{ + switch (chain) { + case RuleChain::Rules: + return chains.d_ruleActions; + case RuleChain::CacheMissRules: + return chains.d_cacheMissRuleActions; + } +} + +const std::vector& getRuleChain(const RuleChains& chains, RuleChain chain) +{ + switch (chain) { + case RuleChain::Rules: + return chains.d_ruleActions; + case RuleChain::CacheMissRules: + return chains.d_cacheMissRuleActions; + } +} + +std::vector& getRuleChain(RuleChains& chains, ResponseRuleChain chain) +{ + return getResponseRuleChain(chains, chain); +} + +const std::vector& getRuleChain(const RuleChains& chains, ResponseRuleChain chain) { - return s_ruleChains.at(static_cast(chain)).holder; + return getResponseRuleChain(chains, chain); } + +std::vector& getResponseRuleChain(RuleChains& chains, ResponseRuleChain chain) +{ + switch (chain) { + case ResponseRuleChain::ResponseRules: + return chains.d_respruleactions; + case ResponseRuleChain::CacheHitResponseRules: + return chains.d_cachehitrespruleactions; + case ResponseRuleChain::CacheInsertedResponseRules: + return chains.d_cacheInsertedRespRuleActions; + case ResponseRuleChain::SelfAnsweredResponseRules: + return chains.d_selfansweredrespruleactions; + case ResponseRuleChain::XFRResponseRules: + return chains.d_XFRRespRuleActions; + } +} + +const std::vector& getResponseRuleChain(const RuleChains& chains, ResponseRuleChain chain) +{ + switch (chain) { + case ResponseRuleChain::ResponseRules: + return chains.d_respruleactions; + case ResponseRuleChain::CacheHitResponseRules: + return chains.d_cachehitrespruleactions; + case ResponseRuleChain::CacheInsertedResponseRules: + return chains.d_cacheInsertedRespRuleActions; + case ResponseRuleChain::SelfAnsweredResponseRules: + return chains.d_selfansweredrespruleactions; + case ResponseRuleChain::XFRResponseRules: + return chains.d_XFRRespRuleActions; + } +} + +void add(RuleChains& chains, RuleChain identifier, const std::shared_ptr& selector, const std::shared_ptr& action, std::string&& name, const boost::uuids::uuid& uuid, uint64_t creationOrder) +{ + auto& chain = getRuleChain(chains, identifier); + chain.push_back({selector, action, std::move(name), uuid, creationOrder}); +} + +void add(RuleChains& chains, ResponseRuleChain identifier, const std::shared_ptr& selector, const std::shared_ptr& action, std::string&& name, const boost::uuids::uuid& uuid, uint64_t creationOrder) +{ + auto& chain = getResponseRuleChain(chains, identifier); + chain.push_back({selector, action, std::move(name), uuid, creationOrder}); +} + } diff --git a/pdns/dnsdistdist/dnsdist-rule-chains.hh b/pdns/dnsdistdist/dnsdist-rule-chains.hh index 5d2220cdd565..47657635ca00 100644 --- a/pdns/dnsdistdist/dnsdist-rule-chains.hh +++ b/pdns/dnsdistdist/dnsdist-rule-chains.hh @@ -25,7 +25,6 @@ #include #include -#include "sholder.hh" #include "uuid-utils.hh" class DNSRule; @@ -43,21 +42,18 @@ struct RuleAction uint64_t d_creationOrder; }; -struct RuleChainDescription -{ - std::string prefix; - std::string metricName; - GlobalStateHolder>& holder; -}; - enum class RuleChain : uint8_t { Rules = 0, CacheMissRules = 1, }; -const std::vector& getRuleChains(); -GlobalStateHolder>& getRuleChainHolder(RuleChain chain); +struct RuleChainDescription +{ + const std::string prefix; + const std::string metricName; + const RuleChain identifier; +}; struct ResponseRuleAction { @@ -79,12 +75,30 @@ enum class ResponseRuleChain : uint8_t struct ResponseRuleChainDescription { - std::string prefix; - std::string metricName; - GlobalStateHolder>& holder; + const std::string prefix; + const std::string metricName; + const ResponseRuleChain identifier; }; -const std::vector& getResponseRuleChains(); -GlobalStateHolder>& getResponseRuleChainHolder(ResponseRuleChain chain); +struct RuleChains +{ + std::vector d_ruleActions; + std::vector d_cacheMissRuleActions; + std::vector d_respruleactions; + std::vector d_cachehitrespruleactions; + std::vector d_selfansweredrespruleactions; + std::vector d_cacheInsertedRespRuleActions; + std::vector d_XFRRespRuleActions; +}; +const std::vector& getRuleChainDescriptions(); +std::vector& getRuleChain(RuleChains& chains, RuleChain chain); +const std::vector& getRuleChain(const RuleChains& chains, RuleChain chain); +const std::vector& getResponseRuleChainDescriptions(); +std::vector& getRuleChain(RuleChains& chains, ResponseRuleChain chain); +const std::vector& getRuleChain(const RuleChains& chains, ResponseRuleChain chain); +std::vector& getResponseRuleChain(RuleChains& chains, ResponseRuleChain chain); +const std::vector& getResponseRuleChain(const RuleChains& chains, ResponseRuleChain chain); +void add(RuleChains& chains, RuleChain identifier, const std::shared_ptr& selector, const std::shared_ptr& action, std::string&& name, const boost::uuids::uuid& uuid, uint64_t creationOrder); +void add(RuleChains& chains, ResponseRuleChain identifier, const std::shared_ptr& selector, const std::shared_ptr& action, std::string&& name, const boost::uuids::uuid& uuid, uint64_t creationOrder); } diff --git a/pdns/dnsdistdist/dnsdist-tcp-upstream.hh b/pdns/dnsdistdist/dnsdist-tcp-upstream.hh index e03646d4560f..44992284d019 100644 --- a/pdns/dnsdistdist/dnsdist-tcp-upstream.hh +++ b/pdns/dnsdistdist/dnsdist-tcp-upstream.hh @@ -10,14 +10,11 @@ class TCPClientThreadData { public: TCPClientThreadData(): - localRespRuleActions(dnsdist::rules::getResponseRuleChainHolder(dnsdist::rules::ResponseRuleChain::ResponseRules).getLocal()), localCacheInsertedRespRuleActions(dnsdist::rules::getResponseRuleChainHolder(dnsdist::rules::ResponseRuleChain::CacheInsertedResponseRules).getLocal()), localXFRRespRuleActions(dnsdist::rules::getResponseRuleChainHolder(dnsdist::rules::ResponseRuleChain::XFRResponseRules).getLocal()), mplexer(std::unique_ptr(FDMultiplexer::getMultiplexerSilent())) + mplexer(std::unique_ptr(FDMultiplexer::getMultiplexerSilent())) { } LocalHolders holders; - LocalStateHolder> localRespRuleActions; - LocalStateHolder> localCacheInsertedRespRuleActions; - LocalStateHolder> localXFRRespRuleActions; std::unique_ptr mplexer{nullptr}; pdns::channel::Receiver queryReceiver; pdns::channel::Receiver crossProtocolQueryReceiver; diff --git a/pdns/dnsdistdist/dnsdist-tcp.cc b/pdns/dnsdistdist/dnsdist-tcp.cc index 955ad29734b5..ef695d211214 100644 --- a/pdns/dnsdistdist/dnsdist-tcp.cc +++ b/pdns/dnsdistdist/dnsdist-tcp.cc @@ -503,7 +503,7 @@ void IncomingTCPConnectionState::handleResponse(const struct timeval& now, TCPRe memcpy(&response.d_cleartextDH, dnsResponse.getHeader().get(), sizeof(response.d_cleartextDH)); - if (!processResponse(response.d_buffer, *state->d_threadData.localRespRuleActions, *state->d_threadData.localCacheInsertedRespRuleActions, dnsResponse, false)) { + if (!processResponse(response.d_buffer, dnsResponse, false)) { state->terminateClientConnection(); return; } @@ -1205,8 +1205,11 @@ void IncomingTCPConnectionState::notifyIOError(const struct timeval& now, TCPRes } } -static bool processXFRResponse(PacketBuffer& response, const std::vector& xfrRespRuleActions, DNSResponse& dnsResponse) +static bool processXFRResponse(PacketBuffer& response, DNSResponse& dnsResponse) { + const auto& chains = dnsdist::configuration::getCurrentRuntimeConfiguration().d_ruleChains; + const auto& xfrRespRuleActions = dnsdist::rules::getResponseRuleChain(chains, dnsdist::rules::ResponseRuleChain::XFRResponseRules); + if (!applyRulesToResponse(xfrRespRuleActions, dnsResponse)) { return false; } @@ -1236,7 +1239,7 @@ void IncomingTCPConnectionState::handleXFRResponse(const struct timeval& now, TC dnsResponse.d_incomingTCPState = state; memcpy(&response.d_cleartextDH, dnsResponse.getHeader().get(), sizeof(response.d_cleartextDH)); - if (!processXFRResponse(response.d_buffer, *state->d_threadData.localXFRRespRuleActions, dnsResponse)) { + if (!processXFRResponse(response.d_buffer, dnsResponse)) { state->terminateClientConnection(); return; } diff --git a/pdns/dnsdistdist/dnsdist-web.cc b/pdns/dnsdistdist/dnsdist-web.cc index fd72f383064a..7fcaad5f468f 100644 --- a/pdns/dnsdistdist/dnsdist-web.cc +++ b/pdns/dnsdistdist/dnsdist-web.cc @@ -446,14 +446,13 @@ static void addCustomHeaders(YaHTTP::Response& resp, const boost::optional -static json11::Json::array someResponseRulesToJson(GlobalStateHolder>* someResponseRules) +static json11::Json::array someResponseRulesToJson(const std::vector& someResponseRules) { using namespace json11; Json::array responseRules; int num = 0; - auto localResponseRules = someResponseRules->getLocal(); - responseRules.reserve(localResponseRules->size()); - for (const auto& rule : *localResponseRules) { + responseRules.reserve(someResponseRules.size()); + for (const auto& rule : someResponseRules) { responseRules.emplace_back(Json::object{ {"id", num++}, {"creationOrder", static_cast(rule.d_creationOrder)}, @@ -469,10 +468,9 @@ static json11::Json::array someResponseRulesToJson(GlobalStateHolder>* #ifndef DISABLE_PROMETHEUS template -static void addRulesToPrometheusOutput(std::ostringstream& output, GlobalStateHolder>& rules) +static void addRulesToPrometheusOutput(std::ostringstream& output, const std::vector& rules) { - auto localRules = rules.getLocal(); - for (const auto& entry : *localRules) { + for (const auto& entry : rules) { std::string identifier = !entry.d_name.empty() ? entry.d_name : boost::uuids::to_string(entry.d_id); output << "dnsdist_rule_hits{id=\"" << identifier << "\"} " << entry.d_rule->d_matches << "\n"; } @@ -897,11 +895,14 @@ static void handlePrometheus(const YaHTTP::Request& req, YaHTTP::Response& resp) output << "# HELP dnsdist_rule_hits " << "Number of hits of that rule" << "\n"; output << "# TYPE dnsdist_rule_hits " << "counter" << "\n"; - for (const auto& chain : dnsdist::rules::getRuleChains()) { - addRulesToPrometheusOutput(output, chain.holder); + const auto& chains = dnsdist::configuration::getCurrentRuntimeConfiguration().d_ruleChains; + for (const auto& chainDescription : dnsdist::rules::getRuleChainDescriptions()) { + const auto& chain = dnsdist::rules::getRuleChain(chains, chainDescription.identifier); + addRulesToPrometheusOutput(output, chain); } - for (const auto& chain : dnsdist::rules::getResponseRuleChains()) { - addRulesToPrometheusOutput(output, chain.holder); + for (const auto& chainDescription : dnsdist::rules::getResponseRuleChainDescriptions()) { + const auto& chain = dnsdist::rules::getResponseRuleChain(chains, chainDescription.identifier); + addRulesToPrometheusOutput(output, chain); } #ifndef DISABLE_DYNBLOCKS @@ -1314,12 +1315,13 @@ static void handleStats(const YaHTTP::Request& req, YaHTTP::Response& resp) /* unfortunately DNSActions have getStats(), and DNSResponseActions do not. */ - for (const auto& chain : dnsdist::rules::getRuleChains()) { + const auto& chains = dnsdist::configuration::getCurrentRuntimeConfiguration().d_ruleChains; + for (const auto& chainDescription : dnsdist::rules::getRuleChainDescriptions()) { Json::array rules; - auto localRules = chain.holder.getLocal(); + const auto& chain = dnsdist::rules::getRuleChain(chains, chainDescription.identifier); num = 0; - rules.reserve(localRules->size()); - for (const auto& lrule : *localRules) { + rules.reserve(chain.size()); + for (const auto& lrule : chain) { Json::object rule{ {"id", num++}, {"creationOrder", (double)lrule.d_creationOrder}, @@ -1331,12 +1333,13 @@ static void handleStats(const YaHTTP::Request& req, YaHTTP::Response& resp) {"action-stats", lrule.d_action->getStats()}}; rules.emplace_back(std::move(rule)); } - responseObject[chain.metricName] = std::move(rules); + responseObject[chainDescription.metricName] = std::move(rules); } - for (const auto& chain : dnsdist::rules::getResponseRuleChains()) { - auto responseRules = someResponseRulesToJson(&chain.holder); - responseObject[chain.metricName] = std::move(responseRules); + for (const auto& chainDescription : dnsdist::rules::getResponseRuleChainDescriptions()) { + const auto& chain = dnsdist::rules::getResponseRuleChain(chains, chainDescription.identifier); + auto responseRules = someResponseRulesToJson(chain); + responseObject[chainDescription.metricName] = std::move(responseRules); } resp.headers["Content-Type"] = "application/json"; diff --git a/pdns/dnsdistdist/dnsdist-xsk.cc b/pdns/dnsdistdist/dnsdist-xsk.cc index 4e09bac81da1..2deafb1887cc 100644 --- a/pdns/dnsdistdist/dnsdist-xsk.cc +++ b/pdns/dnsdistdist/dnsdist-xsk.cc @@ -39,8 +39,6 @@ void XskResponderThread(std::shared_ptr dss, std::shared_ptrisStopped()) { poll(pollfds.data(), pollfds.size(), -1); @@ -75,7 +73,7 @@ void XskResponderThread(std::shared_ptr dss, std::shared_ptrxskPacketHeader.clear(); } - if (!processResponderPacket(dss, response, *localRespRuleActions, *localCacheInsertedRespRuleActions, std::move(*ids))) { + if (!processResponderPacket(dss, response, std::move(*ids))) { xskInfo->markAsFree(packet); infolog("XSK packet pushed to queue because processResponderPacket failed"); return; @@ -114,7 +112,7 @@ void XskResponderThread(std::shared_ptr dss, std::shared_ptr dss, std::shared_ptr xskInfo); -bool XskIsQueryAcceptable(const XskPacket& packet, ClientState& clientState, LocalHolders& holders, bool& expectProxyProtocol); +bool XskIsQueryAcceptable(const XskPacket& packet, ClientState& clientState, bool& expectProxyProtocol); bool XskProcessQuery(ClientState& clientState, LocalHolders& holders, XskPacket& packet); void XskRouter(std::shared_ptr xsk); void XskClientThread(ClientState* clientState); diff --git a/pdns/dnsdistdist/dnsdist.cc b/pdns/dnsdistdist/dnsdist.cc index 953ac3854526..492196183fc4 100644 --- a/pdns/dnsdistdist/dnsdist.cc +++ b/pdns/dnsdistdist/dnsdist.cc @@ -522,7 +522,7 @@ bool applyRulesToResponse(const std::vector& return true; } -bool processResponseAfterRules(PacketBuffer& response, const std::vector& cacheInsertedRespRuleActions, DNSResponse& dnsResponse, bool muted) +bool processResponseAfterRules(PacketBuffer& response, DNSResponse& dnsResponse, bool muted) { bool zeroScope = false; if (!fixUpResponse(response, dnsResponse.ids.qname, dnsResponse.ids.origFlags, dnsResponse.ids.ednsAdded, dnsResponse.ids.ecsAdded, dnsResponse.ids.useZeroScope ? &zeroScope : nullptr)) { @@ -552,6 +552,8 @@ bool processResponseAfterRules(PacketBuffer& response, const std::vectorinsert(cacheKey, zeroScope ? boost::none : dnsResponse.ids.subnet, dnsResponse.ids.cacheFlags, dnsResponse.ids.dnssecOK, dnsResponse.ids.qname, dnsResponse.ids.qtype, dnsResponse.ids.qclass, response, dnsResponse.ids.forwardedOverUDP, dnsResponse.getHeader()->rcode, dnsResponse.ids.tempFailureTTL); + const auto& chains = dnsdist::configuration::getCurrentRuntimeConfiguration().d_ruleChains; + const auto& cacheInsertedRespRuleActions = dnsdist::rules::getResponseRuleChain(chains, dnsdist::rules::ResponseRuleChain::CacheInsertedResponseRules); if (!applyRulesToResponse(cacheInsertedRespRuleActions, dnsResponse)) { return false; } @@ -578,8 +580,11 @@ bool processResponseAfterRules(PacketBuffer& response, const std::vector& respRuleActions, const std::vector& cacheInsertedRespRuleActions, DNSResponse& dnsResponse, bool muted) +bool processResponse(PacketBuffer& response, DNSResponse& dnsResponse, bool muted) { + const auto& chains = dnsdist::configuration::getCurrentRuntimeConfiguration().d_ruleChains; + const auto& respRuleActions = dnsdist::rules::getResponseRuleChain(chains, dnsdist::rules::ResponseRuleChain::ResponseRules); + if (!applyRulesToResponse(respRuleActions, dnsResponse)) { return false; } @@ -588,7 +593,7 @@ bool processResponse(PacketBuffer& response, const std::vector& respRuleActions, const std::vector& cacheInsertedRespRuleActions, const std::shared_ptr& backend, bool isAsync, bool selfGenerated) +static void handleResponseForUDPClient(InternalQueryState& ids, PacketBuffer& response, const std::shared_ptr& backend, bool isAsync, bool selfGenerated) { DNSResponse dnsResponse(ids, response, backend); @@ -684,7 +689,7 @@ static void handleResponseForUDPClient(InternalQueryState& ids, PacketBuffer& re memcpy(&cleartextDH, dnsResponse.getHeader().get(), sizeof(cleartextDH)); if (!isAsync) { - if (!processResponse(response, respRuleActions, cacheInsertedRespRuleActions, dnsResponse, ids.cs != nullptr && ids.cs->muted)) { + if (!processResponse(response, dnsResponse, ids.cs != nullptr && ids.cs->muted)) { return; } @@ -725,7 +730,7 @@ static void handleResponseForUDPClient(InternalQueryState& ids, PacketBuffer& re } } -bool processResponderPacket(std::shared_ptr& dss, PacketBuffer& response, const std::vector& localRespRuleActions, const std::vector& cacheInsertedRespRuleActions, InternalQueryState&& ids) +bool processResponderPacket(std::shared_ptr& dss, PacketBuffer& response, InternalQueryState&& ids) { const dnsheader_aligned dnsHeader(response.data()); @@ -757,7 +762,7 @@ bool processResponderPacket(std::shared_ptr& dss, PacketBuffer& return false; } - handleResponseForUDPClient(ids, response, localRespRuleActions, cacheInsertedRespRuleActions, dss, false, false); + handleResponseForUDPClient(ids, response, dss, false, false); return true; } @@ -766,8 +771,6 @@ void responderThread(std::shared_ptr dss) { try { setThreadName("dnsdist/respond"); - auto localRespRuleActions = dnsdist::rules::getResponseRuleChainHolder(dnsdist::rules::ResponseRuleChain::ResponseRules).getLocal(); - auto localCacheInsertedRespRuleActions = dnsdist::rules::getResponseRuleChainHolder(dnsdist::rules::ResponseRuleChain::CacheInsertedResponseRules).getLocal(); const size_t initialBufferSize = getInitialUDPPacketBufferSize(false); /* allocate one more byte so we can detect truncation */ PacketBuffer response(initialBufferSize + 1); @@ -826,7 +829,7 @@ void responderThread(std::shared_ptr dss) continue; } - if (processResponderPacket(dss, response, *localRespRuleActions, *localCacheInsertedRespRuleActions, std::move(*ids)) && ids->isXSK() && ids->cs->xskInfo) { + if (processResponderPacket(dss, response, std::move(*ids)) && ids->isXSK() && ids->cs->xskInfo) { #ifdef HAVE_XSK auto& xskInfo = ids->cs->xskInfo; auto xskPacket = xskInfo->getEmptyFrame(); @@ -1216,7 +1219,9 @@ static bool applyRulesToQuery(LocalHolders& holders, DNSQuestion& dnsQuestion, c } #endif /* DISABLE_DYNBLOCKS */ - return applyRulesChainToQuery(*holders.ruleactions, dnsQuestion); + const auto& chains = dnsdist::configuration::getCurrentRuntimeConfiguration().d_ruleChains; + const auto& queryRules = dnsdist::rules::getRuleChain(chains, dnsdist::rules::RuleChain::Rules); + return applyRulesChainToQuery(queryRules, dnsQuestion); } ssize_t udpClientSendRequestToBackend(const std::shared_ptr& backend, const int socketDesc, const PacketBuffer& request, bool healthCheck) @@ -1368,14 +1373,17 @@ struct mmsghdr #endif /* self-generated responses or cache hits */ -static bool prepareOutgoingResponse(LocalHolders& holders, const ClientState& clientState, DNSQuestion& dnsQuestion, bool cacheHit) +static bool prepareOutgoingResponse(const ClientState& clientState, DNSQuestion& dnsQuestion, bool cacheHit) { std::shared_ptr backend{nullptr}; DNSResponse dnsResponse(dnsQuestion.ids, dnsQuestion.getMutableData(), backend); dnsResponse.d_incomingTCPState = dnsQuestion.d_incomingTCPState; dnsResponse.ids.selfGenerated = true; - if (!applyRulesToResponse(cacheHit ? *holders.cacheHitRespRuleactions : *holders.selfAnsweredRespRuleactions, dnsResponse)) { + const auto& chains = dnsdist::configuration::getCurrentRuntimeConfiguration().d_ruleChains; + const auto& cacheHitRespRules = dnsdist::rules::getResponseRuleChain(chains, dnsdist::rules::ResponseRuleChain::CacheHitResponseRules); + const auto& selfAnsweredRespRules = dnsdist::rules::getResponseRuleChain(chains, dnsdist::rules::ResponseRuleChain::SelfAnsweredResponseRules); + if (!applyRulesToResponse(cacheHit ? cacheHitRespRules : selfAnsweredRespRules, dnsResponse)) { return false; } @@ -1408,11 +1416,11 @@ static bool prepareOutgoingResponse(LocalHolders& holders, const ClientState& cl return true; } -static ProcessQueryResult handleQueryTurnedIntoSelfAnsweredResponse(DNSQuestion& dnsQuestion, LocalHolders& holders) +static ProcessQueryResult handleQueryTurnedIntoSelfAnsweredResponse(DNSQuestion& dnsQuestion) { fixUpQueryTurnedResponse(dnsQuestion, dnsQuestion.ids.origFlags); - if (!prepareOutgoingResponse(holders, *dnsQuestion.ids.cs, dnsQuestion, false)) { + if (!prepareOutgoingResponse(*dnsQuestion.ids.cs, dnsQuestion, false)) { return ProcessQueryResult::Drop; } @@ -1446,7 +1454,7 @@ ProcessQueryResult processQueryAfterRules(DNSQuestion& dnsQuestion, LocalHolders try { if (dnsQuestion.getHeader()->qr) { // something turned it into a response - return handleQueryTurnedIntoSelfAnsweredResponse(dnsQuestion, holders); + return handleQueryTurnedIntoSelfAnsweredResponse(dnsQuestion); } std::shared_ptr serverPool = getPool(dnsQuestion.ids.poolName); dnsQuestion.ids.packetCache = serverPool->packetCache; @@ -1467,7 +1475,7 @@ ProcessQueryResult processQueryAfterRules(DNSQuestion& dnsQuestion, LocalHolders vinfolog("Packet cache hit for query for %s|%s from %s (%s, %d bytes)", dnsQuestion.ids.qname.toLogString(), QType(dnsQuestion.ids.qtype).toString(), dnsQuestion.ids.origRemote.toStringWithPort(), dnsQuestion.ids.protocol.toString(), dnsQuestion.getData().size()); - if (!prepareOutgoingResponse(holders, *dnsQuestion.ids.cs, dnsQuestion, true)) { + if (!prepareOutgoingResponse(*dnsQuestion.ids.cs, dnsQuestion, true)) { return ProcessQueryResult::Drop; } @@ -1505,7 +1513,7 @@ ProcessQueryResult processQueryAfterRules(DNSQuestion& dnsQuestion, LocalHolders vinfolog("Packet cache hit for query for %s|%s from %s (%s, %d bytes)", dnsQuestion.ids.qname.toLogString(), QType(dnsQuestion.ids.qtype).toString(), dnsQuestion.ids.origRemote.toStringWithPort(), dnsQuestion.ids.protocol.toString(), dnsQuestion.getData().size()); - if (!prepareOutgoingResponse(holders, *dnsQuestion.ids.cs, dnsQuestion, true)) { + if (!prepareOutgoingResponse(*dnsQuestion.ids.cs, dnsQuestion, true)) { return ProcessQueryResult::Drop; } @@ -1516,7 +1524,7 @@ ProcessQueryResult processQueryAfterRules(DNSQuestion& dnsQuestion, LocalHolders if (dnsQuestion.ids.protocol == dnsdist::Protocol::DoH && !forwardedOverUDP) { /* do a second-lookup for UDP responses, but we do not want TC=1 answers */ if (dnsQuestion.ids.packetCache->get(dnsQuestion, dnsQuestion.getHeader()->id, &dnsQuestion.ids.cacheKeyUDP, dnsQuestion.ids.subnet, dnsQuestion.ids.dnssecOK, true, allowExpired, false, false, true)) { - if (!prepareOutgoingResponse(holders, *dnsQuestion.ids.cs, dnsQuestion, true)) { + if (!prepareOutgoingResponse(*dnsQuestion.ids.cs, dnsQuestion, true)) { return ProcessQueryResult::Drop; } @@ -1531,11 +1539,14 @@ ProcessQueryResult processQueryAfterRules(DNSQuestion& dnsQuestion, LocalHolders ++dnsdist::metrics::g_stats.cacheMisses; const auto existingPool = dnsQuestion.ids.poolName; - if (!applyRulesChainToQuery(*holders.cacheMissRuleActions, dnsQuestion)) { + const auto& chains = dnsdist::configuration::getCurrentRuntimeConfiguration().d_ruleChains; + const auto& cacheMissRuleActions = dnsdist::rules::getRuleChain(chains, dnsdist::rules::RuleChain::CacheMissRules); + + if (!applyRulesChainToQuery(cacheMissRuleActions, dnsQuestion)) { return ProcessQueryResult::Drop; } if (dnsQuestion.getHeader()->qr) { // something turned it into a response - return handleQueryTurnedIntoSelfAnsweredResponse(dnsQuestion, holders); + return handleQueryTurnedIntoSelfAnsweredResponse(dnsQuestion); } /* let's be nice and allow the selection of a different pool, but no second cache-lookup for you */ @@ -1560,7 +1571,7 @@ ProcessQueryResult processQueryAfterRules(DNSQuestion& dnsQuestion, LocalHolders fixUpQueryTurnedResponse(dnsQuestion, dnsQuestion.ids.origFlags); - if (!prepareOutgoingResponse(holders, *dnsQuestion.ids.cs, dnsQuestion, false)) { + if (!prepareOutgoingResponse(*dnsQuestion.ids.cs, dnsQuestion, false)) { return ProcessQueryResult::Drop; } ++dnsdist::metrics::g_stats.responses; @@ -1614,10 +1625,7 @@ class UDPTCPCrossQuerySender : public TCPQuerySender auto& ids = response.d_idstate; - static thread_local LocalStateHolder> localRespRuleActions = dnsdist::rules::getResponseRuleChainHolder(dnsdist::rules::ResponseRuleChain::ResponseRules).getLocal(); - static thread_local LocalStateHolder> localCacheInsertedRespRuleActions = dnsdist::rules::getResponseRuleChainHolder(dnsdist::rules::ResponseRuleChain::CacheInsertedResponseRules).getLocal(); - - handleResponseForUDPClient(ids, response.d_buffer, *localRespRuleActions, *localCacheInsertedRespRuleActions, response.d_ds, response.isAsync(), response.d_idstate.selfGenerated); + handleResponseForUDPClient(ids, response.d_buffer, response.d_ds, response.isAsync(), response.d_idstate.selfGenerated); } void handleXFRResponse(const struct timeval& now, TCPResponse&& response) override @@ -1925,7 +1933,7 @@ bool XskProcessQuery(ClientState& clientState, LocalHolders& holders, XskPacket& try { bool expectProxyProtocol = false; - if (!XskIsQueryAcceptable(packet, clientState, holders, expectProxyProtocol)) { + if (!XskIsQueryAcceptable(packet, clientState, expectProxyProtocol)) { return false; } @@ -2782,13 +2790,8 @@ static void cleanupLuaObjects() { /* when our coverage mode is enabled, we need to make sure that the Lua objects are destroyed before the Lua contexts. */ - for (const auto& chain : dnsdist::rules::getRuleChains()) { - chain.holder.setState({}); - } - for (const auto& chain : dnsdist::rules::getResponseRuleChains()) { - chain.holder.setState({}); - } dnsdist::configuration::updateRuntimeConfiguration([](dnsdist::configuration::RuntimeConfiguration& config) { + config.d_ruleChains = dnsdist::rules::RuleChains(); config.d_lbPolicy = std::make_shared(); config.d_pools.clear(); config.d_backends.clear(); diff --git a/pdns/dnsdistdist/dnsdist.hh b/pdns/dnsdistdist/dnsdist.hh index 4a24060c5b69..8ed4136a7266 100644 --- a/pdns/dnsdistdist/dnsdist.hh +++ b/pdns/dnsdistdist/dnsdist.hh @@ -1061,25 +1061,20 @@ enum class ProcessQueryResult : uint8_t struct LocalHolders { LocalHolders() : - ruleactions(dnsdist::rules::getRuleChainHolder(dnsdist::rules::RuleChain::Rules).getLocal()), cacheMissRuleActions(dnsdist::rules::getRuleChainHolder(dnsdist::rules::RuleChain::CacheMissRules).getLocal()), cacheHitRespRuleactions(dnsdist::rules::getResponseRuleChainHolder(dnsdist::rules::ResponseRuleChain::CacheHitResponseRules).getLocal()), cacheInsertedRespRuleActions(dnsdist::rules::getResponseRuleChainHolder(dnsdist::rules::ResponseRuleChain::CacheInsertedResponseRules).getLocal()), selfAnsweredRespRuleactions(dnsdist::rules::getResponseRuleChainHolder(dnsdist::rules::ResponseRuleChain::SelfAnsweredResponseRules).getLocal()), dynNMGBlock(g_dynblockNMG.getLocal()), dynSMTBlock(g_dynblockSMT.getLocal()) + dynNMGBlock(g_dynblockNMG.getLocal()), dynSMTBlock(g_dynblockSMT.getLocal()) { } - LocalStateHolder> ruleactions; - LocalStateHolder> cacheMissRuleActions; - LocalStateHolder> cacheHitRespRuleactions; - LocalStateHolder> cacheInsertedRespRuleActions; - LocalStateHolder> selfAnsweredRespRuleactions; LocalStateHolder> dynNMGBlock; LocalStateHolder> dynSMTBlock; }; ProcessQueryResult processQuery(DNSQuestion& dnsQuestion, LocalHolders& holders, std::shared_ptr& selectedBackend); ProcessQueryResult processQueryAfterRules(DNSQuestion& dnsQuestion, LocalHolders& holders, std::shared_ptr& selectedBackend); -bool processResponse(PacketBuffer& response, const std::vector& respRuleActions, const std::vector& cacheInsertedRespRuleActions, DNSResponse& dnsResponse, bool muted); +bool processResponse(PacketBuffer& response, DNSResponse& dnsResponse, bool muted); bool processRulesResult(const DNSAction::Action& action, DNSQuestion& dnsQuestion, std::string& ruleresult, bool& drop); -bool processResponseAfterRules(PacketBuffer& response, const std::vector& cacheInsertedRespRuleActions, DNSResponse& dnsResponse, bool muted); -bool processResponderPacket(std::shared_ptr& dss, PacketBuffer& response, const std::vector& localRespRuleActions, const std::vector& cacheInsertedRespRuleActions, InternalQueryState&& ids); +bool processResponseAfterRules(PacketBuffer& response, DNSResponse& dnsResponse, bool muted); +bool processResponderPacket(std::shared_ptr& dss, PacketBuffer& response, InternalQueryState&& ids); bool applyRulesToResponse(const std::vector& respRuleActions, DNSResponse& dnsResponse); bool assignOutgoingUDPQueryToBackend(std::shared_ptr& downstream, uint16_t queryID, DNSQuestion& dnsQuestion, PacketBuffer& query, bool actuallySend = true); diff --git a/pdns/dnsdistdist/doh.cc b/pdns/dnsdistdist/doh.cc index fa6604b079ff..ef96639a708f 100644 --- a/pdns/dnsdistdist/doh.cc +++ b/pdns/dnsdistdist/doh.cc @@ -503,12 +503,9 @@ class DoHTCPCrossQuerySender final : public TCPQuerySender memcpy(&cleartextDH, dr.getHeader().get(), sizeof(cleartextDH)); if (!response.isAsync()) { - static thread_local LocalStateHolder> localRespRuleActions = dnsdist::rules::getResponseRuleChainHolder(dnsdist::rules::ResponseRuleChain::ResponseRules).getLocal(); - static thread_local LocalStateHolder> localCacheInsertedRespRuleActions = dnsdist::rules::getResponseRuleChainHolder(dnsdist::rules::ResponseRuleChain::CacheInsertedResponseRules).getLocal(); - dr.ids.du = std::move(dohUnit); - if (!processResponse(dynamic_cast(dr.ids.du.get())->response, *localRespRuleActions, *localCacheInsertedRespRuleActions, dr, false)) { + if (!processResponse(dynamic_cast(dr.ids.du.get())->response, dr, false)) { if (dr.ids.du) { dohUnit = getDUFromIDS(dr.ids); dohUnit->status_code = 503; @@ -1649,15 +1646,12 @@ void DOHUnit::handleUDPResponse(PacketBuffer&& udpResponse, InternalQueryState&& } } if (!dohUnit->truncated) { - static thread_local LocalStateHolder> localRespRuleActions = dnsdist::rules::getResponseRuleChainHolder(dnsdist::rules::ResponseRuleChain::ResponseRules).getLocal(); - static thread_local LocalStateHolder> localCacheInsertedRespRuleActions = dnsdist::rules::getResponseRuleChainHolder(dnsdist::rules::ResponseRuleChain::CacheInsertedResponseRules).getLocal(); - DNSResponse dnsResponse(dohUnit->ids, udpResponse, dohUnit->downstream); dnsheader cleartextDH{}; memcpy(&cleartextDH, dnsResponse.getHeader().get(), sizeof(cleartextDH)); dnsResponse.ids.du = std::move(dohUnit); - if (!processResponse(udpResponse, *localRespRuleActions, *localCacheInsertedRespRuleActions, dnsResponse, false)) { + if (!processResponse(udpResponse, dnsResponse, false)) { if (dnsResponse.ids.du) { dohUnit = getDUFromIDS(dnsResponse.ids); dohUnit->status_code = 503; diff --git a/pdns/dnsdistdist/doh3.cc b/pdns/dnsdistdist/doh3.cc index bbf2d5e3afe8..39df3bfaae15 100644 --- a/pdns/dnsdistdist/doh3.cc +++ b/pdns/dnsdistdist/doh3.cc @@ -142,12 +142,9 @@ class DOH3TCPCrossQuerySender final : public TCPQuerySender if (!response.isAsync()) { - static thread_local LocalStateHolder> localRespRuleActions = dnsdist::rules::getResponseRuleChainHolder(dnsdist::rules::ResponseRuleChain::ResponseRules).getLocal(); - static thread_local LocalStateHolder> localCacheInsertedRespRuleActions = dnsdist::rules::getResponseRuleChainHolder(dnsdist::rules::ResponseRuleChain::CacheInsertedResponseRules).getLocal(); - dnsResponse.ids.doh3u = std::move(unit); - if (!processResponse(dnsResponse.ids.doh3u->response, *localRespRuleActions, *localCacheInsertedRespRuleActions, dnsResponse, false)) { + if (!processResponse(dnsResponse.ids.doh3u->response, dnsResponse, false)) { if (dnsResponse.ids.doh3u) { sendBackDOH3Unit(std::move(dnsResponse.ids.doh3u), "Response dropped by rules"); diff --git a/pdns/dnsdistdist/doq.cc b/pdns/dnsdistdist/doq.cc index 54da845b7ec7..f535b1c5c063 100644 --- a/pdns/dnsdistdist/doq.cc +++ b/pdns/dnsdistdist/doq.cc @@ -135,13 +135,9 @@ class DOQTCPCrossQuerySender final : public TCPQuerySender memcpy(&cleartextDH, dnsResponse.getHeader().get(), sizeof(cleartextDH)); if (!response.isAsync()) { - - static thread_local LocalStateHolder> localRespRuleActions = dnsdist::rules::getResponseRuleChainHolder(dnsdist::rules::ResponseRuleChain::ResponseRules).getLocal(); - static thread_local LocalStateHolder> localCacheInsertedRespRuleActions = dnsdist::rules::getResponseRuleChainHolder(dnsdist::rules::ResponseRuleChain::CacheInsertedResponseRules).getLocal(); - dnsResponse.ids.doqu = std::move(unit); - if (!processResponse(dnsResponse.ids.doqu->response, *localRespRuleActions, *localCacheInsertedRespRuleActions, dnsResponse, false)) { + if (!processResponse(dnsResponse.ids.doqu->response, dnsResponse, false)) { if (dnsResponse.ids.doqu) { sendBackDOQUnit(std::move(dnsResponse.ids.doqu), "Response dropped by rules"); diff --git a/pdns/dnsdistdist/test-dnsdist-lua-ffi.cc b/pdns/dnsdistdist/test-dnsdist-lua-ffi.cc index a412bd2e5d69..03c6a75759cc 100644 --- a/pdns/dnsdistdist/test-dnsdist-lua-ffi.cc +++ b/pdns/dnsdistdist/test-dnsdist-lua-ffi.cc @@ -474,7 +474,7 @@ BOOST_AUTO_TEST_CASE(test_PacketCache) testPool->packetCache = packetCache; std::string poolWithNoCacheName("test-pool-without-cache"); auto testPoolWithNoCache = std::make_shared(); - dnsdist::configuration::updateRuntimeConfiguration([&poolName,&testPool,&poolWithNoCacheName,&testPoolWithNoCache](dnsdist::configuration::RuntimeConfiguration& config) { + dnsdist::configuration::updateRuntimeConfiguration([&poolName, &testPool, &poolWithNoCacheName, &testPoolWithNoCache](dnsdist::configuration::RuntimeConfiguration& config) { config.d_pools.emplace(poolName, testPool); config.d_pools.emplace(poolWithNoCacheName, testPoolWithNoCache); }); diff --git a/pdns/dnsdistdist/test-dnsdist_cc.cc b/pdns/dnsdistdist/test-dnsdist_cc.cc index 9fbdcbfd1af0..233afb94499a 100644 --- a/pdns/dnsdistdist/test-dnsdist_cc.cc +++ b/pdns/dnsdistdist/test-dnsdist_cc.cc @@ -48,7 +48,7 @@ ProcessQueryResult processQueryAfterRules(DNSQuestion& dnsQuestion, LocalHolders return ProcessQueryResult::Drop; } -bool processResponseAfterRules(PacketBuffer& response, const std::vector& cacheInsertedRespRuleActions, DNSResponse& dnsResponse, bool muted) +bool processResponseAfterRules(PacketBuffer& response, DNSResponse& dnsResponse, bool muted) { return false; } @@ -92,7 +92,7 @@ bool XskProcessQuery(ClientState& clientState, LocalHolders& holders, XskPacket& } #endif /* HAVE_XSK */ -bool processResponderPacket(std::shared_ptr& dss, PacketBuffer& response, const std::vector& localRespRuleActions, const std::vector& cacheInsertedRespRuleActions, InternalQueryState&& ids) +bool processResponderPacket(std::shared_ptr& dss, PacketBuffer& response, InternalQueryState&& ids) { return false; } diff --git a/pdns/dnsdistdist/test-dnsdisttcp_cc.cc b/pdns/dnsdistdist/test-dnsdisttcp_cc.cc index 694460f98da5..9475310f96b4 100644 --- a/pdns/dnsdistdist/test-dnsdisttcp_cc.cc +++ b/pdns/dnsdistdist/test-dnsdisttcp_cc.cc @@ -73,7 +73,7 @@ bool responseContentMatches(const PacketBuffer& response, const DNSName& qname, static std::function s_processResponse; -bool processResponse(PacketBuffer& response, const std::vector& respRuleActions, const std::vector& cacheInsertedRespRuleActions, DNSResponse& dnsResponse, bool muted) +bool processResponse(PacketBuffer& response, DNSResponse& dnsResponse, bool muted) { if (s_processResponse) { return s_processResponse(response, dnsResponse, muted);