summaryrefslogtreecommitdiffstats
path: root/security/nss/gtests/ssl_gtest/tls_filter.h
diff options
context:
space:
mode:
authorDaniel Baumann <daniel.baumann@progress-linux.org>2024-04-07 19:33:14 +0000
committerDaniel Baumann <daniel.baumann@progress-linux.org>2024-04-07 19:33:14 +0000
commit36d22d82aa202bb199967e9512281e9a53db42c9 (patch)
tree105e8c98ddea1c1e4784a60a5a6410fa416be2de /security/nss/gtests/ssl_gtest/tls_filter.h
parentInitial commit. (diff)
downloadfirefox-esr-36d22d82aa202bb199967e9512281e9a53db42c9.tar.xz
firefox-esr-36d22d82aa202bb199967e9512281e9a53db42c9.zip
Adding upstream version 115.7.0esr.upstream/115.7.0esrupstream
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'security/nss/gtests/ssl_gtest/tls_filter.h')
-rw-r--r--security/nss/gtests/ssl_gtest/tls_filter.h1013
1 files changed, 1013 insertions, 0 deletions
diff --git a/security/nss/gtests/ssl_gtest/tls_filter.h b/security/nss/gtests/ssl_gtest/tls_filter.h
new file mode 100644
index 0000000000..7c45aab12f
--- /dev/null
+++ b/security/nss/gtests/ssl_gtest/tls_filter.h
@@ -0,0 +1,1013 @@
+/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
+/* vim: set ts=2 et sw=2 tw=80: */
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this file,
+ * You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+#ifndef tls_filter_h_
+#define tls_filter_h_
+
+#include <functional>
+#include <memory>
+#include <set>
+#include <vector>
+#include "pk11pub.h"
+#include "sslt.h"
+#include "sslproto.h"
+#include "test_io.h"
+#include "tls_agent.h"
+#include "tls_parser.h"
+#include "tls_protect.h"
+
+extern "C" {
+#include "libssl_internals.h"
+}
+
+namespace nss_test {
+
+class TlsCipherSpec;
+
+class TlsSendCipherSpecCapturer {
+ public:
+ TlsSendCipherSpecCapturer(const std::shared_ptr<TlsAgent>& agent)
+ : agent_(agent), send_cipher_specs_() {
+ EXPECT_EQ(SECSuccess,
+ SSL_SecretCallback(agent_->ssl_fd(), SecretCallback, this));
+ }
+
+ std::shared_ptr<TlsCipherSpec> spec(size_t i) {
+ if (i >= send_cipher_specs_.size()) {
+ return nullptr;
+ }
+ return send_cipher_specs_[i];
+ }
+
+ private:
+ static void SecretCallback(PRFileDesc* fd, PRUint16 epoch,
+ SSLSecretDirection dir, PK11SymKey* secret,
+ void* arg) {
+ auto self = static_cast<TlsSendCipherSpecCapturer*>(arg);
+ std::cerr << self->agent_->role_str() << ": capture " << dir
+ << " secret for epoch " << epoch << std::endl;
+
+ if (dir == ssl_secret_read) {
+ return;
+ }
+
+ SSLPreliminaryChannelInfo preinfo;
+ EXPECT_EQ(SECSuccess,
+ SSL_GetPreliminaryChannelInfo(self->agent_->ssl_fd(), &preinfo,
+ sizeof(preinfo)));
+ EXPECT_EQ(sizeof(preinfo), preinfo.length);
+ EXPECT_TRUE(preinfo.valuesSet & ssl_preinfo_cipher_suite);
+
+ // Check the version:
+ EXPECT_TRUE(preinfo.valuesSet & ssl_preinfo_version);
+ ASSERT_GE(SSL_LIBRARY_VERSION_TLS_1_3, preinfo.protocolVersion);
+
+ SSLCipherSuiteInfo cipherinfo;
+ EXPECT_EQ(SECSuccess,
+ SSL_GetCipherSuiteInfo(preinfo.cipherSuite, &cipherinfo,
+ sizeof(cipherinfo)));
+ EXPECT_EQ(sizeof(cipherinfo), cipherinfo.length);
+
+ auto spec = std::make_shared<TlsCipherSpec>(true, epoch);
+ EXPECT_TRUE(spec->SetKeys(&cipherinfo, secret));
+ self->send_cipher_specs_.push_back(spec);
+ }
+
+ std::shared_ptr<TlsAgent> agent_;
+ std::vector<std::shared_ptr<TlsCipherSpec>> send_cipher_specs_;
+};
+
+class TlsVersioned {
+ public:
+ TlsVersioned() : variant_(ssl_variant_stream), version_(0) {}
+ TlsVersioned(SSLProtocolVariant var, uint16_t ver)
+ : variant_(var), version_(ver) {}
+
+ bool is_dtls() const { return variant_ == ssl_variant_datagram; }
+ SSLProtocolVariant variant() const { return variant_; }
+ uint16_t version() const { return version_; }
+
+ void WriteStream(std::ostream& stream) const;
+
+ protected:
+ SSLProtocolVariant variant_;
+ uint16_t version_;
+};
+
+class TlsRecordHeader : public TlsVersioned {
+ public:
+ TlsRecordHeader()
+ : TlsVersioned(),
+ content_type_(0),
+ guess_seqno_(0),
+ seqno_is_masked_(false),
+ sequence_number_(0),
+ header_() {}
+ TlsRecordHeader(SSLProtocolVariant var, uint16_t ver, uint8_t ct,
+ uint64_t seqno)
+ : TlsVersioned(var, ver),
+ content_type_(ct),
+ guess_seqno_(0),
+ seqno_is_masked_(false),
+ sequence_number_(seqno),
+ header_(),
+ sn_mask_() {}
+
+ bool is_protected() const {
+ // *TLS < 1.3
+ if (version() < SSL_LIBRARY_VERSION_TLS_1_3 &&
+ content_type() == ssl_ct_application_data) {
+ return true;
+ }
+
+ // TLS 1.3
+ if (!is_dtls() && version() >= SSL_LIBRARY_VERSION_TLS_1_3 &&
+ content_type() == ssl_ct_application_data) {
+ return true;
+ }
+
+ // DTLS 1.3
+ return is_dtls13_ciphertext();
+ }
+
+ uint8_t content_type() const { return content_type_; }
+ uint16_t epoch() const {
+ return static_cast<uint16_t>(sequence_number_ >> 48);
+ }
+ uint64_t sequence_number() const { return sequence_number_; }
+ void sequence_number(uint64_t seqno) { sequence_number_ = seqno; }
+ const DataBuffer& sn_mask() const { return sn_mask_; }
+ bool is_dtls13_ciphertext() const {
+ return is_dtls() && (version() >= SSL_LIBRARY_VERSION_TLS_1_3) &&
+ (content_type() & kCtDtlsCiphertextMask) == kCtDtlsCiphertext;
+ }
+
+ size_t header_length() const;
+ const DataBuffer& header() const { return header_; }
+
+ bool MaskSequenceNumber();
+ bool MaskSequenceNumber(const DataBuffer& mask_buf);
+
+ // Parse the header; return true if successful; body in an outparam if OK.
+ bool Parse(bool is_dtls13, uint64_t sequence_number, TlsParser* parser,
+ DataBuffer* body);
+ // Write the header and body to a buffer at the given offset.
+ // Return the offset of the end of the write.
+ size_t Write(DataBuffer* buffer, size_t offset, const DataBuffer& body) const;
+ size_t WriteHeader(DataBuffer* buffer, size_t offset, size_t body_len) const;
+
+ private:
+ static uint64_t RecoverSequenceNumber(uint64_t guess_seqno, uint32_t partial,
+ size_t partial_bits);
+ uint64_t ParseSequenceNumber(uint64_t expected, uint64_t raw,
+ size_t seq_no_bits, size_t epoch_bits);
+
+ uint8_t content_type_;
+ uint64_t guess_seqno_;
+ bool seqno_is_masked_;
+ uint64_t sequence_number_;
+ DataBuffer header_;
+ DataBuffer sn_mask_;
+};
+
+struct TlsRecord {
+ const TlsRecordHeader header;
+ const DataBuffer buffer;
+};
+
+// Make a filter and install it on a TlsAgent.
+template <class T, typename... Args>
+inline std::shared_ptr<T> MakeTlsFilter(const std::shared_ptr<TlsAgent>& agent,
+ Args&&... args) {
+ auto filter = std::make_shared<T>(agent, std::forward<Args>(args)...);
+ agent->SetFilter(filter);
+ return filter;
+}
+
+// Abstract filter that operates on entire (D)TLS records.
+class TlsRecordFilter : public PacketFilter {
+ public:
+ TlsRecordFilter(const std::shared_ptr<TlsAgent>& a);
+
+ std::shared_ptr<TlsAgent> agent() const { return agent_.lock(); }
+
+ // External interface. Overrides PacketFilter.
+ PacketFilter::Action Filter(const DataBuffer& input, DataBuffer* output);
+
+ // Report how many packets were altered by the filter.
+ size_t filtered_packets() const { return count_; }
+
+ // Enable decryption. This only works properly for TLS 1.3 and above.
+ // Enabling it for lower version tests will cause undefined
+ // behavior.
+ void EnableDecryption();
+ bool decrypting() const { return decrypting_; };
+ bool Unprotect(const TlsRecordHeader& header, const DataBuffer& cipherText,
+ uint16_t* protection_epoch, uint8_t* inner_content_type,
+ DataBuffer* plaintext, TlsRecordHeader* out_header);
+ bool Protect(TlsCipherSpec& protection_spec, const TlsRecordHeader& header,
+ uint8_t inner_content_type, const DataBuffer& plaintext,
+ DataBuffer* ciphertext, TlsRecordHeader* out_header,
+ size_t padding = 0);
+
+ protected:
+ // There are two filter functions which can be overriden. Both are
+ // called with the header and the record but the outer one is called
+ // with a raw pointer to let you write into the buffer and lets you
+ // do anything with this section of the stream. The inner one
+ // just lets you change the record contents. By default, the
+ // outer one calls the inner one, so if you override the outer
+ // one, the inner one is never called unless you call it yourself.
+ virtual PacketFilter::Action FilterRecord(const TlsRecordHeader& header,
+ const DataBuffer& record,
+ size_t* offset, DataBuffer* output);
+
+ // The record filter receives the record contentType, version and DTLS
+ // sequence number (which is zero for TLS), plus the existing record payload.
+ // It returns an action (KEEP, CHANGE, DROP). It writes to the `changed`
+ // outparam with the new record contents if it chooses to CHANGE the record.
+ virtual PacketFilter::Action FilterRecord(const TlsRecordHeader& header,
+ const DataBuffer& data,
+ DataBuffer* changed) {
+ return KEEP;
+ }
+
+ bool is_dtls_agent() const;
+ bool is_dtls13() const;
+ bool is_dtls13_ciphertext(uint8_t ct) const;
+ TlsCipherSpec& spec(uint16_t epoch);
+
+ private:
+ static void SecretCallback(PRFileDesc* fd, PRUint16 epoch,
+ SSLSecretDirection dir, PK11SymKey* secret,
+ void* arg);
+
+ std::weak_ptr<TlsAgent> agent_;
+ size_t count_ = 0;
+ std::vector<TlsCipherSpec> cipher_specs_;
+ bool decrypting_ = false;
+};
+
+inline std::ostream& operator<<(std::ostream& stream, const TlsVersioned& v) {
+ v.WriteStream(stream);
+ return stream;
+}
+
+inline std::ostream& operator<<(std::ostream& stream,
+ const TlsRecordHeader& hdr) {
+ hdr.WriteStream(stream);
+ stream << ' ';
+ switch (hdr.content_type()) {
+ case ssl_ct_change_cipher_spec:
+ stream << "CCS";
+ break;
+ case ssl_ct_alert:
+ stream << "Alert";
+ break;
+ case ssl_ct_handshake:
+ stream << "Handshake";
+ break;
+ case ssl_ct_application_data:
+ stream << "Data";
+ break;
+ case ssl_ct_ack:
+ stream << "ACK";
+ break;
+ default:
+ stream << '<' << static_cast<int>(hdr.content_type()) << '>';
+ break;
+ }
+ return stream << ' ' << std::hex << hdr.sequence_number() << std::dec;
+}
+
+// Abstract filter that operates on handshake messages rather than records.
+// This assumes that the handshake messages are written in a block as entire
+// records and that they don't span records or anything crazy like that.
+class TlsHandshakeFilter : public TlsRecordFilter {
+ public:
+ TlsHandshakeFilter(const std::shared_ptr<TlsAgent>& a)
+ : TlsRecordFilter(a), handshake_types_(), preceding_fragment_() {}
+ TlsHandshakeFilter(const std::shared_ptr<TlsAgent>& a,
+ const std::set<uint8_t>& types)
+ : TlsRecordFilter(a), handshake_types_(types), preceding_fragment_() {}
+
+ // This filter can be set to be selective based on handshake message type. If
+ // this function isn't used (or the set is empty), then all handshake messages
+ // will be filtered.
+ void SetHandshakeTypes(const std::set<uint8_t>& types) {
+ handshake_types_ = types;
+ }
+
+ class HandshakeHeader : public TlsVersioned {
+ public:
+ HandshakeHeader() : TlsVersioned(), handshake_type_(0), message_seq_(0) {}
+
+ uint8_t handshake_type() const { return handshake_type_; }
+ bool Parse(TlsParser* parser, const TlsRecordHeader& record_header,
+ const DataBuffer& preceding_fragment, DataBuffer* body,
+ bool* complete);
+ size_t Write(DataBuffer* buffer, size_t offset,
+ const DataBuffer& body) const;
+ size_t WriteFragment(DataBuffer* buffer, size_t offset,
+ const DataBuffer& body, size_t fragment_offset,
+ size_t fragment_length) const;
+
+ private:
+ // Reads the length from the record header.
+ // This also reads the DTLS fragment information and checks it.
+ bool ReadLength(TlsParser* parser, const TlsRecordHeader& header,
+ uint32_t expected_offset, uint32_t* length,
+ bool* last_fragment);
+
+ uint8_t handshake_type_;
+ uint16_t message_seq_;
+ // fragment_offset is always zero in these tests.
+ };
+
+ protected:
+ virtual PacketFilter::Action FilterRecord(const TlsRecordHeader& header,
+ const DataBuffer& input,
+ DataBuffer* output);
+ virtual PacketFilter::Action FilterHandshake(const HandshakeHeader& header,
+ const DataBuffer& input,
+ DataBuffer* output) = 0;
+
+ private:
+ bool IsFilteredType(const HandshakeHeader& header,
+ const DataBuffer& handshake);
+
+ std::set<uint8_t> handshake_types_;
+ DataBuffer preceding_fragment_;
+};
+
+// Make a copy of the first instance of a handshake message.
+class TlsHandshakeRecorder : public TlsHandshakeFilter {
+ public:
+ TlsHandshakeRecorder(const std::shared_ptr<TlsAgent>& a,
+ uint8_t handshake_type)
+ : TlsHandshakeFilter(a, {handshake_type}), buffer_() {}
+ TlsHandshakeRecorder(const std::shared_ptr<TlsAgent>& a,
+ const std::set<uint8_t>& handshake_types)
+ : TlsHandshakeFilter(a, handshake_types), buffer_() {}
+
+ virtual PacketFilter::Action FilterHandshake(const HandshakeHeader& header,
+ const DataBuffer& input,
+ DataBuffer* output);
+
+ void Reset() { buffer_.Truncate(0); }
+
+ const DataBuffer& buffer() const { return buffer_; }
+
+ private:
+ DataBuffer buffer_;
+};
+
+// Replace all instances of a handshake message.
+class TlsInspectorReplaceHandshakeMessage : public TlsHandshakeFilter {
+ public:
+ TlsInspectorReplaceHandshakeMessage(const std::shared_ptr<TlsAgent>& a,
+ uint8_t handshake_type,
+ const DataBuffer& replacement)
+ : TlsHandshakeFilter(a, {handshake_type}), buffer_(replacement) {}
+
+ virtual PacketFilter::Action FilterHandshake(const HandshakeHeader& header,
+ const DataBuffer& input,
+ DataBuffer* output);
+
+ private:
+ DataBuffer buffer_;
+};
+
+// Make a copy of each record of a given type.
+class TlsRecordRecorder : public TlsRecordFilter {
+ public:
+ TlsRecordRecorder(const std::shared_ptr<TlsAgent>& a, uint8_t ct)
+ : TlsRecordFilter(a), filter_(true), ct_(ct), records_() {}
+ TlsRecordRecorder(const std::shared_ptr<TlsAgent>& a)
+ : TlsRecordFilter(a),
+ filter_(false),
+ ct_(ssl_ct_handshake), // dummy (<optional> is C++14)
+ records_() {}
+ virtual PacketFilter::Action FilterRecord(const TlsRecordHeader& header,
+ const DataBuffer& input,
+ DataBuffer* output);
+
+ size_t count() const { return records_.size(); }
+ void Clear() { records_.clear(); }
+
+ const TlsRecord& record(size_t i) const { return records_[i]; }
+
+ private:
+ bool filter_;
+ uint8_t ct_;
+ std::vector<TlsRecord> records_;
+};
+
+// Make a copy of the complete conversation.
+class TlsConversationRecorder : public TlsRecordFilter {
+ public:
+ TlsConversationRecorder(const std::shared_ptr<TlsAgent>& a,
+ DataBuffer& buffer)
+ : TlsRecordFilter(a), buffer_(buffer) {}
+
+ virtual PacketFilter::Action FilterRecord(const TlsRecordHeader& header,
+ const DataBuffer& input,
+ DataBuffer* output);
+
+ private:
+ DataBuffer buffer_;
+};
+
+// Make a copy of the records
+class TlsHeaderRecorder : public TlsRecordFilter {
+ public:
+ TlsHeaderRecorder(const std::shared_ptr<TlsAgent>& a) : TlsRecordFilter(a) {}
+ virtual PacketFilter::Action FilterRecord(const TlsRecordHeader& header,
+ const DataBuffer& input,
+ DataBuffer* output);
+ const TlsRecordHeader* header(size_t index);
+
+ private:
+ std::vector<TlsRecordHeader> headers_;
+};
+
+typedef std::initializer_list<std::shared_ptr<PacketFilter>>
+ ChainedPacketFilterInit;
+
+// Runs multiple packet filters in series.
+class ChainedPacketFilter : public PacketFilter {
+ public:
+ ChainedPacketFilter() {}
+ ChainedPacketFilter(const std::vector<std::shared_ptr<PacketFilter>> filters)
+ : filters_(filters.begin(), filters.end()) {}
+ ChainedPacketFilter(ChainedPacketFilterInit il) : filters_(il) {}
+ virtual ~ChainedPacketFilter() {}
+
+ virtual PacketFilter::Action Filter(const DataBuffer& input,
+ DataBuffer* output);
+
+ // Takes ownership of the filter.
+ void Add(std::shared_ptr<PacketFilter> filter) { filters_.push_back(filter); }
+
+ private:
+ std::vector<std::shared_ptr<PacketFilter>> filters_;
+};
+
+typedef std::function<bool(TlsParser* parser, const TlsVersioned& header)>
+ TlsExtensionFinder;
+
+class TlsExtensionFilter : public TlsHandshakeFilter {
+ public:
+ TlsExtensionFilter(const std::shared_ptr<TlsAgent>& a)
+ : TlsHandshakeFilter(a,
+ {kTlsHandshakeClientHello, kTlsHandshakeServerHello,
+ kTlsHandshakeHelloRetryRequest,
+ kTlsHandshakeEncryptedExtensions}) {}
+
+ TlsExtensionFilter(const std::shared_ptr<TlsAgent>& a,
+ const std::set<uint8_t>& types)
+ : TlsHandshakeFilter(a, types) {}
+
+ static bool FindExtensions(TlsParser* parser, const HandshakeHeader& header);
+
+ protected:
+ PacketFilter::Action FilterHandshake(const HandshakeHeader& header,
+ const DataBuffer& input,
+ DataBuffer* output) override;
+
+ virtual PacketFilter::Action FilterExtension(uint16_t extension_type,
+ const DataBuffer& input,
+ DataBuffer* output) = 0;
+
+ private:
+ PacketFilter::Action FilterExtensions(TlsParser* parser,
+ const DataBuffer& input,
+ DataBuffer* output);
+};
+
+class TlsExtensionOrderCapture : public TlsExtensionFilter {
+ public:
+ TlsExtensionOrderCapture(const std::shared_ptr<TlsAgent>& a, uint8_t message)
+ : TlsExtensionFilter(a, {message}){};
+
+ std::vector<uint16_t> order;
+
+ protected:
+ PacketFilter::Action FilterExtension(uint16_t extension_type,
+ const DataBuffer& input,
+ DataBuffer* output) override;
+};
+
+class TlsExtensionCapture : public TlsExtensionFilter {
+ public:
+ TlsExtensionCapture(const std::shared_ptr<TlsAgent>& a, uint16_t ext,
+ bool last = false)
+ : TlsExtensionFilter(a),
+ extension_(ext),
+ captured_(false),
+ last_(last),
+ data_() {}
+
+ const DataBuffer& extension() const { return data_; }
+ bool captured() const { return captured_; }
+
+ protected:
+ PacketFilter::Action FilterExtension(uint16_t extension_type,
+ const DataBuffer& input,
+ DataBuffer* output) override;
+
+ private:
+ const uint16_t extension_;
+ bool captured_;
+ bool last_;
+ DataBuffer data_;
+};
+
+class TlsExtensionReplacer : public TlsExtensionFilter {
+ public:
+ TlsExtensionReplacer(const std::shared_ptr<TlsAgent>& a, uint16_t extension,
+ const DataBuffer& data)
+ : TlsExtensionFilter(a), extension_(extension), data_(data) {}
+ PacketFilter::Action FilterExtension(uint16_t extension_type,
+ const DataBuffer& input,
+ DataBuffer* output) override;
+
+ private:
+ const uint16_t extension_;
+ const DataBuffer data_;
+};
+
+class TlsExtensionResizer : public TlsExtensionFilter {
+ public:
+ TlsExtensionResizer(const std::shared_ptr<TlsAgent>& a, uint16_t extension,
+ size_t length)
+ : TlsExtensionFilter(a), extension_(extension), length_(length) {}
+ PacketFilter::Action FilterExtension(uint16_t extension_type,
+ const DataBuffer& input,
+ DataBuffer* output) override;
+
+ private:
+ uint16_t extension_;
+ size_t length_;
+};
+
+class TlsExtensionAppender : public TlsHandshakeFilter {
+ public:
+ TlsExtensionAppender(const std::shared_ptr<TlsAgent>& a,
+ uint8_t handshake_type, uint16_t ext, DataBuffer& data)
+ : TlsHandshakeFilter(a, {handshake_type}), extension_(ext), data_(data) {}
+
+ virtual PacketFilter::Action FilterHandshake(const HandshakeHeader& header,
+ const DataBuffer& input,
+ DataBuffer* output);
+
+ private:
+ bool UpdateLength(DataBuffer* output, size_t offset, size_t size);
+
+ const uint16_t extension_;
+ const DataBuffer data_;
+};
+
+class TlsExtensionDropper : public TlsExtensionFilter {
+ public:
+ TlsExtensionDropper(const std::shared_ptr<TlsAgent>& a, uint16_t extension)
+ : TlsExtensionFilter(a), extension_(extension) {}
+ PacketFilter::Action FilterExtension(uint16_t extension_type,
+ const DataBuffer&, DataBuffer*) override;
+
+ private:
+ uint16_t extension_;
+};
+
+class TlsHandshakeDropper : public TlsHandshakeFilter {
+ public:
+ TlsHandshakeDropper(const std::shared_ptr<TlsAgent>& a)
+ : TlsHandshakeFilter(a) {}
+
+ protected:
+ PacketFilter::Action FilterHandshake(const HandshakeHeader& header,
+ const DataBuffer& input,
+ DataBuffer* output) override {
+ return DROP;
+ }
+};
+
+class TlsEncryptedHandshakeMessageReplacer : public TlsRecordFilter {
+ public:
+ TlsEncryptedHandshakeMessageReplacer(const std::shared_ptr<TlsAgent>& a,
+ uint8_t old_ct, uint8_t new_ct)
+ : TlsRecordFilter(a), old_ct_(old_ct), new_ct_(new_ct) {}
+
+ protected:
+ PacketFilter::Action FilterRecord(const TlsRecordHeader& header,
+ const DataBuffer& record, size_t* offset,
+ DataBuffer* output) override {
+ if (header.content_type() != ssl_ct_application_data) {
+ return KEEP;
+ }
+
+ uint16_t protection_epoch = 0;
+ uint8_t inner_content_type;
+ DataBuffer plaintext;
+ TlsRecordHeader out_header;
+ if (!Unprotect(header, record, &protection_epoch, &inner_content_type,
+ &plaintext, &out_header) ||
+ !plaintext.len()) {
+ return KEEP;
+ }
+
+ if (inner_content_type != ssl_ct_handshake) {
+ return KEEP;
+ }
+
+ size_t off = 0;
+ uint32_t msg_len = 0;
+ uint32_t msg_type = 255; // Not a real message
+ do {
+ if (!plaintext.Read(off, 1, &msg_type) || msg_type == old_ct_) {
+ break;
+ }
+
+ // Increment and check next messages
+ if (!plaintext.Read(++off, 3, &msg_len)) {
+ break;
+ }
+ off += 3 + msg_len;
+ } while (msg_type != old_ct_);
+
+ if (msg_type == old_ct_) {
+ plaintext.Write(off, new_ct_, 1);
+ }
+
+ DataBuffer ciphertext;
+ bool ok = Protect(spec(protection_epoch), out_header, inner_content_type,
+ plaintext, &ciphertext, &out_header);
+ if (!ok) {
+ return KEEP;
+ }
+ *offset = out_header.Write(output, *offset, ciphertext);
+ return CHANGE;
+ }
+
+ private:
+ uint8_t old_ct_;
+ uint8_t new_ct_;
+};
+
+class TlsExtensionInjector : public TlsHandshakeFilter {
+ public:
+ TlsExtensionInjector(const std::shared_ptr<TlsAgent>& a, uint16_t ext,
+ const DataBuffer& data)
+ : TlsHandshakeFilter(a), extension_(ext), data_(data) {}
+
+ protected:
+ PacketFilter::Action FilterHandshake(const HandshakeHeader& header,
+ const DataBuffer& input,
+ DataBuffer* output) override;
+
+ private:
+ const uint16_t extension_;
+ const DataBuffer data_;
+};
+
+class TlsExtensionDamager : public TlsExtensionFilter {
+ public:
+ TlsExtensionDamager(const std::shared_ptr<TlsAgent>& a, uint16_t extension,
+ size_t index)
+ : TlsExtensionFilter(a), extension_(extension), index_(index) {}
+ virtual PacketFilter::Action FilterExtension(uint16_t extension_type,
+ const DataBuffer& input,
+ DataBuffer* output);
+
+ private:
+ uint16_t extension_;
+ size_t index_;
+};
+
+typedef std::function<void(void)> VoidFunction;
+
+class AfterRecordN : public TlsRecordFilter {
+ public:
+ AfterRecordN(const std::shared_ptr<TlsAgent>& src,
+ const std::shared_ptr<TlsAgent>& dest, unsigned int record,
+ VoidFunction func)
+ : TlsRecordFilter(src),
+ dest_(dest),
+ record_(record),
+ func_(func),
+ counter_(0) {}
+
+ virtual PacketFilter::Action FilterRecord(const TlsRecordHeader& header,
+ const DataBuffer& body,
+ DataBuffer* out) override;
+
+ private:
+ std::weak_ptr<TlsAgent> dest_;
+ unsigned int record_;
+ VoidFunction func_;
+ unsigned int counter_;
+};
+
+// When we see the ClientKeyExchange from |client|, increment the
+// ClientHelloVersion on |server|.
+class TlsClientHelloVersionChanger : public TlsHandshakeFilter {
+ public:
+ TlsClientHelloVersionChanger(const std::shared_ptr<TlsAgent>& client,
+ const std::shared_ptr<TlsAgent>& server)
+ : TlsHandshakeFilter(client, {kTlsHandshakeClientKeyExchange}),
+ server_(server) {}
+
+ virtual PacketFilter::Action FilterHandshake(const HandshakeHeader& header,
+ const DataBuffer& input,
+ DataBuffer* output);
+
+ private:
+ std::weak_ptr<TlsAgent> server_;
+};
+
+// Damage a record.
+class TlsRecordLastByteDamager : public TlsRecordFilter {
+ public:
+ TlsRecordLastByteDamager(const std::shared_ptr<TlsAgent>& a)
+ : TlsRecordFilter(a) {}
+
+ protected:
+ PacketFilter::Action FilterRecord(const TlsRecordHeader& header,
+ const DataBuffer& data,
+ DataBuffer* changed) override {
+ *changed = data;
+ changed->data()[changed->len() - 1]++;
+ return CHANGE;
+ }
+};
+
+// This class selectively drops complete writes. This relies on the fact that
+// writes in libssl are on record boundaries.
+class SelectiveDropFilter : public PacketFilter {
+ public:
+ SelectiveDropFilter(uint32_t pattern) : pattern_(pattern), counter_(0) {}
+
+ protected:
+ virtual PacketFilter::Action Filter(const DataBuffer& input,
+ DataBuffer* output) override;
+
+ private:
+ const uint32_t pattern_;
+ uint8_t counter_;
+};
+
+// This class selectively drops complete records. The difference from
+// SelectiveDropFilter is that if multiple DTLS records are in the same
+// datagram, we just drop one.
+class SelectiveRecordDropFilter : public TlsRecordFilter {
+ public:
+ SelectiveRecordDropFilter(const std::shared_ptr<TlsAgent>& a,
+ uint32_t pattern, bool on = true)
+ : TlsRecordFilter(a), pattern_(pattern), counter_(0) {
+ if (!on) {
+ Disable();
+ }
+ }
+ SelectiveRecordDropFilter(const std::shared_ptr<TlsAgent>& a,
+ std::initializer_list<size_t> records)
+ : SelectiveRecordDropFilter(a, ToPattern(records), true) {}
+
+ void Reset(uint32_t pattern) {
+ counter_ = 0;
+ PacketFilter::Enable();
+ pattern_ = pattern;
+ }
+
+ void Reset(std::initializer_list<size_t> records) {
+ Reset(ToPattern(records));
+ }
+
+ protected:
+ PacketFilter::Action FilterRecord(const TlsRecordHeader& header,
+ const DataBuffer& data,
+ DataBuffer* changed) override;
+
+ private:
+ static uint32_t ToPattern(std::initializer_list<size_t> records);
+
+ uint32_t pattern_;
+ uint8_t counter_;
+};
+
+// Set the version value in the ClientHello, ServerHello or HelloRetryRequest
+class TlsMessageVersionSetter : public TlsHandshakeFilter {
+ public:
+ TlsMessageVersionSetter(const std::shared_ptr<TlsAgent>& a, uint8_t message,
+ uint16_t version)
+ : TlsHandshakeFilter(a, {message}), version_(version) {
+ PR_ASSERT(message == kTlsHandshakeClientHello ||
+ message == kTlsHandshakeServerHello ||
+ message == kTlsHandshakeHelloRetryRequest);
+ }
+
+ virtual PacketFilter::Action FilterHandshake(const HandshakeHeader& header,
+ const DataBuffer& input,
+ DataBuffer* output);
+
+ private:
+ uint16_t version_;
+};
+
+// Damages the last byte of a handshake message.
+class TlsLastByteDamager : public TlsHandshakeFilter {
+ public:
+ TlsLastByteDamager(const std::shared_ptr<TlsAgent>& a, uint8_t type)
+ : TlsHandshakeFilter(a), type_(type) {}
+ PacketFilter::Action FilterHandshake(
+ const TlsHandshakeFilter::HandshakeHeader& header,
+ const DataBuffer& input, DataBuffer* output) override {
+ if (header.handshake_type() != type_) {
+ return KEEP;
+ }
+
+ *output = input;
+
+ output->data()[output->len() - 1]++;
+ return CHANGE;
+ }
+
+ private:
+ uint8_t type_;
+};
+
+class SelectedCipherSuiteReplacer : public TlsHandshakeFilter {
+ public:
+ SelectedCipherSuiteReplacer(const std::shared_ptr<TlsAgent>& a,
+ uint16_t suite)
+ : TlsHandshakeFilter(a, {kTlsHandshakeServerHello}),
+ cipher_suite_(suite) {}
+
+ protected:
+ PacketFilter::Action FilterHandshake(const HandshakeHeader& header,
+ const DataBuffer& input,
+ DataBuffer* output) override;
+
+ private:
+ uint16_t cipher_suite_;
+};
+
+class ClientHelloPreambleCapture : public TlsHandshakeFilter {
+ public:
+ ClientHelloPreambleCapture(const std::shared_ptr<TlsAgent>& a)
+ : TlsHandshakeFilter(a, {kTlsHandshakeClientHello}),
+ captured_(false),
+ data_() {}
+
+ const DataBuffer& contents() const { return data_; }
+ bool captured() const { return captured_; }
+
+ protected:
+ PacketFilter::Action FilterHandshake(const HandshakeHeader& header,
+ const DataBuffer& input,
+ DataBuffer* output) override;
+
+ private:
+ bool captured_;
+ DataBuffer data_;
+};
+
+class ClientHelloCiphersuiteCapture : public TlsHandshakeFilter {
+ public:
+ ClientHelloCiphersuiteCapture(const std::shared_ptr<TlsAgent>& a)
+ : TlsHandshakeFilter(a, {kTlsHandshakeClientHello}),
+ captured_(false),
+ data_() {}
+
+ const DataBuffer& contents() const { return data_; }
+ bool captured() const { return captured_; }
+
+ protected:
+ PacketFilter::Action FilterHandshake(const HandshakeHeader& header,
+ const DataBuffer& input,
+ DataBuffer* output) override;
+
+ private:
+ bool captured_;
+ DataBuffer data_;
+};
+
+class ServerHelloRandomChanger : public TlsHandshakeFilter {
+ public:
+ ServerHelloRandomChanger(const std::shared_ptr<TlsAgent>& a)
+ : TlsHandshakeFilter(a, {kTlsHandshakeServerHello}) {}
+
+ protected:
+ PacketFilter::Action FilterHandshake(const HandshakeHeader& header,
+ const DataBuffer& input,
+ DataBuffer* output) override;
+};
+
+// Replace SignatureAndHashAlgorithm of a SKE.
+class DHEServerKEXSigAlgReplacer : public TlsHandshakeFilter {
+ public:
+ DHEServerKEXSigAlgReplacer(const std::shared_ptr<TlsAgent>& server,
+ uint16_t sig_scheme)
+ : TlsHandshakeFilter(server, {kTlsHandshakeServerKeyExchange}),
+ sig_scheme_(sig_scheme) {}
+
+ protected:
+ virtual PacketFilter::Action FilterHandshake(const HandshakeHeader& header,
+ const DataBuffer& input,
+ DataBuffer* output) {
+ *output = input;
+
+ uint32_t len;
+ uint32_t idx = 0;
+ EXPECT_TRUE(output->Read(idx, 2, &len));
+ idx += 2 + len;
+ EXPECT_TRUE(output->Read(idx, 2, &len));
+ idx += 2 + len;
+ EXPECT_TRUE(output->Read(idx, 2, &len));
+ idx += 2 + len;
+ output->Write(idx, sig_scheme_, 2);
+
+ return CHANGE;
+ }
+
+ private:
+ uint16_t sig_scheme_;
+};
+
+// Replace SignatureAndHashAlgorithm of a SKE.
+class ECCServerKEXSigAlgReplacer : public TlsHandshakeFilter {
+ public:
+ ECCServerKEXSigAlgReplacer(const std::shared_ptr<TlsAgent>& server,
+ uint16_t sig_scheme)
+ : TlsHandshakeFilter(server, {kTlsHandshakeServerKeyExchange}),
+ sig_scheme_(sig_scheme) {}
+
+ protected:
+ virtual PacketFilter::Action FilterHandshake(const HandshakeHeader& header,
+ const DataBuffer& input,
+ DataBuffer* output) {
+ *output = input;
+
+ uint32_t point_len;
+ EXPECT_TRUE(output->Read(3, 1, &point_len));
+ output->Write(4 + point_len, sig_scheme_, 2);
+
+ return CHANGE;
+ }
+
+ private:
+ uint16_t sig_scheme_;
+};
+
+// Replace NamedCurve of a ECDHE SKE.
+class ECCServerKEXNamedCurveReplacer : public TlsHandshakeFilter {
+ public:
+ ECCServerKEXNamedCurveReplacer(const std::shared_ptr<TlsAgent>& server,
+ uint16_t curve_name)
+ : TlsHandshakeFilter(server, {kTlsHandshakeServerKeyExchange}),
+ curve_name_(curve_name) {}
+
+ protected:
+ virtual PacketFilter::Action FilterHandshake(const HandshakeHeader& header,
+ const DataBuffer& input,
+ DataBuffer* output) {
+ *output = input;
+
+ uint32_t curve_type;
+ EXPECT_TRUE(output->Read(0, 1, &curve_type));
+ EXPECT_EQ(curve_type, ec_type_named);
+ output->Write(1, curve_name_, 2);
+
+ return CHANGE;
+ }
+
+ private:
+ uint16_t curve_name_;
+};
+
+// Replaces the signature scheme in a CertificateVerify message.
+class TlsReplaceSignatureSchemeFilter : public TlsHandshakeFilter {
+ public:
+ TlsReplaceSignatureSchemeFilter(const std::shared_ptr<TlsAgent>& a,
+ uint16_t scheme)
+ : TlsHandshakeFilter(a, {kTlsHandshakeCertificateVerify}),
+ scheme_(scheme) {}
+
+ protected:
+ virtual PacketFilter::Action FilterHandshake(const HandshakeHeader& header,
+ const DataBuffer& input,
+ DataBuffer* output) {
+ *output = input;
+ output->Write(0, scheme_, 2);
+ return CHANGE;
+ }
+
+ private:
+ uint16_t scheme_;
+};
+
+} // namespace nss_test
+
+#endif