diff options
Diffstat (limited to 'security/nss/gtests/ssl_gtest/ssl_fragment_unittest.cc')
-rw-r--r-- | security/nss/gtests/ssl_gtest/ssl_fragment_unittest.cc | 169 |
1 files changed, 169 insertions, 0 deletions
diff --git a/security/nss/gtests/ssl_gtest/ssl_fragment_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_fragment_unittest.cc new file mode 100644 index 0000000000..3752812633 --- /dev/null +++ b/security/nss/gtests/ssl_gtest/ssl_fragment_unittest.cc @@ -0,0 +1,169 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set ts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this file, + * You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#include "secerr.h" +#include "ssl.h" +#include "sslerr.h" +#include "sslproto.h" + +#include "gtest_utils.h" +#include "nss_scoped_ptrs.h" +#include "tls_connect.h" +#include "tls_filter.h" +#include "tls_parser.h" + +namespace nss_test { + +// This class cuts every unencrypted handshake record into two parts. +class RecordFragmenter : public PacketFilter { + public: + RecordFragmenter(bool is_dtls13) + : is_dtls13_(is_dtls13), sequence_number_(0), splitting_(true) {} + + private: + class HandshakeSplitter { + public: + HandshakeSplitter(bool is_dtls13, const DataBuffer& input, + DataBuffer* output, uint64_t* sequence_number) + : is_dtls13_(is_dtls13), + input_(input), + output_(output), + cursor_(0), + sequence_number_(sequence_number) {} + + private: + void WriteRecord(TlsRecordHeader& record_header, + DataBuffer& record_fragment) { + TlsRecordHeader fragment_header( + record_header.variant(), record_header.version(), + record_header.content_type(), *sequence_number_); + ++*sequence_number_; + if (::g_ssl_gtest_verbose) { + std::cerr << "Fragment: " << fragment_header << ' ' << record_fragment + << std::endl; + } + cursor_ = fragment_header.Write(output_, cursor_, record_fragment); + } + + bool SplitRecord(TlsRecordHeader& record_header, DataBuffer& record) { + TlsParser parser(record); + while (parser.remaining()) { + TlsHandshakeFilter::HandshakeHeader handshake_header; + DataBuffer handshake_body; + bool complete = false; + if (!handshake_header.Parse(&parser, record_header, DataBuffer(), + &handshake_body, &complete)) { + ADD_FAILURE() << "couldn't parse handshake header"; + return false; + } + if (!complete) { + ADD_FAILURE() << "don't want to deal with fragmented messages"; + return false; + } + + DataBuffer record_fragment; + // We can't fragment handshake records that are too small. + if (handshake_body.len() < 2) { + handshake_header.Write(&record_fragment, 0U, handshake_body); + WriteRecord(record_header, record_fragment); + continue; + } + + size_t cut = handshake_body.len() / 2; + handshake_header.WriteFragment(&record_fragment, 0U, handshake_body, 0U, + cut); + WriteRecord(record_header, record_fragment); + + handshake_header.WriteFragment(&record_fragment, 0U, handshake_body, + cut, handshake_body.len() - cut); + WriteRecord(record_header, record_fragment); + } + return true; + } + + public: + bool Split() { + TlsParser parser(input_); + while (parser.remaining()) { + TlsRecordHeader header; + DataBuffer record; + if (!header.Parse(is_dtls13_, 0, &parser, &record)) { + ADD_FAILURE() << "bad record header"; + return false; + } + + if (::g_ssl_gtest_verbose) { + std::cerr << "Record: " << header << ' ' << record << std::endl; + } + + // Don't touch packets from a non-zero epoch. Leave these unmodified. + if ((header.sequence_number() >> 48) != 0ULL) { + cursor_ = header.Write(output_, cursor_, record); + continue; + } + + // Just rewrite the sequence number (CCS only). + if (header.content_type() != ssl_ct_handshake) { + EXPECT_EQ(ssl_ct_change_cipher_spec, header.content_type()); + WriteRecord(header, record); + continue; + } + + if (!SplitRecord(header, record)) { + return false; + } + } + return true; + } + + private: + bool is_dtls13_; + const DataBuffer& input_; + DataBuffer* output_; + size_t cursor_; + uint64_t* sequence_number_; + }; + + protected: + virtual PacketFilter::Action Filter(const DataBuffer& input, + DataBuffer* output) override { + if (!splitting_) { + return KEEP; + } + + output->Allocate(input.len()); + HandshakeSplitter splitter(is_dtls13_, input, output, &sequence_number_); + if (!splitter.Split()) { + // If splitting fails, we obviously reached encrypted packets. + // Stop splitting from that point onward. + splitting_ = false; + return KEEP; + } + + return CHANGE; + } + + private: + bool is_dtls13_; + uint64_t sequence_number_; + bool splitting_; +}; + +TEST_P(TlsConnectDatagram, FragmentClientPackets) { + bool is_dtls13 = version_ >= SSL_LIBRARY_VERSION_TLS_1_3; + client_->SetFilter(std::make_shared<RecordFragmenter>(is_dtls13)); + Connect(); + SendReceive(); +} + +TEST_P(TlsConnectDatagram, FragmentServerPackets) { + bool is_dtls13 = version_ >= SSL_LIBRARY_VERSION_TLS_1_3; + server_->SetFilter(std::make_shared<RecordFragmenter>(is_dtls13)); + Connect(); + SendReceive(); +} + +} // namespace nss_test |