summaryrefslogtreecommitdiffstats
path: root/dnsname.cc
diff options
context:
space:
mode:
Diffstat (limited to 'dnsname.cc')
-rw-r--r--dnsname.cc267
1 files changed, 181 insertions, 86 deletions
diff --git a/dnsname.cc b/dnsname.cc
index 3bfbf30..bbac4ff 100644
--- a/dnsname.cc
+++ b/dnsname.cc
@@ -99,8 +99,7 @@ DNSName::DNSName(const std::string_view sw)
}
}
-
-DNSName::DNSName(const char* pos, int len, int offset, bool uncompress, uint16_t* qtype, uint16_t* qclass, unsigned int* consumed, uint16_t minOffset)
+DNSName::DNSName(const char* pos, size_t len, size_t offset, bool uncompress, uint16_t* qtype, uint16_t* qclass, unsigned int* consumed, uint16_t minOffset)
{
if (offset >= len)
throw std::range_error("Trying to read past the end of the buffer ("+std::to_string(offset)+ " >= "+std::to_string(len)+")");
@@ -114,64 +113,137 @@ DNSName::DNSName(const char* pos, int len, int offset, bool uncompress, uint16_t
packetParser(pos, len, offset, uncompress, qtype, qclass, consumed, 0, minOffset);
}
-// this should be the __only__ dns name parser in PowerDNS.
-void DNSName::packetParser(const char* qpos, int len, int offset, bool uncompress, uint16_t* qtype, uint16_t* qclass, unsigned int* consumed, int depth, uint16_t minOffset)
+static void checkLabelLength(uint8_t length)
{
- const unsigned char* pos=(const unsigned char*)qpos;
- unsigned char labellen;
- const unsigned char *opos = pos;
+ if (length == 0) {
+ throw std::range_error("no such thing as an empty label to append");
+ }
+ if (length > 63) {
+ throw std::range_error("label too long to append");
+ }
+}
- if (offset >= len)
- throw std::range_error("Trying to read past the end of the buffer ("+std::to_string(offset)+ " >= "+std::to_string(len)+")");
- if (offset < (int) minOffset)
- throw std::range_error("Trying to read before the beginning of the buffer ("+std::to_string(offset)+ " < "+std::to_string(minOffset)+")");
+// this parses a DNS name until a compression pointer is found
+size_t DNSName::parsePacketUncompressed(const pdns::views::UnsignedCharView& view, size_t pos, bool uncompress)
+{
+ const size_t initialPos = pos;
+ size_t totalLength = 0;
+ unsigned char labellen = 0;
- const unsigned char* end = pos + len;
- pos += offset;
- while((labellen=*pos++) && pos < end) { // "scan and copy"
- if(labellen >= 0xc0) {
- if(!uncompress)
- throw std::range_error("Found compressed label, instructed not to follow");
+ do {
+ labellen = view.at(pos);
+ ++pos;
+
+ if (labellen == 0) {
+ --pos;
+ break;
+ }
- labellen &= (~0xc0);
- int newpos = (labellen << 8) + *(const unsigned char*)pos;
-
- if(newpos < offset) {
- if(newpos < (int) minOffset)
- throw std::range_error("Invalid label position during decompression ("+std::to_string(newpos)+ " < "+std::to_string(minOffset)+")");
- if (++depth > 100)
- throw std::range_error("Abort label decompression after 100 redirects");
- packetParser((const char*)opos, len, newpos, true, nullptr, nullptr, nullptr, depth, minOffset);
- } else
- throw std::range_error("Found a forward reference during label decompression");
- pos++;
+ if (labellen >= 0xc0) {
+ if (!uncompress) {
+ throw std::range_error("Found compressed label, instructed not to follow");
+ }
+ --pos;
break;
- } else if(labellen & 0xc0) {
+ }
+
+ if ((labellen & 0xc0) != 0) {
throw std::range_error("Found an invalid label length in qname (only one of the first two bits is set)");
}
- if (pos + labellen < end) {
- appendRawLabel((const char*)pos, labellen);
+ checkLabelLength(labellen);
+ // reserve one byte for the label length
+ if (totalLength + labellen > s_maxDNSNameLength - 1) {
+ throw std::range_error("name too long to append");
}
- else
+ if (pos + labellen >= view.size()) {
throw std::range_error("Found an invalid label length in qname");
- pos+=labellen;
- }
- if(d_storage.empty())
- d_storage.append(1, (char)0); // we just parsed the root
- if(consumed)
- *consumed = pos - opos - offset;
- if(qtype) {
- if (pos + 2 > end) {
- throw std::range_error("Trying to read qtype past the end of the buffer ("+std::to_string((pos - opos) + 2)+ " > "+std::to_string(len)+")");
}
- *qtype=(*(const unsigned char*)pos)*256 + *((const unsigned char*)pos+1);
+ pos += labellen;
+ totalLength += 1 + labellen;
}
- pos+=2;
- if(qclass) {
- if (pos + 2 > end) {
- throw std::range_error("Trying to read qclass past the end of the buffer ("+std::to_string((pos - opos) + 2)+ " > "+std::to_string(len)+")");
+ while (pos < view.size());
+
+ if (totalLength != 0) {
+ auto existingSize = d_storage.size();
+ if (existingSize > 0) {
+ // remove the last label count, we are about to override it */
+ --existingSize;
}
- *qclass=(*(const unsigned char*)pos)*256 + *((const unsigned char*)pos+1);
+ d_storage.reserve(existingSize + totalLength + 1);
+ d_storage.resize(existingSize + totalLength);
+ memcpy(&d_storage.at(existingSize), &view.at(initialPos), totalLength);
+ d_storage.append(1, static_cast<char>(0));
+ }
+ return pos;
+}
+
+// this should be the __only__ dns name parser in PowerDNS.
+void DNSName::packetParser(const char* qpos, size_t len, size_t offset, bool uncompress, uint16_t* qtype, uint16_t* qclass, unsigned int* consumed, int depth, uint16_t minOffset)
+{
+ if (offset >= len) {
+ throw std::range_error("Trying to read past the end of the buffer ("+std::to_string(offset)+ " >= "+std::to_string(len)+")");
+ }
+
+ if (offset < static_cast<size_t>(minOffset)) {
+ throw std::range_error("Trying to read before the beginning of the buffer ("+std::to_string(offset)+ " < "+std::to_string(minOffset)+")");
+ }
+ unsigned char labellen{0};
+
+ pdns::views::UnsignedCharView view(qpos, len);
+ auto pos = parsePacketUncompressed(view, offset, uncompress);
+
+ labellen = view.at(pos);
+ pos++;
+ if (labellen != 0 && pos < view.size()) {
+ if (labellen < 0xc0) {
+ abort();
+ }
+
+ if (!uncompress) {
+ throw std::range_error("Found compressed label, instructed not to follow");
+ }
+
+ labellen &= (~0xc0);
+ size_t newpos = (labellen << 8) + view.at(pos);
+
+ if (newpos >= offset) {
+ throw std::range_error("Found a forward reference during label decompression");
+ }
+
+ if (newpos < minOffset) {
+ throw std::range_error("Invalid label position during decompression ("+std::to_string(newpos)+ " < "+std::to_string(minOffset)+")");
+ }
+
+ if (++depth > 100) {
+ throw std::range_error("Abort label decompression after 100 redirects");
+ }
+
+ packetParser(qpos, len, newpos, true, nullptr, nullptr, nullptr, depth, minOffset);
+
+ pos++;
+ }
+
+ if (d_storage.empty()) {
+ d_storage.append(1, static_cast<char>(0)); // we just parsed the root
+ }
+
+ if (consumed != nullptr) {
+ *consumed = pos - offset;
+ }
+
+ if (qtype != nullptr) {
+ if (pos + 2 > view.size()) {
+ throw std::range_error("Trying to read qtype past the end of the buffer ("+std::to_string(pos + 2)+ " > "+std::to_string(len)+")");
+ }
+ *qtype = view.at(pos)*256 + view.at(pos+1);
+ }
+
+ pos += 2;
+ if (qclass != nullptr) {
+ if (pos + 2 > view.size()) {
+ throw std::range_error("Trying to read qclass past the end of the buffer ("+std::to_string(pos + 2)+ " > "+std::to_string(len)+")");
+ }
+ *qclass = view.at(pos)*256 + view.at(pos+1);
}
}
@@ -225,8 +297,9 @@ std::string DNSName::toLogString() const
std::string DNSName::toDNSString() const
{
- if (empty())
+ if (empty()) {
throw std::out_of_range("Attempt to DNSString an unset dnsname");
+ }
return std::string(d_storage.c_str(), d_storage.length());
}
@@ -250,11 +323,13 @@ size_t DNSName::wirelength() const {
// Are WE part of parent
bool DNSName::isPartOf(const DNSName& parent) const
{
- if(parent.empty() || empty())
+ if(parent.empty() || empty()) {
throw std::out_of_range("empty dnsnames aren't part of anything");
+ }
- if(parent.d_storage.size() > d_storage.size())
+ if(parent.d_storage.size() > d_storage.size()) {
return false;
+ }
// this is slightly complicated since we can't start from the end, since we can't see where a label begins/ends then
for(auto us=d_storage.cbegin(); us<d_storage.cend(); us+=*us+1) {
@@ -284,19 +359,24 @@ DNSName DNSName::makeRelative(const DNSName& zone) const
return ret;
}
-void DNSName::makeUsRelative(const DNSName& zone)
+void DNSName::makeUsRelative(const DNSName& zone)
{
if (isPartOf(zone)) {
d_storage.erase(d_storage.size()-zone.d_storage.size());
- d_storage.append(1, (char)0); // put back the trailing 0
- }
- else
+ d_storage.append(1, static_cast<char>(0)); // put back the trailing 0
+ }
+ else {
clear();
+ }
}
DNSName DNSName::getCommonLabels(const DNSName& other) const
{
- DNSName result;
+ if (empty() || other.empty()) {
+ return DNSName();
+ }
+
+ DNSName result(g_rootdnsname);
const std::vector<std::string> ours = getRawLabels();
const std::vector<std::string> others = other.getRawLabels();
@@ -319,8 +399,9 @@ DNSName DNSName::labelReverse() const
{
DNSName ret;
- if(isRoot())
+ if (isRoot()) {
return *this; // we don't create the root automatically below
+ }
if (!empty()) {
vector<string> l=getRawLabels();
@@ -339,41 +420,48 @@ void DNSName::appendRawLabel(const std::string& label)
void DNSName::appendRawLabel(const char* start, unsigned int length)
{
- if(length==0)
- throw std::range_error("no such thing as an empty label to append");
- if(length > 63)
- throw std::range_error("label too long to append");
- if(d_storage.size() + length > s_maxDNSNameLength - 1) // reserve one byte for the label length
+ checkLabelLength(length);
+
+ // reserve one byte for the label length
+ if (d_storage.size() + length > s_maxDNSNameLength - 1) {
throw std::range_error("name too long to append");
+ }
- if(d_storage.empty()) {
- d_storage.append(1, (char)length);
+ if (d_storage.empty()) {
+ d_storage.reserve(1 + length + 1);
+ d_storage.append(1, static_cast<char>(length));
}
else {
- *d_storage.rbegin()=(char)length;
+ d_storage.reserve(d_storage.size() + length + 1);
+ *d_storage.rbegin() = static_cast<char>(length);
}
d_storage.append(start, length);
- d_storage.append(1, (char)0);
+ d_storage.append(1, static_cast<char>(0));
}
void DNSName::prependRawLabel(const std::string& label)
{
- if(label.empty())
- throw std::range_error("no such thing as an empty label to prepend");
- if(label.size() > 63)
- throw std::range_error("label too long to prepend");
- if(d_storage.size() + label.size() > s_maxDNSNameLength - 1) // reserve one byte for the label length
+ checkLabelLength(label.size());
+
+ // reserve one byte for the label length
+ if (d_storage.size() + label.size() > s_maxDNSNameLength - 1) {
throw std::range_error("name too long to prepend");
+ }
- if(d_storage.empty())
- d_storage.append(1, (char)0);
+ if (d_storage.empty()) {
+ d_storage.reserve(1 + label.size() + 1);
+ d_storage.append(1, static_cast<char>(0));
+ }
+ else {
+ d_storage.reserve(d_storage.size() + 1 + label.size());
+ }
- string_t prep(1, (char)label.size());
+ string_t prep(1, static_cast<char>(label.size()));
prep.append(label.c_str(), label.size());
d_storage = prep+d_storage;
}
-bool DNSName::slowCanonCompare(const DNSName& rhs) const
+bool DNSName::slowCanonCompare(const DNSName& rhs) const
{
auto ours=getRawLabels(), rhsLabels = rhs.getRawLabels();
return std::lexicographical_compare(ours.rbegin(), ours.rend(), rhsLabels.rbegin(), rhsLabels.rend(), CIStringCompare());
@@ -411,16 +499,18 @@ DNSName DNSName::getLastLabel() const
bool DNSName::chopOff()
{
- if(d_storage.empty() || d_storage[0]==0)
+ if (d_storage.empty() || d_storage[0]==0) {
return false;
+ }
d_storage.erase(0, (unsigned int)d_storage[0]+1);
return true;
}
bool DNSName::isWildcard() const
{
- if(d_storage.size() < 2)
+ if (d_storage.size() < 2) {
return false;
+ }
auto p = d_storage.begin();
return (*p == 0x01 && *++p == '*');
}
@@ -450,8 +540,9 @@ unsigned int DNSName::countLabels() const
void DNSName::trimToLabels(unsigned int to)
{
- while(countLabels() > to && chopOff())
+ while(countLabels() > to && chopOff()) {
;
+ }
}
@@ -466,12 +557,15 @@ void DNSName::appendEscapedLabel(std::string& appendTo, const char* orig, size_t
while (pos < len) {
auto p = static_cast<uint8_t>(orig[pos]);
- if(p=='.')
+ if (p=='.') {
appendTo+="\\.";
- else if(p=='\\')
+ }
+ else if (p=='\\') {
appendTo+="\\\\";
- else if(p > 0x20 && p < 0x7f)
- appendTo.append(1, (char)p);
+ }
+ else if (p > 0x20 && p < 0x7f) {
+ appendTo.append(1, static_cast<char>(p));
+ }
else {
char buf[] = "000";
auto got = snprintf(buf, sizeof(buf), "%03" PRIu8, p);
@@ -494,11 +588,12 @@ bool DNSName::has8bitBytes() const
for (size_t idx = 0; idx < length; idx++) {
++pos;
char c = s.at(pos);
- if(!((c >= 'a' && c <= 'z') ||
- (c >= 'A' && c <= 'Z') ||
- (c >= '0' && c <= '9') ||
- c =='-' || c == '_' || c=='*' || c=='.' || c=='/' || c=='@' || c==' ' || c=='\\' || c==':'))
+ if (!((c >= 'a' && c <= 'z') ||
+ (c >= 'A' && c <= 'Z') ||
+ (c >= '0' && c <= '9') ||
+ c =='-' || c == '_' || c=='*' || c=='.' || c=='/' || c=='@' || c==' ' || c=='\\' || c==':')) {
return true;
+ }
}
++pos;
length = s.at(pos);