summaryrefslogtreecommitdiffstats
path: root/dnsdist-ecs.cc
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--dnsdist-ecs.cc114
1 files changed, 69 insertions, 45 deletions
diff --git a/dnsdist-ecs.cc b/dnsdist-ecs.cc
index 9e9d9c3..2cad194 100644
--- a/dnsdist-ecs.cc
+++ b/dnsdist-ecs.cc
@@ -21,6 +21,7 @@
*/
#include "dolog.hh"
#include "dnsdist.hh"
+#include "dnsdist-dnsparser.hh"
#include "dnsdist-ecs.hh"
#include "dnsparser.hh"
#include "dnswriter.hh"
@@ -44,13 +45,15 @@ bool g_addEDNSToSelfGeneratedResponses{true};
int rewriteResponseWithoutEDNS(const PacketBuffer& initialPacket, PacketBuffer& newContent)
{
assert(initialPacket.size() >= sizeof(dnsheader));
- const struct dnsheader* dh = reinterpret_cast<const struct dnsheader*>(initialPacket.data());
+ const dnsheader_aligned dh(initialPacket.data());
- if (ntohs(dh->arcount) == 0)
+ if (ntohs(dh->arcount) == 0) {
return ENOENT;
+ }
- if (ntohs(dh->qdcount) == 0)
+ if (ntohs(dh->qdcount) == 0) {
return ENOENT;
+ }
PacketReader pr(std::string_view(reinterpret_cast<const char*>(initialPacket.data()), initialPacket.size()));
@@ -152,7 +155,7 @@ static bool addOrReplaceEDNSOption(std::vector<std::pair<uint16_t, std::string>>
bool slowRewriteEDNSOptionInQueryWithRecords(const PacketBuffer& initialPacket, PacketBuffer& newContent, bool& ednsAdded, uint16_t optionToReplace, bool& optionAdded, bool overrideExisting, const string& newOptionContent)
{
assert(initialPacket.size() >= sizeof(dnsheader));
- const struct dnsheader* dh = reinterpret_cast<const struct dnsheader*>(initialPacket.data());
+ const dnsheader_aligned dh(initialPacket.data());
if (ntohs(dh->qdcount) == 0) {
return false;
@@ -269,7 +272,7 @@ static bool slowParseEDNSOptions(const PacketBuffer& packet, EDNSOptionViewMap&
return false;
}
- const struct dnsheader* dh = reinterpret_cast<const struct dnsheader*>(packet.data());
+ const dnsheader_aligned dh(packet.data());
if (ntohs(dh->qdcount) == 0) {
return false;
@@ -324,10 +327,11 @@ int locateEDNSOptRR(const PacketBuffer& packet, uint16_t * optStart, size_t * op
assert(optStart != NULL);
assert(optLen != NULL);
assert(last != NULL);
- const struct dnsheader* dh = reinterpret_cast<const struct dnsheader*>(packet.data());
+ const dnsheader_aligned dh(packet.data());
- if (ntohs(dh->arcount) == 0)
+ if (ntohs(dh->arcount) == 0) {
return ENOENT;
+ }
PacketReader pr(std::string_view(reinterpret_cast<const char*>(packet.data()), packet.size()));
@@ -390,14 +394,15 @@ int getEDNSOptionsStart(const PacketBuffer& packet, const size_t offset, uint16_
{
assert(optRDPosition != nullptr);
assert(remaining != nullptr);
- const struct dnsheader* dh = reinterpret_cast<const struct dnsheader*>(packet.data());
+ const dnsheader_aligned dh(packet.data());
if (offset >= packet.size()) {
return ENOENT;
}
- if (ntohs(dh->qdcount) != 1 || ntohs(dh->ancount) != 0 || ntohs(dh->arcount) != 1 || ntohs(dh->nscount) != 0)
+ if (ntohs(dh->qdcount) != 1 || ntohs(dh->ancount) != 0 || ntohs(dh->arcount) != 1 || ntohs(dh->nscount) != 0) {
return ENOENT;
+ }
size_t pos = sizeof(dnsheader) + offset;
pos += DNS_TYPE_SIZE + DNS_CLASS_SIZE;
@@ -571,10 +576,12 @@ static bool addEDNSWithECS(PacketBuffer& packet, size_t maximumSize, const strin
return false;
}
- struct dnsheader* dh = reinterpret_cast<struct dnsheader*>(packet.data());
- uint16_t arcount = ntohs(dh->arcount);
- arcount++;
- dh->arcount = htons(arcount);
+ dnsdist::PacketMangling::editDNSHeaderFromPacket(packet, [](dnsheader& header) {
+ uint16_t arcount = ntohs(header.arcount);
+ arcount++;
+ header.arcount = htons(arcount);
+ return true;
+ });
ednsAdded = true;
ecsAdded = true;
@@ -585,7 +592,7 @@ bool handleEDNSClientSubnet(PacketBuffer& packet, const size_t maximumSize, cons
{
assert(qnameWireLength <= packet.size());
- const struct dnsheader* dh = reinterpret_cast<const struct dnsheader*>(packet.data());
+ const dnsheader_aligned dh(packet.data());
if (ntohs(dh->ancount) != 0 || ntohs(dh->nscount) != 0 || (ntohs(dh->arcount) != 0 && ntohs(dh->arcount) != 1)) {
PacketBuffer newContent;
@@ -752,7 +759,7 @@ bool isEDNSOptionInOpt(const PacketBuffer& packet, const size_t optStart, const
int rewriteResponseWithoutEDNSOption(const PacketBuffer& initialPacket, const uint16_t optionCodeToSkip, PacketBuffer& newContent)
{
assert(initialPacket.size() >= sizeof(dnsheader));
- const struct dnsheader* dh = reinterpret_cast<const struct dnsheader*>(initialPacket.data());
+ const dnsheader_aligned dh(initialPacket.data());
if (ntohs(dh->arcount) == 0)
return ENOENT;
@@ -852,8 +859,10 @@ bool addEDNS(PacketBuffer& packet, size_t maximumSize, bool dnssecOK, uint16_t p
return false;
}
- auto dh = reinterpret_cast<dnsheader*>(packet.data());
- dh->arcount = htons(ntohs(dh->arcount) + 1);
+ dnsdist::PacketMangling::editDNSHeaderFromPacket(packet, [](dnsheader& header) {
+ header.arcount = htons(ntohs(header.arcount) + 1);
+ return true;
+ });
return true;
}
@@ -894,17 +903,19 @@ bool setNegativeAndAdditionalSOA(DNSQuestion& dq, bool nxd, const DNSName& zone,
/* chop off everything after the question */
packet.resize(queryPartSize);
- dh = dq.getHeader();
- if (nxd) {
- dh->rcode = RCode::NXDomain;
- }
- else {
- dh->rcode = RCode::NoError;
- }
- dh->qr = true;
- dh->ancount = 0;
- dh->nscount = 0;
- dh->arcount = 0;
+ dnsdist::PacketMangling::editDNSHeaderFromPacket(packet, [nxd](dnsheader& header) {
+ if (nxd) {
+ header.rcode = RCode::NXDomain;
+ }
+ else {
+ header.rcode = RCode::NoError;
+ }
+ header.qr = true;
+ header.ancount = 0;
+ header.nscount = 0;
+ header.arcount = 0;
+ return true;
+ });
rdLength = htons(rdLength);
ttl = htonl(ttl);
@@ -934,16 +945,18 @@ bool setNegativeAndAdditionalSOA(DNSQuestion& dq, bool nxd, const DNSName& zone,
}
packet.insert(packet.end(), soa.begin(), soa.end());
- dh = dq.getHeader();
/* We are populating a response with only the query in place, order of sections is QD,AN,NS,AR
NS (authority) is before AR (additional) so we can just decide which section the SOA record is in here
and have EDNS added to AR afterwards */
- if (soaInAuthoritySection) {
- dh->nscount = htons(1);
- } else {
- dh->arcount = htons(1);
- }
+ dnsdist::PacketMangling::editDNSHeaderFromPacket(packet, [soaInAuthoritySection](dnsheader& header) {
+ if (soaInAuthoritySection) {
+ header.nscount = htons(1);
+ } else {
+ header.arcount = htons(1);
+ }
+ return true;
+ });
if (hadEDNS) {
/* now we need to add a new OPT record */
@@ -982,7 +995,10 @@ bool addEDNSToQueryTurnedResponse(DNSQuestion& dq)
/* remove the existing OPT record, and everything else that follows (any SIG or TSIG would be useless anyway) */
packet.resize(packet.size() - existingOptLen);
- dq.getHeader()->arcount = 0;
+ dnsdist::PacketMangling::editDNSHeaderFromPacket(packet, [](dnsheader& header) {
+ header.arcount = 0;
+ return true;
+ });
if (g_addEDNSToSelfGeneratedResponses) {
/* now we need to add a new OPT record */
@@ -1107,7 +1123,10 @@ bool setEDNSOption(DNSQuestion& dq, uint16_t ednsCode, const std::string& ednsDa
auto& data = dq.getMutableData();
if (generateOptRR(optRData, data, dq.getMaximumSize(), g_EdnsUDPPayloadSize, 0, false)) {
- dq.getHeader()->arcount = htons(1);
+ dnsdist::PacketMangling::editDNSHeaderFromPacket(dq.getMutableData(), [](dnsheader& header) {
+ header.arcount = htons(1);
+ return true;
+ });
// make sure that any EDNS sent by the backend is removed before forwarding the response to the client
dq.ids.ednsAdded = true;
}
@@ -1129,17 +1148,22 @@ bool setInternalQueryRCode(InternalQueryState& state, PacketBuffer& buffer, uin
hadEDNS = getEDNS0Record(buffer, edns0);
}
- auto dh = reinterpret_cast<dnsheader*>(buffer.data());
- dh->rcode = rcode;
- dh->ad = false;
- dh->aa = false;
- dh->ra = dh->rd;
- dh->qr = true;
+ dnsdist::PacketMangling::editDNSHeaderFromPacket(buffer, [rcode,clearAnswers](dnsheader& header) {
+ header.rcode = rcode;
+ header.ad = false;
+ header.aa = false;
+ header.ra = header.rd;
+ header.qr = true;
+
+ if (clearAnswers) {
+ header.ancount = 0;
+ header.nscount = 0;
+ header.arcount = 0;
+ }
+ return true;
+ });
if (clearAnswers) {
- dh->ancount = 0;
- dh->nscount = 0;
- dh->arcount = 0;
buffer.resize(sizeof(dnsheader) + qnameLength + sizeof(uint16_t) + sizeof(uint16_t));
if (hadEDNS) {
DNSQuestion dq(state, buffer);