summaryrefslogtreecommitdiffstats
path: root/security/nss/gtests/ssl_gtest/tls_agent.cc
diff options
context:
space:
mode:
authorDaniel Baumann <daniel.baumann@progress-linux.org>2024-04-07 19:33:14 +0000
committerDaniel Baumann <daniel.baumann@progress-linux.org>2024-04-07 19:33:14 +0000
commit36d22d82aa202bb199967e9512281e9a53db42c9 (patch)
tree105e8c98ddea1c1e4784a60a5a6410fa416be2de /security/nss/gtests/ssl_gtest/tls_agent.cc
parentInitial commit. (diff)
downloadfirefox-esr-36d22d82aa202bb199967e9512281e9a53db42c9.tar.xz
firefox-esr-36d22d82aa202bb199967e9512281e9a53db42c9.zip
Adding upstream version 115.7.0esr.upstream/115.7.0esr
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'security/nss/gtests/ssl_gtest/tls_agent.cc')
-rw-r--r--security/nss/gtests/ssl_gtest/tls_agent.cc1432
1 files changed, 1432 insertions, 0 deletions
diff --git a/security/nss/gtests/ssl_gtest/tls_agent.cc b/security/nss/gtests/ssl_gtest/tls_agent.cc
new file mode 100644
index 0000000000..8ec2f40f75
--- /dev/null
+++ b/security/nss/gtests/ssl_gtest/tls_agent.cc
@@ -0,0 +1,1432 @@
+/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
+/* vim: set ts=2 et sw=2 tw=80: */
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this file,
+ * You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+#include "tls_agent.h"
+#include "databuffer.h"
+#include "keyhi.h"
+#include "pk11func.h"
+#include "ssl.h"
+#include "sslerr.h"
+#include "sslexp.h"
+#include "sslproto.h"
+#include "tls_filter.h"
+#include "tls_parser.h"
+
+// This is an internal header, used to get DTLS_1_3_DRAFT_VERSION.
+#include "ssl3prot.h"
+
+extern "C" {
+// This is not something that should make you happy.
+#include "libssl_internals.h"
+}
+
+#define GTEST_HAS_RTTI 0
+#include "gtest/gtest.h"
+#include "gtest_utils.h"
+#include "nss_scoped_ptrs.h"
+
+extern std::string g_working_dir_path;
+
+namespace nss_test {
+
+const char* TlsAgent::states[] = {"INIT", "CONNECTING", "CONNECTED", "ERROR"};
+
+const std::string TlsAgent::kClient = "client"; // both sign and encrypt
+const std::string TlsAgent::kRsa2048 = "rsa2048"; // bigger
+const std::string TlsAgent::kRsa8192 = "rsa8192"; // biggest allowed
+const std::string TlsAgent::kServerRsa = "rsa"; // both sign and encrypt
+const std::string TlsAgent::kServerRsaSign = "rsa_sign";
+const std::string TlsAgent::kServerRsaPss = "rsa_pss";
+const std::string TlsAgent::kServerRsaDecrypt = "rsa_decrypt";
+const std::string TlsAgent::kServerEcdsa256 = "ecdsa256";
+const std::string TlsAgent::kServerEcdsa384 = "ecdsa384";
+const std::string TlsAgent::kServerEcdsa521 = "ecdsa521";
+const std::string TlsAgent::kServerEcdhRsa = "ecdh_rsa";
+const std::string TlsAgent::kServerEcdhEcdsa = "ecdh_ecdsa";
+const std::string TlsAgent::kServerDsa = "dsa";
+const std::string TlsAgent::kDelegatorEcdsa256 = "delegator_ecdsa256";
+const std::string TlsAgent::kDelegatorRsae2048 = "delegator_rsae2048";
+const std::string TlsAgent::kDelegatorRsaPss2048 = "delegator_rsa_pss2048";
+
+static const uint8_t kCannedTls13ServerHello[] = {
+ 0x03, 0x03, 0x9c, 0xbc, 0x14, 0x9b, 0x0e, 0x2e, 0xfa, 0x0d, 0xf3,
+ 0xf0, 0x5c, 0x70, 0x7a, 0xe0, 0xd1, 0x9b, 0x3e, 0x5a, 0x44, 0x6b,
+ 0xdf, 0xe5, 0xc2, 0x28, 0x64, 0xf7, 0x00, 0xc1, 0x9c, 0x08, 0x76,
+ 0x08, 0x00, 0x13, 0x01, 0x00, 0x00, 0x2e, 0x00, 0x33, 0x00, 0x24,
+ 0x00, 0x1d, 0x00, 0x20, 0xc2, 0xcf, 0x23, 0x17, 0x64, 0x23, 0x03,
+ 0xf0, 0xfb, 0x45, 0x98, 0x26, 0xd1, 0x65, 0x24, 0xa1, 0x6c, 0xa9,
+ 0x80, 0x8f, 0x2c, 0xac, 0x0a, 0xea, 0x53, 0x3a, 0xcb, 0xe3, 0x08,
+ 0x84, 0xae, 0x19, 0x00, 0x2b, 0x00, 0x02, 0x03, 0x04};
+
+TlsAgent::TlsAgent(const std::string& nm, Role rl, SSLProtocolVariant var)
+ : name_(nm),
+ variant_(var),
+ role_(rl),
+ server_key_bits_(0),
+ adapter_(new DummyPrSocket(role_str(), var)),
+ ssl_fd_(nullptr),
+ state_(STATE_INIT),
+ timer_handle_(nullptr),
+ falsestart_enabled_(false),
+ expected_version_(0),
+ expected_cipher_suite_(0),
+ expect_client_auth_(false),
+ expect_ech_(false),
+ expect_psk_(ssl_psk_none),
+ can_falsestart_hook_called_(false),
+ sni_hook_called_(false),
+ auth_certificate_hook_called_(false),
+ expected_received_alert_(kTlsAlertCloseNotify),
+ expected_received_alert_level_(kTlsAlertWarning),
+ expected_sent_alert_(kTlsAlertCloseNotify),
+ expected_sent_alert_level_(kTlsAlertWarning),
+ handshake_callback_called_(false),
+ resumption_callback_called_(false),
+ error_code_(0),
+ send_ctr_(0),
+ recv_ctr_(0),
+ expect_readwrite_error_(false),
+ handshake_callback_(),
+ auth_certificate_callback_(),
+ sni_callback_(),
+ skip_version_checks_(false),
+ resumption_token_(),
+ policy_() {
+ memset(&info_, 0, sizeof(info_));
+ memset(&csinfo_, 0, sizeof(csinfo_));
+ SECStatus rv = SSL_VersionRangeGetDefault(variant_, &vrange_);
+ EXPECT_EQ(SECSuccess, rv);
+}
+
+TlsAgent::~TlsAgent() {
+ if (timer_handle_) {
+ timer_handle_->Cancel();
+ }
+
+ if (adapter_) {
+ Poller::Instance()->Cancel(READABLE_EVENT, adapter_);
+ }
+
+ // Add failures manually, if any, so we don't throw in a destructor.
+ if (expected_received_alert_ != kTlsAlertCloseNotify ||
+ expected_received_alert_level_ != kTlsAlertWarning) {
+ ADD_FAILURE() << "Wrong expected_received_alert status: " << role_str();
+ }
+ if (expected_sent_alert_ != kTlsAlertCloseNotify ||
+ expected_sent_alert_level_ != kTlsAlertWarning) {
+ ADD_FAILURE() << "Wrong expected_sent_alert status: " << role_str();
+ }
+}
+
+void TlsAgent::SetState(State s) {
+ if (state_ == s) return;
+
+ LOG("Changing state from " << state_ << " to " << s);
+ state_ = s;
+}
+
+/*static*/ bool TlsAgent::LoadCertificate(const std::string& name,
+ ScopedCERTCertificate* cert,
+ ScopedSECKEYPrivateKey* priv) {
+ cert->reset(PK11_FindCertFromNickname(name.c_str(), nullptr));
+ EXPECT_NE(nullptr, cert);
+ if (!cert) return false;
+ EXPECT_NE(nullptr, cert->get());
+ if (!cert->get()) return false;
+
+ priv->reset(PK11_FindKeyByAnyCert(cert->get(), nullptr));
+ EXPECT_NE(nullptr, priv);
+ if (!priv) return false;
+ EXPECT_NE(nullptr, priv->get());
+ if (!priv->get()) return false;
+
+ return true;
+}
+
+// Loads a key pair from the certificate identified by |id|.
+/*static*/ bool TlsAgent::LoadKeyPairFromCert(const std::string& name,
+ ScopedSECKEYPublicKey* pub,
+ ScopedSECKEYPrivateKey* priv) {
+ ScopedCERTCertificate cert;
+ if (!TlsAgent::LoadCertificate(name, &cert, priv)) {
+ return false;
+ }
+
+ pub->reset(SECKEY_ExtractPublicKey(&cert->subjectPublicKeyInfo));
+ if (!pub->get()) {
+ return false;
+ }
+
+ return true;
+}
+
+void TlsAgent::DelegateCredential(const std::string& name,
+ const ScopedSECKEYPublicKey& dc_pub,
+ SSLSignatureScheme dc_cert_verify_alg,
+ PRUint32 dc_valid_for, PRTime now,
+ SECItem* dc) {
+ ScopedCERTCertificate cert;
+ ScopedSECKEYPrivateKey cert_priv;
+ EXPECT_TRUE(TlsAgent::LoadCertificate(name, &cert, &cert_priv))
+ << "Could not load delegate certificate: " << name
+ << "; test db corrupt?";
+
+ EXPECT_EQ(SECSuccess,
+ SSL_DelegateCredential(cert.get(), cert_priv.get(), dc_pub.get(),
+ dc_cert_verify_alg, dc_valid_for, now, dc));
+}
+
+void TlsAgent::EnableDelegatedCredentials() {
+ ASSERT_TRUE(EnsureTlsSetup());
+ SetOption(SSL_ENABLE_DELEGATED_CREDENTIALS, PR_TRUE);
+}
+
+void TlsAgent::AddDelegatedCredential(const std::string& dc_name,
+ SSLSignatureScheme dc_cert_verify_alg,
+ PRUint32 dc_valid_for, PRTime now) {
+ ASSERT_TRUE(EnsureTlsSetup());
+
+ ScopedSECKEYPublicKey pub;
+ ScopedSECKEYPrivateKey priv;
+ EXPECT_TRUE(TlsAgent::LoadKeyPairFromCert(dc_name, &pub, &priv));
+
+ StackSECItem dc;
+ TlsAgent::DelegateCredential(name_, pub, dc_cert_verify_alg, dc_valid_for,
+ now, &dc);
+
+ SSLExtraServerCertData extra_data = {ssl_auth_null, nullptr, nullptr,
+ nullptr, &dc, priv.get()};
+ EXPECT_TRUE(ConfigServerCert(name_, true, &extra_data));
+}
+
+bool TlsAgent::ConfigServerCert(const std::string& id, bool updateKeyBits,
+ const SSLExtraServerCertData* serverCertData) {
+ ScopedCERTCertificate cert;
+ ScopedSECKEYPrivateKey priv;
+ if (!TlsAgent::LoadCertificate(id, &cert, &priv)) {
+ return false;
+ }
+
+ if (updateKeyBits) {
+ ScopedSECKEYPublicKey pub(CERT_ExtractPublicKey(cert.get()));
+ EXPECT_NE(nullptr, pub.get());
+ if (!pub.get()) return false;
+ server_key_bits_ = SECKEY_PublicKeyStrengthInBits(pub.get());
+ }
+
+ SECStatus rv =
+ SSL_ConfigSecureServer(ssl_fd(), nullptr, nullptr, ssl_kea_null);
+ EXPECT_EQ(SECFailure, rv);
+ rv = SSL_ConfigServerCert(ssl_fd(), cert.get(), priv.get(), serverCertData,
+ serverCertData ? sizeof(*serverCertData) : 0);
+ return rv == SECSuccess;
+}
+
+bool TlsAgent::EnsureTlsSetup(PRFileDesc* modelSocket) {
+ // Don't set up twice
+ if (ssl_fd_) return true;
+ NssManagePolicy policyManage(policy_, option_);
+
+ ScopedPRFileDesc dummy_fd(adapter_->CreateFD());
+ EXPECT_NE(nullptr, dummy_fd);
+ if (!dummy_fd) {
+ return false;
+ }
+ if (adapter_->variant() == ssl_variant_stream) {
+ ssl_fd_.reset(SSL_ImportFD(modelSocket, dummy_fd.get()));
+ } else {
+ ssl_fd_.reset(DTLS_ImportFD(modelSocket, dummy_fd.get()));
+ }
+
+ EXPECT_NE(nullptr, ssl_fd_);
+ if (!ssl_fd_) {
+ return false;
+ }
+ dummy_fd.release(); // Now subsumed by ssl_fd_.
+
+ SECStatus rv;
+ if (!skip_version_checks_) {
+ rv = SSL_VersionRangeSet(ssl_fd(), &vrange_);
+ EXPECT_EQ(SECSuccess, rv);
+ if (rv != SECSuccess) return false;
+ }
+
+ ScopedCERTCertList anchors(CERT_NewCertList());
+ rv = SSL_SetTrustAnchors(ssl_fd(), anchors.get());
+ if (rv != SECSuccess) return false;
+
+ if (role_ == SERVER) {
+ EXPECT_TRUE(ConfigServerCert(name_, true));
+
+ rv = SSL_SNISocketConfigHook(ssl_fd(), SniHook, this);
+ EXPECT_EQ(SECSuccess, rv);
+ if (rv != SECSuccess) return false;
+
+ rv = SSL_SetMaxEarlyDataSize(ssl_fd(), 1024);
+ EXPECT_EQ(SECSuccess, rv);
+ if (rv != SECSuccess) return false;
+ } else {
+ rv = SSL_SetURL(ssl_fd(), "server");
+ EXPECT_EQ(SECSuccess, rv);
+ if (rv != SECSuccess) return false;
+ }
+
+ rv = SSL_AuthCertificateHook(ssl_fd(), AuthCertificateHook, this);
+ EXPECT_EQ(SECSuccess, rv);
+ if (rv != SECSuccess) return false;
+
+ rv = SSL_AlertReceivedCallback(ssl_fd(), AlertReceivedCallback, this);
+ EXPECT_EQ(SECSuccess, rv);
+ if (rv != SECSuccess) return false;
+
+ rv = SSL_AlertSentCallback(ssl_fd(), AlertSentCallback, this);
+ EXPECT_EQ(SECSuccess, rv);
+ if (rv != SECSuccess) return false;
+
+ rv = SSL_HandshakeCallback(ssl_fd(), HandshakeCallback, this);
+ EXPECT_EQ(SECSuccess, rv);
+ if (rv != SECSuccess) return false;
+
+ // All these tests depend on having this disabled to start with.
+ SetOption(SSL_ENABLE_EXTENDED_MASTER_SECRET, PR_FALSE);
+
+ return true;
+}
+
+bool TlsAgent::MaybeSetResumptionToken() {
+ if (!resumption_token_.empty()) {
+ LOG("setting external resumption token");
+ SECStatus rv = SSL_SetResumptionToken(ssl_fd(), resumption_token_.data(),
+ resumption_token_.size());
+
+ // rv is SECFailure with error set to SSL_ERROR_BAD_RESUMPTION_TOKEN_ERROR
+ // if the resumption token was bad (expired/malformed/etc.).
+ if (expect_psk_ == ssl_psk_resume) {
+ // Only in case we expect resumption this has to be successful. We might
+ // not expect resumption due to some reason but the token is totally fine.
+ EXPECT_EQ(SECSuccess, rv);
+ }
+ if (rv != SECSuccess) {
+ EXPECT_EQ(SSL_ERROR_BAD_RESUMPTION_TOKEN_ERROR, PORT_GetError());
+ resumption_token_.clear();
+ EXPECT_FALSE(expect_psk_ == ssl_psk_resume);
+ if (expect_psk_ == ssl_psk_resume) return false;
+ }
+ }
+
+ return true;
+}
+
+void TlsAgent::SetAntiReplayContext(ScopedSSLAntiReplayContext& ctx) {
+ EXPECT_EQ(SECSuccess, SSL_SetAntiReplayContext(ssl_fd(), ctx.get()));
+}
+
+// Defaults to a Sync callback returning success
+void TlsAgent::SetupClientAuth(ClientAuthCallbackType callbackType,
+ bool callbackSuccess) {
+ EXPECT_TRUE(EnsureTlsSetup());
+ ASSERT_EQ(CLIENT, role_);
+
+ client_auth_callback_type_ = callbackType;
+ client_auth_callback_success_ = callbackSuccess;
+
+ if (callbackType == ClientAuthCallbackType::kNone && !callbackSuccess) {
+ // Don't set a callback for this case.
+ return;
+ }
+ EXPECT_EQ(SECSuccess,
+ SSL_GetClientAuthDataHook(ssl_fd(), GetClientAuthDataHook,
+ reinterpret_cast<void*>(this)));
+}
+
+void CheckCertReqAgainstDefaultCAs(const CERTDistNames* caNames) {
+ ScopedCERTDistNames expected(CERT_GetSSLCACerts(nullptr));
+
+ ASSERT_EQ(expected->nnames, caNames->nnames);
+
+ for (size_t i = 0; i < static_cast<size_t>(expected->nnames); ++i) {
+ EXPECT_EQ(SECEqual,
+ SECITEM_CompareItem(&(expected->names[i]), &(caNames->names[i])));
+ }
+}
+
+// Complete processing of Client Certificate Selection
+// A No-op if the agent is using synchronous client cert selection.
+// Otherwise, calls SSL_ClientCertCallbackComplete.
+// kAsyncDelay triggers a call to SSL_ForceHandshake prior to completion to
+// ensure that the socket is correctly blocked.
+void TlsAgent::ClientAuthCallbackComplete() {
+ ASSERT_EQ(CLIENT, role_);
+
+ if (client_auth_callback_type_ != ClientAuthCallbackType::kAsyncDelay &&
+ client_auth_callback_type_ != ClientAuthCallbackType::kAsyncImmediate) {
+ return;
+ }
+ client_auth_callback_fired_++;
+ EXPECT_TRUE(client_auth_callback_awaiting_);
+
+ std::cerr << "client: calling SSL_ClientCertCallbackComplete with status "
+ << (client_auth_callback_success_ ? "success" : "failed")
+ << std::endl;
+
+ client_auth_callback_awaiting_ = false;
+
+ if (client_auth_callback_type_ == ClientAuthCallbackType::kAsyncDelay) {
+ std::cerr
+ << "Running Handshake prior to running SSL_ClientCertCallbackComplete"
+ << std::endl;
+ SECStatus rv = SSL_ForceHandshake(ssl_fd());
+ EXPECT_EQ(rv, SECFailure);
+ EXPECT_EQ(PORT_GetError(), PR_WOULD_BLOCK_ERROR);
+ }
+
+ ScopedCERTCertificate cert;
+ ScopedSECKEYPrivateKey priv;
+ if (client_auth_callback_success_) {
+ ASSERT_TRUE(TlsAgent::LoadCertificate(name(), &cert, &priv));
+ EXPECT_EQ(SECSuccess,
+ SSL_ClientCertCallbackComplete(ssl_fd(), SECSuccess,
+ priv.release(), cert.release()));
+ } else {
+ EXPECT_EQ(SECSuccess, SSL_ClientCertCallbackComplete(ssl_fd(), SECFailure,
+ nullptr, nullptr));
+ }
+}
+
+SECStatus TlsAgent::GetClientAuthDataHook(void* self, PRFileDesc* fd,
+ CERTDistNames* caNames,
+ CERTCertificate** clientCert,
+ SECKEYPrivateKey** clientKey) {
+ TlsAgent* agent = reinterpret_cast<TlsAgent*>(self);
+ EXPECT_EQ(CLIENT, agent->role_);
+ agent->client_auth_callback_fired_++;
+
+ switch (agent->client_auth_callback_type_) {
+ case ClientAuthCallbackType::kAsyncDelay:
+ case ClientAuthCallbackType::kAsyncImmediate:
+ std::cerr << "Waiting for complete call" << std::endl;
+ agent->client_auth_callback_awaiting_ = true;
+ return SECWouldBlock;
+ case ClientAuthCallbackType::kSync:
+ case ClientAuthCallbackType::kNone:
+ // Handle the sync case. None && Success is treated as Sync and Success.
+ if (!agent->client_auth_callback_success_) {
+ return SECFailure;
+ }
+ ScopedCERTCertificate peerCert(SSL_PeerCertificate(agent->ssl_fd()));
+ EXPECT_TRUE(peerCert) << "Client should be able to see the server cert";
+
+ // See bug 1573945
+ // CheckCertReqAgainstDefaultCAs(caNames);
+
+ ScopedCERTCertificate cert;
+ ScopedSECKEYPrivateKey priv;
+ if (!TlsAgent::LoadCertificate(agent->name(), &cert, &priv)) {
+ return SECFailure;
+ }
+
+ *clientCert = cert.release();
+ *clientKey = priv.release();
+ return SECSuccess;
+ }
+ /* This is unreachable, but some old compilers can't tell that. */
+ PORT_Assert(0);
+ PORT_SetError(SEC_ERROR_LIBRARY_FAILURE);
+ return SECFailure;
+}
+
+// Increments by 1 for each callback
+bool TlsAgent::CheckClientAuthCallbacksCompleted(uint8_t expected) {
+ EXPECT_EQ(CLIENT, role_);
+ return expected == client_auth_callback_fired_;
+}
+
+bool TlsAgent::GetPeerChainLength(size_t* count) {
+ CERTCertList* chain = SSL_PeerCertificateChain(ssl_fd());
+ if (!chain) return false;
+ *count = 0;
+
+ for (PRCList* cursor = PR_NEXT_LINK(&chain->list); cursor != &chain->list;
+ cursor = PR_NEXT_LINK(cursor)) {
+ CERTCertListNode* node = (CERTCertListNode*)cursor;
+ std::cerr << node->cert->subjectName << std::endl;
+ ++(*count);
+ }
+
+ CERT_DestroyCertList(chain);
+
+ return true;
+}
+
+void TlsAgent::CheckCipherSuite(uint16_t suite) {
+ EXPECT_EQ(csinfo_.cipherSuite, suite);
+}
+
+void TlsAgent::RequestClientAuth(bool requireAuth) {
+ ASSERT_EQ(SERVER, role_);
+
+ SetOption(SSL_REQUEST_CERTIFICATE, PR_TRUE);
+ SetOption(SSL_REQUIRE_CERTIFICATE, requireAuth ? PR_TRUE : PR_FALSE);
+
+ EXPECT_EQ(SECSuccess, SSL_AuthCertificateHook(
+ ssl_fd(), &TlsAgent::ClientAuthenticated, this));
+ expect_client_auth_ = true;
+}
+
+void TlsAgent::StartConnect(PRFileDesc* model) {
+ EXPECT_TRUE(EnsureTlsSetup(model));
+
+ SECStatus rv;
+ rv = SSL_ResetHandshake(ssl_fd(), role_ == SERVER ? PR_TRUE : PR_FALSE);
+ EXPECT_EQ(SECSuccess, rv);
+ SetState(STATE_CONNECTING);
+}
+
+void TlsAgent::DisableAllCiphers() {
+ for (size_t i = 0; i < SSL_NumImplementedCiphers; ++i) {
+ SECStatus rv =
+ SSL_CipherPrefSet(ssl_fd(), SSL_ImplementedCiphers[i], PR_FALSE);
+ EXPECT_EQ(SECSuccess, rv);
+ }
+}
+
+// Not actually all groups, just the onece that we are actually willing
+// to use.
+const std::vector<SSLNamedGroup> kAllDHEGroups = {
+ ssl_grp_ec_curve25519, ssl_grp_ec_secp256r1, ssl_grp_ec_secp384r1,
+ ssl_grp_ec_secp521r1, ssl_grp_ffdhe_2048, ssl_grp_ffdhe_3072,
+ ssl_grp_ffdhe_4096, ssl_grp_ffdhe_6144, ssl_grp_ffdhe_8192};
+
+const std::vector<SSLNamedGroup> kECDHEGroups = {
+ ssl_grp_ec_curve25519, ssl_grp_ec_secp256r1, ssl_grp_ec_secp384r1,
+ ssl_grp_ec_secp521r1};
+
+const std::vector<SSLNamedGroup> kFFDHEGroups = {
+ ssl_grp_ffdhe_2048, ssl_grp_ffdhe_3072, ssl_grp_ffdhe_4096,
+ ssl_grp_ffdhe_6144, ssl_grp_ffdhe_8192};
+
+// Defined because the big DHE groups are ridiculously slow.
+const std::vector<SSLNamedGroup> kFasterDHEGroups = {
+ ssl_grp_ec_curve25519, ssl_grp_ec_secp256r1, ssl_grp_ec_secp384r1,
+ ssl_grp_ffdhe_2048, ssl_grp_ffdhe_3072};
+
+void TlsAgent::EnableCiphersByKeyExchange(SSLKEAType kea) {
+ EXPECT_TRUE(EnsureTlsSetup());
+
+ for (size_t i = 0; i < SSL_NumImplementedCiphers; ++i) {
+ SSLCipherSuiteInfo csinfo;
+
+ SECStatus rv = SSL_GetCipherSuiteInfo(SSL_ImplementedCiphers[i], &csinfo,
+ sizeof(csinfo));
+ ASSERT_EQ(SECSuccess, rv);
+ EXPECT_EQ(sizeof(csinfo), csinfo.length);
+
+ if ((csinfo.keaType == kea) || (csinfo.keaType == ssl_kea_tls13_any)) {
+ rv = SSL_CipherPrefSet(ssl_fd(), SSL_ImplementedCiphers[i], PR_TRUE);
+ EXPECT_EQ(SECSuccess, rv);
+ }
+ }
+}
+
+void TlsAgent::EnableGroupsByKeyExchange(SSLKEAType kea) {
+ switch (kea) {
+ case ssl_kea_dh:
+ ConfigNamedGroups(kFFDHEGroups);
+ break;
+ case ssl_kea_ecdh:
+ ConfigNamedGroups(kECDHEGroups);
+ break;
+ default:
+ break;
+ }
+}
+
+void TlsAgent::EnableGroupsByAuthType(SSLAuthType authType) {
+ if (authType == ssl_auth_ecdh_rsa || authType == ssl_auth_ecdh_ecdsa ||
+ authType == ssl_auth_ecdsa || authType == ssl_auth_tls13_any) {
+ ConfigNamedGroups(kECDHEGroups);
+ }
+}
+
+void TlsAgent::EnableCiphersByAuthType(SSLAuthType authType) {
+ EXPECT_TRUE(EnsureTlsSetup());
+
+ for (size_t i = 0; i < SSL_NumImplementedCiphers; ++i) {
+ SSLCipherSuiteInfo csinfo;
+
+ SECStatus rv = SSL_GetCipherSuiteInfo(SSL_ImplementedCiphers[i], &csinfo,
+ sizeof(csinfo));
+ ASSERT_EQ(SECSuccess, rv);
+
+ if ((csinfo.authType == authType) ||
+ (csinfo.keaType == ssl_kea_tls13_any)) {
+ rv = SSL_CipherPrefSet(ssl_fd(), SSL_ImplementedCiphers[i], PR_TRUE);
+ EXPECT_EQ(SECSuccess, rv);
+ }
+ }
+}
+
+void TlsAgent::EnableSingleCipher(uint16_t cipher) {
+ DisableAllCiphers();
+ SECStatus rv = SSL_CipherPrefSet(ssl_fd(), cipher, PR_TRUE);
+ EXPECT_EQ(SECSuccess, rv);
+}
+
+void TlsAgent::ConfigNamedGroups(const std::vector<SSLNamedGroup>& groups) {
+ EXPECT_TRUE(EnsureTlsSetup());
+ SECStatus rv = SSL_NamedGroupConfig(ssl_fd(), &groups[0], groups.size());
+ EXPECT_EQ(SECSuccess, rv);
+}
+
+void TlsAgent::Set0RttEnabled(bool en) {
+ SetOption(SSL_ENABLE_0RTT_DATA, en ? PR_TRUE : PR_FALSE);
+}
+
+void TlsAgent::SetVersionRange(uint16_t minver, uint16_t maxver) {
+ vrange_.min = minver;
+ vrange_.max = maxver;
+
+ if (ssl_fd()) {
+ SECStatus rv = SSL_VersionRangeSet(ssl_fd(), &vrange_);
+ EXPECT_EQ(SECSuccess, rv);
+ }
+}
+
+SECStatus ResumptionTokenCallback(PRFileDesc* fd,
+ const PRUint8* resumptionToken,
+ unsigned int len, void* ctx) {
+ EXPECT_NE(nullptr, resumptionToken);
+ if (!resumptionToken) {
+ return SECFailure;
+ }
+
+ std::vector<uint8_t> new_token(resumptionToken, resumptionToken + len);
+ reinterpret_cast<TlsAgent*>(ctx)->SetResumptionToken(new_token);
+ reinterpret_cast<TlsAgent*>(ctx)->SetResumptionCallbackCalled();
+ return SECSuccess;
+}
+
+void TlsAgent::SetResumptionTokenCallback() {
+ EXPECT_TRUE(EnsureTlsSetup());
+ SECStatus rv =
+ SSL_SetResumptionTokenCallback(ssl_fd(), ResumptionTokenCallback, this);
+ EXPECT_EQ(SECSuccess, rv);
+}
+
+void TlsAgent::GetVersionRange(uint16_t* minver, uint16_t* maxver) {
+ *minver = vrange_.min;
+ *maxver = vrange_.max;
+}
+
+void TlsAgent::SetExpectedVersion(uint16_t ver) { expected_version_ = ver; }
+
+void TlsAgent::SetServerKeyBits(uint16_t bits) { server_key_bits_ = bits; }
+
+void TlsAgent::ExpectReadWriteError() { expect_readwrite_error_ = true; }
+
+void TlsAgent::SkipVersionChecks() { skip_version_checks_ = true; }
+
+void TlsAgent::SetSignatureSchemes(const SSLSignatureScheme* schemes,
+ size_t count) {
+ EXPECT_TRUE(EnsureTlsSetup());
+ EXPECT_LE(count, SSL_SignatureMaxCount());
+ EXPECT_EQ(SECSuccess,
+ SSL_SignatureSchemePrefSet(ssl_fd(), schemes,
+ static_cast<unsigned int>(count)));
+ EXPECT_EQ(SECFailure, SSL_SignatureSchemePrefSet(ssl_fd(), schemes, 0))
+ << "setting no schemes should fail and do nothing";
+
+ std::vector<SSLSignatureScheme> configuredSchemes(count);
+ unsigned int configuredCount;
+ EXPECT_EQ(SECFailure,
+ SSL_SignatureSchemePrefGet(ssl_fd(), nullptr, &configuredCount, 1))
+ << "get schemes, schemes is nullptr";
+ EXPECT_EQ(SECFailure,
+ SSL_SignatureSchemePrefGet(ssl_fd(), &configuredSchemes[0],
+ &configuredCount, 0))
+ << "get schemes, too little space";
+ EXPECT_EQ(SECFailure,
+ SSL_SignatureSchemePrefGet(ssl_fd(), &configuredSchemes[0], nullptr,
+ configuredSchemes.size()))
+ << "get schemes, countOut is nullptr";
+
+ EXPECT_EQ(SECSuccess, SSL_SignatureSchemePrefGet(
+ ssl_fd(), &configuredSchemes[0], &configuredCount,
+ configuredSchemes.size()));
+ // SignatureSchemePrefSet drops unsupported algorithms silently, so the
+ // number that are configured might be fewer.
+ EXPECT_LE(configuredCount, count);
+ unsigned int i = 0;
+ for (unsigned int j = 0; j < count && i < configuredCount; ++j) {
+ if (i < configuredCount && schemes[j] == configuredSchemes[i]) {
+ ++i;
+ }
+ }
+ EXPECT_EQ(i, configuredCount) << "schemes in use were all set";
+}
+
+void TlsAgent::CheckKEA(SSLKEAType kea, SSLNamedGroup kea_group,
+ size_t kea_size) const {
+ EXPECT_EQ(STATE_CONNECTED, state_);
+ EXPECT_EQ(kea, info_.keaType);
+ if (kea_size == 0) {
+ switch (kea_group) {
+ case ssl_grp_ec_curve25519:
+ kea_size = 255;
+ break;
+ case ssl_grp_ec_secp256r1:
+ kea_size = 256;
+ break;
+ case ssl_grp_ec_secp384r1:
+ kea_size = 384;
+ break;
+ case ssl_grp_ffdhe_2048:
+ kea_size = 2048;
+ break;
+ case ssl_grp_ffdhe_3072:
+ kea_size = 3072;
+ break;
+ case ssl_grp_ffdhe_custom:
+ break;
+ default:
+ if (kea == ssl_kea_rsa) {
+ kea_size = server_key_bits_;
+ } else {
+ EXPECT_TRUE(false) << "need to update group sizes";
+ }
+ }
+ }
+ if (kea_group != ssl_grp_ffdhe_custom) {
+ EXPECT_EQ(kea_size, info_.keaKeyBits);
+ EXPECT_EQ(kea_group, info_.keaGroup);
+ }
+}
+
+void TlsAgent::CheckOriginalKEA(SSLNamedGroup kea_group) const {
+ if (kea_group != ssl_grp_ffdhe_custom) {
+ EXPECT_EQ(kea_group, info_.originalKeaGroup);
+ }
+}
+
+void TlsAgent::CheckAuthType(SSLAuthType auth,
+ SSLSignatureScheme sig_scheme) const {
+ EXPECT_EQ(STATE_CONNECTED, state_);
+ EXPECT_EQ(auth, info_.authType);
+ if (auth != ssl_auth_psk) {
+ EXPECT_EQ(server_key_bits_, info_.authKeyBits);
+ }
+ if (expected_version_ < SSL_LIBRARY_VERSION_TLS_1_2) {
+ switch (auth) {
+ case ssl_auth_rsa_sign:
+ sig_scheme = ssl_sig_rsa_pkcs1_sha1md5;
+ break;
+ case ssl_auth_ecdsa:
+ sig_scheme = ssl_sig_ecdsa_sha1;
+ break;
+ default:
+ break;
+ }
+ }
+ EXPECT_EQ(sig_scheme, info_.signatureScheme);
+
+ if (info_.protocolVersion >= SSL_LIBRARY_VERSION_TLS_1_3) {
+ return;
+ }
+
+ // Check authAlgorithm, which is the old value for authType. This is a second
+ // switch statement because default label is different.
+ switch (auth) {
+ case ssl_auth_rsa_sign:
+ case ssl_auth_rsa_pss:
+ EXPECT_EQ(ssl_auth_rsa_decrypt, csinfo_.authAlgorithm)
+ << "authAlgorithm for RSA is always decrypt";
+ break;
+ case ssl_auth_ecdh_rsa:
+ EXPECT_EQ(ssl_auth_rsa_decrypt, csinfo_.authAlgorithm)
+ << "authAlgorithm for ECDH_RSA is RSA decrypt (i.e., wrong)";
+ break;
+ case ssl_auth_ecdh_ecdsa:
+ EXPECT_EQ(ssl_auth_ecdsa, csinfo_.authAlgorithm)
+ << "authAlgorithm for ECDH_ECDSA is ECDSA (i.e., wrong)";
+ break;
+ default:
+ EXPECT_EQ(auth, csinfo_.authAlgorithm)
+ << "authAlgorithm is (usually) the same as authType";
+ break;
+ }
+}
+
+void TlsAgent::EnableFalseStart() {
+ EXPECT_TRUE(EnsureTlsSetup());
+
+ falsestart_enabled_ = true;
+ EXPECT_EQ(SECSuccess, SSL_SetCanFalseStartCallback(
+ ssl_fd(), CanFalseStartCallback, this));
+ SetOption(SSL_ENABLE_FALSE_START, PR_TRUE);
+}
+
+void TlsAgent::ExpectEch(bool expected) { expect_ech_ = expected; }
+
+void TlsAgent::ExpectPsk(SSLPskType psk) { expect_psk_ = psk; }
+
+void TlsAgent::ExpectResumption() { expect_psk_ = ssl_psk_resume; }
+
+void TlsAgent::EnableAlpn(const uint8_t* val, size_t len) {
+ EXPECT_TRUE(EnsureTlsSetup());
+ EXPECT_EQ(SECSuccess, SSL_SetNextProtoNego(ssl_fd(), val, len));
+}
+
+void TlsAgent::AddPsk(const ScopedPK11SymKey& psk, std::string label,
+ SSLHashType hash, uint16_t zeroRttSuite) {
+ EXPECT_TRUE(EnsureTlsSetup());
+ EXPECT_EQ(SECSuccess, SSL_AddExternalPsk0Rtt(
+ ssl_fd(), psk.get(),
+ reinterpret_cast<const uint8_t*>(label.data()),
+ label.length(), hash, zeroRttSuite, 1000));
+}
+
+void TlsAgent::RemovePsk(std::string label) {
+ EXPECT_EQ(SECSuccess,
+ SSL_RemoveExternalPsk(
+ ssl_fd(), reinterpret_cast<const uint8_t*>(label.data()),
+ label.length()));
+}
+
+void TlsAgent::CheckAlpn(SSLNextProtoState expected_state,
+ const std::string& expected) const {
+ SSLNextProtoState alpn_state;
+ char chosen[10];
+ unsigned int chosen_len;
+ SECStatus rv = SSL_GetNextProto(ssl_fd(), &alpn_state,
+ reinterpret_cast<unsigned char*>(chosen),
+ &chosen_len, sizeof(chosen));
+ EXPECT_EQ(SECSuccess, rv);
+ EXPECT_EQ(expected_state, alpn_state);
+ if (alpn_state == SSL_NEXT_PROTO_NO_SUPPORT) {
+ EXPECT_EQ("", expected);
+ } else {
+ EXPECT_NE("", expected);
+ EXPECT_EQ(expected, std::string(chosen, chosen_len));
+ }
+}
+
+void TlsAgent::CheckEpochs(uint16_t expected_read,
+ uint16_t expected_write) const {
+ uint16_t read_epoch = 0;
+ uint16_t write_epoch = 0;
+ EXPECT_EQ(SECSuccess,
+ SSL_GetCurrentEpoch(ssl_fd(), &read_epoch, &write_epoch));
+ EXPECT_EQ(expected_read, read_epoch) << role_str() << " read epoch";
+ EXPECT_EQ(expected_write, write_epoch) << role_str() << " write epoch";
+}
+
+void TlsAgent::EnableSrtp() {
+ EXPECT_TRUE(EnsureTlsSetup());
+ const uint16_t ciphers[] = {SRTP_AES128_CM_HMAC_SHA1_80,
+ SRTP_AES128_CM_HMAC_SHA1_32};
+ EXPECT_EQ(SECSuccess,
+ SSL_SetSRTPCiphers(ssl_fd(), ciphers, PR_ARRAY_SIZE(ciphers)));
+}
+
+void TlsAgent::CheckSrtp() const {
+ uint16_t actual;
+ EXPECT_EQ(SECSuccess, SSL_GetSRTPCipher(ssl_fd(), &actual));
+ EXPECT_EQ(SRTP_AES128_CM_HMAC_SHA1_80, actual);
+}
+
+void TlsAgent::CheckErrorCode(int32_t expected) const {
+ EXPECT_EQ(STATE_ERROR, state_);
+ EXPECT_EQ(expected, error_code_)
+ << "Got error code " << PORT_ErrorToName(error_code_) << " expecting "
+ << PORT_ErrorToName(expected) << std::endl;
+}
+
+static uint8_t GetExpectedAlertLevel(uint8_t alert) {
+ if (alert == kTlsAlertCloseNotify) {
+ return kTlsAlertWarning;
+ }
+ return kTlsAlertFatal;
+}
+
+void TlsAgent::ExpectReceiveAlert(uint8_t alert, uint8_t level) {
+ expected_received_alert_ = alert;
+ if (level == 0) {
+ expected_received_alert_level_ = GetExpectedAlertLevel(alert);
+ } else {
+ expected_received_alert_level_ = level;
+ }
+}
+
+void TlsAgent::ExpectSendAlert(uint8_t alert, uint8_t level) {
+ expected_sent_alert_ = alert;
+ if (level == 0) {
+ expected_sent_alert_level_ = GetExpectedAlertLevel(alert);
+ } else {
+ expected_sent_alert_level_ = level;
+ }
+}
+
+void TlsAgent::CheckAlert(bool sent, const SSLAlert* alert) {
+ LOG(((alert->level == kTlsAlertWarning) ? "Warning" : "Fatal")
+ << " alert " << (sent ? "sent" : "received") << ": "
+ << static_cast<int>(alert->description));
+
+ auto& expected = sent ? expected_sent_alert_ : expected_received_alert_;
+ auto& expected_level =
+ sent ? expected_sent_alert_level_ : expected_received_alert_level_;
+ /* Silently pass close_notify in case the test has already ended. */
+ if (expected == kTlsAlertCloseNotify && expected_level == kTlsAlertWarning &&
+ alert->description == expected && alert->level == expected_level) {
+ return;
+ }
+
+ EXPECT_EQ(expected, alert->description);
+ EXPECT_EQ(expected_level, alert->level);
+ expected = kTlsAlertCloseNotify;
+ expected_level = kTlsAlertWarning;
+}
+
+void TlsAgent::WaitForErrorCode(int32_t expected, uint32_t delay) const {
+ ASSERT_EQ(0, error_code_);
+ WAIT_(error_code_ != 0, delay);
+ EXPECT_EQ(expected, error_code_)
+ << "Got error code " << PORT_ErrorToName(error_code_) << " expecting "
+ << PORT_ErrorToName(expected) << std::endl;
+}
+
+void TlsAgent::CheckPreliminaryInfo() {
+ SSLPreliminaryChannelInfo preinfo;
+ EXPECT_EQ(SECSuccess,
+ SSL_GetPreliminaryChannelInfo(ssl_fd(), &preinfo, sizeof(preinfo)));
+ EXPECT_EQ(sizeof(preinfo), preinfo.length);
+ EXPECT_TRUE(preinfo.valuesSet & ssl_preinfo_version);
+
+ // A version of 0 is invalid and indicates no expectation. This value is
+ // initialized to 0 so that tests that don't explicitly set an expected
+ // version can negotiate a version.
+ if (!expected_version_) {
+ expected_version_ = preinfo.protocolVersion;
+ }
+ EXPECT_EQ(expected_version_, preinfo.protocolVersion);
+
+ // As with the version; 0 is the null cipher suite (and also invalid).
+ if (!expected_cipher_suite_) {
+ expected_cipher_suite_ = preinfo.cipherSuite;
+ }
+ EXPECT_EQ(expected_cipher_suite_, preinfo.cipherSuite);
+}
+
+// Check that all the expected callbacks have been called.
+void TlsAgent::CheckCallbacks() const {
+ // If false start happens, the handshake is reported as being complete at the
+ // point that false start happens.
+ if (expect_psk_ == ssl_psk_resume || !falsestart_enabled_) {
+ EXPECT_TRUE(handshake_callback_called_);
+ }
+
+ // These callbacks shouldn't fire if we are resuming, except on TLS 1.3.
+ if (role_ == SERVER) {
+ PRBool have_sni = SSLInt_ExtensionNegotiated(ssl_fd(), ssl_server_name_xtn);
+ EXPECT_EQ(((expect_psk_ != ssl_psk_resume && have_sni) ||
+ expected_version_ >= SSL_LIBRARY_VERSION_TLS_1_3),
+ sni_hook_called_);
+ } else {
+ EXPECT_EQ(expect_psk_ == ssl_psk_none, auth_certificate_hook_called_);
+ // Note that this isn't unconditionally called, even with false start on.
+ // But the callback is only skipped if a cipher that is ridiculously weak
+ // (80 bits) is chosen. Don't test that: plan to remove bad ciphers.
+ EXPECT_EQ(falsestart_enabled_ && expect_psk_ != ssl_psk_resume,
+ can_falsestart_hook_called_);
+ }
+}
+
+void TlsAgent::ResetPreliminaryInfo() {
+ expected_version_ = 0;
+ expected_cipher_suite_ = 0;
+}
+
+void TlsAgent::UpdatePreliminaryChannelInfo() {
+ SECStatus rv =
+ SSL_GetPreliminaryChannelInfo(ssl_fd(), &pre_info_, sizeof(pre_info_));
+ EXPECT_EQ(SECSuccess, rv);
+ EXPECT_EQ(sizeof(pre_info_), pre_info_.length);
+}
+
+void TlsAgent::ValidateCipherSpecs() {
+ PRInt32 cipherSpecs = SSLInt_CountCipherSpecs(ssl_fd());
+ // We use one ciphersuite in each direction.
+ PRInt32 expected = 2;
+ if (variant_ == ssl_variant_datagram) {
+ // For DTLS 1.3, the client retains the cipher spec for early data and the
+ // handshake so that it can retransmit EndOfEarlyData and its final flight.
+ // It also retains the handshake read cipher spec so that it can read ACKs
+ // from the server. The server retains the handshake read cipher spec so it
+ // can read the client's retransmitted Finished.
+ if (expected_version_ >= SSL_LIBRARY_VERSION_TLS_1_3) {
+ if (role_ == CLIENT) {
+ expected = info_.earlyDataAccepted ? 5 : 4;
+ } else {
+ expected = 3;
+ }
+ } else {
+ // For DTLS 1.1 and 1.2, the last endpoint to send maintains a cipher spec
+ // until the holddown timer runs down.
+ if (expect_psk_ == ssl_psk_resume) {
+ if (role_ == CLIENT) {
+ expected = 3;
+ }
+ } else {
+ if (role_ == SERVER) {
+ expected = 3;
+ }
+ }
+ }
+ }
+ // This function will be run before the handshake completes if false start is
+ // enabled. In that case, the client will still be reading cleartext, but
+ // will have a spec prepared for reading ciphertext. With DTLS, the client
+ // will also have a spec retained for retransmission of handshake messages.
+ if (role_ == CLIENT && falsestart_enabled_ && !handshake_callback_called_) {
+ EXPECT_GT(SSL_LIBRARY_VERSION_TLS_1_3, expected_version_);
+ expected = (variant_ == ssl_variant_datagram) ? 4 : 3;
+ }
+ EXPECT_EQ(expected, cipherSpecs);
+ if (expected != cipherSpecs) {
+ SSLInt_PrintCipherSpecs(role_str().c_str(), ssl_fd());
+ }
+}
+
+void TlsAgent::Connected() {
+ if (state_ == STATE_CONNECTED) {
+ return;
+ }
+
+ LOG("Handshake success");
+ CheckPreliminaryInfo();
+ CheckCallbacks();
+
+ SECStatus rv = SSL_GetChannelInfo(ssl_fd(), &info_, sizeof(info_));
+ EXPECT_EQ(SECSuccess, rv);
+ EXPECT_EQ(sizeof(info_), info_.length);
+
+ EXPECT_EQ(expect_psk_ == ssl_psk_resume, info_.resumed == PR_TRUE);
+ EXPECT_EQ(expect_psk_, info_.pskType);
+ EXPECT_EQ(expect_ech_, info_.echAccepted);
+
+ // Preliminary values are exposed through callbacks during the handshake.
+ // If either expected values were set or the callbacks were called, check
+ // that the final values are correct.
+ UpdatePreliminaryChannelInfo();
+ EXPECT_EQ(expected_version_, info_.protocolVersion);
+ EXPECT_EQ(expected_cipher_suite_, info_.cipherSuite);
+
+ rv = SSL_GetCipherSuiteInfo(info_.cipherSuite, &csinfo_, sizeof(csinfo_));
+ EXPECT_EQ(SECSuccess, rv);
+ EXPECT_EQ(sizeof(csinfo_), csinfo_.length);
+
+ ValidateCipherSpecs();
+
+ SetState(STATE_CONNECTED);
+}
+
+void TlsAgent::CheckClientAuthCompleted(uint8_t handshakes) {
+ EXPECT_FALSE(client_auth_callback_awaiting_);
+ switch (client_auth_callback_type_) {
+ case ClientAuthCallbackType::kNone:
+ if (!client_auth_callback_success_) {
+ EXPECT_TRUE(CheckClientAuthCallbacksCompleted(0));
+ break;
+ }
+ case ClientAuthCallbackType::kSync:
+ EXPECT_TRUE(CheckClientAuthCallbacksCompleted(handshakes));
+ break;
+ case ClientAuthCallbackType::kAsyncDelay:
+ case ClientAuthCallbackType::kAsyncImmediate:
+ EXPECT_TRUE(CheckClientAuthCallbacksCompleted(2 * handshakes));
+ break;
+ }
+}
+
+void TlsAgent::EnableExtendedMasterSecret() {
+ SetOption(SSL_ENABLE_EXTENDED_MASTER_SECRET, PR_TRUE);
+}
+
+void TlsAgent::CheckExtendedMasterSecret(bool expected) {
+ if (version() >= SSL_LIBRARY_VERSION_TLS_1_3) {
+ expected = PR_TRUE;
+ }
+ ASSERT_EQ(expected, info_.extendedMasterSecretUsed != PR_FALSE)
+ << "unexpected extended master secret state for " << name_;
+}
+
+void TlsAgent::CheckEarlyDataAccepted(bool expected) {
+ if (version() < SSL_LIBRARY_VERSION_TLS_1_3) {
+ expected = false;
+ }
+ ASSERT_EQ(expected, info_.earlyDataAccepted != PR_FALSE)
+ << "unexpected early data state for " << name_;
+}
+
+void TlsAgent::CheckSecretsDestroyed() {
+ ASSERT_EQ(PR_TRUE, SSLInt_CheckSecretsDestroyed(ssl_fd()));
+}
+
+void TlsAgent::SetDowngradeCheckVersion(uint16_t ver) {
+ ASSERT_TRUE(EnsureTlsSetup());
+
+ SECStatus rv = SSL_SetDowngradeCheckVersion(ssl_fd(), ver);
+ ASSERT_EQ(SECSuccess, rv);
+}
+
+void TlsAgent::Handshake() {
+ LOGV("Handshake");
+ SECStatus rv = SSL_ForceHandshake(ssl_fd());
+ if (client_auth_callback_awaiting_) {
+ ClientAuthCallbackComplete();
+ rv = SSL_ForceHandshake(ssl_fd());
+ }
+ if (rv == SECSuccess) {
+ Connected();
+ Poller::Instance()->Wait(READABLE_EVENT, adapter_, this,
+ &TlsAgent::ReadableCallback);
+ return;
+ }
+
+ int32_t err = PR_GetError();
+ if (err == PR_WOULD_BLOCK_ERROR) {
+ LOGV("Would have blocked");
+ if (variant_ == ssl_variant_datagram) {
+ if (timer_handle_) {
+ timer_handle_->Cancel();
+ timer_handle_ = nullptr;
+ }
+
+ PRIntervalTime timeout;
+ rv = DTLS_GetHandshakeTimeout(ssl_fd(), &timeout);
+ if (rv == SECSuccess) {
+ Poller::Instance()->SetTimer(
+ timeout + 1, this, &TlsAgent::ReadableCallback, &timer_handle_);
+ }
+ }
+ Poller::Instance()->Wait(READABLE_EVENT, adapter_, this,
+ &TlsAgent::ReadableCallback);
+ return;
+ }
+
+ LOG("Handshake failed with error " << PORT_ErrorToName(err) << ": "
+ << PORT_ErrorToString(err));
+ error_code_ = err;
+ SetState(STATE_ERROR);
+}
+
+void TlsAgent::PrepareForRenegotiate() {
+ EXPECT_EQ(STATE_CONNECTED, state_);
+
+ SetState(STATE_CONNECTING);
+}
+
+void TlsAgent::StartRenegotiate() {
+ PrepareForRenegotiate();
+
+ SECStatus rv = SSL_ReHandshake(ssl_fd(), PR_TRUE);
+ EXPECT_EQ(SECSuccess, rv);
+}
+
+void TlsAgent::SendDirect(const DataBuffer& buf) {
+ LOG("Send Direct " << buf);
+ auto peer = adapter_->peer().lock();
+ if (peer) {
+ peer->PacketReceived(buf);
+ } else {
+ LOG("Send Direct peer absent");
+ }
+}
+
+void TlsAgent::SendRecordDirect(const TlsRecord& record) {
+ DataBuffer buf;
+
+ auto rv = record.header.Write(&buf, 0, record.buffer);
+ EXPECT_EQ(record.header.header_length() + record.buffer.len(), rv);
+ SendDirect(buf);
+}
+
+static bool ErrorIsFatal(PRErrorCode code) {
+ return code != PR_WOULD_BLOCK_ERROR && code != SSL_ERROR_RX_SHORT_DTLS_READ;
+}
+
+void TlsAgent::SendData(size_t bytes, size_t blocksize) {
+ uint8_t block[16385]; // One larger than the maximum record size.
+
+ ASSERT_LE(blocksize, sizeof(block));
+
+ while (bytes) {
+ size_t tosend = std::min(blocksize, bytes);
+
+ for (size_t i = 0; i < tosend; ++i) {
+ block[i] = 0xff & send_ctr_;
+ ++send_ctr_;
+ }
+
+ SendBuffer(DataBuffer(block, tosend));
+ bytes -= tosend;
+ }
+}
+
+void TlsAgent::SendBuffer(const DataBuffer& buf) {
+ LOGV("Writing " << buf.len() << " bytes");
+ int32_t rv = PR_Write(ssl_fd(), buf.data(), buf.len());
+ if (expect_readwrite_error_) {
+ EXPECT_GT(0, rv);
+ EXPECT_NE(PR_WOULD_BLOCK_ERROR, error_code_);
+ error_code_ = PR_GetError();
+ expect_readwrite_error_ = false;
+ } else {
+ ASSERT_EQ(buf.len(), static_cast<size_t>(rv));
+ }
+}
+
+bool TlsAgent::SendEncryptedRecord(const std::shared_ptr<TlsCipherSpec>& spec,
+ uint64_t seq, uint8_t ct,
+ const DataBuffer& buf) {
+ // Ensure that we are doing TLS 1.3.
+ EXPECT_GE(expected_version_, SSL_LIBRARY_VERSION_TLS_1_3);
+ if (variant_ != ssl_variant_datagram) {
+ ADD_FAILURE();
+ return false;
+ }
+
+ LOGV("Encrypting " << buf.len() << " bytes");
+ uint8_t dtls13_ct = kCtDtlsCiphertext | kCtDtlsCiphertext16bSeqno |
+ kCtDtlsCiphertextLengthPresent;
+ TlsRecordHeader header(variant_, expected_version_, dtls13_ct, seq);
+ TlsRecordHeader out_header(header);
+ DataBuffer padded = buf;
+ padded.Write(padded.len(), ct, 1);
+ DataBuffer ciphertext;
+ if (!spec->Protect(header, padded, &ciphertext, &out_header)) {
+ return false;
+ }
+
+ DataBuffer record;
+ auto rv = out_header.Write(&record, 0, ciphertext);
+ EXPECT_EQ(out_header.header_length() + ciphertext.len(), rv);
+ SendDirect(record);
+ return true;
+}
+
+void TlsAgent::ReadBytes(size_t amount) {
+ uint8_t block[16384];
+
+ size_t remaining = amount;
+ while (remaining > 0) {
+ int32_t rv = PR_Read(ssl_fd(), block, (std::min)(amount, sizeof(block)));
+ LOGV("ReadBytes " << rv);
+
+ if (rv > 0) {
+ size_t count = static_cast<size_t>(rv);
+ for (size_t i = 0; i < count; ++i) {
+ ASSERT_EQ(recv_ctr_ & 0xff, block[i]);
+ recv_ctr_++;
+ }
+ remaining -= rv;
+ } else {
+ PRErrorCode err = 0;
+ if (rv < 0) {
+ err = PR_GetError();
+ LOG("Read error " << PORT_ErrorToName(err) << ": "
+ << PORT_ErrorToString(err));
+ if (err != PR_WOULD_BLOCK_ERROR && expect_readwrite_error_) {
+ error_code_ = err;
+ expect_readwrite_error_ = false;
+ }
+ }
+ if (err != 0 && ErrorIsFatal(err)) {
+ // If we hit a fatal error, we're done.
+ remaining = 0;
+ }
+ break;
+ }
+ }
+
+ // If closed, then don't bother waiting around.
+ if (remaining) {
+ LOGV("Re-arming");
+ Poller::Instance()->Wait(READABLE_EVENT, adapter_, this,
+ &TlsAgent::ReadableCallback);
+ }
+}
+
+void TlsAgent::ResetSentBytes(size_t bytes) { send_ctr_ = bytes; }
+
+void TlsAgent::SetOption(int32_t option, int value) {
+ ASSERT_TRUE(EnsureTlsSetup());
+ EXPECT_EQ(SECSuccess, SSL_OptionSet(ssl_fd(), option, value));
+}
+
+void TlsAgent::ConfigureSessionCache(SessionResumptionMode mode) {
+ SetOption(SSL_NO_CACHE, mode & RESUME_SESSIONID ? PR_FALSE : PR_TRUE);
+ SetOption(SSL_ENABLE_SESSION_TICKETS,
+ mode & RESUME_TICKET ? PR_TRUE : PR_FALSE);
+}
+
+void TlsAgent::EnableECDHEServerKeyReuse() {
+ ASSERT_EQ(TlsAgent::SERVER, role_);
+ SetOption(SSL_REUSE_SERVER_ECDHE_KEY, PR_TRUE);
+}
+
+static const std::string kTlsRolesAllArr[] = {"CLIENT", "SERVER"};
+::testing::internal::ParamGenerator<std::string>
+ TlsAgentTestBase::kTlsRolesAll = ::testing::ValuesIn(kTlsRolesAllArr);
+
+void TlsAgentTestBase::SetUp() {
+ SSL_ConfigServerSessionIDCache(1024, 0, 0, g_working_dir_path.c_str());
+}
+
+void TlsAgentTestBase::TearDown() {
+ agent_ = nullptr;
+ SSL_ClearSessionCache();
+ SSL_ShutdownServerSessionIDCache();
+}
+
+void TlsAgentTestBase::Reset(const std::string& server_name) {
+ agent_.reset(
+ new TlsAgent(role_ == TlsAgent::CLIENT ? TlsAgent::kClient : server_name,
+ role_, variant_));
+ if (version_) {
+ agent_->SetVersionRange(version_, version_);
+ }
+ agent_->adapter()->SetPeer(sink_adapter_);
+ agent_->StartConnect();
+}
+
+void TlsAgentTestBase::EnsureInit() {
+ if (!agent_) {
+ Reset();
+ }
+ const std::vector<SSLNamedGroup> groups = {
+ ssl_grp_ec_curve25519, ssl_grp_ec_secp256r1, ssl_grp_ec_secp384r1,
+ ssl_grp_ffdhe_2048};
+ agent_->ConfigNamedGroups(groups);
+}
+
+void TlsAgentTestBase::ExpectAlert(uint8_t alert) {
+ EnsureInit();
+ agent_->ExpectSendAlert(alert);
+}
+
+void TlsAgentTestBase::ProcessMessage(const DataBuffer& buffer,
+ TlsAgent::State expected_state,
+ int32_t error_code) {
+ std::cerr << "Process message: " << buffer << std::endl;
+ EnsureInit();
+ agent_->adapter()->PacketReceived(buffer);
+ agent_->Handshake();
+
+ ASSERT_EQ(expected_state, agent_->state());
+
+ if (expected_state == TlsAgent::STATE_ERROR) {
+ ASSERT_EQ(error_code, agent_->error_code());
+ }
+}
+
+void TlsAgentTestBase::MakeRecord(SSLProtocolVariant variant, uint8_t type,
+ uint16_t version, const uint8_t* buf,
+ size_t len, DataBuffer* out,
+ uint64_t sequence_number) {
+ // Fixup the content type for DTLSCiphertext
+ if (variant == ssl_variant_datagram &&
+ version >= SSL_LIBRARY_VERSION_TLS_1_3 &&
+ type == ssl_ct_application_data) {
+ type = kCtDtlsCiphertext | kCtDtlsCiphertext16bSeqno |
+ kCtDtlsCiphertextLengthPresent;
+ }
+
+ size_t index = 0;
+ if (variant == ssl_variant_stream) {
+ index = out->Write(index, type, 1);
+ index = out->Write(index, version, 2);
+ } else if (version >= SSL_LIBRARY_VERSION_TLS_1_3 &&
+ (type & kCtDtlsCiphertextMask) == kCtDtlsCiphertext) {
+ uint32_t epoch = (sequence_number >> 48) & 0x3;
+ index = out->Write(index, type | epoch, 1);
+ uint32_t seqno = sequence_number & ((1ULL << 16) - 1);
+ index = out->Write(index, seqno, 2);
+ } else {
+ index = out->Write(index, type, 1);
+ index = out->Write(index, TlsVersionToDtlsVersion(version), 2);
+ index = out->Write(index, sequence_number >> 32, 4);
+ index = out->Write(index, sequence_number & PR_UINT32_MAX, 4);
+ }
+ index = out->Write(index, len, 2);
+ out->Write(index, buf, len);
+}
+
+void TlsAgentTestBase::MakeRecord(uint8_t type, uint16_t version,
+ const uint8_t* buf, size_t len,
+ DataBuffer* out, uint64_t seq_num) const {
+ MakeRecord(variant_, type, version, buf, len, out, seq_num);
+}
+
+void TlsAgentTestBase::MakeHandshakeMessage(uint8_t hs_type,
+ const uint8_t* data, size_t hs_len,
+ DataBuffer* out,
+ uint64_t seq_num) const {
+ return MakeHandshakeMessageFragment(hs_type, data, hs_len, out, seq_num, 0,
+ 0);
+}
+
+void TlsAgentTestBase::MakeHandshakeMessageFragment(
+ uint8_t hs_type, const uint8_t* data, size_t hs_len, DataBuffer* out,
+ uint64_t seq_num, uint32_t fragment_offset,
+ uint32_t fragment_length) const {
+ size_t index = 0;
+ if (!fragment_length) fragment_length = hs_len;
+ index = out->Write(index, hs_type, 1); // Handshake record type.
+ index = out->Write(index, hs_len, 3); // Handshake length
+ if (variant_ == ssl_variant_datagram) {
+ index = out->Write(index, seq_num, 2);
+ index = out->Write(index, fragment_offset, 3);
+ index = out->Write(index, fragment_length, 3);
+ }
+ if (data) {
+ index = out->Write(index, data, fragment_length);
+ } else {
+ for (size_t i = 0; i < fragment_length; ++i) {
+ index = out->Write(index, 1, 1);
+ }
+ }
+}
+
+void TlsAgentTestBase::MakeTrivialHandshakeRecord(uint8_t hs_type,
+ size_t hs_len,
+ DataBuffer* out) {
+ size_t index = 0;
+ index = out->Write(index, ssl_ct_handshake, 1); // Content Type
+ index = out->Write(index, 3, 1); // Version high
+ index = out->Write(index, 1, 1); // Version low
+ index = out->Write(index, 4 + hs_len, 2); // Length
+
+ index = out->Write(index, hs_type, 1); // Handshake record type.
+ index = out->Write(index, hs_len, 3); // Handshake length
+ for (size_t i = 0; i < hs_len; ++i) {
+ index = out->Write(index, 1, 1);
+ }
+}
+
+DataBuffer TlsAgentTestBase::MakeCannedTls13ServerHello() {
+ DataBuffer sh(kCannedTls13ServerHello, sizeof(kCannedTls13ServerHello));
+ if (variant_ == ssl_variant_datagram) {
+ sh.Write(0, SSL_LIBRARY_VERSION_DTLS_1_2_WIRE, 2);
+ // The version should be at the end.
+ uint32_t v;
+ EXPECT_TRUE(sh.Read(sh.len() - 2, 2, &v));
+ EXPECT_EQ(static_cast<uint32_t>(SSL_LIBRARY_VERSION_TLS_1_3), v);
+ sh.Write(sh.len() - 2, 0x7f00 | DTLS_1_3_DRAFT_VERSION, 2);
+ }
+ return sh;
+}
+
+} // namespace nss_test