/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
/* vim: set ts=2 et sw=2 tw=80: */
/* This Source Code Form is subject to the terms of the Mozilla Public
 * License, v. 2.0. If a copy of the MPL was not distributed with this file,
 * You can obtain one at http://mozilla.org/MPL/2.0/. */

#include "tls_filter.h"
#include "sslproto.h"

extern "C" {
// This is not something that should make you happy.
#include "libssl_internals.h"
}

#include <cassert>
#include <cstring>
#include <iostream>
#include <map>
#include "gtest_utils.h"
#include "tls_agent.h"
#include "tls_filter.h"
#include "tls_parser.h"
#include "tls_protect.h"

namespace nss_test {

// Writes a human-readable "TLS x.y" / "DTLS x.y" label for this version.
void TlsVersioned::WriteStream(std::ostream& stream) const {
  stream << (is_dtls() ? "DTLS " : "TLS ");
  switch (version()) {
    case 0:
      stream << "(no version)";
      break;
    case SSL_LIBRARY_VERSION_TLS_1_0:
      stream << "1.0";
      break;
    case SSL_LIBRARY_VERSION_TLS_1_1:
      // Versions are stored normalized; the DTLS counterpart is labelled 1.0.
      stream << (is_dtls() ? "1.0" : "1.1");
      break;
    case SSL_LIBRARY_VERSION_TLS_1_2:
      stream << "1.2";
      break;
    case SSL_LIBRARY_VERSION_TLS_1_3:
      stream << "1.3";
      break;
    default:
      stream << "Invalid version: " << version();
      break;
  }
}

// Creates a record filter bound to |a|. The spec list is seeded with an
// entry for epoch 0 so that spec lookups always have a starting point.
TlsRecordFilter::TlsRecordFilter(const std::shared_ptr<TlsAgent>& a)
    : agent_(a) {
  cipher_specs_.emplace_back(a->variant() == ssl_variant_datagram, 0);
}

// Installs SecretCallback on the agent so that write keys are captured and
// records can be decrypted and re-encrypted by the filter.
void TlsRecordFilter::EnableDecryption() {
  EXPECT_EQ(SECSuccess,
            SSL_SecretCallback(agent()->ssl_fd(), SecretCallback, this));
  decrypting_ = true;
}

// Secret callback registered by EnableDecryption(). For each new write
// secret, it adds a cipher spec for |epoch| keyed with |secret|, using the
// 0-RTT suite for epoch 1 and the negotiated suite otherwise. Read secrets
// are ignored. |arg| is the owning TlsRecordFilter.
void TlsRecordFilter::SecretCallback(PRFileDesc* fd, PRUint16 epoch,
                                     SSLSecretDirection dir, PK11SymKey* secret,
                                     void* arg) {
  TlsRecordFilter* self = static_cast<TlsRecordFilter*>(arg);
  if (g_ssl_gtest_verbose) {
    std::cerr << self->agent()->role_str() << ": " << dir
              << " secret changed for epoch " << epoch << std::endl;
  }

  // Only write secrets are recorded.
  if (dir == ssl_secret_read) {
    return;
  }

  for (auto& spec : self->cipher_specs_) {
    ASSERT_NE(spec.epoch(), epoch) << "duplicate spec for epoch " << epoch;
  }

  SSLPreliminaryChannelInfo preinfo;
  EXPECT_EQ(SECSuccess,
            SSL_GetPreliminaryChannelInfo(self->agent()->ssl_fd(), &preinfo,
                                          sizeof(preinfo)));
  EXPECT_EQ(sizeof(preinfo), preinfo.length);

  // Check the version.
  if (preinfo.valuesSet & ssl_preinfo_version) {
    EXPECT_EQ(SSL_LIBRARY_VERSION_TLS_1_3, preinfo.protocolVersion);
  } else {
    EXPECT_EQ(1U, epoch);
  }

  uint16_t suite;
  if (epoch == 1) {
    // 0-RTT
    EXPECT_TRUE(preinfo.valuesSet & ssl_preinfo_0rtt_cipher_suite);
    suite = preinfo.zeroRttCipherSuite;
  } else {
    EXPECT_TRUE(preinfo.valuesSet & ssl_preinfo_cipher_suite);
    suite = preinfo.cipherSuite;
  }

  SSLCipherSuiteInfo cipherinfo;
  EXPECT_EQ(SECSuccess,
            SSL_GetCipherSuiteInfo(suite, &cipherinfo, sizeof(cipherinfo)));
  EXPECT_EQ(sizeof(cipherinfo), cipherinfo.length);

  self->cipher_specs_.emplace_back(self->is_dtls_agent(), epoch);
  EXPECT_TRUE(self->cipher_specs_.back().SetKeys(&cipherinfo, secret));
}

// True when the filtered agent uses the datagram (DTLS) variant.
bool TlsRecordFilter::is_dtls_agent() const {
  return agent()->variant() == ssl_variant_datagram;
}

// True when the agent is DTLS and is (or may be, via early data) using
// DTLS 1.3 record framing.
bool TlsRecordFilter::is_dtls13() const {
  if (!is_dtls_agent()) {
    return false;
  }
  if (agent()->state() == TlsAgent::STATE_CONNECTED) {
    return agent()->version() >= SSL_LIBRARY_VERSION_TLS_1_3;
  }
  SSLPreliminaryChannelInfo info;
  EXPECT_EQ(SECSuccess, SSL_GetPreliminaryChannelInfo(agent()->ssl_fd(), &info,
                                                      sizeof(info)));
  return (info.protocolVersion >= SSL_LIBRARY_VERSION_TLS_1_3) ||
         info.canSendEarlyData;
}

// True when |ct| has the DTLS 1.3 ciphertext header bit pattern and the
// agent is using DTLS 1.3.
bool TlsRecordFilter::is_dtls13_ciphertext(uint8_t ct) const {
  return is_dtls13() &&
         (ct & kCtDtlsCiphertextMask) == kCtDtlsCiphertext;
}

// Gets the cipher spec that matches the specified epoch. If none exists, a
// new spec for that epoch is appended and returned (a failure is recorded
// if decryption is enabled).
TlsCipherSpec& TlsRecordFilter::spec(uint16_t write_epoch) {
  for (auto& sp : cipher_specs_) {
    if (sp.epoch() == write_epoch) {
      return sp;
    }
  }

  // If we aren't decrypting, provide a cipher spec that does nothing other than
  // count sequence numbers.
  EXPECT_FALSE(decrypting_) << "No spec available for epoch " << write_epoch;
  cipher_specs_.emplace_back(is_dtls_agent(), write_epoch);
  return cipher_specs_.back();
}

// Splits |input| into records and passes each one through the per-record
// FilterRecord overload. Returns KEEP when nothing changed, DROP when every
// record was removed, and CHANGE otherwise (with the result in |output|).
PacketFilter::Action TlsRecordFilter::Filter(const DataBuffer& input,
                                             DataBuffer* output) {
  // Disable during shutdown.
  if (!agent()) {
    return KEEP;
  }

  bool changed = false;
  size_t offset = 0U;

  output->Allocate(input.len());
  TlsParser parser(input);

  // This uses the current write spec for the purposes of parsing the epoch and
  // sequence number from the header. This might be wrong because we can
  // receive records from older specs, but guessing is good enough:
  // - In DTLS, parsing the sequence number corrects any errors.
  // - In TLS, we don't use the sequence number unless decrypting, where we use
  // trial decryption to get the right epoch.
  uint16_t write_epoch = 0;
  SECStatus rv = SSL_GetCurrentEpoch(agent()->ssl_fd(), nullptr, &write_epoch);
  if (rv != SECSuccess) {
    ADD_FAILURE() << "unable to read epoch";
    return KEEP;
  }
  uint64_t guess_seqno = static_cast<uint64_t>(write_epoch) << 48;

  while (parser.remaining()) {
    TlsRecordHeader header;
    DataBuffer record;
    if (!header.Parse(is_dtls13(), guess_seqno, &parser, &record)) {
      ADD_FAILURE() << "not a valid record";
      return KEEP;
    }

    if (FilterRecord(header, record, &offset, output) != KEEP) {
      changed = true;
    } else {
      // Unchanged records are copied through verbatim.
      offset = header.Write(output, offset, record);
    }
  }
  output->Truncate(offset);

  // Record how many packets we actually touched.
  if (changed) {
    ++count_;
    return (offset == 0) ? DROP : CHANGE;
  }

  return KEEP;
}

// Processes one record. It unprotects the record if possible, hands the
// plaintext to the subclass's FilterRecord(), and re-protects and writes the
// result at |*offset| in |output| when it changed. Sequence-number
// bookkeeping on the protecting spec is kept up to date in every path.
PacketFilter::Action TlsRecordFilter::FilterRecord(
    const TlsRecordHeader& header, const DataBuffer& record, size_t* offset,
    DataBuffer* output) {
  DataBuffer filtered;
  uint8_t inner_content_type;
  DataBuffer plaintext;
  uint16_t protection_epoch = 0;
  TlsRecordHeader out_header(header);

  if (!Unprotect(header, record, &protection_epoch, &inner_content_type,
                 &plaintext, &out_header)) {
    std::cerr << agent()->role_str() << ": unprotect failed: " << header << ":"
              << record << std::endl;
    return KEEP;
  }

  auto& protection_spec = spec(protection_epoch);
  TlsRecordHeader real_header(out_header.variant(), out_header.version(),
                              inner_content_type, out_header.sequence_number());

  PacketFilter::Action action = FilterRecord(real_header, plaintext, &filtered);
  // In stream mode, even if something doesn't change we need to re-encrypt if
  // previous packets were dropped.
  if (action == KEEP) {
    if (out_header.is_dtls() || !protection_spec.record_dropped()) {
      // Count every outgoing packet.
      protection_spec.RecordProtected();
      return KEEP;
    }
    filtered = plaintext;
  }

  if (action == DROP) {
    std::cerr << "record drop: " << out_header << ":" << record << std::endl;
    protection_spec.RecordDropped();
    return DROP;
  }

  EXPECT_GT(0x10000U, filtered.len());
  if (action != KEEP) {
    std::cerr << "record old: " << plaintext << std::endl;
    std::cerr << "record new: " << filtered << std::endl;
  }

  uint64_t seq_num = protection_spec.next_out_seqno();
  if (!decrypting_ && out_header.is_dtls()) {
    // Copy over the epoch, which isn't tracked when not decrypting.
    seq_num |= out_header.sequence_number() & (0xffffULL << 48);
  }
  out_header.sequence_number(seq_num);

  DataBuffer ciphertext;
  bool rv = Protect(protection_spec, out_header, inner_content_type, filtered,
                    &ciphertext, &out_header);
  if (!rv) {
    return KEEP;
  }
  *offset = out_header.Write(output, *offset, ciphertext);
  return CHANGE;
}

// Returns the length of the serialized record header.
size_t TlsRecordHeader::header_length() const {
  // If we have a header, return its length.
  if (header_.len()) {
    return header_.len();
  }

  // Otherwise make a dummy header and return the length.
  DataBuffer buf;
  return WriteHeader(&buf, 0, 0);
}

// Applies (or removes) the stored sequence-number mask.
bool TlsRecordHeader::MaskSequenceNumber() {
  return MaskSequenceNumber(sn_mask());
}

// XORs the DTLS 1.3 ciphertext sequence number (in |sequence_number_| and in
// the cached |header_| bytes) with |mask_buf|. Each call toggles
// |seqno_is_masked_|. After unmasking, the full sequence number is
// recovered from the saved guess. Returns false if |mask_buf| is empty or
// the value cannot be encoded.
bool TlsRecordHeader::MaskSequenceNumber(const DataBuffer& mask_buf) {
  if (mask_buf.empty()) {
    return false;
  }

  DataBuffer mask;
  if (is_dtls13_ciphertext()) {
    uint64_t seqno = sequence_number();
    uint8_t len = content_type() & kCtDtlsCiphertext16bSeqno ? 2 : 1;
    uint16_t seqno_bitmask = (1 << len * 8) - 1;
    DataBuffer val;
    if (val.Write(0, seqno & seqno_bitmask, len) != len) {
      return false;
    }

#ifdef UNSAFE_FUZZER_MODE
    // Use a null mask.
    mask.Allocate(mask_buf.len());
#endif
    mask.Append(mask_buf);

    val.data()[0] ^= mask.data()[0];
    if (len == 2 && mask.len() > 1) {
      val.data()[1] ^= mask.data()[1];
    }
    uint32_t tmp;
    if (!val.Read(0, len, &tmp)) {
      return false;
    }

    // Replace only the low |len| bytes with the (un)masked value.
    seqno = (seqno & ~static_cast<uint64_t>(seqno_bitmask)) | tmp;
    seqno_is_masked_ = !seqno_is_masked_;
    if (!seqno_is_masked_) {
      seqno = ParseSequenceNumber(guess_seqno_, seqno, len * 8, 2);
    }
    sequence_number_ = seqno;

    // Now update the header bytes. The second sequence-number byte is only
    // touched when the mask has a second byte, matching the |val| update
    // above and avoiding a read past the end of |mask|.
    if (header_.len() > 1) {
      header_.data()[1] ^= mask.data()[0];
      if ((content_type() & kCtDtlsCiphertext16bSeqno) &&
          header().len() > 2 && mask.len() > 1) {
        header_.data()[2] ^= mask.data()[1];
      }
    }
  }

  sn_mask_ = mask;
  return true;
}

// Reconstructs a full value from its low |partial_bits| bits (|partial|),
// choosing the candidate closest to |guess_seqno| (ties go to the lower
// value).
uint64_t TlsRecordHeader::RecoverSequenceNumber(uint64_t guess_seqno,
                                                uint32_t partial,
                                                size_t partial_bits) {
  EXPECT_GE(32U, partial_bits);
  uint64_t mask = (1ULL << partial_bits) - 1;
  // First we determine the highest possible value. This is half the
  // expressible range above the expected value (|guess_seqno|), less 1.
  //
  // We subtract the extra 1 from the cap so that when given a choice between
  // the equidistant expected+N and expected-N we want to chose the lower. With
  // 0-RTT, we sometimes have to recover an epoch of 1 when we expect an epoch
  // of 3 and with 2 partial bits, the alternative result of 5 is wrong.
  uint64_t cap = guess_seqno + (1ULL << (partial_bits - 1)) - 1;
  // Add the partial piece in. e.g., xxxx789a and 1234 becomes xxxx1234.
  uint64_t seq_no = (cap & ~mask) | partial;
  // If the partial value is higher than the same partial piece from the cap,
  // then the real value has to be lower. e.g., xxxx1234 can't become xxxx5678.
  if (partial > (cap & mask) && (seq_no >= (1ULL << partial_bits))) {
    seq_no -= 1ULL << partial_bits;
  }
  return seq_no;
}

// Determine the full epoch and sequence number from an expected and raw value.
// The expected, raw, and output values are packed as they are in DTLS 1.2 and
// earlier: with 16 bits of epoch and 48 bits of sequence number. The raw value
// is packed this way (even before recovery) so that we don't need to track a
// moving value between two calls (one to recover the epoch, and one after
// unmasking to recover the sequence number).
uint64_t TlsRecordHeader::ParseSequenceNumber(uint64_t expected, uint64_t raw,
                                              size_t seq_no_bits,
                                              size_t epoch_bits) {
  uint64_t epoch_mask = (1ULL << epoch_bits) - 1;
  uint64_t ep = RecoverSequenceNumber(expected >> 48, (raw >> 48) & epoch_mask,
                                      epoch_bits);
  if (ep > (expected >> 48)) {
    // If the epoch has changed, reset the expected sequence number.
    expected = 0;
  } else {
    // Otherwise, retain just the sequence number part.
    expected &= (1ULL << 48) - 1;
  }
  uint64_t seq_no_mask = (1ULL << seq_no_bits) - 1;
  uint64_t seq_no = (raw & seq_no_mask);
  // A still-masked sequence number cannot be recovered yet.
  if (!seqno_is_masked_) {
    seq_no = RecoverSequenceNumber(expected, seq_no, seq_no_bits);
  }

  return (ep << 48) | seq_no;
}

// Parses one record header from |parser| and reads its body into |body|.
// |seqno| is the caller's guess at the current epoch/sequence number, used
// when the header does not carry the full value. Returns false on a
// truncated or malformed record.
bool TlsRecordHeader::Parse(bool is_dtls13, uint64_t seqno, TlsParser* parser,
                            DataBuffer* body) {
  auto mark = parser->consumed();

  if (!parser->Read(&content_type_)) {
    return false;
  }

  if (is_dtls13) {
    variant_ = ssl_variant_datagram;
    version_ = SSL_LIBRARY_VERSION_TLS_1_3;

#ifndef UNSAFE_FUZZER_MODE
    // Deal with the DTLSCipherText header.
    if (is_dtls13_ciphertext()) {
      uint8_t seq_no_bytes =
          (content_type_ & kCtDtlsCiphertext16bSeqno) ? 2 : 1;
      uint32_t tmp;
      if (!parser->Read(&tmp, seq_no_bytes)) {
        return false;
      }

      // Store the guess if masked. If and when MaskSequenceNumber() is called,
      // the value will be unmasked and recovered. This assumes we only call
      // Parse() on headers containing masked values.
      seqno_is_masked_ = true;
      guess_seqno_ = seqno;
      uint64_t ep = content_type_ & 0x03;
      sequence_number_ = (ep << 48) | tmp;

      // Recover the full epoch. Note the sequence number portion holds the
      // masked value until a call to MaskSequenceNumber() reveals it (as
      // indicated by |seqno_is_masked_|).
      sequence_number_ =
          ParseSequenceNumber(seqno, sequence_number_, seq_no_bytes * 8, 2);

      uint32_t len_bytes =
          (content_type_ & kCtDtlsCiphertextLengthPresent) ? 2 : 0;
      if (len_bytes) {
        if (!parser->Read(&tmp, 2)) {
          return false;
        }
      }

      if (!parser->ReadFromMark(&header_, parser->consumed() - mark, mark)) {
        return false;
      }

      // Without an explicit length, the body runs to the end of the input.
      return len_bytes ? parser->Read(body, tmp)
                       : parser->Read(body, parser->remaining());
    }

    // The full DTLSPlainText header can only be used for a few types.
    EXPECT_TRUE(content_type_ == ssl_ct_alert ||
                content_type_ == ssl_ct_handshake ||
                content_type_ == ssl_ct_ack);
#endif
  }

  uint32_t ver;
  if (!parser->Read(&ver, 2)) {
    return false;
  }
  if (!is_dtls13) {
    variant_ = IsDtls(ver) ? ssl_variant_datagram : ssl_variant_stream;
  }
  version_ = NormalizeTlsVersion(ver);

  if (is_dtls()) {
    // If this is DTLS, read the sequence number.
    uint32_t tmp;
    if (!parser->Read(&tmp, 4)) {
      return false;
    }
    sequence_number_ = static_cast<uint64_t>(tmp) << 32;
    if (!parser->Read(&tmp, 4)) {
      return false;
    }
    sequence_number_ |= static_cast<uint64_t>(tmp);
  } else {
    sequence_number_ = seqno;
  }
  // The cached header includes the 2-byte length that ReadVariable() reads.
  if (!parser->ReadFromMark(&header_, parser->consumed() + 2 - mark, mark)) {
    return false;
  }
  return parser->ReadVariable(body, 2);
}

// Serializes this header for a body of |body_len| bytes at |offset| in
// |buffer|. Returns the offset just past the header.
size_t TlsRecordHeader::WriteHeader(DataBuffer* buffer, size_t offset,
                                    size_t body_len) const {
  if (is_dtls13_ciphertext()) {
    uint8_t seq_no_bytes =
        (content_type_ & kCtDtlsCiphertext16bSeqno) ? 2 : 1;
    // application_data records in TLS 1.3 have a different header format.
    uint32_t e = (sequence_number_ >> 48) & 0x3;
    uint32_t seqno = sequence_number_ & ((1ULL << seq_no_bytes * 8) - 1);
    // The low two bits of the first byte carry the low epoch bits.
    uint8_t ct_with_epoch = content_type_ | e;
    offset = buffer->Write(offset, ct_with_epoch, 1);
    offset = buffer->Write(offset, seqno, seq_no_bytes);

    if (content_type_ & kCtDtlsCiphertextLengthPresent) {
      offset = buffer->Write(offset, body_len, 2);
    }
  } else {
    offset = buffer->Write(offset, content_type_, 1);
    uint16_t v = is_dtls() ? TlsVersionToDtlsVersion(version_) : version_;
    offset = buffer->Write(offset, v, 2);
    if (is_dtls()) {
      // write epoch (2 octet), and seqnum (6 octet)
      offset = buffer->Write(offset, sequence_number_ >> 32, 4);
      offset = buffer->Write(offset, sequence_number_ & 0xffffffff, 4);
    }
    offset = buffer->Write(offset, body_len, 2);
  }

  return offset;
}

// Writes the header followed by |body|; returns the offset after the body.
size_t TlsRecordHeader::Write(DataBuffer* buffer, size_t offset,
                              const DataBuffer& body) const {
  offset = WriteHeader(buffer, offset, body.len());
  offset = buffer->Write(offset, body);
  return offset;
}

// Recovers the plaintext of a record. If not decrypting (or the record is
// unprotected), the record is passed through and only sequence numbers are
// tracked. Otherwise the matching spec decrypts it: by explicit epoch in
// DTLS, by trial decryption in TLS. Trailing zero padding is stripped, and
// the last non-zero byte becomes |*inner_content_type|.
bool TlsRecordFilter::Unprotect(const TlsRecordHeader& header,
                                const DataBuffer& ciphertext,
                                uint16_t* protection_epoch,
                                uint8_t* inner_content_type,
                                DataBuffer* plaintext,
                                TlsRecordHeader* out_header) {
  if (!decrypting_ || !header.is_protected()) {
    // Maintain the epoch and sequence number for plaintext records.
    uint16_t ep = 0;
    if (is_dtls_agent()) {
      ep = static_cast<uint16_t>(header.sequence_number() >> 48);
    }
    spec(ep).RecordUnprotected(header.sequence_number());
    *protection_epoch = ep;
    *inner_content_type = header.content_type();
    *plaintext = ciphertext;
    return true;
  }

  uint16_t ep = 0;
  if (is_dtls_agent()) {
    ep = static_cast<uint16_t>(header.sequence_number() >> 48);
    if (!spec(ep).Unprotect(header, ciphertext, plaintext, out_header)) {
      return false;
    }
  } else {
    // In TLS, records aren't clearly labelled with their epoch, and we
    // can't just use the newest keys because the same flight of messages can
    // contain multiple epochs. So... trial decrypt!
    for (size_t i = cipher_specs_.size() - 1; i > 0; --i) {
      if (cipher_specs_[i].Unprotect(header, ciphertext, plaintext,
                                     out_header)) {
        ep = cipher_specs_[i].epoch();
        break;
      }
    }
    if (!ep) {
      return false;
    }
  }

  size_t len = plaintext->len();
  while (len > 0 && !plaintext->data()[len - 1]) {
    --len;
  }
  if (!len) {
    // Bogus padding.
    return false;
  }

  *protection_epoch = ep;
  *inner_content_type = plaintext->data()[len - 1];
  plaintext->Truncate(len - 1);
  if (g_ssl_gtest_verbose) {
    std::cerr << agent()->role_str() << ": unprotect: epoch=" << ep
              << " seq=" << std::hex << header.sequence_number() << std::dec
              << " " << *plaintext << std::endl;
  }
  return true;
}

// Encrypts |plaintext| under |protection_spec|. The inner content type byte
// and |padding| zero bytes are appended first. Unprotected specs pass the
// data through and only advance the sequence number.
bool TlsRecordFilter::Protect(TlsCipherSpec& protection_spec,
                              const TlsRecordHeader& header,
                              uint8_t inner_content_type,
                              const DataBuffer& plaintext,
                              DataBuffer* ciphertext,
                              TlsRecordHeader* out_header, size_t padding) {
  if (!protection_spec.is_protected()) {
    // Not protected, just keep the sequence numbers updated.
    protection_spec.RecordProtected();
    *ciphertext = plaintext;
    return true;
  }

  DataBuffer padded;
  padded.Allocate(plaintext.len() + 1 + padding);
  size_t offset = padded.Write(0, plaintext.data(), plaintext.len());
  padded.Write(offset, inner_content_type, 1);

  bool ok = protection_spec.Protect(header, padded, ciphertext, out_header);
  if (!ok) {
    ADD_FAILURE() << "protect fail";
  } else if (g_ssl_gtest_verbose) {
    std::cerr << agent()->role_str() << ": protect: epoch="
              << protection_spec.epoch() << " seq=" << std::hex
              << header.sequence_number() << std::dec << " " << *ciphertext
              << std::endl;
  }
  return ok;
}

// Returns true if |body| (a ServerHello body) carries the special
// HelloRetryRequest random. Bodies too short to hold the random return false
// instead of being read out of bounds.
bool IsHelloRetry(const DataBuffer& body) {
  static const uint8_t ssl_hello_retry_random[] = {
      0xCF, 0x21, 0xAD, 0x74, 0xE5, 0x9A, 0x61, 0x11, 0xBE, 0x1D, 0x8C,
      0x02, 0x1E, 0x65, 0xB8, 0x91, 0xC2, 0xA2, 0x11, 0x16, 0x7A, 0xBB,
      0x8C, 0x5E, 0x07, 0x9E, 0x09, 0xE2, 0xC8, 0xA8, 0x33, 0x9C};
  // The random follows the 2-byte version field.
  static const size_t kRandomOffset = 2;
  if (body.len() < kRandomOffset + sizeof(ssl_hello_retry_random)) {
    return false;
  }
  return memcmp(body.data() + kRandomOffset, ssl_hello_retry_random,
                sizeof(ssl_hello_retry_random)) == 0;
}

// True if this filter should process the message. An empty type set means
// "all types". A ServerHello carrying the HRR random is treated as
// kTlsHandshakeHelloRetryRequest.
bool TlsHandshakeFilter::IsFilteredType(const HandshakeHeader& header,
                                        const DataBuffer& body) {
  if (handshake_types_.empty()) {
    return true;
  }

  uint8_t type = header.handshake_type();
  if (type == kTlsHandshakeServerHello) {
    if (IsHelloRetry(body)) {
      type = kTlsHandshakeHelloRetryRequest;
    }
  }
  return handshake_types_.count(type) > 0U;
}

// Splits a handshake record into messages and runs FilterHandshake() on
// each selected one. DTLS fragments are held back in |preceding_fragment_|
// and reassembled. A reassembled message is always rewritten, because it
// differs from the bytes carried in this record.
PacketFilter::Action TlsHandshakeFilter::FilterRecord(
    const TlsRecordHeader& record_header, const DataBuffer& input,
    DataBuffer* output) {
  // Only handshake records are parsed.
  if (record_header.content_type() != ssl_ct_handshake) {
    return KEEP;
  }

  bool changed = false;
  size_t offset = 0U;
  output->Allocate(input.len());  // Preallocate a little.

  TlsParser parser(input);
  while (parser.remaining()) {
    HandshakeHeader header;
    DataBuffer handshake;
    bool complete = false;
    if (!header.Parse(&parser, record_header, preceding_fragment_, &handshake,
                      &complete)) {
      return KEEP;
    }

    if (!complete) {
      EXPECT_TRUE(record_header.is_dtls());
      // Save the fragment and drop it from this record.  Fragments are
      // coalesced with the last fragment of the handshake message.
      changed = true;
      preceding_fragment_.Assign(handshake);
      continue;
    }
    // Whether |handshake| was stitched together from earlier fragments must
    // be captured before the saved fragment is cleared. Checking afterwards
    // is always false.
    bool reassembled = preceding_fragment_.len() > 0;
    preceding_fragment_.Truncate(0);

    DataBuffer filtered;
    PacketFilter::Action action;
    if (!IsFilteredType(header, handshake)) {
      action = KEEP;
    } else {
      action = FilterHandshake(header, handshake, &filtered);
    }
    if (action == DROP) {
      changed = true;
      std::cerr << "handshake drop: " << handshake << std::endl;
      continue;
    }

    const DataBuffer* source = &handshake;
    if (action == CHANGE) {
      EXPECT_GT(0x1000000U, filtered.len());
      changed = true;
      std::cerr << "handshake old: " << handshake << std::endl;
      std::cerr << "handshake new: " << filtered << std::endl;
      source = &filtered;
    } else if (reassembled) {
      // The earlier fragments were removed from their records, so the whole
      // message has to be emitted here.
      changed = true;
    }

    offset = header.Write(output, offset, *source);
  }
  output->Truncate(offset);
  return changed ? (offset ? CHANGE : DROP) : KEEP;
}

// Reads the message length and, for DTLS, the message_seq, fragment offset
// and fragment length. |*length| receives the number of body bytes in this
// record. |*last_fragment| is set when those bytes complete the message.
// Fragments that don't start at |expected_offset| are rejected.
bool TlsHandshakeFilter::HandshakeHeader::ReadLength(
    TlsParser* parser, const TlsRecordHeader& header, uint32_t expected_offset,
    uint32_t* length, bool* last_fragment) {
  uint32_t message_length;
  if (!parser->Read(&message_length, 3)) {
    return false;  // malformed
  }

  if (!header.is_dtls()) {
    *last_fragment = true;
    *length = message_length;
    return true;  // nothing left to do
  }

  // Read and check DTLS parameters
  uint32_t message_seq_tmp;
  if (!parser->Read(&message_seq_tmp, 2)) {  // sequence number
    return false;
  }
  message_seq_ = message_seq_tmp;

  uint32_t offset = 0;
  if (!parser->Read(&offset, 3)) {
    return false;
  }
  // We only parse if the fragments are all complete and in order.
  if (offset != expected_offset) {
    EXPECT_NE(0U, header.epoch())
        << "Received out of order handshake fragment for epoch 0";
    return false;
  }

  // For DTLS, we return the length of just this fragment.
  if (!parser->Read(length, 3)) {
    return false;
  }

  // It's the last fragment when it reaches the end of the whole message.
  *last_fragment = message_length == (*length + offset);
  return true;
}

// Parses one handshake message header and its body. The body is prefixed
// with |preceding_fragment| so that a split DTLS message can be rebuilt.
bool TlsHandshakeFilter::HandshakeHeader::Parse(
    TlsParser* parser, const TlsRecordHeader& record_header,
    const DataBuffer& preceding_fragment, DataBuffer* body, bool* complete) {
  *complete = false;

  variant_ = record_header.variant();
  version_ = record_header.version();
  if (!parser->Read(&handshake_type_)) {
    return false;  // malformed
  }

  uint32_t length;
  if (!ReadLength(parser, record_header, preceding_fragment.len(), &length,
                  complete)) {
    return false;
  }

  if (!parser->Read(body, length)) {
    return false;
  }
  if (preceding_fragment.len()) {
    body->Splice(preceding_fragment, 0);
  }
  return true;
}

// Writes a DTLS handshake fragment header followed by
// |body[fragment_offset, fragment_offset + fragment_length)|.
size_t TlsHandshakeFilter::HandshakeHeader::WriteFragment(
    DataBuffer* buffer, size_t offset, const DataBuffer& body,
    size_t fragment_offset, size_t fragment_length) const {
  EXPECT_TRUE(is_dtls());
  EXPECT_GE(body.len(), fragment_offset + fragment_length);
  offset = buffer->Write(offset, handshake_type(), 1);
  offset = buffer->Write(offset, body.len(), 3);
  offset = buffer->Write(offset, message_seq_, 2);
  offset = buffer->Write(offset, fragment_offset, 3);
  offset = buffer->Write(offset, fragment_length, 3);
  offset =
      buffer->Write(offset, body.data() + fragment_offset, fragment_length);
  return offset;
}

// Writes the whole message: a single full fragment for DTLS, or a
// type + 3-byte length header for TLS.
size_t TlsHandshakeFilter::HandshakeHeader::Write(
    DataBuffer* buffer, size_t offset, const DataBuffer& body) const {
  if (is_dtls()) {
    return WriteFragment(buffer, offset, body, 0U, body.len());
  }
  offset = buffer->Write(offset, handshake_type(), 1);
  offset = buffer->Write(offset, body.len(), 3);
  offset = buffer->Write(offset, body);
  return offset;
}

// Saves the first non-empty handshake body seen; never modifies anything.
PacketFilter::Action TlsHandshakeRecorder::FilterHandshake(
    const HandshakeHeader& header, const DataBuffer& input,
    DataBuffer* output) {
  // Only do this once.
  if (buffer_.len()) {
    return KEEP;
  }

  buffer_ = input;
  return KEEP;
}

// Replaces every selected handshake body with |buffer_|.
PacketFilter::Action TlsInspectorReplaceHandshakeMessage::FilterHandshake(
    const HandshakeHeader& header, const DataBuffer& input,
    DataBuffer* output) {
  *output = buffer_;
  return CHANGE;
}

// Records (header, body) pairs, optionally limited to content type |ct_|.
PacketFilter::Action TlsRecordRecorder::FilterRecord(
    const TlsRecordHeader& header, const DataBuffer& input,
    DataBuffer* output) {
  if (!filter_ || (header.content_type() == ct_)) {
    records_.push_back({header, input});
  }
  return KEEP;
}

// Appends every record body to |buffer_|.
PacketFilter::Action TlsConversationRecorder::FilterRecord(
    const TlsRecordHeader& header, const DataBuffer& input,
    DataBuffer* output) {
  buffer_.Append(input);
  return KEEP;
}

// Records every record header.
PacketFilter::Action TlsHeaderRecorder::FilterRecord(const TlsRecordHeader& hdr,
                                                     const DataBuffer& input,
                                                     DataBuffer* output) {
  headers_.push_back(hdr);
  return KEEP;
}

// Returns the |index|-th recorded header, or nullptr if there is none.
const TlsRecordHeader* TlsHeaderRecorder::header(size_t index) {
  // Any index at or past size() has no header; return nullptr rather than
  // reading past the end of |headers_|.
  if (index >= headers_.size()) {
    return nullptr;
  }
  return &headers_[index];
}

// Runs each filter in order on the output of the previous one. Stops at the
// first DROP.
PacketFilter::Action ChainedPacketFilter::Filter(const DataBuffer& input,
                                                 DataBuffer* output) {
  DataBuffer in(input);
  bool changed = false;
  for (auto& filter : filters_) {
    PacketFilter::Action action = filter->Process(in, output);
    if (action == DROP) {
      return DROP;
    }

    if (action == CHANGE) {
      in = *output;
      changed = true;
    }
  }
  return changed ? CHANGE : KEEP;
}

// Positions |parser| at the extensions block of a ClientHello body.
bool FindClientHelloExtensions(TlsParser* parser, const TlsVersioned& header) {
  if (!parser->Skip(2 + 32)) {  // version + random
    return false;
  }
  if (!parser->SkipVariable(1)) {  // session ID
    return false;
  }
  if (header.is_dtls() && !parser->SkipVariable(1)) {  // DTLS cookie
    return false;
  }
  if (!parser->SkipVariable(2)) {  // cipher suites
    return false;
  }
  if (!parser->SkipVariable(1)) {  // compression methods
    return false;
  }
  return true;
}

// Positions |parser| at the extensions block of a ServerHello body. The
// session ID and compression method are skipped only for versions up to
// TLS 1.2, going by the version field in the message.
bool FindServerHelloExtensions(TlsParser* parser, const TlsVersioned& header) {
  uint32_t vtmp;
  if (!parser->Read(&vtmp, 2)) {
    return false;
  }
  uint16_t version = static_cast<uint16_t>(vtmp);
  if (!parser->Skip(32)) {  // random
    return false;
  }
  if (NormalizeTlsVersion(version) <= SSL_LIBRARY_VERSION_TLS_1_2) {
    if (!parser->SkipVariable(1)) {  // session ID
      return false;
    }
  }
  if (!parser->Skip(2)) {  // cipher suite
    return false;
  }
  if (NormalizeTlsVersion(version) <= SSL_LIBRARY_VERSION_TLS_1_2) {
    if (!parser->Skip(1)) {  // compression method
      return false;
    }
  }
  return true;
}

// EncryptedExtensions consists solely of the extensions block.
bool FindEncryptedExtensions(TlsParser* parser, const TlsVersioned& header) {
  return true;
}

// Positions |parser| at the extensions of a CertificateRequest body.
static bool FindCertReqExtensions(TlsParser* parser,
                                  const TlsVersioned& header) {
  if (!parser->SkipVariable(1)) {  // request context
    return false;
  }
  return true;
}

// Only look at the EE cert for this one.
static bool FindCertificateExtensions(TlsParser* parser, const TlsVersioned& header) { if (!parser->SkipVariable(1)) { // request context return false; } if (!parser->Skip(3)) { // length of certificate list return false; } if (!parser->SkipVariable(3)) { // ASN1Cert return false; } return true; } static bool FindNewSessionTicketExtensions(TlsParser* parser, const TlsVersioned& header) { if (!parser->Skip(8)) { // lifetime, age add return false; } if (!parser->SkipVariable(1)) { // ticket_nonce return false; } if (!parser->SkipVariable(2)) { // ticket return false; } return true; } static const std::map kExtensionFinders = { {kTlsHandshakeClientHello, FindClientHelloExtensions}, {kTlsHandshakeServerHello, FindServerHelloExtensions}, {kTlsHandshakeEncryptedExtensions, FindEncryptedExtensions}, {kTlsHandshakeCertificateRequest, FindCertReqExtensions}, {kTlsHandshakeCertificate, FindCertificateExtensions}, {kTlsHandshakeNewSessionTicket, FindNewSessionTicketExtensions}}; bool TlsExtensionFilter::FindExtensions(TlsParser* parser, const HandshakeHeader& header) { auto it = kExtensionFinders.find(header.handshake_type()); if (it == kExtensionFinders.end()) { return false; } return (it->second)(parser, header); } PacketFilter::Action TlsExtensionFilter::FilterHandshake( const HandshakeHeader& header, const DataBuffer& input, DataBuffer* output) { TlsParser parser(input); if (!FindExtensions(&parser, header)) { return KEEP; } return FilterExtensions(&parser, input, output); } PacketFilter::Action TlsExtensionFilter::FilterExtensions( TlsParser* parser, const DataBuffer& input, DataBuffer* output) { size_t length_offset = parser->consumed(); uint32_t all_extensions; if (!parser->Read(&all_extensions, 2)) { return KEEP; // no extensions, odd but OK } if (all_extensions != parser->remaining()) { return KEEP; // malformed } bool changed = false; // Write out the start of the message. 
output->Allocate(input.len()); size_t offset = output->Write(0, input.data(), parser->consumed()); while (parser->remaining()) { uint32_t extension_type; if (!parser->Read(&extension_type, 2)) { return KEEP; // malformed } DataBuffer extension; if (!parser->ReadVariable(&extension, 2)) { return KEEP; // malformed } DataBuffer filtered; PacketFilter::Action action = FilterExtension(extension_type, extension, &filtered); if (action == DROP) { changed = true; std::cerr << "extension drop: " << extension << std::endl; continue; } const DataBuffer* source = &extension; if (action == CHANGE) { EXPECT_GT(0x10000U, filtered.len()); changed = true; std::cerr << "extension old: " << extension << std::endl; std::cerr << "extension new: " << filtered << std::endl; source = &filtered; } // Write out extension. offset = output->Write(offset, extension_type, 2); offset = output->Write(offset, source->len(), 2); if (source->len() > 0) { offset = output->Write(offset, *source); } } output->Truncate(offset); if (changed) { size_t newlen = output->len() - length_offset - 2; EXPECT_GT(0x10000U, newlen); if (newlen >= 0x10000) { return KEEP; // bad: size increased too much } output->Write(length_offset, newlen, 2); return CHANGE; } return KEEP; } PacketFilter::Action TlsExtensionOrderCapture::FilterExtension( uint16_t extension_type, const DataBuffer& input, DataBuffer* output) { order.push_back(extension_type); return KEEP; } PacketFilter::Action TlsExtensionCapture::FilterExtension( uint16_t extension_type, const DataBuffer& input, DataBuffer* output) { if (extension_type == extension_ && (last_ || !captured_)) { data_.Assign(input); captured_ = true; } return KEEP; } PacketFilter::Action TlsExtensionReplacer::FilterExtension( uint16_t extension_type, const DataBuffer& input, DataBuffer* output) { if (extension_type != extension_) { return KEEP; } *output = data_; return CHANGE; } PacketFilter::Action TlsExtensionResizer::FilterExtension( uint16_t extension_type, const DataBuffer& 
input, DataBuffer* output) { if (extension_type != extension_) { return KEEP; } if (input.len() <= length_) { DataBuffer buf(length_ - input.len()); output->Append(buf); return CHANGE; } output->Assign(input.data(), length_); return CHANGE; } PacketFilter::Action TlsExtensionAppender::FilterHandshake( const HandshakeHeader& header, const DataBuffer& input, DataBuffer* output) { TlsParser parser(input); if (!TlsExtensionFilter::FindExtensions(&parser, header)) { return KEEP; } *output = input; // Increase the length of the extensions block. if (!UpdateLength(output, parser.consumed(), 2)) { return KEEP; } // Extensions in Certificate are nested twice. Increase the size of the // certificate list. if (header.handshake_type() == kTlsHandshakeCertificate) { TlsParser p2(input); if (!p2.SkipVariable(1)) { ADD_FAILURE(); return KEEP; } if (!UpdateLength(output, p2.consumed(), 3)) { return KEEP; } } size_t offset = output->len(); offset = output->Write(offset, extension_, 2); WriteVariable(output, offset, data_, 2); return CHANGE; } bool TlsExtensionAppender::UpdateLength(DataBuffer* output, size_t offset, size_t size) { uint32_t len; if (!output->Read(offset, size, &len)) { ADD_FAILURE(); return false; } len += 4 + data_.len(); output->Write(offset, len, size); return true; } PacketFilter::Action TlsExtensionDropper::FilterExtension( uint16_t extension_type, const DataBuffer& input, DataBuffer* output) { if (extension_type == extension_) { return DROP; } return KEEP; } PacketFilter::Action TlsExtensionDamager::FilterExtension( uint16_t extension_type, const DataBuffer& input, DataBuffer* output) { if (extension_type != extension_) { return KEEP; } *output = input; output->data()[index_] += 73; // Increment selected for maximum damage return CHANGE; } PacketFilter::Action TlsExtensionInjector::FilterHandshake( const HandshakeHeader& header, const DataBuffer& input, DataBuffer* output) { TlsParser parser(input); if (!TlsExtensionFilter::FindExtensions(&parser, header)) { 
return KEEP; } size_t offset = parser.consumed(); *output = input; // Increase the size of the extensions. uint16_t ext_len; memcpy(&ext_len, output->data() + offset, sizeof(ext_len)); ext_len = htons(ntohs(ext_len) + data_.len() + 4); memcpy(output->data() + offset, &ext_len, sizeof(ext_len)); // Insert the extension type and length. DataBuffer type_length; type_length.Allocate(4); type_length.Write(0, extension_, 2); type_length.Write(2, data_.len(), 2); output->Splice(type_length, offset + 2); // Insert the payload. if (data_.len() > 0) { output->Splice(data_, offset + 6); } return CHANGE; } PacketFilter::Action AfterRecordN::FilterRecord(const TlsRecordHeader& header, const DataBuffer& body, DataBuffer* out) { if (counter_++ == record_) { DataBuffer buf; header.Write(&buf, 0, body); agent()->SendDirect(buf); dest_.lock()->Handshake(); func_(); return DROP; } return KEEP; } PacketFilter::Action TlsClientHelloVersionChanger::FilterHandshake( const HandshakeHeader& header, const DataBuffer& input, DataBuffer* output) { EXPECT_EQ(SECSuccess, SSLInt_IncrementClientHandshakeVersion(server_.lock()->ssl_fd())); return KEEP; } PacketFilter::Action SelectiveDropFilter::Filter(const DataBuffer& input, DataBuffer* output) { if (counter_ >= 32) { return KEEP; } return ((1 << counter_++) & pattern_) ? DROP : KEEP; } PacketFilter::Action SelectiveRecordDropFilter::FilterRecord( const TlsRecordHeader& header, const DataBuffer& data, DataBuffer* changed) { if (counter_ >= 32) { return KEEP; } return ((1 << counter_++) & pattern_) ? 
DROP : KEEP; } /* static */ uint32_t SelectiveRecordDropFilter::ToPattern( std::initializer_list records) { uint32_t pattern = 0; for (auto it = records.begin(); it != records.end(); ++it) { EXPECT_GT(32U, *it); assert(*it < 32U); pattern |= 1 << *it; } return pattern; } PacketFilter::Action TlsMessageVersionSetter::FilterHandshake( const HandshakeHeader& header, const DataBuffer& input, DataBuffer* output) { *output = input; output->Write(0, version_, 2); return CHANGE; } PacketFilter::Action SelectedCipherSuiteReplacer::FilterHandshake( const HandshakeHeader& header, const DataBuffer& input, DataBuffer* output) { *output = input; uint32_t temp = 0; EXPECT_TRUE(input.Read(0, 2, &temp)); EXPECT_EQ(header.version(), NormalizeTlsVersion(temp)); // Cipher suite is after version(2), random(32) // and [legacy_]session_id(<0..32>). size_t pos = 34; EXPECT_TRUE(input.Read(pos, 1, &temp)); pos += 1 + temp; output->Write(pos, static_cast(cipher_suite_), 2); return CHANGE; } PacketFilter::Action ServerHelloRandomChanger::FilterHandshake( const HandshakeHeader& header, const DataBuffer& input, DataBuffer* output) { *output = input; uint32_t temp = 0; size_t pos = 30; EXPECT_TRUE(input.Read(pos, 2, &temp)); output->Write(pos, (temp ^ 0xffff), 2); return CHANGE; } PacketFilter::Action ClientHelloPreambleCapture::FilterHandshake( const HandshakeHeader& header, const DataBuffer& input, DataBuffer* output) { EXPECT_TRUE(header.handshake_type() == kTlsHandshakeClientHello); if (captured_) { return KEEP; } captured_ = true; DataBuffer temp; TlsParser parser(input); EXPECT_TRUE(parser.Read(&temp, 2 + 32)); // Version + Random EXPECT_TRUE(parser.ReadVariable(&temp, 1)); // Session ID if (is_dtls_agent()) { EXPECT_TRUE(parser.ReadVariable(&temp, 1)); // Cookie } EXPECT_TRUE(parser.ReadVariable(&temp, 2)); // Ciphersuites EXPECT_TRUE(parser.ReadVariable(&temp, 1)); // Compression // Copy the preamble into a new buffer data_ = input; data_.Truncate(parser.consumed()); return KEEP; } 
PacketFilter::Action ClientHelloCiphersuiteCapture::FilterHandshake( const HandshakeHeader& header, const DataBuffer& input, DataBuffer* output) { EXPECT_TRUE(header.handshake_type() == kTlsHandshakeClientHello); if (captured_) { return KEEP; } captured_ = true; TlsParser parser(input); EXPECT_TRUE(parser.Skip(2 + 32)); // Version + Random EXPECT_TRUE(parser.SkipVariable(1)); // Session ID if (is_dtls_agent()) { EXPECT_TRUE(parser.SkipVariable(1)); // Cookie } EXPECT_TRUE(parser.ReadVariable(&data_, 2)); // Ciphersuites return KEEP; } } // namespace nss_test