diff options
Diffstat (limited to 'dnsparser.cc')
-rw-r--r-- | dnsparser.cc | 1235 |
1 files changed, 1235 insertions, 0 deletions
diff --git a/dnsparser.cc b/dnsparser.cc new file mode 100644 index 0000000..b7c6a9b --- /dev/null +++ b/dnsparser.cc @@ -0,0 +1,1235 @@ +/* + * This file is part of PowerDNS or dnsdist. + * Copyright -- PowerDNS.COM B.V. and its contributors + * + * This program is free software; you can redistribute it and/or modify + * it under the terms of version 2 of the GNU General Public License as + * published by the Free Software Foundation. + * + * In addition, for the avoidance of any doubt, permission is granted to + * link this program with OpenSSL and to (re)distribute the binaries + * produced as the result of such linking. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program; if not, write to the Free Software + * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. + */ +#include "dnsparser.hh" +#include "dnswriter.hh" +#include <boost/algorithm/string.hpp> +#include <boost/format.hpp> + +#include "namespaces.hh" +#include "noinitvector.hh" + +UnknownRecordContent::UnknownRecordContent(const string& zone) +{ + // parse the input + vector<string> parts; + stringtok(parts, zone); + // we need exactly 3 parts, except if the length field is set to 0 then we only need 2 + if (parts.size() != 3 && !(parts.size() == 2 && boost::equals(parts.at(1), "0"))) { + throw MOADNSException("Unknown record was stored incorrectly, need 3 fields, got " + std::to_string(parts.size()) + ": " + zone); + } + + if (parts.at(0) != "\\#") { + throw MOADNSException("Unknown record was stored incorrectly, first part should be '\\#', got '" + parts.at(0) + "'"); + } + + const string& relevant = (parts.size() > 2) ? parts.at(2) : ""; + auto total = pdns::checked_stoi<unsigned int>(parts.at(1)); + if (relevant.size() % 2 || (relevant.size() / 2) != total) { + throw MOADNSException((boost::format("invalid unknown record length: size not equal to length field (%d != 2 * %d)") % relevant.size() % total).str()); + } + + string out; + out.reserve(total + 1); + + for (unsigned int n = 0; n < total; ++n) { + int c; + if (sscanf(&relevant.at(2*n), "%02x", &c) != 1) { + throw MOADNSException("unable to read data at position " + std::to_string(2 * n) + " from unknown record of size " + std::to_string(relevant.size())); + } + out.append(1, (char)c); + } + + d_record.insert(d_record.end(), out.begin(), out.end()); +} + +string UnknownRecordContent::getZoneRepresentation(bool /* noDot */) const +{ + ostringstream str; + str<<"\\# "<<(unsigned int)d_record.size()<<" "; + char hex[4]; + for (unsigned char n : d_record) { + snprintf(hex, sizeof(hex), "%02x", n); + str << hex; + } + return str.str(); +} + +void UnknownRecordContent::toPacket(DNSPacketWriter& pw) const +{ + pw.xfrBlob(string(d_record.begin(),d_record.end())); +} + +shared_ptr<DNSRecordContent> DNSRecordContent::deserialize(const DNSName& qname, uint16_t qtype, const string& serialized) +{ + dnsheader dnsheader; + memset(&dnsheader, 0, sizeof(dnsheader)); + dnsheader.qdcount=htons(1); + dnsheader.ancount=htons(1); + + PacketBuffer packet; // build pseudo packet + /* will look like: dnsheader, 5 bytes, encoded qname, dns record header, serialized data */ + const auto& encoded = qname.getStorage(); + packet.resize(sizeof(dnsheader) + 5 + encoded.size() + sizeof(struct dnsrecordheader) + serialized.size()); + + uint16_t pos=0; + memcpy(&packet[0], &dnsheader, sizeof(dnsheader)); pos+=sizeof(dnsheader); + + constexpr std::array<uint8_t, 5> tmp= {'\x0', '\x0', '\x1', '\x0', '\x1' }; // root question for ns_t_a + memcpy(&packet[pos], tmp.data(), tmp.size()); pos += tmp.size(); + + memcpy(&packet[pos], encoded.c_str(), encoded.size()); pos+=(uint16_t)encoded.size(); + + struct dnsrecordheader drh; + drh.d_type=htons(qtype); + drh.d_class=htons(QClass::IN); + drh.d_ttl=0; + drh.d_clen=htons(serialized.size()); + + memcpy(&packet[pos], &drh, sizeof(drh)); pos+=sizeof(drh); + if (!serialized.empty()) { + memcpy(&packet[pos], serialized.c_str(), serialized.size()); + pos += (uint16_t) serialized.size(); + (void) pos; + } + + DNSRecord dr; + dr.d_class = QClass::IN; + dr.d_type = qtype; + dr.d_name = qname; + dr.d_clen = serialized.size(); + PacketReader pr(std::string_view(reinterpret_cast<const char*>(packet.data()), packet.size()), packet.size() - serialized.size() - sizeof(dnsrecordheader)); + /* needed to get the record boundaries right */ + pr.getDnsrecordheader(drh); + auto content = DNSRecordContent::mastermake(dr, pr, Opcode::Query); + return content; +} + +std::shared_ptr<DNSRecordContent> DNSRecordContent::mastermake(const DNSRecord &dr, + PacketReader& pr) +{ + uint16_t searchclass = (dr.d_type == QType::OPT) ? 1 : dr.d_class; // class is invalid for OPT + + auto i = getTypemap().find(pair(searchclass, dr.d_type)); + if(i==getTypemap().end() || !i->second) { + return std::make_shared<UnknownRecordContent>(dr, pr); + } + + return i->second(dr, pr); +} + +std::shared_ptr<DNSRecordContent> DNSRecordContent::mastermake(uint16_t qtype, uint16_t qclass, + const string& content) +{ + auto i = getZmakermap().find(pair(qclass, qtype)); + if(i==getZmakermap().end()) { + return std::make_shared<UnknownRecordContent>(content); + } + + return i->second(content); +} + +std::shared_ptr<DNSRecordContent> DNSRecordContent::mastermake(const DNSRecord &dr, PacketReader& pr, uint16_t oc) { + // For opcode UPDATE and where the DNSRecord is an answer record, we don't care about content, because this is + // not used within the prerequisite section of RFC2136, so - we can simply use unknownrecordcontent. + // For section 3.2.3, we do need content so we need to get it properly. But only for the correct QClasses. + if (oc == Opcode::Update && dr.d_place == DNSResourceRecord::ANSWER && dr.d_class != 1) + return std::make_shared<UnknownRecordContent>(dr, pr); + + uint16_t searchclass = (dr.d_type == QType::OPT) ? 1 : dr.d_class; // class is invalid for OPT + + auto i = getTypemap().find(pair(searchclass, dr.d_type)); + if(i==getTypemap().end() || !i->second) { + return std::make_shared<UnknownRecordContent>(dr, pr); + } + + return i->second(dr, pr); +} + +string DNSRecordContent::upgradeContent(const DNSName& qname, const QType& qtype, const string& content) { + // seamless upgrade for previously unsupported but now implemented types. + UnknownRecordContent unknown_content(content); + shared_ptr<DNSRecordContent> rc = DNSRecordContent::deserialize(qname, qtype.getCode(), unknown_content.serialize(qname)); + return rc->getZoneRepresentation(); +} + +DNSRecordContent::typemap_t& DNSRecordContent::getTypemap() +{ + static DNSRecordContent::typemap_t typemap; + return typemap; +} + +DNSRecordContent::n2typemap_t& DNSRecordContent::getN2Typemap() +{ + static DNSRecordContent::n2typemap_t n2typemap; + return n2typemap; +} + +DNSRecordContent::t2namemap_t& DNSRecordContent::getT2Namemap() +{ + static DNSRecordContent::t2namemap_t t2namemap; + return t2namemap; +} + +DNSRecordContent::zmakermap_t& DNSRecordContent::getZmakermap() +{ + static DNSRecordContent::zmakermap_t zmakermap; + return zmakermap; +} + +bool DNSRecordContent::isRegisteredType(uint16_t rtype, uint16_t rclass) +{ + return getTypemap().count(pair(rclass, rtype)) != 0; +} + +DNSRecord::DNSRecord(const DNSResourceRecord& rr): d_name(rr.qname) +{ + d_type = rr.qtype.getCode(); + d_ttl = rr.ttl; + d_class = rr.qclass; + d_place = DNSResourceRecord::ANSWER; + d_clen = 0; + d_content = DNSRecordContent::mastermake(d_type, rr.qclass, rr.content); +} + +// If you call this and you are not parsing a packet coming from a socket, you are doing it wrong. +DNSResourceRecord DNSResourceRecord::fromWire(const DNSRecord& d) { + DNSResourceRecord rr; + rr.qname = d.d_name; + rr.qtype = QType(d.d_type); + rr.ttl = d.d_ttl; + rr.content = d.getContent()->getZoneRepresentation(true); + rr.auth = false; + rr.qclass = d.d_class; + return rr; +} + +void MOADNSParser::init(bool query, const std::string_view& packet) +{ + if (packet.size() < sizeof(dnsheader)) + throw MOADNSException("Packet shorter than minimal header"); + + memcpy(&d_header, packet.data(), sizeof(dnsheader)); + + if(d_header.opcode != Opcode::Query && d_header.opcode != Opcode::Notify && d_header.opcode != Opcode::Update) + throw MOADNSException("Can't parse non-query packet with opcode="+ std::to_string(d_header.opcode)); + + d_header.qdcount=ntohs(d_header.qdcount); + d_header.ancount=ntohs(d_header.ancount); + d_header.nscount=ntohs(d_header.nscount); + d_header.arcount=ntohs(d_header.arcount); + + if (query && (d_header.qdcount > 1)) + throw MOADNSException("Query with QD > 1 ("+std::to_string(d_header.qdcount)+")"); + + unsigned int n=0; + + PacketReader pr(packet); + bool validPacket=false; + try { + d_qtype = d_qclass = 0; // sometimes replies come in with no question, don't present garbage then + + for(n=0;n < d_header.qdcount; ++n) { + d_qname=pr.getName(); + d_qtype=pr.get16BitInt(); + d_qclass=pr.get16BitInt(); + } + + struct dnsrecordheader ah; + vector<unsigned char> record; + bool seenTSIG = false; + validPacket=true; + d_answers.reserve((unsigned int)(d_header.ancount + d_header.nscount + d_header.arcount)); + for(n=0;n < (unsigned int)(d_header.ancount + d_header.nscount + d_header.arcount); ++n) { + DNSRecord dr; + + if(n < d_header.ancount) + dr.d_place=DNSResourceRecord::ANSWER; + else if(n < d_header.ancount + d_header.nscount) + dr.d_place=DNSResourceRecord::AUTHORITY; + else + dr.d_place=DNSResourceRecord::ADDITIONAL; + + unsigned int recordStartPos=pr.getPosition(); + + DNSName name=pr.getName(); + + pr.getDnsrecordheader(ah); + dr.d_ttl=ah.d_ttl; + dr.d_type=ah.d_type; + dr.d_class=ah.d_class; + + dr.d_name=name; + dr.d_clen=ah.d_clen; + + if (query && + !(d_qtype == QType::IXFR && dr.d_place == DNSResourceRecord::AUTHORITY && dr.d_type == QType::SOA) && // IXFR queries have a SOA in their AUTHORITY section + (dr.d_place == DNSResourceRecord::ANSWER || dr.d_place == DNSResourceRecord::AUTHORITY || (dr.d_type != QType::OPT && dr.d_type != QType::TSIG && dr.d_type != QType::SIG && dr.d_type != QType::TKEY) || ((dr.d_type == QType::TSIG || dr.d_type == QType::SIG || dr.d_type == QType::TKEY) && dr.d_class != QClass::ANY))) { +// cerr<<"discarding RR, query is "<<query<<", place is "<<dr.d_place<<", type is "<<dr.d_type<<", class is "<<dr.d_class<<endl; + dr.setContent(std::make_shared<UnknownRecordContent>(dr, pr)); + } + else { +// cerr<<"parsing RR, query is "<<query<<", place is "<<dr.d_place<<", type is "<<dr.d_type<<", class is "<<dr.d_class<<endl; + dr.setContent(DNSRecordContent::mastermake(dr, pr, d_header.opcode)); + } + + /* XXX: XPF records should be allowed after TSIG as soon as the actual XPF option code has been assigned: + if (dr.d_place == DNSResourceRecord::ADDITIONAL && seenTSIG && dr.d_type != QType::XPF) + */ + if (dr.d_place == DNSResourceRecord::ADDITIONAL && seenTSIG) { + /* only XPF records are allowed after a TSIG */ + throw MOADNSException("Packet ("+d_qname.toString()+"|#"+std::to_string(d_qtype)+") has an unexpected record ("+std::to_string(dr.d_type)+") after a TSIG one."); + } + + if(dr.d_type == QType::TSIG && dr.d_class == QClass::ANY) { + if(seenTSIG || dr.d_place != DNSResourceRecord::ADDITIONAL) { + throw MOADNSException("Packet ("+d_qname.toLogString()+"|#"+std::to_string(d_qtype)+") has a TSIG record in an invalid position."); + } + seenTSIG = true; + d_tsigPos = recordStartPos; + } + + d_answers.emplace_back(std::move(dr), pr.getPosition() - sizeof(dnsheader)); + } + +#if 0 + if(pr.getPosition()!=packet.size()) { + throw MOADNSException("Packet ("+d_qname+"|#"+std::to_string(d_qtype)+") has trailing garbage ("+ std::to_string(pr.getPosition()) + " < " + + std::to_string(packet.size()) + ")"); + } +#endif + } + catch(const std::out_of_range &re) { + if(validPacket && d_header.tc) { // don't sweat it over truncated packets, but do adjust an, ns and arcount + if(n < d_header.ancount) { + d_header.ancount=n; d_header.nscount = d_header.arcount = 0; + } + else if(n < d_header.ancount + d_header.nscount) { + d_header.nscount = n - d_header.ancount; d_header.arcount=0; + } + else { + d_header.arcount = n - d_header.ancount - d_header.nscount; + } + } + else { + throw MOADNSException("Error parsing packet of "+std::to_string(packet.size())+" bytes (rd="+ + std::to_string(d_header.rd)+ + "), out of bounds: "+string(re.what())); + } + } +} + +bool MOADNSParser::hasEDNS() const +{ + if (d_header.arcount == 0 || d_answers.empty()) { + return false; + } + + for (const auto& record : d_answers) { + if (record.first.d_place == DNSResourceRecord::ADDITIONAL && record.first.d_type == QType::OPT) { + return true; + } + } + + return false; +} + +void PacketReader::getDnsrecordheader(struct dnsrecordheader &ah) +{ + unsigned char *p = reinterpret_cast<unsigned char*>(&ah); + + for(unsigned int n = 0; n < sizeof(dnsrecordheader); ++n) { + p[n] = d_content.at(d_pos++); + } + + ah.d_type = ntohs(ah.d_type); + ah.d_class = ntohs(ah.d_class); + ah.d_clen = ntohs(ah.d_clen); + ah.d_ttl = ntohl(ah.d_ttl); + + d_startrecordpos = d_pos; // needed for getBlob later on + d_recordlen = ah.d_clen; +} + + +void PacketReader::copyRecord(vector<unsigned char>& dest, uint16_t len) +{ + if (len == 0) { + return; + } + if ((d_pos + len) > d_content.size()) { + throw std::out_of_range("Attempt to copy outside of packet"); + } + + dest.resize(len); + + for (uint16_t n = 0; n < len; ++n) { + dest.at(n) = d_content.at(d_pos++); + } +} + +void PacketReader::copyRecord(unsigned char* dest, uint16_t len) +{ + if (d_pos + len > d_content.size()) { + throw std::out_of_range("Attempt to copy outside of packet"); + } + + memcpy(dest, &d_content.at(d_pos), len); + d_pos += len; +} + +void PacketReader::xfrNodeOrLocatorID(NodeOrLocatorID& ret) +{ + if (d_pos + sizeof(ret) > d_content.size()) { + throw std::out_of_range("Attempt to read 64 bit value outside of packet"); + } + memcpy(&ret.content, &d_content.at(d_pos), sizeof(ret.content)); + d_pos += sizeof(ret); +} + +void PacketReader::xfr48BitInt(uint64_t& ret) +{ + ret=0; + ret+=static_cast<uint8_t>(d_content.at(d_pos++)); + ret<<=8; + ret+=static_cast<uint8_t>(d_content.at(d_pos++)); + ret<<=8; + ret+=static_cast<uint8_t>(d_content.at(d_pos++)); + ret<<=8; + ret+=static_cast<uint8_t>(d_content.at(d_pos++)); + ret<<=8; + ret+=static_cast<uint8_t>(d_content.at(d_pos++)); + ret<<=8; + ret+=static_cast<uint8_t>(d_content.at(d_pos++)); +} + +uint32_t PacketReader::get32BitInt() +{ + uint32_t ret=0; + ret+=static_cast<uint8_t>(d_content.at(d_pos++)); + ret<<=8; + ret+=static_cast<uint8_t>(d_content.at(d_pos++)); + ret<<=8; + ret+=static_cast<uint8_t>(d_content.at(d_pos++)); + ret<<=8; + ret+=static_cast<uint8_t>(d_content.at(d_pos++)); + + return ret; +} + + +uint16_t PacketReader::get16BitInt() +{ + uint16_t ret=0; + ret+=static_cast<uint8_t>(d_content.at(d_pos++)); + ret<<=8; + ret+=static_cast<uint8_t>(d_content.at(d_pos++)); + + return ret; +} + +uint8_t PacketReader::get8BitInt() +{ + return d_content.at(d_pos++); +} + +DNSName PacketReader::getName() +{ + unsigned int consumed; + try { + DNSName dn((const char*) d_content.data(), d_content.size(), d_pos, true /* uncompress */, nullptr /* qtype */, nullptr /* qclass */, &consumed, sizeof(dnsheader)); + + d_pos+=consumed; + return dn; + } + catch(const std::range_error& re) { + throw std::out_of_range(string("dnsname issue: ")+re.what()); + } + catch(...) { + throw std::out_of_range("dnsname issue"); + } + throw PDNSException("PacketReader::getName(): name is empty"); +} + +static string txtEscape(const string &name) +{ + string ret; + char ebuf[5]; + + for(char i : name) { + if((unsigned char) i >= 127 || (unsigned char) i < 32) { + snprintf(ebuf, sizeof(ebuf), "\\%03u", (unsigned char)i); + ret += ebuf; + } + else if(i=='"' || i=='\\'){ + ret += '\\'; + ret += i; + } + else + ret += i; + } + return ret; +} + +// exceptions thrown here do not result in logging in the main pdns auth server - just so you know! +string PacketReader::getText(bool multi, bool lenField) +{ + string ret; + ret.reserve(40); + while(d_pos < d_startrecordpos + d_recordlen ) { + if(!ret.empty()) { + ret.append(1,' '); + } + uint16_t labellen; + if(lenField) + labellen=static_cast<uint8_t>(d_content.at(d_pos++)); + else + labellen=d_recordlen - (d_pos - d_startrecordpos); + + ret.append(1,'"'); + if(labellen) { // no need to do anything for an empty string + string val(&d_content.at(d_pos), &d_content.at(d_pos+labellen-1)+1); + ret.append(txtEscape(val)); // the end is one beyond the packet + } + ret.append(1,'"'); + d_pos+=labellen; + if(!multi) + break; + } + + return ret; +} + +string PacketReader::getUnquotedText(bool lenField) +{ + uint16_t stop_at; + if(lenField) + stop_at = static_cast<uint8_t>(d_content.at(d_pos)) + d_pos + 1; + else + stop_at = d_recordlen; + + /* think unsigned overflow */ + if (stop_at < d_pos) { + throw std::out_of_range("getUnquotedText out of record range"); + } + + if(stop_at == d_pos) + return ""; + + d_pos++; + string ret(&d_content.at(d_pos), &d_content.at(stop_at)); + d_pos = stop_at; + return ret; +} + +void PacketReader::xfrBlob(string& blob) +{ + try { + if(d_recordlen && !(d_pos == (d_startrecordpos + d_recordlen))) { + if (d_pos > (d_startrecordpos + d_recordlen)) { + throw std::out_of_range("xfrBlob out of record range"); + } + blob.assign(&d_content.at(d_pos), &d_content.at(d_startrecordpos + d_recordlen - 1 ) + 1); + } + else { + blob.clear(); + } + + d_pos = d_startrecordpos + d_recordlen; + } + catch(...) + { + throw std::out_of_range("xfrBlob out of range"); + } +} + +void PacketReader::xfrBlobNoSpaces(string& blob, int length) { + xfrBlob(blob, length); +} + +void PacketReader::xfrBlob(string& blob, int length) +{ + if(length) { + if (length < 0) { + throw std::out_of_range("xfrBlob out of range (negative length)"); + } + + blob.assign(&d_content.at(d_pos), &d_content.at(d_pos + length - 1 ) + 1 ); + + d_pos += length; + } + else { + blob.clear(); + } +} + +void PacketReader::xfrSvcParamKeyVals(set<SvcParam> &kvs) { + while (d_pos < (d_startrecordpos + d_recordlen)) { + if (d_pos + 2 > (d_startrecordpos + d_recordlen)) { + throw std::out_of_range("incomplete key"); + } + uint16_t keyInt; + xfr16BitInt(keyInt); + auto key = static_cast<SvcParam::SvcParamKey>(keyInt); + uint16_t len; + xfr16BitInt(len); + + if (d_pos + len > (d_startrecordpos + d_recordlen)) { + throw std::out_of_range("record is shorter than SVCB lengthfield implies"); + } + + switch (key) + { + case SvcParam::mandatory: { + if (len % 2 != 0) { + throw std::out_of_range("mandatory SvcParam has invalid length"); + } + if (len == 0) { + throw std::out_of_range("empty 'mandatory' values"); + } + std::set<SvcParam::SvcParamKey> paramKeys; + size_t stop = d_pos + len; + while (d_pos < stop) { + uint16_t keyval; + xfr16BitInt(keyval); + paramKeys.insert(static_cast<SvcParam::SvcParamKey>(keyval)); + } + kvs.insert(SvcParam(key, std::move(paramKeys))); + break; + } + case SvcParam::alpn: { + size_t stop = d_pos + len; + std::vector<string> alpns; + while (d_pos < stop) { + string alpn; + uint8_t alpnLen = 0; + xfr8BitInt(alpnLen); + if (alpnLen == 0) { + throw std::out_of_range("alpn length of 0"); + } + xfrBlob(alpn, alpnLen); + alpns.push_back(alpn); + } + kvs.insert(SvcParam(key, std::move(alpns))); + break; + } + case SvcParam::no_default_alpn: { + if (len != 0) { + throw std::out_of_range("invalid length for no-default-alpn"); + } + kvs.insert(SvcParam(key)); + break; + } + case SvcParam::port: { + if (len != 2) { + throw std::out_of_range("invalid length for port"); + } + uint16_t port; + xfr16BitInt(port); + kvs.insert(SvcParam(key, port)); + break; + } + case SvcParam::ipv4hint: /* fall-through */ + case SvcParam::ipv6hint: { + size_t addrLen = (key == SvcParam::ipv4hint ? 4 : 16); + if (len % addrLen != 0) { + throw std::out_of_range("invalid length for " + SvcParam::keyToString(key)); + } + vector<ComboAddress> addresses; + auto stop = d_pos + len; + while (d_pos < stop) + { + ComboAddress addr; + xfrCAWithoutPort(key, addr); + addresses.push_back(addr); + } + kvs.insert(SvcParam(key, std::move(addresses))); + break; + } + case SvcParam::ech: { + std::string blob; + blob.reserve(len); + xfrBlobNoSpaces(blob, len); + kvs.insert(SvcParam(key, blob)); + break; + } + default: { + std::string blob; + blob.reserve(len); + xfrBlob(blob, len); + kvs.insert(SvcParam(key, blob)); + break; + } + } + } +} + + +void PacketReader::xfrHexBlob(string& blob, bool /* keepReading */) +{ + xfrBlob(blob); +} + +//FIXME400 remove this method completely +string simpleCompress(const string& elabel, const string& root) +{ + string label=elabel; + // FIXME400: this relies on the semi-canonical escaped output from getName + if(strchr(label.c_str(), '\\')) { + boost::replace_all(label, "\\.", "."); + boost::replace_all(label, "\\032", " "); + boost::replace_all(label, "\\\\", "\\"); + } + typedef vector<pair<unsigned int, unsigned int> > parts_t; + parts_t parts; + vstringtok(parts, label, "."); + string ret; + ret.reserve(label.size()+4); + for(const auto & part : parts) { + if(!root.empty() && !strncasecmp(root.c_str(), label.c_str() + part.first, 1 + label.length() - part.first)) { // also match trailing 0, hence '1 +' + const unsigned char rootptr[2]={0xc0,0x11}; + ret.append((const char *) rootptr, 2); + return ret; + } + ret.append(1, (char)(part.second - part.first)); + ret.append(label.c_str() + part.first, part.second - part.first); + } + ret.append(1, (char)0); + return ret; +} + +// method of operation: silently fail if it doesn't work - we're only trying to be nice, don't fall over on it +void editDNSPacketTTL(char* packet, size_t length, const std::function<uint32_t(uint8_t, uint16_t, uint16_t, uint32_t)>& visitor) +{ + if(length < sizeof(dnsheader)) + return; + try + { + dnsheader dh; + memcpy((void*)&dh, (const dnsheader*)packet, sizeof(dh)); + uint64_t numrecords = ntohs(dh.ancount) + ntohs(dh.nscount) + ntohs(dh.arcount); + DNSPacketMangler dpm(packet, length); + + uint64_t n; + for(n=0; n < ntohs(dh.qdcount) ; ++n) { + dpm.skipDomainName(); + /* type and class */ + dpm.skipBytes(4); + } + + for(n=0; n < numrecords; ++n) { + dpm.skipDomainName(); + + uint8_t section = n < ntohs(dh.ancount) ? 1 : (n < (ntohs(dh.ancount) + ntohs(dh.nscount)) ? 2 : 3); + uint16_t dnstype = dpm.get16BitInt(); + uint16_t dnsclass = dpm.get16BitInt(); + + if(dnstype == QType::OPT) // not getting near that one with a stick + break; + + uint32_t dnsttl = dpm.get32BitInt(); + uint32_t newttl = visitor(section, dnsclass, dnstype, dnsttl); + if (newttl) { + dpm.rewindBytes(sizeof(newttl)); + dpm.setAndSkip32BitInt(newttl); + } + dpm.skipRData(); + } + } + catch(...) + { + return; + } +} + +static bool checkIfPacketContainsRecords(const PacketBuffer& packet, const std::unordered_set<QType>& qtypes) +{ + auto length = packet.size(); + if (length < sizeof(dnsheader)) { + return false; + } + + try { + auto dh = reinterpret_cast<const dnsheader*>(packet.data()); + DNSPacketMangler dpm(const_cast<char*>(reinterpret_cast<const char*>(packet.data())), length); + + const uint16_t qdcount = ntohs(dh->qdcount); + for (size_t n = 0; n < qdcount; ++n) { + dpm.skipDomainName(); + /* type and class */ + dpm.skipBytes(4); + } + const size_t recordsCount = static_cast<size_t>(ntohs(dh->ancount)) + ntohs(dh->nscount) + ntohs(dh->arcount); + for (size_t n = 0; n < recordsCount; ++n) { + dpm.skipDomainName(); + uint16_t dnstype = dpm.get16BitInt(); + uint16_t dnsclass = dpm.get16BitInt(); + if (dnsclass == QClass::IN && qtypes.count(dnstype) > 0) { + return true; + } + /* ttl */ + dpm.skipBytes(4); + dpm.skipRData(); + } + } + catch (...) { + } + + return false; +} + +static int rewritePacketWithoutRecordTypes(const PacketBuffer& initialPacket, PacketBuffer& newContent, const std::unordered_set<QType>& qtypes) +{ + static const std::unordered_set<QType>& safeTypes{QType::A, QType::AAAA, QType::DHCID, QType::TXT, QType::OPT, QType::HINFO, QType::DNSKEY, QType::CDNSKEY, QType::DS, QType::CDS, QType::DLV, QType::SSHFP, QType::KEY, QType::CERT, QType::TLSA, QType::SMIMEA, QType::OPENPGPKEY, QType::SVCB, QType::HTTPS, QType::NSEC3, QType::CSYNC, QType::NSEC3PARAM, QType::LOC, QType::NID, QType::L32, QType::L64, QType::EUI48, QType::EUI64, QType::URI, QType::CAA}; + + if (initialPacket.size() < sizeof(dnsheader)) { + return EINVAL; + } + try { + const struct dnsheader* dh = reinterpret_cast<const struct dnsheader*>(initialPacket.data()); + + if (ntohs(dh->qdcount) == 0) + return ENOENT; + auto packetView = std::string_view(reinterpret_cast<const char*>(initialPacket.data()), initialPacket.size()); + + PacketReader pr(packetView); + + size_t idx = 0; + DNSName rrname; + uint16_t qdcount = ntohs(dh->qdcount); + uint16_t ancount = ntohs(dh->ancount); + uint16_t nscount = ntohs(dh->nscount); + uint16_t arcount = ntohs(dh->arcount); + uint16_t rrtype; + uint16_t rrclass; + string blob; + struct dnsrecordheader ah; + + rrname = pr.getName(); + rrtype = pr.get16BitInt(); + rrclass = pr.get16BitInt(); + + GenericDNSPacketWriter<PacketBuffer> pw(newContent, rrname, rrtype, rrclass, dh->opcode); + pw.getHeader()->id=dh->id; + pw.getHeader()->qr=dh->qr; + pw.getHeader()->aa=dh->aa; + pw.getHeader()->tc=dh->tc; + pw.getHeader()->rd=dh->rd; + pw.getHeader()->ra=dh->ra; + pw.getHeader()->ad=dh->ad; + pw.getHeader()->cd=dh->cd; + pw.getHeader()->rcode=dh->rcode; + + /* consume remaining qd if any */ + if (qdcount > 1) { + for(idx = 1; idx < qdcount; idx++) { + rrname = pr.getName(); + rrtype = pr.get16BitInt(); + rrclass = pr.get16BitInt(); + (void) rrtype; + (void) rrclass; + } + } + + /* copy AN */ + for (idx = 0; idx < ancount; idx++) { + rrname = pr.getName(); + pr.getDnsrecordheader(ah); + pr.xfrBlob(blob); + + if (qtypes.find(ah.d_type) == qtypes.end()) { + // if this is not a safe type + if (safeTypes.find(ah.d_type) == safeTypes.end()) { + // "unsafe" types might countain compressed data, so cancel rewrite + newContent.clear(); + return EIO; + } + pw.startRecord(rrname, ah.d_type, ah.d_ttl, ah.d_class, DNSResourceRecord::ANSWER, true); + pw.xfrBlob(blob); + } + } + + /* copy NS */ + for (idx = 0; idx < nscount; idx++) { + rrname = pr.getName(); + pr.getDnsrecordheader(ah); + pr.xfrBlob(blob); + + if (qtypes.find(ah.d_type) == qtypes.end()) { + if (safeTypes.find(ah.d_type) == safeTypes.end()) { + // "unsafe" types might countain compressed data, so cancel rewrite + newContent.clear(); + return EIO; + } + pw.startRecord(rrname, ah.d_type, ah.d_ttl, ah.d_class, DNSResourceRecord::AUTHORITY, true); + pw.xfrBlob(blob); + } + } + /* copy AR */ + for (idx = 0; idx < arcount; idx++) { + rrname = pr.getName(); + pr.getDnsrecordheader(ah); + pr.xfrBlob(blob); + + if (qtypes.find(ah.d_type) == qtypes.end()) { + if (safeTypes.find(ah.d_type) == safeTypes.end()) { + // "unsafe" types might countain compressed data, so cancel rewrite + newContent.clear(); + return EIO; + } + pw.startRecord(rrname, ah.d_type, ah.d_ttl, ah.d_class, DNSResourceRecord::ADDITIONAL, true); + pw.xfrBlob(blob); + } + } + pw.commit(); + + } + catch (...) + { + newContent.clear(); + return EIO; + } + return 0; +} + +void clearDNSPacketRecordTypes(vector<uint8_t>& packet, const std::unordered_set<QType>& qtypes) +{ + return clearDNSPacketRecordTypes(reinterpret_cast<PacketBuffer&>(packet), qtypes); +} + +void clearDNSPacketRecordTypes(PacketBuffer& packet, const std::unordered_set<QType>& qtypes) +{ + if (!checkIfPacketContainsRecords(packet, qtypes)) { + return; + } + + PacketBuffer newContent; + + auto result = rewritePacketWithoutRecordTypes(packet, newContent, qtypes); + if (!result) { + packet = std::move(newContent); + } +} + +// method of operation: silently fail if it doesn't work - we're only trying to be nice, don't fall over on it +void ageDNSPacket(char* packet, size_t length, uint32_t seconds, const dnsheader_aligned& aligned_dh) +{ + if (length < sizeof(dnsheader)) { + return; + } + try { + const dnsheader* dhp = aligned_dh.get(); + const uint64_t dqcount = ntohs(dhp->qdcount); + const uint64_t numrecords = ntohs(dhp->ancount) + ntohs(dhp->nscount) + ntohs(dhp->arcount); + DNSPacketMangler dpm(packet, length); + + for (uint64_t rec = 0; rec < dqcount; ++rec) { + dpm.skipDomainName(); + /* type and class */ + dpm.skipBytes(4); + } + + for(uint64_t rec = 0; rec < numrecords; ++rec) { + dpm.skipDomainName(); + + uint16_t dnstype = dpm.get16BitInt(); + /* class */ + dpm.skipBytes(2); + + if (dnstype != QType::OPT) { // not aging that one with a stick + dpm.decreaseAndSkip32BitInt(seconds); + } else { + dpm.skipBytes(4); + } + dpm.skipRData(); + } + } + catch(...) { + } +} + +void ageDNSPacket(std::string& packet, uint32_t seconds, const dnsheader_aligned& aligned_dh) +{ + ageDNSPacket(packet.data(), packet.length(), seconds, aligned_dh); +} + +uint32_t getDNSPacketMinTTL(const char* packet, size_t length, bool* seenAuthSOA) +{ + uint32_t result = std::numeric_limits<uint32_t>::max(); + if(length < sizeof(dnsheader)) { + return result; + } + try + { + const dnsheader* dh = (const dnsheader*) packet; + DNSPacketMangler dpm(const_cast<char*>(packet), length); + + const uint16_t qdcount = ntohs(dh->qdcount); + for(size_t n = 0; n < qdcount; ++n) { + dpm.skipDomainName(); + /* type and class */ + dpm.skipBytes(4); + } + const size_t numrecords = ntohs(dh->ancount) + ntohs(dh->nscount) + ntohs(dh->arcount); + for(size_t n = 0; n < numrecords; ++n) { + dpm.skipDomainName(); + const uint16_t dnstype = dpm.get16BitInt(); + /* class */ + const uint16_t dnsclass = dpm.get16BitInt(); + + if(dnstype == QType::OPT) { + break; + } + + /* report it if we see a SOA record in the AUTHORITY section */ + if(dnstype == QType::SOA && dnsclass == QClass::IN && seenAuthSOA != nullptr && n >= ntohs(dh->ancount) && n < (ntohs(dh->ancount) + ntohs(dh->nscount))) { + *seenAuthSOA = true; + } + + const uint32_t ttl = dpm.get32BitInt(); + if (result > ttl) { + result = ttl; + } + + dpm.skipRData(); + } + } + catch(...) + { + } + return result; +} + +uint32_t getDNSPacketLength(const char* packet, size_t length) +{ + uint32_t result = length; + if(length < sizeof(dnsheader)) { + return result; + } + try + { + const dnsheader* dh = reinterpret_cast<const dnsheader*>(packet); + DNSPacketMangler dpm(const_cast<char*>(packet), length); + + const uint16_t qdcount = ntohs(dh->qdcount); + for(size_t n = 0; n < qdcount; ++n) { + dpm.skipDomainName(); + /* type and class */ + dpm.skipBytes(4); + } + const size_t numrecords = ntohs(dh->ancount) + ntohs(dh->nscount) + ntohs(dh->arcount); + for(size_t n = 0; n < numrecords; ++n) { + dpm.skipDomainName(); + /* type (2), class (2) and ttl (4) */ + dpm.skipBytes(8); + dpm.skipRData(); + } + result = dpm.getOffset(); + } + catch(...) + { + } + return result; +} + +uint16_t getRecordsOfTypeCount(const char* packet, size_t length, uint8_t section, uint16_t type) +{ + uint16_t result = 0; + if(length < sizeof(dnsheader)) { + return result; + } + try + { + const dnsheader* dh = (const dnsheader*) packet; + DNSPacketMangler dpm(const_cast<char*>(packet), length); + + const uint16_t qdcount = ntohs(dh->qdcount); + for(size_t n = 0; n < qdcount; ++n) { + dpm.skipDomainName(); + if (section == 0) { + uint16_t dnstype = dpm.get16BitInt(); + if (dnstype == type) { + result++; + } + /* class */ + dpm.skipBytes(2); + } else { + /* type and class */ + dpm.skipBytes(4); + } + } + const uint16_t ancount = ntohs(dh->ancount); + for(size_t n = 0; n < ancount; ++n) { + dpm.skipDomainName(); + if (section == 1) { + uint16_t dnstype = dpm.get16BitInt(); + if (dnstype == type) { + result++; + } + /* class */ + dpm.skipBytes(2); + } else { + /* type and class */ + dpm.skipBytes(4); + } + /* ttl */ + dpm.skipBytes(4); + dpm.skipRData(); + } + const uint16_t nscount = ntohs(dh->nscount); + for(size_t n = 0; n < nscount; ++n) { + dpm.skipDomainName(); + if (section == 2) { + uint16_t dnstype = dpm.get16BitInt(); + if (dnstype == type) { + result++; + } + /* class */ + dpm.skipBytes(2); + } else { + /* type and class */ + dpm.skipBytes(4); + } + /* ttl */ + dpm.skipBytes(4); + dpm.skipRData(); + } + const uint16_t arcount = ntohs(dh->arcount); + for(size_t n = 0; n < arcount; ++n) { + dpm.skipDomainName(); + if (section == 3) { + uint16_t dnstype = dpm.get16BitInt(); + if (dnstype == type) { + result++; + } + /* class */ + dpm.skipBytes(2); + } else { + /* type and class */ + dpm.skipBytes(4); + } + /* ttl */ + dpm.skipBytes(4); + dpm.skipRData(); + } + } + catch(...) + { + } + return result; +} + +bool getEDNSUDPPayloadSizeAndZ(const char* packet, size_t length, uint16_t* payloadSize, uint16_t* z) +{ + if (length < sizeof(dnsheader)) { + return false; + } + + *payloadSize = 0; + *z = 0; + + try + { + const dnsheader* dh = (const dnsheader*) packet; + DNSPacketMangler dpm(const_cast<char*>(packet), length); + + const uint16_t qdcount = ntohs(dh->qdcount); + for(size_t n = 0; n < qdcount; ++n) { + dpm.skipDomainName(); + /* type and class */ + dpm.skipBytes(4); + } + const size_t numrecords = ntohs(dh->ancount) + ntohs(dh->nscount) + ntohs(dh->arcount); + for(size_t n = 0; n < numrecords; ++n) { + dpm.skipDomainName(); + const uint16_t dnstype = dpm.get16BitInt(); + const uint16_t dnsclass = dpm.get16BitInt(); + + if(dnstype == QType::OPT) { + /* skip extended rcode and version */ + dpm.skipBytes(2); + *z = dpm.get16BitInt(); + *payloadSize = dnsclass; + return true; + } + + /* TTL */ + dpm.skipBytes(4); + dpm.skipRData(); + } + } + catch(...) + { + } + + return false; +} + +bool visitDNSPacket(const std::string_view& packet, const std::function<bool(uint8_t, uint16_t, uint16_t, uint32_t, uint16_t, const char*)>& visitor) +{ + if (packet.size() < sizeof(dnsheader)) { + return false; + } + + try + { + dnsheader dh; + memcpy(&dh, reinterpret_cast<const dnsheader*>(packet.data()), sizeof(dh)); + uint64_t numrecords = ntohs(dh.ancount) + ntohs(dh.nscount) + ntohs(dh.arcount); + PacketReader reader(packet); + + uint64_t n; + for (n = 0; n < ntohs(dh.qdcount) ; ++n) { + (void) reader.getName(); + /* type and class */ + reader.skip(4); + } + + for (n = 0; n < numrecords; ++n) { + (void) reader.getName(); + + uint8_t section = n < ntohs(dh.ancount) ? 1 : (n < (ntohs(dh.ancount) + ntohs(dh.nscount)) ? 2 : 3); + uint16_t dnstype = reader.get16BitInt(); + uint16_t dnsclass = reader.get16BitInt(); + + if (dnstype == QType::OPT) { + // not getting near that one with a stick + break; + } + + uint32_t dnsttl = reader.get32BitInt(); + uint16_t contentLength = reader.get16BitInt(); + uint16_t pos = reader.getPosition(); + + bool done = visitor(section, dnsclass, dnstype, dnsttl, contentLength, &packet.at(pos)); + if (done) { + return true; + } + + reader.skip(contentLength); + } + } + catch (...) { + return false; + } + + return true; +} |