diff options
author | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-04-07 19:33:14 +0000 |
---|---|---|
committer | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-04-07 19:33:14 +0000 |
commit | 36d22d82aa202bb199967e9512281e9a53db42c9 (patch) | |
tree | 105e8c98ddea1c1e4784a60a5a6410fa416be2de /security/nss/gtests/ssl_gtest/tls_connect.h | |
parent | Initial commit. (diff) | |
download | firefox-esr-36d22d82aa202bb199967e9512281e9a53db42c9.tar.xz firefox-esr-36d22d82aa202bb199967e9512281e9a53db42c9.zip |
Adding upstream version 115.7.0esr.upstream/115.7.0esr
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'security/nss/gtests/ssl_gtest/tls_connect.h')
-rw-r--r-- | security/nss/gtests/ssl_gtest/tls_connect.h | 390 |
1 files changed, 390 insertions, 0 deletions
diff --git a/security/nss/gtests/ssl_gtest/tls_connect.h b/security/nss/gtests/ssl_gtest/tls_connect.h new file mode 100644 index 0000000000..6a4795f83e --- /dev/null +++ b/security/nss/gtests/ssl_gtest/tls_connect.h @@ -0,0 +1,390 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set ts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this file, + * You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#ifndef tls_connect_h_ +#define tls_connect_h_ + +#include <tuple> + +#include "sslproto.h" +#include "sslt.h" +#include "nss.h" + +#include "tls_agent.h" +#include "tls_filter.h" + +#define GTEST_HAS_RTTI 0 +#include "gtest/gtest.h" + +namespace nss_test { + +extern std::string VersionString(uint16_t version); + +// A generic TLS connection test base. +class TlsConnectTestBase : public ::testing::Test { + public: + static ::testing::internal::ParamGenerator<SSLProtocolVariant> + kTlsVariantsStream; + static ::testing::internal::ParamGenerator<SSLProtocolVariant> + kTlsVariantsDatagram; + static ::testing::internal::ParamGenerator<SSLProtocolVariant> + kTlsVariantsAll; + static ::testing::internal::ParamGenerator<uint16_t> kTlsV10; + static ::testing::internal::ParamGenerator<uint16_t> kTlsV11; + static ::testing::internal::ParamGenerator<uint16_t> kTlsV12; + static ::testing::internal::ParamGenerator<uint16_t> kTlsV10V11; + static ::testing::internal::ParamGenerator<uint16_t> kTlsV11V12; + static ::testing::internal::ParamGenerator<uint16_t> kTlsV10ToV12; + static ::testing::internal::ParamGenerator<uint16_t> kTlsV13; + static ::testing::internal::ParamGenerator<uint16_t> kTlsV11Plus; + static ::testing::internal::ParamGenerator<uint16_t> kTlsV12Plus; + static ::testing::internal::ParamGenerator<uint16_t> kTlsVAll; + + TlsConnectTestBase(SSLProtocolVariant variant, uint16_t version); + virtual ~TlsConnectTestBase(); + + virtual void SetUp(); + virtual void TearDown(); + + PRTime now() const { return now_; } + + // Initialize client and server. + void Init(); + // Clear the statistics. + void ClearStats(); + // Clear the server session cache. + void ClearServerCache(); + // Make sure TLS is configured for a connection. + virtual void EnsureTlsSetup(); + // Reset and keep the same certificate names + void Reset(); + // Reset, and update the certificate names on both peers + void Reset(const std::string& server_name, + const std::string& client_name = "client"); + // Replace the server. + void MakeNewServer(); + + // Set up + void StartConnect(); + // Run the handshake. + void Handshake(); + // Connect and check that it works. + void Connect(); + // Check that the connection was successfully established. + void CheckConnected(); + // Connect and expect it to fail. + void ConnectExpectFail(); + void ExpectAlert(std::shared_ptr<TlsAgent>& sender, uint8_t alert); + void ConnectExpectAlert(std::shared_ptr<TlsAgent>& sender, uint8_t alert); + void ConnectExpectFailOneSide(TlsAgent::Role failingSide); + void ConnectWithCipherSuite(uint16_t cipher_suite); + void CheckEarlyDataLimit(const std::shared_ptr<TlsAgent>& agent, + size_t expected_size); + // Check that the keys used in the handshake match expectations. + void CheckKeys(SSLKEAType kea_type, SSLNamedGroup kea_group, + SSLAuthType auth_type, SSLSignatureScheme sig_scheme) const; + // This version guesses some of the values. + void CheckKeys(SSLKEAType kea_type, SSLAuthType auth_type) const; + // This version assumes defaults. + void CheckKeys() const; + // Check that keys on resumed sessions. + void CheckKeysResumption(SSLKEAType kea_type, SSLNamedGroup kea_group, + SSLNamedGroup original_kea_group, + SSLAuthType auth_type, + SSLSignatureScheme sig_scheme); + void CheckGroups(const DataBuffer& groups, + std::function<void(SSLNamedGroup)> check_group); + void CheckShares(const DataBuffer& shares, + std::function<void(SSLNamedGroup)> check_group); + void CheckEpochs(uint16_t client_epoch, uint16_t server_epoch) const; + + void ConfigureVersion(uint16_t version); + void SetExpectedVersion(uint16_t version); + // Expect resumption of a particular type. + void ExpectResumption(SessionResumptionMode expected, + uint8_t num_resumed = 1); + void DisableAllCiphers(); + void EnableOnlyStaticRsaCiphers(); + void EnableOnlyDheCiphers(); + void EnableSomeEcdhCiphers(); + void EnableExtendedMasterSecret(); + void ConfigureSelfEncrypt(); + void ConfigureSessionCache(SessionResumptionMode client, + SessionResumptionMode server); + void EnableAlpn(); + void EnableAlpnWithCallback(const std::vector<uint8_t>& client, + std::string server_choice); + void EnableAlpn(const std::vector<uint8_t>& vals); + void EnsureModelSockets(); + void CheckAlpn(const std::string& val); + void EnableSrtp(); + void CheckSrtp() const; + void SendReceive(size_t total = 50); + void AddPsk(const ScopedPK11SymKey& psk, std::string label, SSLHashType hash, + uint16_t zeroRttSuite = TLS_NULL_WITH_NULL_NULL); + void RemovePsk(std::string label); + void SetupForZeroRtt(); + void SetupForResume(); + void ZeroRttSendReceive( + bool expect_writable, bool expect_readable, + std::function<bool()> post_clienthello_check = nullptr); + void Receive(size_t amount); + void ExpectExtendedMasterSecret(bool expected); + void ExpectEarlyDataAccepted(bool expected); + void EnableECDHEServerKeyReuse(); + void SkipVersionChecks(); + + // Move the DTLS timers for both endpoints to pop the next timer. + void ShiftDtlsTimers(); + void AdvanceTime(PRTime time_shift); + + void ResetAntiReplay(PRTime window); + void RolloverAntiReplay(); + + void SaveAlgorithmPolicy(); + void RestoreAlgorithmPolicy(); + + static ScopedSECItem MakeEcKeyParams(SSLNamedGroup group); + static void GenerateEchConfig( + HpkeKemId kem_id, const std::vector<HpkeSymmetricSuite>& cipher_suites, + const std::string& public_name, uint16_t max_name_len, DataBuffer& record, + ScopedSECKEYPublicKey& pubKey, ScopedSECKEYPrivateKey& privKey); + void SetupEch(std::shared_ptr<TlsAgent>& client, + std::shared_ptr<TlsAgent>& server, + HpkeKemId kem_id = HpkeDhKemX25519Sha256, + bool expect_ech = true, bool set_client_config = true, + bool set_server_config = true, int maxConfigSize = 100); + + protected: + SSLProtocolVariant variant_; + std::shared_ptr<TlsAgent> client_; + std::shared_ptr<TlsAgent> server_; + std::unique_ptr<TlsAgent> client_model_; + std::unique_ptr<TlsAgent> server_model_; + uint16_t version_; + SessionResumptionMode expected_resumption_mode_; + uint8_t expected_resumptions_; + std::vector<std::vector<uint8_t>> session_ids_; + ScopedSSLAntiReplayContext anti_replay_; + + // A simple value of "a", "b". Note that the preferred value of "a" is placed + // at the end, because the NSS API follows the now defunct NPN specification, + // which places the preferred (and default) entry at the end of the list. + // NSS will move this final entry to the front when used with ALPN. + const uint8_t alpn_dummy_val_[4] = {0x01, 0x62, 0x01, 0x61}; + + // A list of algorithm IDs whose policies need to be preserved + // around test cases. In particular, DSA is checked in + // ssl_extension_unittest.cc. + const std::vector<SECOidTag> algorithms_ = {SEC_OID_APPLY_SSL_POLICY, + SEC_OID_ANSIX9_DSA_SIGNATURE, + SEC_OID_CURVE25519, SEC_OID_SHA1}; + std::vector<std::tuple<SECOidTag, uint32_t>> saved_policies_; + const std::vector<PRInt32> options_ = { + NSS_RSA_MIN_KEY_SIZE, NSS_DH_MIN_KEY_SIZE, NSS_DSA_MIN_KEY_SIZE, + NSS_TLS_VERSION_MIN_POLICY, NSS_TLS_VERSION_MAX_POLICY}; + std::vector<std::tuple<PRInt32, uint32_t>> saved_options_; + + private: + void CheckResumption(SessionResumptionMode expected); + void CheckExtendedMasterSecret(); + void CheckEarlyDataAccepted(); + static PRTime TimeFunc(void* arg); + + bool expect_extended_master_secret_; + bool expect_early_data_accepted_; + bool skip_version_checks_; + PRTime now_; + + // Track groups and make sure that there are no duplicates. + class DuplicateGroupChecker { + public: + void AddAndCheckGroup(SSLNamedGroup group) { + EXPECT_EQ(groups_.end(), groups_.find(group)) + << "Group " << group << " should not be duplicated"; + groups_.insert(group); + } + + private: + std::set<SSLNamedGroup> groups_; + }; +}; + +// A non-parametrized TLS test base. +class TlsConnectTest : public TlsConnectTestBase { + public: + TlsConnectTest() : TlsConnectTestBase(ssl_variant_stream, 0) {} +}; + +// A non-parametrized DTLS-only test base. +class DtlsConnectTest : public TlsConnectTestBase { + public: + DtlsConnectTest() : TlsConnectTestBase(ssl_variant_datagram, 0) {} +}; + +// A TLS-only test base. +class TlsConnectStream : public TlsConnectTestBase, + public ::testing::WithParamInterface<uint16_t> { + public: + TlsConnectStream() : TlsConnectTestBase(ssl_variant_stream, GetParam()) {} +}; + +// A TLS-only test base for tests before 1.3 +class TlsConnectStreamPre13 : public TlsConnectStream {}; + +// A DTLS-only test base. +class TlsConnectDatagram : public TlsConnectTestBase, + public ::testing::WithParamInterface<uint16_t> { + public: + TlsConnectDatagram() : TlsConnectTestBase(ssl_variant_datagram, GetParam()) {} +}; + +// A generic test class that can be either stream or datagram and a single +// version of TLS. This is configured in ssl_loopback_unittest.cc. +class TlsConnectGeneric : public TlsConnectTestBase, + public ::testing::WithParamInterface< + std::tuple<SSLProtocolVariant, uint16_t>> { + public: + TlsConnectGeneric(); +}; + +class TlsConnectGenericResumption + : public TlsConnectTestBase, + public ::testing::WithParamInterface< + std::tuple<SSLProtocolVariant, uint16_t, bool>> { + private: + bool external_cache_; + + public: + TlsConnectGenericResumption(); + + virtual void EnsureTlsSetup() { + TlsConnectTestBase::EnsureTlsSetup(); + // Enable external resumption token cache. + if (external_cache_) { + client_->SetResumptionTokenCallback(); + } + } + + bool use_external_cache() const { return external_cache_; } +}; + +class TlsConnectTls13ResumptionToken + : public TlsConnectTestBase, + public ::testing::WithParamInterface<SSLProtocolVariant> { + public: + TlsConnectTls13ResumptionToken(); + + virtual void EnsureTlsSetup() { + TlsConnectTestBase::EnsureTlsSetup(); + client_->SetResumptionTokenCallback(); + } +}; + +class TlsConnectGenericResumptionToken + : public TlsConnectTestBase, + public ::testing::WithParamInterface< + std::tuple<SSLProtocolVariant, uint16_t>> { + public: + TlsConnectGenericResumptionToken(); + + virtual void EnsureTlsSetup() { + TlsConnectTestBase::EnsureTlsSetup(); + client_->SetResumptionTokenCallback(); + } +}; + +// A Pre TLS 1.2 generic test. +class TlsConnectPre12 : public TlsConnectTestBase, + public ::testing::WithParamInterface< + std::tuple<SSLProtocolVariant, uint16_t>> { + public: + TlsConnectPre12(); +}; + +// A TLS 1.2 only generic test. +class TlsConnectTls12 + : public TlsConnectTestBase, + public ::testing::WithParamInterface<SSLProtocolVariant> { + public: + TlsConnectTls12(); +}; + +// A TLS 1.2 only stream test. +class TlsConnectStreamTls12 : public TlsConnectTestBase { + public: + TlsConnectStreamTls12() + : TlsConnectTestBase(ssl_variant_stream, SSL_LIBRARY_VERSION_TLS_1_2) {} +}; + +// A TLS 1.2+ generic test. +class TlsConnectTls12Plus : public TlsConnectTestBase, + public ::testing::WithParamInterface< + std::tuple<SSLProtocolVariant, uint16_t>> { + public: + TlsConnectTls12Plus(); +}; + +// A TLS 1.3 only generic test. +class TlsConnectTls13 + : public TlsConnectTestBase, + public ::testing::WithParamInterface<SSLProtocolVariant> { + public: + TlsConnectTls13(); +}; + +// A TLS 1.3 only stream test. +class TlsConnectStreamTls13 : public TlsConnectTestBase { + public: + TlsConnectStreamTls13() + : TlsConnectTestBase(ssl_variant_stream, SSL_LIBRARY_VERSION_TLS_1_3) {} +}; + +class TlsConnectDatagram13 : public TlsConnectTestBase { + public: + TlsConnectDatagram13() + : TlsConnectTestBase(ssl_variant_datagram, SSL_LIBRARY_VERSION_TLS_1_3) {} +}; + +class TlsConnectDatagramPre13 : public TlsConnectDatagram { + public: + TlsConnectDatagramPre13() {} +}; + +// A variant that is used only with Pre13. +class TlsConnectGenericPre13 : public TlsConnectGeneric {}; + +class TlsKeyExchangeTest : public TlsConnectGeneric { + protected: + std::shared_ptr<TlsExtensionCapture> groups_capture_; + std::shared_ptr<TlsExtensionCapture> shares_capture_; + std::shared_ptr<TlsExtensionCapture> shares_capture2_; + std::shared_ptr<TlsHandshakeRecorder> capture_hrr_; + + void EnsureKeyShareSetup(); + void ConfigNamedGroups(const std::vector<SSLNamedGroup>& groups); + std::vector<SSLNamedGroup> GetGroupDetails( + const std::shared_ptr<TlsExtensionCapture>& capture); + std::vector<SSLNamedGroup> GetShareDetails( + const std::shared_ptr<TlsExtensionCapture>& capture); + void CheckKEXDetails(const std::vector<SSLNamedGroup>& expectedGroups, + const std::vector<SSLNamedGroup>& expectedShares); + void CheckKEXDetails(const std::vector<SSLNamedGroup>& expectedGroups, + const std::vector<SSLNamedGroup>& expectedShares, + SSLNamedGroup expectedShare2); + + private: + void CheckKEXDetails(const std::vector<SSLNamedGroup>& expectedGroups, + const std::vector<SSLNamedGroup>& expectedShares, + bool expect_hrr); +}; + +class TlsKeyExchangeTest13 : public TlsKeyExchangeTest {}; +class TlsKeyExchangeTestPre13 : public TlsKeyExchangeTest {}; + +} // namespace nss_test + +#endif |