diff options
Diffstat (limited to 'dnsdist-lua-actions.cc')
-rw-r--r-- | dnsdist-lua-actions.cc | 2001 |
1 files changed, 1156 insertions, 845 deletions
diff --git a/dnsdist-lua-actions.cc b/dnsdist-lua-actions.cc index 5d3271a..e643007 100644 --- a/dnsdist-lua-actions.cc +++ b/dnsdist-lua-actions.cc @@ -23,11 +23,14 @@ #include "threadname.hh" #include "dnsdist.hh" #include "dnsdist-async.hh" +#include "dnsdist-dnsparser.hh" #include "dnsdist-ecs.hh" +#include "dnsdist-edns.hh" #include "dnsdist-lua.hh" #include "dnsdist-lua-ffi.hh" #include "dnsdist-mac-address.hh" #include "dnsdist-protobuf.hh" +#include "dnsdist-proxy-protocol.hh" #include "dnsdist-kvs.hh" #include "dnsdist-svc.hh" @@ -45,11 +48,11 @@ class DropAction : public DNSAction { public: - DNSAction::Action operator()(DNSQuestion* dq, std::string* ruleresult) const override + DNSAction::Action operator()(DNSQuestion* dnsquestion, std::string* ruleresult) const override { return Action::Drop; } - std::string toString() const override + [[nodiscard]] std::string toString() const override { return "drop"; } @@ -58,11 +61,11 @@ public: class AllowAction : public DNSAction { public: - DNSAction::Action operator()(DNSQuestion* dq, std::string* ruleresult) const override + DNSAction::Action operator()(DNSQuestion* dnsquestion, std::string* ruleresult) const override { return Action::Allow; } - std::string toString() const override + [[nodiscard]] std::string toString() const override { return "allow"; } @@ -72,11 +75,11 @@ class NoneAction : public DNSAction { public: // this action does not stop the processing - DNSAction::Action operator()(DNSQuestion* dq, std::string* ruleresult) const override + DNSAction::Action operator()(DNSQuestion* dnsquestion, std::string* ruleresult) const override { return Action::None; } - std::string toString() const override + [[nodiscard]] std::string toString() const override { return "no op"; } @@ -85,22 +88,22 @@ public: class QPSAction : public DNSAction { public: - QPSAction(int limit) : d_qps(QPSLimiter(limit, limit)) + QPSAction(int limit) : + d_qps(QPSLimiter(limit, limit)) { } - DNSAction::Action operator()(DNSQuestion* dq, std::string* ruleresult) const override + DNSAction::Action operator()(DNSQuestion* dnsquestion, std::string* ruleresult) const override { if (d_qps.lock()->check()) { return Action::None; } - else { - return Action::Drop; - } + return Action::Drop; } - std::string toString() const override + [[nodiscard]] std::string toString() const override { - return "qps limit to "+std::to_string(d_qps.lock()->getRate()); + return "qps limit to " + std::to_string(d_qps.lock()->getRate()); } + private: mutable LockGuarded<QPSLimiter> d_qps; }; @@ -108,18 +111,20 @@ private: class DelayAction : public DNSAction { public: - DelayAction(int msec) : d_msec(msec) + DelayAction(int msec) : + d_msec(msec) { } - DNSAction::Action operator()(DNSQuestion* dq, std::string* ruleresult) const override + DNSAction::Action operator()(DNSQuestion* dnsquestion, std::string* ruleresult) const override { *ruleresult = std::to_string(d_msec); return Action::Delay; } - std::string toString() const override + [[nodiscard]] std::string toString() const override { - return "delay by "+std::to_string(d_msec)+ " msec"; + return "delay by " + std::to_string(d_msec) + " ms"; } + private: int d_msec; }; @@ -128,18 +133,22 @@ class TeeAction : public DNSAction { public: // this action does not stop the processing - TeeAction(const ComboAddress& rca, const boost::optional<ComboAddress>& lca, bool addECS=false); + TeeAction(const ComboAddress& rca, const boost::optional<ComboAddress>& lca, bool addECS = false, bool addProxyProtocol = false); + TeeAction(TeeAction& other) = delete; + TeeAction(TeeAction&& other) = delete; + TeeAction& operator=(TeeAction& other) = delete; + TeeAction& operator=(TeeAction&& other) = delete; ~TeeAction() override; - DNSAction::Action operator()(DNSQuestion* dq, std::string* ruleresult) const override; - std::string toString() const override; + DNSAction::Action operator()(DNSQuestion* dnsquestion, std::string* ruleresult) const override; + [[nodiscard]] std::string toString() const override; std::map<std::string, double> getStats() const override; private: - ComboAddress d_remote; - std::thread d_worker; void worker(); - int d_fd{-1}; + ComboAddress d_remote; + std::thread d_worker; + Socket d_socket; mutable std::atomic<unsigned long> d_senderrors{0}; unsigned long d_recverrors{0}; mutable std::atomic<unsigned long> d_queries{0}; @@ -154,61 +163,65 @@ private: stat_t d_otherrcode{0}; std::atomic<bool> d_pleaseQuit{false}; bool d_addECS{false}; + bool d_addProxyProtocol{false}; }; -TeeAction::TeeAction(const ComboAddress& rca, const boost::optional<ComboAddress>& lca, bool addECS) - : d_remote(rca), d_addECS(addECS) +TeeAction::TeeAction(const ComboAddress& rca, const boost::optional<ComboAddress>& lca, bool addECS, bool addProxyProtocol) : + d_remote(rca), d_socket(d_remote.sin4.sin_family, SOCK_DGRAM, 0), d_addECS(addECS), d_addProxyProtocol(addProxyProtocol) { - d_fd=SSocket(d_remote.sin4.sin_family, SOCK_DGRAM, 0); - try { - if (lca) { - SBind(d_fd, *lca); - } - SConnect(d_fd, d_remote); - setNonBlocking(d_fd); - d_worker=std::thread([this](){worker();}); - } - catch (...) { - if (d_fd != -1) { - close(d_fd); - } - throw; + if (lca) { + d_socket.bind(*lca, false); } + d_socket.connect(d_remote); + d_socket.setNonBlocking(); + d_worker = std::thread([this]() { + worker(); + }); } TeeAction::~TeeAction() { - d_pleaseQuit=true; - close(d_fd); + d_pleaseQuit = true; + close(d_socket.releaseHandle()); d_worker.join(); } -DNSAction::Action TeeAction::operator()(DNSQuestion* dq, std::string* ruleresult) const +DNSAction::Action TeeAction::operator()(DNSQuestion* dnsquestion, std::string* ruleresult) const { - if (dq->overTCP()) { + if (dnsquestion->overTCP()) { d_tcpdrops++; + return DNSAction::Action::None; } - else { - ssize_t res; - d_queries++; - if(d_addECS) { - PacketBuffer query(dq->getData()); - bool ednsAdded = false; - bool ecsAdded = false; + d_queries++; - std::string newECSOption; - generateECSOption(dq->ecs ? dq->ecs->getNetwork() : dq->ids.origRemote, newECSOption, dq->ecs ? dq->ecs->getBits() : dq->ecsPrefixLength); + PacketBuffer query; + if (d_addECS) { + query = dnsquestion->getData(); + bool ednsAdded = false; + bool ecsAdded = false; - if (!handleEDNSClientSubnet(query, dq->getMaximumSize(), dq->ids.qname.wirelength(), ednsAdded, ecsAdded, dq->ecsOverride, newECSOption)) { - return DNSAction::Action::None; - } + std::string newECSOption; + generateECSOption(dnsquestion->ecs ? dnsquestion->ecs->getNetwork() : dnsquestion->ids.origRemote, newECSOption, dnsquestion->ecs ? dnsquestion->ecs->getBits() : dnsquestion->ecsPrefixLength); - res = send(d_fd, query.data(), query.size(), 0); + if (!handleEDNSClientSubnet(query, dnsquestion->getMaximumSize(), dnsquestion->ids.qname.wirelength(), ednsAdded, ecsAdded, dnsquestion->ecsOverride, newECSOption)) { + return DNSAction::Action::None; } - else { - res = send(d_fd, dq->getData().data(), dq->getData().size(), 0); + } + + if (d_addProxyProtocol) { + auto proxyPayload = getProxyProtocolPayload(*dnsquestion); + if (query.empty()) { + query = dnsquestion->getData(); + } + if (!addProxyProtocol(query, proxyPayload)) { + return DNSAction::Action::None; } + } + + { + const PacketBuffer& payload = query.empty() ? dnsquestion->getData() : query; + auto res = send(d_socket.getHandle(), payload.data(), payload.size(), 0); if (res <= 0) { d_senderrors++; @@ -220,10 +233,10 @@ DNSAction::Action TeeAction::operator()(DNSQuestion* dq, std::string* ruleresult std::string TeeAction::toString() const { - return "tee to "+d_remote.toStringWithPort(); + return "tee to " + d_remote.toStringWithPort(); } -std::map<std::string,double> TeeAction::getStats() const +std::map<std::string, double> TeeAction::getStats() const { return {{"queries", d_queries}, {"responses", d_responses}, @@ -234,81 +247,98 @@ std::map<std::string,double> TeeAction::getStats() const {"refuseds", d_refuseds}, {"servfails", d_servfails}, {"other-rcode", d_otherrcode}, - {"tcp-drops", d_tcpdrops} - }; + {"tcp-drops", d_tcpdrops}}; } void TeeAction::worker() { setThreadName("dnsdist/TeeWork"); - char packet[1500]; - int res=0; - struct dnsheader* dh=(struct dnsheader*)packet; - for(;;) { - res=waitForData(d_fd, 0, 250000); - if(d_pleaseQuit) + std::array<char, s_udpIncomingBufferSize> packet{}; + ssize_t res = 0; + const dnsheader_aligned dnsheader(packet.data()); + for (;;) { + res = waitForData(d_socket.getHandle(), 0, 250000); + if (d_pleaseQuit) { break; - if(res < 0) { + } + + if (res < 0) { usleep(250000); continue; } - if(res==0) + if (res == 0) { continue; - res=recv(d_fd, packet, sizeof(packet), 0); - if(res <= (int)sizeof(struct dnsheader)) + } + res = recv(d_socket.getHandle(), packet.data(), packet.size(), 0); + if (static_cast<size_t>(res) <= sizeof(struct dnsheader)) { d_recverrors++; - else + } + else { d_responses++; + } - if(dh->rcode == RCode::NoError) + // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions): rcode is unsigned, RCode::rcodes_ as well + if (dnsheader->rcode == RCode::NoError) { d_noerrors++; - else if(dh->rcode == RCode::ServFail) + } + // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions): rcode is unsigned, RCode::rcodes_ as well + else if (dnsheader->rcode == RCode::ServFail) { d_servfails++; - else if(dh->rcode == RCode::NXDomain) + } + // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions): rcode is unsigned, RCode::rcodes_ as well + else if (dnsheader->rcode == RCode::NXDomain) { d_nxdomains++; - else if(dh->rcode == RCode::Refused) + } + // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions): rcode is unsigned, RCode::rcodes_ as well + else if (dnsheader->rcode == RCode::Refused) { d_refuseds++; - else if(dh->rcode == RCode::FormErr) + } + // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions): rcode is unsigned, RCode::rcodes_ as well + else if (dnsheader->rcode == RCode::FormErr) { d_formerrs++; - else if(dh->rcode == RCode::NotImp) + } + // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions): rcode is unsigned, RCode::rcodes_ as well + else if (dnsheader->rcode == RCode::NotImp) { d_notimps++; + } } } class PoolAction : public DNSAction { public: - PoolAction(const std::string& pool, bool stopProcessing) : d_pool(pool), d_stopProcessing(stopProcessing) {} + PoolAction(std::string pool, bool stopProcessing) : + d_pool(std::move(pool)), d_stopProcessing(stopProcessing) {} - DNSAction::Action operator()(DNSQuestion* dq, std::string* ruleresult) const override + DNSAction::Action operator()(DNSQuestion* dnsquestion, std::string* ruleresult) const override { if (d_stopProcessing) { /* we need to do it that way to keep compatiblity with custom Lua actions returning DNSAction.Pool, 'poolname' */ *ruleresult = d_pool; return Action::Pool; } - else { - dq->ids.poolName = d_pool; - return Action::None; - } + dnsquestion->ids.poolName = d_pool; + return Action::None; } - std::string toString() const override + [[nodiscard]] std::string toString() const override { return "to pool " + d_pool; } private: + // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members) const std::string d_pool; + // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members) const bool d_stopProcessing; }; - class QPSPoolAction : public DNSAction { public: - QPSPoolAction(unsigned int limit, const std::string& pool, bool stopProcessing) : d_qps(QPSLimiter(limit, limit)), d_pool(pool), d_stopProcessing(stopProcessing) {} - DNSAction::Action operator()(DNSQuestion* dq, std::string* ruleresult) const override + QPSPoolAction(unsigned int limit, std::string pool, bool stopProcessing) : + d_qps(QPSLimiter(limit, limit)), d_pool(std::move(pool)), d_stopProcessing(stopProcessing) {} + DNSAction::Action operator()(DNSQuestion* dnsquestion, std::string* ruleresult) const override { if (d_qps.lock()->check()) { if (d_stopProcessing) { @@ -316,66 +346,79 @@ public: *ruleresult = d_pool; return Action::Pool; } - else { - dq->ids.poolName = d_pool; - return Action::None; - } - } - else { - return Action::None; + dnsquestion->ids.poolName = d_pool; } + return Action::None; } - std::string toString() const override + [[nodiscard]] std::string toString() const override { return "max " + std::to_string(d_qps.lock()->getRate()) + " to pool " + d_pool; } private: mutable LockGuarded<QPSLimiter> d_qps; + // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members) const std::string d_pool; + // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members) const bool d_stopProcessing; }; class RCodeAction : public DNSAction { public: - RCodeAction(uint8_t rcode) : d_rcode(rcode) {} - DNSAction::Action operator()(DNSQuestion* dq, std::string* ruleresult) const override + RCodeAction(uint8_t rcode) : + d_rcode(rcode) {} + DNSAction::Action operator()(DNSQuestion* dnsquestion, std::string* ruleresult) const override { - dq->getHeader()->rcode = d_rcode; - dq->getHeader()->qr = true; // for good measure - setResponseHeadersFromConfig(*dq->getHeader(), d_responseConfig); + dnsdist::PacketMangling::editDNSHeaderFromPacket(dnsquestion->getMutableData(), [this](dnsheader& header) { + header.rcode = d_rcode; + header.qr = true; // for good measure + setResponseHeadersFromConfig(header, d_responseConfig); + return true; + }); return Action::HeaderModify; } - std::string toString() const override + [[nodiscard]] std::string toString() const override + { + return "set rcode " + std::to_string(d_rcode); + } + [[nodiscard]] ResponseConfig& getResponseConfig() { - return "set rcode "+std::to_string(d_rcode); + return d_responseConfig; } - ResponseConfig d_responseConfig; private: + ResponseConfig d_responseConfig; uint8_t d_rcode; }; class ERCodeAction : public DNSAction { public: - ERCodeAction(uint8_t rcode) : d_rcode(rcode) {} - DNSAction::Action operator()(DNSQuestion* dq, std::string* ruleresult) const override + ERCodeAction(uint8_t rcode) : + d_rcode(rcode) {} + DNSAction::Action operator()(DNSQuestion* dnsquestion, std::string* ruleresult) const override { - dq->getHeader()->rcode = (d_rcode & 0xF); - dq->ednsRCode = ((d_rcode & 0xFFF0) >> 4); - dq->getHeader()->qr = true; // for good measure - setResponseHeadersFromConfig(*dq->getHeader(), d_responseConfig); + dnsdist::PacketMangling::editDNSHeaderFromPacket(dnsquestion->getMutableData(), [this](dnsheader& header) { + header.rcode = (d_rcode & 0xF); + header.qr = true; // for good measure + setResponseHeadersFromConfig(header, d_responseConfig); + return true; + }); + dnsquestion->ednsRCode = ((d_rcode & 0xFFF0) >> 4); return Action::HeaderModify; } - std::string toString() const override + [[nodiscard]] std::string toString() const override + { + return "set ercode " + ERCode::to_s(d_rcode); + } + [[nodiscard]] ResponseConfig& getResponseConfig() { - return "set ercode "+ERCode::to_s(d_rcode); + return d_responseConfig; } - ResponseConfig d_responseConfig; private: + ResponseConfig d_responseConfig; uint8_t d_rcode; }; @@ -396,88 +439,106 @@ public: d_payloads.push_back(std::move(payload)); for (const auto& hint : param.second.ipv4hints) { - d_additionals4.insert({ param.second.target, ComboAddress(hint) }); + d_additionals4.insert({param.second.target, ComboAddress(hint)}); } for (const auto& hint : param.second.ipv6hints) { - d_additionals6.insert({ param.second.target, ComboAddress(hint) }); + d_additionals6.insert({param.second.target, ComboAddress(hint)}); } } } - DNSAction::Action operator()(DNSQuestion* dq, std::string* ruleresult) const override + DNSAction::Action operator()(DNSQuestion* dnsquestion, std::string* ruleresult) const override { /* it will likely be a bit bigger than that because of additionals */ - uint16_t numberOfRecords = d_payloads.size(); - const auto qnameWireLength = dq->ids.qname.wirelength(); - if (dq->getMaximumSize() < (sizeof(dnsheader) + qnameWireLength + 4 + numberOfRecords*12 /* recordstart */ + d_totalPayloadsSize)) { + auto numberOfRecords = d_payloads.size(); + const auto qnameWireLength = dnsquestion->ids.qname.wirelength(); + if (dnsquestion->getMaximumSize() < (sizeof(dnsheader) + qnameWireLength + 4 + numberOfRecords * 12 /* recordstart */ + d_totalPayloadsSize)) { return Action::None; } PacketBuffer newPacket; - newPacket.reserve(sizeof(dnsheader) + qnameWireLength + 4 + numberOfRecords*12 /* recordstart */ + d_totalPayloadsSize); - GenericDNSPacketWriter<PacketBuffer> pw(newPacket, dq->ids.qname, dq->ids.qtype); + newPacket.reserve(sizeof(dnsheader) + qnameWireLength + 4 + numberOfRecords * 12 /* recordstart */ + d_totalPayloadsSize); + GenericDNSPacketWriter<PacketBuffer> packetWriter(newPacket, dnsquestion->ids.qname, dnsquestion->ids.qtype); for (const auto& payload : d_payloads) { - pw.startRecord(dq->ids.qname, dq->ids.qtype, d_responseConfig.ttl); - pw.xfrBlob(payload); - pw.commit(); + packetWriter.startRecord(dnsquestion->ids.qname, dnsquestion->ids.qtype, d_responseConfig.ttl); + packetWriter.xfrBlob(payload); + packetWriter.commit(); } - if (newPacket.size() < dq->getMaximumSize()) { + if (newPacket.size() < dnsquestion->getMaximumSize()) { for (const auto& additional : d_additionals4) { - pw.startRecord(additional.first.isRoot() ? dq->ids.qname : additional.first, QType::A, d_responseConfig.ttl, QClass::IN, DNSResourceRecord::ADDITIONAL); - pw.xfrCAWithoutPort(4, additional.second); - pw.commit(); + packetWriter.startRecord(additional.first.isRoot() ? dnsquestion->ids.qname : additional.first, QType::A, d_responseConfig.ttl, QClass::IN, DNSResourceRecord::ADDITIONAL); + packetWriter.xfrCAWithoutPort(4, additional.second); + packetWriter.commit(); } } - if (newPacket.size() < dq->getMaximumSize()) { + if (newPacket.size() < dnsquestion->getMaximumSize()) { for (const auto& additional : d_additionals6) { - pw.startRecord(additional.first.isRoot() ? dq->ids.qname : additional.first, QType::AAAA, d_responseConfig.ttl, QClass::IN, DNSResourceRecord::ADDITIONAL); - pw.xfrCAWithoutPort(6, additional.second); - pw.commit(); + packetWriter.startRecord(additional.first.isRoot() ? dnsquestion->ids.qname : additional.first, QType::AAAA, d_responseConfig.ttl, QClass::IN, DNSResourceRecord::ADDITIONAL); + packetWriter.xfrCAWithoutPort(6, additional.second); + packetWriter.commit(); } } - if (g_addEDNSToSelfGeneratedResponses && queryHasEDNS(*dq)) { - bool dnssecOK = getEDNSZ(*dq) & EDNS_HEADER_FLAG_DO; - pw.addOpt(g_PayloadSizeSelfGenAnswers, 0, dnssecOK ? EDNS_HEADER_FLAG_DO : 0); - pw.commit(); + if (g_addEDNSToSelfGeneratedResponses && queryHasEDNS(*dnsquestion)) { + bool dnssecOK = ((getEDNSZ(*dnsquestion) & EDNS_HEADER_FLAG_DO) != 0); + packetWriter.addOpt(g_PayloadSizeSelfGenAnswers, 0, dnssecOK ? EDNS_HEADER_FLAG_DO : 0); + packetWriter.commit(); } - if (newPacket.size() >= dq->getMaximumSize()) { + if (newPacket.size() >= dnsquestion->getMaximumSize()) { /* sorry! */ return Action::None; } - pw.getHeader()->id = dq->getHeader()->id; - pw.getHeader()->qr = true; // for good measure - setResponseHeadersFromConfig(*pw.getHeader(), d_responseConfig); - dq->getMutableData() = std::move(newPacket); + packetWriter.getHeader()->id = dnsquestion->getHeader()->id; + packetWriter.getHeader()->qr = true; // for good measure + setResponseHeadersFromConfig(*packetWriter.getHeader(), d_responseConfig); + dnsquestion->getMutableData() = std::move(newPacket); return Action::HeaderModify; } - std::string toString() const override + [[nodiscard]] std::string toString() const override { return "spoof SVC record "; } - ResponseConfig d_responseConfig; + [[nodiscard]] ResponseConfig& getResponseConfig() + { + return d_responseConfig; + } + private: - std::vector<std::vector<uint8_t>> d_payloads; - std::set<std::pair<DNSName, ComboAddress>> d_additionals4; - std::set<std::pair<DNSName, ComboAddress>> d_additionals6; + ResponseConfig d_responseConfig; + std::vector<std::vector<uint8_t>> d_payloads{}; + std::set<std::pair<DNSName, ComboAddress>> d_additionals4{}; + std::set<std::pair<DNSName, ComboAddress>> d_additionals6{}; size_t d_totalPayloadsSize{0}; }; class TCAction : public DNSAction { public: - DNSAction::Action operator()(DNSQuestion* dq, std::string* ruleresult) const override + DNSAction::Action operator()(DNSQuestion* dnsquestion, std::string* ruleresult) const override + { + return Action::Truncate; + } + [[nodiscard]] std::string toString() const override + { + return "tc=1 answer"; + } +}; + +class TCResponseAction : public DNSResponseAction +{ +public: + DNSResponseAction::Action operator()(DNSResponse* dnsResponse, std::string* ruleresult) const override { return Action::Truncate; } - std::string toString() const override + [[nodiscard]] std::string toString() const override { return "tc=1 answer"; } @@ -486,18 +547,19 @@ public: class LuaAction : public DNSAction { public: - typedef std::function<std::tuple<int, boost::optional<string> >(DNSQuestion* dq)> func_t; - LuaAction(const LuaAction::func_t& func) : d_func(func) + using func_t = std::function<std::tuple<int, boost::optional<string>>(DNSQuestion* dnsquestion)>; + LuaAction(LuaAction::func_t func) : + d_func(std::move(func)) {} - DNSAction::Action operator()(DNSQuestion* dq, std::string* ruleresult) const override + DNSAction::Action operator()(DNSQuestion* dnsquestion, std::string* ruleresult) const override { try { - DNSAction::Action result; + DNSAction::Action result{}; { auto lock = g_lua.lock(); - auto ret = d_func(dq); - if (ruleresult) { + auto ret = d_func(dnsquestion); + if (ruleresult != nullptr) { if (boost::optional<std::string> rule = std::get<1>(ret)) { *ruleresult = *rule; } @@ -510,18 +572,21 @@ public: } dnsdist::handleQueuedAsynchronousEvents(); return result; - } catch (const std::exception &e) { + } + catch (const std::exception& e) { warnlog("LuaAction failed inside Lua, returning ServFail: %s", e.what()); - } catch (...) { + } + catch (...) { warnlog("LuaAction failed inside Lua, returning ServFail: [unknown exception]"); } return DNSAction::Action::ServFail; } - string toString() const override + [[nodiscard]] std::string toString() const override { return "Lua script"; } + private: func_t d_func; }; @@ -529,17 +594,18 @@ private: class LuaResponseAction : public DNSResponseAction { public: - typedef std::function<std::tuple<int, boost::optional<string> >(DNSResponse* dr)> func_t; - LuaResponseAction(const LuaResponseAction::func_t& func) : d_func(func) + using func_t = std::function<std::tuple<int, boost::optional<string>>(DNSResponse* response)>; + LuaResponseAction(LuaResponseAction::func_t func) : + d_func(std::move(func)) {} - DNSResponseAction::Action operator()(DNSResponse* dr, std::string* ruleresult) const override + DNSResponseAction::Action operator()(DNSResponse* response, std::string* ruleresult) const override { try { - DNSResponseAction::Action result; + DNSResponseAction::Action result{}; { auto lock = g_lua.lock(); - auto ret = d_func(dr); - if (ruleresult) { + auto ret = d_func(response); + if (ruleresult != nullptr) { if (boost::optional<std::string> rule = std::get<1>(ret)) { *ruleresult = *rule; } @@ -552,40 +618,44 @@ public: } dnsdist::handleQueuedAsynchronousEvents(); return result; - } catch (const std::exception &e) { + } + catch (const std::exception& e) { warnlog("LuaResponseAction failed inside Lua, returning ServFail: %s", e.what()); - } catch (...) { + } + catch (...) { warnlog("LuaResponseAction failed inside Lua, returning ServFail: [unknown exception]"); } return DNSResponseAction::Action::ServFail; } - string toString() const override + [[nodiscard]] std::string toString() const override { return "Lua response script"; } + private: func_t d_func; }; -class LuaFFIAction: public DNSAction +class LuaFFIAction : public DNSAction { public: - typedef std::function<int(dnsdist_ffi_dnsquestion_t* dq)> func_t; + using func_t = std::function<int(dnsdist_ffi_dnsquestion_t* dnsquestion)>; - LuaFFIAction(const LuaFFIAction::func_t& func): d_func(func) + LuaFFIAction(LuaFFIAction::func_t func) : + d_func(std::move(func)) { } - DNSAction::Action operator()(DNSQuestion* dq, std::string* ruleresult) const override + DNSAction::Action operator()(DNSQuestion* dnsquestion, std::string* ruleresult) const override { - dnsdist_ffi_dnsquestion_t dqffi(dq); + dnsdist_ffi_dnsquestion_t dqffi(dnsquestion); try { - DNSAction::Action result; + DNSAction::Action result{}; { auto lock = g_lua.lock(); auto ret = d_func(&dqffi); - if (ruleresult) { + if (ruleresult != nullptr) { if (dqffi.result) { *ruleresult = *dqffi.result; } @@ -598,32 +668,36 @@ public: } dnsdist::handleQueuedAsynchronousEvents(); return result; - } catch (const std::exception &e) { + } + catch (const std::exception& e) { warnlog("LuaFFIAction failed inside Lua, returning ServFail: %s", e.what()); - } catch (...) { + } + catch (...) { warnlog("LuaFFIAction failed inside Lua, returning ServFail: [unknown exception]"); } return DNSAction::Action::ServFail; } - string toString() const override + [[nodiscard]] std::string toString() const override { return "Lua FFI script"; } + private: func_t d_func; }; -class LuaFFIPerThreadAction: public DNSAction +class LuaFFIPerThreadAction : public DNSAction { public: - typedef std::function<int(dnsdist_ffi_dnsquestion_t* dq)> func_t; + using func_t = std::function<int(dnsdist_ffi_dnsquestion_t* dnsquestion)>; - LuaFFIPerThreadAction(const std::string& code): d_functionCode(code), d_functionID(s_functionsCounter++) + LuaFFIPerThreadAction(std::string code) : + d_functionCode(std::move(code)), d_functionID(s_functionsCounter++) { } - DNSAction::Action operator()(DNSQuestion* dq, std::string* ruleresult) const override + DNSAction::Action operator()(DNSQuestion* dnsquestion, std::string* ruleresult) const override { try { auto& state = t_perThreadStates[d_functionID]; @@ -640,9 +714,9 @@ public: return DNSAction::Action::None; } - dnsdist_ffi_dnsquestion_t dqffi(dq); + dnsdist_ffi_dnsquestion_t dqffi(dnsquestion); auto ret = state.d_func(&dqffi); - if (ruleresult) { + if (ruleresult != nullptr) { if (dqffi.result) { *ruleresult = *dqffi.result; } @@ -654,7 +728,7 @@ public: dnsdist::handleQueuedAsynchronousEvents(); return static_cast<DNSAction::Action>(ret); } - catch (const std::exception &e) { + catch (const std::exception& e) { warnlog("LuaFFIPerThreadAction failed inside Lua, returning ServFail: %s", e.what()); } catch (...) { @@ -663,7 +737,7 @@ public: return DNSAction::Action::ServFail; } - string toString() const override + [[nodiscard]] std::string toString() const override { return "Lua FFI per-thread script"; } @@ -677,33 +751,36 @@ private: }; static std::atomic<uint64_t> s_functionsCounter; static thread_local std::map<uint64_t, PerThreadState> t_perThreadStates; + // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members) const std::string d_functionCode; + // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members) const uint64_t d_functionID; }; std::atomic<uint64_t> LuaFFIPerThreadAction::s_functionsCounter = 0; thread_local std::map<uint64_t, LuaFFIPerThreadAction::PerThreadState> LuaFFIPerThreadAction::t_perThreadStates; -class LuaFFIResponseAction: public DNSResponseAction +class LuaFFIResponseAction : public DNSResponseAction { public: - typedef std::function<int(dnsdist_ffi_dnsresponse_t* dq)> func_t; + using func_t = std::function<int(dnsdist_ffi_dnsresponse_t* dnsquestion)>; - LuaFFIResponseAction(const LuaFFIResponseAction::func_t& func): d_func(func) + LuaFFIResponseAction(LuaFFIResponseAction::func_t func) : + d_func(std::move(func)) { } - DNSResponseAction::Action operator()(DNSResponse* dr, std::string* ruleresult) const override + DNSResponseAction::Action operator()(DNSResponse* response, std::string* ruleresult) const override { - dnsdist_ffi_dnsresponse_t drffi(dr); + dnsdist_ffi_dnsresponse_t ffiResponse(response); try { - DNSResponseAction::Action result; + DNSResponseAction::Action result{}; { auto lock = g_lua.lock(); - auto ret = d_func(&drffi); - if (ruleresult) { - if (drffi.result) { - *ruleresult = *drffi.result; + auto ret = d_func(&ffiResponse); + if (ruleresult != nullptr) { + if (ffiResponse.result) { + *ruleresult = *ffiResponse.result; } else { // default to empty string @@ -714,32 +791,36 @@ public: } dnsdist::handleQueuedAsynchronousEvents(); return result; - } catch (const std::exception &e) { + } + catch (const std::exception& e) { warnlog("LuaFFIResponseAction failed inside Lua, returning ServFail: %s", e.what()); - } catch (...) { + } + catch (...) { warnlog("LuaFFIResponseAction failed inside Lua, returning ServFail: [unknown exception]"); } return DNSResponseAction::Action::ServFail; } - string toString() const override + [[nodiscard]] std::string toString() const override { return "Lua FFI script"; } + private: func_t d_func; }; -class LuaFFIPerThreadResponseAction: public DNSResponseAction +class LuaFFIPerThreadResponseAction : public DNSResponseAction { public: - typedef std::function<int(dnsdist_ffi_dnsresponse_t* dr)> func_t; + using func_t = std::function<int(dnsdist_ffi_dnsresponse_t* response)>; - LuaFFIPerThreadResponseAction(const std::string& code): d_functionCode(code), d_functionID(s_functionsCounter++) + LuaFFIPerThreadResponseAction(std::string code) : + d_functionCode(std::move(code)), d_functionID(s_functionsCounter++) { } - DNSResponseAction::Action operator()(DNSResponse* dr, std::string* ruleresult) const override + DNSResponseAction::Action operator()(DNSResponse* response, std::string* ruleresult) const override { try { auto& state = t_perThreadStates[d_functionID]; @@ -756,11 +837,11 @@ public: return DNSResponseAction::Action::None; } - dnsdist_ffi_dnsresponse_t drffi(dr); - auto ret = state.d_func(&drffi); - if (ruleresult) { - if (drffi.result) { - *ruleresult = *drffi.result; + dnsdist_ffi_dnsresponse_t ffiResponse(response); + auto ret = state.d_func(&ffiResponse); + if (ruleresult != nullptr) { + if (ffiResponse.result) { + *ruleresult = *ffiResponse.result; } else { // default to empty string @@ -770,7 +851,7 @@ public: dnsdist::handleQueuedAsynchronousEvents(); return static_cast<DNSResponseAction::Action>(ret); } - catch (const std::exception &e) { + catch (const std::exception& e) { warnlog("LuaFFIPerThreadResponseAction failed inside Lua, returning ServFail: %s", e.what()); } catch (...) { @@ -779,7 +860,7 @@ public: return DNSResponseAction::Action::ServFail; } - string toString() const override + [[nodiscard]] std::string toString() const override { return "Lua FFI per-thread script"; } @@ -794,7 +875,9 @@ private: static std::atomic<uint64_t> s_functionsCounter; static thread_local std::map<uint64_t, PerThreadState> t_perThreadStates; + // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members) const std::string d_functionCode; + // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members) const uint64_t d_functionID; }; @@ -803,35 +886,37 @@ thread_local std::map<uint64_t, LuaFFIPerThreadResponseAction::PerThreadState> L thread_local std::default_random_engine SpoofAction::t_randomEngine; -DNSAction::Action SpoofAction::operator()(DNSQuestion* dq, std::string* ruleresult) const +DNSAction::Action SpoofAction::operator()(DNSQuestion* dnsquestion, std::string* ruleresult) const { - uint16_t qtype = dq->ids.qtype; + uint16_t qtype = dnsquestion->ids.qtype; // do we even have a response? - if (d_cname.empty() && - d_rawResponses.empty() && + if (d_cname.empty() && d_rawResponses.empty() && // make sure pre-forged response is greater than sizeof(dnsheader) - (d_raw.size() < sizeof(dnsheader)) && - d_types.count(qtype) == 0) { + (d_raw.size() < sizeof(dnsheader)) && d_types.count(qtype) == 0) { return Action::None; } if (d_raw.size() >= sizeof(dnsheader)) { - auto id = dq->getHeader()->id; - dq->getMutableData() = d_raw; - dq->getHeader()->id = id; + auto questionId = dnsquestion->getHeader()->id; + dnsquestion->getMutableData() = d_raw; + dnsdist::PacketMangling::editDNSHeaderFromPacket(dnsquestion->getMutableData(), [questionId](dnsheader& header) { + header.id = questionId; + return true; + }); return Action::HeaderModify; } - vector<ComboAddress> addrs; - vector<std::string> rawResponses; + std::vector<ComboAddress> addrs = {}; + std::vector<std::string> rawResponses = {}; unsigned int totrdatalen = 0; - uint16_t numberOfRecords = 0; + size_t numberOfRecords = 0; if (!d_cname.empty()) { qtype = QType::CNAME; totrdatalen += d_cname.getStorage().size(); numberOfRecords = 1; - } else if (!d_rawResponses.empty()) { + } + else if (!d_rawResponses.empty()) { rawResponses.reserve(d_rawResponses.size()); - for(const auto& rawResponse : d_rawResponses){ + for (const auto& rawResponse : d_rawResponses) { totrdatalen += rawResponse.size(); rawResponses.push_back(rawResponse); ++numberOfRecords; @@ -841,9 +926,8 @@ DNSAction::Action SpoofAction::operator()(DNSQuestion* dq, std::string* ruleresu } } else { - for(const auto& addr : d_addrs) { - if(qtype != QType::ANY && ((addr.sin4.sin_family == AF_INET && qtype != QType::A) || - (addr.sin4.sin_family == AF_INET6 && qtype != QType::AAAA))) { + for (const auto& addr : d_addrs) { + if (qtype != QType::ANY && ((addr.sin4.sin_family == AF_INET && qtype != QType::A) || (addr.sin4.sin_family == AF_INET6 && qtype != QType::AAAA))) { continue; } totrdatalen += addr.sin4.sin_family == AF_INET ? sizeof(addr.sin4.sin_addr.s_addr) : sizeof(addr.sin6.sin6_addr.s6_addr); @@ -856,37 +940,43 @@ DNSAction::Action SpoofAction::operator()(DNSQuestion* dq, std::string* ruleresu shuffle(addrs.begin(), addrs.end(), t_randomEngine); } - unsigned int qnameWireLength=0; - DNSName ignore(reinterpret_cast<const char*>(dq->getData().data()), dq->getData().size(), sizeof(dnsheader), false, 0, 0, &qnameWireLength); + unsigned int qnameWireLength = 0; + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast) + DNSName ignore(reinterpret_cast<const char*>(dnsquestion->getData().data()), dnsquestion->getData().size(), sizeof(dnsheader), false, nullptr, nullptr, &qnameWireLength); - if (dq->getMaximumSize() < (sizeof(dnsheader) + qnameWireLength + 4 + numberOfRecords*12 /* recordstart */ + totrdatalen)) { + if (dnsquestion->getMaximumSize() < (sizeof(dnsheader) + qnameWireLength + 4 + numberOfRecords * 12 /* recordstart */ + totrdatalen)) { return Action::None; } bool dnssecOK = false; bool hadEDNS = false; - if (g_addEDNSToSelfGeneratedResponses && queryHasEDNS(*dq)) { + if (g_addEDNSToSelfGeneratedResponses && queryHasEDNS(*dnsquestion)) { hadEDNS = true; - dnssecOK = getEDNSZ(*dq) & EDNS_HEADER_FLAG_DO; + dnssecOK = ((getEDNSZ(*dnsquestion) & EDNS_HEADER_FLAG_DO) != 0); } - auto& data = dq->getMutableData(); - data.resize(sizeof(dnsheader) + qnameWireLength + 4 + numberOfRecords*12 /* recordstart */ + totrdatalen); // there goes your EDNS + auto& data = dnsquestion->getMutableData(); + data.resize(sizeof(dnsheader) + qnameWireLength + 4 + numberOfRecords * 12 /* recordstart */ + totrdatalen); // there goes your EDNS uint8_t* dest = &(data.at(sizeof(dnsheader) + qnameWireLength + 4)); - dq->getHeader()->qr = true; // for good measure - setResponseHeadersFromConfig(*dq->getHeader(), d_responseConfig); - dq->getHeader()->ancount = 0; - dq->getHeader()->arcount = 0; // for now, forget about your EDNS, we're marching over it + dnsdist::PacketMangling::editDNSHeaderFromPacket(dnsquestion->getMutableData(), [this](dnsheader& header) { + header.qr = true; // for good measure + setResponseHeadersFromConfig(header, d_responseConfig); + header.ancount = 0; + header.arcount = 0; // for now, forget about your EDNS, we're marching over it + return true; + }); uint32_t ttl = htonl(d_responseConfig.ttl); - uint16_t qclass = htons(dq->ids.qclass); - unsigned char recordstart[] = {0xc0, 0x0c, // compressed name - 0, 0, // QTYPE - 0, 0, // QCLASS - 0, 0, 0, 0, // TTL - 0, 0 }; // rdata length - static_assert(sizeof(recordstart) == 12, "sizeof(recordstart) must be equal to 12, otherwise the above check is invalid"); + uint16_t qclass = htons(dnsquestion->ids.qclass); + std::array<unsigned char, 12> recordstart = { + 0xc0, 0x0c, // compressed name + 0, 0, // QTYPE + 0, 0, // QCLASS + 0, 0, 0, 0, // TTL + 0, 0 // rdata length + }; + static_assert(recordstart.size() == 12, "sizeof(recordstart) must be equal to 12, otherwise the above check is invalid"); memcpy(&recordstart[4], &qclass, sizeof(qclass)); memcpy(&recordstart[6], &ttl, sizeof(ttl)); bool raw = false; @@ -898,50 +988,72 @@ DNSAction::Action SpoofAction::operator()(DNSQuestion* dq, std::string* ruleresu memcpy(&recordstart[2], &qtype, sizeof(qtype)); memcpy(&recordstart[10], &rdataLen, sizeof(rdataLen)); - memcpy(dest, recordstart, sizeof(recordstart)); - dest += sizeof(recordstart); + memcpy(dest, recordstart.data(), recordstart.size()); + // NOLINTNEXTLINE(cppcoreguidelines-pro-bounds-pointer-arithmetic) + dest += recordstart.size(); memcpy(dest, wireData.c_str(), wireData.length()); - dq->getHeader()->ancount++; + dnsdist::PacketMangling::editDNSHeaderFromPacket(dnsquestion->getMutableData(), [](dnsheader& header) { + header.ancount++; + return true; + }); } else if (!rawResponses.empty()) { + if (qtype == QType::ANY && d_rawTypeForAny) { + qtype = *d_rawTypeForAny; + } qtype = htons(qtype); - for(const auto& rawResponse : rawResponses){ + for (const auto& rawResponse : rawResponses) { uint16_t rdataLen = htons(rawResponse.size()); memcpy(&recordstart[2], &qtype, sizeof(qtype)); memcpy(&recordstart[10], &rdataLen, sizeof(rdataLen)); - memcpy(dest, recordstart, sizeof(recordstart)); - dest += sizeof(recordstart); + memcpy(dest, recordstart.data(), sizeof(recordstart)); + // NOLINTNEXTLINE(cppcoreguidelines-pro-bounds-pointer-arithmetic) + dest += recordstart.size(); memcpy(dest, rawResponse.c_str(), rawResponse.size()); + // NOLINTNEXTLINE(cppcoreguidelines-pro-bounds-pointer-arithmetic) dest += rawResponse.size(); - dq->getHeader()->ancount++; + dnsdist::PacketMangling::editDNSHeaderFromPacket(dnsquestion->getMutableData(), [](dnsheader& header) { + header.ancount++; + return true; + }); } raw = true; } else { - for(const auto& addr : addrs) { + for (const auto& addr : addrs) { uint16_t rdataLen = htons(addr.sin4.sin_family == AF_INET ? sizeof(addr.sin4.sin_addr.s_addr) : sizeof(addr.sin6.sin6_addr.s6_addr)); qtype = htons(addr.sin4.sin_family == AF_INET ? QType::A : QType::AAAA); memcpy(&recordstart[2], &qtype, sizeof(qtype)); memcpy(&recordstart[10], &rdataLen, sizeof(rdataLen)); - memcpy(dest, recordstart, sizeof(recordstart)); + memcpy(dest, recordstart.data(), recordstart.size()); + // NOLINTNEXTLINE(cppcoreguidelines-pro-bounds-pointer-arithmetic) dest += sizeof(recordstart); memcpy(dest, + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast) addr.sin4.sin_family == AF_INET ? reinterpret_cast<const void*>(&addr.sin4.sin_addr.s_addr) : reinterpret_cast<const void*>(&addr.sin6.sin6_addr.s6_addr), addr.sin4.sin_family == AF_INET ? sizeof(addr.sin4.sin_addr.s_addr) : sizeof(addr.sin6.sin6_addr.s6_addr)); + // NOLINTNEXTLINE(cppcoreguidelines-pro-bounds-pointer-arithmetic) dest += (addr.sin4.sin_family == AF_INET ? sizeof(addr.sin4.sin_addr.s_addr) : sizeof(addr.sin6.sin6_addr.s6_addr)); - dq->getHeader()->ancount++; + dnsdist::PacketMangling::editDNSHeaderFromPacket(dnsquestion->getMutableData(), [](dnsheader& header) { + header.ancount++; + return true; + }); } } - dq->getHeader()->ancount = htons(dq->getHeader()->ancount); + auto finalANCount = dnsquestion->getHeader()->ancount; + dnsdist::PacketMangling::editDNSHeaderFromPacket(dnsquestion->getMutableData(), [finalANCount](dnsheader& header) { + header.ancount = htons(finalANCount); + return true; + }); - if (hadEDNS && raw == false) { - addEDNS(dq->getMutableData(), dq->getMaximumSize(), dnssecOK, g_PayloadSizeSelfGenAnswers, 0); + if (hadEDNS && !raw) { + addEDNS(dnsquestion->getMutableData(), dnsquestion->getMaximumSize(), dnssecOK, g_PayloadSizeSelfGenAnswers, 0); } return Action::HeaderModify; @@ -951,56 +1063,62 @@ class SetMacAddrAction : public DNSAction { public: // this action does not stop the processing - SetMacAddrAction(uint16_t code) : d_code(code) + SetMacAddrAction(uint16_t code) : + d_code(code) { } - DNSAction::Action operator()(DNSQuestion* dq, std::string* ruleresult) const override + DNSAction::Action operator()(DNSQuestion* dnsquestion, std::string* ruleresult) const override { - dnsdist::MacAddress mac; - int res = dnsdist::MacAddressesCache::get(dq->ids.origRemote, mac.data(), mac.size()); + dnsdist::MacAddress mac{}; + int res = dnsdist::MacAddressesCache::get(dnsquestion->ids.origRemote, mac.data(), mac.size()); if (res != 0) { return Action::None; } std::string optRData; + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast) generateEDNSOption(d_code, reinterpret_cast<const char*>(mac.data()), optRData); - if (dq->getHeader()->arcount) { + if (dnsquestion->getHeader()->arcount > 0) { bool ednsAdded = false; bool optionAdded = false; PacketBuffer newContent; - newContent.reserve(dq->getData().size()); + newContent.reserve(dnsquestion->getData().size()); - if (!slowRewriteEDNSOptionInQueryWithRecords(dq->getData(), newContent, ednsAdded, d_code, optionAdded, true, optRData)) { + if (!slowRewriteEDNSOptionInQueryWithRecords(dnsquestion->getData(), newContent, ednsAdded, d_code, optionAdded, true, optRData)) { return Action::None; } - if (newContent.size() > dq->getMaximumSize()) { + if (newContent.size() > dnsquestion->getMaximumSize()) { return Action::None; } - dq->getMutableData() = std::move(newContent); - if (!dq->ids.ednsAdded && ednsAdded) { - dq->ids.ednsAdded = true; + dnsquestion->getMutableData() = std::move(newContent); + if (!dnsquestion->ids.ednsAdded && ednsAdded) { + dnsquestion->ids.ednsAdded = true; } return Action::None; } - auto& data = dq->getMutableData(); - if (generateOptRR(optRData, data, dq->getMaximumSize(), g_EdnsUDPPayloadSize, 0, false)) { - dq->getHeader()->arcount = htons(1); + auto& data = dnsquestion->getMutableData(); + if (generateOptRR(optRData, data, dnsquestion->getMaximumSize(), g_EdnsUDPPayloadSize, 0, false)) { + dnsdist::PacketMangling::editDNSHeaderFromPacket(dnsquestion->getMutableData(), [](dnsheader& header) { + header.arcount = htons(1); + return true; + }); // make sure that any EDNS sent by the backend is removed before forwarding the response to the client - dq->ids.ednsAdded = true; + dnsquestion->ids.ednsAdded = true; } return Action::None; } - std::string toString() const override + [[nodiscard]] std::string toString() const override { return "add EDNS MAC (code=" + std::to_string(d_code) + ")"; } + private: uint16_t d_code{3}; }; @@ -1009,17 +1127,18 @@ class SetEDNSOptionAction : public DNSAction { public: // this action does not stop the processing - SetEDNSOptionAction(uint16_t code, const std::string& data) : d_code(code), d_data(data) + SetEDNSOptionAction(uint16_t code, std::string data) : + d_code(code), d_data(std::move(data)) { } - DNSAction::Action operator()(DNSQuestion* dq, std::string* ruleresult) const override + DNSAction::Action operator()(DNSQuestion* dnsquestion, std::string* ruleresult) const override { - setEDNSOption(*dq, d_code, d_data); + setEDNSOption(*dnsquestion, d_code, d_data); return Action::None; } - std::string toString() const override + [[nodiscard]] std::string toString() const override { return "add EDNS Option (code=" + std::to_string(d_code) + ")"; } @@ -1033,12 +1152,15 @@ class SetNoRecurseAction : public DNSAction { public: // this action does not stop the processing - DNSAction::Action operator()(DNSQuestion* dq, std::string* ruleresult) const override + DNSAction::Action operator()(DNSQuestion* dnsquestion, std::string* ruleresult) const override { - dq->getHeader()->rd = false; + dnsdist::PacketMangling::editDNSHeaderFromPacket(dnsquestion->getMutableData(), [](dnsheader& header) { + header.rd = false; + return true; + }); return Action::None; } - std::string toString() const override + [[nodiscard]] std::string toString() const override { return "set rd=0"; } @@ -1048,69 +1170,68 @@ class LogAction : public DNSAction, public boost::noncopyable { public: // this action does not stop the processing - LogAction() - { - } + LogAction() = default; - LogAction(const std::string& str, bool binary=true, bool append=false, bool buffered=true, bool verboseOnly=true, bool includeTimestamp=false): d_fname(str), d_binary(binary), d_verboseOnly(verboseOnly), d_includeTimestamp(includeTimestamp), d_append(append), d_buffered(buffered) + LogAction(const std::string& str, bool binary = true, bool append = false, bool buffered = true, bool verboseOnly = true, bool includeTimestamp = false) : + d_fname(str), d_binary(binary), d_verboseOnly(verboseOnly), d_includeTimestamp(includeTimestamp), d_append(append), d_buffered(buffered) { if (str.empty()) { return; } - if (!reopenLogFile()) { + if (!reopenLogFile()) { throw std::runtime_error("Unable to open file '" + str + "' for logging: " + stringerror()); } } - DNSAction::Action operator()(DNSQuestion* dq, std::string* ruleresult) const override + DNSAction::Action operator()(DNSQuestion* dnsquestion, std::string* ruleresult) const override { - auto fp = std::atomic_load_explicit(&d_fp, std::memory_order_acquire); - if (!fp) { + auto filepointer = std::atomic_load_explicit(&d_fp, std::memory_order_acquire); + if (!filepointer) { if (!d_verboseOnly || g_verbose) { if (d_includeTimestamp) { - infolog("[%u.%u] Packet from %s for %s %s with id %d", static_cast<unsigned long long>(dq->getQueryRealTime().tv_sec), static_cast<unsigned long>(dq->getQueryRealTime().tv_nsec), dq->ids.origRemote.toStringWithPort(), dq->ids.qname.toString(), QType(dq->ids.qtype).toString(), dq->getHeader()->id); + infolog("[%u.%u] Packet from %s for %s %s with id %d", static_cast<unsigned long long>(dnsquestion->getQueryRealTime().tv_sec), static_cast<unsigned long>(dnsquestion->getQueryRealTime().tv_nsec), dnsquestion->ids.origRemote.toStringWithPort(), dnsquestion->ids.qname.toString(), QType(dnsquestion->ids.qtype).toString(), dnsquestion->getHeader()->id); } else { - infolog("Packet from %s for %s %s with id %d", dq->ids.origRemote.toStringWithPort(), dq->ids.qname.toString(), QType(dq->ids.qtype).toString(), dq->getHeader()->id); + infolog("Packet from %s for %s %s with id %d", dnsquestion->ids.origRemote.toStringWithPort(), dnsquestion->ids.qname.toString(), QType(dnsquestion->ids.qtype).toString(), dnsquestion->getHeader()->id); } } } else { if (d_binary) { - const auto& out = dq->ids.qname.getStorage(); + const auto& out = dnsquestion->ids.qname.getStorage(); if (d_includeTimestamp) { - uint64_t tv_sec = static_cast<uint64_t>(dq->getQueryRealTime().tv_sec); - uint32_t tv_nsec = static_cast<uint32_t>(dq->getQueryRealTime().tv_nsec); - fwrite(&tv_sec, sizeof(tv_sec), 1, fp.get()); - fwrite(&tv_nsec, sizeof(tv_nsec), 1, fp.get()); + auto tv_sec = static_cast<uint64_t>(dnsquestion->getQueryRealTime().tv_sec); + auto tv_nsec = static_cast<uint32_t>(dnsquestion->getQueryRealTime().tv_nsec); + fwrite(&tv_sec, sizeof(tv_sec), 1, filepointer.get()); + fwrite(&tv_nsec, sizeof(tv_nsec), 1, filepointer.get()); } - uint16_t id = dq->getHeader()->id; - fwrite(&id, sizeof(id), 1, fp.get()); - fwrite(out.c_str(), 1, out.size(), fp.get()); - fwrite(&dq->ids.qtype, sizeof(dq->ids.qtype), 1, fp.get()); - fwrite(&dq->ids.origRemote.sin4.sin_family, sizeof(dq->ids.origRemote.sin4.sin_family), 1, fp.get()); - if (dq->ids.origRemote.sin4.sin_family == AF_INET) { - fwrite(&dq->ids.origRemote.sin4.sin_addr.s_addr, sizeof(dq->ids.origRemote.sin4.sin_addr.s_addr), 1, fp.get()); + uint16_t queryId = dnsquestion->getHeader()->id; + fwrite(&queryId, sizeof(queryId), 1, filepointer.get()); + fwrite(out.c_str(), 1, out.size(), filepointer.get()); + fwrite(&dnsquestion->ids.qtype, sizeof(dnsquestion->ids.qtype), 1, filepointer.get()); + fwrite(&dnsquestion->ids.origRemote.sin4.sin_family, sizeof(dnsquestion->ids.origRemote.sin4.sin_family), 1, filepointer.get()); + if (dnsquestion->ids.origRemote.sin4.sin_family == AF_INET) { + fwrite(&dnsquestion->ids.origRemote.sin4.sin_addr.s_addr, sizeof(dnsquestion->ids.origRemote.sin4.sin_addr.s_addr), 1, filepointer.get()); } - else if (dq->ids.origRemote.sin4.sin_family == AF_INET6) { - fwrite(&dq->ids.origRemote.sin6.sin6_addr.s6_addr, sizeof(dq->ids.origRemote.sin6.sin6_addr.s6_addr), 1, fp.get()); + else if (dnsquestion->ids.origRemote.sin4.sin_family == AF_INET6) { + fwrite(&dnsquestion->ids.origRemote.sin6.sin6_addr.s6_addr, sizeof(dnsquestion->ids.origRemote.sin6.sin6_addr.s6_addr), 1, filepointer.get()); } - fwrite(&dq->ids.origRemote.sin4.sin_port, sizeof(dq->ids.origRemote.sin4.sin_port), 1, fp.get()); + fwrite(&dnsquestion->ids.origRemote.sin4.sin_port, sizeof(dnsquestion->ids.origRemote.sin4.sin_port), 1, filepointer.get()); } else { if (d_includeTimestamp) { - fprintf(fp.get(), "[%llu.%lu] Packet from %s for %s %s with id %u\n", static_cast<unsigned long long>(dq->getQueryRealTime().tv_sec), static_cast<unsigned long>(dq->getQueryRealTime().tv_nsec), dq->ids.origRemote.toStringWithPort().c_str(), dq->ids.qname.toString().c_str(), QType(dq->ids.qtype).toString().c_str(), dq->getHeader()->id); + fprintf(filepointer.get(), "[%llu.%lu] Packet from %s for %s %s with id %u\n", static_cast<unsigned long long>(dnsquestion->getQueryRealTime().tv_sec), static_cast<unsigned long>(dnsquestion->getQueryRealTime().tv_nsec), dnsquestion->ids.origRemote.toStringWithPort().c_str(), dnsquestion->ids.qname.toString().c_str(), QType(dnsquestion->ids.qtype).toString().c_str(), dnsquestion->getHeader()->id); } else { - fprintf(fp.get(), "Packet from %s for %s %s with id %u\n", dq->ids.origRemote.toStringWithPort().c_str(), dq->ids.qname.toString().c_str(), QType(dq->ids.qtype).toString().c_str(), dq->getHeader()->id); + fprintf(filepointer.get(), "Packet from %s for %s %s with id %u\n", dnsquestion->ids.origRemote.toStringWithPort().c_str(), dnsquestion->ids.qname.toString().c_str(), QType(dnsquestion->ids.qtype).toString().c_str(), dnsquestion->getHeader()->id); } } } return Action::None; } - std::string toString() const override + [[nodiscard]] std::string toString() const override { if (!d_fname.empty()) { return "log to " + d_fname; @@ -1131,20 +1252,21 @@ private: // we are using a naked pointer here because we don't want fclose to be called // with a nullptr, which would happen if we constructor a shared_ptr with fclose // as a custom deleter and nullptr as a FILE* - auto nfp = fopen(d_fname.c_str(), d_append ? "a+" : "w"); - if (!nfp) { + // NOLINTNEXTLINE(cppcoreguidelines-owning-memory) + auto* nfp = fopen(d_fname.c_str(), d_append ? "a+" : "w"); + if (nfp == nullptr) { /* don't fall on our sword when reopening */ return false; } - auto fp = std::shared_ptr<FILE>(nfp, fclose); + auto filepointer = std::shared_ptr<FILE>(nfp, fclose); nfp = nullptr; if (!d_buffered) { - setbuf(fp.get(), 0); + setbuf(filepointer.get(), nullptr); } - std::atomic_store_explicit(&d_fp, fp, std::memory_order_release); + std::atomic_store_explicit(&d_fp, std::move(filepointer), std::memory_order_release); return true; } @@ -1160,11 +1282,10 @@ private: class LogResponseAction : public DNSResponseAction, public boost::noncopyable { public: - LogResponseAction() - { - } + LogResponseAction() = default; - LogResponseAction(const std::string& str, bool append=false, bool buffered=true, bool verboseOnly=true, bool includeTimestamp=false): d_fname(str), d_verboseOnly(verboseOnly), d_includeTimestamp(includeTimestamp), d_append(append), d_buffered(buffered) + LogResponseAction(const std::string& str, bool append = false, bool buffered = true, bool verboseOnly = true, bool includeTimestamp = false) : + d_fname(str), d_verboseOnly(verboseOnly), d_includeTimestamp(includeTimestamp), d_append(append), d_buffered(buffered) { if (str.empty()) { return; @@ -1175,31 +1296,31 @@ public: } } - DNSResponseAction::Action operator()(DNSResponse* dr, std::string* ruleresult) const override + DNSResponseAction::Action operator()(DNSResponse* response, std::string* ruleresult) const override { - auto fp = std::atomic_load_explicit(&d_fp, std::memory_order_acquire); - if (!fp) { + auto filepointer = std::atomic_load_explicit(&d_fp, std::memory_order_acquire); + if (!filepointer) { if (!d_verboseOnly || g_verbose) { if (d_includeTimestamp) { - infolog("[%u.%u] Answer to %s for %s %s (%s) with id %u", static_cast<unsigned long long>(dr->getQueryRealTime().tv_sec), static_cast<unsigned long>(dr->getQueryRealTime().tv_nsec), dr->ids.origRemote.toStringWithPort(), dr->ids.qname.toString(), QType(dr->ids.qtype).toString(), RCode::to_s(dr->getHeader()->rcode), dr->getHeader()->id); + infolog("[%u.%u] Answer to %s for %s %s (%s) with id %u", static_cast<unsigned long long>(response->getQueryRealTime().tv_sec), static_cast<unsigned long>(response->getQueryRealTime().tv_nsec), response->ids.origRemote.toStringWithPort(), response->ids.qname.toString(), QType(response->ids.qtype).toString(), RCode::to_s(response->getHeader()->rcode), response->getHeader()->id); } else { - infolog("Answer to %s for %s %s (%s) with id %u", dr->ids.origRemote.toStringWithPort(), dr->ids.qname.toString(), QType(dr->ids.qtype).toString(), RCode::to_s(dr->getHeader()->rcode), dr->getHeader()->id); + infolog("Answer to %s for %s %s (%s) with id %u", response->ids.origRemote.toStringWithPort(), response->ids.qname.toString(), QType(response->ids.qtype).toString(), RCode::to_s(response->getHeader()->rcode), response->getHeader()->id); } } } else { if (d_includeTimestamp) { - fprintf(fp.get(), "[%llu.%lu] Answer to %s for %s %s (%s) with id %u\n", static_cast<unsigned long long>(dr->getQueryRealTime().tv_sec), static_cast<unsigned long>(dr->getQueryRealTime().tv_nsec), dr->ids.origRemote.toStringWithPort().c_str(), dr->ids.qname.toString().c_str(), QType(dr->ids.qtype).toString().c_str(), RCode::to_s(dr->getHeader()->rcode).c_str(), dr->getHeader()->id); + fprintf(filepointer.get(), "[%llu.%lu] Answer to %s for %s %s (%s) with id %u\n", static_cast<unsigned long long>(response->getQueryRealTime().tv_sec), static_cast<unsigned long>(response->getQueryRealTime().tv_nsec), response->ids.origRemote.toStringWithPort().c_str(), response->ids.qname.toString().c_str(), QType(response->ids.qtype).toString().c_str(), RCode::to_s(response->getHeader()->rcode).c_str(), response->getHeader()->id); } else { - fprintf(fp.get(), "Answer to %s for %s %s (%s) with id %u\n", dr->ids.origRemote.toStringWithPort().c_str(), dr->ids.qname.toString().c_str(), QType(dr->ids.qtype).toString().c_str(), RCode::to_s(dr->getHeader()->rcode).c_str(), dr->getHeader()->id); + fprintf(filepointer.get(), "Answer to %s for %s %s (%s) with id %u\n", response->ids.origRemote.toStringWithPort().c_str(), response->ids.qname.toString().c_str(), QType(response->ids.qtype).toString().c_str(), RCode::to_s(response->getHeader()->rcode).c_str(), response->getHeader()->id); } } return Action::None; } - std::string toString() const override + [[nodiscard]] std::string toString() const override { if (!d_fname.empty()) { return "log to " + d_fname; @@ -1220,20 +1341,21 @@ private: // we are using a naked pointer here because we don't want fclose to be called // with a nullptr, which would happen if we constructor a shared_ptr with fclose // as a custom deleter and nullptr as a FILE* - auto nfp = fopen(d_fname.c_str(), d_append ? "a+" : "w"); - if (!nfp) { + // NOLINTNEXTLINE(cppcoreguidelines-owning-memory) + auto* nfp = fopen(d_fname.c_str(), d_append ? "a+" : "w"); + if (nfp == nullptr) { /* don't fall on our sword when reopening */ return false; } - auto fp = std::shared_ptr<FILE>(nfp, fclose); + auto filepointer = std::shared_ptr<FILE>(nfp, fclose); nfp = nullptr; if (!d_buffered) { - setbuf(fp.get(), 0); + setbuf(filepointer.get(), nullptr); } - std::atomic_store_explicit(&d_fp, fp, std::memory_order_release); + std::atomic_store_explicit(&d_fp, std::move(filepointer), std::memory_order_release); return true; } @@ -1249,12 +1371,15 @@ class SetDisableValidationAction : public DNSAction { public: // this action does not stop the processing - DNSAction::Action operator()(DNSQuestion* dq, std::string* ruleresult) const override + DNSAction::Action operator()(DNSQuestion* dnsquestion, std::string* ruleresult) const override { - dq->getHeader()->cd = true; + dnsdist::PacketMangling::editDNSHeaderFromPacket(dnsquestion->getMutableData(), [](dnsheader& header) { + header.cd = true; + return true; + }); return Action::None; } - std::string toString() const override + [[nodiscard]] std::string toString() const override { return "set cd=1"; } @@ -1264,12 +1389,12 @@ class SetSkipCacheAction : public DNSAction { public: // this action does not stop the processing - DNSAction::Action operator()(DNSQuestion* dq, std::string* ruleresult) const override + DNSAction::Action operator()(DNSQuestion* dnsquestion, std::string* ruleresult) const override { - dq->ids.skipCache = true; + dnsquestion->ids.skipCache = true; return Action::None; } - std::string toString() const override + [[nodiscard]] std::string toString() const override { return "skip cache"; } @@ -1278,12 +1403,12 @@ public: class SetSkipCacheResponseAction : public DNSResponseAction { public: - DNSResponseAction::Action operator()(DNSResponse* dr, std::string* ruleresult) const override + DNSResponseAction::Action operator()(DNSResponse* response, std::string* ruleresult) const override { - dr->ids.skipCache = true; + response->ids.skipCache = true; return Action::None; } - std::string toString() const override + [[nodiscard]] std::string toString() const override { return "skip cache"; } @@ -1293,18 +1418,20 @@ class SetTempFailureCacheTTLAction : public DNSAction { public: // this action does not stop the processing - SetTempFailureCacheTTLAction(uint32_t ttl) : d_ttl(ttl) + SetTempFailureCacheTTLAction(uint32_t ttl) : + d_ttl(ttl) { } - DNSAction::Action operator()(DNSQuestion* dq, std::string* ruleresult) const override + DNSAction::Action operator()(DNSQuestion* dnsquestion, std::string* ruleresult) const override { - dq->ids.tempFailureTTL = d_ttl; + dnsquestion->ids.tempFailureTTL = d_ttl; return Action::None; } - std::string toString() const override + [[nodiscard]] std::string toString() const override { - return "set tempfailure cache ttl to "+std::to_string(d_ttl); + return "set tempfailure cache ttl to " + std::to_string(d_ttl); } + private: uint32_t d_ttl; }; @@ -1313,18 +1440,20 @@ class SetECSPrefixLengthAction : public DNSAction { public: // this action does not stop the processing - SetECSPrefixLengthAction(uint16_t v4Length, uint16_t v6Length) : d_v4PrefixLength(v4Length), d_v6PrefixLength(v6Length) + SetECSPrefixLengthAction(uint16_t v4Length, uint16_t v6Length) : + d_v4PrefixLength(v4Length), d_v6PrefixLength(v6Length) { } - DNSAction::Action operator()(DNSQuestion* dq, std::string* ruleresult) const override + DNSAction::Action operator()(DNSQuestion* dnsquestion, std::string* ruleresult) const override { - dq->ecsPrefixLength = dq->ids.origRemote.sin4.sin_family == AF_INET ? d_v4PrefixLength : d_v6PrefixLength; + dnsquestion->ecsPrefixLength = dnsquestion->ids.origRemote.sin4.sin_family == AF_INET ? d_v4PrefixLength : d_v6PrefixLength; return Action::None; } - std::string toString() const override + [[nodiscard]] std::string toString() const override { return "set ECS prefix length to " + std::to_string(d_v4PrefixLength) + "/" + std::to_string(d_v6PrefixLength); } + private: uint16_t d_v4PrefixLength; uint16_t d_v6PrefixLength; @@ -1334,33 +1463,34 @@ class SetECSOverrideAction : public DNSAction { public: // this action does not stop the processing - SetECSOverrideAction(bool ecsOverride) : d_ecsOverride(ecsOverride) + SetECSOverrideAction(bool ecsOverride) : + d_ecsOverride(ecsOverride) { } - DNSAction::Action operator()(DNSQuestion* dq, std::string* ruleresult) const override + DNSAction::Action operator()(DNSQuestion* dnsquestion, std::string* ruleresult) const override { - dq->ecsOverride = d_ecsOverride; + dnsquestion->ecsOverride = d_ecsOverride; return Action::None; } - std::string toString() const override + [[nodiscard]] std::string toString() const override { - return "set ECS override to " + std::to_string(d_ecsOverride); + return "set ECS override to " + std::to_string(static_cast<int>(d_ecsOverride)); } + private: bool d_ecsOverride; }; - class SetDisableECSAction : public DNSAction { public: // this action does not stop the processing - DNSAction::Action operator()(DNSQuestion* dq, std::string* ruleresult) const override + DNSAction::Action operator()(DNSQuestion* dnsquestion, std::string* ruleresult) const override { - dq->useECS = false; + dnsquestion->useECS = false; return Action::None; } - std::string toString() const override + [[nodiscard]] std::string toString() const override { return "disable ECS"; } @@ -1370,27 +1500,29 @@ class SetECSAction : public DNSAction { public: // this action does not stop the processing - SetECSAction(const Netmask& v4): d_v4(v4), d_hasV6(false) + SetECSAction(const Netmask& v4Netmask) : + d_v4(v4Netmask), d_hasV6(false) { } - SetECSAction(const Netmask& v4, const Netmask& v6): d_v4(v4), d_v6(v6), d_hasV6(true) + SetECSAction(const Netmask& v4Netmask, const Netmask& v6Netmask) : + d_v4(v4Netmask), d_v6(v6Netmask), d_hasV6(true) { } - DNSAction::Action operator()(DNSQuestion* dq, std::string* ruleresult) const override + DNSAction::Action operator()(DNSQuestion* dnsquestion, std::string* ruleresult) const override { if (d_hasV6) { - dq->ecs = std::make_unique<Netmask>(dq->ids.origRemote.isIPv4() ? d_v4 : d_v6); + dnsquestion->ecs = std::make_unique<Netmask>(dnsquestion->ids.origRemote.isIPv4() ? d_v4 : d_v6); } else { - dq->ecs = std::make_unique<Netmask>(d_v4); + dnsquestion->ecs = std::make_unique<Netmask>(d_v4); } return Action::None; } - std::string toString() const override + [[nodiscard]] std::string toString() const override { std::string result = "set ECS to " + d_v4.toString(); if (d_hasV6) { @@ -1411,41 +1543,44 @@ static DnstapMessage::ProtocolType ProtocolToDNSTap(dnsdist::Protocol protocol) if (protocol == dnsdist::Protocol::DoUDP) { return DnstapMessage::ProtocolType::DoUDP; } - else if (protocol == dnsdist::Protocol::DoTCP) { + if (protocol == dnsdist::Protocol::DoTCP) { return DnstapMessage::ProtocolType::DoTCP; } - else if (protocol == dnsdist::Protocol::DoT) { + if (protocol == dnsdist::Protocol::DoT) { return DnstapMessage::ProtocolType::DoT; } - else if (protocol == dnsdist::Protocol::DoH) { + if (protocol == dnsdist::Protocol::DoH || protocol == dnsdist::Protocol::DoH3) { return DnstapMessage::ProtocolType::DoH; } - else if (protocol == dnsdist::Protocol::DNSCryptUDP) { + if (protocol == dnsdist::Protocol::DNSCryptUDP) { return DnstapMessage::ProtocolType::DNSCryptUDP; } - else if (protocol == dnsdist::Protocol::DNSCryptTCP) { + if (protocol == dnsdist::Protocol::DNSCryptTCP) { return DnstapMessage::ProtocolType::DNSCryptTCP; } + if (protocol == dnsdist::Protocol::DoQ) { + return DnstapMessage::ProtocolType::DoQ; + } throw std::runtime_error("Unhandled protocol for dnstap: " + protocol.toPrettyString()); } -static void remoteLoggerQueueData(RemoteLoggerInterface& r, const std::string& data) +static void remoteLoggerQueueData(RemoteLoggerInterface& remoteLogger, const std::string& data) { - auto ret = r.queueData(data); + auto ret = remoteLogger.queueData(data); switch (ret) { case RemoteLoggerInterface::Result::Queued: break; case RemoteLoggerInterface::Result::PipeFull: { - vinfolog("%s: %s", r.name(), RemoteLoggerInterface::toErrorString(ret)); + vinfolog("%s: %s", remoteLogger.name(), RemoteLoggerInterface::toErrorString(ret)); break; } case RemoteLoggerInterface::Result::TooLarge: { - warnlog("%s: %s", r.name(), RemoteLoggerInterface::toErrorString(ret)); + warnlog("%s: %s", remoteLogger.name(), RemoteLoggerInterface::toErrorString(ret)); break; } case RemoteLoggerInterface::Result::OtherError: - warnlog("%s: %s", r.name(), RemoteLoggerInterface::toErrorString(ret)); + warnlog("%s: %s", remoteLogger.name(), RemoteLoggerInterface::toErrorString(ret)); } } @@ -1453,51 +1588,57 @@ class DnstapLogAction : public DNSAction, public boost::noncopyable { public: // this action does not stop the processing - DnstapLogAction(const std::string& identity, std::shared_ptr<RemoteLoggerInterface>& logger, boost::optional<std::function<void(DNSQuestion*, DnstapMessage*)> > alterFunc): d_identity(identity), d_logger(logger), d_alterFunc(alterFunc) + DnstapLogAction(std::string identity, std::shared_ptr<RemoteLoggerInterface>& logger, boost::optional<std::function<void(DNSQuestion*, DnstapMessage*)>> alterFunc) : + d_identity(std::move(identity)), d_logger(logger), d_alterFunc(std::move(alterFunc)) { } - DNSAction::Action operator()(DNSQuestion* dq, std::string* ruleresult) const override + DNSAction::Action operator()(DNSQuestion* dnsquestion, std::string* ruleresult) const override { static thread_local std::string data; data.clear(); - DnstapMessage::ProtocolType protocol = ProtocolToDNSTap(dq->getProtocol()); - DnstapMessage message(data, !dq->getHeader()->qr ? DnstapMessage::MessageType::client_query : DnstapMessage::MessageType::client_response, d_identity, &dq->ids.origRemote, &dq->ids.origDest, protocol, reinterpret_cast<const char*>(dq->getData().data()), dq->getData().size(), &dq->getQueryRealTime(), nullptr); + DnstapMessage::ProtocolType protocol = ProtocolToDNSTap(dnsquestion->getProtocol()); + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast) + DnstapMessage message(std::move(data), !dnsquestion->getHeader()->qr ? DnstapMessage::MessageType::client_query : DnstapMessage::MessageType::client_response, d_identity, &dnsquestion->ids.origRemote, &dnsquestion->ids.origDest, protocol, reinterpret_cast<const char*>(dnsquestion->getData().data()), dnsquestion->getData().size(), &dnsquestion->getQueryRealTime(), nullptr); { if (d_alterFunc) { auto lock = g_lua.lock(); - (*d_alterFunc)(dq, &message); + (*d_alterFunc)(dnsquestion, &message); } } + data = message.getBuffer(); remoteLoggerQueueData(*d_logger, data); return Action::None; } - std::string toString() const override + [[nodiscard]] std::string toString() const override { return "remote log as dnstap to " + (d_logger ? d_logger->toString() : ""); } + private: std::string d_identity; std::shared_ptr<RemoteLoggerInterface> d_logger; - boost::optional<std::function<void(DNSQuestion*, DnstapMessage*)> > d_alterFunc; + boost::optional<std::function<void(DNSQuestion*, DnstapMessage*)>> d_alterFunc; }; -static void addMetaDataToProtobuf(DNSDistProtoBufMessage& message, const DNSQuestion& dq, const std::vector<std::pair<std::string, ProtoBufMetaKey>>& metas) +namespace +{ +void addMetaDataToProtobuf(DNSDistProtoBufMessage& message, const DNSQuestion& dnsquestion, const std::vector<std::pair<std::string, ProtoBufMetaKey>>& metas) { for (const auto& [name, meta] : metas) { - message.addMeta(name, meta.getValues(dq)); + message.addMeta(name, meta.getValues(dnsquestion), {}); } } -static void addTagsToProtobuf(DNSDistProtoBufMessage& message, const DNSQuestion& dq, const std::unordered_set<std::string>& allowed) +void addTagsToProtobuf(DNSDistProtoBufMessage& message, const DNSQuestion& dnsquestion, const std::unordered_set<std::string>& allowed) { - if (!dq.ids.qTag) { + if (!dnsquestion.ids.qTag) { return; } - for (const auto& [key, value] : *dq.ids.qTag) { + for (const auto& [key, value] : *dnsquestion.ids.qTag) { if (!allowed.empty() && allowed.count(key) == 0) { continue; } @@ -1506,49 +1647,81 @@ static void addTagsToProtobuf(DNSDistProtoBufMessage& message, const DNSQuestion message.addTag(key); } else { - message.addTag(key + ":" + value); + auto tag = key; + tag.append(":"); + tag.append(value); + message.addTag(tag); } } } +void addExtendedDNSErrorToProtobuf(DNSDistProtoBufMessage& message, const DNSResponse& response, const std::string& metaKey) +{ + auto [infoCode, extraText] = dnsdist::edns::getExtendedDNSError(response.getData()); + if (!infoCode) { + return; + } + + if (extraText) { + message.addMeta(metaKey, {*extraText}, {*infoCode}); + } + else { + message.addMeta(metaKey, {}, {*infoCode}); + } +} +} + +struct RemoteLogActionConfiguration +{ + std::vector<std::pair<std::string, ProtoBufMetaKey>> metas; + std::optional<std::unordered_set<std::string>> tagsToExport{std::nullopt}; + boost::optional<std::function<void(DNSQuestion*, DNSDistProtoBufMessage*)>> alterQueryFunc{boost::none}; + boost::optional<std::function<void(DNSResponse*, DNSDistProtoBufMessage*)>> alterResponseFunc{boost::none}; + std::shared_ptr<RemoteLoggerInterface> logger; + std::string serverID; + std::string ipEncryptKey; + std::optional<std::string> exportExtendedErrorsToMeta{std::nullopt}; + bool includeCNAME{false}; +}; + class RemoteLogAction : public DNSAction, public boost::noncopyable { public: // this action does not stop the processing - RemoteLogAction(std::shared_ptr<RemoteLoggerInterface>& logger, boost::optional<std::function<void(DNSQuestion*, DNSDistProtoBufMessage*)> > alterFunc, const std::string& serverID, const std::string& ipEncryptKey, std::vector<std::pair<std::string, ProtoBufMetaKey>>&& metas, std::optional<std::unordered_set<std::string>>&& tagsToExport): d_tagsToExport(std::move(tagsToExport)), d_metas(std::move(metas)), d_logger(logger), d_alterFunc(alterFunc), d_serverID(serverID), d_ipEncryptKey(ipEncryptKey) + RemoteLogAction(RemoteLogActionConfiguration& config) : + d_tagsToExport(std::move(config.tagsToExport)), d_metas(std::move(config.metas)), d_logger(config.logger), d_alterFunc(std::move(config.alterQueryFunc)), d_serverID(config.serverID), d_ipEncryptKey(config.ipEncryptKey) { } - DNSAction::Action operator()(DNSQuestion* dq, std::string* ruleresult) const override + DNSAction::Action operator()(DNSQuestion* dnsquestion, std::string* ruleresult) const override { - if (!dq->ids.d_protoBufData) { - dq->ids.d_protoBufData = std::make_unique<InternalQueryState::ProtoBufData>(); + if (!dnsquestion->ids.d_protoBufData) { + dnsquestion->ids.d_protoBufData = std::make_unique<InternalQueryState::ProtoBufData>(); } - if (!dq->ids.d_protoBufData->uniqueId) { - dq->ids.d_protoBufData->uniqueId = getUniqueID(); + if (!dnsquestion->ids.d_protoBufData->uniqueId) { + dnsquestion->ids.d_protoBufData->uniqueId = getUniqueID(); } - DNSDistProtoBufMessage message(*dq); + DNSDistProtoBufMessage message(*dnsquestion); if (!d_serverID.empty()) { message.setServerIdentity(d_serverID); } #if HAVE_IPCIPHER - if (!d_ipEncryptKey.empty()) - { - message.setRequestor(encryptCA(dq->ids.origRemote, d_ipEncryptKey)); + if (!d_ipEncryptKey.empty()) { + message.setRequestor(encryptCA(dnsquestion->ids.origRemote, d_ipEncryptKey)); } #endif /* HAVE_IPCIPHER */ if (d_tagsToExport) { - addTagsToProtobuf(message, *dq, *d_tagsToExport); + addTagsToProtobuf(message, *dnsquestion, *d_tagsToExport); } - addMetaDataToProtobuf(message, *dq, d_metas); + addMetaDataToProtobuf(message, *dnsquestion, d_metas); if (d_alterFunc) { auto lock = g_lua.lock(); - (*d_alterFunc)(dq, &message); + (*d_alterFunc)(dnsquestion, &message); } static thread_local std::string data; @@ -1558,15 +1731,16 @@ public: return Action::None; } - std::string toString() const override + [[nodiscard]] std::string toString() const override { return "remote log to " + (d_logger ? d_logger->toString() : ""); } + private: std::optional<std::unordered_set<std::string>> d_tagsToExport; std::vector<std::pair<std::string, ProtoBufMetaKey>> d_metas; std::shared_ptr<RemoteLoggerInterface> d_logger; - boost::optional<std::function<void(DNSQuestion*, DNSDistProtoBufMessage*)> > d_alterFunc; + boost::optional<std::function<void(DNSQuestion*, DNSDistProtoBufMessage*)>> d_alterFunc; std::string d_serverID; std::string d_ipEncryptKey; }; @@ -1577,21 +1751,23 @@ class SNMPTrapAction : public DNSAction { public: // this action does not stop the processing - SNMPTrapAction(const std::string& reason): d_reason(reason) + SNMPTrapAction(std::string reason) : + d_reason(std::move(reason)) { } - DNSAction::Action operator()(DNSQuestion* dq, std::string* ruleresult) const override + DNSAction::Action operator()(DNSQuestion* dnsquestion, std::string* ruleresult) const override { - if (g_snmpAgent && g_snmpTrapsEnabled) { - g_snmpAgent->sendDNSTrap(*dq, d_reason); + if (g_snmpAgent != nullptr && g_snmpTrapsEnabled) { + g_snmpAgent->sendDNSTrap(*dnsquestion, d_reason); } return Action::None; } - std::string toString() const override + [[nodiscard]] std::string toString() const override { return "send SNMP trap"; } + private: std::string d_reason; }; @@ -1600,19 +1776,21 @@ class SetTagAction : public DNSAction { public: // this action does not stop the processing - SetTagAction(const std::string& tag, const std::string& value): d_tag(tag), d_value(value) + SetTagAction(std::string tag, std::string value) : + d_tag(std::move(tag)), d_value(std::move(value)) { } - DNSAction::Action operator()(DNSQuestion* dq, std::string* ruleresult) const override + DNSAction::Action operator()(DNSQuestion* dnsquestion, std::string* ruleresult) const override { - dq->setTag(d_tag, d_value); + dnsquestion->setTag(d_tag, d_value); return Action::None; } - std::string toString() const override + [[nodiscard]] std::string toString() const override { return "set tag '" + d_tag + "' to value '" + d_value + "'"; } + private: std::string d_tag; std::string d_value; @@ -1623,76 +1801,84 @@ class DnstapLogResponseAction : public DNSResponseAction, public boost::noncopya { public: // this action does not stop the processing - DnstapLogResponseAction(const std::string& identity, std::shared_ptr<RemoteLoggerInterface>& logger, boost::optional<std::function<void(DNSResponse*, DnstapMessage*)> > alterFunc): d_identity(identity), d_logger(logger), d_alterFunc(alterFunc) + DnstapLogResponseAction(std::string identity, std::shared_ptr<RemoteLoggerInterface>& logger, boost::optional<std::function<void(DNSResponse*, DnstapMessage*)>> alterFunc) : + d_identity(std::move(identity)), d_logger(logger), d_alterFunc(std::move(alterFunc)) { } - DNSResponseAction::Action operator()(DNSResponse* dr, std::string* ruleresult) const override + DNSResponseAction::Action operator()(DNSResponse* response, std::string* ruleresult) const override { static thread_local std::string data; - struct timespec now; + struct timespec now = {}; gettime(&now, true); data.clear(); - DnstapMessage::ProtocolType protocol = ProtocolToDNSTap(dr->getProtocol()); - DnstapMessage message(data, DnstapMessage::MessageType::client_response, d_identity, &dr->ids.origRemote, &dr->ids.origDest, protocol, reinterpret_cast<const char*>(dr->getData().data()), dr->getData().size(), &dr->getQueryRealTime(), &now); + DnstapMessage::ProtocolType protocol = ProtocolToDNSTap(response->getProtocol()); + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast) + DnstapMessage message(std::move(data), DnstapMessage::MessageType::client_response, d_identity, &response->ids.origRemote, &response->ids.origDest, protocol, reinterpret_cast<const char*>(response->getData().data()), response->getData().size(), &response->getQueryRealTime(), &now); { if (d_alterFunc) { auto lock = g_lua.lock(); - (*d_alterFunc)(dr, &message); + (*d_alterFunc)(response, &message); } } + data = message.getBuffer(); remoteLoggerQueueData(*d_logger, data); return Action::None; } - std::string toString() const override + [[nodiscard]] std::string toString() const override { return "log response as dnstap to " + (d_logger ? d_logger->toString() : ""); } + private: std::string d_identity; std::shared_ptr<RemoteLoggerInterface> d_logger; - boost::optional<std::function<void(DNSResponse*, DnstapMessage*)> > d_alterFunc; + boost::optional<std::function<void(DNSResponse*, DnstapMessage*)>> d_alterFunc; }; class RemoteLogResponseAction : public DNSResponseAction, public boost::noncopyable { public: // this action does not stop the processing - RemoteLogResponseAction(std::shared_ptr<RemoteLoggerInterface>& logger, boost::optional<std::function<void(DNSResponse*, DNSDistProtoBufMessage*)> > alterFunc, const std::string& serverID, const std::string& ipEncryptKey, bool includeCNAME, std::vector<std::pair<std::string, ProtoBufMetaKey>>&& metas, std::optional<std::unordered_set<std::string>>&& tagsToExport): d_tagsToExport(std::move(tagsToExport)), d_metas(std::move(metas)), d_logger(logger), d_alterFunc(alterFunc), d_serverID(serverID), d_ipEncryptKey(ipEncryptKey), d_includeCNAME(includeCNAME) + RemoteLogResponseAction(RemoteLogActionConfiguration& config) : + d_tagsToExport(std::move(config.tagsToExport)), d_metas(std::move(config.metas)), d_logger(config.logger), d_alterFunc(std::move(config.alterResponseFunc)), d_serverID(config.serverID), d_ipEncryptKey(config.ipEncryptKey), d_exportExtendedErrorsToMeta(std::move(config.exportExtendedErrorsToMeta)), d_includeCNAME(config.includeCNAME) { } - DNSResponseAction::Action operator()(DNSResponse* dr, std::string* ruleresult) const override + DNSResponseAction::Action operator()(DNSResponse* response, std::string* ruleresult) const override { - if (!dr->ids.d_protoBufData) { - dr->ids.d_protoBufData = std::make_unique<InternalQueryState::ProtoBufData>(); + if (!response->ids.d_protoBufData) { + response->ids.d_protoBufData = std::make_unique<InternalQueryState::ProtoBufData>(); } - if (!dr->ids.d_protoBufData->uniqueId) { - dr->ids.d_protoBufData->uniqueId = getUniqueID(); + if (!response->ids.d_protoBufData->uniqueId) { + response->ids.d_protoBufData->uniqueId = getUniqueID(); } - DNSDistProtoBufMessage message(*dr, d_includeCNAME); + DNSDistProtoBufMessage message(*response, d_includeCNAME); if (!d_serverID.empty()) { message.setServerIdentity(d_serverID); } #if HAVE_IPCIPHER - if (!d_ipEncryptKey.empty()) - { - message.setRequestor(encryptCA(dr->ids.origRemote, d_ipEncryptKey)); + if (!d_ipEncryptKey.empty()) { + message.setRequestor(encryptCA(response->ids.origRemote, d_ipEncryptKey)); } #endif /* HAVE_IPCIPHER */ if (d_tagsToExport) { - addTagsToProtobuf(message, *dr, *d_tagsToExport); + addTagsToProtobuf(message, *response, *d_tagsToExport); } - addMetaDataToProtobuf(message, *dr, d_metas); + addMetaDataToProtobuf(message, *response, d_metas); + + if (d_exportExtendedErrorsToMeta) { + addExtendedDNSErrorToProtobuf(message, *response, *d_exportExtendedErrorsToMeta); + } if (d_alterFunc) { auto lock = g_lua.lock(); - (*d_alterFunc)(dr, &message); + (*d_alterFunc)(response, &message); } static thread_local std::string data; @@ -1702,17 +1888,19 @@ public: return Action::None; } - std::string toString() const override + [[nodiscard]] std::string toString() const override { return "remote log response to " + (d_logger ? d_logger->toString() : ""); } + private: std::optional<std::unordered_set<std::string>> d_tagsToExport; std::vector<std::pair<std::string, ProtoBufMetaKey>> d_metas; std::shared_ptr<RemoteLoggerInterface> d_logger; - boost::optional<std::function<void(DNSResponse*, DNSDistProtoBufMessage*)> > d_alterFunc; + boost::optional<std::function<void(DNSResponse*, DNSDistProtoBufMessage*)>> d_alterFunc; std::string d_serverID; std::string d_ipEncryptKey; + std::optional<std::string> d_exportExtendedErrorsToMeta{std::nullopt}; bool d_includeCNAME; }; @@ -1721,11 +1909,11 @@ private: class DropResponseAction : public DNSResponseAction { public: - DNSResponseAction::Action operator()(DNSResponse* dr, std::string* ruleresult) const override + DNSResponseAction::Action operator()(DNSResponse* response, std::string* ruleresult) const override { return Action::Drop; } - std::string toString() const override + [[nodiscard]] std::string toString() const override { return "drop"; } @@ -1734,11 +1922,11 @@ public: class AllowResponseAction : public DNSResponseAction { public: - DNSResponseAction::Action operator()(DNSResponse* dr, std::string* ruleresult) const override + DNSResponseAction::Action operator()(DNSResponse* response, std::string* ruleresult) const override { return Action::Allow; } - std::string toString() const override + [[nodiscard]] std::string toString() const override { return "allow"; } @@ -1747,18 +1935,20 @@ public: class DelayResponseAction : public DNSResponseAction { public: - DelayResponseAction(int msec) : d_msec(msec) + DelayResponseAction(int msec) : + d_msec(msec) { } - DNSResponseAction::Action operator()(DNSResponse* dr, std::string* ruleresult) const override + DNSResponseAction::Action operator()(DNSResponse* response, std::string* ruleresult) const override { *ruleresult = std::to_string(d_msec); return Action::Delay; } - std::string toString() const override + [[nodiscard]] std::string toString() const override { - return "delay by "+std::to_string(d_msec)+ " msec"; + return "delay by " + std::to_string(d_msec) + " ms"; } + private: int d_msec; }; @@ -1768,21 +1958,23 @@ class SNMPTrapResponseAction : public DNSResponseAction { public: // this action does not stop the processing - SNMPTrapResponseAction(const std::string& reason): d_reason(reason) + SNMPTrapResponseAction(std::string reason) : + d_reason(std::move(reason)) { } - DNSResponseAction::Action operator()(DNSResponse* dr, std::string* ruleresult) const override + DNSResponseAction::Action operator()(DNSResponse* response, std::string* ruleresult) const override { if (g_snmpAgent && g_snmpTrapsEnabled) { - g_snmpAgent->sendDNSTrap(*dr, d_reason); + g_snmpAgent->sendDNSTrap(*response, d_reason); } return Action::None; } - std::string toString() const override + [[nodiscard]] std::string toString() const override { return "send SNMP trap"; } + private: std::string d_reason; }; @@ -1792,19 +1984,21 @@ class SetTagResponseAction : public DNSResponseAction { public: // this action does not stop the processing - SetTagResponseAction(const std::string& tag, const std::string& value): d_tag(tag), d_value(value) + SetTagResponseAction(std::string tag, std::string value) : + d_tag(std::move(tag)), d_value(std::move(value)) { } - DNSResponseAction::Action operator()(DNSResponse* dr, std::string* ruleresult) const override + DNSResponseAction::Action operator()(DNSResponse* response, std::string* ruleresult) const override { - dr->setTag(d_tag, d_value); + response->setTag(d_tag, d_value); return Action::None; } - std::string toString() const override + [[nodiscard]] std::string toString() const override { return "set tag '" + d_tag + "' to value '" + d_value + "'"; } + private: std::string d_tag; std::string d_value; @@ -1813,19 +2007,20 @@ private: class ClearRecordTypesResponseAction : public DNSResponseAction, public boost::noncopyable { public: - ClearRecordTypesResponseAction(const std::unordered_set<QType>& qtypes) : d_qtypes(qtypes) + ClearRecordTypesResponseAction(std::unordered_set<QType> qtypes) : + d_qtypes(std::move(qtypes)) { } - DNSResponseAction::Action operator()(DNSResponse* dr, std::string* ruleresult) const override + DNSResponseAction::Action operator()(DNSResponse* response, std::string* ruleresult) const override { - if (d_qtypes.size() > 0) { - clearDNSPacketRecordTypes(dr->getMutableData(), d_qtypes); + if (!d_qtypes.empty()) { + clearDNSPacketRecordTypes(response->getMutableData(), d_qtypes); } return DNSResponseAction::Action::None; } - std::string toString() const override + [[nodiscard]] std::string toString() const override { return "clear record types"; } @@ -1838,32 +2033,31 @@ class ContinueAction : public DNSAction { public: // this action does not stop the processing - ContinueAction(std::shared_ptr<DNSAction>& action): d_action(action) + ContinueAction(std::shared_ptr<DNSAction>& action) : + d_action(action) { } - DNSAction::Action operator()(DNSQuestion* dq, std::string* ruleresult) const override + DNSAction::Action operator()(DNSQuestion* dnsquestion, std::string* ruleresult) const override { if (d_action) { /* call the action */ - auto action = (*d_action)(dq, ruleresult); + auto action = (*d_action)(dnsquestion, ruleresult); bool drop = false; /* apply the changes if needed (pool selection, flags, etc */ - processRulesResult(action, *dq, *ruleresult, drop); + processRulesResult(action, *dnsquestion, *ruleresult, drop); } /* but ignore the resulting action no matter what */ return Action::None; } - std::string toString() const override + [[nodiscard]] std::string toString() const override { if (d_action) { return "continue after: " + (d_action ? d_action->toString() : ""); } - else { - return "no op"; - } + return "no op"; } private: @@ -1871,32 +2065,41 @@ private: }; #ifdef HAVE_DNS_OVER_HTTPS -class HTTPStatusAction: public DNSAction +class HTTPStatusAction : public DNSAction { public: - HTTPStatusAction(int code, const PacketBuffer& body, const std::string& contentType): d_body(body), d_contentType(contentType), d_code(code) + HTTPStatusAction(int code, PacketBuffer body, std::string contentType) : + d_body(std::move(body)), d_contentType(std::move(contentType)), d_code(code) { } - DNSAction::Action operator()(DNSQuestion* dq, std::string* ruleresult) const override + DNSAction::Action operator()(DNSQuestion* dnsquestion, std::string* ruleresult) const override { - if (!dq->ids.du) { + if (!dnsquestion->ids.du) { return Action::None; } - dq->ids.du->setHTTPResponse(d_code, PacketBuffer(d_body), d_contentType); - dq->getHeader()->qr = true; // for good measure - setResponseHeadersFromConfig(*dq->getHeader(), d_responseConfig); + dnsquestion->ids.du->setHTTPResponse(d_code, PacketBuffer(d_body), d_contentType); + dnsdist::PacketMangling::editDNSHeaderFromPacket(dnsquestion->getMutableData(), [this](dnsheader& header) { + header.qr = true; // for good measure + setResponseHeadersFromConfig(header, d_responseConfig); + return true; + }); return Action::HeaderModify; } - std::string toString() const override + [[nodiscard]] std::string toString() const override { return "return an HTTP status of " + std::to_string(d_code); } - ResponseConfig d_responseConfig; + [[nodiscard]] ResponseConfig& getResponseConfig() + { + return d_responseConfig; + } + private: + ResponseConfig d_responseConfig; PacketBuffer d_body; std::string d_contentType; int d_code; @@ -1908,26 +2111,27 @@ class KeyValueStoreLookupAction : public DNSAction { public: // this action does not stop the processing - KeyValueStoreLookupAction(std::shared_ptr<KeyValueStore>& kvs, std::shared_ptr<KeyValueLookupKey>& lookupKey, const std::string& destinationTag): d_kvs(kvs), d_key(lookupKey), d_tag(destinationTag) + KeyValueStoreLookupAction(std::shared_ptr<KeyValueStore>& kvs, std::shared_ptr<KeyValueLookupKey>& lookupKey, std::string destinationTag) : + d_kvs(kvs), d_key(lookupKey), d_tag(std::move(destinationTag)) { } - DNSAction::Action operator()(DNSQuestion* dq, std::string* ruleresult) const override + DNSAction::Action operator()(DNSQuestion* dnsquestion, std::string* ruleresult) const override { - std::vector<std::string> keys = d_key->getKeys(*dq); + std::vector<std::string> keys = d_key->getKeys(*dnsquestion); std::string result; for (const auto& key : keys) { - if (d_kvs->getValue(key, result) == true) { + if (d_kvs->getValue(key, result)) { break; } } - dq->setTag(d_tag, std::move(result)); + dnsquestion->setTag(d_tag, std::move(result)); return Action::None; } - std::string toString() const override + [[nodiscard]] std::string toString() const override { return "lookup key-value store based on '" + d_key->toString() + "' and set the result in tag '" + d_tag + "'"; } @@ -1942,26 +2146,27 @@ class KeyValueStoreRangeLookupAction : public DNSAction { public: // this action does not stop the processing - KeyValueStoreRangeLookupAction(std::shared_ptr<KeyValueStore>& kvs, std::shared_ptr<KeyValueLookupKey>& lookupKey, const std::string& destinationTag): d_kvs(kvs), d_key(lookupKey), d_tag(destinationTag) + KeyValueStoreRangeLookupAction(std::shared_ptr<KeyValueStore>& kvs, std::shared_ptr<KeyValueLookupKey>& lookupKey, std::string destinationTag) : + d_kvs(kvs), d_key(lookupKey), d_tag(std::move(destinationTag)) { } - DNSAction::Action operator()(DNSQuestion* dq, std::string* ruleresult) const override + DNSAction::Action operator()(DNSQuestion* dnsquestion, std::string* ruleresult) const override { - std::vector<std::string> keys = d_key->getKeys(*dq); + std::vector<std::string> keys = d_key->getKeys(*dnsquestion); std::string result; for (const auto& key : keys) { - if (d_kvs->getRangeValue(key, result) == true) { + if (d_kvs->getRangeValue(key, result)) { break; } } - dq->setTag(d_tag, std::move(result)); + dnsquestion->setTag(d_tag, std::move(result)); return Action::None; } - std::string toString() const override + [[nodiscard]] std::string toString() const override { return "do a range-based lookup in key-value store based on '" + d_key->toString() + "' and set the result in tag '" + d_tag + "'"; } @@ -1976,17 +2181,18 @@ private: class MaxReturnedTTLAction : public DNSAction { public: - MaxReturnedTTLAction(uint32_t cap) : d_cap(cap) + MaxReturnedTTLAction(uint32_t cap) : + d_cap(cap) { } - DNSAction::Action operator()(DNSQuestion* dq, std::string* ruleresult) const override + DNSAction::Action operator()(DNSQuestion* dnsquestion, std::string* ruleresult) const override { - dq->ids.ttlCap = d_cap; + dnsquestion->ids.ttlCap = d_cap; return DNSAction::Action::None; } - std::string toString() const override + [[nodiscard]] std::string toString() const override { return "cap the TTL of the returned response to " + std::to_string(d_cap); } @@ -1998,17 +2204,18 @@ private: class MaxReturnedTTLResponseAction : public DNSResponseAction { public: - MaxReturnedTTLResponseAction(uint32_t cap) : d_cap(cap) + MaxReturnedTTLResponseAction(uint32_t cap) : + d_cap(cap) { } - DNSResponseAction::Action operator()(DNSResponse* dr, std::string* ruleresult) const override + DNSResponseAction::Action operator()(DNSResponse* response, std::string* ruleresult) const override { - dr->ids.ttlCap = d_cap; + response->ids.ttlCap = d_cap; return DNSResponseAction::Action::None; } - std::string toString() const override + [[nodiscard]] std::string toString() const override { return "cap the TTL of the returned response to " + std::to_string(d_cap); } @@ -2017,41 +2224,54 @@ private: uint32_t d_cap; }; -class NegativeAndSOAAction: public DNSAction +class NegativeAndSOAAction : public DNSAction { public: - NegativeAndSOAAction(bool nxd, const DNSName& zone, uint32_t ttl, const DNSName& mname, const DNSName& rname, uint32_t serial, uint32_t refresh, uint32_t retry, uint32_t expire, uint32_t minimum, bool soaInAuthoritySection): d_zone(zone), d_mname(mname), d_rname(rname), d_ttl(ttl), d_serial(serial), d_refresh(refresh), d_retry(retry), d_expire(expire), d_minimum(minimum), d_nxd(nxd), d_soaInAuthoritySection(soaInAuthoritySection) + struct SOAParams + { + uint32_t serial; + uint32_t refresh; + uint32_t retry; + uint32_t expire; + uint32_t minimum; + }; + + NegativeAndSOAAction(bool nxd, DNSName zone, uint32_t ttl, DNSName mname, DNSName rname, SOAParams params, bool soaInAuthoritySection) : + d_zone(std::move(zone)), d_mname(std::move(mname)), d_rname(std::move(rname)), d_ttl(ttl), d_params(params), d_nxd(nxd), d_soaInAuthoritySection(soaInAuthoritySection) { } - DNSAction::Action operator()(DNSQuestion* dq, std::string* ruleresult) const override + DNSAction::Action operator()(DNSQuestion* dnsquestion, std::string* ruleresult) const override { - if (!setNegativeAndAdditionalSOA(*dq, d_nxd, d_zone, d_ttl, d_mname, d_rname, d_serial, d_refresh, d_retry, d_expire, d_minimum, d_soaInAuthoritySection)) { + if (!setNegativeAndAdditionalSOA(*dnsquestion, d_nxd, d_zone, d_ttl, d_mname, d_rname, d_params.serial, d_params.refresh, d_params.retry, d_params.expire, d_params.minimum, d_soaInAuthoritySection)) { return Action::None; } - setResponseHeadersFromConfig(*dq->getHeader(), d_responseConfig); + dnsdist::PacketMangling::editDNSHeaderFromPacket(dnsquestion->getMutableData(), [this](dnsheader& header) { + setResponseHeadersFromConfig(header, d_responseConfig); + return true; + }); return Action::Allow; } - std::string toString() const override + [[nodiscard]] std::string toString() const override { return std::string(d_nxd ? "NXD " : "NODATA") + " with SOA"; } + [[nodiscard]] ResponseConfig& getResponseConfig() + { + return d_responseConfig; + } +private: ResponseConfig d_responseConfig; -private: DNSName d_zone; DNSName d_mname; DNSName d_rname; uint32_t d_ttl; - uint32_t d_serial; - uint32_t d_refresh; - uint32_t d_retry; - uint32_t d_expire; - uint32_t d_minimum; + SOAParams d_params; bool d_nxd; bool d_soaInAuthoritySection; }; @@ -2068,18 +2288,18 @@ public: } } - DNSAction::Action operator()(DNSQuestion* dq, std::string* ruleresult) const override + DNSAction::Action operator()(DNSQuestion* dnsquestion, std::string* ruleresult) const override { - if (!dq->proxyProtocolValues) { - dq->proxyProtocolValues = make_unique<std::vector<ProxyProtocolValue>>(); + if (!dnsquestion->proxyProtocolValues) { + dnsquestion->proxyProtocolValues = make_unique<std::vector<ProxyProtocolValue>>(); } - *(dq->proxyProtocolValues) = d_values; + *(dnsquestion->proxyProtocolValues) = d_values; return Action::None; } - std::string toString() const override + [[nodiscard]] std::string toString() const override { return "set Proxy-Protocol values"; } @@ -2092,22 +2312,23 @@ class SetAdditionalProxyProtocolValueAction : public DNSAction { public: // this action does not stop the processing - SetAdditionalProxyProtocolValueAction(uint8_t type, const std::string& value): d_value(value), d_type(type) + SetAdditionalProxyProtocolValueAction(uint8_t type, std::string value) : + d_value(std::move(value)), d_type(type) { } - DNSAction::Action operator()(DNSQuestion* dq, std::string* ruleresult) const override + DNSAction::Action operator()(DNSQuestion* dnsquestion, std::string* ruleresult) const override { - if (!dq->proxyProtocolValues) { - dq->proxyProtocolValues = make_unique<std::vector<ProxyProtocolValue>>(); + if (!dnsquestion->proxyProtocolValues) { + dnsquestion->proxyProtocolValues = make_unique<std::vector<ProxyProtocolValue>>(); } - dq->proxyProtocolValues->push_back({ d_value, d_type }); + dnsquestion->proxyProtocolValues->push_back({d_value, d_type}); return Action::None; } - std::string toString() const override + [[nodiscard]] std::string toString() const override { return "add a Proxy-Protocol value of type " + std::to_string(d_type); } @@ -2121,20 +2342,23 @@ class SetReducedTTLResponseAction : public DNSResponseAction, public boost::nonc { public: // this action does not stop the processing - SetReducedTTLResponseAction(uint8_t percentage) : d_ratio(percentage / 100.0) + SetReducedTTLResponseAction(uint8_t percentage) : + d_ratio(percentage / 100.0) { } - DNSResponseAction::Action operator()(DNSResponse* dr, std::string* ruleresult) const override + DNSResponseAction::Action operator()(DNSResponse* response, std::string* ruleresult) const override { + // NOLINTNEXTLINE(bugprone-easily-swappable-parameters) auto visitor = [&](uint8_t section, uint16_t qclass, uint16_t qtype, uint32_t ttl) { return ttl * d_ratio; }; - editDNSPacketTTL(reinterpret_cast<char *>(dr->getMutableData().data()), dr->getData().size(), visitor); + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast) + editDNSPacketTTL(reinterpret_cast<char*>(response->getMutableData().data()), response->getData().size(), visitor); return DNSResponseAction::Action::None; } - std::string toString() const override + [[nodiscard]] std::string toString() const override { return "reduce ttl to " + std::to_string(d_ratio * 100) + " percent of its value"; } @@ -2143,23 +2367,76 @@ private: double d_ratio{1.0}; }; -template<typename T, typename ActionT> -static void addAction(GlobalStateHolder<vector<T> > *someRuleActions, const luadnsrule_t& var, const std::shared_ptr<ActionT>& action, boost::optional<luaruleparams_t>& params) { +class SetExtendedDNSErrorAction : public DNSAction +{ +public: + // this action does not stop the processing + SetExtendedDNSErrorAction(uint16_t infoCode, const std::string& extraText) + { + d_ede.infoCode = infoCode; + d_ede.extraText = extraText; + } + + DNSAction::Action operator()(DNSQuestion* dnsQuestion, std::string* ruleresult) const override + { + dnsQuestion->ids.d_extendedError = std::make_unique<EDNSExtendedError>(d_ede); + + return DNSAction::Action::None; + } + + [[nodiscard]] std::string toString() const override + { + return "set EDNS Extended DNS Error to " + std::to_string(d_ede.infoCode) + (d_ede.extraText.empty() ? std::string() : std::string(": \"") + d_ede.extraText + std::string("\"")); + } + +private: + EDNSExtendedError d_ede; +}; + +class SetExtendedDNSErrorResponseAction : public DNSResponseAction +{ +public: + // this action does not stop the processing + SetExtendedDNSErrorResponseAction(uint16_t infoCode, const std::string& extraText) + { + d_ede.infoCode = infoCode; + d_ede.extraText = extraText; + } + + DNSResponseAction::Action operator()(DNSResponse* dnsResponse, std::string* ruleresult) const override + { + dnsResponse->ids.d_extendedError = std::make_unique<EDNSExtendedError>(d_ede); + + return DNSResponseAction::Action::None; + } + + [[nodiscard]] std::string toString() const override + { + return "set EDNS Extended DNS Error to " + std::to_string(d_ede.infoCode) + (d_ede.extraText.empty() ? std::string() : std::string(": \"") + d_ede.extraText + std::string("\"")); + } + +private: + EDNSExtendedError d_ede; +}; + +template <typename T, typename ActionT> +static void addAction(GlobalStateHolder<vector<T>>* someRuleActions, const luadnsrule_t& var, const std::shared_ptr<ActionT>& action, boost::optional<luaruleparams_t>& params) +{ setLuaSideEffect(); std::string name; - boost::uuids::uuid uuid; - uint64_t creationOrder; + boost::uuids::uuid uuid{}; + uint64_t creationOrder = 0; parseRuleParams(params, uuid, name, creationOrder); checkAllParametersConsumed("addAction", params); - auto rule = makeRule(var); - someRuleActions->modify([&rule, &action, &uuid, creationOrder, &name](vector<T>& ruleactions){ - ruleactions.push_back({std::move(rule), std::move(action), std::move(name), std::move(uuid), creationOrder}); - }); + auto rule = makeRule(var, "addAction"); + someRuleActions->modify([&rule, &action, &uuid, creationOrder, &name](vector<T>& ruleactions) { + ruleactions.push_back({std::move(rule), std::move(action), std::move(name), uuid, creationOrder}); + }); } -typedef std::unordered_map<std::string, boost::variant<bool, uint32_t> > responseParams_t; +using responseParams_t = std::unordered_map<std::string, boost::variant<bool, uint32_t>>; static void parseResponseConfig(boost::optional<responseParams_t>& vars, ResponseConfig& config) { @@ -2169,193 +2446,208 @@ static void parseResponseConfig(boost::optional<responseParams_t>& vars, Respons getOptionalValue<bool>(vars, "ra", config.setRA); } -void setResponseHeadersFromConfig(dnsheader& dh, const ResponseConfig& config) +void setResponseHeadersFromConfig(dnsheader& dnsheader, const ResponseConfig& config) { if (config.setAA) { - dh.aa = *config.setAA; + dnsheader.aa = *config.setAA; } if (config.setAD) { - dh.ad = *config.setAD; + dnsheader.ad = *config.setAD; } else { - dh.ad = false; + dnsheader.ad = false; } if (config.setRA) { - dh.ra = *config.setRA; + dnsheader.ra = *config.setRA; } else { - dh.ra = dh.rd; // for good measure + dnsheader.ra = dnsheader.rd; // for good measure } } +// NOLINTNEXTLINE(readability-function-cognitive-complexity): this function declares Lua bindings, even with a good refactoring it will likely blow up the threshold void setupLuaActions(LuaContext& luaCtx) { - luaCtx.writeFunction("newRuleAction", [](luadnsrule_t dnsrule, std::shared_ptr<DNSAction> action, boost::optional<luaruleparams_t> params) { - boost::uuids::uuid uuid; - uint64_t creationOrder; - std::string name; - parseRuleParams(params, uuid, name, creationOrder); - checkAllParametersConsumed("newRuleAction", params); + luaCtx.writeFunction("newRuleAction", [](const luadnsrule_t& dnsrule, std::shared_ptr<DNSAction> action, boost::optional<luaruleparams_t> params) { + boost::uuids::uuid uuid{}; + uint64_t creationOrder = 0; + std::string name; + parseRuleParams(params, uuid, name, creationOrder); + checkAllParametersConsumed("newRuleAction", params); - auto rule = makeRule(dnsrule); - DNSDistRuleAction ra({std::move(rule), action, std::move(name), uuid, creationOrder}); - return std::make_shared<DNSDistRuleAction>(ra); - }); + auto rule = makeRule(dnsrule, "newRuleAction"); + DNSDistRuleAction ruleaction({std::move(rule), std::move(action), std::move(name), uuid, creationOrder}); + return std::make_shared<DNSDistRuleAction>(ruleaction); + }); - luaCtx.writeFunction("addAction", [](luadnsrule_t var, boost::variant<std::shared_ptr<DNSAction>, std::shared_ptr<DNSResponseAction> > era, boost::optional<luaruleparams_t> params) { - if (era.type() != typeid(std::shared_ptr<DNSAction>)) { - throw std::runtime_error("addAction() can only be called with query-related actions, not response-related ones. Are you looking for addResponseAction()?"); - } + luaCtx.writeFunction("addAction", [](const luadnsrule_t& var, boost::variant<std::shared_ptr<DNSAction>, std::shared_ptr<DNSResponseAction>> era, boost::optional<luaruleparams_t> params) { + if (era.type() != typeid(std::shared_ptr<DNSAction>)) { + throw std::runtime_error("addAction() can only be called with query-related actions, not response-related ones. Are you looking for addResponseAction()?"); + } - addAction(&g_ruleactions, var, boost::get<std::shared_ptr<DNSAction> >(era), params); - }); + addAction(&g_ruleactions, var, boost::get<std::shared_ptr<DNSAction>>(era), params); + }); - luaCtx.writeFunction("addResponseAction", [](luadnsrule_t var, boost::variant<std::shared_ptr<DNSAction>, std::shared_ptr<DNSResponseAction> > era, boost::optional<luaruleparams_t> params) { - if (era.type() != typeid(std::shared_ptr<DNSResponseAction>)) { - throw std::runtime_error("addResponseAction() can only be called with response-related actions, not query-related ones. Are you looking for addAction()?"); - } + luaCtx.writeFunction("addResponseAction", [](const luadnsrule_t& var, boost::variant<std::shared_ptr<DNSAction>, std::shared_ptr<DNSResponseAction>> era, boost::optional<luaruleparams_t> params) { + if (era.type() != typeid(std::shared_ptr<DNSResponseAction>)) { + throw std::runtime_error("addResponseAction() can only be called with response-related actions, not query-related ones. Are you looking for addAction()?"); + } - addAction(&g_respruleactions, var, boost::get<std::shared_ptr<DNSResponseAction> >(era), params); - }); + addAction(&g_respruleactions, var, boost::get<std::shared_ptr<DNSResponseAction>>(era), params); + }); - luaCtx.writeFunction("addCacheHitResponseAction", [](luadnsrule_t var, boost::variant<std::shared_ptr<DNSAction>, std::shared_ptr<DNSResponseAction>> era, boost::optional<luaruleparams_t> params) { - if (era.type() != typeid(std::shared_ptr<DNSResponseAction>)) { - throw std::runtime_error("addCacheHitResponseAction() can only be called with response-related actions, not query-related ones. Are you looking for addAction()?"); - } + luaCtx.writeFunction("addCacheHitResponseAction", [](const luadnsrule_t& var, boost::variant<std::shared_ptr<DNSAction>, std::shared_ptr<DNSResponseAction>> era, boost::optional<luaruleparams_t> params) { + if (era.type() != typeid(std::shared_ptr<DNSResponseAction>)) { + throw std::runtime_error("addCacheHitResponseAction() can only be called with response-related actions, not query-related ones. Are you looking for addAction()?"); + } - addAction(&g_cachehitrespruleactions, var, boost::get<std::shared_ptr<DNSResponseAction> >(era), params); - }); + addAction(&g_cachehitrespruleactions, var, boost::get<std::shared_ptr<DNSResponseAction>>(era), params); + }); - luaCtx.writeFunction("addCacheInsertedResponseAction", [](luadnsrule_t var, boost::variant<std::shared_ptr<DNSAction>, std::shared_ptr<DNSResponseAction>> era, boost::optional<luaruleparams_t> params) { + luaCtx.writeFunction("addCacheInsertedResponseAction", [](const luadnsrule_t& var, boost::variant<std::shared_ptr<DNSAction>, std::shared_ptr<DNSResponseAction>> era, boost::optional<luaruleparams_t> params) { if (era.type() != typeid(std::shared_ptr<DNSResponseAction>)) { throw std::runtime_error("addCacheInsertedResponseAction() can only be called with response-related actions, not query-related ones. Are you looking for addAction()?"); } - addAction(&g_cacheInsertedRespRuleActions, var, boost::get<std::shared_ptr<DNSResponseAction> >(era), params); + addAction(&g_cacheInsertedRespRuleActions, var, boost::get<std::shared_ptr<DNSResponseAction>>(era), params); }); - luaCtx.writeFunction("addSelfAnsweredResponseAction", [](luadnsrule_t var, boost::variant<std::shared_ptr<DNSAction>, std::shared_ptr<DNSResponseAction>> era, boost::optional<luaruleparams_t> params) { - if (era.type() != typeid(std::shared_ptr<DNSResponseAction>)) { - throw std::runtime_error("addSelfAnsweredResponseAction() can only be called with response-related actions, not query-related ones. Are you looking for addAction()?"); - } + luaCtx.writeFunction("addSelfAnsweredResponseAction", [](const luadnsrule_t& var, boost::variant<std::shared_ptr<DNSAction>, std::shared_ptr<DNSResponseAction>> era, boost::optional<luaruleparams_t> params) { + if (era.type() != typeid(std::shared_ptr<DNSResponseAction>)) { + throw std::runtime_error("addSelfAnsweredResponseAction() can only be called with response-related actions, not query-related ones. Are you looking for addAction()?"); + } - addAction(&g_selfansweredrespruleactions, var, boost::get<std::shared_ptr<DNSResponseAction> >(era), params); - }); + addAction(&g_selfansweredrespruleactions, var, boost::get<std::shared_ptr<DNSResponseAction>>(era), params); + }); - luaCtx.registerFunction<void(DNSAction::*)()const>("printStats", [](const DNSAction& ta) { - setLuaNoSideEffect(); - auto stats = ta.getStats(); - for(const auto& s : stats) { - g_outputBuffer+=s.first+"\t"; - if((uint64_t)s.second == s.second) - g_outputBuffer += std::to_string((uint64_t)s.second)+"\n"; - else - g_outputBuffer += std::to_string(s.second)+"\n"; + luaCtx.registerFunction<void (DNSAction::*)() const>("printStats", [](const DNSAction& action) { + setLuaNoSideEffect(); + auto stats = action.getStats(); + for (const auto& stat : stats) { + g_outputBuffer += stat.first + "\t"; + double integral = 0; + if (std::modf(stat.second, &integral) == 0.0 && stat.second < static_cast<double>(std::numeric_limits<uint64_t>::max())) { + g_outputBuffer += std::to_string(static_cast<uint64_t>(stat.second)) + "\n"; } - }); + else { + g_outputBuffer += std::to_string(stat.second) + "\n"; + } + } + }); luaCtx.writeFunction("getAction", [](unsigned int num) { - setLuaNoSideEffect(); - boost::optional<std::shared_ptr<DNSAction>> ret; - auto ruleactions = g_ruleactions.getCopy(); - if(num < ruleactions.size()) - ret=ruleactions[num].d_action; - return ret; - }); + setLuaNoSideEffect(); + boost::optional<std::shared_ptr<DNSAction>> ret; + auto ruleactions = g_ruleactions.getCopy(); + if (num < ruleactions.size()) { + ret = ruleactions[num].d_action; + } + return ret; + }); luaCtx.registerFunction("getStats", &DNSAction::getStats); luaCtx.registerFunction("reload", &DNSAction::reload); luaCtx.registerFunction("reload", &DNSResponseAction::reload); luaCtx.writeFunction("LuaAction", [](LuaAction::func_t func) { - setLuaSideEffect(); - return std::shared_ptr<DNSAction>(new LuaAction(func)); - }); + setLuaSideEffect(); + return std::shared_ptr<DNSAction>(new LuaAction(std::move(func))); + }); luaCtx.writeFunction("LuaFFIAction", [](LuaFFIAction::func_t func) { - setLuaSideEffect(); - return std::shared_ptr<DNSAction>(new LuaFFIAction(func)); - }); + setLuaSideEffect(); + return std::shared_ptr<DNSAction>(new LuaFFIAction(std::move(func))); + }); luaCtx.writeFunction("LuaFFIPerThreadAction", [](const std::string& code) { - setLuaSideEffect(); - return std::shared_ptr<DNSAction>(new LuaFFIPerThreadAction(code)); - }); + setLuaSideEffect(); + return std::shared_ptr<DNSAction>(new LuaFFIPerThreadAction(code)); + }); luaCtx.writeFunction("SetNoRecurseAction", []() { - return std::shared_ptr<DNSAction>(new SetNoRecurseAction); - }); + return std::shared_ptr<DNSAction>(new SetNoRecurseAction); + }); luaCtx.writeFunction("SetMacAddrAction", [](int code) { - return std::shared_ptr<DNSAction>(new SetMacAddrAction(code)); - }); + return std::shared_ptr<DNSAction>(new SetMacAddrAction(code)); + }); luaCtx.writeFunction("SetEDNSOptionAction", [](int code, const std::string& data) { - return std::shared_ptr<DNSAction>(new SetEDNSOptionAction(code, data)); - }); + return std::shared_ptr<DNSAction>(new SetEDNSOptionAction(code, data)); + }); - luaCtx.writeFunction("PoolAction", [](const std::string& a, boost::optional<bool> stopProcessing) { - return std::shared_ptr<DNSAction>(new PoolAction(a, stopProcessing ? *stopProcessing : true)); - }); + luaCtx.writeFunction("PoolAction", [](const std::string& poolname, boost::optional<bool> stopProcessing) { + return std::shared_ptr<DNSAction>(new PoolAction(poolname, stopProcessing ? *stopProcessing : true)); + }); luaCtx.writeFunction("QPSAction", [](int limit) { - return std::shared_ptr<DNSAction>(new QPSAction(limit)); - }); + return std::shared_ptr<DNSAction>(new QPSAction(limit)); + }); - luaCtx.writeFunction("QPSPoolAction", [](int limit, const std::string& a, boost::optional<bool> stopProcessing) { - return std::shared_ptr<DNSAction>(new QPSPoolAction(limit, a, stopProcessing ? *stopProcessing : true)); - }); + luaCtx.writeFunction("QPSPoolAction", [](int limit, const std::string& poolname, boost::optional<bool> stopProcessing) { + return std::shared_ptr<DNSAction>(new QPSPoolAction(limit, poolname, stopProcessing ? *stopProcessing : true)); + }); luaCtx.writeFunction("SpoofAction", [](LuaTypeOrArrayOf<std::string> inp, boost::optional<responseParams_t> vars) { - vector<ComboAddress> addrs; - if(auto s = boost::get<std::string>(&inp)) { - addrs.push_back(ComboAddress(*s)); - } else { - const auto& v = boost::get<LuaArray<std::string>>(inp); - for(const auto& a: v) { - addrs.push_back(ComboAddress(a.second)); - } + vector<ComboAddress> addrs; + if (auto* ipaddr = boost::get<std::string>(&inp)) { + addrs.emplace_back(*ipaddr); + } + else { + const auto& ipsArray = boost::get<LuaArray<std::string>>(inp); + for (const auto& ipAddr : ipsArray) { + addrs.emplace_back(ipAddr.second); } + } - auto ret = std::shared_ptr<DNSAction>(new SpoofAction(addrs)); - auto sa = std::dynamic_pointer_cast<SpoofAction>(ret); - parseResponseConfig(vars, sa->d_responseConfig); - checkAllParametersConsumed("SpoofAction", vars); - return ret; - }); + auto ret = std::shared_ptr<DNSAction>(new SpoofAction(addrs)); + auto spoofaction = std::dynamic_pointer_cast<SpoofAction>(ret); + parseResponseConfig(vars, spoofaction->getResponseConfig()); + checkAllParametersConsumed("SpoofAction", vars); + return ret; + }); luaCtx.writeFunction("SpoofSVCAction", [](const LuaArray<SVCRecordParameters>& parameters, boost::optional<responseParams_t> vars) { - auto ret = std::shared_ptr<DNSAction>(new SpoofSVCAction(parameters)); - auto sa = std::dynamic_pointer_cast<SpoofSVCAction>(ret); - parseResponseConfig(vars, sa->d_responseConfig); - return ret; - }); + auto ret = std::shared_ptr<DNSAction>(new SpoofSVCAction(parameters)); + auto spoofaction = std::dynamic_pointer_cast<SpoofSVCAction>(ret); + parseResponseConfig(vars, spoofaction->getResponseConfig()); + return ret; + }); - luaCtx.writeFunction("SpoofCNAMEAction", [](const std::string& a, boost::optional<responseParams_t> vars) { - auto ret = std::shared_ptr<DNSAction>(new SpoofAction(DNSName(a))); - auto sa = std::dynamic_pointer_cast<SpoofAction>(ret); - parseResponseConfig(vars, sa->d_responseConfig); - checkAllParametersConsumed("SpoofCNAMEAction", vars); - return ret; - }); + luaCtx.writeFunction("SpoofCNAMEAction", [](const std::string& cname, boost::optional<responseParams_t> vars) { + auto ret = std::shared_ptr<DNSAction>(new SpoofAction(DNSName(cname))); + auto spoofaction = std::dynamic_pointer_cast<SpoofAction>(ret); + parseResponseConfig(vars, spoofaction->getResponseConfig()); + checkAllParametersConsumed("SpoofCNAMEAction", vars); + return ret; + }); luaCtx.writeFunction("SpoofRawAction", [](LuaTypeOrArrayOf<std::string> inp, boost::optional<responseParams_t> vars) { - vector<string> raws; - if(auto s = boost::get<std::string>(&inp)) { - raws.push_back(*s); - } else { - const auto& v = boost::get<LuaArray<std::string>>(inp); - for(const auto& raw: v) { - raws.push_back(raw.second); - } + vector<string> raws; + if (const auto* str = boost::get<std::string>(&inp)) { + raws.push_back(*str); + } + else { + const auto& vect = boost::get<LuaArray<std::string>>(inp); + for (const auto& raw : vect) { + raws.push_back(raw.second); } - - auto ret = std::shared_ptr<DNSAction>(new SpoofAction(raws)); - auto sa = std::dynamic_pointer_cast<SpoofAction>(ret); - parseResponseConfig(vars, sa->d_responseConfig); - checkAllParametersConsumed("SpoofRawAction", vars); - return ret; - }); + } + uint32_t qtypeForAny{0}; + getOptionalValue<uint32_t>(vars, "typeForAny", qtypeForAny); + if (qtypeForAny > std::numeric_limits<uint16_t>::max()) { + qtypeForAny = 0; + } + std::optional<uint16_t> qtypeForAnyParam; + if (qtypeForAny > 0) { + qtypeForAnyParam = static_cast<uint16_t>(qtypeForAny); + } + auto ret = std::shared_ptr<DNSAction>(new SpoofAction(raws, qtypeForAnyParam)); + auto spoofaction = std::dynamic_pointer_cast<SpoofAction>(ret); + parseResponseConfig(vars, spoofaction->getResponseConfig()); + checkAllParametersConsumed("SpoofRawAction", vars); + return ret; + }); luaCtx.writeFunction("SpoofPacketAction", [](const std::string& response, size_t len) { if (len < sizeof(dnsheader)) { @@ -2363,58 +2655,62 @@ void setupLuaActions(LuaContext& luaCtx) } auto ret = std::shared_ptr<DNSAction>(new SpoofAction(response.c_str(), len)); return ret; - }); + }); luaCtx.writeFunction("DropAction", []() { - return std::shared_ptr<DNSAction>(new DropAction); - }); + return std::shared_ptr<DNSAction>(new DropAction); + }); luaCtx.writeFunction("AllowAction", []() { - return std::shared_ptr<DNSAction>(new AllowAction); - }); + return std::shared_ptr<DNSAction>(new AllowAction); + }); luaCtx.writeFunction("NoneAction", []() { - return std::shared_ptr<DNSAction>(new NoneAction); - }); + return std::shared_ptr<DNSAction>(new NoneAction); + }); luaCtx.writeFunction("DelayAction", [](int msec) { - return std::shared_ptr<DNSAction>(new DelayAction(msec)); - }); + return std::shared_ptr<DNSAction>(new DelayAction(msec)); + }); luaCtx.writeFunction("TCAction", []() { - return std::shared_ptr<DNSAction>(new TCAction); - }); + return std::shared_ptr<DNSAction>(new TCAction); + }); + + luaCtx.writeFunction("TCResponseAction", []() { + return std::shared_ptr<DNSResponseAction>(new TCResponseAction); + }); luaCtx.writeFunction("SetDisableValidationAction", []() { - return std::shared_ptr<DNSAction>(new SetDisableValidationAction); - }); + return std::shared_ptr<DNSAction>(new SetDisableValidationAction); + }); luaCtx.writeFunction("LogAction", [](boost::optional<std::string> fname, boost::optional<bool> binary, boost::optional<bool> append, boost::optional<bool> buffered, boost::optional<bool> verboseOnly, boost::optional<bool> includeTimestamp) { - return std::shared_ptr<DNSAction>(new LogAction(fname ? *fname : "", binary ? *binary : true, append ? *append : false, buffered ? *buffered : false, verboseOnly ? *verboseOnly : true, includeTimestamp ? *includeTimestamp : false)); - }); + return std::shared_ptr<DNSAction>(new LogAction(fname ? *fname : "", binary ? *binary : true, append ? *append : false, buffered ? *buffered : false, verboseOnly ? *verboseOnly : true, includeTimestamp ? *includeTimestamp : false)); + }); luaCtx.writeFunction("LogResponseAction", [](boost::optional<std::string> fname, boost::optional<bool> append, boost::optional<bool> buffered, boost::optional<bool> verboseOnly, boost::optional<bool> includeTimestamp) { - return std::shared_ptr<DNSResponseAction>(new LogResponseAction(fname ? *fname : "", append ? *append : false, buffered ? *buffered : false, verboseOnly ? *verboseOnly : true, includeTimestamp ? *includeTimestamp : false)); - }); + return std::shared_ptr<DNSResponseAction>(new LogResponseAction(fname ? *fname : "", append ? *append : false, buffered ? *buffered : false, verboseOnly ? *verboseOnly : true, includeTimestamp ? *includeTimestamp : false)); + }); luaCtx.writeFunction("LimitTTLResponseAction", [](uint32_t min, uint32_t max, boost::optional<LuaArray<uint16_t>> types) { - std::unordered_set<QType> capTypes; - if (types) { - capTypes.reserve(types->size()); - for (const auto& [idx, type] : *types) { - capTypes.insert(QType(type)); - } + std::unordered_set<QType> capTypes; + if (types) { + capTypes.reserve(types->size()); + for (const auto& [idx, type] : *types) { + capTypes.insert(QType(type)); } - return std::shared_ptr<DNSResponseAction>(new LimitTTLResponseAction(min, max, capTypes)); - }); + } + return std::shared_ptr<DNSResponseAction>(new LimitTTLResponseAction(min, max, capTypes)); + }); luaCtx.writeFunction("SetMinTTLResponseAction", [](uint32_t min) { - return std::shared_ptr<DNSResponseAction>(new LimitTTLResponseAction(min)); - }); + return std::shared_ptr<DNSResponseAction>(new LimitTTLResponseAction(min)); + }); luaCtx.writeFunction("SetMaxTTLResponseAction", [](uint32_t max) { - return std::shared_ptr<DNSResponseAction>(new LimitTTLResponseAction(0, max)); - }); + return std::shared_ptr<DNSResponseAction>(new LimitTTLResponseAction(0, max)); + }); luaCtx.writeFunction("SetMaxReturnedTTLAction", [](uint32_t max) { return std::shared_ptr<DNSAction>(new MaxReturnedTTLAction(max)); @@ -2425,257 +2721,272 @@ void setupLuaActions(LuaContext& luaCtx) }); luaCtx.writeFunction("SetReducedTTLResponseAction", [](uint8_t percentage) { - if (percentage > 100) { - throw std::runtime_error(std::string("SetReducedTTLResponseAction takes a percentage between 0 and 100.")); - } - return std::shared_ptr<DNSResponseAction>(new SetReducedTTLResponseAction(percentage)); - }); + if (percentage > 100) { + throw std::runtime_error(std::string("SetReducedTTLResponseAction takes a percentage between 0 and 100.")); + } + return std::shared_ptr<DNSResponseAction>(new SetReducedTTLResponseAction(percentage)); + }); luaCtx.writeFunction("ClearRecordTypesResponseAction", [](LuaTypeOrArrayOf<int> types) { - std::unordered_set<QType> qtypes{}; - if (types.type() == typeid(int)) { - qtypes.insert(boost::get<int>(types)); - } else if (types.type() == typeid(LuaArray<int>)) { - const auto& v = boost::get<LuaArray<int>>(types); - for (const auto& tpair: v) { - qtypes.insert(tpair.second); - } + std::unordered_set<QType> qtypes{}; + if (types.type() == typeid(int)) { + qtypes.insert(boost::get<int>(types)); + } + else if (types.type() == typeid(LuaArray<int>)) { + const auto& typesArray = boost::get<LuaArray<int>>(types); + for (const auto& tpair : typesArray) { + qtypes.insert(tpair.second); } - return std::shared_ptr<DNSResponseAction>(new ClearRecordTypesResponseAction(qtypes)); - }); + } + return std::shared_ptr<DNSResponseAction>(new ClearRecordTypesResponseAction(std::move(qtypes))); + }); luaCtx.writeFunction("RCodeAction", [](uint8_t rcode, boost::optional<responseParams_t> vars) { - auto ret = std::shared_ptr<DNSAction>(new RCodeAction(rcode)); - auto rca = std::dynamic_pointer_cast<RCodeAction>(ret); - parseResponseConfig(vars, rca->d_responseConfig); - checkAllParametersConsumed("RCodeAction", vars); - return ret; - }); + auto ret = std::shared_ptr<DNSAction>(new RCodeAction(rcode)); + auto rca = std::dynamic_pointer_cast<RCodeAction>(ret); + parseResponseConfig(vars, rca->getResponseConfig()); + checkAllParametersConsumed("RCodeAction", vars); + return ret; + }); luaCtx.writeFunction("ERCodeAction", [](uint8_t rcode, boost::optional<responseParams_t> vars) { - auto ret = std::shared_ptr<DNSAction>(new ERCodeAction(rcode)); - auto erca = std::dynamic_pointer_cast<ERCodeAction>(ret); - parseResponseConfig(vars, erca->d_responseConfig); - checkAllParametersConsumed("ERCodeAction", vars); - return ret; - }); + auto ret = std::shared_ptr<DNSAction>(new ERCodeAction(rcode)); + auto erca = std::dynamic_pointer_cast<ERCodeAction>(ret); + parseResponseConfig(vars, erca->getResponseConfig()); + checkAllParametersConsumed("ERCodeAction", vars); + return ret; + }); luaCtx.writeFunction("SetSkipCacheAction", []() { - return std::shared_ptr<DNSAction>(new SetSkipCacheAction); - }); + return std::shared_ptr<DNSAction>(new SetSkipCacheAction); + }); luaCtx.writeFunction("SetSkipCacheResponseAction", []() { - return std::shared_ptr<DNSResponseAction>(new SetSkipCacheResponseAction); - }); + return std::shared_ptr<DNSResponseAction>(new SetSkipCacheResponseAction); + }); luaCtx.writeFunction("SetTempFailureCacheTTLAction", [](int maxTTL) { - return std::shared_ptr<DNSAction>(new SetTempFailureCacheTTLAction(maxTTL)); - }); + return std::shared_ptr<DNSAction>(new SetTempFailureCacheTTLAction(maxTTL)); + }); luaCtx.writeFunction("DropResponseAction", []() { - return std::shared_ptr<DNSResponseAction>(new DropResponseAction); - }); + return std::shared_ptr<DNSResponseAction>(new DropResponseAction); + }); luaCtx.writeFunction("AllowResponseAction", []() { - return std::shared_ptr<DNSResponseAction>(new AllowResponseAction); - }); + return std::shared_ptr<DNSResponseAction>(new AllowResponseAction); + }); luaCtx.writeFunction("DelayResponseAction", [](int msec) { - return std::shared_ptr<DNSResponseAction>(new DelayResponseAction(msec)); - }); + return std::shared_ptr<DNSResponseAction>(new DelayResponseAction(msec)); + }); luaCtx.writeFunction("LuaResponseAction", [](LuaResponseAction::func_t func) { - setLuaSideEffect(); - return std::shared_ptr<DNSResponseAction>(new LuaResponseAction(func)); - }); + setLuaSideEffect(); + return std::shared_ptr<DNSResponseAction>(new LuaResponseAction(std::move(func))); + }); luaCtx.writeFunction("LuaFFIResponseAction", [](LuaFFIResponseAction::func_t func) { - setLuaSideEffect(); - return std::shared_ptr<DNSResponseAction>(new LuaFFIResponseAction(func)); - }); + setLuaSideEffect(); + return std::shared_ptr<DNSResponseAction>(new LuaFFIResponseAction(std::move(func))); + }); luaCtx.writeFunction("LuaFFIPerThreadResponseAction", [](const std::string& code) { - setLuaSideEffect(); - return std::shared_ptr<DNSResponseAction>(new LuaFFIPerThreadResponseAction(code)); - }); + setLuaSideEffect(); + return std::shared_ptr<DNSResponseAction>(new LuaFFIPerThreadResponseAction(code)); + }); #ifndef DISABLE_PROTOBUF - luaCtx.writeFunction("RemoteLogAction", [](std::shared_ptr<RemoteLoggerInterface> logger, boost::optional<std::function<void(DNSQuestion*, DNSDistProtoBufMessage*)> > alterFunc, boost::optional<LuaAssociativeTable<std::string>> vars, boost::optional<LuaAssociativeTable<std::string>> metas) { - if (logger) { - // avoids potentially-evaluated-expression warning with clang. - RemoteLoggerInterface& rl = *logger.get(); - if (typeid(rl) != typeid(RemoteLogger)) { - // We could let the user do what he wants, but wrapping PowerDNS Protobuf inside a FrameStream tagged as dnstap is logically wrong. - throw std::runtime_error(std::string("RemoteLogAction only takes RemoteLogger. For other types, please look at DnstapLogAction.")); - } + luaCtx.writeFunction("RemoteLogAction", [](std::shared_ptr<RemoteLoggerInterface> logger, boost::optional<std::function<void(DNSQuestion*, DNSDistProtoBufMessage*)>> alterFunc, boost::optional<LuaAssociativeTable<std::string>> vars, boost::optional<LuaAssociativeTable<std::string>> metas) { + if (logger) { + // avoids potentially-evaluated-expression warning with clang. + RemoteLoggerInterface& remoteLoggerRef = *logger; + if (typeid(remoteLoggerRef) != typeid(RemoteLogger)) { + // We could let the user do what he wants, but wrapping PowerDNS Protobuf inside a FrameStream tagged as dnstap is logically wrong. + throw std::runtime_error(std::string("RemoteLogAction only takes RemoteLogger. For other types, please look at DnstapLogAction.")); } + } - std::string serverID; - std::string ipEncryptKey; - std::string tags; - getOptionalValue<std::string>(vars, "serverID", serverID); - getOptionalValue<std::string>(vars, "ipEncryptKey", ipEncryptKey); - getOptionalValue<std::string>(vars, "exportTags", tags); - - std::vector<std::pair<std::string, ProtoBufMetaKey>> metaOptions; - if (metas) { - for (const auto& [key, value] : *metas) { - metaOptions.push_back({key, ProtoBufMetaKey(value)}); - } + std::string tags; + RemoteLogActionConfiguration config; + config.logger = std::move(logger); + config.alterQueryFunc = std::move(alterFunc); + getOptionalValue<std::string>(vars, "serverID", config.serverID); + getOptionalValue<std::string>(vars, "ipEncryptKey", config.ipEncryptKey); + getOptionalValue<std::string>(vars, "exportTags", tags); + + if (metas) { + for (const auto& [key, value] : *metas) { + config.metas.emplace_back(key, ProtoBufMetaKey(value)); } + } - std::optional<std::unordered_set<std::string>> tagsToExport{std::nullopt}; - if (!tags.empty()) { - tagsToExport = std::unordered_set<std::string>(); - if (tags != "*") { - std::vector<std::string> tokens; - stringtok(tokens, tags, ","); - for (auto& token : tokens) { - tagsToExport->insert(std::move(token)); - } + if (!tags.empty()) { + config.tagsToExport = std::unordered_set<std::string>(); + if (tags != "*") { + std::vector<std::string> tokens; + stringtok(tokens, tags, ","); + for (auto& token : tokens) { + config.tagsToExport->insert(std::move(token)); } } + } - checkAllParametersConsumed("RemoteLogAction", vars); + checkAllParametersConsumed("RemoteLogAction", vars); - return std::shared_ptr<DNSAction>(new RemoteLogAction(logger, alterFunc, serverID, ipEncryptKey, std::move(metaOptions), std::move(tagsToExport))); - }); + return std::shared_ptr<DNSAction>(new RemoteLogAction(config)); + }); - luaCtx.writeFunction("RemoteLogResponseAction", [](std::shared_ptr<RemoteLoggerInterface> logger, boost::optional<std::function<void(DNSResponse*, DNSDistProtoBufMessage*)> > alterFunc, boost::optional<bool> includeCNAME, boost::optional<LuaAssociativeTable<std::string>> vars, boost::optional<LuaAssociativeTable<std::string>> metas) { - if (logger) { - // avoids potentially-evaluated-expression warning with clang. - RemoteLoggerInterface& rl = *logger.get(); - if (typeid(rl) != typeid(RemoteLogger)) { - // We could let the user do what he wants, but wrapping PowerDNS Protobuf inside a FrameStream tagged as dnstap is logically wrong. - throw std::runtime_error("RemoteLogResponseAction only takes RemoteLogger. For other types, please look at DnstapLogResponseAction."); - } + luaCtx.writeFunction("RemoteLogResponseAction", [](std::shared_ptr<RemoteLoggerInterface> logger, boost::optional<std::function<void(DNSResponse*, DNSDistProtoBufMessage*)>> alterFunc, boost::optional<bool> includeCNAME, boost::optional<LuaAssociativeTable<std::string>> vars, boost::optional<LuaAssociativeTable<std::string>> metas) { + if (logger) { + // avoids potentially-evaluated-expression warning with clang. + RemoteLoggerInterface& remoteLoggerRef = *logger; + if (typeid(remoteLoggerRef) != typeid(RemoteLogger)) { + // We could let the user do what he wants, but wrapping PowerDNS Protobuf inside a FrameStream tagged as dnstap is logically wrong. + throw std::runtime_error("RemoteLogResponseAction only takes RemoteLogger. For other types, please look at DnstapLogResponseAction."); } + } - std::string serverID; - std::string ipEncryptKey; - std::string tags; - getOptionalValue<std::string>(vars, "serverID", serverID); - getOptionalValue<std::string>(vars, "ipEncryptKey", ipEncryptKey); - getOptionalValue<std::string>(vars, "exportTags", tags); - - std::vector<std::pair<std::string, ProtoBufMetaKey>> metaOptions; - if (metas) { - for (const auto& [key, value] : *metas) { - metaOptions.push_back({key, ProtoBufMetaKey(value)}); - } + std::string tags; + RemoteLogActionConfiguration config; + config.logger = std::move(logger); + config.alterResponseFunc = std::move(alterFunc); + config.includeCNAME = includeCNAME ? *includeCNAME : false; + getOptionalValue<std::string>(vars, "serverID", config.serverID); + getOptionalValue<std::string>(vars, "ipEncryptKey", config.ipEncryptKey); + getOptionalValue<std::string>(vars, "exportTags", tags); + getOptionalValue<std::string>(vars, "exportExtendedErrorsToMeta", config.exportExtendedErrorsToMeta); + + if (metas) { + for (const auto& [key, value] : *metas) { + config.metas.emplace_back(key, ProtoBufMetaKey(value)); } + } - std::optional<std::unordered_set<std::string>> tagsToExport{std::nullopt}; - if (!tags.empty()) { - tagsToExport = std::unordered_set<std::string>(); - if (tags != "*") { - std::vector<std::string> tokens; - stringtok(tokens, tags, ","); - for (auto& token : tokens) { - tagsToExport->insert(std::move(token)); - } + if (!tags.empty()) { + config.tagsToExport = std::unordered_set<std::string>(); + if (tags != "*") { + std::vector<std::string> tokens; + stringtok(tokens, tags, ","); + for (auto& token : tokens) { + config.tagsToExport->insert(std::move(token)); } } + } - checkAllParametersConsumed("RemoteLogResponseAction", vars); + checkAllParametersConsumed("RemoteLogResponseAction", vars); - return std::shared_ptr<DNSResponseAction>(new RemoteLogResponseAction(logger, alterFunc, serverID, ipEncryptKey, includeCNAME ? *includeCNAME : false, std::move(metaOptions), std::move(tagsToExport))); - }); + return std::shared_ptr<DNSResponseAction>(new RemoteLogResponseAction(config)); + }); - luaCtx.writeFunction("DnstapLogAction", [](const std::string& identity, std::shared_ptr<RemoteLoggerInterface> logger, boost::optional<std::function<void(DNSQuestion*, DnstapMessage*)> > alterFunc) { - return std::shared_ptr<DNSAction>(new DnstapLogAction(identity, logger, alterFunc)); - }); + luaCtx.writeFunction("DnstapLogAction", [](const std::string& identity, std::shared_ptr<RemoteLoggerInterface> logger, boost::optional<std::function<void(DNSQuestion*, DnstapMessage*)>> alterFunc) { + return std::shared_ptr<DNSAction>(new DnstapLogAction(identity, logger, std::move(alterFunc))); + }); - luaCtx.writeFunction("DnstapLogResponseAction", [](const std::string& identity, std::shared_ptr<RemoteLoggerInterface> logger, boost::optional<std::function<void(DNSResponse*, DnstapMessage*)> > alterFunc) { - return std::shared_ptr<DNSResponseAction>(new DnstapLogResponseAction(identity, logger, alterFunc)); - }); + luaCtx.writeFunction("DnstapLogResponseAction", [](const std::string& identity, std::shared_ptr<RemoteLoggerInterface> logger, boost::optional<std::function<void(DNSResponse*, DnstapMessage*)>> alterFunc) { + return std::shared_ptr<DNSResponseAction>(new DnstapLogResponseAction(identity, logger, std::move(alterFunc))); + }); #endif /* DISABLE_PROTOBUF */ - luaCtx.writeFunction("TeeAction", [](const std::string& remote, boost::optional<bool> addECS, boost::optional<std::string> local) { - boost::optional<ComboAddress> localAddr{boost::none}; - if (local) { - localAddr = ComboAddress(*local, 0); - } + luaCtx.writeFunction("TeeAction", [](const std::string& remote, boost::optional<bool> addECS, boost::optional<std::string> local, boost::optional<bool> addProxyProtocol) { + boost::optional<ComboAddress> localAddr{boost::none}; + if (local) { + localAddr = ComboAddress(*local, 0); + } - return std::shared_ptr<DNSAction>(new TeeAction(ComboAddress(remote, 53), localAddr, addECS ? *addECS : false)); - }); + return std::shared_ptr<DNSAction>(new TeeAction(ComboAddress(remote, 53), localAddr, addECS ? *addECS : false, addProxyProtocol ? *addProxyProtocol : false)); + }); luaCtx.writeFunction("SetECSPrefixLengthAction", [](uint16_t v4PrefixLength, uint16_t v6PrefixLength) { - return std::shared_ptr<DNSAction>(new SetECSPrefixLengthAction(v4PrefixLength, v6PrefixLength)); - }); + return std::shared_ptr<DNSAction>(new SetECSPrefixLengthAction(v4PrefixLength, v6PrefixLength)); + }); luaCtx.writeFunction("SetECSOverrideAction", [](bool ecsOverride) { - return std::shared_ptr<DNSAction>(new SetECSOverrideAction(ecsOverride)); - }); + return std::shared_ptr<DNSAction>(new SetECSOverrideAction(ecsOverride)); + }); luaCtx.writeFunction("SetDisableECSAction", []() { - return std::shared_ptr<DNSAction>(new SetDisableECSAction()); - }); + return std::shared_ptr<DNSAction>(new SetDisableECSAction()); + }); - luaCtx.writeFunction("SetECSAction", [](const std::string& v4, boost::optional<std::string> v6) { - if (v6) { - return std::shared_ptr<DNSAction>(new SetECSAction(Netmask(v4), Netmask(*v6))); - } - return std::shared_ptr<DNSAction>(new SetECSAction(Netmask(v4))); - }); + luaCtx.writeFunction("SetECSAction", [](const std::string& v4Netmask, boost::optional<std::string> v6Netmask) { + if (v6Netmask) { + return std::shared_ptr<DNSAction>(new SetECSAction(Netmask(v4Netmask), Netmask(*v6Netmask))); + } + return std::shared_ptr<DNSAction>(new SetECSAction(Netmask(v4Netmask))); + }); #ifdef HAVE_NET_SNMP luaCtx.writeFunction("SNMPTrapAction", [](boost::optional<std::string> reason) { - return std::shared_ptr<DNSAction>(new SNMPTrapAction(reason ? *reason : "")); - }); + return std::shared_ptr<DNSAction>(new SNMPTrapAction(reason ? *reason : "")); + }); luaCtx.writeFunction("SNMPTrapResponseAction", [](boost::optional<std::string> reason) { - return std::shared_ptr<DNSResponseAction>(new SNMPTrapResponseAction(reason ? *reason : "")); - }); + return std::shared_ptr<DNSResponseAction>(new SNMPTrapResponseAction(reason ? *reason : "")); + }); #endif /* HAVE_NET_SNMP */ luaCtx.writeFunction("SetTagAction", [](const std::string& tag, const std::string& value) { - return std::shared_ptr<DNSAction>(new SetTagAction(tag, value)); - }); + return std::shared_ptr<DNSAction>(new SetTagAction(tag, value)); + }); luaCtx.writeFunction("SetTagResponseAction", [](const std::string& tag, const std::string& value) { - return std::shared_ptr<DNSResponseAction>(new SetTagResponseAction(tag, value)); - }); + return std::shared_ptr<DNSResponseAction>(new SetTagResponseAction(tag, value)); + }); luaCtx.writeFunction("ContinueAction", [](std::shared_ptr<DNSAction> action) { - return std::shared_ptr<DNSAction>(new ContinueAction(action)); - }); + return std::shared_ptr<DNSAction>(new ContinueAction(action)); + }); #ifdef HAVE_DNS_OVER_HTTPS luaCtx.writeFunction("HTTPStatusAction", [](uint16_t status, std::string body, boost::optional<std::string> contentType, boost::optional<responseParams_t> vars) { - auto ret = std::shared_ptr<DNSAction>(new HTTPStatusAction(status, PacketBuffer(body.begin(), body.end()), contentType ? *contentType : "")); - auto hsa = std::dynamic_pointer_cast<HTTPStatusAction>(ret); - parseResponseConfig(vars, hsa->d_responseConfig); - checkAllParametersConsumed("HTTPStatusAction", vars); - return ret; - }); + auto ret = std::shared_ptr<DNSAction>(new HTTPStatusAction(status, PacketBuffer(body.begin(), body.end()), contentType ? *contentType : "")); + auto hsa = std::dynamic_pointer_cast<HTTPStatusAction>(ret); + parseResponseConfig(vars, hsa->getResponseConfig()); + checkAllParametersConsumed("HTTPStatusAction", vars); + return ret; + }); #endif /* HAVE_DNS_OVER_HTTPS */ #if defined(HAVE_LMDB) || defined(HAVE_CDB) luaCtx.writeFunction("KeyValueStoreLookupAction", [](std::shared_ptr<KeyValueStore>& kvs, std::shared_ptr<KeyValueLookupKey>& lookupKey, const std::string& destinationTag) { - return std::shared_ptr<DNSAction>(new KeyValueStoreLookupAction(kvs, lookupKey, destinationTag)); - }); + return std::shared_ptr<DNSAction>(new KeyValueStoreLookupAction(kvs, lookupKey, destinationTag)); + }); luaCtx.writeFunction("KeyValueStoreRangeLookupAction", [](std::shared_ptr<KeyValueStore>& kvs, std::shared_ptr<KeyValueLookupKey>& lookupKey, const std::string& destinationTag) { - return std::shared_ptr<DNSAction>(new KeyValueStoreRangeLookupAction(kvs, lookupKey, destinationTag)); - }); + return std::shared_ptr<DNSAction>(new KeyValueStoreRangeLookupAction(kvs, lookupKey, destinationTag)); + }); #endif /* defined(HAVE_LMDB) || defined(HAVE_CDB) */ luaCtx.writeFunction("NegativeAndSOAAction", [](bool nxd, const std::string& zone, uint32_t ttl, const std::string& mname, const std::string& rname, uint32_t serial, uint32_t refresh, uint32_t retry, uint32_t expire, uint32_t minimum, boost::optional<responseParams_t> vars) { - bool soaInAuthoritySection = false; - getOptionalValue<bool>(vars, "soaInAuthoritySection", soaInAuthoritySection); - auto ret = std::shared_ptr<DNSAction>(new NegativeAndSOAAction(nxd, DNSName(zone), ttl, DNSName(mname), DNSName(rname), serial, refresh, retry, expire, minimum, soaInAuthoritySection)); - auto action = std::dynamic_pointer_cast<NegativeAndSOAAction>(ret); - parseResponseConfig(vars, action->d_responseConfig); - checkAllParametersConsumed("NegativeAndSOAAction", vars); - return ret; + bool soaInAuthoritySection = false; + getOptionalValue<bool>(vars, "soaInAuthoritySection", soaInAuthoritySection); + NegativeAndSOAAction::SOAParams params{ + .serial = serial, + .refresh = refresh, + .retry = retry, + .expire = expire, + .minimum = minimum}; + auto ret = std::shared_ptr<DNSAction>(new NegativeAndSOAAction(nxd, DNSName(zone), ttl, DNSName(mname), DNSName(rname), params, soaInAuthoritySection)); + auto action = std::dynamic_pointer_cast<NegativeAndSOAAction>(ret); + parseResponseConfig(vars, action->getResponseConfig()); + checkAllParametersConsumed("NegativeAndSOAAction", vars); + return ret; }); luaCtx.writeFunction("SetProxyProtocolValuesAction", [](const std::vector<std::pair<uint8_t, std::string>>& values) { - return std::shared_ptr<DNSAction>(new SetProxyProtocolValuesAction(values)); - }); + return std::shared_ptr<DNSAction>(new SetProxyProtocolValuesAction(values)); + }); luaCtx.writeFunction("SetAdditionalProxyProtocolValueAction", [](uint8_t type, const std::string& value) { return std::shared_ptr<DNSAction>(new SetAdditionalProxyProtocolValueAction(type, value)); }); + + luaCtx.writeFunction("SetExtendedDNSErrorAction", [](uint16_t infoCode, boost::optional<std::string> extraText) { + return std::shared_ptr<DNSAction>(new SetExtendedDNSErrorAction(infoCode, extraText ? *extraText : "")); + }); + + luaCtx.writeFunction("SetExtendedDNSErrorResponseAction", [](uint16_t infoCode, boost::optional<std::string> extraText) { + return std::shared_ptr<DNSResponseAction>(new SetExtendedDNSErrorResponseAction(infoCode, extraText ? *extraText : "")); + }); } |