summaryrefslogtreecommitdiffstats
path: root/security/nss/gtests/ssl_gtest/tls_filter.cc
diff options
context:
space:
mode:
Diffstat (limited to 'security/nss/gtests/ssl_gtest/tls_filter.cc')
-rw-r--r--security/nss/gtests/ssl_gtest/tls_filter.cc1293
1 files changed, 1293 insertions, 0 deletions
diff --git a/security/nss/gtests/ssl_gtest/tls_filter.cc b/security/nss/gtests/ssl_gtest/tls_filter.cc
new file mode 100644
index 0000000000..ab52a07e84
--- /dev/null
+++ b/security/nss/gtests/ssl_gtest/tls_filter.cc
@@ -0,0 +1,1293 @@
+/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
+/* vim: set ts=2 et sw=2 tw=80: */
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this file,
+ * You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+#include "tls_filter.h"
+#include "sslproto.h"
+
+extern "C" {
+// This is not something that should make you happy.
+#include "libssl_internals.h"
+}
+
+#include <cassert>
+#include <iostream>
+#include "gtest_utils.h"
+#include "tls_agent.h"
+#include "tls_filter.h"
+#include "tls_parser.h"
+#include "tls_protect.h"
+
+namespace nss_test {
+
+void TlsVersioned::WriteStream(std::ostream& stream) const {
+ stream << (is_dtls() ? "DTLS " : "TLS ");
+ switch (version()) {
+ case 0:
+ stream << "(no version)";
+ break;
+ case SSL_LIBRARY_VERSION_TLS_1_0:
+ stream << "1.0";
+ break;
+ case SSL_LIBRARY_VERSION_TLS_1_1:
+ stream << (is_dtls() ? "1.0" : "1.1");
+ break;
+ case SSL_LIBRARY_VERSION_TLS_1_2:
+ stream << "1.2";
+ break;
+ case SSL_LIBRARY_VERSION_TLS_1_3:
+ stream << "1.3";
+ break;
+ default:
+ stream << "Invalid version: " << version();
+ break;
+ }
+}
+
+TlsRecordFilter::TlsRecordFilter(const std::shared_ptr<TlsAgent>& a)
+ : agent_(a) {
+ cipher_specs_.emplace_back(a->variant() == ssl_variant_datagram, 0);
+}
+
+void TlsRecordFilter::EnableDecryption() {
+ EXPECT_EQ(SECSuccess,
+ SSL_SecretCallback(agent()->ssl_fd(), SecretCallback, this));
+ decrypting_ = true;
+}
+
+void TlsRecordFilter::SecretCallback(PRFileDesc* fd, PRUint16 epoch,
+ SSLSecretDirection dir, PK11SymKey* secret,
+ void* arg) {
+ TlsRecordFilter* self = static_cast<TlsRecordFilter*>(arg);
+ if (g_ssl_gtest_verbose) {
+ std::cerr << self->agent()->role_str() << ": " << dir
+ << " secret changed for epoch " << epoch << std::endl;
+ }
+
+ if (dir == ssl_secret_read) {
+ return;
+ }
+
+ for (auto& spec : self->cipher_specs_) {
+ ASSERT_NE(spec.epoch(), epoch) << "duplicate spec for epoch " << epoch;
+ }
+
+ SSLPreliminaryChannelInfo preinfo;
+ EXPECT_EQ(SECSuccess,
+ SSL_GetPreliminaryChannelInfo(self->agent()->ssl_fd(), &preinfo,
+ sizeof(preinfo)));
+ EXPECT_EQ(sizeof(preinfo), preinfo.length);
+
+ // Check the version.
+ if (preinfo.valuesSet & ssl_preinfo_version) {
+ EXPECT_EQ(SSL_LIBRARY_VERSION_TLS_1_3, preinfo.protocolVersion);
+ } else {
+ EXPECT_EQ(1U, epoch);
+ }
+
+ uint16_t suite;
+ if (epoch == 1) {
+ // 0-RTT
+ EXPECT_TRUE(preinfo.valuesSet & ssl_preinfo_0rtt_cipher_suite);
+ suite = preinfo.zeroRttCipherSuite;
+ } else {
+ EXPECT_TRUE(preinfo.valuesSet & ssl_preinfo_cipher_suite);
+ suite = preinfo.cipherSuite;
+ }
+
+ SSLCipherSuiteInfo cipherinfo;
+ EXPECT_EQ(SECSuccess,
+ SSL_GetCipherSuiteInfo(suite, &cipherinfo, sizeof(cipherinfo)));
+ EXPECT_EQ(sizeof(cipherinfo), cipherinfo.length);
+
+ self->cipher_specs_.emplace_back(self->is_dtls_agent(), epoch);
+ EXPECT_TRUE(self->cipher_specs_.back().SetKeys(&cipherinfo, secret));
+}
+
+bool TlsRecordFilter::is_dtls_agent() const {
+ return agent()->variant() == ssl_variant_datagram;
+}
+
+bool TlsRecordFilter::is_dtls13() const {
+ if (!is_dtls_agent()) {
+ return false;
+ }
+ if (agent()->state() == TlsAgent::STATE_CONNECTED) {
+ return agent()->version() >= SSL_LIBRARY_VERSION_TLS_1_3;
+ }
+ SSLPreliminaryChannelInfo info;
+ EXPECT_EQ(SECSuccess, SSL_GetPreliminaryChannelInfo(agent()->ssl_fd(), &info,
+ sizeof(info)));
+ return (info.protocolVersion >= SSL_LIBRARY_VERSION_TLS_1_3) ||
+ info.canSendEarlyData;
+}
+
+bool TlsRecordFilter::is_dtls13_ciphertext(uint8_t ct) const {
+ return is_dtls13() && (ct & kCtDtlsCiphertextMask) == kCtDtlsCiphertext;
+}
+
+// Gets the cipher spec that matches the specified epoch.
+TlsCipherSpec& TlsRecordFilter::spec(uint16_t write_epoch) {
+ for (auto& sp : cipher_specs_) {
+ if (sp.epoch() == write_epoch) {
+ return sp;
+ }
+ }
+
+ // If we aren't decrypting, provide a cipher spec that does nothing other than
+ // count sequence numbers.
+ EXPECT_FALSE(decrypting_) << "No spec available for epoch " << write_epoch;
+ ;
+ cipher_specs_.emplace_back(is_dtls_agent(), write_epoch);
+ return cipher_specs_.back();
+}
+
+PacketFilter::Action TlsRecordFilter::Filter(const DataBuffer& input,
+ DataBuffer* output) {
+ // Disable during shutdown.
+ if (!agent()) {
+ return KEEP;
+ }
+
+ bool changed = false;
+ size_t offset = 0U;
+
+ output->Allocate(input.len());
+ TlsParser parser(input);
+
+ // This uses the current write spec for the purposes of parsing the epoch and
+ // sequence number from the header. This might be wrong because we can
+ // receive records from older specs, but guessing is good enough:
+ // - In DTLS, parsing the sequence number corrects any errors.
+ // - In TLS, we don't use the sequence number unless decrypting, where we use
+ // trial decryption to get the right epoch.
+ uint16_t write_epoch = 0;
+ SECStatus rv = SSL_GetCurrentEpoch(agent()->ssl_fd(), nullptr, &write_epoch);
+ if (rv != SECSuccess) {
+ ADD_FAILURE() << "unable to read epoch";
+ return KEEP;
+ }
+ uint64_t guess_seqno = static_cast<uint64_t>(write_epoch) << 48;
+
+ while (parser.remaining()) {
+ TlsRecordHeader header;
+ DataBuffer record;
+ if (!header.Parse(is_dtls13(), guess_seqno, &parser, &record)) {
+ ADD_FAILURE() << "not a valid record";
+ return KEEP;
+ }
+
+ if (FilterRecord(header, record, &offset, output) != KEEP) {
+ changed = true;
+ } else {
+ offset = header.Write(output, offset, record);
+ }
+ }
+ output->Truncate(offset);
+
+ // Record how many packets we actually touched.
+ if (changed) {
+ ++count_;
+ return (offset == 0) ? DROP : CHANGE;
+ }
+
+ return KEEP;
+}
+
+PacketFilter::Action TlsRecordFilter::FilterRecord(
+ const TlsRecordHeader& header, const DataBuffer& record, size_t* offset,
+ DataBuffer* output) {
+ DataBuffer filtered;
+ uint8_t inner_content_type;
+ DataBuffer plaintext;
+ uint16_t protection_epoch = 0;
+ TlsRecordHeader out_header(header);
+
+ if (!Unprotect(header, record, &protection_epoch, &inner_content_type,
+ &plaintext, &out_header)) {
+ std::cerr << agent()->role_str() << ": unprotect failed: " << header << ":"
+ << record << std::endl;
+ return KEEP;
+ }
+
+ auto& protection_spec = spec(protection_epoch);
+ TlsRecordHeader real_header(out_header.variant(), out_header.version(),
+ inner_content_type, out_header.sequence_number());
+
+ PacketFilter::Action action = FilterRecord(real_header, plaintext, &filtered);
+ // In stream mode, even if something doesn't change we need to re-encrypt if
+ // previous packets were dropped.
+ if (action == KEEP) {
+ if (out_header.is_dtls() || !protection_spec.record_dropped()) {
+ // Count every outgoing packet.
+ protection_spec.RecordProtected();
+ return KEEP;
+ }
+ filtered = plaintext;
+ }
+
+ if (action == DROP) {
+ std::cerr << "record drop: " << out_header << ":" << record << std::endl;
+ protection_spec.RecordDropped();
+ return DROP;
+ }
+
+ EXPECT_GT(0x10000U, filtered.len());
+ if (action != KEEP) {
+ std::cerr << "record old: " << plaintext << std::endl;
+ std::cerr << "record new: " << filtered << std::endl;
+ }
+
+ uint64_t seq_num = protection_spec.next_out_seqno();
+ if (!decrypting_ && out_header.is_dtls()) {
+ // Copy over the epoch, which isn't tracked when not decrypting.
+ seq_num |= out_header.sequence_number() & (0xffffULL << 48);
+ }
+ out_header.sequence_number(seq_num);
+
+ DataBuffer ciphertext;
+ bool rv = Protect(protection_spec, out_header, inner_content_type, filtered,
+ &ciphertext, &out_header);
+ if (!rv) {
+ return KEEP;
+ }
+ *offset = out_header.Write(output, *offset, ciphertext);
+ return CHANGE;
+}
+
+size_t TlsRecordHeader::header_length() const {
+ // If we have a header, return it's length.
+ if (header_.len()) {
+ return header_.len();
+ }
+
+ // Otherwise make a dummy header and return the length.
+ DataBuffer buf;
+ return WriteHeader(&buf, 0, 0);
+}
+
+bool TlsRecordHeader::MaskSequenceNumber() {
+ return MaskSequenceNumber(sn_mask());
+}
+
+bool TlsRecordHeader::MaskSequenceNumber(const DataBuffer& mask_buf) {
+ if (mask_buf.empty()) {
+ return false;
+ }
+
+ DataBuffer mask;
+ if (is_dtls13_ciphertext()) {
+ uint64_t seqno = sequence_number();
+ uint8_t len = content_type() & kCtDtlsCiphertext16bSeqno ? 2 : 1;
+ uint16_t seqno_bitmask = (1 << len * 8) - 1;
+ DataBuffer val;
+ if (val.Write(0, seqno & seqno_bitmask, len) != len) {
+ return false;
+ }
+
+#ifdef UNSAFE_FUZZER_MODE
+ // Use a null mask.
+ mask.Allocate(mask_buf.len());
+#endif
+ mask.Append(mask_buf);
+ val.data()[0] ^= mask.data()[0];
+ if (len == 2 && mask.len() > 1) {
+ val.data()[1] ^= mask.data()[1];
+ }
+ uint32_t tmp;
+ if (!val.Read(0, len, &tmp)) {
+ return false;
+ }
+
+ seqno = (seqno & ~seqno_bitmask) | tmp;
+ seqno_is_masked_ = !seqno_is_masked_;
+ if (!seqno_is_masked_) {
+ seqno = ParseSequenceNumber(guess_seqno_, seqno, len * 8, 2);
+ }
+ sequence_number_ = seqno;
+
+ // Now update the header bytes
+ if (header_.len() > 1) {
+ header_.data()[1] ^= mask.data()[0];
+ if ((content_type() & kCtDtlsCiphertext16bSeqno) && header().len() > 2) {
+ header_.data()[2] ^= mask.data()[1];
+ }
+ }
+ }
+
+ sn_mask_ = mask;
+ return true;
+}
+
+uint64_t TlsRecordHeader::RecoverSequenceNumber(uint64_t guess_seqno,
+ uint32_t partial,
+ size_t partial_bits) {
+ EXPECT_GE(32U, partial_bits);
+ uint64_t mask = (1ULL << partial_bits) - 1;
+ // First we determine the highest possible value. This is half the
+ // expressible range above the expected value (|guess_seqno|), less 1.
+ //
+ // We subtract the extra 1 from the cap so that when given a choice between
+ // the equidistant expected+N and expected-N we want to chose the lower. With
+ // 0-RTT, we sometimes have to recover an epoch of 1 when we expect an epoch
+ // of 3 and with 2 partial bits, the alternative result of 5 is wrong.
+ uint64_t cap = guess_seqno + (1ULL << (partial_bits - 1)) - 1;
+ // Add the partial piece in. e.g., xxxx789a and 1234 becomes xxxx1234.
+ uint64_t seq_no = (cap & ~mask) | partial;
+ // If the partial value is higher than the same partial piece from the cap,
+ // then the real value has to be lower. e.g., xxxx1234 can't become xxxx5678.
+ if (partial > (cap & mask) && (seq_no >= (1ULL << partial_bits))) {
+ seq_no -= 1ULL << partial_bits;
+ }
+ return seq_no;
+}
+
+// Determine the full epoch and sequence number from an expected and raw value.
+// The expected, raw, and output values are packed as they are in DTLS 1.2 and
+// earlier: with 16 bits of epoch and 48 bits of sequence number. The raw value
+// is packed this way (even before recovery) so that we don't need to track a
+// moving value between two calls (one to recover the epoch, and one after
+// unmasking to recover the sequence number).
+uint64_t TlsRecordHeader::ParseSequenceNumber(uint64_t expected, uint64_t raw,
+ size_t seq_no_bits,
+ size_t epoch_bits) {
+ uint64_t epoch_mask = (1ULL << epoch_bits) - 1;
+ uint64_t ep = RecoverSequenceNumber(expected >> 48, (raw >> 48) & epoch_mask,
+ epoch_bits);
+ if (ep > (expected >> 48)) {
+ // If the epoch has changed, reset the expected sequence number.
+ expected = 0;
+ } else {
+ // Otherwise, retain just the sequence number part.
+ expected &= (1ULL << 48) - 1;
+ }
+ uint64_t seq_no_mask = (1ULL << seq_no_bits) - 1;
+ uint64_t seq_no = (raw & seq_no_mask);
+ if (!seqno_is_masked_) {
+ seq_no = RecoverSequenceNumber(expected, seq_no, seq_no_bits);
+ }
+
+ return (ep << 48) | seq_no;
+}
+
+bool TlsRecordHeader::Parse(bool is_dtls13, uint64_t seqno, TlsParser* parser,
+ DataBuffer* body) {
+ auto mark = parser->consumed();
+
+ if (!parser->Read(&content_type_)) {
+ return false;
+ }
+
+ if (is_dtls13) {
+ variant_ = ssl_variant_datagram;
+ version_ = SSL_LIBRARY_VERSION_TLS_1_3;
+
+#ifndef UNSAFE_FUZZER_MODE
+ // Deal with the DTLSCipherText header.
+ if (is_dtls13_ciphertext()) {
+ uint8_t seq_no_bytes =
+ (content_type_ & kCtDtlsCiphertext16bSeqno) ? 2 : 1;
+ uint32_t tmp;
+
+ if (!parser->Read(&tmp, seq_no_bytes)) {
+ return false;
+ }
+
+ // Store the guess if masked. If and when seqno_bytesenceNumber is called,
+ // the value will be unmasked and recovered. This assumes we only call
+ // Parse() on headers containing masked values.
+ seqno_is_masked_ = true;
+ guess_seqno_ = seqno;
+ uint64_t ep = content_type_ & 0x03;
+ sequence_number_ = (ep << 48) | tmp;
+
+ // Recover the full epoch. Note the sequence number portion holds the
+ // masked value until a call to Mask() reveals it (as indicated by
+ // |seqno_is_masked_|).
+ sequence_number_ =
+ ParseSequenceNumber(seqno, sequence_number_, seq_no_bytes * 8, 2);
+
+ uint32_t len_bytes =
+ (content_type_ & kCtDtlsCiphertextLengthPresent) ? 2 : 0;
+ if (len_bytes) {
+ if (!parser->Read(&tmp, 2)) {
+ return false;
+ }
+ }
+
+ if (!parser->ReadFromMark(&header_, parser->consumed() - mark, mark)) {
+ return false;
+ }
+
+ return len_bytes ? parser->Read(body, tmp)
+ : parser->Read(body, parser->remaining());
+ }
+
+ // The full DTLSPlainText header can only be used for a few types.
+ EXPECT_TRUE(content_type_ == ssl_ct_alert ||
+ content_type_ == ssl_ct_handshake ||
+ content_type_ == ssl_ct_ack);
+#endif
+ }
+
+ uint32_t ver;
+ if (!parser->Read(&ver, 2)) {
+ return false;
+ }
+ if (!is_dtls13) {
+ variant_ = IsDtls(ver) ? ssl_variant_datagram : ssl_variant_stream;
+ }
+ version_ = NormalizeTlsVersion(ver);
+
+ if (is_dtls()) {
+ // If this is DTLS, read the sequence number.
+ uint32_t tmp;
+ if (!parser->Read(&tmp, 4)) {
+ return false;
+ }
+ sequence_number_ = static_cast<uint64_t>(tmp) << 32;
+ if (!parser->Read(&tmp, 4)) {
+ return false;
+ }
+ sequence_number_ |= static_cast<uint64_t>(tmp);
+ } else {
+ sequence_number_ = seqno;
+ }
+ if (!parser->ReadFromMark(&header_, parser->consumed() + 2 - mark, mark)) {
+ return false;
+ }
+ return parser->ReadVariable(body, 2);
+}
+
+size_t TlsRecordHeader::WriteHeader(DataBuffer* buffer, size_t offset,
+ size_t body_len) const {
+ if (is_dtls13_ciphertext()) {
+ uint8_t seq_no_bytes = (content_type_ & kCtDtlsCiphertext16bSeqno) ? 2 : 1;
+ // application_data records in TLS 1.3 have a different header format.
+ uint32_t e = (sequence_number_ >> 48) & 0x3;
+ uint32_t seqno = sequence_number_ & ((1ULL << seq_no_bytes * 8) - 1);
+ uint8_t new_content_type_ = content_type_ | e;
+ offset = buffer->Write(offset, new_content_type_, 1);
+ offset = buffer->Write(offset, seqno, seq_no_bytes);
+
+ if (content_type_ & kCtDtlsCiphertextLengthPresent) {
+ offset = buffer->Write(offset, body_len, 2);
+ }
+ } else {
+ offset = buffer->Write(offset, content_type_, 1);
+ uint16_t v = is_dtls() ? TlsVersionToDtlsVersion(version_) : version_;
+ offset = buffer->Write(offset, v, 2);
+ if (is_dtls()) {
+ // write epoch (2 octet), and seqnum (6 octet)
+ offset = buffer->Write(offset, sequence_number_ >> 32, 4);
+ offset = buffer->Write(offset, sequence_number_ & 0xffffffff, 4);
+ }
+ offset = buffer->Write(offset, body_len, 2);
+ }
+
+ return offset;
+}
+
+size_t TlsRecordHeader::Write(DataBuffer* buffer, size_t offset,
+ const DataBuffer& body) const {
+ offset = WriteHeader(buffer, offset, body.len());
+ offset = buffer->Write(offset, body);
+ return offset;
+}
+
+bool TlsRecordFilter::Unprotect(const TlsRecordHeader& header,
+ const DataBuffer& ciphertext,
+ uint16_t* protection_epoch,
+ uint8_t* inner_content_type,
+ DataBuffer* plaintext,
+ TlsRecordHeader* out_header) {
+ if (!decrypting_ || !header.is_protected()) {
+ // Maintain the epoch and sequence number for plaintext records.
+ uint16_t ep = 0;
+ if (is_dtls_agent()) {
+ ep = static_cast<uint16_t>(header.sequence_number() >> 48);
+ }
+ spec(ep).RecordUnprotected(header.sequence_number());
+ *protection_epoch = ep;
+ *inner_content_type = header.content_type();
+ *plaintext = ciphertext;
+ return true;
+ }
+
+ uint16_t ep = 0;
+ if (is_dtls_agent()) {
+ ep = static_cast<uint16_t>(header.sequence_number() >> 48);
+ if (!spec(ep).Unprotect(header, ciphertext, plaintext, out_header)) {
+ return false;
+ }
+ } else {
+ // In TLS, records aren't clearly labelled with their epoch, and we
+ // can't just use the newest keys because the same flight of messages can
+ // contain multiple epochs. So... trial decrypt!
+ for (size_t i = cipher_specs_.size() - 1; i > 0; --i) {
+ if (cipher_specs_[i].Unprotect(header, ciphertext, plaintext,
+ out_header)) {
+ ep = cipher_specs_[i].epoch();
+ break;
+ }
+ }
+ if (!ep) {
+ return false;
+ }
+ }
+
+ size_t len = plaintext->len();
+ while (len > 0 && !plaintext->data()[len - 1]) {
+ --len;
+ }
+ if (!len) {
+ // Bogus padding.
+ return false;
+ }
+
+ *protection_epoch = ep;
+ *inner_content_type = plaintext->data()[len - 1];
+ plaintext->Truncate(len - 1);
+ if (g_ssl_gtest_verbose) {
+ std::cerr << agent()->role_str() << ": unprotect: epoch=" << ep
+ << " seq=" << std::hex << header.sequence_number() << std::dec
+ << " " << *plaintext << std::endl;
+ }
+
+ return true;
+}
+
+bool TlsRecordFilter::Protect(TlsCipherSpec& protection_spec,
+ const TlsRecordHeader& header,
+ uint8_t inner_content_type,
+ const DataBuffer& plaintext,
+ DataBuffer* ciphertext,
+ TlsRecordHeader* out_header, size_t padding) {
+ if (!protection_spec.is_protected()) {
+ // Not protected, just keep the sequence numbers updated.
+ protection_spec.RecordProtected();
+ *ciphertext = plaintext;
+ return true;
+ }
+
+ DataBuffer padded;
+ padded.Allocate(plaintext.len() + 1 + padding);
+ size_t offset = padded.Write(0, plaintext.data(), plaintext.len());
+ padded.Write(offset, inner_content_type, 1);
+
+ bool ok = protection_spec.Protect(header, padded, ciphertext, out_header);
+ if (!ok) {
+ ADD_FAILURE() << "protect fail";
+ } else if (g_ssl_gtest_verbose) {
+ std::cerr << agent()->role_str()
+ << ": protect: epoch=" << protection_spec.epoch()
+ << " seq=" << std::hex << header.sequence_number() << std::dec
+ << " " << *ciphertext << std::endl;
+ }
+ return ok;
+}
+
+bool IsHelloRetry(const DataBuffer& body) {
+ static const uint8_t ssl_hello_retry_random[] = {
+ 0xCF, 0x21, 0xAD, 0x74, 0xE5, 0x9A, 0x61, 0x11, 0xBE, 0x1D, 0x8C,
+ 0x02, 0x1E, 0x65, 0xB8, 0x91, 0xC2, 0xA2, 0x11, 0x16, 0x7A, 0xBB,
+ 0x8C, 0x5E, 0x07, 0x9E, 0x09, 0xE2, 0xC8, 0xA8, 0x33, 0x9C};
+ return memcmp(body.data() + 2, ssl_hello_retry_random,
+ sizeof(ssl_hello_retry_random)) == 0;
+}
+
+bool TlsHandshakeFilter::IsFilteredType(const HandshakeHeader& header,
+ const DataBuffer& body) {
+ if (handshake_types_.empty()) {
+ return true;
+ }
+
+ uint8_t type = header.handshake_type();
+ if (type == kTlsHandshakeServerHello) {
+ if (IsHelloRetry(body)) {
+ type = kTlsHandshakeHelloRetryRequest;
+ }
+ }
+ return handshake_types_.count(type) > 0U;
+}
+
+PacketFilter::Action TlsHandshakeFilter::FilterRecord(
+ const TlsRecordHeader& record_header, const DataBuffer& input,
+ DataBuffer* output) {
+ // Check that the first byte is as requested.
+ if (record_header.content_type() != ssl_ct_handshake) {
+ return KEEP;
+ }
+
+ bool changed = false;
+ size_t offset = 0U;
+ output->Allocate(input.len()); // Preallocate a little.
+
+ TlsParser parser(input);
+ while (parser.remaining()) {
+ HandshakeHeader header;
+ DataBuffer handshake;
+ bool complete = false;
+ if (!header.Parse(&parser, record_header, preceding_fragment_, &handshake,
+ &complete)) {
+ return KEEP;
+ }
+
+ if (!complete) {
+ EXPECT_TRUE(record_header.is_dtls());
+ // Save the fragment and drop it from this record. Fragments are
+ // coalesced with the last fragment of the handshake message.
+ changed = true;
+ preceding_fragment_.Assign(handshake);
+ continue;
+ }
+ preceding_fragment_.Truncate(0);
+
+ DataBuffer filtered;
+ PacketFilter::Action action;
+ if (!IsFilteredType(header, handshake)) {
+ action = KEEP;
+ } else {
+ action = FilterHandshake(header, handshake, &filtered);
+ }
+ if (action == DROP) {
+ changed = true;
+ std::cerr << "handshake drop: " << handshake << std::endl;
+ continue;
+ }
+
+ const DataBuffer* source = &handshake;
+ if (action == CHANGE) {
+ EXPECT_GT(0x1000000U, filtered.len());
+ changed = true;
+ std::cerr << "handshake old: " << handshake << std::endl;
+ std::cerr << "handshake new: " << filtered << std::endl;
+ source = &filtered;
+ } else if (preceding_fragment_.len()) {
+ changed = true;
+ }
+
+ offset = header.Write(output, offset, *source);
+ }
+ output->Truncate(offset);
+ return changed ? (offset ? CHANGE : DROP) : KEEP;
+}
+
+bool TlsHandshakeFilter::HandshakeHeader::ReadLength(
+ TlsParser* parser, const TlsRecordHeader& header, uint32_t expected_offset,
+ uint32_t* length, bool* last_fragment) {
+ uint32_t message_length;
+ if (!parser->Read(&message_length, 3)) {
+ return false; // malformed
+ }
+
+ if (!header.is_dtls()) {
+ *last_fragment = true;
+ *length = message_length;
+ return true; // nothing left to do
+ }
+
+ // Read and check DTLS parameters
+ uint32_t message_seq_tmp;
+ if (!parser->Read(&message_seq_tmp, 2)) { // sequence number
+ return false;
+ }
+ message_seq_ = message_seq_tmp;
+
+ uint32_t offset = 0;
+ if (!parser->Read(&offset, 3)) {
+ return false;
+ }
+ // We only parse if the fragments are all complete and in order.
+ if (offset != expected_offset) {
+ EXPECT_NE(0U, header.epoch())
+ << "Received out of order handshake fragment for epoch 0";
+ return false;
+ }
+
+ // For DTLS, we return the length of just this fragment.
+ if (!parser->Read(length, 3)) {
+ return false;
+ }
+
+ // It's a fragment if the entire message is longer than what we have.
+ *last_fragment = message_length == (*length + offset);
+ return true;
+}
+
+bool TlsHandshakeFilter::HandshakeHeader::Parse(
+ TlsParser* parser, const TlsRecordHeader& record_header,
+ const DataBuffer& preceding_fragment, DataBuffer* body, bool* complete) {
+ *complete = false;
+
+ variant_ = record_header.variant();
+ version_ = record_header.version();
+ if (!parser->Read(&handshake_type_)) {
+ return false; // malformed
+ }
+
+ uint32_t length;
+ if (!ReadLength(parser, record_header, preceding_fragment.len(), &length,
+ complete)) {
+ return false;
+ }
+
+ if (!parser->Read(body, length)) {
+ return false;
+ }
+ if (preceding_fragment.len()) {
+ body->Splice(preceding_fragment, 0);
+ }
+ return true;
+}
+
+size_t TlsHandshakeFilter::HandshakeHeader::WriteFragment(
+ DataBuffer* buffer, size_t offset, const DataBuffer& body,
+ size_t fragment_offset, size_t fragment_length) const {
+ EXPECT_TRUE(is_dtls());
+ EXPECT_GE(body.len(), fragment_offset + fragment_length);
+ offset = buffer->Write(offset, handshake_type(), 1);
+ offset = buffer->Write(offset, body.len(), 3);
+ offset = buffer->Write(offset, message_seq_, 2);
+ offset = buffer->Write(offset, fragment_offset, 3);
+ offset = buffer->Write(offset, fragment_length, 3);
+ offset =
+ buffer->Write(offset, body.data() + fragment_offset, fragment_length);
+ return offset;
+}
+
+size_t TlsHandshakeFilter::HandshakeHeader::Write(
+ DataBuffer* buffer, size_t offset, const DataBuffer& body) const {
+ if (is_dtls()) {
+ return WriteFragment(buffer, offset, body, 0U, body.len());
+ }
+ offset = buffer->Write(offset, handshake_type(), 1);
+ offset = buffer->Write(offset, body.len(), 3);
+ offset = buffer->Write(offset, body);
+ return offset;
+}
+
+PacketFilter::Action TlsHandshakeRecorder::FilterHandshake(
+ const HandshakeHeader& header, const DataBuffer& input,
+ DataBuffer* output) {
+ // Only do this once.
+ if (buffer_.len()) {
+ return KEEP;
+ }
+
+ buffer_ = input;
+ return KEEP;
+}
+
+PacketFilter::Action TlsInspectorReplaceHandshakeMessage::FilterHandshake(
+ const HandshakeHeader& header, const DataBuffer& input,
+ DataBuffer* output) {
+ *output = buffer_;
+ return CHANGE;
+}
+
+PacketFilter::Action TlsRecordRecorder::FilterRecord(
+ const TlsRecordHeader& header, const DataBuffer& input,
+ DataBuffer* output) {
+ if (!filter_ || (header.content_type() == ct_)) {
+ records_.push_back({header, input});
+ }
+ return KEEP;
+}
+
+PacketFilter::Action TlsConversationRecorder::FilterRecord(
+ const TlsRecordHeader& header, const DataBuffer& input,
+ DataBuffer* output) {
+ buffer_.Append(input);
+ return KEEP;
+}
+
+PacketFilter::Action TlsHeaderRecorder::FilterRecord(const TlsRecordHeader& hdr,
+ const DataBuffer& input,
+ DataBuffer* output) {
+ headers_.push_back(hdr);
+ return KEEP;
+}
+
+const TlsRecordHeader* TlsHeaderRecorder::header(size_t index) {
+ if (index > headers_.size() + 1) {
+ return nullptr;
+ }
+ return &headers_[index];
+}
+
+PacketFilter::Action ChainedPacketFilter::Filter(const DataBuffer& input,
+ DataBuffer* output) {
+ DataBuffer in(input);
+ bool changed = false;
+ for (auto it = filters_.begin(); it != filters_.end(); ++it) {
+ PacketFilter::Action action = (*it)->Process(in, output);
+ if (action == DROP) {
+ return DROP;
+ }
+
+ if (action == CHANGE) {
+ in = *output;
+ changed = true;
+ }
+ }
+ return changed ? CHANGE : KEEP;
+}
+
+bool FindClientHelloExtensions(TlsParser* parser, const TlsVersioned& header) {
+ if (!parser->Skip(2 + 32)) { // version + random
+ return false;
+ }
+ if (!parser->SkipVariable(1)) { // session ID
+ return false;
+ }
+ if (header.is_dtls() && !parser->SkipVariable(1)) { // DTLS cookie
+ return false;
+ }
+ if (!parser->SkipVariable(2)) { // cipher suites
+ return false;
+ }
+ if (!parser->SkipVariable(1)) { // compression methods
+ return false;
+ }
+ return true;
+}
+
+bool FindServerHelloExtensions(TlsParser* parser, const TlsVersioned& header) {
+ uint32_t vtmp;
+ if (!parser->Read(&vtmp, 2)) {
+ return false;
+ }
+ uint16_t version = static_cast<uint16_t>(vtmp);
+ if (!parser->Skip(32)) { // random
+ return false;
+ }
+ if (NormalizeTlsVersion(version) <= SSL_LIBRARY_VERSION_TLS_1_2) {
+ if (!parser->SkipVariable(1)) { // session ID
+ return false;
+ }
+ }
+ if (!parser->Skip(2)) { // cipher suite
+ return false;
+ }
+ if (NormalizeTlsVersion(version) <= SSL_LIBRARY_VERSION_TLS_1_2) {
+ if (!parser->Skip(1)) { // compression method
+ return false;
+ }
+ }
+ return true;
+}
+
+bool FindEncryptedExtensions(TlsParser* parser, const TlsVersioned& header) {
+ return true;
+}
+
+static bool FindCertReqExtensions(TlsParser* parser,
+ const TlsVersioned& header) {
+ if (!parser->SkipVariable(1)) { // request context
+ return false;
+ }
+ return true;
+}
+
+// Only look at the EE cert for this one.
+static bool FindCertificateExtensions(TlsParser* parser,
+ const TlsVersioned& header) {
+ if (!parser->SkipVariable(1)) { // request context
+ return false;
+ }
+ if (!parser->Skip(3)) { // length of certificate list
+ return false;
+ }
+ if (!parser->SkipVariable(3)) { // ASN1Cert
+ return false;
+ }
+ return true;
+}
+
+static bool FindNewSessionTicketExtensions(TlsParser* parser,
+ const TlsVersioned& header) {
+ if (!parser->Skip(8)) { // lifetime, age add
+ return false;
+ }
+ if (!parser->SkipVariable(1)) { // ticket_nonce
+ return false;
+ }
+ if (!parser->SkipVariable(2)) { // ticket
+ return false;
+ }
+ return true;
+}
+
+static const std::map<uint16_t, TlsExtensionFinder> kExtensionFinders = {
+ {kTlsHandshakeClientHello, FindClientHelloExtensions},
+ {kTlsHandshakeServerHello, FindServerHelloExtensions},
+ {kTlsHandshakeEncryptedExtensions, FindEncryptedExtensions},
+ {kTlsHandshakeCertificateRequest, FindCertReqExtensions},
+ {kTlsHandshakeCertificate, FindCertificateExtensions},
+ {kTlsHandshakeNewSessionTicket, FindNewSessionTicketExtensions}};
+
+bool TlsExtensionFilter::FindExtensions(TlsParser* parser,
+ const HandshakeHeader& header) {
+ auto it = kExtensionFinders.find(header.handshake_type());
+ if (it == kExtensionFinders.end()) {
+ return false;
+ }
+ return (it->second)(parser, header);
+}
+
+PacketFilter::Action TlsExtensionFilter::FilterHandshake(
+ const HandshakeHeader& header, const DataBuffer& input,
+ DataBuffer* output) {
+ TlsParser parser(input);
+ if (!FindExtensions(&parser, header)) {
+ return KEEP;
+ }
+ return FilterExtensions(&parser, input, output);
+}
+
+PacketFilter::Action TlsExtensionFilter::FilterExtensions(
+ TlsParser* parser, const DataBuffer& input, DataBuffer* output) {
+ size_t length_offset = parser->consumed();
+ uint32_t all_extensions;
+ if (!parser->Read(&all_extensions, 2)) {
+ return KEEP; // no extensions, odd but OK
+ }
+ if (all_extensions != parser->remaining()) {
+ return KEEP; // malformed
+ }
+
+ bool changed = false;
+
+ // Write out the start of the message.
+ output->Allocate(input.len());
+ size_t offset = output->Write(0, input.data(), parser->consumed());
+
+ while (parser->remaining()) {
+ uint32_t extension_type;
+ if (!parser->Read(&extension_type, 2)) {
+ return KEEP; // malformed
+ }
+
+ DataBuffer extension;
+ if (!parser->ReadVariable(&extension, 2)) {
+ return KEEP; // malformed
+ }
+
+ DataBuffer filtered;
+ PacketFilter::Action action =
+ FilterExtension(extension_type, extension, &filtered);
+ if (action == DROP) {
+ changed = true;
+ std::cerr << "extension drop: " << extension << std::endl;
+ continue;
+ }
+
+ const DataBuffer* source = &extension;
+ if (action == CHANGE) {
+ EXPECT_GT(0x10000U, filtered.len());
+ changed = true;
+ std::cerr << "extension old: " << extension << std::endl;
+ std::cerr << "extension new: " << filtered << std::endl;
+ source = &filtered;
+ }
+
+ // Write out extension.
+ offset = output->Write(offset, extension_type, 2);
+ offset = output->Write(offset, source->len(), 2);
+ if (source->len() > 0) {
+ offset = output->Write(offset, *source);
+ }
+ }
+ output->Truncate(offset);
+
+ if (changed) {
+ size_t newlen = output->len() - length_offset - 2;
+ EXPECT_GT(0x10000U, newlen);
+ if (newlen >= 0x10000) {
+ return KEEP; // bad: size increased too much
+ }
+ output->Write(length_offset, newlen, 2);
+ return CHANGE;
+ }
+ return KEEP;
+}
+
+PacketFilter::Action TlsExtensionOrderCapture::FilterExtension(
+ uint16_t extension_type, const DataBuffer& input, DataBuffer* output) {
+ order.push_back(extension_type);
+ return KEEP;
+}
+
+PacketFilter::Action TlsExtensionCapture::FilterExtension(
+ uint16_t extension_type, const DataBuffer& input, DataBuffer* output) {
+ if (extension_type == extension_ && (last_ || !captured_)) {
+ data_.Assign(input);
+ captured_ = true;
+ }
+ return KEEP;
+}
+
+PacketFilter::Action TlsExtensionReplacer::FilterExtension(
+ uint16_t extension_type, const DataBuffer& input, DataBuffer* output) {
+ if (extension_type != extension_) {
+ return KEEP;
+ }
+
+ *output = data_;
+ return CHANGE;
+}
+
+PacketFilter::Action TlsExtensionResizer::FilterExtension(
+ uint16_t extension_type, const DataBuffer& input, DataBuffer* output) {
+ if (extension_type != extension_) {
+ return KEEP;
+ }
+
+ if (input.len() <= length_) {
+ DataBuffer buf(length_ - input.len());
+ output->Append(buf);
+ return CHANGE;
+ }
+
+ output->Assign(input.data(), length_);
+ return CHANGE;
+}
+
+PacketFilter::Action TlsExtensionAppender::FilterHandshake(
+ const HandshakeHeader& header, const DataBuffer& input,
+ DataBuffer* output) {
+ TlsParser parser(input);
+ if (!TlsExtensionFilter::FindExtensions(&parser, header)) {
+ return KEEP;
+ }
+ *output = input;
+
+ // Increase the length of the extensions block.
+ if (!UpdateLength(output, parser.consumed(), 2)) {
+ return KEEP;
+ }
+
+ // Extensions in Certificate are nested twice. Increase the size of the
+ // certificate list.
+ if (header.handshake_type() == kTlsHandshakeCertificate) {
+ TlsParser p2(input);
+ if (!p2.SkipVariable(1)) {
+ ADD_FAILURE();
+ return KEEP;
+ }
+ if (!UpdateLength(output, p2.consumed(), 3)) {
+ return KEEP;
+ }
+ }
+
+ size_t offset = output->len();
+ offset = output->Write(offset, extension_, 2);
+ WriteVariable(output, offset, data_, 2);
+
+ return CHANGE;
+}
+
+bool TlsExtensionAppender::UpdateLength(DataBuffer* output, size_t offset,
+ size_t size) {
+ uint32_t len;
+ if (!output->Read(offset, size, &len)) {
+ ADD_FAILURE();
+ return false;
+ }
+
+ len += 4 + data_.len();
+ output->Write(offset, len, size);
+ return true;
+}
+
+PacketFilter::Action TlsExtensionDropper::FilterExtension(
+ uint16_t extension_type, const DataBuffer& input, DataBuffer* output) {
+ if (extension_type == extension_) {
+ return DROP;
+ }
+ return KEEP;
+}
+
+PacketFilter::Action TlsExtensionDamager::FilterExtension(
+ uint16_t extension_type, const DataBuffer& input, DataBuffer* output) {
+ if (extension_type != extension_) {
+ return KEEP;
+ }
+
+ *output = input;
+ output->data()[index_] += 73; // Increment selected for maximum damage
+ return CHANGE;
+}
+
+PacketFilter::Action TlsExtensionInjector::FilterHandshake(
+ const HandshakeHeader& header, const DataBuffer& input,
+ DataBuffer* output) {
+ TlsParser parser(input);
+ if (!TlsExtensionFilter::FindExtensions(&parser, header)) {
+ return KEEP;
+ }
+ size_t offset = parser.consumed();
+
+ *output = input;
+
+ // Increase the size of the extensions.
+ uint16_t ext_len;
+ memcpy(&ext_len, output->data() + offset, sizeof(ext_len));
+ ext_len = htons(ntohs(ext_len) + data_.len() + 4);
+ memcpy(output->data() + offset, &ext_len, sizeof(ext_len));
+
+ // Insert the extension type and length.
+ DataBuffer type_length;
+ type_length.Allocate(4);
+ type_length.Write(0, extension_, 2);
+ type_length.Write(2, data_.len(), 2);
+ output->Splice(type_length, offset + 2);
+
+ // Insert the payload.
+ if (data_.len() > 0) {
+ output->Splice(data_, offset + 6);
+ }
+
+ return CHANGE;
+}
+
+PacketFilter::Action AfterRecordN::FilterRecord(const TlsRecordHeader& header,
+ const DataBuffer& body,
+ DataBuffer* out) {
+ if (counter_++ == record_) {
+ DataBuffer buf;
+ header.Write(&buf, 0, body);
+ agent()->SendDirect(buf);
+ dest_.lock()->Handshake();
+ func_();
+ return DROP;
+ }
+
+ return KEEP;
+}
+
+PacketFilter::Action TlsClientHelloVersionChanger::FilterHandshake(
+ const HandshakeHeader& header, const DataBuffer& input,
+ DataBuffer* output) {
+ EXPECT_EQ(SECSuccess,
+ SSLInt_IncrementClientHandshakeVersion(server_.lock()->ssl_fd()));
+ return KEEP;
+}
+
+PacketFilter::Action SelectiveDropFilter::Filter(const DataBuffer& input,
+ DataBuffer* output) {
+ if (counter_ >= 32) {
+ return KEEP;
+ }
+ return ((1 << counter_++) & pattern_) ? DROP : KEEP;
+}
+
+PacketFilter::Action SelectiveRecordDropFilter::FilterRecord(
+ const TlsRecordHeader& header, const DataBuffer& data,
+ DataBuffer* changed) {
+ if (counter_ >= 32) {
+ return KEEP;
+ }
+ return ((1 << counter_++) & pattern_) ? DROP : KEEP;
+}
+
+/* static */ uint32_t SelectiveRecordDropFilter::ToPattern(
+ std::initializer_list<size_t> records) {
+ uint32_t pattern = 0;
+ for (auto it = records.begin(); it != records.end(); ++it) {
+ EXPECT_GT(32U, *it);
+ assert(*it < 32U);
+ pattern |= 1 << *it;
+ }
+ return pattern;
+}
+
+PacketFilter::Action TlsMessageVersionSetter::FilterHandshake(
+ const HandshakeHeader& header, const DataBuffer& input,
+ DataBuffer* output) {
+ *output = input;
+ output->Write(0, version_, 2);
+ return CHANGE;
+}
+
+PacketFilter::Action SelectedCipherSuiteReplacer::FilterHandshake(
+ const HandshakeHeader& header, const DataBuffer& input,
+ DataBuffer* output) {
+ *output = input;
+ uint32_t temp = 0;
+ EXPECT_TRUE(input.Read(0, 2, &temp));
+ EXPECT_EQ(header.version(), NormalizeTlsVersion(temp));
+ // Cipher suite is after version(2), random(32)
+ // and [legacy_]session_id(<0..32>).
+ size_t pos = 34;
+ EXPECT_TRUE(input.Read(pos, 1, &temp));
+ pos += 1 + temp;
+
+ output->Write(pos, static_cast<uint32_t>(cipher_suite_), 2);
+ return CHANGE;
+}
+
+PacketFilter::Action ServerHelloRandomChanger::FilterHandshake(
+ const HandshakeHeader& header, const DataBuffer& input,
+ DataBuffer* output) {
+ *output = input;
+ uint32_t temp = 0;
+ size_t pos = 30;
+ EXPECT_TRUE(input.Read(pos, 2, &temp));
+ output->Write(pos, (temp ^ 0xffff), 2);
+ return CHANGE;
+}
+
+PacketFilter::Action ClientHelloPreambleCapture::FilterHandshake(
+ const HandshakeHeader& header, const DataBuffer& input,
+ DataBuffer* output) {
+ EXPECT_TRUE(header.handshake_type() == kTlsHandshakeClientHello);
+
+ if (captured_) {
+ return KEEP;
+ }
+ captured_ = true;
+
+ DataBuffer temp;
+ TlsParser parser(input);
+ EXPECT_TRUE(parser.Read(&temp, 2 + 32)); // Version + Random
+ EXPECT_TRUE(parser.ReadVariable(&temp, 1)); // Session ID
+ if (is_dtls_agent()) {
+ EXPECT_TRUE(parser.ReadVariable(&temp, 1)); // Cookie
+ }
+ EXPECT_TRUE(parser.ReadVariable(&temp, 2)); // Ciphersuites
+ EXPECT_TRUE(parser.ReadVariable(&temp, 1)); // Compression
+
+ // Copy the preamble into a new buffer
+ data_ = input;
+ data_.Truncate(parser.consumed());
+
+ return KEEP;
+}
+
+PacketFilter::Action ClientHelloCiphersuiteCapture::FilterHandshake(
+ const HandshakeHeader& header, const DataBuffer& input,
+ DataBuffer* output) {
+ EXPECT_TRUE(header.handshake_type() == kTlsHandshakeClientHello);
+
+ if (captured_) {
+ return KEEP;
+ }
+ captured_ = true;
+
+ TlsParser parser(input);
+ EXPECT_TRUE(parser.Skip(2 + 32)); // Version + Random
+ EXPECT_TRUE(parser.SkipVariable(1)); // Session ID
+ if (is_dtls_agent()) {
+ EXPECT_TRUE(parser.SkipVariable(1)); // Cookie
+ }
+
+ EXPECT_TRUE(parser.ReadVariable(&data_, 2)); // Ciphersuites
+
+ return KEEP;
+}
+
+} // namespace nss_test