diff options
Diffstat (limited to 'dnsname.cc')
-rw-r--r-- | dnsname.cc | 267 |
1 files changed, 181 insertions, 86 deletions
@@ -99,8 +99,7 @@ DNSName::DNSName(const std::string_view sw) } } - -DNSName::DNSName(const char* pos, int len, int offset, bool uncompress, uint16_t* qtype, uint16_t* qclass, unsigned int* consumed, uint16_t minOffset) +DNSName::DNSName(const char* pos, size_t len, size_t offset, bool uncompress, uint16_t* qtype, uint16_t* qclass, unsigned int* consumed, uint16_t minOffset) { if (offset >= len) throw std::range_error("Trying to read past the end of the buffer ("+std::to_string(offset)+ " >= "+std::to_string(len)+")"); @@ -114,64 +113,137 @@ DNSName::DNSName(const char* pos, int len, int offset, bool uncompress, uint16_t packetParser(pos, len, offset, uncompress, qtype, qclass, consumed, 0, minOffset); } -// this should be the __only__ dns name parser in PowerDNS. -void DNSName::packetParser(const char* qpos, int len, int offset, bool uncompress, uint16_t* qtype, uint16_t* qclass, unsigned int* consumed, int depth, uint16_t minOffset) +static void checkLabelLength(uint8_t length) { - const unsigned char* pos=(const unsigned char*)qpos; - unsigned char labellen; - const unsigned char *opos = pos; + if (length == 0) { + throw std::range_error("no such thing as an empty label to append"); + } + if (length > 63) { + throw std::range_error("label too long to append"); + } +} - if (offset >= len) - throw std::range_error("Trying to read past the end of the buffer ("+std::to_string(offset)+ " >= "+std::to_string(len)+")"); - if (offset < (int) minOffset) - throw std::range_error("Trying to read before the beginning of the buffer ("+std::to_string(offset)+ " < "+std::to_string(minOffset)+")"); +// this parses a DNS name until a compression pointer is found +size_t DNSName::parsePacketUncompressed(const pdns::views::UnsignedCharView& view, size_t pos, bool uncompress) +{ + const size_t initialPos = pos; + size_t totalLength = 0; + unsigned char labellen = 0; - const unsigned char* end = pos + len; - pos += offset; - while((labellen=*pos++) && pos < end) { // "scan and copy" - if(labellen >= 0xc0) { - if(!uncompress) - throw std::range_error("Found compressed label, instructed not to follow"); + do { + labellen = view.at(pos); + ++pos; + + if (labellen == 0) { + --pos; + break; + } - labellen &= (~0xc0); - int newpos = (labellen << 8) + *(const unsigned char*)pos; - - if(newpos < offset) { - if(newpos < (int) minOffset) - throw std::range_error("Invalid label position during decompression ("+std::to_string(newpos)+ " < "+std::to_string(minOffset)+")"); - if (++depth > 100) - throw std::range_error("Abort label decompression after 100 redirects"); - packetParser((const char*)opos, len, newpos, true, nullptr, nullptr, nullptr, depth, minOffset); - } else - throw std::range_error("Found a forward reference during label decompression"); - pos++; + if (labellen >= 0xc0) { + if (!uncompress) { + throw std::range_error("Found compressed label, instructed not to follow"); + } + --pos; break; - } else if(labellen & 0xc0) { + } + + if ((labellen & 0xc0) != 0) { throw std::range_error("Found an invalid label length in qname (only one of the first two bits is set)"); } - if (pos + labellen < end) { - appendRawLabel((const char*)pos, labellen); + checkLabelLength(labellen); + // reserve one byte for the label length + if (totalLength + labellen > s_maxDNSNameLength - 1) { + throw std::range_error("name too long to append"); } - else + if (pos + labellen >= view.size()) { throw std::range_error("Found an invalid label length in qname"); - pos+=labellen; - } - if(d_storage.empty()) - d_storage.append(1, (char)0); // we just parsed the root - if(consumed) - *consumed = pos - opos - offset; - if(qtype) { - if (pos + 2 > end) { - throw std::range_error("Trying to read qtype past the end of the buffer ("+std::to_string((pos - opos) + 2)+ " > "+std::to_string(len)+")"); } - *qtype=(*(const unsigned char*)pos)*256 + *((const unsigned char*)pos+1); + pos += labellen; + totalLength += 1 + labellen; } - pos+=2; - if(qclass) { - if (pos + 2 > end) { - throw std::range_error("Trying to read qclass past the end of the buffer ("+std::to_string((pos - opos) + 2)+ " > "+std::to_string(len)+")"); + while (pos < view.size()); + + if (totalLength != 0) { + auto existingSize = d_storage.size(); + if (existingSize > 0) { + // remove the last label count, we are about to override it */ + --existingSize; } - *qclass=(*(const unsigned char*)pos)*256 + *((const unsigned char*)pos+1); + d_storage.reserve(existingSize + totalLength + 1); + d_storage.resize(existingSize + totalLength); + memcpy(&d_storage.at(existingSize), &view.at(initialPos), totalLength); + d_storage.append(1, static_cast<char>(0)); + } + return pos; +} + +// this should be the __only__ dns name parser in PowerDNS. +void DNSName::packetParser(const char* qpos, size_t len, size_t offset, bool uncompress, uint16_t* qtype, uint16_t* qclass, unsigned int* consumed, int depth, uint16_t minOffset) +{ + if (offset >= len) { + throw std::range_error("Trying to read past the end of the buffer ("+std::to_string(offset)+ " >= "+std::to_string(len)+")"); + } + + if (offset < static_cast<size_t>(minOffset)) { + throw std::range_error("Trying to read before the beginning of the buffer ("+std::to_string(offset)+ " < "+std::to_string(minOffset)+")"); + } + unsigned char labellen{0}; + + pdns::views::UnsignedCharView view(qpos, len); + auto pos = parsePacketUncompressed(view, offset, uncompress); + + labellen = view.at(pos); + pos++; + if (labellen != 0 && pos < view.size()) { + if (labellen < 0xc0) { + abort(); + } + + if (!uncompress) { + throw std::range_error("Found compressed label, instructed not to follow"); + } + + labellen &= (~0xc0); + size_t newpos = (labellen << 8) + view.at(pos); + + if (newpos >= offset) { + throw std::range_error("Found a forward reference during label decompression"); + } + + if (newpos < minOffset) { + throw std::range_error("Invalid label position during decompression ("+std::to_string(newpos)+ " < "+std::to_string(minOffset)+")"); + } + + if (++depth > 100) { + throw std::range_error("Abort label decompression after 100 redirects"); + } + + packetParser(qpos, len, newpos, true, nullptr, nullptr, nullptr, depth, minOffset); + + pos++; + } + + if (d_storage.empty()) { + d_storage.append(1, static_cast<char>(0)); // we just parsed the root + } + + if (consumed != nullptr) { + *consumed = pos - offset; + } + + if (qtype != nullptr) { + if (pos + 2 > view.size()) { + throw std::range_error("Trying to read qtype past the end of the buffer ("+std::to_string(pos + 2)+ " > "+std::to_string(len)+")"); + } + *qtype = view.at(pos)*256 + view.at(pos+1); + } + + pos += 2; + if (qclass != nullptr) { + if (pos + 2 > view.size()) { + throw std::range_error("Trying to read qclass past the end of the buffer ("+std::to_string(pos + 2)+ " > "+std::to_string(len)+")"); + } + *qclass = view.at(pos)*256 + view.at(pos+1); } } @@ -225,8 +297,9 @@ std::string DNSName::toLogString() const std::string DNSName::toDNSString() const { - if (empty()) + if (empty()) { throw std::out_of_range("Attempt to DNSString an unset dnsname"); + } return std::string(d_storage.c_str(), d_storage.length()); } @@ -250,11 +323,13 @@ size_t DNSName::wirelength() const { // Are WE part of parent bool DNSName::isPartOf(const DNSName& parent) const { - if(parent.empty() || empty()) + if(parent.empty() || empty()) { throw std::out_of_range("empty dnsnames aren't part of anything"); + } - if(parent.d_storage.size() > d_storage.size()) + if(parent.d_storage.size() > d_storage.size()) { return false; + } // this is slightly complicated since we can't start from the end, since we can't see where a label begins/ends then for(auto us=d_storage.cbegin(); us<d_storage.cend(); us+=*us+1) { @@ -284,19 +359,24 @@ DNSName DNSName::makeRelative(const DNSName& zone) const return ret; } -void DNSName::makeUsRelative(const DNSName& zone) +void DNSName::makeUsRelative(const DNSName& zone) { if (isPartOf(zone)) { d_storage.erase(d_storage.size()-zone.d_storage.size()); - d_storage.append(1, (char)0); // put back the trailing 0 - } - else + d_storage.append(1, static_cast<char>(0)); // put back the trailing 0 + } + else { clear(); + } } DNSName DNSName::getCommonLabels(const DNSName& other) const { - DNSName result; + if (empty() || other.empty()) { + return DNSName(); + } + + DNSName result(g_rootdnsname); const std::vector<std::string> ours = getRawLabels(); const std::vector<std::string> others = other.getRawLabels(); @@ -319,8 +399,9 @@ DNSName DNSName::labelReverse() const { DNSName ret; - if(isRoot()) + if (isRoot()) { return *this; // we don't create the root automatically below + } if (!empty()) { vector<string> l=getRawLabels(); @@ -339,41 +420,48 @@ void DNSName::appendRawLabel(const std::string& label) void DNSName::appendRawLabel(const char* start, unsigned int length) { - if(length==0) - throw std::range_error("no such thing as an empty label to append"); - if(length > 63) - throw std::range_error("label too long to append"); - if(d_storage.size() + length > s_maxDNSNameLength - 1) // reserve one byte for the label length + checkLabelLength(length); + + // reserve one byte for the label length + if (d_storage.size() + length > s_maxDNSNameLength - 1) { throw std::range_error("name too long to append"); + } - if(d_storage.empty()) { - d_storage.append(1, (char)length); + if (d_storage.empty()) { + d_storage.reserve(1 + length + 1); + d_storage.append(1, static_cast<char>(length)); } else { - *d_storage.rbegin()=(char)length; + d_storage.reserve(d_storage.size() + length + 1); + *d_storage.rbegin() = static_cast<char>(length); } d_storage.append(start, length); - d_storage.append(1, (char)0); + d_storage.append(1, static_cast<char>(0)); } void DNSName::prependRawLabel(const std::string& label) { - if(label.empty()) - throw std::range_error("no such thing as an empty label to prepend"); - if(label.size() > 63) - throw std::range_error("label too long to prepend"); - if(d_storage.size() + label.size() > s_maxDNSNameLength - 1) // reserve one byte for the label length + checkLabelLength(label.size()); + + // reserve one byte for the label length + if (d_storage.size() + label.size() > s_maxDNSNameLength - 1) { throw std::range_error("name too long to prepend"); + } - if(d_storage.empty()) - d_storage.append(1, (char)0); + if (d_storage.empty()) { + d_storage.reserve(1 + label.size() + 1); + d_storage.append(1, static_cast<char>(0)); + } + else { + d_storage.reserve(d_storage.size() + 1 + label.size()); + } - string_t prep(1, (char)label.size()); + string_t prep(1, static_cast<char>(label.size())); prep.append(label.c_str(), label.size()); d_storage = prep+d_storage; } -bool DNSName::slowCanonCompare(const DNSName& rhs) const +bool DNSName::slowCanonCompare(const DNSName& rhs) const { auto ours=getRawLabels(), rhsLabels = rhs.getRawLabels(); return std::lexicographical_compare(ours.rbegin(), ours.rend(), rhsLabels.rbegin(), rhsLabels.rend(), CIStringCompare()); @@ -411,16 +499,18 @@ DNSName DNSName::getLastLabel() const bool DNSName::chopOff() { - if(d_storage.empty() || d_storage[0]==0) + if (d_storage.empty() || d_storage[0]==0) { return false; + } d_storage.erase(0, (unsigned int)d_storage[0]+1); return true; } bool DNSName::isWildcard() const { - if(d_storage.size() < 2) + if (d_storage.size() < 2) { return false; + } auto p = d_storage.begin(); return (*p == 0x01 && *++p == '*'); } @@ -450,8 +540,9 @@ unsigned int DNSName::countLabels() const void DNSName::trimToLabels(unsigned int to) { - while(countLabels() > to && chopOff()) + while(countLabels() > to && chopOff()) { ; + } } @@ -466,12 +557,15 @@ void DNSName::appendEscapedLabel(std::string& appendTo, const char* orig, size_t while (pos < len) { auto p = static_cast<uint8_t>(orig[pos]); - if(p=='.') + if (p=='.') { appendTo+="\\."; - else if(p=='\\') + } + else if (p=='\\') { appendTo+="\\\\"; - else if(p > 0x20 && p < 0x7f) - appendTo.append(1, (char)p); + } + else if (p > 0x20 && p < 0x7f) { + appendTo.append(1, static_cast<char>(p)); + } else { char buf[] = "000"; auto got = snprintf(buf, sizeof(buf), "%03" PRIu8, p); @@ -494,11 +588,12 @@ bool DNSName::has8bitBytes() const for (size_t idx = 0; idx < length; idx++) { ++pos; char c = s.at(pos); - if(!((c >= 'a' && c <= 'z') || - (c >= 'A' && c <= 'Z') || - (c >= '0' && c <= '9') || - c =='-' || c == '_' || c=='*' || c=='.' || c=='/' || c=='@' || c==' ' || c=='\\' || c==':')) + if (!((c >= 'a' && c <= 'z') || + (c >= 'A' && c <= 'Z') || + (c >= '0' && c <= '9') || + c =='-' || c == '_' || c=='*' || c=='.' || c=='/' || c=='@' || c==' ' || c=='\\' || c==':')) { return true; + } } ++pos; length = s.at(pos); |