From 26a029d407be480d791972afb5975cf62c9360a6 Mon Sep 17 00:00:00 2001 From: Daniel Baumann Date: Fri, 19 Apr 2024 02:47:55 +0200 Subject: Adding upstream version 124.0.1. Signed-off-by: Daniel Baumann --- security/nss/gtests/ssl_gtest/tls_filter.cc | 1293 +++++++++++++++++++++++++++ 1 file changed, 1293 insertions(+) create mode 100644 security/nss/gtests/ssl_gtest/tls_filter.cc (limited to 'security/nss/gtests/ssl_gtest/tls_filter.cc') diff --git a/security/nss/gtests/ssl_gtest/tls_filter.cc b/security/nss/gtests/ssl_gtest/tls_filter.cc new file mode 100644 index 0000000000..ab52a07e84 --- /dev/null +++ b/security/nss/gtests/ssl_gtest/tls_filter.cc @@ -0,0 +1,1293 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set ts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this file, + * You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#include "tls_filter.h" +#include "sslproto.h" + +extern "C" { +// This is not something that should make you happy. +#include "libssl_internals.h" +} + +#include +#include +#include "gtest_utils.h" +#include "tls_agent.h" +#include "tls_filter.h" +#include "tls_parser.h" +#include "tls_protect.h" + +namespace nss_test { + +void TlsVersioned::WriteStream(std::ostream& stream) const { + stream << (is_dtls() ? "DTLS " : "TLS "); + switch (version()) { + case 0: + stream << "(no version)"; + break; + case SSL_LIBRARY_VERSION_TLS_1_0: + stream << "1.0"; + break; + case SSL_LIBRARY_VERSION_TLS_1_1: + stream << (is_dtls() ? "1.0" : "1.1"); + break; + case SSL_LIBRARY_VERSION_TLS_1_2: + stream << "1.2"; + break; + case SSL_LIBRARY_VERSION_TLS_1_3: + stream << "1.3"; + break; + default: + stream << "Invalid version: " << version(); + break; + } +} + +TlsRecordFilter::TlsRecordFilter(const std::shared_ptr& a) + : agent_(a) { + cipher_specs_.emplace_back(a->variant() == ssl_variant_datagram, 0); +} + +void TlsRecordFilter::EnableDecryption() { + EXPECT_EQ(SECSuccess, + SSL_SecretCallback(agent()->ssl_fd(), SecretCallback, this)); + decrypting_ = true; +} + +void TlsRecordFilter::SecretCallback(PRFileDesc* fd, PRUint16 epoch, + SSLSecretDirection dir, PK11SymKey* secret, + void* arg) { + TlsRecordFilter* self = static_cast(arg); + if (g_ssl_gtest_verbose) { + std::cerr << self->agent()->role_str() << ": " << dir + << " secret changed for epoch " << epoch << std::endl; + } + + if (dir == ssl_secret_read) { + return; + } + + for (auto& spec : self->cipher_specs_) { + ASSERT_NE(spec.epoch(), epoch) << "duplicate spec for epoch " << epoch; + } + + SSLPreliminaryChannelInfo preinfo; + EXPECT_EQ(SECSuccess, + SSL_GetPreliminaryChannelInfo(self->agent()->ssl_fd(), &preinfo, + sizeof(preinfo))); + EXPECT_EQ(sizeof(preinfo), preinfo.length); + + // Check the version. + if (preinfo.valuesSet & ssl_preinfo_version) { + EXPECT_EQ(SSL_LIBRARY_VERSION_TLS_1_3, preinfo.protocolVersion); + } else { + EXPECT_EQ(1U, epoch); + } + + uint16_t suite; + if (epoch == 1) { + // 0-RTT + EXPECT_TRUE(preinfo.valuesSet & ssl_preinfo_0rtt_cipher_suite); + suite = preinfo.zeroRttCipherSuite; + } else { + EXPECT_TRUE(preinfo.valuesSet & ssl_preinfo_cipher_suite); + suite = preinfo.cipherSuite; + } + + SSLCipherSuiteInfo cipherinfo; + EXPECT_EQ(SECSuccess, + SSL_GetCipherSuiteInfo(suite, &cipherinfo, sizeof(cipherinfo))); + EXPECT_EQ(sizeof(cipherinfo), cipherinfo.length); + + self->cipher_specs_.emplace_back(self->is_dtls_agent(), epoch); + EXPECT_TRUE(self->cipher_specs_.back().SetKeys(&cipherinfo, secret)); +} + +bool TlsRecordFilter::is_dtls_agent() const { + return agent()->variant() == ssl_variant_datagram; +} + +bool TlsRecordFilter::is_dtls13() const { + if (!is_dtls_agent()) { + return false; + } + if (agent()->state() == TlsAgent::STATE_CONNECTED) { + return agent()->version() >= SSL_LIBRARY_VERSION_TLS_1_3; + } + SSLPreliminaryChannelInfo info; + EXPECT_EQ(SECSuccess, SSL_GetPreliminaryChannelInfo(agent()->ssl_fd(), &info, + sizeof(info))); + return (info.protocolVersion >= SSL_LIBRARY_VERSION_TLS_1_3) || + info.canSendEarlyData; +} + +bool TlsRecordFilter::is_dtls13_ciphertext(uint8_t ct) const { + return is_dtls13() && (ct & kCtDtlsCiphertextMask) == kCtDtlsCiphertext; +} + +// Gets the cipher spec that matches the specified epoch. +TlsCipherSpec& TlsRecordFilter::spec(uint16_t write_epoch) { + for (auto& sp : cipher_specs_) { + if (sp.epoch() == write_epoch) { + return sp; + } + } + + // If we aren't decrypting, provide a cipher spec that does nothing other than + // count sequence numbers. + EXPECT_FALSE(decrypting_) << "No spec available for epoch " << write_epoch; + ; + cipher_specs_.emplace_back(is_dtls_agent(), write_epoch); + return cipher_specs_.back(); +} + +PacketFilter::Action TlsRecordFilter::Filter(const DataBuffer& input, + DataBuffer* output) { + // Disable during shutdown. + if (!agent()) { + return KEEP; + } + + bool changed = false; + size_t offset = 0U; + + output->Allocate(input.len()); + TlsParser parser(input); + + // This uses the current write spec for the purposes of parsing the epoch and + // sequence number from the header. This might be wrong because we can + // receive records from older specs, but guessing is good enough: + // - In DTLS, parsing the sequence number corrects any errors. + // - In TLS, we don't use the sequence number unless decrypting, where we use + // trial decryption to get the right epoch. + uint16_t write_epoch = 0; + SECStatus rv = SSL_GetCurrentEpoch(agent()->ssl_fd(), nullptr, &write_epoch); + if (rv != SECSuccess) { + ADD_FAILURE() << "unable to read epoch"; + return KEEP; + } + uint64_t guess_seqno = static_cast(write_epoch) << 48; + + while (parser.remaining()) { + TlsRecordHeader header; + DataBuffer record; + if (!header.Parse(is_dtls13(), guess_seqno, &parser, &record)) { + ADD_FAILURE() << "not a valid record"; + return KEEP; + } + + if (FilterRecord(header, record, &offset, output) != KEEP) { + changed = true; + } else { + offset = header.Write(output, offset, record); + } + } + output->Truncate(offset); + + // Record how many packets we actually touched. + if (changed) { + ++count_; + return (offset == 0) ? DROP : CHANGE; + } + + return KEEP; +} + +PacketFilter::Action TlsRecordFilter::FilterRecord( + const TlsRecordHeader& header, const DataBuffer& record, size_t* offset, + DataBuffer* output) { + DataBuffer filtered; + uint8_t inner_content_type; + DataBuffer plaintext; + uint16_t protection_epoch = 0; + TlsRecordHeader out_header(header); + + if (!Unprotect(header, record, &protection_epoch, &inner_content_type, + &plaintext, &out_header)) { + std::cerr << agent()->role_str() << ": unprotect failed: " << header << ":" + << record << std::endl; + return KEEP; + } + + auto& protection_spec = spec(protection_epoch); + TlsRecordHeader real_header(out_header.variant(), out_header.version(), + inner_content_type, out_header.sequence_number()); + + PacketFilter::Action action = FilterRecord(real_header, plaintext, &filtered); + // In stream mode, even if something doesn't change we need to re-encrypt if + // previous packets were dropped. + if (action == KEEP) { + if (out_header.is_dtls() || !protection_spec.record_dropped()) { + // Count every outgoing packet. + protection_spec.RecordProtected(); + return KEEP; + } + filtered = plaintext; + } + + if (action == DROP) { + std::cerr << "record drop: " << out_header << ":" << record << std::endl; + protection_spec.RecordDropped(); + return DROP; + } + + EXPECT_GT(0x10000U, filtered.len()); + if (action != KEEP) { + std::cerr << "record old: " << plaintext << std::endl; + std::cerr << "record new: " << filtered << std::endl; + } + + uint64_t seq_num = protection_spec.next_out_seqno(); + if (!decrypting_ && out_header.is_dtls()) { + // Copy over the epoch, which isn't tracked when not decrypting. + seq_num |= out_header.sequence_number() & (0xffffULL << 48); + } + out_header.sequence_number(seq_num); + + DataBuffer ciphertext; + bool rv = Protect(protection_spec, out_header, inner_content_type, filtered, + &ciphertext, &out_header); + if (!rv) { + return KEEP; + } + *offset = out_header.Write(output, *offset, ciphertext); + return CHANGE; +} + +size_t TlsRecordHeader::header_length() const { + // If we have a header, return it's length. + if (header_.len()) { + return header_.len(); + } + + // Otherwise make a dummy header and return the length. + DataBuffer buf; + return WriteHeader(&buf, 0, 0); +} + +bool TlsRecordHeader::MaskSequenceNumber() { + return MaskSequenceNumber(sn_mask()); +} + +bool TlsRecordHeader::MaskSequenceNumber(const DataBuffer& mask_buf) { + if (mask_buf.empty()) { + return false; + } + + DataBuffer mask; + if (is_dtls13_ciphertext()) { + uint64_t seqno = sequence_number(); + uint8_t len = content_type() & kCtDtlsCiphertext16bSeqno ? 2 : 1; + uint16_t seqno_bitmask = (1 << len * 8) - 1; + DataBuffer val; + if (val.Write(0, seqno & seqno_bitmask, len) != len) { + return false; + } + +#ifdef UNSAFE_FUZZER_MODE + // Use a null mask. + mask.Allocate(mask_buf.len()); +#endif + mask.Append(mask_buf); + val.data()[0] ^= mask.data()[0]; + if (len == 2 && mask.len() > 1) { + val.data()[1] ^= mask.data()[1]; + } + uint32_t tmp; + if (!val.Read(0, len, &tmp)) { + return false; + } + + seqno = (seqno & ~seqno_bitmask) | tmp; + seqno_is_masked_ = !seqno_is_masked_; + if (!seqno_is_masked_) { + seqno = ParseSequenceNumber(guess_seqno_, seqno, len * 8, 2); + } + sequence_number_ = seqno; + + // Now update the header bytes + if (header_.len() > 1) { + header_.data()[1] ^= mask.data()[0]; + if ((content_type() & kCtDtlsCiphertext16bSeqno) && header().len() > 2) { + header_.data()[2] ^= mask.data()[1]; + } + } + } + + sn_mask_ = mask; + return true; +} + +uint64_t TlsRecordHeader::RecoverSequenceNumber(uint64_t guess_seqno, + uint32_t partial, + size_t partial_bits) { + EXPECT_GE(32U, partial_bits); + uint64_t mask = (1ULL << partial_bits) - 1; + // First we determine the highest possible value. This is half the + // expressible range above the expected value (|guess_seqno|), less 1. + // + // We subtract the extra 1 from the cap so that when given a choice between + // the equidistant expected+N and expected-N we want to chose the lower. With + // 0-RTT, we sometimes have to recover an epoch of 1 when we expect an epoch + // of 3 and with 2 partial bits, the alternative result of 5 is wrong. + uint64_t cap = guess_seqno + (1ULL << (partial_bits - 1)) - 1; + // Add the partial piece in. e.g., xxxx789a and 1234 becomes xxxx1234. + uint64_t seq_no = (cap & ~mask) | partial; + // If the partial value is higher than the same partial piece from the cap, + // then the real value has to be lower. e.g., xxxx1234 can't become xxxx5678. + if (partial > (cap & mask) && (seq_no >= (1ULL << partial_bits))) { + seq_no -= 1ULL << partial_bits; + } + return seq_no; +} + +// Determine the full epoch and sequence number from an expected and raw value. +// The expected, raw, and output values are packed as they are in DTLS 1.2 and +// earlier: with 16 bits of epoch and 48 bits of sequence number. The raw value +// is packed this way (even before recovery) so that we don't need to track a +// moving value between two calls (one to recover the epoch, and one after +// unmasking to recover the sequence number). +uint64_t TlsRecordHeader::ParseSequenceNumber(uint64_t expected, uint64_t raw, + size_t seq_no_bits, + size_t epoch_bits) { + uint64_t epoch_mask = (1ULL << epoch_bits) - 1; + uint64_t ep = RecoverSequenceNumber(expected >> 48, (raw >> 48) & epoch_mask, + epoch_bits); + if (ep > (expected >> 48)) { + // If the epoch has changed, reset the expected sequence number. + expected = 0; + } else { + // Otherwise, retain just the sequence number part. + expected &= (1ULL << 48) - 1; + } + uint64_t seq_no_mask = (1ULL << seq_no_bits) - 1; + uint64_t seq_no = (raw & seq_no_mask); + if (!seqno_is_masked_) { + seq_no = RecoverSequenceNumber(expected, seq_no, seq_no_bits); + } + + return (ep << 48) | seq_no; +} + +bool TlsRecordHeader::Parse(bool is_dtls13, uint64_t seqno, TlsParser* parser, + DataBuffer* body) { + auto mark = parser->consumed(); + + if (!parser->Read(&content_type_)) { + return false; + } + + if (is_dtls13) { + variant_ = ssl_variant_datagram; + version_ = SSL_LIBRARY_VERSION_TLS_1_3; + +#ifndef UNSAFE_FUZZER_MODE + // Deal with the DTLSCipherText header. + if (is_dtls13_ciphertext()) { + uint8_t seq_no_bytes = + (content_type_ & kCtDtlsCiphertext16bSeqno) ? 2 : 1; + uint32_t tmp; + + if (!parser->Read(&tmp, seq_no_bytes)) { + return false; + } + + // Store the guess if masked. If and when seqno_bytesenceNumber is called, + // the value will be unmasked and recovered. This assumes we only call + // Parse() on headers containing masked values. + seqno_is_masked_ = true; + guess_seqno_ = seqno; + uint64_t ep = content_type_ & 0x03; + sequence_number_ = (ep << 48) | tmp; + + // Recover the full epoch. Note the sequence number portion holds the + // masked value until a call to Mask() reveals it (as indicated by + // |seqno_is_masked_|). + sequence_number_ = + ParseSequenceNumber(seqno, sequence_number_, seq_no_bytes * 8, 2); + + uint32_t len_bytes = + (content_type_ & kCtDtlsCiphertextLengthPresent) ? 2 : 0; + if (len_bytes) { + if (!parser->Read(&tmp, 2)) { + return false; + } + } + + if (!parser->ReadFromMark(&header_, parser->consumed() - mark, mark)) { + return false; + } + + return len_bytes ? parser->Read(body, tmp) + : parser->Read(body, parser->remaining()); + } + + // The full DTLSPlainText header can only be used for a few types. + EXPECT_TRUE(content_type_ == ssl_ct_alert || + content_type_ == ssl_ct_handshake || + content_type_ == ssl_ct_ack); +#endif + } + + uint32_t ver; + if (!parser->Read(&ver, 2)) { + return false; + } + if (!is_dtls13) { + variant_ = IsDtls(ver) ? ssl_variant_datagram : ssl_variant_stream; + } + version_ = NormalizeTlsVersion(ver); + + if (is_dtls()) { + // If this is DTLS, read the sequence number. + uint32_t tmp; + if (!parser->Read(&tmp, 4)) { + return false; + } + sequence_number_ = static_cast(tmp) << 32; + if (!parser->Read(&tmp, 4)) { + return false; + } + sequence_number_ |= static_cast(tmp); + } else { + sequence_number_ = seqno; + } + if (!parser->ReadFromMark(&header_, parser->consumed() + 2 - mark, mark)) { + return false; + } + return parser->ReadVariable(body, 2); +} + +size_t TlsRecordHeader::WriteHeader(DataBuffer* buffer, size_t offset, + size_t body_len) const { + if (is_dtls13_ciphertext()) { + uint8_t seq_no_bytes = (content_type_ & kCtDtlsCiphertext16bSeqno) ? 2 : 1; + // application_data records in TLS 1.3 have a different header format. + uint32_t e = (sequence_number_ >> 48) & 0x3; + uint32_t seqno = sequence_number_ & ((1ULL << seq_no_bytes * 8) - 1); + uint8_t new_content_type_ = content_type_ | e; + offset = buffer->Write(offset, new_content_type_, 1); + offset = buffer->Write(offset, seqno, seq_no_bytes); + + if (content_type_ & kCtDtlsCiphertextLengthPresent) { + offset = buffer->Write(offset, body_len, 2); + } + } else { + offset = buffer->Write(offset, content_type_, 1); + uint16_t v = is_dtls() ? TlsVersionToDtlsVersion(version_) : version_; + offset = buffer->Write(offset, v, 2); + if (is_dtls()) { + // write epoch (2 octet), and seqnum (6 octet) + offset = buffer->Write(offset, sequence_number_ >> 32, 4); + offset = buffer->Write(offset, sequence_number_ & 0xffffffff, 4); + } + offset = buffer->Write(offset, body_len, 2); + } + + return offset; +} + +size_t TlsRecordHeader::Write(DataBuffer* buffer, size_t offset, + const DataBuffer& body) const { + offset = WriteHeader(buffer, offset, body.len()); + offset = buffer->Write(offset, body); + return offset; +} + +bool TlsRecordFilter::Unprotect(const TlsRecordHeader& header, + const DataBuffer& ciphertext, + uint16_t* protection_epoch, + uint8_t* inner_content_type, + DataBuffer* plaintext, + TlsRecordHeader* out_header) { + if (!decrypting_ || !header.is_protected()) { + // Maintain the epoch and sequence number for plaintext records. + uint16_t ep = 0; + if (is_dtls_agent()) { + ep = static_cast(header.sequence_number() >> 48); + } + spec(ep).RecordUnprotected(header.sequence_number()); + *protection_epoch = ep; + *inner_content_type = header.content_type(); + *plaintext = ciphertext; + return true; + } + + uint16_t ep = 0; + if (is_dtls_agent()) { + ep = static_cast(header.sequence_number() >> 48); + if (!spec(ep).Unprotect(header, ciphertext, plaintext, out_header)) { + return false; + } + } else { + // In TLS, records aren't clearly labelled with their epoch, and we + // can't just use the newest keys because the same flight of messages can + // contain multiple epochs. So... trial decrypt! + for (size_t i = cipher_specs_.size() - 1; i > 0; --i) { + if (cipher_specs_[i].Unprotect(header, ciphertext, plaintext, + out_header)) { + ep = cipher_specs_[i].epoch(); + break; + } + } + if (!ep) { + return false; + } + } + + size_t len = plaintext->len(); + while (len > 0 && !plaintext->data()[len - 1]) { + --len; + } + if (!len) { + // Bogus padding. + return false; + } + + *protection_epoch = ep; + *inner_content_type = plaintext->data()[len - 1]; + plaintext->Truncate(len - 1); + if (g_ssl_gtest_verbose) { + std::cerr << agent()->role_str() << ": unprotect: epoch=" << ep + << " seq=" << std::hex << header.sequence_number() << std::dec + << " " << *plaintext << std::endl; + } + + return true; +} + +bool TlsRecordFilter::Protect(TlsCipherSpec& protection_spec, + const TlsRecordHeader& header, + uint8_t inner_content_type, + const DataBuffer& plaintext, + DataBuffer* ciphertext, + TlsRecordHeader* out_header, size_t padding) { + if (!protection_spec.is_protected()) { + // Not protected, just keep the sequence numbers updated. + protection_spec.RecordProtected(); + *ciphertext = plaintext; + return true; + } + + DataBuffer padded; + padded.Allocate(plaintext.len() + 1 + padding); + size_t offset = padded.Write(0, plaintext.data(), plaintext.len()); + padded.Write(offset, inner_content_type, 1); + + bool ok = protection_spec.Protect(header, padded, ciphertext, out_header); + if (!ok) { + ADD_FAILURE() << "protect fail"; + } else if (g_ssl_gtest_verbose) { + std::cerr << agent()->role_str() + << ": protect: epoch=" << protection_spec.epoch() + << " seq=" << std::hex << header.sequence_number() << std::dec + << " " << *ciphertext << std::endl; + } + return ok; +} + +bool IsHelloRetry(const DataBuffer& body) { + static const uint8_t ssl_hello_retry_random[] = { + 0xCF, 0x21, 0xAD, 0x74, 0xE5, 0x9A, 0x61, 0x11, 0xBE, 0x1D, 0x8C, + 0x02, 0x1E, 0x65, 0xB8, 0x91, 0xC2, 0xA2, 0x11, 0x16, 0x7A, 0xBB, + 0x8C, 0x5E, 0x07, 0x9E, 0x09, 0xE2, 0xC8, 0xA8, 0x33, 0x9C}; + return memcmp(body.data() + 2, ssl_hello_retry_random, + sizeof(ssl_hello_retry_random)) == 0; +} + +bool TlsHandshakeFilter::IsFilteredType(const HandshakeHeader& header, + const DataBuffer& body) { + if (handshake_types_.empty()) { + return true; + } + + uint8_t type = header.handshake_type(); + if (type == kTlsHandshakeServerHello) { + if (IsHelloRetry(body)) { + type = kTlsHandshakeHelloRetryRequest; + } + } + return handshake_types_.count(type) > 0U; +} + +PacketFilter::Action TlsHandshakeFilter::FilterRecord( + const TlsRecordHeader& record_header, const DataBuffer& input, + DataBuffer* output) { + // Check that the first byte is as requested. + if (record_header.content_type() != ssl_ct_handshake) { + return KEEP; + } + + bool changed = false; + size_t offset = 0U; + output->Allocate(input.len()); // Preallocate a little. + + TlsParser parser(input); + while (parser.remaining()) { + HandshakeHeader header; + DataBuffer handshake; + bool complete = false; + if (!header.Parse(&parser, record_header, preceding_fragment_, &handshake, + &complete)) { + return KEEP; + } + + if (!complete) { + EXPECT_TRUE(record_header.is_dtls()); + // Save the fragment and drop it from this record. Fragments are + // coalesced with the last fragment of the handshake message. + changed = true; + preceding_fragment_.Assign(handshake); + continue; + } + preceding_fragment_.Truncate(0); + + DataBuffer filtered; + PacketFilter::Action action; + if (!IsFilteredType(header, handshake)) { + action = KEEP; + } else { + action = FilterHandshake(header, handshake, &filtered); + } + if (action == DROP) { + changed = true; + std::cerr << "handshake drop: " << handshake << std::endl; + continue; + } + + const DataBuffer* source = &handshake; + if (action == CHANGE) { + EXPECT_GT(0x1000000U, filtered.len()); + changed = true; + std::cerr << "handshake old: " << handshake << std::endl; + std::cerr << "handshake new: " << filtered << std::endl; + source = &filtered; + } else if (preceding_fragment_.len()) { + changed = true; + } + + offset = header.Write(output, offset, *source); + } + output->Truncate(offset); + return changed ? (offset ? CHANGE : DROP) : KEEP; +} + +bool TlsHandshakeFilter::HandshakeHeader::ReadLength( + TlsParser* parser, const TlsRecordHeader& header, uint32_t expected_offset, + uint32_t* length, bool* last_fragment) { + uint32_t message_length; + if (!parser->Read(&message_length, 3)) { + return false; // malformed + } + + if (!header.is_dtls()) { + *last_fragment = true; + *length = message_length; + return true; // nothing left to do + } + + // Read and check DTLS parameters + uint32_t message_seq_tmp; + if (!parser->Read(&message_seq_tmp, 2)) { // sequence number + return false; + } + message_seq_ = message_seq_tmp; + + uint32_t offset = 0; + if (!parser->Read(&offset, 3)) { + return false; + } + // We only parse if the fragments are all complete and in order. + if (offset != expected_offset) { + EXPECT_NE(0U, header.epoch()) + << "Received out of order handshake fragment for epoch 0"; + return false; + } + + // For DTLS, we return the length of just this fragment. + if (!parser->Read(length, 3)) { + return false; + } + + // It's a fragment if the entire message is longer than what we have. + *last_fragment = message_length == (*length + offset); + return true; +} + +bool TlsHandshakeFilter::HandshakeHeader::Parse( + TlsParser* parser, const TlsRecordHeader& record_header, + const DataBuffer& preceding_fragment, DataBuffer* body, bool* complete) { + *complete = false; + + variant_ = record_header.variant(); + version_ = record_header.version(); + if (!parser->Read(&handshake_type_)) { + return false; // malformed + } + + uint32_t length; + if (!ReadLength(parser, record_header, preceding_fragment.len(), &length, + complete)) { + return false; + } + + if (!parser->Read(body, length)) { + return false; + } + if (preceding_fragment.len()) { + body->Splice(preceding_fragment, 0); + } + return true; +} + +size_t TlsHandshakeFilter::HandshakeHeader::WriteFragment( + DataBuffer* buffer, size_t offset, const DataBuffer& body, + size_t fragment_offset, size_t fragment_length) const { + EXPECT_TRUE(is_dtls()); + EXPECT_GE(body.len(), fragment_offset + fragment_length); + offset = buffer->Write(offset, handshake_type(), 1); + offset = buffer->Write(offset, body.len(), 3); + offset = buffer->Write(offset, message_seq_, 2); + offset = buffer->Write(offset, fragment_offset, 3); + offset = buffer->Write(offset, fragment_length, 3); + offset = + buffer->Write(offset, body.data() + fragment_offset, fragment_length); + return offset; +} + +size_t TlsHandshakeFilter::HandshakeHeader::Write( + DataBuffer* buffer, size_t offset, const DataBuffer& body) const { + if (is_dtls()) { + return WriteFragment(buffer, offset, body, 0U, body.len()); + } + offset = buffer->Write(offset, handshake_type(), 1); + offset = buffer->Write(offset, body.len(), 3); + offset = buffer->Write(offset, body); + return offset; +} + +PacketFilter::Action TlsHandshakeRecorder::FilterHandshake( + const HandshakeHeader& header, const DataBuffer& input, + DataBuffer* output) { + // Only do this once. + if (buffer_.len()) { + return KEEP; + } + + buffer_ = input; + return KEEP; +} + +PacketFilter::Action TlsInspectorReplaceHandshakeMessage::FilterHandshake( + const HandshakeHeader& header, const DataBuffer& input, + DataBuffer* output) { + *output = buffer_; + return CHANGE; +} + +PacketFilter::Action TlsRecordRecorder::FilterRecord( + const TlsRecordHeader& header, const DataBuffer& input, + DataBuffer* output) { + if (!filter_ || (header.content_type() == ct_)) { + records_.push_back({header, input}); + } + return KEEP; +} + +PacketFilter::Action TlsConversationRecorder::FilterRecord( + const TlsRecordHeader& header, const DataBuffer& input, + DataBuffer* output) { + buffer_.Append(input); + return KEEP; +} + +PacketFilter::Action TlsHeaderRecorder::FilterRecord(const TlsRecordHeader& hdr, + const DataBuffer& input, + DataBuffer* output) { + headers_.push_back(hdr); + return KEEP; +} + +const TlsRecordHeader* TlsHeaderRecorder::header(size_t index) { + if (index > headers_.size() + 1) { + return nullptr; + } + return &headers_[index]; +} + +PacketFilter::Action ChainedPacketFilter::Filter(const DataBuffer& input, + DataBuffer* output) { + DataBuffer in(input); + bool changed = false; + for (auto it = filters_.begin(); it != filters_.end(); ++it) { + PacketFilter::Action action = (*it)->Process(in, output); + if (action == DROP) { + return DROP; + } + + if (action == CHANGE) { + in = *output; + changed = true; + } + } + return changed ? CHANGE : KEEP; +} + +bool FindClientHelloExtensions(TlsParser* parser, const TlsVersioned& header) { + if (!parser->Skip(2 + 32)) { // version + random + return false; + } + if (!parser->SkipVariable(1)) { // session ID + return false; + } + if (header.is_dtls() && !parser->SkipVariable(1)) { // DTLS cookie + return false; + } + if (!parser->SkipVariable(2)) { // cipher suites + return false; + } + if (!parser->SkipVariable(1)) { // compression methods + return false; + } + return true; +} + +bool FindServerHelloExtensions(TlsParser* parser, const TlsVersioned& header) { + uint32_t vtmp; + if (!parser->Read(&vtmp, 2)) { + return false; + } + uint16_t version = static_cast(vtmp); + if (!parser->Skip(32)) { // random + return false; + } + if (NormalizeTlsVersion(version) <= SSL_LIBRARY_VERSION_TLS_1_2) { + if (!parser->SkipVariable(1)) { // session ID + return false; + } + } + if (!parser->Skip(2)) { // cipher suite + return false; + } + if (NormalizeTlsVersion(version) <= SSL_LIBRARY_VERSION_TLS_1_2) { + if (!parser->Skip(1)) { // compression method + return false; + } + } + return true; +} + +bool FindEncryptedExtensions(TlsParser* parser, const TlsVersioned& header) { + return true; +} + +static bool FindCertReqExtensions(TlsParser* parser, + const TlsVersioned& header) { + if (!parser->SkipVariable(1)) { // request context + return false; + } + return true; +} + +// Only look at the EE cert for this one. +static bool FindCertificateExtensions(TlsParser* parser, + const TlsVersioned& header) { + if (!parser->SkipVariable(1)) { // request context + return false; + } + if (!parser->Skip(3)) { // length of certificate list + return false; + } + if (!parser->SkipVariable(3)) { // ASN1Cert + return false; + } + return true; +} + +static bool FindNewSessionTicketExtensions(TlsParser* parser, + const TlsVersioned& header) { + if (!parser->Skip(8)) { // lifetime, age add + return false; + } + if (!parser->SkipVariable(1)) { // ticket_nonce + return false; + } + if (!parser->SkipVariable(2)) { // ticket + return false; + } + return true; +} + +static const std::map kExtensionFinders = { + {kTlsHandshakeClientHello, FindClientHelloExtensions}, + {kTlsHandshakeServerHello, FindServerHelloExtensions}, + {kTlsHandshakeEncryptedExtensions, FindEncryptedExtensions}, + {kTlsHandshakeCertificateRequest, FindCertReqExtensions}, + {kTlsHandshakeCertificate, FindCertificateExtensions}, + {kTlsHandshakeNewSessionTicket, FindNewSessionTicketExtensions}}; + +bool TlsExtensionFilter::FindExtensions(TlsParser* parser, + const HandshakeHeader& header) { + auto it = kExtensionFinders.find(header.handshake_type()); + if (it == kExtensionFinders.end()) { + return false; + } + return (it->second)(parser, header); +} + +PacketFilter::Action TlsExtensionFilter::FilterHandshake( + const HandshakeHeader& header, const DataBuffer& input, + DataBuffer* output) { + TlsParser parser(input); + if (!FindExtensions(&parser, header)) { + return KEEP; + } + return FilterExtensions(&parser, input, output); +} + +PacketFilter::Action TlsExtensionFilter::FilterExtensions( + TlsParser* parser, const DataBuffer& input, DataBuffer* output) { + size_t length_offset = parser->consumed(); + uint32_t all_extensions; + if (!parser->Read(&all_extensions, 2)) { + return KEEP; // no extensions, odd but OK + } + if (all_extensions != parser->remaining()) { + return KEEP; // malformed + } + + bool changed = false; + + // Write out the start of the message. + output->Allocate(input.len()); + size_t offset = output->Write(0, input.data(), parser->consumed()); + + while (parser->remaining()) { + uint32_t extension_type; + if (!parser->Read(&extension_type, 2)) { + return KEEP; // malformed + } + + DataBuffer extension; + if (!parser->ReadVariable(&extension, 2)) { + return KEEP; // malformed + } + + DataBuffer filtered; + PacketFilter::Action action = + FilterExtension(extension_type, extension, &filtered); + if (action == DROP) { + changed = true; + std::cerr << "extension drop: " << extension << std::endl; + continue; + } + + const DataBuffer* source = &extension; + if (action == CHANGE) { + EXPECT_GT(0x10000U, filtered.len()); + changed = true; + std::cerr << "extension old: " << extension << std::endl; + std::cerr << "extension new: " << filtered << std::endl; + source = &filtered; + } + + // Write out extension. + offset = output->Write(offset, extension_type, 2); + offset = output->Write(offset, source->len(), 2); + if (source->len() > 0) { + offset = output->Write(offset, *source); + } + } + output->Truncate(offset); + + if (changed) { + size_t newlen = output->len() - length_offset - 2; + EXPECT_GT(0x10000U, newlen); + if (newlen >= 0x10000) { + return KEEP; // bad: size increased too much + } + output->Write(length_offset, newlen, 2); + return CHANGE; + } + return KEEP; +} + +PacketFilter::Action TlsExtensionOrderCapture::FilterExtension( + uint16_t extension_type, const DataBuffer& input, DataBuffer* output) { + order.push_back(extension_type); + return KEEP; +} + +PacketFilter::Action TlsExtensionCapture::FilterExtension( + uint16_t extension_type, const DataBuffer& input, DataBuffer* output) { + if (extension_type == extension_ && (last_ || !captured_)) { + data_.Assign(input); + captured_ = true; + } + return KEEP; +} + +PacketFilter::Action TlsExtensionReplacer::FilterExtension( + uint16_t extension_type, const DataBuffer& input, DataBuffer* output) { + if (extension_type != extension_) { + return KEEP; + } + + *output = data_; + return CHANGE; +} + +PacketFilter::Action TlsExtensionResizer::FilterExtension( + uint16_t extension_type, const DataBuffer& input, DataBuffer* output) { + if (extension_type != extension_) { + return KEEP; + } + + if (input.len() <= length_) { + DataBuffer buf(length_ - input.len()); + output->Append(buf); + return CHANGE; + } + + output->Assign(input.data(), length_); + return CHANGE; +} + +PacketFilter::Action TlsExtensionAppender::FilterHandshake( + const HandshakeHeader& header, const DataBuffer& input, + DataBuffer* output) { + TlsParser parser(input); + if (!TlsExtensionFilter::FindExtensions(&parser, header)) { + return KEEP; + } + *output = input; + + // Increase the length of the extensions block. + if (!UpdateLength(output, parser.consumed(), 2)) { + return KEEP; + } + + // Extensions in Certificate are nested twice. Increase the size of the + // certificate list. + if (header.handshake_type() == kTlsHandshakeCertificate) { + TlsParser p2(input); + if (!p2.SkipVariable(1)) { + ADD_FAILURE(); + return KEEP; + } + if (!UpdateLength(output, p2.consumed(), 3)) { + return KEEP; + } + } + + size_t offset = output->len(); + offset = output->Write(offset, extension_, 2); + WriteVariable(output, offset, data_, 2); + + return CHANGE; +} + +bool TlsExtensionAppender::UpdateLength(DataBuffer* output, size_t offset, + size_t size) { + uint32_t len; + if (!output->Read(offset, size, &len)) { + ADD_FAILURE(); + return false; + } + + len += 4 + data_.len(); + output->Write(offset, len, size); + return true; +} + +PacketFilter::Action TlsExtensionDropper::FilterExtension( + uint16_t extension_type, const DataBuffer& input, DataBuffer* output) { + if (extension_type == extension_) { + return DROP; + } + return KEEP; +} + +PacketFilter::Action TlsExtensionDamager::FilterExtension( + uint16_t extension_type, const DataBuffer& input, DataBuffer* output) { + if (extension_type != extension_) { + return KEEP; + } + + *output = input; + output->data()[index_] += 73; // Increment selected for maximum damage + return CHANGE; +} + +PacketFilter::Action TlsExtensionInjector::FilterHandshake( + const HandshakeHeader& header, const DataBuffer& input, + DataBuffer* output) { + TlsParser parser(input); + if (!TlsExtensionFilter::FindExtensions(&parser, header)) { + return KEEP; + } + size_t offset = parser.consumed(); + + *output = input; + + // Increase the size of the extensions. + uint16_t ext_len; + memcpy(&ext_len, output->data() + offset, sizeof(ext_len)); + ext_len = htons(ntohs(ext_len) + data_.len() + 4); + memcpy(output->data() + offset, &ext_len, sizeof(ext_len)); + + // Insert the extension type and length. + DataBuffer type_length; + type_length.Allocate(4); + type_length.Write(0, extension_, 2); + type_length.Write(2, data_.len(), 2); + output->Splice(type_length, offset + 2); + + // Insert the payload. + if (data_.len() > 0) { + output->Splice(data_, offset + 6); + } + + return CHANGE; +} + +PacketFilter::Action AfterRecordN::FilterRecord(const TlsRecordHeader& header, + const DataBuffer& body, + DataBuffer* out) { + if (counter_++ == record_) { + DataBuffer buf; + header.Write(&buf, 0, body); + agent()->SendDirect(buf); + dest_.lock()->Handshake(); + func_(); + return DROP; + } + + return KEEP; +} + +PacketFilter::Action TlsClientHelloVersionChanger::FilterHandshake( + const HandshakeHeader& header, const DataBuffer& input, + DataBuffer* output) { + EXPECT_EQ(SECSuccess, + SSLInt_IncrementClientHandshakeVersion(server_.lock()->ssl_fd())); + return KEEP; +} + +PacketFilter::Action SelectiveDropFilter::Filter(const DataBuffer& input, + DataBuffer* output) { + if (counter_ >= 32) { + return KEEP; + } + return ((1 << counter_++) & pattern_) ? DROP : KEEP; +} + +PacketFilter::Action SelectiveRecordDropFilter::FilterRecord( + const TlsRecordHeader& header, const DataBuffer& data, + DataBuffer* changed) { + if (counter_ >= 32) { + return KEEP; + } + return ((1 << counter_++) & pattern_) ? DROP : KEEP; +} + +/* static */ uint32_t SelectiveRecordDropFilter::ToPattern( + std::initializer_list records) { + uint32_t pattern = 0; + for (auto it = records.begin(); it != records.end(); ++it) { + EXPECT_GT(32U, *it); + assert(*it < 32U); + pattern |= 1 << *it; + } + return pattern; +} + +PacketFilter::Action TlsMessageVersionSetter::FilterHandshake( + const HandshakeHeader& header, const DataBuffer& input, + DataBuffer* output) { + *output = input; + output->Write(0, version_, 2); + return CHANGE; +} + +PacketFilter::Action SelectedCipherSuiteReplacer::FilterHandshake( + const HandshakeHeader& header, const DataBuffer& input, + DataBuffer* output) { + *output = input; + uint32_t temp = 0; + EXPECT_TRUE(input.Read(0, 2, &temp)); + EXPECT_EQ(header.version(), NormalizeTlsVersion(temp)); + // Cipher suite is after version(2), random(32) + // and [legacy_]session_id(<0..32>). + size_t pos = 34; + EXPECT_TRUE(input.Read(pos, 1, &temp)); + pos += 1 + temp; + + output->Write(pos, static_cast(cipher_suite_), 2); + return CHANGE; +} + +PacketFilter::Action ServerHelloRandomChanger::FilterHandshake( + const HandshakeHeader& header, const DataBuffer& input, + DataBuffer* output) { + *output = input; + uint32_t temp = 0; + size_t pos = 30; + EXPECT_TRUE(input.Read(pos, 2, &temp)); + output->Write(pos, (temp ^ 0xffff), 2); + return CHANGE; +} + +PacketFilter::Action ClientHelloPreambleCapture::FilterHandshake( + const HandshakeHeader& header, const DataBuffer& input, + DataBuffer* output) { + EXPECT_TRUE(header.handshake_type() == kTlsHandshakeClientHello); + + if (captured_) { + return KEEP; + } + captured_ = true; + + DataBuffer temp; + TlsParser parser(input); + EXPECT_TRUE(parser.Read(&temp, 2 + 32)); // Version + Random + EXPECT_TRUE(parser.ReadVariable(&temp, 1)); // Session ID + if (is_dtls_agent()) { + EXPECT_TRUE(parser.ReadVariable(&temp, 1)); // Cookie + } + EXPECT_TRUE(parser.ReadVariable(&temp, 2)); // Ciphersuites + EXPECT_TRUE(parser.ReadVariable(&temp, 1)); // Compression + + // Copy the preamble into a new buffer + data_ = input; + data_.Truncate(parser.consumed()); + + return KEEP; +} + +PacketFilter::Action ClientHelloCiphersuiteCapture::FilterHandshake( + const HandshakeHeader& header, const DataBuffer& input, + DataBuffer* output) { + EXPECT_TRUE(header.handshake_type() == kTlsHandshakeClientHello); + + if (captured_) { + return KEEP; + } + captured_ = true; + + TlsParser parser(input); + EXPECT_TRUE(parser.Skip(2 + 32)); // Version + Random + EXPECT_TRUE(parser.SkipVariable(1)); // Session ID + if (is_dtls_agent()) { + EXPECT_TRUE(parser.SkipVariable(1)); // Cookie + } + + EXPECT_TRUE(parser.ReadVariable(&data_, 2)); // Ciphersuites + + return KEEP; +} + +} // namespace nss_test -- cgit v1.2.3