/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ /* vim: set ts=2 et sw=2 tw=80: */ /* This Source Code Form is subject to the terms of the Mozilla Public * License, v. 2.0. If a copy of the MPL was not distributed with this file, * You can obtain one at http://mozilla.org/MPL/2.0/. */ #ifndef tls_filter_h_ #define tls_filter_h_ #include #include #include #include #include "pk11pub.h" #include "sslt.h" #include "sslproto.h" #include "test_io.h" #include "tls_agent.h" #include "tls_parser.h" #include "tls_protect.h" extern "C" { #include "libssl_internals.h" } namespace nss_test { class TlsCipherSpec; class TlsSendCipherSpecCapturer { public: TlsSendCipherSpecCapturer(const std::shared_ptr& agent) : agent_(agent), send_cipher_specs_() { EXPECT_EQ(SECSuccess, SSL_SecretCallback(agent_->ssl_fd(), SecretCallback, this)); } std::shared_ptr spec(size_t i) { if (i >= send_cipher_specs_.size()) { return nullptr; } return send_cipher_specs_[i]; } private: static void SecretCallback(PRFileDesc* fd, PRUint16 epoch, SSLSecretDirection dir, PK11SymKey* secret, void* arg) { auto self = static_cast(arg); std::cerr << self->agent_->role_str() << ": capture " << dir << " secret for epoch " << epoch << std::endl; if (dir == ssl_secret_read) { return; } SSLPreliminaryChannelInfo preinfo; EXPECT_EQ(SECSuccess, SSL_GetPreliminaryChannelInfo(self->agent_->ssl_fd(), &preinfo, sizeof(preinfo))); EXPECT_EQ(sizeof(preinfo), preinfo.length); EXPECT_TRUE(preinfo.valuesSet & ssl_preinfo_cipher_suite); // Check the version: EXPECT_TRUE(preinfo.valuesSet & ssl_preinfo_version); ASSERT_GE(SSL_LIBRARY_VERSION_TLS_1_3, preinfo.protocolVersion); SSLCipherSuiteInfo cipherinfo; EXPECT_EQ(SECSuccess, SSL_GetCipherSuiteInfo(preinfo.cipherSuite, &cipherinfo, sizeof(cipherinfo))); EXPECT_EQ(sizeof(cipherinfo), cipherinfo.length); auto spec = std::make_shared(true, epoch); EXPECT_TRUE(spec->SetKeys(&cipherinfo, secret)); self->send_cipher_specs_.push_back(spec); } std::shared_ptr agent_; std::vector> send_cipher_specs_; }; class TlsVersioned { public: TlsVersioned() : variant_(ssl_variant_stream), version_(0) {} TlsVersioned(SSLProtocolVariant var, uint16_t ver) : variant_(var), version_(ver) {} bool is_dtls() const { return variant_ == ssl_variant_datagram; } SSLProtocolVariant variant() const { return variant_; } uint16_t version() const { return version_; } void WriteStream(std::ostream& stream) const; protected: SSLProtocolVariant variant_; uint16_t version_; }; class TlsRecordHeader : public TlsVersioned { public: TlsRecordHeader() : TlsVersioned(), content_type_(0), guess_seqno_(0), seqno_is_masked_(false), sequence_number_(0), header_() {} TlsRecordHeader(SSLProtocolVariant var, uint16_t ver, uint8_t ct, uint64_t seqno) : TlsVersioned(var, ver), content_type_(ct), guess_seqno_(0), seqno_is_masked_(false), sequence_number_(seqno), header_(), sn_mask_() {} bool is_protected() const { // *TLS < 1.3 if (version() < SSL_LIBRARY_VERSION_TLS_1_3 && content_type() == ssl_ct_application_data) { return true; } // TLS 1.3 if (!is_dtls() && version() >= SSL_LIBRARY_VERSION_TLS_1_3 && content_type() == ssl_ct_application_data) { return true; } // DTLS 1.3 return is_dtls13_ciphertext(); } uint8_t content_type() const { return content_type_; } uint16_t epoch() const { return static_cast(sequence_number_ >> 48); } uint64_t sequence_number() const { return sequence_number_; } void sequence_number(uint64_t seqno) { sequence_number_ = seqno; } const DataBuffer& sn_mask() const { return sn_mask_; } bool is_dtls13_ciphertext() const { return is_dtls() && (version() >= SSL_LIBRARY_VERSION_TLS_1_3) && (content_type() & kCtDtlsCiphertextMask) == kCtDtlsCiphertext; } size_t header_length() const; const DataBuffer& header() const { return header_; } bool MaskSequenceNumber(); bool MaskSequenceNumber(const DataBuffer& mask_buf); // Parse the header; return true if successful; body in an outparam if OK. bool Parse(bool is_dtls13, uint64_t sequence_number, TlsParser* parser, DataBuffer* body); // Write the header and body to a buffer at the given offset. // Return the offset of the end of the write. size_t Write(DataBuffer* buffer, size_t offset, const DataBuffer& body) const; size_t WriteHeader(DataBuffer* buffer, size_t offset, size_t body_len) const; private: static uint64_t RecoverSequenceNumber(uint64_t guess_seqno, uint32_t partial, size_t partial_bits); uint64_t ParseSequenceNumber(uint64_t expected, uint64_t raw, size_t seq_no_bits, size_t epoch_bits); uint8_t content_type_; uint64_t guess_seqno_; bool seqno_is_masked_; uint64_t sequence_number_; DataBuffer header_; DataBuffer sn_mask_; }; struct TlsRecord { const TlsRecordHeader header; const DataBuffer buffer; }; // Make a filter and install it on a TlsAgent. template inline std::shared_ptr MakeTlsFilter(const std::shared_ptr& agent, Args&&... args) { auto filter = std::make_shared(agent, std::forward(args)...); agent->SetFilter(filter); return filter; } // Abstract filter that operates on entire (D)TLS records. class TlsRecordFilter : public PacketFilter { public: TlsRecordFilter(const std::shared_ptr& a); std::shared_ptr agent() const { return agent_.lock(); } // External interface. Overrides PacketFilter. PacketFilter::Action Filter(const DataBuffer& input, DataBuffer* output); // Report how many packets were altered by the filter. size_t filtered_packets() const { return count_; } // Enable decryption. This only works properly for TLS 1.3 and above. // Enabling it for lower version tests will cause undefined // behavior. void EnableDecryption(); bool decrypting() const { return decrypting_; }; bool Unprotect(const TlsRecordHeader& header, const DataBuffer& cipherText, uint16_t* protection_epoch, uint8_t* inner_content_type, DataBuffer* plaintext, TlsRecordHeader* out_header); bool Protect(TlsCipherSpec& protection_spec, const TlsRecordHeader& header, uint8_t inner_content_type, const DataBuffer& plaintext, DataBuffer* ciphertext, TlsRecordHeader* out_header, size_t padding = 0); protected: // There are two filter functions which can be overriden. Both are // called with the header and the record but the outer one is called // with a raw pointer to let you write into the buffer and lets you // do anything with this section of the stream. The inner one // just lets you change the record contents. By default, the // outer one calls the inner one, so if you override the outer // one, the inner one is never called unless you call it yourself. virtual PacketFilter::Action FilterRecord(const TlsRecordHeader& header, const DataBuffer& record, size_t* offset, DataBuffer* output); // The record filter receives the record contentType, version and DTLS // sequence number (which is zero for TLS), plus the existing record payload. // It returns an action (KEEP, CHANGE, DROP). It writes to the `changed` // outparam with the new record contents if it chooses to CHANGE the record. virtual PacketFilter::Action FilterRecord(const TlsRecordHeader& header, const DataBuffer& data, DataBuffer* changed) { return KEEP; } bool is_dtls_agent() const; bool is_dtls13() const; bool is_dtls13_ciphertext(uint8_t ct) const; TlsCipherSpec& spec(uint16_t epoch); private: static void SecretCallback(PRFileDesc* fd, PRUint16 epoch, SSLSecretDirection dir, PK11SymKey* secret, void* arg); std::weak_ptr agent_; size_t count_ = 0; std::vector cipher_specs_; bool decrypting_ = false; }; inline std::ostream& operator<<(std::ostream& stream, const TlsVersioned& v) { v.WriteStream(stream); return stream; } inline std::ostream& operator<<(std::ostream& stream, const TlsRecordHeader& hdr) { hdr.WriteStream(stream); stream << ' '; switch (hdr.content_type()) { case ssl_ct_change_cipher_spec: stream << "CCS"; break; case ssl_ct_alert: stream << "Alert"; break; case ssl_ct_handshake: stream << "Handshake"; break; case ssl_ct_application_data: stream << "Data"; break; case ssl_ct_ack: stream << "ACK"; break; default: stream << '<' << static_cast(hdr.content_type()) << '>'; break; } return stream << ' ' << std::hex << hdr.sequence_number() << std::dec; } // Abstract filter that operates on handshake messages rather than records. // This assumes that the handshake messages are written in a block as entire // records and that they don't span records or anything crazy like that. class TlsHandshakeFilter : public TlsRecordFilter { public: TlsHandshakeFilter(const std::shared_ptr& a) : TlsRecordFilter(a), handshake_types_(), preceding_fragment_() {} TlsHandshakeFilter(const std::shared_ptr& a, const std::set& types) : TlsRecordFilter(a), handshake_types_(types), preceding_fragment_() {} // This filter can be set to be selective based on handshake message type. If // this function isn't used (or the set is empty), then all handshake messages // will be filtered. void SetHandshakeTypes(const std::set& types) { handshake_types_ = types; } class HandshakeHeader : public TlsVersioned { public: HandshakeHeader() : TlsVersioned(), handshake_type_(0), message_seq_(0) {} uint8_t handshake_type() const { return handshake_type_; } bool Parse(TlsParser* parser, const TlsRecordHeader& record_header, const DataBuffer& preceding_fragment, DataBuffer* body, bool* complete); size_t Write(DataBuffer* buffer, size_t offset, const DataBuffer& body) const; size_t WriteFragment(DataBuffer* buffer, size_t offset, const DataBuffer& body, size_t fragment_offset, size_t fragment_length) const; private: // Reads the length from the record header. // This also reads the DTLS fragment information and checks it. bool ReadLength(TlsParser* parser, const TlsRecordHeader& header, uint32_t expected_offset, uint32_t* length, bool* last_fragment); uint8_t handshake_type_; uint16_t message_seq_; // fragment_offset is always zero in these tests. }; protected: virtual PacketFilter::Action FilterRecord(const TlsRecordHeader& header, const DataBuffer& input, DataBuffer* output); virtual PacketFilter::Action FilterHandshake(const HandshakeHeader& header, const DataBuffer& input, DataBuffer* output) = 0; private: bool IsFilteredType(const HandshakeHeader& header, const DataBuffer& handshake); std::set handshake_types_; DataBuffer preceding_fragment_; }; // Make a copy of the first instance of a handshake message. class TlsHandshakeRecorder : public TlsHandshakeFilter { public: TlsHandshakeRecorder(const std::shared_ptr& a, uint8_t handshake_type) : TlsHandshakeFilter(a, {handshake_type}), buffer_() {} TlsHandshakeRecorder(const std::shared_ptr& a, const std::set& handshake_types) : TlsHandshakeFilter(a, handshake_types), buffer_() {} virtual PacketFilter::Action FilterHandshake(const HandshakeHeader& header, const DataBuffer& input, DataBuffer* output); void Reset() { buffer_.Truncate(0); } const DataBuffer& buffer() const { return buffer_; } private: DataBuffer buffer_; }; // Replace all instances of a handshake message. class TlsInspectorReplaceHandshakeMessage : public TlsHandshakeFilter { public: TlsInspectorReplaceHandshakeMessage(const std::shared_ptr& a, uint8_t handshake_type, const DataBuffer& replacement) : TlsHandshakeFilter(a, {handshake_type}), buffer_(replacement) {} virtual PacketFilter::Action FilterHandshake(const HandshakeHeader& header, const DataBuffer& input, DataBuffer* output); private: DataBuffer buffer_; }; // Make a copy of each record of a given type. class TlsRecordRecorder : public TlsRecordFilter { public: TlsRecordRecorder(const std::shared_ptr& a, uint8_t ct) : TlsRecordFilter(a), filter_(true), ct_(ct), records_() {} TlsRecordRecorder(const std::shared_ptr& a) : TlsRecordFilter(a), filter_(false), ct_(ssl_ct_handshake), // dummy ( is C++14) records_() {} virtual PacketFilter::Action FilterRecord(const TlsRecordHeader& header, const DataBuffer& input, DataBuffer* output); size_t count() const { return records_.size(); } void Clear() { records_.clear(); } const TlsRecord& record(size_t i) const { return records_[i]; } private: bool filter_; uint8_t ct_; std::vector records_; }; // Make a copy of the complete conversation. class TlsConversationRecorder : public TlsRecordFilter { public: TlsConversationRecorder(const std::shared_ptr& a, DataBuffer& buffer) : TlsRecordFilter(a), buffer_(buffer) {} virtual PacketFilter::Action FilterRecord(const TlsRecordHeader& header, const DataBuffer& input, DataBuffer* output); private: DataBuffer buffer_; }; // Make a copy of the records class TlsHeaderRecorder : public TlsRecordFilter { public: TlsHeaderRecorder(const std::shared_ptr& a) : TlsRecordFilter(a) {} virtual PacketFilter::Action FilterRecord(const TlsRecordHeader& header, const DataBuffer& input, DataBuffer* output); const TlsRecordHeader* header(size_t index); private: std::vector headers_; }; typedef std::initializer_list> ChainedPacketFilterInit; // Runs multiple packet filters in series. class ChainedPacketFilter : public PacketFilter { public: ChainedPacketFilter() {} ChainedPacketFilter(const std::vector> filters) : filters_(filters.begin(), filters.end()) {} ChainedPacketFilter(ChainedPacketFilterInit il) : filters_(il) {} virtual ~ChainedPacketFilter() {} virtual PacketFilter::Action Filter(const DataBuffer& input, DataBuffer* output); // Takes ownership of the filter. void Add(std::shared_ptr filter) { filters_.push_back(filter); } private: std::vector> filters_; }; typedef std::function TlsExtensionFinder; class TlsExtensionFilter : public TlsHandshakeFilter { public: TlsExtensionFilter(const std::shared_ptr& a) : TlsHandshakeFilter(a, {kTlsHandshakeClientHello, kTlsHandshakeServerHello, kTlsHandshakeHelloRetryRequest, kTlsHandshakeEncryptedExtensions}) {} TlsExtensionFilter(const std::shared_ptr& a, const std::set& types) : TlsHandshakeFilter(a, types) {} static bool FindExtensions(TlsParser* parser, const HandshakeHeader& header); protected: PacketFilter::Action FilterHandshake(const HandshakeHeader& header, const DataBuffer& input, DataBuffer* output) override; virtual PacketFilter::Action FilterExtension(uint16_t extension_type, const DataBuffer& input, DataBuffer* output) = 0; private: PacketFilter::Action FilterExtensions(TlsParser* parser, const DataBuffer& input, DataBuffer* output); }; class TlsExtensionOrderCapture : public TlsExtensionFilter { public: TlsExtensionOrderCapture(const std::shared_ptr& a, uint8_t message) : TlsExtensionFilter(a, {message}){}; std::vector order; protected: PacketFilter::Action FilterExtension(uint16_t extension_type, const DataBuffer& input, DataBuffer* output) override; }; class TlsExtensionCapture : public TlsExtensionFilter { public: TlsExtensionCapture(const std::shared_ptr& a, uint16_t ext, bool last = false) : TlsExtensionFilter(a), extension_(ext), captured_(false), last_(last), data_() {} const DataBuffer& extension() const { return data_; } bool captured() const { return captured_; } protected: PacketFilter::Action FilterExtension(uint16_t extension_type, const DataBuffer& input, DataBuffer* output) override; private: const uint16_t extension_; bool captured_; bool last_; DataBuffer data_; }; class TlsExtensionReplacer : public TlsExtensionFilter { public: TlsExtensionReplacer(const std::shared_ptr& a, uint16_t extension, const DataBuffer& data) : TlsExtensionFilter(a), extension_(extension), data_(data) {} PacketFilter::Action FilterExtension(uint16_t extension_type, const DataBuffer& input, DataBuffer* output) override; private: const uint16_t extension_; const DataBuffer data_; }; class TlsExtensionResizer : public TlsExtensionFilter { public: TlsExtensionResizer(const std::shared_ptr& a, uint16_t extension, size_t length) : TlsExtensionFilter(a), extension_(extension), length_(length) {} PacketFilter::Action FilterExtension(uint16_t extension_type, const DataBuffer& input, DataBuffer* output) override; private: uint16_t extension_; size_t length_; }; class TlsExtensionAppender : public TlsHandshakeFilter { public: TlsExtensionAppender(const std::shared_ptr& a, uint8_t handshake_type, uint16_t ext, DataBuffer& data) : TlsHandshakeFilter(a, {handshake_type}), extension_(ext), data_(data) {} virtual PacketFilter::Action FilterHandshake(const HandshakeHeader& header, const DataBuffer& input, DataBuffer* output); private: bool UpdateLength(DataBuffer* output, size_t offset, size_t size); const uint16_t extension_; const DataBuffer data_; }; class TlsExtensionDropper : public TlsExtensionFilter { public: TlsExtensionDropper(const std::shared_ptr& a, uint16_t extension) : TlsExtensionFilter(a), extension_(extension) {} PacketFilter::Action FilterExtension(uint16_t extension_type, const DataBuffer&, DataBuffer*) override; private: uint16_t extension_; }; class TlsHandshakeDropper : public TlsHandshakeFilter { public: TlsHandshakeDropper(const std::shared_ptr& a) : TlsHandshakeFilter(a) {} protected: PacketFilter::Action FilterHandshake(const HandshakeHeader& header, const DataBuffer& input, DataBuffer* output) override { return DROP; } }; class TlsEncryptedHandshakeMessageReplacer : public TlsRecordFilter { public: TlsEncryptedHandshakeMessageReplacer(const std::shared_ptr& a, uint8_t old_ct, uint8_t new_ct) : TlsRecordFilter(a), old_ct_(old_ct), new_ct_(new_ct) {} protected: PacketFilter::Action FilterRecord(const TlsRecordHeader& header, const DataBuffer& record, size_t* offset, DataBuffer* output) override { if (header.content_type() != ssl_ct_application_data) { return KEEP; } uint16_t protection_epoch = 0; uint8_t inner_content_type; DataBuffer plaintext; TlsRecordHeader out_header; if (!Unprotect(header, record, &protection_epoch, &inner_content_type, &plaintext, &out_header) || !plaintext.len()) { return KEEP; } if (inner_content_type != ssl_ct_handshake) { return KEEP; } size_t off = 0; uint32_t msg_len = 0; uint32_t msg_type = 255; // Not a real message do { if (!plaintext.Read(off, 1, &msg_type) || msg_type == old_ct_) { break; } // Increment and check next messages if (!plaintext.Read(++off, 3, &msg_len)) { break; } off += 3 + msg_len; } while (msg_type != old_ct_); if (msg_type == old_ct_) { plaintext.Write(off, new_ct_, 1); } DataBuffer ciphertext; bool ok = Protect(spec(protection_epoch), out_header, inner_content_type, plaintext, &ciphertext, &out_header); if (!ok) { return KEEP; } *offset = out_header.Write(output, *offset, ciphertext); return CHANGE; } private: uint8_t old_ct_; uint8_t new_ct_; }; class TlsExtensionInjector : public TlsHandshakeFilter { public: TlsExtensionInjector(const std::shared_ptr& a, uint16_t ext, const DataBuffer& data) : TlsHandshakeFilter(a), extension_(ext), data_(data) {} protected: PacketFilter::Action FilterHandshake(const HandshakeHeader& header, const DataBuffer& input, DataBuffer* output) override; private: const uint16_t extension_; const DataBuffer data_; }; class TlsExtensionDamager : public TlsExtensionFilter { public: TlsExtensionDamager(const std::shared_ptr& a, uint16_t extension, size_t index) : TlsExtensionFilter(a), extension_(extension), index_(index) {} virtual PacketFilter::Action FilterExtension(uint16_t extension_type, const DataBuffer& input, DataBuffer* output); private: uint16_t extension_; size_t index_; }; typedef std::function VoidFunction; class AfterRecordN : public TlsRecordFilter { public: AfterRecordN(const std::shared_ptr& src, const std::shared_ptr& dest, unsigned int record, VoidFunction func) : TlsRecordFilter(src), dest_(dest), record_(record), func_(func), counter_(0) {} virtual PacketFilter::Action FilterRecord(const TlsRecordHeader& header, const DataBuffer& body, DataBuffer* out) override; private: std::weak_ptr dest_; unsigned int record_; VoidFunction func_; unsigned int counter_; }; // When we see the ClientKeyExchange from |client|, increment the // ClientHelloVersion on |server|. class TlsClientHelloVersionChanger : public TlsHandshakeFilter { public: TlsClientHelloVersionChanger(const std::shared_ptr& client, const std::shared_ptr& server) : TlsHandshakeFilter(client, {kTlsHandshakeClientKeyExchange}), server_(server) {} virtual PacketFilter::Action FilterHandshake(const HandshakeHeader& header, const DataBuffer& input, DataBuffer* output); private: std::weak_ptr server_; }; // Damage a record. class TlsRecordLastByteDamager : public TlsRecordFilter { public: TlsRecordLastByteDamager(const std::shared_ptr& a) : TlsRecordFilter(a) {} protected: PacketFilter::Action FilterRecord(const TlsRecordHeader& header, const DataBuffer& data, DataBuffer* changed) override { *changed = data; changed->data()[changed->len() - 1]++; return CHANGE; } }; // Saves the first received message into a buffer and then drops it. // After receiving, the filter is disabled. class TLSRecordSaveAndDropNext : public TlsRecordFilter { public: TLSRecordSaveAndDropNext(const std::shared_ptr& a) : TlsRecordFilter(a), replaced_(false), data_(0) {} DataBuffer ReturnRecorded() { return data_; } protected: PacketFilter::Action Filter(const DataBuffer& input, DataBuffer* output) { if (!replaced_) { data_ = input; replaced_ = true; return DROP; } return KEEP; } private: bool replaced_; DataBuffer data_; }; // This class selectively drops complete writes. This relies on the fact that // writes in libssl are on record boundaries. class SelectiveDropFilter : public PacketFilter { public: SelectiveDropFilter(uint32_t pattern) : pattern_(pattern), counter_(0) {} protected: virtual PacketFilter::Action Filter(const DataBuffer& input, DataBuffer* output) override; private: const uint32_t pattern_; uint8_t counter_; }; // This class selectively drops complete records. The difference from // SelectiveDropFilter is that if multiple DTLS records are in the same // datagram, we just drop one. class SelectiveRecordDropFilter : public TlsRecordFilter { public: SelectiveRecordDropFilter(const std::shared_ptr& a, uint32_t pattern, bool on = true) : TlsRecordFilter(a), pattern_(pattern), counter_(0) { if (!on) { Disable(); } } SelectiveRecordDropFilter(const std::shared_ptr& a, std::initializer_list records) : SelectiveRecordDropFilter(a, ToPattern(records), true) {} void Reset(uint32_t pattern) { counter_ = 0; PacketFilter::Enable(); pattern_ = pattern; } void Reset(std::initializer_list records) { Reset(ToPattern(records)); } protected: PacketFilter::Action FilterRecord(const TlsRecordHeader& header, const DataBuffer& data, DataBuffer* changed) override; private: static uint32_t ToPattern(std::initializer_list records); uint32_t pattern_; uint8_t counter_; }; // Set the version value in the ClientHello, ServerHello or HelloRetryRequest class TlsMessageVersionSetter : public TlsHandshakeFilter { public: TlsMessageVersionSetter(const std::shared_ptr& a, uint8_t message, uint16_t version) : TlsHandshakeFilter(a, {message}), version_(version) { PR_ASSERT(message == kTlsHandshakeClientHello || message == kTlsHandshakeServerHello || message == kTlsHandshakeHelloRetryRequest); } virtual PacketFilter::Action FilterHandshake(const HandshakeHeader& header, const DataBuffer& input, DataBuffer* output); private: uint16_t version_; }; // Damages the last byte of a handshake message. class TlsLastByteDamager : public TlsHandshakeFilter { public: TlsLastByteDamager(const std::shared_ptr& a, uint8_t type) : TlsHandshakeFilter(a), type_(type) {} PacketFilter::Action FilterHandshake( const TlsHandshakeFilter::HandshakeHeader& header, const DataBuffer& input, DataBuffer* output) override { if (header.handshake_type() != type_) { return KEEP; } *output = input; output->data()[output->len() - 1]++; return CHANGE; } private: uint8_t type_; }; class SelectedCipherSuiteReplacer : public TlsHandshakeFilter { public: SelectedCipherSuiteReplacer(const std::shared_ptr& a, uint16_t suite) : TlsHandshakeFilter(a, {kTlsHandshakeServerHello}), cipher_suite_(suite) {} protected: PacketFilter::Action FilterHandshake(const HandshakeHeader& header, const DataBuffer& input, DataBuffer* output) override; private: uint16_t cipher_suite_; }; class ClientHelloPreambleCapture : public TlsHandshakeFilter { public: ClientHelloPreambleCapture(const std::shared_ptr& a) : TlsHandshakeFilter(a, {kTlsHandshakeClientHello}), captured_(false), data_() {} const DataBuffer& contents() const { return data_; } bool captured() const { return captured_; } protected: PacketFilter::Action FilterHandshake(const HandshakeHeader& header, const DataBuffer& input, DataBuffer* output) override; private: bool captured_; DataBuffer data_; }; class ClientHelloCiphersuiteCapture : public TlsHandshakeFilter { public: ClientHelloCiphersuiteCapture(const std::shared_ptr& a) : TlsHandshakeFilter(a, {kTlsHandshakeClientHello}), captured_(false), data_() {} const DataBuffer& contents() const { return data_; } bool captured() const { return captured_; } protected: PacketFilter::Action FilterHandshake(const HandshakeHeader& header, const DataBuffer& input, DataBuffer* output) override; private: bool captured_; DataBuffer data_; }; class ServerHelloRandomChanger : public TlsHandshakeFilter { public: ServerHelloRandomChanger(const std::shared_ptr& a) : TlsHandshakeFilter(a, {kTlsHandshakeServerHello}) {} protected: PacketFilter::Action FilterHandshake(const HandshakeHeader& header, const DataBuffer& input, DataBuffer* output) override; }; // Replace SignatureAndHashAlgorithm of a SKE. class DHEServerKEXSigAlgReplacer : public TlsHandshakeFilter { public: DHEServerKEXSigAlgReplacer(const std::shared_ptr& server, uint16_t sig_scheme) : TlsHandshakeFilter(server, {kTlsHandshakeServerKeyExchange}), sig_scheme_(sig_scheme) {} protected: virtual PacketFilter::Action FilterHandshake(const HandshakeHeader& header, const DataBuffer& input, DataBuffer* output) { *output = input; uint32_t len; uint32_t idx = 0; EXPECT_TRUE(output->Read(idx, 2, &len)); idx += 2 + len; EXPECT_TRUE(output->Read(idx, 2, &len)); idx += 2 + len; EXPECT_TRUE(output->Read(idx, 2, &len)); idx += 2 + len; output->Write(idx, sig_scheme_, 2); return CHANGE; } private: uint16_t sig_scheme_; }; // Replace SignatureAndHashAlgorithm of a SKE. class ECCServerKEXSigAlgReplacer : public TlsHandshakeFilter { public: ECCServerKEXSigAlgReplacer(const std::shared_ptr& server, uint16_t sig_scheme) : TlsHandshakeFilter(server, {kTlsHandshakeServerKeyExchange}), sig_scheme_(sig_scheme) {} protected: virtual PacketFilter::Action FilterHandshake(const HandshakeHeader& header, const DataBuffer& input, DataBuffer* output) { *output = input; uint32_t point_len; EXPECT_TRUE(output->Read(3, 1, &point_len)); output->Write(4 + point_len, sig_scheme_, 2); return CHANGE; } private: uint16_t sig_scheme_; }; // Replace NamedCurve of a ECDHE SKE. class ECCServerKEXNamedCurveReplacer : public TlsHandshakeFilter { public: ECCServerKEXNamedCurveReplacer(const std::shared_ptr& server, uint16_t curve_name) : TlsHandshakeFilter(server, {kTlsHandshakeServerKeyExchange}), curve_name_(curve_name) {} protected: virtual PacketFilter::Action FilterHandshake(const HandshakeHeader& header, const DataBuffer& input, DataBuffer* output) { *output = input; uint32_t curve_type; EXPECT_TRUE(output->Read(0, 1, &curve_type)); EXPECT_EQ(curve_type, ec_type_named); output->Write(1, curve_name_, 2); return CHANGE; } private: uint16_t curve_name_; }; // Replaces the signature scheme in a CertificateVerify message. class TlsReplaceSignatureSchemeFilter : public TlsHandshakeFilter { public: TlsReplaceSignatureSchemeFilter(const std::shared_ptr& a, uint16_t scheme) : TlsHandshakeFilter(a, {kTlsHandshakeCertificateVerify}), scheme_(scheme) {} protected: virtual PacketFilter::Action FilterHandshake(const HandshakeHeader& header, const DataBuffer& input, DataBuffer* output) { *output = input; output->Write(0, scheme_, 2); return CHANGE; } private: uint16_t scheme_; }; } // namespace nss_test #endif