diff options
Diffstat (limited to 'security/nss/gtests/ssl_gtest/tls_connect.cc')
-rw-r--r-- | security/nss/gtests/ssl_gtest/tls_connect.cc | 1055 |
1 files changed, 1055 insertions, 0 deletions
diff --git a/security/nss/gtests/ssl_gtest/tls_connect.cc b/security/nss/gtests/ssl_gtest/tls_connect.cc new file mode 100644 index 0000000000..6456bff563 --- /dev/null +++ b/security/nss/gtests/ssl_gtest/tls_connect.cc @@ -0,0 +1,1055 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set ts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this file, + * You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#include "tls_connect.h" +#include "sslexp.h" +extern "C" { +#include "libssl_internals.h" +} + +#include <iostream> + +#include "databuffer.h" +#include "gtest_utils.h" +#include "nss_scoped_ptrs.h" +#include "sslproto.h" + +extern std::string g_working_dir_path; + +namespace nss_test { + +static const SSLProtocolVariant kTlsVariantsStreamArr[] = {ssl_variant_stream}; +::testing::internal::ParamGenerator<SSLProtocolVariant> + TlsConnectTestBase::kTlsVariantsStream = + ::testing::ValuesIn(kTlsVariantsStreamArr); +static const SSLProtocolVariant kTlsVariantsDatagramArr[] = { + ssl_variant_datagram}; +::testing::internal::ParamGenerator<SSLProtocolVariant> + TlsConnectTestBase::kTlsVariantsDatagram = + ::testing::ValuesIn(kTlsVariantsDatagramArr); +static const SSLProtocolVariant kTlsVariantsAllArr[] = {ssl_variant_stream, + ssl_variant_datagram}; +::testing::internal::ParamGenerator<SSLProtocolVariant> + TlsConnectTestBase::kTlsVariantsAll = + ::testing::ValuesIn(kTlsVariantsAllArr); + +static const uint16_t kTlsV10Arr[] = {SSL_LIBRARY_VERSION_TLS_1_0}; +::testing::internal::ParamGenerator<uint16_t> TlsConnectTestBase::kTlsV10 = + ::testing::ValuesIn(kTlsV10Arr); +static const uint16_t kTlsV11Arr[] = {SSL_LIBRARY_VERSION_TLS_1_1}; +::testing::internal::ParamGenerator<uint16_t> TlsConnectTestBase::kTlsV11 = + ::testing::ValuesIn(kTlsV11Arr); +static const uint16_t kTlsV12Arr[] = {SSL_LIBRARY_VERSION_TLS_1_2}; +::testing::internal::ParamGenerator<uint16_t> TlsConnectTestBase::kTlsV12 = + ::testing::ValuesIn(kTlsV12Arr); +static const uint16_t kTlsV10V11Arr[] = {SSL_LIBRARY_VERSION_TLS_1_0, + SSL_LIBRARY_VERSION_TLS_1_1}; +::testing::internal::ParamGenerator<uint16_t> TlsConnectTestBase::kTlsV10V11 = + ::testing::ValuesIn(kTlsV10V11Arr); +static const uint16_t kTlsV10ToV12Arr[] = {SSL_LIBRARY_VERSION_TLS_1_0, + SSL_LIBRARY_VERSION_TLS_1_1, + SSL_LIBRARY_VERSION_TLS_1_2}; +::testing::internal::ParamGenerator<uint16_t> TlsConnectTestBase::kTlsV10ToV12 = + ::testing::ValuesIn(kTlsV10ToV12Arr); +static const uint16_t kTlsV11V12Arr[] = {SSL_LIBRARY_VERSION_TLS_1_1, + SSL_LIBRARY_VERSION_TLS_1_2}; +::testing::internal::ParamGenerator<uint16_t> TlsConnectTestBase::kTlsV11V12 = + ::testing::ValuesIn(kTlsV11V12Arr); + +static const uint16_t kTlsV11PlusArr[] = { +#ifndef NSS_DISABLE_TLS_1_3 + SSL_LIBRARY_VERSION_TLS_1_3, +#endif + SSL_LIBRARY_VERSION_TLS_1_2, SSL_LIBRARY_VERSION_TLS_1_1}; +::testing::internal::ParamGenerator<uint16_t> TlsConnectTestBase::kTlsV11Plus = + ::testing::ValuesIn(kTlsV11PlusArr); +static const uint16_t kTlsV12PlusArr[] = { +#ifndef NSS_DISABLE_TLS_1_3 + SSL_LIBRARY_VERSION_TLS_1_3, +#endif + SSL_LIBRARY_VERSION_TLS_1_2}; +::testing::internal::ParamGenerator<uint16_t> TlsConnectTestBase::kTlsV12Plus = + ::testing::ValuesIn(kTlsV12PlusArr); +static const uint16_t kTlsV13Arr[] = {SSL_LIBRARY_VERSION_TLS_1_3}; +::testing::internal::ParamGenerator<uint16_t> TlsConnectTestBase::kTlsV13 = + ::testing::ValuesIn(kTlsV13Arr); +static const uint16_t kTlsVAllArr[] = { +#ifndef NSS_DISABLE_TLS_1_3 + SSL_LIBRARY_VERSION_TLS_1_3, +#endif + SSL_LIBRARY_VERSION_TLS_1_2, SSL_LIBRARY_VERSION_TLS_1_1, + SSL_LIBRARY_VERSION_TLS_1_0}; +::testing::internal::ParamGenerator<uint16_t> TlsConnectTestBase::kTlsVAll = + ::testing::ValuesIn(kTlsVAllArr); + +std::string VersionString(uint16_t version) { + switch (version) { + case 0: + return "(no version)"; + case SSL_LIBRARY_VERSION_3_0: + return "1.0"; + case SSL_LIBRARY_VERSION_TLS_1_0: + return "1.0"; + case SSL_LIBRARY_VERSION_TLS_1_1: + return "1.1"; + case SSL_LIBRARY_VERSION_TLS_1_2: + return "1.2"; + case SSL_LIBRARY_VERSION_TLS_1_3: + return "1.3"; + default: + std::cerr << "Invalid version: " << version << std::endl; + EXPECT_TRUE(false); + return ""; + } +} + +// The default anti-replay window for tests. Tests that rely on a different +// value call ResetAntiReplay directly. +static PRTime kAntiReplayWindow = 100 * PR_USEC_PER_SEC; + +TlsConnectTestBase::TlsConnectTestBase(SSLProtocolVariant variant, + uint16_t version) + : variant_(variant), + client_(new TlsAgent(TlsAgent::kClient, TlsAgent::CLIENT, variant_)), + server_(new TlsAgent(TlsAgent::kServerRsa, TlsAgent::SERVER, variant_)), + client_model_(nullptr), + server_model_(nullptr), + version_(version), + expected_resumption_mode_(RESUME_NONE), + expected_resumptions_(0), + session_ids_(), + expect_extended_master_secret_(false), + expect_early_data_accepted_(false), + skip_version_checks_(false) { + std::string v; + if (variant_ == ssl_variant_datagram && + version_ == SSL_LIBRARY_VERSION_TLS_1_1) { + v = "1.0"; + } else { + v = VersionString(version_); + } + std::cerr << "Version: " << variant_ << " " << v << std::endl; +} + +TlsConnectTestBase::~TlsConnectTestBase() {} + +// Check the group of each of the supported groups +void TlsConnectTestBase::CheckGroups( + const DataBuffer& groups, std::function<void(SSLNamedGroup)> check_group) { + DuplicateGroupChecker group_set; + uint32_t tmp = 0; + EXPECT_TRUE(groups.Read(0, 2, &tmp)); + EXPECT_EQ(groups.len() - 2, static_cast<size_t>(tmp)); + for (size_t i = 2; i < groups.len(); i += 2) { + EXPECT_TRUE(groups.Read(i, 2, &tmp)); + SSLNamedGroup group = static_cast<SSLNamedGroup>(tmp); + group_set.AddAndCheckGroup(group); + check_group(group); + } +} + +// Check the group of each of the shares +void TlsConnectTestBase::CheckShares( + const DataBuffer& shares, std::function<void(SSLNamedGroup)> check_group) { + DuplicateGroupChecker group_set; + uint32_t tmp = 0; + EXPECT_TRUE(shares.Read(0, 2, &tmp)); + EXPECT_EQ(shares.len() - 2, static_cast<size_t>(tmp)); + size_t i; + for (i = 2; i < shares.len(); i += 4 + tmp) { + ASSERT_TRUE(shares.Read(i, 2, &tmp)); + SSLNamedGroup group = static_cast<SSLNamedGroup>(tmp); + group_set.AddAndCheckGroup(group); + check_group(group); + ASSERT_TRUE(shares.Read(i + 2, 2, &tmp)); + } + EXPECT_EQ(shares.len(), i); +} + +void TlsConnectTestBase::CheckEpochs(uint16_t client_epoch, + uint16_t server_epoch) const { + client_->CheckEpochs(server_epoch, client_epoch); + server_->CheckEpochs(client_epoch, server_epoch); +} + +void TlsConnectTestBase::ClearStats() { + // Clear statistics. + SSL3Statistics* stats = SSL_GetStatistics(); + memset(stats, 0, sizeof(*stats)); +} + +void TlsConnectTestBase::ClearServerCache() { + SSL_ShutdownServerSessionIDCache(); + SSLInt_ClearSelfEncryptKey(); + SSL_ConfigServerSessionIDCache(1024, 0, 0, g_working_dir_path.c_str()); +} + +void TlsConnectTestBase::SaveAlgorithmPolicy() { + saved_policies_.clear(); + for (auto it = algorithms_.begin(); it != algorithms_.end(); ++it) { + uint32_t policy; + SECStatus rv = NSS_GetAlgorithmPolicy(*it, &policy); + ASSERT_EQ(SECSuccess, rv); + saved_policies_.push_back(std::make_tuple(*it, policy)); + } +} + +void TlsConnectTestBase::RestoreAlgorithmPolicy() { + for (auto it = saved_policies_.begin(); it != saved_policies_.end(); ++it) { + auto algorithm = std::get<0>(*it); + auto policy = std::get<1>(*it); + SECStatus rv = NSS_SetAlgorithmPolicy( + algorithm, policy, NSS_USE_POLICY_IN_SSL | NSS_USE_ALG_IN_SSL_KX); + ASSERT_EQ(SECSuccess, rv); + } +} + +PRTime TlsConnectTestBase::TimeFunc(void* arg) { + return *reinterpret_cast<PRTime*>(arg); +} + +void TlsConnectTestBase::SetUp() { + SSL_ConfigServerSessionIDCache(1024, 0, 0, g_working_dir_path.c_str()); + SSLInt_ClearSelfEncryptKey(); + now_ = PR_Now(); + ResetAntiReplay(kAntiReplayWindow); + ClearStats(); + SaveAlgorithmPolicy(); + Init(); +} + +void TlsConnectTestBase::TearDown() { + client_ = nullptr; + server_ = nullptr; + + SSL_ClearSessionCache(); + SSLInt_ClearSelfEncryptKey(); + SSL_ShutdownServerSessionIDCache(); + RestoreAlgorithmPolicy(); +} + +void TlsConnectTestBase::Init() { + client_->SetPeer(server_); + server_->SetPeer(client_); + + if (version_) { + ConfigureVersion(version_); + } +} + +void TlsConnectTestBase::ResetAntiReplay(PRTime window) { + SSLAntiReplayContext* p_anti_replay = nullptr; + EXPECT_EQ(SECSuccess, + SSL_CreateAntiReplayContext(now_, window, 1, 3, &p_anti_replay)); + EXPECT_NE(nullptr, p_anti_replay); + anti_replay_.reset(p_anti_replay); +} + +void TlsConnectTestBase::MakeEcKeyParams(SECItem* params, SSLNamedGroup group) { + auto groupDef = ssl_LookupNamedGroup(group); + ASSERT_NE(nullptr, groupDef); + + auto oidData = SECOID_FindOIDByTag(groupDef->oidTag); + ASSERT_NE(nullptr, oidData); + ASSERT_NE(nullptr, + SECITEM_AllocItem(nullptr, params, (2 + oidData->oid.len))); + params->data[0] = SEC_ASN1_OBJECT_ID; + params->data[1] = oidData->oid.len; + memcpy(params->data + 2, oidData->oid.data, oidData->oid.len); +} + +void TlsConnectTestBase::GenerateEchConfig( + HpkeKemId kem_id, const std::vector<uint32_t>& cipher_suites, + const std::string& public_name, uint16_t max_name_len, DataBuffer& record, + ScopedSECKEYPublicKey& pubKey, ScopedSECKEYPrivateKey& privKey) { + bool gen_keys = !pubKey && !privKey; + SECKEYECParams ecParams = {siBuffer, NULL, 0}; + MakeEcKeyParams(&ecParams, ssl_grp_ec_curve25519); + + SECKEYPublicKey* pub = nullptr; + SECKEYPrivateKey* priv = nullptr; + + if (gen_keys) { + priv = SECKEY_CreateECPrivateKey(&ecParams, &pub, nullptr); + } else { + priv = privKey.get(); + pub = pubKey.get(); + } + ASSERT_NE(nullptr, priv); + SECITEM_FreeItem(&ecParams, PR_FALSE); + PRUint8 encoded[1024]; + unsigned int encoded_len = 0; + SECStatus rv = SSL_EncodeEchConfig( + public_name.c_str(), cipher_suites.data(), cipher_suites.size(), kem_id, + pub, max_name_len, encoded, &encoded_len, sizeof(encoded)); + EXPECT_EQ(SECSuccess, rv); + EXPECT_GT(encoded_len, 0U); + + if (gen_keys) { + pubKey.reset(pub); + privKey.reset(priv); + } + record.Truncate(0); + record.Write(0, encoded, encoded_len); +} + +void TlsConnectTestBase::SetupEch(std::shared_ptr<TlsAgent>& client, + std::shared_ptr<TlsAgent>& server, + HpkeKemId kem_id, bool expect_ech, + bool set_client_config, + bool set_server_config) { + EXPECT_TRUE(set_server_config || set_client_config); + ScopedSECKEYPublicKey pub; + ScopedSECKEYPrivateKey priv; + DataBuffer record; + static const std::vector<uint32_t> kDefaultSuites = { + (static_cast<uint16_t>(HpkeKdfHkdfSha256) << 16) | + HpkeAeadChaCha20Poly1305, + (static_cast<uint16_t>(HpkeKdfHkdfSha256) << 16) | HpkeAeadAes128Gcm}; + + GenerateEchConfig(kem_id, kDefaultSuites, "public.name", 100, record, pub, + priv); + ASSERT_NE(0U, record.len()); + SECStatus rv; + if (set_server_config) { + rv = SSL_SetServerEchConfigs(server->ssl_fd(), pub.get(), priv.get(), + record.data(), record.len()); + ASSERT_EQ(SECSuccess, rv); + } + if (set_client_config) { + rv = SSL_SetClientEchConfigs(client->ssl_fd(), record.data(), record.len()); + ASSERT_EQ(SECSuccess, rv); + } + + /* Filter expect_ech, which typically defaults to true. Parameterized tests + * running DTLS or TLS < 1.3 should expect only a non-ECH result. */ + bool expect = expect_ech && variant_ != ssl_variant_datagram && + version_ >= SSL_LIBRARY_VERSION_TLS_1_3 && set_client_config && + set_server_config; + client->ExpectEch(expect); + server->ExpectEch(expect); +} + +void TlsConnectTestBase::Reset() { + // Take a copy of the names because they are about to disappear. + std::string server_name = server_->name(); + std::string client_name = client_->name(); + Reset(server_name, client_name); +} + +void TlsConnectTestBase::Reset(const std::string& server_name, + const std::string& client_name) { + auto token = client_->GetResumptionToken(); + client_.reset(new TlsAgent(client_name, TlsAgent::CLIENT, variant_)); + client_->SetResumptionToken(token); + server_.reset(new TlsAgent(server_name, TlsAgent::SERVER, variant_)); + if (skip_version_checks_) { + client_->SkipVersionChecks(); + server_->SkipVersionChecks(); + } + + std::cerr << "Reset server:" << server_name << ", client:" << client_name + << std::endl; + Init(); +} + +void TlsConnectTestBase::MakeNewServer() { + auto replacement = std::make_shared<TlsAgent>( + server_->name(), TlsAgent::SERVER, server_->variant()); + server_ = replacement; + if (version_) { + server_->SetVersionRange(version_, version_); + } + client_->SetPeer(server_); + server_->SetPeer(client_); + server_->StartConnect(); +} + +void TlsConnectTestBase::ExpectResumption(SessionResumptionMode expected, + uint8_t num_resumptions) { + expected_resumption_mode_ = expected; + if (expected != RESUME_NONE) { + client_->ExpectResumption(); + server_->ExpectResumption(); + expected_resumptions_ = num_resumptions; + } + EXPECT_EQ(expected_resumptions_ == 0, expected == RESUME_NONE); +} + +void TlsConnectTestBase::EnsureTlsSetup() { + EXPECT_TRUE(server_->EnsureTlsSetup(server_model_ ? server_model_->ssl_fd() + : nullptr)); + EXPECT_TRUE(client_->EnsureTlsSetup(client_model_ ? client_model_->ssl_fd() + : nullptr)); + server_->SetAntiReplayContext(anti_replay_); + EXPECT_EQ(SECSuccess, SSL_SetTimeFunc(client_->ssl_fd(), + TlsConnectTestBase::TimeFunc, &now_)); + EXPECT_EQ(SECSuccess, SSL_SetTimeFunc(server_->ssl_fd(), + TlsConnectTestBase::TimeFunc, &now_)); +} + +void TlsConnectTestBase::Handshake() { + client_->SetServerKeyBits(server_->server_key_bits()); + client_->Handshake(); + server_->Handshake(); + + ASSERT_TRUE_WAIT((client_->state() != TlsAgent::STATE_CONNECTING) && + (server_->state() != TlsAgent::STATE_CONNECTING), + 5000); +} + +void TlsConnectTestBase::EnableExtendedMasterSecret() { + client_->EnableExtendedMasterSecret(); + server_->EnableExtendedMasterSecret(); + ExpectExtendedMasterSecret(true); +} + +void TlsConnectTestBase::Connect() { + StartConnect(); + client_->MaybeSetResumptionToken(); + Handshake(); + CheckConnected(); +} + +void TlsConnectTestBase::StartConnect() { + EnsureTlsSetup(); + server_->StartConnect(); + client_->StartConnect(); +} + +void TlsConnectTestBase::ConnectWithCipherSuite(uint16_t cipher_suite) { + EnsureTlsSetup(); + client_->EnableSingleCipher(cipher_suite); + + Connect(); + SendReceive(); + + // Check that we used the right cipher suite. + uint16_t actual; + EXPECT_TRUE(client_->cipher_suite(&actual)); + EXPECT_EQ(cipher_suite, actual); + EXPECT_TRUE(server_->cipher_suite(&actual)); + EXPECT_EQ(cipher_suite, actual); +} + +void TlsConnectTestBase::CheckConnected() { + // Have the client read handshake twice to make sure we get the + // NST and the ACK. + if (client_->version() >= SSL_LIBRARY_VERSION_TLS_1_3 && + variant_ == ssl_variant_datagram) { + client_->Handshake(); + client_->Handshake(); + auto suites = SSLInt_CountCipherSpecs(client_->ssl_fd()); + // Verify that we dropped the client's retransmission cipher suites. + EXPECT_EQ(2, suites) << "Client has the wrong number of suites"; + if (suites != 2) { + SSLInt_PrintCipherSpecs("client", client_->ssl_fd()); + } + } + EXPECT_EQ(client_->version(), server_->version()); + if (!skip_version_checks_) { + // Check the version is as expected + EXPECT_EQ(std::min(client_->max_version(), server_->max_version()), + client_->version()); + } + + EXPECT_EQ(TlsAgent::STATE_CONNECTED, client_->state()); + EXPECT_EQ(TlsAgent::STATE_CONNECTED, server_->state()); + + uint16_t cipher_suite1, cipher_suite2; + bool ret = client_->cipher_suite(&cipher_suite1); + EXPECT_TRUE(ret); + ret = server_->cipher_suite(&cipher_suite2); + EXPECT_TRUE(ret); + EXPECT_EQ(cipher_suite1, cipher_suite2); + + std::cerr << "Connected with version " << client_->version() + << " cipher suite " << client_->cipher_suite_name() << std::endl; + + if (client_->version() < SSL_LIBRARY_VERSION_TLS_1_3) { + // Check and store session ids. + std::vector<uint8_t> sid_c1 = client_->session_id(); + EXPECT_EQ(32U, sid_c1.size()); + std::vector<uint8_t> sid_s1 = server_->session_id(); + EXPECT_EQ(32U, sid_s1.size()); + EXPECT_EQ(sid_c1, sid_s1); + session_ids_.push_back(sid_c1); + } + + CheckExtendedMasterSecret(); + CheckEarlyDataAccepted(); + CheckResumption(expected_resumption_mode_); + client_->CheckSecretsDestroyed(); + server_->CheckSecretsDestroyed(); +} + +void TlsConnectTestBase::CheckEarlyDataLimit( + const std::shared_ptr<TlsAgent>& agent, size_t expected_size) { + SSLPreliminaryChannelInfo preinfo; + SECStatus rv = + SSL_GetPreliminaryChannelInfo(agent->ssl_fd(), &preinfo, sizeof(preinfo)); + EXPECT_EQ(SECSuccess, rv); + EXPECT_EQ(expected_size, static_cast<size_t>(preinfo.maxEarlyDataSize)); +} + +void TlsConnectTestBase::CheckKeys(SSLKEAType kea_type, SSLNamedGroup kea_group, + SSLAuthType auth_type, + SSLSignatureScheme sig_scheme) const { + if (kea_group != ssl_grp_none) { + client_->CheckKEA(kea_type, kea_group); + server_->CheckKEA(kea_type, kea_group); + } + server_->CheckAuthType(auth_type, sig_scheme); + client_->CheckAuthType(auth_type, sig_scheme); +} + +void TlsConnectTestBase::CheckKeys(SSLKEAType kea_type, + SSLAuthType auth_type) const { + SSLNamedGroup group; + switch (kea_type) { + case ssl_kea_ecdh: + group = ssl_grp_ec_curve25519; + break; + case ssl_kea_dh: + group = ssl_grp_ffdhe_2048; + break; + case ssl_kea_rsa: + group = ssl_grp_none; + break; + default: + EXPECT_TRUE(false) << "unexpected KEA"; + group = ssl_grp_none; + break; + } + + SSLSignatureScheme scheme; + switch (auth_type) { + case ssl_auth_rsa_decrypt: + scheme = ssl_sig_none; + break; + case ssl_auth_rsa_sign: + if (version_ >= SSL_LIBRARY_VERSION_TLS_1_2) { + scheme = ssl_sig_rsa_pss_rsae_sha256; + } else { + scheme = ssl_sig_rsa_pkcs1_sha256; + } + break; + case ssl_auth_rsa_pss: + scheme = ssl_sig_rsa_pss_rsae_sha256; + break; + case ssl_auth_ecdsa: + scheme = ssl_sig_ecdsa_secp256r1_sha256; + break; + case ssl_auth_dsa: + scheme = ssl_sig_dsa_sha1; + break; + default: + EXPECT_TRUE(false) << "unexpected auth type"; + scheme = static_cast<SSLSignatureScheme>(0x0100); + break; + } + CheckKeys(kea_type, group, auth_type, scheme); +} + +void TlsConnectTestBase::CheckKeys() const { + CheckKeys(ssl_kea_ecdh, ssl_auth_rsa_sign); +} + +void TlsConnectTestBase::CheckKeysResumption(SSLKEAType kea_type, + SSLNamedGroup kea_group, + SSLNamedGroup original_kea_group, + SSLAuthType auth_type, + SSLSignatureScheme sig_scheme) { + CheckKeys(kea_type, kea_group, auth_type, sig_scheme); + EXPECT_TRUE(expected_resumption_mode_ != RESUME_NONE); + client_->CheckOriginalKEA(original_kea_group); + server_->CheckOriginalKEA(original_kea_group); +} + +void TlsConnectTestBase::ConnectExpectFail() { + StartConnect(); + Handshake(); + ASSERT_EQ(TlsAgent::STATE_ERROR, client_->state()); + ASSERT_EQ(TlsAgent::STATE_ERROR, server_->state()); +} + +void TlsConnectTestBase::ExpectAlert(std::shared_ptr<TlsAgent>& sender, + uint8_t alert) { + EnsureTlsSetup(); + auto receiver = (sender == client_) ? server_ : client_; + sender->ExpectSendAlert(alert); + receiver->ExpectReceiveAlert(alert); +} + +void TlsConnectTestBase::ConnectExpectAlert(std::shared_ptr<TlsAgent>& sender, + uint8_t alert) { + ExpectAlert(sender, alert); + ConnectExpectFail(); +} + +void TlsConnectTestBase::ConnectExpectFailOneSide(TlsAgent::Role failing_side) { + StartConnect(); + client_->SetServerKeyBits(server_->server_key_bits()); + client_->Handshake(); + server_->Handshake(); + + auto failing_agent = server_; + if (failing_side == TlsAgent::CLIENT) { + failing_agent = client_; + } + ASSERT_TRUE_WAIT(failing_agent->state() == TlsAgent::STATE_ERROR, 5000); +} + +void TlsConnectTestBase::ConfigureVersion(uint16_t version) { + version_ = version; + client_->SetVersionRange(version, version); + server_->SetVersionRange(version, version); +} + +void TlsConnectTestBase::SetExpectedVersion(uint16_t version) { + client_->SetExpectedVersion(version); + server_->SetExpectedVersion(version); +} + +void TlsConnectTestBase::AddPsk(const ScopedPK11SymKey& psk, std::string label, + SSLHashType hash, uint16_t zeroRttSuite) { + client_->AddPsk(psk, label, hash, zeroRttSuite); + server_->AddPsk(psk, label, hash, zeroRttSuite); + client_->ExpectPsk(); + server_->ExpectPsk(); +} + +void TlsConnectTestBase::DisableAllCiphers() { + EnsureTlsSetup(); + client_->DisableAllCiphers(); + server_->DisableAllCiphers(); +} + +void TlsConnectTestBase::EnableOnlyStaticRsaCiphers() { + DisableAllCiphers(); + + client_->EnableCiphersByKeyExchange(ssl_kea_rsa); + server_->EnableCiphersByKeyExchange(ssl_kea_rsa); +} + +void TlsConnectTestBase::EnableOnlyDheCiphers() { + if (version_ < SSL_LIBRARY_VERSION_TLS_1_3) { + DisableAllCiphers(); + client_->EnableCiphersByKeyExchange(ssl_kea_dh); + server_->EnableCiphersByKeyExchange(ssl_kea_dh); + } else { + client_->ConfigNamedGroups(kFFDHEGroups); + server_->ConfigNamedGroups(kFFDHEGroups); + } +} + +void TlsConnectTestBase::EnableSomeEcdhCiphers() { + if (version_ < SSL_LIBRARY_VERSION_TLS_1_3) { + client_->EnableCiphersByAuthType(ssl_auth_ecdh_rsa); + client_->EnableCiphersByAuthType(ssl_auth_ecdh_ecdsa); + server_->EnableCiphersByAuthType(ssl_auth_ecdh_rsa); + server_->EnableCiphersByAuthType(ssl_auth_ecdh_ecdsa); + } else { + client_->ConfigNamedGroups(kECDHEGroups); + server_->ConfigNamedGroups(kECDHEGroups); + } +} + +void TlsConnectTestBase::ConfigureSelfEncrypt() { + ScopedCERTCertificate cert; + ScopedSECKEYPrivateKey privKey; + ASSERT_TRUE( + TlsAgent::LoadCertificate(TlsAgent::kServerRsaDecrypt, &cert, &privKey)); + + ScopedSECKEYPublicKey pubKey(CERT_ExtractPublicKey(cert.get())); + ASSERT_TRUE(pubKey); + + EXPECT_EQ(SECSuccess, + SSL_SetSessionTicketKeyPair(pubKey.get(), privKey.get())); +} + +void TlsConnectTestBase::ConfigureSessionCache(SessionResumptionMode client, + SessionResumptionMode server) { + client_->ConfigureSessionCache(client); + server_->ConfigureSessionCache(server); + if ((server & RESUME_TICKET) != 0) { + ConfigureSelfEncrypt(); + } +} + +void TlsConnectTestBase::CheckResumption(SessionResumptionMode expected) { + EXPECT_NE(RESUME_BOTH, expected); + + int resume_count = expected ? expected_resumptions_ : 0; + int stateless_count = (expected & RESUME_TICKET) ? expected_resumptions_ : 0; + + // Note: hch == server counter; hsh == client counter. + SSL3Statistics* stats = SSL_GetStatistics(); + EXPECT_EQ(resume_count, stats->hch_sid_cache_hits); + EXPECT_EQ(resume_count, stats->hsh_sid_cache_hits); + + EXPECT_EQ(stateless_count, stats->hch_sid_stateless_resumes); + EXPECT_EQ(stateless_count, stats->hsh_sid_stateless_resumes); + + if (expected != RESUME_NONE) { + if (client_->version() < SSL_LIBRARY_VERSION_TLS_1_3 && + client_->GetResumptionToken().size() == 0) { + // Check that the last two session ids match. + ASSERT_EQ(1U + expected_resumptions_, session_ids_.size()); + EXPECT_EQ(session_ids_[session_ids_.size() - 1], + session_ids_[session_ids_.size() - 2]); + } else { + // We've either chosen TLS 1.3 or are using an external resumption token, + // both of which only use tickets. + EXPECT_TRUE(expected & RESUME_TICKET); + } + } +} + +static SECStatus NextProtoCallbackServer(void* arg, PRFileDesc* fd, + const unsigned char* protos, + unsigned int protos_len, + unsigned char* protoOut, + unsigned int* protoOutLen, + unsigned int protoMaxLen) { + EXPECT_EQ(protoMaxLen, 255U); + TlsAgent* agent = reinterpret_cast<TlsAgent*>(arg); + // Check that agent->alpn_value_to_use_ is in protos. + if (protos_len < 1) { + return SECFailure; + } + for (size_t i = 0; i < protos_len;) { + size_t l = protos[i]; + EXPECT_LT(i + l, protos_len); + if (i + l >= protos_len) { + return SECFailure; + } + std::string protos_s(reinterpret_cast<const char*>(protos + i + 1), l); + if (protos_s == agent->alpn_value_to_use_) { + size_t s_len = agent->alpn_value_to_use_.size(); + EXPECT_LE(s_len, 255U); + memcpy(protoOut, &agent->alpn_value_to_use_[0], s_len); + *protoOutLen = s_len; + return SECSuccess; + } + i += l + 1; + } + return SECFailure; +} + +void TlsConnectTestBase::EnableAlpn() { + client_->EnableAlpn(alpn_dummy_val_, sizeof(alpn_dummy_val_)); + server_->EnableAlpn(alpn_dummy_val_, sizeof(alpn_dummy_val_)); +} + +void TlsConnectTestBase::EnableAlpnWithCallback( + const std::vector<uint8_t>& client_vals, std::string server_choice) { + EnsureTlsSetup(); + server_->alpn_value_to_use_ = server_choice; + EXPECT_EQ(SECSuccess, + SSL_SetNextProtoNego(client_->ssl_fd(), client_vals.data(), + client_vals.size())); + SECStatus rv = SSL_SetNextProtoCallback( + server_->ssl_fd(), NextProtoCallbackServer, server_.get()); + EXPECT_EQ(SECSuccess, rv); +} + +void TlsConnectTestBase::EnableAlpn(const std::vector<uint8_t>& vals) { + client_->EnableAlpn(vals.data(), vals.size()); + server_->EnableAlpn(vals.data(), vals.size()); +} + +void TlsConnectTestBase::EnsureModelSockets() { + // Make sure models agents are available. + if (!client_model_) { + ASSERT_EQ(server_model_, nullptr); + client_model_.reset( + new TlsAgent(TlsAgent::kClient, TlsAgent::CLIENT, variant_)); + server_model_.reset( + new TlsAgent(TlsAgent::kServerRsa, TlsAgent::SERVER, variant_)); + if (skip_version_checks_) { + client_model_->SkipVersionChecks(); + server_model_->SkipVersionChecks(); + } + } +} + +void TlsConnectTestBase::CheckAlpn(const std::string& val) { + client_->CheckAlpn(SSL_NEXT_PROTO_SELECTED, val); + server_->CheckAlpn(SSL_NEXT_PROTO_NEGOTIATED, val); +} + +void TlsConnectTestBase::EnableSrtp() { + client_->EnableSrtp(); + server_->EnableSrtp(); +} + +void TlsConnectTestBase::CheckSrtp() const { + client_->CheckSrtp(); + server_->CheckSrtp(); +} + +void TlsConnectTestBase::SendReceive(size_t total) { + ASSERT_GT(total, client_->received_bytes()); + ASSERT_GT(total, server_->received_bytes()); + client_->SendData(total - server_->received_bytes()); + server_->SendData(total - client_->received_bytes()); + Receive(total); // Receive() is cumulative +} + +// Do a first connection so we can do 0-RTT on the second one. +void TlsConnectTestBase::SetupForZeroRtt() { + // Force rollover of the anti-replay window. + // If we don't do this, then all 0-RTT attempts will be rejected. + RolloverAntiReplay(); + + ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET); + ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3); + server_->Set0RttEnabled(true); // So we signal that we allow 0-RTT. + Connect(); + SendReceive(); // Need to read so that we absorb the session ticket. + CheckKeys(); + + Reset(); + StartConnect(); +} + +// Do a first connection so we can do resumption +void TlsConnectTestBase::SetupForResume() { + EnsureTlsSetup(); + ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET); + Connect(); + SendReceive(); // Need to read so that we absorb the session ticket. + CheckKeys(); + + Reset(); +} + +void TlsConnectTestBase::ZeroRttSendReceive( + bool expect_writable, bool expect_readable, + std::function<bool()> post_clienthello_check) { + const char* k0RttData = "ABCDEF"; + const PRInt32 k0RttDataLen = static_cast<PRInt32>(strlen(k0RttData)); + + client_->Handshake(); // Send ClientHello. + if (post_clienthello_check) { + if (!post_clienthello_check()) return; + } + PRInt32 rv = + PR_Write(client_->ssl_fd(), k0RttData, k0RttDataLen); // 0-RTT write. + if (expect_writable) { + EXPECT_EQ(k0RttDataLen, rv); + } else { + EXPECT_EQ(SECFailure, rv); + } + server_->Handshake(); // Consume ClientHello + + std::vector<uint8_t> buf(k0RttDataLen); + rv = PR_Read(server_->ssl_fd(), buf.data(), k0RttDataLen); // 0-RTT read + if (expect_readable) { + std::cerr << "0-RTT read " << rv << " bytes\n"; + EXPECT_EQ(k0RttDataLen, rv); + } else { + EXPECT_EQ(SECFailure, rv); + EXPECT_EQ(PR_WOULD_BLOCK_ERROR, PORT_GetError()) + << "Unexpected error: " << PORT_ErrorToName(PORT_GetError()); + } + + // Do a second read. This should fail. + rv = PR_Read(server_->ssl_fd(), buf.data(), k0RttDataLen); + EXPECT_EQ(SECFailure, rv); + EXPECT_EQ(PR_WOULD_BLOCK_ERROR, PORT_GetError()); +} + +void TlsConnectTestBase::Receive(size_t amount) { + WAIT_(client_->received_bytes() == amount && + server_->received_bytes() == amount, + 2000); + ASSERT_EQ(amount, client_->received_bytes()); + ASSERT_EQ(amount, server_->received_bytes()); +} + +void TlsConnectTestBase::ExpectExtendedMasterSecret(bool expected) { + expect_extended_master_secret_ = expected; +} + +void TlsConnectTestBase::CheckExtendedMasterSecret() { + client_->CheckExtendedMasterSecret(expect_extended_master_secret_); + server_->CheckExtendedMasterSecret(expect_extended_master_secret_); +} + +void TlsConnectTestBase::ExpectEarlyDataAccepted(bool expected) { + expect_early_data_accepted_ = expected; +} + +void TlsConnectTestBase::CheckEarlyDataAccepted() { + client_->CheckEarlyDataAccepted(expect_early_data_accepted_); + server_->CheckEarlyDataAccepted(expect_early_data_accepted_); +} + +void TlsConnectTestBase::DisableECDHEServerKeyReuse() { + server_->DisableECDHEServerKeyReuse(); +} + +void TlsConnectTestBase::SkipVersionChecks() { + skip_version_checks_ = true; + client_->SkipVersionChecks(); + server_->SkipVersionChecks(); +} + +// Shift the DTLS timers, to the minimum time necessary to let the next timer +// run on either client or server. This allows tests to skip waiting without +// having timers run out of order. +void TlsConnectTestBase::ShiftDtlsTimers() { + PRIntervalTime time_shift = PR_INTERVAL_NO_TIMEOUT; + PRIntervalTime time; + SECStatus rv = DTLS_GetHandshakeTimeout(client_->ssl_fd(), &time); + if (rv == SECSuccess) { + time_shift = time; + } + rv = DTLS_GetHandshakeTimeout(server_->ssl_fd(), &time); + if (rv == SECSuccess && + (time < time_shift || time_shift == PR_INTERVAL_NO_TIMEOUT)) { + time_shift = time; + } + + if (time_shift != PR_INTERVAL_NO_TIMEOUT) { + AdvanceTime(PR_IntervalToMicroseconds(time_shift)); + EXPECT_EQ(SECSuccess, + SSLInt_ShiftDtlsTimers(client_->ssl_fd(), time_shift)); + EXPECT_EQ(SECSuccess, + SSLInt_ShiftDtlsTimers(server_->ssl_fd(), time_shift)); + } +} + +void TlsConnectTestBase::AdvanceTime(PRTime time_shift) { now_ += time_shift; } + +// Advance time by a full anti-replay window. +void TlsConnectTestBase::RolloverAntiReplay() { + AdvanceTime(kAntiReplayWindow); +} + +TlsConnectGeneric::TlsConnectGeneric() + : TlsConnectTestBase(std::get<0>(GetParam()), std::get<1>(GetParam())) {} + +TlsConnectPre12::TlsConnectPre12() + : TlsConnectTestBase(std::get<0>(GetParam()), std::get<1>(GetParam())) {} + +TlsConnectTls12::TlsConnectTls12() + : TlsConnectTestBase(GetParam(), SSL_LIBRARY_VERSION_TLS_1_2) {} + +TlsConnectTls12Plus::TlsConnectTls12Plus() + : TlsConnectTestBase(std::get<0>(GetParam()), std::get<1>(GetParam())) {} + +TlsConnectTls13::TlsConnectTls13() + : TlsConnectTestBase(GetParam(), SSL_LIBRARY_VERSION_TLS_1_3) {} + +TlsConnectGenericResumption::TlsConnectGenericResumption() + : TlsConnectTestBase(std::get<0>(GetParam()), std::get<1>(GetParam())), + external_cache_(std::get<2>(GetParam())) {} + +TlsConnectTls13ResumptionToken::TlsConnectTls13ResumptionToken() + : TlsConnectTestBase(GetParam(), SSL_LIBRARY_VERSION_TLS_1_3) {} + +TlsConnectGenericResumptionToken::TlsConnectGenericResumptionToken() + : TlsConnectTestBase(std::get<0>(GetParam()), std::get<1>(GetParam())) {} + +void TlsKeyExchangeTest::EnsureKeyShareSetup() { + EnsureTlsSetup(); + groups_capture_ = + std::make_shared<TlsExtensionCapture>(client_, ssl_supported_groups_xtn); + shares_capture_ = + std::make_shared<TlsExtensionCapture>(client_, ssl_tls13_key_share_xtn); + shares_capture2_ = std::make_shared<TlsExtensionCapture>( + client_, ssl_tls13_key_share_xtn, true); + std::vector<std::shared_ptr<PacketFilter>> captures = { + groups_capture_, shares_capture_, shares_capture2_}; + client_->SetFilter(std::make_shared<ChainedPacketFilter>(captures)); + capture_hrr_ = MakeTlsFilter<TlsHandshakeRecorder>( + server_, kTlsHandshakeHelloRetryRequest); +} + +void TlsKeyExchangeTest::ConfigNamedGroups( + const std::vector<SSLNamedGroup>& groups) { + client_->ConfigNamedGroups(groups); + server_->ConfigNamedGroups(groups); +} + +std::vector<SSLNamedGroup> TlsKeyExchangeTest::GetGroupDetails( + const std::shared_ptr<TlsExtensionCapture>& capture) { + EXPECT_TRUE(capture->captured()); + const DataBuffer& ext = capture->extension(); + + uint32_t tmp = 0; + EXPECT_TRUE(ext.Read(0, 2, &tmp)); + EXPECT_EQ(ext.len() - 2, static_cast<size_t>(tmp)); + EXPECT_TRUE(ext.len() % 2 == 0); + + std::vector<SSLNamedGroup> groups; + for (size_t i = 1; i < ext.len() / 2; i += 1) { + EXPECT_TRUE(ext.Read(2 * i, 2, &tmp)); + groups.push_back(static_cast<SSLNamedGroup>(tmp)); + } + return groups; +} + +std::vector<SSLNamedGroup> TlsKeyExchangeTest::GetShareDetails( + const std::shared_ptr<TlsExtensionCapture>& capture) { + EXPECT_TRUE(capture->captured()); + const DataBuffer& ext = capture->extension(); + + uint32_t tmp = 0; + EXPECT_TRUE(ext.Read(0, 2, &tmp)); + EXPECT_EQ(ext.len() - 2, static_cast<size_t>(tmp)); + + std::vector<SSLNamedGroup> shares; + size_t i = 2; + while (i < ext.len()) { + EXPECT_TRUE(ext.Read(i, 2, &tmp)); + shares.push_back(static_cast<SSLNamedGroup>(tmp)); + EXPECT_TRUE(ext.Read(i + 2, 2, &tmp)); + i += 4 + tmp; + } + EXPECT_EQ(ext.len(), i); + return shares; +} + +void TlsKeyExchangeTest::CheckKEXDetails( + const std::vector<SSLNamedGroup>& expected_groups, + const std::vector<SSLNamedGroup>& expected_shares, bool expect_hrr) { + std::vector<SSLNamedGroup> groups = GetGroupDetails(groups_capture_); + EXPECT_EQ(expected_groups, groups); + + if (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) { + ASSERT_LT(0U, expected_shares.size()); + std::vector<SSLNamedGroup> shares = GetShareDetails(shares_capture_); + EXPECT_EQ(expected_shares, shares); + } else { + EXPECT_FALSE(shares_capture_->captured()); + } + + EXPECT_EQ(expect_hrr, capture_hrr_->buffer().len() != 0); +} + +void TlsKeyExchangeTest::CheckKEXDetails( + const std::vector<SSLNamedGroup>& expected_groups, + const std::vector<SSLNamedGroup>& expected_shares) { + CheckKEXDetails(expected_groups, expected_shares, false); +} + +void TlsKeyExchangeTest::CheckKEXDetails( + const std::vector<SSLNamedGroup>& expected_groups, + const std::vector<SSLNamedGroup>& expected_shares, + SSLNamedGroup expected_share2) { + CheckKEXDetails(expected_groups, expected_shares, true); + + for (auto it : expected_shares) { + EXPECT_NE(expected_share2, it); + } + std::vector<SSLNamedGroup> expected_shares2 = {expected_share2}; + EXPECT_EQ(expected_shares2, GetShareDetails(shares_capture2_)); +} +} // namespace nss_test |