/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
/* vim: set ts=2 et sw=2 tw=80: */
/* This Source Code Form is subject to the terms of the Mozilla Public
 * License, v. 2.0. If a copy of the MPL was not distributed with this file,
 * You can obtain one at http://mozilla.org/MPL/2.0/. */

#include "tls_agent.h"
#include "databuffer.h"
#include "keyhi.h"
#include "pk11func.h"
#include "ssl.h"
#include "sslerr.h"
#include "sslexp.h"
#include "sslproto.h"
#include "tls_filter.h"
#include "tls_parser.h"

extern "C" {
// This is not something that should make you happy.
#include "libssl_internals.h"
}

#define GTEST_HAS_RTTI 0
#include "gtest/gtest.h"
#include "gtest_utils.h"
#include "nss_scoped_ptrs.h"

extern std::string g_working_dir_path;

namespace nss_test {

// Printable names for the agent states, in the order listed here.
const char* TlsAgent::states[] = {"INIT", "CONNECTING", "CONNECTED", "ERROR"};

// Certificate nicknames used by the tests.
const std::string TlsAgent::kClient = "client";    // both sign and encrypt
const std::string TlsAgent::kRsa2048 = "rsa2048";  // bigger
const std::string TlsAgent::kRsa8192 = "rsa8192";  // biggest allowed
const std::string TlsAgent::kServerRsa = "rsa";    // both sign and encrypt
const std::string TlsAgent::kServerRsaSign = "rsa_sign";
const std::string TlsAgent::kServerRsaPss = "rsa_pss";
const std::string TlsAgent::kServerRsaDecrypt = "rsa_decrypt";
const std::string TlsAgent::kServerEcdsa256 = "ecdsa256";
const std::string TlsAgent::kServerEcdsa384 = "ecdsa384";
const std::string TlsAgent::kServerEcdsa521 = "ecdsa521";
const std::string TlsAgent::kServerEcdhRsa = "ecdh_rsa";
const std::string TlsAgent::kServerEcdhEcdsa = "ecdh_ecdsa";
const std::string TlsAgent::kServerDsa = "dsa";
const std::string TlsAgent::kDelegatorEcdsa256 = "delegator_ecdsa256";
const std::string TlsAgent::kDelegatorRsae2048 = "delegator_rsae2048";
const std::string TlsAgent::kDelegatorRsaPss2048 = "delegator_rsa_pss2048";

// NOTE(review): by its name, a pre-recorded TLS 1.3 ServerHello body; it is
// not referenced in this part of the file.
static const uint8_t kCannedTls13ServerHello[] = {
    0x03, 0x03, 0x9c, 0xbc, 0x14, 0x9b, 0x0e, 0x2e, 0xfa, 0x0d, 0xf3,
    0xf0, 0x5c, 0x70, 0x7a, 0xe0, 0xd1, 0x9b, 0x3e, 0x5a, 0x44, 0x6b,
    0xdf, 0xe5, 0xc2, 0x28, 0x64, 0xf7, 0x00, 0xc1, 0x9c, 0x08, 0x76,
    0x08, 0x00, 0x13, 0x01, 0x00, 0x00, 0x2e, 0x00, 0x33, 0x00, 0x24,
    0x00, 0x1d, 0x00, 0x20, 0xc2, 0xcf, 0x23, 0x17, 0x64, 0x23, 0x03,
    0xf0, 0xfb, 0x45, 0x98, 0x26, 0xd1, 0x65, 0x24, 0xa1, 0x6c, 0xa9,
    0x80, 0x8f, 0x2c, 0xac, 0x0a, 0xea, 0x53, 0x3a, 0xcb, 0xe3, 0x08,
    0x84, 0xae, 0x19, 0x00, 0x2b, 0x00, 0x02, 0x03, 0x04};

// Builds an agent named |nm| acting as |rl| over protocol variant |var|.
// No SSL socket is created here (ssl_fd_ starts out null); the expected
// alerts default to a warning-level close_notify, and the version range is
// seeded from the library default for |var|.
TlsAgent::TlsAgent(const std::string& nm, Role rl, SSLProtocolVariant var)
    : name_(nm),
      variant_(var),
      role_(rl),
      server_key_bits_(0),
      adapter_(new DummyPrSocket(role_str(), var)),
      ssl_fd_(nullptr),
      state_(STATE_INIT),
      timer_handle_(nullptr),
      falsestart_enabled_(false),
      expected_version_(0),
      expected_cipher_suite_(0),
      expect_client_auth_(false),
      expect_ech_(false),
      expect_psk_(ssl_psk_none),
      can_falsestart_hook_called_(false),
      sni_hook_called_(false),
      auth_certificate_hook_called_(false),
      expected_received_alert_(kTlsAlertCloseNotify),
      expected_received_alert_level_(kTlsAlertWarning),
      expected_sent_alert_(kTlsAlertCloseNotify),
      expected_sent_alert_level_(kTlsAlertWarning),
      handshake_callback_called_(false),
      resumption_callback_called_(false),
      error_code_(0),
      send_ctr_(0),
      recv_ctr_(0),
      expect_readwrite_error_(false),
      handshake_callback_(),
      auth_certificate_callback_(),
      sni_callback_(),
      skip_version_checks_(false),
      resumption_token_(),
      policy_() {
  memset(&info_, 0, sizeof(info_));
  memset(&csinfo_, 0, sizeof(csinfo_));
  SECStatus rv = SSL_VersionRangeGetDefault(variant_, &vrange_);
  EXPECT_EQ(SECSuccess, rv);
}

// Cancels any pending timer and poller registration, then reports a test
// failure if an alert expectation set on this agent was never consumed
// (i.e. it is not back at the close_notify/warning default).
TlsAgent::~TlsAgent() {
  if (timer_handle_) {
    timer_handle_->Cancel();
  }

  if (adapter_) {
    Poller::Instance()->Cancel(READABLE_EVENT, adapter_);
  }

  // Add failures manually, if any, so we don't throw in a destructor.
  if (expected_received_alert_ != kTlsAlertCloseNotify ||
      expected_received_alert_level_ != kTlsAlertWarning) {
    ADD_FAILURE() << "Wrong expected_received_alert status: " << role_str();
  }
  if (expected_sent_alert_ != kTlsAlertCloseNotify ||
      expected_sent_alert_level_ != kTlsAlertWarning) {
    ADD_FAILURE() << "Wrong expected_sent_alert status: " << role_str();
  }
}

// Moves the agent to state |s|, logging the transition. No-op when the
// state is unchanged.
void TlsAgent::SetState(State s) {
  if (state_ == s) return;

  LOG("Changing state from " << state_ << " to " << s);
  state_ = s;
}

// Looks up the certificate with nickname |name| and its private key, storing
// them in |*cert| and |*priv|. Returns false, after recording an EXPECT
// failure, if an output pointer is null or either lookup yields nothing.
/*static*/ bool TlsAgent::LoadCertificate(const std::string& name,
                                          ScopedCERTCertificate* cert,
                                          ScopedSECKEYPrivateKey* priv) {
  // Validate the output pointers before they are dereferenced below.
  EXPECT_NE(nullptr, cert);
  EXPECT_NE(nullptr, priv);
  if (!cert || !priv) return false;

  cert->reset(PK11_FindCertFromNickname(name.c_str(), nullptr));
  EXPECT_NE(nullptr, cert->get());
  if (!cert->get()) return false;

  priv->reset(PK11_FindKeyByAnyCert(cert->get(), nullptr));
  EXPECT_NE(nullptr, priv->get());
  if (!priv->get()) return false;

  return true;
}

// Loads a key pair from the certificate identified by |name|.
/*static*/ bool TlsAgent::LoadKeyPairFromCert(const std::string& name, ScopedSECKEYPublicKey* pub, ScopedSECKEYPrivateKey* priv) { ScopedCERTCertificate cert; if (!TlsAgent::LoadCertificate(name, &cert, priv)) { return false; } pub->reset(SECKEY_ExtractPublicKey(&cert->subjectPublicKeyInfo)); if (!pub->get()) { return false; } return true; } void TlsAgent::DelegateCredential(const std::string& name, const ScopedSECKEYPublicKey& dc_pub, SSLSignatureScheme dc_cert_verify_alg, PRUint32 dc_valid_for, PRTime now, SECItem* dc) { ScopedCERTCertificate cert; ScopedSECKEYPrivateKey cert_priv; EXPECT_TRUE(TlsAgent::LoadCertificate(name, &cert, &cert_priv)) << "Could not load delegate certificate: " << name << "; test db corrupt?"; EXPECT_EQ(SECSuccess, SSL_DelegateCredential(cert.get(), cert_priv.get(), dc_pub.get(), dc_cert_verify_alg, dc_valid_for, now, dc)); } void TlsAgent::EnableDelegatedCredentials() { ASSERT_TRUE(EnsureTlsSetup()); SetOption(SSL_ENABLE_DELEGATED_CREDENTIALS, PR_TRUE); } void TlsAgent::AddDelegatedCredential(const std::string& dc_name, SSLSignatureScheme dc_cert_verify_alg, PRUint32 dc_valid_for, PRTime now) { ASSERT_TRUE(EnsureTlsSetup()); ScopedSECKEYPublicKey pub; ScopedSECKEYPrivateKey priv; EXPECT_TRUE(TlsAgent::LoadKeyPairFromCert(dc_name, &pub, &priv)); StackSECItem dc; TlsAgent::DelegateCredential(name_, pub, dc_cert_verify_alg, dc_valid_for, now, &dc); SSLExtraServerCertData extra_data = {ssl_auth_null, nullptr, nullptr, nullptr, &dc, priv.get()}; EXPECT_TRUE(ConfigServerCert(name_, true, &extra_data)); } bool TlsAgent::ConfigServerCert(const std::string& id, bool updateKeyBits, const SSLExtraServerCertData* serverCertData) { ScopedCERTCertificate cert; ScopedSECKEYPrivateKey priv; if (!TlsAgent::LoadCertificate(id, &cert, &priv)) { return false; } if (updateKeyBits) { ScopedSECKEYPublicKey pub(CERT_ExtractPublicKey(cert.get())); EXPECT_NE(nullptr, pub.get()); if (!pub.get()) return false; server_key_bits_ = 
SECKEY_PublicKeyStrengthInBits(pub.get()); } SECStatus rv = SSL_ConfigSecureServer(ssl_fd(), nullptr, nullptr, ssl_kea_null); EXPECT_EQ(SECFailure, rv); rv = SSL_ConfigServerCert(ssl_fd(), cert.get(), priv.get(), serverCertData, serverCertData ? sizeof(*serverCertData) : 0); return rv == SECSuccess; } bool TlsAgent::EnsureTlsSetup(PRFileDesc* modelSocket) { // Don't set up twice if (ssl_fd_) return true; NssManagePolicy policyManage(policy_, option_); ScopedPRFileDesc dummy_fd(adapter_->CreateFD()); EXPECT_NE(nullptr, dummy_fd); if (!dummy_fd) { return false; } if (adapter_->variant() == ssl_variant_stream) { ssl_fd_.reset(SSL_ImportFD(modelSocket, dummy_fd.get())); } else { ssl_fd_.reset(DTLS_ImportFD(modelSocket, dummy_fd.get())); } EXPECT_NE(nullptr, ssl_fd_); if (!ssl_fd_) { return false; } dummy_fd.release(); // Now subsumed by ssl_fd_. SECStatus rv; if (!skip_version_checks_) { rv = SSL_VersionRangeSet(ssl_fd(), &vrange_); EXPECT_EQ(SECSuccess, rv); if (rv != SECSuccess) return false; } ScopedCERTCertList anchors(CERT_NewCertList()); rv = SSL_SetTrustAnchors(ssl_fd(), anchors.get()); if (rv != SECSuccess) return false; if (role_ == SERVER) { EXPECT_TRUE(ConfigServerCert(name_, true)); rv = SSL_SNISocketConfigHook(ssl_fd(), SniHook, this); EXPECT_EQ(SECSuccess, rv); if (rv != SECSuccess) return false; rv = SSL_SetMaxEarlyDataSize(ssl_fd(), 1024); EXPECT_EQ(SECSuccess, rv); if (rv != SECSuccess) return false; } else { rv = SSL_SetURL(ssl_fd(), "server"); EXPECT_EQ(SECSuccess, rv); if (rv != SECSuccess) return false; } rv = SSL_AuthCertificateHook(ssl_fd(), AuthCertificateHook, this); EXPECT_EQ(SECSuccess, rv); if (rv != SECSuccess) return false; rv = SSL_AlertReceivedCallback(ssl_fd(), AlertReceivedCallback, this); EXPECT_EQ(SECSuccess, rv); if (rv != SECSuccess) return false; rv = SSL_AlertSentCallback(ssl_fd(), AlertSentCallback, this); EXPECT_EQ(SECSuccess, rv); if (rv != SECSuccess) return false; rv = SSL_HandshakeCallback(ssl_fd(), HandshakeCallback, 
this); EXPECT_EQ(SECSuccess, rv); if (rv != SECSuccess) return false; // All these tests depend on having this disabled to start with. SetOption(SSL_ENABLE_EXTENDED_MASTER_SECRET, PR_FALSE); return true; } bool TlsAgent::MaybeSetResumptionToken() { if (!resumption_token_.empty()) { LOG("setting external resumption token"); SECStatus rv = SSL_SetResumptionToken(ssl_fd(), resumption_token_.data(), resumption_token_.size()); // rv is SECFailure with error set to SSL_ERROR_BAD_RESUMPTION_TOKEN_ERROR // if the resumption token was bad (expired/malformed/etc.). if (expect_psk_ == ssl_psk_resume) { // Only in case we expect resumption this has to be successful. We might // not expect resumption due to some reason but the token is totally fine. EXPECT_EQ(SECSuccess, rv); } if (rv != SECSuccess) { EXPECT_EQ(SSL_ERROR_BAD_RESUMPTION_TOKEN_ERROR, PORT_GetError()); resumption_token_.clear(); EXPECT_FALSE(expect_psk_ == ssl_psk_resume); if (expect_psk_ == ssl_psk_resume) return false; } } return true; } void TlsAgent::SetAntiReplayContext(ScopedSSLAntiReplayContext& ctx) { EXPECT_EQ(SECSuccess, SSL_SetAntiReplayContext(ssl_fd(), ctx.get())); } // Defaults to a Sync callback returning success void TlsAgent::SetupClientAuth(ClientAuthCallbackType callbackType, bool callbackSuccess) { EXPECT_TRUE(EnsureTlsSetup()); ASSERT_EQ(CLIENT, role_); client_auth_callback_type_ = callbackType; client_auth_callback_success_ = callbackSuccess; if (callbackType == ClientAuthCallbackType::kNone && !callbackSuccess) { // Don't set a callback for this case. 
return; } EXPECT_EQ(SECSuccess, SSL_GetClientAuthDataHook(ssl_fd(), GetClientAuthDataHook, reinterpret_cast(this))); } void CheckCertReqAgainstDefaultCAs(const CERTDistNames* caNames) { ScopedCERTDistNames expected(CERT_GetSSLCACerts(nullptr)); ASSERT_EQ(expected->nnames, caNames->nnames); for (size_t i = 0; i < static_cast(expected->nnames); ++i) { EXPECT_EQ(SECEqual, SECITEM_CompareItem(&(expected->names[i]), &(caNames->names[i]))); } } // Complete processing of Client Certificate Selection // A No-op if the agent is using synchronous client cert selection. // Otherwise, calls SSL_ClientCertCallbackComplete. // kAsyncDelay triggers a call to SSL_ForceHandshake prior to completion to // ensure that the socket is correctly blocked. void TlsAgent::ClientAuthCallbackComplete() { ASSERT_EQ(CLIENT, role_); if (client_auth_callback_type_ != ClientAuthCallbackType::kAsyncDelay && client_auth_callback_type_ != ClientAuthCallbackType::kAsyncImmediate) { return; } client_auth_callback_fired_++; EXPECT_TRUE(client_auth_callback_awaiting_); std::cerr << "client: calling SSL_ClientCertCallbackComplete with status " << (client_auth_callback_success_ ? 
"success" : "failed") << std::endl; client_auth_callback_awaiting_ = false; if (client_auth_callback_type_ == ClientAuthCallbackType::kAsyncDelay) { std::cerr << "Running Handshake prior to running SSL_ClientCertCallbackComplete" << std::endl; SECStatus rv = SSL_ForceHandshake(ssl_fd()); EXPECT_EQ(rv, SECFailure); EXPECT_EQ(PORT_GetError(), PR_WOULD_BLOCK_ERROR); } ScopedCERTCertificate cert; ScopedSECKEYPrivateKey priv; if (client_auth_callback_success_) { ASSERT_TRUE(TlsAgent::LoadCertificate(name(), &cert, &priv)); EXPECT_EQ(SECSuccess, SSL_ClientCertCallbackComplete(ssl_fd(), SECSuccess, priv.release(), cert.release())); } else { EXPECT_EQ(SECSuccess, SSL_ClientCertCallbackComplete(ssl_fd(), SECFailure, nullptr, nullptr)); } } SECStatus TlsAgent::GetClientAuthDataHook(void* self, PRFileDesc* fd, CERTDistNames* caNames, CERTCertificate** clientCert, SECKEYPrivateKey** clientKey) { TlsAgent* agent = reinterpret_cast(self); EXPECT_EQ(CLIENT, agent->role_); agent->client_auth_callback_fired_++; switch (agent->client_auth_callback_type_) { case ClientAuthCallbackType::kAsyncDelay: case ClientAuthCallbackType::kAsyncImmediate: std::cerr << "Waiting for complete call" << std::endl; agent->client_auth_callback_awaiting_ = true; return SECWouldBlock; case ClientAuthCallbackType::kSync: case ClientAuthCallbackType::kNone: // Handle the sync case. None && Success is treated as Sync and Success. if (!agent->client_auth_callback_success_) { return SECFailure; } ScopedCERTCertificate peerCert(SSL_PeerCertificate(agent->ssl_fd())); EXPECT_TRUE(peerCert) << "Client should be able to see the server cert"; // See bug 1573945 // CheckCertReqAgainstDefaultCAs(caNames); ScopedCERTCertificate cert; ScopedSECKEYPrivateKey priv; if (!TlsAgent::LoadCertificate(agent->name(), &cert, &priv)) { return SECFailure; } *clientCert = cert.release(); *clientKey = priv.release(); return SECSuccess; } /* This is unreachable, but some old compilers can't tell that. 
*/ PORT_Assert(0); PORT_SetError(SEC_ERROR_LIBRARY_FAILURE); return SECFailure; } // Increments by 1 for each callback bool TlsAgent::CheckClientAuthCallbacksCompleted(uint8_t expected) { EXPECT_EQ(CLIENT, role_); return expected == client_auth_callback_fired_; } bool TlsAgent::GetPeerChainLength(size_t* count) { CERTCertList* chain = SSL_PeerCertificateChain(ssl_fd()); if (!chain) return false; *count = 0; for (PRCList* cursor = PR_NEXT_LINK(&chain->list); cursor != &chain->list; cursor = PR_NEXT_LINK(cursor)) { CERTCertListNode* node = (CERTCertListNode*)cursor; std::cerr << node->cert->subjectName << std::endl; ++(*count); } CERT_DestroyCertList(chain); return true; } void TlsAgent::CheckCipherSuite(uint16_t suite) { EXPECT_EQ(csinfo_.cipherSuite, suite); } void TlsAgent::RequestClientAuth(bool requireAuth) { ASSERT_EQ(SERVER, role_); SetOption(SSL_REQUEST_CERTIFICATE, PR_TRUE); SetOption(SSL_REQUIRE_CERTIFICATE, requireAuth ? PR_TRUE : PR_FALSE); EXPECT_EQ(SECSuccess, SSL_AuthCertificateHook( ssl_fd(), &TlsAgent::ClientAuthenticated, this)); expect_client_auth_ = true; } void TlsAgent::StartConnect(PRFileDesc* model) { EXPECT_TRUE(EnsureTlsSetup(model)); SECStatus rv; rv = SSL_ResetHandshake(ssl_fd(), role_ == SERVER ? PR_TRUE : PR_FALSE); EXPECT_EQ(SECSuccess, rv); SetState(STATE_CONNECTING); } void TlsAgent::DisableAllCiphers() { for (size_t i = 0; i < SSL_NumImplementedCiphers; ++i) { SECStatus rv = SSL_CipherPrefSet(ssl_fd(), SSL_ImplementedCiphers[i], PR_FALSE); EXPECT_EQ(SECSuccess, rv); } } // Not actually all groups, just the ones that we are actually willing // to use. 
const std::vector kAllDHEGroups = { ssl_grp_ec_curve25519, ssl_grp_ec_secp256r1, ssl_grp_ec_secp384r1, ssl_grp_ec_secp521r1, ssl_grp_ffdhe_2048, ssl_grp_ffdhe_3072, ssl_grp_ffdhe_4096, ssl_grp_ffdhe_6144, ssl_grp_ffdhe_8192, ssl_grp_kem_xyber768d00, }; const std::vector kECDHEGroups = { ssl_grp_ec_curve25519, ssl_grp_ec_secp256r1, ssl_grp_ec_secp384r1, ssl_grp_ec_secp521r1, ssl_grp_kem_xyber768d00, }; const std::vector kFFDHEGroups = { ssl_grp_ffdhe_2048, ssl_grp_ffdhe_3072, ssl_grp_ffdhe_4096, ssl_grp_ffdhe_6144, ssl_grp_ffdhe_8192}; // Defined because the big DHE groups are ridiculously slow. const std::vector kFasterDHEGroups = { ssl_grp_ec_curve25519, ssl_grp_ec_secp256r1, ssl_grp_ec_secp384r1, ssl_grp_ffdhe_2048, ssl_grp_ffdhe_3072, ssl_grp_kem_xyber768d00, }; const std::vector kEcdhHybridGroups = { ssl_grp_kem_xyber768d00, }; void TlsAgent::EnableCiphersByKeyExchange(SSLKEAType kea) { EXPECT_TRUE(EnsureTlsSetup()); for (size_t i = 0; i < SSL_NumImplementedCiphers; ++i) { SSLCipherSuiteInfo csinfo; SECStatus rv = SSL_GetCipherSuiteInfo(SSL_ImplementedCiphers[i], &csinfo, sizeof(csinfo)); ASSERT_EQ(SECSuccess, rv); EXPECT_EQ(sizeof(csinfo), csinfo.length); if ((csinfo.keaType == kea) || (csinfo.keaType == ssl_kea_tls13_any)) { rv = SSL_CipherPrefSet(ssl_fd(), SSL_ImplementedCiphers[i], PR_TRUE); EXPECT_EQ(SECSuccess, rv); } } } void TlsAgent::EnableGroupsByKeyExchange(SSLKEAType kea) { switch (kea) { case ssl_kea_dh: ConfigNamedGroups(kFFDHEGroups); break; case ssl_kea_ecdh: ConfigNamedGroups(kECDHEGroups); break; case ssl_kea_ecdh_hybrid: ConfigNamedGroups(kEcdhHybridGroups); break; default: break; } } void TlsAgent::EnableGroupsByAuthType(SSLAuthType authType) { if (authType == ssl_auth_ecdh_rsa || authType == ssl_auth_ecdh_ecdsa || authType == ssl_auth_ecdsa || authType == ssl_auth_tls13_any) { ConfigNamedGroups(kECDHEGroups); } } void TlsAgent::EnableCiphersByAuthType(SSLAuthType authType) { EXPECT_TRUE(EnsureTlsSetup()); for (size_t i = 0; i < 
SSL_NumImplementedCiphers; ++i) { SSLCipherSuiteInfo csinfo; SECStatus rv = SSL_GetCipherSuiteInfo(SSL_ImplementedCiphers[i], &csinfo, sizeof(csinfo)); ASSERT_EQ(SECSuccess, rv); if ((csinfo.authType == authType) || (csinfo.keaType == ssl_kea_tls13_any)) { rv = SSL_CipherPrefSet(ssl_fd(), SSL_ImplementedCiphers[i], PR_TRUE); EXPECT_EQ(SECSuccess, rv); } } } void TlsAgent::EnableSingleCipher(uint16_t cipher) { DisableAllCiphers(); SECStatus rv = SSL_CipherPrefSet(ssl_fd(), cipher, PR_TRUE); EXPECT_EQ(SECSuccess, rv); } void TlsAgent::ConfigNamedGroups(const std::vector& groups) { EXPECT_TRUE(EnsureTlsSetup()); SECStatus rv = SSL_NamedGroupConfig(ssl_fd(), &groups[0], groups.size()); EXPECT_EQ(SECSuccess, rv); } void TlsAgent::Set0RttEnabled(bool en) { SetOption(SSL_ENABLE_0RTT_DATA, en ? PR_TRUE : PR_FALSE); } void TlsAgent::SetVersionRange(uint16_t minver, uint16_t maxver) { vrange_.min = minver; vrange_.max = maxver; if (ssl_fd()) { SECStatus rv = SSL_VersionRangeSet(ssl_fd(), &vrange_); EXPECT_EQ(SECSuccess, rv); } } SECStatus ResumptionTokenCallback(PRFileDesc* fd, const PRUint8* resumptionToken, unsigned int len, void* ctx) { EXPECT_NE(nullptr, resumptionToken); if (!resumptionToken) { return SECFailure; } std::vector new_token(resumptionToken, resumptionToken + len); reinterpret_cast(ctx)->SetResumptionToken(new_token); reinterpret_cast(ctx)->SetResumptionCallbackCalled(); return SECSuccess; } void TlsAgent::SetResumptionTokenCallback() { EXPECT_TRUE(EnsureTlsSetup()); SECStatus rv = SSL_SetResumptionTokenCallback(ssl_fd(), ResumptionTokenCallback, this); EXPECT_EQ(SECSuccess, rv); } void TlsAgent::GetVersionRange(uint16_t* minver, uint16_t* maxver) { *minver = vrange_.min; *maxver = vrange_.max; } void TlsAgent::SetExpectedVersion(uint16_t ver) { expected_version_ = ver; } void TlsAgent::SetServerKeyBits(uint16_t bits) { server_key_bits_ = bits; } void TlsAgent::ExpectReadWriteError() { expect_readwrite_error_ = true; } void TlsAgent::SkipVersionChecks() { 
skip_version_checks_ = true; } void TlsAgent::SetSignatureSchemes(const SSLSignatureScheme* schemes, size_t count) { EXPECT_TRUE(EnsureTlsSetup()); EXPECT_LE(count, SSL_SignatureMaxCount()); EXPECT_EQ(SECSuccess, SSL_SignatureSchemePrefSet(ssl_fd(), schemes, static_cast(count))); EXPECT_EQ(SECFailure, SSL_SignatureSchemePrefSet(ssl_fd(), schemes, 0)) << "setting no schemes should fail and do nothing"; std::vector configuredSchemes(count); unsigned int configuredCount; EXPECT_EQ(SECFailure, SSL_SignatureSchemePrefGet(ssl_fd(), nullptr, &configuredCount, 1)) << "get schemes, schemes is nullptr"; EXPECT_EQ(SECFailure, SSL_SignatureSchemePrefGet(ssl_fd(), &configuredSchemes[0], &configuredCount, 0)) << "get schemes, too little space"; EXPECT_EQ(SECFailure, SSL_SignatureSchemePrefGet(ssl_fd(), &configuredSchemes[0], nullptr, configuredSchemes.size())) << "get schemes, countOut is nullptr"; EXPECT_EQ(SECSuccess, SSL_SignatureSchemePrefGet( ssl_fd(), &configuredSchemes[0], &configuredCount, configuredSchemes.size())); // SignatureSchemePrefSet drops unsupported algorithms silently, so the // number that are configured might be fewer. 
EXPECT_LE(configuredCount, count); unsigned int i = 0; for (unsigned int j = 0; j < count && i < configuredCount; ++j) { if (i < configuredCount && schemes[j] == configuredSchemes[i]) { ++i; } } EXPECT_EQ(i, configuredCount) << "schemes in use were all set"; } void TlsAgent::CheckKEA(SSLKEAType kea, SSLNamedGroup kea_group, size_t kea_size) const { EXPECT_EQ(STATE_CONNECTED, state_); EXPECT_EQ(kea, info_.keaType); if (kea_size == 0) { switch (kea_group) { case ssl_grp_ec_curve25519: case ssl_grp_kem_xyber768d00: kea_size = 255; break; case ssl_grp_ec_secp256r1: kea_size = 256; break; case ssl_grp_ec_secp384r1: kea_size = 384; break; case ssl_grp_ffdhe_2048: kea_size = 2048; break; case ssl_grp_ffdhe_3072: kea_size = 3072; break; case ssl_grp_ffdhe_custom: break; default: if (kea == ssl_kea_rsa) { kea_size = server_key_bits_; } else { EXPECT_TRUE(false) << "need to update group sizes"; } } } if (kea_group != ssl_grp_ffdhe_custom) { EXPECT_EQ(kea_size, info_.keaKeyBits); EXPECT_EQ(kea_group, info_.keaGroup); } } void TlsAgent::CheckOriginalKEA(SSLNamedGroup kea_group) const { if (kea_group != ssl_grp_ffdhe_custom) { EXPECT_EQ(kea_group, info_.originalKeaGroup); } } void TlsAgent::CheckAuthType(SSLAuthType auth, SSLSignatureScheme sig_scheme) const { EXPECT_EQ(STATE_CONNECTED, state_); EXPECT_EQ(auth, info_.authType); if (auth != ssl_auth_psk) { EXPECT_EQ(server_key_bits_, info_.authKeyBits); } if (expected_version_ < SSL_LIBRARY_VERSION_TLS_1_2) { switch (auth) { case ssl_auth_rsa_sign: sig_scheme = ssl_sig_rsa_pkcs1_sha1md5; break; case ssl_auth_ecdsa: sig_scheme = ssl_sig_ecdsa_sha1; break; default: break; } } EXPECT_EQ(sig_scheme, info_.signatureScheme); if (info_.protocolVersion >= SSL_LIBRARY_VERSION_TLS_1_3) { return; } // Check authAlgorithm, which is the old value for authType. This is a second // switch statement because default label is different. 
switch (auth) { case ssl_auth_rsa_sign: case ssl_auth_rsa_pss: EXPECT_EQ(ssl_auth_rsa_decrypt, csinfo_.authAlgorithm) << "authAlgorithm for RSA is always decrypt"; break; case ssl_auth_ecdh_rsa: EXPECT_EQ(ssl_auth_rsa_decrypt, csinfo_.authAlgorithm) << "authAlgorithm for ECDH_RSA is RSA decrypt (i.e., wrong)"; break; case ssl_auth_ecdh_ecdsa: EXPECT_EQ(ssl_auth_ecdsa, csinfo_.authAlgorithm) << "authAlgorithm for ECDH_ECDSA is ECDSA (i.e., wrong)"; break; default: EXPECT_EQ(auth, csinfo_.authAlgorithm) << "authAlgorithm is (usually) the same as authType"; break; } } void TlsAgent::EnableFalseStart() { EXPECT_TRUE(EnsureTlsSetup()); falsestart_enabled_ = true; EXPECT_EQ(SECSuccess, SSL_SetCanFalseStartCallback( ssl_fd(), CanFalseStartCallback, this)); SetOption(SSL_ENABLE_FALSE_START, PR_TRUE); } void TlsAgent::ExpectEch(bool expected) { expect_ech_ = expected; } void TlsAgent::ExpectPsk(SSLPskType psk) { expect_psk_ = psk; } void TlsAgent::ExpectResumption() { expect_psk_ = ssl_psk_resume; } void TlsAgent::EnableAlpn(const uint8_t* val, size_t len) { EXPECT_TRUE(EnsureTlsSetup()); EXPECT_EQ(SECSuccess, SSL_SetNextProtoNego(ssl_fd(), val, len)); } void TlsAgent::AddPsk(const ScopedPK11SymKey& psk, std::string label, SSLHashType hash, uint16_t zeroRttSuite) { EXPECT_TRUE(EnsureTlsSetup()); EXPECT_EQ(SECSuccess, SSL_AddExternalPsk0Rtt( ssl_fd(), psk.get(), reinterpret_cast(label.data()), label.length(), hash, zeroRttSuite, 1000)); } void TlsAgent::RemovePsk(std::string label) { EXPECT_EQ(SECSuccess, SSL_RemoveExternalPsk( ssl_fd(), reinterpret_cast(label.data()), label.length())); } void TlsAgent::CheckAlpn(SSLNextProtoState expected_state, const std::string& expected) const { SSLNextProtoState alpn_state; char chosen[10]; unsigned int chosen_len; SECStatus rv = SSL_GetNextProto(ssl_fd(), &alpn_state, reinterpret_cast(chosen), &chosen_len, sizeof(chosen)); EXPECT_EQ(SECSuccess, rv); EXPECT_EQ(expected_state, alpn_state); if (alpn_state == SSL_NEXT_PROTO_NO_SUPPORT) { 
EXPECT_EQ("", expected); } else { EXPECT_NE("", expected); EXPECT_EQ(expected, std::string(chosen, chosen_len)); } } void TlsAgent::CheckEpochs(uint16_t expected_read, uint16_t expected_write) const { uint16_t read_epoch = 0; uint16_t write_epoch = 0; EXPECT_EQ(SECSuccess, SSL_GetCurrentEpoch(ssl_fd(), &read_epoch, &write_epoch)); EXPECT_EQ(expected_read, read_epoch) << role_str() << " read epoch"; EXPECT_EQ(expected_write, write_epoch) << role_str() << " write epoch"; } void TlsAgent::EnableSrtp() { EXPECT_TRUE(EnsureTlsSetup()); const uint16_t ciphers[] = {SRTP_AES128_CM_HMAC_SHA1_80, SRTP_AES128_CM_HMAC_SHA1_32}; EXPECT_EQ(SECSuccess, SSL_SetSRTPCiphers(ssl_fd(), ciphers, PR_ARRAY_SIZE(ciphers))); } void TlsAgent::CheckSrtp() const { uint16_t actual; EXPECT_EQ(SECSuccess, SSL_GetSRTPCipher(ssl_fd(), &actual)); EXPECT_EQ(SRTP_AES128_CM_HMAC_SHA1_80, actual); } void TlsAgent::CheckErrorCode(int32_t expected) const { EXPECT_EQ(STATE_ERROR, state_); EXPECT_EQ(expected, error_code_) << "Got error code " << PORT_ErrorToName(error_code_) << " expecting " << PORT_ErrorToName(expected) << std::endl; } static uint8_t GetExpectedAlertLevel(uint8_t alert) { if (alert == kTlsAlertCloseNotify) { return kTlsAlertWarning; } return kTlsAlertFatal; } void TlsAgent::ExpectReceiveAlert(uint8_t alert, uint8_t level) { expected_received_alert_ = alert; if (level == 0) { expected_received_alert_level_ = GetExpectedAlertLevel(alert); } else { expected_received_alert_level_ = level; } } void TlsAgent::ExpectSendAlert(uint8_t alert, uint8_t level) { expected_sent_alert_ = alert; if (level == 0) { expected_sent_alert_level_ = GetExpectedAlertLevel(alert); } else { expected_sent_alert_level_ = level; } } void TlsAgent::CheckAlert(bool sent, const SSLAlert* alert) { LOG(((alert->level == kTlsAlertWarning) ? "Warning" : "Fatal") << " alert " << (sent ? "sent" : "received") << ": " << static_cast(alert->description)); auto& expected = sent ? 
expected_sent_alert_ : expected_received_alert_; auto& expected_level = sent ? expected_sent_alert_level_ : expected_received_alert_level_; /* Silently pass close_notify in case the test has already ended. */ if (expected == kTlsAlertCloseNotify && expected_level == kTlsAlertWarning && alert->description == expected && alert->level == expected_level) { return; } EXPECT_EQ(expected, alert->description); EXPECT_EQ(expected_level, alert->level); expected = kTlsAlertCloseNotify; expected_level = kTlsAlertWarning; } void TlsAgent::WaitForErrorCode(int32_t expected, uint32_t delay) const { ASSERT_EQ(0, error_code_); WAIT_(error_code_ != 0, delay); EXPECT_EQ(expected, error_code_) << "Got error code " << PORT_ErrorToName(error_code_) << " expecting " << PORT_ErrorToName(expected) << std::endl; } void TlsAgent::CheckPreliminaryInfo() { SSLPreliminaryChannelInfo preinfo; EXPECT_EQ(SECSuccess, SSL_GetPreliminaryChannelInfo(ssl_fd(), &preinfo, sizeof(preinfo))); EXPECT_EQ(sizeof(preinfo), preinfo.length); EXPECT_TRUE(preinfo.valuesSet & ssl_preinfo_version); // A version of 0 is invalid and indicates no expectation. This value is // initialized to 0 so that tests that don't explicitly set an expected // version can negotiate a version. if (!expected_version_) { expected_version_ = preinfo.protocolVersion; } EXPECT_EQ(expected_version_, preinfo.protocolVersion); // As with the version; 0 is the null cipher suite (and also invalid). if (!expected_cipher_suite_) { expected_cipher_suite_ = preinfo.cipherSuite; } EXPECT_EQ(expected_cipher_suite_, preinfo.cipherSuite); } // Check that all the expected callbacks have been called. void TlsAgent::CheckCallbacks() const { // If false start happens, the handshake is reported as being complete at the // point that false start happens. if (expect_psk_ == ssl_psk_resume || !falsestart_enabled_) { EXPECT_TRUE(handshake_callback_called_); } // These callbacks shouldn't fire if we are resuming, except on TLS 1.3. 
if (role_ == SERVER) { PRBool have_sni = SSLInt_ExtensionNegotiated(ssl_fd(), ssl_server_name_xtn); EXPECT_EQ(((expect_psk_ != ssl_psk_resume && have_sni) || expected_version_ >= SSL_LIBRARY_VERSION_TLS_1_3), sni_hook_called_); } else { EXPECT_EQ(expect_psk_ == ssl_psk_none, auth_certificate_hook_called_); // Note that this isn't unconditionally called, even with false start on. // But the callback is only skipped if a cipher that is ridiculously weak // (80 bits) is chosen. Don't test that: plan to remove bad ciphers. EXPECT_EQ(falsestart_enabled_ && expect_psk_ != ssl_psk_resume, can_falsestart_hook_called_); } } void TlsAgent::ResetPreliminaryInfo() { expected_version_ = 0; expected_cipher_suite_ = 0; } void TlsAgent::UpdatePreliminaryChannelInfo() { SECStatus rv = SSL_GetPreliminaryChannelInfo(ssl_fd(), &pre_info_, sizeof(pre_info_)); EXPECT_EQ(SECSuccess, rv); EXPECT_EQ(sizeof(pre_info_), pre_info_.length); } void TlsAgent::ValidateCipherSpecs() { PRInt32 cipherSpecs = SSLInt_CountCipherSpecs(ssl_fd()); // We use one ciphersuite in each direction. PRInt32 expected = 2; if (variant_ == ssl_variant_datagram) { // For DTLS 1.3, the client retains the cipher spec for early data and the // handshake so that it can retransmit EndOfEarlyData and its final flight. // It also retains the handshake read cipher spec so that it can read ACKs // from the server. The server retains the handshake read cipher spec so it // can read the client's retransmitted Finished. if (expected_version_ >= SSL_LIBRARY_VERSION_TLS_1_3) { if (role_ == CLIENT) { expected = info_.earlyDataAccepted ? 5 : 4; } else { expected = 3; } } else { // For DTLS 1.1 and 1.2, the last endpoint to send maintains a cipher spec // until the holddown timer runs down. if (expect_psk_ == ssl_psk_resume) { if (role_ == CLIENT) { expected = 3; } } else { if (role_ == SERVER) { expected = 3; } } } } // This function will be run before the handshake completes if false start is // enabled. 
In that case, the client will still be reading cleartext, but // will have a spec prepared for reading ciphertext. With DTLS, the client // will also have a spec retained for retransmission of handshake messages. if (role_ == CLIENT && falsestart_enabled_ && !handshake_callback_called_) { EXPECT_GT(SSL_LIBRARY_VERSION_TLS_1_3, expected_version_); expected = (variant_ == ssl_variant_datagram) ? 4 : 3; } EXPECT_EQ(expected, cipherSpecs); if (expected != cipherSpecs) { SSLInt_PrintCipherSpecs(role_str().c_str(), ssl_fd()); } } void TlsAgent::Connected() { if (state_ == STATE_CONNECTED) { return; } LOG("Handshake success"); CheckPreliminaryInfo(); CheckCallbacks(); SECStatus rv = SSL_GetChannelInfo(ssl_fd(), &info_, sizeof(info_)); EXPECT_EQ(SECSuccess, rv); EXPECT_EQ(sizeof(info_), info_.length); EXPECT_EQ(expect_psk_ == ssl_psk_resume, info_.resumed == PR_TRUE); EXPECT_EQ(expect_psk_, info_.pskType); EXPECT_EQ(expect_ech_, info_.echAccepted); // Preliminary values are exposed through callbacks during the handshake. // If either expected values were set or the callbacks were called, check // that the final values are correct. 
UpdatePreliminaryChannelInfo(); EXPECT_EQ(expected_version_, info_.protocolVersion); EXPECT_EQ(expected_cipher_suite_, info_.cipherSuite); rv = SSL_GetCipherSuiteInfo(info_.cipherSuite, &csinfo_, sizeof(csinfo_)); EXPECT_EQ(SECSuccess, rv); EXPECT_EQ(sizeof(csinfo_), csinfo_.length); ValidateCipherSpecs(); SetState(STATE_CONNECTED); } void TlsAgent::CheckClientAuthCompleted(uint8_t handshakes) { EXPECT_FALSE(client_auth_callback_awaiting_); switch (client_auth_callback_type_) { case ClientAuthCallbackType::kNone: if (!client_auth_callback_success_) { EXPECT_TRUE(CheckClientAuthCallbacksCompleted(0)); break; } case ClientAuthCallbackType::kSync: EXPECT_TRUE(CheckClientAuthCallbacksCompleted(handshakes)); break; case ClientAuthCallbackType::kAsyncDelay: case ClientAuthCallbackType::kAsyncImmediate: EXPECT_TRUE(CheckClientAuthCallbacksCompleted(2 * handshakes)); break; } } void TlsAgent::EnableExtendedMasterSecret() { SetOption(SSL_ENABLE_EXTENDED_MASTER_SECRET, PR_TRUE); } void TlsAgent::CheckExtendedMasterSecret(bool expected) { if (version() >= SSL_LIBRARY_VERSION_TLS_1_3) { expected = PR_TRUE; } ASSERT_EQ(expected, info_.extendedMasterSecretUsed != PR_FALSE) << "unexpected extended master secret state for " << name_; } void TlsAgent::CheckEarlyDataAccepted(bool expected) { if (version() < SSL_LIBRARY_VERSION_TLS_1_3) { expected = false; } ASSERT_EQ(expected, info_.earlyDataAccepted != PR_FALSE) << "unexpected early data state for " << name_; } void TlsAgent::CheckSecretsDestroyed() { ASSERT_EQ(PR_TRUE, SSLInt_CheckSecretsDestroyed(ssl_fd())); } void TlsAgent::SetDowngradeCheckVersion(uint16_t ver) { ASSERT_TRUE(EnsureTlsSetup()); SECStatus rv = SSL_SetDowngradeCheckVersion(ssl_fd(), ver); ASSERT_EQ(SECSuccess, rv); } void TlsAgent::Handshake() { LOGV("Handshake"); SECStatus rv = SSL_ForceHandshake(ssl_fd()); if (client_auth_callback_awaiting_) { ClientAuthCallbackComplete(); rv = SSL_ForceHandshake(ssl_fd()); } if (rv == SECSuccess) { Connected(); 
Poller::Instance()->Wait(READABLE_EVENT, adapter_, this, &TlsAgent::ReadableCallback); return; } int32_t err = PR_GetError(); if (err == PR_WOULD_BLOCK_ERROR) { LOGV("Would have blocked"); if (variant_ == ssl_variant_datagram) { if (timer_handle_) { timer_handle_->Cancel(); timer_handle_ = nullptr; } PRIntervalTime timeout; rv = DTLS_GetHandshakeTimeout(ssl_fd(), &timeout); if (rv == SECSuccess) { Poller::Instance()->SetTimer( timeout + 1, this, &TlsAgent::ReadableCallback, &timer_handle_); } } Poller::Instance()->Wait(READABLE_EVENT, adapter_, this, &TlsAgent::ReadableCallback); return; } if (err != 0) { LOG("Handshake failed with error " << PORT_ErrorToName(err) << ": " << PORT_ErrorToString(err)); } error_code_ = err; SetState(STATE_ERROR); } void TlsAgent::PrepareForRenegotiate() { EXPECT_EQ(STATE_CONNECTED, state_); SetState(STATE_CONNECTING); } void TlsAgent::StartRenegotiate() { PrepareForRenegotiate(); SECStatus rv = SSL_ReHandshake(ssl_fd(), PR_TRUE); EXPECT_EQ(SECSuccess, rv); } void TlsAgent::SendDirect(const DataBuffer& buf) { LOG("Send Direct " << buf); auto peer = adapter_->peer().lock(); if (peer) { peer->PacketReceived(buf); } else { LOG("Send Direct peer absent"); } } void TlsAgent::SendRecordDirect(const TlsRecord& record) { DataBuffer buf; auto rv = record.header.Write(&buf, 0, record.buffer); EXPECT_EQ(record.header.header_length() + record.buffer.len(), rv); SendDirect(buf); } static bool ErrorIsFatal(PRErrorCode code) { return code != PR_WOULD_BLOCK_ERROR && code != SSL_ERROR_RX_SHORT_DTLS_READ; } void TlsAgent::SendData(size_t bytes, size_t blocksize) { uint8_t block[16385]; // One larger than the maximum record size. 
ASSERT_LE(blocksize, sizeof(block)); while (bytes) { size_t tosend = std::min(blocksize, bytes); for (size_t i = 0; i < tosend; ++i) { block[i] = 0xff & send_ctr_; ++send_ctr_; } SendBuffer(DataBuffer(block, tosend)); bytes -= tosend; } } void TlsAgent::SendBuffer(const DataBuffer& buf) { LOGV("Writing " << buf.len() << " bytes"); int32_t rv = PR_Write(ssl_fd(), buf.data(), buf.len()); if (expect_readwrite_error_) { EXPECT_GT(0, rv); EXPECT_NE(PR_WOULD_BLOCK_ERROR, error_code_); error_code_ = PR_GetError(); expect_readwrite_error_ = false; } else { ASSERT_EQ(buf.len(), static_cast(rv)); } } bool TlsAgent::SendEncryptedRecord(const std::shared_ptr& spec, uint64_t seq, uint8_t ct, const DataBuffer& buf) { // Ensure that we are doing TLS 1.3. EXPECT_GE(expected_version_, SSL_LIBRARY_VERSION_TLS_1_3); if (variant_ != ssl_variant_datagram) { ADD_FAILURE(); return false; } LOGV("Encrypting " << buf.len() << " bytes"); uint8_t dtls13_ct = kCtDtlsCiphertext | kCtDtlsCiphertext16bSeqno | kCtDtlsCiphertextLengthPresent; TlsRecordHeader header(variant_, expected_version_, dtls13_ct, seq); TlsRecordHeader out_header(header); DataBuffer padded = buf; padded.Write(padded.len(), ct, 1); DataBuffer ciphertext; if (!spec->Protect(header, padded, &ciphertext, &out_header)) { return false; } DataBuffer record; auto rv = out_header.Write(&record, 0, ciphertext); EXPECT_EQ(out_header.header_length() + ciphertext.len(), rv); SendDirect(record); return true; } void TlsAgent::ReadBytes(size_t amount) { uint8_t block[16384]; size_t remaining = amount; while (remaining > 0) { int32_t rv = PR_Read(ssl_fd(), block, (std::min)(amount, sizeof(block))); LOGV("ReadBytes " << rv); if (rv > 0) { size_t count = static_cast(rv); for (size_t i = 0; i < count; ++i) { ASSERT_EQ(recv_ctr_ & 0xff, block[i]); recv_ctr_++; } remaining -= rv; } else { PRErrorCode err = 0; if (rv < 0) { err = PR_GetError(); if (err != 0) { LOG("Read error " << PORT_ErrorToName(err) << ": " << PORT_ErrorToString(err)); } if 
(err != PR_WOULD_BLOCK_ERROR && expect_readwrite_error_) { if (ErrorIsFatal(err)) { SetState(STATE_ERROR); } error_code_ = err; expect_readwrite_error_ = false; } } if (err != 0 && ErrorIsFatal(err)) { // If we hit a fatal error, we're done. remaining = 0; } break; } } // If closed, then don't bother waiting around. if (remaining) { LOGV("Re-arming"); Poller::Instance()->Wait(READABLE_EVENT, adapter_, this, &TlsAgent::ReadableCallback); } } void TlsAgent::ResetSentBytes(size_t bytes) { send_ctr_ = bytes; } void TlsAgent::SetOption(int32_t option, int value) { ASSERT_TRUE(EnsureTlsSetup()); EXPECT_EQ(SECSuccess, SSL_OptionSet(ssl_fd(), option, value)); } void TlsAgent::ConfigureSessionCache(SessionResumptionMode mode) { SetOption(SSL_NO_CACHE, mode & RESUME_SESSIONID ? PR_FALSE : PR_TRUE); SetOption(SSL_ENABLE_SESSION_TICKETS, mode & RESUME_TICKET ? PR_TRUE : PR_FALSE); } void TlsAgent::EnableECDHEServerKeyReuse() { ASSERT_EQ(TlsAgent::SERVER, role_); SetOption(SSL_REUSE_SERVER_ECDHE_KEY, PR_TRUE); } static const std::string kTlsRolesAllArr[] = {"CLIENT", "SERVER"}; ::testing::internal::ParamGenerator TlsAgentTestBase::kTlsRolesAll = ::testing::ValuesIn(kTlsRolesAllArr); void TlsAgentTestBase::SetUp() { SSL_ConfigServerSessionIDCache(1024, 0, 0, g_working_dir_path.c_str()); } void TlsAgentTestBase::TearDown() { agent_ = nullptr; SSL_ClearSessionCache(); SSL_ShutdownServerSessionIDCache(); } void TlsAgentTestBase::Reset(const std::string& server_name) { agent_.reset( new TlsAgent(role_ == TlsAgent::CLIENT ? 
TlsAgent::kClient : server_name, role_, variant_)); if (version_) { agent_->SetVersionRange(version_, version_); } agent_->adapter()->SetPeer(sink_adapter_); agent_->StartConnect(); } void TlsAgentTestBase::EnsureInit() { if (!agent_) { Reset(); } const std::vector groups = { ssl_grp_ec_curve25519, ssl_grp_ec_secp256r1, ssl_grp_ec_secp384r1, ssl_grp_ffdhe_2048}; agent_->ConfigNamedGroups(groups); } void TlsAgentTestBase::ExpectAlert(uint8_t alert) { EnsureInit(); agent_->ExpectSendAlert(alert); } void TlsAgentTestBase::ProcessMessage(const DataBuffer& buffer, TlsAgent::State expected_state, int32_t error_code) { std::cerr << "Process message: " << buffer << std::endl; EnsureInit(); agent_->adapter()->PacketReceived(buffer); agent_->Handshake(); ASSERT_EQ(expected_state, agent_->state()); if (expected_state == TlsAgent::STATE_ERROR) { ASSERT_EQ(error_code, agent_->error_code()); } } void TlsAgentTestBase::MakeRecord(SSLProtocolVariant variant, uint8_t type, uint16_t version, const uint8_t* buf, size_t len, DataBuffer* out, uint64_t sequence_number) { // Fixup the content type for DTLSCiphertext if (variant == ssl_variant_datagram && version >= SSL_LIBRARY_VERSION_TLS_1_3 && type == ssl_ct_application_data) { type = kCtDtlsCiphertext | kCtDtlsCiphertext16bSeqno | kCtDtlsCiphertextLengthPresent; } size_t index = 0; if (variant == ssl_variant_stream) { index = out->Write(index, type, 1); index = out->Write(index, version, 2); } else if (version >= SSL_LIBRARY_VERSION_TLS_1_3 && (type & kCtDtlsCiphertextMask) == kCtDtlsCiphertext) { uint32_t epoch = (sequence_number >> 48) & 0x3; index = out->Write(index, type | epoch, 1); uint32_t seqno = sequence_number & ((1ULL << 16) - 1); index = out->Write(index, seqno, 2); } else { index = out->Write(index, type, 1); index = out->Write(index, TlsVersionToDtlsVersion(version), 2); index = out->Write(index, sequence_number >> 32, 4); index = out->Write(index, sequence_number & PR_UINT32_MAX, 4); } index = out->Write(index, len, 2); 
out->Write(index, buf, len); } void TlsAgentTestBase::MakeRecord(uint8_t type, uint16_t version, const uint8_t* buf, size_t len, DataBuffer* out, uint64_t seq_num) const { MakeRecord(variant_, type, version, buf, len, out, seq_num); } void TlsAgentTestBase::MakeHandshakeMessage(uint8_t hs_type, const uint8_t* data, size_t hs_len, DataBuffer* out, uint64_t seq_num) const { return MakeHandshakeMessageFragment(hs_type, data, hs_len, out, seq_num, 0, 0); } void TlsAgentTestBase::MakeHandshakeMessageFragment( uint8_t hs_type, const uint8_t* data, size_t hs_len, DataBuffer* out, uint64_t seq_num, uint32_t fragment_offset, uint32_t fragment_length) const { size_t index = 0; if (!fragment_length) fragment_length = hs_len; index = out->Write(index, hs_type, 1); // Handshake record type. index = out->Write(index, hs_len, 3); // Handshake length if (variant_ == ssl_variant_datagram) { index = out->Write(index, seq_num, 2); index = out->Write(index, fragment_offset, 3); index = out->Write(index, fragment_length, 3); } if (data) { index = out->Write(index, data, fragment_length); } else { for (size_t i = 0; i < fragment_length; ++i) { index = out->Write(index, 1, 1); } } } void TlsAgentTestBase::MakeTrivialHandshakeRecord(uint8_t hs_type, size_t hs_len, DataBuffer* out) { size_t index = 0; index = out->Write(index, ssl_ct_handshake, 1); // Content Type index = out->Write(index, 3, 1); // Version high index = out->Write(index, 1, 1); // Version low index = out->Write(index, 4 + hs_len, 2); // Length index = out->Write(index, hs_type, 1); // Handshake record type. index = out->Write(index, hs_len, 3); // Handshake length for (size_t i = 0; i < hs_len; ++i) { index = out->Write(index, 1, 1); } } DataBuffer TlsAgentTestBase::MakeCannedTls13ServerHello() { DataBuffer sh(kCannedTls13ServerHello, sizeof(kCannedTls13ServerHello)); if (variant_ == ssl_variant_datagram) { sh.Write(0, SSL_LIBRARY_VERSION_DTLS_1_2_WIRE, 2); // The version should be at the end. 
uint32_t v = 0;  // Initialized so a failed Read() cannot leave it undefined.
    EXPECT_TRUE(sh.Read(sh.len() - 2, 2, &v));
    EXPECT_EQ(static_cast<uint32_t>(SSL_LIBRARY_VERSION_TLS_1_3), v);
    // Replace the trailing TLS 1.3 version with its DTLS 1.3 wire value.
    sh.Write(sh.len() - 2, SSL_LIBRARY_VERSION_DTLS_1_3_WIRE, 2);
  }
  return sh;
}

}  // namespace nss_test