From 26a029d407be480d791972afb5975cf62c9360a6 Mon Sep 17 00:00:00 2001 From: Daniel Baumann Date: Fri, 19 Apr 2024 02:47:55 +0200 Subject: Adding upstream version 124.0.1. Signed-off-by: Daniel Baumann --- security/nss/gtests/ssl_gtest/tls_agent.h | 594 ++++++++++++++++++++++++++++++ 1 file changed, 594 insertions(+) create mode 100644 security/nss/gtests/ssl_gtest/tls_agent.h (limited to 'security/nss/gtests/ssl_gtest/tls_agent.h') diff --git a/security/nss/gtests/ssl_gtest/tls_agent.h b/security/nss/gtests/ssl_gtest/tls_agent.h new file mode 100644 index 0000000000..00045b4365 --- /dev/null +++ b/security/nss/gtests/ssl_gtest/tls_agent.h @@ -0,0 +1,594 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set ts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this file, + * You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#ifndef tls_agent_h_ +#define tls_agent_h_ + +#include "prio.h" +#include "ssl.h" +#include "sslproto.h" + +#include +#include + +#include "nss_policy.h" +#include "test_io.h" + +#define GTEST_HAS_RTTI 0 +#include "gtest/gtest.h" +#include "nss_scoped_ptrs.h" +#include "scoped_ptrs_ssl.h" + +extern bool g_ssl_gtest_verbose; + +namespace nss_test { + +#define LOG(msg) std::cerr << role_str() << ": " << msg << std::endl +#define LOGV(msg) \ + do { \ + if (g_ssl_gtest_verbose) LOG(msg); \ + } while (false) + +enum SessionResumptionMode { + RESUME_NONE = 0, + RESUME_SESSIONID = 1, + RESUME_TICKET = 2, + RESUME_BOTH = RESUME_SESSIONID | RESUME_TICKET +}; + +enum class ClientAuthCallbackType { + kAsyncImmediate, + kAsyncDelay, + kSync, + kNone, +}; + +class PacketFilter; +class TlsAgent; +class TlsCipherSpec; +struct TlsRecord; + +const extern std::vector kAllDHEGroups; +const extern std::vector kECDHEGroups; +const extern std::vector kFFDHEGroups; +const extern std::vector kFasterDHEGroups; +const extern std::vector kEcdhHybridGroups; + +// These functions are called from callbacks. They use bare pointers because +// TlsAgent sets up the callback and it doesn't know who owns it. +typedef std::function + AuthCertificateCallbackFunction; + +typedef std::function HandshakeCallbackFunction; + +typedef std::function + SniCallbackFunction; + +class TlsAgent : public PollTarget { + public: + enum Role { CLIENT, SERVER }; + enum State { STATE_INIT, STATE_CONNECTING, STATE_CONNECTED, STATE_ERROR }; + + static const std::string kClient; // the client key is sign only + static const std::string kRsa2048; // bigger sign and encrypt for either + static const std::string kRsa8192; // biggest sign and encrypt for either + static const std::string kServerRsa; // both sign and encrypt + static const std::string kServerRsaSign; + static const std::string kServerRsaPss; + static const std::string kServerRsaDecrypt; + static const std::string kServerEcdsa256; + static const std::string kServerEcdsa384; + static const std::string kServerEcdsa521; + static const std::string kServerEcdhEcdsa; + static const std::string kServerEcdhRsa; + static const std::string kServerDsa; + static const std::string kDelegatorEcdsa256; // draft-ietf-tls-subcerts + static const std::string kDelegatorRsae2048; // draft-ietf-tls-subcerts + static const std::string kDelegatorRsaPss2048; // draft-ietf-tls-subcerts + + TlsAgent(const std::string& name, Role role, SSLProtocolVariant variant); + virtual ~TlsAgent(); + + void SetPeer(std::shared_ptr& peer) { + adapter_->SetPeer(peer->adapter_); + } + + void SetFilter(std::shared_ptr filter) { + adapter_->SetPacketFilter(filter); + } + void ClearFilter() { adapter_->SetPacketFilter(nullptr); } + + void StartConnect(PRFileDesc* model = nullptr); + void CheckKEA(SSLKEAType kea_type, SSLNamedGroup group, + size_t kea_size = 0) const; + void CheckOriginalKEA(SSLNamedGroup kea_group) const; + void CheckAuthType(SSLAuthType auth_type, + SSLSignatureScheme sig_scheme) const; + + void DisableAllCiphers(); + void EnableCiphersByAuthType(SSLAuthType authType); + void EnableCiphersByKeyExchange(SSLKEAType kea); + void EnableGroupsByKeyExchange(SSLKEAType kea); + void EnableGroupsByAuthType(SSLAuthType authType); + void EnableSingleCipher(uint16_t cipher); + + void Handshake(); + // Marks the internal state as CONNECTING in anticipation of renegotiation. + void PrepareForRenegotiate(); + // Prepares for renegotiation, then actually triggers it. + void StartRenegotiate(); + void SetAntiReplayContext(ScopedSSLAntiReplayContext& ctx); + + static bool LoadCertificate(const std::string& name, + ScopedCERTCertificate* cert, + ScopedSECKEYPrivateKey* priv); + static bool LoadKeyPairFromCert(const std::string& name, + ScopedSECKEYPublicKey* pub, + ScopedSECKEYPrivateKey* priv); + + // Delegated credentials. + // + // Generate a delegated credential and sign it using the certificate + // associated with |name|. + static void DelegateCredential(const std::string& name, + const ScopedSECKEYPublicKey& dcPub, + SSLSignatureScheme dcCertVerifyAlg, + PRUint32 dcValidFor, PRTime now, SECItem* dc); + // Indicate support for the delegated credentials extension. + void EnableDelegatedCredentials(); + // Generate and configure a delegated credential to use in the handshake with + // clients that support this extension.. + void AddDelegatedCredential(const std::string& dc_name, + SSLSignatureScheme dcCertVerifyAlg, + PRUint32 dcValidFor, PRTime now); + void UpdatePreliminaryChannelInfo(); + + bool ConfigServerCert(const std::string& name, bool updateKeyBits = false, + const SSLExtraServerCertData* serverCertData = nullptr); + bool ConfigServerCertWithChain(const std::string& name); + bool EnsureTlsSetup(PRFileDesc* modelSocket = nullptr); + + void SetupClientAuth( + ClientAuthCallbackType callbackType = ClientAuthCallbackType::kSync, + bool callbackSuccess = true); + void RequestClientAuth(bool requireAuth); + void ClientAuthCallbackComplete(); + bool CheckClientAuthCallbacksCompleted(uint8_t expected); + void CheckClientAuthCompleted(uint8_t handshakes = 1); + void SetOption(int32_t option, int value); + void ConfigureSessionCache(SessionResumptionMode mode); + void Set0RttEnabled(bool en); + void SetFallbackSCSVEnabled(bool en); + void SetVersionRange(uint16_t minver, uint16_t maxver); + void GetVersionRange(uint16_t* minver, uint16_t* maxver); + void CheckPreliminaryInfo(); + void ResetPreliminaryInfo(); + void SetExpectedVersion(uint16_t version); + void SetServerKeyBits(uint16_t bits); + void ExpectReadWriteError(); + void EnableFalseStart(); + void ExpectEch(bool expected = true); + bool GetEchExpected() const { return expect_ech_; } + void ExpectPsk(SSLPskType psk = ssl_psk_external); + void ExpectResumption(); + void SkipVersionChecks(); + void SetSignatureSchemes(const SSLSignatureScheme* schemes, size_t count); + void EnableAlpn(const uint8_t* val, size_t len); + void CheckAlpn(SSLNextProtoState expected_state, + const std::string& expected = "") const; + void EnableSrtp(); + void CheckSrtp() const; + void CheckEpochs(uint16_t expected_read, uint16_t expected_write) const; + void CheckErrorCode(int32_t expected) const; + void WaitForErrorCode(int32_t expected, uint32_t delay) const; + // Send data on the socket, encrypting it. + void SendData(size_t bytes, size_t blocksize = 1024); + void SendBuffer(const DataBuffer& buf); + bool SendEncryptedRecord(const std::shared_ptr& spec, + uint64_t seq, uint8_t ct, const DataBuffer& buf); + // Send data directly to the underlying socket, skipping the TLS layer. + void SendDirect(const DataBuffer& buf); + void SendRecordDirect(const TlsRecord& record); + void AddPsk(const ScopedPK11SymKey& psk, std::string label, SSLHashType hash, + uint16_t zeroRttSuite = TLS_NULL_WITH_NULL_NULL); + void RemovePsk(std::string label); + void ReadBytes(size_t max = 16384U); + void ResetSentBytes(size_t bytes = 0); // Hack to test drops. + void EnableExtendedMasterSecret(); + void CheckExtendedMasterSecret(bool expected); + void CheckEarlyDataAccepted(bool expected); + void CheckEchAccepted(bool expected); + void SetDowngradeCheckVersion(uint16_t version); + void CheckSecretsDestroyed(); + void ConfigNamedGroups(const std::vector& groups); + void EnableECDHEServerKeyReuse(); + bool GetPeerChainLength(size_t* count); + void CheckCipherSuite(uint16_t cipher_suite); + void SetResumptionTokenCallback(); + bool MaybeSetResumptionToken(); + void SetResumptionToken(const std::vector& resumption_token) { + resumption_token_ = resumption_token; + } + const std::vector& GetResumptionToken() const { + return resumption_token_; + } + void GetTokenInfo(ScopedSSLResumptionTokenInfo& token) { + SECStatus rv = SSL_GetResumptionTokenInfo( + resumption_token_.data(), resumption_token_.size(), token.get(), + sizeof(SSLResumptionTokenInfo)); + ASSERT_EQ(SECSuccess, rv); + } + void SetResumptionCallbackCalled() { resumption_callback_called_ = true; } + bool resumption_callback_called() const { + return resumption_callback_called_; + } + + const std::string& name() const { return name_; } + + Role role() const { return role_; } + std::string role_str() const { return role_ == SERVER ? "server" : "client"; } + + SSLProtocolVariant variant() const { return variant_; } + + State state() const { return state_; } + + const CERTCertificate* peer_cert() const { + return SSL_PeerCertificate(ssl_fd_.get()); + } + + const char* state_str() const { return state_str(state()); } + + static const char* state_str(State state) { return states[state]; } + + NssManagedFileDesc ssl_fd() const { + return NssManagedFileDesc(ssl_fd_.get(), policy_, option_); + } + std::shared_ptr& adapter() { return adapter_; } + + const SSLChannelInfo& info() const { + EXPECT_EQ(STATE_CONNECTED, state_); + return info_; + } + + const SSLPreliminaryChannelInfo& pre_info() const { return pre_info_; } + + bool is_compressed() const { + return info().compressionMethod != ssl_compression_null; + } + uint16_t server_key_bits() const { return server_key_bits_; } + uint16_t min_version() const { return vrange_.min; } + uint16_t max_version() const { return vrange_.max; } + uint16_t version() const { return info().protocolVersion; } + + bool cipher_suite(uint16_t* suite) const { + if (state_ != STATE_CONNECTED) return false; + + *suite = info_.cipherSuite; + return true; + } + + void expected_cipher_suite(uint16_t suite) { expected_cipher_suite_ = suite; } + + std::string cipher_suite_name() const { + if (state_ != STATE_CONNECTED) return "UNKNOWN"; + + return csinfo_.cipherSuiteName; + } + + std::vector session_id() const { + return std::vector(info_.sessionID, + info_.sessionID + info_.sessionIDLength); + } + + bool auth_type(SSLAuthType* a) const { + if (state_ != STATE_CONNECTED) return false; + + *a = info_.authType; + return true; + } + + bool kea_type(SSLKEAType* k) const { + if (state_ != STATE_CONNECTED) return false; + + *k = info_.keaType; + return true; + } + + size_t received_bytes() const { return recv_ctr_; } + PRErrorCode error_code() const { return error_code_; } + + bool can_falsestart_hook_called() const { + return can_falsestart_hook_called_; + } + + void SetHandshakeCallback(HandshakeCallbackFunction handshake_callback) { + handshake_callback_ = handshake_callback; + } + + void SetAuthCertificateCallback( + AuthCertificateCallbackFunction auth_certificate_callback) { + auth_certificate_callback_ = auth_certificate_callback; + } + + void SetSniCallback(SniCallbackFunction sni_callback) { + sni_callback_ = sni_callback; + } + + void ExpectReceiveAlert(uint8_t alert, uint8_t level = 0); + void ExpectSendAlert(uint8_t alert, uint8_t level = 0); + + std::string alpn_value_to_use_ = ""; + // set the given policy before this agent runs + void SetPolicy(SECOidTag oid, PRUint32 set, PRUint32 clear) { + policy_ = NssPolicy(oid, set, clear); + } + void SetNssOption(PRInt32 id, PRInt32 value) { + option_ = NssOption(id, value); + } + + private: + const static char* states[]; + + void SetState(State state); + void ValidateCipherSpecs(); + + // Dummy auth certificate hook. + static SECStatus AuthCertificateHook(void* arg, PRFileDesc* fd, + PRBool checksig, PRBool isServer) { + TlsAgent* agent = reinterpret_cast(arg); + agent->CheckPreliminaryInfo(); + agent->auth_certificate_hook_called_ = true; + if (agent->auth_certificate_callback_) { + return agent->auth_certificate_callback_(agent, checksig ? true : false, + isServer ? true : false); + } + return SECSuccess; + } + + // Client auth certificate hook. + static SECStatus ClientAuthenticated(void* arg, PRFileDesc* fd, + PRBool checksig, PRBool isServer) { + TlsAgent* agent = reinterpret_cast(arg); + EXPECT_TRUE(agent->expect_client_auth_); + EXPECT_EQ(PR_TRUE, isServer); + if (agent->auth_certificate_callback_) { + return agent->auth_certificate_callback_(agent, checksig ? true : false, + isServer ? true : false); + } + return SECSuccess; + } + + static SECStatus GetClientAuthDataHook(void* self, PRFileDesc* fd, + CERTDistNames* caNames, + CERTCertificate** cert, + SECKEYPrivateKey** privKey); + + static void ReadableCallback(PollTarget* self, Event event) { + TlsAgent* agent = static_cast(self); + if (event == TIMER_EVENT) { + agent->timer_handle_ = nullptr; + } + agent->ReadableCallback_int(); + } + + void ReadableCallback_int() { + LOGV("Readable"); + switch (state_) { + case STATE_CONNECTING: + Handshake(); + break; + case STATE_CONNECTED: + ReadBytes(); + break; + default: + break; + } + } + + static PRInt32 SniHook(PRFileDesc* fd, const SECItem* srvNameArr, + PRUint32 srvNameArrSize, void* arg) { + TlsAgent* agent = reinterpret_cast(arg); + agent->CheckPreliminaryInfo(); + agent->sni_hook_called_ = true; + EXPECT_EQ(1UL, srvNameArrSize); + if (agent->sni_callback_) { + return agent->sni_callback_(agent, srvNameArr, srvNameArrSize); + } + return 0; // First configuration. + } + + static SECStatus CanFalseStartCallback(PRFileDesc* fd, void* arg, + PRBool* canFalseStart) { + TlsAgent* agent = reinterpret_cast(arg); + agent->CheckPreliminaryInfo(); + EXPECT_TRUE(agent->falsestart_enabled_); + EXPECT_FALSE(agent->can_falsestart_hook_called_); + agent->can_falsestart_hook_called_ = true; + *canFalseStart = true; + return SECSuccess; + } + + void CheckAlert(bool sent, const SSLAlert* alert); + + static void AlertReceivedCallback(const PRFileDesc* fd, void* arg, + const SSLAlert* alert) { + reinterpret_cast(arg)->CheckAlert(false, alert); + } + + static void AlertSentCallback(const PRFileDesc* fd, void* arg, + const SSLAlert* alert) { + reinterpret_cast(arg)->CheckAlert(true, alert); + } + + static void HandshakeCallback(PRFileDesc* fd, void* arg) { + TlsAgent* agent = reinterpret_cast(arg); + agent->handshake_callback_called_ = true; + agent->Connected(); + if (agent->handshake_callback_) { + agent->handshake_callback_(agent); + } + } + + void DisableLameGroups(); + void ConfigStrongECGroups(bool en); + void ConfigAllDHGroups(bool en); + void CheckCallbacks() const; + void Connected(); + + const std::string name_; + SSLProtocolVariant variant_; + Role role_; + uint16_t server_key_bits_; + std::shared_ptr adapter_; + ScopedPRFileDesc ssl_fd_; + State state_; + std::shared_ptr timer_handle_; + bool falsestart_enabled_; + uint16_t expected_version_; + uint16_t expected_cipher_suite_; + bool expect_client_auth_; + bool expect_ech_; + SSLPskType expect_psk_; + bool can_falsestart_hook_called_; + bool sni_hook_called_; + bool auth_certificate_hook_called_; + uint8_t expected_received_alert_; + uint8_t expected_received_alert_level_; + uint8_t expected_sent_alert_; + uint8_t expected_sent_alert_level_; + bool handshake_callback_called_; + bool resumption_callback_called_; + SSLChannelInfo info_; + SSLPreliminaryChannelInfo pre_info_; + SSLCipherSuiteInfo csinfo_; + SSLVersionRange vrange_; + PRErrorCode error_code_; + size_t send_ctr_; + size_t recv_ctr_; + bool expect_readwrite_error_; + HandshakeCallbackFunction handshake_callback_; + AuthCertificateCallbackFunction auth_certificate_callback_; + SniCallbackFunction sni_callback_; + bool skip_version_checks_; + std::vector resumption_token_; + NssPolicy policy_; + NssOption option_; + ClientAuthCallbackType client_auth_callback_type_ = + ClientAuthCallbackType::kNone; + bool client_auth_callback_success_ = false; + uint8_t client_auth_callback_fired_ = 0; + bool client_auth_callback_awaiting_ = false; +}; + +inline std::ostream& operator<<(std::ostream& stream, + const TlsAgent::State& state) { + return stream << TlsAgent::state_str(state); +} + +class TlsAgentTestBase : public ::testing::Test { + public: + static ::testing::internal::ParamGenerator kTlsRolesAll; + + TlsAgentTestBase(TlsAgent::Role role, SSLProtocolVariant variant, + uint16_t version = 0) + : agent_(nullptr), + role_(role), + variant_(variant), + version_(version), + sink_adapter_(new DummyPrSocket("sink", variant)) {} + virtual ~TlsAgentTestBase() {} + + void SetUp(); + void TearDown(); + + void ExpectAlert(uint8_t alert); + + static void MakeRecord(SSLProtocolVariant variant, uint8_t type, + uint16_t version, const uint8_t* buf, size_t len, + DataBuffer* out, uint64_t seq_num = 0); + void MakeRecord(uint8_t type, uint16_t version, const uint8_t* buf, + size_t len, DataBuffer* out, uint64_t seq_num = 0) const; + void MakeHandshakeMessage(uint8_t hs_type, const uint8_t* data, size_t hs_len, + DataBuffer* out, uint64_t seq_num = 0) const; + void MakeHandshakeMessageFragment(uint8_t hs_type, const uint8_t* data, + size_t hs_len, DataBuffer* out, + uint64_t seq_num, uint32_t fragment_offset, + uint32_t fragment_length) const; + DataBuffer MakeCannedTls13ServerHello(); + static void MakeTrivialHandshakeRecord(uint8_t hs_type, size_t hs_len, + DataBuffer* out); + static inline TlsAgent::Role ToRole(const std::string& str) { + return str == "CLIENT" ? TlsAgent::CLIENT : TlsAgent::SERVER; + } + + void Init(const std::string& server_name = TlsAgent::kServerRsa); + void Reset(const std::string& server_name = TlsAgent::kServerRsa); + + protected: + void EnsureInit(); + void ProcessMessage(const DataBuffer& buffer, TlsAgent::State expected_state, + int32_t error_code = 0); + + std::shared_ptr agent_; + TlsAgent::Role role_; + SSLProtocolVariant variant_; + uint16_t version_; + // This adapter is here just to accept packets from this agent. + std::shared_ptr sink_adapter_; +}; + +class TlsAgentTest + : public TlsAgentTestBase, + public ::testing::WithParamInterface< + std::tuple> { + public: + TlsAgentTest() + : TlsAgentTestBase(ToRole(std::get<0>(GetParam())), + std::get<1>(GetParam()), std::get<2>(GetParam())) {} +}; + +class TlsAgentTestClient : public TlsAgentTestBase, + public ::testing::WithParamInterface< + std::tuple> { + public: + TlsAgentTestClient() + : TlsAgentTestBase(TlsAgent::CLIENT, std::get<0>(GetParam()), + std::get<1>(GetParam())) {} +}; + +class TlsAgentTestClient13 : public TlsAgentTestClient {}; + +class TlsAgentStreamTestClient13 : public TlsAgentTestClient { + public: + TlsAgentStreamTestClient13() { variant_ = ssl_variant_stream; } +}; + +class TlsAgentStreamTestClient : public TlsAgentTestBase { + public: + TlsAgentStreamTestClient() + : TlsAgentTestBase(TlsAgent::CLIENT, ssl_variant_stream) {} +}; + +class TlsAgentStreamTestServer : public TlsAgentTestBase { + public: + TlsAgentStreamTestServer() + : TlsAgentTestBase(TlsAgent::SERVER, ssl_variant_stream) {} +}; + +class TlsAgentDgramTestClient : public TlsAgentTestBase { + public: + TlsAgentDgramTestClient() + : TlsAgentTestBase(TlsAgent::CLIENT, ssl_variant_datagram) {} +}; + +inline bool operator==(const SSLVersionRange& vr1, const SSLVersionRange& vr2) { + return vr1.min == vr2.min && vr1.max == vr2.max; +} + +} // namespace nss_test + +#endif -- cgit v1.2.3