/* * This file is part of PowerDNS or dnsdist. * Copyright -- PowerDNS.COM B.V. and its contributors * * This program is free software; you can redistribute it and/or modify * it under the terms of version 2 of the GNU General Public License as * published by the Free Software Foundation. * * In addition, for the avoidance of any doubt, permission is granted to * link this program with OpenSSL and to (re)distribute the binaries * produced as the result of such linking. * * This program is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU General Public License for more details. * * You should have received a copy of the GNU General Public License * along with this program; if not, write to the Free Software * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. */ #pragma once #include "cachecleaner.hh" #include "dnsdist.hh" #include "dnsdist-ecs.hh" #include "dnsdist-kvs.hh" #include "dnsdist-lua-ffi.hh" #include "dolog.hh" #include "dnsparser.hh" class MaxQPSIPRule : public DNSRule { public: MaxQPSIPRule(unsigned int qps, unsigned int burst, unsigned int ipv4trunc=32, unsigned int ipv6trunc=64, unsigned int expiration=300, unsigned int cleanupDelay=60, unsigned int scanFraction=10): d_qps(qps), d_burst(burst), d_ipv4trunc(ipv4trunc), d_ipv6trunc(ipv6trunc), d_cleanupDelay(cleanupDelay), d_expiration(expiration), d_scanFraction(scanFraction) { gettime(&d_lastCleanup, true); } void clear() { d_limits.lock()->clear(); } size_t cleanup(const struct timespec& cutOff, size_t* scannedCount=nullptr) const { auto limits = d_limits.lock(); size_t toLook = limits->size() / d_scanFraction + 1; size_t lookedAt = 0; size_t removed = 0; auto& sequence = limits->get(); for (auto entry = sequence.begin(); entry != sequence.end() && lookedAt < toLook; lookedAt++) { if (entry->d_limiter.seenSince(cutOff)) { /* entries are ordered from least recently seen to more recently seen, as soon as we see one that has not expired yet, we are done */ lookedAt++; break; } entry = sequence.erase(entry); removed++; } if (scannedCount != nullptr) { *scannedCount = lookedAt; } return removed; } void cleanupIfNeeded(const struct timespec& now) const { if (d_cleanupDelay > 0) { struct timespec cutOff = d_lastCleanup; cutOff.tv_sec += d_cleanupDelay; if (cutOff < now) { /* the QPS Limiter doesn't use realtime, be careful! */ gettime(&cutOff, false); cutOff.tv_sec -= d_expiration; cleanup(cutOff); d_lastCleanup = now; } } } bool matches(const DNSQuestion* dq) const override { cleanupIfNeeded(*dq->queryTime); ComboAddress zeroport(*dq->remote); zeroport.sin4.sin_port=0; zeroport.truncate(zeroport.sin4.sin_family == AF_INET ? d_ipv4trunc : d_ipv6trunc); { auto limits = d_limits.lock(); auto iter = limits->find(zeroport); if (iter == limits->end()) { Entry e(zeroport, QPSLimiter(d_qps, d_burst)); iter = limits->insert(e).first; } moveCacheItemToBack(*limits, iter); return !iter->d_limiter.check(d_qps, d_burst); } } string toString() const override { return "IP (/"+std::to_string(d_ipv4trunc)+", /"+std::to_string(d_ipv6trunc)+") match for QPS over " + std::to_string(d_qps) + " burst "+ std::to_string(d_burst); } size_t getEntriesCount() const { return d_limits.lock()->size(); } private: struct OrderedTag {}; struct SequencedTag {}; struct Entry { Entry(const ComboAddress& addr, BasicQPSLimiter&& limiter): d_limiter(limiter), d_addr(addr) { } mutable BasicQPSLimiter d_limiter; ComboAddress d_addr; }; typedef multi_index_container< Entry, indexed_by < ordered_unique, member, ComboAddress::addressOnlyLessThan >, sequenced > > > qpsContainer_t; mutable LockGuarded d_limits; mutable struct timespec d_lastCleanup; unsigned int d_qps, d_burst, d_ipv4trunc, d_ipv6trunc, d_cleanupDelay, d_expiration; unsigned int d_scanFraction{10}; }; class MaxQPSRule : public DNSRule { public: MaxQPSRule(unsigned int qps) : d_qps(qps, qps) {} MaxQPSRule(unsigned int qps, unsigned int burst) : d_qps(qps, burst) {} bool matches(const DNSQuestion* qd) const override { return d_qps.check(); } string toString() const override { return "Max " + std::to_string(d_qps.getRate()) + " qps"; } private: mutable QPSLimiter d_qps; }; class NMGRule : public DNSRule { public: NMGRule(const NetmaskGroup& nmg) : d_nmg(nmg) {} protected: NetmaskGroup d_nmg; }; class NetmaskGroupRule : public NMGRule { public: NetmaskGroupRule(const NetmaskGroup& nmg, bool src, bool quiet = false) : NMGRule(nmg) { d_src = src; d_quiet = quiet; } bool matches(const DNSQuestion* dq) const override { if(!d_src) { return d_nmg.match(*dq->local); } return d_nmg.match(*dq->remote); } string toString() const override { string ret = "Src: "; if(!d_src) { ret = "Dst: "; } if (d_quiet) { return ret + "in-group"; } return ret + d_nmg.toString(); } private: bool d_src; bool d_quiet; }; class TimedIPSetRule : public DNSRule, boost::noncopyable { private: struct IPv6 { IPv6(const ComboAddress& ca) { static_assert(sizeof(*this)==16, "IPv6 struct has wrong size"); memcpy((char*)this, ca.sin6.sin6_addr.s6_addr, 16); } bool operator==(const IPv6& rhs) const { return a==rhs.a && b==rhs.b; } uint64_t a, b; }; public: TimedIPSetRule() { } ~TimedIPSetRule() { } bool matches(const DNSQuestion* dq) const override { if (dq->remote->sin4.sin_family == AF_INET) { auto ip4s = d_ip4s.read_lock(); auto fnd = ip4s->find(dq->remote->sin4.sin_addr.s_addr); if (fnd == ip4s->end()) { return false; } return time(nullptr) < fnd->second; } else { auto ip6s = d_ip6s.read_lock(); auto fnd = ip6s->find({*dq->remote}); if (fnd == ip6s->end()) { return false; } return time(nullptr) < fnd->second; } } void add(const ComboAddress& ca, time_t ttd) { // think twice before adding templates here if (ca.sin4.sin_family == AF_INET) { auto res = d_ip4s.write_lock()->insert({ca.sin4.sin_addr.s_addr, ttd}); if (!res.second && (time_t)res.first->second < ttd) { res.first->second = (uint32_t)ttd; } } else { auto res = d_ip6s.write_lock()->insert({{ca}, ttd}); if (!res.second && (time_t)res.first->second < ttd) { res.first->second = (uint32_t)ttd; } } } void remove(const ComboAddress& ca) { if (ca.sin4.sin_family == AF_INET) { d_ip4s.write_lock()->erase(ca.sin4.sin_addr.s_addr); } else { d_ip6s.write_lock()->erase({ca}); } } void clear() { d_ip4s.write_lock()->clear(); d_ip6s.write_lock()->clear(); } void cleanup() { time_t now = time(nullptr); { auto ip4s = d_ip4s.write_lock(); for (auto iter = ip4s->begin(); iter != ip4s->end(); ) { if (iter->second < now) { iter = ip4s->erase(iter); } else { ++iter; } } } { auto ip6s = d_ip6s.write_lock(); for (auto iter = ip6s->begin(); iter != ip6s->end(); ) { if (iter->second < now) { iter = ip6s->erase(iter); } else { ++iter; } } } } string toString() const override { time_t now = time(nullptr); uint64_t count = 0; for (const auto& ip : *(d_ip4s.read_lock())) { if (now < ip.second) { ++count; } } for (const auto& ip : *(d_ip6s.read_lock())) { if (now < ip.second) { ++count; } } return "Src: "+std::to_string(count)+" ips"; } private: struct IPv6Hash { std::size_t operator()(const IPv6& ip) const { auto ah=std::hash{}(ip.a); auto bh=std::hash{}(ip.b); return ah & (bh<<1); } }; mutable SharedLockGuarded> d_ip6s; mutable SharedLockGuarded> d_ip4s; }; class AllRule : public DNSRule { public: AllRule() {} bool matches(const DNSQuestion* dq) const override { return true; } string toString() const override { return "All"; } }; class DNSSECRule : public DNSRule { public: DNSSECRule() { } bool matches(const DNSQuestion* dq) const override { return dq->getHeader()->cd || (getEDNSZ(*dq) & EDNS_HEADER_FLAG_DO); // turns out dig sets ad by default.. } string toString() const override { return "DNSSEC"; } }; class AndRule : public DNSRule { public: AndRule(const vector > >& rules) { for(const auto& r : rules) d_rules.push_back(r.second); } bool matches(const DNSQuestion* dq) const override { auto iter = d_rules.begin(); for(; iter != d_rules.end(); ++iter) if(!(*iter)->matches(dq)) break; return iter == d_rules.end(); } string toString() const override { string ret; for(const auto& rule : d_rules) { if(!ret.empty()) ret+= " && "; ret += "("+ rule->toString()+")"; } return ret; } private: vector > d_rules; }; class OrRule : public DNSRule { public: OrRule(const vector > >& rules) { for(const auto& r : rules) d_rules.push_back(r.second); } bool matches(const DNSQuestion* dq) const override { auto iter = d_rules.begin(); for(; iter != d_rules.end(); ++iter) if((*iter)->matches(dq)) return true; return false; } string toString() const override { string ret; for(const auto& rule : d_rules) { if(!ret.empty()) ret+= " || "; ret += "("+ rule->toString()+")"; } return ret; } private: vector > d_rules; }; class RegexRule : public DNSRule { public: RegexRule(const std::string& regex) : d_regex(regex), d_visual(regex) { } bool matches(const DNSQuestion* dq) const override { return d_regex.match(dq->qname->toStringNoDot()); } string toString() const override { return "Regex: "+d_visual; } private: Regex d_regex; string d_visual; }; #ifdef HAVE_RE2 #include class RE2Rule : public DNSRule { public: RE2Rule(const std::string& re2) : d_re2(re2, RE2::Latin1), d_visual(re2) { } bool matches(const DNSQuestion* dq) const override { return RE2::FullMatch(dq->qname->toStringNoDot(), d_re2); } string toString() const override { return "RE2 match: "+d_visual; } private: RE2 d_re2; string d_visual; }; #endif #ifdef HAVE_DNS_OVER_HTTPS class HTTPHeaderRule : public DNSRule { public: HTTPHeaderRule(const std::string& header, const std::string& regex); bool matches(const DNSQuestion* dq) const override; string toString() const override; private: string d_header; Regex d_regex; string d_visual; }; class HTTPPathRule : public DNSRule { public: HTTPPathRule(const std::string& path); bool matches(const DNSQuestion* dq) const override; string toString() const override; private: string d_path; }; class HTTPPathRegexRule : public DNSRule { public: HTTPPathRegexRule(const std::string& regex); bool matches(const DNSQuestion* dq) const override; string toString() const override; private: Regex d_regex; std::string d_visual; }; #endif class SNIRule : public DNSRule { public: SNIRule(const std::string& name) : d_sni(name) { } bool matches(const DNSQuestion* dq) const override { return dq->sni == d_sni; } string toString() const override { return "SNI == " + d_sni; } private: std::string d_sni; }; class SuffixMatchNodeRule : public DNSRule { public: SuffixMatchNodeRule(const SuffixMatchNode& smn, bool quiet=false) : d_smn(smn), d_quiet(quiet) { } bool matches(const DNSQuestion* dq) const override { return d_smn.check(*dq->qname); } string toString() const override { if(d_quiet) return "qname==in-set"; else return "qname in "+d_smn.toString(); } private: SuffixMatchNode d_smn; bool d_quiet; }; class QNameRule : public DNSRule { public: QNameRule(const DNSName& qname) : d_qname(qname) { } bool matches(const DNSQuestion* dq) const override { return d_qname==*dq->qname; } string toString() const override { return "qname=="+d_qname.toString(); } private: DNSName d_qname; }; class QNameSetRule : public DNSRule { public: QNameSetRule(const DNSNameSet& names) : qname_idx(names) {} bool matches(const DNSQuestion* dq) const override { return qname_idx.find(*dq->qname) != qname_idx.end(); } string toString() const override { std::stringstream ss; ss << "qname in DNSNameSet(" << qname_idx.size() << " FQDNs)"; return ss.str(); } private: DNSNameSet qname_idx; }; class QTypeRule : public DNSRule { public: QTypeRule(uint16_t qtype) : d_qtype(qtype) { } bool matches(const DNSQuestion* dq) const override { return d_qtype == dq->qtype; } string toString() const override { QType qt(d_qtype); return "qtype=="+qt.toString(); } private: uint16_t d_qtype; }; class QClassRule : public DNSRule { public: QClassRule(uint16_t qclass) : d_qclass(qclass) { } bool matches(const DNSQuestion* dq) const override { return d_qclass == dq->qclass; } string toString() const override { return "qclass=="+std::to_string(d_qclass); } private: uint16_t d_qclass; }; class OpcodeRule : public DNSRule { public: OpcodeRule(uint8_t opcode) : d_opcode(opcode) { } bool matches(const DNSQuestion* dq) const override { return d_opcode == dq->getHeader()->opcode; } string toString() const override { return "opcode=="+std::to_string(d_opcode); } private: uint8_t d_opcode; }; class DSTPortRule : public DNSRule { public: DSTPortRule(uint16_t port) : d_port(port) { } bool matches(const DNSQuestion* dq) const override { return htons(d_port) == dq->local->sin4.sin_port; } string toString() const override { return "dst port=="+std::to_string(d_port); } private: uint16_t d_port; }; class TCPRule : public DNSRule { public: TCPRule(bool tcp): d_tcp(tcp) { } bool matches(const DNSQuestion* dq) const override { return dq->overTCP() == d_tcp; } string toString() const override { return (d_tcp ? "TCP" : "UDP"); } private: bool d_tcp; }; class NotRule : public DNSRule { public: NotRule(shared_ptr& rule): d_rule(rule) { } bool matches(const DNSQuestion* dq) const override { return !d_rule->matches(dq); } string toString() const override { return "!("+ d_rule->toString()+")"; } private: shared_ptr d_rule; }; class RecordsCountRule : public DNSRule { public: RecordsCountRule(uint8_t section, uint16_t minCount, uint16_t maxCount): d_minCount(minCount), d_maxCount(maxCount), d_section(section) { } bool matches(const DNSQuestion* dq) const override { uint16_t count = 0; switch(d_section) { case 0: count = ntohs(dq->getHeader()->qdcount); break; case 1: count = ntohs(dq->getHeader()->ancount); break; case 2: count = ntohs(dq->getHeader()->nscount); break; case 3: count = ntohs(dq->getHeader()->arcount); break; } return count >= d_minCount && count <= d_maxCount; } string toString() const override { string section; switch(d_section) { case 0: section = "QD"; break; case 1: section = "AN"; break; case 2: section = "NS"; break; case 3: section = "AR"; break; } return std::to_string(d_minCount) + " <= records in " + section + " <= "+ std::to_string(d_maxCount); } private: uint16_t d_minCount; uint16_t d_maxCount; uint8_t d_section; }; class RecordsTypeCountRule : public DNSRule { public: RecordsTypeCountRule(uint8_t section, uint16_t type, uint16_t minCount, uint16_t maxCount): d_type(type), d_minCount(minCount), d_maxCount(maxCount), d_section(section) { } bool matches(const DNSQuestion* dq) const override { uint16_t count = 0; switch(d_section) { case 0: count = ntohs(dq->getHeader()->qdcount); break; case 1: count = ntohs(dq->getHeader()->ancount); break; case 2: count = ntohs(dq->getHeader()->nscount); break; case 3: count = ntohs(dq->getHeader()->arcount); break; } if (count < d_minCount) { return false; } count = getRecordsOfTypeCount(reinterpret_cast(dq->getData().data()), dq->getData().size(), d_section, d_type); return count >= d_minCount && count <= d_maxCount; } string toString() const override { string section; switch(d_section) { case 0: section = "QD"; break; case 1: section = "AN"; break; case 2: section = "NS"; break; case 3: section = "AR"; break; } return std::to_string(d_minCount) + " <= " + QType(d_type).toString() + " records in " + section + " <= "+ std::to_string(d_maxCount); } private: uint16_t d_type; uint16_t d_minCount; uint16_t d_maxCount; uint8_t d_section; }; class TrailingDataRule : public DNSRule { public: TrailingDataRule() { } bool matches(const DNSQuestion* dq) const override { uint16_t length = getDNSPacketLength(reinterpret_cast(dq->getData().data()), dq->getData().size()); return length < dq->getData().size(); } string toString() const override { return "trailing data"; } }; class QNameLabelsCountRule : public DNSRule { public: QNameLabelsCountRule(unsigned int minLabelsCount, unsigned int maxLabelsCount): d_min(minLabelsCount), d_max(maxLabelsCount) { } bool matches(const DNSQuestion* dq) const override { unsigned int count = dq->qname->countLabels(); return count < d_min || count > d_max; } string toString() const override { return "labels count < " + std::to_string(d_min) + " || labels count > " + std::to_string(d_max); } private: unsigned int d_min; unsigned int d_max; }; class QNameWireLengthRule : public DNSRule { public: QNameWireLengthRule(size_t min, size_t max): d_min(min), d_max(max) { } bool matches(const DNSQuestion* dq) const override { size_t const wirelength = dq->qname->wirelength(); return wirelength < d_min || wirelength > d_max; } string toString() const override { return "wire length < " + std::to_string(d_min) + " || wire length > " + std::to_string(d_max); } private: size_t d_min; size_t d_max; }; class RCodeRule : public DNSRule { public: RCodeRule(uint8_t rcode) : d_rcode(rcode) { } bool matches(const DNSQuestion* dq) const override { return d_rcode == dq->getHeader()->rcode; } string toString() const override { return "rcode=="+RCode::to_s(d_rcode); } private: uint8_t d_rcode; }; class ERCodeRule : public DNSRule { public: ERCodeRule(uint8_t rcode) : d_rcode(rcode & 0xF), d_extrcode(rcode >> 4) { } bool matches(const DNSQuestion* dq) const override { // avoid parsing EDNS OPT RR when not needed. if (d_rcode != dq->getHeader()->rcode) { return false; } EDNS0Record edns0; if (!getEDNS0Record(*dq, edns0)) { return false; } return d_extrcode == edns0.extRCode; } string toString() const override { return "ercode=="+ERCode::to_s(d_rcode | (d_extrcode << 4)); } private: uint8_t d_rcode; // plain DNS Rcode uint8_t d_extrcode; // upper bits in EDNS0 record }; class EDNSVersionRule : public DNSRule { public: EDNSVersionRule(uint8_t version) : d_version(version) { } bool matches(const DNSQuestion* dq) const override { EDNS0Record edns0; if (!getEDNS0Record(*dq, edns0)) { return false; } return d_version < edns0.version; } string toString() const override { return "ednsversion>"+std::to_string(d_version); } private: uint8_t d_version; }; class EDNSOptionRule : public DNSRule { public: EDNSOptionRule(uint16_t optcode) : d_optcode(optcode) { } bool matches(const DNSQuestion* dq) const override { uint16_t optStart; size_t optLen = 0; bool last = false; int res = locateEDNSOptRR(dq->getData(), &optStart, &optLen, &last); if (res != 0) { // no EDNS OPT RR return false; } if (optLen < optRecordMinimumSize) { return false; } if (optStart < dq->getData().size() && dq->getData().at(optStart) != 0) { // OPT RR Name != '.' return false; } return isEDNSOptionInOpt(dq->getData(), optStart, optLen, d_optcode); } string toString() const override { return "ednsoptcode=="+std::to_string(d_optcode); } private: uint16_t d_optcode; }; class RDRule : public DNSRule { public: RDRule() { } bool matches(const DNSQuestion* dq) const override { return dq->getHeader()->rd == 1; } string toString() const override { return "rd==1"; } }; class ProbaRule : public DNSRule { public: ProbaRule(double proba) : d_proba(proba) { } bool matches(const DNSQuestion* dq) const override { if(d_proba == 1.0) return true; double rnd = 1.0*random() / RAND_MAX; return rnd > (1.0 - d_proba); } string toString() const override { return "match with prob. " + (boost::format("%0.2f") % d_proba).str(); } private: double d_proba; }; class TagRule : public DNSRule { public: TagRule(const std::string& tag, boost::optional value) : d_value(value), d_tag(tag) { } bool matches(const DNSQuestion* dq) const override { if (!dq->qTag) { return false; } const auto it = dq->qTag->find(d_tag); if (it == dq->qTag->cend()) { return false; } if (!d_value) { return true; } return it->second == *d_value; } string toString() const override { return "tag '" + d_tag + "' is set" + (d_value ? (" to '" + *d_value + "'") : ""); } private: boost::optional d_value; std::string d_tag; }; class PoolAvailableRule : public DNSRule { public: PoolAvailableRule(const std::string& poolname) : d_pools(&g_pools), d_poolname(poolname) { } bool matches(const DNSQuestion* dq) const override { return (getPool(*d_pools, d_poolname)->countServers(true) > 0); } string toString() const override { return "pool '" + d_poolname + "' is available"; } private: mutable LocalStateHolder d_pools; std::string d_poolname; }; class PoolOutstandingRule : public DNSRule { public: PoolOutstandingRule(const std::string& poolname, const size_t limit) : d_pools(&g_pools), d_poolname(poolname), d_limit(limit) { } bool matches(const DNSQuestion* dq) const override { return (getPool(*d_pools, d_poolname)->poolLoad()) > d_limit; } string toString() const override { return "pool '" + d_poolname + "' outstanding > " + std::to_string(d_limit); } private: mutable LocalStateHolder d_pools; std::string d_poolname; size_t d_limit; }; class KeyValueStoreLookupRule: public DNSRule { public: KeyValueStoreLookupRule(std::shared_ptr& kvs, std::shared_ptr& lookupKey): d_kvs(kvs), d_key(lookupKey) { } bool matches(const DNSQuestion* dq) const override { std::vector keys = d_key->getKeys(*dq); for (const auto& key : keys) { if (d_kvs->keyExists(key) == true) { return true; } } return false; } string toString() const override { return "lookup key-value store based on '" + d_key->toString() + "'"; } private: std::shared_ptr d_kvs; std::shared_ptr d_key; }; class KeyValueStoreRangeLookupRule: public DNSRule { public: KeyValueStoreRangeLookupRule(std::shared_ptr& kvs, std::shared_ptr& lookupKey): d_kvs(kvs), d_key(lookupKey) { } bool matches(const DNSQuestion* dq) const override { std::vector keys = d_key->getKeys(*dq); for (const auto& key : keys) { std::string value; if (d_kvs->getRangeValue(key, value) == true) { return true; } } return false; } string toString() const override { return "range-based lookup key-value store based on '" + d_key->toString() + "'"; } private: std::shared_ptr d_kvs; std::shared_ptr d_key; }; class LuaRule : public DNSRule { public: typedef std::function func_t; LuaRule(const func_t& func): d_func(func) {} bool matches(const DNSQuestion* dq) const override { try { auto lock = g_lua.lock(); return d_func(dq); } catch (const std::exception &e) { warnlog("LuaRule failed inside Lua: %s", e.what()); } catch (...) { warnlog("LuaRule failed inside Lua: [unknown exception]"); } return false; } string toString() const override { return "Lua script"; } private: func_t d_func; }; class LuaFFIRule : public DNSRule { public: typedef std::function func_t; LuaFFIRule(const func_t& func): d_func(func) {} bool matches(const DNSQuestion* dq) const override { dnsdist_ffi_dnsquestion_t dqffi(const_cast(dq)); try { auto lock = g_lua.lock(); return d_func(&dqffi); } catch (const std::exception &e) { warnlog("LuaFFIRule failed inside Lua: %s", e.what()); } catch (...) { warnlog("LuaFFIRule failed inside Lua: [unknown exception]"); } return false; } string toString() const override { return "Lua FFI script"; } private: func_t d_func; }; class LuaFFIPerThreadRule : public DNSRule { public: typedef std::function func_t; LuaFFIPerThreadRule(const std::string& code): d_functionCode(code), d_functionID(s_functionsCounter++) { } bool matches(const DNSQuestion* dq) const override { try { auto& state = t_perThreadStates[d_functionID]; if (!state.d_initialized) { setupLuaFFIPerThreadContext(state.d_luaContext); /* mark the state as initialized first so if there is a syntax error we only try to execute the code once */ state.d_initialized = true; state.d_func = state.d_luaContext.executeCode(d_functionCode); } if (!state.d_func) { /* the function was not properly initialized */ return false; } dnsdist_ffi_dnsquestion_t dqffi(const_cast(dq)); return state.d_func(&dqffi); } catch (const std::exception &e) { warnlog("LuaFFIPerthreadRule failed inside Lua: %s", e.what()); } catch (...) { warnlog("LuaFFIPerThreadRule failed inside Lua: [unknown exception]"); } return false; } string toString() const override { return "Lua FFI per-thread script"; } private: struct PerThreadState { LuaContext d_luaContext; func_t d_func; bool d_initialized{false}; }; static std::atomic s_functionsCounter; static thread_local std::map t_perThreadStates; const std::string d_functionCode; const uint64_t d_functionID; }; class ProxyProtocolValueRule : public DNSRule { public: ProxyProtocolValueRule(uint8_t type, boost::optional value): d_value(value), d_type(type) { } bool matches(const DNSQuestion* dq) const override { if (!dq->proxyProtocolValues) { return false; } for (const auto& entry : *dq->proxyProtocolValues) { if (entry.type == d_type && (!d_value || entry.content == *d_value)) { return true; } } return false; } string toString() const override { if (d_value) { return "proxy protocol value of type " + std::to_string(d_type) + " matches"; } return "proxy protocol value of type " + std::to_string(d_type) + " is present"; } private: boost::optional d_value; uint8_t d_type; };