summaryrefslogtreecommitdiffstats
path: root/security/nss/gtests/ssl_gtest/tls_connect.h
diff options
context:
space:
mode:
authorDaniel Baumann <daniel.baumann@progress-linux.org>2024-04-07 19:33:14 +0000
committerDaniel Baumann <daniel.baumann@progress-linux.org>2024-04-07 19:33:14 +0000
commit36d22d82aa202bb199967e9512281e9a53db42c9 (patch)
tree105e8c98ddea1c1e4784a60a5a6410fa416be2de /security/nss/gtests/ssl_gtest/tls_connect.h
parentInitial commit. (diff)
downloadfirefox-esr-36d22d82aa202bb199967e9512281e9a53db42c9.tar.xz
firefox-esr-36d22d82aa202bb199967e9512281e9a53db42c9.zip
Adding upstream version 115.7.0esr.upstream/115.7.0esr
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'security/nss/gtests/ssl_gtest/tls_connect.h')
-rw-r--r--security/nss/gtests/ssl_gtest/tls_connect.h390
1 files changed, 390 insertions, 0 deletions
diff --git a/security/nss/gtests/ssl_gtest/tls_connect.h b/security/nss/gtests/ssl_gtest/tls_connect.h
new file mode 100644
index 0000000000..6a4795f83e
--- /dev/null
+++ b/security/nss/gtests/ssl_gtest/tls_connect.h
@@ -0,0 +1,390 @@
+/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
+/* vim: set ts=2 et sw=2 tw=80: */
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this file,
+ * You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+#ifndef tls_connect_h_
+#define tls_connect_h_
+
+#include <tuple>
+
+#include "sslproto.h"
+#include "sslt.h"
+#include "nss.h"
+
+#include "tls_agent.h"
+#include "tls_filter.h"
+
+#define GTEST_HAS_RTTI 0
+#include "gtest/gtest.h"
+
+namespace nss_test {
+
+extern std::string VersionString(uint16_t version);
+
+// A generic TLS connection test base.
+class TlsConnectTestBase : public ::testing::Test {
+ public:
+ static ::testing::internal::ParamGenerator<SSLProtocolVariant>
+ kTlsVariantsStream;
+ static ::testing::internal::ParamGenerator<SSLProtocolVariant>
+ kTlsVariantsDatagram;
+ static ::testing::internal::ParamGenerator<SSLProtocolVariant>
+ kTlsVariantsAll;
+ static ::testing::internal::ParamGenerator<uint16_t> kTlsV10;
+ static ::testing::internal::ParamGenerator<uint16_t> kTlsV11;
+ static ::testing::internal::ParamGenerator<uint16_t> kTlsV12;
+ static ::testing::internal::ParamGenerator<uint16_t> kTlsV10V11;
+ static ::testing::internal::ParamGenerator<uint16_t> kTlsV11V12;
+ static ::testing::internal::ParamGenerator<uint16_t> kTlsV10ToV12;
+ static ::testing::internal::ParamGenerator<uint16_t> kTlsV13;
+ static ::testing::internal::ParamGenerator<uint16_t> kTlsV11Plus;
+ static ::testing::internal::ParamGenerator<uint16_t> kTlsV12Plus;
+ static ::testing::internal::ParamGenerator<uint16_t> kTlsVAll;
+
+ TlsConnectTestBase(SSLProtocolVariant variant, uint16_t version);
+ virtual ~TlsConnectTestBase();
+
+ virtual void SetUp();
+ virtual void TearDown();
+
+ PRTime now() const { return now_; }
+
+ // Initialize client and server.
+ void Init();
+ // Clear the statistics.
+ void ClearStats();
+ // Clear the server session cache.
+ void ClearServerCache();
+ // Make sure TLS is configured for a connection.
+ virtual void EnsureTlsSetup();
+ // Reset and keep the same certificate names
+ void Reset();
+ // Reset, and update the certificate names on both peers
+ void Reset(const std::string& server_name,
+ const std::string& client_name = "client");
+ // Replace the server.
+ void MakeNewServer();
+
+ // Set up
+ void StartConnect();
+ // Run the handshake.
+ void Handshake();
+ // Connect and check that it works.
+ void Connect();
+ // Check that the connection was successfully established.
+ void CheckConnected();
+ // Connect and expect it to fail.
+ void ConnectExpectFail();
+ void ExpectAlert(std::shared_ptr<TlsAgent>& sender, uint8_t alert);
+ void ConnectExpectAlert(std::shared_ptr<TlsAgent>& sender, uint8_t alert);
+ void ConnectExpectFailOneSide(TlsAgent::Role failingSide);
+ void ConnectWithCipherSuite(uint16_t cipher_suite);
+ void CheckEarlyDataLimit(const std::shared_ptr<TlsAgent>& agent,
+ size_t expected_size);
+ // Check that the keys used in the handshake match expectations.
+ void CheckKeys(SSLKEAType kea_type, SSLNamedGroup kea_group,
+ SSLAuthType auth_type, SSLSignatureScheme sig_scheme) const;
+ // This version guesses some of the values.
+ void CheckKeys(SSLKEAType kea_type, SSLAuthType auth_type) const;
+ // This version assumes defaults.
+ void CheckKeys() const;
+ // Check that keys on resumed sessions.
+ void CheckKeysResumption(SSLKEAType kea_type, SSLNamedGroup kea_group,
+ SSLNamedGroup original_kea_group,
+ SSLAuthType auth_type,
+ SSLSignatureScheme sig_scheme);
+ void CheckGroups(const DataBuffer& groups,
+ std::function<void(SSLNamedGroup)> check_group);
+ void CheckShares(const DataBuffer& shares,
+ std::function<void(SSLNamedGroup)> check_group);
+ void CheckEpochs(uint16_t client_epoch, uint16_t server_epoch) const;
+
+ void ConfigureVersion(uint16_t version);
+ void SetExpectedVersion(uint16_t version);
+ // Expect resumption of a particular type.
+ void ExpectResumption(SessionResumptionMode expected,
+ uint8_t num_resumed = 1);
+ void DisableAllCiphers();
+ void EnableOnlyStaticRsaCiphers();
+ void EnableOnlyDheCiphers();
+ void EnableSomeEcdhCiphers();
+ void EnableExtendedMasterSecret();
+ void ConfigureSelfEncrypt();
+ void ConfigureSessionCache(SessionResumptionMode client,
+ SessionResumptionMode server);
+ void EnableAlpn();
+ void EnableAlpnWithCallback(const std::vector<uint8_t>& client,
+ std::string server_choice);
+ void EnableAlpn(const std::vector<uint8_t>& vals);
+ void EnsureModelSockets();
+ void CheckAlpn(const std::string& val);
+ void EnableSrtp();
+ void CheckSrtp() const;
+ void SendReceive(size_t total = 50);
+ void AddPsk(const ScopedPK11SymKey& psk, std::string label, SSLHashType hash,
+ uint16_t zeroRttSuite = TLS_NULL_WITH_NULL_NULL);
+ void RemovePsk(std::string label);
+ void SetupForZeroRtt();
+ void SetupForResume();
+ void ZeroRttSendReceive(
+ bool expect_writable, bool expect_readable,
+ std::function<bool()> post_clienthello_check = nullptr);
+ void Receive(size_t amount);
+ void ExpectExtendedMasterSecret(bool expected);
+ void ExpectEarlyDataAccepted(bool expected);
+ void EnableECDHEServerKeyReuse();
+ void SkipVersionChecks();
+
+ // Move the DTLS timers for both endpoints to pop the next timer.
+ void ShiftDtlsTimers();
+ void AdvanceTime(PRTime time_shift);
+
+ void ResetAntiReplay(PRTime window);
+ void RolloverAntiReplay();
+
+ void SaveAlgorithmPolicy();
+ void RestoreAlgorithmPolicy();
+
+ static ScopedSECItem MakeEcKeyParams(SSLNamedGroup group);
+ static void GenerateEchConfig(
+ HpkeKemId kem_id, const std::vector<HpkeSymmetricSuite>& cipher_suites,
+ const std::string& public_name, uint16_t max_name_len, DataBuffer& record,
+ ScopedSECKEYPublicKey& pubKey, ScopedSECKEYPrivateKey& privKey);
+ void SetupEch(std::shared_ptr<TlsAgent>& client,
+ std::shared_ptr<TlsAgent>& server,
+ HpkeKemId kem_id = HpkeDhKemX25519Sha256,
+ bool expect_ech = true, bool set_client_config = true,
+ bool set_server_config = true, int maxConfigSize = 100);
+
+ protected:
+ SSLProtocolVariant variant_;
+ std::shared_ptr<TlsAgent> client_;
+ std::shared_ptr<TlsAgent> server_;
+ std::unique_ptr<TlsAgent> client_model_;
+ std::unique_ptr<TlsAgent> server_model_;
+ uint16_t version_;
+ SessionResumptionMode expected_resumption_mode_;
+ uint8_t expected_resumptions_;
+ std::vector<std::vector<uint8_t>> session_ids_;
+ ScopedSSLAntiReplayContext anti_replay_;
+
+ // A simple value of "a", "b". Note that the preferred value of "a" is placed
+ // at the end, because the NSS API follows the now defunct NPN specification,
+ // which places the preferred (and default) entry at the end of the list.
+ // NSS will move this final entry to the front when used with ALPN.
+ const uint8_t alpn_dummy_val_[4] = {0x01, 0x62, 0x01, 0x61};
+
+ // A list of algorithm IDs whose policies need to be preserved
+ // around test cases. In particular, DSA is checked in
+ // ssl_extension_unittest.cc.
+ const std::vector<SECOidTag> algorithms_ = {SEC_OID_APPLY_SSL_POLICY,
+ SEC_OID_ANSIX9_DSA_SIGNATURE,
+ SEC_OID_CURVE25519, SEC_OID_SHA1};
+ std::vector<std::tuple<SECOidTag, uint32_t>> saved_policies_;
+ const std::vector<PRInt32> options_ = {
+ NSS_RSA_MIN_KEY_SIZE, NSS_DH_MIN_KEY_SIZE, NSS_DSA_MIN_KEY_SIZE,
+ NSS_TLS_VERSION_MIN_POLICY, NSS_TLS_VERSION_MAX_POLICY};
+ std::vector<std::tuple<PRInt32, uint32_t>> saved_options_;
+
+ private:
+ void CheckResumption(SessionResumptionMode expected);
+ void CheckExtendedMasterSecret();
+ void CheckEarlyDataAccepted();
+ static PRTime TimeFunc(void* arg);
+
+ bool expect_extended_master_secret_;
+ bool expect_early_data_accepted_;
+ bool skip_version_checks_;
+ PRTime now_;
+
+ // Track groups and make sure that there are no duplicates.
+ class DuplicateGroupChecker {
+ public:
+ void AddAndCheckGroup(SSLNamedGroup group) {
+ EXPECT_EQ(groups_.end(), groups_.find(group))
+ << "Group " << group << " should not be duplicated";
+ groups_.insert(group);
+ }
+
+ private:
+ std::set<SSLNamedGroup> groups_;
+ };
+};
+
+// A non-parametrized TLS test base.
+class TlsConnectTest : public TlsConnectTestBase {
+ public:
+ TlsConnectTest() : TlsConnectTestBase(ssl_variant_stream, 0) {}
+};
+
+// A non-parametrized DTLS-only test base.
+class DtlsConnectTest : public TlsConnectTestBase {
+ public:
+ DtlsConnectTest() : TlsConnectTestBase(ssl_variant_datagram, 0) {}
+};
+
+// A TLS-only test base.
+class TlsConnectStream : public TlsConnectTestBase,
+ public ::testing::WithParamInterface<uint16_t> {
+ public:
+ TlsConnectStream() : TlsConnectTestBase(ssl_variant_stream, GetParam()) {}
+};
+
+// A TLS-only test base for tests before 1.3
+class TlsConnectStreamPre13 : public TlsConnectStream {};
+
+// A DTLS-only test base.
+class TlsConnectDatagram : public TlsConnectTestBase,
+ public ::testing::WithParamInterface<uint16_t> {
+ public:
+ TlsConnectDatagram() : TlsConnectTestBase(ssl_variant_datagram, GetParam()) {}
+};
+
+// A generic test class that can be either stream or datagram and a single
+// version of TLS. This is configured in ssl_loopback_unittest.cc.
+class TlsConnectGeneric : public TlsConnectTestBase,
+ public ::testing::WithParamInterface<
+ std::tuple<SSLProtocolVariant, uint16_t>> {
+ public:
+ TlsConnectGeneric();
+};
+
+class TlsConnectGenericResumption
+ : public TlsConnectTestBase,
+ public ::testing::WithParamInterface<
+ std::tuple<SSLProtocolVariant, uint16_t, bool>> {
+ private:
+ bool external_cache_;
+
+ public:
+ TlsConnectGenericResumption();
+
+ virtual void EnsureTlsSetup() {
+ TlsConnectTestBase::EnsureTlsSetup();
+ // Enable external resumption token cache.
+ if (external_cache_) {
+ client_->SetResumptionTokenCallback();
+ }
+ }
+
+ bool use_external_cache() const { return external_cache_; }
+};
+
+class TlsConnectTls13ResumptionToken
+ : public TlsConnectTestBase,
+ public ::testing::WithParamInterface<SSLProtocolVariant> {
+ public:
+ TlsConnectTls13ResumptionToken();
+
+ virtual void EnsureTlsSetup() {
+ TlsConnectTestBase::EnsureTlsSetup();
+ client_->SetResumptionTokenCallback();
+ }
+};
+
+class TlsConnectGenericResumptionToken
+ : public TlsConnectTestBase,
+ public ::testing::WithParamInterface<
+ std::tuple<SSLProtocolVariant, uint16_t>> {
+ public:
+ TlsConnectGenericResumptionToken();
+
+ virtual void EnsureTlsSetup() {
+ TlsConnectTestBase::EnsureTlsSetup();
+ client_->SetResumptionTokenCallback();
+ }
+};
+
+// A Pre TLS 1.2 generic test.
+class TlsConnectPre12 : public TlsConnectTestBase,
+ public ::testing::WithParamInterface<
+ std::tuple<SSLProtocolVariant, uint16_t>> {
+ public:
+ TlsConnectPre12();
+};
+
+// A TLS 1.2 only generic test.
+class TlsConnectTls12
+ : public TlsConnectTestBase,
+ public ::testing::WithParamInterface<SSLProtocolVariant> {
+ public:
+ TlsConnectTls12();
+};
+
+// A TLS 1.2 only stream test.
+class TlsConnectStreamTls12 : public TlsConnectTestBase {
+ public:
+ TlsConnectStreamTls12()
+ : TlsConnectTestBase(ssl_variant_stream, SSL_LIBRARY_VERSION_TLS_1_2) {}
+};
+
+// A TLS 1.2+ generic test.
+class TlsConnectTls12Plus : public TlsConnectTestBase,
+ public ::testing::WithParamInterface<
+ std::tuple<SSLProtocolVariant, uint16_t>> {
+ public:
+ TlsConnectTls12Plus();
+};
+
+// A TLS 1.3 only generic test.
+class TlsConnectTls13
+ : public TlsConnectTestBase,
+ public ::testing::WithParamInterface<SSLProtocolVariant> {
+ public:
+ TlsConnectTls13();
+};
+
+// A TLS 1.3 only stream test.
+class TlsConnectStreamTls13 : public TlsConnectTestBase {
+ public:
+ TlsConnectStreamTls13()
+ : TlsConnectTestBase(ssl_variant_stream, SSL_LIBRARY_VERSION_TLS_1_3) {}
+};
+
+class TlsConnectDatagram13 : public TlsConnectTestBase {
+ public:
+ TlsConnectDatagram13()
+ : TlsConnectTestBase(ssl_variant_datagram, SSL_LIBRARY_VERSION_TLS_1_3) {}
+};
+
+class TlsConnectDatagramPre13 : public TlsConnectDatagram {
+ public:
+ TlsConnectDatagramPre13() {}
+};
+
+// A variant that is used only with Pre13.
+class TlsConnectGenericPre13 : public TlsConnectGeneric {};
+
+class TlsKeyExchangeTest : public TlsConnectGeneric {
+ protected:
+ std::shared_ptr<TlsExtensionCapture> groups_capture_;
+ std::shared_ptr<TlsExtensionCapture> shares_capture_;
+ std::shared_ptr<TlsExtensionCapture> shares_capture2_;
+ std::shared_ptr<TlsHandshakeRecorder> capture_hrr_;
+
+ void EnsureKeyShareSetup();
+ void ConfigNamedGroups(const std::vector<SSLNamedGroup>& groups);
+ std::vector<SSLNamedGroup> GetGroupDetails(
+ const std::shared_ptr<TlsExtensionCapture>& capture);
+ std::vector<SSLNamedGroup> GetShareDetails(
+ const std::shared_ptr<TlsExtensionCapture>& capture);
+ void CheckKEXDetails(const std::vector<SSLNamedGroup>& expectedGroups,
+ const std::vector<SSLNamedGroup>& expectedShares);
+ void CheckKEXDetails(const std::vector<SSLNamedGroup>& expectedGroups,
+ const std::vector<SSLNamedGroup>& expectedShares,
+ SSLNamedGroup expectedShare2);
+
+ private:
+ void CheckKEXDetails(const std::vector<SSLNamedGroup>& expectedGroups,
+ const std::vector<SSLNamedGroup>& expectedShares,
+ bool expect_hrr);
+};
+
+class TlsKeyExchangeTest13 : public TlsKeyExchangeTest {};
+class TlsKeyExchangeTestPre13 : public TlsKeyExchangeTest {};
+
+} // namespace nss_test
+
+#endif