diff options
Diffstat (limited to 'netwerk/dns/DNSPacket.cpp')
-rw-r--r-- | netwerk/dns/DNSPacket.cpp | 1625 |
1 files changed, 1625 insertions, 0 deletions
diff --git a/netwerk/dns/DNSPacket.cpp b/netwerk/dns/DNSPacket.cpp new file mode 100644 index 0000000000..4d28889294 --- /dev/null +++ b/netwerk/dns/DNSPacket.cpp @@ -0,0 +1,1625 @@ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#include "DNSPacket.h" + +#include "DNS.h" +#include "mozilla/EndianUtils.h" +#include "mozilla/ScopeExit.h" +#include "mozilla/StaticPrefs_network.h" +#include "ODoHService.h" +// Put DNSLogging.h at the end to avoid LOG being overwritten by other headers. +#include "DNSLogging.h" + +#include "nsIInputStream.h" + +namespace mozilla { +namespace net { + +static uint16_t get16bit(const unsigned char* aData, unsigned int index) { + return ((aData[index] << 8) | aData[index + 1]); +} + +static bool get16bit(const Span<const uint8_t>& aData, + Span<const uint8_t>::const_iterator& it, + uint16_t& result) { + if (it >= aData.cend() || std::distance(it, aData.cend()) < 2) { + return false; + } + + result = (*it << 8) | *(it + 1); + it += 2; + return true; +} + +static uint32_t get32bit(const unsigned char* aData, unsigned int index) { + return (aData[index] << 24) | (aData[index + 1] << 16) | + (aData[index + 2] << 8) | aData[index + 3]; +} + +// https://datatracker.ietf.org/doc/html/draft-ietf-dnsop-extended-error-16#section-4 +// This is a list of errors for which we should not fallback to Do53. +// These are normally DNSSEC failures or explicit filtering performed by the +// recursive resolver. +bool hardFail(uint16_t code) { + const uint16_t noFallbackErrors[] = { + 4, // Forged answer (malware filtering) + 6, // DNSSEC Boggus + 7, // Signature expired + 8, // Signature not yet valid + 9, // DNSKEY Missing + 10, // RRSIG missing + 11, // No ZONE Key Bit set + 12, // NSEC Missing + 17, // Filtered + }; + + for (const auto& err : noFallbackErrors) { + if (code == err) { + return true; + } + } + return false; +} + +// static +nsresult DNSPacket::ParseSvcParam(unsigned int svcbIndex, uint16_t key, + SvcFieldValue& field, uint16_t length, + const unsigned char* aBuffer) { + switch (key) { + case SvcParamKeyMandatory: { + if (length % 2 != 0) { + // This key should encode a list of uint16_t + return NS_ERROR_UNEXPECTED; + } + while (length > 0) { + uint16_t mandatoryKey = get16bit(aBuffer, svcbIndex); + length -= 2; + svcbIndex += 2; + + if (!IsValidSvcParamKey(mandatoryKey)) { + LOG(("The mandatory field includes a key we don't support %u", + mandatoryKey)); + return NS_ERROR_UNEXPECTED; + } + } + break; + } + case SvcParamKeyAlpn: { + field.mValue = AsVariant(SvcParamAlpn()); + auto& alpnArray = field.mValue.as<SvcParamAlpn>().mValue; + while (length > 0) { + uint8_t alpnIdLength = aBuffer[svcbIndex++]; + length -= 1; + if (alpnIdLength > length) { + return NS_ERROR_UNEXPECTED; + } + + alpnArray.AppendElement( + nsCString((const char*)&aBuffer[svcbIndex], alpnIdLength)); + length -= alpnIdLength; + svcbIndex += alpnIdLength; + } + break; + } + case SvcParamKeyNoDefaultAlpn: { + if (length != 0) { + // This key should not contain a value + return NS_ERROR_UNEXPECTED; + } + field.mValue = AsVariant(SvcParamNoDefaultAlpn{}); + break; + } + case SvcParamKeyPort: { + if (length != 2) { + // This key should only encode a uint16_t + return NS_ERROR_UNEXPECTED; + } + field.mValue = + AsVariant(SvcParamPort{.mValue = get16bit(aBuffer, svcbIndex)}); + break; + } + case SvcParamKeyIpv4Hint: { + if (length % 4 != 0) { + // This key should only encode IPv4 addresses + return NS_ERROR_UNEXPECTED; + } + + field.mValue = AsVariant(SvcParamIpv4Hint()); + auto& ipv4array = field.mValue.as<SvcParamIpv4Hint>().mValue; + while (length > 0) { + NetAddr addr; + addr.inet.family = AF_INET; + addr.inet.port = 0; + addr.inet.ip = ntohl(get32bit(aBuffer, svcbIndex)); + ipv4array.AppendElement(addr); + length -= 4; + svcbIndex += 4; + } + break; + } + case SvcParamKeyEchConfig: { + field.mValue = AsVariant(SvcParamEchConfig{ + .mValue = nsCString((const char*)(&aBuffer[svcbIndex]), length)}); + break; + } + case SvcParamKeyIpv6Hint: { + if (length % 16 != 0) { + // This key should only encode IPv6 addresses + return NS_ERROR_UNEXPECTED; + } + + field.mValue = AsVariant(SvcParamIpv6Hint()); + auto& ipv6array = field.mValue.as<SvcParamIpv6Hint>().mValue; + while (length > 0) { + NetAddr addr; + addr.inet6.family = AF_INET6; + addr.inet6.port = 0; // unknown + addr.inet6.flowinfo = 0; // unknown + addr.inet6.scope_id = 0; // unknown + for (int i = 0; i < 16; i++, svcbIndex++) { + addr.inet6.ip.u8[i] = aBuffer[svcbIndex]; + } + ipv6array.AppendElement(addr); + length -= 16; + // no need to increase svcbIndex - we did it in the for above. + } + break; + } + case SvcParamKeyODoHConfig: { + field.mValue = AsVariant(SvcParamODoHConfig{ + .mValue = nsCString((const char*)(&aBuffer[svcbIndex]), length)}); + break; + } + default: { + // Unespected type. We'll just ignore it. + return NS_OK; + break; + } + } + return NS_OK; +} + +nsresult DNSPacket::PassQName(unsigned int& index, + const unsigned char* aBuffer) { + uint8_t length; + do { + if (mBodySize < (index + 1)) { + LOG(("TRR: PassQName:%d fail at index %d\n", __LINE__, index)); + return NS_ERROR_ILLEGAL_VALUE; + } + length = static_cast<uint8_t>(aBuffer[index]); + if ((length & 0xc0) == 0xc0) { + // name pointer, advance over it and be done + if (mBodySize < (index + 2)) { + return NS_ERROR_ILLEGAL_VALUE; + } + index += 2; + break; + } + if (length & 0xc0) { + LOG(("TRR: illegal label length byte (%x) at index %d\n", length, index)); + return NS_ERROR_ILLEGAL_VALUE; + } + // pass label + if (mBodySize < (index + 1 + length)) { + LOG(("TRR: PassQName:%d fail at index %d\n", __LINE__, index)); + return NS_ERROR_ILLEGAL_VALUE; + } + index += 1 + length; + } while (length); + return NS_OK; +} + +// GetQname: retrieves the qname (stores in 'aQname') and stores the index +// after qname was parsed into the 'aIndex'. +nsresult DNSPacket::GetQname(nsACString& aQname, unsigned int& aIndex, + const unsigned char* aBuffer) { + uint8_t clength = 0; + unsigned int cindex = aIndex; + unsigned int loop = 128; // a valid DNS name can never loop this much + unsigned int endindex = 0; // index position after this data + do { + if (cindex >= mBodySize) { + LOG(("TRR: bad Qname packet\n")); + return NS_ERROR_ILLEGAL_VALUE; + } + clength = static_cast<uint8_t>(aBuffer[cindex]); + if ((clength & 0xc0) == 0xc0) { + // name pointer, get the new offset (14 bits) + if ((cindex + 1) >= mBodySize) { + return NS_ERROR_ILLEGAL_VALUE; + } + // extract the new index position for the next label + uint16_t newpos = (clength & 0x3f) << 8 | aBuffer[cindex + 1]; + if (!endindex) { + // only update on the first "jump" + endindex = cindex + 2; + } + cindex = newpos; + continue; + } + if (clength & 0xc0) { + // any of those bits set individually is an error + LOG(("TRR: bad Qname packet\n")); + return NS_ERROR_ILLEGAL_VALUE; + } + + cindex++; + + if (clength) { + if (!aQname.IsEmpty()) { + aQname.Append("."); + } + if ((cindex + clength) > mBodySize) { + return NS_ERROR_ILLEGAL_VALUE; + } + aQname.Append((const char*)(&aBuffer[cindex]), clength); + cindex += clength; // skip label + } + } while (clength && --loop); + + if (!loop) { + LOG(("DNSPacket::DohDecode pointer loop error\n")); + return NS_ERROR_ILLEGAL_VALUE; + } + if (!endindex) { + // there was no "jump" + endindex = cindex; + } + aIndex = endindex; + return NS_OK; +} + +nsresult DOHresp::Add(uint32_t TTL, unsigned char const* dns, + unsigned int index, uint16_t len, bool aLocalAllowed) { + NetAddr addr; + if (4 == len) { + // IPv4 + addr.inet.family = AF_INET; + addr.inet.port = 0; // unknown + addr.inet.ip = ntohl(get32bit(dns, index)); + } else if (16 == len) { + // IPv6 + addr.inet6.family = AF_INET6; + addr.inet6.port = 0; // unknown + addr.inet6.flowinfo = 0; // unknown + addr.inet6.scope_id = 0; // unknown + for (int i = 0; i < 16; i++, index++) { + addr.inet6.ip.u8[i] = dns[index]; + } + } else { + return NS_ERROR_UNEXPECTED; + } + + if (addr.IsIPAddrLocal() && !aLocalAllowed) { + return NS_ERROR_FAILURE; + } + + // While the DNS packet might return individual TTLs for each address, + // we can only return one value in the AddrInfo class so pick the + // lowest number. + if (mTtl < TTL) { + mTtl = TTL; + } + + if (LOG_ENABLED()) { + char buf[128]; + addr.ToStringBuffer(buf, sizeof(buf)); + LOG(("DOHresp:Add %s\n", buf)); + } + mAddresses.AppendElement(addr); + return NS_OK; +} + +nsresult DNSPacket::OnDataAvailable(nsIRequest* aRequest, + nsIInputStream* aInputStream, + uint64_t aOffset, const uint32_t aCount) { + if (aCount + mBodySize > MAX_SIZE) { + LOG(("DNSPacket::OnDataAvailable:%d fail\n", __LINE__)); + return NS_ERROR_FAILURE; + } + uint32_t count; + nsresult rv = + aInputStream->Read((char*)mResponse + mBodySize, aCount, &count); + if (NS_FAILED(rv)) { + return rv; + } + MOZ_ASSERT(count == aCount); + mBodySize += aCount; + return NS_OK; +} + +const uint8_t kDNS_CLASS_IN = 1; + +nsresult DNSPacket::EncodeRequest(nsCString& aBody, const nsACString& aHost, + uint16_t aType, bool aDisableECS) { + aBody.Truncate(); + // Header + aBody += '\0'; + aBody += '\0'; // 16 bit id + aBody += 0x01; // |QR| Opcode |AA|TC|RD| Set the RD bit + aBody += '\0'; // |RA| Z | RCODE | + aBody += '\0'; + aBody += 1; // QDCOUNT (number of entries in the question section) + aBody += '\0'; + aBody += '\0'; // ANCOUNT + aBody += '\0'; + aBody += '\0'; // NSCOUNT + + char additionalRecords = + (aDisableECS || StaticPrefs::network_trr_padding()) ? 1 : 0; + aBody += '\0'; // ARCOUNT + aBody += additionalRecords; // ARCOUNT low byte for EDNS(0) + + // Question + + // The input host name should be converted to a sequence of labels, where + // each label consists of a length octet followed by that number of + // octets. The domain name terminates with the zero length octet for the + // null label of the root. + // Followed by 16 bit QTYPE and 16 bit QCLASS + + int32_t index = 0; + int32_t offset = 0; + do { + bool dotFound = false; + int32_t labelLength; + index = aHost.FindChar('.', offset); + if (kNotFound != index) { + dotFound = true; + labelLength = index - offset; + } else { + labelLength = aHost.Length() - offset; + } + if (labelLength > 63) { + // too long label! + SetDNSPacketStatus(DNSPacketStatus::EncodeError); + return NS_ERROR_ILLEGAL_VALUE; + } + if (labelLength > 0) { + aBody += static_cast<unsigned char>(labelLength); + nsDependentCSubstring label = Substring(aHost, offset, labelLength); + aBody.Append(label); + } + if (!dotFound) { + aBody += '\0'; // terminate with a final zero + break; + } + offset += labelLength + 1; // move over label and dot + } while (true); + + aBody += static_cast<uint8_t>(aType >> 8); // upper 8 bit TYPE + aBody += static_cast<uint8_t>(aType); + aBody += '\0'; // upper 8 bit CLASS + aBody += kDNS_CLASS_IN; // IN - "the Internet" + + if (additionalRecords) { + // EDNS(0) is RFC 6891, ECS is RFC 7871 + aBody += '\0'; // NAME | domain name | MUST be 0 (root domain) | + aBody += '\0'; + aBody += 41; // TYPE | u_int16_t | OPT (41) | + aBody += 16; // CLASS | u_int16_t | requestor's UDP payload size | + aBody += + '\0'; // advertise 4K (high-byte: 16 | low-byte: 0), ignored by DoH + aBody += '\0'; // TTL | u_int32_t | extended RCODE and flags | + aBody += '\0'; + aBody += '\0'; + aBody += '\0'; + + // calculate padding length + unsigned int paddingLen = 0; + unsigned int rdlen = 0; + bool padding = StaticPrefs::network_trr_padding(); + if (padding) { + // always add padding specified in rfc 7830 when this config is enabled + // to allow the reponse to be padded as well + + // two bytes RDLEN, 4 bytes padding header + unsigned int packetLen = aBody.Length() + 2 + 4; + if (aDisableECS) { + // 8 bytes for disabling ecs + packetLen += 8; + } + + // clamp the padding length, because the padding extension only allows up + // to 2^16 - 1 bytes padding and adding too much padding wastes resources + uint32_t padTo = std::clamp<uint32_t>( + StaticPrefs::network_trr_padding_length(), 0, 1024); + + // Calculate number of padding bytes. The second '%'-operator is necessary + // because we prefer to add 0 bytes padding rather than padTo bytes + if (padTo > 0) { + paddingLen = (padTo - (packetLen % padTo)) % padTo; + } + // padding header + padding length + rdlen += 4 + paddingLen; + } + if (aDisableECS) { + rdlen += 8; + } + + // RDLEN | u_int16_t | length of all RDATA | + aBody += (char)((rdlen >> 8) & 0xff); // upper 8 bit RDLEN + aBody += (char)(rdlen & 0xff); + + // RDATA | octet stream | {attribute,value} pairs | + // The RDATA is just the ECS option setting zero subnet prefix + + if (aDisableECS) { + aBody += '\0'; // upper 8 bit OPTION-CODE ECS + aBody += 8; // OPTION-CODE, 2 octets, for ECS is 8 + + aBody += '\0'; // upper 8 bit OPTION-LENGTH + aBody += 4; // OPTION-LENGTH, 2 octets, contains the length of the + // payload after OPTION-LENGTH + aBody += '\0'; // upper 8 bit FAMILY. IANA Address Family Numbers + // registry, not the AF_* constants! + aBody += 1; // FAMILY (Ipv4), 2 octets + + aBody += '\0'; // SOURCE PREFIX-LENGTH | SCOPE PREFIX-LENGTH | + aBody += '\0'; + + // ADDRESS, minimum number of octets == nothing because zero bits + } + + if (padding) { + aBody += '\0'; // upper 8 bit option OPTION-CODE PADDING + aBody += 12; // OPTION-CODE, 2 octets, for PADDING is 12 + + // OPTION-LENGTH, 2 octets + aBody += (char)((paddingLen >> 8) & 0xff); + aBody += (char)(paddingLen & 0xff); + for (unsigned int i = 0; i < paddingLen; i++) { + aBody += '\0'; + } + } + } + + SetDNSPacketStatus(DNSPacketStatus::Success); + return NS_OK; +} + +Result<uint8_t, nsresult> DNSPacket::GetRCode() const { + if (mBodySize < 12) { + LOG(("DNSPacket::GetRCode - packet too small")); + return Err(NS_ERROR_ILLEGAL_VALUE); + } + + return mResponse[3] & 0x0F; +} + +Result<bool, nsresult> DNSPacket::RecursionAvailable() const { + if (mBodySize < 12) { + LOG(("DNSPacket::GetRCode - packet too small")); + return Err(NS_ERROR_ILLEGAL_VALUE); + } + + return mResponse[3] & 0x80; +} + +nsresult DNSPacket::DecodeInternal( + nsCString& aHost, enum TrrType aType, nsCString& aCname, bool aAllowRFC1918, + DOHresp& aResp, TypeRecordResultType& aTypeResult, + nsClassHashtable<nsCStringHashKey, DOHresp>& aAdditionalRecords, + uint32_t& aTTL, const unsigned char* aBuffer, uint32_t aLen) { + // The response has a 12 byte header followed by the question (returned) + // and then the answer. The answer section itself contains the name, type + // and class again and THEN the record data. + + // www.example.com response: + // header: + // abcd 8180 0001 0001 0000 0000 + // the question: + // 0377 7777 0765 7861 6d70 6c65 0363 6f6d 0000 0100 01 + // the answer: + // 03 7777 7707 6578 616d 706c 6503 636f 6d00 0001 0001 + // 0000 0080 0004 5db8 d822 + + unsigned int index = 12; + uint8_t length; + nsAutoCString host; + nsresult rv; + uint16_t extendedError = UINT16_MAX; + + LOG(("doh decode %s %d bytes\n", aHost.get(), aLen)); + + aCname.Truncate(); + + if (aLen < 12 || aBuffer[0] || aBuffer[1]) { + LOG(("TRR bad incoming DOH, eject!\n")); + return NS_ERROR_ILLEGAL_VALUE; + } + uint8_t rcode = mResponse[3] & 0x0F; + LOG(("TRR Decode %s RCODE %d\n", PromiseFlatCString(aHost).get(), rcode)); + + uint16_t questionRecords = get16bit(aBuffer, 4); // qdcount + // iterate over the single(?) host name in question + while (questionRecords) { + do { + if (aLen < (index + 1)) { + LOG(("TRR Decode 1 index: %u size: %u", index, aLen)); + return NS_ERROR_ILLEGAL_VALUE; + } + length = static_cast<uint8_t>(aBuffer[index]); + if (length) { + if (host.Length()) { + host.Append("."); + } + if (aLen < (index + 1 + length)) { + LOG(("TRR Decode 2 index: %u size: %u len: %u", index, aLen, length)); + return NS_ERROR_ILLEGAL_VALUE; + } + host.Append(((char*)aBuffer) + index + 1, length); + } + index += 1 + length; // skip length byte + label + } while (length); + if (aLen < (index + 4)) { + LOG(("TRR Decode 3 index: %u size: %u", index, aLen)); + return NS_ERROR_ILLEGAL_VALUE; + } + index += 4; // skip question's type, class + questionRecords--; + } + + // Figure out the number of answer records from ANCOUNT + uint16_t answerRecords = get16bit(aBuffer, 6); + + LOG(("TRR Decode: %d answer records (%u bytes body) %s index=%u\n", + answerRecords, aLen, host.get(), index)); + + while (answerRecords) { + nsAutoCString qname; + rv = GetQname(qname, index, aBuffer); + if (NS_FAILED(rv)) { + return rv; + } + // 16 bit TYPE + if (aLen < (index + 2)) { + LOG(("TRR: Dohdecode:%d fail at index %d\n", __LINE__, index + 2)); + return NS_ERROR_ILLEGAL_VALUE; + } + uint16_t TYPE = get16bit(aBuffer, index); + + if ((TYPE != TRRTYPE_CNAME) && (TYPE != TRRTYPE_HTTPSSVC) && + (TYPE != static_cast<uint16_t>(aType))) { + // Not the same type as was asked for nor CNAME + LOG(("TRR: Dohdecode:%d asked for type %d got %d\n", __LINE__, aType, + TYPE)); + return NS_ERROR_UNEXPECTED; + } + index += 2; + + // 16 bit class + if (aLen < (index + 2)) { + LOG(("TRR: Dohdecode:%d fail at index %d\n", __LINE__, index + 2)); + return NS_ERROR_ILLEGAL_VALUE; + } + uint16_t CLASS = get16bit(aBuffer, index); + if (kDNS_CLASS_IN != CLASS) { + LOG(("TRR bad CLASS (%u) at index %d\n", CLASS, index)); + return NS_ERROR_UNEXPECTED; + } + index += 2; + + // 32 bit TTL (seconds) + if (aLen < (index + 4)) { + LOG(("TRR: Dohdecode:%d fail at index %d\n", __LINE__, index)); + return NS_ERROR_ILLEGAL_VALUE; + } + uint32_t TTL = get32bit(aBuffer, index); + index += 4; + + // 16 bit RDLENGTH + if (aLen < (index + 2)) { + LOG(("TRR: Dohdecode:%d fail at index %d\n", __LINE__, index)); + return NS_ERROR_ILLEGAL_VALUE; + } + uint16_t RDLENGTH = get16bit(aBuffer, index); + index += 2; + + if (aLen < (index + RDLENGTH)) { + LOG(("TRR: Dohdecode:%d fail RDLENGTH=%d at index %d\n", __LINE__, + RDLENGTH, index)); + return NS_ERROR_ILLEGAL_VALUE; + } + + // We check if the qname is a case-insensitive match for the host or the + // FQDN version of the host + bool responseMatchesQuestion = + (qname.Length() == aHost.Length() || + (aHost.Length() == qname.Length() + 1 && aHost.Last() == '.')) && + StringBeginsWith(aHost, qname, nsCaseInsensitiveCStringComparator); + + if (responseMatchesQuestion) { + // RDATA + // - A (TYPE 1): 4 bytes + // - AAAA (TYPE 28): 16 bytes + // - NS (TYPE 2): N bytes + + switch (TYPE) { + case TRRTYPE_A: + if (RDLENGTH != 4) { + LOG(("TRR bad length for A (%u)\n", RDLENGTH)); + return NS_ERROR_UNEXPECTED; + } + rv = aResp.Add(TTL, aBuffer, index, RDLENGTH, aAllowRFC1918); + if (NS_FAILED(rv)) { + LOG( + ("TRR:DohDecode failed: local IP addresses or unknown IP " + "family\n")); + return rv; + } + break; + case TRRTYPE_AAAA: + if (RDLENGTH != 16) { + LOG(("TRR bad length for AAAA (%u)\n", RDLENGTH)); + return NS_ERROR_UNEXPECTED; + } + rv = aResp.Add(TTL, aBuffer, index, RDLENGTH, aAllowRFC1918); + if (NS_FAILED(rv)) { + LOG(("TRR got unique/local IPv6 address!\n")); + return rv; + } + break; + + case TRRTYPE_NS: + break; + case TRRTYPE_CNAME: + if (aCname.IsEmpty()) { + nsAutoCString qname; + unsigned int qnameindex = index; + rv = GetQname(qname, qnameindex, aBuffer); + if (NS_FAILED(rv)) { + return rv; + } + if (!qname.IsEmpty()) { + ToLowerCase(qname); + aCname = qname; + LOG(("DNSPacket::DohDecode CNAME host %s => %s\n", host.get(), + aCname.get())); + } else { + LOG(("DNSPacket::DohDecode empty CNAME for host %s!\n", + host.get())); + } + } else { + LOG(("DNSPacket::DohDecode CNAME - ignoring another entry\n")); + } + break; + case TRRTYPE_TXT: { + // TXT record RRDATA sections are a series of character-strings + // each character string is a length byte followed by that many data + // bytes + nsAutoCString txt; + unsigned int txtIndex = index; + uint16_t available = RDLENGTH; + + while (available > 0) { + uint8_t characterStringLen = aBuffer[txtIndex++]; + available--; + if (characterStringLen > available) { + LOG(("DNSPacket::DohDecode MALFORMED TXT RECORD\n")); + break; + } + txt.Append((const char*)(&aBuffer[txtIndex]), characterStringLen); + txtIndex += characterStringLen; + available -= characterStringLen; + } + + if (!aTypeResult.is<TypeRecordTxt>()) { + aTypeResult = AsVariant(CopyableTArray<nsCString>()); + } + + { + auto& results = aTypeResult.as<TypeRecordTxt>(); + results.AppendElement(txt); + } + if (aTTL > TTL) { + aTTL = TTL; + } + LOG(("DNSPacket::DohDecode TXT host %s => %s\n", host.get(), + txt.get())); + + break; + } + case TRRTYPE_HTTPSSVC: { + struct SVCB parsed; + int32_t lastSvcParamKey = -1; + + unsigned int svcbIndex = index; + CheckedInt<uint16_t> available = RDLENGTH; + + // Should have at least 2 bytes for the priority and one for the + // qname length. + if (available.value() < 3) { + return NS_ERROR_UNEXPECTED; + } + + parsed.mSvcFieldPriority = get16bit(aBuffer, svcbIndex); + svcbIndex += 2; + + rv = GetQname(parsed.mSvcDomainName, svcbIndex, aBuffer); + if (NS_FAILED(rv)) { + return rv; + } + + if (parsed.mSvcDomainName.IsEmpty()) { + if (parsed.mSvcFieldPriority == 0) { + // For AliasMode SVCB RRs, a TargetName of "." indicates that the + // service is not available or does not exist. + continue; + } + + // For ServiceMode SVCB RRs, if TargetName has the value ".", + // then the owner name of this record MUST be used as + // the effective TargetName. + // When the qname is port prefix name, we need to use the + // original host name as TargetName. + if (mOriginHost) { + parsed.mSvcDomainName = *mOriginHost; + } else { + parsed.mSvcDomainName = qname; + } + } + + available -= (svcbIndex - index); + if (!available.isValid()) { + return NS_ERROR_UNEXPECTED; + } + aTTL = TTL; + while (available.value() >= 4) { + // Every SvcFieldValues must have at least 4 bytes for the + // SvcParamKey (2 bytes) and length of SvcParamValue (2 bytes) + // If the length ever goes above the available data, meaning if + // available ever underflows, then that is an error. + struct SvcFieldValue value; + uint16_t key = get16bit(aBuffer, svcbIndex); + svcbIndex += 2; + + // 2.2 Clients MUST consider an RR malformed if SvcParamKeys are + // not in strictly increasing numeric order. + if (key <= lastSvcParamKey) { + LOG(("SvcParamKeys not in increasing order")); + return NS_ERROR_UNEXPECTED; + } + lastSvcParamKey = key; + + uint16_t len = get16bit(aBuffer, svcbIndex); + svcbIndex += 2; + + available -= 4 + len; + if (!available.isValid()) { + return NS_ERROR_UNEXPECTED; + } + + rv = ParseSvcParam(svcbIndex, key, value, len, aBuffer); + if (NS_FAILED(rv)) { + return rv; + } + svcbIndex += len; + + // If this is an unknown key, we will simply ignore it. + // We also don't need to record SvcParamKeyMandatory + if (key == SvcParamKeyMandatory || !IsValidSvcParamKey(key)) { + continue; + } + + if (value.mValue.is<SvcParamIpv4Hint>() || + value.mValue.is<SvcParamIpv6Hint>()) { + parsed.mHasIPHints = true; + } + if (value.mValue.is<SvcParamEchConfig>()) { + parsed.mHasEchConfig = true; + parsed.mEchConfig = value.mValue.as<SvcParamEchConfig>().mValue; + } + if (value.mValue.is<SvcParamODoHConfig>()) { + parsed.mODoHConfig = value.mValue.as<SvcParamODoHConfig>().mValue; + } + parsed.mSvcFieldValue.AppendElement(value); + } + + if (aType != TRRTYPE_HTTPSSVC) { + // Ignore the entry that we just parsed if we didn't ask for it. + break; + } + + // Check for AliasForm + if (aCname.IsEmpty() && parsed.mSvcFieldPriority == 0) { + // Alias form SvcDomainName must not have the "." value (empty) + if (parsed.mSvcDomainName.IsEmpty()) { + return NS_ERROR_UNEXPECTED; + } + aCname = parsed.mSvcDomainName; + // If aliasForm is present, Service form must be ignored. + aTypeResult = mozilla::AsVariant(Nothing()); + ToLowerCase(aCname); + LOG(("DNSPacket::DohDecode HTTPSSVC AliasForm host %s => %s\n", + host.get(), aCname.get())); + break; + } + + if (!aTypeResult.is<TypeRecordHTTPSSVC>()) { + aTypeResult = mozilla::AsVariant(CopyableTArray<SVCB>()); + } + { + auto& results = aTypeResult.as<TypeRecordHTTPSSVC>(); + results.AppendElement(parsed); + } + + break; + } + default: + // skip unknown record types + LOG(("TRR unsupported TYPE (%u) RDLENGTH %u\n", TYPE, RDLENGTH)); + break; + } + } else { + LOG(("TRR asked for %s data but got %s\n", aHost.get(), qname.get())); + } + + index += RDLENGTH; + LOG(("done with record type %u len %u index now %u of %u\n", TYPE, RDLENGTH, + index, aLen)); + answerRecords--; + } + + // NSCOUNT + uint16_t nsRecords = get16bit(aBuffer, 8); + LOG(("TRR Decode: %d ns records (%u bytes body)\n", nsRecords, aLen)); + while (nsRecords) { + rv = PassQName(index, aBuffer); + if (NS_FAILED(rv)) { + return rv; + } + + if (aLen < (index + 8)) { + return NS_ERROR_ILLEGAL_VALUE; + } + index += 2; // type + index += 2; // class + index += 4; // ttl + + // 16 bit RDLENGTH + if (aLen < (index + 2)) { + return NS_ERROR_ILLEGAL_VALUE; + } + uint16_t RDLENGTH = get16bit(aBuffer, index); + index += 2; + if (aLen < (index + RDLENGTH)) { + return NS_ERROR_ILLEGAL_VALUE; + } + index += RDLENGTH; + LOG(("done with nsRecord now %u of %u\n", index, aLen)); + nsRecords--; + } + + // additional resource records + uint16_t arRecords = get16bit(aBuffer, 10); + LOG(("TRR Decode: %d additional resource records (%u bytes body)\n", + arRecords, aLen)); + + while (arRecords) { + nsAutoCString qname; + rv = GetQname(qname, index, aBuffer); + if (NS_FAILED(rv)) { + LOG(("Bad qname for additional record")); + return rv; + } + + if (aLen < (index + 8)) { + return NS_ERROR_ILLEGAL_VALUE; + } + uint16_t type = get16bit(aBuffer, index); + index += 2; + // The next two bytes encode class + // (or udpPayloadSize when type is TRRTYPE_OPT) + uint16_t cls = get16bit(aBuffer, index); + index += 2; + // The next 4 bytes encode TTL + // (or extRCode + ednsVersion + flags when type is TRRTYPE_OPT) + uint32_t ttl = get32bit(aBuffer, index); + index += 4; + // cls and ttl are unused when type is TRRTYPE_OPT + + // 16 bit RDLENGTH + if (aLen < (index + 2)) { + LOG(("Record too small")); + return NS_ERROR_ILLEGAL_VALUE; + } + + uint16_t rdlength = get16bit(aBuffer, index); + index += 2; + if (aLen < (index + rdlength)) { + LOG(("rdlength too big")); + return NS_ERROR_ILLEGAL_VALUE; + } + + auto parseRecord = [&]() { + LOG(("Parsing additional record type: %u", type)); + auto* entry = aAdditionalRecords.GetOrInsertNew(qname); + + switch (type) { + case TRRTYPE_A: + if (kDNS_CLASS_IN != cls) { + LOG(("NOT IN - returning")); + return; + } + if (rdlength != 4) { + LOG(("TRR bad length for A (%u)\n", rdlength)); + return; + } + rv = entry->Add(ttl, aBuffer, index, rdlength, aAllowRFC1918); + if (NS_FAILED(rv)) { + LOG( + ("TRR:DohDecode failed: local IP addresses or unknown IP " + "family\n")); + return; + } + break; + case TRRTYPE_AAAA: + if (kDNS_CLASS_IN != cls) { + LOG(("NOT IN - returning")); + return; + } + if (rdlength != 16) { + LOG(("TRR bad length for AAAA (%u)\n", rdlength)); + return; + } + rv = entry->Add(ttl, aBuffer, index, rdlength, aAllowRFC1918); + if (NS_FAILED(rv)) { + LOG(("TRR got unique/local IPv6 address!\n")); + return; + } + break; + case TRRTYPE_OPT: { // OPT + LOG(("Parsing opt rdlen: %u", rdlength)); + unsigned int offset = 0; + while (offset + 2 <= rdlength) { + uint16_t optCode = get16bit(aBuffer, index + offset); + LOG(("optCode: %u", optCode)); + offset += 2; + if (offset + 2 > rdlength) { + break; + } + uint16_t optLen = get16bit(aBuffer, index + offset); + LOG(("optLen: %u", optLen)); + offset += 2; + if (offset + optLen > rdlength) { + LOG(("offset: %u, optLen: %u, rdlen: %u", offset, optLen, + rdlength)); + break; + } + + LOG(("OPT: code: %u len:%u", optCode, optLen)); + + if (optCode != 15) { + offset += optLen; + continue; + } + + // optCode == 15; Extended DNS error + + if (offset + 2 > rdlength || optLen < 2) { + break; + } + extendedError = get16bit(aBuffer, index + offset); + + LOG(("Extended error code: %u message: %s", extendedError, + nsAutoCString((char*)aBuffer + index + offset + 2, optLen - 2) + .get())); + offset += optLen; + } + break; + } + default: + break; + } + }; + + parseRecord(); + + index += rdlength; + LOG(("done with additional rr now %u of %u\n", index, aLen)); + arRecords--; + } + + if (index != aLen) { + LOG(("DohDecode failed to parse entire response body, %u out of %u bytes\n", + index, aLen)); + // failed to parse 100%, do not continue + return NS_ERROR_ILLEGAL_VALUE; + } + + if (aType == TRRTYPE_NS && rcode != 0) { + return NS_ERROR_UNKNOWN_HOST; + } + + if ((aType != TRRTYPE_NS) && aCname.IsEmpty() && aResp.mAddresses.IsEmpty() && + aTypeResult.is<TypeRecordEmpty>()) { + // no entries were stored! + LOG(("TRR: No entries were stored!\n")); + + if (extendedError != UINT16_MAX && hardFail(extendedError)) { + return NS_ERROR_DEFINITIVE_UNKNOWN_HOST; + } + return NS_ERROR_UNKNOWN_HOST; + } + + // https://tools.ietf.org/html/draft-ietf-dnsop-svcb-httpssvc-03#page-14 + // If one or more SVCB records of ServiceForm SvcRecordType are returned for + // HOST, clients should select the highest-priority option with acceptable + // parameters. + if (aTypeResult.is<TypeRecordHTTPSSVC>()) { + auto& results = aTypeResult.as<TypeRecordHTTPSSVC>(); + results.Sort(); + } + + return NS_OK; +} + +// +// DohDecode() collects the TTL and the IP addresses in the response +// +nsresult DNSPacket::Decode( + nsCString& aHost, enum TrrType aType, nsCString& aCname, bool aAllowRFC1918, + DOHresp& aResp, TypeRecordResultType& aTypeResult, + nsClassHashtable<nsCStringHashKey, DOHresp>& aAdditionalRecords, + uint32_t& aTTL) { + nsresult rv = + DecodeInternal(aHost, aType, aCname, aAllowRFC1918, aResp, aTypeResult, + aAdditionalRecords, aTTL, mResponse, mBodySize); + SetDNSPacketStatus(NS_SUCCEEDED(rv) ? DNSPacketStatus::Success + : DNSPacketStatus::DecodeError); + return rv; +} + +static SECItem* CreateRawConfig(const ObliviousDoHConfig& aConfig) { + SECItem* item(::SECITEM_AllocItem(nullptr, nullptr, + 8 + aConfig.mContents.mPublicKey.Length())); + if (!item) { + return nullptr; + } + + uint16_t index = 0; + NetworkEndian::writeUint16(&item->data[index], aConfig.mContents.mKemId); + index += 2; + NetworkEndian::writeUint16(&item->data[index], aConfig.mContents.mKdfId); + index += 2; + NetworkEndian::writeUint16(&item->data[index], aConfig.mContents.mAeadId); + index += 2; + uint16_t keyLength = aConfig.mContents.mPublicKey.Length(); + NetworkEndian::writeUint16(&item->data[index], keyLength); + index += 2; + memcpy(&item->data[index], aConfig.mContents.mPublicKey.Elements(), + aConfig.mContents.mPublicKey.Length()); + return item; +} + +static bool CreateConfigId(ObliviousDoHConfig& aConfig) { + SECStatus rv; + CK_HKDF_PARAMS params = {0}; + SECItem paramsi = {siBuffer, (unsigned char*)¶ms, sizeof(params)}; + + UniquePK11SlotInfo slot(PK11_GetInternalSlot()); + if (!slot) { + return false; + } + + UniqueSECItem rawConfig(CreateRawConfig(aConfig)); + if (!rawConfig) { + return false; + } + + UniquePK11SymKey configKey(PK11_ImportDataKey(slot.get(), CKM_HKDF_DATA, + PK11_OriginUnwrap, CKA_DERIVE, + rawConfig.get(), nullptr)); + if (!configKey) { + return false; + } + + params.bExtract = CK_TRUE; + params.bExpand = CK_TRUE; + params.prfHashMechanism = CKM_SHA256; + params.ulSaltType = CKF_HKDF_SALT_NULL; + params.pInfo = (unsigned char*)&hODoHConfigID[0]; + params.ulInfoLen = strlen(hODoHConfigID); + UniquePK11SymKey derived(PK11_DeriveWithFlags( + configKey.get(), CKM_HKDF_DATA, ¶msi, CKM_HKDF_DERIVE, CKA_DERIVE, + SHA256_LENGTH, CKF_SIGN | CKF_VERIFY)); + + rv = PK11_ExtractKeyValue(derived.get()); + if (rv != SECSuccess) { + return false; + } + + SECItem* derivedItem = PK11_GetKeyData(derived.get()); + if (!derivedItem) { + return false; + } + + if (derivedItem->len != SHA256_LENGTH) { + return false; + } + + aConfig.mConfigId.AppendElements(derivedItem->data, derivedItem->len); + return true; +} + +// static +bool ODoHDNSPacket::ParseODoHConfigs(Span<const uint8_t> aData, + nsTArray<ObliviousDoHConfig>& aOut) { + // struct { + // uint16 kem_id; + // uint16 kdf_id; + // uint16 aead_id; + // opaque public_key<1..2^16-1>; + // } ObliviousDoHConfigContents; + // + // struct { + // uint16 version; + // uint16 length; + // select (ObliviousDoHConfig.version) { + // case 0xff03: ObliviousDoHConfigContents contents; + // } + // } ObliviousDoHConfig; + // + // ObliviousDoHConfig ObliviousDoHConfigs<1..2^16-1>; + + Span<const uint8_t>::const_iterator it = aData.begin(); + uint16_t length = 0; + if (!get16bit(aData, it, length)) { + return false; + } + + if (length != aData.Length() - 2) { + return false; + } + + nsTArray<ObliviousDoHConfig> result; + static const uint32_t kMinimumConfigContentLength = 12; + while (std::distance(it, aData.cend()) > kMinimumConfigContentLength) { + ObliviousDoHConfig config; + if (!get16bit(aData, it, config.mVersion)) { + return false; + } + + if (!get16bit(aData, it, config.mLength)) { + return false; + } + + if (std::distance(it, aData.cend()) < config.mLength) { + return false; + } + + if (!get16bit(aData, it, config.mContents.mKemId)) { + return false; + } + + if (!get16bit(aData, it, config.mContents.mKdfId)) { + return false; + } + + if (!get16bit(aData, it, config.mContents.mAeadId)) { + return false; + } + + uint16_t keyLength = 0; + if (!get16bit(aData, it, keyLength)) { + return false; + } + + if (!keyLength || std::distance(it, aData.cend()) < keyLength) { + return false; + } + + config.mContents.mPublicKey.AppendElements(Span(it, it + keyLength)); + it += keyLength; + + CreateConfigId(config); + + // Check if the version of the config is supported and validate its content. + if (config.mVersion == ODOH_VERSION && + PK11_HPKE_ValidateParameters( + static_cast<HpkeKemId>(config.mContents.mKemId), + static_cast<HpkeKdfId>(config.mContents.mKdfId), + static_cast<HpkeAeadId>(config.mContents.mAeadId)) == SECSuccess) { + result.AppendElement(std::move(config)); + } else { + LOG(("ODoHDNSPacket::ParseODoHConfigs got an invalid config")); + } + } + + aOut = std::move(result); + return true; +} + +ODoHDNSPacket::~ODoHDNSPacket() { PK11_HPKE_DestroyContext(mContext, true); } + +nsresult ODoHDNSPacket::EncodeRequest(nsCString& aBody, const nsACString& aHost, + uint16_t aType, bool aDisableECS) { + nsAutoCString queryBody; + nsresult rv = DNSPacket::EncodeRequest(queryBody, aHost, aType, aDisableECS); + if (NS_FAILED(rv)) { + SetDNSPacketStatus(DNSPacketStatus::EncodeError); + return rv; + } + + if (!gODoHService->ODoHConfigs()) { + SetDNSPacketStatus(DNSPacketStatus::KeyNotAvailable); + return NS_ERROR_FAILURE; + } + + if (gODoHService->ODoHConfigs()->IsEmpty()) { + SetDNSPacketStatus(DNSPacketStatus::KeyNotUsable); + return NS_ERROR_FAILURE; + } + + // We only use the first ODoHConfig. + const ObliviousDoHConfig& config = (*gODoHService->ODoHConfigs())[0]; + + ObliviousDoHMessage message; + // The spec didn't recommand padding length for encryption, let's use 0 here. + if (!EncryptDNSQuery(queryBody, 0, config, message)) { + SetDNSPacketStatus(DNSPacketStatus::EncryptError); + return NS_ERROR_FAILURE; + } + + aBody.Truncate(); + aBody += message.mType; + uint16_t keyIdLength = message.mKeyId.Length(); + aBody += static_cast<uint8_t>(keyIdLength >> 8); + aBody += static_cast<uint8_t>(keyIdLength); + aBody.Append(reinterpret_cast<const char*>(message.mKeyId.Elements()), + keyIdLength); + uint16_t messageLen = message.mEncryptedMessage.Length(); + aBody += static_cast<uint8_t>(messageLen >> 8); + aBody += static_cast<uint8_t>(messageLen); + aBody.Append( + reinterpret_cast<const char*>(message.mEncryptedMessage.Elements()), + messageLen); + + SetDNSPacketStatus(DNSPacketStatus::Success); + return NS_OK; +} + +/* + * def encrypt_query_body(pkR, key_id, Q_plain): + * enc, context = SetupBaseS(pkR, "odoh query") + * aad = 0x01 || len(key_id) || key_id + * ct = context.Seal(aad, Q_plain) + * Q_encrypted = enc || ct + * return Q_encrypted + */ +bool ODoHDNSPacket::EncryptDNSQuery(const nsACString& aQuery, + uint16_t aPaddingLen, + const ObliviousDoHConfig& aConfig, + ObliviousDoHMessage& aOut) { + mContext = PK11_HPKE_NewContext( + static_cast<HpkeKemId>(aConfig.mContents.mKemId), + static_cast<HpkeKdfId>(aConfig.mContents.mKdfId), + static_cast<HpkeAeadId>(aConfig.mContents.mAeadId), nullptr, nullptr); + if (!mContext) { + LOG(("ODoHDNSPacket::EncryptDNSQuery create context failed")); + return false; + } + + SECKEYPublicKey* pkR; + SECStatus rv = + PK11_HPKE_Deserialize(mContext, aConfig.mContents.mPublicKey.Elements(), + aConfig.mContents.mPublicKey.Length(), &pkR); + if (rv != SECSuccess) { + return false; + } + + UniqueSECItem hpkeInfo( + ::SECITEM_AllocItem(nullptr, nullptr, strlen(kODoHQuery))); + if (!hpkeInfo) { + return false; + } + + memcpy(hpkeInfo->data, kODoHQuery, strlen(kODoHQuery)); + + rv = PK11_HPKE_SetupS(mContext, nullptr, nullptr, pkR, hpkeInfo.get()); + if (rv != SECSuccess) { + LOG(("ODoHDNSPacket::EncryptDNSQuery setupS failed")); + return false; + } + + const SECItem* hpkeEnc = PK11_HPKE_GetEncapPubKey(mContext); + if (!hpkeEnc) { + return false; + } + + // aad = 0x01 || len(key_id) || key_id + UniqueSECItem aad(::SECITEM_AllocItem(nullptr, nullptr, + 1 + 2 + aConfig.mConfigId.Length())); + if (!aad) { + return false; + } + + aad->data[0] = ODOH_QUERY; + NetworkEndian::writeUint16(&aad->data[1], aConfig.mConfigId.Length()); + memcpy(&aad->data[3], aConfig.mConfigId.Elements(), + aConfig.mConfigId.Length()); + + // struct { + // opaque dns_message<1..2^16-1>; + // opaque padding<0..2^16-1>; + // } ObliviousDoHMessagePlaintext; + SECItem* odohPlainText(::SECITEM_AllocItem( + nullptr, nullptr, 2 + aQuery.Length() + 2 + aPaddingLen)); + if (!odohPlainText) { + return false; + } + + mPlainQuery.reset(odohPlainText); + memset(mPlainQuery->data, 0, mPlainQuery->len); + NetworkEndian::writeUint16(&mPlainQuery->data[0], aQuery.Length()); + memcpy(&mPlainQuery->data[2], aQuery.BeginReading(), aQuery.Length()); + NetworkEndian::writeUint16(&mPlainQuery->data[2 + aQuery.Length()], + aPaddingLen); + + SECItem* chCt = nullptr; + rv = PK11_HPKE_Seal(mContext, aad.get(), mPlainQuery.get(), &chCt); + if (rv != SECSuccess) { + LOG(("ODoHDNSPacket::EncryptDNSQuery seal failed")); + return false; + } + + UniqueSECItem ct(chCt); + + aOut.mType = ODOH_QUERY; + aOut.mKeyId.AppendElements(aConfig.mConfigId); + aOut.mEncryptedMessage.AppendElements(Span(hpkeEnc->data, hpkeEnc->len)); + aOut.mEncryptedMessage.AppendElements(Span(ct->data, ct->len)); + + return true; +} + +nsresult ODoHDNSPacket::Decode( + nsCString& aHost, enum TrrType aType, nsCString& aCname, bool aAllowRFC1918, + DOHresp& aResp, TypeRecordResultType& aTypeResult, + nsClassHashtable<nsCStringHashKey, DOHresp>& aAdditionalRecords, + uint32_t& aTTL) { + // This function could be called multiple times when we are checking CNAME + // records, but we only need to decrypt the response once. + if (!mDecryptedResponseRange) { + if (!DecryptDNSResponse()) { + SetDNSPacketStatus(DNSPacketStatus::DecryptError); + return NS_ERROR_FAILURE; + } + + uint32_t index = 0; + uint16_t responseLength = get16bit(mResponse, index); + index += 2; + + if (mBodySize < (index + responseLength)) { + SetDNSPacketStatus(DNSPacketStatus::DecryptError); + return NS_ERROR_ILLEGAL_VALUE; + } + + DecryptedResponseRange range; + range.mStart = index; + range.mLength = responseLength; + + index += responseLength; + uint16_t paddingLen = get16bit(mResponse, index); + + if (static_cast<unsigned int>(4 + responseLength + paddingLen) != + mBodySize) { + SetDNSPacketStatus(DNSPacketStatus::DecryptError); + return NS_ERROR_ILLEGAL_VALUE; + } + + mDecryptedResponseRange.emplace(range); + } + + nsresult rv = DecodeInternal(aHost, aType, aCname, aAllowRFC1918, aResp, + aTypeResult, aAdditionalRecords, aTTL, + &mResponse[mDecryptedResponseRange->mStart], + mDecryptedResponseRange->mLength); + SetDNSPacketStatus(NS_SUCCEEDED(rv) ? DNSPacketStatus::Success + : DNSPacketStatus::DecodeError); + return rv; +} + +static bool CreateObliviousDoHMessage(const unsigned char* aData, + unsigned int aLength, + ObliviousDoHMessage& aOut) { + if (aLength < 5) { + return false; + } + + unsigned int index = 0; + aOut.mType = static_cast<ObliviousDoHMessageType>(aData[index++]); + + uint16_t keyIdLength = get16bit(aData, index); + index += 2; + if (aLength < (index + keyIdLength)) { + return false; + } + + aOut.mKeyId.AppendElements(Span(aData + index, keyIdLength)); + index += keyIdLength; + + uint16_t messageLen = get16bit(aData, index); + index += 2; + if (aLength < (index + messageLen)) { + return false; + } + + aOut.mEncryptedMessage.AppendElements(Span(aData + index, messageLen)); + return true; +} + +static SECStatus HKDFExtract(SECItem* aSalt, PK11SymKey* aIkm, + UniquePK11SymKey& aOutKey) { + CK_HKDF_PARAMS params = {0}; + SECItem paramsItem = {siBuffer, (unsigned char*)¶ms, sizeof(params)}; + + params.bExtract = CK_TRUE; + params.bExpand = CK_FALSE; + params.prfHashMechanism = CKM_SHA256; + params.ulSaltType = aSalt ? CKF_HKDF_SALT_DATA : CKF_HKDF_SALT_NULL; + params.pSalt = aSalt ? (CK_BYTE_PTR)aSalt->data : nullptr; + params.ulSaltLen = aSalt ? aSalt->len : 0; + + UniquePK11SymKey prk(PK11_Derive(aIkm, CKM_HKDF_DERIVE, ¶msItem, + CKM_HKDF_DERIVE, CKA_DERIVE, 0)); + if (!prk) { + return SECFailure; + } + + aOutKey.swap(prk); + return SECSuccess; +} + +static SECStatus HKDFExpand(PK11SymKey* aPrk, const SECItem* aInfo, int aLen, + bool aKey, UniquePK11SymKey& aOutKey) { + CK_HKDF_PARAMS params = {0}; + SECItem paramsItem = {siBuffer, (unsigned char*)¶ms, sizeof(params)}; + + params.bExtract = CK_FALSE; + params.bExpand = CK_TRUE; + params.prfHashMechanism = CKM_SHA256; + params.ulSaltType = CKF_HKDF_SALT_NULL; + params.pInfo = (CK_BYTE_PTR)aInfo->data; + params.ulInfoLen = aInfo->len; + CK_MECHANISM_TYPE deriveMech = CKM_HKDF_DERIVE; + CK_MECHANISM_TYPE keyMech = aKey ? CKM_AES_GCM : CKM_HKDF_DERIVE; + + UniquePK11SymKey derivedKey( + PK11_Derive(aPrk, deriveMech, ¶msItem, keyMech, CKA_DERIVE, aLen)); + if (!derivedKey) { + return SECFailure; + } + + aOutKey.swap(derivedKey); + return SECSuccess; +} + +/* + * def decrypt_response_body(context, Q_plain, R_encrypted, response_nonce): + * aead_key, aead_nonce = derive_secrets(context, Q_plain, response_nonce) + * aad = 0x02 || len(response_nonce) || response_nonce + * R_plain, error = Open(key, nonce, aad, R_encrypted) + * return R_plain, error + */ +bool ODoHDNSPacket::DecryptDNSResponse() { + ObliviousDoHMessage message; + if (!CreateObliviousDoHMessage(mResponse, mBodySize, message)) { + LOG(("ODoHDNSPacket::DecryptDNSResponse invalid response")); + return false; + } + + if (message.mType != ODOH_RESPONSE) { + return false; + } + + const unsigned int kResponseNonceLen = 16; + // KeyId is actually response_nonce + if (message.mKeyId.Length() != kResponseNonceLen) { + return false; + } + + // def derive_secrets(context, Q_plain, response_nonce): + // secret = context.Export("odoh response", Nk) + // salt = Q_plain || len(response_nonce) || response_nonce + // prk = Extract(salt, secret) + // key = Expand(odoh_prk, "odoh key", Nk) + // nonce = Expand(odoh_prk, "odoh nonce", Nn) + // return key, nonce + const SECItem kODoHResponsetInfoItem = { + siBuffer, (unsigned char*)kODoHResponse, + static_cast<unsigned int>(strlen(kODoHResponse))}; + const unsigned int kAes128GcmKeyLen = 16; + const unsigned int kAes128GcmNonceLen = 12; + PK11SymKey* tmp = nullptr; + SECStatus rv = PK11_HPKE_ExportSecret(mContext, &kODoHResponsetInfoItem, + kAes128GcmKeyLen, &tmp); + if (rv != SECSuccess) { + LOG(("ODoHDNSPacket::DecryptDNSResponse export secret failed")); + return false; + } + UniquePK11SymKey odohSecret(tmp); + + SECItem* salt(::SECITEM_AllocItem(nullptr, nullptr, + mPlainQuery->len + 2 + kResponseNonceLen)); + memcpy(salt->data, mPlainQuery->data, mPlainQuery->len); + NetworkEndian::writeUint16(&salt->data[mPlainQuery->len], kResponseNonceLen); + memcpy(salt->data + mPlainQuery->len + 2, message.mKeyId.Elements(), + kResponseNonceLen); + UniqueSECItem st(salt); + UniquePK11SymKey odohPrk; + rv = HKDFExtract(salt, odohSecret.get(), odohPrk); + if (rv != SECSuccess) { + LOG(("ODoHDNSPacket::DecryptDNSResponse extract failed")); + return false; + } + + SECItem keyInfoItem = {siBuffer, (unsigned char*)&kODoHKey[0], + static_cast<unsigned int>(strlen(kODoHKey))}; + UniquePK11SymKey key; + rv = HKDFExpand(odohPrk.get(), &keyInfoItem, kAes128GcmKeyLen, true, key); + if (rv != SECSuccess) { + LOG(("ODoHDNSPacket::DecryptDNSResponse expand key failed")); + return false; + } + + SECItem nonceInfoItem = {siBuffer, (unsigned char*)&kODoHNonce[0], + static_cast<unsigned int>(strlen(kODoHNonce))}; + UniquePK11SymKey nonce; + rv = HKDFExpand(odohPrk.get(), &nonceInfoItem, kAes128GcmNonceLen, false, + nonce); + if (rv != SECSuccess) { + LOG(("ODoHDNSPacket::DecryptDNSResponse expand nonce failed")); + return false; + } + + rv = PK11_ExtractKeyValue(nonce.get()); + if (rv != SECSuccess) { + return false; + } + + SECItem* derivedItem = PK11_GetKeyData(nonce.get()); + if (!derivedItem) { + return false; + } + + // aad = 0x02 || len(response_nonce) || response_nonce + SECItem* aadItem( + ::SECITEM_AllocItem(nullptr, nullptr, 1 + 2 + kResponseNonceLen)); + aadItem->data[0] = ODOH_RESPONSE; + NetworkEndian::writeUint16(&aadItem->data[1], kResponseNonceLen); + memcpy(&aadItem->data[3], message.mKeyId.Elements(), kResponseNonceLen); + UniqueSECItem aad(aadItem); + + SECItem paramItem; + CK_GCM_PARAMS param; + param.pIv = derivedItem->data; + param.ulIvLen = derivedItem->len; + param.ulIvBits = param.ulIvLen * 8; + param.ulTagBits = 16 * 8; + param.pAAD = (CK_BYTE_PTR)aad->data; + param.ulAADLen = aad->len; + + paramItem.type = siBuffer; + paramItem.data = (unsigned char*)(¶m); + paramItem.len = sizeof(CK_GCM_PARAMS); + + memset(mResponse, 0, mBodySize); + rv = PK11_Decrypt(key.get(), CKM_AES_GCM, ¶mItem, mResponse, &mBodySize, + MAX_SIZE, message.mEncryptedMessage.Elements(), + message.mEncryptedMessage.Length()); + if (rv != SECSuccess) { + LOG(("ODoHDNSPacket::DecryptDNSResponse decrypt failed %d", + PORT_GetError())); + return false; + } + + return true; +} + +} // namespace net +} // namespace mozilla |