diff options
author | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-04-07 19:33:14 +0000 |
---|---|---|
committer | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-04-07 19:33:14 +0000 |
commit | 36d22d82aa202bb199967e9512281e9a53db42c9 (patch) | |
tree | 105e8c98ddea1c1e4784a60a5a6410fa416be2de /security/nss/gtests/ssl_gtest | |
parent | Initial commit. (diff) | |
download | firefox-esr-36d22d82aa202bb199967e9512281e9a53db42c9.tar.xz firefox-esr-36d22d82aa202bb199967e9512281e9a53db42c9.zip |
Adding upstream version 115.7.0esr.upstream/115.7.0esrupstream
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'security/nss/gtests/ssl_gtest')
62 files changed, 33053 insertions, 0 deletions
diff --git a/security/nss/gtests/ssl_gtest/Makefile b/security/nss/gtests/ssl_gtest/Makefile new file mode 100644 index 0000000000..46f0303576 --- /dev/null +++ b/security/nss/gtests/ssl_gtest/Makefile @@ -0,0 +1,58 @@ +#! gmake +# +# This Source Code Form is subject to the terms of the Mozilla Public +# License, v. 2.0. If a copy of the MPL was not distributed with this +# file, You can obtain one at http://mozilla.org/MPL/2.0/. + +####################################################################### +# (1) Include initial platform-independent assignments (MANDATORY). # +####################################################################### + +include manifest.mn + +####################################################################### +# (2) Include "global" configuration information. (OPTIONAL) # +####################################################################### + +include $(CORE_DEPTH)/coreconf/config.mk + +####################################################################### +# (3) Include "component" configuration information. (OPTIONAL) # +####################################################################### + + +####################################################################### +# (4) Include "local" platform-dependent assignments (OPTIONAL). # +####################################################################### + +include ../common/gtest.mk + +CFLAGS += -I$(CORE_DEPTH)/lib/ssl + +ifdef NSS_DISABLE_TLS_1_3 +NSS_DISABLE_TLS_1_3=1 +# Run parameterized tests only, for which we can easily exclude TLS 1.3 +CPPSRCS := $(filter-out $(shell grep -l '^TEST_F' $(CPPSRCS)), $(CPPSRCS)) +CFLAGS += -DNSS_DISABLE_TLS_1_3 +endif + +ifdef NSS_ALLOW_SSLKEYLOGFILE +SSLKEYLOGFILE_FILES = ssl_keylog_unittest.cc +else +SSLKEYLOGFILE_FILES = $(NULL) +endif + +####################################################################### +# (5) Execute "global" rules. (OPTIONAL) # +####################################################################### + +include $(CORE_DEPTH)/coreconf/rules.mk + +####################################################################### +# (6) Execute "component" rules. (OPTIONAL) # +####################################################################### + + +####################################################################### +# (7) Execute "local" rules. (OPTIONAL). # +####################################################################### diff --git a/security/nss/gtests/ssl_gtest/bloomfilter_unittest.cc b/security/nss/gtests/ssl_gtest/bloomfilter_unittest.cc new file mode 100644 index 0000000000..ccb2cd88ef --- /dev/null +++ b/security/nss/gtests/ssl_gtest/bloomfilter_unittest.cc @@ -0,0 +1,108 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set ts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this file, + * You can obtain one at http://mozilla.org/MPL/2.0/. */ + +extern "C" { +#include "sslbloom.h" +} + +#include "gtest_utils.h" + +namespace nss_test { + +// Some random-ish inputs to test with. These don't result in collisions in any +// of the configurations that are tested below. +static const uint8_t kHashes1[] = { + 0x79, 0x53, 0xb8, 0xdd, 0x6b, 0x98, 0xce, 0x00, 0xb7, 0xdc, 0xe8, + 0x03, 0x70, 0x8c, 0xe3, 0xac, 0x06, 0x8b, 0x22, 0xfd, 0x0e, 0x34, + 0x48, 0xe6, 0xe5, 0xe0, 0x8a, 0xd6, 0x16, 0x18, 0xe5, 0x48}; +static const uint8_t kHashes2[] = { + 0xc6, 0xdd, 0x6e, 0xc4, 0x76, 0xb8, 0x55, 0xf2, 0xa4, 0xfc, 0x59, + 0x04, 0xa4, 0x90, 0xdc, 0xa7, 0xa7, 0x0d, 0x94, 0x8f, 0xc2, 0xdc, + 0x15, 0x6d, 0x48, 0x93, 0x9d, 0x05, 0xbb, 0x9a, 0xbc, 0xc1}; + +typedef struct { + unsigned int k; + unsigned int bits; +} BloomFilterConfig; + +class BloomFilterTest + : public ::testing::Test, + public ::testing::WithParamInterface<BloomFilterConfig> { + public: + BloomFilterTest() : filter_() {} + + void SetUp() { Init(); } + + void TearDown() { sslBloom_Destroy(&filter_); } + + protected: + void Init() { + if (filter_.filter) { + sslBloom_Destroy(&filter_); + } + ASSERT_EQ(SECSuccess, + sslBloom_Init(&filter_, GetParam().k, GetParam().bits)); + } + + bool Check(const uint8_t* hashes) { + return sslBloom_Check(&filter_, hashes) ? true : false; + } + + void Add(const uint8_t* hashes, bool expect_collision = false) { + EXPECT_EQ(expect_collision, sslBloom_Add(&filter_, hashes) ? true : false); + EXPECT_TRUE(Check(hashes)); + } + + sslBloomFilter filter_; +}; + +TEST_P(BloomFilterTest, InitOnly) {} + +TEST_P(BloomFilterTest, AddToEmpty) { + EXPECT_FALSE(Check(kHashes1)); + Add(kHashes1); +} + +TEST_P(BloomFilterTest, AddTwo) { + Add(kHashes1); + Add(kHashes2); +} + +TEST_P(BloomFilterTest, AddOneTwice) { + Add(kHashes1); + Add(kHashes1, true); +} + +TEST_P(BloomFilterTest, Zero) { + Add(kHashes1); + sslBloom_Zero(&filter_); + EXPECT_FALSE(Check(kHashes1)); + EXPECT_FALSE(Check(kHashes2)); +} + +TEST_P(BloomFilterTest, Fill) { + sslBloom_Fill(&filter_); + EXPECT_TRUE(Check(kHashes1)); + EXPECT_TRUE(Check(kHashes2)); +} + +static const BloomFilterConfig kBloomFilterConfigurations[] = { + {1, 1}, // 1 hash, 1 bit input - high chance of collision. + {1, 2}, // 1 hash, 2 bits - smaller than the basic unit size. + {1, 3}, // 1 hash, 3 bits - same as basic unit size. + {1, 4}, // 1 hash, 4 bits - 2 octets each. + {3, 10}, // 3 hashes over a reasonable number of bits. + {3, 3}, // Test that we can read multiple bits. + {4, 15}, // A credible filter. + {2, 18}, // A moderately large allocation. + {16, 16}, // Insane, use all of the bits from the hashes. + {16, 9}, // This also uses all of the bits from the hashes. +}; + +INSTANTIATE_TEST_SUITE_P(BloomFilterConfigurations, BloomFilterTest, + ::testing::ValuesIn(kBloomFilterConfigurations)); + +} // namespace nss_test diff --git a/security/nss/gtests/ssl_gtest/gtest_utils.h b/security/nss/gtests/ssl_gtest/gtest_utils.h new file mode 100644 index 0000000000..2344c3cea9 --- /dev/null +++ b/security/nss/gtests/ssl_gtest/gtest_utils.h @@ -0,0 +1,57 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set ts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this file, + * You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#ifndef gtest_utils_h__ +#define gtest_utils_h__ + +#define GTEST_HAS_RTTI 0 +#include "gtest/gtest.h" +#include "test_io.h" + +namespace nss_test { + +// Gtest utilities +class Timeout : public PollTarget { + public: + Timeout(int32_t timer_ms) : handle_(nullptr) { + Poller::Instance()->SetTimer(timer_ms, this, &Timeout::ExpiredCallback, + &handle_); + } + ~Timeout() { + if (handle_) { + handle_->Cancel(); + } + } + + static void ExpiredCallback(PollTarget* target, Event event) { + Timeout* timeout = static_cast<Timeout*>(target); + timeout->handle_ = nullptr; + } + + bool timed_out() const { return !handle_; } + + private: + std::shared_ptr<Poller::Timer> handle_; +}; + +} // namespace nss_test + +#define WAIT_(expression, timeout) \ + do { \ + Timeout tm(timeout); \ + while (!(expression)) { \ + Poller::Instance()->Poll(); \ + if (tm.timed_out()) break; \ + } \ + } while (0) + +#define ASSERT_TRUE_WAIT(expression, timeout) \ + do { \ + WAIT_(expression, timeout); \ + ASSERT_TRUE(expression); \ + } while (0) + +#endif diff --git a/security/nss/gtests/ssl_gtest/libssl_internals.c b/security/nss/gtests/ssl_gtest/libssl_internals.c new file mode 100644 index 0000000000..c6b03c530f --- /dev/null +++ b/security/nss/gtests/ssl_gtest/libssl_internals.c @@ -0,0 +1,501 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set ts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this file, + * You can obtain one at http://mozilla.org/MPL/2.0/. */ + +/* This file contains functions for frobbing the internals of libssl */ +#include "libssl_internals.h" + +#include "nss.h" +#include "pk11hpke.h" +#include "pk11pub.h" +#include "pk11priv.h" +#include "tls13ech.h" +#include "seccomon.h" +#include "selfencrypt.h" +#include "secmodti.h" +#include "sslproto.h" + +SECStatus SSLInt_RemoveServerCertificates(PRFileDesc *fd) { + if (!fd) { + return SECFailure; + } + sslSocket *ss = ssl_FindSocket(fd); + if (!ss) { + return SECFailure; + } + + PRCList *cursor; + while (!PR_CLIST_IS_EMPTY(&ss->serverCerts)) { + cursor = PR_LIST_TAIL(&ss->serverCerts); + PR_REMOVE_LINK(cursor); + ssl_FreeServerCert((sslServerCert *)cursor); + } + return SECSuccess; +} + +SECStatus SSLInt_SetDCAdvertisedSigSchemes(PRFileDesc *fd, + const SSLSignatureScheme *schemes, + uint32_t num_sig_schemes) { + if (!fd) { + return SECFailure; + } + sslSocket *ss = ssl_FindSocket(fd); + if (!ss) { + return SECFailure; + } + + // Alloc and copy, libssl will free. + SSLSignatureScheme *dc_schemes = + PORT_ZNewArray(SSLSignatureScheme, num_sig_schemes); + if (!dc_schemes) { + return SECFailure; + } + memcpy(dc_schemes, schemes, sizeof(SSLSignatureScheme) * num_sig_schemes); + + if (ss->xtnData.delegCredSigSchemesAdvertised) { + PORT_Free(ss->xtnData.delegCredSigSchemesAdvertised); + } + ss->xtnData.delegCredSigSchemesAdvertised = dc_schemes; + ss->xtnData.numDelegCredSigSchemesAdvertised = num_sig_schemes; + return SECSuccess; +} + +SECStatus SSLInt_TweakChannelInfoForDC(PRFileDesc *fd, PRBool changeAuthKeyBits, + PRBool changeScheme) { + if (!fd) { + return SECFailure; + } + sslSocket *ss = ssl_FindSocket(fd); + if (!ss) { + return SECFailure; + } + + // Just toggle so we'll always have a valid value. + if (changeScheme) { + ss->sec.signatureScheme = (ss->sec.signatureScheme == ssl_sig_ed25519) + ? ssl_sig_ecdsa_secp256r1_sha256 + : ssl_sig_ed25519; + } + if (changeAuthKeyBits) { + ss->sec.authKeyBits = ss->sec.authKeyBits ? ss->sec.authKeyBits * 2 : 384; + } + + return SECSuccess; +} + +SECStatus SSLInt_GetHandshakeRandoms(PRFileDesc *fd, SSL3Random client_random, + SSL3Random server_random) { + if (!fd) { + return SECFailure; + } + sslSocket *ss = ssl_FindSocket(fd); + if (!ss) { + return SECFailure; + } + + if (client_random) { + memcpy(client_random, ss->ssl3.hs.client_random, sizeof(SSL3Random)); + } + if (server_random) { + memcpy(server_random, ss->ssl3.hs.server_random, sizeof(SSL3Random)); + } + return SECSuccess; +} + +SECStatus SSLInt_IncrementClientHandshakeVersion(PRFileDesc *fd) { + sslSocket *ss = ssl_FindSocket(fd); + if (!ss) { + return SECFailure; + } + + ++ss->clientHelloVersion; + + return SECSuccess; +} + +/* Use this function to update the ClientRandom of a client's handshake state + * after replacing its ClientHello message. We for example need to do this + * when replacing an SSLv3 ClientHello with its SSLv2 equivalent. */ +SECStatus SSLInt_UpdateSSLv2ClientRandom(PRFileDesc *fd, uint8_t *rnd, + size_t rnd_len, uint8_t *msg, + size_t msg_len) { + sslSocket *ss = ssl_FindSocket(fd); + if (!ss) { + return SECFailure; + } + + ssl3_RestartHandshakeHashes(ss); + + // Ensure we don't overrun hs.client_random. + rnd_len = PR_MIN(SSL3_RANDOM_LENGTH, rnd_len); + + // Zero the client_random. + PORT_Memset(ss->ssl3.hs.client_random, 0, SSL3_RANDOM_LENGTH); + + // Copy over the challenge bytes. + size_t offset = SSL3_RANDOM_LENGTH - rnd_len; + PORT_Memcpy(ss->ssl3.hs.client_random + offset, rnd, rnd_len); + + // Rehash the SSLv2 client hello message. + return ssl3_UpdateHandshakeHashes(ss, msg, msg_len); +} + +PRBool SSLInt_ExtensionNegotiated(PRFileDesc *fd, PRUint16 ext) { + sslSocket *ss = ssl_FindSocket(fd); + return (PRBool)(ss && ssl3_ExtensionNegotiated(ss, ext)); +} + +// Tests should not use this function directly, because the keys may +// still be in cache. Instead, use TlsConnectTestBase::ClearServerCache. +void SSLInt_ClearSelfEncryptKey() { ssl_ResetSelfEncryptKeys(); } + +sslSelfEncryptKeys *ssl_GetSelfEncryptKeysInt(); + +void SSLInt_SetSelfEncryptMacKey(PK11SymKey *key) { + sslSelfEncryptKeys *keys = ssl_GetSelfEncryptKeysInt(); + + PK11_FreeSymKey(keys->macKey); + keys->macKey = key; +} + +SECStatus SSLInt_SetMTU(PRFileDesc *fd, PRUint16 mtu) { + sslSocket *ss = ssl_FindSocket(fd); + if (!ss) { + return SECFailure; + } + ss->ssl3.mtu = mtu; + ss->ssl3.hs.rtRetries = 0; /* Avoid DTLS shrinking the MTU any more. */ + return SECSuccess; +} + +PRInt32 SSLInt_CountCipherSpecs(PRFileDesc *fd) { + PRCList *cur_p; + PRInt32 ct = 0; + + sslSocket *ss = ssl_FindSocket(fd); + if (!ss) { + return -1; + } + + for (cur_p = PR_NEXT_LINK(&ss->ssl3.hs.cipherSpecs); + cur_p != &ss->ssl3.hs.cipherSpecs; cur_p = PR_NEXT_LINK(cur_p)) { + ++ct; + } + return ct; +} + +void SSLInt_PrintCipherSpecs(const char *label, PRFileDesc *fd) { + PRCList *cur_p; + + sslSocket *ss = ssl_FindSocket(fd); + if (!ss) { + return; + } + + fprintf(stderr, "Cipher specs for %s\n", label); + for (cur_p = PR_NEXT_LINK(&ss->ssl3.hs.cipherSpecs); + cur_p != &ss->ssl3.hs.cipherSpecs; cur_p = PR_NEXT_LINK(cur_p)) { + ssl3CipherSpec *spec = (ssl3CipherSpec *)cur_p; + fprintf(stderr, " %s spec epoch=%d (%s) refct=%d\n", SPEC_DIR(spec), + spec->epoch, spec->phase, spec->refCt); + } +} + +/* DTLS timers are separate from the time that the rest of the stack uses. + * Force a timer expiry by backdating when all active timers were started. + * We could set the remaining time to 0 but then backoff would not work properly + * if we decide to test it. */ +SECStatus SSLInt_ShiftDtlsTimers(PRFileDesc *fd, PRIntervalTime shift) { + size_t i; + sslSocket *ss = ssl_FindSocket(fd); + if (!ss) { + return SECFailure; + } + + for (i = 0; i < PR_ARRAY_SIZE(ss->ssl3.hs.timers); ++i) { + if (ss->ssl3.hs.timers[i].cb) { + ss->ssl3.hs.timers[i].started -= shift; + } + } + return SECSuccess; +} + +#define CHECK_SECRET(secret) \ + if (ss->ssl3.hs.secret) { \ + fprintf(stderr, "%s != NULL\n", #secret); \ + return PR_FALSE; \ + } + +PRBool SSLInt_CheckSecretsDestroyed(PRFileDesc *fd) { + sslSocket *ss = ssl_FindSocket(fd); + if (!ss) { + return PR_FALSE; + } + + CHECK_SECRET(currentSecret); + CHECK_SECRET(dheSecret); + CHECK_SECRET(clientEarlyTrafficSecret); + CHECK_SECRET(clientHsTrafficSecret); + CHECK_SECRET(serverHsTrafficSecret); + + return PR_TRUE; +} + +PRBool sslint_DamageTrafficSecret(PRFileDesc *fd, size_t offset) { + unsigned char data[32] = {0}; + PK11SymKey **keyPtr; + PK11SlotInfo *slot = PK11_GetInternalSlot(); + SECItem key_item = {siBuffer, data, sizeof(data)}; + sslSocket *ss = ssl_FindSocket(fd); + if (!ss) { + return PR_FALSE; + } + if (!slot) { + return PR_FALSE; + } + keyPtr = (PK11SymKey **)((char *)&ss->ssl3.hs + offset); + if (!*keyPtr) { + return PR_FALSE; + } + PK11_FreeSymKey(*keyPtr); + *keyPtr = PK11_ImportSymKey(slot, CKM_NSS_HKDF_SHA256, PK11_OriginUnwrap, + CKA_DERIVE, &key_item, NULL); + PK11_FreeSlot(slot); + if (!*keyPtr) { + return PR_FALSE; + } + + return PR_TRUE; +} + +PRBool SSLInt_DamageClientHsTrafficSecret(PRFileDesc *fd) { + return sslint_DamageTrafficSecret( + fd, offsetof(SSL3HandshakeState, clientHsTrafficSecret)); +} + +PRBool SSLInt_DamageServerHsTrafficSecret(PRFileDesc *fd) { + return sslint_DamageTrafficSecret( + fd, offsetof(SSL3HandshakeState, serverHsTrafficSecret)); +} + +PRBool SSLInt_DamageEarlyTrafficSecret(PRFileDesc *fd) { + return sslint_DamageTrafficSecret( + fd, offsetof(SSL3HandshakeState, clientEarlyTrafficSecret)); +} + +SECStatus SSLInt_Set0RttAlpn(PRFileDesc *fd, PRUint8 *data, unsigned int len) { + sslSocket *ss = ssl_FindSocket(fd); + if (!ss) { + return SECFailure; + } + + ss->xtnData.nextProtoState = SSL_NEXT_PROTO_EARLY_VALUE; + if (ss->xtnData.nextProto.data) { + SECITEM_FreeItem(&ss->xtnData.nextProto, PR_FALSE); + } + if (!SECITEM_AllocItem(NULL, &ss->xtnData.nextProto, len)) { + return SECFailure; + } + PORT_Memcpy(ss->xtnData.nextProto.data, data, len); + + return SECSuccess; +} + +PRBool SSLInt_HasCertWithAuthType(PRFileDesc *fd, SSLAuthType authType) { + sslSocket *ss = ssl_FindSocket(fd); + if (!ss) { + return PR_FALSE; + } + + return (PRBool)(!!ssl_FindServerCert(ss, authType, NULL)); +} + +PRBool SSLInt_SendAlert(PRFileDesc *fd, uint8_t level, uint8_t type) { + sslSocket *ss = ssl_FindSocket(fd); + if (!ss) { + return PR_FALSE; + } + + SECStatus rv = SSL3_SendAlert(ss, level, type); + if (rv != SECSuccess) return PR_FALSE; + + return PR_TRUE; +} + +SECStatus SSLInt_AdvanceReadSeqNum(PRFileDesc *fd, PRUint64 to) { + sslSocket *ss; + ssl3CipherSpec *spec; + + ss = ssl_FindSocket(fd); + if (!ss) { + return SECFailure; + } + if (to > RECORD_SEQ_MAX) { + PORT_SetError(SEC_ERROR_INVALID_ARGS); + return SECFailure; + } + ssl_GetSpecWriteLock(ss); + spec = ss->ssl3.crSpec; + spec->nextSeqNum = to; + + /* For DTLS, we need to fix the record sequence number. For this, we can just + * scrub the entire structure on the assumption that the new sequence number + * is far enough past the last received sequence number. */ + if (spec->nextSeqNum <= + spec->recvdRecords.right + DTLS_RECVD_RECORDS_WINDOW) { + PORT_SetError(SEC_ERROR_INVALID_ARGS); + return SECFailure; + } + dtls_RecordSetRecvd(&spec->recvdRecords, spec->nextSeqNum - 1); + + ssl_ReleaseSpecWriteLock(ss); + return SECSuccess; +} + +SECStatus SSLInt_AdvanceWriteSeqNum(PRFileDesc *fd, PRUint64 to) { + sslSocket *ss; + ssl3CipherSpec *spec; + PK11Context *pk11ctxt; + const ssl3BulkCipherDef *cipher_def; + + ss = ssl_FindSocket(fd); + if (!ss) { + return SECFailure; + } + if (to >= RECORD_SEQ_MAX) { + PORT_SetError(SEC_ERROR_INVALID_ARGS); + return SECFailure; + } + ssl_GetSpecWriteLock(ss); + spec = ss->ssl3.cwSpec; + cipher_def = spec->cipherDef; + spec->nextSeqNum = to; + if (cipher_def->type != type_aead) { + ssl_ReleaseSpecWriteLock(ss); + return SECSuccess; + } + /* If we are using aead, we need to advance the counter in the + * internal IV generator as well. + * This could be in the token or software. */ + pk11ctxt = spec->cipherContext; + /* If counter is in the token, we need to switch it to software, + * since we don't have access to the internal state of the token. We do + * that by turning on the simulated message interface, then setting up the + * software IV generator */ + if (pk11ctxt->ivCounter == 0) { + _PK11_ContextSetAEADSimulation(pk11ctxt); + pk11ctxt->ivLen = cipher_def->iv_size + cipher_def->explicit_nonce_size; + pk11ctxt->ivMaxCount = PR_UINT64(0xffffffffffffffff); + if ((cipher_def->explicit_nonce_size == 0) || + (spec->version >= SSL_LIBRARY_VERSION_TLS_1_3)) { + pk11ctxt->ivFixedBits = + (pk11ctxt->ivLen - sizeof(sslSequenceNumber)) * BPB; + pk11ctxt->ivGen = CKG_GENERATE_COUNTER_XOR; + } else { + pk11ctxt->ivFixedBits = cipher_def->iv_size * BPB; + pk11ctxt->ivGen = CKG_GENERATE_COUNTER; + } + /* DTLS included the epoch in the fixed portion of the IV */ + if (IS_DTLS(ss)) { + pk11ctxt->ivFixedBits += 2 * BPB; + } + } + /* now we can update the internal counter (either we are already using + * the software IV generator, or we just switched to it above */ + pk11ctxt->ivCounter = to; + + ssl_ReleaseSpecWriteLock(ss); + return SECSuccess; +} + +SECStatus SSLInt_AdvanceWriteSeqByAWindow(PRFileDesc *fd, PRInt32 extra) { + sslSocket *ss; + sslSequenceNumber to; + + ss = ssl_FindSocket(fd); + if (!ss) { + return SECFailure; + } + ssl_GetSpecReadLock(ss); + to = ss->ssl3.cwSpec->nextSeqNum + DTLS_RECVD_RECORDS_WINDOW + extra; + ssl_ReleaseSpecReadLock(ss); + return SSLInt_AdvanceWriteSeqNum(fd, to); +} + +SECStatus SSLInt_AdvanceDtls13DecryptFailures(PRFileDesc *fd, PRUint64 to) { + sslSocket *ss = ssl_FindSocket(fd); + if (!ss) { + return SECFailure; + } + + ssl_GetSpecWriteLock(ss); + ssl3CipherSpec *spec = ss->ssl3.crSpec; + if (spec->cipherDef->type != type_aead) { + ssl_ReleaseSpecWriteLock(ss); + return SECFailure; + } + + spec->deprotectionFailures = to; + ssl_ReleaseSpecWriteLock(ss); + return SECSuccess; +} + +SSLKEAType SSLInt_GetKEAType(SSLNamedGroup group) { + const sslNamedGroupDef *groupDef = ssl_LookupNamedGroup(group); + if (!groupDef) return ssl_kea_null; + + return groupDef->keaType; +} + +SECStatus SSLInt_SetSocketMaxEarlyDataSize(PRFileDesc *fd, uint32_t size) { + sslSocket *ss; + + ss = ssl_FindSocket(fd); + if (!ss) { + return SECFailure; + } + + /* This only works when resuming. */ + if (!ss->statelessResume) { + PORT_SetError(SEC_INTERNAL_ONLY); + return SECFailure; + } + + /* Modifying both specs allows this to be used on either peer. */ + ssl_GetSpecWriteLock(ss); + ss->ssl3.crSpec->earlyDataRemaining = size; + ss->ssl3.cwSpec->earlyDataRemaining = size; + ssl_ReleaseSpecWriteLock(ss); + + return SECSuccess; +} + +SECStatus SSLInt_HasPendingHandshakeData(PRFileDesc *fd, PRBool *pending) { + sslSocket *ss = ssl_FindSocket(fd); + if (!ss) { + return SECFailure; + } + + ssl_GetSSL3HandshakeLock(ss); + *pending = ss->ssl3.hs.msg_body.len > 0; + ssl_ReleaseSSL3HandshakeLock(ss); + return SECSuccess; +} + +SECStatus SSLInt_SetRawEchConfigForRetry(PRFileDesc *fd, const uint8_t *buf, + size_t len) { + sslSocket *ss = ssl_FindSocket(fd); + if (!ss) { + return SECFailure; + } + + sslEchConfig *cfg = (sslEchConfig *)PR_LIST_HEAD(&ss->echConfigs); + SECITEM_FreeItem(&cfg->raw, PR_FALSE); + SECITEM_AllocItem(NULL, &cfg->raw, len); + PORT_Memcpy(cfg->raw.data, buf, len); + return SECSuccess; +} + +PRBool SSLInt_IsIp(PRUint8 *s, unsigned int len) { return tls13_IsIp(s, len); } diff --git a/security/nss/gtests/ssl_gtest/libssl_internals.h b/security/nss/gtests/ssl_gtest/libssl_internals.h new file mode 100644 index 0000000000..70c6520c54 --- /dev/null +++ b/security/nss/gtests/ssl_gtest/libssl_internals.h @@ -0,0 +1,56 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set ts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this file, + * You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#ifndef libssl_internals_h_ +#define libssl_internals_h_ + +#include <stdint.h> + +#include "prio.h" +#include "seccomon.h" +#include "ssl.h" +#include "sslimpl.h" +#include "sslt.h" + +SECStatus SSLInt_IncrementClientHandshakeVersion(PRFileDesc *fd); + +SECStatus SSLInt_UpdateSSLv2ClientRandom(PRFileDesc *fd, uint8_t *rnd, + size_t rnd_len, uint8_t *msg, + size_t msg_len); +SECStatus SSLInt_GetHandshakeRandoms(PRFileDesc *fd, SSL3Random client_random, + SSL3Random server_random); +PRBool SSLInt_ExtensionNegotiated(PRFileDesc *fd, PRUint16 ext); +void SSLInt_ClearSelfEncryptKey(); +void SSLInt_SetSelfEncryptMacKey(PK11SymKey *key); +PRInt32 SSLInt_CountCipherSpecs(PRFileDesc *fd); +void SSLInt_PrintCipherSpecs(const char *label, PRFileDesc *fd); +SECStatus SSLInt_ShiftDtlsTimers(PRFileDesc *fd, PRIntervalTime shift); +SECStatus SSLInt_SetMTU(PRFileDesc *fd, PRUint16 mtu); +PRBool SSLInt_CheckSecretsDestroyed(PRFileDesc *fd); +PRBool SSLInt_DamageClientHsTrafficSecret(PRFileDesc *fd); +PRBool SSLInt_DamageServerHsTrafficSecret(PRFileDesc *fd); +PRBool SSLInt_DamageEarlyTrafficSecret(PRFileDesc *fd); +SECStatus SSLInt_Set0RttAlpn(PRFileDesc *fd, PRUint8 *data, unsigned int len); +PRBool SSLInt_HasCertWithAuthType(PRFileDesc *fd, SSLAuthType authType); +PRBool SSLInt_SendAlert(PRFileDesc *fd, uint8_t level, uint8_t type); +SECStatus SSLInt_AdvanceDtls13DecryptFailures(PRFileDesc *fd, PRUint64 to); +SECStatus SSLInt_AdvanceWriteSeqNum(PRFileDesc *fd, PRUint64 to); +SECStatus SSLInt_AdvanceReadSeqNum(PRFileDesc *fd, PRUint64 to); +SECStatus SSLInt_AdvanceWriteSeqByAWindow(PRFileDesc *fd, PRInt32 extra); +SSLKEAType SSLInt_GetKEAType(SSLNamedGroup group); +SECStatus SSLInt_HasPendingHandshakeData(PRFileDesc *fd, PRBool *pending); +SECStatus SSLInt_SetSocketMaxEarlyDataSize(PRFileDesc *fd, uint32_t size); +SECStatus SSLInt_TweakChannelInfoForDC(PRFileDesc *fd, PRBool changeAuthKeyBits, + PRBool changeScheme); +SECStatus SSLInt_SetDCAdvertisedSigSchemes(PRFileDesc *fd, + const SSLSignatureScheme *schemes, + uint32_t num_sig_schemes); +SECStatus SSLInt_RemoveServerCertificates(PRFileDesc *fd); +SECStatus SSLInt_SetRawEchConfigForRetry(PRFileDesc *fd, const uint8_t *buf, + size_t len); +PRBool SSLInt_IsIp(PRUint8 *s, unsigned int len); + +#endif // ifndef libssl_internals_h_ diff --git a/security/nss/gtests/ssl_gtest/manifest.mn b/security/nss/gtests/ssl_gtest/manifest.mn new file mode 100644 index 0000000000..af3081e8e4 --- /dev/null +++ b/security/nss/gtests/ssl_gtest/manifest.mn @@ -0,0 +1,77 @@ +# +# This Source Code Form is subject to the terms of the Mozilla Public +# License, v. 2.0. If a copy of the MPL was not distributed with this +# file, You can obtain one at http://mozilla.org/MPL/2.0/. +CORE_DEPTH = ../.. +DEPTH = ../.. +MODULE = nss + +# These sources have access to libssl internals +CSRCS = \ + libssl_internals.c \ + $(NULL) + +CPPSRCS = \ + bloomfilter_unittest.cc \ + ssl_0rtt_unittest.cc \ + ssl_aead_unittest.cc \ + ssl_agent_unittest.cc \ + ssl_auth_unittest.cc \ + ssl_cert_ext_unittest.cc \ + ssl_cipherorder_unittest.cc \ + ssl_ciphersuite_unittest.cc \ + ssl_custext_unittest.cc \ + ssl_damage_unittest.cc \ + ssl_debug_env_unittest.cc \ + ssl_dhe_unittest.cc \ + ssl_drop_unittest.cc \ + ssl_ecdh_unittest.cc \ + ssl_ems_unittest.cc \ + ssl_exporter_unittest.cc \ + ssl_extension_unittest.cc \ + ssl_fragment_unittest.cc \ + ssl_fuzz_unittest.cc \ + ssl_gather_unittest.cc \ + ssl_gtest.cc \ + ssl_hrr_unittest.cc \ + ssl_keyupdate_unittest.cc \ + ssl_loopback_unittest.cc \ + ssl_masking_unittest.cc \ + ssl_misc_unittest.cc \ + ssl_record_unittest.cc \ + ssl_recordsep_unittest.cc \ + ssl_recordsize_unittest.cc \ + ssl_resumption_unittest.cc \ + ssl_renegotiation_unittest.cc \ + ssl_skip_unittest.cc \ + ssl_staticrsa_unittest.cc \ + ssl_tls13compat_unittest.cc \ + ssl_v2_client_hello_unittest.cc \ + ssl_version_unittest.cc \ + ssl_versionpolicy_unittest.cc \ + selfencrypt_unittest.cc \ + test_io.cc \ + tls_agent.cc \ + tls_connect.cc \ + tls_hkdf_unittest.cc \ + tls_filter.cc \ + tls_protect.cc \ + tls_psk_unittest.cc \ + tls_subcerts_unittest.cc \ + tls_ech_unittest.cc \ + $(SSLKEYLOGFILE_FILES) \ + $(NULL) + +INCLUDES += -I$(CORE_DEPTH)/gtests/google_test/gtest/include \ + -I$(CORE_DEPTH)/gtests/common \ + -I$(CORE_DEPTH)/cpputil + +REQUIRES = nspr nss libdbm gtest cpputil + +PROGRAM = ssl_gtest +EXTRA_LIBS += \ + $(DIST)/lib/$(LIB_PREFIX)gtest.$(LIB_SUFFIX) \ + $(DIST)/lib/$(LIB_PREFIX)cpputil.$(LIB_SUFFIX) \ + $(NULL) + +USE_STATIC_LIBS = 1 diff --git a/security/nss/gtests/ssl_gtest/nss_policy.h b/security/nss/gtests/ssl_gtest/nss_policy.h new file mode 100644 index 0000000000..ceab03becc --- /dev/null +++ b/security/nss/gtests/ssl_gtest/nss_policy.h @@ -0,0 +1,107 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set ts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this file, + * You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#ifndef nss_policy_h_ +#define nss_policy_h_ + +#include "prtypes.h" +#include "secoid.h" +#include "nss.h" + +namespace nss_test { + +// container class to hold all a temp policy +class NssPolicy { + public: + NssPolicy() : oid_(SEC_OID_UNKNOWN), set_(0), clear_(0) {} + NssPolicy(SECOidTag _oid, PRUint32 _set, PRUint32 _clear) + : oid_(_oid), set_(_set), clear_(_clear) {} + NssPolicy(const NssPolicy &p) + : oid_(p.oid_), set_(p.set_), clear_(p.clear_) {} + // clone the current policy for this oid + NssPolicy(SECOidTag _oid) : oid_(_oid), set_(0), clear_(0) { + NSS_GetAlgorithmPolicy(_oid, &set_); + clear_ = ~set_; + } + SECOidTag oid(void) const { return oid_; } + PRUint32 set(void) const { return set_; } + PRUint32 clear(void) const { return clear_; } + operator bool() const { return oid_ != SEC_OID_UNKNOWN; } + + private: + SECOidTag oid_; + PRUint32 set_; + PRUint32 clear_; +}; + +// container class to hold a temp option +class NssOption { + public: + NssOption() : id_(-1), value_(0) {} + NssOption(PRInt32 _id, PRInt32 _value) : id_(_id), value_(_value) {} + NssOption(const NssOption &o) : id_(o.id_), value_(o.value_) {} + // clone the current option for this id + NssOption(PRInt32 _id) : id_(_id), value_(0) { NSS_OptionGet(id_, &value_); } + PRInt32 id(void) const { return id_; } + PRInt32 value(void) const { return value_; } + operator bool() const { return id_ != -1; } + + private: + PRInt32 id_; + PRInt32 value_; +}; + +// set the policy indicated in NssPolicy and restor the old policy +// when we go out of scope +class NssManagePolicy { + public: + NssManagePolicy(const NssPolicy &p, const NssOption &o) + : policy_(p), save_policy_(~(PRUint32)0), option_(o), save_option_(0) { + if (p) { + (void)NSS_GetAlgorithmPolicy(p.oid(), &save_policy_); + (void)NSS_SetAlgorithmPolicy(p.oid(), p.set(), p.clear()); + } + if (o) { + (void)NSS_OptionGet(o.id(), &save_option_); + (void)NSS_OptionSet(o.id(), o.value()); + } + } + ~NssManagePolicy() { + if (policy_) { + (void)NSS_SetAlgorithmPolicy(policy_.oid(), save_policy_, ~save_policy_); + } + if (option_) { + (void)NSS_OptionSet(option_.id(), save_option_); + } + } + + private: + NssPolicy policy_; + PRUint32 save_policy_; + NssOption option_; + PRInt32 save_option_; +}; + +// wrapping PRFileDesc this way ensures that tests that attempt to access +// PRFileDesc always correctly apply +// the policy that was bound to that socket with TlsAgent::SetPolicy(). +class NssManagedFileDesc { + public: + NssManagedFileDesc(PRFileDesc *fd, const NssPolicy &policy, + const NssOption &option) + : fd_(fd), managed_policy_(policy, option) {} + PRFileDesc *get(void) const { return fd_; } + operator PRFileDesc *() const { return fd_; } + bool operator==(PRFileDesc *fd) const { return fd_ == fd; } + + private: + PRFileDesc *fd_; + NssManagePolicy managed_policy_; +}; + +} // namespace nss_test + +#endif diff --git a/security/nss/gtests/ssl_gtest/rsa8193.h b/security/nss/gtests/ssl_gtest/rsa8193.h new file mode 100644 index 0000000000..1ac8503bc0 --- /dev/null +++ b/security/nss/gtests/ssl_gtest/rsa8193.h @@ -0,0 +1,209 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set ts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this file, + * You can obtain one at http://mozilla.org/MPL/2.0/. */ + +// openssl req -nodes -x509 -newkey rsa:8193 -out cert.pem -days 365 +static const uint8_t rsa8193[] = { + 0x30, 0x82, 0x09, 0x61, 0x30, 0x82, 0x05, 0x48, 0xa0, 0x03, 0x02, 0x01, + 0x02, 0x02, 0x09, 0x00, 0xaf, 0xff, 0x37, 0x91, 0x3e, 0x44, 0xae, 0x57, + 0x30, 0x0d, 0x06, 0x09, 0x2a, 0x86, 0x48, 0x86, 0xf7, 0x0d, 0x01, 0x01, + 0x0b, 0x05, 0x00, 0x30, 0x45, 0x31, 0x0b, 0x30, 0x09, 0x06, 0x03, 0x55, + 0x04, 0x06, 0x13, 0x02, 0x41, 0x55, 0x31, 0x13, 0x30, 0x11, 0x06, 0x03, + 0x55, 0x04, 0x08, 0x0c, 0x0a, 0x53, 0x6f, 0x6d, 0x65, 0x2d, 0x53, 0x74, + 0x61, 0x74, 0x65, 0x31, 0x21, 0x30, 0x1f, 0x06, 0x03, 0x55, 0x04, 0x0a, + 0x0c, 0x18, 0x49, 0x6e, 0x74, 0x65, 0x72, 0x6e, 0x65, 0x74, 0x20, 0x57, + 0x69, 0x64, 0x67, 0x69, 0x74, 0x73, 0x20, 0x50, 0x74, 0x79, 0x20, 0x4c, + 0x74, 0x64, 0x30, 0x1e, 0x17, 0x0d, 0x31, 0x38, 0x30, 0x35, 0x31, 0x37, + 0x30, 0x39, 0x34, 0x32, 0x32, 0x39, 0x5a, 0x17, 0x0d, 0x31, 0x39, 0x30, + 0x35, 0x31, 0x37, 0x30, 0x39, 0x34, 0x32, 0x32, 0x39, 0x5a, 0x30, 0x45, + 0x31, 0x0b, 0x30, 0x09, 0x06, 0x03, 0x55, 0x04, 0x06, 0x13, 0x02, 0x41, + 0x55, 0x31, 0x13, 0x30, 0x11, 0x06, 0x03, 0x55, 0x04, 0x08, 0x0c, 0x0a, + 0x53, 0x6f, 0x6d, 0x65, 0x2d, 0x53, 0x74, 0x61, 0x74, 0x65, 0x31, 0x21, + 0x30, 0x1f, 0x06, 0x03, 0x55, 0x04, 0x0a, 0x0c, 0x18, 0x49, 0x6e, 0x74, + 0x65, 0x72, 0x6e, 0x65, 0x74, 0x20, 0x57, 0x69, 0x64, 0x67, 0x69, 0x74, + 0x73, 0x20, 0x50, 0x74, 0x79, 0x20, 0x4c, 0x74, 0x64, 0x30, 0x82, 0x04, + 0x22, 0x30, 0x0d, 0x06, 0x09, 0x2a, 0x86, 0x48, 0x86, 0xf7, 0x0d, 0x01, + 0x01, 0x01, 0x05, 0x00, 0x03, 0x82, 0x04, 0x0f, 0x00, 0x30, 0x82, 0x04, + 0x0a, 0x02, 0x82, 0x04, 0x01, 0x01, 0x77, 0xd6, 0xa9, 0x93, 0x4e, 0x15, + 0xb5, 0x67, 0x70, 0x8e, 0xc3, 0x77, 0x4f, 0xc9, 0x8a, 0x06, 0xd9, 0xb9, + 0xa6, 0x41, 0xb8, 0xfa, 0x4a, 0x13, 0x26, 0xdc, 0x2b, 0xc5, 0x82, 0xa0, + 0x74, 0x8c, 0x1e, 0xe9, 0xc0, 0x70, 0x15, 0x56, 0xec, 0x1f, 0x7e, 0x91, + 0x6e, 0x31, 0x42, 0x8b, 0xd5, 0xe2, 0x0e, 0x9c, 0xeb, 0xff, 0xbc, 0xf9, + 0x42, 0xd3, 0xb9, 0x1c, 0x5e, 0x46, 0x80, 0x90, 0x5f, 0xe1, 0x59, 0x22, + 0x13, 0x71, 0xd3, 0xd6, 0x66, 0x7a, 0xe0, 0x56, 0x04, 0x10, 0x59, 0x01, + 0xb3, 0xb6, 0xd2, 0xc7, 0xa7, 0x3b, 0xbc, 0xe6, 0x38, 0x44, 0xd5, 0x71, + 0x66, 0x1d, 0xb2, 0x63, 0x2f, 0xa9, 0x5e, 0x80, 0x92, 0x3c, 0x21, 0x0e, + 0xe1, 0xda, 0xd6, 0x1d, 0xcb, 0xce, 0xac, 0xe1, 0x5f, 0x97, 0x45, 0x8f, + 0xc1, 0x64, 0x16, 0xa6, 0x88, 0x2a, 0x36, 0x4a, 0x76, 0x64, 0x8f, 0x83, + 0x7a, 0x1d, 0xd8, 0x91, 0x90, 0x7b, 0x58, 0xb8, 0x1c, 0x7f, 0x56, 0x57, + 0x35, 0xfb, 0xf3, 0x1a, 0xcb, 0x7c, 0x66, 0x66, 0x04, 0x95, 0xee, 0x3a, + 0x80, 0xf0, 0xd4, 0x12, 0x3a, 0x7e, 0x7e, 0x5e, 0xb8, 0x55, 0x29, 0x23, + 0x06, 0xd3, 0x85, 0x0c, 0x99, 0x91, 0x42, 0xee, 0x5a, 0x30, 0x7f, 0x52, + 0x20, 0xb3, 0xe2, 0xe7, 0x39, 0x69, 0xb6, 0xfc, 0x42, 0x1e, 0x98, 0xd3, + 0x31, 0xa2, 0xfa, 0x81, 0x52, 0x69, 0x6d, 0x23, 0xf8, 0xc4, 0xc3, 0x3c, + 0x9b, 0x48, 0x75, 0xa8, 0xc7, 0xe7, 0x61, 0x81, 0x1f, 0xf7, 0xce, 0x10, + 0xaa, 0x13, 0xcb, 0x6e, 0x19, 0xc0, 0x4f, 0x6f, 0x90, 0xa8, 0x41, 0xea, + 0x49, 0xdf, 0xe4, 0xef, 0x84, 0x54, 0xb5, 0x37, 0xaf, 0x12, 0x75, 0x1a, + 0x11, 0x4b, 0x58, 0x7f, 0x63, 0x22, 0x33, 0xb1, 0xc8, 0x4d, 0xf2, 0x41, + 0x10, 0xbc, 0x37, 0xb5, 0xd5, 0xb2, 0x21, 0x32, 0x35, 0x9d, 0xf3, 0x8d, + 0xab, 0x66, 0x9d, 0x19, 0x12, 0x71, 0x45, 0xb3, 0x82, 0x5a, 0x5c, 0xff, + 0x2d, 0xcf, 0xf4, 0x5b, 0x56, 0xb8, 0x08, 0xb3, 0xd2, 0x43, 0x8c, 0xac, + 0xd2, 0xf8, 0xcc, 0x6d, 0x90, 0x97, 0xff, 0x12, 0x74, 0x97, 0xf8, 0xa4, + 0xe3, 0x95, 0xae, 0x92, 0xdc, 0x7e, 0x9d, 0x2b, 0xb4, 0x94, 0xc3, 0x8d, + 0x80, 0xe7, 0x77, 0x5c, 0x5b, 0xbb, 0x43, 0xdc, 0xa6, 0xe9, 0xbe, 0x20, + 0xcc, 0x9d, 0x8e, 0xa4, 0x2b, 0xf2, 0x72, 0xdc, 0x44, 0x61, 0x0f, 0xad, + 0x1a, 0x5e, 0xa5, 0x48, 0xe4, 0x42, 0xc5, 0xe4, 0xf1, 0x6d, 0x33, 0xdb, + 0xb2, 0x1b, 0x9f, 0xb2, 0xff, 0x18, 0x0e, 0x62, 0x35, 0x99, 0xed, 0x22, + 0x19, 0x4a, 0x5e, 0xb3, 0x3c, 0x07, 0x8f, 0x6e, 0x22, 0x5b, 0x16, 0x4a, + 0x9f, 0xef, 0xf3, 0xe7, 0xd6, 0x48, 0xe1, 0xb4, 0x3b, 0xab, 0x1b, 0x9e, + 0x53, 0xd7, 0x1b, 0xd9, 0x2d, 0x51, 0x8f, 0xe4, 0x1c, 0xab, 0xdd, 0xb9, + 0xe2, 0xee, 0xe4, 0xdd, 0x60, 0x04, 0x86, 0x6b, 0x4e, 0x7a, 0xc8, 0x09, + 0x51, 0xd1, 0x9b, 0x36, 0x9a, 0x36, 0x7f, 0xe8, 0x6b, 0x09, 0x6c, 0xee, + 0xad, 0x3a, 0x2f, 0xa8, 0x63, 0x92, 0x23, 0x2f, 0x7e, 0x00, 0xe2, 0xd1, + 0xbb, 0xd9, 0x5b, 0x5b, 0xfa, 0x4b, 0x83, 0x00, 0x19, 0x28, 0xfb, 0x7e, + 0xfe, 0x58, 0xab, 0xb7, 0x33, 0x45, 0x8f, 0x75, 0x9a, 0x54, 0x3d, 0x77, + 0x06, 0x75, 0x61, 0x4f, 0x5c, 0x93, 0xa0, 0xf9, 0xe8, 0xcf, 0xf6, 0x04, + 0x14, 0xda, 0x1b, 0x2e, 0x79, 0x35, 0xb8, 0xb4, 0xfa, 0x08, 0x27, 0x9a, + 0x03, 0x70, 0x78, 0x97, 0x8f, 0xae, 0x2e, 0xd5, 0x1c, 0xe0, 0x4d, 0x91, + 0x3a, 0xfe, 0x1a, 0x64, 0xd8, 0x49, 0xdf, 0x6c, 0x66, 0xac, 0xc9, 0x57, + 0x06, 0x72, 0xc0, 0xc0, 0x09, 0x71, 0x6a, 0xd0, 0xb0, 0x7d, 0x35, 0x3f, + 0x53, 0x17, 0x49, 0x38, 0x92, 0x22, 0x55, 0xf6, 0x58, 0x56, 0xa2, 0x42, + 0x77, 0x94, 0xb7, 0x28, 0x0a, 0xa0, 0xd2, 0xda, 0x25, 0xc1, 0xcc, 0x52, + 0x51, 0xd6, 0xba, 0x18, 0x0f, 0x0d, 0xe3, 0x7d, 0xd1, 0xda, 0xd9, 0x0c, + 0x5e, 0x3a, 0xca, 0xe9, 0xf1, 0xf5, 0x65, 0xfc, 0xc3, 0x99, 0x72, 0x25, + 0xf2, 0xc0, 0xa1, 0x8c, 0x43, 0x9d, 0xb2, 0xc9, 0xb1, 0x1a, 0x24, 0x34, + 0x57, 0xd8, 0xa7, 0x52, 0xa3, 0x39, 0x6e, 0x0b, 0xec, 0xbd, 0x5e, 0xc9, + 0x1f, 0x74, 0xed, 0xae, 0xe6, 0x4e, 0x49, 0xe8, 0x87, 0x3e, 0x46, 0x0d, + 0x40, 0x30, 0xda, 0x9d, 0xcf, 0xf5, 0x03, 0x1f, 0x38, 0x29, 0x3b, 0x66, + 0xe5, 0xc0, 0x89, 0x4c, 0xfc, 0x09, 0x62, 0x37, 0x01, 0xf9, 0x01, 0xab, + 0x8d, 0x53, 0x9c, 0x36, 0x5d, 0x36, 0x66, 0x8d, 0x87, 0xf4, 0xab, 0x37, + 0xb7, 0xf7, 0xe3, 0xdf, 0xc1, 0x52, 0xc0, 0x1d, 0x09, 0x92, 0x21, 0x47, + 0x49, 0x9a, 0x19, 0x38, 0x05, 0x62, 0xf3, 0x47, 0x80, 0x89, 0x1e, 0x70, + 0xa1, 0x57, 0xb7, 0x72, 0xd0, 0x41, 0x7a, 0x5c, 0x6a, 0x13, 0x8b, 0x6c, + 0xda, 0xdf, 0x6b, 0x01, 0x15, 0x20, 0xfa, 0xc8, 0x67, 0xee, 0xb2, 0x13, + 0xd8, 0x5f, 0x84, 0x30, 0x44, 0x8e, 0xf9, 0x2a, 0xae, 0x17, 0x53, 0x49, + 0xaa, 0x34, 0x31, 0x12, 0x31, 0xec, 0xf3, 0x25, 0x27, 0x53, 0x6b, 0xb5, + 0x63, 0xa6, 0xbc, 0xf1, 0x77, 0xd4, 0xb4, 0x77, 0xd1, 0xee, 0xad, 0x62, + 0x9d, 0x2c, 0x2e, 0x11, 0x0a, 0xd1, 0x87, 0xfe, 0xef, 0x77, 0x0e, 0xd1, + 0x38, 0xfe, 0xcc, 0x88, 0xaa, 0x1c, 0x06, 0x93, 0x25, 0x56, 0xfe, 0x0c, + 0x52, 0xe9, 0x7f, 0x4c, 0x3b, 0x2a, 0xfb, 0x40, 0x62, 0x29, 0x0a, 0x1d, + 0x58, 0x78, 0x8b, 0x09, 0x25, 0xaa, 0xc6, 0x8f, 0x66, 0x8f, 0xd1, 0x93, + 0x5a, 0xd6, 0x68, 0x35, 0x69, 0x13, 0x5d, 0x42, 0x35, 0x95, 0xcb, 0xc4, + 0xec, 0x17, 0x92, 0x96, 0xcb, 0x4a, 0xb9, 0x8f, 0xe5, 0xc4, 0x4a, 0xe7, + 0x54, 0x52, 0x4c, 0x64, 0x06, 0xac, 0x2f, 0x13, 0x32, 0x02, 0x47, 0x13, + 0x5c, 0xa2, 0x66, 0xdc, 0x36, 0x0c, 0x4f, 0xbb, 0x89, 0x58, 0x85, 0x16, + 0xf1, 0xf1, 0xff, 0xd2, 0x86, 0x54, 0x29, 0xb3, 0x7e, 0x2a, 0xbd, 0xf9, + 0x53, 0x8c, 0xa0, 0x60, 0x60, 0xb2, 0x90, 0x7f, 0x3a, 0x11, 0x5f, 0x2a, + 0x50, 0x74, 0x2a, 0xd1, 0x68, 0x78, 0xdb, 0x31, 0x1b, 0x8b, 0xee, 0xee, + 0x18, 0x97, 0xf3, 0x50, 0x84, 0xc1, 0x8f, 0xe1, 0xc6, 0x01, 0xb4, 0x16, + 0x65, 0x25, 0x0c, 0x03, 0xab, 0xed, 0x4f, 0xd6, 0xe6, 0x16, 0x23, 0xcc, + 0x42, 0x93, 0xff, 0xfa, 0x92, 0x63, 0x33, 0x9e, 0x36, 0xb0, 0xdc, 0x9a, + 0xb6, 0xaa, 0xd7, 0x48, 0xfe, 0x27, 0x01, 0xcf, 0x67, 0xc0, 0x75, 0xa0, + 0x86, 0x9a, 0xec, 0xa7, 0x2e, 0xb8, 0x7b, 0x00, 0x7f, 0xd4, 0xe3, 0xb3, + 0xfc, 0x48, 0xab, 0x50, 0x20, 0xd4, 0x0d, 0x58, 0x26, 0xc0, 0x3c, 0x09, + 0x0b, 0x80, 0x9e, 0xaf, 0x14, 0x3c, 0x0c, 0x6e, 0x69, 0xbc, 0x6c, 0x4e, + 0x50, 0x33, 0xb0, 0x07, 0x64, 0x6e, 0x77, 0x96, 0xc2, 0xe6, 0x3b, 0xd7, + 0xfe, 0xdc, 0xa4, 0x2f, 0x18, 0x5b, 0x53, 0xe5, 0xdd, 0xb6, 0xce, 0xeb, + 0x16, 0xb4, 0x25, 0xc6, 0xcb, 0xf2, 0x65, 0x3c, 0x4f, 0x94, 0xa5, 0x11, + 0x18, 0xeb, 0x7b, 0x62, 0x1d, 0xd5, 0x02, 0x35, 0x76, 0xf6, 0xb5, 0x20, + 0x27, 0x21, 0x9b, 0xab, 0xf4, 0xb6, 0x8f, 0x1a, 0x70, 0x1d, 0x12, 0xe3, + 0xb9, 0x8e, 0x29, 0x52, 0x25, 0xf4, 0xba, 0xb4, 0x25, 0x2c, 0x91, 0x11, + 0xf2, 0xae, 0x7b, 0xbe, 0xb6, 0x67, 0xd6, 0x08, 0xf8, 0x6f, 0xe7, 0xb0, + 0x16, 0xc5, 0xf6, 0xd5, 0xfb, 0x07, 0x71, 0x5b, 0x0e, 0xe1, 0x02, 0x03, + 0x01, 0x00, 0x01, 0xa3, 0x53, 0x30, 0x51, 0x30, 0x1d, 0x06, 0x03, 0x55, + 0x1d, 0x0e, 0x04, 0x16, 0x04, 0x14, 0xaa, 0xe7, 0x7f, 0xcf, 0xf8, 0xb4, + 0xe0, 0x8d, 0x39, 0x9a, 0x1d, 0x4f, 0x86, 0xa2, 0xac, 0x56, 0x32, 0xd9, + 0x58, 0xe3, 0x30, 0x1f, 0x06, 0x03, 0x55, 0x1d, 0x23, 0x04, 0x18, 0x30, + 0x16, 0x80, 0x14, 0xaa, 0xe7, 0x7f, 0xcf, 0xf8, 0xb4, 0xe0, 0x8d, 0x39, + 0x9a, 0x1d, 0x4f, 0x86, 0xa2, 0xac, 0x56, 0x32, 0xd9, 0x58, 0xe3, 0x30, + 0x0f, 0x06, 0x03, 0x55, 0x1d, 0x13, 0x01, 0x01, 0xff, 0x04, 0x05, 0x30, + 0x03, 0x01, 0x01, 0xff, 0x30, 0x0d, 0x06, 0x09, 0x2a, 0x86, 0x48, 0x86, + 0xf7, 0x0d, 0x01, 0x01, 0x0b, 0x05, 0x00, 0x03, 0x82, 0x04, 0x02, 0x00, + 0x00, 0x0a, 0x0a, 0x81, 0xb5, 0x2e, 0xac, 0x52, 0xab, 0x0f, 0xeb, 0xad, + 0x96, 0xd6, 0xd6, 0x59, 0x8f, 0x55, 0x15, 0x56, 0x70, 0xda, 0xd5, 0x75, + 0x47, 0x12, 0x9a, 0x0e, 0xd1, 0x65, 0x68, 0xe0, 0x51, 0x89, 0x59, 0xcc, + 0xe3, 0x5a, 0x1b, 0x85, 0x14, 0xa3, 0x1d, 0x9b, 0x3f, 0xd1, 0xa4, 0x42, + 0xb0, 0x89, 0x12, 0x93, 0xd3, 0x54, 0x19, 0x04, 0xa2, 0xaf, 0xaa, 0x60, + 0xca, 0x03, 0xc2, 0xae, 0x62, 0x8c, 0xb6, 0x31, 0x03, 0xd6, 0xa5, 0xf3, + 0x5e, 0x8d, 0x5c, 0x69, 0x4c, 0x7d, 0x81, 0x49, 0x20, 0x25, 0x41, 0xa4, + 0x2a, 0x95, 0x87, 0x36, 0xa3, 0x9b, 0x9e, 0x9f, 0xed, 0x85, 0xf3, 0xb1, + 0xf1, 0xe9, 0x1b, 0xbb, 0xe3, 0xbc, 0x3b, 0x11, 0x36, 0xca, 0xb9, 0x5f, + 0xee, 0x64, 0xde, 0x2a, 0x99, 0x27, 0x91, 0xc0, 0x54, 0x9e, 0x7a, 0xd4, + 0x89, 0x8c, 0xa0, 0xe3, 0xfd, 0x44, 0x6f, 0x02, 0x38, 0x3c, 0xee, 0x52, + 0x48, 0x1b, 0xd4, 0x25, 0x2b, 0xcb, 0x8e, 0xa8, 0x1b, 0x09, 0xd6, 0x30, + 0x51, 0x15, 0x6c, 0x5c, 0x03, 0x76, 0xad, 0x64, 0x45, 0x50, 0xa2, 0xe1, + 0x3c, 0x5a, 0x67, 0x87, 0xff, 0x8c, 0xed, 0x9a, 0x8d, 0x04, 0xc1, 0xac, + 0xf9, 0xca, 0xf5, 0x2a, 0x05, 0x9c, 0xdd, 0x78, 0xce, 0x99, 0x78, 0x7b, + 0xcd, 0x43, 0x10, 0x40, 0xf7, 0xb5, 0x27, 0x12, 0xec, 0xe9, 0xb2, 0x3f, + 0xf4, 0x5d, 0xd9, 0xbb, 0xf8, 0xc4, 0xc9, 0xa4, 0x46, 0x20, 0x41, 0x7f, + 0xeb, 0x79, 0xb0, 0x51, 0x8c, 0xf7, 0xc3, 0x2c, 0x16, 0xfe, 0x42, 0x59, + 0x77, 0xfe, 0x53, 0xfe, 0x19, 0x57, 0x58, 0x44, 0x6d, 0x12, 0xe2, 0x95, + 0xd0, 0xd3, 0x5a, 0xb5, 0x2d, 0xe5, 0x7e, 0xb4, 0xb3, 0xa9, 0xcc, 0x7d, + 0x53, 0x77, 0x81, 0x01, 0x0f, 0x0a, 0xf6, 0x86, 0x3c, 0x7d, 0xb5, 0x2c, + 0xbf, 0x62, 0xc3, 0xf5, 0x38, 0x89, 0x13, 0x84, 0x1f, 0x44, 0x2d, 0x87, + 0x5c, 0x23, 0x9e, 0x05, 0x62, 0x56, 0x3d, 0x71, 0x4d, 0xd0, 0xe3, 0x15, + 0xe9, 0x09, 0x9c, 0x1a, 0xc0, 0x9a, 0x19, 0x8b, 0x9c, 0xe9, 0xae, 0xde, + 0x62, 0x05, 0x23, 0xe2, 0xd0, 0x3f, 0xf5, 0xef, 0x04, 0x96, 0x4c, 0x87, + 0x34, 0x2f, 0xd5, 0x90, 0xde, 0xbf, 0x4b, 0x56, 0x12, 0x5f, 0xc6, 0xdc, + 0xa4, 0x1c, 0xc4, 0x53, 0x0c, 0xf9, 0xb4, 0xe4, 0x2c, 0xe7, 0x48, 0xbd, + 0xb1, 0xac, 0xf1, 0xc1, 0x8d, 0x53, 0x47, 0x84, 0xc0, 0x78, 0x0a, 0x5e, + 0xc2, 0x16, 0xff, 0xef, 0x97, 0x5b, 0x33, 0x85, 0x92, 0xcd, 0xd4, 0xbb, + 0x64, 0xee, 0xed, 0x17, 0x18, 0x43, 0x32, 0x99, 0x32, 0x36, 0x25, 0xf4, + 0x21, 0x3c, 0x2f, 0x55, 0xdc, 0x16, 0x06, 0x4d, 0x86, 0xa3, 0xa9, 0x34, + 0x22, 0xd5, 0xc3, 0xc8, 0x64, 0x3c, 0x4e, 0x3a, 0x69, 0xbd, 0xcf, 0xd7, + 0xee, 0x3f, 0x0d, 0x15, 0xeb, 0xfb, 0xbd, 0x91, 0x7f, 0xef, 0x48, 0xec, + 0x86, 0xb2, 0x78, 0xf7, 0x53, 0x90, 0x38, 0xb5, 0x04, 0x9c, 0xb7, 0xd7, + 0x9e, 0xaa, 0x15, 0xf7, 0xcd, 0xc2, 0x17, 0xd5, 0x8f, 0x82, 0x98, 0xa3, + 0xaf, 0x59, 0xf1, 0x71, 0xda, 0x6e, 0xaf, 0x97, 0x6d, 0x77, 0x72, 0xfd, + 0xa8, 0x80, 0x25, 0xce, 0x46, 0x04, 0x6e, 0x40, 0x15, 0x24, 0xc0, 0xf9, + 0xbf, 0x13, 0x16, 0x72, 0xcb, 0xb7, 0x10, 0xc7, 0x0a, 0xd6, 0x66, 0x96, + 0x5b, 0x27, 0x4d, 0x66, 0xc4, 0x2f, 0x21, 0x90, 0x9f, 0x8c, 0x24, 0xa0, + 0x0e, 0xa2, 0x89, 0x92, 0xd2, 0x44, 0x63, 0x06, 0xb2, 0xab, 0x07, 0x26, + 0xde, 0x03, 0x1d, 0xdb, 0x2a, 0x42, 0x5b, 0x4c, 0xf6, 0xfe, 0x53, 0xfa, + 0x80, 0x45, 0x8d, 0x75, 0xf6, 0x0e, 0x1d, 0xcc, 0x4c, 0x3b, 0xb0, 0x80, + 0x6d, 0x4c, 0xed, 0x7c, 0xe0, 0xd2, 0xe7, 0x62, 0x59, 0xb1, 0x5a, 0x5d, + 0x3a, 0xec, 0x86, 0x04, 0xfe, 0x26, 0xd1, 0x18, 0xed, 0x56, 0x7d, 0x67, + 0x56, 0x24, 0x6d, 0x7c, 0x6e, 0x8f, 0xc8, 0xa0, 0xba, 0x42, 0x0a, 0x33, + 0x38, 0x7a, 0x09, 0x03, 0xc2, 0xbf, 0x9b, 0x01, 0xdd, 0x03, 0x5a, 0xba, + 0x76, 0x04, 0xb1, 0xc3, 0x40, 0x23, 0x53, 0xbd, 0x64, 0x4e, 0x0f, 0xe7, + 0xc3, 0x4e, 0x48, 0xea, 0x19, 0x2b, 0x1c, 0xe4, 0x3d, 0x93, 0xd8, 0xf6, + 0xfb, 0xda, 0x3d, 0xeb, 0xed, 0xc2, 0xbd, 0x14, 0x57, 0x40, 0xde, 0xd1, + 0x74, 0x54, 0x1b, 0xa8, 0x39, 0xda, 0x73, 0x56, 0xd4, 0xbe, 0xab, 0xec, + 0xc7, 0x17, 0x4f, 0x91, 0xb6, 0xf6, 0xcb, 0x24, 0xc6, 0x1c, 0x07, 0xc4, + 0xf3, 0xd0, 0x5e, 0x8d, 0xfa, 0x44, 0x98, 0x5c, 0x87, 0x36, 0x75, 0xb6, + 0xa5, 0x31, 0xaa, 0xab, 0x7d, 0x38, 0x66, 0xb3, 0x18, 0x58, 0x65, 0x97, + 0x06, 0xfd, 0x61, 0x81, 0x71, 0xc5, 0x17, 0x8b, 0x19, 0x03, 0xc8, 0x58, + 0xec, 0x05, 0xca, 0x7b, 0x0f, 0xec, 0x9d, 0xb4, 0xbc, 0xa3, 0x20, 0x2e, + 0xf8, 0xe4, 0xb1, 0x82, 0xdc, 0x5a, 0xd2, 0x92, 0x9c, 0x43, 0x5d, 0x16, + 0x5b, 0x90, 0x80, 0xe4, 0xfb, 0x6e, 0x24, 0x6b, 0x8c, 0x1a, 0x35, 0xab, + 0xbd, 0x77, 0x7f, 0xf9, 0x61, 0x80, 0xa5, 0xab, 0xa3, 0x39, 0xc2, 0xc9, + 0x69, 0x3c, 0xfc, 0xb3, 0x9a, 0x05, 0x45, 0x03, 0x88, 0x8f, 0x8e, 0x23, + 0xf2, 0x0c, 0x4c, 0x54, 0xb9, 0x40, 0x3a, 0x31, 0x1a, 0x22, 0x67, 0x43, + 0x4a, 0x3e, 0xa0, 0x8c, 0x2d, 0x4d, 0x4f, 0xfc, 0xb5, 0x9b, 0x1f, 0xe1, + 0xef, 0x02, 0x54, 0xab, 0x8d, 0x75, 0x4d, 0x93, 0xba, 0x76, 0xe1, 0xbc, + 0x42, 0x7f, 0x6c, 0xcb, 0xf5, 0x47, 0xd6, 0x8a, 0xac, 0x5d, 0xe9, 0xbb, + 0x3a, 0x65, 0x2c, 0x81, 0xe5, 0xff, 0x27, 0x7e, 0x60, 0x64, 0x80, 0x42, + 0x8d, 0x36, 0x6b, 0x07, 0x76, 0x6a, 0xf1, 0xdf, 0x96, 0x17, 0x93, 0x21, + 0x5d, 0xe4, 0x6c, 0xce, 0x1c, 0xb9, 0x82, 0x45, 0x05, 0x61, 0xe2, 0x41, + 0x96, 0x03, 0x7d, 0x10, 0x8b, 0x3e, 0xc7, 0xe5, 0xcf, 0x08, 0xeb, 0x81, + 0xd3, 0x82, 0x1b, 0x04, 0x96, 0x93, 0x5a, 0xe2, 0x8c, 0x8e, 0x50, 0x33, + 0xf6, 0xf9, 0xf0, 0xfb, 0xb1, 0xd7, 0xc6, 0x97, 0xaa, 0xef, 0x0b, 0x87, + 0xe1, 0x34, 0x97, 0x78, 0x2e, 0x7c, 0x46, 0x11, 0xd5, 0x3c, 0xec, 0x38, + 0x70, 0x59, 0x14, 0x65, 0x4d, 0x0e, 0xd1, 0xeb, 0x49, 0xb3, 0x99, 0x6f, + 0x87, 0xf1, 0x79, 0x21, 0xd9, 0x5c, 0x37, 0xb2, 0xfe, 0xc4, 0x7a, 0xc1, + 0x67, 0xbd, 0x02, 0xfc, 0x02, 0xab, 0x2f, 0xf5, 0x0f, 0xa7, 0xae, 0x90, + 0xc2, 0xaf, 0xdb, 0xd1, 0x96, 0xb2, 0x92, 0x5a, 0xfb, 0xca, 0x28, 0x74, + 0x17, 0xed, 0xda, 0x2c, 0x9f, 0xb4, 0x2d, 0xf5, 0x71, 0x20, 0x64, 0x2d, + 0x44, 0xe5, 0xa3, 0xa0, 0x94, 0x6f, 0x20, 0xb3, 0x73, 0x96, 0x40, 0x06, + 0x9b, 0x25, 0x47, 0x4b, 0xe0, 0x63, 0x91, 0xd9, 0xda, 0xf3, 0xc3, 0xe5, + 0x3a, 0x3c, 0xb7, 0x5f, 0xab, 0x1e, 0x51, 0x17, 0x4f, 0xec, 0xc1, 0x6d, + 0x82, 0x79, 0x8e, 0xba, 0x7c, 0x47, 0x8e, 0x99, 0x00, 0x17, 0x9e, 0xda, + 0x10, 0x42, 0x70, 0x25, 0x42, 0x84, 0xc8, 0xb1, 0x95, 0x56, 0xb2, 0x08, + 0xa0, 0x4f, 0xdc, 0xcd, 0x9e, 0x31, 0x4b, 0x0c, 0x0b, 0x03, 0x5d, 0x2c, + 0x26, 0xbc, 0xa9, 0x4b, 0x19, 0xdf, 0x90, 0x01, 0x9a, 0xe0, 0x06, 0x05, + 0x13, 0x34, 0x9d, 0x34, 0xb8, 0xef, 0x13, 0x3a, 0x20, 0xf5, 0x74, 0x02, + 0x70, 0x3b, 0x41, 0x60, 0x1f, 0x5e, 0x76, 0x0a, 0xb1, 0x17, 0xd5, 0xcf, + 0x79, 0xef, 0xf7, 0xab, 0xe7, 0xd6, 0x0f, 0xad, 0x85, 0x2c, 0x52, 0x67, + 0xb5, 0xa0, 0x4a, 0xfd, 0xaf}; diff --git a/security/nss/gtests/ssl_gtest/selfencrypt_unittest.cc b/security/nss/gtests/ssl_gtest/selfencrypt_unittest.cc new file mode 100644 index 0000000000..24f000454b --- /dev/null +++ b/security/nss/gtests/ssl_gtest/selfencrypt_unittest.cc @@ -0,0 +1,281 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set ts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this file, + * You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#include <functional> +#include <memory> +#include "nss.h" +#include "pk11pub.h" +#include "prerror.h" +#include "secerr.h" +#include "ssl.h" +#include "sslerr.h" +extern "C" { +#include "sslimpl.h" +#include "selfencrypt.h" +} + +#include "databuffer.h" +#include "gtest_utils.h" +#include "nss_scoped_ptrs.h" + +namespace nss_test { + +static const uint8_t kAesKey1Buf[] = {0x00, 0x01, 0x02, 0x03, 0x04, 0x05, + 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, + 0x0c, 0x0d, 0x0e, 0x0f}; +static const DataBuffer kAesKey1(kAesKey1Buf, sizeof(kAesKey1Buf)); + +static const uint8_t kAesKey2Buf[] = {0x10, 0x11, 0x12, 0x13, 0x14, 0x15, + 0x16, 0x17, 0x18, 0x19, 0x1a, 0x1b, + 0x1c, 0x1d, 0x1e, 0x1f}; +static const DataBuffer kAesKey2(kAesKey2Buf, sizeof(kAesKey2Buf)); + +static const uint8_t kHmacKey1Buf[] = { + 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, + 0x0b, 0x0c, 0x0d, 0x0e, 0x0f, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, + 0x16, 0x17, 0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f}; +static const DataBuffer kHmacKey1(kHmacKey1Buf, sizeof(kHmacKey1Buf)); + +static const uint8_t kHmacKey2Buf[] = { + 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, 0x19, 0x1a, + 0x1b, 0x1c, 0x1d, 0x1e, 0x1f, 0x20, 0x21, 0x22, 0x23, 0x24, 0x25, + 0x26, 0x27, 0x28, 0x29, 0x2a, 0x2b, 0x2c, 0x2d, 0x2e, 0x2f}; +static const DataBuffer kHmacKey2(kHmacKey2Buf, sizeof(kHmacKey2Buf)); + +static const uint8_t* kKeyName1 = + reinterpret_cast<const unsigned char*>("KEY1KEY1KEY1KEY1"); +static const uint8_t* kKeyName2 = + reinterpret_cast<const uint8_t*>("KEY2KEY2KEY2KEY2"); + +static void ImportKey(const DataBuffer& key, PK11SlotInfo* slot, + CK_MECHANISM_TYPE mech, CK_ATTRIBUTE_TYPE cka, + ScopedPK11SymKey* to) { + SECItem key_item = {siBuffer, const_cast<uint8_t*>(key.data()), + static_cast<unsigned int>(key.len())}; + + PK11SymKey* inner = + PK11_ImportSymKey(slot, mech, PK11_OriginUnwrap, cka, &key_item, nullptr); + ASSERT_NE(nullptr, inner); + to->reset(inner); +} + +extern "C" { +extern char ssl_trace; +extern FILE* ssl_trace_iob; +} + +class SelfEncryptTestBase : public ::testing::Test { + public: + SelfEncryptTestBase(size_t message_size) + : aes1_(), + aes2_(), + hmac1_(), + hmac2_(), + message_(), + slot_(PK11_GetInternalSlot()) { + EXPECT_NE(nullptr, slot_); + char* ev = getenv("SSLTRACE"); + if (ev && ev[0]) { + ssl_trace = atoi(ev); + ssl_trace_iob = stderr; + } + message_.Allocate(message_size); + for (size_t i = 0; i < message_.len(); ++i) { + message_.data()[i] = i; + } + } + + void SetUp() { + message_.Allocate(100); + for (size_t i = 0; i < 100; ++i) { + message_.data()[i] = i; + } + ImportKey(kAesKey1, slot_.get(), CKM_AES_CBC, CKA_ENCRYPT, &aes1_); + ImportKey(kAesKey2, slot_.get(), CKM_AES_CBC, CKA_ENCRYPT, &aes2_); + ImportKey(kHmacKey1, slot_.get(), CKM_SHA256_HMAC, CKA_SIGN, &hmac1_); + ImportKey(kHmacKey2, slot_.get(), CKM_SHA256_HMAC, CKA_SIGN, &hmac2_); + } + + void SelfTest( + const uint8_t* writeKeyName, const ScopedPK11SymKey& writeAes, + const ScopedPK11SymKey& writeHmac, const uint8_t* readKeyName, + const ScopedPK11SymKey& readAes, const ScopedPK11SymKey& readHmac, + PRErrorCode protect_error_code = 0, PRErrorCode unprotect_error_code = 0, + std::function<void(uint8_t* ciphertext, unsigned int* ciphertext_len)> + mutate = nullptr) { + uint8_t ciphertext[1000]; + unsigned int ciphertext_len; + uint8_t plaintext[1000]; + unsigned int plaintext_len; + + SECStatus rv = ssl_SelfEncryptProtectInt( + writeAes.get(), writeHmac.get(), writeKeyName, message_.data(), + message_.len(), ciphertext, &ciphertext_len, sizeof(ciphertext)); + if (rv != SECSuccess) { + std::cerr << "Error: " << PORT_ErrorToName(PORT_GetError()) << std::endl; + } + if (protect_error_code) { + ASSERT_EQ(protect_error_code, PORT_GetError()); + return; + } + ASSERT_EQ(SECSuccess, rv); + + if (mutate) { + mutate(ciphertext, &ciphertext_len); + } + rv = ssl_SelfEncryptUnprotectInt(readAes.get(), readHmac.get(), readKeyName, + ciphertext, ciphertext_len, plaintext, + &plaintext_len, sizeof(plaintext)); + if (rv != SECSuccess) { + std::cerr << "Error: " << PORT_ErrorToName(PORT_GetError()) << std::endl; + } + if (!unprotect_error_code) { + ASSERT_EQ(SECSuccess, rv); + EXPECT_EQ(message_.len(), plaintext_len); + EXPECT_EQ(0, memcmp(message_.data(), plaintext, message_.len())); + } else { + ASSERT_EQ(SECFailure, rv); + EXPECT_EQ(unprotect_error_code, PORT_GetError()); + } + } + + protected: + ScopedPK11SymKey aes1_; + ScopedPK11SymKey aes2_; + ScopedPK11SymKey hmac1_; + ScopedPK11SymKey hmac2_; + DataBuffer message_; + + private: + ScopedPK11SlotInfo slot_; +}; + +class SelfEncryptTestVariable : public SelfEncryptTestBase, + public ::testing::WithParamInterface<size_t> { + public: + SelfEncryptTestVariable() : SelfEncryptTestBase(GetParam()) {} +}; + +class SelfEncryptTest128 : public SelfEncryptTestBase { + public: + SelfEncryptTest128() : SelfEncryptTestBase(128) {} +}; + +TEST_P(SelfEncryptTestVariable, SuccessCase) { + SelfTest(kKeyName1, aes1_, hmac1_, kKeyName1, aes1_, hmac1_); +} + +TEST_P(SelfEncryptTestVariable, WrongMacKey) { + SelfTest(kKeyName1, aes1_, hmac1_, kKeyName1, aes1_, hmac2_, 0, + SEC_ERROR_BAD_DATA); +} + +TEST_P(SelfEncryptTestVariable, WrongKeyName) { + SelfTest(kKeyName1, aes1_, hmac1_, kKeyName2, aes1_, hmac1_, 0, + SEC_ERROR_NOT_A_RECIPIENT); +} + +TEST_P(SelfEncryptTestVariable, AddAByte) { + SelfTest(kKeyName1, aes1_, hmac1_, kKeyName1, aes1_, hmac1_, 0, + SEC_ERROR_BAD_DATA, + [](uint8_t* ciphertext, unsigned int* ciphertext_len) { + (*ciphertext_len)++; + }); +} + +TEST_P(SelfEncryptTestVariable, SubtractAByte) { + SelfTest(kKeyName1, aes1_, hmac1_, kKeyName1, aes1_, hmac1_, 0, + SEC_ERROR_BAD_DATA, + [](uint8_t* ciphertext, unsigned int* ciphertext_len) { + (*ciphertext_len)--; + }); +} + +TEST_P(SelfEncryptTestVariable, BogusIv) { + SelfTest(kKeyName1, aes1_, hmac1_, kKeyName1, aes1_, hmac1_, 0, + SEC_ERROR_BAD_DATA, + [](uint8_t* ciphertext, unsigned int* ciphertext_len) { + ciphertext[16]++; + }); +} + +TEST_P(SelfEncryptTestVariable, BogusCiphertext) { + SelfTest(kKeyName1, aes1_, hmac1_, kKeyName1, aes1_, hmac1_, 0, + SEC_ERROR_BAD_DATA, + [](uint8_t* ciphertext, unsigned int* ciphertext_len) { + ciphertext[32]++; + }); +} + +TEST_P(SelfEncryptTestVariable, BadMac) { + SelfTest(kKeyName1, aes1_, hmac1_, kKeyName1, aes1_, hmac1_, 0, + SEC_ERROR_BAD_DATA, + [](uint8_t* ciphertext, unsigned int* ciphertext_len) { + ciphertext[*ciphertext_len - 1]++; + }); +} + +TEST_F(SelfEncryptTest128, DISABLED_BadPadding) { + SelfTest(kKeyName1, aes1_, hmac1_, kKeyName1, aes2_, hmac1_, 0, + SEC_ERROR_BAD_DATA); +} + +TEST_F(SelfEncryptTest128, ShortKeyName) { + SelfTest(kKeyName1, aes1_, hmac1_, kKeyName1, aes1_, hmac1_, 0, + SEC_ERROR_BAD_DATA, + [](uint8_t* ciphertext, unsigned int* ciphertext_len) { + *ciphertext_len = 15; + }); +} + +TEST_F(SelfEncryptTest128, ShortIv) { + SelfTest(kKeyName1, aes1_, hmac1_, kKeyName1, aes1_, hmac1_, 0, + SEC_ERROR_BAD_DATA, + [](uint8_t* ciphertext, unsigned int* ciphertext_len) { + *ciphertext_len = 31; + }); +} + +TEST_F(SelfEncryptTest128, ShortCiphertextLen) { + SelfTest(kKeyName1, aes1_, hmac1_, kKeyName1, aes1_, hmac1_, 0, + SEC_ERROR_BAD_DATA, + [](uint8_t* ciphertext, unsigned int* ciphertext_len) { + *ciphertext_len = 32; + }); +} + +TEST_F(SelfEncryptTest128, ShortCiphertext) { + SelfTest(kKeyName1, aes1_, hmac1_, kKeyName1, aes1_, hmac1_, 0, + SEC_ERROR_BAD_DATA, + [](uint8_t* ciphertext, unsigned int* ciphertext_len) { + *ciphertext_len -= 17; + }); +} + +TEST_F(SelfEncryptTest128, MacWithAESKeyEncrypt) { + SelfTest(kKeyName1, aes1_, aes1_, kKeyName1, aes1_, hmac1_, + SEC_ERROR_LIBRARY_FAILURE); +} + +TEST_F(SelfEncryptTest128, AESWithMacKeyEncrypt) { + SelfTest(kKeyName1, hmac1_, hmac1_, kKeyName1, aes1_, hmac1_, + SEC_ERROR_INVALID_KEY); +} + +TEST_F(SelfEncryptTest128, MacWithAESKeyDecrypt) { + SelfTest(kKeyName1, aes1_, hmac1_, kKeyName1, aes1_, aes1_, 0, + SEC_ERROR_LIBRARY_FAILURE); +} + +TEST_F(SelfEncryptTest128, AESWithMacKeyDecrypt) { + SelfTest(kKeyName1, aes1_, hmac1_, kKeyName1, hmac1_, hmac1_, 0, + SEC_ERROR_INVALID_KEY); +} + +INSTANTIATE_TEST_SUITE_P(VariousSizes, SelfEncryptTestVariable, + ::testing::Values(0, 15, 16, 31, 255, 256, 257)); + +} // namespace nss_test diff --git a/security/nss/gtests/ssl_gtest/ssl_0rtt_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_0rtt_unittest.cc new file mode 100644 index 0000000000..51ec9d3ee5 --- /dev/null +++ b/security/nss/gtests/ssl_gtest/ssl_0rtt_unittest.cc @@ -0,0 +1,1183 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set ts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this file, + * You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#include "secerr.h" +#include "ssl.h" +#include "sslerr.h" +#include "sslexp.h" +#include "sslproto.h" + +extern "C" { +// This is not something that should make you happy. +#include "libssl_internals.h" +} + +#include "cpputil.h" +#include "gtest_utils.h" +#include "nss_scoped_ptrs.h" +#include "tls_connect.h" +#include "tls_filter.h" +#include "tls_parser.h" + +namespace nss_test { + +TEST_P(TlsConnectTls13, ZeroRtt) { + SetupForZeroRtt(); + client_->Set0RttEnabled(true); + server_->Set0RttEnabled(true); + ExpectResumption(RESUME_TICKET); + ZeroRttSendReceive(true, true); + Handshake(); + ExpectEarlyDataAccepted(true); + CheckConnected(); + SendReceive(); +} + +TEST_P(TlsConnectTls13, ZeroRttServerRejectByOption) { + SetupForZeroRtt(); + client_->Set0RttEnabled(true); + ExpectResumption(RESUME_TICKET); + ZeroRttSendReceive(true, false); + Handshake(); + CheckConnected(); + SendReceive(); +} + +TEST_P(TlsConnectTls13, ZeroRttApplicationReject) { + SetupForZeroRtt(); + client_->Set0RttEnabled(true); + server_->Set0RttEnabled(true); + ExpectResumption(RESUME_TICKET); + + auto reject_0rtt = [](PRBool firstHello, const PRUint8* clientToken, + unsigned int clientTokenLen, PRUint8* appToken, + unsigned int* appTokenLen, unsigned int appTokenMax, + void* arg) { + auto* called = reinterpret_cast<bool*>(arg); + *called = true; + + EXPECT_TRUE(firstHello); + EXPECT_EQ(0U, clientTokenLen); + return ssl_hello_retry_reject_0rtt; + }; + + bool cb_run = false; + EXPECT_EQ(SECSuccess, SSL_HelloRetryRequestCallback(server_->ssl_fd(), + reject_0rtt, &cb_run)); + ZeroRttSendReceive(true, false); + Handshake(); + EXPECT_TRUE(cb_run); + CheckConnected(); + SendReceive(); +} + +TEST_P(TlsConnectTls13, ZeroRttApparentReplayAfterRestart) { + // The test fixtures enable anti-replay in SetUp(). This results in 0-RTT + // being rejected until at least one window passes. SetupFor0Rtt() forces a + // rollover of the anti-replay filters, which clears that state and allows + // 0-RTT to work. Make the first connection manually to avoid that rollover + // and cause 0-RTT to be rejected. + + ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET); + ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3); + server_->Set0RttEnabled(true); // So we signal that we allow 0-RTT. + Connect(); + SendReceive(); // Need to read so that we absorb the session ticket. + CheckKeys(); + + Reset(); + StartConnect(); + client_->Set0RttEnabled(true); + server_->Set0RttEnabled(true); + ExpectResumption(RESUME_TICKET); + ZeroRttSendReceive(true, false); + Handshake(); + CheckConnected(); + SendReceive(); +} + +class TlsZeroRttReplayTest : public TlsConnectTls13 { + private: + class SaveFirstPacket : public PacketFilter { + public: + PacketFilter::Action Filter(const DataBuffer& input, + DataBuffer* output) override { + if (!packet_.len() && input.len()) { + packet_ = input; + } + return KEEP; + } + + const DataBuffer& packet() const { return packet_; } + + private: + DataBuffer packet_; + }; + + protected: + void RunTest(bool rollover, const ScopedPK11SymKey& epsk) { + // Now run a true 0-RTT handshake, but capture the first packet. + auto first_packet = std::make_shared<SaveFirstPacket>(); + client_->SetFilter(first_packet); + client_->Set0RttEnabled(true); + server_->Set0RttEnabled(true); + ZeroRttSendReceive(true, true); + Handshake(); + EXPECT_LT(0U, first_packet->packet().len()); + ExpectEarlyDataAccepted(true); + CheckConnected(); + SendReceive(); + + if (rollover) { + RolloverAntiReplay(); + } + + // Now replay that packet against the server. + Reset(); + server_->StartConnect(); + server_->Set0RttEnabled(true); + server_->SetAntiReplayContext(anti_replay_); + if (epsk) { + AddPsk(epsk, std::string("foo"), ssl_hash_sha256, + TLS_CHACHA20_POLY1305_SHA256); + } + + // Capture the early_data extension, which should not appear. + auto early_data_ext = + MakeTlsFilter<TlsExtensionCapture>(server_, ssl_tls13_early_data_xtn); + + // Finally, replay the ClientHello and force the server to consume it. Stop + // after the server sends its first flight; the client will not be able to + // complete this handshake. + server_->adapter()->PacketReceived(first_packet->packet()); + server_->Handshake(); + EXPECT_FALSE(early_data_ext->captured()); + } + + void RunResPskTest(bool rollover) { + // Run the initial handshake + SetupForZeroRtt(); + ExpectResumption(RESUME_TICKET); + RunTest(rollover, ScopedPK11SymKey(nullptr)); + } + + void RunExtPskTest(bool rollover) { + ScopedPK11SlotInfo slot(PK11_GetInternalSlot()); + ASSERT_NE(nullptr, slot); + + const std::vector<uint8_t> kPskDummyVal(16, 0xFF); + SECItem psk_item = {siBuffer, toUcharPtr(kPskDummyVal.data()), + static_cast<unsigned int>(kPskDummyVal.size())}; + PK11SymKey* key = + PK11_ImportSymKey(slot.get(), CKM_HKDF_KEY_GEN, PK11_OriginUnwrap, + CKA_DERIVE, &psk_item, NULL); + ASSERT_NE(nullptr, key); + ScopedPK11SymKey scoped_psk(key); + RolloverAntiReplay(); + AddPsk(scoped_psk, std::string("foo"), ssl_hash_sha256, + TLS_CHACHA20_POLY1305_SHA256); + StartConnect(); + RunTest(rollover, scoped_psk); + } +}; + +TEST_P(TlsZeroRttReplayTest, ResPskZeroRttReplay) { RunResPskTest(false); } + +TEST_P(TlsZeroRttReplayTest, ExtPskZeroRttReplay) { RunExtPskTest(false); } + +TEST_P(TlsZeroRttReplayTest, ZeroRttReplayAfterRollover) { + RunResPskTest(true); +} + +// Test that we don't try to send 0-RTT data when the server sent +// us a ticket without the 0-RTT flags. +TEST_P(TlsConnectTls13, ZeroRttOptionsSetLate) { + ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET); + Connect(); + SendReceive(); // Need to read so that we absorb the session ticket. + CheckKeys(ssl_kea_ecdh, ssl_auth_rsa_sign); + Reset(); + StartConnect(); + // Now turn on 0-RTT but too late for the ticket. + client_->Set0RttEnabled(true); + server_->Set0RttEnabled(true); + ExpectResumption(RESUME_TICKET); + ZeroRttSendReceive(false, false); + Handshake(); + CheckConnected(); + SendReceive(); +} + +// Make sure that a session ticket sent well after the original handshake +// can be used for 0-RTT. +// Stream because DTLS doesn't support SSL_SendSessionTicket. +TEST_F(TlsConnectStreamTls13, ZeroRttUsingLateTicket) { + // Use a small-ish anti-replay window. + ResetAntiReplay(100 * PR_USEC_PER_MSEC); + RolloverAntiReplay(); + + ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET); + server_->Set0RttEnabled(true); + Connect(); + CheckKeys(); + + // Now move time forward 30s and send a ticket. + AdvanceTime(30 * PR_USEC_PER_SEC); + EXPECT_EQ(SECSuccess, SSL_SendSessionTicket(server_->ssl_fd(), NULL, 0)); + SendReceive(); + Reset(); + StartConnect(); + + client_->Set0RttEnabled(true); + server_->Set0RttEnabled(true); + ExpectResumption(RESUME_TICKET); + ZeroRttSendReceive(true, true); + Handshake(); + ExpectEarlyDataAccepted(true); + CheckConnected(); + SendReceive(); +} + +// Check that post-handshake authentication with a long RTT doesn't +// make things worse. +TEST_F(TlsConnectStreamTls13, ZeroRttUsingLateTicketPha) { + // Use a small-ish anti-replay window. + ResetAntiReplay(100 * PR_USEC_PER_MSEC); + RolloverAntiReplay(); + + ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET); + server_->Set0RttEnabled(true); + client_->SetupClientAuth(); + client_->SetOption(SSL_ENABLE_POST_HANDSHAKE_AUTH, PR_TRUE); + Connect(); + CheckKeys(); + + // Add post-handshake authentication, with some added delays. + AdvanceTime(10 * PR_USEC_PER_SEC); + EXPECT_EQ(SECSuccess, SSL_SendCertificateRequest(server_->ssl_fd())); + AdvanceTime(10 * PR_USEC_PER_SEC); + server_->SendData(50); + client_->ReadBytes(50); + client_->SendData(50); + server_->ReadBytes(50); + + AdvanceTime(10 * PR_USEC_PER_SEC); + EXPECT_EQ(SECSuccess, SSL_SendSessionTicket(server_->ssl_fd(), NULL, 0)); + server_->SendData(100); + client_->ReadBytes(100); + Reset(); + StartConnect(); + + client_->Set0RttEnabled(true); + server_->Set0RttEnabled(true); + ExpectResumption(RESUME_TICKET); + ZeroRttSendReceive(true, true); + Handshake(); + ExpectEarlyDataAccepted(true); + CheckConnected(); + SendReceive(); +} + +// Same, but with client authentication on the first connection. +TEST_F(TlsConnectStreamTls13, ZeroRttUsingLateTicketClientAuth) { + // Use a small-ish anti-replay window. + ResetAntiReplay(100 * PR_USEC_PER_MSEC); + RolloverAntiReplay(); + + ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET); + client_->SetupClientAuth(); + server_->RequestClientAuth(true); + server_->Set0RttEnabled(true); + Connect(); + CheckKeys(); + + // Now move time forward 30s and send a ticket. + AdvanceTime(30 * PR_USEC_PER_SEC); + EXPECT_EQ(SECSuccess, SSL_SendSessionTicket(server_->ssl_fd(), NULL, 0)); + SendReceive(); + Reset(); + StartConnect(); + + client_->Set0RttEnabled(true); + server_->Set0RttEnabled(true); + ExpectResumption(RESUME_TICKET); + ZeroRttSendReceive(true, true); + Handshake(); + ExpectEarlyDataAccepted(true); + CheckConnected(); + SendReceive(); +} + +TEST_P(TlsConnectTls13, ZeroRttServerForgetTicket) { + SetupForZeroRtt(); + client_->Set0RttEnabled(true); + server_->Set0RttEnabled(true); + ClearServerCache(); + ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET); + ExpectResumption(RESUME_NONE); + ZeroRttSendReceive(true, false); + Handshake(); + CheckConnected(); + SendReceive(); +} + +TEST_P(TlsConnectTls13, ZeroRttServerOnly) { + ExpectResumption(RESUME_NONE); + server_->Set0RttEnabled(true); + StartConnect(); + + // Client sends ordinary ClientHello. + client_->Handshake(); + + // Verify that the server doesn't get data. + uint8_t buf[100]; + PRInt32 rv = PR_Read(server_->ssl_fd(), buf, sizeof(buf)); + EXPECT_EQ(SECFailure, rv); + EXPECT_EQ(PR_WOULD_BLOCK_ERROR, PORT_GetError()); + + // Now make sure that things complete. + Handshake(); + CheckConnected(); + SendReceive(); + CheckKeys(); +} + +// Advancing time after sending the ClientHello means that the ticket age that +// arrives at the server is too low. The server then rejects early data if this +// delay exceeds half the anti-replay window. +TEST_P(TlsConnectTls13, ZeroRttRejectOldTicket) { + static const PRTime kWindow = 10 * PR_USEC_PER_SEC; + ResetAntiReplay(kWindow); + SetupForZeroRtt(); + + Reset(); + StartConnect(); + client_->Set0RttEnabled(true); + server_->Set0RttEnabled(true); + ExpectResumption(RESUME_TICKET); + ZeroRttSendReceive(true, false, [this]() { + AdvanceTime(1 + kWindow / 2); + return true; + }); + Handshake(); + ExpectEarlyDataAccepted(false); + CheckConnected(); + SendReceive(); +} + +// In this test, we falsely inflate the estimate of the RTT by delaying the +// ServerHello on the first handshake. This results in the server estimating a +// higher value of the ticket age than the client ultimately provides. Add a +// small tolerance for variation in ticket age and the ticket will appear to +// arrive prematurely, causing the server to reject early data. +TEST_P(TlsConnectTls13, ZeroRttRejectPrematureTicket) { + static const PRTime kWindow = 10 * PR_USEC_PER_SEC; + ResetAntiReplay(kWindow); + ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET); + ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3); + server_->Set0RttEnabled(true); + StartConnect(); + client_->Handshake(); // ClientHello + server_->Handshake(); // ServerHello + AdvanceTime(1 + kWindow / 2); + Handshake(); // Remainder of handshake + CheckConnected(); + SendReceive(); + CheckKeys(); + + Reset(); + client_->Set0RttEnabled(true); + server_->Set0RttEnabled(true); + ExpectResumption(RESUME_TICKET); + ExpectEarlyDataAccepted(false); + StartConnect(); + ZeroRttSendReceive(true, false); + Handshake(); + CheckConnected(); + SendReceive(); +} + +TEST_P(TlsConnectTls13, TestTls13ZeroRttAlpn) { + EnableAlpn(); + SetupForZeroRtt(); + EnableAlpn(); + client_->Set0RttEnabled(true); + server_->Set0RttEnabled(true); + ExpectResumption(RESUME_TICKET); + ExpectEarlyDataAccepted(true); + ZeroRttSendReceive(true, true, [this]() { + client_->CheckAlpn(SSL_NEXT_PROTO_EARLY_VALUE, "a"); + return true; + }); + Handshake(); + CheckConnected(); + SendReceive(); + CheckAlpn("a"); +} + +// NOTE: In this test and those below, the client always sends +// post-ServerHello alerts with the handshake keys, even if the server +// has accepted 0-RTT. In some cases, as with errors in +// EncryptedExtensions, the client can't know the server's behavior, +// and in others it's just simpler. What the server is expecting +// depends on whether it accepted 0-RTT or not. Eventually, we may +// make the server trial decrypt. +// +// Have the server negotiate a different ALPN value, and therefore +// reject 0-RTT. +TEST_P(TlsConnectTls13, TestTls13ZeroRttAlpnChangeServer) { + EnableAlpn(); + SetupForZeroRtt(); + static const uint8_t client_alpn[] = {0x01, 0x61, 0x01, 0x62}; // "a", "b" + static const uint8_t server_alpn[] = {0x01, 0x62}; // "b" + client_->EnableAlpn(client_alpn, sizeof(client_alpn)); + server_->EnableAlpn(server_alpn, sizeof(server_alpn)); + client_->Set0RttEnabled(true); + server_->Set0RttEnabled(true); + ExpectResumption(RESUME_TICKET); + ZeroRttSendReceive(true, false, [this]() { + client_->CheckAlpn(SSL_NEXT_PROTO_EARLY_VALUE, "a"); + return true; + }); + Handshake(); + CheckConnected(); + SendReceive(); + CheckAlpn("b"); +} + +// Check that the client validates the ALPN selection of the server. +// Stomp the ALPN on the client after sending the ClientHello so +// that the server selection appears to be incorrect. The client +// should then fail the connection. +TEST_P(TlsConnectTls13, TestTls13ZeroRttNoAlpnServer) { + EnableAlpn(); + SetupForZeroRtt(); + client_->Set0RttEnabled(true); + server_->Set0RttEnabled(true); + EnableAlpn(); + ExpectResumption(RESUME_TICKET); + ZeroRttSendReceive(true, true, [this]() { + PRUint8 b[] = {'b'}; + client_->CheckAlpn(SSL_NEXT_PROTO_EARLY_VALUE, "a"); + EXPECT_EQ(SECSuccess, SSLInt_Set0RttAlpn(client_->ssl_fd(), b, sizeof(b))); + client_->CheckAlpn(SSL_NEXT_PROTO_EARLY_VALUE, "b"); + client_->ExpectSendAlert(kTlsAlertIllegalParameter); + return true; + }); + if (variant_ == ssl_variant_stream) { + server_->ExpectSendAlert(kTlsAlertBadRecordMac); + Handshake(); + server_->CheckErrorCode(SSL_ERROR_BAD_MAC_READ); + } else { + client_->Handshake(); + } + client_->CheckErrorCode(SSL_ERROR_NEXT_PROTOCOL_DATA_INVALID); +} + +// Set up with no ALPN and then set the client so it thinks it has ALPN. +// The server responds without the extension and the client returns an +// error. +TEST_P(TlsConnectTls13, TestTls13ZeroRttNoAlpnClient) { + SetupForZeroRtt(); + client_->Set0RttEnabled(true); + server_->Set0RttEnabled(true); + ExpectResumption(RESUME_TICKET); + ZeroRttSendReceive(true, true, [this]() { + PRUint8 b[] = {'b'}; + EXPECT_EQ(SECSuccess, SSLInt_Set0RttAlpn(client_->ssl_fd(), b, 1)); + client_->CheckAlpn(SSL_NEXT_PROTO_EARLY_VALUE, "b"); + client_->ExpectSendAlert(kTlsAlertIllegalParameter); + return true; + }); + if (variant_ == ssl_variant_stream) { + server_->ExpectSendAlert(kTlsAlertBadRecordMac); + Handshake(); + server_->CheckErrorCode(SSL_ERROR_BAD_MAC_READ); + } else { + client_->Handshake(); + } + client_->CheckErrorCode(SSL_ERROR_NEXT_PROTOCOL_DATA_INVALID); +} + +// Remove the old ALPN value and so the client will not offer early data. +TEST_P(TlsConnectTls13, TestTls13ZeroRttAlpnChangeBoth) { + EnableAlpn(); + SetupForZeroRtt(); + static const std::vector<uint8_t> alpn({0x01, 0x62}); // "b" + EnableAlpn(alpn); + client_->Set0RttEnabled(true); + server_->Set0RttEnabled(true); + ExpectResumption(RESUME_TICKET); + ZeroRttSendReceive(true, false, [this]() { + client_->CheckAlpn(SSL_NEXT_PROTO_NO_SUPPORT); + return false; + }); + Handshake(); + CheckConnected(); + SendReceive(); + CheckAlpn("b"); +} + +// The client should abort the connection when sending a 0-rtt handshake but +// the servers responds with a TLS 1.2 ServerHello. (no app data sent) +TEST_P(TlsConnectTls13, TestTls13ZeroRttDowngrade) { + ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET); + server_->Set0RttEnabled(true); // set ticket_allow_early_data + Connect(); + + SendReceive(); // Need to read so that we absorb the session tickets. + CheckKeys(); + + Reset(); + ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET); + client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2, + SSL_LIBRARY_VERSION_TLS_1_3); + server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2, + SSL_LIBRARY_VERSION_TLS_1_2); + StartConnect(); + // We will send the early data xtn without sending actual early data. Thus + // a 1.2 server shouldn't fail until the client sends an alert because the + // client sends end_of_early_data only after reading the server's flight. + client_->Set0RttEnabled(true); + + client_->ExpectSendAlert(kTlsAlertIllegalParameter); + if (variant_ == ssl_variant_stream) { + server_->ExpectSendAlert(kTlsAlertUnexpectedMessage); + } + client_->Handshake(); + server_->Handshake(); + ASSERT_TRUE_WAIT( + (client_->error_code() == SSL_ERROR_DOWNGRADE_WITH_EARLY_DATA), 2000); + + // DTLS will timeout as we bump the epoch when installing the early app data + // cipher suite. Thus the encrypted alert will be ignored. + if (variant_ == ssl_variant_stream) { + // The client sends an encrypted alert message. + ASSERT_TRUE_WAIT( + (server_->error_code() == SSL_ERROR_RX_UNEXPECTED_APPLICATION_DATA), + 2000); + } +} + +// The client should abort the connection when sending a 0-rtt handshake but +// the servers responds with a TLS 1.2 ServerHello. (with app data) +TEST_P(TlsConnectTls13, TestTls13ZeroRttDowngradeEarlyData) { + const char* k0RttData = "ABCDEF"; + const PRInt32 k0RttDataLen = static_cast<PRInt32>(strlen(k0RttData)); + + ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET); + server_->Set0RttEnabled(true); // set ticket_allow_early_data + Connect(); + + SendReceive(); // Need to read so that we absorb the session tickets. + CheckKeys(); + + Reset(); + ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET); + client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2, + SSL_LIBRARY_VERSION_TLS_1_3); + server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2, + SSL_LIBRARY_VERSION_TLS_1_2); + StartConnect(); + // Send the early data xtn in the CH, followed by early app data. The server + // will fail right after sending its flight, when receiving the early data. + client_->Set0RttEnabled(true); + client_->Handshake(); // Send ClientHello. + PRInt32 rv = + PR_Write(client_->ssl_fd(), k0RttData, k0RttDataLen); // 0-RTT write. + EXPECT_EQ(k0RttDataLen, rv); + + if (variant_ == ssl_variant_stream) { + // When the server receives the early data, it will fail. + server_->ExpectSendAlert(kTlsAlertUnexpectedMessage); + server_->Handshake(); // Consume ClientHello + EXPECT_EQ(TlsAgent::STATE_ERROR, server_->state()); + server_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_APPLICATION_DATA); + } else { + // If it's datagram, we just discard the early data. + server_->Handshake(); // Consume ClientHello + EXPECT_EQ(TlsAgent::STATE_CONNECTING, server_->state()); + } + + // The client now reads the ServerHello and fails. + ASSERT_EQ(TlsAgent::STATE_CONNECTING, client_->state()); + client_->ExpectSendAlert(kTlsAlertIllegalParameter); + client_->Handshake(); + client_->CheckErrorCode(SSL_ERROR_DOWNGRADE_WITH_EARLY_DATA); +} + +TEST_P(TlsConnectTls13, SendTooMuchEarlyData) { + EnsureTlsSetup(); + const char* big_message = "0123456789abcdef"; + const size_t short_size = strlen(big_message) - 1; + const PRInt32 short_length = static_cast<PRInt32>(short_size); + EXPECT_EQ(SECSuccess, + SSL_SetMaxEarlyDataSize(server_->ssl_fd(), + static_cast<PRUint32>(short_size))); + SetupForZeroRtt(); + + client_->Set0RttEnabled(true); + server_->Set0RttEnabled(true); + ExpectResumption(RESUME_TICKET); + + client_->Handshake(); + CheckEarlyDataLimit(client_, short_size); + + PRInt32 sent; + // Writing more than the limit will succeed in TLS, but fail in DTLS. + if (variant_ == ssl_variant_stream) { + sent = PR_Write(client_->ssl_fd(), big_message, + static_cast<PRInt32>(strlen(big_message))); + } else { + sent = PR_Write(client_->ssl_fd(), big_message, + static_cast<PRInt32>(strlen(big_message))); + EXPECT_GE(0, sent); + EXPECT_EQ(PR_WOULD_BLOCK_ERROR, PORT_GetError()); + + // Try an exact-sized write now. + sent = PR_Write(client_->ssl_fd(), big_message, short_length); + } + EXPECT_EQ(short_length, sent); + + // Even a single octet write should now fail. + sent = PR_Write(client_->ssl_fd(), big_message, 1); + EXPECT_GE(0, sent); + EXPECT_EQ(PR_WOULD_BLOCK_ERROR, PORT_GetError()); + + // Process the ClientHello and read 0-RTT. + server_->Handshake(); + CheckEarlyDataLimit(server_, short_size); + + std::vector<uint8_t> buf(short_size + 1); + PRInt32 read = PR_Read(server_->ssl_fd(), buf.data(), buf.capacity()); + EXPECT_EQ(short_length, read); + EXPECT_EQ(0, memcmp(big_message, buf.data(), short_size)); + + // Second read fails. + read = PR_Read(server_->ssl_fd(), buf.data(), buf.capacity()); + EXPECT_EQ(SECFailure, read); + EXPECT_EQ(PR_WOULD_BLOCK_ERROR, PORT_GetError()); + + Handshake(); + ExpectEarlyDataAccepted(true); + CheckConnected(); + SendReceive(); +} + +TEST_P(TlsConnectTls13, ReceiveTooMuchEarlyData) { + EnsureTlsSetup(); + + const size_t limit = 5; + EXPECT_EQ(SECSuccess, SSL_SetMaxEarlyDataSize(server_->ssl_fd(), limit)); + SetupForZeroRtt(); + + client_->Set0RttEnabled(true); + server_->Set0RttEnabled(true); + ExpectResumption(RESUME_TICKET); + + client_->Handshake(); // Send ClientHello + CheckEarlyDataLimit(client_, limit); + + server_->Handshake(); // Process ClientHello, send server flight. + + // Lift the limit on the client. + EXPECT_EQ(SECSuccess, + SSLInt_SetSocketMaxEarlyDataSize(client_->ssl_fd(), 1000)); + + // Send message + const char* message = "0123456789abcdef"; + const PRInt32 message_len = static_cast<PRInt32>(strlen(message)); + EXPECT_EQ(message_len, PR_Write(client_->ssl_fd(), message, message_len)); + + if (variant_ == ssl_variant_stream) { + // This error isn't fatal for DTLS. + ExpectAlert(server_, kTlsAlertUnexpectedMessage); + } + + server_->Handshake(); // This reads the early data and maybe throws an error. + if (variant_ == ssl_variant_stream) { + server_->CheckErrorCode(SSL_ERROR_TOO_MUCH_EARLY_DATA); + } else { + EXPECT_EQ(TlsAgent::STATE_CONNECTING, server_->state()); + } + CheckEarlyDataLimit(server_, limit); + + // Attempt to read early data. This will get an error. + std::vector<uint8_t> buf(strlen(message) + 1); + EXPECT_GT(0, PR_Read(server_->ssl_fd(), buf.data(), buf.capacity())); + if (variant_ == ssl_variant_stream) { + EXPECT_EQ(SSL_ERROR_HANDSHAKE_FAILED, PORT_GetError()); + } else { + EXPECT_EQ(PR_WOULD_BLOCK_ERROR, PORT_GetError()); + } + + client_->Handshake(); // Process the server's first flight. + if (variant_ == ssl_variant_stream) { + client_->Handshake(); // Process the alert. + client_->CheckErrorCode(SSL_ERROR_HANDSHAKE_UNEXPECTED_ALERT); + } else { + server_->Handshake(); // Finish connecting. + EXPECT_EQ(TlsAgent::STATE_CONNECTED, server_->state()); + } +} + +class PacketCoalesceFilter : public PacketFilter { + public: + PacketCoalesceFilter() : packet_data_() {} + + void SendCoalesced(std::shared_ptr<TlsAgent> agent) { + agent->SendDirect(packet_data_); + } + + protected: + PacketFilter::Action Filter(const DataBuffer& input, + DataBuffer* output) override { + packet_data_.Write(packet_data_.len(), input); + return DROP; + } + + private: + DataBuffer packet_data_; +}; + +TEST_P(TlsConnectTls13, ZeroRttOrdering) { + SetupForZeroRtt(); + client_->Set0RttEnabled(true); + server_->Set0RttEnabled(true); + ExpectResumption(RESUME_TICKET); + + // Send out the ClientHello. + client_->Handshake(); + + // Now, coalesce the next three things from the client: early data, second + // flight and 1-RTT data. + auto coalesce = std::make_shared<PacketCoalesceFilter>(); + client_->SetFilter(coalesce); + + // Send (and hold) early data. + static const std::vector<uint8_t> early_data = {3, 2, 1}; + EXPECT_EQ(static_cast<PRInt32>(early_data.size()), + PR_Write(client_->ssl_fd(), early_data.data(), early_data.size())); + + // Send (and hold) the second client handshake flight. + // The client sends EndOfEarlyData after seeing the server Finished. + server_->Handshake(); + client_->Handshake(); + + // Send (and hold) 1-RTT data. + static const std::vector<uint8_t> late_data = {7, 8, 9, 10}; + EXPECT_EQ(static_cast<PRInt32>(late_data.size()), + PR_Write(client_->ssl_fd(), late_data.data(), late_data.size())); + + // Now release them all at once. + coalesce->SendCoalesced(client_); + + // Now ensure that the three steps are exposed in the right order on the + // server: delivery of early data, handshake callback, delivery of 1-RTT. + size_t step = 0; + server_->SetHandshakeCallback([&step](TlsAgent*) { + EXPECT_EQ(1U, step); + ++step; + }); + + std::vector<uint8_t> buf(10); + PRInt32 read = PR_Read(server_->ssl_fd(), buf.data(), buf.size()); + ASSERT_EQ(static_cast<PRInt32>(early_data.size()), read); + buf.resize(read); + EXPECT_EQ(early_data, buf); + EXPECT_EQ(0U, step); + ++step; + + // The third read should be after the handshake callback and should return the + // data that was sent after the handshake completed. + buf.resize(10); + read = PR_Read(server_->ssl_fd(), buf.data(), buf.size()); + ASSERT_EQ(static_cast<PRInt32>(late_data.size()), read); + buf.resize(read); + EXPECT_EQ(late_data, buf); + EXPECT_EQ(2U, step); +} + +// Early data remains available after the handshake completes for TLS. +TEST_F(TlsConnectStreamTls13, ZeroRttLateReadTls) { + SetupForZeroRtt(); + client_->Set0RttEnabled(true); + server_->Set0RttEnabled(true); + ExpectResumption(RESUME_TICKET); + client_->Handshake(); // ClientHello + + // Write some early data. + const uint8_t data[] = {1, 2, 3, 4, 5, 6, 7, 8}; + PRInt32 rv = PR_Write(client_->ssl_fd(), data, sizeof(data)); + EXPECT_EQ(static_cast<PRInt32>(sizeof(data)), rv); + + // Consume the ClientHello and generate ServerHello..Finished. + server_->Handshake(); + + // Read some of the data. + std::vector<uint8_t> small_buffer(1 + sizeof(data) / 2); + rv = PR_Read(server_->ssl_fd(), small_buffer.data(), small_buffer.size()); + EXPECT_EQ(static_cast<PRInt32>(small_buffer.size()), rv); + EXPECT_EQ(0, memcmp(data, small_buffer.data(), small_buffer.size())); + + Handshake(); // Complete the handshake. + ExpectEarlyDataAccepted(true); + CheckConnected(); + + // After the handshake, it should be possible to read the remainder. + uint8_t big_buf[100]; + rv = PR_Read(server_->ssl_fd(), big_buf, sizeof(big_buf)); + EXPECT_EQ(static_cast<PRInt32>(sizeof(data) - small_buffer.size()), rv); + EXPECT_EQ(0, memcmp(&data[small_buffer.size()], big_buf, + sizeof(data) - small_buffer.size())); + + // And that's all there is to read. + rv = PR_Read(server_->ssl_fd(), big_buf, sizeof(big_buf)); + EXPECT_GT(0, rv); + EXPECT_EQ(PR_WOULD_BLOCK_ERROR, PORT_GetError()); +} + +// Early data that arrives before the handshake can be read after the handshake +// is complete. +TEST_F(TlsConnectDatagram13, ZeroRttLateReadDtls) { + SetupForZeroRtt(); + client_->Set0RttEnabled(true); + server_->Set0RttEnabled(true); + ExpectResumption(RESUME_TICKET); + client_->Handshake(); // ClientHello + + // Write some early data. + const uint8_t data[] = {1, 2, 3}; + PRInt32 written = PR_Write(client_->ssl_fd(), data, sizeof(data)); + EXPECT_EQ(static_cast<PRInt32>(sizeof(data)), written); + + Handshake(); // Complete the handshake. + ExpectEarlyDataAccepted(true); + CheckConnected(); + + // Reading at the server should return the early data, which was buffered. + uint8_t buf[sizeof(data) + 1] = {0}; + PRInt32 read = PR_Read(server_->ssl_fd(), buf, sizeof(buf)); + EXPECT_EQ(static_cast<PRInt32>(sizeof(data)), read); + EXPECT_EQ(0, memcmp(data, buf, sizeof(data))); +} + +class PacketHolder : public PacketFilter { + public: + PacketHolder() = default; + + virtual Action Filter(const DataBuffer& input, DataBuffer* output) { + packet_ = input; + Disable(); + return DROP; + } + + const DataBuffer& packet() const { return packet_; } + + private: + DataBuffer packet_; +}; + +// Early data that arrives late is discarded for DTLS. +TEST_F(TlsConnectDatagram13, ZeroRttLateArrivalDtls) { + SetupForZeroRtt(); + client_->Set0RttEnabled(true); + server_->Set0RttEnabled(true); + ExpectResumption(RESUME_TICKET); + client_->Handshake(); // ClientHello + + // Write some early data. Twice, so that we can read bits of it. + const uint8_t data[] = {1, 2, 3}; + PRInt32 written = PR_Write(client_->ssl_fd(), data, sizeof(data)); + EXPECT_EQ(static_cast<PRInt32>(sizeof(data)), written); + + // Block and capture the next packet. + auto holder = std::make_shared<PacketHolder>(); + client_->SetFilter(holder); + written = PR_Write(client_->ssl_fd(), data, sizeof(data)); + EXPECT_EQ(static_cast<PRInt32>(sizeof(data)), written); + EXPECT_FALSE(holder->enabled()) << "the filter should disable itself"; + + // Consume the ClientHello and generate ServerHello..Finished. + server_->Handshake(); + + // Read some of the data. + std::vector<uint8_t> small_buffer(sizeof(data)); + PRInt32 read = + PR_Read(server_->ssl_fd(), small_buffer.data(), small_buffer.size()); + + EXPECT_EQ(static_cast<PRInt32>(small_buffer.size()), read); + EXPECT_EQ(0, memcmp(data, small_buffer.data(), small_buffer.size())); + + Handshake(); // Complete the handshake. + ExpectEarlyDataAccepted(true); + CheckConnected(); + + server_->SendDirect(holder->packet()); + + // Reading now should return nothing, even though a valid packet was + // delivered. + read = PR_Read(server_->ssl_fd(), small_buffer.data(), small_buffer.size()); + EXPECT_GT(0, read); + EXPECT_EQ(PR_WOULD_BLOCK_ERROR, PORT_GetError()); +} + +// Early data reads in TLS should be coalesced. +TEST_F(TlsConnectStreamTls13, ZeroRttCoalesceReadTls) { + SetupForZeroRtt(); + client_->Set0RttEnabled(true); + server_->Set0RttEnabled(true); + ExpectResumption(RESUME_TICKET); + client_->Handshake(); // ClientHello + + // Write some early data. In two writes. + const uint8_t data[] = {1, 2, 3, 4, 5, 6}; + PRInt32 written = PR_Write(client_->ssl_fd(), data, 1); + EXPECT_EQ(1, written); + + written = PR_Write(client_->ssl_fd(), data + 1, sizeof(data) - 1); + EXPECT_EQ(static_cast<PRInt32>(sizeof(data) - 1), written); + + // Consume the ClientHello and generate ServerHello..Finished. + server_->Handshake(); + + // Read all of the data. + std::vector<uint8_t> buffer(sizeof(data)); + PRInt32 read = PR_Read(server_->ssl_fd(), buffer.data(), buffer.size()); + EXPECT_EQ(static_cast<PRInt32>(sizeof(data)), read); + EXPECT_EQ(0, memcmp(data, buffer.data(), sizeof(data))); + + Handshake(); // Complete the handshake. + ExpectEarlyDataAccepted(true); + CheckConnected(); +} + +// Early data reads in DTLS should not be coalesced. +TEST_F(TlsConnectDatagram13, ZeroRttNoCoalesceReadDtls) { + SetupForZeroRtt(); + client_->Set0RttEnabled(true); + server_->Set0RttEnabled(true); + ExpectResumption(RESUME_TICKET); + client_->Handshake(); // ClientHello + + // Write some early data. In two writes. + const uint8_t data[] = {1, 2, 3, 4, 5, 6}; + PRInt32 written = PR_Write(client_->ssl_fd(), data, 1); + EXPECT_EQ(1, written); + + written = PR_Write(client_->ssl_fd(), data + 1, sizeof(data) - 1); + EXPECT_EQ(static_cast<PRInt32>(sizeof(data) - 1), written); + + // Consume the ClientHello and generate ServerHello..Finished. + server_->Handshake(); + + // Try to read all of the data. + std::vector<uint8_t> buffer(sizeof(data)); + PRInt32 read = PR_Read(server_->ssl_fd(), buffer.data(), buffer.size()); + EXPECT_EQ(1, read); + EXPECT_EQ(0, memcmp(data, buffer.data(), 1)); + + // Read the remainder. + read = PR_Read(server_->ssl_fd(), buffer.data(), buffer.size()); + EXPECT_EQ(static_cast<PRInt32>(sizeof(data) - 1), read); + EXPECT_EQ(0, memcmp(data + 1, buffer.data(), sizeof(data) - 1)); + + Handshake(); // Complete the handshake. + ExpectEarlyDataAccepted(true); + CheckConnected(); +} + +// Early data reads in DTLS should fail if the buffer is too small. +TEST_F(TlsConnectDatagram13, ZeroRttShortReadDtls) { + SetupForZeroRtt(); + client_->Set0RttEnabled(true); + server_->Set0RttEnabled(true); + ExpectResumption(RESUME_TICKET); + client_->Handshake(); // ClientHello + + // Write some early data. In two writes. + const uint8_t data[] = {1, 2, 3, 4, 5, 6}; + PRInt32 written = PR_Write(client_->ssl_fd(), data, sizeof(data)); + EXPECT_EQ(static_cast<PRInt32>(sizeof(data)), written); + + // Consume the ClientHello and generate ServerHello..Finished. + server_->Handshake(); + + // Try to read all of the data into a small buffer. + std::vector<uint8_t> buffer(sizeof(data)); + PRInt32 read = PR_Read(server_->ssl_fd(), buffer.data(), 1); + EXPECT_GT(0, read); + EXPECT_EQ(SSL_ERROR_RX_SHORT_DTLS_READ, PORT_GetError()); + + // Read again with more space. + read = PR_Read(server_->ssl_fd(), buffer.data(), buffer.size()); + EXPECT_EQ(static_cast<PRInt32>(sizeof(data)), read); + EXPECT_EQ(0, memcmp(data, buffer.data(), sizeof(data))); + + Handshake(); // Complete the handshake. + ExpectEarlyDataAccepted(true); + CheckConnected(); +} + +// There are few ways in which TLS uses the clock and most of those operate on +// timescales that would be ridiculous to wait for in a test. This is the one +// test we have that uses the real clock. It tests that time passes by checking +// that a small sleep results in rejection of early data. 0-RTT has a +// configurable timer, which makes it ideal for this. +TEST_F(TlsConnectStreamTls13, TimePassesByDefault) { + // Calling EnsureTlsSetup() replaces the time function on client and server, + // and sets up anti-replay, which we don't want, so initialize each directly. + client_->EnsureTlsSetup(); + server_->EnsureTlsSetup(); + // StartConnect() calls EnsureTlsSetup(), so avoid that too. + client_->StartConnect(); + server_->StartConnect(); + + // Set a tiny anti-replay window. This has to be at least 2 milliseconds to + // have any chance of being relevant as that is the smallest window that we + // can detect. Anything smaller rounds to zero. + static const unsigned int kTinyWindowMs = 5; + ResetAntiReplay(static_cast<PRTime>(kTinyWindowMs * PR_USEC_PER_MSEC)); + server_->SetAntiReplayContext(anti_replay_); + + ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET); + ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3); + server_->Set0RttEnabled(true); + Handshake(); + CheckConnected(); + SendReceive(); // Absorb a session ticket. + CheckKeys(); + + // Clear the first window. + PR_Sleep(PR_MillisecondsToInterval(kTinyWindowMs)); + + Reset(); + client_->EnsureTlsSetup(); + server_->EnsureTlsSetup(); + client_->StartConnect(); + server_->StartConnect(); + + // Early data is rejected by the server only if time passes for it as well. + client_->Set0RttEnabled(true); + server_->Set0RttEnabled(true); + ExpectResumption(RESUME_TICKET); + ZeroRttSendReceive(true, false, []() { + // Sleep long enough that we minimize the risk of our RTT estimation being + // duped by stutters in test execution. This is very long to allow for + // flaky and low-end hardware, especially what our CI runs on. + PR_Sleep(PR_MillisecondsToInterval(1000)); + return true; + }); + Handshake(); + ExpectEarlyDataAccepted(false); + CheckConnected(); +} + +// Test that SSL_CreateAntiReplayContext doesn't pass bad inputs. +TEST_F(TlsConnectStreamTls13, BadAntiReplayArgs) { + SSLAntiReplayContext* p; + // Zero or negative window. + EXPECT_EQ(SECFailure, SSL_CreateAntiReplayContext(0, -1, 1, 1, &p)); + EXPECT_EQ(SEC_ERROR_INVALID_ARGS, PORT_GetError()); + EXPECT_EQ(SECFailure, SSL_CreateAntiReplayContext(0, 0, 1, 1, &p)); + EXPECT_EQ(SEC_ERROR_INVALID_ARGS, PORT_GetError()); + // Zero k. + EXPECT_EQ(SECFailure, SSL_CreateAntiReplayContext(0, 1, 0, 1, &p)); + EXPECT_EQ(SEC_ERROR_INVALID_ARGS, PORT_GetError()); + // Zero bits. + EXPECT_EQ(SECFailure, SSL_CreateAntiReplayContext(0, 1, 1, 0, &p)); + EXPECT_EQ(SEC_ERROR_INVALID_ARGS, PORT_GetError()); + EXPECT_EQ(SECFailure, SSL_CreateAntiReplayContext(0, 1, 1, 1, nullptr)); + EXPECT_EQ(SEC_ERROR_INVALID_ARGS, PORT_GetError()); + + // Prove that these parameters do work, even if they are useless.. + EXPECT_EQ(SECSuccess, SSL_CreateAntiReplayContext(0, 1, 1, 1, &p)); + ASSERT_NE(nullptr, p); + ScopedSSLAntiReplayContext ctx(p); + + // The socket isn't a client or server until later, so configuring a client + // should work OK. + client_->EnsureTlsSetup(); + EXPECT_EQ(SECSuccess, SSL_SetAntiReplayContext(client_->ssl_fd(), ctx.get())); + EXPECT_EQ(SECSuccess, SSL_SetAntiReplayContext(client_->ssl_fd(), nullptr)); +} + +// See also TlsConnectGenericResumption.ResumeServerIncompatibleCipher +TEST_P(TlsConnectTls13, ZeroRttDifferentCompatibleCipher) { + EnsureTlsSetup(); + server_->EnableSingleCipher(TLS_AES_128_GCM_SHA256); + SetupForZeroRtt(); + client_->Set0RttEnabled(true); + server_->Set0RttEnabled(true); + // Change the ciphersuite. Resumption is OK because the hash is the same, but + // early data will be rejected. + server_->EnableSingleCipher(TLS_CHACHA20_POLY1305_SHA256); + ExpectResumption(RESUME_TICKET); + + StartConnect(); + ZeroRttSendReceive(true, false); + + Handshake(); + ExpectEarlyDataAccepted(false); + CheckConnected(); + SendReceive(); +} + +// See also TlsConnectGenericResumption.ResumeServerIncompatibleCipher +TEST_P(TlsConnectTls13, ZeroRttDifferentIncompatibleCipher) { + EnsureTlsSetup(); + server_->EnableSingleCipher(TLS_AES_256_GCM_SHA384); + SetupForZeroRtt(); + client_->Set0RttEnabled(true); + server_->Set0RttEnabled(true); + // Resumption is rejected because the hash is different. + server_->EnableSingleCipher(TLS_CHACHA20_POLY1305_SHA256); + ExpectResumption(RESUME_NONE); + + StartConnect(); + ZeroRttSendReceive(true, false); + + Handshake(); + ExpectEarlyDataAccepted(false); + CheckConnected(); + SendReceive(); +} + +// The client failing to provide EndOfEarlyData results in failure. +// After 0-RTT working perfectly, things fall apart later. +// The server is unable to detect the change in keys, so it fails decryption. +// The client thinks everything has worked until it gets the alert. +TEST_F(TlsConnectStreamTls13, SuppressEndOfEarlyDataClientOnly) { + SetupForZeroRtt(); + client_->Set0RttEnabled(true); + server_->Set0RttEnabled(true); + client_->SetOption(SSL_SUPPRESS_END_OF_EARLY_DATA, true); + ExpectResumption(RESUME_TICKET); + ZeroRttSendReceive(true, true); + ExpectAlert(server_, kTlsAlertBadRecordMac); + Handshake(); + EXPECT_EQ(TlsAgent::STATE_CONNECTED, client_->state()); + EXPECT_EQ(TlsAgent::STATE_ERROR, server_->state()); + server_->CheckErrorCode(SSL_ERROR_BAD_MAC_READ); + client_->Handshake(); + EXPECT_EQ(TlsAgent::STATE_ERROR, client_->state()); + client_->CheckErrorCode(SSL_ERROR_BAD_MAC_ALERT); +} + +TEST_P(TlsConnectGeneric, SuppressEndOfEarlyDataNoZeroRtt) { + EnsureTlsSetup(); + client_->SetOption(SSL_SUPPRESS_END_OF_EARLY_DATA, true); + server_->SetOption(SSL_SUPPRESS_END_OF_EARLY_DATA, true); + Connect(); + SendReceive(); +} + +#ifndef NSS_DISABLE_TLS_1_3 +INSTANTIATE_TEST_SUITE_P(Tls13ZeroRttReplayTest, TlsZeroRttReplayTest, + TlsConnectTestBase::kTlsVariantsAll); +#endif + +} // namespace nss_test diff --git a/security/nss/gtests/ssl_gtest/ssl_aead_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_aead_unittest.cc new file mode 100644 index 0000000000..d94683be30 --- /dev/null +++ b/security/nss/gtests/ssl_gtest/ssl_aead_unittest.cc @@ -0,0 +1,218 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set ts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this file, + * You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#include <memory> + +#include "keyhi.h" +#include "pk11pub.h" +#include "secerr.h" +#include "ssl.h" +#include "sslerr.h" +#include "sslexp.h" +#include "sslproto.h" + +#include "gtest_utils.h" +#include "nss_scoped_ptrs.h" +#include "scoped_ptrs_ssl.h" +#include "tls_connect.h" + +namespace nss_test { + +// From tls_hkdf_unittest.cc: +extern size_t GetHashLength(SSLHashType ht); + +class AeadTest : public ::testing::Test { + public: + AeadTest() : slot_(PK11_GetInternalSlot()) {} + + void InitSecret(SSLHashType hash_type) { + static const uint8_t kData[64] = {'s', 'e', 'c', 'r', 'e', 't'}; + SECItem key_item = {siBuffer, const_cast<uint8_t *>(kData), + static_cast<unsigned int>(GetHashLength(hash_type))}; + PK11SymKey *s = + PK11_ImportSymKey(slot_.get(), CKM_SSL3_MASTER_KEY_DERIVE, + PK11_OriginUnwrap, CKA_DERIVE, &key_item, NULL); + ASSERT_NE(nullptr, s); + secret_.reset(s); + } + + void SetUp() override { + InitSecret(ssl_hash_sha256); + PORT_SetError(0); + } + + protected: + static void EncryptDecrypt(const ScopedSSLAeadContext &ctx, + const uint8_t *ciphertext, size_t ciphertext_len) { + static const uint8_t kAad[] = {'a', 'a', 'd'}; + static const uint8_t kPlaintext[] = {'t', 'e', 'x', 't'}; + static const size_t kMaxSize = 32; + + ASSERT_GE(kMaxSize, ciphertext_len); + ASSERT_LT(0U, ciphertext_len); + + uint8_t output[kMaxSize] = {0}; + unsigned int output_len = 0; + EXPECT_EQ(SECSuccess, SSL_AeadEncrypt(ctx.get(), 0, kAad, sizeof(kAad), + kPlaintext, sizeof(kPlaintext), + output, &output_len, sizeof(output))); + ASSERT_EQ(ciphertext_len, static_cast<size_t>(output_len)); + EXPECT_EQ(0, memcmp(ciphertext, output, ciphertext_len)); + + memset(output, 0, sizeof(output)); + EXPECT_EQ(SECSuccess, SSL_AeadDecrypt(ctx.get(), 0, kAad, sizeof(kAad), + ciphertext, ciphertext_len, output, + &output_len, sizeof(output))); + ASSERT_EQ(sizeof(kPlaintext), static_cast<size_t>(output_len)); + EXPECT_EQ(0, memcmp(kPlaintext, output, sizeof(kPlaintext))); + + // Now for some tests of decryption failure. + // Truncate the input. + EXPECT_EQ(SECFailure, SSL_AeadDecrypt(ctx.get(), 0, kAad, sizeof(kAad), + ciphertext, ciphertext_len - 1, + output, &output_len, sizeof(output))); + EXPECT_EQ(SEC_ERROR_BAD_DATA, PORT_GetError()); + + // Skip the first byte of the AAD. + EXPECT_EQ( + SECFailure, + SSL_AeadDecrypt(ctx.get(), 0, kAad + 1, sizeof(kAad) - 1, ciphertext, + ciphertext_len, output, &output_len, sizeof(output))); + EXPECT_EQ(SEC_ERROR_BAD_DATA, PORT_GetError()); + + uint8_t input[kMaxSize] = {0}; + // Toggle a byte of the input. + memcpy(input, ciphertext, ciphertext_len); + input[0] ^= 9; + EXPECT_EQ(SECFailure, SSL_AeadDecrypt(ctx.get(), 0, kAad, sizeof(kAad), + input, ciphertext_len, output, + &output_len, sizeof(output))); + EXPECT_EQ(SEC_ERROR_BAD_DATA, PORT_GetError()); + + // Toggle the last byte (the auth tag). + memcpy(input, ciphertext, ciphertext_len); + input[ciphertext_len - 1] ^= 77; + EXPECT_EQ(SECFailure, SSL_AeadDecrypt(ctx.get(), 0, kAad, sizeof(kAad), + input, ciphertext_len, output, + &output_len, sizeof(output))); + EXPECT_EQ(SEC_ERROR_BAD_DATA, PORT_GetError()); + + // Toggle some of the AAD. + memcpy(input, kAad, sizeof(kAad)); + input[1] ^= 23; + EXPECT_EQ(SECFailure, SSL_AeadDecrypt(ctx.get(), 0, input, sizeof(kAad), + ciphertext, ciphertext_len, output, + &output_len, sizeof(output))); + EXPECT_EQ(SEC_ERROR_BAD_DATA, PORT_GetError()); + } + + protected: + ScopedPK11SymKey secret_; + + private: + ScopedPK11SlotInfo slot_; +}; + +// These tests all use fixed inputs: a fixed secret, a fixed label, and fixed +// inputs. So they have fixed outputs. +static const char *kLabel = "test "; +static const uint8_t kCiphertextAes128Gcm[] = { + 0x11, 0x14, 0xfc, 0x58, 0x4f, 0x44, 0xff, 0x8c, 0xb6, 0xd8, + 0x20, 0xb3, 0xfb, 0x50, 0xd9, 0x3b, 0xd4, 0xc6, 0xe1, 0x14}; +static const uint8_t kCiphertextAes256Gcm[] = { + 0xf7, 0x27, 0x35, 0x80, 0x88, 0xaf, 0x99, 0x85, 0xf2, 0x83, + 0xca, 0xbb, 0x95, 0x42, 0x09, 0x3f, 0x9c, 0xf3, 0x29, 0xf0}; +static const uint8_t kCiphertextChaCha20Poly1305[] = { + 0x4e, 0x89, 0x2c, 0xfa, 0xfc, 0x8c, 0x40, 0x55, 0x6d, 0x7e, + 0x99, 0xac, 0x8e, 0x54, 0x58, 0xb1, 0x18, 0xd2, 0x66, 0x22}; + +TEST_F(AeadTest, AeadBadVersion) { + SSLAeadContext *ctx = nullptr; + ASSERT_EQ(SECFailure, + SSL_MakeAead(SSL_LIBRARY_VERSION_TLS_1_2, TLS_AES_128_GCM_SHA256, + secret_.get(), kLabel, strlen(kLabel), &ctx)); + EXPECT_EQ(nullptr, ctx); +} + +TEST_F(AeadTest, AeadUnsupportedCipher) { + SSLAeadContext *ctx = nullptr; + ASSERT_EQ(SECFailure, + SSL_MakeAead(SSL_LIBRARY_VERSION_TLS_1_3, TLS_RSA_WITH_NULL_MD5, + secret_.get(), kLabel, strlen(kLabel), &ctx)); + EXPECT_EQ(nullptr, ctx); +} + +TEST_F(AeadTest, AeadOlderCipher) { + SSLAeadContext *ctx = nullptr; + ASSERT_EQ( + SECFailure, + SSL_MakeAead(SSL_LIBRARY_VERSION_TLS_1_3, TLS_RSA_WITH_AES_128_CBC_SHA, + secret_.get(), kLabel, strlen(kLabel), &ctx)); + EXPECT_EQ(nullptr, ctx); +} + +TEST_F(AeadTest, AeadNoLabel) { + SSLAeadContext *ctx = nullptr; + ASSERT_EQ(SECFailure, + SSL_MakeAead(SSL_LIBRARY_VERSION_TLS_1_3, TLS_AES_128_GCM_SHA256, + secret_.get(), nullptr, 12, &ctx)); + EXPECT_EQ(nullptr, ctx); +} + +TEST_F(AeadTest, AeadLongLabel) { + SSLAeadContext *ctx = nullptr; + ASSERT_EQ(SECFailure, + SSL_MakeAead(SSL_LIBRARY_VERSION_TLS_1_3, TLS_AES_128_GCM_SHA256, + secret_.get(), "", 254, &ctx)); + EXPECT_EQ(SEC_ERROR_INVALID_ARGS, PORT_GetError()); + EXPECT_EQ(nullptr, ctx); +} + +TEST_F(AeadTest, AeadNoPointer) { + SSLAeadContext *ctx = nullptr; + ASSERT_EQ(SECFailure, + SSL_MakeAead(SSL_LIBRARY_VERSION_TLS_1_3, TLS_AES_128_GCM_SHA256, + secret_.get(), kLabel, strlen(kLabel), nullptr)); + EXPECT_EQ(SEC_ERROR_INVALID_ARGS, PORT_GetError()); + EXPECT_EQ(nullptr, ctx); +} + +TEST_F(AeadTest, AeadAes128Gcm) { + SSLAeadContext *ctxInit = nullptr; + ASSERT_EQ(SECSuccess, + SSL_MakeAead(SSL_LIBRARY_VERSION_TLS_1_3, TLS_AES_128_GCM_SHA256, + secret_.get(), kLabel, strlen(kLabel), &ctxInit)); + ScopedSSLAeadContext ctx(ctxInit); + EXPECT_NE(nullptr, ctx); + + EncryptDecrypt(ctx, kCiphertextAes128Gcm, sizeof(kCiphertextAes128Gcm)); +} + +TEST_F(AeadTest, AeadAes256Gcm) { + SSLAeadContext *ctxInit = nullptr; + ASSERT_EQ(SECSuccess, + SSL_MakeAead(SSL_LIBRARY_VERSION_TLS_1_3, TLS_AES_256_GCM_SHA384, + secret_.get(), kLabel, strlen(kLabel), &ctxInit)); + ScopedSSLAeadContext ctx(ctxInit); + EXPECT_NE(nullptr, ctx); + + EncryptDecrypt(ctx, kCiphertextAes256Gcm, sizeof(kCiphertextAes256Gcm)); +} + +TEST_F(AeadTest, AeadChaCha20Poly1305) { + SSLAeadContext *ctxInit = nullptr; + ASSERT_EQ( + SECSuccess, + SSL_MakeAead(SSL_LIBRARY_VERSION_TLS_1_3, TLS_CHACHA20_POLY1305_SHA256, + secret_.get(), kLabel, strlen(kLabel), &ctxInit)); + ScopedSSLAeadContext ctx(ctxInit); + EXPECT_NE(nullptr, ctx); + + EncryptDecrypt(ctx, kCiphertextChaCha20Poly1305, + sizeof(kCiphertextChaCha20Poly1305)); +} + +} // namespace nss_test diff --git a/security/nss/gtests/ssl_gtest/ssl_agent_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_agent_unittest.cc new file mode 100644 index 0000000000..283bfec169 --- /dev/null +++ b/security/nss/gtests/ssl_gtest/ssl_agent_unittest.cc @@ -0,0 +1,235 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set ts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this file, + * You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#include "ssl.h" +#include "sslerr.h" +#include "sslproto.h" + +#include <memory> + +#include "databuffer.h" +#include "tls_agent.h" +#include "tls_connect.h" +#include "tls_filter.h" +#include "tls_parser.h" + +namespace nss_test { + +// This is a 1-RTT ClientHello with ECDHE. +const static uint8_t kCannedTls13ClientHello[] = { + 0x01, 0x00, 0x00, 0xcf, 0x03, 0x03, 0x6c, 0xb3, 0x46, 0x81, 0xc8, 0x1a, + 0xf9, 0xd2, 0x05, 0x97, 0x48, 0x7c, 0xa8, 0x31, 0x03, 0x1c, 0x06, 0xa8, + 0x62, 0xb1, 0x90, 0xd6, 0x21, 0x44, 0x7f, 0xc1, 0x9b, 0x87, 0x3e, 0xad, + 0x91, 0x85, 0x00, 0x00, 0x06, 0x13, 0x01, 0x13, 0x03, 0x13, 0x02, 0x01, + 0x00, 0x00, 0xa0, 0x00, 0x00, 0x00, 0x0b, 0x00, 0x09, 0x00, 0x00, 0x06, + 0x73, 0x65, 0x72, 0x76, 0x65, 0x72, 0xff, 0x01, 0x00, 0x01, 0x00, 0x00, + 0x0a, 0x00, 0x12, 0x00, 0x10, 0x00, 0x17, 0x00, 0x18, 0x00, 0x19, 0x01, + 0x00, 0x01, 0x01, 0x01, 0x02, 0x01, 0x03, 0x01, 0x04, 0x00, 0x33, 0x00, + 0x47, 0x00, 0x45, 0x00, 0x17, 0x00, 0x41, 0x04, 0x86, 0x4a, 0xb9, 0xdc, + 0x6a, 0x38, 0xa7, 0xce, 0xe7, 0xc2, 0x4f, 0xa6, 0x28, 0xb9, 0xdc, 0x65, + 0xbf, 0x73, 0x47, 0x3c, 0x9c, 0x65, 0x8c, 0x47, 0x6d, 0x57, 0x22, 0x8a, + 0xc2, 0xb3, 0xc6, 0x80, 0x72, 0x86, 0x08, 0x86, 0x8f, 0x52, 0xc5, 0xcb, + 0xbf, 0x2a, 0xb5, 0x59, 0x64, 0xcc, 0x0c, 0x49, 0x95, 0x36, 0xe4, 0xd9, + 0x2f, 0xd4, 0x24, 0x66, 0x71, 0x6f, 0x5d, 0x70, 0xe2, 0xa0, 0xea, 0x26, + 0x00, 0x2b, 0x00, 0x03, 0x02, 0x03, 0x04, 0x00, 0x0d, 0x00, 0x20, 0x00, + 0x1e, 0x04, 0x03, 0x05, 0x03, 0x06, 0x03, 0x02, 0x03, 0x08, 0x04, 0x08, + 0x05, 0x08, 0x06, 0x04, 0x01, 0x05, 0x01, 0x06, 0x01, 0x02, 0x01, 0x04, + 0x02, 0x05, 0x02, 0x06, 0x02, 0x02, 0x02}; +static const size_t kFirstFragmentSize = 20; +static const char *k0RttData = "ABCDEF"; + +TEST_P(TlsAgentTest, EarlyFinished) { + DataBuffer buffer; + MakeTrivialHandshakeRecord(kTlsHandshakeFinished, 0, &buffer); + ExpectAlert(kTlsAlertUnexpectedMessage); + ProcessMessage(buffer, TlsAgent::STATE_ERROR, + SSL_ERROR_RX_UNEXPECTED_FINISHED); +} + +TEST_P(TlsAgentTest, EarlyCertificateVerify) { + DataBuffer buffer; + MakeTrivialHandshakeRecord(kTlsHandshakeCertificateVerify, 0, &buffer); + ExpectAlert(kTlsAlertUnexpectedMessage); + ProcessMessage(buffer, TlsAgent::STATE_ERROR, + SSL_ERROR_RX_UNEXPECTED_CERT_VERIFY); +} + +TEST_P(TlsAgentTestClient13, CannedHello) { + DataBuffer buffer; + EnsureInit(); + DataBuffer server_hello; + auto sh = MakeCannedTls13ServerHello(); + MakeHandshakeMessage(kTlsHandshakeServerHello, sh.data(), sh.len(), + &server_hello); + MakeRecord(ssl_ct_handshake, SSL_LIBRARY_VERSION_TLS_1_3, server_hello.data(), + server_hello.len(), &buffer); + ProcessMessage(buffer, TlsAgent::STATE_CONNECTING); +} + +TEST_P(TlsAgentTestClient13, EncryptedExtensionsInClear) { + DataBuffer server_hello; + auto sh = MakeCannedTls13ServerHello(); + MakeHandshakeMessage(kTlsHandshakeServerHello, sh.data(), sh.len(), + &server_hello); + DataBuffer encrypted_extensions; + MakeHandshakeMessage(kTlsHandshakeEncryptedExtensions, nullptr, 0, + &encrypted_extensions, 1); + server_hello.Append(encrypted_extensions); + DataBuffer buffer; + MakeRecord(ssl_ct_handshake, SSL_LIBRARY_VERSION_TLS_1_3, server_hello.data(), + server_hello.len(), &buffer); + EnsureInit(); + ExpectAlert(kTlsAlertUnexpectedMessage); + ProcessMessage(buffer, TlsAgent::STATE_ERROR, + SSL_ERROR_RX_UNEXPECTED_HANDSHAKE); +} + +TEST_F(TlsAgentStreamTestClient, EncryptedExtensionsInClearTwoPieces) { + DataBuffer server_hello; + auto sh = MakeCannedTls13ServerHello(); + MakeHandshakeMessage(kTlsHandshakeServerHello, sh.data(), sh.len(), + &server_hello); + DataBuffer encrypted_extensions; + MakeHandshakeMessage(kTlsHandshakeEncryptedExtensions, nullptr, 0, + &encrypted_extensions, 1); + server_hello.Append(encrypted_extensions); + DataBuffer buffer; + MakeRecord(ssl_ct_handshake, SSL_LIBRARY_VERSION_TLS_1_3, server_hello.data(), + kFirstFragmentSize, &buffer); + + DataBuffer buffer2; + MakeRecord(ssl_ct_handshake, SSL_LIBRARY_VERSION_TLS_1_3, + server_hello.data() + kFirstFragmentSize, + server_hello.len() - kFirstFragmentSize, &buffer2); + + EnsureInit(); + agent_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_3, + SSL_LIBRARY_VERSION_TLS_1_3); + ProcessMessage(buffer, TlsAgent::STATE_CONNECTING); + ExpectAlert(kTlsAlertUnexpectedMessage); + ProcessMessage(buffer2, TlsAgent::STATE_ERROR, + SSL_ERROR_RX_UNEXPECTED_HANDSHAKE); +} + +TEST_F(TlsAgentDgramTestClient, EncryptedExtensionsInClearTwoPieces) { + auto sh = MakeCannedTls13ServerHello(); + DataBuffer server_hello_frag1; + MakeHandshakeMessageFragment(kTlsHandshakeServerHello, sh.data(), sh.len(), + &server_hello_frag1, 0, 0, kFirstFragmentSize); + DataBuffer server_hello_frag2; + MakeHandshakeMessageFragment(kTlsHandshakeServerHello, + sh.data() + kFirstFragmentSize, sh.len(), + &server_hello_frag2, 0, kFirstFragmentSize, + sh.len() - kFirstFragmentSize); + DataBuffer encrypted_extensions; + MakeHandshakeMessage(kTlsHandshakeEncryptedExtensions, nullptr, 0, + &encrypted_extensions, 1); + server_hello_frag2.Append(encrypted_extensions); + DataBuffer buffer; + MakeRecord(ssl_ct_handshake, SSL_LIBRARY_VERSION_TLS_1_3, + server_hello_frag1.data(), server_hello_frag1.len(), &buffer); + + DataBuffer buffer2; + MakeRecord(ssl_ct_handshake, SSL_LIBRARY_VERSION_TLS_1_3, + server_hello_frag2.data(), server_hello_frag2.len(), &buffer2, 1); + + EnsureInit(); + agent_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_3, + SSL_LIBRARY_VERSION_TLS_1_3); + ProcessMessage(buffer, TlsAgent::STATE_CONNECTING); + ExpectAlert(kTlsAlertUnexpectedMessage); + ProcessMessage(buffer2, TlsAgent::STATE_ERROR, + SSL_ERROR_RX_UNEXPECTED_HANDSHAKE); +} + +TEST_F(TlsAgentDgramTestClient, AckWithBogusLengthField) { + EnsureInit(); + // Length doesn't match + const uint8_t ackBuf[] = {0x00, 0x08, 0x00}; + DataBuffer record; + MakeRecord(variant_, ssl_ct_ack, SSL_LIBRARY_VERSION_TLS_1_2, ackBuf, + sizeof(ackBuf), &record, 0); + agent_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_3, + SSL_LIBRARY_VERSION_TLS_1_3); + ProcessMessage(record, TlsAgent::STATE_ERROR, + SSL_ERROR_RX_MALFORMED_DTLS_ACK); +} + +TEST_F(TlsAgentDgramTestClient, AckWithNonEvenLength) { + EnsureInit(); + // Length isn't a multiple of 8 + const uint8_t ackBuf[] = {0x00, 0x01, 0x00}; + DataBuffer record; + MakeRecord(variant_, ssl_ct_ack, SSL_LIBRARY_VERSION_TLS_1_2, ackBuf, + sizeof(ackBuf), &record, 0); + agent_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_3, + SSL_LIBRARY_VERSION_TLS_1_3); + // Because we haven't negotiated the version, + // ssl3_DecodeError() sends an older (pre-TLS error). + ExpectAlert(kTlsAlertIllegalParameter); + ProcessMessage(record, TlsAgent::STATE_ERROR, + SSL_ERROR_RX_MALFORMED_DTLS_ACK); +} + +TEST_F(TlsAgentStreamTestClient, Set0RttOptionThenWrite) { + EnsureInit(); + agent_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_1, + SSL_LIBRARY_VERSION_TLS_1_3); + agent_->StartConnect(); + agent_->Set0RttEnabled(true); + auto filter = + MakeTlsFilter<TlsHandshakeRecorder>(agent_, kTlsHandshakeClientHello); + PRInt32 rv = PR_Write(agent_->ssl_fd(), k0RttData, strlen(k0RttData)); + EXPECT_EQ(-1, rv); + int32_t err = PORT_GetError(); + EXPECT_EQ(PR_WOULD_BLOCK_ERROR, err); + EXPECT_LT(0UL, filter->buffer().len()); +} + +TEST_F(TlsAgentStreamTestClient, Set0RttOptionThenRead) { + EnsureInit(); + agent_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_1, + SSL_LIBRARY_VERSION_TLS_1_3); + agent_->StartConnect(); + agent_->Set0RttEnabled(true); + DataBuffer buffer; + MakeRecord(ssl_ct_application_data, SSL_LIBRARY_VERSION_TLS_1_3, + reinterpret_cast<const uint8_t *>(k0RttData), strlen(k0RttData), + &buffer); + ExpectAlert(kTlsAlertUnexpectedMessage); + ProcessMessage(buffer, TlsAgent::STATE_ERROR, + SSL_ERROR_RX_UNEXPECTED_APPLICATION_DATA); +} + +// The server is allowing 0-RTT but the client doesn't offer it, +// so trial decryption isn't engaged and 0-RTT messages cause +// an error. +TEST_F(TlsAgentStreamTestServer, Set0RttOptionClientHelloThenRead) { + EnsureInit(); + agent_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_1, + SSL_LIBRARY_VERSION_TLS_1_3); + agent_->StartConnect(); + agent_->Set0RttEnabled(true); + DataBuffer buffer; + MakeRecord(ssl_ct_handshake, SSL_LIBRARY_VERSION_TLS_1_3, + kCannedTls13ClientHello, sizeof(kCannedTls13ClientHello), &buffer); + ProcessMessage(buffer, TlsAgent::STATE_CONNECTING); + MakeRecord(ssl_ct_application_data, SSL_LIBRARY_VERSION_TLS_1_3, + reinterpret_cast<const uint8_t *>(k0RttData), strlen(k0RttData), + &buffer); + ExpectAlert(kTlsAlertBadRecordMac); + ProcessMessage(buffer, TlsAgent::STATE_ERROR, SSL_ERROR_BAD_MAC_READ); +} + +INSTANTIATE_TEST_SUITE_P( + AgentTests, TlsAgentTest, + ::testing::Combine(TlsAgentTestBase::kTlsRolesAll, + TlsConnectTestBase::kTlsVariantsStream, + TlsConnectTestBase::kTlsVAll)); +INSTANTIATE_TEST_SUITE_P(ClientTests13, TlsAgentTestClient13, + ::testing::Combine(TlsConnectTestBase::kTlsVariantsAll, + TlsConnectTestBase::kTlsV13)); +} // namespace nss_test diff --git a/security/nss/gtests/ssl_gtest/ssl_auth_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_auth_unittest.cc new file mode 100644 index 0000000000..edd2479a78 --- /dev/null +++ b/security/nss/gtests/ssl_gtest/ssl_auth_unittest.cc @@ -0,0 +1,2261 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set ts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this file, + * You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#include "secerr.h" +#include "ssl.h" +#include "sslerr.h" +#include "sslproto.h" + +extern "C" { +// This is not something that should make you happy. +#include "libssl_internals.h" +} + +#include "gtest_utils.h" +#include "nss_scoped_ptrs.h" +#include "tls_connect.h" +#include "tls_filter.h" +#include "tls_parser.h" + +namespace nss_test { + +TEST_P(TlsConnectGeneric, ServerAuthBigRsa) { + Reset(TlsAgent::kRsa2048); + Connect(); + CheckKeys(); +} + +TEST_P(TlsConnectGeneric, ServerAuthRsaChain) { + Reset("rsa_chain"); + Connect(); + CheckKeys(); + size_t chain_length; + EXPECT_TRUE(client_->GetPeerChainLength(&chain_length)); + EXPECT_EQ(2UL, chain_length); +} + +TEST_P(TlsConnectTls12Plus, ServerAuthRsaPss) { + static const SSLSignatureScheme kSignatureSchemePss[] = { + ssl_sig_rsa_pss_pss_sha256}; + + Reset(TlsAgent::kServerRsaPss); + client_->SetSignatureSchemes(kSignatureSchemePss, + PR_ARRAY_SIZE(kSignatureSchemePss)); + server_->SetSignatureSchemes(kSignatureSchemePss, + PR_ARRAY_SIZE(kSignatureSchemePss)); + Connect(); + CheckKeys(ssl_kea_ecdh, ssl_grp_ec_curve25519, ssl_auth_rsa_pss, + ssl_sig_rsa_pss_pss_sha256); +} + +// PSS doesn't work with TLS 1.0 or 1.1 because we can't signal it. +TEST_P(TlsConnectPre12, ServerAuthRsaPssFails) { + static const SSLSignatureScheme kSignatureSchemePss[] = { + ssl_sig_rsa_pss_pss_sha256}; + + Reset(TlsAgent::kServerRsaPss); + client_->SetSignatureSchemes(kSignatureSchemePss, + PR_ARRAY_SIZE(kSignatureSchemePss)); + server_->SetSignatureSchemes(kSignatureSchemePss, + PR_ARRAY_SIZE(kSignatureSchemePss)); + ConnectExpectAlert(server_, kTlsAlertHandshakeFailure); + server_->CheckErrorCode(SSL_ERROR_NO_CYPHER_OVERLAP); + client_->CheckErrorCode(SSL_ERROR_NO_CYPHER_OVERLAP); +} + +// Check that a PSS certificate with no parameters works. +TEST_P(TlsConnectTls12Plus, ServerAuthRsaPssNoParameters) { + static const SSLSignatureScheme kSignatureSchemePss[] = { + ssl_sig_rsa_pss_pss_sha256}; + + Reset("rsa_pss_noparam"); + client_->SetSignatureSchemes(kSignatureSchemePss, + PR_ARRAY_SIZE(kSignatureSchemePss)); + server_->SetSignatureSchemes(kSignatureSchemePss, + PR_ARRAY_SIZE(kSignatureSchemePss)); + Connect(); + CheckKeys(ssl_kea_ecdh, ssl_grp_ec_curve25519, ssl_auth_rsa_pss, + ssl_sig_rsa_pss_pss_sha256); +} + +TEST_P(TlsConnectGeneric, ServerAuthRsaPssChain) { + Reset("rsa_pss_chain"); + Connect(); + CheckKeys(); + size_t chain_length; + EXPECT_TRUE(client_->GetPeerChainLength(&chain_length)); + EXPECT_EQ(2UL, chain_length); +} + +TEST_P(TlsConnectGeneric, ServerAuthRsaCARsaPssChain) { + Reset("rsa_ca_rsa_pss_chain"); + Connect(); + CheckKeys(); + size_t chain_length; + EXPECT_TRUE(client_->GetPeerChainLength(&chain_length)); + EXPECT_EQ(2UL, chain_length); +} + +TEST_P(TlsConnectGeneric, ServerAuthRejected) { + EnsureTlsSetup(); + client_->SetAuthCertificateCallback( + [](TlsAgent*, PRBool, PRBool) -> SECStatus { return SECFailure; }); + ConnectExpectAlert(client_, kTlsAlertBadCertificate); + client_->CheckErrorCode(SSL_ERROR_BAD_CERTIFICATE); + server_->CheckErrorCode(SSL_ERROR_BAD_CERT_ALERT); + EXPECT_EQ(TlsAgent::STATE_ERROR, client_->state()); +} + +struct AuthCompleteArgs : public PollTarget { + AuthCompleteArgs(const std::shared_ptr<TlsAgent>& a, PRErrorCode c) + : agent(a), code(c) {} + + std::shared_ptr<TlsAgent> agent; + PRErrorCode code; +}; + +static void CallAuthComplete(PollTarget* target, Event event) { + EXPECT_EQ(TIMER_EVENT, event); + auto args = reinterpret_cast<AuthCompleteArgs*>(target); + std::cerr << args->agent->role_str() << ": call SSL_AuthCertificateComplete " + << (args->code ? PR_ErrorToName(args->code) : "no error") + << std::endl; + EXPECT_EQ(SECSuccess, + SSL_AuthCertificateComplete(args->agent->ssl_fd(), args->code)); + args->agent->Handshake(); // Make the TlsAgent aware of the error. + delete args; +} + +// Install an AuthCertificateCallback that blocks when called. Then +// SSL_AuthCertificateComplete is called on a very short timer. This allows any +// processing that might follow the callback to complete. +static void SetDeferredAuthCertificateCallback(std::shared_ptr<TlsAgent> agent, + PRErrorCode code) { + auto args = new AuthCompleteArgs(agent, code); + agent->SetAuthCertificateCallback( + [args](TlsAgent*, PRBool, PRBool) -> SECStatus { + // This can't be 0 or we race the message from the client to the server, + // and tests assume that we lose that race. + std::shared_ptr<Poller::Timer> timer_handle; + Poller::Instance()->SetTimer(1U, args, CallAuthComplete, &timer_handle); + return SECWouldBlock; + }); +} + +TEST_P(TlsConnectTls13, ServerAuthRejectAsync) { + SetDeferredAuthCertificateCallback(client_, SEC_ERROR_REVOKED_CERTIFICATE); + ConnectExpectAlert(client_, kTlsAlertCertificateRevoked); + // We only detect the error here when we attempt to handshake, so all the + // client learns is that the handshake has already failed. + client_->CheckErrorCode(SSL_ERROR_HANDSHAKE_FAILED); + server_->CheckErrorCode(SSL_ERROR_REVOKED_CERT_ALERT); +} + +// In TLS 1.2 and earlier, this will result in the client sending its Finished +// before learning that the server certificate is bad. That means that the +// server will believe that the handshake is complete. +TEST_P(TlsConnectGenericPre13, ServerAuthRejectAsync) { + SetDeferredAuthCertificateCallback(client_, SEC_ERROR_EXPIRED_CERTIFICATE); + client_->ExpectSendAlert(kTlsAlertCertificateExpired); + server_->ExpectReceiveAlert(kTlsAlertCertificateExpired); + ConnectExpectFailOneSide(TlsAgent::CLIENT); + client_->CheckErrorCode(SSL_ERROR_HANDSHAKE_FAILED); + + // The server might not receive the alert that the client sends, which would + // cause the test to fail when it cleans up. Reset expectations. + server_->ExpectReceiveAlert(kTlsAlertCloseNotify, kTlsAlertWarning); +} + +class TlsCertificateRequestContextRecorder : public TlsHandshakeFilter { + public: + TlsCertificateRequestContextRecorder(const std::shared_ptr<TlsAgent>& a, + uint8_t handshake_type) + : TlsHandshakeFilter(a, {handshake_type}), buffer_(), filtered_(false) { + EnableDecryption(); + } + + bool filtered() const { return filtered_; } + const DataBuffer& buffer() const { return buffer_; } + + protected: + virtual PacketFilter::Action FilterHandshake(const HandshakeHeader& header, + const DataBuffer& input, + DataBuffer* output) { + assert(1 < input.len()); + size_t len = input.data()[0]; + assert(len + 1 < input.len()); + buffer_.Assign(input.data() + 1, len); + filtered_ = true; + return KEEP; + } + + private: + DataBuffer buffer_; + bool filtered_; +}; + +using ClientAuthParam = + std::tuple<SSLProtocolVariant, uint16_t, ClientAuthCallbackType>; + +class TlsConnectClientAuth + : public TlsConnectTestBase, + public testing::WithParamInterface<ClientAuthParam> { + public: + TlsConnectClientAuth() + : TlsConnectTestBase(std::get<0>(GetParam()), std::get<1>(GetParam())) {} +}; + +// Wrapper classes for tests that target specific versions + +class TlsConnectClientAuth13 : public TlsConnectClientAuth {}; + +class TlsConnectClientAuth12 : public TlsConnectClientAuth {}; + +class TlsConnectClientAuthStream13 : public TlsConnectClientAuth {}; + +class TlsConnectClientAuthPre13 : public TlsConnectClientAuth {}; + +class TlsConnectClientAuth12Plus : public TlsConnectClientAuth {}; + +std::string getClientAuthTestName( + testing::TestParamInfo<ClientAuthParam> info) { + auto param = info.param; + auto variant = std::get<0>(param); + auto version = std::get<1>(param); + auto callback_type = std::get<2>(param); + + std::string output = std::string(); + switch (variant) { + case ssl_variant_stream: + output.append("TLS"); + break; + case ssl_variant_datagram: + output.append("DTLS"); + break; + } + output.append(VersionString(version).replace(1, 1, "")); + switch (callback_type) { + case ClientAuthCallbackType::kAsyncImmediate: + output.append("AsyncImmediate"); + break; + case ClientAuthCallbackType::kAsyncDelay: + output.append("AsyncDelay"); + break; + case ClientAuthCallbackType::kSync: + output.append("Sync"); + break; + case ClientAuthCallbackType::kNone: + output.append("None"); + break; + } + return output; +} + +auto kClientAuthCallbacks = testing::Values( + ClientAuthCallbackType::kAsyncImmediate, + ClientAuthCallbackType::kAsyncDelay, ClientAuthCallbackType::kSync, + ClientAuthCallbackType::kNone); + +INSTANTIATE_TEST_SUITE_P( + ClientAuthGenericStream, TlsConnectClientAuth, + testing::Combine(TlsConnectTestBase::kTlsVariantsStream, + TlsConnectTestBase::kTlsVAll, kClientAuthCallbacks), + getClientAuthTestName); + +INSTANTIATE_TEST_SUITE_P( + ClientAuthGenericDatagram, TlsConnectClientAuth, + testing::Combine(TlsConnectTestBase::kTlsVariantsDatagram, + TlsConnectTestBase::kTlsV11Plus, kClientAuthCallbacks), + getClientAuthTestName); + +INSTANTIATE_TEST_SUITE_P(ClientAuth13, TlsConnectClientAuth13, + testing::Combine(TlsConnectTestBase::kTlsVariantsAll, + TlsConnectTestBase::kTlsV13, + kClientAuthCallbacks), + getClientAuthTestName); + +INSTANTIATE_TEST_SUITE_P( + ClientAuth13, TlsConnectClientAuthStream13, + testing::Combine(TlsConnectTestBase::kTlsVariantsStream, + TlsConnectTestBase::kTlsV13, kClientAuthCallbacks), + getClientAuthTestName); + +INSTANTIATE_TEST_SUITE_P(ClientAuth12, TlsConnectClientAuth12, + testing::Combine(TlsConnectTestBase::kTlsVariantsAll, + TlsConnectTestBase::kTlsV12, + kClientAuthCallbacks), + getClientAuthTestName); + +INSTANTIATE_TEST_SUITE_P( + ClientAuthPre13Stream, TlsConnectClientAuthPre13, + testing::Combine(TlsConnectTestBase::kTlsVariantsStream, + TlsConnectTestBase::kTlsV10ToV12, kClientAuthCallbacks), + getClientAuthTestName); + +INSTANTIATE_TEST_SUITE_P( + ClientAuthPre13Datagram, TlsConnectClientAuthPre13, + testing::Combine(TlsConnectTestBase::kTlsVariantsDatagram, + TlsConnectTestBase::kTlsV11V12, kClientAuthCallbacks), + getClientAuthTestName); + +INSTANTIATE_TEST_SUITE_P(ClientAuth12Plus, TlsConnectClientAuth12Plus, + testing::Combine(TlsConnectTestBase::kTlsVariantsAll, + TlsConnectTestBase::kTlsV12Plus, + kClientAuthCallbacks), + getClientAuthTestName); + +TEST_P(TlsConnectClientAuth, ClientAuth) { + EnsureTlsSetup(); + client_->SetupClientAuth(std::get<2>(GetParam()), true); + server_->RequestClientAuth(true); + Connect(); + CheckKeys(); + client_->CheckClientAuthCompleted(); +} + +// All stream only tests; PostHandshakeAuth isn't supported for DTLS. + +TEST_P(TlsConnectClientAuthStream13, PostHandshakeAuth) { + EnsureTlsSetup(); + client_->SetupClientAuth(std::get<2>(GetParam()), true); + auto capture_cert_req = MakeTlsFilter<TlsCertificateRequestContextRecorder>( + server_, kTlsHandshakeCertificateRequest); + auto capture_certificate = + MakeTlsFilter<TlsCertificateRequestContextRecorder>( + client_, kTlsHandshakeCertificate); + client_->SetOption(SSL_ENABLE_POST_HANDSHAKE_AUTH, PR_TRUE); + size_t called = 0; + server_->SetAuthCertificateCallback( + [&called](TlsAgent*, PRBool, PRBool) -> SECStatus { + called++; + return SECSuccess; + }); + Connect(); + EXPECT_EQ(0U, called); + EXPECT_FALSE(capture_cert_req->filtered()); + EXPECT_FALSE(capture_certificate->filtered()); + // Send CertificateRequest. + EXPECT_EQ(SECSuccess, SSL_SendCertificateRequest(server_->ssl_fd())) + << "Unexpected error: " << PORT_ErrorToName(PORT_GetError()); + // Need to do a round-trip so that the post-handshake message is + // handled on both client and server. + server_->SendData(50); + client_->ReadBytes(50); + client_->ClientAuthCallbackComplete(); + client_->SendData(50); + server_->ReadBytes(50); + + EXPECT_EQ(1U, called); + ASSERT_TRUE(capture_cert_req->filtered()); + ASSERT_TRUE(capture_certificate->filtered()); + + client_->CheckClientAuthCompleted(); + // Check if a non-empty request context is generated and it is + // properly sent back. + EXPECT_LT(0U, capture_cert_req->buffer().len()); + EXPECT_EQ(capture_cert_req->buffer().len(), + capture_certificate->buffer().len()); + EXPECT_EQ(0, memcmp(capture_cert_req->buffer().data(), + capture_certificate->buffer().data(), + capture_cert_req->buffer().len())); + ScopedCERTCertificate cert1(SSL_PeerCertificate(server_->ssl_fd())); + ASSERT_NE(nullptr, cert1.get()); + ScopedCERTCertificate cert2(SSL_LocalCertificate(client_->ssl_fd())); + ASSERT_NE(nullptr, cert2.get()); + EXPECT_TRUE(SECITEM_ItemsAreEqual(&cert1->derCert, &cert2->derCert)); +} + +TEST_P(TlsConnectClientAuthStream13, PostHandshakeAuthAfterResumption) { + ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET); + ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3); + Connect(); + + SendReceive(); // Need to read so that we absorb the session tickets. + CheckKeys(); + + // Resume the connection. + Reset(); + + ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET); + ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3); + ExpectResumption(RESUME_TICKET); + + client_->SetupClientAuth(std::get<2>(GetParam()), true); + client_->SetOption(SSL_ENABLE_POST_HANDSHAKE_AUTH, PR_TRUE); + Connect(); + SendReceive(); + + size_t called = 0; + server_->SetAuthCertificateCallback( + [&called](TlsAgent*, PRBool, PRBool) -> SECStatus { + called++; + return SECSuccess; + }); + EXPECT_EQ(SECSuccess, SSL_SendCertificateRequest(server_->ssl_fd())) + << "Unexpected error: " << PORT_ErrorToName(PORT_GetError()); + + server_->SendData(50); + client_->ReadBytes(50); + client_->ClientAuthCallbackComplete(); + client_->SendData(50); + server_->ReadBytes(50); + + client_->CheckClientAuthCompleted(); + EXPECT_EQ(1U, called); + + ScopedCERTCertificate cert1(SSL_PeerCertificate(server_->ssl_fd())); + ASSERT_NE(nullptr, cert1.get()); + ScopedCERTCertificate cert2(SSL_LocalCertificate(client_->ssl_fd())); + ASSERT_NE(nullptr, cert2.get()); + EXPECT_TRUE(SECITEM_ItemsAreEqual(&cert1->derCert, &cert2->derCert)); +} + +static SECStatus GetClientAuthDataHook(void* self, PRFileDesc* fd, + CERTDistNames* caNames, + CERTCertificate** clientCert, + SECKEYPrivateKey** clientKey) { + ScopedCERTCertificate cert; + ScopedSECKEYPrivateKey priv; + // use a different certificate than TlsAgent::kClient + if (!TlsAgent::LoadCertificate(TlsAgent::kRsa2048, &cert, &priv)) { + return SECFailure; + } + + *clientCert = cert.release(); + *clientKey = priv.release(); + return SECSuccess; +} + +typedef struct AutoClientTestStr { + SECStatus result; + const std::string cert; +} AutoClientTest; + +typedef struct AutoClientResultsStr { + AutoClientTest isRsa2048; + AutoClientTest isClient; + AutoClientTest isNull; + bool hookCalled; +} AutoClientResults; + +void VerifyClientCertMatch(CERTCertificate* clientCert, + const std::string expectedName) { + const char* name = clientCert->nickname; + std::cout << "Match name=\"" << name << "\" expected=\"" << expectedName + << "\"" << std::endl; + EXPECT_TRUE(PORT_Strcmp(name, expectedName.c_str()) == 0) + << " Certmismatch: \"" << name << "\" != \"" << expectedName << "\""; +} + +static SECStatus GetAutoClientAuthDataHook(void* expectResults, PRFileDesc* fd, + CERTDistNames* caNames, + CERTCertificate** clientCert, + SECKEYPrivateKey** clientKey) { + AutoClientResults& results = *(AutoClientResults*)expectResults; + SECStatus rv; + + results.hookCalled = true; + *clientCert = NULL; + *clientKey = NULL; + rv = NSS_GetClientAuthData((void*)TlsAgent::kRsa2048.c_str(), fd, caNames, + clientCert, clientKey); + if (rv == SECSuccess) { + VerifyClientCertMatch(*clientCert, results.isRsa2048.cert); + CERT_DestroyCertificate(*clientCert); + SECKEY_DestroyPrivateKey(*clientKey); + *clientCert = NULL; + *clientKey = NULL; + } + EXPECT_EQ(results.isRsa2048.result, rv); + + rv = NSS_GetClientAuthData((void*)TlsAgent::kClient.c_str(), fd, caNames, + clientCert, clientKey); + if (rv == SECSuccess) { + VerifyClientCertMatch(*clientCert, results.isClient.cert); + CERT_DestroyCertificate(*clientCert); + SECKEY_DestroyPrivateKey(*clientKey); + *clientCert = NULL; + *clientKey = NULL; + } + EXPECT_EQ(results.isClient.result, rv); + EXPECT_EQ(*clientCert, nullptr); + EXPECT_EQ(*clientKey, nullptr); + rv = NSS_GetClientAuthData(NULL, fd, caNames, clientCert, clientKey); + if (rv == SECSuccess) { + VerifyClientCertMatch(*clientCert, results.isNull.cert); + // return this result + } + EXPECT_EQ(results.isNull.result, rv); + return rv; +} + +// while I would have liked to use a new INSTANTIATE macro the +// generates the following three tests, figuring out how to make that +// work on top of the existing TlsConnect* plumbing hurts my head. +TEST_P(TlsConnectTls12, AutoClientSelectRsaPss) { + AutoClientResults rsa = {{SECSuccess, TlsAgent::kRsa2048}, + {SECSuccess, TlsAgent::kClient}, + {SECSuccess, TlsAgent::kDelegatorRsaPss2048}, + false}; + static const SSLSignatureScheme kSchemes[] = {ssl_sig_rsa_pss_pss_sha256, + ssl_sig_rsa_pkcs1_sha256, + ssl_sig_rsa_pkcs1_sha1}; + Reset("rsa_pss_noparam"); + client_->SetupClientAuth(); + server_->RequestClientAuth(true); + EXPECT_EQ(SECSuccess, + SSL_GetClientAuthDataHook(client_->ssl_fd(), + GetAutoClientAuthDataHook, (void*)&rsa)); + server_->SetSignatureSchemes(kSchemes, PR_ARRAY_SIZE(kSchemes)); + client_->SetSignatureSchemes(kSchemes, PR_ARRAY_SIZE(kSchemes)); + Connect(); + EXPECT_TRUE(rsa.hookCalled); +} + +TEST_P(TlsConnectTls12, AutoClientSelectEcc) { + AutoClientResults ecc = {{SECFailure, TlsAgent::kClient}, + {SECFailure, TlsAgent::kClient}, + {SECSuccess, TlsAgent::kDelegatorEcdsa256}, + false}; + static const SSLSignatureScheme kSchemes[] = {ssl_sig_ecdsa_secp256r1_sha256}; + client_->SetupClientAuth(); + server_->RequestClientAuth(true); + EXPECT_EQ(SECSuccess, + SSL_GetClientAuthDataHook(client_->ssl_fd(), + GetAutoClientAuthDataHook, (void*)&ecc)); + server_->SetSignatureSchemes(kSchemes, PR_ARRAY_SIZE(kSchemes)); + client_->SetSignatureSchemes(kSchemes, PR_ARRAY_SIZE(kSchemes)); + Connect(); + EXPECT_TRUE(ecc.hookCalled); +} + +TEST_P(TlsConnectTls12, AutoClientSelectDsa) { + AutoClientResults dsa = {{SECFailure, TlsAgent::kClient}, + {SECFailure, TlsAgent::kClient}, + {SECSuccess, TlsAgent::kServerDsa}, + false}; + static const SSLSignatureScheme kSchemes[] = {ssl_sig_dsa_sha256}; + client_->SetupClientAuth(); + server_->RequestClientAuth(true); + EXPECT_EQ(SECSuccess, + SSL_GetClientAuthDataHook(client_->ssl_fd(), + GetAutoClientAuthDataHook, (void*)&dsa)); + server_->SetSignatureSchemes(kSchemes, PR_ARRAY_SIZE(kSchemes)); + client_->SetSignatureSchemes(kSchemes, PR_ARRAY_SIZE(kSchemes)); + Connect(); + EXPECT_TRUE(dsa.hookCalled); +} + +TEST_P(TlsConnectClientAuthStream13, PostHandshakeAuthMultiple) { + client_->SetupClientAuth(std::get<2>(GetParam()), true); + EXPECT_EQ(SECSuccess, SSL_OptionSet(client_->ssl_fd(), + SSL_ENABLE_POST_HANDSHAKE_AUTH, PR_TRUE)); + size_t called = 0; + server_->SetAuthCertificateCallback( + [&called](TlsAgent*, PRBool, PRBool) -> SECStatus { + called++; + return SECSuccess; + }); + Connect(); + EXPECT_EQ(0U, called); + EXPECT_EQ(nullptr, SSL_PeerCertificate(server_->ssl_fd())); + // Send 1st CertificateRequest. + EXPECT_EQ(SECSuccess, SSL_SendCertificateRequest(server_->ssl_fd())) + << "Unexpected error: " << PORT_ErrorToName(PORT_GetError()); + + server_->SendData(50); + client_->ReadBytes(50); + client_->ClientAuthCallbackComplete(); + client_->ReadBytes(50); + client_->SendData(50); + server_->ReadBytes(50); + EXPECT_EQ(1U, called); + client_->CheckClientAuthCompleted(1); + ScopedCERTCertificate cert1(SSL_PeerCertificate(server_->ssl_fd())); + ASSERT_NE(nullptr, cert1.get()); + ScopedCERTCertificate cert2(SSL_LocalCertificate(client_->ssl_fd())); + ASSERT_NE(nullptr, cert2.get()); + EXPECT_TRUE(SECITEM_ItemsAreEqual(&cert1->derCert, &cert2->derCert)); + // Send 2nd CertificateRequest. + client_->SetupClientAuth(std::get<2>(GetParam()), true); + EXPECT_EQ(SECSuccess, SSL_SendCertificateRequest(server_->ssl_fd())) + << "Unexpected error: " << PORT_ErrorToName(PORT_GetError()); + + server_->SendData(50); + client_->ReadBytes(50); + client_->ClientAuthCallbackComplete(); + client_->ReadBytes(50); + client_->SendData(50); + server_->ReadBytes(50); + client_->CheckClientAuthCompleted(2); + EXPECT_EQ(2U, called); + ScopedCERTCertificate cert3(SSL_PeerCertificate(server_->ssl_fd())); + ASSERT_NE(nullptr, cert3.get()); + ScopedCERTCertificate cert4(SSL_LocalCertificate(client_->ssl_fd())); + ASSERT_NE(nullptr, cert4.get()); + EXPECT_TRUE(SECITEM_ItemsAreEqual(&cert3->derCert, &cert4->derCert)); +} + +TEST_P(TlsConnectClientAuthStream13, PostHandshakeAuthConcurrent) { + client_->SetupClientAuth(std::get<2>(GetParam()), true); + EXPECT_EQ(SECSuccess, SSL_OptionSet(client_->ssl_fd(), + SSL_ENABLE_POST_HANDSHAKE_AUTH, PR_TRUE)); + Connect(); + // Send 1st CertificateRequest. + EXPECT_EQ(SECSuccess, SSL_SendCertificateRequest(server_->ssl_fd())) + << "Unexpected error: " << PORT_ErrorToName(PORT_GetError()); + // Send 2nd CertificateRequest. + EXPECT_EQ(SECFailure, SSL_SendCertificateRequest(server_->ssl_fd())); + EXPECT_EQ(PR_WOULD_BLOCK_ERROR, PORT_GetError()); +} + +TEST_P(TlsConnectClientAuthStream13, PostHandshakeAuthBeforeKeyUpdate) { + client_->SetupClientAuth(std::get<2>(GetParam()), true); + EXPECT_EQ(SECSuccess, SSL_OptionSet(client_->ssl_fd(), + SSL_ENABLE_POST_HANDSHAKE_AUTH, PR_TRUE)); + Connect(); + // Send CertificateRequest. + EXPECT_EQ(SECSuccess, SSL_SendCertificateRequest(server_->ssl_fd())) + << "Unexpected error: " << PORT_ErrorToName(PORT_GetError()); + // Send KeyUpdate. + EXPECT_EQ(SECFailure, SSL_KeyUpdate(server_->ssl_fd(), PR_TRUE)); + EXPECT_EQ(PR_WOULD_BLOCK_ERROR, PORT_GetError()); +} + +TEST_P(TlsConnectClientAuthStream13, PostHandshakeAuthDuringClientKeyUpdate) { + client_->SetupClientAuth(std::get<2>(GetParam()), true); + ; + EXPECT_EQ(SECSuccess, SSL_OptionSet(client_->ssl_fd(), + SSL_ENABLE_POST_HANDSHAKE_AUTH, PR_TRUE)); + Connect(); + CheckEpochs(3, 3); + // Send CertificateRequest from server. + EXPECT_EQ(SECSuccess, SSL_SendCertificateRequest(server_->ssl_fd())) + << "Unexpected error: " << PORT_ErrorToName(PORT_GetError()); + // Send KeyUpdate from client. + EXPECT_EQ(SECSuccess, SSL_KeyUpdate(client_->ssl_fd(), PR_TRUE)); + server_->SendData(50); // server sends CertificateRequest + client_->SendData(50); // client sends KeyUpdate + server_->ReadBytes(50); // server receives KeyUpdate and defers response + CheckEpochs(4, 3); + client_->ReadBytes(60); // client receives CertificateRequest + client_->ClientAuthCallbackComplete(); + client_->ReadBytes(50); // Finish reading the remaining bytes + client_->SendData( + 50); // client sends Certificate, CertificateVerify, Finished + server_->ReadBytes( + 50); // server receives Certificate, CertificateVerify, Finished + client_->CheckClientAuthCompleted(); + client_->CheckEpochs(3, 4); + server_->CheckEpochs(4, 4); + server_->SendData(50); // server sends KeyUpdate + client_->ReadBytes(50); // client receives KeyUpdate + client_->CheckEpochs(4, 4); +} + +TEST_P(TlsConnectClientAuthStream13, PostHandshakeAuthMissingExtension) { + client_->SetupClientAuth(std::get<2>(GetParam()), true); + Connect(); + // Send CertificateRequest, should fail due to missing + // post_handshake_auth extension. + EXPECT_EQ(SECFailure, SSL_SendCertificateRequest(server_->ssl_fd())); + EXPECT_EQ(SSL_ERROR_MISSING_POST_HANDSHAKE_AUTH_EXTENSION, PORT_GetError()); +} + +TEST_P(TlsConnectClientAuthStream13, PostHandshakeAuthAfterClientAuth) { + client_->SetupClientAuth(std::get<2>(GetParam()), true); + server_->RequestClientAuth(true); + EXPECT_EQ(SECSuccess, SSL_OptionSet(client_->ssl_fd(), + SSL_ENABLE_POST_HANDSHAKE_AUTH, PR_TRUE)); + size_t called = 0; + server_->SetAuthCertificateCallback( + [&called](TlsAgent*, PRBool, PRBool) -> SECStatus { + called++; + return SECSuccess; + }); + Connect(); + EXPECT_EQ(1U, called); + ScopedCERTCertificate cert1(SSL_PeerCertificate(server_->ssl_fd())); + ASSERT_NE(nullptr, cert1.get()); + ScopedCERTCertificate cert2(SSL_LocalCertificate(client_->ssl_fd())); + ASSERT_NE(nullptr, cert2.get()); + EXPECT_TRUE(SECITEM_ItemsAreEqual(&cert1->derCert, &cert2->derCert)); + // Send CertificateRequest. + EXPECT_EQ(SECSuccess, SSL_GetClientAuthDataHook( + client_->ssl_fd(), GetClientAuthDataHook, nullptr)); + EXPECT_EQ(SECSuccess, SSL_SendCertificateRequest(server_->ssl_fd())) + << "Unexpected error: " << PORT_ErrorToName(PORT_GetError()); + server_->SendData(50); + client_->ReadBytes(50); + client_->SendData(50); + server_->ReadBytes(50); + EXPECT_EQ(2U, called); + ScopedCERTCertificate cert3(SSL_PeerCertificate(server_->ssl_fd())); + ASSERT_NE(nullptr, cert3.get()); + ScopedCERTCertificate cert4(SSL_LocalCertificate(client_->ssl_fd())); + ASSERT_NE(nullptr, cert4.get()); + EXPECT_TRUE(SECITEM_ItemsAreEqual(&cert3->derCert, &cert4->derCert)); + EXPECT_FALSE(SECITEM_ItemsAreEqual(&cert3->derCert, &cert1->derCert)); +} + +// Damages the request context in a CertificateRequest message. +// We don't modify a Certificate message instead, so that the client +// can compute CertificateVerify correctly. +class TlsDamageCertificateRequestContextFilter : public TlsHandshakeFilter { + public: + TlsDamageCertificateRequestContextFilter(const std::shared_ptr<TlsAgent>& a) + : TlsHandshakeFilter(a, {kTlsHandshakeCertificateRequest}) { + EnableDecryption(); + } + + protected: + virtual PacketFilter::Action FilterHandshake(const HandshakeHeader& header, + const DataBuffer& input, + DataBuffer* output) { + *output = input; + assert(1 < output->len()); + // The request context has a 1 octet length. + output->data()[1] ^= 73; + return CHANGE; + } +}; + +TEST_P(TlsConnectClientAuthStream13, PostHandshakeAuthContextMismatch) { + EnsureTlsSetup(); + MakeTlsFilter<TlsDamageCertificateRequestContextFilter>(server_); + client_->SetupClientAuth(std::get<2>(GetParam()), true); + EXPECT_EQ(SECSuccess, SSL_OptionSet(client_->ssl_fd(), + SSL_ENABLE_POST_HANDSHAKE_AUTH, PR_TRUE)); + Connect(); + // Send CertificateRequest. + EXPECT_EQ(SECSuccess, SSL_SendCertificateRequest(server_->ssl_fd())) + << "Unexpected error: " << PORT_ErrorToName(PORT_GetError()); + server_->SendData(50); + client_->ReadBytes(50); + client_->ClientAuthCallbackComplete(); + client_->ReadBytes(50); + client_->SendData(50); + server_->ExpectSendAlert(kTlsAlertIllegalParameter); + server_->ReadBytes(50); + EXPECT_EQ(SSL_ERROR_RX_MALFORMED_CERTIFICATE, PORT_GetError()); + server_->ExpectReadWriteError(); + server_->SendData(50); + client_->ExpectReceiveAlert(kTlsAlertIllegalParameter); + client_->ReadBytes(50); + EXPECT_EQ(SSL_ERROR_ILLEGAL_PARAMETER_ALERT, PORT_GetError()); +} + +// Replaces signature in a CertificateVerify message. +class TlsDamageSignatureFilter : public TlsHandshakeFilter { + public: + TlsDamageSignatureFilter(const std::shared_ptr<TlsAgent>& a) + : TlsHandshakeFilter(a, {kTlsHandshakeCertificateVerify}) { + EnableDecryption(); + } + + protected: + virtual PacketFilter::Action FilterHandshake(const HandshakeHeader& header, + const DataBuffer& input, + DataBuffer* output) { + *output = input; + assert(2 < output->len()); + // The signature follows a 2-octet signature scheme. + output->data()[2] ^= 73; + return CHANGE; + } +}; + +TEST_P(TlsConnectClientAuthStream13, PostHandshakeAuthBadSignature) { + EnsureTlsSetup(); + MakeTlsFilter<TlsDamageSignatureFilter>(client_); + client_->SetupClientAuth(std::get<2>(GetParam()), true); + EXPECT_EQ(SECSuccess, SSL_OptionSet(client_->ssl_fd(), + SSL_ENABLE_POST_HANDSHAKE_AUTH, PR_TRUE)); + Connect(); + // Send CertificateRequest. + EXPECT_EQ(SECSuccess, SSL_SendCertificateRequest(server_->ssl_fd())) + << "Unexpected error: " << PORT_ErrorToName(PORT_GetError()); + server_->SendData(50); + client_->ReadBytes(50); + client_->ClientAuthCallbackComplete(); + client_->SendData(50); + client_->CheckClientAuthCompleted(); + server_->ExpectSendAlert(kTlsAlertDecodeError); + server_->ReadBytes(50); + EXPECT_EQ(SSL_ERROR_RX_MALFORMED_CERT_VERIFY, PORT_GetError()); +} + +TEST_P(TlsConnectClientAuthStream13, PostHandshakeAuthDecline) { + EnsureTlsSetup(); + auto capture_cert_req = MakeTlsFilter<TlsCertificateRequestContextRecorder>( + server_, kTlsHandshakeCertificateRequest); + auto capture_certificate = + MakeTlsFilter<TlsCertificateRequestContextRecorder>( + client_, kTlsHandshakeCertificate); + client_->SetupClientAuth(std::get<2>(GetParam()), true); + EXPECT_EQ(SECSuccess, SSL_OptionSet(client_->ssl_fd(), + SSL_ENABLE_POST_HANDSHAKE_AUTH, PR_TRUE)); + EXPECT_EQ(SECSuccess, + SSL_OptionSet(server_->ssl_fd(), SSL_REQUIRE_CERTIFICATE, + SSL_REQUIRE_ALWAYS)); + // Client to decline the certificate request. + EXPECT_EQ(SECSuccess, + SSL_GetClientAuthDataHook( + client_->ssl_fd(), + [](void*, PRFileDesc*, CERTDistNames*, CERTCertificate**, + SECKEYPrivateKey**) -> SECStatus { return SECFailure; }, + nullptr)); + size_t called = 0; + server_->SetAuthCertificateCallback( + [&called](TlsAgent*, PRBool, PRBool) -> SECStatus { + called++; + return SECSuccess; + }); + Connect(); + EXPECT_EQ(0U, called); + // Send CertificateRequest. + EXPECT_EQ(SECSuccess, SSL_SendCertificateRequest(server_->ssl_fd())) + << "Unexpected error: " << PORT_ErrorToName(PORT_GetError()); + server_->SendData(50); // send Certificate Request + client_->ReadBytes(50); // read Certificate Request + client_->SendData(50); // send empty Certificate+Finished + server_->ExpectSendAlert(kTlsAlertCertificateRequired); + server_->ReadBytes(50); // read empty Certificate+Finished + server_->ExpectReadWriteError(); + server_->SendData(50); // send alert + // AuthCertificateCallback is not called, because the client sends + // an empty certificate_list. + EXPECT_EQ(0U, called); + EXPECT_TRUE(capture_cert_req->filtered()); + EXPECT_TRUE(capture_certificate->filtered()); + // Check if a non-empty request context is generated and it is + // properly sent back. + EXPECT_LT(0U, capture_cert_req->buffer().len()); + EXPECT_EQ(capture_cert_req->buffer().len(), + capture_certificate->buffer().len()); + EXPECT_EQ(0, memcmp(capture_cert_req->buffer().data(), + capture_certificate->buffer().data(), + capture_cert_req->buffer().len())); +} + +// Check if post-handshake auth still works when session tickets are enabled: +// https://bugzilla.mozilla.org/show_bug.cgi?id=1553443 +TEST_P(TlsConnectClientAuthStream13, + PostHandshakeAuthWithSessionTicketsEnabled) { + EnsureTlsSetup(); + client_->SetupClientAuth(std::get<2>(GetParam()), true); + EXPECT_EQ(SECSuccess, SSL_OptionSet(client_->ssl_fd(), + SSL_ENABLE_POST_HANDSHAKE_AUTH, PR_TRUE)); + EXPECT_EQ(SECSuccess, SSL_OptionSet(client_->ssl_fd(), + SSL_ENABLE_SESSION_TICKETS, PR_TRUE)); + EXPECT_EQ(SECSuccess, SSL_OptionSet(server_->ssl_fd(), + SSL_ENABLE_SESSION_TICKETS, PR_TRUE)); + size_t called = 0; + server_->SetAuthCertificateCallback( + [&called](TlsAgent*, PRBool, PRBool) -> SECStatus { + called++; + return SECSuccess; + }); + Connect(); + EXPECT_EQ(0U, called); + // Send CertificateRequest. + EXPECT_EQ(SECSuccess, SSL_GetClientAuthDataHook( + client_->ssl_fd(), GetClientAuthDataHook, nullptr)); + EXPECT_EQ(SECSuccess, SSL_SendCertificateRequest(server_->ssl_fd())) + << "Unexpected error: " << PORT_ErrorToName(PORT_GetError()); + server_->SendData(50); + client_->ReadBytes(50); + client_->SendData(50); + server_->ReadBytes(50); + EXPECT_EQ(1U, called); + ScopedCERTCertificate cert1(SSL_PeerCertificate(server_->ssl_fd())); + ASSERT_NE(nullptr, cert1.get()); + ScopedCERTCertificate cert2(SSL_LocalCertificate(client_->ssl_fd())); + ASSERT_NE(nullptr, cert2.get()); + EXPECT_TRUE(SECITEM_ItemsAreEqual(&cert1->derCert, &cert2->derCert)); +} + +TEST_P(TlsConnectClientAuthPre13, ClientAuthRequiredRejected) { + client_->SetupClientAuth(std::get<2>(GetParam()), false); + server_->RequestClientAuth(true); + ConnectExpectAlert(server_, kTlsAlertBadCertificate); + client_->CheckErrorCode(SSL_ERROR_BAD_CERT_ALERT); + server_->CheckErrorCode(SSL_ERROR_NO_CERTIFICATE); +} + +// In TLS 1.3, the client will claim that the connection is done and then +// receive the alert afterwards. So drive the handshake manually. +TEST_P(TlsConnectClientAuth13, ClientAuthRequiredRejected) { + client_->SetupClientAuth(std::get<2>(GetParam()), false); + server_->RequestClientAuth(true); + StartConnect(); + client_->Handshake(); // CH + server_->Handshake(); // SH.. (no resumption) + + client_->Handshake(); // Next message + ASSERT_EQ(TlsAgent::STATE_CONNECTED, client_->state()); + client_->CheckClientAuthCompleted(); + ExpectAlert(server_, kTlsAlertCertificateRequired); + server_->Handshake(); // Alert + server_->CheckErrorCode(SSL_ERROR_NO_CERTIFICATE); + client_->Handshake(); // Receive Alert + client_->CheckErrorCode(SSL_ERROR_RX_CERTIFICATE_REQUIRED_ALERT); +} + +TEST_P(TlsConnectClientAuth, ClientAuthRequestedRejected) { + client_->SetupClientAuth(std::get<2>(GetParam()), false); + server_->RequestClientAuth(false); + Connect(); + CheckKeys(); +} + +TEST_P(TlsConnectClientAuth, ClientAuthEcdsa) { + Reset(TlsAgent::kServerEcdsa256); + client_->SetupClientAuth(std::get<2>(GetParam()), true); + server_->RequestClientAuth(true); + Connect(); + CheckKeys(ssl_kea_ecdh, ssl_auth_ecdsa); +} + +TEST_P(TlsConnectClientAuth, ClientAuthWithEch) { + Reset(TlsAgent::kServerEcdsa256); + EnsureTlsSetup(); + SetupEch(client_, server_); + client_->SetupClientAuth(std::get<2>(GetParam()), true); + server_->RequestClientAuth(true); + Connect(); + CheckKeys(ssl_kea_ecdh, ssl_auth_ecdsa); +} + +TEST_P(TlsConnectClientAuth, ClientAuthBigRsa) { + Reset(TlsAgent::kServerRsa, TlsAgent::kRsa2048); + client_->SetupClientAuth(std::get<2>(GetParam()), true); + server_->RequestClientAuth(true); + Connect(); + CheckKeys(); +} + +// Offset is the position in the captured buffer where the signature sits. +static void CheckSigScheme(std::shared_ptr<TlsHandshakeRecorder>& capture, + size_t offset, std::shared_ptr<TlsAgent>& peer, + uint16_t expected_scheme, size_t expected_size) { + EXPECT_LT(offset + 2U, capture->buffer().len()); + + uint32_t scheme = 0; + capture->buffer().Read(offset, 2, &scheme); + EXPECT_EQ(expected_scheme, static_cast<uint16_t>(scheme)); + + ScopedCERTCertificate remote_cert(SSL_PeerCertificate(peer->ssl_fd())); + ASSERT_NE(nullptr, remote_cert.get()); + ScopedSECKEYPublicKey remote_key(CERT_ExtractPublicKey(remote_cert.get())); + ASSERT_NE(nullptr, remote_key.get()); + EXPECT_EQ(expected_size, SECKEY_PublicKeyStrengthInBits(remote_key.get())); +} + +// The server should prefer SHA-256 by default, even for the small key size used +// in the default certificate. +TEST_P(TlsConnectTls12, ServerAuthCheckSigAlg) { + EnsureTlsSetup(); + auto capture_ske = MakeTlsFilter<TlsHandshakeRecorder>( + server_, kTlsHandshakeServerKeyExchange); + Connect(); + CheckKeys(); + + const DataBuffer& buffer = capture_ske->buffer(); + EXPECT_LT(3U, buffer.len()); + EXPECT_EQ(3U, buffer.data()[0]) << "curve_type == named_curve"; + uint32_t tmp; + EXPECT_TRUE(buffer.Read(1, 2, &tmp)) << "read NamedCurve"; + EXPECT_EQ(ssl_grp_ec_curve25519, tmp); + EXPECT_TRUE(buffer.Read(3, 1, &tmp)) << " read ECPoint"; + CheckSigScheme(capture_ske, 4 + tmp, client_, ssl_sig_rsa_pss_rsae_sha256, + 1024); +} + +TEST_P(TlsConnectClientAuth12, ClientAuthCheckSigAlg) { + EnsureTlsSetup(); + auto capture_cert_verify = MakeTlsFilter<TlsHandshakeRecorder>( + client_, kTlsHandshakeCertificateVerify); + client_->SetupClientAuth(std::get<2>(GetParam()), true); + server_->RequestClientAuth(true); + Connect(); + CheckKeys(); + + CheckSigScheme(capture_cert_verify, 0, server_, ssl_sig_rsa_pkcs1_sha1, 1024); +} + +TEST_P(TlsConnectClientAuth12, ClientAuthBigRsaCheckSigAlg) { + Reset(TlsAgent::kServerRsa, TlsAgent::kRsa2048); + auto capture_cert_verify = MakeTlsFilter<TlsHandshakeRecorder>( + client_, kTlsHandshakeCertificateVerify); + client_->SetupClientAuth(std::get<2>(GetParam()), true); + server_->RequestClientAuth(true); + Connect(); + CheckKeys(); + CheckSigScheme(capture_cert_verify, 0, server_, ssl_sig_rsa_pss_rsae_sha256, + 2048); +} + +// Check if CertificateVerify signed with rsa_pss_rsae_* is properly +// rejected when the certificate is RSA-PSS. +// +// This only works under TLS 1.2, because PSS doesn't work with TLS +// 1.0 or TLS 1.1 and the TLS 1.3 1-RTT handshake is partially +// successful at the client side. +TEST_P(TlsConnectClientAuth12, ClientAuthInconsistentRsaeSignatureScheme) { + static const SSLSignatureScheme kSignatureSchemePss[] = { + ssl_sig_rsa_pss_pss_sha256, ssl_sig_rsa_pss_rsae_sha256}; + + Reset(TlsAgent::kServerRsa, "rsa_pss"); + client_->SetSignatureSchemes(kSignatureSchemePss, + PR_ARRAY_SIZE(kSignatureSchemePss)); + server_->SetSignatureSchemes(kSignatureSchemePss, + PR_ARRAY_SIZE(kSignatureSchemePss)); + client_->SetupClientAuth(std::get<2>(GetParam()), true); + server_->RequestClientAuth(true); + + EnsureTlsSetup(); + + MakeTlsFilter<TlsReplaceSignatureSchemeFilter>(client_, + ssl_sig_rsa_pss_rsae_sha256); + + ConnectExpectAlert(server_, kTlsAlertIllegalParameter); +} + +// Check if CertificateVerify signed with rsa_pss_pss_* is properly +// rejected when the certificate is RSA. +// +// This only works under TLS 1.2, because PSS doesn't work with TLS +// 1.0 or TLS 1.1 and the TLS 1.3 1-RTT handshake is partially +// successful at the client side. +TEST_P(TlsConnectClientAuth12, ClientAuthInconsistentPssSignatureScheme) { + static const SSLSignatureScheme kSignatureSchemePss[] = { + ssl_sig_rsa_pss_rsae_sha256, ssl_sig_rsa_pss_pss_sha256}; + + Reset(TlsAgent::kServerRsa, "rsa"); + client_->SetSignatureSchemes(kSignatureSchemePss, + PR_ARRAY_SIZE(kSignatureSchemePss)); + server_->SetSignatureSchemes(kSignatureSchemePss, + PR_ARRAY_SIZE(kSignatureSchemePss)); + client_->SetupClientAuth(std::get<2>(GetParam()), true); + server_->RequestClientAuth(true); + + EnsureTlsSetup(); + + MakeTlsFilter<TlsReplaceSignatureSchemeFilter>(client_, + ssl_sig_rsa_pss_pss_sha256); + + ConnectExpectAlert(server_, kTlsAlertIllegalParameter); +} + +TEST_P(TlsConnectClientAuth13, ClientAuthPkcs1SignatureScheme) { + static const SSLSignatureScheme kSignatureScheme[] = { + ssl_sig_rsa_pkcs1_sha256, ssl_sig_rsa_pss_rsae_sha256}; + + Reset(TlsAgent::kServerRsa, "rsa"); + client_->SetSignatureSchemes(kSignatureScheme, + PR_ARRAY_SIZE(kSignatureScheme)); + server_->SetSignatureSchemes(kSignatureScheme, + PR_ARRAY_SIZE(kSignatureScheme)); + client_->SetupClientAuth(std::get<2>(GetParam()), true); + server_->RequestClientAuth(true); + + auto capture_cert_verify = MakeTlsFilter<TlsHandshakeRecorder>( + client_, kTlsHandshakeCertificateVerify); + capture_cert_verify->EnableDecryption(); + + Connect(); + CheckSigScheme(capture_cert_verify, 0, server_, ssl_sig_rsa_pss_rsae_sha256, + 1024); +} + +// Client should refuse to connect without a usable signature scheme. +TEST_P(TlsConnectClientAuth13, ClientAuthPkcs1SignatureSchemeOnly) { + static const SSLSignatureScheme kSignatureScheme[] = { + ssl_sig_rsa_pkcs1_sha256}; + + Reset(TlsAgent::kServerRsa, "rsa"); + client_->SetSignatureSchemes(kSignatureScheme, + PR_ARRAY_SIZE(kSignatureScheme)); + client_->SetupClientAuth(std::get<2>(GetParam()), true); + client_->StartConnect(); + client_->Handshake(); + EXPECT_EQ(TlsAgent::STATE_ERROR, client_->state()); + client_->CheckErrorCode(SSL_ERROR_NO_SUPPORTED_SIGNATURE_ALGORITHM); +} + +// Though the client has a usable signature scheme, when a certificate is +// requested, it can't produce one. +TEST_P(TlsConnectClientAuth13, ClientAuthPkcs1AndEcdsaScheme) { + static const SSLSignatureScheme kSignatureScheme[] = { + ssl_sig_rsa_pkcs1_sha256, ssl_sig_ecdsa_secp256r1_sha256}; + + Reset(TlsAgent::kServerRsa, "rsa"); + client_->SetSignatureSchemes(kSignatureScheme, + PR_ARRAY_SIZE(kSignatureScheme)); + client_->SetupClientAuth(std::get<2>(GetParam()), true); + server_->RequestClientAuth(true); + + ConnectExpectAlert(server_, kTlsAlertHandshakeFailure); + server_->CheckErrorCode(SSL_ERROR_UNSUPPORTED_SIGNATURE_ALGORITHM); + client_->CheckErrorCode(SSL_ERROR_NO_CYPHER_OVERLAP); +} + +class TlsZeroCertificateRequestSigAlgsFilter : public TlsHandshakeFilter { + public: + TlsZeroCertificateRequestSigAlgsFilter(const std::shared_ptr<TlsAgent>& a) + : TlsHandshakeFilter(a, {kTlsHandshakeCertificateRequest}) {} + virtual PacketFilter::Action FilterHandshake( + const TlsHandshakeFilter::HandshakeHeader& header, + const DataBuffer& input, DataBuffer* output) { + TlsParser parser(input); + std::cerr << "Zeroing CertReq.supported_signature_algorithms" << std::endl; + + DataBuffer cert_types; + if (!parser.ReadVariable(&cert_types, 1)) { + ADD_FAILURE(); + return KEEP; + } + + if (!parser.SkipVariable(2)) { + ADD_FAILURE(); + return KEEP; + } + + DataBuffer cas; + if (!parser.ReadVariable(&cas, 2)) { + ADD_FAILURE(); + return KEEP; + } + + size_t idx = 0; + + // Write certificate types. + idx = output->Write(idx, cert_types.len(), 1); + idx = output->Write(idx, cert_types); + + // Write zero signature algorithms. + idx = output->Write(idx, 0U, 2); + + // Write certificate authorities. + idx = output->Write(idx, cas.len(), 2); + idx = output->Write(idx, cas); + + return CHANGE; + } +}; + +// Check that we send an alert when the server doesn't provide any +// supported_signature_algorithms in the CertificateRequest message. +TEST_P(TlsConnectClientAuth12, ClientAuthNoSigAlgs) { + EnsureTlsSetup(); + MakeTlsFilter<TlsZeroCertificateRequestSigAlgsFilter>(server_); + auto capture_cert_verify = MakeTlsFilter<TlsHandshakeRecorder>( + client_, kTlsHandshakeCertificateVerify); + client_->SetupClientAuth(std::get<2>(GetParam()), true); + server_->RequestClientAuth(true); + + ConnectExpectAlert(client_, kTlsAlertHandshakeFailure); + + server_->CheckErrorCode(SSL_ERROR_HANDSHAKE_FAILURE_ALERT); + client_->CheckErrorCode(SSL_ERROR_UNSUPPORTED_SIGNATURE_ALGORITHM); +} + +static SECStatus GetEcClientAuthDataHook(void* self, PRFileDesc* fd, + CERTDistNames* caNames, + CERTCertificate** clientCert, + SECKEYPrivateKey** clientKey) { + ScopedCERTCertificate cert; + ScopedSECKEYPrivateKey priv; + // use a different certificate than TlsAgent::kClient + if (!TlsAgent::LoadCertificate(TlsAgent::kServerEcdsa256, &cert, &priv)) { + return SECFailure; + } + + *clientCert = cert.release(); + *clientKey = priv.release(); + return SECSuccess; +} + +TEST_P(TlsConnectClientAuth12Plus, ClientAuthDisjointSchemes) { + EnsureTlsSetup(); + client_->SetupClientAuth(std::get<2>(GetParam()), true); + server_->RequestClientAuth(true); + + SSLSignatureScheme server_scheme = ssl_sig_rsa_pss_rsae_sha256; + std::vector<SSLSignatureScheme> client_schemes{ + ssl_sig_rsa_pss_rsae_sha256, ssl_sig_ecdsa_secp256r1_sha256}; + SECStatus rv = + SSL_SignatureSchemePrefSet(server_->ssl_fd(), &server_scheme, 1); + EXPECT_EQ(SECSuccess, rv); + rv = SSL_SignatureSchemePrefSet( + client_->ssl_fd(), client_schemes.data(), + static_cast<unsigned int>(client_schemes.size())); + EXPECT_EQ(SECSuccess, rv); + + // Select an EC cert that's incompatible with server schemes. + EXPECT_EQ(SECSuccess, + SSL_GetClientAuthDataHook(client_->ssl_fd(), + GetEcClientAuthDataHook, nullptr)); + + StartConnect(); + client_->Handshake(); // CH + server_->Handshake(); // SH + client_->Handshake(); + if (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) { + ASSERT_EQ(TlsAgent::STATE_CONNECTED, client_->state()); + ExpectAlert(server_, kTlsAlertCertificateRequired); + server_->Handshake(); // Alert + server_->CheckErrorCode(SSL_ERROR_NO_CERTIFICATE); + client_->Handshake(); // Receive Alert + client_->CheckErrorCode(SSL_ERROR_RX_CERTIFICATE_REQUIRED_ALERT); + } else { + ASSERT_EQ(TlsAgent::STATE_CONNECTING, client_->state()); + ExpectAlert(server_, kTlsAlertBadCertificate); + server_->Handshake(); // Alert + server_->CheckErrorCode(SSL_ERROR_NO_CERTIFICATE); + client_->Handshake(); // Receive Alert + client_->CheckErrorCode(SSL_ERROR_BAD_CERT_ALERT); + } +} + +TEST_P(TlsConnectClientAuthStream13, PostHandshakeAuthDisjointSchemes) { + EnsureTlsSetup(); + SSLSignatureScheme server_scheme = ssl_sig_rsa_pss_rsae_sha256; + std::vector<SSLSignatureScheme> client_schemes{ + ssl_sig_rsa_pss_rsae_sha256, ssl_sig_ecdsa_secp256r1_sha256}; + SECStatus rv = + SSL_SignatureSchemePrefSet(server_->ssl_fd(), &server_scheme, 1); + EXPECT_EQ(SECSuccess, rv); + rv = SSL_SignatureSchemePrefSet( + client_->ssl_fd(), client_schemes.data(), + static_cast<unsigned int>(client_schemes.size())); + EXPECT_EQ(SECSuccess, rv); + + client_->SetupClientAuth(std::get<2>(GetParam()), true); + client_->SetOption(SSL_ENABLE_POST_HANDSHAKE_AUTH, PR_TRUE); + + // Select an EC cert that's incompatible with server schemes. + EXPECT_EQ(SECSuccess, + SSL_GetClientAuthDataHook(client_->ssl_fd(), + GetEcClientAuthDataHook, nullptr)); + + Connect(); + + // Send CertificateRequest. + EXPECT_EQ(SECSuccess, SSL_SendCertificateRequest(server_->ssl_fd())) + << "Unexpected error: " << PORT_ErrorToName(PORT_GetError()); + + // Need to do a round-trip so that the post-handshake message is + // handled on both client and server. + server_->SendData(50); + client_->ReadBytes(50); + client_->SendData(50); + server_->ReadBytes(50); + + ScopedCERTCertificate cert1(SSL_PeerCertificate(server_->ssl_fd())); + ASSERT_EQ(nullptr, cert1.get()); + ScopedCERTCertificate cert2(SSL_LocalCertificate(client_->ssl_fd())); + ASSERT_EQ(nullptr, cert2.get()); +} + +static const SSLSignatureScheme kSignatureSchemeEcdsaSha384[] = { + ssl_sig_ecdsa_secp384r1_sha384}; +static const SSLSignatureScheme kSignatureSchemeEcdsaSha256[] = { + ssl_sig_ecdsa_secp256r1_sha256}; +static const SSLSignatureScheme kSignatureSchemeRsaSha384[] = { + ssl_sig_rsa_pkcs1_sha384}; +static const SSLSignatureScheme kSignatureSchemeRsaSha256[] = { + ssl_sig_rsa_pkcs1_sha256}; + +static SSLNamedGroup NamedGroupForEcdsa384(uint16_t version) { + // NSS tries to match the group size to the symmetric cipher. In TLS 1.1 and + // 1.0, TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA is the highest priority suite, so + // we use P-384. With TLS 1.2 on we pick AES-128 GCM so use x25519. + if (version <= SSL_LIBRARY_VERSION_TLS_1_1) { + return ssl_grp_ec_secp384r1; + } + return ssl_grp_ec_curve25519; +} + +// When signature algorithms match up, this should connect successfully; even +// for TLS 1.1 and 1.0, where they should be ignored. +TEST_P(TlsConnectGeneric, SignatureAlgorithmServerAuth) { + Reset(TlsAgent::kServerEcdsa384); + client_->SetSignatureSchemes(kSignatureSchemeEcdsaSha384, + PR_ARRAY_SIZE(kSignatureSchemeEcdsaSha384)); + server_->SetSignatureSchemes(kSignatureSchemeEcdsaSha384, + PR_ARRAY_SIZE(kSignatureSchemeEcdsaSha384)); + Connect(); + CheckKeys(ssl_kea_ecdh, NamedGroupForEcdsa384(version_), ssl_auth_ecdsa, + ssl_sig_ecdsa_secp384r1_sha384); +} + +// Here the client picks a single option, which should work in all versions. +// Defaults on the server include the first option. +TEST_P(TlsConnectGeneric, SignatureAlgorithmClientOnly) { + const SSLSignatureAndHashAlg clientAlgorithms[] = { + {ssl_hash_sha384, ssl_sign_ecdsa}, + {ssl_hash_sha384, ssl_sign_rsa}, // supported but unusable + {ssl_hash_md5, ssl_sign_ecdsa} // unsupported and ignored + }; + Reset(TlsAgent::kServerEcdsa384); + EnsureTlsSetup(); + // Use the old API for this function. + EXPECT_EQ(SECSuccess, + SSL_SignaturePrefSet(client_->ssl_fd(), clientAlgorithms, + PR_ARRAY_SIZE(clientAlgorithms))); + Connect(); + CheckKeys(ssl_kea_ecdh, NamedGroupForEcdsa384(version_), ssl_auth_ecdsa, + ssl_sig_ecdsa_secp384r1_sha384); +} + +// Here the server picks a single option, which should work in all versions. +// Defaults on the client include the provided option. +TEST_P(TlsConnectGeneric, SignatureAlgorithmServerOnly) { + Reset(TlsAgent::kServerEcdsa384); + server_->SetSignatureSchemes(kSignatureSchemeEcdsaSha384, + PR_ARRAY_SIZE(kSignatureSchemeEcdsaSha384)); + Connect(); + CheckKeys(ssl_kea_ecdh, NamedGroupForEcdsa384(version_), ssl_auth_ecdsa, + ssl_sig_ecdsa_secp384r1_sha384); +} + +// In TLS 1.2, curve and hash aren't bound together. +TEST_P(TlsConnectTls12, SignatureSchemeCurveMismatch) { + Reset(TlsAgent::kServerEcdsa256); + client_->SetSignatureSchemes(kSignatureSchemeEcdsaSha384, + PR_ARRAY_SIZE(kSignatureSchemeEcdsaSha384)); + Connect(); +} + +// In TLS 1.3, curve and hash are coupled. +TEST_P(TlsConnectTls13, SignatureSchemeCurveMismatch) { + Reset(TlsAgent::kServerEcdsa256); + client_->SetSignatureSchemes(kSignatureSchemeEcdsaSha384, + PR_ARRAY_SIZE(kSignatureSchemeEcdsaSha384)); + ConnectExpectAlert(server_, kTlsAlertHandshakeFailure); + server_->CheckErrorCode(SSL_ERROR_UNSUPPORTED_SIGNATURE_ALGORITHM); + client_->CheckErrorCode(SSL_ERROR_NO_CYPHER_OVERLAP); +} + +// Configuring a P-256 cert with only SHA-384 signatures is OK in TLS 1.2. +TEST_P(TlsConnectTls12, SignatureSchemeBadConfig) { + Reset(TlsAgent::kServerEcdsa256); // P-256 cert can't be used. + server_->SetSignatureSchemes(kSignatureSchemeEcdsaSha384, + PR_ARRAY_SIZE(kSignatureSchemeEcdsaSha384)); + Connect(); +} + +// A P-256 certificate in TLS 1.3 needs a SHA-256 signature scheme. +TEST_P(TlsConnectTls13, SignatureSchemeBadConfig) { + Reset(TlsAgent::kServerEcdsa256); // P-256 cert can't be used. + server_->SetSignatureSchemes(kSignatureSchemeEcdsaSha384, + PR_ARRAY_SIZE(kSignatureSchemeEcdsaSha384)); + ConnectExpectAlert(server_, kTlsAlertHandshakeFailure); + server_->CheckErrorCode(SSL_ERROR_UNSUPPORTED_SIGNATURE_ALGORITHM); + client_->CheckErrorCode(SSL_ERROR_NO_CYPHER_OVERLAP); +} + +// Where there is no overlap on signature schemes, we still connect successfully +// if we aren't going to use a signature. +TEST_P(TlsConnectGenericPre13, SignatureAlgorithmNoOverlapStaticRsa) { + client_->SetSignatureSchemes(kSignatureSchemeRsaSha384, + PR_ARRAY_SIZE(kSignatureSchemeRsaSha384)); + server_->SetSignatureSchemes(kSignatureSchemeRsaSha256, + PR_ARRAY_SIZE(kSignatureSchemeRsaSha256)); + EnableOnlyStaticRsaCiphers(); + Connect(); + CheckKeys(ssl_kea_rsa, ssl_auth_rsa_decrypt); +} + +TEST_P(TlsConnectTls12Plus, SignatureAlgorithmNoOverlapEcdsa) { + Reset(TlsAgent::kServerEcdsa256); + client_->SetSignatureSchemes(kSignatureSchemeEcdsaSha384, + PR_ARRAY_SIZE(kSignatureSchemeEcdsaSha384)); + server_->SetSignatureSchemes(kSignatureSchemeEcdsaSha256, + PR_ARRAY_SIZE(kSignatureSchemeEcdsaSha256)); + ConnectExpectAlert(server_, kTlsAlertHandshakeFailure); + client_->CheckErrorCode(SSL_ERROR_NO_CYPHER_OVERLAP); + server_->CheckErrorCode(SSL_ERROR_UNSUPPORTED_SIGNATURE_ALGORITHM); +} + +// Pre 1.2, a mismatch on signature algorithms shouldn't affect anything. +TEST_P(TlsConnectPre12, SignatureAlgorithmNoOverlapEcdsa) { + Reset(TlsAgent::kServerEcdsa256); + client_->SetSignatureSchemes(kSignatureSchemeEcdsaSha384, + PR_ARRAY_SIZE(kSignatureSchemeEcdsaSha384)); + server_->SetSignatureSchemes(kSignatureSchemeEcdsaSha256, + PR_ARRAY_SIZE(kSignatureSchemeEcdsaSha256)); + Connect(); +} + +// The signature_algorithms extension is mandatory in TLS 1.3. +TEST_P(TlsConnectTls13, SignatureAlgorithmDrop) { + MakeTlsFilter<TlsExtensionDropper>(client_, ssl_signature_algorithms_xtn); + ConnectExpectAlert(server_, kTlsAlertMissingExtension); + client_->CheckErrorCode(SSL_ERROR_MISSING_EXTENSION_ALERT); + server_->CheckErrorCode(SSL_ERROR_MISSING_SIGNATURE_ALGORITHMS_EXTENSION); +} + +// TLS 1.2 has trouble detecting this sort of modification: it uses SHA1 and +// only fails when the Finished is checked. +TEST_P(TlsConnectTls12, SignatureAlgorithmDrop) { + MakeTlsFilter<TlsExtensionDropper>(client_, ssl_signature_algorithms_xtn); + ConnectExpectAlert(server_, kTlsAlertDecryptError); + client_->CheckErrorCode(SSL_ERROR_DECRYPT_ERROR_ALERT); + server_->CheckErrorCode(SSL_ERROR_BAD_HANDSHAKE_HASH_VALUE); +} + +TEST_P(TlsConnectTls13, UnsupportedSignatureSchemeAlert) { + EnsureTlsSetup(); + auto filter = + MakeTlsFilter<TlsReplaceSignatureSchemeFilter>(server_, ssl_sig_none); + filter->EnableDecryption(); + + ConnectExpectAlert(client_, kTlsAlertIllegalParameter); + server_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT); + client_->CheckErrorCode(SSL_ERROR_RX_MALFORMED_CERT_VERIFY); +} + +TEST_P(TlsConnectTls13, InconsistentSignatureSchemeAlert) { + EnsureTlsSetup(); + + // This won't work because we use an RSA cert by default. + auto filter = MakeTlsFilter<TlsReplaceSignatureSchemeFilter>( + server_, ssl_sig_ecdsa_secp256r1_sha256); + filter->EnableDecryption(); + + ConnectExpectAlert(client_, kTlsAlertIllegalParameter); + server_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT); + client_->CheckErrorCode(SSL_ERROR_INCORRECT_SIGNATURE_ALGORITHM); +} + +TEST_P(TlsConnectTls12, RequestClientAuthWithSha384) { + server_->SetSignatureSchemes(kSignatureSchemeRsaSha384, + PR_ARRAY_SIZE(kSignatureSchemeRsaSha384)); + server_->RequestClientAuth(false); + Connect(); +} + +class BeforeFinished : public TlsRecordFilter { + private: + enum HandshakeState { BEFORE_CCS, AFTER_CCS, DONE }; + + public: + BeforeFinished(const std::shared_ptr<TlsAgent>& server, + const std::shared_ptr<TlsAgent>& client, + VoidFunction before_ccs, VoidFunction before_finished) + : TlsRecordFilter(server), + client_(client), + before_ccs_(before_ccs), + before_finished_(before_finished), + state_(BEFORE_CCS) {} + + protected: + virtual PacketFilter::Action FilterRecord(const TlsRecordHeader& header, + const DataBuffer& body, + DataBuffer* out) { + switch (state_) { + case BEFORE_CCS: + // Awaken when we see the CCS. + if (header.content_type() == ssl_ct_change_cipher_spec) { + before_ccs_(); + + // Write the CCS out as a separate write, so that we can make + // progress. Ordinarily, libssl sends the CCS and Finished together, + // but that means that they both get processed together. + DataBuffer ccs; + header.Write(&ccs, 0, body); + agent()->SendDirect(ccs); + client_.lock()->Handshake(); + state_ = AFTER_CCS; + // Request that the original record be dropped by the filter. + return DROP; + } + break; + + case AFTER_CCS: + EXPECT_EQ(ssl_ct_handshake, header.content_type()); + // This could check that data contains a Finished message, but it's + // encrypted, so that's too much extra work. + + before_finished_(); + state_ = DONE; + break; + + case DONE: + break; + } + return KEEP; + } + + private: + std::weak_ptr<TlsAgent> client_; + VoidFunction before_ccs_; + VoidFunction before_finished_; + HandshakeState state_; +}; + +// Running code after the client has started processing the encrypted part of +// the server's first flight, but before the Finished is processed is very hard +// in TLS 1.3. These encrypted messages are sent in a single encrypted blob. +// The following test uses DTLS to make it possible to force the client to +// process the handshake in pieces. +// +// The first encrypted message from the server is dropped, and the MTU is +// reduced to just below the original message size so that the server sends two +// messages. The Finished message is then processed separately. +class BeforeFinished13 : public PacketFilter { + private: + enum HandshakeState { + INIT, + BEFORE_FIRST_FRAGMENT, + BEFORE_SECOND_FRAGMENT, + DONE + }; + + public: + BeforeFinished13(const std::shared_ptr<TlsAgent>& server, + const std::shared_ptr<TlsAgent>& client, + VoidFunction before_finished) + : server_(server), + client_(client), + before_finished_(before_finished), + records_(0) {} + + protected: + virtual PacketFilter::Action Filter(const DataBuffer& input, + DataBuffer* output) { + switch (++records_) { + case 1: + // Packet 1 is the server's entire first flight. Drop it. + EXPECT_EQ(SECSuccess, + SSLInt_SetMTU(server_.lock()->ssl_fd(), input.len() - 1)); + return DROP; + + // Packet 2 is the first part of the server's retransmitted first + // flight. Keep that. + + case 3: + // Packet 3 is the second part of the server's retransmitted first + // flight. Before passing that on, make sure that the client processes + // packet 2, then call the before_finished_() callback. + client_.lock()->Handshake(); + before_finished_(); + break; + + default: + break; + } + return KEEP; + } + + private: + std::weak_ptr<TlsAgent> server_; + std::weak_ptr<TlsAgent> client_; + VoidFunction before_finished_; + size_t records_; +}; + +static SECStatus AuthCompleteBlock(TlsAgent*, PRBool, PRBool) { + return SECWouldBlock; +} + +// This test uses an AuthCertificateCallback that blocks. A filter is used to +// split the server's first flight into two pieces. Before the second piece is +// processed by the client, SSL_AuthCertificateComplete() is called. +TEST_F(TlsConnectDatagram13, AuthCompleteBeforeFinished) { + client_->SetAuthCertificateCallback(AuthCompleteBlock); + MakeTlsFilter<BeforeFinished13>(server_, client_, [this]() { + EXPECT_EQ(SECSuccess, SSL_AuthCertificateComplete(client_->ssl_fd(), 0)); + }); + Connect(); +} + +// This test uses a simple AuthCertificateCallback. Due to the way that the +// entire server flight is processed, the call to SSL_AuthCertificateComplete +// will trigger after the Finished message is processed. +TEST_P(TlsConnectTls13, AuthCompleteAfterFinished) { + SetDeferredAuthCertificateCallback(client_, 0); // 0 = success. + Connect(); +} + +TEST_P(TlsConnectGenericPre13, ClientWriteBetweenCCSAndFinishedWithFalseStart) { + client_->EnableFalseStart(); + MakeTlsFilter<BeforeFinished>( + server_, client_, + [this]() { EXPECT_TRUE(client_->can_falsestart_hook_called()); }, + [this]() { + // Write something, which used to fail: bug 1235366. + client_->SendData(10); + }); + + Connect(); + server_->SendData(10); + Receive(10); +} + +TEST_P(TlsConnectGenericPre13, AuthCompleteBeforeFinishedWithFalseStart) { + client_->EnableFalseStart(); + client_->SetAuthCertificateCallback(AuthCompleteBlock); + MakeTlsFilter<BeforeFinished>( + server_, client_, + []() { + // Do nothing before CCS + }, + [this]() { + EXPECT_FALSE(client_->can_falsestart_hook_called()); + // AuthComplete before Finished still enables false start. + EXPECT_EQ(SECSuccess, + SSL_AuthCertificateComplete(client_->ssl_fd(), 0)); + EXPECT_TRUE(client_->can_falsestart_hook_called()); + client_->SendData(10); + }); + + Connect(); + server_->SendData(10); + Receive(10); +} + +class EnforceNoActivity : public PacketFilter { + protected: + PacketFilter::Action Filter(const DataBuffer& input, + DataBuffer* output) override { + std::cerr << "Unexpected packet: " << input << std::endl; + EXPECT_TRUE(false) << "should not send anything"; + return KEEP; + } +}; + +// In this test, we want to make sure that the server completes its handshake, +// but the client does not. Because the AuthCertificate callback blocks and we +// never call SSL_AuthCertificateComplete(), the client should never report that +// it has completed the handshake. Manually call Handshake(), alternating sides +// between client and server, until the desired state is reached. +TEST_P(TlsConnectGenericPre13, AuthCompleteDelayed) { + client_->SetAuthCertificateCallback(AuthCompleteBlock); + + StartConnect(); + client_->Handshake(); // Send ClientHello + server_->Handshake(); // Send ServerHello + client_->Handshake(); // Send ClientKeyExchange and Finished + server_->Handshake(); // Send Finished + // The server should now report that it is connected + EXPECT_EQ(TlsAgent::STATE_CONNECTED, server_->state()); + + // The client should send nothing from here on. + client_->SetFilter(std::make_shared<EnforceNoActivity>()); + client_->Handshake(); + EXPECT_EQ(TlsAgent::STATE_CONNECTING, client_->state()); + + // This should allow the handshake to complete now. + EXPECT_EQ(SECSuccess, SSL_AuthCertificateComplete(client_->ssl_fd(), 0)); + client_->Handshake(); // Transition to connected + EXPECT_EQ(TlsAgent::STATE_CONNECTED, client_->state()); + EXPECT_EQ(TlsAgent::STATE_CONNECTED, server_->state()); + + // Remove filter before closing or the close_notify alert will trigger it. + client_->ClearFilter(); +} + +TEST_P(TlsConnectGenericPre13, AuthCompleteFailDelayed) { + client_->SetAuthCertificateCallback(AuthCompleteBlock); + + StartConnect(); + client_->Handshake(); // Send ClientHello + server_->Handshake(); // Send ServerHello + client_->Handshake(); // Send ClientKeyExchange and Finished + server_->Handshake(); // Send Finished + // The server should now report that it is connected + EXPECT_EQ(TlsAgent::STATE_CONNECTED, server_->state()); + + // The client should send nothing from here on. + client_->SetFilter(std::make_shared<EnforceNoActivity>()); + client_->Handshake(); + EXPECT_EQ(TlsAgent::STATE_CONNECTING, client_->state()); + + // Report failure. + client_->ClearFilter(); + client_->ExpectSendAlert(kTlsAlertBadCertificate); + EXPECT_EQ(SECSuccess, SSL_AuthCertificateComplete(client_->ssl_fd(), + SSL_ERROR_BAD_CERTIFICATE)); + client_->Handshake(); // Fail + EXPECT_EQ(TlsAgent::STATE_ERROR, client_->state()); +} + +// TLS 1.3 handles a delayed AuthComplete callback differently since the +// shape of the handshake is different. +TEST_P(TlsConnectTls13, AuthCompleteDelayed) { + client_->SetAuthCertificateCallback(AuthCompleteBlock); + + StartConnect(); + client_->Handshake(); // Send ClientHello + server_->Handshake(); // Send ServerHello + EXPECT_EQ(TlsAgent::STATE_CONNECTING, client_->state()); + EXPECT_EQ(TlsAgent::STATE_CONNECTING, server_->state()); + + // The client will send nothing until AuthCertificateComplete is called. + client_->SetFilter(std::make_shared<EnforceNoActivity>()); + client_->Handshake(); + EXPECT_EQ(TlsAgent::STATE_CONNECTING, client_->state()); + + // This should allow the handshake to complete now. + client_->ClearFilter(); + EXPECT_EQ(SECSuccess, SSL_AuthCertificateComplete(client_->ssl_fd(), 0)); + client_->Handshake(); // Send Finished + server_->Handshake(); // Transition to connected and send NewSessionTicket + EXPECT_EQ(TlsAgent::STATE_CONNECTED, client_->state()); + EXPECT_EQ(TlsAgent::STATE_CONNECTED, server_->state()); +} + +TEST_P(TlsConnectTls13, AuthCompleteFailDelayed) { + client_->SetAuthCertificateCallback(AuthCompleteBlock); + + StartConnect(); + client_->Handshake(); // Send ClientHello + server_->Handshake(); // Send ServerHello + EXPECT_EQ(TlsAgent::STATE_CONNECTING, client_->state()); + EXPECT_EQ(TlsAgent::STATE_CONNECTING, server_->state()); + + // The client will send nothing until AuthCertificateComplete is called. + client_->SetFilter(std::make_shared<EnforceNoActivity>()); + client_->Handshake(); + EXPECT_EQ(TlsAgent::STATE_CONNECTING, client_->state()); + + // Report failure. + client_->ClearFilter(); + ExpectAlert(client_, kTlsAlertBadCertificate); + EXPECT_EQ(SECSuccess, SSL_AuthCertificateComplete(client_->ssl_fd(), + SSL_ERROR_BAD_CERTIFICATE)); + client_->Handshake(); // This should now fail. + server_->Handshake(); // Get the error. + EXPECT_EQ(TlsAgent::STATE_ERROR, client_->state()); + EXPECT_EQ(TlsAgent::STATE_ERROR, server_->state()); +} + +static SECStatus AuthCompleteFail(TlsAgent*, PRBool, PRBool) { + PORT_SetError(SSL_ERROR_BAD_CERTIFICATE); + return SECFailure; +} + +TEST_P(TlsConnectGeneric, AuthFailImmediate) { + client_->SetAuthCertificateCallback(AuthCompleteFail); + + StartConnect(); + ConnectExpectAlert(client_, kTlsAlertBadCertificate); + client_->CheckErrorCode(SSL_ERROR_BAD_CERTIFICATE); +} + +static const SSLExtraServerCertData ServerCertDataRsaPkcs1Decrypt = { + ssl_auth_rsa_decrypt, nullptr, nullptr, nullptr, nullptr, nullptr}; +static const SSLExtraServerCertData ServerCertDataRsaPkcs1Sign = { + ssl_auth_rsa_sign, nullptr, nullptr, nullptr, nullptr, nullptr}; +static const SSLExtraServerCertData ServerCertDataRsaPss = { + ssl_auth_rsa_pss, nullptr, nullptr, nullptr, nullptr, nullptr}; + +// Test RSA cert with usage=[signature, encipherment]. +TEST_F(TlsAgentStreamTestServer, ConfigureCertRsaPkcs1SignAndKEX) { + Reset(TlsAgent::kServerRsa); + + PRFileDesc* ssl_fd = agent_->ssl_fd(); + EXPECT_TRUE(SSLInt_HasCertWithAuthType(ssl_fd, ssl_auth_rsa_decrypt)); + EXPECT_TRUE(SSLInt_HasCertWithAuthType(ssl_fd, ssl_auth_rsa_sign)); + EXPECT_FALSE(SSLInt_HasCertWithAuthType(ssl_fd, ssl_auth_rsa_pss)); + + // Configuring for only rsa_sign or rsa_decrypt should work. + EXPECT_TRUE(agent_->ConfigServerCert(TlsAgent::kServerRsa, false, + &ServerCertDataRsaPkcs1Decrypt)); + EXPECT_TRUE(agent_->ConfigServerCert(TlsAgent::kServerRsa, false, + &ServerCertDataRsaPkcs1Sign)); + EXPECT_FALSE(agent_->ConfigServerCert(TlsAgent::kServerRsa, false, + &ServerCertDataRsaPss)); +} + +// Test RSA cert with usage=[signature]. +TEST_F(TlsAgentStreamTestServer, ConfigureCertRsaPkcs1Sign) { + Reset(TlsAgent::kServerRsaSign); + + PRFileDesc* ssl_fd = agent_->ssl_fd(); + EXPECT_FALSE(SSLInt_HasCertWithAuthType(ssl_fd, ssl_auth_rsa_decrypt)); + EXPECT_TRUE(SSLInt_HasCertWithAuthType(ssl_fd, ssl_auth_rsa_sign)); + EXPECT_FALSE(SSLInt_HasCertWithAuthType(ssl_fd, ssl_auth_rsa_pss)); + + // Configuring for only rsa_decrypt should fail. + EXPECT_FALSE(agent_->ConfigServerCert(TlsAgent::kServerRsaSign, false, + &ServerCertDataRsaPkcs1Decrypt)); + + // Configuring for only rsa_sign should work. + EXPECT_TRUE(agent_->ConfigServerCert(TlsAgent::kServerRsaSign, false, + &ServerCertDataRsaPkcs1Sign)); + EXPECT_FALSE(agent_->ConfigServerCert(TlsAgent::kServerRsaSign, false, + &ServerCertDataRsaPss)); +} + +// Test RSA cert with usage=[encipherment]. +TEST_F(TlsAgentStreamTestServer, ConfigureCertRsaPkcs1KEX) { + Reset(TlsAgent::kServerRsaDecrypt); + + PRFileDesc* ssl_fd = agent_->ssl_fd(); + EXPECT_TRUE(SSLInt_HasCertWithAuthType(ssl_fd, ssl_auth_rsa_decrypt)); + EXPECT_FALSE(SSLInt_HasCertWithAuthType(ssl_fd, ssl_auth_rsa_sign)); + EXPECT_FALSE(SSLInt_HasCertWithAuthType(ssl_fd, ssl_auth_rsa_pss)); + + // Configuring for only rsa_sign or rsa_pss should fail. + EXPECT_FALSE(agent_->ConfigServerCert(TlsAgent::kServerRsaDecrypt, false, + &ServerCertDataRsaPkcs1Sign)); + EXPECT_FALSE(agent_->ConfigServerCert(TlsAgent::kServerRsaDecrypt, false, + &ServerCertDataRsaPss)); + + // Configuring for only rsa_decrypt should work. + EXPECT_TRUE(agent_->ConfigServerCert(TlsAgent::kServerRsaDecrypt, false, + &ServerCertDataRsaPkcs1Decrypt)); +} + +// Test configuring an RSA-PSS cert. +TEST_F(TlsAgentStreamTestServer, ConfigureCertRsaPss) { + Reset(TlsAgent::kServerRsaPss); + + PRFileDesc* ssl_fd = agent_->ssl_fd(); + EXPECT_FALSE(SSLInt_HasCertWithAuthType(ssl_fd, ssl_auth_rsa_decrypt)); + EXPECT_FALSE(SSLInt_HasCertWithAuthType(ssl_fd, ssl_auth_rsa_sign)); + EXPECT_TRUE(SSLInt_HasCertWithAuthType(ssl_fd, ssl_auth_rsa_pss)); + + // Configuring for only rsa_sign or rsa_decrypt should fail. + EXPECT_FALSE(agent_->ConfigServerCert(TlsAgent::kServerRsaPss, false, + &ServerCertDataRsaPkcs1Sign)); + EXPECT_FALSE(agent_->ConfigServerCert(TlsAgent::kServerRsaPss, false, + &ServerCertDataRsaPkcs1Decrypt)); + + // Configuring for only rsa_pss should work. + EXPECT_TRUE(agent_->ConfigServerCert(TlsAgent::kServerRsaPss, false, + &ServerCertDataRsaPss)); +} + +// A server should refuse to even start a handshake with +// misconfigured certificate and signature scheme. +TEST_P(TlsConnectTls12Plus, MisconfiguredCertScheme) { + Reset(TlsAgent::kServerDsa); + static const SSLSignatureScheme kScheme[] = {ssl_sig_ecdsa_secp256r1_sha256}; + server_->SetSignatureSchemes(kScheme, PR_ARRAY_SIZE(kScheme)); + ConnectExpectAlert(server_, kTlsAlertHandshakeFailure); + if (version_ < SSL_LIBRARY_VERSION_TLS_1_3) { + // TLS 1.2 disables cipher suites, which leads to a different error. + server_->CheckErrorCode(SSL_ERROR_NO_CYPHER_OVERLAP); + } else { + server_->CheckErrorCode(SSL_ERROR_NO_SUPPORTED_SIGNATURE_ALGORITHM); + } + client_->CheckErrorCode(SSL_ERROR_NO_CYPHER_OVERLAP); +} + +// In TLS 1.2, disabling an EC group causes ECDSA to be invalid. +TEST_P(TlsConnectTls12, Tls12CertDisabledGroup) { + Reset(TlsAgent::kServerEcdsa256); + static const std::vector<SSLNamedGroup> k25519 = {ssl_grp_ec_curve25519}; + server_->ConfigNamedGroups(k25519); + ConnectExpectAlert(server_, kTlsAlertHandshakeFailure); + server_->CheckErrorCode(SSL_ERROR_NO_CYPHER_OVERLAP); + client_->CheckErrorCode(SSL_ERROR_NO_CYPHER_OVERLAP); +} + +// In TLS 1.3, ECDSA configuration only depends on the signature scheme. +TEST_P(TlsConnectTls13, Tls13CertDisabledGroup) { + Reset(TlsAgent::kServerEcdsa256); + static const std::vector<SSLNamedGroup> k25519 = {ssl_grp_ec_curve25519}; + server_->ConfigNamedGroups(k25519); + Connect(); +} + +// A client should refuse to even start a handshake with only DSA. +TEST_P(TlsConnectTls13, Tls13DsaOnlyClient) { + static const SSLSignatureScheme kDsa[] = {ssl_sig_dsa_sha256}; + client_->SetSignatureSchemes(kDsa, PR_ARRAY_SIZE(kDsa)); + client_->StartConnect(); + client_->Handshake(); + EXPECT_EQ(TlsAgent::STATE_ERROR, client_->state()); + client_->CheckErrorCode(SSL_ERROR_NO_SUPPORTED_SIGNATURE_ALGORITHM); +} + +TEST_P(TlsConnectTls13, Tls13DsaOnlyServer) { + Reset(TlsAgent::kServerDsa); + static const SSLSignatureScheme kDsa[] = {ssl_sig_dsa_sha256}; + server_->SetSignatureSchemes(kDsa, PR_ARRAY_SIZE(kDsa)); + ConnectExpectAlert(server_, kTlsAlertHandshakeFailure); + server_->CheckErrorCode(SSL_ERROR_NO_SUPPORTED_SIGNATURE_ALGORITHM); + client_->CheckErrorCode(SSL_ERROR_NO_CYPHER_OVERLAP); +} + +TEST_P(TlsConnectTls13, Tls13Pkcs1OnlyClient) { + static const SSLSignatureScheme kPkcs1[] = {ssl_sig_rsa_pkcs1_sha256}; + client_->SetSignatureSchemes(kPkcs1, PR_ARRAY_SIZE(kPkcs1)); + client_->StartConnect(); + client_->Handshake(); + EXPECT_EQ(TlsAgent::STATE_ERROR, client_->state()); + client_->CheckErrorCode(SSL_ERROR_NO_SUPPORTED_SIGNATURE_ALGORITHM); +} + +TEST_P(TlsConnectTls13, Tls13Pkcs1OnlyServer) { + static const SSLSignatureScheme kPkcs1[] = {ssl_sig_rsa_pkcs1_sha256}; + server_->SetSignatureSchemes(kPkcs1, PR_ARRAY_SIZE(kPkcs1)); + ConnectExpectAlert(server_, kTlsAlertHandshakeFailure); + server_->CheckErrorCode(SSL_ERROR_NO_SUPPORTED_SIGNATURE_ALGORITHM); + client_->CheckErrorCode(SSL_ERROR_NO_CYPHER_OVERLAP); +} + +TEST_P(TlsConnectTls13, Tls13DsaIsNotAdvertisedClient) { + EnsureTlsSetup(); + static const SSLSignatureScheme kSchemes[] = {ssl_sig_dsa_sha256, + ssl_sig_rsa_pss_rsae_sha256}; + client_->SetSignatureSchemes(kSchemes, PR_ARRAY_SIZE(kSchemes)); + auto capture = + MakeTlsFilter<TlsExtensionCapture>(client_, ssl_signature_algorithms_xtn); + Connect(); + // We should only have the one signature algorithm advertised. + static const uint8_t kExpectedExt[] = {0, 2, ssl_sig_rsa_pss_rsae_sha256 >> 8, + ssl_sig_rsa_pss_rsae_sha256 & 0xff}; + ASSERT_EQ(DataBuffer(kExpectedExt, sizeof(kExpectedExt)), + capture->extension()); +} + +TEST_P(TlsConnectTls13, Tls13DsaIsNotAdvertisedServer) { + EnsureTlsSetup(); + static const SSLSignatureScheme kSchemes[] = {ssl_sig_dsa_sha256, + ssl_sig_rsa_pss_rsae_sha256}; + server_->SetSignatureSchemes(kSchemes, PR_ARRAY_SIZE(kSchemes)); + auto capture = MakeTlsFilter<TlsExtensionCapture>( + server_, ssl_signature_algorithms_xtn, true); + capture->SetHandshakeTypes({kTlsHandshakeCertificateRequest}); + capture->EnableDecryption(); + server_->RequestClientAuth(false); // So we get a CertificateRequest. + Connect(); + // We should only have the one signature algorithm advertised. + static const uint8_t kExpectedExt[] = {0, 2, ssl_sig_rsa_pss_rsae_sha256 >> 8, + ssl_sig_rsa_pss_rsae_sha256 & 0xff}; + ASSERT_EQ(DataBuffer(kExpectedExt, sizeof(kExpectedExt)), + capture->extension()); +} + +TEST_P(TlsConnectTls13, Tls13RsaPkcs1IsAdvertisedClient) { + EnsureTlsSetup(); + static const SSLSignatureScheme kSchemes[] = {ssl_sig_rsa_pkcs1_sha256, + ssl_sig_rsa_pss_rsae_sha256}; + client_->SetSignatureSchemes(kSchemes, PR_ARRAY_SIZE(kSchemes)); + auto capture = + MakeTlsFilter<TlsExtensionCapture>(client_, ssl_signature_algorithms_xtn); + Connect(); + // We should only have the one signature algorithm advertised. + static const uint8_t kExpectedExt[] = {0, + 4, + ssl_sig_rsa_pss_rsae_sha256 >> 8, + ssl_sig_rsa_pss_rsae_sha256 & 0xff, + ssl_sig_rsa_pkcs1_sha256 >> 8, + ssl_sig_rsa_pkcs1_sha256 & 0xff}; + ASSERT_EQ(DataBuffer(kExpectedExt, sizeof(kExpectedExt)), + capture->extension()); +} + +TEST_P(TlsConnectTls13, Tls13RsaPkcs1IsAdvertisedServer) { + EnsureTlsSetup(); + static const SSLSignatureScheme kSchemes[] = {ssl_sig_rsa_pkcs1_sha256, + ssl_sig_rsa_pss_rsae_sha256}; + server_->SetSignatureSchemes(kSchemes, PR_ARRAY_SIZE(kSchemes)); + auto capture = MakeTlsFilter<TlsExtensionCapture>( + server_, ssl_signature_algorithms_xtn, true); + capture->SetHandshakeTypes({kTlsHandshakeCertificateRequest}); + capture->EnableDecryption(); + server_->RequestClientAuth(false); // So we get a CertificateRequest. + Connect(); + // We should only have the one signature algorithm advertised. + static const uint8_t kExpectedExt[] = {0, + 4, + ssl_sig_rsa_pss_rsae_sha256 >> 8, + ssl_sig_rsa_pss_rsae_sha256 & 0xff, + ssl_sig_rsa_pkcs1_sha256 >> 8, + ssl_sig_rsa_pkcs1_sha256 & 0xff}; + ASSERT_EQ(DataBuffer(kExpectedExt, sizeof(kExpectedExt)), + capture->extension()); +} + +// variant, version, certificate, auth type, signature scheme +typedef std::tuple<SSLProtocolVariant, uint16_t, std::string, SSLAuthType, + SSLSignatureScheme> + SignatureSchemeProfile; + +class TlsSignatureSchemeConfiguration + : public TlsConnectTestBase, + public ::testing::WithParamInterface<SignatureSchemeProfile> { + public: + TlsSignatureSchemeConfiguration() + : TlsConnectTestBase(std::get<0>(GetParam()), std::get<1>(GetParam())), + certificate_(std::get<2>(GetParam())), + auth_type_(std::get<3>(GetParam())), + signature_scheme_(std::get<4>(GetParam())) {} + + protected: + void TestSignatureSchemeConfig(std::shared_ptr<TlsAgent>& configPeer) { + EnsureTlsSetup(); + configPeer->SetSignatureSchemes(&signature_scheme_, 1); + Connect(); + CheckKeys(ssl_kea_ecdh, ssl_grp_ec_curve25519, auth_type_, + signature_scheme_); + } + + std::string certificate_; + SSLAuthType auth_type_; + SSLSignatureScheme signature_scheme_; +}; + +TEST_P(TlsSignatureSchemeConfiguration, SignatureSchemeConfigServer) { + Reset(certificate_); + TestSignatureSchemeConfig(server_); +} + +TEST_P(TlsSignatureSchemeConfiguration, SignatureSchemeConfigClient) { + Reset(certificate_); + auto capture = + MakeTlsFilter<TlsExtensionCapture>(client_, ssl_signature_algorithms_xtn); + TestSignatureSchemeConfig(client_); + + const DataBuffer& ext = capture->extension(); + ASSERT_EQ(2U + 2U, ext.len()); + uint32_t v = 0; + ASSERT_TRUE(ext.Read(0, 2, &v)); + EXPECT_EQ(2U, v); + ASSERT_TRUE(ext.Read(2, 2, &v)); + EXPECT_EQ(signature_scheme_, static_cast<SSLSignatureScheme>(v)); +} + +TEST_P(TlsSignatureSchemeConfiguration, SignatureSchemeConfigBoth) { + Reset(certificate_); + EnsureTlsSetup(); + client_->SetSignatureSchemes(&signature_scheme_, 1); + server_->SetSignatureSchemes(&signature_scheme_, 1); + Connect(); + CheckKeys(ssl_kea_ecdh, ssl_grp_ec_curve25519, auth_type_, signature_scheme_); +} + +class Tls12CertificateRequestReplacer : public TlsHandshakeFilter { + public: + Tls12CertificateRequestReplacer(const std::shared_ptr<TlsAgent>& a, + SSLSignatureScheme scheme) + : TlsHandshakeFilter(a, {kTlsHandshakeCertificateRequest}), + scheme_(scheme) {} + + virtual PacketFilter::Action FilterHandshake(const HandshakeHeader& header, + const DataBuffer& input, + DataBuffer* output) { + uint32_t offset = 0; + + if (header.handshake_type() != ssl_hs_certificate_request) { + return KEEP; + } + + *output = input; + + uint32_t types_len = 0; + if (!output->Read(offset, 1, &types_len)) { + ADD_FAILURE(); + return KEEP; + } + offset += 1 + types_len; + uint32_t scheme_len = 0; + if (!output->Read(offset, 2, &scheme_len)) { + ADD_FAILURE(); + return KEEP; + } + DataBuffer schemes; + schemes.Write(0, 2, 2); + schemes.Write(2, scheme_, 2); + output->Write(offset, 2, schemes.len()); + output->Splice(schemes, offset + 2, scheme_len); + + return CHANGE; + } + + private: + SSLSignatureScheme scheme_; +}; + +// +// Test how policy interacts with client auth connections +// + +// TLS/DTLS version algorithm policy +typedef std::tuple<SSLProtocolVariant, uint16_t, SECOidTag, PRUint32> + PolicySignatureSchemeProfile; + +// Only TLS 1.2 handles client auth schemes inside +// the certificate request packet, so our failure tests for +// those kinds of connections only occur here. +class TlsConnectAuthWithPolicyTls12 + : public TlsConnectTestBase, + public ::testing::WithParamInterface<PolicySignatureSchemeProfile> { + public: + TlsConnectAuthWithPolicyTls12() + : TlsConnectTestBase(std::get<0>(GetParam()), std::get<1>(GetParam())) { + alg_ = std::get<2>(GetParam()); + policy_ = std::get<3>(GetParam()); + // use the algorithm to select which single scheme to deploy + // We use these schemes to force servers sending schemes the client + // didn't advertise to make sure the client will still filter these + // by policy and detect that no valid schemes were presented, rather + // than sending an empty client auth message. + switch (alg_) { + case SEC_OID_SHA256: + case SEC_OID_PKCS1_RSA_PSS_SIGNATURE: + scheme_ = ssl_sig_rsa_pss_pss_sha256; + break; + case SEC_OID_PKCS1_RSA_ENCRYPTION: + scheme_ = ssl_sig_rsa_pkcs1_sha256; + break; + case SEC_OID_ANSIX962_EC_PUBLIC_KEY: + scheme_ = ssl_sig_ecdsa_secp256r1_sha256; + break; + default: + ADD_FAILURE() << "need to update algorithm table in " + "TlsConnectAuthWithPolicyTls12"; + scheme_ = ssl_sig_none; + break; + } + } + + protected: + SECOidTag alg_; + PRUint32 policy_; + SSLSignatureScheme scheme_; +}; + +// Only TLS 1.2 and greater looks at schemes extensions on client auth +class TlsConnectAuthWithPolicyTls12Plus + : public TlsConnectTestBase, + public ::testing::WithParamInterface<PolicySignatureSchemeProfile> { + public: + TlsConnectAuthWithPolicyTls12Plus() + : TlsConnectTestBase(std::get<0>(GetParam()), std::get<1>(GetParam())) { + alg_ = std::get<2>(GetParam()); + policy_ = std::get<3>(GetParam()); + } + + protected: + SECOidTag alg_; + PRUint32 policy_; +}; + +// make sure we can turn single algorithms off by policy an still connect +// this is basically testing that we are properly filtering our schemes +// by policy before communicating them to the server, and that the +// server is respecting our choices +TEST_P(TlsConnectAuthWithPolicyTls12Plus, PolicySuccessTest) { + // in TLS 1.3, RSA PKCS1 is restricted. If we are also + // restricting RSA PSS by policy, we can't use the default + // RSA certificate as the server cert, switch to ECDSA + if ((version_ >= SSL_LIBRARY_VERSION_TLS_1_3) && + (alg_ == SEC_OID_PKCS1_RSA_PSS_SIGNATURE)) { + Reset(TlsAgent::kServerEcdsa256); + } + client_->SetPolicy(alg_, 0, policy_); // Disable policy for client + client_->SetupClientAuth(); + server_->RequestClientAuth(false); + Connect(); +} + +// make sure we fail if the server ignores our policy preference and +// requests client auth with a scheme we don't support +TEST_P(TlsConnectAuthWithPolicyTls12, PolicyFailureTest) { + client_->SetPolicy(alg_, 0, policy_); + client_->SetupClientAuth(); + server_->RequestClientAuth(false); + MakeTlsFilter<Tls12CertificateRequestReplacer>(server_, scheme_); + ConnectExpectAlert(client_, kTlsAlertHandshakeFailure); + client_->CheckErrorCode(SSL_ERROR_UNSUPPORTED_SIGNATURE_ALGORITHM); + server_->CheckErrorCode(SSL_ERROR_HANDSHAKE_FAILURE_ALERT); +} + +INSTANTIATE_TEST_SUITE_P( + SignaturesWithPolicyFail, TlsConnectAuthWithPolicyTls12, + ::testing::Combine(TlsConnectTestBase::kTlsVariantsAll, + TlsConnectTestBase::kTlsV12, + ::testing::Values(SEC_OID_SHA256, + SEC_OID_PKCS1_RSA_PSS_SIGNATURE, + SEC_OID_PKCS1_RSA_ENCRYPTION, + SEC_OID_ANSIX962_EC_PUBLIC_KEY), + ::testing::Values(NSS_USE_ALG_IN_SSL_KX, + NSS_USE_ALG_IN_ANY_SIGNATURE))); + +INSTANTIATE_TEST_SUITE_P( + SignaturesWithPolicySuccess, TlsConnectAuthWithPolicyTls12Plus, + ::testing::Combine(TlsConnectTestBase::kTlsVariantsAll, + TlsConnectTestBase::kTlsV12Plus, + ::testing::Values(SEC_OID_SHA256, + SEC_OID_PKCS1_RSA_PSS_SIGNATURE, + SEC_OID_PKCS1_RSA_ENCRYPTION, + SEC_OID_ANSIX962_EC_PUBLIC_KEY), + ::testing::Values(NSS_USE_ALG_IN_SSL_KX, + NSS_USE_ALG_IN_ANY_SIGNATURE))); + +INSTANTIATE_TEST_SUITE_P( + SignatureSchemeRsa, TlsSignatureSchemeConfiguration, + ::testing::Combine( + TlsConnectTestBase::kTlsVariantsAll, TlsConnectTestBase::kTlsV12, + ::testing::Values(TlsAgent::kServerRsaSign), + ::testing::Values(ssl_auth_rsa_sign), + ::testing::Values(ssl_sig_rsa_pkcs1_sha256, ssl_sig_rsa_pkcs1_sha384, + ssl_sig_rsa_pkcs1_sha512, ssl_sig_rsa_pss_rsae_sha256, + ssl_sig_rsa_pss_rsae_sha384))); +// RSASSA-PKCS1-v1_5 is not allowed to be used in TLS 1.3 +INSTANTIATE_TEST_SUITE_P( + SignatureSchemeRsaTls13, TlsSignatureSchemeConfiguration, + ::testing::Combine(TlsConnectTestBase::kTlsVariantsAll, + TlsConnectTestBase::kTlsV13, + ::testing::Values(TlsAgent::kServerRsaSign), + ::testing::Values(ssl_auth_rsa_sign), + ::testing::Values(ssl_sig_rsa_pss_rsae_sha256, + ssl_sig_rsa_pss_rsae_sha384))); +// PSS with SHA-512 needs a bigger key to work. +INSTANTIATE_TEST_SUITE_P( + SignatureSchemeBigRsa, TlsSignatureSchemeConfiguration, + ::testing::Combine(TlsConnectTestBase::kTlsVariantsAll, + TlsConnectTestBase::kTlsV12Plus, + ::testing::Values(TlsAgent::kRsa2048), + ::testing::Values(ssl_auth_rsa_sign), + ::testing::Values(ssl_sig_rsa_pss_rsae_sha512))); +INSTANTIATE_TEST_SUITE_P( + SignatureSchemeRsaSha1, TlsSignatureSchemeConfiguration, + ::testing::Combine(TlsConnectTestBase::kTlsVariantsAll, + TlsConnectTestBase::kTlsV12, + ::testing::Values(TlsAgent::kServerRsa), + ::testing::Values(ssl_auth_rsa_sign), + ::testing::Values(ssl_sig_rsa_pkcs1_sha1))); +INSTANTIATE_TEST_SUITE_P( + SignatureSchemeEcdsaP256, TlsSignatureSchemeConfiguration, + ::testing::Combine(TlsConnectTestBase::kTlsVariantsAll, + TlsConnectTestBase::kTlsV12Plus, + ::testing::Values(TlsAgent::kServerEcdsa256), + ::testing::Values(ssl_auth_ecdsa), + ::testing::Values(ssl_sig_ecdsa_secp256r1_sha256))); +INSTANTIATE_TEST_SUITE_P( + SignatureSchemeEcdsaP384, TlsSignatureSchemeConfiguration, + ::testing::Combine(TlsConnectTestBase::kTlsVariantsAll, + TlsConnectTestBase::kTlsV12Plus, + ::testing::Values(TlsAgent::kServerEcdsa384), + ::testing::Values(ssl_auth_ecdsa), + ::testing::Values(ssl_sig_ecdsa_secp384r1_sha384))); +INSTANTIATE_TEST_SUITE_P( + SignatureSchemeEcdsaP521, TlsSignatureSchemeConfiguration, + ::testing::Combine(TlsConnectTestBase::kTlsVariantsAll, + TlsConnectTestBase::kTlsV12Plus, + ::testing::Values(TlsAgent::kServerEcdsa521), + ::testing::Values(ssl_auth_ecdsa), + ::testing::Values(ssl_sig_ecdsa_secp521r1_sha512))); +INSTANTIATE_TEST_SUITE_P( + SignatureSchemeEcdsaSha1, TlsSignatureSchemeConfiguration, + ::testing::Combine(TlsConnectTestBase::kTlsVariantsAll, + TlsConnectTestBase::kTlsV12, + ::testing::Values(TlsAgent::kServerEcdsa256, + TlsAgent::kServerEcdsa384), + ::testing::Values(ssl_auth_ecdsa), + ::testing::Values(ssl_sig_ecdsa_sha1))); +} // namespace nss_test diff --git a/security/nss/gtests/ssl_gtest/ssl_cert_ext_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_cert_ext_unittest.cc new file mode 100644 index 0000000000..26e5fb5028 --- /dev/null +++ b/security/nss/gtests/ssl_gtest/ssl_cert_ext_unittest.cc @@ -0,0 +1,246 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set ts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this file, + * You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#include "ssl.h" +#include "sslerr.h" +#include "sslproto.h" + +#include <memory> + +#include "tls_connect.h" +#include "tls_filter.h" +#include "tls_parser.h" + +namespace nss_test { + +// Tests for Certificate Transparency (RFC 6962) +// These don't work with TLS 1.3: see bug 1252745. + +// Helper class - stores signed certificate timestamps as provided +// by the relevant callbacks on the client. +class SignedCertificateTimestampsExtractor { + public: + SignedCertificateTimestampsExtractor(std::shared_ptr<TlsAgent>& client) + : client_(client) { + client->SetAuthCertificateCallback( + [this](TlsAgent* agent, bool checksig, bool isServer) -> SECStatus { + const SECItem* scts = SSL_PeerSignedCertTimestamps(agent->ssl_fd()); + EXPECT_TRUE(scts); + if (!scts) { + return SECFailure; + } + auth_timestamps_.reset(new DataBuffer(scts->data, scts->len)); + return SECSuccess; + }); + client->SetHandshakeCallback([this](TlsAgent* agent) { + const SECItem* scts = SSL_PeerSignedCertTimestamps(agent->ssl_fd()); + ASSERT_TRUE(scts); + handshake_timestamps_.reset(new DataBuffer(scts->data, scts->len)); + }); + } + + void assertTimestamps(const DataBuffer& timestamps) { + ASSERT_NE(nullptr, auth_timestamps_); + EXPECT_EQ(timestamps, *auth_timestamps_); + + ASSERT_NE(nullptr, handshake_timestamps_); + EXPECT_EQ(timestamps, *handshake_timestamps_); + + const SECItem* current = + SSL_PeerSignedCertTimestamps(client_.lock()->ssl_fd()); + EXPECT_EQ(timestamps, DataBuffer(current->data, current->len)); + } + + private: + std::weak_ptr<TlsAgent> client_; + std::unique_ptr<DataBuffer> auth_timestamps_; + std::unique_ptr<DataBuffer> handshake_timestamps_; +}; + +static const uint8_t kSctValue[] = {0x01, 0x23, 0x45, 0x67, 0x89}; +static const SECItem kSctItem = {siBuffer, const_cast<uint8_t*>(kSctValue), + sizeof(kSctValue)}; +static const DataBuffer kSctBuffer(kSctValue, sizeof(kSctValue)); +static const SSLExtraServerCertData kExtraSctData = { + ssl_auth_null, nullptr, nullptr, &kSctItem, nullptr, nullptr}; + +// Test timestamps extraction during a successful handshake. +TEST_P(TlsConnectGenericPre13, SignedCertificateTimestampsLegacy) { + EnsureTlsSetup(); + + // We have to use the legacy API consistently here for configuring certs. + // Also, this doesn't work in TLS 1.3 because this only configures the SCT for + // RSA decrypt and PKCS#1 signing, not PSS. + ScopedCERTCertificate cert; + ScopedSECKEYPrivateKey priv; + ASSERT_TRUE(TlsAgent::LoadCertificate(TlsAgent::kServerRsa, &cert, &priv)); + EXPECT_EQ(SECSuccess, SSL_ConfigSecureServerWithCertChain( + server_->ssl_fd(), cert.get(), nullptr, priv.get(), + ssl_kea_rsa)); + EXPECT_EQ(SECSuccess, SSL_SetSignedCertTimestamps(server_->ssl_fd(), + &kSctItem, ssl_kea_rsa)); + + client_->SetOption(SSL_ENABLE_SIGNED_CERT_TIMESTAMPS, PR_TRUE); + SignedCertificateTimestampsExtractor timestamps_extractor(client_); + + Connect(); + + timestamps_extractor.assertTimestamps(kSctBuffer); +} + +TEST_P(TlsConnectGeneric, SignedCertificateTimestampsSuccess) { + EnsureTlsSetup(); + EXPECT_TRUE( + server_->ConfigServerCert(TlsAgent::kServerRsa, true, &kExtraSctData)); + client_->SetOption(SSL_ENABLE_SIGNED_CERT_TIMESTAMPS, PR_TRUE); + SignedCertificateTimestampsExtractor timestamps_extractor(client_); + + Connect(); + + timestamps_extractor.assertTimestamps(kSctBuffer); +} + +// Test SSL_PeerSignedCertTimestamps returning zero-length SECItem +// when the client / the server / both have not enabled the feature. +TEST_P(TlsConnectGeneric, SignedCertificateTimestampsInactiveClient) { + EnsureTlsSetup(); + EXPECT_TRUE( + server_->ConfigServerCert(TlsAgent::kServerRsa, true, &kExtraSctData)); + SignedCertificateTimestampsExtractor timestamps_extractor(client_); + + Connect(); + timestamps_extractor.assertTimestamps(DataBuffer()); +} + +TEST_P(TlsConnectGeneric, SignedCertificateTimestampsInactiveServer) { + EnsureTlsSetup(); + client_->SetOption(SSL_ENABLE_SIGNED_CERT_TIMESTAMPS, PR_TRUE); + SignedCertificateTimestampsExtractor timestamps_extractor(client_); + + Connect(); + timestamps_extractor.assertTimestamps(DataBuffer()); +} + +TEST_P(TlsConnectGeneric, SignedCertificateTimestampsInactiveBoth) { + EnsureTlsSetup(); + SignedCertificateTimestampsExtractor timestamps_extractor(client_); + + Connect(); + timestamps_extractor.assertTimestamps(DataBuffer()); +} + +// Check that the given agent doesn't have an OCSP response for its peer. +static SECStatus CheckNoOCSP(TlsAgent* agent, bool checksig, bool isServer) { + const SECItemArray* ocsp = SSL_PeerStapledOCSPResponses(agent->ssl_fd()); + EXPECT_TRUE(ocsp); + EXPECT_EQ(0U, ocsp->len); + return SECSuccess; +} + +static const uint8_t kOcspValue1[] = {1, 2, 3, 4, 5, 6}; +static const uint8_t kOcspValue2[] = {7, 8, 9}; +static const SECItem kOcspItems[] = { + {siBuffer, const_cast<uint8_t*>(kOcspValue1), sizeof(kOcspValue1)}, + {siBuffer, const_cast<uint8_t*>(kOcspValue2), sizeof(kOcspValue2)}}; +static const SECItemArray kOcspResponses = {const_cast<SECItem*>(kOcspItems), + PR_ARRAY_SIZE(kOcspItems)}; +const static SSLExtraServerCertData kOcspExtraData = { + ssl_auth_null, nullptr, &kOcspResponses, nullptr, nullptr, nullptr}; + +TEST_P(TlsConnectGeneric, NoOcsp) { + EnsureTlsSetup(); + client_->SetAuthCertificateCallback(CheckNoOCSP); + Connect(); +} + +// The client doesn't get OCSP stapling unless it asks. +TEST_P(TlsConnectGeneric, OcspNotRequested) { + EnsureTlsSetup(); + client_->SetAuthCertificateCallback(CheckNoOCSP); + EXPECT_TRUE( + server_->ConfigServerCert(TlsAgent::kServerRsa, true, &kOcspExtraData)); + Connect(); +} + +// Even if the client asks, the server has nothing unless it is configured. +TEST_P(TlsConnectGeneric, OcspNotProvided) { + EnsureTlsSetup(); + client_->SetOption(SSL_ENABLE_OCSP_STAPLING, PR_TRUE); + client_->SetAuthCertificateCallback(CheckNoOCSP); + Connect(); +} + +TEST_P(TlsConnectGenericPre13, OcspMangled) { + EnsureTlsSetup(); + client_->SetOption(SSL_ENABLE_OCSP_STAPLING, PR_TRUE); + EXPECT_TRUE( + server_->ConfigServerCert(TlsAgent::kServerRsa, true, &kOcspExtraData)); + + static const uint8_t val[] = {1}; + auto replacer = MakeTlsFilter<TlsExtensionReplacer>( + server_, ssl_cert_status_xtn, DataBuffer(val, sizeof(val))); + ConnectExpectAlert(client_, kTlsAlertIllegalParameter); + client_->CheckErrorCode(SSL_ERROR_RX_MALFORMED_SERVER_HELLO); + server_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT); +} + +TEST_P(TlsConnectGeneric, OcspSuccess) { + EnsureTlsSetup(); + client_->SetOption(SSL_ENABLE_OCSP_STAPLING, PR_TRUE); + auto capture_ocsp = + MakeTlsFilter<TlsExtensionCapture>(server_, ssl_cert_status_xtn); + + // The value should be available during the AuthCertificateCallback + client_->SetAuthCertificateCallback([](TlsAgent* agent, bool checksig, + bool isServer) -> SECStatus { + const SECItemArray* ocsp = SSL_PeerStapledOCSPResponses(agent->ssl_fd()); + if (!ocsp) { + return SECFailure; + } + EXPECT_EQ(1U, ocsp->len) << "We only provide the first item"; + EXPECT_EQ(0, SECITEM_CompareItem(&kOcspItems[0], &ocsp->items[0])); + return SECSuccess; + }); + EXPECT_TRUE( + server_->ConfigServerCert(TlsAgent::kServerRsa, true, &kOcspExtraData)); + + Connect(); + // In TLS 1.3, the server doesn't provide a visible ServerHello extension. + // For earlier versions, the extension is just empty. + EXPECT_EQ(0U, capture_ocsp->extension().len()); +} + +TEST_P(TlsConnectGeneric, OcspHugeSuccess) { + EnsureTlsSetup(); + client_->SetOption(SSL_ENABLE_OCSP_STAPLING, PR_TRUE); + + uint8_t hugeOcspValue[16385]; + memset(hugeOcspValue, 0xa1, sizeof(hugeOcspValue)); + const SECItem hugeOcspItems[] = { + {siBuffer, const_cast<uint8_t*>(hugeOcspValue), sizeof(hugeOcspValue)}}; + const SECItemArray hugeOcspResponses = {const_cast<SECItem*>(hugeOcspItems), + PR_ARRAY_SIZE(hugeOcspItems)}; + const SSLExtraServerCertData hugeOcspExtraData = { + ssl_auth_null, nullptr, &hugeOcspResponses, nullptr, nullptr, nullptr}; + + // The value should be available during the AuthCertificateCallback + client_->SetAuthCertificateCallback([&](TlsAgent* agent, bool checksig, + bool isServer) -> SECStatus { + const SECItemArray* ocsp = SSL_PeerStapledOCSPResponses(agent->ssl_fd()); + if (!ocsp) { + return SECFailure; + } + EXPECT_EQ(1U, ocsp->len) << "We only provide the first item"; + EXPECT_EQ(0, SECITEM_CompareItem(&hugeOcspItems[0], &ocsp->items[0])); + return SECSuccess; + }); + EXPECT_TRUE(server_->ConfigServerCert(TlsAgent::kServerRsa, true, + &hugeOcspExtraData)); + + Connect(); +} + +} // namespace nss_test diff --git a/security/nss/gtests/ssl_gtest/ssl_cipherorder_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_cipherorder_unittest.cc new file mode 100644 index 0000000000..1e4f817e95 --- /dev/null +++ b/security/nss/gtests/ssl_gtest/ssl_cipherorder_unittest.cc @@ -0,0 +1,241 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set ts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this file, + * You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#include "ssl.h" +#include "sslerr.h" +#include "sslproto.h" + +#include <memory> + +#include "tls_connect.h" +#include "tls_filter.h" + +namespace nss_test { + +class TlsCipherOrderTest : public TlsConnectTestBase { + protected: + virtual void ConfigureTLS() { + EnsureTlsSetup(); + ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3); + } + + virtual SECStatus BuildTestLists(std::vector<uint16_t> &cs_initial_list, + std::vector<uint16_t> &cs_new_list) { + // This is the current CipherSuites order of enabled CipherSuites as defined + // in ssl3con.c + const PRUint16 *kCipherSuites = SSL_GetImplementedCiphers(); + + for (unsigned int i = 0; i < kNumImplementedCiphers; i++) { + PRBool pref = PR_FALSE, policy = PR_FALSE; + SECStatus rv; + rv = SSL_CipherPolicyGet(kCipherSuites[i], &policy); + if (rv != SECSuccess) { + return SECFailure; + } + rv = SSL_CipherPrefGetDefault(kCipherSuites[i], &pref); + if (rv != SECSuccess) { + return SECFailure; + } + if (pref && policy) { + cs_initial_list.push_back(kCipherSuites[i]); + } + } + + // We will test set function with the first 15 enabled ciphers. + const PRUint16 kNumCiphersToSet = 15; + for (unsigned int i = 0; i < kNumCiphersToSet; i++) { + cs_new_list.push_back(cs_initial_list[i]); + } + cs_new_list[0] = cs_initial_list[1]; + cs_new_list[1] = cs_initial_list[0]; + return SECSuccess; + } + + public: + TlsCipherOrderTest() : TlsConnectTestBase(ssl_variant_stream, 0) {} + const unsigned int kNumImplementedCiphers = SSL_GetNumImplementedCiphers(); +}; + +const PRUint16 kCSUnsupported[] = {20196, 10101}; +const PRUint16 kNumCSUnsupported = PR_ARRAY_SIZE(kCSUnsupported); +const PRUint16 kCSEmpty[] = {0}; + +// Get the active CipherSuites odered as they were compiled +TEST_F(TlsCipherOrderTest, CipherOrderGet) { + std::vector<uint16_t> initial_cs_order; + std::vector<uint16_t> new_cs_order; + SECStatus result = BuildTestLists(initial_cs_order, new_cs_order); + ASSERT_EQ(result, SECSuccess); + ConfigureTLS(); + + std::vector<uint16_t> current_cs_order(SSL_GetNumImplementedCiphers() + 1); + unsigned int current_num_active_cs = 0; + result = SSL_CipherSuiteOrderGet(client_->ssl_fd(), current_cs_order.data(), + ¤t_num_active_cs); + ASSERT_EQ(result, SECSuccess); + ASSERT_EQ(current_num_active_cs, initial_cs_order.size()); + for (unsigned int i = 0; i < initial_cs_order.size(); i++) { + EXPECT_EQ(initial_cs_order[i], current_cs_order[i]); + } + // Get the chosen CipherSuite during the Handshake without any modification. + Connect(); + SSLChannelInfo channel; + result = SSL_GetChannelInfo(client_->ssl_fd(), &channel, sizeof channel); + ASSERT_EQ(result, SECSuccess); + EXPECT_EQ(channel.cipherSuite, initial_cs_order[0]); +} + +// The "server" used for gtests honor only its ciphersuites order. +// So, we apply the new set for the server instead of client. +// This is enough to test the effect of SSL_CipherSuiteOrderSet function. +TEST_F(TlsCipherOrderTest, CipherOrderSet) { + std::vector<uint16_t> initial_cs_order; + std::vector<uint16_t> new_cs_order; + SECStatus result = BuildTestLists(initial_cs_order, new_cs_order); + ASSERT_EQ(result, SECSuccess); + ConfigureTLS(); + + // change the server_ ciphersuites order. + result = SSL_CipherSuiteOrderSet(server_->ssl_fd(), new_cs_order.data(), + new_cs_order.size()); + ASSERT_EQ(result, SECSuccess); + + // The function expect an array. We are using vector for VStudio + // compatibility. + std::vector<uint16_t> current_cs_order(SSL_GetNumImplementedCiphers() + 1); + unsigned int current_num_active_cs = 0; + result = SSL_CipherSuiteOrderGet(server_->ssl_fd(), current_cs_order.data(), + ¤t_num_active_cs); + ASSERT_EQ(result, SECSuccess); + ASSERT_EQ(current_num_active_cs, new_cs_order.size()); + for (unsigned int i = 0; i < new_cs_order.size(); i++) { + ASSERT_EQ(new_cs_order[i], current_cs_order[i]); + } + + Connect(); + SSLChannelInfo channel; + // changes in server_ order reflect in client chosen ciphersuite. + result = SSL_GetChannelInfo(client_->ssl_fd(), &channel, sizeof channel); + ASSERT_EQ(result, SECSuccess); + EXPECT_EQ(channel.cipherSuite, new_cs_order[0]); +} + +// Duplicate socket configuration from a model. +TEST_F(TlsCipherOrderTest, CipherOrderCopySocket) { + std::vector<uint16_t> initial_cs_order; + std::vector<uint16_t> new_cs_order; + SECStatus result = BuildTestLists(initial_cs_order, new_cs_order); + ASSERT_EQ(result, SECSuccess); + ConfigureTLS(); + + // Use the existing sockets for this test. + result = SSL_CipherSuiteOrderSet(client_->ssl_fd(), new_cs_order.data(), + new_cs_order.size()); + ASSERT_EQ(result, SECSuccess); + + std::vector<uint16_t> current_cs_order(SSL_GetNumImplementedCiphers() + 1); + unsigned int current_num_active_cs = 0; + result = SSL_CipherSuiteOrderGet(server_->ssl_fd(), current_cs_order.data(), + ¤t_num_active_cs); + ASSERT_EQ(result, SECSuccess); + ASSERT_EQ(current_num_active_cs, initial_cs_order.size()); + for (unsigned int i = 0; i < current_num_active_cs; i++) { + ASSERT_EQ(initial_cs_order[i], current_cs_order[i]); + } + + // Import/Duplicate configurations from client_ to server_ + PRFileDesc *rv = SSL_ImportFD(client_->ssl_fd(), server_->ssl_fd()); + EXPECT_NE(nullptr, rv); + + result = SSL_CipherSuiteOrderGet(server_->ssl_fd(), current_cs_order.data(), + ¤t_num_active_cs); + ASSERT_EQ(result, SECSuccess); + ASSERT_EQ(current_num_active_cs, new_cs_order.size()); + for (unsigned int i = 0; i < new_cs_order.size(); i++) { + EXPECT_EQ(new_cs_order.data()[i], current_cs_order[i]); + } +} + +// If the infomed num of elements is lower than the actual list size, only the +// first "informed num" elements will be considered. The rest is ignored. +TEST_F(TlsCipherOrderTest, CipherOrderSetLower) { + std::vector<uint16_t> initial_cs_order; + std::vector<uint16_t> new_cs_order; + SECStatus result = BuildTestLists(initial_cs_order, new_cs_order); + ASSERT_EQ(result, SECSuccess); + ConfigureTLS(); + + result = SSL_CipherSuiteOrderSet(client_->ssl_fd(), new_cs_order.data(), + new_cs_order.size() - 1); + ASSERT_EQ(result, SECSuccess); + + std::vector<uint16_t> current_cs_order(SSL_GetNumImplementedCiphers() + 1); + unsigned int current_num_active_cs = 0; + result = SSL_CipherSuiteOrderGet(client_->ssl_fd(), current_cs_order.data(), + ¤t_num_active_cs); + ASSERT_EQ(result, SECSuccess); + ASSERT_EQ(current_num_active_cs, new_cs_order.size() - 1); + for (unsigned int i = 0; i < new_cs_order.size() - 1; i++) { + ASSERT_EQ(new_cs_order.data()[i], current_cs_order[i]); + } +} + +// Testing Errors Controls +TEST_F(TlsCipherOrderTest, CipherOrderSetControls) { + std::vector<uint16_t> initial_cs_order; + std::vector<uint16_t> new_cs_order; + SECStatus result = BuildTestLists(initial_cs_order, new_cs_order); + ASSERT_EQ(result, SECSuccess); + ConfigureTLS(); + + // Create a new vector with diplicated entries + std::vector<uint16_t> repeated_cs_order(SSL_GetNumImplementedCiphers() + 1); + std::copy(initial_cs_order.begin(), initial_cs_order.end(), + repeated_cs_order.begin()); + repeated_cs_order[0] = repeated_cs_order[1]; + + // Repeated ciphersuites in the list + result = SSL_CipherSuiteOrderSet(client_->ssl_fd(), repeated_cs_order.data(), + initial_cs_order.size()); + EXPECT_EQ(result, SECFailure); + + // Zero size for the sent list + result = SSL_CipherSuiteOrderSet(client_->ssl_fd(), new_cs_order.data(), 0); + EXPECT_EQ(result, SECFailure); + + // Wrong size, greater than actual + result = SSL_CipherSuiteOrderSet(client_->ssl_fd(), new_cs_order.data(), + SSL_GetNumImplementedCiphers() + 1); + EXPECT_EQ(result, SECFailure); + + // Wrong ciphersuites, not implemented + result = SSL_CipherSuiteOrderSet(client_->ssl_fd(), kCSUnsupported, + kNumCSUnsupported); + EXPECT_EQ(result, SECFailure); + + // Null list + result = + SSL_CipherSuiteOrderSet(client_->ssl_fd(), nullptr, new_cs_order.size()); + EXPECT_EQ(result, SECFailure); + + // Empty list + result = + SSL_CipherSuiteOrderSet(client_->ssl_fd(), kCSEmpty, new_cs_order.size()); + EXPECT_EQ(result, SECFailure); + + // Confirm that the controls are working, as the current ciphersuites + // remained untouched + std::vector<uint16_t> current_cs_order(SSL_GetNumImplementedCiphers() + 1); + unsigned int current_num_active_cs = 0; + result = SSL_CipherSuiteOrderGet(client_->ssl_fd(), current_cs_order.data(), + ¤t_num_active_cs); + ASSERT_EQ(result, SECSuccess); + ASSERT_EQ(current_num_active_cs, initial_cs_order.size()); + for (unsigned int i = 0; i < initial_cs_order.size(); i++) { + ASSERT_EQ(initial_cs_order[i], current_cs_order[i]); + } +} +} // namespace nss_test diff --git a/security/nss/gtests/ssl_gtest/ssl_ciphersuite_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_ciphersuite_unittest.cc new file mode 100644 index 0000000000..db0618e042 --- /dev/null +++ b/security/nss/gtests/ssl_gtest/ssl_ciphersuite_unittest.cc @@ -0,0 +1,531 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set ts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this file, + * You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#include <functional> +#include <memory> +#include "secerr.h" +#include "ssl.h" +#include "sslerr.h" +#include "sslproto.h" + +extern "C" { +// This is not something that should make you happy. +#include "libssl_internals.h" +} + +#include "gtest_utils.h" +#include "tls_connect.h" +#include "tls_parser.h" + +namespace nss_test { + +// variant, version, cipher suite +typedef std::tuple<SSLProtocolVariant, uint16_t, uint16_t, SSLNamedGroup, + SSLSignatureScheme> + CipherSuiteProfile; + +class TlsCipherSuiteTestBase : public TlsConnectTestBase { + public: + TlsCipherSuiteTestBase(SSLProtocolVariant variant, uint16_t version, + uint16_t cipher_suite, SSLNamedGroup group, + SSLSignatureScheme sig_scheme) + : TlsConnectTestBase(variant, version), + cipher_suite_(cipher_suite), + group_(group), + sig_scheme_(sig_scheme), + csinfo_({0}) { + SECStatus rv = + SSL_GetCipherSuiteInfo(cipher_suite_, &csinfo_, sizeof(csinfo_)); + EXPECT_EQ(SECSuccess, rv); + if (rv == SECSuccess) { + std::cerr << "Cipher suite: " << csinfo_.cipherSuiteName << std::endl; + } + auth_type_ = csinfo_.authType; + kea_type_ = csinfo_.keaType; + } + + protected: + void EnableSingleCipher() { + EnsureTlsSetup(); + // It doesn't matter which does this, but the test is better if both do it. + client_->EnableSingleCipher(cipher_suite_); + server_->EnableSingleCipher(cipher_suite_); + + if (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) { + std::vector<SSLNamedGroup> groups = {group_}; + if (cert_group_ != ssl_grp_none) { + groups.push_back(cert_group_); + } + client_->ConfigNamedGroups(groups); + server_->ConfigNamedGroups(groups); + kea_type_ = SSLInt_GetKEAType(group_); + + client_->SetSignatureSchemes(&sig_scheme_, 1); + server_->SetSignatureSchemes(&sig_scheme_, 1); + } + } + + virtual void SetupCertificate() { + if (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) { + switch (sig_scheme_) { + case ssl_sig_rsa_pss_rsae_sha256: + std::cerr << "Signature scheme: rsa_pss_rsae_sha256" << std::endl; + Reset(TlsAgent::kServerRsaSign); + auth_type_ = ssl_auth_rsa_sign; + break; + case ssl_sig_rsa_pss_rsae_sha384: + std::cerr << "Signature scheme: rsa_pss_rsae_sha384" << std::endl; + Reset(TlsAgent::kServerRsaSign); + auth_type_ = ssl_auth_rsa_sign; + break; + case ssl_sig_rsa_pss_rsae_sha512: + // You can't fit SHA-512 PSS in a 1024-bit key. + std::cerr << "Signature scheme: rsa_pss_rsae_sha512" << std::endl; + Reset(TlsAgent::kRsa2048); + auth_type_ = ssl_auth_rsa_sign; + break; + case ssl_sig_rsa_pss_pss_sha256: + std::cerr << "Signature scheme: rsa_pss_pss_sha256" << std::endl; + Reset(TlsAgent::kServerRsaPss); + auth_type_ = ssl_auth_rsa_pss; + break; + case ssl_sig_rsa_pss_pss_sha384: + std::cerr << "Signature scheme: rsa_pss_pss_sha384" << std::endl; + Reset("rsa_pss384"); + auth_type_ = ssl_auth_rsa_pss; + break; + case ssl_sig_rsa_pss_pss_sha512: + std::cerr << "Signature scheme: rsa_pss_pss_sha512" << std::endl; + Reset("rsa_pss512"); + auth_type_ = ssl_auth_rsa_pss; + break; + case ssl_sig_ecdsa_secp256r1_sha256: + std::cerr << "Signature scheme: ecdsa_secp256r1_sha256" << std::endl; + Reset(TlsAgent::kServerEcdsa256); + auth_type_ = ssl_auth_ecdsa; + cert_group_ = ssl_grp_ec_secp256r1; + break; + case ssl_sig_ecdsa_secp384r1_sha384: + std::cerr << "Signature scheme: ecdsa_secp384r1_sha384" << std::endl; + Reset(TlsAgent::kServerEcdsa384); + auth_type_ = ssl_auth_ecdsa; + cert_group_ = ssl_grp_ec_secp384r1; + break; + default: + ADD_FAILURE() << "Unsupported signature scheme: " << sig_scheme_; + break; + } + } else { + switch (csinfo_.authType) { + case ssl_auth_rsa_sign: + Reset(TlsAgent::kServerRsaSign); + break; + case ssl_auth_rsa_decrypt: + Reset(TlsAgent::kServerRsaDecrypt); + break; + case ssl_auth_ecdsa: + Reset(TlsAgent::kServerEcdsa256); + cert_group_ = ssl_grp_ec_secp256r1; + break; + case ssl_auth_ecdh_ecdsa: + Reset(TlsAgent::kServerEcdhEcdsa); + cert_group_ = ssl_grp_ec_secp256r1; + break; + case ssl_auth_ecdh_rsa: + Reset(TlsAgent::kServerEcdhRsa); + break; + case ssl_auth_dsa: + Reset(TlsAgent::kServerDsa); + break; + default: + ASSERT_TRUE(false) << "Unsupported cipher suite: " << cipher_suite_; + break; + } + } + } + + void ConnectAndCheckCipherSuite() { + Connect(); + SendReceive(); + + // Check that we used the right cipher suite, auth type and kea type. + uint16_t actual = TLS_NULL_WITH_NULL_NULL; + EXPECT_TRUE(client_->cipher_suite(&actual)); + EXPECT_EQ(cipher_suite_, actual); + EXPECT_TRUE(server_->cipher_suite(&actual)); + EXPECT_EQ(cipher_suite_, actual); + SSLAuthType auth = ssl_auth_size; + EXPECT_TRUE(client_->auth_type(&auth)); + EXPECT_EQ(auth_type_, auth); + EXPECT_TRUE(server_->auth_type(&auth)); + EXPECT_EQ(auth_type_, auth); + SSLKEAType kea = ssl_kea_size; + EXPECT_TRUE(client_->kea_type(&kea)); + EXPECT_EQ(kea_type_, kea); + EXPECT_TRUE(server_->kea_type(&kea)); + EXPECT_EQ(kea_type_, kea); + } + + // Get the expected limit on the number of records that can be sent for the + // cipher suite. + uint64_t record_limit() const { + switch (csinfo_.symCipher) { + case ssl_calg_rc4: + case ssl_calg_3des: + return 1ULL << 20; + case ssl_calg_aes: + case ssl_calg_aes_gcm: + return 0x5aULL << 28; + case ssl_calg_null: + case ssl_calg_chacha20: + return (1ULL << 48) - 1; + case ssl_calg_rc2: + case ssl_calg_des: + case ssl_calg_idea: + case ssl_calg_fortezza: + case ssl_calg_camellia: + case ssl_calg_seed: + break; + } + ADD_FAILURE() << "No limit for " << csinfo_.cipherSuiteName; + return 0; + } + + uint64_t last_safe_write() const { + uint64_t limit = record_limit() - 1; + if (version_ < SSL_LIBRARY_VERSION_TLS_1_1 && + (csinfo_.symCipher == ssl_calg_3des || + csinfo_.symCipher == ssl_calg_aes)) { + // 1/n-1 record splitting needs space for two records. + limit--; + } + return limit; + } + + protected: + uint16_t cipher_suite_; + SSLAuthType auth_type_; + SSLKEAType kea_type_; + SSLNamedGroup group_; + SSLNamedGroup cert_group_ = ssl_grp_none; + SSLSignatureScheme sig_scheme_; + SSLCipherSuiteInfo csinfo_; +}; + +class TlsCipherSuiteTest + : public TlsCipherSuiteTestBase, + public ::testing::WithParamInterface<CipherSuiteProfile> { + public: + TlsCipherSuiteTest() + : TlsCipherSuiteTestBase(std::get<0>(GetParam()), std::get<1>(GetParam()), + std::get<2>(GetParam()), std::get<3>(GetParam()), + std::get<4>(GetParam())) {} + + protected: + bool SkipIfCipherSuiteIsDSA() { + bool isDSA = csinfo_.authType == ssl_auth_dsa; + if (isDSA) { + std::cerr << "Skipping DSA suite: " << csinfo_.cipherSuiteName + << std::endl; + } + return isDSA; + } +}; + +TEST_P(TlsCipherSuiteTest, SingleCipherSuite) { + SetupCertificate(); + EnableSingleCipher(); + ConnectAndCheckCipherSuite(); +} + +TEST_P(TlsCipherSuiteTest, ResumeCipherSuite) { + if (SkipIfCipherSuiteIsDSA()) { + GTEST_SKIP() << "Tickets not supported with DSA (bug 1174677)."; + } + + SetupCertificate(); // This is only needed once. + + ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET); + EnableSingleCipher(); + + ConnectAndCheckCipherSuite(); + + Reset(); + ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET); + EnableSingleCipher(); + ExpectResumption(RESUME_TICKET); + ConnectAndCheckCipherSuite(); +} + +TEST_P(TlsCipherSuiteTest, ReadLimit) { + SetupCertificate(); + EnableSingleCipher(); + TlsSendCipherSpecCapturer capturer(client_); + ConnectAndCheckCipherSuite(); + if (version_ < SSL_LIBRARY_VERSION_TLS_1_3) { + uint64_t last = last_safe_write(); + EXPECT_EQ(SECSuccess, SSLInt_AdvanceWriteSeqNum(client_->ssl_fd(), last)); + EXPECT_EQ(SECSuccess, SSLInt_AdvanceReadSeqNum(server_->ssl_fd(), last)); + + client_->SendData(10, 10); + server_->ReadBytes(); // This should be OK. + server_->ReadBytes(); // Read twice to flush any 1,N-1 record splitting. + } else { + // In TLS 1.3, reading or writing triggers a KeyUpdate. That would mean + // that the sequence numbers would reset and we wouldn't hit the limit. So + // move the sequence number to the limit directly and don't test sending and + // receiving just before the limit. + uint64_t last = record_limit(); + EXPECT_EQ(SECSuccess, SSLInt_AdvanceReadSeqNum(server_->ssl_fd(), last)); + } + + // The payload needs to be big enough to pass for encrypted. The code checks + // the limit before it tries to decrypt. + static const uint8_t payload[32] = {6}; + DataBuffer record; + uint64_t epoch; + if (variant_ == ssl_variant_datagram) { + if (version_ == SSL_LIBRARY_VERSION_TLS_1_3) { + epoch = 3; // Application traffic keys. + } else { + epoch = 1; + } + } else { + epoch = 0; + } + + uint64_t seqno = (epoch << 48) | record_limit(); + + // DTLS 1.3 masks the sequence number + if (variant_ == ssl_variant_datagram && + version_ >= SSL_LIBRARY_VERSION_TLS_1_3) { + auto spec = capturer.spec(1); + ASSERT_NE(nullptr, spec.get()); + ASSERT_EQ(3, spec->epoch()); + + DataBuffer pt, ct; + uint8_t dtls13_ctype = kCtDtlsCiphertext | kCtDtlsCiphertext16bSeqno | + kCtDtlsCiphertextLengthPresent; + TlsRecordHeader hdr(variant_, version_, dtls13_ctype, seqno); + pt.Assign(payload, sizeof(payload)); + TlsRecordHeader out_hdr; + spec->Protect(hdr, pt, &ct, &out_hdr); + + auto rv = out_hdr.Write(&record, 0, ct); + EXPECT_EQ(out_hdr.header_length() + ct.len(), rv); + } else { + TlsAgentTestBase::MakeRecord(variant_, ssl_ct_application_data, version_, + payload, sizeof(payload), &record, seqno); + } + + client_->SendDirect(record); + server_->ExpectReadWriteError(); + server_->ReadBytes(); + EXPECT_EQ(SSL_ERROR_TOO_MANY_RECORDS, server_->error_code()); +} + +TEST_P(TlsCipherSuiteTest, WriteLimit) { + // This asserts in TLS 1.3 because we expect an automatic update. + if (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) { + GTEST_SKIP(); + } + SetupCertificate(); + EnableSingleCipher(); + ConnectAndCheckCipherSuite(); + EXPECT_EQ(SECSuccess, + SSLInt_AdvanceWriteSeqNum(client_->ssl_fd(), last_safe_write())); + client_->SendData(10, 10); + client_->ExpectReadWriteError(); + client_->SendData(10, 10); + EXPECT_EQ(SSL_ERROR_TOO_MANY_RECORDS, client_->error_code()); +} + +// This awful macro makes the test instantiations easier to read. +#define INSTANTIATE_CIPHER_TEST_P(name, modes, versions, groups, sigalgs, ...) \ + static const uint16_t k##name##CiphersArr[] = {__VA_ARGS__}; \ + static const ::testing::internal::ParamGenerator<uint16_t> \ + k##name##Ciphers = ::testing::ValuesIn(k##name##CiphersArr); \ + INSTANTIATE_TEST_SUITE_P( \ + CipherSuite##name, TlsCipherSuiteTest, \ + ::testing::Combine(TlsConnectTestBase::kTlsVariants##modes, \ + TlsConnectTestBase::kTls##versions, k##name##Ciphers, \ + groups, sigalgs)); + +static const auto kDummyNamedGroupParams = ::testing::Values(ssl_grp_none); +static const auto kDummySignatureSchemesParams = + ::testing::Values(ssl_sig_none); + +static SSLSignatureScheme kSignatureSchemesParamsArr[] = { + ssl_sig_rsa_pkcs1_sha256, ssl_sig_rsa_pkcs1_sha384, + ssl_sig_rsa_pkcs1_sha512, ssl_sig_ecdsa_secp256r1_sha256, + ssl_sig_ecdsa_secp384r1_sha384, ssl_sig_rsa_pss_rsae_sha256, + ssl_sig_rsa_pss_rsae_sha384, ssl_sig_rsa_pss_rsae_sha512, + ssl_sig_rsa_pss_pss_sha256, ssl_sig_rsa_pss_pss_sha384, + ssl_sig_rsa_pss_pss_sha512}; + +static SSLSignatureScheme kSignatureSchemesParamsArrTls13[] = { + ssl_sig_ecdsa_secp256r1_sha256, ssl_sig_ecdsa_secp384r1_sha384, + ssl_sig_rsa_pss_rsae_sha256, ssl_sig_rsa_pss_rsae_sha384, + ssl_sig_rsa_pss_rsae_sha512, ssl_sig_rsa_pss_pss_sha256, + ssl_sig_rsa_pss_pss_sha384, ssl_sig_rsa_pss_pss_sha512}; + +INSTANTIATE_CIPHER_TEST_P(RC4, Stream, V10ToV12, kDummyNamedGroupParams, + kDummySignatureSchemesParams, + TLS_RSA_WITH_RC4_128_SHA, + TLS_ECDH_ECDSA_WITH_RC4_128_SHA, + TLS_ECDHE_ECDSA_WITH_RC4_128_SHA, + TLS_ECDH_RSA_WITH_RC4_128_SHA, + TLS_ECDHE_RSA_WITH_RC4_128_SHA); +INSTANTIATE_CIPHER_TEST_P(AEAD12, All, V12, kDummyNamedGroupParams, + kDummySignatureSchemesParams, + TLS_RSA_WITH_AES_128_GCM_SHA256, + TLS_RSA_WITH_AES_256_GCM_SHA384, + TLS_DHE_DSS_WITH_AES_128_GCM_SHA256, + TLS_DHE_DSS_WITH_AES_256_GCM_SHA384, + TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA384, + TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA384); +INSTANTIATE_CIPHER_TEST_P(AEAD, All, V12, kDummyNamedGroupParams, + kDummySignatureSchemesParams, + TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, + TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, + TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, + TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, + TLS_DHE_RSA_WITH_AES_128_GCM_SHA256, + TLS_DHE_RSA_WITH_AES_256_GCM_SHA384, + TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256, + TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256, + TLS_DHE_RSA_WITH_CHACHA20_POLY1305_SHA256); +INSTANTIATE_CIPHER_TEST_P( + CBC12, All, V12, kDummyNamedGroupParams, kDummySignatureSchemesParams, + TLS_DHE_RSA_WITH_AES_256_CBC_SHA256, TLS_RSA_WITH_AES_256_CBC_SHA256, + TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256, + TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256, TLS_DHE_RSA_WITH_AES_128_CBC_SHA256, + TLS_RSA_WITH_AES_128_CBC_SHA256, TLS_DHE_DSS_WITH_AES_128_CBC_SHA256, + TLS_DHE_DSS_WITH_AES_256_CBC_SHA256); +INSTANTIATE_CIPHER_TEST_P( + CBCStream, Stream, V10ToV12, kDummyNamedGroupParams, + kDummySignatureSchemesParams, TLS_ECDH_ECDSA_WITH_NULL_SHA, + TLS_ECDH_ECDSA_WITH_3DES_EDE_CBC_SHA, TLS_ECDH_ECDSA_WITH_AES_128_CBC_SHA, + TLS_ECDH_ECDSA_WITH_AES_256_CBC_SHA, TLS_ECDHE_ECDSA_WITH_NULL_SHA, + TLS_ECDHE_ECDSA_WITH_3DES_EDE_CBC_SHA, TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA, + TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA, TLS_ECDH_RSA_WITH_NULL_SHA, + TLS_ECDH_RSA_WITH_3DES_EDE_CBC_SHA, TLS_ECDH_RSA_WITH_AES_128_CBC_SHA, + TLS_ECDH_RSA_WITH_AES_256_CBC_SHA, TLS_ECDHE_RSA_WITH_NULL_SHA, + TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA, TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA, + TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA); +INSTANTIATE_CIPHER_TEST_P( + CBCDatagram, Datagram, V11V12, kDummyNamedGroupParams, + kDummySignatureSchemesParams, TLS_ECDH_ECDSA_WITH_3DES_EDE_CBC_SHA, + TLS_ECDH_ECDSA_WITH_AES_128_CBC_SHA, TLS_ECDH_ECDSA_WITH_AES_256_CBC_SHA, + TLS_ECDHE_ECDSA_WITH_3DES_EDE_CBC_SHA, TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA, + TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA, TLS_ECDH_RSA_WITH_3DES_EDE_CBC_SHA, + TLS_ECDH_RSA_WITH_AES_128_CBC_SHA, TLS_ECDH_RSA_WITH_AES_256_CBC_SHA, + TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA, TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA, + TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA); +INSTANTIATE_CIPHER_TEST_P( + TLS12SigSchemes, All, V12, ::testing::ValuesIn(kFasterDHEGroups), + ::testing::ValuesIn(kSignatureSchemesParamsArr), + TLS_DHE_RSA_WITH_AES_256_CBC_SHA256, TLS_RSA_WITH_AES_256_CBC_SHA256, + TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256, + TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256, TLS_DHE_RSA_WITH_AES_128_CBC_SHA256, + TLS_RSA_WITH_AES_128_CBC_SHA256, TLS_DHE_DSS_WITH_AES_128_CBC_SHA256, + TLS_DHE_DSS_WITH_AES_256_CBC_SHA256); +#ifndef NSS_DISABLE_TLS_1_3 +INSTANTIATE_CIPHER_TEST_P(TLS13, All, V13, + ::testing::ValuesIn(kFasterDHEGroups), + ::testing::ValuesIn(kSignatureSchemesParamsArrTls13), + TLS_AES_128_GCM_SHA256, TLS_CHACHA20_POLY1305_SHA256, + TLS_AES_256_GCM_SHA384); +INSTANTIATE_CIPHER_TEST_P(TLS13AllGroups, All, V13, + ::testing::ValuesIn(kAllDHEGroups), + ::testing::Values(ssl_sig_ecdsa_secp384r1_sha384), + TLS_AES_256_GCM_SHA384); +#endif + +// Fields are: version, cipher suite, bulk cipher name, secretKeySize +struct SecStatusParams { + uint16_t version; + uint16_t cipher_suite; + std::string name; + int keySize; +}; + +inline std::ostream &operator<<(std::ostream &stream, + const SecStatusParams &vals) { + SSLCipherSuiteInfo csinfo; + SECStatus rv = + SSL_GetCipherSuiteInfo(vals.cipher_suite, &csinfo, sizeof(csinfo)); + if (rv != SECSuccess) { + return stream << "Error invoking SSL_GetCipherSuiteInfo()"; + } + + return stream << "TLS " << VersionString(vals.version) << ", " + << csinfo.cipherSuiteName << ", name = \"" << vals.name + << "\", key size = " << vals.keySize; +} + +class SecurityStatusTest + : public TlsCipherSuiteTestBase, + public ::testing::WithParamInterface<SecStatusParams> { + public: + SecurityStatusTest() + : TlsCipherSuiteTestBase(ssl_variant_stream, GetParam().version, + GetParam().cipher_suite, ssl_grp_none, + ssl_sig_none) {} +}; + +// SSL_SecurityStatus produces fairly useless output when compared to +// SSL_GetCipherSuiteInfo and SSL_GetChannelInfo, but we can't break it, so we +// need to check it. +TEST_P(SecurityStatusTest, CheckSecurityStatus) { + SetupCertificate(); + EnableSingleCipher(); + ConnectAndCheckCipherSuite(); + + int on; + char *cipher; + int keySize; + int secretKeySize; + char *issuer; + char *subject; + EXPECT_EQ(SECSuccess, + SSL_SecurityStatus(client_->ssl_fd(), &on, &cipher, &keySize, + &secretKeySize, &issuer, &subject)); + if (std::string(cipher) == "NULL") { + EXPECT_EQ(0, on); + } else { + EXPECT_NE(0, on); + } + EXPECT_EQ(GetParam().name, std::string(cipher)); + // All the ciphers we support have secret key size == key size. + EXPECT_EQ(GetParam().keySize, keySize); + EXPECT_EQ(GetParam().keySize, secretKeySize); + EXPECT_LT(0U, strlen(issuer)); + EXPECT_LT(0U, strlen(subject)); + + PORT_Free(cipher); + PORT_Free(issuer); + PORT_Free(subject); +} + +static const SecStatusParams kSecStatusTestValuesArr[] = { + {SSL_LIBRARY_VERSION_TLS_1_0, TLS_ECDHE_RSA_WITH_NULL_SHA, "NULL", 0}, + {SSL_LIBRARY_VERSION_TLS_1_0, TLS_RSA_WITH_RC4_128_SHA, "RC4", 128}, + {SSL_LIBRARY_VERSION_TLS_1_0, TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA, + "3DES-EDE-CBC", 168}, + {SSL_LIBRARY_VERSION_TLS_1_0, TLS_RSA_WITH_AES_128_CBC_SHA, "AES-128", 128}, + {SSL_LIBRARY_VERSION_TLS_1_2, TLS_RSA_WITH_AES_256_CBC_SHA256, "AES-256", + 256}, + {SSL_LIBRARY_VERSION_TLS_1_2, TLS_RSA_WITH_AES_128_GCM_SHA256, + "AES-128-GCM", 128}, + {SSL_LIBRARY_VERSION_TLS_1_2, TLS_RSA_WITH_AES_256_GCM_SHA384, + "AES-256-GCM", 256}, + {SSL_LIBRARY_VERSION_TLS_1_2, TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256, + "ChaCha20-Poly1305", 256}}; +INSTANTIATE_TEST_SUITE_P(TestSecurityStatus, SecurityStatusTest, + ::testing::ValuesIn(kSecStatusTestValuesArr)); + +} // namespace nss_test diff --git a/security/nss/gtests/ssl_gtest/ssl_custext_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_custext_unittest.cc new file mode 100644 index 0000000000..7ed0e5d934 --- /dev/null +++ b/security/nss/gtests/ssl_gtest/ssl_custext_unittest.cc @@ -0,0 +1,499 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set ts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this file, + * You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#include "ssl.h" +#include "ssl3prot.h" +#include "sslerr.h" +#include "sslproto.h" +#include "sslexp.h" + +#include <memory> + +#include "tls_connect.h" + +namespace nss_test { + +static void IncrementCounterArg(void *arg) { + if (arg) { + auto *called = reinterpret_cast<size_t *>(arg); + ++*called; + } +} + +static PRBool NoopExtensionWriter(PRFileDesc *fd, SSLHandshakeType message, + PRUint8 *data, unsigned int *len, + unsigned int maxLen, void *arg) { + IncrementCounterArg(arg); + return PR_FALSE; +} + +static PRBool EmptyExtensionWriter(PRFileDesc *fd, SSLHandshakeType message, + PRUint8 *data, unsigned int *len, + unsigned int maxLen, void *arg) { + IncrementCounterArg(arg); + return PR_TRUE; +} + +static SECStatus NoopExtensionHandler(PRFileDesc *fd, SSLHandshakeType message, + const PRUint8 *data, unsigned int len, + SSLAlertDescription *alert, void *arg) { + return SECSuccess; +} + +// All of the (current) set of supported extensions, plus a few extra. +static const uint16_t kManyExtensions[] = { + ssl_server_name_xtn, + ssl_cert_status_xtn, + ssl_supported_groups_xtn, + ssl_ec_point_formats_xtn, + ssl_signature_algorithms_xtn, + ssl_signature_algorithms_cert_xtn, + ssl_use_srtp_xtn, + ssl_app_layer_protocol_xtn, + ssl_signed_cert_timestamp_xtn, + ssl_padding_xtn, + ssl_extended_master_secret_xtn, + ssl_session_ticket_xtn, + ssl_tls13_key_share_xtn, + ssl_tls13_pre_shared_key_xtn, + ssl_tls13_early_data_xtn, + ssl_tls13_supported_versions_xtn, + ssl_tls13_cookie_xtn, + ssl_tls13_psk_key_exchange_modes_xtn, + ssl_tls13_ticket_early_data_info_xtn, + ssl_tls13_certificate_authorities_xtn, + ssl_next_proto_nego_xtn, + ssl_renegotiation_info_xtn, + ssl_record_size_limit_xtn, + ssl_tls13_encrypted_client_hello_xtn, + 1, + 0xffff}; +// The list here includes all extensions we expect to use (SSL_MAX_EXTENSIONS), +// plus the deprecated values (see sslt.h), and two extra dummy values. +PR_STATIC_ASSERT((SSL_MAX_EXTENSIONS + 5) == PR_ARRAY_SIZE(kManyExtensions)); + +void InstallManyWriters(std::shared_ptr<TlsAgent> agent, + SSLExtensionWriter writer, size_t *installed = nullptr, + size_t *called = nullptr) { + for (size_t i = 0; i < PR_ARRAY_SIZE(kManyExtensions); ++i) { + SSLExtensionSupport support = ssl_ext_none; + SECStatus rv = SSL_GetExtensionSupport(kManyExtensions[i], &support); + ASSERT_EQ(SECSuccess, rv) << "SSL_GetExtensionSupport cannot fail"; + + rv = SSL_InstallExtensionHooks(agent->ssl_fd(), kManyExtensions[i], writer, + called, NoopExtensionHandler, nullptr); + if (support == ssl_ext_native_only) { + EXPECT_EQ(SECFailure, rv); + EXPECT_EQ(SEC_ERROR_INVALID_ARGS, PORT_GetError()); + } else { + if (installed) { + ++*installed; + } + EXPECT_EQ(SECSuccess, rv); + } + } +} + +TEST_F(TlsConnectStreamTls13, CustomExtensionAllNoopClient) { + EnsureTlsSetup(); + size_t installed = 0; + size_t called = 0; + InstallManyWriters(client_, NoopExtensionWriter, &installed, &called); + EXPECT_LT(0U, installed); + Connect(); + EXPECT_EQ(installed, called); +} + +TEST_F(TlsConnectStreamTls13, CustomExtensionAllNoopServer) { + EnsureTlsSetup(); + size_t installed = 0; + size_t called = 0; + InstallManyWriters(server_, NoopExtensionWriter, &installed, &called); + EXPECT_LT(0U, installed); + Connect(); + // Extension writers are all called for each of ServerHello, + // EncryptedExtensions, and Certificate. + EXPECT_EQ(installed * 3, called); +} + +TEST_F(TlsConnectStreamTls13, CustomExtensionEmptyWriterClient) { + EnsureTlsSetup(); + InstallManyWriters(client_, EmptyExtensionWriter); + InstallManyWriters(server_, EmptyExtensionWriter); + Connect(); +} + +TEST_F(TlsConnectStreamTls13, CustomExtensionEmptyWriterServer) { + EnsureTlsSetup(); + InstallManyWriters(server_, EmptyExtensionWriter); + // Sending extensions that the client doesn't expect leads to extensions + // appearing even if the client didn't send one, or in the wrong messages. + client_->ExpectSendAlert(kTlsAlertUnsupportedExtension); + server_->ExpectSendAlert(kTlsAlertUnexpectedMessage); + ConnectExpectFail(); +} + +// Install an writer to disable sending of a natively-supported extension. +TEST_F(TlsConnectStreamTls13, CustomExtensionWriterDisable) { + EnsureTlsSetup(); + + // This option enables sending the extension via the native support. + SECStatus rv = SSL_OptionSet(client_->ssl_fd(), + SSL_ENABLE_SIGNED_CERT_TIMESTAMPS, PR_TRUE); + EXPECT_EQ(SECSuccess, rv); + + // This installs an override that doesn't do anything. You have to specify + // something; passing all nullptr values removes an existing handler. + rv = SSL_InstallExtensionHooks( + client_->ssl_fd(), ssl_signed_cert_timestamp_xtn, NoopExtensionWriter, + nullptr, NoopExtensionHandler, nullptr); + EXPECT_EQ(SECSuccess, rv); + auto capture = MakeTlsFilter<TlsExtensionCapture>( + client_, ssl_signed_cert_timestamp_xtn); + + Connect(); + // So nothing will be sent. + EXPECT_FALSE(capture->captured()); +} + +// An extension that is unlikely to be parsed as valid. +static uint8_t kNonsenseExtension[] = {91, 82, 73, 64, 55, 46, 37, 28, 19}; + +static PRBool NonsenseExtensionWriter(PRFileDesc *fd, SSLHandshakeType message, + PRUint8 *data, unsigned int *len, + unsigned int maxLen, void *arg) { + TlsAgent *agent = reinterpret_cast<TlsAgent *>(arg); + EXPECT_NE(nullptr, agent); + EXPECT_NE(nullptr, data); + EXPECT_NE(nullptr, len); + EXPECT_EQ(0U, *len); + EXPECT_LT(0U, maxLen); + EXPECT_EQ(agent->ssl_fd(), fd); + + if (message != ssl_hs_client_hello && message != ssl_hs_server_hello && + message != ssl_hs_encrypted_extensions) { + return PR_FALSE; + } + + *len = static_cast<unsigned int>(sizeof(kNonsenseExtension)); + EXPECT_GE(maxLen, *len); + if (maxLen < *len) { + return PR_FALSE; + } + PORT_Memcpy(data, kNonsenseExtension, *len); + return PR_TRUE; +} + +// Override the extension handler for an natively-supported and produce +// nonsense, which results in a handshake failure. +TEST_F(TlsConnectStreamTls13, CustomExtensionOverride) { + EnsureTlsSetup(); + + // This option enables sending the extension via the native support. + SECStatus rv = SSL_OptionSet(client_->ssl_fd(), + SSL_ENABLE_SIGNED_CERT_TIMESTAMPS, PR_TRUE); + EXPECT_EQ(SECSuccess, rv); + + // This installs an override that sends nonsense. + rv = SSL_InstallExtensionHooks( + client_->ssl_fd(), ssl_signed_cert_timestamp_xtn, NonsenseExtensionWriter, + client_.get(), NoopExtensionHandler, nullptr); + EXPECT_EQ(SECSuccess, rv); + + // Capture it to see what we got. + auto capture = MakeTlsFilter<TlsExtensionCapture>( + client_, ssl_signed_cert_timestamp_xtn); + + ConnectExpectAlert(server_, kTlsAlertDecodeError); + + EXPECT_TRUE(capture->captured()); + EXPECT_EQ(DataBuffer(kNonsenseExtension, sizeof(kNonsenseExtension)), + capture->extension()); +} + +static SECStatus NonsenseExtensionHandler(PRFileDesc *fd, + SSLHandshakeType message, + const PRUint8 *data, unsigned int len, + SSLAlertDescription *alert, + void *arg) { + TlsAgent *agent = reinterpret_cast<TlsAgent *>(arg); + EXPECT_EQ(agent->ssl_fd(), fd); + if (agent->role() == TlsAgent::SERVER) { + EXPECT_EQ(ssl_hs_client_hello, message); + } else { + EXPECT_TRUE(message == ssl_hs_server_hello || + message == ssl_hs_encrypted_extensions); + } + EXPECT_EQ(DataBuffer(kNonsenseExtension, sizeof(kNonsenseExtension)), + DataBuffer(data, len)); + EXPECT_NE(nullptr, alert); + return SECSuccess; +} + +// Send nonsense in an extension from client to server. +TEST_F(TlsConnectStreamTls13, CustomExtensionClientToServer) { + EnsureTlsSetup(); + + // This installs an override that sends nonsense. + const uint16_t extension_code = 0xffe5; + SECStatus rv = SSL_InstallExtensionHooks( + client_->ssl_fd(), extension_code, NonsenseExtensionWriter, client_.get(), + NoopExtensionHandler, nullptr); + EXPECT_EQ(SECSuccess, rv); + + // Capture it to see what we got. + auto capture = MakeTlsFilter<TlsExtensionCapture>(client_, extension_code); + + // Handle it so that the handshake completes. + rv = SSL_InstallExtensionHooks(server_->ssl_fd(), extension_code, + NoopExtensionWriter, nullptr, + NonsenseExtensionHandler, server_.get()); + EXPECT_EQ(SECSuccess, rv); + + Connect(); + + EXPECT_TRUE(capture->captured()); + EXPECT_EQ(DataBuffer(kNonsenseExtension, sizeof(kNonsenseExtension)), + capture->extension()); +} + +static PRBool NonsenseExtensionWriterSH(PRFileDesc *fd, + SSLHandshakeType message, PRUint8 *data, + unsigned int *len, unsigned int maxLen, + void *arg) { + if (message == ssl_hs_server_hello) { + return NonsenseExtensionWriter(fd, message, data, len, maxLen, arg); + } + return PR_FALSE; +} + +// Send nonsense in an extension from server to client, in ServerHello. +TEST_F(TlsConnectStreamTls13, CustomExtensionServerToClientSH) { + EnsureTlsSetup(); + + // This installs an override that sends nothing but expects nonsense. + const uint16_t extension_code = 0xff5e; + SECStatus rv = SSL_InstallExtensionHooks( + client_->ssl_fd(), extension_code, EmptyExtensionWriter, nullptr, + NonsenseExtensionHandler, client_.get()); + EXPECT_EQ(SECSuccess, rv); + + // Have the server send nonsense. + rv = SSL_InstallExtensionHooks(server_->ssl_fd(), extension_code, + NonsenseExtensionWriterSH, server_.get(), + NoopExtensionHandler, nullptr); + EXPECT_EQ(SECSuccess, rv); + + // Capture the extension from the ServerHello only and check it. + auto capture = MakeTlsFilter<TlsExtensionCapture>(server_, extension_code); + capture->SetHandshakeTypes({kTlsHandshakeServerHello}); + + Connect(); + + EXPECT_TRUE(capture->captured()); + EXPECT_EQ(DataBuffer(kNonsenseExtension, sizeof(kNonsenseExtension)), + capture->extension()); +} + +static PRBool NonsenseExtensionWriterEE(PRFileDesc *fd, + SSLHandshakeType message, PRUint8 *data, + unsigned int *len, unsigned int maxLen, + void *arg) { + if (message == ssl_hs_encrypted_extensions) { + return NonsenseExtensionWriter(fd, message, data, len, maxLen, arg); + } + return PR_FALSE; +} + +// Send nonsense in an extension from server to client, in EncryptedExtensions. +TEST_F(TlsConnectStreamTls13, CustomExtensionServerToClientEE) { + EnsureTlsSetup(); + + // This installs an override that sends nothing but expects nonsense. + const uint16_t extension_code = 0xff5e; + SECStatus rv = SSL_InstallExtensionHooks( + client_->ssl_fd(), extension_code, EmptyExtensionWriter, nullptr, + NonsenseExtensionHandler, client_.get()); + EXPECT_EQ(SECSuccess, rv); + + // Have the server send nonsense. + rv = SSL_InstallExtensionHooks(server_->ssl_fd(), extension_code, + NonsenseExtensionWriterEE, server_.get(), + NoopExtensionHandler, nullptr); + EXPECT_EQ(SECSuccess, rv); + + // Capture the extension from the EncryptedExtensions only and check it. + auto capture = MakeTlsFilter<TlsExtensionCapture>(server_, extension_code); + capture->SetHandshakeTypes({kTlsHandshakeEncryptedExtensions}); + capture->EnableDecryption(); + + Connect(); + + EXPECT_TRUE(capture->captured()); + EXPECT_EQ(DataBuffer(kNonsenseExtension, sizeof(kNonsenseExtension)), + capture->extension()); +} + +TEST_F(TlsConnectStreamTls13, CustomExtensionUnsolicitedServer) { + EnsureTlsSetup(); + + const uint16_t extension_code = 0xff5e; + SECStatus rv = SSL_InstallExtensionHooks( + server_->ssl_fd(), extension_code, NonsenseExtensionWriter, server_.get(), + NoopExtensionHandler, nullptr); + EXPECT_EQ(SECSuccess, rv); + + // Capture it to see what we got. + auto capture = MakeTlsFilter<TlsExtensionCapture>(server_, extension_code); + + client_->ExpectSendAlert(kTlsAlertUnsupportedExtension); + server_->ExpectSendAlert(kTlsAlertUnexpectedMessage); + ConnectExpectFail(); + + EXPECT_TRUE(capture->captured()); + EXPECT_EQ(DataBuffer(kNonsenseExtension, sizeof(kNonsenseExtension)), + capture->extension()); +} + +SECStatus RejectExtensionHandler(PRFileDesc *fd, SSLHandshakeType message, + const PRUint8 *data, unsigned int len, + SSLAlertDescription *alert, void *arg) { + return SECFailure; +} + +TEST_F(TlsConnectStreamTls13, CustomExtensionServerReject) { + EnsureTlsSetup(); + + // This installs an override that sends nonsense. + const uint16_t extension_code = 0xffe7; + SECStatus rv = SSL_InstallExtensionHooks(client_->ssl_fd(), extension_code, + EmptyExtensionWriter, nullptr, + NoopExtensionHandler, nullptr); + EXPECT_EQ(SECSuccess, rv); + + // Reject the extension for no good reason. + rv = SSL_InstallExtensionHooks(server_->ssl_fd(), extension_code, + NoopExtensionWriter, nullptr, + RejectExtensionHandler, nullptr); + EXPECT_EQ(SECSuccess, rv); + + ConnectExpectAlert(server_, kTlsAlertHandshakeFailure); +} + +// Send nonsense in an extension from client to server. +TEST_F(TlsConnectStreamTls13, CustomExtensionClientReject) { + EnsureTlsSetup(); + + // This installs an override that sends nothing but expects nonsense. + const uint16_t extension_code = 0xff58; + SECStatus rv = SSL_InstallExtensionHooks(client_->ssl_fd(), extension_code, + EmptyExtensionWriter, nullptr, + RejectExtensionHandler, nullptr); + EXPECT_EQ(SECSuccess, rv); + + // Have the server send nonsense. + rv = SSL_InstallExtensionHooks(server_->ssl_fd(), extension_code, + EmptyExtensionWriter, nullptr, + NoopExtensionHandler, nullptr); + EXPECT_EQ(SECSuccess, rv); + + client_->ExpectSendAlert(kTlsAlertHandshakeFailure); + server_->ExpectSendAlert(kTlsAlertUnexpectedMessage); + ConnectExpectFail(); +} + +static const uint8_t kCustomAlert = 0xf6; + +SECStatus AlertExtensionHandler(PRFileDesc *fd, SSLHandshakeType message, + const PRUint8 *data, unsigned int len, + SSLAlertDescription *alert, void *arg) { + *alert = kCustomAlert; + return SECFailure; +} + +TEST_F(TlsConnectStreamTls13, CustomExtensionServerRejectAlert) { + EnsureTlsSetup(); + + // This installs an override that sends nonsense. + const uint16_t extension_code = 0xffea; + SECStatus rv = SSL_InstallExtensionHooks(client_->ssl_fd(), extension_code, + EmptyExtensionWriter, nullptr, + NoopExtensionHandler, nullptr); + EXPECT_EQ(SECSuccess, rv); + + // Reject the extension for no good reason. + rv = SSL_InstallExtensionHooks(server_->ssl_fd(), extension_code, + NoopExtensionWriter, nullptr, + AlertExtensionHandler, nullptr); + EXPECT_EQ(SECSuccess, rv); + + ConnectExpectAlert(server_, kCustomAlert); +} + +// Send nonsense in an extension from client to server. +TEST_F(TlsConnectStreamTls13, CustomExtensionClientRejectAlert) { + EnsureTlsSetup(); + + // This installs an override that sends nothing but expects nonsense. + const uint16_t extension_code = 0xff5a; + SECStatus rv = SSL_InstallExtensionHooks(client_->ssl_fd(), extension_code, + EmptyExtensionWriter, nullptr, + AlertExtensionHandler, nullptr); + EXPECT_EQ(SECSuccess, rv); + + // Have the server send nonsense. + rv = SSL_InstallExtensionHooks(server_->ssl_fd(), extension_code, + EmptyExtensionWriter, nullptr, + NoopExtensionHandler, nullptr); + EXPECT_EQ(SECSuccess, rv); + + client_->ExpectSendAlert(kCustomAlert); + server_->ExpectSendAlert(kTlsAlertUnexpectedMessage); + ConnectExpectFail(); +} + +// Configure a custom extension hook badly. +TEST_F(TlsConnectStreamTls13, CustomExtensionOnlyWriter) { + EnsureTlsSetup(); + + // This installs an override that sends nothing but expects nonsense. + SECStatus rv = + SSL_InstallExtensionHooks(client_->ssl_fd(), 0xff6c, EmptyExtensionWriter, + nullptr, nullptr, nullptr); + EXPECT_EQ(SECFailure, rv); + EXPECT_EQ(SEC_ERROR_INVALID_ARGS, PORT_GetError()); +} + +TEST_F(TlsConnectStreamTls13, CustomExtensionOnlyHandler) { + EnsureTlsSetup(); + + // This installs an override that sends nothing but expects nonsense. + SECStatus rv = + SSL_InstallExtensionHooks(client_->ssl_fd(), 0xff6d, nullptr, nullptr, + NoopExtensionHandler, nullptr); + EXPECT_EQ(SECFailure, rv); + EXPECT_EQ(SEC_ERROR_INVALID_ARGS, PORT_GetError()); +} + +TEST_F(TlsConnectStreamTls13, CustomExtensionOverrunBuffer) { + EnsureTlsSetup(); + // This doesn't actually overrun the buffer, but it says that it does. + auto overrun_writer = [](PRFileDesc *fd, SSLHandshakeType message, + PRUint8 *data, unsigned int *len, + unsigned int maxLen, void *arg) -> PRBool { + *len = maxLen + 1; + return PR_TRUE; + }; + SECStatus rv = + SSL_InstallExtensionHooks(client_->ssl_fd(), 0xff71, overrun_writer, + nullptr, NoopExtensionHandler, nullptr); + EXPECT_EQ(SECSuccess, rv); + client_->StartConnect(); + client_->Handshake(); + client_->CheckErrorCode(SEC_ERROR_APPLICATION_CALLBACK_ERROR); +} + +} // namespace nss_test diff --git a/security/nss/gtests/ssl_gtest/ssl_damage_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_damage_unittest.cc new file mode 100644 index 0000000000..9cbe9566f1 --- /dev/null +++ b/security/nss/gtests/ssl_gtest/ssl_damage_unittest.cc @@ -0,0 +1,104 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set ts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this file, + * You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#include <functional> +#include <memory> +#include "secerr.h" +#include "ssl.h" +#include "sslerr.h" +#include "sslproto.h" + +extern "C" { +// This is not something that should make you happy. +#include "libssl_internals.h" +} + +#include "gtest_utils.h" +#include "nss_scoped_ptrs.h" +#include "tls_connect.h" +#include "tls_filter.h" +#include "tls_parser.h" + +namespace nss_test { + +TEST_F(TlsConnectTest, DamageSecretHandleClientFinished) { + client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_1, + SSL_LIBRARY_VERSION_TLS_1_3); + server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_1, + SSL_LIBRARY_VERSION_TLS_1_3); + StartConnect(); + client_->Handshake(); + server_->Handshake(); + std::cerr << "Damaging HS secret" << std::endl; + SSLInt_DamageClientHsTrafficSecret(server_->ssl_fd()); + client_->Handshake(); + // The client thinks it has connected. + EXPECT_EQ(TlsAgent::STATE_CONNECTED, client_->state()); + + ExpectAlert(server_, kTlsAlertDecryptError); + server_->Handshake(); + server_->CheckErrorCode(SSL_ERROR_BAD_HANDSHAKE_HASH_VALUE); + client_->Handshake(); + client_->CheckErrorCode(SSL_ERROR_DECRYPT_ERROR_ALERT); +} + +TEST_F(TlsConnectTest, DamageSecretHandleServerFinished) { + client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_1, + SSL_LIBRARY_VERSION_TLS_1_3); + server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_1, + SSL_LIBRARY_VERSION_TLS_1_3); + MakeTlsFilter<AfterRecordN>( + server_, client_, + 0, // ServerHello. + [this]() { SSLInt_DamageServerHsTrafficSecret(client_->ssl_fd()); }); + ConnectExpectAlert(client_, kTlsAlertDecryptError); + client_->CheckErrorCode(SSL_ERROR_BAD_HANDSHAKE_HASH_VALUE); +} + +TEST_P(TlsConnectGenericPre13, DamageServerSignature) { + EnsureTlsSetup(); + auto filter = MakeTlsFilter<TlsLastByteDamager>( + server_, kTlsHandshakeServerKeyExchange); + ExpectAlert(client_, kTlsAlertDecryptError); + ConnectExpectFail(); + client_->CheckErrorCode(SEC_ERROR_BAD_SIGNATURE); + server_->CheckErrorCode(SSL_ERROR_DECRYPT_ERROR_ALERT); +} + +TEST_P(TlsConnectTls13, DamageServerSignature) { + EnsureTlsSetup(); + auto filter = MakeTlsFilter<TlsLastByteDamager>( + server_, kTlsHandshakeCertificateVerify); + filter->EnableDecryption(); + ConnectExpectAlert(client_, kTlsAlertDecryptError); + client_->CheckErrorCode(SEC_ERROR_BAD_SIGNATURE); +} + +TEST_P(TlsConnectGeneric, DamageClientSignature) { + EnsureTlsSetup(); + client_->SetupClientAuth(); + server_->RequestClientAuth(true); + auto filter = MakeTlsFilter<TlsLastByteDamager>( + client_, kTlsHandshakeCertificateVerify); + if (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) { + filter->EnableDecryption(); + } + server_->ExpectSendAlert(kTlsAlertDecryptError); + // Do these handshakes by hand to avoid race condition on + // the client processing the server's alert. + StartConnect(); + client_->Handshake(); + server_->Handshake(); + client_->Handshake(); + server_->Handshake(); + EXPECT_EQ(version_ >= SSL_LIBRARY_VERSION_TLS_1_3 + ? TlsAgent::STATE_CONNECTED + : TlsAgent::STATE_CONNECTING, + client_->state()); + server_->CheckErrorCode(SEC_ERROR_BAD_SIGNATURE); +} + +} // namespace nss_test diff --git a/security/nss/gtests/ssl_gtest/ssl_debug_env_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_debug_env_unittest.cc new file mode 100644 index 0000000000..77b4d69afc --- /dev/null +++ b/security/nss/gtests/ssl_gtest/ssl_debug_env_unittest.cc @@ -0,0 +1,51 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set ts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this file, + * You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#include <cstdlib> +#include <fstream> +#include <sstream> + +#include "gtest_utils.h" +#include "tls_connect.h" + +namespace nss_test { + +extern "C" { +extern FILE* ssl_trace_iob; + +#ifdef NSS_ALLOW_SSLKEYLOGFILE +extern FILE* ssl_keylog_iob; +#endif +} + +// These tests ensure that when the associated environment variables are unset +// that the lazily-initialized defaults are what they are supposed to be. + +#ifdef DEBUG +TEST_P(TlsConnectGeneric, DebugEnvTraceFileNotSet) { + char* ev = PR_GetEnvSecure("SSLDEBUGFILE"); + if (ev && ev[0]) { + GTEST_SKIP(); + } + + Connect(); + EXPECT_EQ(stderr, ssl_trace_iob); +} +#endif + +#ifdef NSS_ALLOW_SSLKEYLOGFILE +TEST_P(TlsConnectGeneric, DebugEnvKeylogFileNotSet) { + char* ev = PR_GetEnvSecure("SSLKEYLOGFILE"); + if (ev && ev[0]) { + GTEST_SKIP(); + } + + Connect(); + EXPECT_EQ(nullptr, ssl_keylog_iob); +} +#endif + +} // namespace nss_test diff --git a/security/nss/gtests/ssl_gtest/ssl_dhe_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_dhe_unittest.cc new file mode 100644 index 0000000000..09beb2a6d9 --- /dev/null +++ b/security/nss/gtests/ssl_gtest/ssl_dhe_unittest.cc @@ -0,0 +1,802 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set ts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this file, + * You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#include <functional> +#include <memory> +#include <set> +#include "secerr.h" +#include "ssl.h" +#include "sslerr.h" +#include "sslproto.h" + +#include "gtest_utils.h" +#include "nss_scoped_ptrs.h" +#include "tls_connect.h" +#include "tls_filter.h" +#include "tls_parser.h" + +namespace nss_test { + +TEST_P(TlsConnectGeneric, ConnectDhe) { + EnableOnlyDheCiphers(); + Connect(); + CheckKeys(ssl_kea_dh, ssl_grp_ffdhe_2048, ssl_auth_rsa_sign, + ssl_sig_rsa_pss_rsae_sha256); +} + +TEST_P(TlsConnectTls13, SharesForBothEcdheAndDhe) { + EnsureTlsSetup(); + client_->ConfigNamedGroups(kAllDHEGroups); + + auto groups_capture = + std::make_shared<TlsExtensionCapture>(client_, ssl_supported_groups_xtn); + auto shares_capture = + std::make_shared<TlsExtensionCapture>(client_, ssl_tls13_key_share_xtn); + std::vector<std::shared_ptr<PacketFilter>> captures = {groups_capture, + shares_capture}; + client_->SetFilter(std::make_shared<ChainedPacketFilter>(captures)); + + Connect(); + + CheckKeys(); + + bool ec, dh; + auto track_group_type = [&ec, &dh](SSLNamedGroup group) { + if ((group & 0xff00U) == 0x100U) { + dh = true; + } else { + ec = true; + } + }; + CheckGroups(groups_capture->extension(), track_group_type); + CheckShares(shares_capture->extension(), track_group_type); + EXPECT_TRUE(ec) << "Should include an EC group and share"; + EXPECT_TRUE(dh) << "Should include an FFDHE group and share"; +} + +TEST_P(TlsConnectGeneric, ConnectFfdheClient) { + EnableOnlyDheCiphers(); + client_->SetOption(SSL_REQUIRE_DH_NAMED_GROUPS, PR_TRUE); + auto groups_capture = + std::make_shared<TlsExtensionCapture>(client_, ssl_supported_groups_xtn); + auto shares_capture = + std::make_shared<TlsExtensionCapture>(client_, ssl_tls13_key_share_xtn); + std::vector<std::shared_ptr<PacketFilter>> captures = {groups_capture, + shares_capture}; + client_->SetFilter(std::make_shared<ChainedPacketFilter>(captures)); + + Connect(); + + CheckKeys(ssl_kea_dh, ssl_auth_rsa_sign); + auto is_ffdhe = [](SSLNamedGroup group) { + // The group has to be in this range. + EXPECT_LE(ssl_grp_ffdhe_2048, group); + EXPECT_GE(ssl_grp_ffdhe_8192, group); + }; + CheckGroups(groups_capture->extension(), is_ffdhe); + if (version_ == SSL_LIBRARY_VERSION_TLS_1_3) { + CheckShares(shares_capture->extension(), is_ffdhe); + } else { + EXPECT_EQ(0U, shares_capture->extension().len()); + } +} + +// Requiring the FFDHE extension on the server alone means that clients won't be +// able to connect using a DHE suite. They should still connect in TLS 1.3, +// because the client automatically sends the supported groups extension. +TEST_P(TlsConnectGenericPre13, ConnectFfdheServer) { + EnableOnlyDheCiphers(); + server_->SetOption(SSL_REQUIRE_DH_NAMED_GROUPS, PR_TRUE); + + if (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) { + Connect(); + CheckKeys(ssl_kea_dh, ssl_auth_rsa_sign); + } else { + ConnectExpectAlert(server_, kTlsAlertHandshakeFailure); + client_->CheckErrorCode(SSL_ERROR_NO_CYPHER_OVERLAP); + server_->CheckErrorCode(SSL_ERROR_NO_CYPHER_OVERLAP); + } +} + +class TlsDheServerKeyExchangeDamager : public TlsHandshakeFilter { + public: + TlsDheServerKeyExchangeDamager(const std::shared_ptr<TlsAgent>& a) + : TlsHandshakeFilter(a, {kTlsHandshakeServerKeyExchange}) {} + virtual PacketFilter::Action FilterHandshake( + const TlsHandshakeFilter::HandshakeHeader& header, + const DataBuffer& input, DataBuffer* output) { + // Damage the first octet of dh_p. Anything other than the known prime will + // be rejected as "weak" when we have SSL_REQUIRE_DH_NAMED_GROUPS enabled. + *output = input; + output->data()[3] ^= 73; + return CHANGE; + } +}; + +// Changing the prime in the server's key share results in an error. This will +// invalidate the signature over the ServerKeyShare. That's ok, NSS won't check +// the signature until everything else has been checked. +TEST_P(TlsConnectGenericPre13, DamageServerKeyShare) { + EnableOnlyDheCiphers(); + client_->SetOption(SSL_REQUIRE_DH_NAMED_GROUPS, PR_TRUE); + MakeTlsFilter<TlsDheServerKeyExchangeDamager>(server_); + + ConnectExpectAlert(client_, kTlsAlertIllegalParameter); + + client_->CheckErrorCode(SSL_ERROR_WEAK_SERVER_EPHEMERAL_DH_KEY); + server_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT); +} + +class TlsDheSkeChangeY : public TlsHandshakeFilter { + public: + enum ChangeYTo { + kYZero, + kYOne, + kYPMinusOne, + kYGreaterThanP, + kYTooLarge, + kYZeroPad + }; + + TlsDheSkeChangeY(const std::shared_ptr<TlsAgent>& a, uint8_t handshake_type, + ChangeYTo change) + : TlsHandshakeFilter(a, {handshake_type}), change_Y_(change) {} + + protected: + void ChangeY(const DataBuffer& input, DataBuffer* output, size_t offset, + const DataBuffer& prime) { + static const uint8_t kExtraZero = 0; + static const uint8_t kTooLargeExtra = 1; + + uint32_t dh_Ys_len; + EXPECT_TRUE(input.Read(offset, 2, &dh_Ys_len)); + EXPECT_LT(offset + dh_Ys_len, input.len()); + offset += 2; + + // This isn't generally true, but our code pads. + EXPECT_EQ(prime.len(), dh_Ys_len) + << "Length of dh_Ys must equal length of dh_p"; + + *output = input; + switch (change_Y_) { + case kYZero: + memset(output->data() + offset, 0, prime.len()); + break; + + case kYOne: + memset(output->data() + offset, 0, prime.len() - 1); + output->Write(offset + prime.len() - 1, 1U, 1); + break; + + case kYPMinusOne: + output->Write(offset, prime); + EXPECT_TRUE(output->data()[offset + prime.len() - 1] & 0x01) + << "P must at least be odd"; + --output->data()[offset + prime.len() - 1]; + break; + + case kYGreaterThanP: + // Set the first 32 octets of Y to 0xff, except the first which we set + // to p[0]. This will make Y > p. That is, unless p is Mersenne, or + // improbably large (but still the same bit length). We currently only + // use a fixed prime that isn't a problem for this code. + EXPECT_LT(0, prime.data()[0]) << "dh_p should not be zero-padded"; + offset = output->Write(offset, prime.data()[0], 1); + memset(output->data() + offset, 0xff, 31); + break; + + case kYTooLarge: + // Increase the dh_Ys length. + output->Write(offset - 2, prime.len() + sizeof(kTooLargeExtra), 2); + // Then insert the octet. + output->Splice(&kTooLargeExtra, sizeof(kTooLargeExtra), offset); + break; + + case kYZeroPad: + output->Write(offset - 2, prime.len() + sizeof(kExtraZero), 2); + output->Splice(&kExtraZero, sizeof(kExtraZero), offset); + break; + } + } + + private: + ChangeYTo change_Y_; +}; + +class TlsDheSkeChangeYServer : public TlsDheSkeChangeY { + public: + TlsDheSkeChangeYServer(const std::shared_ptr<TlsAgent>& a, ChangeYTo change, + bool modify) + : TlsDheSkeChangeY(a, kTlsHandshakeServerKeyExchange, change), + modify_(modify), + p_() {} + + const DataBuffer& prime() const { return p_; } + + protected: + virtual PacketFilter::Action FilterHandshake( + const TlsHandshakeFilter::HandshakeHeader& header, + const DataBuffer& input, DataBuffer* output) override { + size_t offset = 2; + // Read dh_p + uint32_t dh_len = 0; + EXPECT_TRUE(input.Read(0, 2, &dh_len)); + EXPECT_GT(input.len(), offset + dh_len); + p_.Assign(input.data() + offset, dh_len); + offset += dh_len; + + // Skip dh_g to find dh_Ys + EXPECT_TRUE(input.Read(offset, 2, &dh_len)); + offset += 2 + dh_len; + + if (modify_) { + ChangeY(input, output, offset, p_); + return CHANGE; + } + return KEEP; + } + + private: + bool modify_; + DataBuffer p_; +}; + +class TlsDheSkeChangeYClient : public TlsDheSkeChangeY { + public: + TlsDheSkeChangeYClient( + const std::shared_ptr<TlsAgent>& a, ChangeYTo change, + std::shared_ptr<const TlsDheSkeChangeYServer> server_filter) + : TlsDheSkeChangeY(a, kTlsHandshakeClientKeyExchange, change), + server_filter_(server_filter) {} + + protected: + virtual PacketFilter::Action FilterHandshake( + const TlsHandshakeFilter::HandshakeHeader& header, + const DataBuffer& input, DataBuffer* output) override { + ChangeY(input, output, 0, server_filter_->prime()); + return CHANGE; + } + + private: + std::shared_ptr<const TlsDheSkeChangeYServer> server_filter_; +}; + +/* This matrix includes: variant (stream/datagram), TLS version, what change to + * make to dh_Ys, whether the client will be configured to require DH named + * groups. Test all combinations. */ +typedef std::tuple<SSLProtocolVariant, uint16_t, TlsDheSkeChangeY::ChangeYTo, + bool> + DamageDHYProfile; +class TlsDamageDHYTest + : public TlsConnectTestBase, + public ::testing::WithParamInterface<DamageDHYProfile> { + public: + TlsDamageDHYTest() + : TlsConnectTestBase(std::get<0>(GetParam()), std::get<1>(GetParam())) {} +}; + +TEST_P(TlsDamageDHYTest, DamageServerY) { + EnableOnlyDheCiphers(); + if (std::get<3>(GetParam())) { + client_->SetOption(SSL_REQUIRE_DH_NAMED_GROUPS, PR_TRUE); + } + TlsDheSkeChangeY::ChangeYTo change = std::get<2>(GetParam()); + MakeTlsFilter<TlsDheSkeChangeYServer>(server_, change, true); + + if (change == TlsDheSkeChangeY::kYZeroPad) { + ExpectAlert(client_, kTlsAlertDecryptError); + } else { + ExpectAlert(client_, kTlsAlertIllegalParameter); + } + ConnectExpectFail(); + if (change == TlsDheSkeChangeY::kYZeroPad) { + // Zero padding Y only manifests in a signature failure. + // In TLS 1.0 and 1.1, the client reports a device error. + if (version_ < SSL_LIBRARY_VERSION_TLS_1_2) { + client_->CheckErrorCode(SEC_ERROR_PKCS11_DEVICE_ERROR); + } else { + client_->CheckErrorCode(SEC_ERROR_BAD_SIGNATURE); + } + server_->CheckErrorCode(SSL_ERROR_DECRYPT_ERROR_ALERT); + } else { + client_->CheckErrorCode(SSL_ERROR_RX_MALFORMED_DHE_KEY_SHARE); + server_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT); + } +} + +TEST_P(TlsDamageDHYTest, DamageClientY) { + EnableOnlyDheCiphers(); + if (std::get<3>(GetParam())) { + client_->SetOption(SSL_REQUIRE_DH_NAMED_GROUPS, PR_TRUE); + } + // The filter on the server is required to capture the prime. + auto server_filter = MakeTlsFilter<TlsDheSkeChangeYServer>( + server_, TlsDheSkeChangeY::kYZero, false); + + // The client filter does the damage. + TlsDheSkeChangeY::ChangeYTo change = std::get<2>(GetParam()); + MakeTlsFilter<TlsDheSkeChangeYClient>(client_, change, server_filter); + + if (change == TlsDheSkeChangeY::kYZeroPad) { + ExpectAlert(server_, kTlsAlertDecryptError); + } else { + ExpectAlert(server_, kTlsAlertHandshakeFailure); + } + ConnectExpectFail(); + if (change == TlsDheSkeChangeY::kYZeroPad) { + // Zero padding Y only manifests in a finished error. + client_->CheckErrorCode(SSL_ERROR_DECRYPT_ERROR_ALERT); + server_->CheckErrorCode(SSL_ERROR_BAD_HANDSHAKE_HASH_VALUE); + } else { + client_->CheckErrorCode(SSL_ERROR_HANDSHAKE_FAILURE_ALERT); + server_->CheckErrorCode(SSL_ERROR_RX_MALFORMED_DHE_KEY_SHARE); + } +} + +static const TlsDheSkeChangeY::ChangeYTo kAllYArr[] = { + TlsDheSkeChangeY::kYZero, TlsDheSkeChangeY::kYOne, + TlsDheSkeChangeY::kYPMinusOne, TlsDheSkeChangeY::kYGreaterThanP, + TlsDheSkeChangeY::kYTooLarge, TlsDheSkeChangeY::kYZeroPad}; +static ::testing::internal::ParamGenerator<TlsDheSkeChangeY::ChangeYTo> kAllY = + ::testing::ValuesIn(kAllYArr); +static const bool kTrueFalseArr[] = {true, false}; +static ::testing::internal::ParamGenerator<bool> kTrueFalse = + ::testing::ValuesIn(kTrueFalseArr); + +INSTANTIATE_TEST_SUITE_P( + DamageYStream, TlsDamageDHYTest, + ::testing::Combine(TlsConnectTestBase::kTlsVariantsStream, + TlsConnectTestBase::kTlsV10ToV12, kAllY, kTrueFalse)); +INSTANTIATE_TEST_SUITE_P( + DamageYDatagram, TlsDamageDHYTest, + ::testing::Combine(TlsConnectTestBase::kTlsVariantsDatagram, + TlsConnectTestBase::kTlsV11V12, kAllY, kTrueFalse)); + +class TlsDheSkeMakePEven : public TlsHandshakeFilter { + public: + TlsDheSkeMakePEven(const std::shared_ptr<TlsAgent>& a) + : TlsHandshakeFilter(a, {kTlsHandshakeServerKeyExchange}) {} + + virtual PacketFilter::Action FilterHandshake( + const TlsHandshakeFilter::HandshakeHeader& header, + const DataBuffer& input, DataBuffer* output) { + // Find the end of dh_p + uint32_t dh_len = 0; + EXPECT_TRUE(input.Read(0, 2, &dh_len)); + EXPECT_GT(input.len(), 2 + dh_len) << "enough space for dh_p"; + size_t offset = 2 + dh_len - 1; + EXPECT_TRUE((input.data()[offset] & 0x01) == 0x01) << "p should be odd"; + + *output = input; + output->data()[offset] &= 0xfe; + + return CHANGE; + } +}; + +// Even without requiring named groups, an even value for p is bad news. +TEST_P(TlsConnectGenericPre13, MakeDhePEven) { + EnableOnlyDheCiphers(); + MakeTlsFilter<TlsDheSkeMakePEven>(server_); + + ConnectExpectAlert(client_, kTlsAlertIllegalParameter); + + client_->CheckErrorCode(SSL_ERROR_RX_MALFORMED_DHE_KEY_SHARE); + server_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT); +} + +class TlsDheSkeZeroPadP : public TlsHandshakeFilter { + public: + TlsDheSkeZeroPadP(const std::shared_ptr<TlsAgent>& a) + : TlsHandshakeFilter(a, {kTlsHandshakeServerKeyExchange}) {} + + virtual PacketFilter::Action FilterHandshake( + const TlsHandshakeFilter::HandshakeHeader& header, + const DataBuffer& input, DataBuffer* output) { + *output = input; + uint32_t dh_len = 0; + EXPECT_TRUE(input.Read(0, 2, &dh_len)); + static const uint8_t kZeroPad = 0; + output->Write(0, dh_len + sizeof(kZeroPad), 2); // increment the length + output->Splice(&kZeroPad, sizeof(kZeroPad), 2); // insert a zero + + return CHANGE; + } +}; + +// Zero padding only causes signature failure. +TEST_P(TlsConnectGenericPre13, PadDheP) { + EnableOnlyDheCiphers(); + MakeTlsFilter<TlsDheSkeZeroPadP>(server_); + + ConnectExpectAlert(client_, kTlsAlertDecryptError); + + // In TLS 1.0 and 1.1, the client reports a device error. + if (version_ < SSL_LIBRARY_VERSION_TLS_1_2) { + client_->CheckErrorCode(SEC_ERROR_PKCS11_DEVICE_ERROR); + } else { + client_->CheckErrorCode(SEC_ERROR_BAD_SIGNATURE); + } + server_->CheckErrorCode(SSL_ERROR_DECRYPT_ERROR_ALERT); +} + +// The server should not pick the weak DH group if the client includes FFDHE +// named groups in the supported_groups extension. The server then picks a +// commonly-supported named DH group and this connects. +// +// Note: This test case can take ages to generate the weak DH key. +TEST_P(TlsConnectGenericPre13, WeakDHGroup) { + EnableOnlyDheCiphers(); + client_->SetOption(SSL_REQUIRE_DH_NAMED_GROUPS, PR_TRUE); + EXPECT_EQ(SECSuccess, + SSL_EnableWeakDHEPrimeGroup(server_->ssl_fd(), PR_TRUE)); + + Connect(); +} + +TEST_P(TlsConnectGeneric, Ffdhe3072) { + EnableOnlyDheCiphers(); + static const std::vector<SSLNamedGroup> groups = {ssl_grp_ffdhe_3072}; + client_->ConfigNamedGroups(groups); + + Connect(); +} + +// Even though the client doesn't have DHE groups enabled the server assumes it +// does. Because the client doesn't require named groups it accepts FF3072 as +// custom group. +TEST_P(TlsConnectGenericPre13, NamedGroupMismatchPre13) { + EnableOnlyDheCiphers(); + static const std::vector<SSLNamedGroup> server_groups = {ssl_grp_ffdhe_3072}; + static const std::vector<SSLNamedGroup> client_groups = { + ssl_grp_ec_secp256r1}; + server_->ConfigNamedGroups(server_groups); + client_->ConfigNamedGroups(client_groups); + + Connect(); + CheckKeys(ssl_kea_dh, ssl_grp_ffdhe_custom, ssl_auth_rsa_sign, + ssl_sig_rsa_pss_rsae_sha256); +} + +// Same test but for TLS 1.3. This has to fail. +TEST_P(TlsConnectTls13, NamedGroupMismatch13) { + EnableOnlyDheCiphers(); + static const std::vector<SSLNamedGroup> server_groups = {ssl_grp_ffdhe_3072}; + static const std::vector<SSLNamedGroup> client_groups = { + ssl_grp_ec_secp256r1}; + server_->ConfigNamedGroups(server_groups); + client_->ConfigNamedGroups(client_groups); + + ConnectExpectAlert(server_, kTlsAlertHandshakeFailure); + server_->CheckErrorCode(SSL_ERROR_NO_CYPHER_OVERLAP); + client_->CheckErrorCode(SSL_ERROR_NO_CYPHER_OVERLAP); +} + +// Replace the key share in the server key exchange message with one that's +// larger than 8192 bits. +class TooLongDHEServerKEXFilter : public TlsHandshakeFilter { + public: + TooLongDHEServerKEXFilter(const std::shared_ptr<TlsAgent>& server) + : TlsHandshakeFilter(server, {kTlsHandshakeServerKeyExchange}) {} + + protected: + virtual PacketFilter::Action FilterHandshake(const HandshakeHeader& header, + const DataBuffer& input, + DataBuffer* output) { + // Replace the server key exchange message very large DH shares that are + // not supported by NSS. + const uint32_t share_len = 0x401; + const uint8_t zero_share[share_len] = {0x80}; + size_t offset = 0; + // Write dh_p. + offset = output->Write(offset, share_len, 2); + offset = output->Write(offset, zero_share, share_len); + // Write dh_g. + offset = output->Write(offset, share_len, 2); + offset = output->Write(offset, zero_share, share_len); + // Write dh_Y. + offset = output->Write(offset, share_len, 2); + offset = output->Write(offset, zero_share, share_len); + + return CHANGE; + } +}; + +TEST_P(TlsConnectGenericPre13, TooBigDHGroup) { + EnableOnlyDheCiphers(); + MakeTlsFilter<TooLongDHEServerKEXFilter>(server_); + client_->SetOption(SSL_REQUIRE_DH_NAMED_GROUPS, PR_FALSE); + ConnectExpectAlert(client_, kTlsAlertIllegalParameter); + server_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT); + client_->CheckErrorCode(SSL_ERROR_DH_KEY_TOO_LONG); +} + +// Even though the client doesn't have DHE groups enabled the server assumes it +// does. The client requires named groups and thus does not accept FF3072 as +// custom group in contrast to the previous test. +TEST_P(TlsConnectGenericPre13, RequireNamedGroupsMismatchPre13) { + EnableOnlyDheCiphers(); + client_->SetOption(SSL_REQUIRE_DH_NAMED_GROUPS, PR_TRUE); + static const std::vector<SSLNamedGroup> server_groups = {ssl_grp_ffdhe_3072}; + static const std::vector<SSLNamedGroup> client_groups = {ssl_grp_ec_secp256r1, + ssl_grp_ffdhe_2048}; + server_->ConfigNamedGroups(server_groups); + client_->ConfigNamedGroups(client_groups); + + ConnectExpectAlert(server_, kTlsAlertHandshakeFailure); + server_->CheckErrorCode(SSL_ERROR_NO_CYPHER_OVERLAP); + client_->CheckErrorCode(SSL_ERROR_NO_CYPHER_OVERLAP); +} + +TEST_P(TlsConnectGenericPre13, PreferredFfdhe) { + EnableOnlyDheCiphers(); + static const SSLDHEGroupType groups[] = {ssl_ff_dhe_3072_group, + ssl_ff_dhe_2048_group}; + EXPECT_EQ(SECSuccess, SSL_DHEGroupPrefSet(server_->ssl_fd(), groups, + PR_ARRAY_SIZE(groups))); + + Connect(); + client_->CheckKEA(ssl_kea_dh, ssl_grp_ffdhe_3072, 3072); + server_->CheckKEA(ssl_kea_dh, ssl_grp_ffdhe_3072, 3072); + client_->CheckAuthType(ssl_auth_rsa_sign, ssl_sig_rsa_pss_rsae_sha256); + server_->CheckAuthType(ssl_auth_rsa_sign, ssl_sig_rsa_pss_rsae_sha256); +} + +TEST_P(TlsConnectGenericPre13, MismatchDHE) { + EnableOnlyDheCiphers(); + client_->SetOption(SSL_REQUIRE_DH_NAMED_GROUPS, PR_TRUE); + static const SSLDHEGroupType serverGroups[] = {ssl_ff_dhe_3072_group}; + EXPECT_EQ(SECSuccess, SSL_DHEGroupPrefSet(server_->ssl_fd(), serverGroups, + PR_ARRAY_SIZE(serverGroups))); + static const SSLDHEGroupType clientGroups[] = {ssl_ff_dhe_2048_group}; + EXPECT_EQ(SECSuccess, SSL_DHEGroupPrefSet(client_->ssl_fd(), clientGroups, + PR_ARRAY_SIZE(clientGroups))); + + ConnectExpectAlert(server_, kTlsAlertHandshakeFailure); + server_->CheckErrorCode(SSL_ERROR_NO_CYPHER_OVERLAP); + client_->CheckErrorCode(SSL_ERROR_NO_CYPHER_OVERLAP); +} + +TEST_P(TlsConnectTls13, ResumeFfdhe) { + EnableOnlyDheCiphers(); + ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET); + Connect(); + SendReceive(); // Need to read so that we absorb the session ticket. + CheckKeys(ssl_kea_dh, ssl_grp_ffdhe_2048, ssl_auth_rsa_sign, + ssl_sig_rsa_pss_rsae_sha256); + + Reset(); + ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET); + EnableOnlyDheCiphers(); + auto clientCapture = + MakeTlsFilter<TlsExtensionCapture>(client_, ssl_tls13_pre_shared_key_xtn); + auto serverCapture = + MakeTlsFilter<TlsExtensionCapture>(server_, ssl_tls13_pre_shared_key_xtn); + ExpectResumption(RESUME_TICKET); + Connect(); + CheckKeys(ssl_kea_dh, ssl_grp_ffdhe_2048, ssl_auth_rsa_sign, + ssl_sig_rsa_pss_rsae_sha256); + ASSERT_LT(0UL, clientCapture->extension().len()); + ASSERT_LT(0UL, serverCapture->extension().len()); +} + +class TlsDheSkeChangeSignature : public TlsHandshakeFilter { + public: + TlsDheSkeChangeSignature(const std::shared_ptr<TlsAgent>& a, uint16_t version, + const uint8_t* data, size_t len) + : TlsHandshakeFilter(a, {kTlsHandshakeServerKeyExchange}), + version_(version), + data_(data), + len_(len) {} + + protected: + virtual PacketFilter::Action FilterHandshake(const HandshakeHeader& header, + const DataBuffer& input, + DataBuffer* output) { + TlsParser parser(input); + EXPECT_TRUE(parser.SkipVariable(2)); // dh_p + EXPECT_TRUE(parser.SkipVariable(2)); // dh_g + EXPECT_TRUE(parser.SkipVariable(2)); // dh_Ys + + // Copy DH params to output. + size_t offset = output->Write(0, input.data(), parser.consumed()); + + if (version_ == SSL_LIBRARY_VERSION_TLS_1_2) { + // Write signature algorithm. + offset = output->Write(offset, ssl_sig_dsa_sha256, 2); + } + + // Write new signature. + offset = output->Write(offset, len_, 2); + offset = output->Write(offset, data_, len_); + + return CHANGE; + } + + private: + uint16_t version_; + const uint8_t* data_; + size_t len_; +}; + +TEST_P(TlsConnectGenericPre13, InvalidDERSignatureFfdhe) { + const uint8_t kBogusDheSignature[] = { + 0x30, 0x69, 0x3c, 0x02, 0x1c, 0x7d, 0x0b, 0x2f, 0x64, 0x00, 0x27, + 0xae, 0xcf, 0x1e, 0x28, 0x08, 0x6a, 0x7f, 0xb1, 0xbd, 0x78, 0xb5, + 0x3b, 0x8c, 0x8f, 0x59, 0xed, 0x8f, 0xee, 0x78, 0xeb, 0x2c, 0xe9, + 0x02, 0x1c, 0x6d, 0x7f, 0x3c, 0x0f, 0xf4, 0x44, 0x35, 0x0b, 0xb2, + 0x6d, 0xdc, 0xb8, 0x21, 0x87, 0xdd, 0x0d, 0xb9, 0x46, 0x09, 0x3e, + 0xef, 0x81, 0x5b, 0x37, 0x09, 0x39, 0xeb}; + + Reset(TlsAgent::kServerDsa); + + const std::vector<SSLNamedGroup> client_groups = {ssl_grp_ffdhe_2048}; + client_->ConfigNamedGroups(client_groups); + + MakeTlsFilter<TlsDheSkeChangeSignature>(server_, version_, kBogusDheSignature, + sizeof(kBogusDheSignature)); + + ConnectExpectAlert(client_, kTlsAlertDecryptError); + client_->CheckErrorCode(SSL_ERROR_BAD_HANDSHAKE_HASH_VALUE); +} + +TEST_P(TlsConnectTls12, ConnectInconsistentSigAlgDHE) { + EnableOnlyDheCiphers(); + + MakeTlsFilter<DHEServerKEXSigAlgReplacer>(server_, + ssl_sig_ecdsa_secp256r1_sha256); + ConnectExpectAlert(client_, kTlsAlertIllegalParameter); +} + +static void CheckSkeSigScheme( + std::shared_ptr<TlsHandshakeRecorder>& capture_ske, + uint16_t expected_scheme) { + TlsParser parser(capture_ske->buffer()); + EXPECT_TRUE(parser.SkipVariable(2)) << " read dh_p"; + EXPECT_TRUE(parser.SkipVariable(2)) << " read dh_q"; + EXPECT_TRUE(parser.SkipVariable(2)) << " read dh_Ys"; + + uint32_t tmp; + EXPECT_TRUE(parser.Read(&tmp, 2)) << " read sig_scheme"; + EXPECT_EQ(expected_scheme, static_cast<uint16_t>(tmp)); +} + +TEST_P(TlsConnectTls12, ConnectSigAlgEnabledByPolicyDhe) { + EnableOnlyDheCiphers(); + + const std::vector<SSLSignatureScheme> schemes = {ssl_sig_rsa_pkcs1_sha1, + ssl_sig_rsa_pkcs1_sha384}; + + EnsureTlsSetup(); + client_->SetSignatureSchemes(schemes.data(), schemes.size()); + server_->SetSignatureSchemes(schemes.data(), schemes.size()); + auto capture_ske = MakeTlsFilter<TlsHandshakeRecorder>( + server_, kTlsHandshakeServerKeyExchange); + + StartConnect(); + client_->Handshake(); // Send ClientHello + + // Enable SHA-1 by policy. + SECStatus rv = NSS_SetAlgorithmPolicy(SEC_OID_SHA1, NSS_USE_ALG_IN_SSL_KX, 0); + ASSERT_EQ(SECSuccess, rv); + rv = NSS_SetAlgorithmPolicy(SEC_OID_APPLY_SSL_POLICY, NSS_USE_POLICY_IN_SSL, + 0); + ASSERT_EQ(SECSuccess, rv); + + Handshake(); // Remainder of handshake + // The server should now report that it is connected + EXPECT_EQ(TlsAgent::STATE_CONNECTED, server_->state()); + + CheckSkeSigScheme(capture_ske, ssl_sig_rsa_pkcs1_sha1); +} + +TEST_P(TlsConnectTls12, ConnectSigAlgDisabledByPolicyDhe) { + EnableOnlyDheCiphers(); + + const std::vector<SSLSignatureScheme> schemes = {ssl_sig_rsa_pkcs1_sha1, + ssl_sig_rsa_pkcs1_sha384}; + + EnsureTlsSetup(); + client_->SetSignatureSchemes(schemes.data(), schemes.size()); + server_->SetSignatureSchemes(schemes.data(), schemes.size()); + auto capture_ske = MakeTlsFilter<TlsHandshakeRecorder>( + server_, kTlsHandshakeServerKeyExchange); + + StartConnect(); + client_->Handshake(); // Send ClientHello + + // Disable SHA-1 by policy after sending ClientHello so that CH + // includes SHA-1 signature scheme. + SECStatus rv = NSS_SetAlgorithmPolicy(SEC_OID_SHA1, 0, NSS_USE_ALG_IN_SSL_KX); + ASSERT_EQ(SECSuccess, rv); + rv = NSS_SetAlgorithmPolicy(SEC_OID_APPLY_SSL_POLICY, NSS_USE_POLICY_IN_SSL, + 0); + ASSERT_EQ(SECSuccess, rv); + + Handshake(); // Remainder of handshake + // The server should now report that it is connected + EXPECT_EQ(TlsAgent::STATE_CONNECTED, server_->state()); + + CheckSkeSigScheme(capture_ske, ssl_sig_rsa_pkcs1_sha384); +} + +TEST_P(TlsConnectPre12, ConnectSigAlgDisabledWeakGroupByOption3072DhePre12) { + EnableOnlyDheCiphers(); + + // explicitly enable the weak groups + EXPECT_EQ(SECSuccess, + SSL_EnableWeakDHEPrimeGroup(server_->ssl_fd(), PR_TRUE)); + EXPECT_EQ(SECSuccess, + SSL_EnableWeakDHEPrimeGroup(client_->ssl_fd(), PR_TRUE)); + server_->SetNssOption(NSS_DH_MIN_KEY_SIZE, 3072); + Connect(); + client_->CheckKEA(ssl_kea_dh, ssl_grp_ffdhe_3072, 3072); + server_->CheckKEA(ssl_kea_dh, ssl_grp_ffdhe_3072, 3072); +} + +TEST_P(TlsConnectPre12, ConnectSigAlgDisabledWeakGroupByOption2048DhePre12) { + EnableOnlyDheCiphers(); + + // explicitly enable the weak groups + EXPECT_EQ(SECSuccess, + SSL_EnableWeakDHEPrimeGroup(server_->ssl_fd(), PR_TRUE)); + EXPECT_EQ(SECSuccess, + SSL_EnableWeakDHEPrimeGroup(client_->ssl_fd(), PR_TRUE)); + server_->SetNssOption(NSS_DH_MIN_KEY_SIZE, 2048); + Connect(); + client_->CheckKEA(ssl_kea_dh, ssl_grp_ffdhe_2048, 2048); + server_->CheckKEA(ssl_kea_dh, ssl_grp_ffdhe_2048, 2048); +} + +TEST_P(TlsConnectPre12, ConnectSigAlgDisabledByPolicyDhePre12) { + EnableOnlyDheCiphers(); + + EnsureTlsSetup(); + StartConnect(); + client_->Handshake(); // Send ClientHello + + // Disable SHA-1 by policy. This will cause the connection fail as + // TLS 1.1 or earlier uses combined SHA-1 + MD5 signature. + SECStatus rv = NSS_SetAlgorithmPolicy(SEC_OID_SHA1, 0, NSS_USE_ALG_IN_SSL_KX); + ASSERT_EQ(SECSuccess, rv); + rv = NSS_SetAlgorithmPolicy(SEC_OID_APPLY_SSL_POLICY, NSS_USE_POLICY_IN_SSL, + 0); + ASSERT_EQ(SECSuccess, rv); + + server_->ExpectSendAlert(kTlsAlertHandshakeFailure); + client_->ExpectReceiveAlert(kTlsAlertHandshakeFailure); + + // Remainder of handshake + Handshake(); + + server_->CheckErrorCode(SSL_ERROR_UNSUPPORTED_HASH_ALGORITHM); +} + +TEST_P(TlsConnectTls12, ConnectSigAlgDisablePreferredGroupByOption3072Dhe) { + EnableOnlyDheCiphers(); + static const SSLDHEGroupType dhe_groups[] = { + ssl_ff_dhe_2048_group, // first in the lists is the preferred group + ssl_ff_dhe_3072_group}; + + server_->SetNssOption(NSS_DH_MIN_KEY_SIZE, 3072); + EXPECT_EQ(SECSuccess, SSL_DHEGroupPrefSet(server_->ssl_fd(), &dhe_groups[0], + PR_ARRAY_SIZE(dhe_groups))); + Connect(); + // our option size should override the preferred group + client_->CheckKEA(ssl_kea_dh, ssl_grp_ffdhe_3072, 3072); + server_->CheckKEA(ssl_kea_dh, ssl_grp_ffdhe_3072, 3072); +} + +TEST_P(TlsConnectTls12, ConnectSigAlgDisableGroupByOption3072Dhe) { + EnableOnlyDheCiphers(); + + server_->SetNssOption(NSS_DH_MIN_KEY_SIZE, 3072); + Connect(); + client_->CheckKEA(ssl_kea_dh, ssl_grp_ffdhe_3072, 3072); + server_->CheckKEA(ssl_kea_dh, ssl_grp_ffdhe_3072, 3072); +} + +} // namespace nss_test diff --git a/security/nss/gtests/ssl_gtest/ssl_drop_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_drop_unittest.cc new file mode 100644 index 0000000000..98b29921ea --- /dev/null +++ b/security/nss/gtests/ssl_gtest/ssl_drop_unittest.cc @@ -0,0 +1,914 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set ts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this file, + * You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#include "secerr.h" +#include "ssl.h" +#include "sslexp.h" + +extern "C" { +// This is not something that should make you happy. +#include "libssl_internals.h" +} + +#include "gtest_utils.h" +#include "nss_scoped_ptrs.h" +#include "tls_connect.h" +#include "tls_filter.h" +#include "tls_parser.h" + +namespace nss_test { + +TEST_P(TlsConnectDatagramPre13, DropClientFirstFlightOnce) { + client_->SetFilter(std::make_shared<SelectiveDropFilter>(0x1)); + Connect(); + SendReceive(); +} + +TEST_P(TlsConnectDatagramPre13, DropServerFirstFlightOnce) { + server_->SetFilter(std::make_shared<SelectiveDropFilter>(0x1)); + Connect(); + SendReceive(); +} + +// This drops the first transmission from both the client and server of all +// flights that they send. Note: In DTLS 1.3, the shorter handshake means that +// this will also drop some application data, so we can't call SendReceive(). +TEST_P(TlsConnectDatagramPre13, DropAllFirstTransmissions) { + client_->SetFilter(std::make_shared<SelectiveDropFilter>(0x15)); + server_->SetFilter(std::make_shared<SelectiveDropFilter>(0x5)); + Connect(); +} + +// This drops the server's first flight three times. +TEST_P(TlsConnectDatagramPre13, DropServerFirstFlightThrice) { + server_->SetFilter(std::make_shared<SelectiveDropFilter>(0x7)); + Connect(); +} + +// This drops the client's second flight once +TEST_P(TlsConnectDatagramPre13, DropClientSecondFlightOnce) { + client_->SetFilter(std::make_shared<SelectiveDropFilter>(0x2)); + Connect(); +} + +// This drops the client's second flight three times. +TEST_P(TlsConnectDatagramPre13, DropClientSecondFlightThrice) { + client_->SetFilter(std::make_shared<SelectiveDropFilter>(0xe)); + Connect(); +} + +// This drops the server's second flight three times. +TEST_P(TlsConnectDatagramPre13, DropServerSecondFlightThrice) { + server_->SetFilter(std::make_shared<SelectiveDropFilter>(0xe)); + Connect(); +} + +static void CheckAcks(const std::shared_ptr<TlsRecordRecorder>& acks, + size_t index, std::vector<uint64_t> expected) { + ASSERT_LT(index, acks->count()); + const DataBuffer& buf = acks->record(index).buffer; + size_t offset = 2; + uint64_t len; + + EXPECT_EQ(2 + expected.size() * 8, buf.len()); + ASSERT_TRUE(buf.Read(0, 2, &len)); + ASSERT_EQ(static_cast<size_t>(len + 2), buf.len()); + if ((2 + expected.size() * 8) != buf.len()) { + while (offset < buf.len()) { + uint64_t ack; + ASSERT_TRUE(buf.Read(offset, 8, &ack)); + offset += 8; + std::cerr << "Ack=0x" << std::hex << ack << std::dec << std::endl; + } + return; + } + + for (size_t i = 0; i < expected.size(); ++i) { + uint64_t a = expected[i]; + uint64_t ack; + ASSERT_TRUE(buf.Read(offset, 8, &ack)); + offset += 8; + if (a != ack) { + ADD_FAILURE() << "Wrong ack " << i << " expected=0x" << std::hex << a + << " got=0x" << ack << std::dec; + } + } +} + +class TlsDropDatagram13 : public TlsConnectDatagram13, + public ::testing::WithParamInterface<bool> { + public: + TlsDropDatagram13() + : client_filters_(), + server_filters_(), + expected_client_acks_(0), + expected_server_acks_(1) {} + + void SetUp() override { + TlsConnectDatagram13::SetUp(); + ConfigureSessionCache(RESUME_NONE, RESUME_NONE); + int short_header = GetParam() ? PR_TRUE : PR_FALSE; + client_->SetOption(SSL_ENABLE_DTLS_SHORT_HEADER, short_header); + server_->SetOption(SSL_ENABLE_DTLS_SHORT_HEADER, short_header); + SetFilters(); + } + + void SetFilters() { + EnsureTlsSetup(); + client_filters_.Init(client_); + server_filters_.Init(server_); + } + + void HandshakeAndAck(const std::shared_ptr<TlsAgent>& agent) { + agent->Handshake(); // Read flight. + ShiftDtlsTimers(); + agent->Handshake(); // Generate ACK. + } + + void ShrinkPostServerHelloMtu() { + // Abuse the custom extension mechanism to modify the MTU so that the + // Certificate message is split into two pieces. + ASSERT_EQ( + SECSuccess, + SSL_InstallExtensionHooks( + server_->ssl_fd(), 1, + [](PRFileDesc* fd, SSLHandshakeType message, PRUint8* data, + unsigned int* len, unsigned int maxLen, void* arg) -> PRBool { + SSLInt_SetMTU(fd, 500); // Splits the certificate. + return PR_FALSE; + }, + nullptr, + [](PRFileDesc* fd, SSLHandshakeType message, const PRUint8* data, + unsigned int len, SSLAlertDescription* alert, + void* arg) -> SECStatus { return SECSuccess; }, + nullptr)); + } + + protected: + class DropAckChain { + public: + DropAckChain() + : records_(nullptr), ack_(nullptr), drop_(nullptr), chain_(nullptr) {} + + void Init(const std::shared_ptr<TlsAgent>& agent) { + records_ = std::make_shared<TlsRecordRecorder>(agent); + ack_ = std::make_shared<TlsRecordRecorder>(agent, ssl_ct_ack); + ack_->EnableDecryption(); + drop_ = std::make_shared<SelectiveRecordDropFilter>(agent, 0, false); + chain_ = std::make_shared<ChainedPacketFilter>( + ChainedPacketFilterInit({records_, ack_, drop_})); + agent->SetFilter(chain_); + } + + const TlsRecord& record(size_t i) const { return records_->record(i); } + + std::shared_ptr<TlsRecordRecorder> records_; + std::shared_ptr<TlsRecordRecorder> ack_; + std::shared_ptr<SelectiveRecordDropFilter> drop_; + std::shared_ptr<PacketFilter> chain_; + }; + + void CheckedHandshakeSendReceive() { + Handshake(); + CheckPostHandshake(); + } + + void CheckPostHandshake() { + CheckConnected(); + SendReceive(); + EXPECT_EQ(expected_client_acks_, client_filters_.ack_->count()); + EXPECT_EQ(expected_server_acks_, server_filters_.ack_->count()); + } + + protected: + DropAckChain client_filters_; + DropAckChain server_filters_; + size_t expected_client_acks_; + size_t expected_server_acks_; +}; + +// All of these tests produce a minimum one ACK, from the server +// to the client upon receiving the client Finished. +// Dropping complete first and second flights does not produce +// ACKs +TEST_P(TlsDropDatagram13, DropClientFirstFlightOnce) { + client_filters_.drop_->Reset({0}); + StartConnect(); + client_->Handshake(); + server_->Handshake(); + CheckedHandshakeSendReceive(); + CheckAcks(server_filters_.ack_, 0, {0x0002000000000000ULL}); +} + +TEST_P(TlsDropDatagram13, DropServerFirstFlightOnce) { + server_filters_.drop_->Reset(0xff); + StartConnect(); + client_->Handshake(); + // Send the first flight, all dropped. + server_->Handshake(); + server_filters_.drop_->Disable(); + CheckedHandshakeSendReceive(); + CheckAcks(server_filters_.ack_, 0, {0x0002000000000000ULL}); +} + +// Dropping the server's first record also does not produce +// an ACK because the next record is ignored. +// TODO(ekr@rtfm.com): We should generate an empty ACK. +TEST_P(TlsDropDatagram13, DropServerFirstRecordOnce) { + server_filters_.drop_->Reset({0}); + StartConnect(); + client_->Handshake(); + server_->Handshake(); + Handshake(); + CheckedHandshakeSendReceive(); + CheckAcks(server_filters_.ack_, 0, {0x0002000000000000ULL}); +} + +// Dropping the second packet of the server's flight should +// produce an ACK. +TEST_P(TlsDropDatagram13, DropServerSecondRecordOnce) { + server_filters_.drop_->Reset({1}); + StartConnect(); + client_->Handshake(); + server_->Handshake(); + HandshakeAndAck(client_); + expected_client_acks_ = 1; + CheckedHandshakeSendReceive(); + CheckAcks(client_filters_.ack_, 0, {0}); // ServerHello + CheckAcks(server_filters_.ack_, 0, {0x0002000000000000ULL}); +} + +// Drop the server ACK and verify that the client retransmits +// the ClientHello. +TEST_P(TlsDropDatagram13, DropServerAckOnce) { + StartConnect(); + client_->Handshake(); + server_->Handshake(); + // At this point the server has sent it's first flight, + // so make it drop the ACK. + server_filters_.drop_->Reset({0}); + client_->Handshake(); // Send the client Finished. + server_->Handshake(); // Receive the Finished and send the ACK. + EXPECT_EQ(TlsAgent::STATE_CONNECTED, client_->state()); + EXPECT_EQ(TlsAgent::STATE_CONNECTED, server_->state()); + // Wait for the DTLS timeout to make sure we retransmit the + // Finished. + ShiftDtlsTimers(); + client_->Handshake(); // Retransmit the Finished. + server_->Handshake(); // Read the Finished and send an ACK. + uint8_t buf[1]; + PRInt32 rv = PR_Read(client_->ssl_fd(), buf, sizeof(buf)); + expected_server_acks_ = 2; + EXPECT_GT(0, rv); + EXPECT_EQ(PR_WOULD_BLOCK_ERROR, PORT_GetError()); + CheckPostHandshake(); + // There should be two copies of the finished ACK + CheckAcks(server_filters_.ack_, 0, {0x0002000000000000ULL}); + CheckAcks(server_filters_.ack_, 1, {0x0002000000000000ULL}); +} + +// Drop the client certificate verify. +TEST_P(TlsDropDatagram13, DropClientCertVerify) { + StartConnect(); + client_->SetupClientAuth(); + server_->RequestClientAuth(true); + client_->Handshake(); + server_->Handshake(); + // Have the client drop Cert Verify + client_filters_.drop_->Reset({1}); + expected_server_acks_ = 2; + CheckedHandshakeSendReceive(); + // Ack of the Cert. + CheckAcks(server_filters_.ack_, 0, {0x0002000000000000ULL}); + // Ack of the whole client handshake. + CheckAcks( + server_filters_.ack_, 1, + {0x0002000000000000ULL, // CH (we drop everything after this on client) + 0x0002000000000003ULL, // CT (2) + 0x0002000000000004ULL}); // FIN (2) +} + +// Shrink the MTU down so that certs get split and drop the first piece. +TEST_P(TlsDropDatagram13, DropFirstHalfOfServerCertificate) { + server_filters_.drop_->Reset({2}); + StartConnect(); + ShrinkPostServerHelloMtu(); + client_->Handshake(); + server_->Handshake(); + // Check that things got split. + EXPECT_EQ(6UL, + server_filters_.records_->count()); // SH, EE, CT1, CT2, CV, FIN + size_t ct1_size = server_filters_.record(2).buffer.len(); + server_filters_.records_->Clear(); + expected_client_acks_ = 1; + HandshakeAndAck(client_); + server_->Handshake(); // Retransmit + EXPECT_EQ(3UL, server_filters_.records_->count()); // CT2, CV, FIN + // Check that the first record is CT1 (which is identical to the same + // as the previous CT1). + EXPECT_EQ(ct1_size, server_filters_.record(0).buffer.len()); + CheckedHandshakeSendReceive(); + CheckAcks(client_filters_.ack_, 0, + {0, // SH + 0x0002000000000000ULL, // EE + 0x0002000000000002ULL}); // CT2 + CheckAcks(server_filters_.ack_, 0, {0x0002000000000000ULL}); +} + +// Shrink the MTU down so that certs get split and drop the second piece. +TEST_P(TlsDropDatagram13, DropSecondHalfOfServerCertificate) { + server_filters_.drop_->Reset({3}); + StartConnect(); + ShrinkPostServerHelloMtu(); + client_->Handshake(); + server_->Handshake(); + // Check that things got split. + EXPECT_EQ(6UL, + server_filters_.records_->count()); // SH, EE, CT1, CT2, CV, FIN + size_t ct1_size = server_filters_.record(3).buffer.len(); + server_filters_.records_->Clear(); + expected_client_acks_ = 1; + HandshakeAndAck(client_); + server_->Handshake(); // Retransmit + EXPECT_EQ(3UL, server_filters_.records_->count()); // CT1, CV, FIN + // Check that the first record is CT1 + EXPECT_EQ(ct1_size, server_filters_.record(0).buffer.len()); + CheckedHandshakeSendReceive(); + CheckAcks(client_filters_.ack_, 0, + { + 0, // SH + 0x0002000000000000ULL, // EE + 0x0002000000000001ULL, // CT1 + }); + CheckAcks(server_filters_.ack_, 0, {0x0002000000000000ULL}); +} + +// In this test, the Certificate message is sent four times, we drop all or part +// of the first three attempts: +// 1. Without fragmentation so that we can see how big it is - we drop that. +// 2. In two pieces - we drop half AND the resulting ACK. +// 3. In three pieces - we drop the middle piece. +// +// After that we let all the ACKs through and allow the handshake to complete +// without further interference. +// +// This allows us to test that ranges of handshake messages are sent correctly +// even when there are overlapping acknowledgments; that ACKs with duplicate or +// overlapping message ranges are handled properly; and that extra +// retransmissions are handled properly. +class TlsFragmentationAndRecoveryTest : public TlsDropDatagram13 { + public: + TlsFragmentationAndRecoveryTest() : cert_len_(0) {} + + protected: + void RunTest(size_t dropped_half) { + FirstFlightDropCertificate(); + + SecondAttemptDropHalf(dropped_half); + size_t dropped_half_size = server_record_len(dropped_half); + size_t second_flight_count = server_filters_.records_->count(); + + ThirdAttemptDropMiddle(); + size_t repaired_third_size = server_record_len((dropped_half == 0) ? 0 : 2); + size_t third_flight_count = server_filters_.records_->count(); + + AckAndCompleteRetransmission(); + size_t final_server_flight_count = server_filters_.records_->count(); + EXPECT_LE(3U, final_server_flight_count); // CT(sixth), CV, Fin + CheckSizeOfSixth(dropped_half_size, repaired_third_size); + + SendDelayedAck(); + // Same number of messages as the last flight. + EXPECT_EQ(final_server_flight_count, server_filters_.records_->count()); + // Double check that the Certificate size is still correct. + CheckSizeOfSixth(dropped_half_size, repaired_third_size); + + CompleteHandshake(final_server_flight_count); + + // This is the ACK for the first attempt to send a whole certificate. + std::vector<uint64_t> client_acks = { + 0, // SH + 0x0002000000000000ULL // EE + }; + CheckAcks(client_filters_.ack_, 0, client_acks); + // And from the second attempt for the half was kept (we delayed this ACK). + client_acks.push_back(0x0002000000000000ULL + second_flight_count + + ~dropped_half % 2); + CheckAcks(client_filters_.ack_, 1, client_acks); + // And the third attempt where the first and last thirds got through. + client_acks.push_back(0x0002000000000000ULL + second_flight_count + + third_flight_count - 1); + client_acks.push_back(0x0002000000000000ULL + second_flight_count + + third_flight_count + 1); + CheckAcks(client_filters_.ack_, 2, client_acks); + CheckAcks(server_filters_.ack_, 0, {0x0002000000000000ULL}); + } + + private: + void FirstFlightDropCertificate() { + StartConnect(); + client_->Handshake(); + + // Note: 1 << N is the Nth packet, starting from zero. + server_filters_.drop_->Reset(1 << 2); // Drop Cert0. + server_->Handshake(); + EXPECT_EQ(5U, server_filters_.records_->count()); // SH, EE, CT, CV, Fin + cert_len_ = server_filters_.records_->record(2).buffer.len(); + + HandshakeAndAck(client_); + EXPECT_EQ(2U, client_filters_.records_->count()); + } + + // Lower the MTU so that the server has to split the certificate in two + // pieces. The server resends Certificate (in two), plus CV and Fin. + void SecondAttemptDropHalf(size_t dropped_half) { + ASSERT_LE(0U, dropped_half); + ASSERT_GT(2U, dropped_half); + server_filters_.records_->Clear(); + server_filters_.drop_->Reset({dropped_half}); // Drop Cert1[half] + SplitServerMtu(2); + server_->Handshake(); + EXPECT_LE(4U, server_filters_.records_->count()); // CT x2, CV, Fin + + // Generate and capture the ACK from the client. + client_filters_.drop_->Reset({0}); + HandshakeAndAck(client_); + EXPECT_EQ(3U, client_filters_.records_->count()); + } + + // Lower the MTU again so that the server sends Certificate cut into three + // pieces. Drop the middle piece. + void ThirdAttemptDropMiddle() { + server_filters_.records_->Clear(); + server_filters_.drop_->Reset({1}); // Drop Cert2[1] (of 3) + SplitServerMtu(3); + // Because we dropped the client ACK, the server retransmits on a timer. + ShiftDtlsTimers(); + server_->Handshake(); + EXPECT_LE(5U, server_filters_.records_->count()); // CT x3, CV, Fin + } + + void AckAndCompleteRetransmission() { + // Generate ACKs. + HandshakeAndAck(client_); + // The server should send the final sixth of the certificate: the client has + // acknowledged the first half and the last third. Also send CV and Fin. + server_filters_.records_->Clear(); + server_->Handshake(); + } + + void CheckSizeOfSixth(size_t size_of_half, size_t size_of_third) { + // Work out if the final sixth is the right size. We get the records with + // overheads added, which obscures the length of the payload. We want to + // ensure that the server only sent the missing sixth of the Certificate. + // + // We captured |size_of_half + overhead| and |size_of_third + overhead| and + // want to calculate |size_of_third - size_of_third + overhead|. We can't + // calculate |overhead|, but it is is (currently) always a handshake message + // header, a content type, and an authentication tag: + static const size_t record_overhead = 12 + 1 + 16; + EXPECT_EQ(size_of_half - size_of_third + record_overhead, + server_filters_.records_->record(0).buffer.len()); + } + + void SendDelayedAck() { + // Send the ACK we held back. The reordered ACK doesn't add new + // information, + // but triggers an extra retransmission of the missing records again (even + // though the client has all that it needs). + client_->SendRecordDirect(client_filters_.records_->record(2)); + server_filters_.records_->Clear(); + server_->Handshake(); + } + + void CompleteHandshake(size_t extra_retransmissions) { + // All this messing around shouldn't cause a failure... + Handshake(); + // ...but it leaves a mess. Add an extra few calls to Handshake() for the + // client so that it absorbs the extra retransmissions. + for (size_t i = 0; i < extra_retransmissions; ++i) { + client_->Handshake(); + } + CheckConnected(); + } + + // Split the server MTU so that the Certificate is split into |count| pieces. + // The calculation doesn't need to be perfect as long as the Certificate + // message is split into the right number of pieces. + void SplitServerMtu(size_t count) { + // Set the MTU based on the formula: + // bare_size = cert_len_ - actual_overhead + // MTU = ceil(bare_size / count) + pessimistic_overhead + // + // actual_overhead is the amount of actual overhead on the record we + // captured, which is (note that our length doesn't include the header): + static const size_t actual_overhead = 12 + // handshake message header + 1 + // content type + 16; // authentication tag + size_t bare_size = cert_len_ - actual_overhead; + + // pessimistic_overhead is the amount of expansion that NSS assumes will be + // added to each handshake record. Right now, that is DTLS_MIN_FRAGMENT: + static const size_t pessimistic_overhead = + 12 + // handshake message header + 1 + // content type + 13 + // record header length + 64; // maximum record expansion: IV, MAC and block cipher expansion + + size_t mtu = (bare_size + count - 1) / count + pessimistic_overhead; + if (g_ssl_gtest_verbose) { + std::cerr << "server: set MTU to " << mtu << std::endl; + } + EXPECT_EQ(SECSuccess, SSLInt_SetMTU(server_->ssl_fd(), mtu)); + } + + size_t server_record_len(size_t index) const { + return server_filters_.records_->record(index).buffer.len(); + } + + size_t cert_len_; +}; + +TEST_P(TlsFragmentationAndRecoveryTest, DropFirstHalf) { RunTest(0); } + +TEST_P(TlsFragmentationAndRecoveryTest, DropSecondHalf) { RunTest(1); } + +TEST_P(TlsDropDatagram13, NoDropsDuringZeroRtt) { + SetupForZeroRtt(); + SetFilters(); + std::cerr << "Starting second handshake" << std::endl; + client_->Set0RttEnabled(true); + server_->Set0RttEnabled(true); + ExpectResumption(RESUME_TICKET); + ZeroRttSendReceive(true, true); + Handshake(); + ExpectEarlyDataAccepted(true); + CheckConnected(); + SendReceive(); + EXPECT_EQ(0U, client_filters_.ack_->count()); + CheckAcks(server_filters_.ack_, 0, + {0x0001000000000001ULL, // EOED + 0x0002000000000000ULL}); // Finished +} + +TEST_P(TlsDropDatagram13, DropEEDuringZeroRtt) { + SetupForZeroRtt(); + SetFilters(); + std::cerr << "Starting second handshake" << std::endl; + client_->Set0RttEnabled(true); + server_->Set0RttEnabled(true); + ExpectResumption(RESUME_TICKET); + server_filters_.drop_->Reset({1}); + ZeroRttSendReceive(true, true); + HandshakeAndAck(client_); + Handshake(); + ExpectEarlyDataAccepted(true); + CheckConnected(); + SendReceive(); + CheckAcks(client_filters_.ack_, 0, {0}); + CheckAcks(server_filters_.ack_, 0, + {0x0001000000000002ULL, // EOED + 0x0002000000000000ULL}); // Finished +} + +class TlsReorderDatagram13 : public TlsDropDatagram13 { + public: + TlsReorderDatagram13() {} + + // Send records from the records buffer in the given order. + void ReSend(TlsAgent::Role side, std::vector<size_t> indices) { + std::shared_ptr<TlsAgent> agent; + std::shared_ptr<TlsRecordRecorder> records; + + if (side == TlsAgent::CLIENT) { + agent = client_; + records = client_filters_.records_; + } else { + agent = server_; + records = server_filters_.records_; + } + + for (auto i : indices) { + agent->SendRecordDirect(records->record(i)); + } + } +}; + +// Reorder the server records so that EE comes at the end +// of the flight and will still produce an ACK. +TEST_P(TlsDropDatagram13, ReorderServerEE) { + server_filters_.drop_->Reset({1}); + StartConnect(); + client_->Handshake(); + server_->Handshake(); + // We dropped EE, now reinject. + server_->SendRecordDirect(server_filters_.record(1)); + expected_client_acks_ = 1; + HandshakeAndAck(client_); + CheckedHandshakeSendReceive(); + CheckAcks(client_filters_.ack_, 0, + { + 0, // SH + 0x0002000000000000, // EE + }); + CheckAcks(server_filters_.ack_, 0, {0x0002000000000000ULL}); +} + +// The client sends an out of order non-handshake message +// but with the handshake key. +TEST_F(TlsConnectDatagram13, SendOutOfOrderAppWithHandshakeKey) { + StartConnect(); + // Capturing secrets means that we can't use decrypting filters on the client. + TlsSendCipherSpecCapturer capturer(client_); + client_->Handshake(); + server_->Handshake(); + client_->Handshake(); + EXPECT_EQ(TlsAgent::STATE_CONNECTED, client_->state()); + server_->Handshake(); + EXPECT_EQ(TlsAgent::STATE_CONNECTED, server_->state()); + // After the client sends Finished, inject an app data record + // with the handshake key. This should produce an alert. + uint8_t buf[] = {'a', 'b', 'c'}; + auto spec = capturer.spec(0); + ASSERT_NE(nullptr, spec.get()); + ASSERT_EQ(2, spec->epoch()); + + uint8_t dtls13_ct = kCtDtlsCiphertext | kCtDtlsCiphertext16bSeqno | + kCtDtlsCiphertextLengthPresent; + ASSERT_TRUE(client_->SendEncryptedRecord(spec, 0x0002000000000002, dtls13_ct, + DataBuffer(buf, sizeof(buf)))); + + // Now have the server consume the bogus message. + server_->ExpectSendAlert(illegal_parameter, kTlsAlertFatal); + server_->Handshake(); + EXPECT_EQ(TlsAgent::STATE_ERROR, server_->state()); + EXPECT_EQ(SSL_ERROR_RX_UNKNOWN_RECORD_TYPE, PORT_GetError()); +} + +TEST_F(TlsConnectDatagram13, SendOutOfOrderHsNonsenseWithHandshakeKey) { + StartConnect(); + TlsSendCipherSpecCapturer capturer(client_); + auto acks = MakeTlsFilter<TlsRecordRecorder>(server_, ssl_ct_ack); + acks->EnableDecryption(); + + client_->Handshake(); + server_->Handshake(); + client_->Handshake(); + EXPECT_EQ(TlsAgent::STATE_CONNECTED, client_->state()); + server_->Handshake(); + EXPECT_EQ(TlsAgent::STATE_CONNECTED, server_->state()); + // Inject a new bogus handshake record, which the server responds + // to by just ACKing the original one (we ignore the contents). + uint8_t buf[] = {'a', 'b', 'c'}; + auto spec = capturer.spec(0); + ASSERT_NE(nullptr, spec.get()); + ASSERT_EQ(2, spec->epoch()); + ASSERT_TRUE(client_->SendEncryptedRecord(spec, 0x0002000000000002, + ssl_ct_handshake, + DataBuffer(buf, sizeof(buf)))); + server_->Handshake(); + EXPECT_EQ(2UL, acks->count()); + // The server acknowledges client Finished twice. + CheckAcks(acks, 0, {0x0002000000000000ULL}); + CheckAcks(acks, 1, {0x0002000000000000ULL}); +} + +// Shrink the MTU down so that certs get split and then swap the first and +// second pieces of the server certificate. +TEST_P(TlsReorderDatagram13, ReorderServerCertificate) { + StartConnect(); + ShrinkPostServerHelloMtu(); + client_->Handshake(); + // Drop the entire handshake flight so we can reorder. + server_filters_.drop_->Reset(0xff); + server_->Handshake(); + // Check that things got split. + EXPECT_EQ(6UL, + server_filters_.records_->count()); // CH, EE, CT1, CT2, CV, FIN + // Now re-send things in a different order. + ReSend(TlsAgent::SERVER, std::vector<size_t>{0, 1, 3, 2, 4, 5}); + // Clear. + server_filters_.drop_->Disable(); + server_filters_.records_->Clear(); + // Wait for client to send ACK. + ShiftDtlsTimers(); + CheckedHandshakeSendReceive(); + EXPECT_EQ(2UL, server_filters_.records_->count()); // ACK + Data + CheckAcks(server_filters_.ack_, 0, {0x0002000000000000ULL}); +} + +TEST_P(TlsReorderDatagram13, DataAfterEOEDDuringZeroRtt) { + SetupForZeroRtt(); + SetFilters(); + std::cerr << "Starting second handshake" << std::endl; + client_->Set0RttEnabled(true); + server_->Set0RttEnabled(true); + ExpectResumption(RESUME_TICKET); + // Send the client's first flight of zero RTT data. + ZeroRttSendReceive(true, true); + // Now send another client application data record but + // capture it. + client_filters_.records_->Clear(); + client_filters_.drop_->Reset(0xff); + const char* k0RttData = "123456"; + const PRInt32 k0RttDataLen = static_cast<PRInt32>(strlen(k0RttData)); + PRInt32 rv = + PR_Write(client_->ssl_fd(), k0RttData, k0RttDataLen); // 0-RTT write. + EXPECT_EQ(k0RttDataLen, rv); + EXPECT_EQ(1UL, client_filters_.records_->count()); // data + server_->Handshake(); + client_->Handshake(); + ExpectEarlyDataAccepted(true); + // The server still hasn't received anything at this point. + EXPECT_EQ(3UL, client_filters_.records_->count()); // data, EOED, FIN + EXPECT_EQ(TlsAgent::STATE_CONNECTED, client_->state()); + EXPECT_EQ(TlsAgent::STATE_CONNECTING, server_->state()); + // Now re-send the client's messages: EOED, data, FIN + ReSend(TlsAgent::CLIENT, std::vector<size_t>({1, 0, 2})); + server_->Handshake(); + CheckConnected(); + EXPECT_EQ(0U, client_filters_.ack_->count()); + // Acknowledgements for EOED and Finished. + CheckAcks(server_filters_.ack_, 0, + {0x0001000000000002ULL, 0x0002000000000000ULL}); + uint8_t buf[8]; + rv = PR_Read(server_->ssl_fd(), buf, sizeof(buf)); + EXPECT_EQ(-1, rv); + EXPECT_EQ(PR_WOULD_BLOCK_ERROR, PORT_GetError()); +} + +TEST_P(TlsReorderDatagram13, DataAfterFinDuringZeroRtt) { + SetupForZeroRtt(); + SetFilters(); + std::cerr << "Starting second handshake" << std::endl; + client_->Set0RttEnabled(true); + server_->Set0RttEnabled(true); + ExpectResumption(RESUME_TICKET); + // Send the client's first flight of zero RTT data. + ZeroRttSendReceive(true, true); + // Now send another client application data record but + // capture it. + client_filters_.records_->Clear(); + client_filters_.drop_->Reset(0xff); + const char* k0RttData = "123456"; + const PRInt32 k0RttDataLen = static_cast<PRInt32>(strlen(k0RttData)); + PRInt32 rv = + PR_Write(client_->ssl_fd(), k0RttData, k0RttDataLen); // 0-RTT write. + EXPECT_EQ(k0RttDataLen, rv); + EXPECT_EQ(1UL, client_filters_.records_->count()); // data + server_->Handshake(); + client_->Handshake(); + ExpectEarlyDataAccepted(true); + // The server still hasn't received anything at this point. + EXPECT_EQ(3UL, client_filters_.records_->count()); // EOED, FIN, Data + EXPECT_EQ(TlsAgent::STATE_CONNECTED, client_->state()); + EXPECT_EQ(TlsAgent::STATE_CONNECTING, server_->state()); + // Now re-send the client's messages: EOED, FIN, Data + ReSend(TlsAgent::CLIENT, std::vector<size_t>({1, 2, 0})); + server_->Handshake(); + CheckConnected(); + EXPECT_EQ(0U, client_filters_.ack_->count()); + // Acknowledgements for EOED and Finished. + CheckAcks(server_filters_.ack_, 0, + {0x0001000000000002ULL, 0x0002000000000000ULL}); + uint8_t buf[8]; + rv = PR_Read(server_->ssl_fd(), buf, sizeof(buf)); + EXPECT_EQ(-1, rv); + EXPECT_EQ(PR_WOULD_BLOCK_ERROR, PORT_GetError()); +} + +static void GetCipherAndLimit(uint16_t version, uint16_t* cipher, + uint64_t* limit = nullptr) { + uint64_t l; + if (!limit) limit = &l; + + if (version < SSL_LIBRARY_VERSION_TLS_1_2) { + *cipher = TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA; + *limit = 0x5aULL << 28; + } else if (version == SSL_LIBRARY_VERSION_TLS_1_2) { + *cipher = TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256; + *limit = (1ULL << 48) - 1; + } else { + // This test probably isn't especially useful for TLS 1.3, which has a much + // shorter sequence number encoding. That space can probably be searched in + // a reasonable amount of time. + *cipher = TLS_CHACHA20_POLY1305_SHA256; + // Assume that we are starting with an expected sequence number of 0. + *limit = (1ULL << 15) - 1; + } +} + +// This simulates a huge number of drops on one side. +// See Bug 12965514 where a large gap was handled very inefficiently. +TEST_P(TlsConnectDatagram, MissLotsOfPackets) { + uint16_t cipher; + uint64_t limit; + + GetCipherAndLimit(version_, &cipher, &limit); + + EnsureTlsSetup(); + server_->EnableSingleCipher(cipher); + Connect(); + + // Note that the limit for ChaCha is 2^48-1. + EXPECT_EQ(SECSuccess, + SSLInt_AdvanceWriteSeqNum(client_->ssl_fd(), limit - 10)); + SendReceive(); +} + +// Send a sequence number of 0xfffd and it should be interpreted as that +// (and not -3 or UINT64_MAX - 2). +TEST_F(TlsConnectDatagram13, UnderflowSequenceNumber) { + Connect(); + // This is only valid if short headers are disabled. + client_->SetOption(SSL_ENABLE_DTLS_SHORT_HEADER, PR_FALSE); + EXPECT_EQ(SECSuccess, + SSLInt_AdvanceWriteSeqNum(client_->ssl_fd(), (1ULL << 16) - 3)); + SendReceive(); +} + +class TlsConnectDatagram12Plus : public TlsConnectDatagram { + public: + TlsConnectDatagram12Plus() : TlsConnectDatagram() {} +}; + +// This simulates missing a window's worth of packets. +TEST_P(TlsConnectDatagram12Plus, MissAWindow) { + EnsureTlsSetup(); + uint16_t cipher; + GetCipherAndLimit(version_, &cipher); + server_->EnableSingleCipher(cipher); + Connect(); + EXPECT_EQ(SECSuccess, SSLInt_AdvanceWriteSeqByAWindow(client_->ssl_fd(), 0)); + SendReceive(); +} + +TEST_P(TlsConnectDatagram12Plus, MissAWindowAndOne) { + EnsureTlsSetup(); + uint16_t cipher; + GetCipherAndLimit(version_, &cipher); + server_->EnableSingleCipher(cipher); + Connect(); + + EXPECT_EQ(SECSuccess, SSLInt_AdvanceWriteSeqByAWindow(client_->ssl_fd(), 1)); + SendReceive(); +} + +// This filter replaces the first record it sees with junk application data. +class TlsReplaceFirstRecordWithJunk : public TlsRecordFilter { + public: + TlsReplaceFirstRecordWithJunk(const std::shared_ptr<TlsAgent>& a) + : TlsRecordFilter(a), replaced_(false) {} + + protected: + PacketFilter::Action FilterRecord(const TlsRecordHeader& header, + const DataBuffer& record, size_t* offset, + DataBuffer* output) override { + if (replaced_) { + return KEEP; + } + replaced_ = true; + + uint8_t dtls13_ct = kCtDtlsCiphertext | kCtDtlsCiphertext16bSeqno | + kCtDtlsCiphertextLengthPresent; + TlsRecordHeader out_header( + header.variant(), header.version(), + is_dtls13() ? dtls13_ct : ssl_ct_application_data, + header.sequence_number()); + + static const uint8_t junk[] = {1, 2, 3, 4}; + *offset = out_header.Write(output, *offset, DataBuffer(junk, sizeof(junk))); + return CHANGE; + } + + private: + bool replaced_; +}; + +// DTLS needs to discard application_data that it receives prior to handshake +// completion, not generate an error. +TEST_P(TlsConnectDatagram, ReplaceFirstServerRecordWithApplicationData) { + MakeTlsFilter<TlsReplaceFirstRecordWithJunk>(server_); + Connect(); +} + +TEST_P(TlsConnectDatagram, ReplaceFirstClientRecordWithApplicationData) { + MakeTlsFilter<TlsReplaceFirstRecordWithJunk>(client_); + Connect(); +} + +INSTANTIATE_TEST_SUITE_P(Datagram12Plus, TlsConnectDatagram12Plus, + TlsConnectTestBase::kTlsV12Plus); +INSTANTIATE_TEST_SUITE_P(DatagramPre13, TlsConnectDatagramPre13, + TlsConnectTestBase::kTlsV11V12); +INSTANTIATE_TEST_SUITE_P(DatagramDrop13, TlsDropDatagram13, + ::testing::Values(true, false)); +INSTANTIATE_TEST_SUITE_P(DatagramReorder13, TlsReorderDatagram13, + ::testing::Values(true, false)); +INSTANTIATE_TEST_SUITE_P(DatagramFragment13, TlsFragmentationAndRecoveryTest, + ::testing::Values(true, false)); + +} // namespace nss_test diff --git a/security/nss/gtests/ssl_gtest/ssl_ecdh_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_ecdh_unittest.cc new file mode 100644 index 0000000000..89bd0c679a --- /dev/null +++ b/security/nss/gtests/ssl_gtest/ssl_ecdh_unittest.cc @@ -0,0 +1,728 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set ts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this file, + * You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#include <functional> +#include <memory> +#include "secerr.h" +#include "ssl.h" +#include "sslerr.h" +#include "sslproto.h" + +extern "C" { +// This is not something that should make you happy. +#include "libssl_internals.h" +} + +#include "gtest_utils.h" +#include "nss_scoped_ptrs.h" +#include "tls_connect.h" +#include "tls_filter.h" +#include "tls_parser.h" + +namespace nss_test { + +TEST_P(TlsConnectGenericPre13, ConnectEcdh) { + SetExpectedVersion(std::get<1>(GetParam())); + Reset(TlsAgent::kServerEcdhEcdsa); + DisableAllCiphers(); + EnableSomeEcdhCiphers(); + + Connect(); + CheckKeys(ssl_kea_ecdh, ssl_grp_ec_secp256r1, ssl_auth_ecdh_ecdsa, + ssl_sig_none); +} + +TEST_P(TlsConnectGenericPre13, ConnectEcdhWithoutDisablingSuites) { + SetExpectedVersion(std::get<1>(GetParam())); + Reset(TlsAgent::kServerEcdhEcdsa); + EnableSomeEcdhCiphers(); + + Connect(); + CheckKeys(ssl_kea_ecdh, ssl_grp_ec_secp256r1, ssl_auth_ecdh_ecdsa, + ssl_sig_none); +} + +TEST_P(TlsConnectGeneric, ConnectEcdhe) { + Connect(); + CheckKeys(); +} + +// If we pick a 256-bit cipher suite and use a P-384 certificate, the server +// should choose P-384 for key exchange too. Only valid for TLS == 1.2 because +// we don't have 256-bit ciphers before then and 1.3 doesn't try to couple +// DHE size to symmetric size. +TEST_P(TlsConnectTls12, ConnectEcdheP384) { + Reset(TlsAgent::kServerEcdsa384); + ConnectWithCipherSuite(TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256); + CheckKeys(ssl_kea_ecdh, ssl_grp_ec_secp384r1, ssl_auth_ecdsa, + ssl_sig_ecdsa_secp256r1_sha256); +} + +TEST_P(TlsConnectGeneric, ConnectEcdheP384Client) { + EnsureTlsSetup(); + const std::vector<SSLNamedGroup> groups = {ssl_grp_ec_secp384r1, + ssl_grp_ffdhe_2048}; + client_->ConfigNamedGroups(groups); + server_->ConfigNamedGroups(groups); + Connect(); + CheckKeys(ssl_kea_ecdh, ssl_grp_ec_secp384r1, ssl_auth_rsa_sign, + ssl_sig_rsa_pss_rsae_sha256); +} + +// This causes a HelloRetryRequest in TLS 1.3. Earlier versions don't care. +TEST_P(TlsConnectGeneric, ConnectEcdheP384Server) { + EnsureTlsSetup(); + auto hrr_capture = MakeTlsFilter<TlsHandshakeRecorder>( + server_, kTlsHandshakeHelloRetryRequest); + const std::vector<SSLNamedGroup> groups = {ssl_grp_ec_secp384r1}; + server_->ConfigNamedGroups(groups); + Connect(); + CheckKeys(ssl_kea_ecdh, ssl_grp_ec_secp384r1, ssl_auth_rsa_sign, + ssl_sig_rsa_pss_rsae_sha256); + EXPECT_EQ(version_ == SSL_LIBRARY_VERSION_TLS_1_3, + hrr_capture->buffer().len() != 0); +} + +// This enables only P-256 on the client and disables it on the server. +// This test will fail when we add other groups that identify as ECDHE. +TEST_P(TlsConnectGeneric, ConnectEcdheGroupMismatch) { + EnsureTlsSetup(); + const std::vector<SSLNamedGroup> client_groups = {ssl_grp_ec_secp256r1, + ssl_grp_ffdhe_2048}; + const std::vector<SSLNamedGroup> server_groups = {ssl_grp_ffdhe_2048}; + client_->ConfigNamedGroups(client_groups); + server_->ConfigNamedGroups(server_groups); + + Connect(); + CheckKeys(ssl_kea_dh, ssl_auth_rsa_sign); +} + +TEST_P(TlsKeyExchangeTest, P384Priority) { + // P256, P384 and P521 are enabled. Both prefer P384. + const std::vector<SSLNamedGroup> groups = { + ssl_grp_ec_secp384r1, ssl_grp_ec_secp256r1, ssl_grp_ec_secp521r1}; + EnsureKeyShareSetup(); + ConfigNamedGroups(groups); + client_->DisableAllCiphers(); + client_->EnableCiphersByKeyExchange(ssl_kea_ecdh); + Connect(); + + CheckKeys(ssl_kea_ecdh, ssl_grp_ec_secp384r1, ssl_auth_rsa_sign, + ssl_sig_rsa_pss_rsae_sha256); + + std::vector<SSLNamedGroup> shares = {ssl_grp_ec_secp384r1}; + CheckKEXDetails(groups, shares); +} + +TEST_P(TlsKeyExchangeTest, DuplicateGroupConfig) { + const std::vector<SSLNamedGroup> groups = { + ssl_grp_ec_secp384r1, ssl_grp_ec_secp384r1, ssl_grp_ec_secp384r1, + ssl_grp_ec_secp256r1, ssl_grp_ec_secp256r1}; + EnsureKeyShareSetup(); + ConfigNamedGroups(groups); + client_->DisableAllCiphers(); + client_->EnableCiphersByKeyExchange(ssl_kea_ecdh); + Connect(); + + CheckKeys(ssl_kea_ecdh, ssl_grp_ec_secp384r1, ssl_auth_rsa_sign, + ssl_sig_rsa_pss_rsae_sha256); + + std::vector<SSLNamedGroup> shares = {ssl_grp_ec_secp384r1}; + std::vector<SSLNamedGroup> expectedGroups = {ssl_grp_ec_secp384r1, + ssl_grp_ec_secp256r1}; + CheckKEXDetails(expectedGroups, shares); +} + +TEST_P(TlsKeyExchangeTest, P384PriorityDHEnabled) { + // P256, P384, P521, and FFDHE2048 are enabled. Both prefer P384. + const std::vector<SSLNamedGroup> groups = { + ssl_grp_ec_secp384r1, ssl_grp_ffdhe_2048, ssl_grp_ec_secp256r1, + ssl_grp_ec_secp521r1}; + EnsureKeyShareSetup(); + ConfigNamedGroups(groups); + Connect(); + + CheckKeys(ssl_kea_ecdh, ssl_grp_ec_secp384r1, ssl_auth_rsa_sign, + ssl_sig_rsa_pss_rsae_sha256); + + if (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) { + std::vector<SSLNamedGroup> shares = {ssl_grp_ec_secp384r1}; + CheckKEXDetails(groups, shares); + } else { + std::vector<SSLNamedGroup> oldtlsgroups = { + ssl_grp_ec_secp384r1, ssl_grp_ec_secp256r1, ssl_grp_ec_secp521r1}; + CheckKEXDetails(oldtlsgroups, std::vector<SSLNamedGroup>()); + } +} + +TEST_P(TlsConnectGenericPre13, P384PriorityOnServer) { + EnsureTlsSetup(); + client_->DisableAllCiphers(); + client_->EnableCiphersByKeyExchange(ssl_kea_ecdh); + + // The server prefers P384. It has to win. + const std::vector<SSLNamedGroup> server_groups = { + ssl_grp_ec_secp384r1, ssl_grp_ec_secp256r1, ssl_grp_ec_secp521r1}; + server_->ConfigNamedGroups(server_groups); + + Connect(); + + CheckKeys(ssl_kea_ecdh, ssl_grp_ec_secp384r1, ssl_auth_rsa_sign, + ssl_sig_rsa_pss_rsae_sha256); +} + +TEST_P(TlsConnectGenericPre13, P384PriorityFromModelSocket) { + EnsureModelSockets(); + + /* Both prefer P384, set on the model socket. */ + const std::vector<SSLNamedGroup> groups = { + ssl_grp_ec_secp384r1, ssl_grp_ec_secp256r1, ssl_grp_ec_secp521r1, + ssl_grp_ffdhe_2048}; + client_model_->ConfigNamedGroups(groups); + server_model_->ConfigNamedGroups(groups); + + Connect(); + + CheckKeys(ssl_kea_ecdh, ssl_grp_ec_secp384r1, ssl_auth_rsa_sign, + ssl_sig_rsa_pss_rsae_sha256); +} + +class TlsKeyExchangeGroupCapture : public TlsHandshakeFilter { + public: + TlsKeyExchangeGroupCapture(const std::shared_ptr<TlsAgent> &a) + : TlsHandshakeFilter(a, {kTlsHandshakeServerKeyExchange}), + group_(ssl_grp_none) {} + + SSLNamedGroup group() const { return group_; } + + protected: + virtual PacketFilter::Action FilterHandshake(const HandshakeHeader &header, + const DataBuffer &input, + DataBuffer *output) { + uint32_t value = 0; + EXPECT_TRUE(input.Read(0, 1, &value)); + EXPECT_EQ(3U, value) << "curve type has to be 3"; + + EXPECT_TRUE(input.Read(1, 2, &value)); + group_ = static_cast<SSLNamedGroup>(value); + + return KEEP; + } + + private: + SSLNamedGroup group_; +}; + +// If we strip the client's supported groups extension, the server should assume +// P-256 is supported by the client (<= 1.2 only). +TEST_P(TlsConnectGenericPre13, DropSupportedGroupExtensionP256) { + EnsureTlsSetup(); + MakeTlsFilter<TlsExtensionDropper>(client_, ssl_supported_groups_xtn); + auto group_capture = MakeTlsFilter<TlsKeyExchangeGroupCapture>(server_); + + ConnectExpectAlert(server_, kTlsAlertDecryptError); + client_->CheckErrorCode(SSL_ERROR_DECRYPT_ERROR_ALERT); + server_->CheckErrorCode(SSL_ERROR_BAD_HANDSHAKE_HASH_VALUE); + + EXPECT_EQ(ssl_grp_ec_secp256r1, group_capture->group()); +} + +// Supported groups is mandatory in TLS 1.3. +TEST_P(TlsConnectTls13, DropSupportedGroupExtension) { + EnsureTlsSetup(); + MakeTlsFilter<TlsExtensionDropper>(client_, ssl_supported_groups_xtn); + ConnectExpectAlert(server_, kTlsAlertMissingExtension); + client_->CheckErrorCode(SSL_ERROR_MISSING_EXTENSION_ALERT); + server_->CheckErrorCode(SSL_ERROR_MISSING_SUPPORTED_GROUPS_EXTENSION); +} + +// If we only have a lame group, we fall back to static RSA. +TEST_P(TlsConnectGenericPre13, UseLameGroup) { + const std::vector<SSLNamedGroup> groups = {ssl_grp_ec_secp192r1}; + client_->ConfigNamedGroups(groups); + server_->ConfigNamedGroups(groups); + Connect(); + CheckKeys(ssl_kea_rsa, ssl_grp_none, ssl_auth_rsa_decrypt, ssl_sig_none); +} + +// In TLS 1.3, we can't generate the ClientHello. +TEST_P(TlsConnectTls13, UseLameGroup) { + const std::vector<SSLNamedGroup> groups = {ssl_grp_ec_sect283k1}; + client_->ConfigNamedGroups(groups); + server_->ConfigNamedGroups(groups); + client_->StartConnect(); + client_->Handshake(); + client_->CheckErrorCode(SSL_ERROR_NO_CIPHERS_SUPPORTED); +} + +TEST_P(TlsConnectStreamPre13, ConfiguredGroupsRenegotiate) { + EnsureTlsSetup(); + client_->DisableAllCiphers(); + client_->EnableCiphersByKeyExchange(ssl_kea_ecdh); + + const std::vector<SSLNamedGroup> client_groups = {ssl_grp_ec_secp256r1}; + const std::vector<SSLNamedGroup> server_groups = {ssl_grp_ec_secp256r1, + ssl_grp_ec_secp256r1}; + client_->ConfigNamedGroups(client_groups); + server_->ConfigNamedGroups(server_groups); + + Connect(); + + CheckKeys(ssl_kea_ecdh, ssl_grp_ec_secp256r1, ssl_auth_rsa_sign, + ssl_sig_rsa_pss_rsae_sha256); + CheckConnected(); + + // The renegotiation has to use the same preferences as the original session. + server_->PrepareForRenegotiate(); + client_->StartRenegotiate(); + Handshake(); + CheckKeys(ssl_kea_ecdh, ssl_grp_ec_secp256r1, ssl_auth_rsa_sign, + ssl_sig_rsa_pss_rsae_sha256); +} + +TEST_P(TlsKeyExchangeTest, Curve25519) { + Reset(TlsAgent::kServerEcdsa256); + const std::vector<SSLNamedGroup> groups = { + ssl_grp_ec_curve25519, ssl_grp_ec_secp256r1, ssl_grp_ec_secp521r1}; + EnsureKeyShareSetup(); + ConfigNamedGroups(groups); + Connect(); + + CheckKeys(ssl_kea_ecdh, ssl_grp_ec_curve25519, ssl_auth_ecdsa, + ssl_sig_ecdsa_secp256r1_sha256); + const std::vector<SSLNamedGroup> shares = {ssl_grp_ec_curve25519}; + CheckKEXDetails(groups, shares); +} + +TEST_P(TlsConnectGenericPre13, GroupPreferenceServerPriority) { + EnsureTlsSetup(); + client_->DisableAllCiphers(); + client_->EnableCiphersByKeyExchange(ssl_kea_ecdh); + + // The client prefers P256 while the server prefers 25519. + // The server's preference has to win. + const std::vector<SSLNamedGroup> client_groups = {ssl_grp_ec_secp256r1, + ssl_grp_ec_curve25519}; + const std::vector<SSLNamedGroup> server_groups = {ssl_grp_ec_curve25519, + ssl_grp_ec_secp256r1}; + client_->ConfigNamedGroups(client_groups); + server_->ConfigNamedGroups(server_groups); + + Connect(); + + CheckKeys(ssl_kea_ecdh, ssl_grp_ec_curve25519, ssl_auth_rsa_sign, + ssl_sig_rsa_pss_rsae_sha256); +} + +#ifndef NSS_DISABLE_TLS_1_3 +TEST_P(TlsKeyExchangeTest13, Curve25519P256EqualPriorityClient13) { + EnsureKeyShareSetup(); + + // The client sends a P256 key share while the server prefers 25519. + // We have to accept P256 without retry. + const std::vector<SSLNamedGroup> client_groups = {ssl_grp_ec_secp256r1, + ssl_grp_ec_curve25519}; + const std::vector<SSLNamedGroup> server_groups = {ssl_grp_ec_curve25519, + ssl_grp_ec_secp256r1}; + client_->ConfigNamedGroups(client_groups); + server_->ConfigNamedGroups(server_groups); + + Connect(); + + CheckKeys(ssl_kea_ecdh, ssl_grp_ec_secp256r1, ssl_auth_rsa_sign, + ssl_sig_rsa_pss_rsae_sha256); + const std::vector<SSLNamedGroup> shares = {ssl_grp_ec_secp256r1}; + CheckKEXDetails(client_groups, shares); +} + +TEST_P(TlsKeyExchangeTest13, Curve25519P256EqualPriorityServer13) { + EnsureKeyShareSetup(); + + // The client sends a 25519 key share while the server prefers P256. + // We have to accept 25519 without retry. + const std::vector<SSLNamedGroup> client_groups = {ssl_grp_ec_curve25519, + ssl_grp_ec_secp256r1}; + const std::vector<SSLNamedGroup> server_groups = {ssl_grp_ec_secp256r1, + ssl_grp_ec_curve25519}; + client_->ConfigNamedGroups(client_groups); + server_->ConfigNamedGroups(server_groups); + + Connect(); + + CheckKeys(ssl_kea_ecdh, ssl_grp_ec_curve25519, ssl_auth_rsa_sign, + ssl_sig_rsa_pss_rsae_sha256); + const std::vector<SSLNamedGroup> shares = {ssl_grp_ec_curve25519}; + CheckKEXDetails(client_groups, shares); +} + +TEST_P(TlsKeyExchangeTest13, EqualPriorityTestRetryECServer13) { + EnsureKeyShareSetup(); + + // The client sends a 25519 key share while the server prefers P256. + // The server prefers P-384 over x25519, so it must not consider P-256 and + // x25519 to be equivalent. It will therefore request a P-256 share + // with a HelloRetryRequest. + const std::vector<SSLNamedGroup> client_groups = { + ssl_grp_ec_curve25519, ssl_grp_ec_secp256r1, ssl_grp_ec_secp384r1}; + const std::vector<SSLNamedGroup> server_groups = { + ssl_grp_ec_secp256r1, ssl_grp_ec_secp384r1, ssl_grp_ec_curve25519}; + client_->ConfigNamedGroups(client_groups); + server_->ConfigNamedGroups(server_groups); + + Connect(); + + CheckKeys(ssl_kea_ecdh, ssl_grp_ec_secp256r1, ssl_auth_rsa_sign, + ssl_sig_rsa_pss_rsae_sha256); + const std::vector<SSLNamedGroup> shares = {ssl_grp_ec_curve25519}; + CheckKEXDetails(client_groups, shares, ssl_grp_ec_secp256r1); +} + +TEST_P(TlsKeyExchangeTest13, NotEqualPriorityWithIntermediateGroup13) { + EnsureKeyShareSetup(); + + // The client sends a 25519 key share while the server prefers P256. + // The server prefers ffdhe_2048 over x25519, so it must not consider the + // P-256 and x25519 to be equivalent. It will therefore request a P-256 share + // with a HelloRetryRequest. + const std::vector<SSLNamedGroup> client_groups = { + ssl_grp_ec_curve25519, ssl_grp_ec_secp256r1, ssl_grp_ffdhe_2048}; + const std::vector<SSLNamedGroup> server_groups = { + ssl_grp_ec_secp256r1, ssl_grp_ffdhe_2048, ssl_grp_ec_curve25519}; + client_->ConfigNamedGroups(client_groups); + server_->ConfigNamedGroups(server_groups); + + Connect(); + + CheckKeys(ssl_kea_ecdh, ssl_grp_ec_secp256r1, ssl_auth_rsa_sign, + ssl_sig_rsa_pss_rsae_sha256); + const std::vector<SSLNamedGroup> shares = {ssl_grp_ec_curve25519}; + CheckKEXDetails(client_groups, shares, ssl_grp_ec_secp256r1); +} + +TEST_P(TlsKeyExchangeTest13, + NotEqualPriorityWithUnsupportedFFIntermediateGroup13) { + EnsureKeyShareSetup(); + + // As in the previous test, the server prefers ffdhe_2048. Thus, even though + // the client doesn't support this group, the server must not regard x25519 as + // equivalent to P-256. + const std::vector<SSLNamedGroup> client_groups = {ssl_grp_ec_curve25519, + ssl_grp_ec_secp256r1}; + const std::vector<SSLNamedGroup> server_groups = { + ssl_grp_ec_secp256r1, ssl_grp_ffdhe_2048, ssl_grp_ec_curve25519}; + client_->ConfigNamedGroups(client_groups); + server_->ConfigNamedGroups(server_groups); + + Connect(); + + CheckKeys(ssl_kea_ecdh, ssl_grp_ec_secp256r1, ssl_auth_rsa_sign, + ssl_sig_rsa_pss_rsae_sha256); + const std::vector<SSLNamedGroup> shares = {ssl_grp_ec_curve25519}; + CheckKEXDetails(client_groups, shares, ssl_grp_ec_secp256r1); +} + +TEST_P(TlsKeyExchangeTest13, + NotEqualPriorityWithUnsupportedECIntermediateGroup13) { + EnsureKeyShareSetup(); + + // As in the previous test, the server prefers P-384. Thus, even though + // the client doesn't support this group, the server must not regard x25519 as + // equivalent to P-256. The server sends a HelloRetryRequest. + const std::vector<SSLNamedGroup> client_groups = {ssl_grp_ec_curve25519, + ssl_grp_ec_secp256r1}; + const std::vector<SSLNamedGroup> server_groups = { + ssl_grp_ec_secp256r1, ssl_grp_ec_secp384r1, ssl_grp_ec_curve25519}; + client_->ConfigNamedGroups(client_groups); + server_->ConfigNamedGroups(server_groups); + + Connect(); + + CheckKeys(ssl_kea_ecdh, ssl_grp_ec_secp256r1, ssl_auth_rsa_sign, + ssl_sig_rsa_pss_rsae_sha256); + const std::vector<SSLNamedGroup> shares = {ssl_grp_ec_curve25519}; + CheckKEXDetails(client_groups, shares, ssl_grp_ec_secp256r1); +} + +TEST_P(TlsKeyExchangeTest13, EqualPriority13) { + EnsureKeyShareSetup(); + + // The client sends a 25519 key share while the server prefers P256. + // We have to accept 25519 without retry because it's considered equivalent to + // P256 by the server. + const std::vector<SSLNamedGroup> client_groups = { + ssl_grp_ec_curve25519, ssl_grp_ffdhe_2048, ssl_grp_ec_secp256r1}; + const std::vector<SSLNamedGroup> server_groups = {ssl_grp_ec_secp256r1, + ssl_grp_ec_curve25519}; + client_->ConfigNamedGroups(client_groups); + server_->ConfigNamedGroups(server_groups); + + Connect(); + + CheckKeys(); + const std::vector<SSLNamedGroup> shares = {ssl_grp_ec_curve25519}; + CheckKEXDetails(client_groups, shares); +} +#endif + +TEST_P(TlsConnectGeneric, P256ClientAndCurve25519Server) { + EnsureTlsSetup(); + client_->DisableAllCiphers(); + client_->EnableCiphersByKeyExchange(ssl_kea_ecdh); + + // The client sends a P256 key share while the server prefers 25519. + const std::vector<SSLNamedGroup> client_groups = {ssl_grp_ec_secp256r1}; + const std::vector<SSLNamedGroup> server_groups = {ssl_grp_ec_curve25519}; + + client_->ConfigNamedGroups(client_groups); + server_->ConfigNamedGroups(server_groups); + + ConnectExpectAlert(server_, kTlsAlertHandshakeFailure); + client_->CheckErrorCode(SSL_ERROR_NO_CYPHER_OVERLAP); + server_->CheckErrorCode(SSL_ERROR_NO_CYPHER_OVERLAP); +} + +TEST_P(TlsKeyExchangeTest13, MultipleClientShares) { + EnsureKeyShareSetup(); + + // The client sends 25519 and P256 key shares. The server prefers P256, + // which must be chosen here. + const std::vector<SSLNamedGroup> client_groups = {ssl_grp_ec_curve25519, + ssl_grp_ec_secp256r1}; + const std::vector<SSLNamedGroup> server_groups = {ssl_grp_ec_secp256r1, + ssl_grp_ec_curve25519}; + client_->ConfigNamedGroups(client_groups); + server_->ConfigNamedGroups(server_groups); + + // Generate a key share on the client for both curves. + EXPECT_EQ(SECSuccess, SSL_SendAdditionalKeyShares(client_->ssl_fd(), 1)); + + Connect(); + + // The server would accept 25519 but its preferred group (P256) has to win. + CheckKeys(ssl_kea_ecdh, ssl_grp_ec_secp256r1, ssl_auth_rsa_sign, + ssl_sig_rsa_pss_rsae_sha256); + const std::vector<SSLNamedGroup> shares = {ssl_grp_ec_curve25519, + ssl_grp_ec_secp256r1}; + CheckKEXDetails(client_groups, shares); +} + +// Replace the point in the client key exchange message with an empty one +class ECCClientKEXFilter : public TlsHandshakeFilter { + public: + ECCClientKEXFilter(const std::shared_ptr<TlsAgent> &client) + : TlsHandshakeFilter(client, {kTlsHandshakeClientKeyExchange}) {} + + protected: + virtual PacketFilter::Action FilterHandshake(const HandshakeHeader &header, + const DataBuffer &input, + DataBuffer *output) { + // Replace the client key exchange message with an empty point + output->Allocate(1); + output->Write(0, 0U, 1); // set point length 0 + return CHANGE; + } +}; + +// Replace the point in the server key exchange message with an empty one +class ECCServerKEXFilter : public TlsHandshakeFilter { + public: + ECCServerKEXFilter(const std::shared_ptr<TlsAgent> &server) + : TlsHandshakeFilter(server, {kTlsHandshakeServerKeyExchange}) {} + + protected: + virtual PacketFilter::Action FilterHandshake(const HandshakeHeader &header, + const DataBuffer &input, + DataBuffer *output) { + // Replace the server key exchange message with an empty point + output->Allocate(4); + output->Write(0, 3U, 1); // named curve + uint32_t curve = 0; + EXPECT_TRUE(input.Read(1, 2, &curve)); // get curve id + output->Write(1, curve, 2); // write curve id + output->Write(3, 0U, 1); // point length 0 + return CHANGE; + } +}; + +TEST_P(TlsConnectGenericPre13, ConnectECDHEmptyServerPoint) { + MakeTlsFilter<ECCServerKEXFilter>(server_); + ConnectExpectAlert(client_, kTlsAlertIllegalParameter); + client_->CheckErrorCode(SSL_ERROR_RX_MALFORMED_SERVER_KEY_EXCH); +} + +TEST_P(TlsConnectGenericPre13, ConnectECDHEmptyClientPoint) { + MakeTlsFilter<ECCClientKEXFilter>(client_); + ConnectExpectAlert(server_, kTlsAlertIllegalParameter); + server_->CheckErrorCode(SSL_ERROR_RX_MALFORMED_CLIENT_KEY_EXCH); +} + +// Damage ECParams/ECPoint of a SKE. +class ECCServerKEXDamager : public TlsHandshakeFilter { + public: + ECCServerKEXDamager(const std::shared_ptr<TlsAgent> &server, ECType ec_type, + SSLNamedGroup named_curve) + : TlsHandshakeFilter(server, {kTlsHandshakeServerKeyExchange}), + ec_type_(ec_type), + named_curve_(named_curve) {} + + protected: + virtual PacketFilter::Action FilterHandshake(const HandshakeHeader &header, + const DataBuffer &input, + DataBuffer *output) { + size_t offset = 0; + output->Allocate(5); + offset = output->Write(offset, ec_type_, 1); + offset = output->Write(offset, named_curve_, 2); + // Write a point with fmt != EC_POINT_FORM_UNCOMPRESSED. + offset = output->Write(offset, 1U, 1); + (void)output->Write(offset, 0x02, 1); // EC_POINT_FORM_COMPRESSED_Y0 + return CHANGE; + } + + private: + ECType ec_type_; + SSLNamedGroup named_curve_; +}; + +TEST_P(TlsConnectGenericPre13, ConnectUnsupportedCurveType) { + EnsureTlsSetup(); + client_->DisableAllCiphers(); + client_->EnableCiphersByKeyExchange(ssl_kea_ecdh); + + MakeTlsFilter<ECCServerKEXDamager>(server_, ec_type_explicitPrime, + ssl_grp_none); + ConnectExpectAlert(client_, kTlsAlertHandshakeFailure); + client_->CheckErrorCode(SEC_ERROR_UNSUPPORTED_ELLIPTIC_CURVE); +} + +TEST_P(TlsConnectGenericPre13, ConnectUnsupportedCurve) { + EnsureTlsSetup(); + client_->DisableAllCiphers(); + client_->EnableCiphersByKeyExchange(ssl_kea_ecdh); + + MakeTlsFilter<ECCServerKEXDamager>(server_, ec_type_named, + ssl_grp_ffdhe_2048); + ConnectExpectAlert(client_, kTlsAlertHandshakeFailure); + client_->CheckErrorCode(SEC_ERROR_UNSUPPORTED_ELLIPTIC_CURVE); +} + +TEST_P(TlsConnectGenericPre13, ConnectUnsupportedPointFormat) { + EnsureTlsSetup(); + client_->DisableAllCiphers(); + client_->EnableCiphersByKeyExchange(ssl_kea_ecdh); + + MakeTlsFilter<ECCServerKEXDamager>(server_, ec_type_named, + ssl_grp_ec_secp256r1); + ConnectExpectAlert(client_, kTlsAlertHandshakeFailure); + client_->CheckErrorCode(SEC_ERROR_UNSUPPORTED_EC_POINT_FORM); +} + +TEST_P(TlsConnectTls12, ConnectUnsupportedSigAlg) { + EnsureTlsSetup(); + client_->DisableAllCiphers(); + client_->EnableCiphersByKeyExchange(ssl_kea_ecdh); + + MakeTlsFilter<ECCServerKEXSigAlgReplacer>(server_, ssl_sig_none); + ConnectExpectAlert(client_, kTlsAlertIllegalParameter); + client_->CheckErrorCode(SSL_ERROR_UNSUPPORTED_SIGNATURE_ALGORITHM); +} + +TEST_P(TlsConnectTls12, ConnectIncorrectSigAlg) { + EnsureTlsSetup(); + client_->DisableAllCiphers(); + client_->EnableCiphersByKeyExchange(ssl_kea_ecdh); + + MakeTlsFilter<ECCServerKEXSigAlgReplacer>(server_, + ssl_sig_ecdsa_secp256r1_sha256); + ConnectExpectAlert(client_, kTlsAlertIllegalParameter); + client_->CheckErrorCode(SSL_ERROR_INCORRECT_SIGNATURE_ALGORITHM); +} + +static void CheckSkeSigScheme( + std::shared_ptr<TlsHandshakeRecorder> &capture_ske, + uint16_t expected_scheme) { + TlsParser parser(capture_ske->buffer()); + uint32_t tmp = 0; + EXPECT_TRUE(parser.Read(&tmp, 1)) << " read curve_type"; + EXPECT_EQ(3U, tmp) << "curve type has to be 3"; + EXPECT_TRUE(parser.Skip(2)) << " read namedcurve"; + EXPECT_TRUE(parser.SkipVariable(1)) << " read public"; + + EXPECT_TRUE(parser.Read(&tmp, 2)) << " read sig_scheme"; + EXPECT_EQ(expected_scheme, static_cast<uint16_t>(tmp)); +} + +TEST_P(TlsConnectTls12, ConnectSigAlgEnabledByPolicy) { + EnsureTlsSetup(); + client_->DisableAllCiphers(); + client_->EnableCiphersByKeyExchange(ssl_kea_ecdh); + + const std::vector<SSLSignatureScheme> schemes = {ssl_sig_rsa_pkcs1_sha1, + ssl_sig_rsa_pkcs1_sha384}; + + client_->SetSignatureSchemes(schemes.data(), schemes.size()); + server_->SetSignatureSchemes(schemes.data(), schemes.size()); + auto capture_ske = MakeTlsFilter<TlsHandshakeRecorder>( + server_, kTlsHandshakeServerKeyExchange); + + StartConnect(); + client_->Handshake(); // Send ClientHello + + // Enable SHA-1 by policy. + SECStatus rv = NSS_SetAlgorithmPolicy(SEC_OID_SHA1, NSS_USE_ALG_IN_SSL_KX, 0); + ASSERT_EQ(SECSuccess, rv); + rv = NSS_SetAlgorithmPolicy(SEC_OID_APPLY_SSL_POLICY, NSS_USE_POLICY_IN_SSL, + 0); + ASSERT_EQ(SECSuccess, rv); + + Handshake(); // Remainder of handshake + // The server should now report that it is connected + EXPECT_EQ(TlsAgent::STATE_CONNECTED, server_->state()); + + CheckSkeSigScheme(capture_ske, ssl_sig_rsa_pkcs1_sha1); +} + +TEST_P(TlsConnectTls12, ConnectSigAlgDisabledByPolicy) { + EnsureTlsSetup(); + client_->DisableAllCiphers(); + client_->EnableCiphersByKeyExchange(ssl_kea_ecdh); + + const std::vector<SSLSignatureScheme> schemes = {ssl_sig_rsa_pkcs1_sha1, + ssl_sig_rsa_pkcs1_sha384}; + + client_->SetSignatureSchemes(schemes.data(), schemes.size()); + server_->SetSignatureSchemes(schemes.data(), schemes.size()); + auto capture_ske = MakeTlsFilter<TlsHandshakeRecorder>( + server_, kTlsHandshakeServerKeyExchange); + + StartConnect(); + client_->Handshake(); // Send ClientHello + + // Disable SHA-1 by policy. + SECStatus rv = NSS_SetAlgorithmPolicy(SEC_OID_SHA1, 0, NSS_USE_ALG_IN_SSL_KX); + ASSERT_EQ(SECSuccess, rv); + rv = NSS_SetAlgorithmPolicy(SEC_OID_APPLY_SSL_POLICY, NSS_USE_POLICY_IN_SSL, + 0); + ASSERT_EQ(SECSuccess, rv); + + Handshake(); // Remainder of handshake + // The server should now report that it is connected + EXPECT_EQ(TlsAgent::STATE_CONNECTED, server_->state()); + + CheckSkeSigScheme(capture_ske, ssl_sig_rsa_pkcs1_sha384); +} + +INSTANTIATE_TEST_SUITE_P(KeyExchangeTest, TlsKeyExchangeTest, + ::testing::Combine(TlsConnectTestBase::kTlsVariantsAll, + TlsConnectTestBase::kTlsV11Plus)); + +#ifndef NSS_DISABLE_TLS_1_3 +INSTANTIATE_TEST_SUITE_P(KeyExchangeTest, TlsKeyExchangeTest13, + ::testing::Combine(TlsConnectTestBase::kTlsVariantsAll, + TlsConnectTestBase::kTlsV13)); +#endif + +} // namespace nss_test diff --git a/security/nss/gtests/ssl_gtest/ssl_ems_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_ems_unittest.cc new file mode 100644 index 0000000000..39b2d58736 --- /dev/null +++ b/security/nss/gtests/ssl_gtest/ssl_ems_unittest.cc @@ -0,0 +1,96 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set ts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this file, + * You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#include "secerr.h" +#include "ssl.h" +#include "sslerr.h" +#include "sslproto.h" + +#include "gtest_utils.h" +#include "nss_scoped_ptrs.h" +#include "tls_connect.h" +#include "tls_filter.h" +#include "tls_parser.h" + +namespace nss_test { + +TEST_P(TlsConnectGenericPre13, ConnectExtendedMasterSecret) { + EnableExtendedMasterSecret(); + Connect(); + Reset(); + ExpectResumption(RESUME_SESSIONID); + EnableExtendedMasterSecret(); + Connect(); +} + +TEST_P(TlsConnectTls12, ConnectExtendedMasterSecretSha384) { + EnableExtendedMasterSecret(); + server_->EnableSingleCipher(TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384); + ConnectWithCipherSuite(TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384); +} + +TEST_P(TlsConnectGenericPre13, ConnectExtendedMasterSecretStaticRSA) { + EnableOnlyStaticRsaCiphers(); + EnableExtendedMasterSecret(); + Connect(); +} + +TEST_P(TlsConnectGenericPre13, ConnectExtendedMasterSecretECDHE) { + EnableExtendedMasterSecret(); + Connect(); + + Reset(); + EnableExtendedMasterSecret(); + ExpectResumption(RESUME_SESSIONID); + Connect(); +} + +TEST_P(TlsConnectGenericPre13, ConnectExtendedMasterSecretTicket) { + ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET); + EnableExtendedMasterSecret(); + Connect(); + + Reset(); + ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET); + + EnableExtendedMasterSecret(); + ExpectResumption(RESUME_TICKET); + Connect(); +} + +TEST_P(TlsConnectGenericPre13, ConnectExtendedMasterSecretClientOnly) { + client_->EnableExtendedMasterSecret(); + ExpectExtendedMasterSecret(false); + Connect(); +} + +TEST_P(TlsConnectGenericPre13, ConnectExtendedMasterSecretServerOnly) { + server_->EnableExtendedMasterSecret(); + ExpectExtendedMasterSecret(false); + Connect(); +} + +TEST_P(TlsConnectGenericPre13, ConnectExtendedMasterSecretResumeWithout) { + EnableExtendedMasterSecret(); + Connect(); + + Reset(); + server_->EnableExtendedMasterSecret(); + ConnectExpectAlert(server_, kTlsAlertHandshakeFailure); +} + +TEST_P(TlsConnectGenericPre13, ConnectNormalResumeWithExtendedMasterSecret) { + ConfigureSessionCache(RESUME_SESSIONID, RESUME_SESSIONID); + ExpectExtendedMasterSecret(false); + Connect(); + + Reset(); + EnableExtendedMasterSecret(); + ExpectResumption(RESUME_NONE); + Connect(); +} + +} // namespace nss_test diff --git a/security/nss/gtests/ssl_gtest/ssl_exporter_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_exporter_unittest.cc new file mode 100644 index 0000000000..26ed6bc0ed --- /dev/null +++ b/security/nss/gtests/ssl_gtest/ssl_exporter_unittest.cc @@ -0,0 +1,188 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set ts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this file, + * You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#include "ssl.h" + +#include "gtest_utils.h" +#include "tls_connect.h" + +namespace nss_test { + +static const char* kExporterLabel = "EXPORTER-duck"; +static const uint8_t kExporterContext[] = {0x12, 0x34, 0x56}; + +static void ExportAndCompare(std::shared_ptr<TlsAgent>& client, + std::shared_ptr<TlsAgent>& server, bool context) { + static const size_t exporter_len = 10; + uint8_t client_value[exporter_len] = {0}; + EXPECT_EQ(SECSuccess, + SSL_ExportKeyingMaterial( + client->ssl_fd(), kExporterLabel, strlen(kExporterLabel), + context ? PR_TRUE : PR_FALSE, kExporterContext, + sizeof(kExporterContext), client_value, sizeof(client_value))); + uint8_t server_value[exporter_len] = {0xff}; + EXPECT_EQ(SECSuccess, + SSL_ExportKeyingMaterial( + server->ssl_fd(), kExporterLabel, strlen(kExporterLabel), + context ? PR_TRUE : PR_FALSE, kExporterContext, + sizeof(kExporterContext), server_value, sizeof(server_value))); + EXPECT_EQ(0, memcmp(client_value, server_value, sizeof(client_value))); +} + +TEST_P(TlsConnectGeneric, ExporterBasic) { + EnsureTlsSetup(); + if (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) { + server_->EnableSingleCipher(TLS_AES_128_GCM_SHA256); + } else { + server_->EnableSingleCipher(TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA); + } + Connect(); + CheckKeys(); + ExportAndCompare(client_, server_, false); +} + +TEST_P(TlsConnectGeneric, ExporterContext) { + EnsureTlsSetup(); + if (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) { + server_->EnableSingleCipher(TLS_AES_128_GCM_SHA256); + } else { + server_->EnableSingleCipher(TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA); + } + Connect(); + CheckKeys(); + ExportAndCompare(client_, server_, true); +} + +// Bug 1312976 - SHA-384 doesn't work in 1.2 right now. +TEST_P(TlsConnectTls13, ExporterSha384) { + EnsureTlsSetup(); + client_->EnableSingleCipher(TLS_AES_256_GCM_SHA384); + Connect(); + CheckKeys(); + ExportAndCompare(client_, server_, false); +} + +TEST_P(TlsConnectTls13, ExporterContextEmptyIsSameAsNone) { + EnsureTlsSetup(); + if (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) { + server_->EnableSingleCipher(TLS_AES_128_GCM_SHA256); + } else { + server_->EnableSingleCipher(TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA); + } + Connect(); + CheckKeys(); + ExportAndCompare(client_, server_, false); +} + +TEST_P(TlsConnectGenericPre13, ExporterContextLengthTooLong) { + static const uint8_t kExporterContextTooLong[PR_UINT16_MAX] = { + 0x12, 0x34, 0x56, 0x78, 0x9A, 0xBC, 0xDE, 0xFF}; + + EnsureTlsSetup(); + Connect(); + CheckKeys(); + + static const size_t exporter_len = 10; + uint8_t client_value[exporter_len] = {0}; + EXPECT_EQ(SECFailure, + SSL_ExportKeyingMaterial(client_->ssl_fd(), kExporterLabel, + strlen(kExporterLabel), PR_TRUE, + kExporterContextTooLong, + sizeof(kExporterContextTooLong), + client_value, sizeof(client_value))); + EXPECT_EQ(PORT_GetError(), SEC_ERROR_INVALID_ARGS); + uint8_t server_value[exporter_len] = {0xff}; + EXPECT_EQ(SECFailure, + SSL_ExportKeyingMaterial(server_->ssl_fd(), kExporterLabel, + strlen(kExporterLabel), PR_TRUE, + kExporterContextTooLong, + sizeof(kExporterContextTooLong), + server_value, sizeof(server_value))); + EXPECT_EQ(PORT_GetError(), SEC_ERROR_INVALID_ARGS); +} + +// This has a weird signature so that it can be passed to the SNI callback. +int32_t RegularExporterShouldFail(TlsAgent* agent, const SECItem* srvNameArr, + PRUint32 srvNameArrSize) { + uint8_t val[10]; + EXPECT_EQ(SECFailure, SSL_ExportKeyingMaterial( + agent->ssl_fd(), kExporterLabel, + strlen(kExporterLabel), PR_TRUE, kExporterContext, + sizeof(kExporterContext), val, sizeof(val))) + << "regular exporter should fail"; + return 0; +} + +TEST_P(TlsConnectTls13, EarlyExporter) { + SetupForZeroRtt(); + client_->Set0RttEnabled(true); + server_->Set0RttEnabled(true); + ExpectResumption(RESUME_TICKET); + + client_->Handshake(); // Send ClientHello. + uint8_t client_value[10] = {0}; + RegularExporterShouldFail(client_.get(), nullptr, 0); + + EXPECT_EQ(SECSuccess, + SSL_ExportEarlyKeyingMaterial( + client_->ssl_fd(), kExporterLabel, strlen(kExporterLabel), + kExporterContext, sizeof(kExporterContext), client_value, + sizeof(client_value))); + + server_->SetSniCallback(RegularExporterShouldFail); + server_->Handshake(); // Handle ClientHello. + uint8_t server_value[10] = {0}; + EXPECT_EQ(SECSuccess, + SSL_ExportEarlyKeyingMaterial( + server_->ssl_fd(), kExporterLabel, strlen(kExporterLabel), + kExporterContext, sizeof(kExporterContext), server_value, + sizeof(server_value))); + EXPECT_EQ(0, memcmp(client_value, server_value, sizeof(client_value))); + + Handshake(); + ExpectEarlyDataAccepted(true); + CheckConnected(); + SendReceive(); +} + +TEST_P(TlsConnectTls13, EarlyExporterExternalPsk) { + RolloverAntiReplay(); + ScopedPK11SlotInfo slot(PK11_GetInternalSlot()); + ASSERT_TRUE(!!slot); + ScopedPK11SymKey scoped_psk( + PK11_KeyGen(slot.get(), CKM_HKDF_KEY_GEN, nullptr, 16, nullptr)); + AddPsk(scoped_psk, std::string("foo"), ssl_hash_sha256, + TLS_CHACHA20_POLY1305_SHA256); + StartConnect(); + client_->Set0RttEnabled(true); + server_->Set0RttEnabled(true); + client_->Handshake(); // Send ClientHello. + uint8_t client_value[10] = {0}; + RegularExporterShouldFail(client_.get(), nullptr, 0); + + EXPECT_EQ(SECSuccess, + SSL_ExportEarlyKeyingMaterial( + client_->ssl_fd(), kExporterLabel, strlen(kExporterLabel), + kExporterContext, sizeof(kExporterContext), client_value, + sizeof(client_value))); + + server_->SetSniCallback(RegularExporterShouldFail); + server_->Handshake(); // Handle ClientHello. + uint8_t server_value[10] = {0}; + EXPECT_EQ(SECSuccess, + SSL_ExportEarlyKeyingMaterial( + server_->ssl_fd(), kExporterLabel, strlen(kExporterLabel), + kExporterContext, sizeof(kExporterContext), server_value, + sizeof(server_value))); + EXPECT_EQ(0, memcmp(client_value, server_value, sizeof(client_value))); + + Handshake(); + ExpectEarlyDataAccepted(true); + CheckConnected(); + SendReceive(); +} + +} // namespace nss_test diff --git a/security/nss/gtests/ssl_gtest/ssl_extension_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_extension_unittest.cc new file mode 100644 index 0000000000..eb45e71422 --- /dev/null +++ b/security/nss/gtests/ssl_gtest/ssl_extension_unittest.cc @@ -0,0 +1,1513 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set ts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this file, + * You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#include "ssl.h" +#include "ssl3prot.h" +#include "sslerr.h" +#include "sslproto.h" + +// This is only to get DTLS_1_3_DRAFT_VERSION +#include "ssl3prot.h" + +#include <memory> + +#include "tls_connect.h" +#include "tls_filter.h" +#include "tls_parser.h" + +namespace nss_test { + +class Dtls13LegacyCookieInjector : public TlsHandshakeFilter { + public: + Dtls13LegacyCookieInjector(const std::shared_ptr<TlsAgent>& a) + : TlsHandshakeFilter(a, {kTlsHandshakeClientHello}) {} + + virtual PacketFilter::Action FilterHandshake(const HandshakeHeader& header, + const DataBuffer& input, + DataBuffer* output) { + const uint8_t cookie_bytes[] = {0x03, 0x0A, 0x0B, 0x0C}; + uint32_t offset = 2 /* version */ + 32 /* random */; + + if (agent()->variant() != ssl_variant_datagram) { + ADD_FAILURE(); + return KEEP; + } + + if (header.handshake_type() != ssl_hs_client_hello) { + return KEEP; + } + + DataBuffer cookie(cookie_bytes, sizeof(cookie_bytes)); + *output = input; + + // Add the SID length (if any) to locate the cookie. + uint32_t sid_len = 0; + if (!output->Read(offset, 1, &sid_len)) { + ADD_FAILURE(); + return KEEP; + } + offset += 1 + sid_len; + output->Splice(cookie, offset, 1); + + return CHANGE; + } + + private: + DataBuffer cookie_; +}; + +class TlsExtensionTruncator : public TlsExtensionFilter { + public: + TlsExtensionTruncator(const std::shared_ptr<TlsAgent>& a, uint16_t extension, + size_t length) + : TlsExtensionFilter(a), extension_(extension), length_(length) {} + virtual PacketFilter::Action FilterExtension(uint16_t extension_type, + const DataBuffer& input, + DataBuffer* output) { + if (extension_type != extension_) { + return KEEP; + } + if (input.len() <= length_) { + return KEEP; + } + + output->Assign(input.data(), length_); + return CHANGE; + } + + private: + uint16_t extension_; + size_t length_; +}; + +class TlsExtensionTestBase : public TlsConnectTestBase { + protected: + TlsExtensionTestBase(SSLProtocolVariant variant, uint16_t version) + : TlsConnectTestBase(variant, version) {} + + void ClientHelloErrorTest(std::shared_ptr<PacketFilter> filter, + uint8_t desc = kTlsAlertDecodeError) { + client_->SetFilter(filter); + ConnectExpectAlert(server_, desc); + } + + void ServerHelloErrorTest(std::shared_ptr<PacketFilter> filter, + uint8_t desc = kTlsAlertDecodeError) { + server_->SetFilter(filter); + ConnectExpectAlert(client_, desc); + } + + static void InitSimpleSni(DataBuffer* extension) { + const char* name = "host.name"; + const size_t namelen = PL_strlen(name); + extension->Allocate(namelen + 5); + extension->Write(0, namelen + 3, 2); + extension->Write(2, static_cast<uint32_t>(0), 1); // 0 == hostname + extension->Write(3, namelen, 2); + extension->Write(5, reinterpret_cast<const uint8_t*>(name), namelen); + } + + void HrrThenRemoveExtensionsTest(SSLExtensionType type, PRInt32 client_error, + PRInt32 server_error) { + static const std::vector<SSLNamedGroup> client_groups = { + ssl_grp_ec_secp384r1, ssl_grp_ec_curve25519}; + static const std::vector<SSLNamedGroup> server_groups = { + ssl_grp_ec_curve25519, ssl_grp_ec_secp384r1}; + client_->ConfigNamedGroups(client_groups); + server_->ConfigNamedGroups(server_groups); + EnsureTlsSetup(); + StartConnect(); + client_->Handshake(); // Send ClientHello + server_->Handshake(); // Send HRR. + MakeTlsFilter<TlsExtensionDropper>(client_, type); + Handshake(); + client_->CheckErrorCode(client_error); + server_->CheckErrorCode(server_error); + } +}; + +class TlsExtensionTestDtls : public TlsExtensionTestBase, + public ::testing::WithParamInterface<uint16_t> { + public: + TlsExtensionTestDtls() + : TlsExtensionTestBase(ssl_variant_datagram, GetParam()) {} +}; + +class TlsExtensionTest12Plus : public TlsExtensionTestBase, + public ::testing::WithParamInterface< + std::tuple<SSLProtocolVariant, uint16_t>> { + public: + TlsExtensionTest12Plus() + : TlsExtensionTestBase(std::get<0>(GetParam()), std::get<1>(GetParam())) { + } +}; + +class TlsExtensionTest12 : public TlsExtensionTestBase, + public ::testing::WithParamInterface< + std::tuple<SSLProtocolVariant, uint16_t>> { + public: + TlsExtensionTest12() + : TlsExtensionTestBase(std::get<0>(GetParam()), std::get<1>(GetParam())) { + } +}; + +class TlsExtensionTest13 + : public TlsExtensionTestBase, + public ::testing::WithParamInterface<SSLProtocolVariant> { + public: + TlsExtensionTest13() + : TlsExtensionTestBase(GetParam(), SSL_LIBRARY_VERSION_TLS_1_3) {} + + void ConnectWithBogusVersionList(const uint8_t* buf, size_t len) { + DataBuffer versions_buf(buf, len); + MakeTlsFilter<TlsExtensionReplacer>( + client_, ssl_tls13_supported_versions_xtn, versions_buf); + ConnectExpectAlert(server_, kTlsAlertIllegalParameter); + client_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT); + server_->CheckErrorCode(SSL_ERROR_RX_MALFORMED_CLIENT_HELLO); + } + + void ConnectWithReplacementVersionList(uint16_t version) { + // Convert the version encoding for DTLS, if needed. + if (variant_ == ssl_variant_datagram) { + switch (version) { + case SSL_LIBRARY_VERSION_TLS_1_3: +#ifdef DTLS_1_3_DRAFT_VERSION + version = 0x7f00 | DTLS_1_3_DRAFT_VERSION; +#else + version = SSL_LIBRARY_VERSION_DTLS_1_3_WIRE; +#endif + break; + case SSL_LIBRARY_VERSION_TLS_1_2: + version = SSL_LIBRARY_VERSION_DTLS_1_2_WIRE; + break; + case SSL_LIBRARY_VERSION_TLS_1_1: + /* TLS_1_1 maps to DTLS_1_0, see sslproto.h. */ + version = SSL_LIBRARY_VERSION_DTLS_1_0_WIRE; + break; + default: + PORT_Assert(0); + } + } + + DataBuffer versions_buf; + size_t index = versions_buf.Write(0, 2, 1); + versions_buf.Write(index, version, 2); + MakeTlsFilter<TlsExtensionReplacer>( + client_, ssl_tls13_supported_versions_xtn, versions_buf); + ConnectExpectFail(); + } +}; + +class TlsExtensionTest13Stream : public TlsExtensionTestBase { + public: + TlsExtensionTest13Stream() + : TlsExtensionTestBase(ssl_variant_stream, SSL_LIBRARY_VERSION_TLS_1_3) {} +}; + +class TlsExtensionTestGeneric : public TlsExtensionTestBase, + public ::testing::WithParamInterface< + std::tuple<SSLProtocolVariant, uint16_t>> { + public: + TlsExtensionTestGeneric() + : TlsExtensionTestBase(std::get<0>(GetParam()), std::get<1>(GetParam())) { + } +}; + +class TlsExtensionTestPre13 : public TlsExtensionTestBase, + public ::testing::WithParamInterface< + std::tuple<SSLProtocolVariant, uint16_t>> { + public: + TlsExtensionTestPre13() + : TlsExtensionTestBase(std::get<0>(GetParam()), std::get<1>(GetParam())) { + } +}; + +TEST_P(TlsExtensionTestGeneric, DamageSniLength) { + ClientHelloErrorTest( + std::make_shared<TlsExtensionDamager>(client_, ssl_server_name_xtn, 1)); +} + +TEST_P(TlsExtensionTestGeneric, DamageSniHostLength) { + ClientHelloErrorTest( + std::make_shared<TlsExtensionDamager>(client_, ssl_server_name_xtn, 4)); +} + +TEST_P(TlsExtensionTestGeneric, TruncateSni) { + ClientHelloErrorTest( + std::make_shared<TlsExtensionTruncator>(client_, ssl_server_name_xtn, 7)); +} + +// A valid extension that appears twice will be reported as unsupported. +TEST_P(TlsExtensionTestGeneric, RepeatSni) { + DataBuffer extension; + InitSimpleSni(&extension); + ClientHelloErrorTest(std::make_shared<TlsExtensionInjector>( + client_, ssl_server_name_xtn, extension), + kTlsAlertIllegalParameter); +} + +// An SNI entry with zero length is considered invalid (strangely, not if it is +// the last entry, which is probably a bug). +TEST_P(TlsExtensionTestGeneric, BadSni) { + DataBuffer simple; + InitSimpleSni(&simple); + DataBuffer extension; + extension.Allocate(simple.len() + 3); + extension.Write(0, static_cast<uint32_t>(0), 3); + extension.Write(3, simple); + ClientHelloErrorTest(std::make_shared<TlsExtensionReplacer>( + client_, ssl_server_name_xtn, extension)); +} + +TEST_P(TlsExtensionTestGeneric, EmptySni) { + DataBuffer extension; + extension.Allocate(2); + extension.Write(0, static_cast<uint32_t>(0), 2); + ClientHelloErrorTest(std::make_shared<TlsExtensionReplacer>( + client_, ssl_server_name_xtn, extension)); +} + +TEST_P(TlsExtensionTestGeneric, EmptyAlpnExtension) { + EnableAlpn(); + DataBuffer extension; + ClientHelloErrorTest(std::make_shared<TlsExtensionReplacer>( + client_, ssl_app_layer_protocol_xtn, extension), + kTlsAlertIllegalParameter); +} + +// An empty ALPN isn't considered bad, though it does lead to there being no +// protocol for the server to select. +TEST_P(TlsExtensionTestGeneric, EmptyAlpnList) { + EnableAlpn(); + const uint8_t val[] = {0x00, 0x00}; + DataBuffer extension(val, sizeof(val)); + ClientHelloErrorTest(std::make_shared<TlsExtensionReplacer>( + client_, ssl_app_layer_protocol_xtn, extension), + kTlsAlertNoApplicationProtocol); +} + +TEST_P(TlsExtensionTestGeneric, OneByteAlpn) { + EnableAlpn(); + ClientHelloErrorTest(std::make_shared<TlsExtensionTruncator>( + client_, ssl_app_layer_protocol_xtn, 1)); +} + +TEST_P(TlsExtensionTestGeneric, AlpnMissingValue) { + EnableAlpn(); + // This will leave the length of the second entry, but no value. + ClientHelloErrorTest(std::make_shared<TlsExtensionTruncator>( + client_, ssl_app_layer_protocol_xtn, 5)); +} + +TEST_P(TlsExtensionTestGeneric, AlpnZeroLength) { + EnableAlpn(); + const uint8_t val[] = {0x00, 0x03, 0x01, 0x61, 0x00}; + DataBuffer extension(val, sizeof(val)); + ClientHelloErrorTest(std::make_shared<TlsExtensionReplacer>( + client_, ssl_app_layer_protocol_xtn, extension)); +} + +TEST_P(TlsExtensionTestGeneric, AlpnLengthOverflow) { + EnableAlpn(); + const uint8_t val[] = {0x00, 0x03, 0x01, 0x61, 0x01}; + DataBuffer extension(val, sizeof(val)); + ClientHelloErrorTest(std::make_shared<TlsExtensionReplacer>( + client_, ssl_app_layer_protocol_xtn, extension)); +} + +TEST_P(TlsExtensionTestGeneric, AlpnMismatch) { + const uint8_t client_alpn[] = {0x01, 0x61}; + client_->EnableAlpn(client_alpn, sizeof(client_alpn)); + const uint8_t server_alpn[] = {0x02, 0x61, 0x62}; + server_->EnableAlpn(server_alpn, sizeof(server_alpn)); + + ClientHelloErrorTest(nullptr, kTlsAlertNoApplicationProtocol); + client_->CheckErrorCode(SSL_ERROR_NEXT_PROTOCOL_NO_PROTOCOL); +} + +TEST_P(TlsExtensionTestGeneric, AlpnDisabledServer) { + const uint8_t client_alpn[] = {0x01, 0x61}; + client_->EnableAlpn(client_alpn, sizeof(client_alpn)); + server_->EnableAlpn(nullptr, 0); + + ClientHelloErrorTest(nullptr, kTlsAlertUnsupportedExtension); +} + +TEST_P(TlsConnectGeneric, AlpnDisabled) { + server_->EnableAlpn(nullptr, 0); + Connect(); + + SSLNextProtoState state; + uint8_t buf[255] = {0}; + unsigned int buf_len = 3; + EXPECT_EQ(SECSuccess, SSL_GetNextProto(client_->ssl_fd(), &state, buf, + &buf_len, sizeof(buf))); + EXPECT_EQ(SSL_NEXT_PROTO_NO_SUPPORT, state); + EXPECT_EQ(0U, buf_len); +} + +// Many of these tests fail in TLS 1.3 because the extension is encrypted, which +// prevents modification of the value from the ServerHello. +TEST_P(TlsExtensionTestPre13, AlpnReturnedEmptyList) { + EnableAlpn(); + const uint8_t val[] = {0x00, 0x00}; + DataBuffer extension(val, sizeof(val)); + ServerHelloErrorTest(std::make_shared<TlsExtensionReplacer>( + server_, ssl_app_layer_protocol_xtn, extension)); +} + +TEST_P(TlsExtensionTestPre13, AlpnReturnedEmptyName) { + EnableAlpn(); + const uint8_t val[] = {0x00, 0x01, 0x00}; + DataBuffer extension(val, sizeof(val)); + ServerHelloErrorTest(std::make_shared<TlsExtensionReplacer>( + server_, ssl_app_layer_protocol_xtn, extension)); +} + +TEST_P(TlsExtensionTestPre13, AlpnReturnedListTrailingData) { + EnableAlpn(); + const uint8_t val[] = {0x00, 0x02, 0x01, 0x61, 0x00}; + DataBuffer extension(val, sizeof(val)); + ServerHelloErrorTest(std::make_shared<TlsExtensionReplacer>( + server_, ssl_app_layer_protocol_xtn, extension)); +} + +TEST_P(TlsExtensionTestPre13, AlpnReturnedExtraEntry) { + EnableAlpn(); + const uint8_t val[] = {0x00, 0x04, 0x01, 0x61, 0x01, 0x62}; + DataBuffer extension(val, sizeof(val)); + ServerHelloErrorTest(std::make_shared<TlsExtensionReplacer>( + server_, ssl_app_layer_protocol_xtn, extension)); +} + +TEST_P(TlsExtensionTestPre13, AlpnReturnedBadListLength) { + EnableAlpn(); + const uint8_t val[] = {0x00, 0x99, 0x01, 0x61, 0x00}; + DataBuffer extension(val, sizeof(val)); + ServerHelloErrorTest(std::make_shared<TlsExtensionReplacer>( + server_, ssl_app_layer_protocol_xtn, extension)); +} + +TEST_P(TlsExtensionTestPre13, AlpnReturnedBadNameLength) { + EnableAlpn(); + const uint8_t val[] = {0x00, 0x02, 0x99, 0x61}; + DataBuffer extension(val, sizeof(val)); + ServerHelloErrorTest(std::make_shared<TlsExtensionReplacer>( + server_, ssl_app_layer_protocol_xtn, extension)); +} + +TEST_P(TlsExtensionTestPre13, AlpnReturnedUnknownName) { + EnableAlpn(); + const uint8_t val[] = {0x00, 0x02, 0x01, 0x67}; + DataBuffer extension(val, sizeof(val)); + ServerHelloErrorTest(std::make_shared<TlsExtensionReplacer>( + server_, ssl_app_layer_protocol_xtn, extension), + kTlsAlertIllegalParameter); +} + +TEST_P(TlsExtensionTestDtls, SrtpShort) { + EnableSrtp(); + ClientHelloErrorTest( + std::make_shared<TlsExtensionTruncator>(client_, ssl_use_srtp_xtn, 3)); +} + +TEST_P(TlsExtensionTestDtls, SrtpOdd) { + EnableSrtp(); + const uint8_t val[] = {0x00, 0x01, 0xff, 0x00}; + DataBuffer extension(val, sizeof(val)); + ClientHelloErrorTest(std::make_shared<TlsExtensionReplacer>( + client_, ssl_use_srtp_xtn, extension)); +} + +TEST_P(TlsExtensionTest12Plus, SignatureAlgorithmsBadLength) { + const uint8_t val[] = {0x00}; + DataBuffer extension(val, sizeof(val)); + ClientHelloErrorTest(std::make_shared<TlsExtensionReplacer>( + client_, ssl_signature_algorithms_xtn, extension)); +} + +TEST_P(TlsExtensionTest12Plus, SignatureAlgorithmsTrailingData) { + // make sure the test uses an algorithm that is legal for + // tls 1.3 (or tls 1.3 will throw a handshake failure alert + // instead of a decode error alert) + const uint8_t val[] = {0x00, 0x02, 0x08, 0x09, 0x00}; // sha-256, rsa-pss-pss + DataBuffer extension(val, sizeof(val)); + ClientHelloErrorTest(std::make_shared<TlsExtensionReplacer>( + client_, ssl_signature_algorithms_xtn, extension)); +} + +TEST_P(TlsExtensionTest12Plus, SignatureAlgorithmsEmpty) { + const uint8_t val[] = {0x00, 0x00}; + DataBuffer extension(val, sizeof(val)); + ClientHelloErrorTest(std::make_shared<TlsExtensionReplacer>( + client_, ssl_signature_algorithms_xtn, extension), + kTlsAlertHandshakeFailure); +} + +TEST_P(TlsExtensionTest12Plus, SignatureAlgorithmsNoOverlap) { + const uint8_t val[] = {0x00, 0x02, 0xff, 0xff}; + DataBuffer extension(val, sizeof(val)); + ClientHelloErrorTest(std::make_shared<TlsExtensionReplacer>( + client_, ssl_signature_algorithms_xtn, extension), + kTlsAlertHandshakeFailure); +} + +TEST_P(TlsExtensionTest12Plus, SignatureAlgorithmsOddLength) { + const uint8_t val[] = {0x00, 0x01, 0x04}; + DataBuffer extension(val, sizeof(val)); + ClientHelloErrorTest(std::make_shared<TlsExtensionReplacer>( + client_, ssl_signature_algorithms_xtn, extension)); +} + +TEST_F(TlsExtensionTest13Stream, SignatureAlgorithmsPrecedingGarbage) { + // 31 unknown signature algorithms followed by sha-256, rsa-pss + const uint8_t val[] = { + 0x00, 0x40, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x08, 0x04}; + DataBuffer extension(val, sizeof(val)); + MakeTlsFilter<TlsExtensionReplacer>(client_, ssl_signature_algorithms_xtn, + extension); + client_->ExpectSendAlert(kTlsAlertBadRecordMac); + server_->ExpectSendAlert(kTlsAlertBadRecordMac); + ConnectExpectFail(); + client_->CheckErrorCode(SSL_ERROR_BAD_MAC_READ); + server_->CheckErrorCode(SSL_ERROR_BAD_MAC_READ); +} + +TEST_P(TlsExtensionTestGeneric, NoSupportedGroups) { + ClientHelloErrorTest( + std::make_shared<TlsExtensionDropper>(client_, ssl_supported_groups_xtn), + version_ < SSL_LIBRARY_VERSION_TLS_1_3 ? kTlsAlertDecryptError + : kTlsAlertMissingExtension); +} + +TEST_P(TlsExtensionTestGeneric, SupportedCurvesShort) { + const uint8_t val[] = {0x00, 0x01, 0x00}; + DataBuffer extension(val, sizeof(val)); + ClientHelloErrorTest(std::make_shared<TlsExtensionReplacer>( + client_, ssl_elliptic_curves_xtn, extension)); +} + +TEST_P(TlsExtensionTestGeneric, SupportedCurvesBadLength) { + const uint8_t val[] = {0x09, 0x99, 0x00, 0x00}; + DataBuffer extension(val, sizeof(val)); + ClientHelloErrorTest(std::make_shared<TlsExtensionReplacer>( + client_, ssl_elliptic_curves_xtn, extension)); +} + +TEST_P(TlsExtensionTestGeneric, SupportedCurvesTrailingData) { + const uint8_t val[] = {0x00, 0x02, 0x00, 0x00, 0x00}; + DataBuffer extension(val, sizeof(val)); + ClientHelloErrorTest(std::make_shared<TlsExtensionReplacer>( + client_, ssl_elliptic_curves_xtn, extension)); +} + +TEST_P(TlsExtensionTest12, SupportedCurvesDisableX25519) { + // Disable session resumption. + ConfigureSessionCache(RESUME_NONE, RESUME_NONE); + + // Ensure that we can enable its use in the key exchange. + SECStatus rv = + NSS_SetAlgorithmPolicy(SEC_OID_CURVE25519, NSS_USE_ALG_IN_SSL_KX, 0); + ASSERT_EQ(SECSuccess, rv); + rv = NSS_SetAlgorithmPolicy(SEC_OID_APPLY_SSL_POLICY, NSS_USE_POLICY_IN_SSL, + 0); + ASSERT_EQ(SECSuccess, rv); + + auto capture1 = + MakeTlsFilter<TlsExtensionCapture>(client_, ssl_elliptic_curves_xtn); + Connect(); + + EXPECT_TRUE(capture1->captured()); + const DataBuffer& ext1 = capture1->extension(); + + uint32_t count; + ASSERT_TRUE(ext1.Read(0, 2, &count)); + + // Whether or not we've seen x25519 offered in this handshake. + bool seen1_x25519 = false; + for (size_t offset = 2; offset <= count; offset++) { + uint32_t val; + ASSERT_TRUE(ext1.Read(offset, 2, &val)); + if (val == ssl_grp_ec_curve25519) { + seen1_x25519 = true; + break; + } + } + ASSERT_TRUE(seen1_x25519); + + // Ensure that we can disable its use in the key exchange. + rv = NSS_SetAlgorithmPolicy(SEC_OID_CURVE25519, 0, NSS_USE_ALG_IN_SSL_KX); + ASSERT_EQ(SECSuccess, rv); + rv = NSS_SetAlgorithmPolicy(SEC_OID_APPLY_SSL_POLICY, NSS_USE_POLICY_IN_SSL, + 0); + ASSERT_EQ(SECSuccess, rv); + + // Clean up after the last run. + Reset(); + auto capture2 = + MakeTlsFilter<TlsExtensionCapture>(client_, ssl_elliptic_curves_xtn); + Connect(); + + EXPECT_TRUE(capture2->captured()); + const DataBuffer& ext2 = capture2->extension(); + + ASSERT_TRUE(ext2.Read(0, 2, &count)); + + // Whether or not we've seen x25519 offered in this handshake. + bool seen2_x25519 = false; + for (size_t offset = 2; offset <= count; offset++) { + uint32_t val; + ASSERT_TRUE(ext2.Read(offset, 2, &val)); + + if (val == ssl_grp_ec_curve25519) { + seen2_x25519 = true; + break; + } + } + + ASSERT_FALSE(seen2_x25519); +} + +TEST_P(TlsExtensionTestPre13, SupportedPointsEmpty) { + const uint8_t val[] = {0x00}; + DataBuffer extension(val, sizeof(val)); + ClientHelloErrorTest(std::make_shared<TlsExtensionReplacer>( + client_, ssl_ec_point_formats_xtn, extension)); +} + +TEST_P(TlsExtensionTestPre13, SupportedPointsBadLength) { + const uint8_t val[] = {0x99, 0x00, 0x00}; + DataBuffer extension(val, sizeof(val)); + ClientHelloErrorTest(std::make_shared<TlsExtensionReplacer>( + client_, ssl_ec_point_formats_xtn, extension)); +} + +TEST_P(TlsExtensionTestPre13, SupportedPointsTrailingData) { + const uint8_t val[] = {0x01, 0x00, 0x00}; + DataBuffer extension(val, sizeof(val)); + ClientHelloErrorTest(std::make_shared<TlsExtensionReplacer>( + client_, ssl_ec_point_formats_xtn, extension)); +} + +TEST_P(TlsExtensionTestPre13, SupportedPointsCompressed) { + const uint8_t val[] = {0x01, 0x02}; + DataBuffer extension(val, sizeof(val)); + ClientHelloErrorTest(std::make_shared<TlsExtensionReplacer>( + client_, ssl_ec_point_formats_xtn, extension), + kTlsAlertIllegalParameter); +} + +TEST_P(TlsExtensionTestPre13, SupportedPointsUndefined) { + const uint8_t val[] = {0x01, 0xAA}; + DataBuffer extension(val, sizeof(val)); + ClientHelloErrorTest(std::make_shared<TlsExtensionReplacer>( + client_, ssl_ec_point_formats_xtn, extension), + kTlsAlertIllegalParameter); +} + +TEST_P(TlsExtensionTestPre13, RenegotiationInfoBadLength) { + const uint8_t val[] = {0x99}; + DataBuffer extension(val, sizeof(val)); + ClientHelloErrorTest(std::make_shared<TlsExtensionReplacer>( + client_, ssl_renegotiation_info_xtn, extension)); +} + +TEST_P(TlsExtensionTestPre13, RenegotiationInfoMismatch) { + const uint8_t val[] = {0x01, 0x00}; + DataBuffer extension(val, sizeof(val)); + ClientHelloErrorTest(std::make_shared<TlsExtensionReplacer>( + client_, ssl_renegotiation_info_xtn, extension)); +} + +// The extension has to contain a length. +TEST_P(TlsExtensionTestPre13, RenegotiationInfoExtensionEmpty) { + DataBuffer extension; + ClientHelloErrorTest(std::make_shared<TlsExtensionReplacer>( + client_, ssl_renegotiation_info_xtn, extension)); +} + +// This only works on TLS 1.2, since it relies on static RSA; otherwise libssl +// picks the wrong cipher suite. +TEST_P(TlsExtensionTest12, SignatureAlgorithmConfiguration) { + const SSLSignatureScheme schemes[] = {ssl_sig_rsa_pss_rsae_sha512, + ssl_sig_rsa_pss_rsae_sha384}; + + auto capture = + MakeTlsFilter<TlsExtensionCapture>(client_, ssl_signature_algorithms_xtn); + client_->SetSignatureSchemes(schemes, PR_ARRAY_SIZE(schemes)); + EnableOnlyStaticRsaCiphers(); + Connect(); + + const DataBuffer& ext = capture->extension(); + EXPECT_EQ(2 + PR_ARRAY_SIZE(schemes) * 2, ext.len()); + for (size_t i = 0, cursor = 2; + i < PR_ARRAY_SIZE(schemes) && cursor < ext.len(); ++i) { + uint32_t v = 0; + EXPECT_TRUE(ext.Read(cursor, 2, &v)); + cursor += 2; + EXPECT_EQ(schemes[i], static_cast<SSLSignatureScheme>(v)); + } +} + +// This only works on TLS 1.2, since it relies on DSA. +TEST_P(TlsExtensionTest12, SignatureAlgorithmDisableDSA) { + const std::vector<SSLSignatureScheme> schemes = { + ssl_sig_dsa_sha1, ssl_sig_dsa_sha256, ssl_sig_dsa_sha384, + ssl_sig_dsa_sha512, ssl_sig_rsa_pss_rsae_sha256}; + + // Connect with DSA enabled by policy. + SECStatus rv = NSS_SetAlgorithmPolicy(SEC_OID_ANSIX9_DSA_SIGNATURE, + NSS_USE_ALG_IN_SSL_KX, 0); + ASSERT_EQ(SECSuccess, rv); + rv = NSS_SetAlgorithmPolicy(SEC_OID_APPLY_SSL_POLICY, NSS_USE_POLICY_IN_SSL, + 0); + ASSERT_EQ(SECSuccess, rv); + + Reset(TlsAgent::kServerDsa); + auto capture1 = + MakeTlsFilter<TlsExtensionCapture>(client_, ssl_signature_algorithms_xtn); + client_->SetSignatureSchemes(schemes.data(), schemes.size()); + Connect(); + + // Check if all the signature algorithms are advertised. + EXPECT_TRUE(capture1->captured()); + const DataBuffer& ext1 = capture1->extension(); + EXPECT_EQ(2U + 2U * schemes.size(), ext1.len()); + + // Connect with DSA disabled by policy. + rv = NSS_SetAlgorithmPolicy(SEC_OID_ANSIX9_DSA_SIGNATURE, 0, + NSS_USE_ALG_IN_SSL_KX); + ASSERT_EQ(SECSuccess, rv); + rv = NSS_SetAlgorithmPolicy(SEC_OID_APPLY_SSL_POLICY, NSS_USE_POLICY_IN_SSL, + 0); + ASSERT_EQ(SECSuccess, rv); + + Reset(TlsAgent::kServerDsa); + auto capture2 = + MakeTlsFilter<TlsExtensionCapture>(client_, ssl_signature_algorithms_xtn); + client_->SetSignatureSchemes(schemes.data(), schemes.size()); + ConnectExpectAlert(server_, kTlsAlertHandshakeFailure); + server_->CheckErrorCode(SSL_ERROR_NO_CYPHER_OVERLAP); + client_->CheckErrorCode(SSL_ERROR_NO_CYPHER_OVERLAP); + + // Check if no DSA algorithms are advertised. + EXPECT_TRUE(capture2->captured()); + const DataBuffer& ext2 = capture2->extension(); + EXPECT_EQ(2U + 2U, ext2.len()); + uint32_t v = 0; + EXPECT_TRUE(ext2.Read(2, 2, &v)); + EXPECT_EQ(ssl_sig_rsa_pss_rsae_sha256, v); +} + +// Temporary test to verify that we choke on an empty ClientKeyShare. +// This test will fail when we implement HelloRetryRequest. +TEST_P(TlsExtensionTest13, EmptyClientKeyShare) { + ClientHelloErrorTest(std::make_shared<TlsExtensionTruncator>( + client_, ssl_tls13_key_share_xtn, 2), + kTlsAlertHandshakeFailure); +} + +// These tests only work in stream mode because the client sends a +// cleartext alert which causes a MAC error on the server. With +// stream this causes handshake failure but with datagram, the +// packet gets dropped. +TEST_F(TlsExtensionTest13Stream, DropServerKeyShare) { + EnsureTlsSetup(); + MakeTlsFilter<TlsExtensionDropper>(server_, ssl_tls13_key_share_xtn); + client_->ExpectSendAlert(kTlsAlertMissingExtension); + server_->ExpectSendAlert(kTlsAlertUnexpectedMessage); + ConnectExpectFail(); + EXPECT_EQ(SSL_ERROR_MISSING_KEY_SHARE, client_->error_code()); + EXPECT_EQ(SSL_ERROR_RX_UNEXPECTED_RECORD_TYPE, server_->error_code()); +} + +TEST_F(TlsExtensionTest13Stream, WrongServerKeyShare) { + const uint16_t wrong_group = ssl_grp_ec_secp384r1; + + static const uint8_t key_share[] = { + wrong_group >> 8, + wrong_group & 0xff, // Group we didn't offer. + 0x00, + 0x02, // length = 2 + 0x01, + 0x02}; + DataBuffer buf(key_share, sizeof(key_share)); + EnsureTlsSetup(); + MakeTlsFilter<TlsExtensionReplacer>(server_, ssl_tls13_key_share_xtn, buf); + client_->ExpectSendAlert(kTlsAlertIllegalParameter); + server_->ExpectSendAlert(kTlsAlertUnexpectedMessage); + ConnectExpectFail(); + EXPECT_EQ(SSL_ERROR_RX_MALFORMED_KEY_SHARE, client_->error_code()); + EXPECT_EQ(SSL_ERROR_RX_UNEXPECTED_RECORD_TYPE, server_->error_code()); +} + +TEST_F(TlsExtensionTest13Stream, UnknownServerKeyShare) { + const uint16_t wrong_group = 0xffff; + + static const uint8_t key_share[] = { + wrong_group >> 8, + wrong_group & 0xff, // Group we didn't offer. + 0x00, + 0x02, // length = 2 + 0x01, + 0x02}; + DataBuffer buf(key_share, sizeof(key_share)); + EnsureTlsSetup(); + MakeTlsFilter<TlsExtensionReplacer>(server_, ssl_tls13_key_share_xtn, buf); + client_->ExpectSendAlert(kTlsAlertIllegalParameter); + server_->ExpectSendAlert(kTlsAlertUnexpectedMessage); + ConnectExpectFail(); + EXPECT_EQ(SSL_ERROR_RX_MALFORMED_KEY_SHARE, client_->error_code()); + EXPECT_EQ(SSL_ERROR_RX_UNEXPECTED_RECORD_TYPE, server_->error_code()); +} + +TEST_F(TlsExtensionTest13Stream, AddServerSignatureAlgorithmsOnResumption) { + SetupForResume(); + DataBuffer empty; + MakeTlsFilter<TlsExtensionInjector>(server_, ssl_signature_algorithms_xtn, + empty); + client_->ExpectSendAlert(kTlsAlertIllegalParameter); + server_->ExpectSendAlert(kTlsAlertUnexpectedMessage); + ConnectExpectFail(); + EXPECT_EQ(SSL_ERROR_EXTENSION_DISALLOWED_FOR_VERSION, client_->error_code()); + EXPECT_EQ(SSL_ERROR_RX_UNEXPECTED_RECORD_TYPE, server_->error_code()); +} + +struct PskIdentity { + DataBuffer identity; + uint32_t obfuscated_ticket_age; +}; + +class TlsPreSharedKeyReplacer; + +typedef std::function<void(TlsPreSharedKeyReplacer*)> + TlsPreSharedKeyReplacerFunc; + +class TlsPreSharedKeyReplacer : public TlsExtensionFilter { + public: + TlsPreSharedKeyReplacer(const std::shared_ptr<TlsAgent>& a, + TlsPreSharedKeyReplacerFunc function) + : TlsExtensionFilter(a), identities_(), binders_(), function_(function) {} + + static size_t CopyAndMaybeReplace(TlsParser* parser, size_t size, + const std::unique_ptr<DataBuffer>& replace, + size_t index, DataBuffer* output) { + DataBuffer tmp; + bool ret = parser->ReadVariable(&tmp, size); + EXPECT_EQ(true, ret); + if (!ret) return 0; + if (replace) { + tmp = *replace; + } + + return WriteVariable(output, index, tmp, size); + } + + PacketFilter::Action FilterExtension(uint16_t extension_type, + const DataBuffer& input, + DataBuffer* output) { + if (extension_type != ssl_tls13_pre_shared_key_xtn) { + return KEEP; + } + + if (!Decode(input)) { + return KEEP; + } + + // Call the function. + function_(this); + + Encode(output); + + return CHANGE; + } + + std::vector<PskIdentity> identities_; + std::vector<DataBuffer> binders_; + + private: + bool Decode(const DataBuffer& input) { + std::unique_ptr<TlsParser> parser(new TlsParser(input)); + DataBuffer identities; + + if (!parser->ReadVariable(&identities, 2)) { + ADD_FAILURE(); + return false; + } + + DataBuffer binders; + if (!parser->ReadVariable(&binders, 2)) { + ADD_FAILURE(); + return false; + } + EXPECT_EQ(0UL, parser->remaining()); + + // Now parse the inner sections. + parser.reset(new TlsParser(identities)); + while (parser->remaining()) { + PskIdentity identity; + + if (!parser->ReadVariable(&identity.identity, 2)) { + ADD_FAILURE(); + return false; + } + + if (!parser->Read(&identity.obfuscated_ticket_age, 4)) { + ADD_FAILURE(); + return false; + } + + identities_.push_back(identity); + } + + parser.reset(new TlsParser(binders)); + while (parser->remaining()) { + DataBuffer binder; + + if (!parser->ReadVariable(&binder, 1)) { + ADD_FAILURE(); + return false; + } + + binders_.push_back(binder); + } + + return true; + } + + void Encode(DataBuffer* output) { + DataBuffer identities; + size_t index = 0; + for (auto id : identities_) { + index = WriteVariable(&identities, index, id.identity, 2); + index = identities.Write(index, id.obfuscated_ticket_age, 4); + } + + DataBuffer binders; + index = 0; + for (auto binder : binders_) { + index = WriteVariable(&binders, index, binder, 1); + } + + output->Truncate(0); + index = 0; + index = WriteVariable(output, index, identities, 2); + index = WriteVariable(output, index, binders, 2); + } + + TlsPreSharedKeyReplacerFunc function_; +}; + +TEST_F(TlsExtensionTest13Stream, ResumeEmptyPskLabel) { + SetupForResume(); + + MakeTlsFilter<TlsPreSharedKeyReplacer>( + client_, [](TlsPreSharedKeyReplacer* r) { + r->identities_[0].identity.Truncate(0); + }); + ConnectExpectAlert(server_, kTlsAlertIllegalParameter); + client_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT); + server_->CheckErrorCode(SSL_ERROR_RX_MALFORMED_CLIENT_HELLO); +} + +// Flip the first byte of the binder. +TEST_F(TlsExtensionTest13Stream, ResumeIncorrectBinderValue) { + SetupForResume(); + + MakeTlsFilter<TlsPreSharedKeyReplacer>( + client_, [](TlsPreSharedKeyReplacer* r) { + r->binders_[0].Write(0, r->binders_[0].data()[0] ^ 0xff, 1); + }); + ConnectExpectAlert(server_, kTlsAlertDecryptError); + client_->CheckErrorCode(SSL_ERROR_DECRYPT_ERROR_ALERT); + server_->CheckErrorCode(SSL_ERROR_BAD_HANDSHAKE_HASH_VALUE); +} + +// Do the same with an External PSK. +TEST_P(TlsConnectTls13, TestTls13PskInvalidBinderValue) { + ScopedPK11SlotInfo slot(PK11_GetInternalSlot()); + ASSERT_TRUE(!!slot); + ScopedPK11SymKey key( + PK11_KeyGen(slot.get(), CKM_HKDF_KEY_GEN, nullptr, 16, nullptr)); + ASSERT_TRUE(!!key); + AddPsk(key, std::string("foo"), ssl_hash_sha256); + StartConnect(); + ASSERT_TRUE(client_->MaybeSetResumptionToken()); + + MakeTlsFilter<TlsPreSharedKeyReplacer>( + client_, [](TlsPreSharedKeyReplacer* r) { + r->binders_[0].Write(0, r->binders_[0].data()[0] ^ 0xff, 1); + }); + ConnectExpectAlert(server_, kTlsAlertDecryptError); + client_->CheckErrorCode(SSL_ERROR_DECRYPT_ERROR_ALERT); + server_->CheckErrorCode(SSL_ERROR_BAD_HANDSHAKE_HASH_VALUE); +} + +// Extend the binder by one. +TEST_F(TlsExtensionTest13Stream, ResumeIncorrectBinderLength) { + SetupForResume(); + + MakeTlsFilter<TlsPreSharedKeyReplacer>( + client_, [](TlsPreSharedKeyReplacer* r) { + r->binders_[0].Write(r->binders_[0].len(), 0xff, 1); + }); + ConnectExpectAlert(server_, kTlsAlertIllegalParameter); + client_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT); + server_->CheckErrorCode(SSL_ERROR_RX_MALFORMED_CLIENT_HELLO); +} + +// Binders must be at least 32 bytes. +TEST_F(TlsExtensionTest13Stream, ResumeBinderTooShort) { + SetupForResume(); + + MakeTlsFilter<TlsPreSharedKeyReplacer>( + client_, [](TlsPreSharedKeyReplacer* r) { r->binders_[0].Truncate(31); }); + ConnectExpectAlert(server_, kTlsAlertIllegalParameter); + client_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT); + server_->CheckErrorCode(SSL_ERROR_RX_MALFORMED_CLIENT_HELLO); +} + +// Duplicate the identity and binder. This will fail with an error +// processing the binder (because we extended the identity list.) +TEST_F(TlsExtensionTest13Stream, ResumeTwoPsks) { + SetupForResume(); + + MakeTlsFilter<TlsPreSharedKeyReplacer>( + client_, [](TlsPreSharedKeyReplacer* r) { + r->identities_.push_back(r->identities_[0]); + r->binders_.push_back(r->binders_[0]); + }); + ConnectExpectAlert(server_, kTlsAlertDecryptError); + client_->CheckErrorCode(SSL_ERROR_DECRYPT_ERROR_ALERT); + server_->CheckErrorCode(SSL_ERROR_BAD_HANDSHAKE_HASH_VALUE); +} + +// The next two tests have mismatches in the number of identities +// and binders. This generates an illegal parameter alert. +TEST_F(TlsExtensionTest13Stream, ResumeTwoIdentitiesOneBinder) { + SetupForResume(); + + MakeTlsFilter<TlsPreSharedKeyReplacer>( + client_, [](TlsPreSharedKeyReplacer* r) { + r->identities_.push_back(r->identities_[0]); + }); + ConnectExpectAlert(server_, kTlsAlertIllegalParameter); + client_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT); + server_->CheckErrorCode(SSL_ERROR_RX_MALFORMED_CLIENT_HELLO); +} + +TEST_F(TlsExtensionTest13Stream, ResumeOneIdentityTwoBinders) { + SetupForResume(); + + MakeTlsFilter<TlsPreSharedKeyReplacer>( + client_, [](TlsPreSharedKeyReplacer* r) { + r->binders_.push_back(r->binders_[0]); + }); + ConnectExpectAlert(server_, kTlsAlertIllegalParameter); + client_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT); + server_->CheckErrorCode(SSL_ERROR_RX_MALFORMED_CLIENT_HELLO); +} + +TEST_F(TlsExtensionTest13Stream, ResumePskExtensionNotLast) { + SetupForResume(); + + const uint8_t empty_buf[] = {0}; + DataBuffer empty(empty_buf, 0); + // Inject an unused extension after the PSK extension. + MakeTlsFilter<TlsExtensionAppender>(client_, kTlsHandshakeClientHello, 0xffff, + empty); + ConnectExpectAlert(server_, kTlsAlertIllegalParameter); + client_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT); + server_->CheckErrorCode(SSL_ERROR_RX_MALFORMED_CLIENT_HELLO); +} + +TEST_F(TlsExtensionTest13Stream, ResumeNoKeModes) { + SetupForResume(); + + DataBuffer empty; + MakeTlsFilter<TlsExtensionDropper>(client_, + ssl_tls13_psk_key_exchange_modes_xtn); + ConnectExpectAlert(server_, kTlsAlertMissingExtension); + client_->CheckErrorCode(SSL_ERROR_MISSING_EXTENSION_ALERT); + server_->CheckErrorCode(SSL_ERROR_MISSING_PSK_KEY_EXCHANGE_MODES); +} + +// The following test contains valid but unacceptable PreSharedKey +// modes and therefore produces non-resumption followed by MAC +// errors. +TEST_F(TlsExtensionTest13Stream, ResumeBogusKeModes) { + SetupForResume(); + const static uint8_t ke_modes[] = {1, // Length + kTls13PskKe}; + + DataBuffer modes(ke_modes, sizeof(ke_modes)); + MakeTlsFilter<TlsExtensionReplacer>( + client_, ssl_tls13_psk_key_exchange_modes_xtn, modes); + client_->ExpectSendAlert(kTlsAlertBadRecordMac); + server_->ExpectSendAlert(kTlsAlertBadRecordMac); + ConnectExpectFail(); + client_->CheckErrorCode(SSL_ERROR_BAD_MAC_READ); + server_->CheckErrorCode(SSL_ERROR_BAD_MAC_READ); +} + +TEST_P(TlsExtensionTest13, NoKeModesIfResumptionOff) { + ConfigureSessionCache(RESUME_NONE, RESUME_NONE); + auto capture = MakeTlsFilter<TlsExtensionCapture>( + client_, ssl_tls13_psk_key_exchange_modes_xtn); + Connect(); + EXPECT_FALSE(capture->captured()); +} + +// In these tests, we downgrade to TLS 1.2, causing the +// server to negotiate TLS 1.2. +// 1. Both sides only support TLS 1.3, so we get a cipher version +// error. +TEST_P(TlsExtensionTest13, RemoveTls13FromVersionList) { + ExpectAlert(server_, kTlsAlertProtocolVersion); + ConnectWithReplacementVersionList(SSL_LIBRARY_VERSION_TLS_1_2); + client_->CheckErrorCode(SSL_ERROR_PROTOCOL_VERSION_ALERT); + server_->CheckErrorCode(SSL_ERROR_UNSUPPORTED_VERSION); +} + +// 2. Server supports 1.2 and 1.3, client supports 1.2, so we +// can't negotiate any ciphers. +TEST_P(TlsExtensionTest13, RemoveTls13FromVersionListServerV12) { + server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2, + SSL_LIBRARY_VERSION_TLS_1_3); + ExpectAlert(server_, kTlsAlertHandshakeFailure); + ConnectWithReplacementVersionList(SSL_LIBRARY_VERSION_TLS_1_2); + client_->CheckErrorCode(SSL_ERROR_NO_CYPHER_OVERLAP); + server_->CheckErrorCode(SSL_ERROR_NO_CYPHER_OVERLAP); +} + +// 3. Server supports 1.2 and 1.3, client supports 1.2 and 1.3 +// but advertises 1.2 (because we changed things). +TEST_P(TlsExtensionTest13, RemoveTls13FromVersionListBothV12) { + client_->SetOption(SSL_ENABLE_HELLO_DOWNGRADE_CHECK, PR_TRUE); + client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2, + SSL_LIBRARY_VERSION_TLS_1_3); + server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2, + SSL_LIBRARY_VERSION_TLS_1_3); +// The downgrade check is disabled in DTLS 1.3, so all that happens when we +// tamper with the supported versions is that the Finished check fails. +#ifdef DTLS_1_3_DRAFT_VERSION + if (variant_ == ssl_variant_datagram) { + ExpectAlert(server_, kTlsAlertDecryptError); + } else +#endif + { + ExpectAlert(client_, kTlsAlertIllegalParameter); + } + ConnectWithReplacementVersionList(SSL_LIBRARY_VERSION_TLS_1_2); +#ifdef DTLS_1_3_DRAFT_VERSION + if (variant_ == ssl_variant_datagram) { + client_->CheckErrorCode(SSL_ERROR_DECRYPT_ERROR_ALERT); + server_->CheckErrorCode(SSL_ERROR_BAD_HANDSHAKE_HASH_VALUE); + } else +#endif + { + client_->CheckErrorCode(SSL_ERROR_RX_MALFORMED_SERVER_HELLO); + server_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT); + } +} + +TEST_P(TlsExtensionTest13, HrrThenRemoveSignatureAlgorithms) { + ExpectAlert(server_, kTlsAlertMissingExtension); + HrrThenRemoveExtensionsTest(ssl_signature_algorithms_xtn, + SSL_ERROR_MISSING_EXTENSION_ALERT, + SSL_ERROR_MISSING_SIGNATURE_ALGORITHMS_EXTENSION); +} + +TEST_P(TlsExtensionTest13, HrrThenRemoveKeyShare) { + ExpectAlert(server_, kTlsAlertIllegalParameter); + HrrThenRemoveExtensionsTest(ssl_tls13_key_share_xtn, + SSL_ERROR_ILLEGAL_PARAMETER_ALERT, + SSL_ERROR_BAD_2ND_CLIENT_HELLO); +} + +TEST_P(TlsExtensionTest13, HrrThenRemoveSupportedGroups) { + ExpectAlert(server_, kTlsAlertMissingExtension); + HrrThenRemoveExtensionsTest(ssl_supported_groups_xtn, + SSL_ERROR_MISSING_EXTENSION_ALERT, + SSL_ERROR_MISSING_SUPPORTED_GROUPS_EXTENSION); +} + +TEST_P(TlsExtensionTest13, EmptyVersionList) { + static const uint8_t kExt[] = {0x00, 0x00}; + ConnectWithBogusVersionList(kExt, sizeof(kExt)); +} + +TEST_P(TlsExtensionTest13, OddVersionList) { + static const uint8_t kExt[] = {0x00, 0x01, 0x00}; + ConnectWithBogusVersionList(kExt, sizeof(kExt)); +} + +TEST_P(TlsExtensionTest13, SignatureAlgorithmsInvalidTls13) { + // testing the case where we ask for a invalid parameter for tls13 + const uint8_t val[] = {0x00, 0x02, 0x04, 0x01}; // sha-256, rsa-pkcs1 + DataBuffer extension(val, sizeof(val)); + ClientHelloErrorTest(std::make_shared<TlsExtensionReplacer>( + client_, ssl_signature_algorithms_xtn, extension), + kTlsAlertHandshakeFailure); +} + +// Use the stream version number for TLS 1.3 (0x0304) in DTLS. +TEST_F(TlsConnectDatagram13, TlsVersionInDtls) { + static const uint8_t kExt[] = {0x02, 0x03, 0x04}; + + DataBuffer versions_buf(kExt, sizeof(kExt)); + MakeTlsFilter<TlsExtensionReplacer>(client_, ssl_tls13_supported_versions_xtn, + versions_buf); + ConnectExpectAlert(server_, kTlsAlertProtocolVersion); + client_->CheckErrorCode(SSL_ERROR_PROTOCOL_VERSION_ALERT); + server_->CheckErrorCode(SSL_ERROR_UNSUPPORTED_VERSION); +} + +// TODO: this only tests extensions in server messages. The client can extend +// Certificate messages, which is not checked here. +class TlsBogusExtensionTest : public TlsConnectTestBase, + public ::testing::WithParamInterface< + std::tuple<SSLProtocolVariant, uint16_t>> { + public: + TlsBogusExtensionTest() + : TlsConnectTestBase(std::get<0>(GetParam()), std::get<1>(GetParam())) {} + + protected: + virtual void ConnectAndFail(uint8_t message) = 0; + + void AddFilter(uint8_t message, uint16_t extension) { + static uint8_t empty_buf[1] = {0}; + DataBuffer empty(empty_buf, 0); + auto filter = + MakeTlsFilter<TlsExtensionAppender>(server_, message, extension, empty); + if (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) { + filter->EnableDecryption(); + } + } + + void Run(uint8_t message, uint16_t extension = 0xff) { + EnsureTlsSetup(); + AddFilter(message, extension); + ConnectAndFail(message); + } +}; + +class TlsBogusExtensionTestPre13 : public TlsBogusExtensionTest { + protected: + void ConnectAndFail(uint8_t) override { + ConnectExpectAlert(client_, kTlsAlertUnsupportedExtension); + } +}; + +class TlsBogusExtensionTest13 : public TlsBogusExtensionTest { + protected: + void ConnectAndFail(uint8_t message) override { + if (message != kTlsHandshakeServerHello) { + ConnectExpectAlert(client_, kTlsAlertUnsupportedExtension); + return; + } + + FailWithAlert(kTlsAlertUnsupportedExtension); + } + + void FailWithAlert(uint8_t alert) { + StartConnect(); + client_->Handshake(); // ClientHello + server_->Handshake(); // ServerHello + + client_->ExpectSendAlert(alert); + client_->Handshake(); + if (variant_ == ssl_variant_stream) { + server_->ExpectSendAlert(kTlsAlertUnexpectedMessage); + } + server_->Handshake(); + } +}; + +TEST_P(TlsBogusExtensionTestPre13, AddBogusExtensionServerHello) { + Run(kTlsHandshakeServerHello); +} + +TEST_P(TlsBogusExtensionTest13, AddBogusExtensionServerHello) { + Run(kTlsHandshakeServerHello); +} + +TEST_P(TlsBogusExtensionTest13, AddBogusExtensionEncryptedExtensions) { + Run(kTlsHandshakeEncryptedExtensions); +} + +TEST_P(TlsBogusExtensionTest13, AddBogusExtensionCertificate) { + Run(kTlsHandshakeCertificate); +} + +// It's perfectly valid to set unknown extensions in CertificateRequest. +TEST_P(TlsBogusExtensionTest13, AddBogusExtensionCertificateRequest) { + server_->RequestClientAuth(false); + AddFilter(kTlsHandshakeCertificateRequest, 0xff); + ConnectExpectAlert(client_, kTlsAlertDecryptError); + client_->CheckErrorCode(SEC_ERROR_BAD_SIGNATURE); +} + +TEST_P(TlsBogusExtensionTest13, AddBogusExtensionHelloRetryRequest) { + static const std::vector<SSLNamedGroup> groups = {ssl_grp_ec_secp384r1}; + server_->ConfigNamedGroups(groups); + + Run(kTlsHandshakeHelloRetryRequest); +} + +// NewSessionTicket allows unknown extensions AND it isn't protected by the +// Finished. So adding an unknown extension doesn't cause an error. +TEST_P(TlsBogusExtensionTest13, AddBogusExtensionNewSessionTicket) { + ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET); + + AddFilter(kTlsHandshakeNewSessionTicket, 0xff); + Connect(); + SendReceive(); + CheckKeys(); + + Reset(); + ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET); + ExpectResumption(RESUME_TICKET); + Connect(); + SendReceive(); +} + +class TlsDisallowedExtensionTest13 : public TlsBogusExtensionTest { + protected: + void ConnectAndFail(uint8_t message) override { + ConnectExpectAlert(client_, kTlsAlertIllegalParameter); + } +}; + +TEST_P(TlsDisallowedExtensionTest13, AddVersionExtensionEncryptedExtensions) { + Run(kTlsHandshakeEncryptedExtensions, ssl_tls13_supported_versions_xtn); +} + +TEST_P(TlsDisallowedExtensionTest13, AddVersionExtensionCertificate) { + Run(kTlsHandshakeCertificate, ssl_tls13_supported_versions_xtn); +} + +TEST_P(TlsDisallowedExtensionTest13, AddVersionExtensionCertificateRequest) { + server_->RequestClientAuth(false); + Run(kTlsHandshakeCertificateRequest, ssl_tls13_supported_versions_xtn); +} + +/* For unadvertised disallowed extensions an unsupported_extension alert is + * thrown since NSS checks for unadvertised extensions before its disallowed + * extension check. */ +class TlsDisallowedUnadvertisedExtensionTest13 : public TlsBogusExtensionTest { + protected: + void ConnectAndFail(uint8_t message) override { + uint8_t alert = kTlsAlertUnsupportedExtension; + if (message == kTlsHandshakeCertificateRequest) { + alert = kTlsAlertIllegalParameter; + } + ConnectExpectAlert(client_, alert); + } +}; + +TEST_P(TlsDisallowedUnadvertisedExtensionTest13, + AddPSKExtensionEncryptedExtensions) { + Run(kTlsHandshakeEncryptedExtensions, ssl_tls13_pre_shared_key_xtn); +} + +TEST_P(TlsDisallowedUnadvertisedExtensionTest13, AddPSKExtensionCertificate) { + Run(kTlsHandshakeCertificate, ssl_tls13_pre_shared_key_xtn); +} + +TEST_P(TlsDisallowedUnadvertisedExtensionTest13, + AddPSKExtensionCertificateRequest) { + server_->RequestClientAuth(false); + Run(kTlsHandshakeCertificateRequest, ssl_tls13_pre_shared_key_xtn); +} + +TEST_P(TlsConnectStream, IncludePadding) { + EnsureTlsSetup(); + SSL_EnableTls13GreaseEch(client_->ssl_fd(), PR_FALSE); // Don't GREASE + + // This needs to be long enough to push a TLS 1.0 ClientHello over 255, but + // short enough not to push a TLS 1.3 ClientHello over 511. + static const char* long_name = + "chickenchickenchickenchickenchickenchickenchickenchicken." + "chickenchickenchickenchickenchickenchickenchickenchicken." + "chickenchickenchickenchickenchicken."; + SECStatus rv = SSL_SetURL(client_->ssl_fd(), long_name); + EXPECT_EQ(SECSuccess, rv); + + auto capture = MakeTlsFilter<TlsExtensionCapture>(client_, ssl_padding_xtn); + client_->StartConnect(); + client_->Handshake(); + EXPECT_TRUE(capture->captured()); +} + +TEST_F(TlsConnectDatagram13, Dtls13RejectLegacyCookie) { + EnsureTlsSetup(); + MakeTlsFilter<Dtls13LegacyCookieInjector>(client_); + ConnectExpectAlert(server_, kTlsAlertIllegalParameter); + server_->CheckErrorCode(SSL_ERROR_RX_MALFORMED_CLIENT_HELLO); + client_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT); +} + +TEST_P(TlsConnectGeneric, ClientHelloExtensionPermutation) { + EnsureTlsSetup(); + PR_ASSERT(SSL_OptionSet(client_->ssl_fd(), + SSL_ENABLE_CH_EXTENSION_PERMUTATION, + PR_TRUE) == SECSuccess); + Connect(); +} + +TEST_F(TlsConnectStreamTls13, ClientHelloExtensionPermutationWithPSK) { + EnsureTlsSetup(); + + ScopedPK11SlotInfo slot(PK11_GetInternalSlot()); + const uint8_t kPskDummyVal_[16] = {0x01, 0x02, 0x03, 0x04, 0x05, + 0x06, 0x07, 0x08, 0x09, 0x0a, + 0x0b, 0x0c, 0x0d, 0x0e, 0x0f}; + SECItem psk_item; + psk_item.type = siBuffer; + psk_item.len = sizeof(kPskDummyVal_); + psk_item.data = const_cast<uint8_t*>(kPskDummyVal_); + PK11SymKey* key = + PK11_ImportSymKey(slot.get(), CKM_HKDF_KEY_GEN, PK11_OriginUnwrap, + CKA_DERIVE, &psk_item, NULL); + + ScopedPK11SymKey scoped_psk_(key); + const std::string kPskDummyLabel_ = "NSS PSK GTEST label"; + const SSLHashType kPskHash_ = ssl_hash_sha384; + AddPsk(scoped_psk_, kPskDummyLabel_, kPskHash_); + + PR_ASSERT(SSL_OptionSet(client_->ssl_fd(), + SSL_ENABLE_CH_EXTENSION_PERMUTATION, + PR_TRUE) == SECSuccess); + Connect(); + SendReceive(); + CheckKeys(ssl_kea_ecdh, ssl_grp_ec_curve25519, ssl_auth_psk, ssl_sig_none); +} + +/* This test checks that the ClientHello extension order is actually permuted + * if ss->opt.chXtnPermutation is set. It is asserted that at least one out of + * 10 extension orders differs from the others. + * + * This is a probabilistic test: The default TLS 1.3 ClientHello contains 8 + * extensions, leading to a 1/8! probability for any extension order and the + * same probability for two drawn extension orders to coincide. + * Since all sequences are compared against each other this leads to a false + * positive rate of (1/8!)^(n^2-n). + * To achieve a spurious failure rate << 1/2^64, we compare n=10 drawn orders. + * + * This test assures that randomisation is happening but does not check quality + * of the used Fisher-Yates shuffle. */ +TEST_F(TlsConnectStreamTls13, + ClientHelloExtensionPermutationProbabilisticTest) { + std::vector<std::vector<uint16_t>> orders; + + /* Capture the extension order of 10 ClientHello messages. */ + for (size_t i = 0; i < 10; i++) { + client_->StartConnect(); + /* Enable ClientHello extension permutation. */ + ASSERT_TRUE(SSL_OptionSet(client_->ssl_fd(), + SSL_ENABLE_CH_EXTENSION_PERMUTATION, + PR_TRUE) == SECSuccess); + /* Capture extension order filter. */ + auto filter = MakeTlsFilter<TlsExtensionOrderCapture>( + client_, kTlsHandshakeClientHello); + /* Send ClientHello. */ + client_->Handshake(); + /* Remember extension order. */ + orders.push_back(filter->order); + /* Reset client / server state. */ + Reset(); + } + + /* Check for extension order inequality. */ + size_t inequal = 0; + for (auto& outerOrders : orders) { + for (auto& innerOrders : orders) { + if (outerOrders != innerOrders) { + inequal++; + } + } + } + ASSERT_TRUE(inequal >= 1); +} + +// The certificate_authorities xtn can be included in a ClientHello [RFC 8446, +// Section 4.2] +TEST_F(TlsConnectStreamTls13, ClientHelloCertAuthXtnToleration) { + EnsureTlsSetup(); + uint8_t bodyBuf[3] = {0x00, 0x01, 0xff}; + DataBuffer body(bodyBuf, sizeof(bodyBuf)); + auto ch = MakeTlsFilter<TlsExtensionAppender>( + client_, kTlsHandshakeClientHello, ssl_tls13_certificate_authorities_xtn, + body); + // The Connection will fail because the added extension isn't in the client's + // transcript not because the extension is unsupported (Bug 1815167). + server_->ExpectSendAlert(bad_record_mac); + client_->ExpectSendAlert(bad_record_mac); + ConnectExpectFail(); + server_->CheckErrorCode(SSL_ERROR_BAD_MAC_READ); + client_->CheckErrorCode(SSL_ERROR_BAD_MAC_READ); +} + +INSTANTIATE_TEST_SUITE_P( + ExtensionStream, TlsExtensionTestGeneric, + ::testing::Combine(TlsConnectTestBase::kTlsVariantsStream, + TlsConnectTestBase::kTlsVAll)); +INSTANTIATE_TEST_SUITE_P( + ExtensionDatagram, TlsExtensionTestGeneric, + ::testing::Combine(TlsConnectTestBase::kTlsVariantsDatagram, + TlsConnectTestBase::kTlsV11Plus)); +INSTANTIATE_TEST_SUITE_P(ExtensionDatagramOnly, TlsExtensionTestDtls, + TlsConnectTestBase::kTlsV11Plus); + +INSTANTIATE_TEST_SUITE_P(ExtensionTls12, TlsExtensionTest12, + ::testing::Combine(TlsConnectTestBase::kTlsVariantsAll, + TlsConnectTestBase::kTlsV12)); + +INSTANTIATE_TEST_SUITE_P(ExtensionTls12Plus, TlsExtensionTest12Plus, + ::testing::Combine(TlsConnectTestBase::kTlsVariantsAll, + TlsConnectTestBase::kTlsV12Plus)); + +INSTANTIATE_TEST_SUITE_P( + ExtensionPre13Stream, TlsExtensionTestPre13, + ::testing::Combine(TlsConnectTestBase::kTlsVariantsStream, + TlsConnectTestBase::kTlsV10ToV12)); +INSTANTIATE_TEST_SUITE_P(ExtensionPre13Datagram, TlsExtensionTestPre13, + ::testing::Combine(TlsConnectTestBase::kTlsVariantsAll, + TlsConnectTestBase::kTlsV11V12)); + +INSTANTIATE_TEST_SUITE_P(ExtensionTls13, TlsExtensionTest13, + TlsConnectTestBase::kTlsVariantsAll); + +INSTANTIATE_TEST_SUITE_P( + BogusExtensionStream, TlsBogusExtensionTestPre13, + ::testing::Combine(TlsConnectTestBase::kTlsVariantsStream, + TlsConnectTestBase::kTlsV10ToV12)); +INSTANTIATE_TEST_SUITE_P( + BogusExtensionDatagram, TlsBogusExtensionTestPre13, + ::testing::Combine(TlsConnectTestBase::kTlsVariantsDatagram, + TlsConnectTestBase::kTlsV11V12)); + +INSTANTIATE_TEST_SUITE_P(BogusExtension13, TlsBogusExtensionTest13, + ::testing::Combine(TlsConnectTestBase::kTlsVariantsAll, + TlsConnectTestBase::kTlsV13)); + +INSTANTIATE_TEST_SUITE_P(DisallowedExtension13, TlsDisallowedExtensionTest13, + ::testing::Combine(TlsConnectTestBase::kTlsVariantsAll, + TlsConnectTestBase::kTlsV13)); + +INSTANTIATE_TEST_SUITE_P(DisallowedUnadvertisedExtension13, + TlsDisallowedUnadvertisedExtensionTest13, + ::testing::Combine(TlsConnectTestBase::kTlsVariantsAll, + TlsConnectTestBase::kTlsV13)); + +} // namespace nss_test diff --git a/security/nss/gtests/ssl_gtest/ssl_fragment_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_fragment_unittest.cc new file mode 100644 index 0000000000..3752812633 --- /dev/null +++ b/security/nss/gtests/ssl_gtest/ssl_fragment_unittest.cc @@ -0,0 +1,169 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set ts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this file, + * You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#include "secerr.h" +#include "ssl.h" +#include "sslerr.h" +#include "sslproto.h" + +#include "gtest_utils.h" +#include "nss_scoped_ptrs.h" +#include "tls_connect.h" +#include "tls_filter.h" +#include "tls_parser.h" + +namespace nss_test { + +// This class cuts every unencrypted handshake record into two parts. +class RecordFragmenter : public PacketFilter { + public: + RecordFragmenter(bool is_dtls13) + : is_dtls13_(is_dtls13), sequence_number_(0), splitting_(true) {} + + private: + class HandshakeSplitter { + public: + HandshakeSplitter(bool is_dtls13, const DataBuffer& input, + DataBuffer* output, uint64_t* sequence_number) + : is_dtls13_(is_dtls13), + input_(input), + output_(output), + cursor_(0), + sequence_number_(sequence_number) {} + + private: + void WriteRecord(TlsRecordHeader& record_header, + DataBuffer& record_fragment) { + TlsRecordHeader fragment_header( + record_header.variant(), record_header.version(), + record_header.content_type(), *sequence_number_); + ++*sequence_number_; + if (::g_ssl_gtest_verbose) { + std::cerr << "Fragment: " << fragment_header << ' ' << record_fragment + << std::endl; + } + cursor_ = fragment_header.Write(output_, cursor_, record_fragment); + } + + bool SplitRecord(TlsRecordHeader& record_header, DataBuffer& record) { + TlsParser parser(record); + while (parser.remaining()) { + TlsHandshakeFilter::HandshakeHeader handshake_header; + DataBuffer handshake_body; + bool complete = false; + if (!handshake_header.Parse(&parser, record_header, DataBuffer(), + &handshake_body, &complete)) { + ADD_FAILURE() << "couldn't parse handshake header"; + return false; + } + if (!complete) { + ADD_FAILURE() << "don't want to deal with fragmented messages"; + return false; + } + + DataBuffer record_fragment; + // We can't fragment handshake records that are too small. + if (handshake_body.len() < 2) { + handshake_header.Write(&record_fragment, 0U, handshake_body); + WriteRecord(record_header, record_fragment); + continue; + } + + size_t cut = handshake_body.len() / 2; + handshake_header.WriteFragment(&record_fragment, 0U, handshake_body, 0U, + cut); + WriteRecord(record_header, record_fragment); + + handshake_header.WriteFragment(&record_fragment, 0U, handshake_body, + cut, handshake_body.len() - cut); + WriteRecord(record_header, record_fragment); + } + return true; + } + + public: + bool Split() { + TlsParser parser(input_); + while (parser.remaining()) { + TlsRecordHeader header; + DataBuffer record; + if (!header.Parse(is_dtls13_, 0, &parser, &record)) { + ADD_FAILURE() << "bad record header"; + return false; + } + + if (::g_ssl_gtest_verbose) { + std::cerr << "Record: " << header << ' ' << record << std::endl; + } + + // Don't touch packets from a non-zero epoch. Leave these unmodified. + if ((header.sequence_number() >> 48) != 0ULL) { + cursor_ = header.Write(output_, cursor_, record); + continue; + } + + // Just rewrite the sequence number (CCS only). + if (header.content_type() != ssl_ct_handshake) { + EXPECT_EQ(ssl_ct_change_cipher_spec, header.content_type()); + WriteRecord(header, record); + continue; + } + + if (!SplitRecord(header, record)) { + return false; + } + } + return true; + } + + private: + bool is_dtls13_; + const DataBuffer& input_; + DataBuffer* output_; + size_t cursor_; + uint64_t* sequence_number_; + }; + + protected: + virtual PacketFilter::Action Filter(const DataBuffer& input, + DataBuffer* output) override { + if (!splitting_) { + return KEEP; + } + + output->Allocate(input.len()); + HandshakeSplitter splitter(is_dtls13_, input, output, &sequence_number_); + if (!splitter.Split()) { + // If splitting fails, we obviously reached encrypted packets. + // Stop splitting from that point onward. + splitting_ = false; + return KEEP; + } + + return CHANGE; + } + + private: + bool is_dtls13_; + uint64_t sequence_number_; + bool splitting_; +}; + +TEST_P(TlsConnectDatagram, FragmentClientPackets) { + bool is_dtls13 = version_ >= SSL_LIBRARY_VERSION_TLS_1_3; + client_->SetFilter(std::make_shared<RecordFragmenter>(is_dtls13)); + Connect(); + SendReceive(); +} + +TEST_P(TlsConnectDatagram, FragmentServerPackets) { + bool is_dtls13 = version_ >= SSL_LIBRARY_VERSION_TLS_1_3; + server_->SetFilter(std::make_shared<RecordFragmenter>(is_dtls13)); + Connect(); + SendReceive(); +} + +} // namespace nss_test diff --git a/security/nss/gtests/ssl_gtest/ssl_fuzz_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_fuzz_unittest.cc new file mode 100644 index 0000000000..ef6f7602cf --- /dev/null +++ b/security/nss/gtests/ssl_gtest/ssl_fuzz_unittest.cc @@ -0,0 +1,252 @@ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this file, + * You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#include "blapi.h" +#include "ssl.h" +#include "sslimpl.h" +#include "tls_connect.h" + +#include "gtest/gtest.h" + +namespace nss_test { + +#ifdef UNSAFE_FUZZER_MODE +#define FUZZ_F(c, f) TEST_F(c, Fuzz_##f) +#define FUZZ_P(c, f) TEST_P(c, Fuzz_##f) +#else +#define FUZZ_F(c, f) TEST_F(c, DISABLED_Fuzz_##f) +#define FUZZ_P(c, f) TEST_P(c, DISABLED_Fuzz_##f) +#endif + +const uint8_t kShortEmptyFinished[8] = {0}; +const uint8_t kLongEmptyFinished[128] = {0}; + +class TlsFuzzTest : public TlsConnectGeneric {}; + +// Record the application data stream. +class TlsApplicationDataRecorder : public TlsRecordFilter { + public: + TlsApplicationDataRecorder(const std::shared_ptr<TlsAgent>& a) + : TlsRecordFilter(a), buffer_() {} + + virtual PacketFilter::Action FilterRecord(const TlsRecordHeader& header, + const DataBuffer& input, + DataBuffer* output) { + if (header.content_type() == ssl_ct_application_data) { + buffer_.Append(input); + } + + return KEEP; + } + + const DataBuffer& buffer() const { return buffer_; } + + private: + DataBuffer buffer_; +}; + +// Check that due to the deterministic PRNG we derive +// the same master secret in two consecutive TLS sessions. +FUZZ_P(TlsFuzzTest, DeterministicExporter) { + const char kLabel[] = "label"; + std::vector<unsigned char> out1(32), out2(32); + + // Make sure we have RSA blinding params. + Connect(); + + Reset(); + ConfigureSessionCache(RESUME_NONE, RESUME_NONE); + + // Reset the RNG state. + EXPECT_EQ(SECSuccess, RNG_RandomUpdate(NULL, 0)); + Connect(); + + // Export a key derived from the MS and nonces. + SECStatus rv = + SSL_ExportKeyingMaterial(client_->ssl_fd(), kLabel, strlen(kLabel), false, + NULL, 0, out1.data(), out1.size()); + EXPECT_EQ(SECSuccess, rv); + + Reset(); + ConfigureSessionCache(RESUME_NONE, RESUME_NONE); + + // Reset the RNG state. + EXPECT_EQ(SECSuccess, RNG_RandomUpdate(NULL, 0)); + Connect(); + + // Export another key derived from the MS and nonces. + rv = SSL_ExportKeyingMaterial(client_->ssl_fd(), kLabel, strlen(kLabel), + false, NULL, 0, out2.data(), out2.size()); + EXPECT_EQ(SECSuccess, rv); + + // The two exported keys should be the same. + EXPECT_EQ(out1, out2); +} + +// Check that due to the deterministic RNG two consecutive +// TLS sessions will have the exact same transcript. +FUZZ_P(TlsFuzzTest, DeterministicTranscript) { + // Make sure we have RSA blinding params. + Connect(); + + // Connect a few times and compare the transcripts byte-by-byte. + DataBuffer last; + for (size_t i = 0; i < 5; i++) { + Reset(); + ConfigureSessionCache(RESUME_NONE, RESUME_NONE); + + DataBuffer buffer; + MakeTlsFilter<TlsConversationRecorder>(client_, buffer); + MakeTlsFilter<TlsConversationRecorder>(server_, buffer); + + // Reset the RNG state. + EXPECT_EQ(SECSuccess, RNG_RandomUpdate(NULL, 0)); + Connect(); + + // Ensure the filters go away before |buffer| does. + client_->ClearFilter(); + server_->ClearFilter(); + + if (last.len() > 0) { + EXPECT_EQ(last, buffer); + } + + last = buffer; + } +} + +// Check that we can establish and use a connection +// with all supported TLS versions, STREAM and DGRAM. +// Check that records are NOT encrypted. +// Check that records don't have a MAC. +FUZZ_P(TlsFuzzTest, ConnectSendReceive_NullCipher) { + // Set up app data filters. + auto client_recorder = MakeTlsFilter<TlsApplicationDataRecorder>(client_); + auto server_recorder = MakeTlsFilter<TlsApplicationDataRecorder>(server_); + + Connect(); + + // Construct the plaintext. + DataBuffer buf; + buf.Allocate(50); + for (size_t i = 0; i < buf.len(); ++i) { + buf.data()[i] = i & 0xff; + } + + // Send/Receive data. + client_->SendBuffer(buf); + server_->SendBuffer(buf); + Receive(buf.len()); + + // Check for plaintext on the wire. + EXPECT_EQ(buf, client_recorder->buffer()); + EXPECT_EQ(buf, server_recorder->buffer()); +} + +// Check that an invalid Finished message doesn't abort the connection. +FUZZ_P(TlsFuzzTest, BogusClientFinished) { + EnsureTlsSetup(); + + MakeTlsFilter<TlsInspectorReplaceHandshakeMessage>( + client_, kTlsHandshakeFinished, + DataBuffer(kShortEmptyFinished, sizeof(kShortEmptyFinished))); + Connect(); + SendReceive(); +} + +// Check that an invalid Finished message doesn't abort the connection. +FUZZ_P(TlsFuzzTest, BogusServerFinished) { + EnsureTlsSetup(); + + MakeTlsFilter<TlsInspectorReplaceHandshakeMessage>( + server_, kTlsHandshakeFinished, + DataBuffer(kLongEmptyFinished, sizeof(kLongEmptyFinished))); + Connect(); + SendReceive(); +} + +// Check that an invalid server auth signature doesn't abort the connection. +FUZZ_P(TlsFuzzTest, BogusServerAuthSignature) { + EnsureTlsSetup(); + uint8_t msg_type = version_ == SSL_LIBRARY_VERSION_TLS_1_3 + ? kTlsHandshakeCertificateVerify + : kTlsHandshakeServerKeyExchange; + MakeTlsFilter<TlsLastByteDamager>(server_, msg_type); + Connect(); + SendReceive(); +} + +// Check that an invalid client auth signature doesn't abort the connection. +FUZZ_P(TlsFuzzTest, BogusClientAuthSignature) { + EnsureTlsSetup(); + client_->SetupClientAuth(); + server_->RequestClientAuth(true); + MakeTlsFilter<TlsLastByteDamager>(client_, kTlsHandshakeCertificateVerify); + Connect(); +} + +// Check that session ticket resumption works. +FUZZ_P(TlsFuzzTest, SessionTicketResumption) { + ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET); + Connect(); + SendReceive(); + + Reset(); + ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET); + ExpectResumption(RESUME_TICKET); + Connect(); + SendReceive(); +} + +// Check that session tickets are not encrypted. +FUZZ_P(TlsFuzzTest, UnencryptedSessionTickets) { + ConfigureSessionCache(RESUME_TICKET, RESUME_TICKET); + + auto filter = MakeTlsFilter<TlsHandshakeRecorder>( + server_, kTlsHandshakeNewSessionTicket); + Connect(); + + std::cerr << "ticket" << filter->buffer() << std::endl; + size_t offset = 4; // Skip lifetime. + + if (version_ == SSL_LIBRARY_VERSION_TLS_1_3) { + offset += 4; // Skip ticket_age_add. + uint32_t nonce_len = 0; + EXPECT_TRUE(filter->buffer().Read(offset, 1, &nonce_len)); + offset += 1 + nonce_len; + } + + offset += 2; // Skip the ticket length. + + // This bit parses the contents of the ticket, which would ordinarily be + // encrypted. Start by checking that we have the right version. This needs + // to be updated every time that TLS_EX_SESS_TICKET_VERSION is changed. But + // we don't use the #define. That way, any time that code is updated, this + // test will fail unless it is manually checked. + uint32_t ticket_version; + EXPECT_TRUE(filter->buffer().Read(offset, 2, &ticket_version)); + EXPECT_EQ(0x010aU, ticket_version); + offset += 2; + + // Check the protocol version number. + uint32_t tls_version = 0; + EXPECT_TRUE(filter->buffer().Read(offset, sizeof(version_), &tls_version)); + EXPECT_EQ(version_, static_cast<decltype(version_)>(tls_version)); + offset += sizeof(version_); + + // Check the cipher suite. + uint32_t suite = 0; + EXPECT_TRUE(filter->buffer().Read(offset, 2, &suite)); + client_->CheckCipherSuite(static_cast<uint16_t>(suite)); +} + +INSTANTIATE_TEST_SUITE_P( + FuzzStream, TlsFuzzTest, + ::testing::Combine(TlsConnectTestBase::kTlsVariantsStream, + TlsConnectTestBase::kTlsVAll)); +INSTANTIATE_TEST_SUITE_P( + FuzzDatagram, TlsFuzzTest, + ::testing::Combine(TlsConnectTestBase::kTlsVariantsDatagram, + TlsConnectTestBase::kTlsV11Plus)); +} // namespace nss_test diff --git a/security/nss/gtests/ssl_gtest/ssl_gather_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_gather_unittest.cc new file mode 100644 index 0000000000..2b0b722ae2 --- /dev/null +++ b/security/nss/gtests/ssl_gtest/ssl_gather_unittest.cc @@ -0,0 +1,156 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set ts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this file, + * You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#include "gtest_utils.h" +#include "tls_connect.h" + +namespace nss_test { + +class GatherV2ClientHelloTest : public TlsConnectTestBase { + public: + GatherV2ClientHelloTest() : TlsConnectTestBase(ssl_variant_stream, 0) {} + + void ConnectExpectMalformedClientHello(const DataBuffer &data) { + EnsureTlsSetup(); + server_->SetOption(SSL_ENABLE_V2_COMPATIBLE_HELLO, PR_TRUE); + server_->ExpectSendAlert(kTlsAlertIllegalParameter); + client_->SendDirect(data); + server_->StartConnect(); + server_->Handshake(); + ASSERT_TRUE_WAIT( + (server_->error_code() == SSL_ERROR_RX_MALFORMED_CLIENT_HELLO), 2000); + } +}; + +// Gather a 5-byte v3 record, with a fragment length exceeding the maximum. +TEST_F(TlsConnectTest, GatherExcessiveV3Record) { + DataBuffer buffer; + + size_t idx = 0; + idx = buffer.Write(idx, 0x16, 1); // handshake + idx = buffer.Write(idx, 0x0301, 2); // record_version + (void)buffer.Write(idx, MAX_FRAGMENT_LENGTH + 2048 + 1, 2); // length=max+1 + + EnsureTlsSetup(); + server_->ExpectSendAlert(kTlsAlertRecordOverflow); + client_->SendDirect(buffer); + server_->StartConnect(); + server_->Handshake(); + ASSERT_TRUE_WAIT((server_->error_code() == SSL_ERROR_RX_RECORD_TOO_LONG), + 2000); +} + +// Gather a 3-byte v2 header, with a fragment length of 2. +TEST_F(GatherV2ClientHelloTest, GatherV2RecordLongHeader) { + DataBuffer buffer; + + size_t idx = 0; + idx = buffer.Write(idx, 0x0002, 2); // length=2 (long header) + idx = buffer.Write(idx, 0U, 1); // padding=0 + (void)buffer.Write(idx, 0U, 2); // data + + ConnectExpectMalformedClientHello(buffer); +} + +// Gather a 3-byte v2 header, with a fragment length of 1. +TEST_F(GatherV2ClientHelloTest, GatherV2RecordLongHeader2) { + DataBuffer buffer; + + size_t idx = 0; + idx = buffer.Write(idx, 0x0001, 2); // length=1 (long header) + idx = buffer.Write(idx, 0U, 1); // padding=0 + idx = buffer.Write(idx, 0U, 1); // data + (void)buffer.Write(idx, 0U, 1); // surplus (need 5 bytes total) + + ConnectExpectMalformedClientHello(buffer); +} + +// Gather a 3-byte v2 header, with a zero fragment length. +TEST_F(GatherV2ClientHelloTest, GatherEmptyV2RecordLongHeader) { + DataBuffer buffer; + + size_t idx = 0; + idx = buffer.Write(idx, 0U, 2); // length=0 (long header) + idx = buffer.Write(idx, 0U, 1); // padding=0 + (void)buffer.Write(idx, 0U, 2); // surplus (need 5 bytes total) + + ConnectExpectMalformedClientHello(buffer); +} + +// Gather a 2-byte v2 header, with a fragment length of 3. +TEST_F(GatherV2ClientHelloTest, GatherV2RecordShortHeader) { + DataBuffer buffer; + + size_t idx = 0; + idx = buffer.Write(idx, 0x8003, 2); // length=3 (short header) + (void)buffer.Write(idx, 0U, 3); // data + + ConnectExpectMalformedClientHello(buffer); +} + +// Gather a 2-byte v2 header, with a fragment length of 2. +TEST_F(GatherV2ClientHelloTest, GatherEmptyV2RecordShortHeader2) { + DataBuffer buffer; + + size_t idx = 0; + idx = buffer.Write(idx, 0x8002, 2); // length=2 (short header) + idx = buffer.Write(idx, 0U, 2); // data + (void)buffer.Write(idx, 0U, 1); // surplus (need 5 bytes total) + + ConnectExpectMalformedClientHello(buffer); +} + +// Gather a 2-byte v2 header, with a fragment length of 1. +TEST_F(GatherV2ClientHelloTest, GatherEmptyV2RecordShortHeader3) { + DataBuffer buffer; + + size_t idx = 0; + idx = buffer.Write(idx, 0x8001, 2); // length=1 (short header) + idx = buffer.Write(idx, 0U, 1); // data + (void)buffer.Write(idx, 0U, 2); // surplus (need 5 bytes total) + + ConnectExpectMalformedClientHello(buffer); +} + +// Gather a 2-byte v2 header, with a zero fragment length. +TEST_F(GatherV2ClientHelloTest, GatherEmptyV2RecordShortHeader) { + DataBuffer buffer; + + size_t idx = 0; + idx = buffer.Write(idx, 0x8000, 2); // length=0 (short header) + (void)buffer.Write(idx, 0U, 3); // surplus (need 5 bytes total) + + ConnectExpectMalformedClientHello(buffer); +} + +/* Test correct gather buffer clearing/freeing and (re-)allocation. + * + * Freeing and (re-)allocation of the gather buffers after reception of single + * records is only done in DEBUG builds. Normally they are created and + * destroyed with the SSL socket. + * + * TLS 1.0 record splitting leads to implicit complete read of the data. + * + * The NSS DTLS impelmentation does not allow partial reads + * (see sslsecur.c, line 535-543). */ +TEST_P(TlsConnectStream, GatherBufferPartialReadTest) { + EnsureTlsSetup(); + Connect(); + + client_->SendData(1000); + + if (version_ > SSL_LIBRARY_VERSION_TLS_1_0) { + for (unsigned i = 1; i <= 20; i++) { + server_->ReadBytes(50); + ASSERT_EQ(server_->received_bytes(), 50U * i); + } + } else { + server_->ReadBytes(50); + ASSERT_EQ(server_->received_bytes(), 1000U); + } +} + +} // namespace nss_test diff --git a/security/nss/gtests/ssl_gtest/ssl_gtest.cc b/security/nss/gtests/ssl_gtest/ssl_gtest.cc new file mode 100644 index 0000000000..2fff9d7cbb --- /dev/null +++ b/security/nss/gtests/ssl_gtest/ssl_gtest.cc @@ -0,0 +1,52 @@ +#include "nspr.h" +#include "nss.h" +#include "prenv.h" +#include "ssl.h" + +#include <cstdlib> + +#include "test_io.h" +#include "databuffer.h" + +#define GTEST_HAS_RTTI 0 +#include "gtest/gtest.h" + +std::string g_working_dir_path; +bool g_ssl_gtest_verbose; + +int main(int argc, char** argv) { + // Start the tests + ::testing::InitGoogleTest(&argc, argv); + g_working_dir_path = "."; + g_ssl_gtest_verbose = false; + + char* workdir = PR_GetEnvSecure("NSS_GTEST_WORKDIR"); + if (workdir) g_working_dir_path = workdir; + + for (int i = 0; i < argc; i++) { + if (!strcmp(argv[i], "-d")) { + g_working_dir_path = argv[i + 1]; + ++i; + } else if (!strcmp(argv[i], "-v")) { + g_ssl_gtest_verbose = true; + nss_test::DataBuffer::SetLogLimit(16384); + } + } + + if (NSS_Initialize(g_working_dir_path.c_str(), "", "", SECMOD_DB, + NSS_INIT_READONLY) != SECSuccess) { + return 1; + } + if (NSS_SetDomesticPolicy() != SECSuccess) { + return 1; + } + int rv = RUN_ALL_TESTS(); + + if (NSS_Shutdown() != SECSuccess) { + return 1; + } + + nss_test::Poller::Shutdown(); + + return rv; +} diff --git a/security/nss/gtests/ssl_gtest/ssl_gtest.gyp b/security/nss/gtests/ssl_gtest/ssl_gtest.gyp new file mode 100644 index 0000000000..d078ce2303 --- /dev/null +++ b/security/nss/gtests/ssl_gtest/ssl_gtest.gyp @@ -0,0 +1,135 @@ +# This Source Code Form is subject to the terms of the Mozilla Public +# License, v. 2.0. If a copy of the MPL was not distributed with this +# file, You can obtain one at http://mozilla.org/MPL/2.0/. +{ + 'includes': [ + '../../coreconf/config.gypi', + '../common/gtest.gypi', + ], + 'targets': [ + { + 'target_name': 'ssl_gtest', + 'type': 'executable', + 'sources': [ + 'bloomfilter_unittest.cc', + 'libssl_internals.c', + 'selfencrypt_unittest.cc', + 'ssl_0rtt_unittest.cc', + 'ssl_aead_unittest.cc', + 'ssl_agent_unittest.cc', + 'ssl_auth_unittest.cc', + 'ssl_cert_ext_unittest.cc', + 'ssl_cipherorder_unittest.cc', + 'ssl_ciphersuite_unittest.cc', + 'ssl_custext_unittest.cc', + 'ssl_damage_unittest.cc', + 'ssl_debug_env_unittest.cc', + 'ssl_dhe_unittest.cc', + 'ssl_drop_unittest.cc', + 'ssl_ecdh_unittest.cc', + 'ssl_ems_unittest.cc', + 'ssl_exporter_unittest.cc', + 'ssl_extension_unittest.cc', + 'ssl_fuzz_unittest.cc', + 'ssl_fragment_unittest.cc', + 'ssl_gather_unittest.cc', + 'ssl_gtest.cc', + 'ssl_hrr_unittest.cc', + 'ssl_keyupdate_unittest.cc', + 'ssl_loopback_unittest.cc', + 'ssl_masking_unittest.cc', + 'ssl_misc_unittest.cc', + 'ssl_record_unittest.cc', + 'ssl_recordsep_unittest.cc', + 'ssl_recordsize_unittest.cc', + 'ssl_resumption_unittest.cc', + 'ssl_renegotiation_unittest.cc', + 'ssl_skip_unittest.cc', + 'ssl_staticrsa_unittest.cc', + 'ssl_tls13compat_unittest.cc', + 'ssl_v2_client_hello_unittest.cc', + 'ssl_version_unittest.cc', + 'ssl_versionpolicy_unittest.cc', + 'test_io.cc', + 'tls_agent.cc', + 'tls_connect.cc', + 'tls_filter.cc', + 'tls_hkdf_unittest.cc', + 'tls_ech_unittest.cc', + 'tls_protect.cc', + 'tls_psk_unittest.cc', + 'tls_subcerts_unittest.cc', + 'tls_grease_unittest.cc' + ], + 'dependencies': [ + '<(DEPTH)/exports.gyp:nss_exports', + '<(DEPTH)/lib/util/util.gyp:nssutil3', + '<(DEPTH)/gtests/google_test/google_test.gyp:gtest', + '<(DEPTH)/lib/smime/smime.gyp:smime', + '<(DEPTH)/lib/ssl/ssl.gyp:ssl', + '<(DEPTH)/lib/nss/nss.gyp:nss_static', + '<(DEPTH)/lib/pkcs12/pkcs12.gyp:pkcs12', + '<(DEPTH)/lib/pkcs7/pkcs7.gyp:pkcs7', + '<(DEPTH)/lib/certhigh/certhigh.gyp:certhi', + '<(DEPTH)/lib/cryptohi/cryptohi.gyp:cryptohi', + '<(DEPTH)/lib/certdb/certdb.gyp:certdb', + '<(DEPTH)/lib/pki/pki.gyp:nsspki', + '<(DEPTH)/lib/dev/dev.gyp:nssdev', + '<(DEPTH)/lib/base/base.gyp:nssb', + '<(DEPTH)/lib/zlib/zlib.gyp:nss_zlib', + '<(DEPTH)/cpputil/cpputil.gyp:cpputil', + '<(DEPTH)/lib/libpkix/libpkix.gyp:libpkix', + ], + 'conditions': [ + [ 'static_libs==1', { + 'dependencies': [ + '<(DEPTH)/lib/pk11wrap/pk11wrap.gyp:pk11wrap_static', + ], + }, { + 'dependencies': [ + '<(DEPTH)/lib/sqlite/sqlite.gyp:sqlite3', + '<(DEPTH)/lib/pk11wrap/pk11wrap.gyp:pk11wrap', + '<(DEPTH)/lib/softoken/softoken.gyp:softokn', + '<(DEPTH)/lib/freebl/freebl.gyp:freebl', + ], + }], + [ 'disable_dbm==0', { + 'dependencies': [ + '<(DEPTH)/lib/dbm/src/src.gyp:dbm', + ], + }], + [ 'enable_sslkeylogfile==1 and sanitizer_flags==0', { + 'sources': [ + 'ssl_keylog_unittest.cc', + ], + 'defines': [ + 'NSS_ALLOW_SSLKEYLOGFILE', + ], + }], + # ssl_gtest fuzz defines should only be determined by the 'fuzz_tls' + # flag (so as to match lib/ssl). If gtest.gypi added the define due + # to '--fuzz' only, remove it. + ['fuzz_tls==1', { + 'defines': [ + 'UNSAFE_FUZZER_MODE', + ], + }, { + 'defines!': [ + 'UNSAFE_FUZZER_MODE', + ], + }], + ], + } + ], + 'target_defaults': { + 'include_dirs': [ + '../../lib/ssl' + ], + 'defines': [ + 'NSS_USE_STATIC_LIBS' + ], + }, + 'variables': { + 'module': 'nss', + } +} diff --git a/security/nss/gtests/ssl_gtest/ssl_hrr_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_hrr_unittest.cc new file mode 100644 index 0000000000..3b81278f47 --- /dev/null +++ b/security/nss/gtests/ssl_gtest/ssl_hrr_unittest.cc @@ -0,0 +1,1364 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set ts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this file, + * You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#include "secerr.h" +#include "ssl.h" +#include "sslerr.h" +#include "sslproto.h" + +// This is internal, just to get DTLS_1_3_DRAFT_VERSION. +#include "ssl3prot.h" + +#include "gtest_utils.h" +#include "nss_scoped_ptrs.h" +#include "tls_connect.h" +#include "tls_filter.h" +#include "tls_parser.h" + +namespace nss_test { + +TEST_P(TlsConnectTls13, HelloRetryRequestAbortsZeroRtt) { + const char* k0RttData = "Such is life"; + const PRInt32 k0RttDataLen = static_cast<PRInt32>(strlen(k0RttData)); + + SetupForZeroRtt(); // initial handshake as normal + + static const std::vector<SSLNamedGroup> groups = {ssl_grp_ec_secp384r1, + ssl_grp_ec_secp521r1}; + server_->ConfigNamedGroups(groups); + client_->Set0RttEnabled(true); + server_->Set0RttEnabled(true); + ExpectResumption(RESUME_TICKET); + + // Send first ClientHello and send 0-RTT data + auto capture_early_data = + MakeTlsFilter<TlsExtensionCapture>(client_, ssl_tls13_early_data_xtn); + client_->Handshake(); + EXPECT_EQ(k0RttDataLen, PR_Write(client_->ssl_fd(), k0RttData, + k0RttDataLen)); // 0-RTT write. + EXPECT_TRUE(capture_early_data->captured()); + + // Send the HelloRetryRequest + auto hrr_capture = MakeTlsFilter<TlsHandshakeRecorder>( + server_, kTlsHandshakeHelloRetryRequest); + server_->Handshake(); + EXPECT_LT(0U, hrr_capture->buffer().len()); + + // The server can't read + std::vector<uint8_t> buf(k0RttDataLen); + EXPECT_EQ(SECFailure, PR_Read(server_->ssl_fd(), buf.data(), k0RttDataLen)); + EXPECT_EQ(PR_WOULD_BLOCK_ERROR, PORT_GetError()); + + // Make a new capture for the early data. + capture_early_data = + MakeTlsFilter<TlsExtensionCapture>(client_, ssl_tls13_early_data_xtn); + + // Complete the handshake successfully + Handshake(); + ExpectEarlyDataAccepted(false); // The server should reject 0-RTT + CheckConnected(); + SendReceive(); + EXPECT_FALSE(capture_early_data->captured()); +} + +// This filter only works for DTLS 1.3 where there is exactly one handshake +// packet. If the record is split into two packets, or there are multiple +// handshake packets, this will break. +class CorrectMessageSeqAfterHrrFilter : public TlsRecordFilter { + public: + CorrectMessageSeqAfterHrrFilter(const std::shared_ptr<TlsAgent>& a) + : TlsRecordFilter(a) {} + + protected: + PacketFilter::Action FilterRecord(const TlsRecordHeader& header, + const DataBuffer& record, size_t* offset, + DataBuffer* output) { + if (filtered_packets() > 0 || header.content_type() != ssl_ct_handshake) { + return KEEP; + } + + DataBuffer buffer(record); + TlsRecordHeader new_header(header.variant(), header.version(), + header.content_type(), + header.sequence_number() + 1); + + // Correct message_seq. + buffer.Write(4, 1U, 2); + + *offset = new_header.Write(output, *offset, buffer); + return CHANGE; + } +}; + +TEST_P(TlsConnectTls13, SecondClientHelloRejectEarlyDataXtn) { + static const std::vector<SSLNamedGroup> groups = {ssl_grp_ec_secp384r1, + ssl_grp_ec_secp521r1}; + + SetupForZeroRtt(); + ExpectResumption(RESUME_TICKET); + + client_->ConfigNamedGroups(groups); + server_->ConfigNamedGroups(groups); + client_->Set0RttEnabled(true); + server_->Set0RttEnabled(true); + + // A new client that tries to resume with 0-RTT but doesn't send the + // correct key share(s). The server will respond with an HRR. + auto orig_client = + std::make_shared<TlsAgent>(client_->name(), TlsAgent::CLIENT, variant_); + client_.swap(orig_client); + client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_1, + SSL_LIBRARY_VERSION_TLS_1_3); + client_->ConfigureSessionCache(RESUME_BOTH); + client_->Set0RttEnabled(true); + client_->StartConnect(); + + // Swap in the new client. + client_->SetPeer(server_); + server_->SetPeer(client_); + + // Send the ClientHello. + client_->Handshake(); + // Process the CH, send an HRR. + server_->Handshake(); + + // Swap the client we created manually with the one that successfully + // received a PSK, and try to resume with 0-RTT. The client doesn't know + // about the HRR so it will send the early_data xtn as well as 0-RTT data. + client_.swap(orig_client); + orig_client.reset(); + + // Correct the DTLS message sequence number after an HRR. + if (variant_ == ssl_variant_datagram) { + MakeTlsFilter<CorrectMessageSeqAfterHrrFilter>(client_); + } + + server_->SetPeer(client_); + client_->Handshake(); + + // Send 0-RTT data. + const char* k0RttData = "ABCDEF"; + const PRInt32 k0RttDataLen = static_cast<PRInt32>(strlen(k0RttData)); + PRInt32 rv = PR_Write(client_->ssl_fd(), k0RttData, k0RttDataLen); + EXPECT_EQ(k0RttDataLen, rv); + + ExpectAlert(server_, kTlsAlertUnsupportedExtension); + Handshake(); + client_->CheckErrorCode(SSL_ERROR_UNSUPPORTED_EXTENSION_ALERT); +} + +class KeyShareReplayer : public TlsExtensionFilter { + public: + KeyShareReplayer(const std::shared_ptr<TlsAgent>& a) + : TlsExtensionFilter(a) {} + + virtual PacketFilter::Action FilterExtension(uint16_t extension_type, + const DataBuffer& input, + DataBuffer* output) { + if (extension_type != ssl_tls13_key_share_xtn) { + return KEEP; + } + + if (!data_.len()) { + data_ = input; + return KEEP; + } + + *output = data_; + return CHANGE; + } + + private: + DataBuffer data_; +}; + +// This forces a HelloRetryRequest by disabling P-256 on the server. However, +// the second ClientHello is modified so that it omits the requested share. The +// server should reject this. +TEST_P(TlsConnectTls13, RetryWithSameKeyShare) { + EnsureTlsSetup(); + MakeTlsFilter<KeyShareReplayer>(client_); + static const std::vector<SSLNamedGroup> groups = {ssl_grp_ec_secp384r1, + ssl_grp_ec_secp521r1}; + server_->ConfigNamedGroups(groups); + ConnectExpectAlert(server_, kTlsAlertIllegalParameter); + EXPECT_EQ(SSL_ERROR_BAD_2ND_CLIENT_HELLO, server_->error_code()); + EXPECT_EQ(SSL_ERROR_ILLEGAL_PARAMETER_ALERT, client_->error_code()); +} + +// Here we modify the second ClientHello so that the client retries with the +// same shares, even though the server wanted something else. +TEST_P(TlsConnectTls13, RetryWithTwoShares) { + EnsureTlsSetup(); + EXPECT_EQ(SECSuccess, SSL_SendAdditionalKeyShares(client_->ssl_fd(), 1)); + MakeTlsFilter<KeyShareReplayer>(client_); + + static const std::vector<SSLNamedGroup> groups = {ssl_grp_ec_secp384r1, + ssl_grp_ec_secp521r1}; + server_->ConfigNamedGroups(groups); + ConnectExpectAlert(server_, kTlsAlertIllegalParameter); + EXPECT_EQ(SSL_ERROR_BAD_2ND_CLIENT_HELLO, server_->error_code()); + EXPECT_EQ(SSL_ERROR_ILLEGAL_PARAMETER_ALERT, client_->error_code()); +} + +TEST_P(TlsConnectTls13, RetryCallbackAccept) { + EnsureTlsSetup(); + + auto accept_hello = [](PRBool firstHello, const PRUint8* clientToken, + unsigned int clientTokenLen, PRUint8* appToken, + unsigned int* appTokenLen, unsigned int appTokenMax, + void* arg) { + auto* called = reinterpret_cast<bool*>(arg); + *called = true; + + EXPECT_TRUE(firstHello); + EXPECT_EQ(0U, clientTokenLen); + return ssl_hello_retry_accept; + }; + + bool cb_run = false; + EXPECT_EQ(SECSuccess, SSL_HelloRetryRequestCallback(server_->ssl_fd(), + accept_hello, &cb_run)); + Connect(); + EXPECT_TRUE(cb_run); +} + +TEST_P(TlsConnectTls13, RetryCallbackAcceptGroupMismatch) { + EnsureTlsSetup(); + + auto accept_hello_twice = [](PRBool firstHello, const PRUint8* clientToken, + unsigned int clientTokenLen, PRUint8* appToken, + unsigned int* appTokenLen, + unsigned int appTokenMax, void* arg) { + auto* called = reinterpret_cast<size_t*>(arg); + ++*called; + + EXPECT_EQ(0U, clientTokenLen); + return ssl_hello_retry_accept; + }; + + auto capture = + MakeTlsFilter<TlsExtensionCapture>(server_, ssl_tls13_cookie_xtn); + capture->SetHandshakeTypes({kTlsHandshakeHelloRetryRequest}); + + static const std::vector<SSLNamedGroup> groups = {ssl_grp_ec_secp384r1}; + server_->ConfigNamedGroups(groups); + + size_t cb_run = 0; + EXPECT_EQ(SECSuccess, SSL_HelloRetryRequestCallback( + server_->ssl_fd(), accept_hello_twice, &cb_run)); + Connect(); + EXPECT_EQ(2U, cb_run); + EXPECT_TRUE(capture->captured()) << "expected a cookie in HelloRetryRequest"; +} + +TEST_P(TlsConnectTls13, RetryCallbackFail) { + EnsureTlsSetup(); + + auto fail_hello = [](PRBool firstHello, const PRUint8* clientToken, + unsigned int clientTokenLen, PRUint8* appToken, + unsigned int* appTokenLen, unsigned int appTokenMax, + void* arg) { + auto* called = reinterpret_cast<bool*>(arg); + *called = true; + + EXPECT_TRUE(firstHello); + EXPECT_EQ(0U, clientTokenLen); + return ssl_hello_retry_fail; + }; + + bool cb_run = false; + EXPECT_EQ(SECSuccess, SSL_HelloRetryRequestCallback(server_->ssl_fd(), + fail_hello, &cb_run)); + ConnectExpectAlert(server_, kTlsAlertHandshakeFailure); + server_->CheckErrorCode(SSL_ERROR_APPLICATION_ABORT); + EXPECT_TRUE(cb_run); +} + +// Asking for retry twice isn't allowed. +TEST_P(TlsConnectTls13, RetryCallbackRequestHrrTwice) { + EnsureTlsSetup(); + + auto bad_callback = [](PRBool firstHello, const PRUint8* clientToken, + unsigned int clientTokenLen, PRUint8* appToken, + unsigned int* appTokenLen, unsigned int appTokenMax, + void* arg) -> SSLHelloRetryRequestAction { + return ssl_hello_retry_request; + }; + EXPECT_EQ(SECSuccess, SSL_HelloRetryRequestCallback(server_->ssl_fd(), + bad_callback, NULL)); + ConnectExpectAlert(server_, kTlsAlertInternalError); + server_->CheckErrorCode(SSL_ERROR_APP_CALLBACK_ERROR); +} + +// Accepting the CH and modifying the token isn't allowed. +TEST_P(TlsConnectTls13, RetryCallbackAcceptAndSetToken) { + EnsureTlsSetup(); + + auto bad_callback = [](PRBool firstHello, const PRUint8* clientToken, + unsigned int clientTokenLen, PRUint8* appToken, + unsigned int* appTokenLen, unsigned int appTokenMax, + void* arg) -> SSLHelloRetryRequestAction { + *appTokenLen = 1; + return ssl_hello_retry_accept; + }; + EXPECT_EQ(SECSuccess, SSL_HelloRetryRequestCallback(server_->ssl_fd(), + bad_callback, NULL)); + ConnectExpectAlert(server_, kTlsAlertInternalError); + server_->CheckErrorCode(SSL_ERROR_APP_CALLBACK_ERROR); +} + +// As above, but with reject. +TEST_P(TlsConnectTls13, RetryCallbackRejectAndSetToken) { + EnsureTlsSetup(); + + auto bad_callback = [](PRBool firstHello, const PRUint8* clientToken, + unsigned int clientTokenLen, PRUint8* appToken, + unsigned int* appTokenLen, unsigned int appTokenMax, + void* arg) -> SSLHelloRetryRequestAction { + *appTokenLen = 1; + return ssl_hello_retry_fail; + }; + EXPECT_EQ(SECSuccess, SSL_HelloRetryRequestCallback(server_->ssl_fd(), + bad_callback, NULL)); + ConnectExpectAlert(server_, kTlsAlertInternalError); + server_->CheckErrorCode(SSL_ERROR_APP_CALLBACK_ERROR); +} + +// This is a (pretend) buffer overflow. +TEST_P(TlsConnectTls13, RetryCallbackSetTooLargeToken) { + EnsureTlsSetup(); + + auto bad_callback = [](PRBool firstHello, const PRUint8* clientToken, + unsigned int clientTokenLen, PRUint8* appToken, + unsigned int* appTokenLen, unsigned int appTokenMax, + void* arg) -> SSLHelloRetryRequestAction { + *appTokenLen = appTokenMax + 1; + return ssl_hello_retry_accept; + }; + EXPECT_EQ(SECSuccess, SSL_HelloRetryRequestCallback(server_->ssl_fd(), + bad_callback, NULL)); + ConnectExpectAlert(server_, kTlsAlertInternalError); + server_->CheckErrorCode(SSL_ERROR_APP_CALLBACK_ERROR); +} + +SSLHelloRetryRequestAction RetryHello(PRBool firstHello, + const PRUint8* clientToken, + unsigned int clientTokenLen, + PRUint8* appToken, + unsigned int* appTokenLen, + unsigned int appTokenMax, void* arg) { + auto* called = reinterpret_cast<size_t*>(arg); + ++*called; + + EXPECT_EQ(0U, clientTokenLen); + return firstHello ? ssl_hello_retry_request : ssl_hello_retry_accept; +} + +TEST_P(TlsConnectTls13, RetryCallbackRetry) { + EnsureTlsSetup(); + + auto capture_hrr = std::make_shared<TlsHandshakeRecorder>( + server_, ssl_hs_hello_retry_request); + auto capture_key_share = + std::make_shared<TlsExtensionCapture>(server_, ssl_tls13_key_share_xtn); + capture_key_share->SetHandshakeTypes({kTlsHandshakeHelloRetryRequest}); + std::vector<std::shared_ptr<PacketFilter>> chain = {capture_hrr, + capture_key_share}; + server_->SetFilter(std::make_shared<ChainedPacketFilter>(chain)); + + size_t cb_called = 0; + EXPECT_EQ(SECSuccess, SSL_HelloRetryRequestCallback(server_->ssl_fd(), + RetryHello, &cb_called)); + + // Do the first message exchange. + StartConnect(); + client_->Handshake(); + server_->Handshake(); + + EXPECT_EQ(1U, cb_called) << "callback should be called once here"; + EXPECT_LT(0U, capture_hrr->buffer().len()) << "HelloRetryRequest expected"; + EXPECT_FALSE(capture_key_share->captured()) + << "no key_share extension expected"; + + auto capture_cookie = + MakeTlsFilter<TlsExtensionCapture>(client_, ssl_tls13_cookie_xtn); + + Handshake(); + CheckConnected(); + EXPECT_EQ(2U, cb_called); + EXPECT_TRUE(capture_cookie->captured()) << "should have a cookie"; +} + +static size_t CountShares(const DataBuffer& key_share) { + size_t count = 0; + uint32_t len = 0; + size_t offset = 2; + + EXPECT_TRUE(key_share.Read(0, 2, &len)); + EXPECT_EQ(key_share.len() - 2, len); + while (offset < key_share.len()) { + offset += 2; // Skip KeyShareEntry.group + EXPECT_TRUE(key_share.Read(offset, 2, &len)); + offset += 2 + len; // Skip KeyShareEntry.key_exchange + ++count; + } + return count; +} + +TEST_P(TlsConnectTls13, RetryCallbackRetryWithAdditionalShares) { + EnsureTlsSetup(); + EXPECT_EQ(SECSuccess, SSL_SendAdditionalKeyShares(client_->ssl_fd(), 1)); + + auto capture_server = + MakeTlsFilter<TlsExtensionCapture>(server_, ssl_tls13_key_share_xtn); + capture_server->SetHandshakeTypes({kTlsHandshakeHelloRetryRequest}); + + size_t cb_called = 0; + EXPECT_EQ(SECSuccess, SSL_HelloRetryRequestCallback(server_->ssl_fd(), + RetryHello, &cb_called)); + + // Do the first message exchange. + StartConnect(); + client_->Handshake(); + server_->Handshake(); + + EXPECT_EQ(1U, cb_called) << "callback should be called once here"; + EXPECT_FALSE(capture_server->captured()) + << "no key_share extension expected from server"; + + auto capture_client_2nd = + MakeTlsFilter<TlsExtensionCapture>(client_, ssl_tls13_key_share_xtn); + + Handshake(); + CheckConnected(); + EXPECT_EQ(2U, cb_called); + EXPECT_TRUE(capture_client_2nd->captured()) << "client should send key_share"; + EXPECT_EQ(2U, CountShares(capture_client_2nd->extension())) + << "client should still send two shares"; +} + +// The callback should be run even if we have another reason to send +// HelloRetryRequest. In this case, the server sends HRR because the server +// wants a P-384 key share and the client didn't offer one. +TEST_P(TlsConnectTls13, RetryCallbackRetryWithGroupMismatch) { + EnsureTlsSetup(); + + auto capture_cookie = + std::make_shared<TlsExtensionCapture>(server_, ssl_tls13_cookie_xtn); + capture_cookie->SetHandshakeTypes({kTlsHandshakeHelloRetryRequest}); + auto capture_key_share = + std::make_shared<TlsExtensionCapture>(server_, ssl_tls13_key_share_xtn); + capture_key_share->SetHandshakeTypes({kTlsHandshakeHelloRetryRequest}); + server_->SetFilter(std::make_shared<ChainedPacketFilter>( + ChainedPacketFilterInit{capture_cookie, capture_key_share})); + + static const std::vector<SSLNamedGroup> groups = {ssl_grp_ec_secp384r1}; + server_->ConfigNamedGroups(groups); + + size_t cb_called = 0; + EXPECT_EQ(SECSuccess, SSL_HelloRetryRequestCallback(server_->ssl_fd(), + RetryHello, &cb_called)); + Connect(); + EXPECT_EQ(2U, cb_called); + EXPECT_TRUE(capture_cookie->captured()) << "cookie expected"; + EXPECT_TRUE(capture_key_share->captured()) << "key_share expected"; +} + +static const uint8_t kApplicationToken[] = {0x92, 0x44, 0x00}; + +SSLHelloRetryRequestAction RetryHelloWithToken( + PRBool firstHello, const PRUint8* clientToken, unsigned int clientTokenLen, + PRUint8* appToken, unsigned int* appTokenLen, unsigned int appTokenMax, + void* arg) { + auto* called = reinterpret_cast<size_t*>(arg); + ++*called; + + if (firstHello) { + memcpy(appToken, kApplicationToken, sizeof(kApplicationToken)); + *appTokenLen = sizeof(kApplicationToken); + return ssl_hello_retry_request; + } + + EXPECT_EQ(DataBuffer(kApplicationToken, sizeof(kApplicationToken)), + DataBuffer(clientToken, static_cast<size_t>(clientTokenLen))); + return ssl_hello_retry_accept; +} + +TEST_P(TlsConnectTls13, RetryCallbackRetryWithToken) { + EnsureTlsSetup(); + + auto capture_key_share = + MakeTlsFilter<TlsExtensionCapture>(server_, ssl_tls13_key_share_xtn); + capture_key_share->SetHandshakeTypes({kTlsHandshakeHelloRetryRequest}); + + size_t cb_called = 0; + EXPECT_EQ(SECSuccess, + SSL_HelloRetryRequestCallback(server_->ssl_fd(), + RetryHelloWithToken, &cb_called)); + Connect(); + EXPECT_EQ(2U, cb_called); + EXPECT_FALSE(capture_key_share->captured()) << "no key share expected"; +} + +TEST_P(TlsConnectTls13, RetryCallbackRetryWithTokenAndGroupMismatch) { + EnsureTlsSetup(); + + static const std::vector<SSLNamedGroup> groups = {ssl_grp_ec_secp384r1}; + server_->ConfigNamedGroups(groups); + + auto capture_key_share = + MakeTlsFilter<TlsExtensionCapture>(server_, ssl_tls13_key_share_xtn); + capture_key_share->SetHandshakeTypes({kTlsHandshakeHelloRetryRequest}); + + size_t cb_called = 0; + EXPECT_EQ(SECSuccess, + SSL_HelloRetryRequestCallback(server_->ssl_fd(), + RetryHelloWithToken, &cb_called)); + Connect(); + EXPECT_EQ(2U, cb_called); + EXPECT_TRUE(capture_key_share->captured()) << "key share expected"; +} + +SSLHelloRetryRequestAction CheckTicketToken( + PRBool firstHello, const PRUint8* clientToken, unsigned int clientTokenLen, + PRUint8* appToken, unsigned int* appTokenLen, unsigned int appTokenMax, + void* arg) { + auto* called = reinterpret_cast<bool*>(arg); + *called = true; + + EXPECT_TRUE(firstHello); + EXPECT_EQ(DataBuffer(kApplicationToken, sizeof(kApplicationToken)), + DataBuffer(clientToken, static_cast<size_t>(clientTokenLen))); + return ssl_hello_retry_accept; +} + +// Stream because SSL_SendSessionTicket only supports that. +TEST_F(TlsConnectStreamTls13, RetryCallbackWithSessionTicketToken) { + ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET); + Connect(); + EXPECT_EQ(SECSuccess, + SSL_SendSessionTicket(server_->ssl_fd(), kApplicationToken, + sizeof(kApplicationToken))); + SendReceive(); + + Reset(); + ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET); + ExpectResumption(RESUME_TICKET); + + bool cb_run = false; + EXPECT_EQ(SECSuccess, SSL_HelloRetryRequestCallback( + server_->ssl_fd(), CheckTicketToken, &cb_run)); + Connect(); + EXPECT_TRUE(cb_run); +} + +void TriggerHelloRetryRequest(std::shared_ptr<TlsAgent>& client, + std::shared_ptr<TlsAgent>& server) { + size_t cb_called = 0; + EXPECT_EQ(SECSuccess, SSL_HelloRetryRequestCallback(server->ssl_fd(), + RetryHello, &cb_called)); + + // Start the handshake. + client->StartConnect(); + server->StartConnect(); + client->Handshake(); + server->Handshake(); + EXPECT_EQ(1U, cb_called); + // Stop the callback from being called in future handshakes. + EXPECT_EQ(SECSuccess, + SSL_HelloRetryRequestCallback(server->ssl_fd(), nullptr, nullptr)); +} + +TEST_P(TlsConnectTls13, VersionNumbersAfterRetry) { + ConfigureSelfEncrypt(); + EnsureTlsSetup(); + auto r = MakeTlsFilter<TlsRecordRecorder>(client_); + TriggerHelloRetryRequest(client_, server_); + Handshake(); + ASSERT_GT(r->count(), 1UL); + auto ch1 = r->record(0); + if (ch1.header.is_dtls()) { + ASSERT_EQ(SSL_LIBRARY_VERSION_TLS_1_1, ch1.header.version()); + } else { + ASSERT_EQ(SSL_LIBRARY_VERSION_TLS_1_0, ch1.header.version()); + } + auto ch2 = r->record(1); + ASSERT_EQ(SSL_LIBRARY_VERSION_TLS_1_2, ch2.header.version()); + + CheckConnected(); +} + +TEST_P(TlsConnectTls13, RetryStateless) { + ConfigureSelfEncrypt(); + EnsureTlsSetup(); + + TriggerHelloRetryRequest(client_, server_); + MakeNewServer(); + + Handshake(); + CheckConnected(); + SendReceive(); +} + +TEST_P(TlsConnectTls13, RetryStatefulDropCookie) { + ConfigureSelfEncrypt(); + EnsureTlsSetup(); + + TriggerHelloRetryRequest(client_, server_); + MakeTlsFilter<TlsExtensionDropper>(client_, ssl_tls13_cookie_xtn); + + ExpectAlert(server_, kTlsAlertMissingExtension); + Handshake(); + client_->CheckErrorCode(SSL_ERROR_MISSING_EXTENSION_ALERT); + server_->CheckErrorCode(SSL_ERROR_MISSING_COOKIE_EXTENSION); +} + +class TruncateHrrCookie : public TlsExtensionFilter { + public: + TruncateHrrCookie(const std::shared_ptr<TlsAgent>& a) + : TlsExtensionFilter(a) {} + virtual PacketFilter::Action FilterExtension(uint16_t extension_type, + const DataBuffer& input, + DataBuffer* output) { + if (extension_type != ssl_tls13_cookie_xtn) { + return KEEP; + } + + // Claim a zero-length cookie. + output->Allocate(2); + output->Write(0, static_cast<uint32_t>(0), 2); + return CHANGE; + } +}; + +TEST_P(TlsConnectTls13, RetryCookieEmpty) { + ConfigureSelfEncrypt(); + EnsureTlsSetup(); + + TriggerHelloRetryRequest(client_, server_); + MakeTlsFilter<TruncateHrrCookie>(client_); + + ExpectAlert(server_, kTlsAlertHandshakeFailure); + Handshake(); + client_->CheckErrorCode(SSL_ERROR_NO_CYPHER_OVERLAP); + server_->CheckErrorCode(SSL_ERROR_RX_MALFORMED_CLIENT_HELLO); +} + +class AddJunkToCookie : public TlsExtensionFilter { + public: + AddJunkToCookie(const std::shared_ptr<TlsAgent>& a) : TlsExtensionFilter(a) {} + virtual PacketFilter::Action FilterExtension(uint16_t extension_type, + const DataBuffer& input, + DataBuffer* output) { + if (extension_type != ssl_tls13_cookie_xtn) { + return KEEP; + } + + *output = input; + // Add junk after the cookie. + static const uint8_t junk[2] = {1, 2}; + output->Append(DataBuffer(junk, sizeof(junk))); + return CHANGE; + } +}; + +TEST_P(TlsConnectTls13, RetryCookieWithExtras) { + ConfigureSelfEncrypt(); + EnsureTlsSetup(); + + TriggerHelloRetryRequest(client_, server_); + MakeTlsFilter<AddJunkToCookie>(client_); + + ExpectAlert(server_, kTlsAlertHandshakeFailure); + Handshake(); + client_->CheckErrorCode(SSL_ERROR_NO_CYPHER_OVERLAP); + server_->CheckErrorCode(SSL_ERROR_RX_MALFORMED_CLIENT_HELLO); +} + +// Stream only because DTLS drops bad packets. +TEST_F(TlsConnectStreamTls13, RetryStatelessDamageFirstClientHello) { + ConfigureSelfEncrypt(); + EnsureTlsSetup(); + + auto damage_ch = + MakeTlsFilter<TlsExtensionInjector>(client_, 0xfff3, DataBuffer()); + + TriggerHelloRetryRequest(client_, server_); + MakeNewServer(); + + // Key exchange fails when the handshake continues because client and server + // disagree about the transcript. + client_->ExpectSendAlert(kTlsAlertBadRecordMac); + server_->ExpectSendAlert(kTlsAlertBadRecordMac); + Handshake(); + server_->CheckErrorCode(SSL_ERROR_BAD_MAC_READ); + client_->CheckErrorCode(SSL_ERROR_BAD_MAC_READ); +} + +TEST_F(TlsConnectStreamTls13, RetryStatelessDamageSecondClientHello) { + ConfigureSelfEncrypt(); + EnsureTlsSetup(); + + TriggerHelloRetryRequest(client_, server_); + MakeNewServer(); + + auto damage_ch = + MakeTlsFilter<TlsExtensionInjector>(client_, 0xfff3, DataBuffer()); + + // Key exchange fails when the handshake continues because client and server + // disagree about the transcript. + client_->ExpectSendAlert(kTlsAlertBadRecordMac); + server_->ExpectSendAlert(kTlsAlertBadRecordMac); + Handshake(); + server_->CheckErrorCode(SSL_ERROR_BAD_MAC_READ); + client_->CheckErrorCode(SSL_ERROR_BAD_MAC_READ); +} + +// Stream because SSL_SendSessionTicket only supports that. +TEST_F(TlsConnectStreamTls13, SecondClientHelloSendSameTicket) { + // This simulates the scenario described at: + // https://bugzilla.mozilla.org/show_bug.cgi?id=1481271#c7 + // + // Here two connections are interleaved. Tickets are issued on one + // connection. A HelloRetryRequest is triggered on the second connection, + // meaning that there are two ClientHellos. We need to check that both + // ClientHellos have the same ticket, even if a new ticket is issued on the + // other connection in the meantime. + // + // Connection 1: <handshake> + // Connection 1: S->C: NST=X + // Connection 2: C->S: CH [PSK_ID=X] + // Connection 1: S->C: NST=Y + // Connection 2: S->C: HRR + // Connection 2: C->S: CH [PSK_ID=Y] + + // Connection 1, send a ticket after handshake is complete. + ConfigureSessionCache(RESUME_TICKET, RESUME_TICKET); + + Connect(); + + // Set this token so that RetryHelloWithToken() will check that this + // is the token that it receives in the HelloRetryRequest callback. + EXPECT_EQ(SECSuccess, + SSL_SendSessionTicket(server_->ssl_fd(), kApplicationToken, + sizeof(kApplicationToken))); + SendReceive(50); + + // Connection 2, trigger HRR. + auto client2 = + std::make_shared<TlsAgent>(client_->name(), TlsAgent::CLIENT, variant_); + auto server2 = + std::make_shared<TlsAgent>(server_->name(), TlsAgent::SERVER, variant_); + + client2->SetPeer(server2); + server2->SetPeer(client2); + + client_.swap(client2); + server_.swap(server2); + + ConfigureSessionCache(RESUME_TICKET, RESUME_TICKET); + + ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3); + + client_->StartConnect(); + server_->StartConnect(); + + size_t cb_called = 0; + EXPECT_EQ(SECSuccess, + SSL_HelloRetryRequestCallback(server_->ssl_fd(), + RetryHelloWithToken, &cb_called)); + client_->Handshake(); // Send ClientHello. + server_->Handshake(); // Process ClientHello, send HelloRetryRequest. + + EXPECT_EQ(1U, cb_called) << "callback should be called once here"; + + // Connection 1, send another ticket. + client_.swap(client2); + server_.swap(server2); + + // If the client uses this token, RetryHelloWithToken() will fail the test. + const uint8_t kAnotherApplicationToken[] = {0x92, 0x44, 0x01}; + EXPECT_EQ(SECSuccess, + SSL_SendSessionTicket(server_->ssl_fd(), kAnotherApplicationToken, + sizeof(kAnotherApplicationToken))); + SendReceive(60); + + // Connection 2, continue the handshake. + // The client should use kApplicationToken, not kAnotherApplicationToken. + client_.swap(client2); + server_.swap(server2); + + client_->Handshake(); + server_->Handshake(); + + EXPECT_EQ(2U, cb_called) << "callback should be called twice here"; +} + +// Read the cipher suite from the HRR and disable it on the identified agent. +static void DisableSuiteFromHrr( + std::shared_ptr<TlsAgent>& agent, + std::shared_ptr<TlsHandshakeRecorder>& capture_hrr) { + uint32_t tmp; + size_t offset = 2 + 32; // skip version + server_random + ASSERT_TRUE( + capture_hrr->buffer().Read(offset, 1, &tmp)); // session_id length + EXPECT_EQ(0U, tmp); + offset += 1 + tmp; + ASSERT_TRUE(capture_hrr->buffer().Read(offset, 2, &tmp)); // suite + EXPECT_EQ( + SECSuccess, + SSL_CipherPrefSet(agent->ssl_fd(), static_cast<uint16_t>(tmp), PR_FALSE)); +} + +TEST_P(TlsConnectTls13, RetryStatelessDisableSuiteClient) { + ConfigureSelfEncrypt(); + EnsureTlsSetup(); + + auto capture_hrr = + MakeTlsFilter<TlsHandshakeRecorder>(server_, ssl_hs_hello_retry_request); + + TriggerHelloRetryRequest(client_, server_); + MakeNewServer(); + + DisableSuiteFromHrr(client_, capture_hrr); + + // The client thinks that the HelloRetryRequest is bad, even though its + // because it changed its mind about the cipher suite. + ExpectAlert(client_, kTlsAlertIllegalParameter); + Handshake(); + client_->CheckErrorCode(SSL_ERROR_NO_CYPHER_OVERLAP); + server_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT); +} + +TEST_P(TlsConnectTls13, RetryStatelessDisableSuiteServer) { + ConfigureSelfEncrypt(); + EnsureTlsSetup(); + + auto capture_hrr = + MakeTlsFilter<TlsHandshakeRecorder>(server_, ssl_hs_hello_retry_request); + + TriggerHelloRetryRequest(client_, server_); + MakeNewServer(); + + DisableSuiteFromHrr(server_, capture_hrr); + + ExpectAlert(server_, kTlsAlertIllegalParameter); + Handshake(); + server_->CheckErrorCode(SSL_ERROR_BAD_2ND_CLIENT_HELLO); + client_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT); +} + +TEST_P(TlsConnectTls13, RetryStatelessDisableGroupClient) { + ConfigureSelfEncrypt(); + EnsureTlsSetup(); + + TriggerHelloRetryRequest(client_, server_); + MakeNewServer(); + + static const std::vector<SSLNamedGroup> groups = {ssl_grp_ec_secp384r1}; + client_->ConfigNamedGroups(groups); + + // We're into undefined behavior on the client side, but - at the point this + // test was written - the client here doesn't amend its key shares because the + // server doesn't ask it to. The server notices that the key share (x25519) + // doesn't match the negotiated group (P-384) and objects. + ExpectAlert(server_, kTlsAlertIllegalParameter); + Handshake(); + server_->CheckErrorCode(SSL_ERROR_BAD_2ND_CLIENT_HELLO); + client_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT); +} + +TEST_P(TlsConnectTls13, RetryStatelessDisableGroupServer) { + ConfigureSelfEncrypt(); + EnsureTlsSetup(); + + TriggerHelloRetryRequest(client_, server_); + MakeNewServer(); + + static const std::vector<SSLNamedGroup> groups = {ssl_grp_ec_secp384r1}; + server_->ConfigNamedGroups(groups); + + ExpectAlert(server_, kTlsAlertIllegalParameter); + Handshake(); + server_->CheckErrorCode(SSL_ERROR_BAD_2ND_CLIENT_HELLO); + client_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT); +} + +TEST_P(TlsConnectTls13, RetryStatelessBadCookie) { + ConfigureSelfEncrypt(); + EnsureTlsSetup(); + + TriggerHelloRetryRequest(client_, server_); + + // Now replace the self-encrypt MAC key with a garbage key. + static const uint8_t bad_hmac_key[32] = {0}; + SECItem key_item = {siBuffer, const_cast<uint8_t*>(bad_hmac_key), + sizeof(bad_hmac_key)}; + ScopedPK11SlotInfo slot(PK11_GetInternalSlot()); + PK11SymKey* hmac_key = + PK11_ImportSymKey(slot.get(), CKM_SHA256_HMAC, PK11_OriginUnwrap, + CKA_SIGN, &key_item, nullptr); + ASSERT_NE(nullptr, hmac_key); + SSLInt_SetSelfEncryptMacKey(hmac_key); // Passes ownership. + + MakeNewServer(); + + ExpectAlert(server_, kTlsAlertIllegalParameter); + Handshake(); + server_->CheckErrorCode(SSL_ERROR_BAD_2ND_CLIENT_HELLO); + client_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT); +} + +// Stream because the server doesn't consume the alert and terminate. +TEST_F(TlsConnectStreamTls13, RetryWithDifferentCipherSuite) { + EnsureTlsSetup(); + // Force a HelloRetryRequest. + static const std::vector<SSLNamedGroup> groups = {ssl_grp_ec_secp384r1}; + server_->ConfigNamedGroups(groups); + // Then switch out the default suite (TLS_AES_128_GCM_SHA256). + MakeTlsFilter<SelectedCipherSuiteReplacer>(server_, + TLS_CHACHA20_POLY1305_SHA256); + + client_->ExpectSendAlert(kTlsAlertIllegalParameter); + server_->ExpectSendAlert(kTlsAlertUnexpectedMessage); + ConnectExpectFail(); + EXPECT_EQ(SSL_ERROR_RX_MALFORMED_SERVER_HELLO, client_->error_code()); + EXPECT_EQ(SSL_ERROR_RX_UNEXPECTED_RECORD_TYPE, server_->error_code()); +} + +// This tests that the second attempt at sending a ClientHello (after receiving +// a HelloRetryRequest) is correctly retransmitted. +TEST_F(TlsConnectDatagram13, DropClientSecondFlightWithHelloRetry) { + static const std::vector<SSLNamedGroup> groups = {ssl_grp_ec_secp384r1, + ssl_grp_ec_secp521r1}; + server_->ConfigNamedGroups(groups); + server_->SetFilter(std::make_shared<SelectiveDropFilter>(0x2)); + Connect(); +} + +class TlsKeyExchange13 : public TlsKeyExchangeTest {}; + +// This should work, with an HRR, because the server prefers x25519 and the +// client generates a share for P-384 on the initial ClientHello. +TEST_P(TlsKeyExchange13, ConnectEcdhePreferenceMismatchHrr) { + EnsureKeyShareSetup(); + static const std::vector<SSLNamedGroup> client_groups = { + ssl_grp_ec_secp384r1, ssl_grp_ec_curve25519}; + static const std::vector<SSLNamedGroup> server_groups = { + ssl_grp_ec_curve25519, ssl_grp_ec_secp384r1}; + client_->ConfigNamedGroups(client_groups); + server_->ConfigNamedGroups(server_groups); + Connect(); + CheckKeys(); + static const std::vector<SSLNamedGroup> expectedShares = { + ssl_grp_ec_secp384r1}; + CheckKEXDetails(client_groups, expectedShares, ssl_grp_ec_curve25519); +} + +TEST_P(TlsKeyExchange13, SecondClientHelloPreambleMatches) { + server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_3, + SSL_LIBRARY_VERSION_TLS_1_3); + client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_1, + SSL_LIBRARY_VERSION_TLS_1_3); + + ConfigureSelfEncrypt(); + static const std::vector<SSLNamedGroup> client_groups = { + ssl_grp_ec_secp384r1, ssl_grp_ec_curve25519}; + static const std::vector<SSLNamedGroup> server_groups = { + ssl_grp_ec_curve25519}; + client_->ConfigNamedGroups(client_groups); + server_->ConfigNamedGroups(server_groups); + + auto ch1 = MakeTlsFilter<ClientHelloPreambleCapture>(client_); + StartConnect(); + client_->Handshake(); + server_->Handshake(); + + MakeNewServer(); + auto ch2 = MakeTlsFilter<ClientHelloPreambleCapture>(client_); + Handshake(); + + EXPECT_TRUE(ch1->captured()); + EXPECT_TRUE(ch2->captured()); + EXPECT_EQ(ch1->contents(), ch2->contents()); +} + +// This should work, but not use HRR because the key share for x25519 was +// pre-generated by the client. +TEST_P(TlsKeyExchange13, ConnectEcdhePreferenceMismatchHrrExtraShares) { + EnsureKeyShareSetup(); + static const std::vector<SSLNamedGroup> client_groups = { + ssl_grp_ec_secp384r1, ssl_grp_ec_curve25519}; + static const std::vector<SSLNamedGroup> server_groups = { + ssl_grp_ec_curve25519, ssl_grp_ec_secp384r1}; + client_->ConfigNamedGroups(client_groups); + server_->ConfigNamedGroups(server_groups); + EXPECT_EQ(SECSuccess, SSL_SendAdditionalKeyShares(client_->ssl_fd(), 1)); + + Connect(); + CheckKeys(); + CheckKEXDetails(client_groups, client_groups); +} + +// The callback should be run even if we have another reason to send +// HelloRetryRequest. In this case, the server sends HRR because the server +// wants an X25519 key share and the client didn't offer one. +TEST_P(TlsKeyExchange13, + RetryCallbackRetryWithGroupMismatchAndAdditionalShares) { + EnsureKeyShareSetup(); + + static const std::vector<SSLNamedGroup> client_groups = { + ssl_grp_ec_secp256r1, ssl_grp_ec_secp384r1, ssl_grp_ec_curve25519}; + client_->ConfigNamedGroups(client_groups); + static const std::vector<SSLNamedGroup> server_groups = { + ssl_grp_ec_curve25519}; + server_->ConfigNamedGroups(server_groups); + EXPECT_EQ(SECSuccess, SSL_SendAdditionalKeyShares(client_->ssl_fd(), 1)); + + auto capture_server = + std::make_shared<TlsExtensionCapture>(server_, ssl_tls13_key_share_xtn); + capture_server->SetHandshakeTypes({kTlsHandshakeHelloRetryRequest}); + server_->SetFilter(std::make_shared<ChainedPacketFilter>( + ChainedPacketFilterInit{capture_hrr_, capture_server})); + + size_t cb_called = 0; + EXPECT_EQ(SECSuccess, SSL_HelloRetryRequestCallback(server_->ssl_fd(), + RetryHello, &cb_called)); + + // Do the first message exchange. + StartConnect(); + client_->Handshake(); + server_->Handshake(); + + EXPECT_EQ(1U, cb_called) << "callback should be called once here"; + EXPECT_TRUE(capture_server->captured()) << "key_share extension expected"; + + uint32_t server_group = 0; + EXPECT_TRUE(capture_server->extension().Read(0, 2, &server_group)); + EXPECT_EQ(ssl_grp_ec_curve25519, static_cast<SSLNamedGroup>(server_group)); + + Handshake(); + CheckConnected(); + EXPECT_EQ(2U, cb_called); + EXPECT_TRUE(shares_capture2_->captured()) << "client should send shares"; + + CheckKeys(); + static const std::vector<SSLNamedGroup> client_shares( + client_groups.begin(), client_groups.begin() + 2); + CheckKEXDetails(client_groups, client_shares, server_groups[0]); +} + +TEST_F(TlsConnectTest, Select12AfterHelloRetryRequest) { + EnsureTlsSetup(); + client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2, + SSL_LIBRARY_VERSION_TLS_1_3); + server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2, + SSL_LIBRARY_VERSION_TLS_1_3); + static const std::vector<SSLNamedGroup> client_groups = { + ssl_grp_ec_secp256r1, ssl_grp_ec_secp521r1}; + client_->ConfigNamedGroups(client_groups); + static const std::vector<SSLNamedGroup> server_groups = { + ssl_grp_ec_secp384r1, ssl_grp_ec_secp521r1}; + server_->ConfigNamedGroups(server_groups); + StartConnect(); + + client_->Handshake(); + server_->Handshake(); + + // Here we replace the TLS server with one that does TLS 1.2 only. + // This will happily send the client a TLS 1.2 ServerHello. + server_.reset(new TlsAgent(server_->name(), TlsAgent::SERVER, variant_)); + client_->SetPeer(server_); + server_->SetPeer(client_); + server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2, + SSL_LIBRARY_VERSION_TLS_1_2); + server_->StartConnect(); + ExpectAlert(client_, kTlsAlertIllegalParameter); + Handshake(); + EXPECT_EQ(SSL_ERROR_ILLEGAL_PARAMETER_ALERT, server_->error_code()); + EXPECT_EQ(SSL_ERROR_RX_MALFORMED_SERVER_HELLO, client_->error_code()); +} + +// This class increments the low byte of the first Handshake.message_seq +// field in every handshake record. +class MessageSeqIncrementer : public TlsRecordFilter { + public: + MessageSeqIncrementer(const std::shared_ptr<TlsAgent>& a) + : TlsRecordFilter(a) {} + + protected: + PacketFilter::Action FilterRecord(const TlsRecordHeader& header, + const DataBuffer& data, + DataBuffer* changed) override { + if (header.content_type() != ssl_ct_handshake) { + return KEEP; + } + + *changed = data; + // struct { uint8 msg_type; uint24 length; uint16 message_seq; ... } + // Handshake; + changed->data()[5]++; + EXPECT_NE(0, changed->data()[5]); // Check for overflow. + return CHANGE; + } +}; + +// A server that receives a ClientHello with message_seq == 1 +// assumes that this is after a stateless HelloRetryRequest. +// However, it should reject the ClientHello if it lacks a cookie. +TEST_F(TlsConnectDatagram13, MessageSeq1ClientHello) { + EnsureTlsSetup(); + MakeTlsFilter<MessageSeqIncrementer>(client_); + ConnectExpectAlert(server_, kTlsAlertMissingExtension); + EXPECT_EQ(SSL_ERROR_MISSING_COOKIE_EXTENSION, server_->error_code()); + EXPECT_EQ(SSL_ERROR_MISSING_EXTENSION_ALERT, client_->error_code()); +} + +class HelloRetryRequestAgentTest : public TlsAgentTestClient { + protected: + void SetUp() override { + TlsAgentTestClient::SetUp(); + EnsureInit(); + agent_->StartConnect(); + } + + void MakeCannedHrr(const uint8_t* body, size_t len, DataBuffer* hrr_record, + uint32_t seq_num = 0) const { + DataBuffer hrr_data; + const uint8_t ssl_hello_retry_random[] = { + 0xCF, 0x21, 0xAD, 0x74, 0xE5, 0x9A, 0x61, 0x11, 0xBE, 0x1D, 0x8C, + 0x02, 0x1E, 0x65, 0xB8, 0x91, 0xC2, 0xA2, 0x11, 0x16, 0x7A, 0xBB, + 0x8C, 0x5E, 0x07, 0x9E, 0x09, 0xE2, 0xC8, 0xA8, 0x33, 0x9C}; + + hrr_data.Allocate(len + 6); + size_t i = 0; + i = hrr_data.Write(i, + variant_ == ssl_variant_datagram + ? SSL_LIBRARY_VERSION_DTLS_1_2_WIRE + : SSL_LIBRARY_VERSION_TLS_1_2, + 2); + i = hrr_data.Write(i, ssl_hello_retry_random, + sizeof(ssl_hello_retry_random)); + i = hrr_data.Write(i, static_cast<uint32_t>(0), 1); // session_id + i = hrr_data.Write(i, TLS_AES_128_GCM_SHA256, 2); + i = hrr_data.Write(i, ssl_compression_null, 1); + // Add extensions. First a length, which includes the supported version. + i = hrr_data.Write(i, static_cast<uint32_t>(len) + 6, 2); + // Now the supported version. + i = hrr_data.Write(i, ssl_tls13_supported_versions_xtn, 2); + i = hrr_data.Write(i, 2, 2); + i = hrr_data.Write(i, + (variant_ == ssl_variant_datagram) + ? (0x7f00 | DTLS_1_3_DRAFT_VERSION) + : SSL_LIBRARY_VERSION_TLS_1_3, + 2); + if (len) { + hrr_data.Write(i, body, len); + } + DataBuffer hrr; + MakeHandshakeMessage(kTlsHandshakeServerHello, hrr_data.data(), + hrr_data.len(), &hrr, seq_num); + MakeRecord(ssl_ct_handshake, SSL_LIBRARY_VERSION_TLS_1_3, hrr.data(), + hrr.len(), hrr_record, seq_num); + } + + void MakeGroupHrr(SSLNamedGroup group, DataBuffer* hrr_record, + uint32_t seq_num = 0) const { + const uint8_t group_hrr[] = { + static_cast<uint8_t>(ssl_tls13_key_share_xtn >> 8), + static_cast<uint8_t>(ssl_tls13_key_share_xtn), + 0, + 2, // length of key share extension + static_cast<uint8_t>(group >> 8), + static_cast<uint8_t>(group)}; + MakeCannedHrr(group_hrr, sizeof(group_hrr), hrr_record, seq_num); + } +}; + +// Send two HelloRetryRequest messages in response to the ClientHello. The are +// constructed to appear legitimate by asking for a new share in each, so that +// the client has to count to work out that the server is being unreasonable. +TEST_P(HelloRetryRequestAgentTest, SendSecondHelloRetryRequest) { + DataBuffer hrr; + MakeGroupHrr(ssl_grp_ec_secp384r1, &hrr, 0); + ProcessMessage(hrr, TlsAgent::STATE_CONNECTING); + MakeGroupHrr(ssl_grp_ec_secp521r1, &hrr, 1); + ExpectAlert(kTlsAlertUnexpectedMessage); + ProcessMessage(hrr, TlsAgent::STATE_ERROR, + SSL_ERROR_RX_UNEXPECTED_HELLO_RETRY_REQUEST); +} + +// Here the client receives a HelloRetryRequest with a group that they already +// provided a share for. +TEST_P(HelloRetryRequestAgentTest, HandleBogusHelloRetryRequest) { + DataBuffer hrr; + MakeGroupHrr(ssl_grp_ec_curve25519, &hrr); + ExpectAlert(kTlsAlertIllegalParameter); + ProcessMessage(hrr, TlsAgent::STATE_ERROR, + SSL_ERROR_RX_MALFORMED_HELLO_RETRY_REQUEST); +} + +TEST_P(HelloRetryRequestAgentTest, HandleNoopHelloRetryRequest) { + DataBuffer hrr; + MakeCannedHrr(nullptr, 0U, &hrr); + ExpectAlert(kTlsAlertDecodeError); + ProcessMessage(hrr, TlsAgent::STATE_ERROR, + SSL_ERROR_RX_MALFORMED_HELLO_RETRY_REQUEST); +} + +class ReplaceRandom : public TlsHandshakeFilter { + public: + ReplaceRandom(const std::shared_ptr<TlsAgent>& a, const DataBuffer& r) + : TlsHandshakeFilter(a, {kTlsHandshakeServerHello}), random_(r) {} + + PacketFilter::Action FilterHandshake(const HandshakeHeader& header, + const DataBuffer& input, + DataBuffer* output) override { + output->Assign(input); + output->Write(2, random_); + return CHANGE; + } + + private: + DataBuffer random_; +}; + +// Make sure that the TLS 1.3 special value for the ServerHello.random +// is rejected by earlier versions. +TEST_P(TlsConnectStreamPre13, HrrRandomOnTls10) { + static const uint8_t hrr_random[] = { + 0xCF, 0x21, 0xAD, 0x74, 0xE5, 0x9A, 0x61, 0x11, 0xBE, 0x1D, 0x8C, + 0x02, 0x1E, 0x65, 0xB8, 0x91, 0xC2, 0xA2, 0x11, 0x16, 0x7A, 0xBB, + 0x8C, 0x5E, 0x07, 0x9E, 0x09, 0xE2, 0xC8, 0xA8, 0x33, 0x9C}; + + EnsureTlsSetup(); + MakeTlsFilter<ReplaceRandom>(server_, + DataBuffer(hrr_random, sizeof(hrr_random))); + ConnectExpectAlert(client_, kTlsAlertIllegalParameter); + client_->CheckErrorCode(SSL_ERROR_RX_MALFORMED_SERVER_HELLO); + server_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT); +} + +TEST_F(TlsConnectStreamTls13, HrrThenTls12) { + StartConnect(); + size_t cb_called = 0; + EXPECT_EQ(SECSuccess, SSL_HelloRetryRequestCallback(server_->ssl_fd(), + RetryHello, &cb_called)); + server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2, + SSL_LIBRARY_VERSION_TLS_1_3); + + client_->Handshake(); // Send CH (1.3) + server_->Handshake(); // Send HRR. + EXPECT_EQ(1U, cb_called); + + // Replace the client with a new TLS 1.2 client. Don't call Init(), since + // it will artifically limit the server's vrange. + client_.reset( + new TlsAgent(client_->name(), TlsAgent::CLIENT, ssl_variant_stream)); + client_->SetPeer(server_); + server_->SetPeer(client_); + client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2, + SSL_LIBRARY_VERSION_TLS_1_2); + + client_->StartConnect(); + client_->Handshake(); // Send CH (1.2) + ExpectAlert(server_, kTlsAlertProtocolVersion); + server_->Handshake(); + server_->CheckErrorCode(SSL_ERROR_UNSUPPORTED_VERSION); + client_->Handshake(); + client_->CheckErrorCode(SSL_ERROR_PROTOCOL_VERSION_ALERT); +} + +TEST_F(TlsConnectStreamTls13, ZeroRttHrrThenTls12) { + SetupForZeroRtt(); + + client_->Set0RttEnabled(true); + size_t cb_called = 0; + EXPECT_EQ(SECSuccess, SSL_HelloRetryRequestCallback(server_->ssl_fd(), + RetryHello, &cb_called)); + server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2, + SSL_LIBRARY_VERSION_TLS_1_3); + + client_->Handshake(); // Send CH (1.3) + ZeroRttSendReceive(true, false); + server_->Handshake(); // Send HRR. + EXPECT_EQ(1U, cb_called); + + // Replace the client with a new TLS 1.2 client. Don't call Init(), since + // it will artifically limit the server's vrange. + client_.reset( + new TlsAgent(client_->name(), TlsAgent::CLIENT, ssl_variant_stream)); + client_->SetPeer(server_); + server_->SetPeer(client_); + client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2, + SSL_LIBRARY_VERSION_TLS_1_2); + + client_->StartConnect(); + client_->Handshake(); // Send CH (1.2) + ExpectAlert(server_, kTlsAlertProtocolVersion); + server_->Handshake(); + server_->CheckErrorCode(SSL_ERROR_UNSUPPORTED_VERSION); + client_->Handshake(); + client_->CheckErrorCode(SSL_ERROR_PROTOCOL_VERSION_ALERT); + + // Try to write something + server_->Handshake(); + client_->ExpectReadWriteError(); + client_->SendData(1); + uint8_t buf[1]; + EXPECT_EQ(-1, PR_Read(server_->ssl_fd(), buf, sizeof(buf))); + EXPECT_EQ(SSL_ERROR_HANDSHAKE_FAILED, PR_GetError()); +} + +TEST_F(TlsConnectStreamTls13, HrrThenTls12SupportedVersions) { + SetupForZeroRtt(); + client_->Set0RttEnabled(true); + size_t cb_called = 0; + EXPECT_EQ(SECSuccess, SSL_HelloRetryRequestCallback(server_->ssl_fd(), + RetryHello, &cb_called)); + server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2, + SSL_LIBRARY_VERSION_TLS_1_3); + + client_->Handshake(); // Send CH (1.3) + ZeroRttSendReceive(true, false); + server_->Handshake(); // Send HRR. + EXPECT_EQ(1U, cb_called); + + // Replace the client with a new TLS 1.2 client. Don't call Init(), since + // it will artifically limit the server's vrange. + client_.reset( + new TlsAgent(client_->name(), TlsAgent::CLIENT, ssl_variant_stream)); + client_->SetPeer(server_); + server_->SetPeer(client_); + client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_1, + SSL_LIBRARY_VERSION_TLS_1_2); + // Negotiate via supported_versions + static const uint8_t tls12[] = {0x02, 0x03, 0x03}; + auto replacer = MakeTlsFilter<TlsExtensionInjector>( + client_, ssl_tls13_supported_versions_xtn, + DataBuffer(tls12, sizeof(tls12))); + + client_->StartConnect(); + client_->Handshake(); // Send CH (1.2) + ExpectAlert(server_, kTlsAlertProtocolVersion); + server_->Handshake(); + server_->CheckErrorCode(SSL_ERROR_UNSUPPORTED_VERSION); + client_->Handshake(); + client_->CheckErrorCode(SSL_ERROR_PROTOCOL_VERSION_ALERT); +} + +INSTANTIATE_TEST_SUITE_P(HelloRetryRequestAgentTests, + HelloRetryRequestAgentTest, + ::testing::Combine(TlsConnectTestBase::kTlsVariantsAll, + TlsConnectTestBase::kTlsV13)); +#ifndef NSS_DISABLE_TLS_1_3 +INSTANTIATE_TEST_SUITE_P(HelloRetryRequestKeyExchangeTests, TlsKeyExchange13, + ::testing::Combine(TlsConnectTestBase::kTlsVariantsAll, + TlsConnectTestBase::kTlsV13)); +#endif + +} // namespace nss_test diff --git a/security/nss/gtests/ssl_gtest/ssl_keylog_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_keylog_unittest.cc new file mode 100644 index 0000000000..b7f0351d11 --- /dev/null +++ b/security/nss/gtests/ssl_gtest/ssl_keylog_unittest.cc @@ -0,0 +1,164 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set ts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this file, + * You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#include <cstdlib> +#include <fstream> +#include <sstream> + +#include "gtest_utils.h" +#include "tls_connect.h" + +namespace nss_test { + +static const std::string kKeylogFilePath = "keylog.txt"; +static const std::string kKeylogBlankEnv = "SSLKEYLOGFILE="; +static const std::string kKeylogSetEnv = kKeylogBlankEnv + kKeylogFilePath; + +extern "C" { +extern FILE* ssl_keylog_iob; +} + +class KeyLogFileTestBase : public TlsConnectGeneric { + private: + std::string env_to_set_; + + public: + virtual void CheckKeyLog() = 0; + + KeyLogFileTestBase(std::string env) : env_to_set_(env) {} + + void SetUp() override { + TlsConnectGeneric::SetUp(); + // Remove previous results (if any). + (void)remove(kKeylogFilePath.c_str()); + PR_SetEnv(env_to_set_.c_str()); + } + + void ConnectAndCheck() { + // This is a child process, ensure that error messages immediately + // propagate or else it will not be visible. + ::testing::GTEST_FLAG(throw_on_failure) = true; + + if (version_ == SSL_LIBRARY_VERSION_TLS_1_3) { + SetupForZeroRtt(); + client_->Set0RttEnabled(true); + server_->Set0RttEnabled(true); + ExpectResumption(RESUME_TICKET); + ZeroRttSendReceive(true, true); + Handshake(); + ExpectEarlyDataAccepted(true); + CheckConnected(); + SendReceive(); + } else { + Connect(); + } + CheckKeyLog(); + _exit(0); + } +}; + +class KeyLogFileTest : public KeyLogFileTestBase { + public: + KeyLogFileTest() : KeyLogFileTestBase(kKeylogSetEnv) {} + + void CheckKeyLog() override { + std::ifstream f(kKeylogFilePath); + std::map<std::string, size_t> labels; + std::set<std::string> client_randoms; + for (std::string line; std::getline(f, line);) { + if (line[0] == '#') { + continue; + } + + std::istringstream iss(line); + std::string label, client_random, secret; + iss >> label >> client_random >> secret; + + ASSERT_EQ(64U, client_random.size()); + client_randoms.insert(client_random); + labels[label]++; + } + + if (version_ < SSL_LIBRARY_VERSION_TLS_1_3) { + ASSERT_EQ(1U, client_randoms.size()); + } else { + /* two handshakes for 0-RTT */ + ASSERT_EQ(2U, client_randoms.size()); + } + + // Every entry occurs twice (one log from server, one from client). + if (version_ < SSL_LIBRARY_VERSION_TLS_1_3) { + ASSERT_EQ(2U, labels["CLIENT_RANDOM"]); + } else { + ASSERT_EQ(2U, labels["CLIENT_EARLY_TRAFFIC_SECRET"]); + ASSERT_EQ(2U, labels["EARLY_EXPORTER_SECRET"]); + ASSERT_EQ(4U, labels["CLIENT_HANDSHAKE_TRAFFIC_SECRET"]); + ASSERT_EQ(4U, labels["SERVER_HANDSHAKE_TRAFFIC_SECRET"]); + ASSERT_EQ(4U, labels["CLIENT_TRAFFIC_SECRET_0"]); + ASSERT_EQ(4U, labels["SERVER_TRAFFIC_SECRET_0"]); + ASSERT_EQ(4U, labels["EXPORTER_SECRET"]); + } + } +}; + +// Tests are run in a separate process to ensure that NSS is not initialized yet +// and can process the SSLKEYLOGFILE environment variable. + +TEST_P(KeyLogFileTest, KeyLogFile) { + testing::GTEST_FLAG(death_test_style) = "threadsafe"; + + ASSERT_EXIT(ConnectAndCheck(), ::testing::ExitedWithCode(0), ""); +} + +INSTANTIATE_TEST_SUITE_P( + KeyLogFileDTLS12, KeyLogFileTest, + ::testing::Combine(TlsConnectTestBase::kTlsVariantsDatagram, + TlsConnectTestBase::kTlsV11V12)); +INSTANTIATE_TEST_SUITE_P( + KeyLogFileTLS12, KeyLogFileTest, + ::testing::Combine(TlsConnectTestBase::kTlsVariantsStream, + TlsConnectTestBase::kTlsV10ToV12)); +#ifndef NSS_DISABLE_TLS_1_3 +INSTANTIATE_TEST_SUITE_P( + KeyLogFileTLS13, KeyLogFileTest, + ::testing::Combine(TlsConnectTestBase::kTlsVariantsStream, + TlsConnectTestBase::kTlsV13)); +#endif + +class KeyLogFileUnsetTest : public KeyLogFileTestBase { + public: + KeyLogFileUnsetTest() : KeyLogFileTestBase(kKeylogBlankEnv) {} + + void CheckKeyLog() override { + std::ifstream f(kKeylogFilePath); + EXPECT_FALSE(f.good()); + + EXPECT_EQ(nullptr, ssl_keylog_iob); + } +}; + +TEST_P(KeyLogFileUnsetTest, KeyLogFile) { + testing::GTEST_FLAG(death_test_style) = "threadsafe"; + + ASSERT_EXIT(ConnectAndCheck(), ::testing::ExitedWithCode(0), ""); +} + +INSTANTIATE_TEST_SUITE_P( + KeyLogFileDTLS12, KeyLogFileUnsetTest, + ::testing::Combine(TlsConnectTestBase::kTlsVariantsDatagram, + TlsConnectTestBase::kTlsV11V12)); +INSTANTIATE_TEST_SUITE_P( + KeyLogFileTLS12, KeyLogFileUnsetTest, + ::testing::Combine(TlsConnectTestBase::kTlsVariantsStream, + TlsConnectTestBase::kTlsV10ToV12)); +#ifndef NSS_DISABLE_TLS_1_3 +INSTANTIATE_TEST_SUITE_P( + KeyLogFileTLS13, KeyLogFileUnsetTest, + ::testing::Combine(TlsConnectTestBase::kTlsVariantsStream, + TlsConnectTestBase::kTlsV13)); +#endif + +} // namespace nss_test diff --git a/security/nss/gtests/ssl_gtest/ssl_keyupdate_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_keyupdate_unittest.cc new file mode 100644 index 0000000000..b921d2c1e6 --- /dev/null +++ b/security/nss/gtests/ssl_gtest/ssl_keyupdate_unittest.cc @@ -0,0 +1,209 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set ts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this file, + * You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#include "secerr.h" +#include "ssl.h" +#include "sslerr.h" +#include "sslproto.h" + +extern "C" { +// This is not something that should make you happy. +#include "libssl_internals.h" +} + +#include "gtest_utils.h" +#include "nss_scoped_ptrs.h" +#include "tls_connect.h" +#include "tls_filter.h" +#include "tls_parser.h" + +namespace nss_test { + +// All stream only tests; DTLS isn't supported yet. + +TEST_F(TlsConnectTest, KeyUpdateClient) { + ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3); + Connect(); + EXPECT_EQ(SECSuccess, SSL_KeyUpdate(client_->ssl_fd(), PR_FALSE)); + SendReceive(50); + SendReceive(60); + CheckEpochs(4, 3); +} + +TEST_F(TlsConnectStreamTls13, KeyUpdateTooEarly_Client) { + StartConnect(); + auto filter = MakeTlsFilter<TlsEncryptedHandshakeMessageReplacer>( + server_, kTlsHandshakeFinished, kTlsHandshakeKeyUpdate); + filter->EnableDecryption(); + + client_->Handshake(); + server_->Handshake(); + ExpectAlert(client_, kTlsAlertUnexpectedMessage); + client_->Handshake(); + client_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_KEY_UPDATE); + server_->Handshake(); + server_->CheckErrorCode(SSL_ERROR_HANDSHAKE_UNEXPECTED_ALERT); +} + +TEST_F(TlsConnectStreamTls13, KeyUpdateTooEarly_Server) { + StartConnect(); + auto filter = MakeTlsFilter<TlsEncryptedHandshakeMessageReplacer>( + client_, kTlsHandshakeFinished, kTlsHandshakeKeyUpdate); + filter->EnableDecryption(); + + client_->Handshake(); + server_->Handshake(); + client_->Handshake(); + ExpectAlert(server_, kTlsAlertUnexpectedMessage); + server_->Handshake(); + server_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_KEY_UPDATE); + client_->Handshake(); + client_->CheckErrorCode(SSL_ERROR_HANDSHAKE_UNEXPECTED_ALERT); +} + +TEST_F(TlsConnectTest, KeyUpdateClientRequestUpdate) { + ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3); + Connect(); + EXPECT_EQ(SECSuccess, SSL_KeyUpdate(client_->ssl_fd(), PR_TRUE)); + // SendReceive() only gives each peer one chance to read. This isn't enough + // when the read on one side generates another handshake message. A second + // read gives each peer an extra chance to consume the KeyUpdate. + SendReceive(50); + SendReceive(60); // Cumulative count. + CheckEpochs(4, 4); +} + +TEST_F(TlsConnectTest, KeyUpdateServer) { + ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3); + Connect(); + EXPECT_EQ(SECSuccess, SSL_KeyUpdate(server_->ssl_fd(), PR_FALSE)); + SendReceive(50); + SendReceive(60); + CheckEpochs(3, 4); +} + +TEST_F(TlsConnectTest, KeyUpdateServerRequestUpdate) { + ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3); + Connect(); + EXPECT_EQ(SECSuccess, SSL_KeyUpdate(server_->ssl_fd(), PR_TRUE)); + SendReceive(50); + SendReceive(60); + CheckEpochs(4, 4); +} + +TEST_F(TlsConnectTest, KeyUpdateConsecutiveRequests) { + ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3); + Connect(); + EXPECT_EQ(SECSuccess, SSL_KeyUpdate(server_->ssl_fd(), PR_TRUE)); + EXPECT_EQ(SECSuccess, SSL_KeyUpdate(server_->ssl_fd(), PR_TRUE)); + SendReceive(50); + SendReceive(60); + // The server should have updated twice, but the client should have declined + // to respond to the second request from the server, since it doesn't send + // anything in between those two requests. + CheckEpochs(4, 5); +} + +// Check that a local update can be immediately followed by a remotely triggered +// update even if there is no use of the keys. +TEST_F(TlsConnectTest, KeyUpdateLocalUpdateThenConsecutiveRequests) { + ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3); + Connect(); + // This should trigger an update on the client. + EXPECT_EQ(SECSuccess, SSL_KeyUpdate(client_->ssl_fd(), PR_FALSE)); + // The client should update for the first request. + EXPECT_EQ(SECSuccess, SSL_KeyUpdate(server_->ssl_fd(), PR_TRUE)); + // ...but not the second. + EXPECT_EQ(SECSuccess, SSL_KeyUpdate(server_->ssl_fd(), PR_TRUE)); + SendReceive(50); + SendReceive(60); + // Both should have updated twice. + CheckEpochs(5, 5); +} + +TEST_F(TlsConnectTest, KeyUpdateMultiple) { + ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3); + Connect(); + EXPECT_EQ(SECSuccess, SSL_KeyUpdate(server_->ssl_fd(), PR_FALSE)); + EXPECT_EQ(SECSuccess, SSL_KeyUpdate(server_->ssl_fd(), PR_TRUE)); + EXPECT_EQ(SECSuccess, SSL_KeyUpdate(server_->ssl_fd(), PR_FALSE)); + EXPECT_EQ(SECSuccess, SSL_KeyUpdate(client_->ssl_fd(), PR_FALSE)); + SendReceive(50); + SendReceive(60); + CheckEpochs(5, 6); +} + +// Both ask the other for an update, and both should react. +TEST_F(TlsConnectTest, KeyUpdateBothRequest) { + ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3); + Connect(); + EXPECT_EQ(SECSuccess, SSL_KeyUpdate(client_->ssl_fd(), PR_TRUE)); + EXPECT_EQ(SECSuccess, SSL_KeyUpdate(server_->ssl_fd(), PR_TRUE)); + SendReceive(50); + SendReceive(60); + CheckEpochs(5, 5); +} + +// If the sequence number exceeds the number of writes before an automatic +// update (currently 3/4 of the max records for the cipher suite), then the +// stack should send an update automatically (but not request one). +TEST_F(TlsConnectTest, KeyUpdateAutomaticOnWrite) { + ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3); + ConnectWithCipherSuite(TLS_AES_128_GCM_SHA256); + + // Set this to one below the write threshold. + uint64_t threshold = (0x5aULL << 28) * 3 / 4; + EXPECT_EQ(SECSuccess, + SSLInt_AdvanceWriteSeqNum(client_->ssl_fd(), threshold)); + EXPECT_EQ(SECSuccess, SSLInt_AdvanceReadSeqNum(server_->ssl_fd(), threshold)); + + // This should be OK. + client_->SendData(10); + server_->ReadBytes(); + + // This should cause the client to update. + client_->SendData(10); + server_->ReadBytes(); + + SendReceive(100); + CheckEpochs(4, 3); +} + +// If the sequence number exceeds a certain number of reads (currently 7/8 of +// the max records for the cipher suite), then the stack should send AND request +// an update automatically. However, the sender (client) will be above its +// automatic update threshold, so the KeyUpdate - that it sends with the old +// cipher spec - will exceed the receiver (server) automatic update threshold. +// The receiver gets a packet with a sequence number over its automatic read +// update threshold. Even though the sender has updated, the code that checks +// the sequence numbers at the receiver doesn't know this and it will request an +// update. This causes two updates: one from the sender (without requesting a +// response) and one from the receiver (which does request a response). +TEST_F(TlsConnectTest, KeyUpdateAutomaticOnRead) { + ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3); + ConnectWithCipherSuite(TLS_AES_128_GCM_SHA256); + + // Move to right at the read threshold. Unlike the write test, we can't send + // packets because that would cause the client to update, which would spoil + // the test. + uint64_t threshold = ((0x5aULL << 28) * 7 / 8) + 1; + EXPECT_EQ(SECSuccess, + SSLInt_AdvanceWriteSeqNum(client_->ssl_fd(), threshold)); + EXPECT_EQ(SECSuccess, SSLInt_AdvanceReadSeqNum(server_->ssl_fd(), threshold)); + + // This should cause the client to update, but not early enough to prevent the + // server from updating also. + client_->SendData(10); + server_->ReadBytes(); + + // Need two SendReceive() calls to ensure that the update that the server + // requested is properly generated and consumed. + SendReceive(70); + SendReceive(80); + CheckEpochs(5, 4); +} + +} // namespace nss_test diff --git a/security/nss/gtests/ssl_gtest/ssl_loopback_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_loopback_unittest.cc new file mode 100644 index 0000000000..491f50921f --- /dev/null +++ b/security/nss/gtests/ssl_gtest/ssl_loopback_unittest.cc @@ -0,0 +1,801 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set ts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this file, + * You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#include <functional> +#include <memory> +#include <vector> +#include "secerr.h" +#include "ssl.h" +#include "sslerr.h" +#include "sslproto.h" + +extern "C" { +// This is not something that should make you happy. +#include "libssl_internals.h" +} + +#include "gtest_utils.h" +#include "nss_scoped_ptrs.h" +#include "tls_connect.h" +#include "tls_filter.h" +#include "tls_parser.h" + +namespace nss_test { + +TEST_P(TlsConnectGeneric, SetupOnly) {} + +TEST_P(TlsConnectGeneric, Connect) { + SetExpectedVersion(std::get<1>(GetParam())); + Connect(); + CheckKeys(); +} + +TEST_P(TlsConnectGeneric, ConnectEcdsa) { + SetExpectedVersion(std::get<1>(GetParam())); + Reset(TlsAgent::kServerEcdsa256); + Connect(); + CheckKeys(ssl_kea_ecdh, ssl_auth_ecdsa); +} + +TEST_P(TlsConnectGeneric, CipherSuiteMismatch) { + EnsureTlsSetup(); + if (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) { + client_->EnableSingleCipher(TLS_AES_128_GCM_SHA256); + server_->EnableSingleCipher(TLS_AES_256_GCM_SHA384); + } else { + client_->EnableSingleCipher(TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA); + server_->EnableSingleCipher(TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA); + } + ConnectExpectAlert(server_, kTlsAlertHandshakeFailure); + client_->CheckErrorCode(SSL_ERROR_NO_CYPHER_OVERLAP); + server_->CheckErrorCode(SSL_ERROR_NO_CYPHER_OVERLAP); +} + +class TlsAlertRecorder : public TlsRecordFilter { + public: + TlsAlertRecorder(const std::shared_ptr<TlsAgent>& a) + : TlsRecordFilter(a), level_(255), description_(255) {} + + PacketFilter::Action FilterRecord(const TlsRecordHeader& header, + const DataBuffer& input, + DataBuffer* output) override { + if (level_ != 255) { // Already captured. + return KEEP; + } + if (header.content_type() != ssl_ct_alert) { + return KEEP; + } + + std::cerr << "Alert: " << input << std::endl; + + TlsParser parser(input); + EXPECT_TRUE(parser.Read(&level_)); + EXPECT_TRUE(parser.Read(&description_)); + return KEEP; + } + + uint8_t level() const { return level_; } + uint8_t description() const { return description_; } + + private: + uint8_t level_; + uint8_t description_; +}; + +class HelloTruncator : public TlsHandshakeFilter { + public: + HelloTruncator(const std::shared_ptr<TlsAgent>& a) + : TlsHandshakeFilter( + a, {kTlsHandshakeClientHello, kTlsHandshakeServerHello}) {} + PacketFilter::Action FilterHandshake(const HandshakeHeader& header, + const DataBuffer& input, + DataBuffer* output) override { + output->Assign(input.data(), input.len() - 1); + return CHANGE; + } +}; + +// Verify that when NSS reports that an alert is sent, it is actually sent. +TEST_P(TlsConnectGeneric, CaptureAlertServer) { + MakeTlsFilter<HelloTruncator>(client_); + auto alert_recorder = MakeTlsFilter<TlsAlertRecorder>(server_); + + ConnectExpectAlert(server_, kTlsAlertDecodeError); + EXPECT_EQ(kTlsAlertFatal, alert_recorder->level()); + EXPECT_EQ(kTlsAlertDecodeError, alert_recorder->description()); +} + +TEST_P(TlsConnectGenericPre13, CaptureAlertClient) { + MakeTlsFilter<HelloTruncator>(server_); + auto alert_recorder = MakeTlsFilter<TlsAlertRecorder>(client_); + + ConnectExpectAlert(client_, kTlsAlertDecodeError); + EXPECT_EQ(kTlsAlertFatal, alert_recorder->level()); + EXPECT_EQ(kTlsAlertDecodeError, alert_recorder->description()); +} + +// In TLS 1.3, the server can't read the client alert. +TEST_P(TlsConnectTls13, CaptureAlertClient) { + MakeTlsFilter<HelloTruncator>(server_); + auto alert_recorder = MakeTlsFilter<TlsAlertRecorder>(client_); + + StartConnect(); + + client_->Handshake(); + client_->ExpectSendAlert(kTlsAlertDecodeError); + server_->Handshake(); + client_->Handshake(); + if (variant_ == ssl_variant_stream) { + // DTLS just drops the alert it can't decrypt. + server_->ExpectSendAlert(kTlsAlertUnexpectedMessage); + } + server_->Handshake(); + EXPECT_EQ(kTlsAlertFatal, alert_recorder->level()); + EXPECT_EQ(kTlsAlertDecodeError, alert_recorder->description()); +} + +TEST_P(TlsConnectGenericPre13, ConnectFalseStart) { + client_->EnableFalseStart(); + Connect(); + SendReceive(); +} + +TEST_P(TlsConnectGeneric, ConnectAlpn) { + EnableAlpn(); + Connect(); + CheckAlpn("a"); +} + +TEST_P(TlsConnectGeneric, ConnectAlpnPriorityA) { + // "alpn" "npn" + // alpn is the fallback here. npn has the highest priority and should be + // picked. + const std::vector<uint8_t> alpn = {0x04, 0x61, 0x6c, 0x70, 0x6e, + 0x03, 0x6e, 0x70, 0x6e}; + EnableAlpn(alpn); + Connect(); + CheckAlpn("npn"); +} + +TEST_P(TlsConnectGeneric, ConnectAlpnPriorityB) { + // "alpn" "npn" "http" + // npn has the highest priority and should be picked. + const std::vector<uint8_t> alpn = {0x04, 0x61, 0x6c, 0x70, 0x6e, 0x03, 0x6e, + 0x70, 0x6e, 0x04, 0x68, 0x74, 0x74, 0x70}; + EnableAlpn(alpn); + Connect(); + CheckAlpn("npn"); +} + +TEST_P(TlsConnectGeneric, ConnectAlpnClone) { + EnsureModelSockets(); + client_model_->EnableAlpn(alpn_dummy_val_, sizeof(alpn_dummy_val_)); + server_model_->EnableAlpn(alpn_dummy_val_, sizeof(alpn_dummy_val_)); + Connect(); + CheckAlpn("a"); +} + +TEST_P(TlsConnectGeneric, ConnectAlpnWithCustomCallbackA) { + // "ab" "alpn" + const std::vector<uint8_t> client_alpn = {0x02, 0x61, 0x62, 0x04, + 0x61, 0x6c, 0x70, 0x6e}; + EnableAlpnWithCallback(client_alpn, "alpn"); + Connect(); + CheckAlpn("alpn"); +} + +TEST_P(TlsConnectGeneric, ConnectAlpnWithCustomCallbackB) { + // "ab" "alpn" + const std::vector<uint8_t> client_alpn = {0x02, 0x61, 0x62, 0x04, + 0x61, 0x6c, 0x70, 0x6e}; + EnableAlpnWithCallback(client_alpn, "ab"); + Connect(); + CheckAlpn("ab"); +} + +TEST_P(TlsConnectGeneric, ConnectAlpnWithCustomCallbackC) { + // "cd" "npn" "alpn" + const std::vector<uint8_t> client_alpn = {0x02, 0x63, 0x64, 0x03, 0x6e, 0x70, + 0x6e, 0x04, 0x61, 0x6c, 0x70, 0x6e}; + EnableAlpnWithCallback(client_alpn, "npn"); + Connect(); + CheckAlpn("npn"); +} + +TEST_P(TlsConnectDatagram, ConnectSrtp) { + EnableSrtp(); + Connect(); + CheckSrtp(); + SendReceive(); +} + +TEST_P(TlsConnectGeneric, ConnectSendReceive) { + Connect(); + SendReceive(); +} + +class SaveTlsRecord : public TlsRecordFilter { + public: + SaveTlsRecord(const std::shared_ptr<TlsAgent>& a, size_t index) + : TlsRecordFilter(a), index_(index), count_(0), contents_() {} + + const DataBuffer& contents() const { return contents_; } + + protected: + PacketFilter::Action FilterRecord(const TlsRecordHeader& header, + const DataBuffer& data, + DataBuffer* changed) override { + if (count_++ == index_) { + contents_ = data; + } + return KEEP; + } + + private: + const size_t index_; + size_t count_; + DataBuffer contents_; +}; + +// Check that decrypting filters work and can read any record. +// This test (currently) only works in TLS 1.3 where we can decrypt. +TEST_F(TlsConnectStreamTls13, DecryptRecordClient) { + EnsureTlsSetup(); + // 0 = ClientHello, 1 = Finished, 2 = SendReceive, 3 = SendBuffer + auto saved = MakeTlsFilter<SaveTlsRecord>(client_, 3); + saved->EnableDecryption(); + Connect(); + SendReceive(); + + static const uint8_t data[] = {0xde, 0xad, 0xdc}; + DataBuffer buf(data, sizeof(data)); + client_->SendBuffer(buf); + EXPECT_EQ(buf, saved->contents()); +} + +TEST_F(TlsConnectStreamTls13, DecryptRecordServer) { + EnsureTlsSetup(); + // Disable tickets so that we are sure to not get NewSessionTicket. + EXPECT_EQ(SECSuccess, SSL_OptionSet(server_->ssl_fd(), + SSL_ENABLE_SESSION_TICKETS, PR_FALSE)); + // 0 = ServerHello, 1 = other handshake, 2 = SendReceive, 3 = SendBuffer + auto saved = MakeTlsFilter<SaveTlsRecord>(server_, 3); + saved->EnableDecryption(); + Connect(); + SendReceive(); + + static const uint8_t data[] = {0xde, 0xad, 0xd5}; + DataBuffer buf(data, sizeof(data)); + server_->SendBuffer(buf); + EXPECT_EQ(buf, saved->contents()); +} + +class DropTlsRecord : public TlsRecordFilter { + public: + DropTlsRecord(const std::shared_ptr<TlsAgent>& a, size_t index) + : TlsRecordFilter(a), index_(index), count_(0) {} + + protected: + PacketFilter::Action FilterRecord(const TlsRecordHeader& header, + const DataBuffer& data, + DataBuffer* changed) override { + if (count_++ == index_) { + return DROP; + } + return KEEP; + } + + private: + const size_t index_; + size_t count_; +}; + +// Test that decrypting filters work correctly and are able to drop records. +TEST_F(TlsConnectStreamTls13, DropRecordServer) { + EnsureTlsSetup(); + // Disable session tickets so that the server doesn't send an extra record. + EXPECT_EQ(SECSuccess, SSL_OptionSet(server_->ssl_fd(), + SSL_ENABLE_SESSION_TICKETS, PR_FALSE)); + + // 0 = ServerHello, 1 = other handshake, 2 = first write + auto filter = MakeTlsFilter<DropTlsRecord>(server_, 2); + filter->EnableDecryption(); + Connect(); + server_->SendData(23, 23); // This should be dropped, so it won't be counted. + server_->ResetSentBytes(); + SendReceive(); +} + +TEST_F(TlsConnectStreamTls13, DropRecordClient) { + EnsureTlsSetup(); + // 0 = ClientHello, 1 = Finished, 2 = first write + auto filter = MakeTlsFilter<DropTlsRecord>(client_, 2); + filter->EnableDecryption(); + Connect(); + client_->SendData(26, 26); // This should be dropped, so it won't be counted. + client_->ResetSentBytes(); + SendReceive(); +} + +// Check that a server can use 0.5 RTT if client authentication isn't enabled. +TEST_P(TlsConnectTls13, WriteBeforeClientFinished) { + EnsureTlsSetup(); + StartConnect(); + client_->Handshake(); // ClientHello + server_->Handshake(); // ServerHello + + server_->SendData(10); + client_->ReadBytes(10); // Client should emit the Finished as a side-effect. + server_->Handshake(); // Server consumes the Finished. + CheckConnected(); +} + +// We don't allow 0.5 RTT if client authentication is requested. +TEST_P(TlsConnectTls13, WriteBeforeClientFinishedClientAuth) { + client_->SetupClientAuth(); + server_->RequestClientAuth(false); + StartConnect(); + client_->Handshake(); // ClientHello + server_->Handshake(); // ServerHello + + static const uint8_t data[] = {1, 2, 3}; + EXPECT_GT(0, PR_Write(server_->ssl_fd(), data, sizeof(data))); + EXPECT_EQ(PR_WOULD_BLOCK_ERROR, PORT_GetError()); + + Handshake(); + CheckConnected(); + SendReceive(); +} + +// 0.5 RTT should fail with client authentication required. +TEST_P(TlsConnectTls13, WriteBeforeClientFinishedClientAuthRequired) { + client_->SetupClientAuth(); + server_->RequestClientAuth(true); + StartConnect(); + client_->Handshake(); // ClientHello + server_->Handshake(); // ServerHello + + static const uint8_t data[] = {1, 2, 3}; + EXPECT_GT(0, PR_Write(server_->ssl_fd(), data, sizeof(data))); + EXPECT_EQ(PR_WOULD_BLOCK_ERROR, PORT_GetError()); + + Handshake(); + CheckConnected(); + SendReceive(); +} + +// The next two tests takes advantage of the fact that we +// automatically read the first 1024 bytes, so if +// we provide 1200 bytes, they overrun the read buffer +// provided by the calling test. + +// DTLS should return an error. +TEST_P(TlsConnectDatagram, ShortRead) { + Connect(); + client_->ExpectReadWriteError(); + server_->SendData(50, 50); + client_->ReadBytes(20); + EXPECT_EQ(0U, client_->received_bytes()); + EXPECT_EQ(SSL_ERROR_RX_SHORT_DTLS_READ, PORT_GetError()); + + // Now send and receive another packet. + server_->ResetSentBytes(); // Reset the counter. + SendReceive(); +} + +// TLS should get the write in two chunks. +TEST_P(TlsConnectStream, ShortRead) { + // This test behaves oddly with TLS 1.0 because of 1/n+1 splitting, + // so skip in that case. + if (version_ < SSL_LIBRARY_VERSION_TLS_1_1) GTEST_SKIP(); + + Connect(); + server_->SendData(50, 50); + // Read the first tranche. + client_->ReadBytes(20); + ASSERT_EQ(20U, client_->received_bytes()); + // The second tranche should now immediately be available. + client_->ReadBytes(); + ASSERT_EQ(50U, client_->received_bytes()); +} + +// We enable compression via the API but it's disabled internally, +// so we should never get it. +TEST_P(TlsConnectGeneric, ConnectWithCompressionEnabled) { + EnsureTlsSetup(); + client_->SetOption(SSL_ENABLE_DEFLATE, PR_TRUE); + server_->SetOption(SSL_ENABLE_DEFLATE, PR_TRUE); + Connect(); + EXPECT_FALSE(client_->is_compressed()); + SendReceive(); +} + +class TlsHolddownTest : public TlsConnectDatagram { + protected: + // This causes all timers to run to completion. It advances the clock and + // handshakes on both peers until both peers have no more timers pending, + // which should happen at the end of a handshake. This is necessary to ensure + // that the relatively long holddown timer expires, but that any other timers + // also expire and run correctly. + void RunAllTimersDown() { + while (true) { + PRIntervalTime time; + SECStatus rv = DTLS_GetHandshakeTimeout(client_->ssl_fd(), &time); + if (rv != SECSuccess) { + rv = DTLS_GetHandshakeTimeout(server_->ssl_fd(), &time); + if (rv != SECSuccess) { + break; // Neither peer has an outstanding timer. + } + } + + if (g_ssl_gtest_verbose) { + std::cerr << "Shifting timers" << std::endl; + } + ShiftDtlsTimers(); + Handshake(); + } + } +}; + +TEST_P(TlsHolddownTest, TestDtlsHolddownExpiry) { + Connect(); + std::cerr << "Expiring holddown timer" << std::endl; + RunAllTimersDown(); + SendReceive(); + if (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) { + // One for send, one for receive. + EXPECT_EQ(2, SSLInt_CountCipherSpecs(client_->ssl_fd())); + } +} + +TEST_P(TlsHolddownTest, TestDtlsHolddownExpiryResumption) { + ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET); + Connect(); + SendReceive(); + + Reset(); + ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET); + ExpectResumption(RESUME_TICKET); + Connect(); + RunAllTimersDown(); + SendReceive(); + // One for send, one for receive. + EXPECT_EQ(2, SSLInt_CountCipherSpecs(client_->ssl_fd())); +} + +class TlsPreCCSHeaderInjector : public TlsRecordFilter { + public: + TlsPreCCSHeaderInjector(const std::shared_ptr<TlsAgent>& a) + : TlsRecordFilter(a) {} + virtual PacketFilter::Action FilterRecord( + const TlsRecordHeader& record_header, const DataBuffer& input, + size_t* offset, DataBuffer* output) override { + if (record_header.content_type() != ssl_ct_change_cipher_spec) { + return KEEP; + } + + std::cerr << "Injecting Finished header before CCS\n"; + const uint8_t hhdr[] = {kTlsHandshakeFinished, 0x00, 0x00, 0x0c}; + DataBuffer hhdr_buf(hhdr, sizeof(hhdr)); + TlsRecordHeader nhdr(record_header.variant(), record_header.version(), + ssl_ct_handshake, 0); + *offset = nhdr.Write(output, *offset, hhdr_buf); + *offset = record_header.Write(output, *offset, input); + return CHANGE; + } +}; + +TEST_P(TlsConnectStreamPre13, ClientFinishedHeaderBeforeCCS) { + MakeTlsFilter<TlsPreCCSHeaderInjector>(client_); + ConnectExpectAlert(server_, kTlsAlertUnexpectedMessage); + client_->CheckErrorCode(SSL_ERROR_HANDSHAKE_UNEXPECTED_ALERT); + server_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_CHANGE_CIPHER); +} + +TEST_P(TlsConnectStreamPre13, ServerFinishedHeaderBeforeCCS) { + MakeTlsFilter<TlsPreCCSHeaderInjector>(server_); + StartConnect(); + ExpectAlert(client_, kTlsAlertUnexpectedMessage); + Handshake(); + EXPECT_EQ(TlsAgent::STATE_ERROR, client_->state()); + client_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_CHANGE_CIPHER); + EXPECT_EQ(TlsAgent::STATE_CONNECTED, server_->state()); + server_->Handshake(); // Make sure alert is consumed. +} + +TEST_P(TlsConnectTls13, UnknownAlert) { + Connect(); + server_->ExpectSendAlert(0xff, kTlsAlertWarning); + client_->ExpectReceiveAlert(0xff, kTlsAlertWarning); + SSLInt_SendAlert(server_->ssl_fd(), kTlsAlertWarning, + 0xff); // Unknown value. + client_->ExpectReadWriteError(); + client_->WaitForErrorCode(SSL_ERROR_RX_UNKNOWN_ALERT, 2000); +} + +TEST_P(TlsConnectTls13, AlertWrongLevel) { + Connect(); + server_->ExpectSendAlert(kTlsAlertUnexpectedMessage, kTlsAlertWarning); + client_->ExpectReceiveAlert(kTlsAlertUnexpectedMessage, kTlsAlertWarning); + SSLInt_SendAlert(server_->ssl_fd(), kTlsAlertWarning, + kTlsAlertUnexpectedMessage); + client_->ExpectReadWriteError(); + client_->WaitForErrorCode(SSL_ERROR_HANDSHAKE_UNEXPECTED_ALERT, 2000); +} + +TEST_P(TlsConnectTls13, UnknownRecord) { + static const uint8_t kUknownRecord[] = { + 0xff, SSL_LIBRARY_VERSION_TLS_1_2 >> 8, + SSL_LIBRARY_VERSION_TLS_1_2 & 0xff, 0, 0}; + + Connect(); + if (variant_ == ssl_variant_stream) { + // DTLS just drops the record with an invalid type. + server_->ExpectSendAlert(kTlsAlertUnexpectedMessage); + } + client_->SendDirect(DataBuffer(kUknownRecord, sizeof(kUknownRecord))); + server_->ExpectReadWriteError(); + server_->ReadBytes(); + if (variant_ == ssl_variant_stream) { + EXPECT_EQ(SSL_ERROR_RX_UNEXPECTED_RECORD_TYPE, server_->error_code()); + } else { + EXPECT_EQ(SSL_ERROR_RX_UNKNOWN_RECORD_TYPE, server_->error_code()); + } +} + +TEST_F(TlsConnectStreamTls13, Tls13FailedWriteSecondFlight) { + EnsureTlsSetup(); + StartConnect(); + client_->Handshake(); + server_->Handshake(); // Send first flight. + client_->adapter()->SetWriteError(PR_IO_ERROR); + client_->Handshake(); // This will get an error, but shouldn't crash. + client_->CheckErrorCode(SSL_ERROR_SOCKET_WRITE_FAILURE); +} + +TEST_P(TlsConnectDatagram, BlockedWrite) { + Connect(); + + // Mark the socket as blocked. + client_->adapter()->SetWriteError(PR_WOULD_BLOCK_ERROR); + static const uint8_t data[] = {1, 2, 3}; + int32_t rv = PR_Write(client_->ssl_fd(), data, sizeof(data)); + EXPECT_GT(0, rv); + EXPECT_EQ(PR_WOULD_BLOCK_ERROR, PORT_GetError()); + + // Remove the write error and though the previous write failed, future reads + // and writes should just work as if it never happened. + client_->adapter()->SetWriteError(0); + SendReceive(); +} + +TEST_F(TlsConnectTest, ConnectSSLv3) { + ConfigureVersion(SSL_LIBRARY_VERSION_3_0); + EnableOnlyStaticRsaCiphers(); + Connect(); + CheckKeys(ssl_kea_rsa, ssl_grp_none, ssl_auth_rsa_decrypt, ssl_sig_none); +} + +TEST_F(TlsConnectTest, ConnectSSLv3ClientAuth) { + ConfigureVersion(SSL_LIBRARY_VERSION_3_0); + EnableOnlyStaticRsaCiphers(); + client_->SetupClientAuth(); + server_->RequestClientAuth(true); + Connect(); + CheckKeys(ssl_kea_rsa, ssl_grp_none, ssl_auth_rsa_decrypt, ssl_sig_none); +} + +static size_t ExpectedCbcLen(size_t in, size_t hmac = 20, size_t block = 16) { + // MAC-then-Encrypt expansion formula: + return ((in + hmac + (block - 1)) / block) * block; +} + +TEST_F(TlsConnectTest, OneNRecordSplitting) { + ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_0); + EnsureTlsSetup(); + ConnectWithCipherSuite(TLS_RSA_WITH_AES_128_CBC_SHA); + auto records = MakeTlsFilter<TlsRecordRecorder>(server_); + // This should be split into 1, 16384 and 20. + DataBuffer big_buffer; + big_buffer.Allocate(1 + 16384 + 20); + server_->SendBuffer(big_buffer); + ASSERT_EQ(3U, records->count()); + EXPECT_EQ(ExpectedCbcLen(1), records->record(0).buffer.len()); + EXPECT_EQ(ExpectedCbcLen(16384), records->record(1).buffer.len()); + EXPECT_EQ(ExpectedCbcLen(20), records->record(2).buffer.len()); +} + +// We can't test for randomness easily here, but we can test that we don't +// produce a zero value, or produce the same value twice. There are 5 values +// here: two ClientHello.random, two ServerHello.random, and one zero value. +// Matrix them and fail if any are the same. +TEST_P(TlsConnectGeneric, CheckRandoms) { + ConfigureSessionCache(RESUME_NONE, RESUME_NONE); + + static const size_t random_len = 32; + uint8_t crandom1[random_len], srandom1[random_len]; + uint8_t z[random_len] = {0}; + + auto ch = MakeTlsFilter<TlsHandshakeRecorder>(client_, ssl_hs_client_hello); + auto sh = MakeTlsFilter<TlsHandshakeRecorder>(server_, ssl_hs_server_hello); + Connect(); + ASSERT_TRUE(ch->buffer().len() > (random_len + 2)); + ASSERT_TRUE(sh->buffer().len() > (random_len + 2)); + memcpy(crandom1, ch->buffer().data() + 2, random_len); + memcpy(srandom1, sh->buffer().data() + 2, random_len); + EXPECT_NE(0, memcmp(crandom1, srandom1, random_len)); + EXPECT_NE(0, memcmp(crandom1, z, random_len)); + EXPECT_NE(0, memcmp(srandom1, z, random_len)); + + Reset(); + ch = MakeTlsFilter<TlsHandshakeRecorder>(client_, ssl_hs_client_hello); + sh = MakeTlsFilter<TlsHandshakeRecorder>(server_, ssl_hs_server_hello); + Connect(); + ASSERT_TRUE(ch->buffer().len() > (random_len + 2)); + ASSERT_TRUE(sh->buffer().len() > (random_len + 2)); + const uint8_t* crandom2 = ch->buffer().data() + 2; + const uint8_t* srandom2 = sh->buffer().data() + 2; + + EXPECT_NE(0, memcmp(crandom2, srandom2, random_len)); + EXPECT_NE(0, memcmp(crandom2, z, random_len)); + EXPECT_NE(0, memcmp(srandom2, z, random_len)); + + EXPECT_NE(0, memcmp(crandom1, crandom2, random_len)); + EXPECT_NE(0, memcmp(crandom1, srandom2, random_len)); + EXPECT_NE(0, memcmp(srandom1, crandom2, random_len)); + EXPECT_NE(0, memcmp(srandom1, srandom2, random_len)); +} + +void FailOnCloseNotify(const PRFileDesc* fd, void* arg, const SSLAlert* alert) { + ADD_FAILURE() << "received alert " << alert->description; +} + +void CheckCloseNotify(const PRFileDesc* fd, void* arg, const SSLAlert* alert) { + *reinterpret_cast<bool*>(arg) = true; + EXPECT_EQ(close_notify, alert->description); + EXPECT_EQ(alert_warning, alert->level); +} + +TEST_P(TlsConnectGeneric, ShutdownOneSide) { + Connect(); + + // Setup to check alerts. + EXPECT_EQ(SECSuccess, SSL_AlertSentCallback(server_->ssl_fd(), + FailOnCloseNotify, nullptr)); + EXPECT_EQ(SECSuccess, SSL_AlertReceivedCallback(client_->ssl_fd(), + FailOnCloseNotify, nullptr)); + + bool client_sent = false; + EXPECT_EQ(SECSuccess, SSL_AlertSentCallback(client_->ssl_fd(), + CheckCloseNotify, &client_sent)); + bool server_received = false; + EXPECT_EQ(SECSuccess, + SSL_AlertReceivedCallback(server_->ssl_fd(), CheckCloseNotify, + &server_received)); + EXPECT_EQ(PR_SUCCESS, PR_Shutdown(client_->ssl_fd(), PR_SHUTDOWN_SEND)); + + // Make sure that the server reads out the close_notify. + uint8_t buf[10]; + EXPECT_EQ(0, PR_Read(server_->ssl_fd(), buf, sizeof(buf))); + + // Reading and writing should still work in the one open direction. + EXPECT_TRUE(client_sent); + EXPECT_TRUE(server_received); + server_->SendData(10, 10); + client_->ReadBytes(10); + + // Now close the other side and do the same checks. + bool server_sent = false; + EXPECT_EQ(SECSuccess, SSL_AlertSentCallback(server_->ssl_fd(), + CheckCloseNotify, &server_sent)); + bool client_received = false; + EXPECT_EQ(SECSuccess, + SSL_AlertReceivedCallback(client_->ssl_fd(), CheckCloseNotify, + &client_received)); + EXPECT_EQ(PR_SUCCESS, PR_Shutdown(server_->ssl_fd(), PR_SHUTDOWN_SEND)); + + EXPECT_EQ(0, PR_Read(client_->ssl_fd(), buf, sizeof(buf))); + EXPECT_TRUE(server_sent); + EXPECT_TRUE(client_received); +} + +TEST_P(TlsConnectGeneric, ShutdownOneSideThenCloseTcp) { + Connect(); + + bool client_sent = false; + EXPECT_EQ(SECSuccess, SSL_AlertSentCallback(client_->ssl_fd(), + CheckCloseNotify, &client_sent)); + bool server_received = false; + EXPECT_EQ(SECSuccess, + SSL_AlertReceivedCallback(server_->ssl_fd(), CheckCloseNotify, + &server_received)); + EXPECT_EQ(PR_SUCCESS, PR_Shutdown(client_->ssl_fd(), PR_SHUTDOWN_SEND)); + + // Make sure that the server reads out the close_notify. + uint8_t buf[10]; + EXPECT_EQ(0, PR_Read(server_->ssl_fd(), buf, sizeof(buf))); + + // Now simulate the underlying connection closing. + client_->adapter()->Reset(); + + // Now close the other side and see that things don't explode. + EXPECT_EQ(PR_SUCCESS, PR_Shutdown(server_->ssl_fd(), PR_SHUTDOWN_SEND)); + + EXPECT_GT(0, PR_Read(client_->ssl_fd(), buf, sizeof(buf))); + EXPECT_EQ(PR_NOT_CONNECTED_ERROR, PR_GetError()); +} + +INSTANTIATE_TEST_SUITE_P( + GenericStream, TlsConnectGeneric, + ::testing::Combine(TlsConnectTestBase::kTlsVariantsStream, + TlsConnectTestBase::kTlsVAll)); +INSTANTIATE_TEST_SUITE_P( + GenericDatagram, TlsConnectGeneric, + ::testing::Combine(TlsConnectTestBase::kTlsVariantsDatagram, + TlsConnectTestBase::kTlsV11Plus)); + +INSTANTIATE_TEST_SUITE_P(StreamOnly, TlsConnectStream, + TlsConnectTestBase::kTlsVAll); +INSTANTIATE_TEST_SUITE_P(DatagramOnly, TlsConnectDatagram, + TlsConnectTestBase::kTlsV11Plus); +INSTANTIATE_TEST_SUITE_P(DatagramHolddown, TlsHolddownTest, + TlsConnectTestBase::kTlsV11Plus); + +INSTANTIATE_TEST_SUITE_P( + Pre12Stream, TlsConnectPre12, + ::testing::Combine(TlsConnectTestBase::kTlsVariantsStream, + TlsConnectTestBase::kTlsV10V11)); +INSTANTIATE_TEST_SUITE_P( + Pre12Datagram, TlsConnectPre12, + ::testing::Combine(TlsConnectTestBase::kTlsVariantsDatagram, + TlsConnectTestBase::kTlsV11)); + +INSTANTIATE_TEST_SUITE_P(Version12Only, TlsConnectTls12, + TlsConnectTestBase::kTlsVariantsAll); +#ifndef NSS_DISABLE_TLS_1_3 +INSTANTIATE_TEST_SUITE_P(Version13Only, TlsConnectTls13, + TlsConnectTestBase::kTlsVariantsAll); +#endif + +INSTANTIATE_TEST_SUITE_P( + Pre13Stream, TlsConnectGenericPre13, + ::testing::Combine(TlsConnectTestBase::kTlsVariantsStream, + TlsConnectTestBase::kTlsV10ToV12)); +INSTANTIATE_TEST_SUITE_P( + Pre13Datagram, TlsConnectGenericPre13, + ::testing::Combine(TlsConnectTestBase::kTlsVariantsDatagram, + TlsConnectTestBase::kTlsV11V12)); +INSTANTIATE_TEST_SUITE_P(Pre13StreamOnly, TlsConnectStreamPre13, + TlsConnectTestBase::kTlsV10ToV12); + +INSTANTIATE_TEST_SUITE_P(Version12Plus, TlsConnectTls12Plus, + ::testing::Combine(TlsConnectTestBase::kTlsVariantsAll, + TlsConnectTestBase::kTlsV12Plus)); + +INSTANTIATE_TEST_SUITE_P( + GenericStream, TlsConnectGenericResumption, + ::testing::Combine(TlsConnectTestBase::kTlsVariantsStream, + TlsConnectTestBase::kTlsVAll, + ::testing::Values(true, false))); +INSTANTIATE_TEST_SUITE_P( + GenericDatagram, TlsConnectGenericResumption, + ::testing::Combine(TlsConnectTestBase::kTlsVariantsDatagram, + TlsConnectTestBase::kTlsV11Plus, + ::testing::Values(true, false))); + +INSTANTIATE_TEST_SUITE_P( + GenericStream, TlsConnectGenericResumptionToken, + ::testing::Combine(TlsConnectTestBase::kTlsVariantsStream, + TlsConnectTestBase::kTlsVAll)); +INSTANTIATE_TEST_SUITE_P( + GenericDatagram, TlsConnectGenericResumptionToken, + ::testing::Combine(TlsConnectTestBase::kTlsVariantsDatagram, + TlsConnectTestBase::kTlsV11Plus)); + +INSTANTIATE_TEST_SUITE_P(GenericDatagram, TlsConnectTls13ResumptionToken, + TlsConnectTestBase::kTlsVariantsAll); + +} // namespace nss_test diff --git a/security/nss/gtests/ssl_gtest/ssl_masking_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_masking_unittest.cc new file mode 100644 index 0000000000..8209a6e4e0 --- /dev/null +++ b/security/nss/gtests/ssl_gtest/ssl_masking_unittest.cc @@ -0,0 +1,350 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set ts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this file, + * You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#include <memory> + +#include "keyhi.h" +#include "pk11pub.h" +#include "secerr.h" +#include "ssl.h" +#include "sslerr.h" +#include "sslexp.h" +#include "sslproto.h" + +#include "gtest_utils.h" +#include "nss_scoped_ptrs.h" +#include "scoped_ptrs_ssl.h" +#include "tls_connect.h" + +namespace nss_test { + +// From tls_hkdf_unittest.cc: +extern size_t GetHashLength(SSLHashType ht); + +const std::string kLabel = "sn"; + +class MaskingTest : public ::testing::Test { + public: + MaskingTest() : slot_(PK11_GetInternalSlot()) {} + + void InitSecret(SSLHashType hash_type) { + ScopedPK11SlotInfo slot(PK11_GetInternalSlot()); + PK11SymKey *s = PK11_KeyGen(slot_.get(), CKM_GENERIC_SECRET_KEY_GEN, + nullptr, AES_128_KEY_LENGTH, nullptr); + ASSERT_NE(nullptr, s); + secret_.reset(s); + } + + void SetUp() override { + InitSecret(ssl_hash_sha256); + PORT_SetError(0); + } + + protected: + ScopedPK11SymKey secret_; + ScopedPK11SlotInfo slot_; + // Should have 4B ctr, 12B nonce for ChaCha, or >=16B ciphertext for AES. + // Use the same default size for mask output. + static const int kSampleSize = 16; + static const int kMaskSize = 16; + void CreateMask(PRUint16 ciphersuite, SSLProtocolVariant variant, + std::string label, const std::vector<uint8_t> &sample, + std::vector<uint8_t> *out_mask) { + ASSERT_NE(nullptr, out_mask); + SSLMaskingContext *ctx_init = nullptr; + EXPECT_EQ(SECSuccess, + SSL_CreateVariantMaskingContext( + SSL_LIBRARY_VERSION_TLS_1_3, ciphersuite, variant, + secret_.get(), label.c_str(), label.size(), &ctx_init)); + ASSERT_NE(nullptr, ctx_init); + ScopedSSLMaskingContext ctx(ctx_init); + + EXPECT_EQ(SECSuccess, + SSL_CreateMask(ctx.get(), sample.data(), sample.size(), + out_mask->data(), out_mask->size())); + bool all_zeros = std::all_of(out_mask->begin(), out_mask->end(), + [](uint8_t v) { return v == 0; }); + + // If out_mask is short, |all_zeros| will be (expectedly) true often enough + // to fail tests. + // In this case, just retry to make sure we're not outputting zeros + // continuously. + if (all_zeros && out_mask->size() < 3) { + unsigned int tries = 2; + std::vector<uint8_t> tmp_sample = sample; + std::vector<uint8_t> tmp_mask(out_mask->size()); + while (tries--) { + tmp_sample.data()[0]++; // Tweak something to get a new mask. + EXPECT_EQ(SECSuccess, SSL_CreateMask(ctx.get(), tmp_sample.data(), + tmp_sample.size(), tmp_mask.data(), + tmp_mask.size())); + bool retry_zero = std::all_of(tmp_mask.begin(), tmp_mask.end(), + [](uint8_t v) { return v == 0; }); + if (!retry_zero) { + all_zeros = false; + break; + } + } + } + EXPECT_FALSE(all_zeros); + } +}; + +class SuiteTest : public MaskingTest, + public ::testing::WithParamInterface<uint16_t> { + public: + SuiteTest() : ciphersuite_(GetParam()) {} + void CreateMask(std::string label, const std::vector<uint8_t> &sample, + std::vector<uint8_t> *out_mask) { + MaskingTest::CreateMask(ciphersuite_, ssl_variant_datagram, label, sample, + out_mask); + } + + protected: + const uint16_t ciphersuite_; +}; + +class VariantTest : public MaskingTest, + public ::testing::WithParamInterface<SSLProtocolVariant> { + public: + VariantTest() : variant_(GetParam()) {} + void CreateMask(uint16_t ciphersuite, std::string label, + const std::vector<uint8_t> &sample, + std::vector<uint8_t> *out_mask) { + MaskingTest::CreateMask(ciphersuite, variant_, label, sample, out_mask); + } + + protected: + const SSLProtocolVariant variant_; +}; + +class VariantSuiteTest : public MaskingTest, + public ::testing::WithParamInterface< + std::tuple<SSLProtocolVariant, uint16_t>> { + public: + VariantSuiteTest() + : variant_(std::get<0>(GetParam())), + ciphersuite_(std::get<1>(GetParam())) {} + void CreateMask(std::string label, const std::vector<uint8_t> &sample, + std::vector<uint8_t> *out_mask) { + MaskingTest::CreateMask(ciphersuite_, variant_, label, sample, out_mask); + } + + protected: + const SSLProtocolVariant variant_; + const uint16_t ciphersuite_; +}; + +TEST_P(VariantSuiteTest, MaskContextNoLabel) { + std::vector<uint8_t> sample(kSampleSize); + std::vector<uint8_t> mask(kMaskSize); + CreateMask(std::string(""), sample, &mask); +} + +TEST_P(VariantSuiteTest, MaskNoSample) { + std::vector<uint8_t> mask(kMaskSize); + SSLMaskingContext *ctx_init = nullptr; + EXPECT_EQ(SECSuccess, + SSL_CreateVariantMaskingContext( + SSL_LIBRARY_VERSION_TLS_1_3, ciphersuite_, variant_, + secret_.get(), kLabel.c_str(), kLabel.size(), &ctx_init)); + ASSERT_NE(nullptr, ctx_init); + ScopedSSLMaskingContext ctx(ctx_init); + + EXPECT_EQ(SECFailure, + SSL_CreateMask(ctx.get(), nullptr, 0, mask.data(), mask.size())); + EXPECT_EQ(SEC_ERROR_INVALID_ARGS, PORT_GetError()); + + EXPECT_EQ(SECFailure, SSL_CreateMask(ctx.get(), nullptr, mask.size(), + mask.data(), mask.size())); + EXPECT_EQ(SEC_ERROR_INVALID_ARGS, PORT_GetError()); +} + +TEST_P(VariantSuiteTest, MaskShortSample) { + std::vector<uint8_t> sample(kSampleSize); + std::vector<uint8_t> mask(kMaskSize); + SSLMaskingContext *ctx_init = nullptr; + EXPECT_EQ(SECSuccess, + SSL_CreateVariantMaskingContext( + SSL_LIBRARY_VERSION_TLS_1_3, ciphersuite_, variant_, + secret_.get(), kLabel.c_str(), kLabel.size(), &ctx_init)); + ASSERT_NE(nullptr, ctx_init); + ScopedSSLMaskingContext ctx(ctx_init); + + EXPECT_EQ(SECFailure, + SSL_CreateMask(ctx.get(), sample.data(), sample.size() - 1, + mask.data(), mask.size())); + EXPECT_EQ(SEC_ERROR_INVALID_ARGS, PORT_GetError()); +} + +TEST_P(VariantSuiteTest, MaskContextUnsupportedMech) { + std::vector<uint8_t> sample(kSampleSize); + std::vector<uint8_t> mask(kMaskSize); + SSLMaskingContext *ctx_init = nullptr; + EXPECT_EQ(SECFailure, + SSL_CreateVariantMaskingContext( + SSL_LIBRARY_VERSION_TLS_1_3, TLS_RSA_WITH_AES_128_CBC_SHA256, + variant_, secret_.get(), nullptr, 0, &ctx_init)); + EXPECT_EQ(SEC_ERROR_INVALID_ARGS, PORT_GetError()); + EXPECT_EQ(nullptr, ctx_init); +} + +TEST_P(VariantSuiteTest, MaskContextUnsupportedVersion) { + std::vector<uint8_t> sample(kSampleSize); + std::vector<uint8_t> mask(kMaskSize); + SSLMaskingContext *ctx_init = nullptr; + EXPECT_EQ(SECFailure, SSL_CreateVariantMaskingContext( + SSL_LIBRARY_VERSION_TLS_1_2, ciphersuite_, variant_, + secret_.get(), nullptr, 0, &ctx_init)); + EXPECT_EQ(SEC_ERROR_INVALID_ARGS, PORT_GetError()); + EXPECT_EQ(nullptr, ctx_init); +} + +TEST_P(VariantSuiteTest, MaskMaxLength) { + uint32_t max_mask_len = kMaskSize; + if (ciphersuite_ == TLS_CHACHA20_POLY1305_SHA256) { + // Internal limitation for ChaCha20 masks. + max_mask_len = 128; + } + + std::vector<uint8_t> sample(kSampleSize); + std::vector<uint8_t> mask(max_mask_len + 1); + SSLMaskingContext *ctx_init = nullptr; + EXPECT_EQ(SECSuccess, + SSL_CreateVariantMaskingContext( + SSL_LIBRARY_VERSION_TLS_1_3, ciphersuite_, variant_, + secret_.get(), kLabel.c_str(), kLabel.size(), &ctx_init)); + ASSERT_NE(nullptr, ctx_init); + ScopedSSLMaskingContext ctx(ctx_init); + + EXPECT_EQ(SECSuccess, SSL_CreateMask(ctx.get(), sample.data(), sample.size(), + mask.data(), mask.size() - 1)); + EXPECT_EQ(SECFailure, SSL_CreateMask(ctx.get(), sample.data(), sample.size(), + mask.data(), mask.size())); + EXPECT_EQ(SEC_ERROR_OUTPUT_LEN, PORT_GetError()); +} + +TEST_P(VariantSuiteTest, MaskMinLength) { + std::vector<uint8_t> sample(kSampleSize); + std::vector<uint8_t> mask(1); // Don't pass a null + + SSLMaskingContext *ctx_init = nullptr; + EXPECT_EQ(SECSuccess, + SSL_CreateVariantMaskingContext( + SSL_LIBRARY_VERSION_TLS_1_3, ciphersuite_, variant_, + secret_.get(), kLabel.c_str(), kLabel.size(), &ctx_init)); + ASSERT_NE(nullptr, ctx_init); + ScopedSSLMaskingContext ctx(ctx_init); + EXPECT_EQ(SECFailure, SSL_CreateMask(ctx.get(), sample.data(), sample.size(), + mask.data(), 0)); + EXPECT_EQ(SEC_ERROR_INVALID_ARGS, PORT_GetError()); + EXPECT_EQ(SECSuccess, SSL_CreateMask(ctx.get(), sample.data(), sample.size(), + mask.data(), 1)); +} + +TEST_P(VariantSuiteTest, MaskRotateLabel) { + std::vector<uint8_t> sample(kSampleSize); + std::vector<uint8_t> mask1(kMaskSize); + std::vector<uint8_t> mask2(kMaskSize); + EXPECT_EQ(SECSuccess, PK11_GenerateRandomOnSlot(slot_.get(), sample.data(), + sample.size())); + + CreateMask(kLabel, sample, &mask1); + CreateMask(std::string("sn1"), sample, &mask2); + EXPECT_FALSE(mask1 == mask2); +} + +TEST_P(VariantSuiteTest, MaskRotateSample) { + std::vector<uint8_t> sample(kSampleSize); + std::vector<uint8_t> mask1(kMaskSize); + std::vector<uint8_t> mask2(kMaskSize); + + EXPECT_EQ(SECSuccess, PK11_GenerateRandomOnSlot(slot_.get(), sample.data(), + sample.size())); + CreateMask(kLabel, sample, &mask1); + + EXPECT_EQ(SECSuccess, PK11_GenerateRandomOnSlot(slot_.get(), sample.data(), + sample.size())); + CreateMask(kLabel, sample, &mask2); + EXPECT_FALSE(mask1 == mask2); +} + +TEST_P(VariantSuiteTest, MaskRederive) { + std::vector<uint8_t> sample(kSampleSize); + std::vector<uint8_t> mask1(kMaskSize); + std::vector<uint8_t> mask2(kMaskSize); + + SECStatus rv = + PK11_GenerateRandomOnSlot(slot_.get(), sample.data(), sample.size()); + EXPECT_EQ(SECSuccess, rv); + + // Check that re-using inputs with a new context produces the same mask. + CreateMask(kLabel, sample, &mask1); + CreateMask(kLabel, sample, &mask2); + EXPECT_TRUE(mask1 == mask2); +} + +TEST_P(SuiteTest, MaskTlsVariantKeySeparation) { + std::vector<uint8_t> sample(kSampleSize); + std::vector<uint8_t> tls_mask(kMaskSize); + std::vector<uint8_t> dtls_mask(kMaskSize); + SSLMaskingContext *stream_ctx_init = nullptr; + SSLMaskingContext *datagram_ctx_init = nullptr; + + // Init + EXPECT_EQ(SECSuccess, SSL_CreateVariantMaskingContext( + SSL_LIBRARY_VERSION_TLS_1_3, ciphersuite_, + ssl_variant_stream, secret_.get(), kLabel.c_str(), + kLabel.size(), &stream_ctx_init)); + ASSERT_NE(nullptr, stream_ctx_init); + EXPECT_EQ(SECSuccess, SSL_CreateVariantMaskingContext( + SSL_LIBRARY_VERSION_TLS_1_3, ciphersuite_, + ssl_variant_datagram, secret_.get(), kLabel.c_str(), + kLabel.size(), &datagram_ctx_init)); + ASSERT_NE(nullptr, datagram_ctx_init); + ScopedSSLMaskingContext tls_ctx(stream_ctx_init); + ScopedSSLMaskingContext dtls_ctx(datagram_ctx_init); + + // Derive + EXPECT_EQ(SECSuccess, + SSL_CreateMask(tls_ctx.get(), sample.data(), sample.size(), + tls_mask.data(), tls_mask.size())); + + EXPECT_EQ(SECSuccess, + SSL_CreateMask(dtls_ctx.get(), sample.data(), sample.size(), + dtls_mask.data(), dtls_mask.size())); + EXPECT_NE(tls_mask, dtls_mask); +} + +TEST_P(VariantTest, MaskChaChaRederiveOddSizes) { + // Non-block-aligned. + std::vector<uint8_t> sample(27); + std::vector<uint8_t> mask1(26); + std::vector<uint8_t> mask2(25); + EXPECT_EQ(SECSuccess, PK11_GenerateRandomOnSlot(slot_.get(), sample.data(), + sample.size())); + CreateMask(TLS_CHACHA20_POLY1305_SHA256, kLabel, sample, &mask1); + CreateMask(TLS_CHACHA20_POLY1305_SHA256, kLabel, sample, &mask2); + mask1.pop_back(); + EXPECT_TRUE(mask1 == mask2); +} + +static const uint16_t kMaskingCiphersuites[] = {TLS_CHACHA20_POLY1305_SHA256, + TLS_AES_128_GCM_SHA256, + TLS_AES_256_GCM_SHA384}; +::testing::internal::ParamGenerator<uint16_t> kMaskingCiphersuiteParams = + ::testing::ValuesIn(kMaskingCiphersuites); + +INSTANTIATE_TEST_SUITE_P(GenericMasking, SuiteTest, kMaskingCiphersuiteParams); + +INSTANTIATE_TEST_SUITE_P(GenericMasking, VariantTest, + TlsConnectTestBase::kTlsVariantsAll); + +INSTANTIATE_TEST_SUITE_P(GenericMasking, VariantSuiteTest, + ::testing::Combine(TlsConnectTestBase::kTlsVariantsAll, + kMaskingCiphersuiteParams)); + +} // namespace nss_test diff --git a/security/nss/gtests/ssl_gtest/ssl_misc_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_misc_unittest.cc new file mode 100644 index 0000000000..2b1b92dcd8 --- /dev/null +++ b/security/nss/gtests/ssl_gtest/ssl_misc_unittest.cc @@ -0,0 +1,20 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set ts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this file, + * You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#include "sslexp.h" + +#include "gtest_utils.h" + +namespace nss_test { + +class MiscTest : public ::testing::Test {}; + +TEST_F(MiscTest, NonExistentExperimentalAPI) { + EXPECT_EQ(nullptr, SSL_GetExperimentalAPI("blah")); + EXPECT_EQ(SSL_ERROR_UNSUPPORTED_EXPERIMENTAL_API, PORT_GetError()); +} + +} // namespace nss_test diff --git a/security/nss/gtests/ssl_gtest/ssl_record_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_record_unittest.cc new file mode 100644 index 0000000000..5378d67af8 --- /dev/null +++ b/security/nss/gtests/ssl_gtest/ssl_record_unittest.cc @@ -0,0 +1,826 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set ts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this file, + * You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#include "nss.h" +#include "ssl.h" +#include "sslimpl.h" + +#include "databuffer.h" +#include "gtest_utils.h" +#include "tls_connect.h" +#include "tls_filter.h" + +namespace nss_test { + +const static size_t kMacSize = 20; + +class TlsPaddingTest + : public ::testing::Test, + public ::testing::WithParamInterface<std::tuple<size_t, bool>> { + public: + TlsPaddingTest() : plaintext_len_(std::get<0>(GetParam())) { + size_t extra = + (plaintext_len_ + 1) % 16; // Bytes past a block (1 == pad len) + // Minimal padding. + pad_len_ = extra ? 16 - extra : 0; + if (std::get<1>(GetParam())) { + // Maximal padding. + pad_len_ += 240; + } + MakePaddedPlaintext(); + } + + // Makes a plaintext record with correct padding. + void MakePaddedPlaintext() { + EXPECT_EQ(0UL, (plaintext_len_ + pad_len_ + 1) % 16); + size_t i = 0; + plaintext_.Allocate(plaintext_len_ + pad_len_ + 1); + for (; i < plaintext_len_; ++i) { + plaintext_.Write(i, 'A', 1); + } + + for (; i < plaintext_len_ + pad_len_ + 1; ++i) { + plaintext_.Write(i, pad_len_, 1); + } + } + + void Unpad(bool expect_success) { + std::cerr << "Content length=" << plaintext_len_ + << " padding length=" << pad_len_ + << " total length=" << plaintext_.len() << std::endl; + std::cerr << "Plaintext: " << plaintext_ << std::endl; + sslBuffer s; + s.buf = const_cast<unsigned char*>( + static_cast<const unsigned char*>(plaintext_.data())); + s.len = plaintext_.len(); + SECStatus rv = ssl_RemoveTLSCBCPadding(&s, kMacSize); + if (expect_success) { + EXPECT_EQ(SECSuccess, rv); + EXPECT_EQ(plaintext_len_, static_cast<size_t>(s.len)); + } else { + EXPECT_EQ(SECFailure, rv); + } + } + + protected: + size_t plaintext_len_; + size_t pad_len_; + DataBuffer plaintext_; +}; + +TEST_P(TlsPaddingTest, Correct) { + if (plaintext_len_ >= kMacSize) { + Unpad(true); + } else { + Unpad(false); + } +} + +TEST_P(TlsPaddingTest, PadTooLong) { + if (plaintext_.len() < 255) { + plaintext_.Write(plaintext_.len() - 1, plaintext_.len(), 1); + Unpad(false); + } +} + +TEST_P(TlsPaddingTest, FirstByteOfPadWrong) { + if (pad_len_) { + plaintext_.Write(plaintext_len_, plaintext_.data()[plaintext_len_] + 1, 1); + Unpad(false); + } +} + +TEST_P(TlsPaddingTest, LastByteOfPadWrong) { + if (pad_len_) { + plaintext_.Write(plaintext_.len() - 2, + plaintext_.data()[plaintext_.len() - 1] + 1, 1); + Unpad(false); + } +} + +class RecordReplacer : public TlsRecordFilter { + public: + RecordReplacer(const std::shared_ptr<TlsAgent>& a, size_t size) + : TlsRecordFilter(a), size_(size) { + Disable(); + } + + PacketFilter::Action FilterRecord(const TlsRecordHeader& header, + const DataBuffer& data, + DataBuffer* changed) override { + EXPECT_EQ(ssl_ct_application_data, header.content_type()); + changed->Allocate(size_); + + for (size_t i = 0; i < size_; ++i) { + changed->data()[i] = i & 0xff; + } + + Disable(); + return CHANGE; + } + + private: + size_t size_; +}; + +TEST_P(TlsConnectStream, BadRecordMac) { + EnsureTlsSetup(); + Connect(); + client_->SetFilter(std::make_shared<TlsRecordLastByteDamager>(client_)); + ExpectAlert(server_, kTlsAlertBadRecordMac); + client_->SendData(10); + + // Read from the client, get error. + uint8_t buf[10]; + PRInt32 rv = PR_Read(server_->ssl_fd(), buf, sizeof(buf)); + EXPECT_GT(0, rv); + EXPECT_EQ(SSL_ERROR_BAD_MAC_READ, PORT_GetError()); + + // Read the server alert. + rv = PR_Read(client_->ssl_fd(), buf, sizeof(buf)); + EXPECT_GT(0, rv); + EXPECT_EQ(SSL_ERROR_BAD_MAC_ALERT, PORT_GetError()); +} + +TEST_F(TlsConnectStreamTls13, LargeRecord) { + EnsureTlsSetup(); + + const size_t record_limit = 16384; + auto replacer = MakeTlsFilter<RecordReplacer>(client_, record_limit); + replacer->EnableDecryption(); + Connect(); + + replacer->Enable(); + client_->SendData(10); + WAIT_(server_->received_bytes() == record_limit, 2000); + ASSERT_EQ(record_limit, server_->received_bytes()); +} + +TEST_F(TlsConnectStreamTls13, TooLargeRecord) { + EnsureTlsSetup(); + + const size_t record_limit = 16384; + auto replacer = MakeTlsFilter<RecordReplacer>(client_, record_limit + 1); + replacer->EnableDecryption(); + Connect(); + + replacer->Enable(); + ExpectAlert(server_, kTlsAlertRecordOverflow); + client_->SendData(10); // This is expanded. + + uint8_t buf[record_limit + 2]; + PRInt32 rv = PR_Read(server_->ssl_fd(), buf, sizeof(buf)); + EXPECT_GT(0, rv); + EXPECT_EQ(SSL_ERROR_RX_RECORD_TOO_LONG, PORT_GetError()); + + // Read the server alert. + rv = PR_Read(client_->ssl_fd(), buf, sizeof(buf)); + EXPECT_GT(0, rv); + EXPECT_EQ(SSL_ERROR_RECORD_OVERFLOW_ALERT, PORT_GetError()); +} + +class ShortHeaderChecker : public PacketFilter { + public: + PacketFilter::Action Filter(const DataBuffer& input, DataBuffer* output) { + // The first octet should be 0b001000xx. + EXPECT_EQ(kCtDtlsCiphertext, (input.data()[0] & ~0x3)); + return KEEP; + } +}; + +TEST_F(TlsConnectDatagram13, AeadLimit) { + Connect(); + EXPECT_EQ(SECSuccess, SSLInt_AdvanceDtls13DecryptFailures(server_->ssl_fd(), + (1ULL << 36) - 2)); + SendReceive(50); + + // Expect this to increment the counter. We should still be able to talk. + client_->SetFilter(std::make_shared<TlsRecordLastByteDamager>(client_)); + client_->SendData(10); + server_->ReadBytes(10); + client_->ClearFilter(); + client_->ResetSentBytes(50); + SendReceive(60); + + // Expect alert when the limit is hit. + client_->SetFilter(std::make_shared<TlsRecordLastByteDamager>(client_)); + client_->SendData(10); + ExpectAlert(server_, kTlsAlertBadRecordMac); + + // Check the error on both endpoints. + uint8_t buf[10]; + PRInt32 rv = PR_Read(server_->ssl_fd(), buf, sizeof(buf)); + EXPECT_EQ(-1, rv); + EXPECT_EQ(SSL_ERROR_BAD_MAC_READ, PORT_GetError()); + + rv = PR_Read(client_->ssl_fd(), buf, sizeof(buf)); + EXPECT_EQ(-1, rv); + EXPECT_EQ(SSL_ERROR_BAD_MAC_ALERT, PORT_GetError()); +} + +TEST_F(TlsConnectDatagram13, ShortHeadersClient) { + Connect(); + client_->SetOption(SSL_ENABLE_DTLS_SHORT_HEADER, PR_TRUE); + client_->SetFilter(std::make_shared<ShortHeaderChecker>()); + SendReceive(); +} + +TEST_F(TlsConnectDatagram13, ShortHeadersServer) { + Connect(); + server_->SetOption(SSL_ENABLE_DTLS_SHORT_HEADER, PR_TRUE); + server_->SetFilter(std::make_shared<ShortHeaderChecker>()); + SendReceive(); +} + +// Send a DTLSCiphertext header with a 2B sequence number, and no length. +TEST_F(TlsConnectDatagram13, DtlsAlternateShortHeader) { + StartConnect(); + TlsSendCipherSpecCapturer capturer(client_); + Connect(); + SendReceive(50); + + uint8_t buf[] = {0x32, 0x33, 0x34}; + auto spec = capturer.spec(1); + ASSERT_NE(nullptr, spec.get()); + ASSERT_EQ(3, spec->epoch()); + + uint8_t dtls13_ct = kCtDtlsCiphertext | kCtDtlsCiphertext16bSeqno; + TlsRecordHeader header(variant_, SSL_LIBRARY_VERSION_TLS_1_3, dtls13_ct, + 0x0003000000000001); + TlsRecordHeader out_header(header); + DataBuffer msg(buf, sizeof(buf)); + msg.Write(msg.len(), ssl_ct_application_data, 1); + DataBuffer ciphertext; + EXPECT_TRUE(spec->Protect(header, msg, &ciphertext, &out_header)); + + DataBuffer record; + auto rv = out_header.Write(&record, 0, ciphertext); + EXPECT_EQ(out_header.header_length() + ciphertext.len(), rv); + client_->SendDirect(record); + + server_->ReadBytes(3); +} + +TEST_F(TlsConnectStreamTls13, UnencryptedFinishedMessage) { + StartConnect(); + client_->Handshake(); // Send ClientHello + server_->Handshake(); // Send first server flight + + // Record and drop the first record, which is the Finished. + auto recorder = std::make_shared<TlsRecordRecorder>(client_); + recorder->EnableDecryption(); + auto dropper = std::make_shared<SelectiveDropFilter>(1); + client_->SetFilter(std::make_shared<ChainedPacketFilter>( + ChainedPacketFilterInit({recorder, dropper}))); + client_->Handshake(); // Save and drop CFIN. + EXPECT_EQ(TlsAgent::STATE_CONNECTED, client_->state()); + + ASSERT_EQ(1U, recorder->count()); + auto& finished = recorder->record(0); + + DataBuffer d; + size_t offset = d.Write(0, ssl_ct_handshake, 1); + offset = d.Write(offset, SSL_LIBRARY_VERSION_TLS_1_2, 2); + offset = d.Write(offset, finished.buffer.len(), 2); + d.Append(finished.buffer); + client_->SendDirect(d); + + // Now process the message. + ExpectAlert(server_, kTlsAlertUnexpectedMessage); + // The server should generate an alert. + server_->Handshake(); + EXPECT_EQ(TlsAgent::STATE_ERROR, server_->state()); + server_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_RECORD_TYPE); + // Have the client consume the alert. + client_->Handshake(); + EXPECT_EQ(TlsAgent::STATE_ERROR, client_->state()); + client_->CheckErrorCode(SSL_ERROR_HANDSHAKE_UNEXPECTED_ALERT); +} + +const static size_t kContentSizesArr[] = { + 1, kMacSize - 1, kMacSize, 30, 31, 32, 36, 256, 257, 287, 288}; + +auto kContentSizes = ::testing::ValuesIn(kContentSizesArr); +const static bool kTrueFalseArr[] = {true, false}; +auto kTrueFalse = ::testing::ValuesIn(kTrueFalseArr); + +INSTANTIATE_TEST_SUITE_P(TlsPadding, TlsPaddingTest, + ::testing::Combine(kContentSizes, kTrueFalse)); + +/* Filter to modify record header and content */ +class Tls13RecordModifier : public TlsRecordFilter { + public: + Tls13RecordModifier(const std::shared_ptr<TlsAgent>& a, + uint8_t contentType = ssl_ct_handshake, size_t size = 0, + size_t padding = 0) + : TlsRecordFilter(a), + contentType_(contentType), + size_(size), + padding_(padding) {} + + protected: + PacketFilter::Action FilterRecord(const TlsRecordHeader& header, + const DataBuffer& record, size_t* offset, + DataBuffer* output) override { + if (!header.is_protected()) { + return KEEP; + } + + uint16_t protection_epoch; + uint8_t inner_content_type; + DataBuffer plaintext; + TlsRecordHeader out_header; + if (!Unprotect(header, record, &protection_epoch, &inner_content_type, + &plaintext, &out_header)) { + return KEEP; + } + + if (decrypting() && inner_content_type != ssl_ct_application_data) { + return KEEP; + } + + DataBuffer ciphertext; + bool ok = Protect(spec(protection_epoch), out_header, contentType_, + DataBuffer(size_), &ciphertext, &out_header, padding_); + EXPECT_TRUE(ok); + if (!ok) { + return KEEP; + } + + *offset = out_header.Write(output, *offset, ciphertext); + return CHANGE; + } + + private: + uint8_t contentType_; + size_t size_; + size_t padding_; +}; + +/* Zero-length InnerPlaintext test class + * + * Parameter = Tuple of: + * - TLS variant (datagram/stream) + * - Content type to be set in zero-length inner plaintext record + * - Padding of record plaintext + */ +class ZeroLengthInnerPlaintextSetupTls13 + : public TlsConnectTestBase, + public testing::WithParamInterface< + std::tuple<SSLProtocolVariant, SSLContentType, size_t>> { + public: + ZeroLengthInnerPlaintextSetupTls13() + : TlsConnectTestBase(std::get<0>(GetParam()), + SSL_LIBRARY_VERSION_TLS_1_3), + contentType_(std::get<1>(GetParam())), + padding_(std::get<2>(GetParam())){}; + + protected: + SSLContentType contentType_; + size_t padding_; +}; + +/* Test correct rejection of TLS 1.3 encrypted handshake/alert records with + * zero-length inner plaintext content length with and without padding. + * + * Implementations MUST NOT send Handshake and Alert records that have a + * zero-length TLSInnerPlaintext.content; if such a message is received, + * the receiving implementation MUST terminate the connection with an + * "unexpected_message" alert [RFC8446, Section 5.4]. */ +TEST_P(ZeroLengthInnerPlaintextSetupTls13, ZeroLengthInnerPlaintextRun) { + EnsureTlsSetup(); + + // Filter modifies record to be zero-length + auto filter = + MakeTlsFilter<Tls13RecordModifier>(client_, contentType_, 0, padding_); + filter->EnableDecryption(); + filter->Disable(); + + Connect(); + + filter->Enable(); + + // Record will be overwritten + client_->SendData(0xf); + + // Receive corrupt record + if (variant_ == ssl_variant_stream) { + server_->ExpectSendAlert(kTlsAlertUnexpectedMessage); + // 22B = 16B MAC + 1B innerContentType + 5B Header + server_->ReadBytes(22); + // Process alert at peer + client_->ExpectReceiveAlert(kTlsAlertUnexpectedMessage); + client_->Handshake(); + } else { /* DTLS */ + size_t received = server_->received_bytes(); + // 22B = 16B MAC + 1B innerContentType + 5B Header + server_->ReadBytes(22); + // Check that no bytes were received => packet was dropped + ASSERT_EQ(received, server_->received_bytes()); + // Check that we are still connected / not in error state + EXPECT_EQ(TlsAgent::STATE_CONNECTED, client_->state()); + EXPECT_EQ(TlsAgent::STATE_CONNECTED, server_->state()); + } +} + +// Test for TLS and DTLS +const SSLProtocolVariant kZeroLengthInnerPlaintextVariants[] = { + ssl_variant_stream, ssl_variant_datagram}; +// Test for handshake and alert fragments +const SSLContentType kZeroLengthInnerPlaintextContentTypes[] = { + ssl_ct_handshake, ssl_ct_alert}; +// Test with 0,1 and 100 octets of padding +const size_t kZeroLengthInnerPlaintextPadding[] = {0, 1, 100}; + +INSTANTIATE_TEST_SUITE_P( + ZeroLengthInnerPlaintextTest, ZeroLengthInnerPlaintextSetupTls13, + testing::Combine(testing::ValuesIn(kZeroLengthInnerPlaintextVariants), + testing::ValuesIn(kZeroLengthInnerPlaintextContentTypes), + testing::ValuesIn(kZeroLengthInnerPlaintextPadding)), + [](const testing::TestParamInfo< + ZeroLengthInnerPlaintextSetupTls13::ParamType>& inf) { + return std::string(std::get<0>(inf.param) == ssl_variant_stream + ? "Tls" + : "Dtls") + + "ZeroLengthInnerPlaintext" + + (std::get<1>(inf.param) == ssl_ct_handshake ? "Handshake" + : "Alert") + + (std::get<2>(inf.param) + ? "Padding" + std::to_string(std::get<2>(inf.param)) + "B" + : "") + + "Test"; + }); + +/* Zero-length record test class + * + * Parameter = Tuple of: + * - TLS variant (datagram/stream) + * - TLS version + * - Content type to be set in zero-length record + */ +class ZeroLengthRecordSetup + : public TlsConnectTestBase, + public testing::WithParamInterface< + std::tuple<SSLProtocolVariant, uint16_t, SSLContentType>> { + public: + ZeroLengthRecordSetup() + : TlsConnectTestBase(std::get<0>(GetParam()), std::get<1>(GetParam())), + variant_(std::get<0>(GetParam())), + contentType_(std::get<2>(GetParam())){}; + + void createZeroLengthRecord(DataBuffer& buffer, unsigned epoch = 0, + unsigned seqn = 0) { + size_t idx = 0; + // Set header content type + idx = buffer.Write(idx, contentType_, 1); + // The record version is not checked during record layer handling + idx = buffer.Write(idx, 0xDEAD, 2); + // DTLS (version always < TLS 1.3) + if (variant_ == ssl_variant_datagram) { + // Set epoch (Should be 0 before handshake) + idx = buffer.Write(idx, 0U, 2); + // Set 6B sequence number (0 if send as first message) + idx = buffer.Write(idx, 0U, 2); + idx = buffer.Write(idx, 0U, 4); + } + // Set fragment to be of zero-length + (void)buffer.Write(idx, 0U, 2); + } + + protected: + SSLProtocolVariant variant_; + SSLContentType contentType_; +}; + +/* Test handling of zero-length (ciphertext/fragment) records before handshake. + * + * This is only tested before the first handshake, since after it all of these + * messages are expected to be encrypted which is impossible for a content + * length of zero, always leading to a bad record mac. For TLS 1.3 only + * records of application data content type is legal after the handshake. + * + * Handshake records of length zero will be ignored in the record layer since + * the RFC does only specify that such records MUST NOT be sent but it does not + * state that an alert should be sent or the connection be terminated + * [RFC8446, Section 5.1]. + * + * Even though only handshake messages are handled (ignored) in the record + * layer handling, this test covers zero-length records of all content types + * for complete coverage of cases. + * + * !!! Expected TLS (Stream) behavior !!! + * - Handshake records of zero length are ignored. + * - Alert and ChangeCipherSpec records of zero-length lead to illegal + * parameter alerts due to the malformed record content. + * - ApplicationData before the handshake leads to an unexpected message alert. + * + * !!! Expected DTLS (Datagram) behavior !!! + * - Handshake message of zero length are ignored. + * - Alert messages lead to an illegal parameter alert due to malformed record + * content. + * - ChangeCipherSpec records before the first handshake are not expected and + * ignored (see ssl3con.c, line 3276). + * - ApplicationData before the handshake is ignored since it could be a packet + * received in incorrect order (see ssl3con.c, line 13353). + */ +TEST_P(ZeroLengthRecordSetup, ZeroLengthRecordRun) { + EnsureTlsSetup(); + + // Send zero-length record + DataBuffer buffer; + createZeroLengthRecord(buffer); + client_->SendDirect(buffer); + // This must be set, otherwise handshake completness assertions might fail + server_->StartConnect(); + + SSLAlertDescription alert = close_notify; + + switch (variant_) { + case ssl_variant_datagram: + switch (contentType_) { + case ssl_ct_alert: + // Should actually be ignored, see bug 1829391. + alert = illegal_parameter; + break; + case ssl_ct_ack: + if (version_ == SSL_LIBRARY_VERSION_TLS_1_3) { + // Skipped due to bug 1829391. + GTEST_SKIP(); + } + // DTLS versions < 1.3 correctly ignore the invalid record + // so we fall through. + case ssl_ct_change_cipher_spec: + case ssl_ct_application_data: + case ssl_ct_handshake: + server_->Handshake(); + Connect(); + return; + } + break; + case ssl_variant_stream: + switch (contentType_) { + case ssl_ct_alert: + case ssl_ct_change_cipher_spec: + alert = illegal_parameter; + break; + case ssl_ct_application_data: + case ssl_ct_ack: + alert = unexpected_message; + break; + case ssl_ct_handshake: + // TLS ignores unprotected zero-length handshake records + server_->Handshake(); + Connect(); + return; + } + break; + } + + // Assert alert is send for TLS and DTLS alert records + server_->ExpectSendAlert(alert); + server_->Handshake(); + + // Consume alert at peer, expect alert for TLS and DTLS alert records + client_->StartConnect(); + client_->ExpectReceiveAlert(alert); + client_->Handshake(); +} + +// Test for handshake, alert, change_cipher_spec and application data fragments +const SSLContentType kZeroLengthRecordContentTypes[] = { + ssl_ct_handshake, ssl_ct_alert, ssl_ct_change_cipher_spec, + ssl_ct_application_data, ssl_ct_ack}; + +INSTANTIATE_TEST_SUITE_P( + ZeroLengthRecordTest, ZeroLengthRecordSetup, + testing::Combine(TlsConnectTestBase::kTlsVariantsAll, + TlsConnectTestBase::kTlsV11Plus, + testing::ValuesIn(kZeroLengthRecordContentTypes)), + [](const testing::TestParamInfo<ZeroLengthRecordSetup::ParamType>& inf) { + std::string variant = + (std::get<0>(inf.param) == ssl_variant_stream) ? "Tls" : "Dtls"; + std::string version = VersionString(std::get<1>(inf.param)); + std::replace(version.begin(), version.end(), '.', '_'); + std::string contentType; + switch (std::get<2>(inf.param)) { + case ssl_ct_handshake: + contentType = "Handshake"; + break; + case ssl_ct_alert: + contentType = "Alert"; + break; + case ssl_ct_application_data: + contentType = "ApplicationData"; + break; + case ssl_ct_change_cipher_spec: + contentType = "ChangeCipherSpec"; + break; + case ssl_ct_ack: + contentType = "Ack"; + break; + } + return variant + version + "ZeroLength" + contentType + "Test"; + }); + +/* Test correct handling of records with invalid content types. + * + * TLS: + * If a TLS implementation receives an unexpected record type, it MUST + * terminate the connection with an "unexpected_message" alert + * [RFC8446, Section 5]. + * + * DTLS: + * In general, invalid records SHOULD be silently discarded... + * [RFC6347, Section 4.1.2.7]. */ +class UndefinedContentTypeSetup : public TlsConnectGeneric { + public: + UndefinedContentTypeSetup() : TlsConnectGeneric() { StartConnect(); }; + + void createUndefinedContentTypeRecord(DataBuffer& buffer, unsigned epoch = 0, + unsigned seqn = 0) { + // dummy data + uint8_t data[] = {0xAA, 0xBB, 0xCC, 0xDD, 0xEE}; + + size_t idx = 0; + // Set undefined content type + idx = buffer.Write(idx, 0xFF, 1); + // The record version is not checked during record layer handling + idx = buffer.Write(idx, 0xDEAD, 2); + // DTLS (version always < TLS 1.3) + if (variant_ == ssl_variant_datagram) { + // Set epoch (Should be 0 before/during handshake) + idx = buffer.Write(idx, epoch, 2); + // Set 6B sequence number (0 if send as first message) + idx = buffer.Write(idx, 0U, 2); + idx = buffer.Write(idx, seqn, 4); + } + // Set fragment length + idx = buffer.Write(idx, 5U, 2); + // Add data to record + (void)buffer.Write(idx, data, 5); + } + + void checkUndefinedContentTypeHandling(std::shared_ptr<TlsAgent> sender, + std::shared_ptr<TlsAgent> receiver) { + if (variant_ == ssl_variant_stream) { + // Handle record and expect alert to be sent + receiver->ExpectSendAlert(kTlsAlertUnexpectedMessage); + receiver->ReadBytes(); + /* Digest and assert that the correct alert was received at peer + * + * The 1.3 server expects all messages other than the ClientHello to be + * encrypted and responds with an unexpected message alert to alerts. */ + if (version_ >= SSL_LIBRARY_VERSION_TLS_1_3 && sender == server_) { + sender->ExpectSendAlert(kTlsAlertUnexpectedMessage); + } else { + sender->ExpectReceiveAlert(kTlsAlertUnexpectedMessage); + } + sender->ReadBytes(); + } else { // DTLS drops invalid records silently + size_t received = receiver->received_bytes(); + receiver->ReadBytes(); + // Ensure no bytes were received/record was dropped + ASSERT_EQ(received, receiver->received_bytes()); + } + } + + protected: + DataBuffer buffer_; +}; + +INSTANTIATE_TEST_SUITE_P( + UndefinedContentTypePreHandshakeStream, UndefinedContentTypeSetup, + ::testing::Combine(TlsConnectTestBase::kTlsVariantsStream, + TlsConnectTestBase::kTlsVAll)); + +INSTANTIATE_TEST_SUITE_P( + UndefinedContentTypePreHandshakeDatagram, UndefinedContentTypeSetup, + ::testing::Combine(TlsConnectTestBase::kTlsVariantsDatagram, + TlsConnectTestBase::kTlsV11Plus)); + +TEST_P(UndefinedContentTypeSetup, + ServerReceiveUndefinedContentTypePreClientHello) { + createUndefinedContentTypeRecord(buffer_); + + // Send undefined content type record + client_->SendDirect(buffer_); + + checkUndefinedContentTypeHandling(client_, server_); +} + +TEST_P(UndefinedContentTypeSetup, + ServerReceiveUndefinedContentTypePostClientHello) { + // Set epoch to 0 (handshake), and sequence number to 1 since hello is sent + createUndefinedContentTypeRecord(buffer_, 0, 1); + + // Send ClientHello + client_->Handshake(); + // Send undefined content type record + client_->SendDirect(buffer_); + + checkUndefinedContentTypeHandling(client_, server_); +} + +TEST_P(UndefinedContentTypeSetup, + ClientReceiveUndefinedContentTypePreClientHello) { + createUndefinedContentTypeRecord(buffer_); + + // Send undefined content type record + server_->SendDirect(buffer_); + + checkUndefinedContentTypeHandling(server_, client_); +} + +TEST_P(UndefinedContentTypeSetup, + ClientReceiveUndefinedContentTypePostClientHello) { + // Set epoch to 0 (handshake), and sequence number to 1 since hello is sent + createUndefinedContentTypeRecord(buffer_, 0, 1); + + // Send ClientHello + client_->Handshake(); + // Send undefined content type record + server_->SendDirect(buffer_); + + checkUndefinedContentTypeHandling(server_, client_); +} + +class RecordOuterContentTypeSetter : public TlsRecordFilter { + public: + RecordOuterContentTypeSetter(const std::shared_ptr<TlsAgent>& a, + uint8_t contentType = ssl_ct_handshake) + : TlsRecordFilter(a), contentType_(contentType) {} + + protected: + PacketFilter::Action FilterRecord(const TlsRecordHeader& header, + const DataBuffer& record, size_t* offset, + DataBuffer* output) override { + TlsRecordHeader hdr(header.variant(), header.version(), contentType_, + header.sequence_number()); + + *offset = hdr.Write(output, *offset, record); + return CHANGE; + } + + private: + uint8_t contentType_; +}; + +/* Test correct handling of invalid inner and outer record content type. + * This is only possible for TLS 1.3, since only for this version decryption + * and encryption of manipulated records is supported by the test suite. */ +TEST_P(TlsConnectTls13, UndefinedOuterContentType13) { + EnsureTlsSetup(); + Connect(); + + // Manipulate record: set invalid content type 0xff + MakeTlsFilter<RecordOuterContentTypeSetter>(client_, 0xff); + client_->SendData(50); + + if (variant_ == ssl_variant_stream) { + // Handle invalid record + server_->ExpectSendAlert(kTlsAlertUnexpectedMessage); + server_->ReadBytes(); + // Handle alert at peer + client_->ExpectReceiveAlert(kTlsAlertUnexpectedMessage); + client_->ReadBytes(); + } else { + // Make sure DTLS drops invalid record silently + size_t received = server_->received_bytes(); + server_->ReadBytes(); + ASSERT_EQ(received, server_->received_bytes()); + } +} + +TEST_P(TlsConnectTls13, UndefinedInnerContentType13) { + EnsureTlsSetup(); + + // Manipulate record: set invalid content type 0xff and length to 50. + auto filter = MakeTlsFilter<Tls13RecordModifier>(client_, 0xff, 50, 0); + filter->EnableDecryption(); + filter->Disable(); + + Connect(); + + filter->Enable(); + // Send manipulate record with invalid content type + client_->SendData(50); + + if (variant_ == ssl_variant_stream) { + // Handle invalid record + server_->ExpectSendAlert(kTlsAlertUnexpectedMessage); + server_->ReadBytes(); + // Handle alert at peer + client_->ExpectReceiveAlert(kTlsAlertUnexpectedMessage); + client_->ReadBytes(); + } else { + // Make sure DTLS drops invalid record silently + size_t received = server_->received_bytes(); + server_->ReadBytes(); + ASSERT_EQ(received, server_->received_bytes()); + } +} + +} // namespace nss_test
\ No newline at end of file diff --git a/security/nss/gtests/ssl_gtest/ssl_recordsep_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_recordsep_unittest.cc new file mode 100644 index 0000000000..8051b58d01 --- /dev/null +++ b/security/nss/gtests/ssl_gtest/ssl_recordsep_unittest.cc @@ -0,0 +1,679 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set ts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this file, + * You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#include "secerr.h" +#include "ssl.h" +#include "sslerr.h" +#include "sslproto.h" + +extern "C" { +// This is not something that should make you happy. +#include "libssl_internals.h" +} + +#include <queue> +#include "gtest_utils.h" +#include "nss_scoped_ptrs.h" +#include "tls_connect.h" +#include "tls_filter.h" +#include "tls_parser.h" + +namespace nss_test { + +class HandshakeSecretTracker { + public: + HandshakeSecretTracker(const std::shared_ptr<TlsAgent>& agent, + uint16_t first_read_epoch, uint16_t first_write_epoch) + : agent_(agent), + next_read_epoch_(first_read_epoch), + next_write_epoch_(first_write_epoch) { + EXPECT_EQ(SECSuccess, + SSL_SecretCallback(agent_->ssl_fd(), + HandshakeSecretTracker::SecretCb, this)); + } + + void CheckComplete() const { + EXPECT_EQ(0, next_read_epoch_); + EXPECT_EQ(0, next_write_epoch_); + } + + private: + static void SecretCb(PRFileDesc* fd, PRUint16 epoch, SSLSecretDirection dir, + PK11SymKey* secret, void* arg) { + HandshakeSecretTracker* t = reinterpret_cast<HandshakeSecretTracker*>(arg); + t->SecretUpdated(epoch, dir, secret); + } + + void SecretUpdated(PRUint16 epoch, SSLSecretDirection dir, + PK11SymKey* secret) { + if (g_ssl_gtest_verbose) { + std::cerr << agent_->role_str() << ": secret callback for " << dir + << " epoch " << epoch << std::endl; + } + + EXPECT_TRUE(secret); + uint16_t* p; + if (dir == ssl_secret_read) { + p = &next_read_epoch_; + } else { + ASSERT_EQ(ssl_secret_write, dir); + p = &next_write_epoch_; + } + EXPECT_EQ(*p, epoch); + switch (*p) { + case 1: // 1 == 0-RTT, next should be handshake. + case 2: // 2 == handshake, next should be application data. + (*p)++; + break; + + case 3: // 3 == application data, there should be no more. + // Use 0 as a sentinel value. + *p = 0; + break; + + default: + ADD_FAILURE() << "Unexpected next epoch: " << *p; + } + } + + std::shared_ptr<TlsAgent> agent_; + uint16_t next_read_epoch_; + uint16_t next_write_epoch_; +}; + +TEST_F(TlsConnectTest, HandshakeSecrets) { + ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3); + EnsureTlsSetup(); + + HandshakeSecretTracker c(client_, 2, 2); + HandshakeSecretTracker s(server_, 2, 2); + + Connect(); + SendReceive(); + + c.CheckComplete(); + s.CheckComplete(); +} + +TEST_F(TlsConnectTest, ZeroRttSecrets) { + SetupForZeroRtt(); + + HandshakeSecretTracker c(client_, 2, 1); + HandshakeSecretTracker s(server_, 1, 2); + + client_->Set0RttEnabled(true); + server_->Set0RttEnabled(true); + ExpectResumption(RESUME_TICKET); + ZeroRttSendReceive(true, true); + Handshake(); + ExpectEarlyDataAccepted(true); + CheckConnected(); + SendReceive(); + + c.CheckComplete(); + s.CheckComplete(); +} + +class KeyUpdateTracker { + public: + KeyUpdateTracker(const std::shared_ptr<TlsAgent>& agent, + bool expect_read_secret) + : agent_(agent), expect_read_secret_(expect_read_secret), called_(false) { + EXPECT_EQ(SECSuccess, SSL_SecretCallback(agent_->ssl_fd(), + KeyUpdateTracker::SecretCb, this)); + } + + void CheckCalled() const { EXPECT_TRUE(called_); } + + private: + static void SecretCb(PRFileDesc* fd, PRUint16 epoch, SSLSecretDirection dir, + PK11SymKey* secret, void* arg) { + KeyUpdateTracker* t = reinterpret_cast<KeyUpdateTracker*>(arg); + t->SecretUpdated(epoch, dir, secret); + } + + void SecretUpdated(PRUint16 epoch, SSLSecretDirection dir, + PK11SymKey* secret) { + EXPECT_EQ(4U, epoch); + EXPECT_EQ(expect_read_secret_, dir == ssl_secret_read); + EXPECT_TRUE(secret); + called_ = true; + } + + std::shared_ptr<TlsAgent> agent_; + bool expect_read_secret_; + bool called_; +}; + +TEST_F(TlsConnectTest, KeyUpdateSecrets) { + ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3); + Connect(); + // The update is to the client write secret; the server read secret. + KeyUpdateTracker c(client_, false); + KeyUpdateTracker s(server_, true); + EXPECT_EQ(SECSuccess, SSL_KeyUpdate(client_->ssl_fd(), PR_FALSE)); + SendReceive(50); + SendReceive(60); + CheckEpochs(4, 3); + c.CheckCalled(); + s.CheckCalled(); +} + +// BadPrSocket is an instance of a PR IO layer that crashes the test if it is +// ever used for reading or writing. It does that by failing to overwrite any +// of the DummyIOLayerMethods, which all crash when invoked. +class BadPrSocket : public DummyIOLayerMethods { + public: + BadPrSocket(std::shared_ptr<TlsAgent>& agent) : DummyIOLayerMethods() { + static PRDescIdentity bad_identity = PR_GetUniqueIdentity("bad NSPR id"); + fd_ = DummyIOLayerMethods::CreateFD(bad_identity, this); + + // This is terrible, but NSPR doesn't provide an easy way to replace the + // bottom layer of an IO stack. Take the DummyPrSocket and replace its + // NSPR method vtable with the ones from this object. + dummy_layer_ = + PR_GetIdentitiesLayer(agent->ssl_fd(), DummyPrSocket::LayerId()); + EXPECT_TRUE(dummy_layer_); + original_methods_ = dummy_layer_->methods; + original_secret_ = dummy_layer_->secret; + dummy_layer_->methods = fd_->methods; + dummy_layer_->secret = reinterpret_cast<PRFilePrivate*>(this); + } + + // This will be destroyed before the agent, so we need to restore the state + // before we tampered with it. + virtual ~BadPrSocket() { + dummy_layer_->methods = original_methods_; + dummy_layer_->secret = original_secret_; + } + + private: + ScopedPRFileDesc fd_; + PRFileDesc* dummy_layer_; + const PRIOMethods* original_methods_; + PRFilePrivate* original_secret_; +}; + +class StagedRecords { + public: + StagedRecords(std::shared_ptr<TlsAgent>& agent) : agent_(agent), records_() { + EXPECT_EQ(SECSuccess, + SSL_RecordLayerWriteCallback( + agent_->ssl_fd(), StagedRecords::StageRecordData, this)); + } + + virtual ~StagedRecords() { + // Uninstall so that the callback doesn't fire during cleanup. + EXPECT_EQ(SECSuccess, + SSL_RecordLayerWriteCallback(agent_->ssl_fd(), nullptr, nullptr)); + } + + bool empty() const { return records_.empty(); } + + void ForwardAll(std::shared_ptr<TlsAgent>& peer) { + EXPECT_NE(agent_, peer) << "can't forward to self"; + for (auto r : records_) { + r.Forward(peer); + } + records_.clear(); + } + + // This forwards all saved data and checks the resulting state. + void ForwardAll(std::shared_ptr<TlsAgent>& peer, + TlsAgent::State expected_state) { + ForwardAll(peer); + switch (expected_state) { + case TlsAgent::STATE_CONNECTED: + // The handshake callback should have been called, so check that before + // checking that SSL_ForceHandshake succeeds. + EXPECT_EQ(expected_state, peer->state()); + EXPECT_EQ(SECSuccess, SSL_ForceHandshake(peer->ssl_fd())); + break; + + case TlsAgent::STATE_CONNECTING: + // Check that SSL_ForceHandshake() blocks. + EXPECT_EQ(SECFailure, SSL_ForceHandshake(peer->ssl_fd())); + EXPECT_EQ(PR_WOULD_BLOCK_ERROR, PORT_GetError()); + // Update and check the state. + peer->Handshake(); + EXPECT_EQ(TlsAgent::STATE_CONNECTING, peer->state()); + break; + + default: + ADD_FAILURE() << "No idea how to handle this state"; + } + } + + void ForwardPartial(std::shared_ptr<TlsAgent>& peer) { + if (records_.empty()) { + ADD_FAILURE() << "No records to slice"; + return; + } + auto& last = records_.back(); + auto tail = last.SliceTail(); + ForwardAll(peer, TlsAgent::STATE_CONNECTING); + records_.push_back(tail); + EXPECT_EQ(TlsAgent::STATE_CONNECTING, peer->state()); + } + + private: + // A single record. + class StagedRecord { + public: + StagedRecord(const std::string role, uint16_t epoch, SSLContentType ct, + const uint8_t* data, size_t len) + : role_(role), epoch_(epoch), content_type_(ct), data_(data, len) { + if (g_ssl_gtest_verbose) { + std::cerr << role_ << ": staged epoch " << epoch_ << " " + << content_type_ << ": " << data_ << std::endl; + } + } + + // This forwards staged data to the identified agent. + void Forward(std::shared_ptr<TlsAgent>& peer) { + // Now there should be staged data. + EXPECT_FALSE(data_.empty()); + if (g_ssl_gtest_verbose) { + std::cerr << role_ << ": forward epoch " << epoch_ << " " << data_ + << std::endl; + } + EXPECT_EQ(SECSuccess, + SSL_RecordLayerData(peer->ssl_fd(), epoch_, content_type_, + data_.data(), + static_cast<unsigned int>(data_.len()))); + } + + // Slices the tail off this record and returns it. + StagedRecord SliceTail() { + size_t slice = 1; + if (data_.len() <= slice) { + ADD_FAILURE() << "record too small to slice in two"; + slice = 0; + } + size_t keep = data_.len() - slice; + StagedRecord tail(role_, epoch_, content_type_, data_.data() + keep, + slice); + data_.Truncate(keep); + return tail; + } + + private: + std::string role_; + uint16_t epoch_; + SSLContentType content_type_; + DataBuffer data_; + }; + + // This is an SSLRecordWriteCallback that stages data. + static SECStatus StageRecordData(PRFileDesc* fd, PRUint16 epoch, + SSLContentType content_type, + const PRUint8* data, unsigned int len, + void* arg) { + auto stage = reinterpret_cast<StagedRecords*>(arg); + stage->records_.push_back(StagedRecord(stage->agent_->role_str(), epoch, + content_type, data, + static_cast<size_t>(len))); + return SECSuccess; + } + + std::shared_ptr<TlsAgent>& agent_; + std::deque<StagedRecord> records_; +}; + +// Attempting to feed application data in before the handshake is complete +// should be caught. +static void RefuseApplicationData(std::shared_ptr<TlsAgent>& peer, + uint16_t epoch) { + static const uint8_t d[] = {1, 2, 3}; + EXPECT_EQ(SECFailure, + SSL_RecordLayerData(peer->ssl_fd(), epoch, ssl_ct_application_data, + d, static_cast<unsigned int>(sizeof(d)))); + EXPECT_EQ(SEC_ERROR_INVALID_ARGS, PORT_GetError()); +} + +static void SendForwardReceive(std::shared_ptr<TlsAgent>& sender, + StagedRecords& sender_stage, + std::shared_ptr<TlsAgent>& receiver) { + const size_t count = 10; + sender->SendData(count, count); + sender_stage.ForwardAll(receiver); + receiver->ReadBytes(count); +} + +TEST_P(TlsConnectStream, ReplaceRecordLayer) { + StartConnect(); + client_->SetServerKeyBits(server_->server_key_bits()); + + // BadPrSocket installs an IO layer that crashes when the SSL layer attempts + // to read or write. + BadPrSocket bad_layer_client(client_); + BadPrSocket bad_layer_server(server_); + + // StagedRecords installs a handler for unprotected data from the socket, and + // captures that data. + StagedRecords client_stage(client_); + StagedRecords server_stage(server_); + + // Both peers should refuse application data from epoch 0. + RefuseApplicationData(client_, 0); + RefuseApplicationData(server_, 0); + + // This first call forwards nothing, but it causes the client to handshake, + // which starts things off. This stages the ClientHello as a result. + server_stage.ForwardAll(client_, TlsAgent::STATE_CONNECTING); + // This processes the ClientHello and stages the first server flight. + client_stage.ForwardAll(server_, TlsAgent::STATE_CONNECTING); + + // In TLS 1.3, this is 0-RTT; in <TLS 1.3, this is application data. + // Neither is acceptable. + RefuseApplicationData(client_, 1); + RefuseApplicationData(server_, 1); + + if (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) { + // Application data in handshake is never acceptable. + RefuseApplicationData(client_, 2); + RefuseApplicationData(server_, 2); + // Don't accept real data until the handshake is done. + RefuseApplicationData(client_, 3); + RefuseApplicationData(server_, 3); + // Process the server flight and the client is done. + server_stage.ForwardAll(client_, TlsAgent::STATE_CONNECTED); + client_stage.ForwardAll(server_, TlsAgent::STATE_CONNECTED); + } else { + server_stage.ForwardAll(client_, TlsAgent::STATE_CONNECTING); + client_stage.ForwardAll(server_, TlsAgent::STATE_CONNECTED); + server_stage.ForwardAll(client_, TlsAgent::STATE_CONNECTED); + } + CheckKeys(); + + // Reading and writing application data should work. + SendForwardReceive(client_, client_stage, server_); + SendForwardReceive(server_, server_stage, client_); +} + +TEST_F(TlsConnectStreamTls13, ReplaceRecordLayerZeroRtt) { + SetupForZeroRtt(); + + client_->Set0RttEnabled(true); + server_->Set0RttEnabled(true); + StartConnect(); + client_->SetServerKeyBits(server_->server_key_bits()); + + BadPrSocket bad_layer_client(client_); + BadPrSocket bad_layer_server(server_); + + StagedRecords client_stage(client_); + StagedRecords server_stage(server_); + + ExpectResumption(RESUME_TICKET); + + // Send ClientHello + server_stage.ForwardAll(client_, TlsAgent::STATE_CONNECTING); + + // The client can never accept 0-RTT. + RefuseApplicationData(client_, 1); + + // Send some 0-RTT data, which get staged in `client_stage`. + const char* kMsg = "EarlyData"; + const PRInt32 kMsgLen = static_cast<PRInt32>(strlen(kMsg)); + PRInt32 rv = PR_Write(client_->ssl_fd(), kMsg, kMsgLen); + EXPECT_EQ(kMsgLen, rv); + + client_stage.ForwardAll(server_, TlsAgent::STATE_CONNECTING); + + // The server should now have 0-RTT to read. + std::vector<uint8_t> buf(kMsgLen); + rv = PR_Read(server_->ssl_fd(), buf.data(), kMsgLen); + EXPECT_EQ(kMsgLen, rv); + + // The handshake should happily finish. + server_stage.ForwardAll(client_, TlsAgent::STATE_CONNECTED); + client_stage.ForwardAll(server_, TlsAgent::STATE_CONNECTED); + ExpectEarlyDataAccepted(true); + CheckConnected(); + + // Reading and writing application data should work. + SendForwardReceive(client_, client_stage, server_); + SendForwardReceive(server_, server_stage, client_); +} + +static SECStatus AuthCompleteBlock(TlsAgent*, PRBool, PRBool) { + return SECWouldBlock; +} + +TEST_P(TlsConnectStream, ReplaceRecordLayerAsyncLateAuth) { + StartConnect(); + client_->SetServerKeyBits(server_->server_key_bits()); + + BadPrSocket bad_layer_client(client_); + BadPrSocket bad_layer_server(server_); + StagedRecords client_stage(client_); + StagedRecords server_stage(server_); + + client_->SetAuthCertificateCallback(AuthCompleteBlock); + + server_stage.ForwardAll(client_, TlsAgent::STATE_CONNECTING); + client_stage.ForwardAll(server_, TlsAgent::STATE_CONNECTING); + server_stage.ForwardAll(client_, TlsAgent::STATE_CONNECTING); + + // Prior to TLS 1.3, the client sends its second flight immediately. But in + // TLS 1.3, a client won't send a Finished until it is happy with the server + // certificate. So blocking certificate validation causes the client to send + // nothing. + if (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) { + ASSERT_TRUE(client_stage.empty()); + + // Client should have stopped reading when it saw the Certificate message, + // so it will be reading handshake epoch, and writing cleartext. + client_->CheckEpochs(2, 0); + // Server should be reading handshake, and writing application data. + server_->CheckEpochs(2, 3); + + // Handshake again and the client will read the remainder of the server's + // flight, but it will remain blocked. + client_->Handshake(); + ASSERT_TRUE(client_stage.empty()); + EXPECT_EQ(TlsAgent::STATE_CONNECTING, client_->state()); + } else { + // In prior versions, the client's second flight is always sent. + ASSERT_FALSE(client_stage.empty()); + } + + // Now declare the certificate good. + EXPECT_EQ(SECSuccess, SSL_AuthCertificateComplete(client_->ssl_fd(), 0)); + client_->Handshake(); + ASSERT_FALSE(client_stage.empty()); + + if (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) { + EXPECT_EQ(TlsAgent::STATE_CONNECTED, client_->state()); + client_stage.ForwardAll(server_, TlsAgent::STATE_CONNECTED); + } else { + client_stage.ForwardAll(server_, TlsAgent::STATE_CONNECTED); + server_stage.ForwardAll(client_, TlsAgent::STATE_CONNECTED); + } + CheckKeys(); + + // Reading and writing application data should work. + SendForwardReceive(client_, client_stage, server_); +} + +TEST_F(TlsConnectStreamTls13, ReplaceRecordLayerAsyncPostHandshake) { + StartConnect(); + client_->SetServerKeyBits(server_->server_key_bits()); + + BadPrSocket bad_layer_client(client_); + BadPrSocket bad_layer_server(server_); + StagedRecords client_stage(client_); + StagedRecords server_stage(server_); + + client_->SetAuthCertificateCallback(AuthCompleteBlock); + + server_stage.ForwardAll(client_, TlsAgent::STATE_CONNECTING); + client_stage.ForwardAll(server_, TlsAgent::STATE_CONNECTING); + server_stage.ForwardAll(client_, TlsAgent::STATE_CONNECTING); + + ASSERT_TRUE(client_stage.empty()); + client_->Handshake(); + ASSERT_TRUE(client_stage.empty()); + EXPECT_EQ(TlsAgent::STATE_CONNECTING, client_->state()); + + // Now declare the certificate good. + EXPECT_EQ(SECSuccess, SSL_AuthCertificateComplete(client_->ssl_fd(), 0)); + client_->Handshake(); + ASSERT_FALSE(client_stage.empty()); + + if (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) { + EXPECT_EQ(TlsAgent::STATE_CONNECTED, client_->state()); + client_stage.ForwardAll(server_, TlsAgent::STATE_CONNECTED); + } else { + client_stage.ForwardAll(server_, TlsAgent::STATE_CONNECTED); + server_stage.ForwardAll(client_, TlsAgent::STATE_CONNECTED); + } + CheckKeys(); + + // Reading and writing application data should work. + SendForwardReceive(client_, client_stage, server_); + + // Post-handshake messages should work here. + EXPECT_EQ(SECSuccess, SSL_SendSessionTicket(server_->ssl_fd(), nullptr, 0)); + SendForwardReceive(server_, server_stage, client_); +} + +// This test ensures that data is correctly forwarded when the handshake is +// resumed after asynchronous server certificate authentication, when +// SSL_AuthCertificateComplete() is called. The logic for resuming the +// handshake involves a different code path than the usual one, so this test +// exercises that code fully. +TEST_F(TlsConnectStreamTls13, ReplaceRecordLayerAsyncEarlyAuth) { + StartConnect(); + client_->SetServerKeyBits(server_->server_key_bits()); + + BadPrSocket bad_layer_client(client_); + BadPrSocket bad_layer_server(server_); + StagedRecords client_stage(client_); + StagedRecords server_stage(server_); + + client_->SetAuthCertificateCallback(AuthCompleteBlock); + + server_stage.ForwardAll(client_, TlsAgent::STATE_CONNECTING); + client_stage.ForwardAll(server_, TlsAgent::STATE_CONNECTING); + + // Send a partial flight on to the client. + // This includes enough to trigger the certificate callback. + server_stage.ForwardPartial(client_); + EXPECT_TRUE(client_stage.empty()); + + // Declare the certificate good. + EXPECT_EQ(SECSuccess, SSL_AuthCertificateComplete(client_->ssl_fd(), 0)); + client_->Handshake(); + EXPECT_TRUE(client_stage.empty()); + + // Send the remainder of the server flight. + PRBool pending = PR_FALSE; + EXPECT_EQ(SECSuccess, + SSLInt_HasPendingHandshakeData(client_->ssl_fd(), &pending)); + EXPECT_EQ(PR_TRUE, pending); + EXPECT_EQ(TlsAgent::STATE_CONNECTING, client_->state()); + server_stage.ForwardAll(client_, TlsAgent::STATE_CONNECTED); + client_stage.ForwardAll(server_, TlsAgent::STATE_CONNECTED); + CheckKeys(); + + SendForwardReceive(server_, server_stage, client_); +} + +TEST_P(TlsConnectStream, ForwardDataFromWrongEpoch) { + const uint8_t data[] = {1}; + Connect(); + uint16_t next_epoch; + if (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) { + EXPECT_EQ(SECFailure, + SSL_RecordLayerData(client_->ssl_fd(), 2, ssl_ct_application_data, + data, sizeof(data))); + EXPECT_EQ(SEC_ERROR_INVALID_ARGS, PORT_GetError()) + << "Passing data from an old epoch is rejected"; + next_epoch = 4; + } else { + // Prior to TLS 1.3, the epoch is only updated once during the handshake. + next_epoch = 2; + } + EXPECT_EQ(SECFailure, + SSL_RecordLayerData(client_->ssl_fd(), next_epoch, + ssl_ct_application_data, data, sizeof(data))); + EXPECT_EQ(PR_WOULD_BLOCK_ERROR, PORT_GetError()) + << "Passing data from a future epoch blocks"; +} + +TEST_F(TlsConnectStreamTls13, ForwardInvalidData) { + const uint8_t data[1] = {0}; + + EnsureTlsSetup(); + // Zero-length data. + EXPECT_EQ(SECFailure, SSL_RecordLayerData(client_->ssl_fd(), 0, + ssl_ct_application_data, data, 0)); + EXPECT_EQ(SEC_ERROR_INVALID_ARGS, PORT_GetError()); + + // NULL data. + EXPECT_EQ(SECFailure, + SSL_RecordLayerData(client_->ssl_fd(), 0, ssl_ct_application_data, + nullptr, 1)); + EXPECT_EQ(SEC_ERROR_INVALID_ARGS, PORT_GetError()); +} + +TEST_F(TlsConnectDatagram13, ForwardDataDtls) { + EnsureTlsSetup(); + const uint8_t data[1] = {0}; + EXPECT_EQ(SECFailure, + SSL_RecordLayerData(client_->ssl_fd(), 0, ssl_ct_application_data, + data, sizeof(data))); + EXPECT_EQ(SEC_ERROR_INVALID_ARGS, PORT_GetError()); +} + +TEST_F(TlsConnectStreamTls13, SuppressEndOfEarlyData) { + SetupForZeroRtt(); + + client_->Set0RttEnabled(true); + server_->Set0RttEnabled(true); + client_->SetOption(SSL_SUPPRESS_END_OF_EARLY_DATA, true); + server_->SetOption(SSL_SUPPRESS_END_OF_EARLY_DATA, true); + StartConnect(); + client_->SetServerKeyBits(server_->server_key_bits()); + + BadPrSocket bad_layer_client(client_); + BadPrSocket bad_layer_server(server_); + + StagedRecords client_stage(client_); + StagedRecords server_stage(server_); + + ExpectResumption(RESUME_TICKET); + + // Send ClientHello + server_stage.ForwardAll(client_, TlsAgent::STATE_CONNECTING); + + // Send some 0-RTT data, which get staged in `client_stage`. + const char* kMsg = "ABCDEF"; + const PRInt32 kMsgLen = static_cast<PRInt32>(strlen(kMsg)); + PRInt32 rv = PR_Write(client_->ssl_fd(), kMsg, kMsgLen); + EXPECT_EQ(kMsgLen, rv); + + client_stage.ForwardAll(server_, TlsAgent::STATE_CONNECTING); + + // The server should now have 0-RTT to read. + std::vector<uint8_t> buf(kMsgLen); + rv = PR_Read(server_->ssl_fd(), buf.data(), kMsgLen); + EXPECT_EQ(kMsgLen, rv); + + // The handshake should happily finish, without the end of the early data. + server_stage.ForwardAll(client_, TlsAgent::STATE_CONNECTED); + client_stage.ForwardAll(server_, TlsAgent::STATE_CONNECTED); + ExpectEarlyDataAccepted(true); + CheckConnected(); + + // Reading and writing application data should work. + SendForwardReceive(client_, client_stage, server_); + SendForwardReceive(server_, server_stage, client_); +} + +} // namespace nss_test diff --git a/security/nss/gtests/ssl_gtest/ssl_recordsize_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_recordsize_unittest.cc new file mode 100644 index 0000000000..8a84db5749 --- /dev/null +++ b/security/nss/gtests/ssl_gtest/ssl_recordsize_unittest.cc @@ -0,0 +1,726 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set ts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this file, + * You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#include "secerr.h" +#include "ssl.h" +#include "sslerr.h" +#include "sslproto.h" + +#include "gtest_utils.h" +#include "nss_scoped_ptrs.h" +#include "tls_connect.h" +#include "tls_filter.h" +#include "tls_parser.h" + +namespace nss_test { + +// This class tracks the maximum size of record that was sent, both cleartext +// and plain. It only tracks records that have an outer type of +// application_data or DTLSCiphertext. In TLS 1.3, this includes handshake +// messages. +class TlsRecordMaximum : public TlsRecordFilter { + public: + TlsRecordMaximum(const std::shared_ptr<TlsAgent>& a) + : TlsRecordFilter(a), max_ciphertext_(0), max_plaintext_(0) {} + + size_t max_ciphertext() const { return max_ciphertext_; } + size_t max_plaintext() const { return max_plaintext_; } + + protected: + PacketFilter::Action FilterRecord(const TlsRecordHeader& header, + const DataBuffer& record, size_t* offset, + DataBuffer* output) override { + std::cerr << "max: " << record << std::endl; + // Ignore unprotected packets. + if (!header.is_protected()) { + return KEEP; + } + + max_ciphertext_ = (std::max)(max_ciphertext_, record.len()); + return TlsRecordFilter::FilterRecord(header, record, offset, output); + } + + PacketFilter::Action FilterRecord(const TlsRecordHeader& header, + const DataBuffer& data, + DataBuffer* changed) override { + max_plaintext_ = (std::max)(max_plaintext_, data.len()); + return KEEP; + } + + private: + size_t max_ciphertext_; + size_t max_plaintext_; +}; + +void CheckRecordSizes(const std::shared_ptr<TlsAgent>& agent, + const std::shared_ptr<TlsRecordMaximum>& record_max, + size_t config) { + uint16_t cipher_suite; + ASSERT_TRUE(agent->cipher_suite(&cipher_suite)); + + size_t expansion; + size_t iv; + switch (cipher_suite) { + case TLS_AES_128_GCM_SHA256: + case TLS_AES_256_GCM_SHA384: + case TLS_CHACHA20_POLY1305_SHA256: + expansion = 16; + iv = 0; + break; + + case TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256: + expansion = 16; + iv = 8; + break; + + case TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA: + // Expansion is 20 for the MAC. Maximum block padding is 16. Maximum + // padding is added when the input plus the MAC is an exact multiple of + // the block size. + expansion = 20 + 16 - ((config + 20) % 16); + iv = 16; + break; + + default: + ADD_FAILURE() << "No expansion set for ciphersuite " + << agent->cipher_suite_name(); + return; + } + + switch (agent->version()) { + case SSL_LIBRARY_VERSION_TLS_1_3: + EXPECT_EQ(0U, iv) << "No IV for TLS 1.3"; + // We only have decryption in TLS 1.3. + EXPECT_EQ(config - 1, record_max->max_plaintext()) + << "bad plaintext length for " << agent->role_str(); + break; + + case SSL_LIBRARY_VERSION_TLS_1_2: + case SSL_LIBRARY_VERSION_TLS_1_1: + expansion += iv; + break; + + case SSL_LIBRARY_VERSION_TLS_1_0: + break; + + default: + ADD_FAILURE() << "Unexpected version " << agent->version(); + return; + } + + EXPECT_EQ(config + expansion, record_max->max_ciphertext()) + << "bad ciphertext length for " << agent->role_str(); +} + +TEST_P(TlsConnectGeneric, RecordSizeMaximum) { + uint16_t max_record_size = + (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) ? 16385 : 16384; + size_t send_size = (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) + ? max_record_size + : max_record_size + 1; + + EnsureTlsSetup(); + auto client_max = MakeTlsFilter<TlsRecordMaximum>(client_); + auto server_max = MakeTlsFilter<TlsRecordMaximum>(server_); + if (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) { + client_max->EnableDecryption(); + server_max->EnableDecryption(); + } + + Connect(); + client_->SendData(send_size, send_size); + server_->SendData(send_size, send_size); + server_->ReadBytes(send_size); + client_->ReadBytes(send_size); + + CheckRecordSizes(client_, client_max, max_record_size); + CheckRecordSizes(server_, server_max, max_record_size); +} + +TEST_P(TlsConnectGeneric, RecordSizeMinimumClient) { + EnsureTlsSetup(); + auto server_max = MakeTlsFilter<TlsRecordMaximum>(server_); + if (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) { + server_max->EnableDecryption(); + } + + client_->SetOption(SSL_RECORD_SIZE_LIMIT, 64); + Connect(); + SendReceive(127); // Big enough for one record, allowing for 1+N splitting. + + CheckRecordSizes(server_, server_max, 64); +} + +TEST_P(TlsConnectGeneric, RecordSizeMinimumServer) { + EnsureTlsSetup(); + auto client_max = MakeTlsFilter<TlsRecordMaximum>(client_); + if (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) { + client_max->EnableDecryption(); + } + + server_->SetOption(SSL_RECORD_SIZE_LIMIT, 64); + Connect(); + SendReceive(127); + + CheckRecordSizes(client_, client_max, 64); +} + +TEST_P(TlsConnectGeneric, RecordSizeAsymmetric) { + EnsureTlsSetup(); + auto client_max = MakeTlsFilter<TlsRecordMaximum>(client_); + auto server_max = MakeTlsFilter<TlsRecordMaximum>(server_); + if (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) { + client_max->EnableDecryption(); + server_max->EnableDecryption(); + } + + client_->SetOption(SSL_RECORD_SIZE_LIMIT, 64); + server_->SetOption(SSL_RECORD_SIZE_LIMIT, 100); + Connect(); + SendReceive(127); + + CheckRecordSizes(client_, client_max, 100); + CheckRecordSizes(server_, server_max, 64); +} + +// This just modifies the encrypted payload so to include a few extra zeros. +class TlsRecordExpander : public TlsRecordFilter { + public: + TlsRecordExpander(const std::shared_ptr<TlsAgent>& a, size_t expansion) + : TlsRecordFilter(a), expansion_(expansion) {} + + protected: + virtual PacketFilter::Action FilterRecord(const TlsRecordHeader& header, + const DataBuffer& data, + DataBuffer* changed) { + if (!header.is_protected()) { + // We're targeting application_data records. If the record is + // |!is_protected()|, we have two possibilities: + if (!decrypting()) { + // 1) We're not decrypting, in which this case this is truly an + // unencrypted record (Keep). + return KEEP; + } + if (header.content_type() != ssl_ct_application_data) { + // 2) We are decrypting, so is_protected() read the internal + // content_type. If the internal ct IS NOT application_data, then + // it's not our target (Keep). + return KEEP; + } + // Otherwise, the the internal ct IS application_data (Change). + } + + changed->Allocate(data.len() + expansion_); + changed->Write(0, data.data(), data.len()); + return CHANGE; + } + + private: + size_t expansion_; +}; + +// Tweak the plaintext of server records so that they exceed the client's limit. +TEST_F(TlsConnectStreamTls13, RecordSizePlaintextExceed) { + EnsureTlsSetup(); + auto server_expand = MakeTlsFilter<TlsRecordExpander>(server_, 1); + server_expand->EnableDecryption(); + + client_->SetOption(SSL_RECORD_SIZE_LIMIT, 64); + Connect(); + + server_->SendData(100); + + client_->ExpectReadWriteError(); + ExpectAlert(client_, kTlsAlertRecordOverflow); + client_->ReadBytes(100); + EXPECT_EQ(SSL_ERROR_RX_RECORD_TOO_LONG, client_->error_code()); + + // Consume the alert at the server. + server_->Handshake(); + server_->CheckErrorCode(SSL_ERROR_RECORD_OVERFLOW_ALERT); +} + +// Tweak the ciphertext of server records so that they greatly exceed the limit. +// This requires a much larger expansion than for plaintext to trigger the +// guard, which runs before decryption (current allowance is 320 octets, +// see MAX_EXPANSION in ssl3con.c). +TEST_F(TlsConnectStreamTls13, RecordSizeCiphertextExceed) { + EnsureTlsSetup(); + + client_->SetOption(SSL_RECORD_SIZE_LIMIT, 64); + Connect(); + + auto server_expand = MakeTlsFilter<TlsRecordExpander>(server_, 336); + server_->SendData(100); + + client_->ExpectReadWriteError(); + ExpectAlert(client_, kTlsAlertRecordOverflow); + client_->ReadBytes(100); + EXPECT_EQ(SSL_ERROR_RX_RECORD_TOO_LONG, client_->error_code()); + + // Consume the alert at the server. + server_->Handshake(); + server_->CheckErrorCode(SSL_ERROR_RECORD_OVERFLOW_ALERT); +} + +TEST_F(TlsConnectStreamTls13, ClientHelloF5Padding) { + EnsureTlsSetup(); + ScopedPK11SlotInfo slot(PK11_GetInternalSlot()); + ScopedPK11SymKey key( + PK11_KeyGen(slot.get(), CKM_NSS_CHACHA20_POLY1305, nullptr, 32, nullptr)); + + auto filter = + MakeTlsFilter<TlsHandshakeRecorder>(client_, kTlsHandshakeClientHello); + + // Add PSK with label long enough to push CH length into [256, 511]. + std::vector<uint8_t> label(100); + EXPECT_EQ(SECSuccess, + SSL_AddExternalPsk(client_->ssl_fd(), key.get(), label.data(), + label.size(), ssl_hash_sha256)); + StartConnect(); + client_->Handshake(); + + // Filter removes the 4B handshake header. + EXPECT_EQ(508UL, filter->buffer().len()); +} + +// This indiscriminately adds padding to application data records. +class TlsRecordPadder : public TlsRecordFilter { + public: + TlsRecordPadder(const std::shared_ptr<TlsAgent>& a, size_t padding) + : TlsRecordFilter(a), padding_(padding) {} + + protected: + PacketFilter::Action FilterRecord(const TlsRecordHeader& header, + const DataBuffer& record, size_t* offset, + DataBuffer* output) override { + if (!header.is_protected()) { + return KEEP; + } + + uint16_t protection_epoch; + uint8_t inner_content_type; + DataBuffer plaintext; + TlsRecordHeader out_header; + if (!Unprotect(header, record, &protection_epoch, &inner_content_type, + &plaintext, &out_header)) { + return KEEP; + } + + if (decrypting() && inner_content_type != ssl_ct_application_data) { + return KEEP; + } + + DataBuffer ciphertext; + bool ok = Protect(spec(protection_epoch), out_header, inner_content_type, + plaintext, &ciphertext, &out_header, padding_); + EXPECT_TRUE(ok); + if (!ok) { + return KEEP; + } + *offset = out_header.Write(output, *offset, ciphertext); + return CHANGE; + } + + private: + size_t padding_; +}; + +TEST_F(TlsConnectStreamTls13, RecordSizeExceedPad) { + EnsureTlsSetup(); + auto server_max = std::make_shared<TlsRecordMaximum>(server_); + auto server_expand = std::make_shared<TlsRecordPadder>(server_, 1); + server_->SetFilter(std::make_shared<ChainedPacketFilter>( + ChainedPacketFilterInit({server_max, server_expand}))); + server_expand->EnableDecryption(); + + client_->SetOption(SSL_RECORD_SIZE_LIMIT, 64); + Connect(); + + server_->SendData(100); + + client_->ExpectReadWriteError(); + ExpectAlert(client_, kTlsAlertRecordOverflow); + client_->ReadBytes(100); + EXPECT_EQ(SSL_ERROR_RX_RECORD_TOO_LONG, client_->error_code()); + + // Consume the alert at the server. + server_->Handshake(); + server_->CheckErrorCode(SSL_ERROR_RECORD_OVERFLOW_ALERT); +} + +TEST_P(TlsConnectGeneric, RecordSizeBadValues) { + EnsureTlsSetup(); + EXPECT_EQ(SECFailure, + SSL_OptionSet(client_->ssl_fd(), SSL_RECORD_SIZE_LIMIT, 63)); + EXPECT_EQ(SECFailure, + SSL_OptionSet(client_->ssl_fd(), SSL_RECORD_SIZE_LIMIT, -1)); + EXPECT_EQ(SECFailure, + SSL_OptionSet(server_->ssl_fd(), SSL_RECORD_SIZE_LIMIT, 16386)); + Connect(); +} + +TEST_P(TlsConnectGeneric, RecordSizeGetValues) { + EnsureTlsSetup(); + int v; + EXPECT_EQ(SECSuccess, + SSL_OptionGet(client_->ssl_fd(), SSL_RECORD_SIZE_LIMIT, &v)); + EXPECT_EQ(16385, v); + client_->SetOption(SSL_RECORD_SIZE_LIMIT, 300); + EXPECT_EQ(SECSuccess, + SSL_OptionGet(client_->ssl_fd(), SSL_RECORD_SIZE_LIMIT, &v)); + EXPECT_EQ(300, v); + Connect(); +} + +// The value of the extension is capped by the maximum version of the client. +TEST_P(TlsConnectGeneric, RecordSizeCapExtensionClient) { + EnsureTlsSetup(); + client_->SetOption(SSL_RECORD_SIZE_LIMIT, 16385); + auto capture = + MakeTlsFilter<TlsExtensionCapture>(client_, ssl_record_size_limit_xtn); + if (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) { + capture->EnableDecryption(); + } + Connect(); + + uint64_t val = 0; + EXPECT_TRUE(capture->extension().Read(0, 2, &val)); + if (version_ < SSL_LIBRARY_VERSION_TLS_1_3) { + EXPECT_EQ(16384U, val) << "Extension should be capped"; + } else { + EXPECT_EQ(16385U, val); + } +} + +// The value of the extension is capped by the maximum version of the server. +TEST_P(TlsConnectGeneric, RecordSizeCapExtensionServer) { + EnsureTlsSetup(); + server_->SetOption(SSL_RECORD_SIZE_LIMIT, 16385); + auto capture = + MakeTlsFilter<TlsExtensionCapture>(server_, ssl_record_size_limit_xtn); + if (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) { + capture->EnableDecryption(); + } + Connect(); + + uint64_t val = 0; + EXPECT_TRUE(capture->extension().Read(0, 2, &val)); + if (version_ < SSL_LIBRARY_VERSION_TLS_1_3) { + EXPECT_EQ(16384U, val) << "Extension should be capped"; + } else { + EXPECT_EQ(16385U, val); + } +} + +// Damage the client extension and the handshake fails, but the server +// doesn't generate a validation error. +TEST_P(TlsConnectGenericPre13, RecordSizeClientExtensionInvalid) { + EnsureTlsSetup(); + client_->SetOption(SSL_RECORD_SIZE_LIMIT, 1000); + static const uint8_t v[] = {0xf4, 0x1f}; + MakeTlsFilter<TlsExtensionReplacer>(client_, ssl_record_size_limit_xtn, + DataBuffer(v, sizeof(v))); + ConnectExpectAlert(server_, kTlsAlertDecryptError); +} + +// Special handling for TLS 1.3, where the alert isn't read. +TEST_F(TlsConnectStreamTls13, RecordSizeClientExtensionInvalid) { + EnsureTlsSetup(); + client_->SetOption(SSL_RECORD_SIZE_LIMIT, 1000); + static const uint8_t v[] = {0xf4, 0x1f}; + MakeTlsFilter<TlsExtensionReplacer>(client_, ssl_record_size_limit_xtn, + DataBuffer(v, sizeof(v))); + client_->ExpectSendAlert(kTlsAlertBadRecordMac); + server_->ExpectSendAlert(kTlsAlertBadRecordMac); + ConnectExpectFail(); +} + +TEST_P(TlsConnectGeneric, RecordSizeServerExtensionInvalid) { + EnsureTlsSetup(); + server_->SetOption(SSL_RECORD_SIZE_LIMIT, 1000); + static const uint8_t v[] = {0xf4, 0x1f}; + auto replace = MakeTlsFilter<TlsExtensionReplacer>( + server_, ssl_record_size_limit_xtn, DataBuffer(v, sizeof(v))); + if (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) { + replace->EnableDecryption(); + } + ConnectExpectAlert(client_, kTlsAlertIllegalParameter); +} + +TEST_P(TlsConnectGeneric, RecordSizeServerExtensionExtra) { + EnsureTlsSetup(); + server_->SetOption(SSL_RECORD_SIZE_LIMIT, 1000); + static const uint8_t v[] = {0x01, 0x00, 0x00}; + auto replace = MakeTlsFilter<TlsExtensionReplacer>( + server_, ssl_record_size_limit_xtn, DataBuffer(v, sizeof(v))); + if (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) { + replace->EnableDecryption(); + } + ConnectExpectAlert(client_, kTlsAlertDecodeError); +} + +class RecordSizeDefaultsTest : public ::testing::Test { + public: + void SetUp() { + EXPECT_EQ(SECSuccess, + SSL_OptionGetDefault(SSL_RECORD_SIZE_LIMIT, &default_)); + } + void TearDown() { + // Make sure to restore the default value at the end. + EXPECT_EQ(SECSuccess, + SSL_OptionSetDefault(SSL_RECORD_SIZE_LIMIT, default_)); + } + + private: + PRIntn default_ = 0; +}; + +TEST_F(RecordSizeDefaultsTest, RecordSizeBadValues) { + EXPECT_EQ(SECFailure, SSL_OptionSetDefault(SSL_RECORD_SIZE_LIMIT, 63)); + EXPECT_EQ(SECFailure, SSL_OptionSetDefault(SSL_RECORD_SIZE_LIMIT, -1)); + EXPECT_EQ(SECFailure, SSL_OptionSetDefault(SSL_RECORD_SIZE_LIMIT, 16386)); +} + +TEST_F(RecordSizeDefaultsTest, RecordSizeGetValue) { + int v; + EXPECT_EQ(SECSuccess, SSL_OptionGetDefault(SSL_RECORD_SIZE_LIMIT, &v)); + EXPECT_EQ(16385, v); + EXPECT_EQ(SECSuccess, SSL_OptionSetDefault(SSL_RECORD_SIZE_LIMIT, 3000)); + EXPECT_EQ(SECSuccess, SSL_OptionGetDefault(SSL_RECORD_SIZE_LIMIT, &v)); + EXPECT_EQ(3000, v); +} + +class TlsCtextResizer : public TlsRecordFilter { + public: + TlsCtextResizer(const std::shared_ptr<TlsAgent>& a, size_t size) + : TlsRecordFilter(a), size_(size) {} + + protected: + virtual PacketFilter::Action FilterRecord(const TlsRecordHeader& header, + const DataBuffer& data, + DataBuffer* changed) { + // allocate and initialise buffer + changed->Allocate(size_); + + // copy record data (partially) + changed->Write(0, data.data(), + ((data.len() >= size_) ? size_ : data.len())); + + return CHANGE; + } + + private: + size_t size_; +}; + +/* (D)TLS overlong record test for maximum default record size of + * 2^14 + (256 (TLS 1.3) OR 2048 (TLS <= 1.2) + * [RFC8446, Section 5.2; RFC5246 , Section 6.2.3]. + * This should fail the first size check in ssl3gthr.c/ssl3_GatherData(). + * DTLS Record errors are dropped silently. [RFC6347, Section 4.1.2.7]. */ +TEST_P(TlsConnectGeneric, RecordGatherOverlong) { + EnsureTlsSetup(); + + size_t max_ctext = MAX_FRAGMENT_LENGTH; + if (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) { + max_ctext += TLS_1_3_MAX_EXPANSION; + } else { + max_ctext += TLS_1_2_MAX_EXPANSION; + } + + Connect(); + + MakeTlsFilter<TlsCtextResizer>(server_, max_ctext + 1); + // Dummy record will be overwritten + server_->SendData(0xf0); + + /* Drop DTLS Record Errors silently [RFC6347, Section 4.1.2.7]. */ + if (variant_ == ssl_variant_datagram) { + size_t received = client_->received_bytes(); + client_->ReadBytes(max_ctext + 1); + ASSERT_EQ(received, client_->received_bytes()); + } else { + client_->ExpectSendAlert(kTlsAlertRecordOverflow); + client_->ReadBytes(max_ctext + 1); + server_->ExpectReceiveAlert(kTlsAlertRecordOverflow); + server_->Handshake(); + } +} + +/* (D)TLS overlong record test with recordSizeLimit Extension and plus RFC + * specified maximum Expansion: 2^14 + (256 (TLS 1.3) OR 2048 (TLS <= 1.2) + * [RFC8446, Section 5.2; RFC5246 , Section 6.2.3]. + * DTLS Record errors are dropped silently. [RFC6347, Section 4.1.2.7]. */ +TEST_P(TlsConnectGeneric, RecordSizeExtensionOverlong) { + EnsureTlsSetup(); + + // Set some boundary + size_t max_ctext = 1000; + + client_->SetOption(SSL_RECORD_SIZE_LIMIT, max_ctext); + + if (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) { + // The record size limit includes the inner content type byte + max_ctext += TLS_1_3_MAX_EXPANSION - 1; + } else { + max_ctext += TLS_1_2_MAX_EXPANSION; + } + + Connect(); + + MakeTlsFilter<TlsCtextResizer>(server_, max_ctext + 1); + // Dummy record will be overwritten + server_->SendData(0xf); + + /* Drop DTLS Record Errors silently [RFC6347, Section 4.1.2.7]. + * For DTLS 1.0 and 1.2 the package is dropped before the size check because + * of the modification. This just tests that no error is thrown as required. + */ + if (variant_ == ssl_variant_datagram) { + size_t received = client_->received_bytes(); + client_->ReadBytes(max_ctext + 1); + ASSERT_EQ(received, client_->received_bytes()); + } else { + client_->ExpectSendAlert(kTlsAlertRecordOverflow); + client_->ReadBytes(max_ctext + 1); + server_->ExpectReceiveAlert(kTlsAlertRecordOverflow); + server_->Handshake(); + } +} + +/* For TLS <= 1.2: + * MAX_EXPANSION is the amount by which a record might plausibly be expanded + * when protected. It's the worst case estimate, so the sum of block cipher + * padding (up to 256 octets), HMAC (48 octets for SHA-384), and IV (16 + * octets for AES). */ +#define MAX_EXPANSION (256 + 48 + 16) + +/* (D)TLS overlong record test for specific ciphersuite expansion. + * Testing the smallest illegal record. + * This check is performed in ssl3con.c/ssl3_UnprotectRecord() OR + * tls13con.c/tls13_UnprotectRecord() and enforces stricter size limitations, + * dependent on the implemented cipher suites, than the RFC. + * DTLS Record errors are dropped silently. [RFC6347, Section 4.1.2.7]. */ +TEST_P(TlsConnectGeneric, RecordExpansionOverlong) { + EnsureTlsSetup(); + + // Set some boundary + size_t max_ctext = 1000; + + client_->SetOption(SSL_RECORD_SIZE_LIMIT, max_ctext); + + if (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) { + // For TLS1.3 all ciphers expand the cipherext by 16B + // The inner content type byte is included in the record size limit + max_ctext += 16; + } else { + // For TLS<=1.2 the max possible expansion in the NSS implementation is 320 + max_ctext += MAX_EXPANSION; + } + + Connect(); + + MakeTlsFilter<TlsCtextResizer>(server_, max_ctext + 1); + // Dummy record will be overwritten + server_->SendData(0xf); + + /* Drop DTLS Record Errors silently [RFC6347, Section 4.1.2.7]. + * For DTLS 1.0 and 1.2 the package is dropped before the size check because + * of the modification. This just tests that no error is thrown as required/ + * no bytes are received. */ + if (variant_ == ssl_variant_datagram) { + size_t received = client_->received_bytes(); + client_->ReadBytes(max_ctext + 1); + ASSERT_EQ(received, client_->received_bytes()); + } else { + client_->ExpectSendAlert(kTlsAlertRecordOverflow); + client_->ReadBytes(max_ctext + 1); + server_->ExpectReceiveAlert(kTlsAlertRecordOverflow); + server_->Handshake(); + } +} + +/* (D)TLS longest allowed record default size test. */ +TEST_P(TlsConnectGeneric, RecordSizeDefaultLong) { + EnsureTlsSetup(); + Connect(); + + // Maximum allowed plaintext size + size_t max = MAX_FRAGMENT_LENGTH; + + /* For TLS 1.0 the first byte of application data is sent in a single record + * as explained in the documentation of SSL_CBC_RANDOM_IV in ssl.h. + * Because of that we use TlsCTextResizer to send a record of max size. + * A bad record mac alert is expected since we modify the record. */ + if (version_ == SSL_LIBRARY_VERSION_TLS_1_0 && + variant_ == ssl_variant_stream) { + // Set size to maxi plaintext + max allowed expansion + MakeTlsFilter<TlsCtextResizer>(server_, max + MAX_EXPANSION); + // Dummy record will be overwritten + server_->SendData(0xF); + // Expect alert + client_->ExpectSendAlert(kTlsAlertBadRecordMac); + // Receive record + client_->ReadBytes(max); + // Handle alert on server side + server_->ExpectReceiveAlert(kTlsAlertBadRecordMac); + server_->Handshake(); + } else { // Everything but TLS 1.0 + // Send largest legal plaintext as single record + // by setting SendData() block size to max. + server_->SendData(max, max); + // Receive record + client_->ReadBytes(max); + // Assert that data was received successfully + ASSERT_EQ(client_->received_bytes(), max); + } +} + +/* (D)TLS longest allowed record size limit extension test. */ +TEST_P(TlsConnectGeneric, RecordSizeLimitLong) { + EnsureTlsSetup(); + + // Set some boundary + size_t max = 1000; + client_->SetOption(SSL_RECORD_SIZE_LIMIT, max); + + Connect(); + + // For TLS 1.3 the InnerContentType byte is included in the record size limit + if (version_ == SSL_LIBRARY_VERSION_TLS_1_3) { + max--; + } + + /* For TLS 1.0 the first byte of application data is sent in a single record + * as explained in the documentation of SSL_CBC_RANDOM_IV in ssl.h. + * Because of that we use TlsCTextResizer to send a record of max size. + * A bad record mac alert is expected since we modify the record. */ + if (version_ == SSL_LIBRARY_VERSION_TLS_1_0 && + variant_ == ssl_variant_stream) { + // Set size to maxi plaintext + max allowed expansion + MakeTlsFilter<TlsCtextResizer>(server_, max + MAX_EXPANSION); + // Dummy record will be overwritten + server_->SendData(0xF); + // Expect alert + client_->ExpectSendAlert(kTlsAlertBadRecordMac); + // Receive record + client_->ReadBytes(max); + // Handle alert on server side + server_->ExpectReceiveAlert(kTlsAlertBadRecordMac); + server_->Handshake(); + } else { // Everything but TLS 1.0 + // Send largest legal plaintext as single record + // by setting SendData() block size to max. + server_->SendData(max, max); + // Receive record + client_->ReadBytes(max); + // Assert that data was received successfully + ASSERT_EQ(client_->received_bytes(), max); + } +} + +} // namespace nss_test
\ No newline at end of file diff --git a/security/nss/gtests/ssl_gtest/ssl_renegotiation_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_renegotiation_unittest.cc new file mode 100644 index 0000000000..3f7074a096 --- /dev/null +++ b/security/nss/gtests/ssl_gtest/ssl_renegotiation_unittest.cc @@ -0,0 +1,235 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set ts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this file, + * You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#include <functional> +#include <memory> +#include "secerr.h" +#include "ssl.h" +#include "sslerr.h" +#include "sslproto.h" + +extern "C" { +// This is not something that should make you happy. +#include "libssl_internals.h" +} + +#include "gtest_utils.h" +#include "tls_connect.h" + +namespace nss_test { + +// 1.3 is disabled in the next few tests because we don't +// presently support resumption in 1.3. +TEST_P(TlsConnectStreamPre13, RenegotiateClient) { + Connect(); + server_->PrepareForRenegotiate(); + client_->StartRenegotiate(); + Handshake(); + CheckConnected(); +} + +TEST_P(TlsConnectStreamPre13, RenegotiateServer) { + Connect(); + client_->PrepareForRenegotiate(); + server_->StartRenegotiate(); + Handshake(); + CheckConnected(); +} + +TEST_P(TlsConnectStreamPre13, RenegotiateRandoms) { + SSL3Random crand1, crand2, srand1, srand2; + Connect(); + EXPECT_EQ(SECSuccess, + SSLInt_GetHandshakeRandoms(client_->ssl_fd(), crand1, srand1)); + + // Renegotiate and check that both randoms have changed. + client_->PrepareForRenegotiate(); + server_->StartRenegotiate(); + Handshake(); + CheckConnected(); + EXPECT_EQ(SECSuccess, + SSLInt_GetHandshakeRandoms(client_->ssl_fd(), crand2, srand2)); + + EXPECT_NE(0, memcmp(crand1, crand2, sizeof(SSL3Random))); + EXPECT_NE(0, memcmp(srand1, srand2, sizeof(SSL3Random))); +} + +// The renegotiation options shouldn't cause an error if TLS 1.3 is chosen. +TEST_F(TlsConnectTest, RenegotiationConfigTls13) { + EnsureTlsSetup(); + ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3); + server_->SetOption(SSL_ENABLE_RENEGOTIATION, SSL_RENEGOTIATE_UNRESTRICTED); + server_->SetOption(SSL_REQUIRE_SAFE_NEGOTIATION, PR_TRUE); + Connect(); + SendReceive(); + CheckKeys(); +} + +TEST_P(TlsConnectStream, ConnectTls10AndServerRenegotiateHigher) { + if (version_ == SSL_LIBRARY_VERSION_TLS_1_0) { + GTEST_SKIP(); + } + // Set the client so it will accept any version from 1.0 + // to |version_|. + client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_0, version_); + server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_0, + SSL_LIBRARY_VERSION_TLS_1_0); + // Reset version so that the checks succeed. + uint16_t test_version = version_; + version_ = SSL_LIBRARY_VERSION_TLS_1_0; + Connect(); + + // Now renegotiate, with the server being set to do + // |version_|. + client_->PrepareForRenegotiate(); + server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_0, test_version); + // Reset version and cipher suite so that the preinfo callback + // doesn't fail. + server_->ResetPreliminaryInfo(); + server_->StartRenegotiate(); + + if (test_version >= SSL_LIBRARY_VERSION_TLS_1_3) { + ExpectAlert(server_, kTlsAlertUnexpectedMessage); + } else { + ExpectAlert(server_, kTlsAlertProtocolVersion); + } + + Handshake(); + if (test_version >= SSL_LIBRARY_VERSION_TLS_1_3) { + // In TLS 1.3, the server detects this problem. + client_->CheckErrorCode(SSL_ERROR_HANDSHAKE_UNEXPECTED_ALERT); + server_->CheckErrorCode(SSL_ERROR_RENEGOTIATION_NOT_ALLOWED); + } else { + client_->CheckErrorCode(SSL_ERROR_PROTOCOL_VERSION_ALERT); + server_->CheckErrorCode(SSL_ERROR_UNSUPPORTED_VERSION); + } +} + +TEST_P(TlsConnectStream, ConnectTls10AndClientRenegotiateHigher) { + if (version_ == SSL_LIBRARY_VERSION_TLS_1_0) { + GTEST_SKIP(); + } + // Set the client so it will accept any version from 1.0 + // to |version_|. + client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_0, version_); + server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_0, + SSL_LIBRARY_VERSION_TLS_1_0); + // Reset version so that the checks succeed. + uint16_t test_version = version_; + version_ = SSL_LIBRARY_VERSION_TLS_1_0; + Connect(); + + // Now renegotiate, with the server being set to do + // |version_|. + server_->PrepareForRenegotiate(); + server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_0, test_version); + // Reset version and cipher suite so that the preinfo callback + // doesn't fail. + server_->ResetPreliminaryInfo(); + client_->StartRenegotiate(); + if (test_version >= SSL_LIBRARY_VERSION_TLS_1_3) { + ExpectAlert(server_, kTlsAlertUnexpectedMessage); + } else { + ExpectAlert(server_, kTlsAlertProtocolVersion); + } + Handshake(); + if (test_version >= SSL_LIBRARY_VERSION_TLS_1_3) { + // In TLS 1.3, the server detects this problem. + client_->CheckErrorCode(SSL_ERROR_HANDSHAKE_UNEXPECTED_ALERT); + server_->CheckErrorCode(SSL_ERROR_RENEGOTIATION_NOT_ALLOWED); + } else { + client_->CheckErrorCode(SSL_ERROR_PROTOCOL_VERSION_ALERT); + server_->CheckErrorCode(SSL_ERROR_UNSUPPORTED_VERSION); + } +} + +TEST_P(TlsConnectStream, ConnectAndServerRenegotiateLower) { + if (version_ == SSL_LIBRARY_VERSION_TLS_1_0) { + GTEST_SKIP(); + } + Connect(); + + // Now renegotiate with the server set to TLS 1.0. + client_->PrepareForRenegotiate(); + server_->PrepareForRenegotiate(); + client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_0, version_); + // Reset version and cipher suite so that the preinfo callback + // doesn't fail. + server_->ResetPreliminaryInfo(); + + SECStatus rv = SSL_ReHandshake(server_->ssl_fd(), PR_TRUE); + if (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) { + EXPECT_EQ(SECFailure, rv); + return; + } + ASSERT_EQ(SECSuccess, rv); + + // Now, before handshaking, tweak the server configuration. + server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_0, + SSL_LIBRARY_VERSION_TLS_1_0); + + // The server should catch the own error. + ExpectAlert(server_, kTlsAlertProtocolVersion); + + Handshake(); + client_->CheckErrorCode(SSL_ERROR_PROTOCOL_VERSION_ALERT); + server_->CheckErrorCode(SSL_ERROR_UNSUPPORTED_VERSION); +} + +TEST_P(TlsConnectStream, ConnectAndServerWontRenegotiateLower) { + if (version_ == SSL_LIBRARY_VERSION_TLS_1_0) { + GTEST_SKIP(); + } + Connect(); + + // Now renegotiate with the server set to TLS 1.0. + client_->PrepareForRenegotiate(); + server_->PrepareForRenegotiate(); + client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_0, version_); + server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_0, + SSL_LIBRARY_VERSION_TLS_1_0); + // Reset version and cipher suite so that the preinfo callback + // doesn't fail. + server_->ResetPreliminaryInfo(); + + EXPECT_EQ(SECFailure, SSL_ReHandshake(server_->ssl_fd(), PR_TRUE)); +} + +TEST_P(TlsConnectStream, ConnectAndClientWontRenegotiateLower) { + if (version_ == SSL_LIBRARY_VERSION_TLS_1_0) { + GTEST_SKIP(); + } + Connect(); + + // Now renegotiate with the client set to TLS 1.0. + client_->PrepareForRenegotiate(); + server_->PrepareForRenegotiate(); + server_->ResetPreliminaryInfo(); + client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_0, + SSL_LIBRARY_VERSION_TLS_1_0); + // The client will refuse to renegotiate down. + EXPECT_EQ(SECFailure, SSL_ReHandshake(client_->ssl_fd(), PR_TRUE)); +} + +TEST_F(TlsConnectTest, Tls13RejectsRehandshakeClient) { + EnsureTlsSetup(); + ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3); + Connect(); + SECStatus rv = SSL_ReHandshake(client_->ssl_fd(), PR_TRUE); + EXPECT_EQ(SECFailure, rv); + EXPECT_EQ(SSL_ERROR_RENEGOTIATION_NOT_ALLOWED, PORT_GetError()); +} + +TEST_F(TlsConnectTest, Tls13RejectsRehandshakeServer) { + EnsureTlsSetup(); + ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3); + Connect(); + SECStatus rv = SSL_ReHandshake(server_->ssl_fd(), PR_TRUE); + EXPECT_EQ(SECFailure, rv); + EXPECT_EQ(SSL_ERROR_RENEGOTIATION_NOT_ALLOWED, PORT_GetError()); +} + +} // namespace nss_test diff --git a/security/nss/gtests/ssl_gtest/ssl_resumption_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_resumption_unittest.cc new file mode 100644 index 0000000000..2e23fc096a --- /dev/null +++ b/security/nss/gtests/ssl_gtest/ssl_resumption_unittest.cc @@ -0,0 +1,1522 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set ts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this file, + * You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#include <functional> +#include <memory> +#include "secerr.h" +#include "ssl.h" +#include "sslerr.h" +#include "sslexp.h" +#include "sslproto.h" + +extern "C" { +// This is not something that should make you happy. +#include "libssl_internals.h" +} + +#include "gtest_utils.h" +#include "nss_scoped_ptrs.h" +#include "scoped_ptrs_ssl.h" +#include "tls_connect.h" +#include "tls_filter.h" +#include "tls_parser.h" +#include "tls_protect.h" + +namespace nss_test { + +class TlsServerKeyExchangeEcdhe { + public: + bool Parse(const DataBuffer& buffer) { + TlsParser parser(buffer); + + uint8_t curve_type; + if (!parser.Read(&curve_type)) { + return false; + } + + if (curve_type != 3) { // named_curve + return false; + } + + uint32_t named_curve; + if (!parser.Read(&named_curve, 2)) { + return false; + } + + return parser.ReadVariable(&public_key_, 1); + } + + DataBuffer public_key_; +}; + +TEST_P(TlsConnectGenericPre13, ConnectResumed) { + ConfigureSessionCache(RESUME_SESSIONID, RESUME_SESSIONID); + Connect(); + + Reset(); + ExpectResumption(RESUME_SESSIONID); + Connect(); +} + +TEST_P(TlsConnectGenericResumption, ConnectClientCacheDisabled) { + ConfigureSessionCache(RESUME_NONE, RESUME_SESSIONID); + Connect(); + SendReceive(); + + Reset(); + ExpectResumption(RESUME_NONE); + Connect(); + SendReceive(); +} + +TEST_P(TlsConnectGenericResumption, ConnectServerCacheDisabled) { + ConfigureSessionCache(RESUME_SESSIONID, RESUME_NONE); + Connect(); + SendReceive(); + + Reset(); + ExpectResumption(RESUME_NONE); + Connect(); + SendReceive(); +} + +TEST_P(TlsConnectGenericResumption, ConnectSessionCacheDisabled) { + ConfigureSessionCache(RESUME_NONE, RESUME_NONE); + Connect(); + SendReceive(); + + Reset(); + ExpectResumption(RESUME_NONE); + Connect(); + SendReceive(); +} + +TEST_P(TlsConnectGenericResumption, ConnectResumeSupportBoth) { + // This prefers tickets. + ConfigureSessionCache(RESUME_BOTH, RESUME_BOTH); + Connect(); + SendReceive(); + + Reset(); + ConfigureSessionCache(RESUME_BOTH, RESUME_BOTH); + ExpectResumption(RESUME_TICKET); + Connect(); + SendReceive(); +} + +TEST_P(TlsConnectGenericResumption, ConnectResumeClientTicketServerBoth) { + // This causes no resumption because the client needs the + // session cache to resume even with tickets. + ConfigureSessionCache(RESUME_TICKET, RESUME_BOTH); + Connect(); + SendReceive(); + + Reset(); + ConfigureSessionCache(RESUME_TICKET, RESUME_BOTH); + ExpectResumption(RESUME_NONE); + Connect(); + SendReceive(); +} + +TEST_P(TlsConnectGenericResumption, ConnectResumeClientBothTicketServerTicket) { + // This causes a ticket resumption. + ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET); + Connect(); + SendReceive(); + + Reset(); + ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET); + ExpectResumption(RESUME_TICKET); + Connect(); + SendReceive(); +} + +TEST_P(TlsConnectGenericResumption, ConnectResumeClientServerTicketOnly) { + // This causes no resumption because the client needs the + // session cache to resume even with tickets. + ConfigureSessionCache(RESUME_TICKET, RESUME_TICKET); + Connect(); + SendReceive(); + + Reset(); + ConfigureSessionCache(RESUME_TICKET, RESUME_TICKET); + ExpectResumption(RESUME_NONE); + Connect(); + SendReceive(); +} + +TEST_P(TlsConnectGenericResumption, ConnectResumeClientBothServerNone) { + ConfigureSessionCache(RESUME_BOTH, RESUME_NONE); + Connect(); + SendReceive(); + + Reset(); + ConfigureSessionCache(RESUME_BOTH, RESUME_NONE); + ExpectResumption(RESUME_NONE); + Connect(); + SendReceive(); +} + +TEST_P(TlsConnectGenericResumption, ConnectResumeClientNoneServerBoth) { + ConfigureSessionCache(RESUME_NONE, RESUME_BOTH); + Connect(); + SendReceive(); + + Reset(); + ConfigureSessionCache(RESUME_NONE, RESUME_BOTH); + ExpectResumption(RESUME_NONE); + Connect(); + SendReceive(); +} + +TEST_P(TlsConnectGenericPre13, ResumeWithHigherVersionTls13) { + uint16_t lower_version = version_; + ConfigureSessionCache(RESUME_BOTH, RESUME_BOTH); + Connect(); + SendReceive(); + CheckKeys(); + + Reset(); + ConfigureSessionCache(RESUME_BOTH, RESUME_BOTH); + EnsureTlsSetup(); + auto psk_ext = std::make_shared<TlsExtensionCapture>( + client_, ssl_tls13_pre_shared_key_xtn); + auto ticket_ext = + std::make_shared<TlsExtensionCapture>(client_, ssl_session_ticket_xtn); + client_->SetFilter(std::make_shared<ChainedPacketFilter>( + ChainedPacketFilterInit({psk_ext, ticket_ext}))); + SetExpectedVersion(SSL_LIBRARY_VERSION_TLS_1_3); + client_->SetVersionRange(lower_version, SSL_LIBRARY_VERSION_TLS_1_3); + server_->SetVersionRange(lower_version, SSL_LIBRARY_VERSION_TLS_1_3); + ExpectResumption(RESUME_NONE); + Connect(); + + // The client shouldn't have sent a PSK, though it will send a ticket. + EXPECT_FALSE(psk_ext->captured()); + EXPECT_TRUE(ticket_ext->captured()); +} + +class CaptureSessionId : public TlsHandshakeFilter { + public: + CaptureSessionId(const std::shared_ptr<TlsAgent>& a) + : TlsHandshakeFilter( + a, {kTlsHandshakeClientHello, kTlsHandshakeServerHello}), + sid_() {} + + const DataBuffer& sid() const { return sid_; } + + protected: + PacketFilter::Action FilterHandshake(const HandshakeHeader& header, + const DataBuffer& input, + DataBuffer* output) override { + // The session_id is in the same place in both Hello messages: + size_t offset = 2 + 32; // Version(2) + Random(32) + uint32_t len = 0; + EXPECT_TRUE(input.Read(offset, 1, &len)); + offset++; + if (input.len() < offset + len) { + ADD_FAILURE() << "session_id overflows the Hello message"; + return KEEP; + } + sid_.Assign(input.data() + offset, len); + return KEEP; + } + + private: + DataBuffer sid_; +}; + +// Attempting to resume from TLS 1.2 when 1.3 is possible should not result in +// resumption, though it will appear to be TLS 1.3 compatibility mode if the +// server uses a session ID. +TEST_P(TlsConnectGenericPre13, ResumeWithHigherVersionTls13SessionId) { + uint16_t lower_version = version_; + ConfigureSessionCache(RESUME_SESSIONID, RESUME_SESSIONID); + auto original_sid = MakeTlsFilter<CaptureSessionId>(server_); + Connect(); + CheckKeys(); + EXPECT_EQ(32U, original_sid->sid().len()); + + // The client should now attempt to resume with the session ID from the last + // connection. This looks like compatibility mode, we just want to ensure + // that we get TLS 1.3 rather than 1.2 (and no resumption). + Reset(); + auto client_sid = MakeTlsFilter<CaptureSessionId>(client_); + auto server_sid = MakeTlsFilter<CaptureSessionId>(server_); + ConfigureSessionCache(RESUME_SESSIONID, RESUME_SESSIONID); + SetExpectedVersion(SSL_LIBRARY_VERSION_TLS_1_3); + client_->SetVersionRange(lower_version, SSL_LIBRARY_VERSION_TLS_1_3); + server_->SetVersionRange(lower_version, SSL_LIBRARY_VERSION_TLS_1_3); + ExpectResumption(RESUME_NONE); + + Connect(); + SendReceive(); + + EXPECT_EQ(client_sid->sid(), original_sid->sid()); + if (variant_ == ssl_variant_stream) { + EXPECT_EQ(client_sid->sid(), server_sid->sid()); + } else { + // DTLS servers don't echo the session ID. + EXPECT_EQ(0U, server_sid->sid().len()); + } +} + +TEST_P(TlsConnectPre12, ResumeWithHigherVersionTls12) { + uint16_t lower_version = version_; + ConfigureSessionCache(RESUME_BOTH, RESUME_BOTH); + Connect(); + + Reset(); + ConfigureSessionCache(RESUME_BOTH, RESUME_BOTH); + EnsureTlsSetup(); + SetExpectedVersion(SSL_LIBRARY_VERSION_TLS_1_3); + client_->SetVersionRange(lower_version, SSL_LIBRARY_VERSION_TLS_1_3); + server_->SetVersionRange(lower_version, SSL_LIBRARY_VERSION_TLS_1_3); + ExpectResumption(RESUME_NONE); + Connect(); +} + +TEST_P(TlsConnectGenericPre13, ResumeWithLowerVersionFromTls13) { + uint16_t original_version = version_; + ConfigureSessionCache(RESUME_BOTH, RESUME_BOTH); + ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3); + Connect(); + SendReceive(); + CheckKeys(); + + Reset(); + ConfigureSessionCache(RESUME_BOTH, RESUME_BOTH); + ConfigureVersion(original_version); + ExpectResumption(RESUME_NONE); + Connect(); + SendReceive(); +} + +TEST_P(TlsConnectPre12, ResumeWithLowerVersionFromTls12) { + uint16_t original_version = version_; + ConfigureSessionCache(RESUME_BOTH, RESUME_BOTH); + ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_2); + Connect(); + SendReceive(); + CheckKeys(); + + Reset(); + ConfigureSessionCache(RESUME_BOTH, RESUME_BOTH); + ConfigureVersion(original_version); + ExpectResumption(RESUME_NONE); + Connect(); + SendReceive(); +} + +TEST_P(TlsConnectGeneric, ConnectResumeClientBothTicketServerTicketForget) { + // This causes a ticket resumption. + ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET); + Connect(); + SendReceive(); + + Reset(); + ClearServerCache(); + ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET); + ExpectResumption(RESUME_NONE); + Connect(); + SendReceive(); +} + +// Tickets last two days maximum; this is a time longer than that. +static const PRTime kLongerThanTicketLifetime = + 3LL * 24 * 60 * 60 * PR_USEC_PER_SEC; + +TEST_P(TlsConnectGenericResumption, ConnectWithExpiredTicketAtClient) { + // This causes a ticket resumption. + ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET); + Connect(); + SendReceive(); + + AdvanceTime(kLongerThanTicketLifetime); + + Reset(); + ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET); + ExpectResumption(RESUME_NONE); + + // TLS 1.3 uses the pre-shared key extension instead. + SSLExtensionType xtn = (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) + ? ssl_tls13_pre_shared_key_xtn + : ssl_session_ticket_xtn; + auto capture = MakeTlsFilter<TlsExtensionCapture>(client_, xtn); + Connect(); + + if (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) { + EXPECT_FALSE(capture->captured()); + } else { + EXPECT_TRUE(capture->captured()); + EXPECT_EQ(0U, capture->extension().len()); + } +} + +TEST_P(TlsConnectGeneric, ConnectWithExpiredTicketAtServer) { + // This causes a ticket resumption. + ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET); + Connect(); + SendReceive(); + + Reset(); + ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET); + ExpectResumption(RESUME_NONE); + + SSLExtensionType xtn = (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) + ? ssl_tls13_pre_shared_key_xtn + : ssl_session_ticket_xtn; + auto capture = MakeTlsFilter<TlsExtensionCapture>(client_, xtn); + StartConnect(); + client_->Handshake(); + EXPECT_TRUE(capture->captured()); + EXPECT_LT(0U, capture->extension().len()); + + AdvanceTime(kLongerThanTicketLifetime); + + Handshake(); + CheckConnected(); +} + +TEST_P(TlsConnectGeneric, ConnectResumeCorruptTicket) { + // This causes a ticket resumption. + ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET); + Connect(); + SendReceive(); + + Reset(); + static const uint8_t kHmacKey1Buf[32] = {0}; + static const DataBuffer kHmacKey1(kHmacKey1Buf, sizeof(kHmacKey1Buf)); + + SECItem key_item = {siBuffer, const_cast<uint8_t*>(kHmacKey1Buf), + sizeof(kHmacKey1Buf)}; + + ScopedPK11SlotInfo slot(PK11_GetInternalSlot()); + PK11SymKey* hmac_key = + PK11_ImportSymKey(slot.get(), CKM_SHA256_HMAC, PK11_OriginUnwrap, + CKA_SIGN, &key_item, nullptr); + ASSERT_NE(nullptr, hmac_key); + SSLInt_SetSelfEncryptMacKey(hmac_key); + ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET); + if (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) { + ExpectResumption(RESUME_NONE); + Connect(); + } else { + ConnectExpectAlert(server_, illegal_parameter); + server_->CheckErrorCode(SSL_ERROR_RX_MALFORMED_CLIENT_HELLO); + } +} + +// This callback switches out the "server" cert used on the server with +// the "client" certificate, which should be the same type. +static int32_t SwitchCertificates(TlsAgent* agent, const SECItem* srvNameArr, + uint32_t srvNameArrSize) { + bool ok = agent->ConfigServerCert("client"); + if (!ok) return SSL_SNI_SEND_ALERT; + + return 0; // first config +}; + +TEST_P(TlsConnectGeneric, ServerSNICertSwitch) { + Connect(); + ScopedCERTCertificate cert1(SSL_PeerCertificate(client_->ssl_fd())); + ASSERT_NE(nullptr, cert1.get()); + + Reset(); + ConfigureSessionCache(RESUME_NONE, RESUME_NONE); + + server_->SetSniCallback(SwitchCertificates); + + Connect(); + ScopedCERTCertificate cert2(SSL_PeerCertificate(client_->ssl_fd())); + ASSERT_NE(nullptr, cert2.get()); + CheckKeys(); + EXPECT_FALSE(SECITEM_ItemsAreEqual(&cert1->derCert, &cert2->derCert)); +} + +TEST_P(TlsConnectGeneric, ServerSNICertTypeSwitch) { + Reset(TlsAgent::kServerEcdsa256); + Connect(); + ScopedCERTCertificate cert1(SSL_PeerCertificate(client_->ssl_fd())); + ASSERT_NE(nullptr, cert1.get()); + + Reset(); + ConfigureSessionCache(RESUME_NONE, RESUME_NONE); + + // Because we configure an RSA certificate here, it only adds a second, unused + // certificate, which has no effect on what the server uses. + server_->SetSniCallback(SwitchCertificates); + + Connect(); + ScopedCERTCertificate cert2(SSL_PeerCertificate(client_->ssl_fd())); + ASSERT_NE(nullptr, cert2.get()); + CheckKeys(ssl_kea_ecdh, ssl_auth_ecdsa); + EXPECT_TRUE(SECITEM_ItemsAreEqual(&cert1->derCert, &cert2->derCert)); +} + +TEST_P(TlsConnectGenericPre13, ConnectEcdheTwiceReuseKey) { + auto filter = MakeTlsFilter<TlsHandshakeRecorder>( + server_, kTlsHandshakeServerKeyExchange); + EnableECDHEServerKeyReuse(); + Connect(); + CheckKeys(); + TlsServerKeyExchangeEcdhe dhe1; + EXPECT_TRUE(dhe1.Parse(filter->buffer())); + + // Restart + Reset(); + EnableECDHEServerKeyReuse(); + auto filter2 = MakeTlsFilter<TlsHandshakeRecorder>( + server_, kTlsHandshakeServerKeyExchange); + ConfigureSessionCache(RESUME_NONE, RESUME_NONE); + Connect(); + CheckKeys(); + + TlsServerKeyExchangeEcdhe dhe2; + EXPECT_TRUE(dhe2.Parse(filter2->buffer())); + + // Make sure they are the same. + EXPECT_EQ(dhe1.public_key_.len(), dhe2.public_key_.len()); + EXPECT_TRUE(!memcmp(dhe1.public_key_.data(), dhe2.public_key_.data(), + dhe1.public_key_.len())); +} + +// This test parses the ServerKeyExchange, which isn't in 1.3 +TEST_P(TlsConnectGenericPre13, ConnectEcdheTwiceNewKey) { + auto filter = MakeTlsFilter<TlsHandshakeRecorder>( + server_, kTlsHandshakeServerKeyExchange); + Connect(); + CheckKeys(); + TlsServerKeyExchangeEcdhe dhe1; + EXPECT_TRUE(dhe1.Parse(filter->buffer())); + + // Restart + Reset(); + auto filter2 = MakeTlsFilter<TlsHandshakeRecorder>( + server_, kTlsHandshakeServerKeyExchange); + ConfigureSessionCache(RESUME_NONE, RESUME_NONE); + Connect(); + CheckKeys(); + + TlsServerKeyExchangeEcdhe dhe2; + EXPECT_TRUE(dhe2.Parse(filter2->buffer())); + + // Make sure they are different. + EXPECT_FALSE((dhe1.public_key_.len() == dhe2.public_key_.len()) && + (!memcmp(dhe1.public_key_.data(), dhe2.public_key_.data(), + dhe1.public_key_.len()))); +} + +// Verify that TLS 1.3 reports an accurate group on resumption. +TEST_P(TlsConnectTls13, TestTls13ResumeDifferentGroup) { + ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET); + Connect(); + SendReceive(); // Need to read so that we absorb the session ticket. + CheckKeys(); + + Reset(); + ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET); + ExpectResumption(RESUME_TICKET); + client_->ConfigNamedGroups(kFFDHEGroups); + server_->ConfigNamedGroups(kFFDHEGroups); + Connect(); + CheckKeys(ssl_kea_dh, ssl_grp_ffdhe_2048, ssl_auth_rsa_sign, + ssl_sig_rsa_pss_rsae_sha256); +} + +// Verify that TLS 1.3 server doesn't request certificate in the main +// handshake, after resumption. +TEST_P(TlsConnectTls13, TestTls13ResumeNoCertificateRequest) { + ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET); + client_->SetupClientAuth(); + server_->RequestClientAuth(true); + Connect(); + SendReceive(); // Need to read so that we absorb the session ticket. + ScopedCERTCertificate cert1(SSL_LocalCertificate(client_->ssl_fd())); + ASSERT_NE(nullptr, cert1.get()); + + Reset(); + ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET); + ExpectResumption(RESUME_TICKET); + server_->RequestClientAuth(false); + auto cr_capture = + MakeTlsFilter<TlsHandshakeRecorder>(server_, ssl_hs_certificate_request); + cr_capture->EnableDecryption(); + Connect(); + SendReceive(); + EXPECT_EQ(0U, cr_capture->buffer().len()) << "expect nothing captured yet"; + + // Sanity check whether the client certificate matches the one + // decrypted from ticket. + ScopedCERTCertificate cert2(SSL_PeerCertificate(server_->ssl_fd())); + ASSERT_NE(nullptr, cert2.get()); + EXPECT_TRUE(SECITEM_ItemsAreEqual(&cert1->derCert, &cert2->derCert)); +} + +// Here we test that 0.5 RTT is available at the server when resuming, even if +// configured to request a client certificate. The resumed handshake relies on +// the authentication from the original handshake, so no certificate is +// requested this time around. The server can write before the handshake +// completes because the PSK binder is sufficient authentication for the client. +TEST_P(TlsConnectTls13, WriteBeforeHandshakeCompleteOnResumption) { + ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET); + client_->SetupClientAuth(); + server_->RequestClientAuth(true); + Connect(); + SendReceive(); // Absorb the session ticket. + ScopedCERTCertificate cert1(SSL_LocalCertificate(client_->ssl_fd())); + ASSERT_NE(nullptr, cert1.get()); + + Reset(); + ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET); + ExpectResumption(RESUME_TICKET); + server_->RequestClientAuth(false); + StartConnect(); + client_->Handshake(); // ClientHello + server_->Handshake(); // ServerHello + + server_->SendData(10); + client_->ReadBytes(10); // Client should emit the Finished as a side-effect. + server_->Handshake(); // Server consumes the Finished. + CheckConnected(); + + // Check whether the client certificate matches the one from the ticket. + ScopedCERTCertificate cert2(SSL_PeerCertificate(server_->ssl_fd())); + ASSERT_NE(nullptr, cert2.get()); + EXPECT_TRUE(SECITEM_ItemsAreEqual(&cert1->derCert, &cert2->derCert)); +} + +// We need to enable different cipher suites at different times in the following +// tests. Those cipher suites need to be suited to the version. +static uint16_t ChooseOneCipher(uint16_t version) { + if (version >= SSL_LIBRARY_VERSION_TLS_1_3) { + return TLS_AES_128_GCM_SHA256; + } + return TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA; +} + +static uint16_t ChooseIncompatibleCipher(uint16_t version) { + if (version >= SSL_LIBRARY_VERSION_TLS_1_3) { + return TLS_AES_256_GCM_SHA384; + } + return TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA; +} + +// Test that we don't resume when we can't negotiate the same cipher. Note that +// for TLS 1.3, resumption is allowed between compatible ciphers, that is those +// with the same KDF hash, but we choose an incompatible one here. +TEST_P(TlsConnectGenericResumption, ResumeClientIncompatibleCipher) { + ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET); + client_->EnableSingleCipher(ChooseOneCipher(version_)); + Connect(); + SendReceive(); + CheckKeys(ssl_kea_ecdh, ssl_auth_rsa_sign); + + Reset(); + ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET); + ExpectResumption(RESUME_NONE); + client_->EnableSingleCipher(ChooseIncompatibleCipher(version_)); + uint16_t ticket_extension; + if (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) { + ticket_extension = ssl_tls13_pre_shared_key_xtn; + } else { + ticket_extension = ssl_session_ticket_xtn; + } + auto ticket_capture = + MakeTlsFilter<TlsExtensionCapture>(client_, ticket_extension); + Connect(); + CheckKeys(ssl_kea_ecdh, ssl_auth_rsa_sign); + EXPECT_EQ(0U, ticket_capture->extension().len()); +} + +// Test that we don't resume when we can't negotiate the same cipher. +TEST_P(TlsConnectGenericResumption, ResumeServerIncompatibleCipher) { + ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET); + server_->EnableSingleCipher(ChooseOneCipher(version_)); + Connect(); + SendReceive(); // Absorb the session ticket. + CheckKeys(); + + Reset(); + ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET); + ExpectResumption(RESUME_NONE); + server_->EnableSingleCipher(ChooseIncompatibleCipher(version_)); + Connect(); + CheckKeys(); +} + +// Test that the client doesn't tolerate the server picking a different cipher +// suite for resumption. +TEST_P(TlsConnectStream, ResumptionOverrideCipher) { + ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET); + server_->EnableSingleCipher(ChooseOneCipher(version_)); + Connect(); + SendReceive(); + CheckKeys(ssl_kea_ecdh, ssl_auth_rsa_sign); + + Reset(); + ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET); + MakeTlsFilter<SelectedCipherSuiteReplacer>( + server_, ChooseIncompatibleCipher(version_)); + + if (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) { + client_->ExpectSendAlert(kTlsAlertIllegalParameter); + server_->ExpectSendAlert(kTlsAlertUnexpectedMessage); + } else { + ExpectAlert(client_, kTlsAlertHandshakeFailure); + } + ConnectExpectFail(); + client_->CheckErrorCode(SSL_ERROR_RX_MALFORMED_SERVER_HELLO); + if (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) { + // The reason this test is stream only: the server is unable to decrypt + // the alert that the client sends, see bug 1304603. + server_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_RECORD_TYPE); + } else { + server_->CheckErrorCode(SSL_ERROR_HANDSHAKE_FAILURE_ALERT); + } +} + +// In TLS 1.3, it is possible to resume with a different cipher if it has the +// same hash. +TEST_P(TlsConnectTls13, ResumeClientCompatibleCipher) { + ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET); + client_->EnableSingleCipher(TLS_AES_128_GCM_SHA256); + Connect(); + SendReceive(); // Absorb the session ticket. + CheckKeys(); + + Reset(); + ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET); + ExpectResumption(RESUME_TICKET); + client_->EnableSingleCipher(TLS_CHACHA20_POLY1305_SHA256); + Connect(); + CheckKeys(); +} + +TEST_P(TlsConnectTls13, ResumeServerCompatibleCipher) { + ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET); + server_->EnableSingleCipher(TLS_AES_128_GCM_SHA256); + Connect(); + SendReceive(); // Absorb the session ticket. + CheckKeys(); + + Reset(); + ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET); + ExpectResumption(RESUME_TICKET); + server_->EnableSingleCipher(TLS_CHACHA20_POLY1305_SHA256); + Connect(); + CheckKeys(); +} + +class SelectedVersionReplacer : public TlsHandshakeFilter { + public: + SelectedVersionReplacer(const std::shared_ptr<TlsAgent>& a, uint16_t version) + : TlsHandshakeFilter(a, {kTlsHandshakeServerHello}), version_(version) {} + + protected: + PacketFilter::Action FilterHandshake(const HandshakeHeader& header, + const DataBuffer& input, + DataBuffer* output) override { + *output = input; + output->Write(0, static_cast<uint32_t>(version_), 2); + return CHANGE; + } + + private: + uint16_t version_; +}; + +// Test how the client handles the case where the server picks a +// lower version number on resumption. +TEST_P(TlsConnectGenericPre13, TestResumptionOverrideVersion) { + uint16_t override_version = 0; + if (variant_ == ssl_variant_stream) { + switch (version_) { + case SSL_LIBRARY_VERSION_TLS_1_0: + GTEST_SKIP(); + case SSL_LIBRARY_VERSION_TLS_1_1: + override_version = SSL_LIBRARY_VERSION_TLS_1_0; + break; + case SSL_LIBRARY_VERSION_TLS_1_2: + override_version = SSL_LIBRARY_VERSION_TLS_1_1; + break; + default: + ASSERT_TRUE(false) << "unknown version"; + } + } else { + if (version_ == SSL_LIBRARY_VERSION_TLS_1_2) { + override_version = SSL_LIBRARY_VERSION_DTLS_1_0_WIRE; + } else { + ASSERT_EQ(SSL_LIBRARY_VERSION_TLS_1_1, version_); + GTEST_SKIP(); + } + } + + ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET); + // Need to use a cipher that is plausible for the lower version. + server_->EnableSingleCipher(TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA); + Connect(); + CheckKeys(ssl_kea_ecdh, ssl_auth_rsa_sign); + + Reset(); + ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET); + // Enable the lower version on the client. + client_->SetVersionRange(version_ - 1, version_); + server_->EnableSingleCipher(TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA); + MakeTlsFilter<SelectedVersionReplacer>(server_, override_version); + + ConnectExpectAlert(client_, kTlsAlertHandshakeFailure); + client_->CheckErrorCode(SSL_ERROR_RX_MALFORMED_SERVER_HELLO); + server_->CheckErrorCode(SSL_ERROR_HANDSHAKE_FAILURE_ALERT); +} + +// Test that two TLS resumptions work and produce the same ticket. +// This will change after bug 1257047 is fixed. +TEST_F(TlsConnectTest, TestTls13ResumptionTwice) { + ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET); + ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3); + + Connect(); + SendReceive(); // Need to read so that we absorb the session ticket. + CheckKeys(); + uint16_t original_suite; + EXPECT_TRUE(client_->cipher_suite(&original_suite)); + + Reset(); + ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET); + ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3); + ExpectResumption(RESUME_TICKET); + auto c1 = + MakeTlsFilter<TlsExtensionCapture>(client_, ssl_tls13_pre_shared_key_xtn); + Connect(); + SendReceive(); + CheckKeys(ssl_kea_ecdh, ssl_grp_ec_curve25519, ssl_auth_rsa_sign, + ssl_sig_rsa_pss_rsae_sha256); + // The filter will go away when we reset, so save the captured extension. + DataBuffer initialTicket(c1->extension()); + ASSERT_LT(0U, initialTicket.len()); + + ScopedCERTCertificate cert1(SSL_PeerCertificate(client_->ssl_fd())); + ASSERT_NE(nullptr, cert1.get()); + + Reset(); + ClearStats(); + ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET); + ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3); + auto c2 = + MakeTlsFilter<TlsExtensionCapture>(client_, ssl_tls13_pre_shared_key_xtn); + ExpectResumption(RESUME_TICKET); + Connect(); + SendReceive(); + CheckKeys(ssl_kea_ecdh, ssl_grp_ec_curve25519, ssl_auth_rsa_sign, + ssl_sig_rsa_pss_rsae_sha256); + ASSERT_LT(0U, c2->extension().len()); + + ScopedCERTCertificate cert2(SSL_PeerCertificate(client_->ssl_fd())); + ASSERT_NE(nullptr, cert2.get()); + + // Check that the cipher suite is reported the same on both sides, though in + // TLS 1.3 resumption actually negotiates a different cipher suite. + uint16_t resumed_suite; + EXPECT_TRUE(server_->cipher_suite(&resumed_suite)); + EXPECT_EQ(original_suite, resumed_suite); + EXPECT_TRUE(client_->cipher_suite(&resumed_suite)); + EXPECT_EQ(original_suite, resumed_suite); + + ASSERT_NE(initialTicket, c2->extension()); +} + +// Check that resumption works after receiving two NST messages. +TEST_F(TlsConnectTest, TestTls13ResumptionDuplicateNST) { + ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET); + ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3); + Connect(); + + // Clear the session ticket keys to invalidate the old ticket. + ClearServerCache(); + EXPECT_EQ(SECSuccess, SSL_SendSessionTicket(server_->ssl_fd(), NULL, 0)); + + SendReceive(); // Need to read so that we absorb the session tickets. + CheckKeys(); + + // Resume the connection. + Reset(); + ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET); + ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3); + ExpectResumption(RESUME_TICKET); + Connect(); + SendReceive(); +} + +// Check that the value captured in a NewSessionTicket message matches the value +// captured from a pre_shared_key extension. +void NstTicketMatchesPskIdentity(const DataBuffer& nst, const DataBuffer& psk) { + uint32_t len; + + size_t offset = 4 + 4; // Skip ticket_lifetime and ticket_age_add. + ASSERT_TRUE(nst.Read(offset, 1, &len)); + offset += 1 + len; // Skip ticket_nonce. + + ASSERT_TRUE(nst.Read(offset, 2, &len)); + offset += 2; // Skip the ticket length. + ASSERT_LE(offset + len, nst.len()); + DataBuffer nst_ticket(nst.data() + offset, static_cast<size_t>(len)); + + offset = 2; // Skip the identities length. + ASSERT_TRUE(psk.Read(offset, 2, &len)); + offset += 2; // Skip the identity length. + ASSERT_LE(offset + len, psk.len()); + DataBuffer psk_ticket(psk.data() + offset, static_cast<size_t>(len)); + + EXPECT_EQ(nst_ticket, psk_ticket); +} + +TEST_F(TlsConnectTest, TestTls13ResumptionDuplicateNSTWithToken) { + ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET); + ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3); + + auto nst_capture = + MakeTlsFilter<TlsHandshakeRecorder>(server_, ssl_hs_new_session_ticket); + nst_capture->EnableDecryption(); + Connect(); + + // Clear the session ticket keys to invalidate the old ticket. + ClearServerCache(); + nst_capture->Reset(); + uint8_t token[] = {0x20, 0x20, 0xff, 0x00}; + EXPECT_EQ(SECSuccess, + SSL_SendSessionTicket(server_->ssl_fd(), token, sizeof(token))); + + SendReceive(); // Need to read so that we absorb the session tickets. + CheckKeys(); + EXPECT_LT(0U, nst_capture->buffer().len()); + + // Resume the connection. + Reset(); + ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET); + ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3); + ExpectResumption(RESUME_TICKET); + + auto psk_capture = + MakeTlsFilter<TlsExtensionCapture>(client_, ssl_tls13_pre_shared_key_xtn); + Connect(); + SendReceive(); + + NstTicketMatchesPskIdentity(nst_capture->buffer(), psk_capture->extension()); +} + +// Disable SSL_ENABLE_SESSION_TICKETS but ensure that tickets can still be sent +// by invoking SSL_SendSessionTicket directly (and that the ticket is usable). +TEST_F(TlsConnectTest, SendSessionTicketWithTicketsDisabled) { + ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET); + ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3); + + server_->SetOption(SSL_ENABLE_SESSION_TICKETS, PR_FALSE); + + auto nst_capture = + MakeTlsFilter<TlsHandshakeRecorder>(server_, ssl_hs_new_session_ticket); + nst_capture->EnableDecryption(); + Connect(); + + EXPECT_EQ(0U, nst_capture->buffer().len()) << "expect nothing captured yet"; + + EXPECT_EQ(SECSuccess, SSL_SendSessionTicket(server_->ssl_fd(), NULL, 0)); + EXPECT_LT(0U, nst_capture->buffer().len()) << "should capture now"; + + SendReceive(); // Ensure that the client reads the ticket. + + // Resume the connection. + Reset(); + ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET); + ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3); + ExpectResumption(RESUME_TICKET); + + auto psk_capture = + MakeTlsFilter<TlsExtensionCapture>(client_, ssl_tls13_pre_shared_key_xtn); + Connect(); + SendReceive(); + + NstTicketMatchesPskIdentity(nst_capture->buffer(), psk_capture->extension()); +} + +// Successfully send a session ticket after resuming and then use it. +TEST_F(TlsConnectTest, SendTicketAfterResumption) { + ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET); + ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3); + Connect(); + + SendReceive(); // Need to read so that we absorb the session tickets. + CheckKeys(); + + // Resume the connection. + Reset(); + ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET); + ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3); + ExpectResumption(RESUME_TICKET); + + // We need to capture just one ticket, so + // disable automatic sending of tickets at the server. + // ConfigureSessionCache enables this option, so revert that. + server_->SetOption(SSL_ENABLE_SESSION_TICKETS, PR_FALSE); + auto nst_capture = + MakeTlsFilter<TlsHandshakeRecorder>(server_, ssl_hs_new_session_ticket); + nst_capture->EnableDecryption(); + Connect(); + + ClearServerCache(); + EXPECT_EQ(SECSuccess, SSL_SendSessionTicket(server_->ssl_fd(), NULL, 0)); + SendReceive(); + + // Reset stats so that the counters for resumptions match up. + ClearStats(); + // Resume again and ensure that we get the same ticket. + Reset(); + ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET); + ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3); + ExpectResumption(RESUME_TICKET); + + auto psk_capture = + MakeTlsFilter<TlsExtensionCapture>(client_, ssl_tls13_pre_shared_key_xtn); + Connect(); + SendReceive(); + + NstTicketMatchesPskIdentity(nst_capture->buffer(), psk_capture->extension()); +} + +// Test calling SSL_SendSessionTicket in inappropriate conditions. +TEST_F(TlsConnectTest, SendSessionTicketInappropriate) { + ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET); + ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_2); + + EXPECT_EQ(SECFailure, SSL_SendSessionTicket(client_->ssl_fd(), NULL, 0)) + << "clients can't send tickets"; + EXPECT_EQ(SEC_ERROR_INVALID_ARGS, PORT_GetError()); + + StartConnect(); + + EXPECT_EQ(SECFailure, SSL_SendSessionTicket(server_->ssl_fd(), NULL, 0)) + << "no ticket before the handshake has started"; + EXPECT_EQ(SEC_ERROR_INVALID_ARGS, PORT_GetError()); + Handshake(); + EXPECT_EQ(SECFailure, SSL_SendSessionTicket(server_->ssl_fd(), NULL, 0)) + << "no special tickets in TLS 1.2"; + EXPECT_EQ(SEC_ERROR_INVALID_ARGS, PORT_GetError()); +} + +TEST_F(TlsConnectTest, SendSessionTicketMassiveToken) { + ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET); + ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3); + Connect(); + // It should be safe to set length with a NULL token because the length should + // be checked before reading token. + EXPECT_EQ(SECFailure, SSL_SendSessionTicket(server_->ssl_fd(), NULL, 0x1ffff)) + << "this is clearly too big"; + EXPECT_EQ(SEC_ERROR_INVALID_ARGS, PORT_GetError()); + + static const uint8_t big_token[0xffff] = {1}; + EXPECT_EQ(SECFailure, SSL_SendSessionTicket(server_->ssl_fd(), big_token, + sizeof(big_token))) + << "this is too big, but that's not immediately obvious"; + EXPECT_EQ(SEC_ERROR_INVALID_ARGS, PORT_GetError()); +} + +TEST_F(TlsConnectDatagram13, SendSessionTicketDtls) { + ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET); + ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3); + Connect(); + EXPECT_EQ(SECFailure, SSL_SendSessionTicket(server_->ssl_fd(), NULL, 0)) + << "no extra tickets in DTLS until we have Ack support"; + EXPECT_EQ(SSL_ERROR_FEATURE_NOT_SUPPORTED_FOR_VERSION, PORT_GetError()); +} + +TEST_F(TlsConnectStreamTls13, ExternalResumptionUseSecondTicket) { + ConfigureSessionCache(RESUME_BOTH, RESUME_BOTH); + ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3); + + struct ResumptionTicketState { + std::vector<uint8_t> ticket; + size_t invoked = 0; + } ticket_state; + auto cb = [](PRFileDesc* fd, const PRUint8* ticket, unsigned int ticket_len, + void* arg) -> SECStatus { + auto state = reinterpret_cast<ResumptionTicketState*>(arg); + state->ticket.assign(ticket, ticket + ticket_len); + state->invoked++; + return SECSuccess; + }; + EXPECT_EQ(SECSuccess, SSL_SetResumptionTokenCallback(client_->ssl_fd(), cb, + &ticket_state)); + + Connect(); + EXPECT_EQ(SECSuccess, SSL_SendSessionTicket(server_->ssl_fd(), nullptr, 0)); + SendReceive(); + EXPECT_EQ(2U, ticket_state.invoked); + + Reset(); + ConfigureSessionCache(RESUME_BOTH, RESUME_BOTH); + client_->SetResumptionToken(ticket_state.ticket); + ExpectResumption(RESUME_TICKET); + Connect(); + SendReceive(); +} + +TEST_F(TlsConnectTest, TestTls13ResumptionDowngrade) { + ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET); + ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3); + Connect(); + + SendReceive(); // Need to read so that we absorb the session tickets. + CheckKeys(); + + // Try resuming the connection. This will fail resuming the 1.3 session + // from before, but will successfully establish a 1.2 connection. + Reset(); + ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET); + client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2, + SSL_LIBRARY_VERSION_TLS_1_3); + server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2, + SSL_LIBRARY_VERSION_TLS_1_2); + Connect(); + + // Renegotiate to ensure we don't carryover any state + // from the 1.3 resumption attempt. + client_->SetExpectedVersion(SSL_LIBRARY_VERSION_TLS_1_2); + client_->PrepareForRenegotiate(); + server_->StartRenegotiate(); + Handshake(); + + SendReceive(); + CheckKeys(); +} + +TEST_F(TlsConnectTest, TestTls13ResumptionForcedDowngrade) { + ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET); + ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3); + Connect(); + + SendReceive(); // Need to read so that we absorb the session tickets. + CheckKeys(); + + // Try resuming the connection. + Reset(); + ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3); + ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET); + // Enable the lower version on the client. + client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2, + SSL_LIBRARY_VERSION_TLS_1_3); + + // Add filters that set downgrade SH.version to 1.2 and the cipher suite + // to one that works with 1.2, so that we don't run into early sanity checks. + // We will eventually fail the (sid.version == SH.version) check. + std::vector<std::shared_ptr<PacketFilter>> filters; + filters.push_back(std::make_shared<SelectedCipherSuiteReplacer>( + server_, TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256)); + filters.push_back(std::make_shared<SelectedVersionReplacer>( + server_, SSL_LIBRARY_VERSION_TLS_1_2)); + + // Drop a bunch of extensions so that we get past the SH processing. The + // version extension says TLS 1.3, which is counter to our goal, the others + // are not permitted in TLS 1.2 handshakes. + filters.push_back(std::make_shared<TlsExtensionDropper>( + server_, ssl_tls13_supported_versions_xtn)); + filters.push_back( + std::make_shared<TlsExtensionDropper>(server_, ssl_tls13_key_share_xtn)); + filters.push_back(std::make_shared<TlsExtensionDropper>( + server_, ssl_tls13_pre_shared_key_xtn)); + server_->SetFilter(std::make_shared<ChainedPacketFilter>(filters)); + + // The client here generates an unexpected_message alert when it receives an + // encrypted handshake message from the server (EncryptedExtension). The + // client expects to receive an unencrypted TLS 1.2 Certificate message. + // The server can't decrypt the alert. + client_->ExpectSendAlert(kTlsAlertUnexpectedMessage); + server_->ExpectSendAlert(kTlsAlertUnexpectedMessage); // Server can't read + ConnectExpectFail(); + client_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_APPLICATION_DATA); + server_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_RECORD_TYPE); +} + +TEST_P(TlsConnectGenericResumption, ReConnectTicket) { + ConfigureSessionCache(RESUME_BOTH, RESUME_BOTH); + server_->EnableSingleCipher(ChooseOneCipher(version_)); + Connect(); + SendReceive(); + CheckKeys(ssl_kea_ecdh, ssl_grp_ec_curve25519, ssl_auth_rsa_sign, + ssl_sig_rsa_pss_rsae_sha256); + // Resume + Reset(); + ConfigureSessionCache(RESUME_BOTH, RESUME_BOTH); + ExpectResumption(RESUME_TICKET); + Connect(); + // Only the client knows this. + CheckKeysResumption(ssl_kea_ecdh, ssl_grp_none, ssl_grp_ec_curve25519, + ssl_auth_rsa_sign, ssl_sig_rsa_pss_rsae_sha256); +} + +TEST_P(TlsConnectGenericPre13, ReConnectCache) { + ConfigureSessionCache(RESUME_SESSIONID, RESUME_SESSIONID); + server_->EnableSingleCipher(ChooseOneCipher(version_)); + Connect(); + SendReceive(); + CheckKeys(ssl_kea_ecdh, ssl_grp_ec_curve25519, ssl_auth_rsa_sign, + ssl_sig_rsa_pss_rsae_sha256); + // Resume + Reset(); + ExpectResumption(RESUME_SESSIONID); + Connect(); + CheckKeysResumption(ssl_kea_ecdh, ssl_grp_none, ssl_grp_ec_curve25519, + ssl_auth_rsa_sign, ssl_sig_rsa_pss_rsae_sha256); +} + +TEST_P(TlsConnectGenericResumption, ReConnectAgainTicket) { + ConfigureSessionCache(RESUME_BOTH, RESUME_BOTH); + server_->EnableSingleCipher(ChooseOneCipher(version_)); + Connect(); + SendReceive(); + CheckKeys(ssl_kea_ecdh, ssl_grp_ec_curve25519, ssl_auth_rsa_sign, + ssl_sig_rsa_pss_rsae_sha256); + // Resume + Reset(); + ConfigureSessionCache(RESUME_BOTH, RESUME_BOTH); + ExpectResumption(RESUME_TICKET); + Connect(); + // Only the client knows this. + CheckKeysResumption(ssl_kea_ecdh, ssl_grp_none, ssl_grp_ec_curve25519, + ssl_auth_rsa_sign, ssl_sig_rsa_pss_rsae_sha256); + // Resume connection again + Reset(); + ConfigureSessionCache(RESUME_BOTH, RESUME_BOTH); + ExpectResumption(RESUME_TICKET, 2); + Connect(); + // Only the client knows this. + CheckKeysResumption(ssl_kea_ecdh, ssl_grp_none, ssl_grp_ec_curve25519, + ssl_auth_rsa_sign, ssl_sig_rsa_pss_rsae_sha256); +} + +void CheckGetInfoResult(PRTime now, uint32_t alpnSize, uint32_t earlyDataSize, + ScopedCERTCertificate& cert, + ScopedSSLResumptionTokenInfo& token) { + ASSERT_TRUE(cert); + ASSERT_TRUE(token->peerCert); + + // Check that the server cert is the correct one. + ASSERT_EQ(cert->derCert.len, token->peerCert->derCert.len); + EXPECT_EQ(0, memcmp(cert->derCert.data, token->peerCert->derCert.data, + cert->derCert.len)); + + ASSERT_EQ(alpnSize, token->alpnSelectionLen); + EXPECT_EQ(0, memcmp("a", token->alpnSelection, token->alpnSelectionLen)); + + ASSERT_EQ(earlyDataSize, token->maxEarlyDataSize); + + ASSERT_LT(now, token->expirationTime); +} + +// The client should generate a new, randomized session_id +// when resuming using an external token. +TEST_P(TlsConnectGenericResumptionToken, CheckSessionId) { + ConfigureSessionCache(RESUME_BOTH, RESUME_BOTH); + auto original_sid = MakeTlsFilter<CaptureSessionId>(client_); + Connect(); + SendReceive(); + + Reset(); + ConfigureSessionCache(RESUME_BOTH, RESUME_BOTH); + ExpectResumption(RESUME_TICKET); + + StartConnect(); + ASSERT_TRUE(client_->MaybeSetResumptionToken()); + auto resumed_sid = MakeTlsFilter<CaptureSessionId>(client_); + + Handshake(); + CheckConnected(); + SendReceive(); + + if (version_ < SSL_LIBRARY_VERSION_TLS_1_3) { + EXPECT_NE(resumed_sid->sid(), original_sid->sid()); + EXPECT_EQ(32U, resumed_sid->sid().len()); + } else { + EXPECT_EQ(0U, resumed_sid->sid().len()); + } +} + +TEST_P(TlsConnectGenericResumptionToken, ConnectResumeGetInfo) { + ConfigureSessionCache(RESUME_BOTH, RESUME_BOTH); + Connect(); + SendReceive(); + + Reset(); + ConfigureSessionCache(RESUME_BOTH, RESUME_BOTH); + ExpectResumption(RESUME_TICKET); + + StartConnect(); + ASSERT_TRUE(client_->MaybeSetResumptionToken()); + + // Get resumption token infos + SSLResumptionTokenInfo tokenInfo = {0}; + ScopedSSLResumptionTokenInfo token(&tokenInfo); + client_->GetTokenInfo(token); + ScopedCERTCertificate cert( + PK11_FindCertFromNickname(server_->name().c_str(), nullptr)); + ASSERT_NE(nullptr, cert.get()); + + CheckGetInfoResult(now(), 0, 0, cert, token); + + Handshake(); + CheckConnected(); + + SendReceive(); +} + +TEST_P(TlsConnectGenericResumptionToken, RefuseExpiredTicketClient) { + ConfigureSessionCache(RESUME_BOTH, RESUME_BOTH); + Connect(); + SendReceive(); + + // Move the clock to the expiration time of the ticket. + SSLResumptionTokenInfo tokenInfo = {0}; + ScopedSSLResumptionTokenInfo token(&tokenInfo); + client_->GetTokenInfo(token); + AdvanceTime(token->expirationTime - now()); + + Reset(); + ConfigureSessionCache(RESUME_BOTH, RESUME_BOTH); + ExpectResumption(RESUME_TICKET); + + StartConnect(); + ASSERT_EQ(SECFailure, + SSL_SetResumptionToken(client_->ssl_fd(), + client_->GetResumptionToken().data(), + client_->GetResumptionToken().size())); + EXPECT_EQ(SSL_ERROR_BAD_RESUMPTION_TOKEN_ERROR, PORT_GetError()); +} + +TEST_P(TlsConnectGenericResumptionToken, RefuseExpiredTicketServer) { + ConfigureSessionCache(RESUME_BOTH, RESUME_BOTH); + Connect(); + SendReceive(); + + Reset(); + ConfigureSessionCache(RESUME_BOTH, RESUME_BOTH); + ExpectResumption(RESUME_NONE); + + // Start the handshake and send the ClientHello. + StartConnect(); + ASSERT_EQ(SECSuccess, + SSL_SetResumptionToken(client_->ssl_fd(), + client_->GetResumptionToken().data(), + client_->GetResumptionToken().size())); + client_->Handshake(); + + // Move the clock to the expiration time of the ticket. + SSLResumptionTokenInfo tokenInfo = {0}; + ScopedSSLResumptionTokenInfo token(&tokenInfo); + client_->GetTokenInfo(token); + AdvanceTime(token->expirationTime - now()); + + Handshake(); + CheckConnected(); +} + +TEST_P(TlsConnectGenericResumptionToken, ConnectResumeGetInfoAlpn) { + EnableAlpn(); + ConfigureSessionCache(RESUME_BOTH, RESUME_BOTH); + Connect(); + CheckAlpn("a"); + SendReceive(); + + Reset(); + EnableAlpn(); + ConfigureSessionCache(RESUME_BOTH, RESUME_BOTH); + ExpectResumption(RESUME_TICKET); + + StartConnect(); + ASSERT_TRUE(client_->MaybeSetResumptionToken()); + + // Get resumption token infos + SSLResumptionTokenInfo tokenInfo = {0}; + ScopedSSLResumptionTokenInfo token(&tokenInfo); + client_->GetTokenInfo(token); + ScopedCERTCertificate cert( + PK11_FindCertFromNickname(server_->name().c_str(), nullptr)); + ASSERT_NE(nullptr, cert.get()); + + CheckGetInfoResult(now(), 1, 0, cert, token); + + Handshake(); + CheckConnected(); + CheckAlpn("a"); + + SendReceive(); +} + +TEST_P(TlsConnectTls13ResumptionToken, ConnectResumeGetInfoZeroRtt) { + EnableAlpn(); + RolloverAntiReplay(); + ConfigureSessionCache(RESUME_BOTH, RESUME_BOTH); + server_->Set0RttEnabled(true); + Connect(); + CheckAlpn("a"); + SendReceive(); + + Reset(); + EnableAlpn(); + ConfigureSessionCache(RESUME_BOTH, RESUME_BOTH); + ExpectResumption(RESUME_TICKET); + + StartConnect(); + server_->Set0RttEnabled(true); + client_->Set0RttEnabled(true); + ASSERT_TRUE(client_->MaybeSetResumptionToken()); + + // Get resumption token infos + SSLResumptionTokenInfo tokenInfo = {0}; + ScopedSSLResumptionTokenInfo token(&tokenInfo); + client_->GetTokenInfo(token); + ScopedCERTCertificate cert( + PK11_FindCertFromNickname(server_->name().c_str(), nullptr)); + ASSERT_NE(nullptr, cert.get()); + CheckGetInfoResult(now(), 1, 1024, cert, token); + + ZeroRttSendReceive(true, true); + Handshake(); + ExpectEarlyDataAccepted(true); + CheckConnected(); + CheckAlpn("a"); + + SendReceive(); +} + +// Resumption on sessions with client authentication only works with internal +// caching. +TEST_P(TlsConnectGenericResumption, ConnectResumeClientAuth) { + ConfigureSessionCache(RESUME_BOTH, RESUME_BOTH); + client_->SetupClientAuth(); + server_->RequestClientAuth(true); + Connect(); + SendReceive(); + EXPECT_FALSE(client_->resumption_callback_called()); + + Reset(); + ConfigureSessionCache(RESUME_BOTH, RESUME_BOTH); + if (use_external_cache()) { + ExpectResumption(RESUME_NONE); + } else { + ExpectResumption(RESUME_TICKET); + } + Connect(); + SendReceive(); +} + +// Check that resumption is blocked if the server requires client auth. +TEST_P(TlsConnectGenericResumption, ClientAuthRequiredOnResumption) { + ConfigureSessionCache(RESUME_BOTH, RESUME_BOTH); + server_->RequestClientAuth(false); + Connect(); + SendReceive(); + + Reset(); + client_->SetupClientAuth(); + server_->RequestClientAuth(true); + ConfigureSessionCache(RESUME_BOTH, RESUME_BOTH); + ExpectResumption(RESUME_NONE); + Connect(); + SendReceive(); +} + +// Check that resumption is blocked if the server requires client auth and +// the client fails to provide a certificate. +TEST_P(TlsConnectGenericResumption, ClientAuthRequiredOnResumptionNoCert) { + ConfigureSessionCache(RESUME_BOTH, RESUME_BOTH); + server_->RequestClientAuth(false); + Connect(); + SendReceive(); + + Reset(); + server_->RequestClientAuth(true); + ConfigureSessionCache(RESUME_BOTH, RESUME_BOTH); + // Drive handshake manually because TLS 1.3 needs it. + StartConnect(); + client_->Handshake(); // CH + server_->Handshake(); // SH.. (no resumption) + client_->Handshake(); // ... + if (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) { + // In TLS 1.3, the client thinks that everything is OK here. + ASSERT_EQ(TlsAgent::STATE_CONNECTED, client_->state()); + ExpectAlert(server_, kTlsAlertCertificateRequired); + server_->Handshake(); // Alert + client_->Handshake(); // Receive Alert + client_->CheckErrorCode(SSL_ERROR_RX_CERTIFICATE_REQUIRED_ALERT); + } else { + ExpectAlert(server_, kTlsAlertBadCertificate); + server_->Handshake(); // Alert + client_->Handshake(); // Receive Alert + client_->CheckErrorCode(SSL_ERROR_BAD_CERT_ALERT); + } + server_->CheckErrorCode(SSL_ERROR_NO_CERTIFICATE); +} + +TEST_F(TlsConnectStreamTls13, ExternalTokenAfterHrr) { + ConfigureSessionCache(RESUME_BOTH, RESUME_BOTH); + Connect(); + SendReceive(); + + Reset(); + ConfigureSessionCache(RESUME_BOTH, RESUME_BOTH); + ExpectResumption(RESUME_TICKET); + + static const std::vector<SSLNamedGroup> groups = {ssl_grp_ec_secp384r1, + ssl_grp_ec_secp521r1}; + server_->ConfigNamedGroups(groups); + + StartConnect(); + ASSERT_TRUE(client_->MaybeSetResumptionToken()); + + client_->Handshake(); // Send ClientHello. + server_->Handshake(); // Process ClientHello, send HelloRetryRequest. + + auto& token = client_->GetResumptionToken(); + SECStatus rv = + SSL_SetResumptionToken(client_->ssl_fd(), token.data(), token.size()); + ASSERT_EQ(SECFailure, rv); + ASSERT_EQ(SEC_ERROR_INVALID_ARGS, PORT_GetError()); + + Handshake(); + CheckConnected(); + SendReceive(); +} + +TEST_F(TlsConnectStreamTls13, ExternalTokenWithPeerId) { + ConfigureSessionCache(RESUME_BOTH, RESUME_BOTH); + ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3); + EXPECT_EQ(SECSuccess, SSL_SetSockPeerID(client_->ssl_fd(), "testPeerId")); + std::vector<uint8_t> ticket_state; + auto cb = [](PRFileDesc* fd, const PRUint8* ticket, unsigned int ticket_len, + void* arg) -> SECStatus { + EXPECT_NE(0U, ticket_len); + EXPECT_NE(nullptr, ticket); + auto ticket_state_ = reinterpret_cast<std::vector<uint8_t>*>(arg); + ticket_state_->assign(ticket, ticket + ticket_len); + return SECSuccess; + }; + EXPECT_EQ(SECSuccess, SSL_SetResumptionTokenCallback(client_->ssl_fd(), cb, + &ticket_state)); + + Connect(); + SendReceive(); + EXPECT_NE(0U, ticket_state.size()); + + Reset(); + ConfigureSessionCache(RESUME_BOTH, RESUME_BOTH); + EXPECT_EQ(SECSuccess, SSL_SetSockPeerID(client_->ssl_fd(), "testPeerId")); + client_->SetResumptionToken(ticket_state); + ASSERT_TRUE(client_->MaybeSetResumptionToken()); + ExpectResumption(RESUME_TICKET); + Connect(); + SendReceive(); +} + +} // namespace nss_test diff --git a/security/nss/gtests/ssl_gtest/ssl_skip_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_skip_unittest.cc new file mode 100644 index 0000000000..606e731033 --- /dev/null +++ b/security/nss/gtests/ssl_gtest/ssl_skip_unittest.cc @@ -0,0 +1,246 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set ts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this file, + * You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#include "sslerr.h" + +#include "tls_connect.h" +#include "tls_filter.h" +#include "tls_parser.h" + +/* + * The tests in this file test that the TLS state machine is robust against + * attacks that alter the order of handshake messages. + * + * See <https://www.smacktls.com/smack.pdf> for a description of the problems + * that this sort of attack can enable. + */ +namespace nss_test { + +class TlsHandshakeSkipFilter : public TlsRecordFilter { + public: + // A TLS record filter that skips handshake messages of the identified type. + TlsHandshakeSkipFilter(const std::shared_ptr<TlsAgent>& a, + uint8_t handshake_type) + : TlsRecordFilter(a), handshake_type_(handshake_type), skipped_(false) {} + + protected: + // Takes a record; if it is a handshake record, it removes the first handshake + // message that is of handshake_type_ type. + virtual PacketFilter::Action FilterRecord( + const TlsRecordHeader& record_header, const DataBuffer& input, + DataBuffer* output) { + if (record_header.content_type() != ssl_ct_handshake) { + return KEEP; + } + + size_t output_offset = 0U; + output->Allocate(input.len()); + + TlsParser parser(input); + while (parser.remaining()) { + size_t start = parser.consumed(); + TlsHandshakeFilter::HandshakeHeader header; + DataBuffer ignored; + bool complete = false; + if (!header.Parse(&parser, record_header, DataBuffer(), &ignored, + &complete)) { + ADD_FAILURE() << "Error parsing handshake header"; + return KEEP; + } + if (!complete) { + ADD_FAILURE() << "Don't want to deal with fragmented input"; + return KEEP; + } + + if (skipped_ || header.handshake_type() != handshake_type_) { + size_t entire_length = parser.consumed() - start; + output->Write(output_offset, input.data() + start, entire_length); + // DTLS sequence numbers need to be rewritten + if (skipped_ && header.is_dtls()) { + output->data()[start + 5] -= 1; + } + output_offset += entire_length; + } else { + std::cerr << "Dropping handshake: " + << static_cast<unsigned>(handshake_type_) << std::endl; + // We only need to report that the output contains changed data if we + // drop a handshake message. But once we've skipped one message, we + // have to modify all subsequent handshake messages so that they include + // the correct DTLS sequence numbers. + skipped_ = true; + } + } + output->Truncate(output_offset); + return skipped_ ? CHANGE : KEEP; + } + + private: + // The type of handshake message to drop. + uint8_t handshake_type_; + // Whether this filter has ever skipped a handshake message. Track this so + // that sequence numbers on DTLS handshake messages can be rewritten in + // subsequent calls. + bool skipped_; +}; + +class TlsSkipTest : public TlsConnectTestBase, + public ::testing::WithParamInterface< + std::tuple<SSLProtocolVariant, uint16_t>> { + protected: + TlsSkipTest() + : TlsConnectTestBase(std::get<0>(GetParam()), std::get<1>(GetParam())) {} + + void SetUp() override { + TlsConnectTestBase::SetUp(); + EnsureTlsSetup(); + } + + void ServerSkipTest(std::shared_ptr<PacketFilter> filter, + uint8_t alert = kTlsAlertUnexpectedMessage) { + server_->SetFilter(filter); + ConnectExpectAlert(client_, alert); + } +}; + +class Tls13SkipTest : public TlsConnectTestBase, + public ::testing::WithParamInterface<SSLProtocolVariant> { + protected: + Tls13SkipTest() + : TlsConnectTestBase(GetParam(), SSL_LIBRARY_VERSION_TLS_1_3) {} + + void SetUp() override { + TlsConnectTestBase::SetUp(); + EnsureTlsSetup(); + } + + void ServerSkipTest(std::shared_ptr<TlsRecordFilter> filter, int32_t error) { + filter->EnableDecryption(); + server_->SetFilter(filter); + ExpectAlert(client_, kTlsAlertUnexpectedMessage); + ConnectExpectFail(); + client_->CheckErrorCode(error); + server_->CheckErrorCode(SSL_ERROR_HANDSHAKE_UNEXPECTED_ALERT); + } + + void ClientSkipTest(std::shared_ptr<TlsRecordFilter> filter, int32_t error) { + filter->EnableDecryption(); + client_->SetFilter(filter); + server_->ExpectSendAlert(kTlsAlertUnexpectedMessage); + ConnectExpectFailOneSide(TlsAgent::SERVER); + + server_->CheckErrorCode(error); + ASSERT_EQ(TlsAgent::STATE_CONNECTED, client_->state()); + + client_->Handshake(); // Make sure to consume the alert the server sends. + } +}; + +TEST_P(TlsSkipTest, SkipCertificateRsa) { + EnableOnlyStaticRsaCiphers(); + ServerSkipTest(std::make_shared<TlsHandshakeSkipFilter>( + server_, kTlsHandshakeCertificate)); + client_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_HELLO_DONE); +} + +TEST_P(TlsSkipTest, SkipCertificateDhe) { + ServerSkipTest(std::make_shared<TlsHandshakeSkipFilter>( + server_, kTlsHandshakeCertificate)); + client_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_SERVER_KEY_EXCH); +} + +TEST_P(TlsSkipTest, SkipCertificateEcdhe) { + ServerSkipTest(std::make_shared<TlsHandshakeSkipFilter>( + server_, kTlsHandshakeCertificate)); + client_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_SERVER_KEY_EXCH); +} + +TEST_P(TlsSkipTest, SkipCertificateEcdsa) { + Reset(TlsAgent::kServerEcdsa256); + ServerSkipTest(std::make_shared<TlsHandshakeSkipFilter>( + server_, kTlsHandshakeCertificate)); + client_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_SERVER_KEY_EXCH); +} + +TEST_P(TlsSkipTest, SkipServerKeyExchange) { + ServerSkipTest(std::make_shared<TlsHandshakeSkipFilter>( + server_, kTlsHandshakeServerKeyExchange)); + client_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_HELLO_DONE); +} + +TEST_P(TlsSkipTest, SkipServerKeyExchangeEcdsa) { + Reset(TlsAgent::kServerEcdsa256); + ServerSkipTest(std::make_shared<TlsHandshakeSkipFilter>( + server_, kTlsHandshakeServerKeyExchange)); + client_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_HELLO_DONE); +} + +TEST_P(TlsSkipTest, SkipCertAndKeyExch) { + auto chain = std::make_shared<ChainedPacketFilter>( + ChainedPacketFilterInit{std::make_shared<TlsHandshakeSkipFilter>( + server_, kTlsHandshakeCertificate), + std::make_shared<TlsHandshakeSkipFilter>( + server_, kTlsHandshakeServerKeyExchange)}); + ServerSkipTest(chain); + client_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_HELLO_DONE); +} + +TEST_P(TlsSkipTest, SkipCertAndKeyExchEcdsa) { + Reset(TlsAgent::kServerEcdsa256); + auto chain = std::make_shared<ChainedPacketFilter>(); + chain->Add(std::make_shared<TlsHandshakeSkipFilter>( + server_, kTlsHandshakeCertificate)); + chain->Add(std::make_shared<TlsHandshakeSkipFilter>( + server_, kTlsHandshakeServerKeyExchange)); + ServerSkipTest(chain); + client_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_HELLO_DONE); +} + +TEST_P(Tls13SkipTest, SkipEncryptedExtensions) { + ServerSkipTest(std::make_shared<TlsHandshakeSkipFilter>( + server_, kTlsHandshakeEncryptedExtensions), + SSL_ERROR_RX_UNEXPECTED_CERTIFICATE); +} + +TEST_P(Tls13SkipTest, SkipServerCertificate) { + ServerSkipTest(std::make_shared<TlsHandshakeSkipFilter>( + server_, kTlsHandshakeCertificate), + SSL_ERROR_RX_UNEXPECTED_CERT_VERIFY); +} + +TEST_P(Tls13SkipTest, SkipServerCertificateVerify) { + ServerSkipTest(std::make_shared<TlsHandshakeSkipFilter>( + server_, kTlsHandshakeCertificateVerify), + SSL_ERROR_RX_UNEXPECTED_FINISHED); +} + +TEST_P(Tls13SkipTest, SkipClientCertificate) { + client_->SetupClientAuth(); + server_->RequestClientAuth(true); + client_->ExpectReceiveAlert(kTlsAlertUnexpectedMessage); + ClientSkipTest(std::make_shared<TlsHandshakeSkipFilter>( + client_, kTlsHandshakeCertificate), + SSL_ERROR_RX_UNEXPECTED_CERT_VERIFY); +} + +TEST_P(Tls13SkipTest, SkipClientCertificateVerify) { + client_->SetupClientAuth(); + server_->RequestClientAuth(true); + client_->ExpectReceiveAlert(kTlsAlertUnexpectedMessage); + ClientSkipTest(std::make_shared<TlsHandshakeSkipFilter>( + client_, kTlsHandshakeCertificateVerify), + SSL_ERROR_RX_UNEXPECTED_FINISHED); +} + +INSTANTIATE_TEST_SUITE_P( + SkipTls10, TlsSkipTest, + ::testing::Combine(TlsConnectTestBase::kTlsVariantsStream, + TlsConnectTestBase::kTlsV10)); +INSTANTIATE_TEST_SUITE_P(SkipVariants, TlsSkipTest, + ::testing::Combine(TlsConnectTestBase::kTlsVariantsAll, + TlsConnectTestBase::kTlsV11V12)); +INSTANTIATE_TEST_SUITE_P(Skip13Variants, Tls13SkipTest, + TlsConnectTestBase::kTlsVariantsAll); +} // namespace nss_test diff --git a/security/nss/gtests/ssl_gtest/ssl_staticrsa_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_staticrsa_unittest.cc new file mode 100644 index 0000000000..abddaa5b61 --- /dev/null +++ b/security/nss/gtests/ssl_gtest/ssl_staticrsa_unittest.cc @@ -0,0 +1,139 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set ts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this file, + * You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#include <functional> +#include <memory> +#include "secerr.h" +#include "ssl.h" +#include "sslerr.h" +#include "sslproto.h" + +extern "C" { +// This is not something that should make you happy. +#include "libssl_internals.h" +} + +#include "gtest_utils.h" +#include "nss_scoped_ptrs.h" +#include "tls_connect.h" +#include "tls_filter.h" +#include "tls_parser.h" +#include "rsa8193.h" + +namespace nss_test { + +const uint8_t kBogusClientKeyExchange[] = { + 0x01, 0x00, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, +}; + +TEST_P(TlsConnectGenericPre13, ConnectStaticRSA) { + EnableOnlyStaticRsaCiphers(); + Connect(); + CheckKeys(ssl_kea_rsa, ssl_grp_none, ssl_auth_rsa_decrypt, ssl_sig_none); +} + +// Test that a totally bogus EPMS is handled correctly. +// This test is stream so we can catch the bad_record_mac alert. +TEST_P(TlsConnectStreamPre13, ConnectStaticRSABogusCKE) { + EnableOnlyStaticRsaCiphers(); + MakeTlsFilter<TlsInspectorReplaceHandshakeMessage>( + client_, kTlsHandshakeClientKeyExchange, + DataBuffer(kBogusClientKeyExchange, sizeof(kBogusClientKeyExchange))); + ConnectExpectAlert(server_, kTlsAlertBadRecordMac); +} + +// Test that a PMS with a bogus version number is handled correctly. +// This test is stream so we can catch the bad_record_mac alert. +TEST_P(TlsConnectStreamPre13, ConnectStaticRSABogusPMSVersionDetect) { + EnableOnlyStaticRsaCiphers(); + MakeTlsFilter<TlsClientHelloVersionChanger>(client_, server_); + ConnectExpectAlert(server_, kTlsAlertBadRecordMac); +} + +// Test that a PMS with a bogus version number is ignored when +// rollback detection is disabled. This is a positive control for +// ConnectStaticRSABogusPMSVersionDetect. +TEST_P(TlsConnectGenericPre13, ConnectStaticRSABogusPMSVersionIgnore) { + EnableOnlyStaticRsaCiphers(); + MakeTlsFilter<TlsClientHelloVersionChanger>(client_, server_); + server_->SetOption(SSL_ROLLBACK_DETECTION, PR_FALSE); + Connect(); +} + +// This test is stream so we can catch the bad_record_mac alert. +TEST_P(TlsConnectStreamPre13, ConnectExtendedMasterSecretStaticRSABogusCKE) { + EnableOnlyStaticRsaCiphers(); + EnableExtendedMasterSecret(); + MakeTlsFilter<TlsInspectorReplaceHandshakeMessage>( + client_, kTlsHandshakeClientKeyExchange, + DataBuffer(kBogusClientKeyExchange, sizeof(kBogusClientKeyExchange))); + ConnectExpectAlert(server_, kTlsAlertBadRecordMac); +} + +// This test is stream so we can catch the bad_record_mac alert. +TEST_P(TlsConnectStreamPre13, + ConnectExtendedMasterSecretStaticRSABogusPMSVersionDetect) { + EnableOnlyStaticRsaCiphers(); + EnableExtendedMasterSecret(); + MakeTlsFilter<TlsClientHelloVersionChanger>(client_, server_); + ConnectExpectAlert(server_, kTlsAlertBadRecordMac); +} + +TEST_P(TlsConnectStreamPre13, + ConnectExtendedMasterSecretStaticRSABogusPMSVersionIgnore) { + EnableOnlyStaticRsaCiphers(); + EnableExtendedMasterSecret(); + MakeTlsFilter<TlsClientHelloVersionChanger>(client_, server_); + server_->SetOption(SSL_ROLLBACK_DETECTION, PR_FALSE); + Connect(); +} + +// Replace the server certificate with one that uses 8193-bit RSA. +class TooLargeRSACertFilter : public TlsHandshakeFilter { + public: + TooLargeRSACertFilter(const std::shared_ptr<TlsAgent> &server) + : TlsHandshakeFilter(server, {kTlsHandshakeCertificate}) {} + + protected: + virtual PacketFilter::Action FilterHandshake(const HandshakeHeader &header, + const DataBuffer &input, + DataBuffer *output) { + const uint32_t cert_len = sizeof(rsa8193); + const uint32_t outer_len = cert_len + 3; + size_t offset = 0; + offset = output->Write(offset, outer_len, 3); + offset = output->Write(offset, cert_len, 3); + offset = output->Write(offset, rsa8193, cert_len); + + return CHANGE; + } +}; + +TEST_P(TlsConnectGenericPre13, TooLargeRSAKeyInCert) { + EnableOnlyStaticRsaCiphers(); + MakeTlsFilter<TooLargeRSACertFilter>(server_); + ConnectExpectAlert(client_, kTlsAlertIllegalParameter); + client_->CheckErrorCode(SSL_ERROR_CLIENT_KEY_EXCHANGE_FAILURE); + server_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT); +} + +TEST_P(TlsConnectGeneric, ServerAuthBiggestRsa) { + Reset(TlsAgent::kRsa8192); + Connect(); + CheckKeys(); +} + +} // namespace nss_test diff --git a/security/nss/gtests/ssl_gtest/ssl_tls13compat_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_tls13compat_unittest.cc new file mode 100644 index 0000000000..2421470a4f --- /dev/null +++ b/security/nss/gtests/ssl_gtest/ssl_tls13compat_unittest.cc @@ -0,0 +1,573 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set ts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this file, + * You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#include <memory> +#include <vector> +#include "ssl.h" +#include "sslerr.h" +#include "sslproto.h" + +#include "gtest_utils.h" +#include "tls_connect.h" +#include "tls_filter.h" +#include "tls_parser.h" + +namespace nss_test { + +class Tls13CompatTest : public TlsConnectStreamTls13 { + protected: + void EnableCompatMode() { + client_->SetOption(SSL_ENABLE_TLS13_COMPAT_MODE, PR_TRUE); + } + + void InstallFilters() { + EnsureTlsSetup(); + client_recorders_.Install(client_); + server_recorders_.Install(server_); + } + + void CheckRecordVersions() { + ASSERT_EQ(SSL_LIBRARY_VERSION_TLS_1_0, + client_recorders_.records_->record(0).header.version()); + CheckRecordsAreTls12("client", client_recorders_.records_, 1); + CheckRecordsAreTls12("server", server_recorders_.records_, 0); + } + + void CheckHelloVersions() { + uint32_t ver; + ASSERT_TRUE(server_recorders_.hello_->buffer().Read(0, 2, &ver)); + ASSERT_EQ(SSL_LIBRARY_VERSION_TLS_1_2, static_cast<uint16_t>(ver)); + ASSERT_TRUE(client_recorders_.hello_->buffer().Read(0, 2, &ver)); + ASSERT_EQ(SSL_LIBRARY_VERSION_TLS_1_2, static_cast<uint16_t>(ver)); + } + + void CheckForCCS(bool expected_client, bool expected_server) { + client_recorders_.CheckForCCS(expected_client); + server_recorders_.CheckForCCS(expected_server); + } + + void CheckForRegularHandshake() { + CheckRecordVersions(); + CheckHelloVersions(); + EXPECT_EQ(0U, client_recorders_.session_id_length()); + EXPECT_EQ(0U, server_recorders_.session_id_length()); + CheckForCCS(false, false); + } + + void CheckForCompatHandshake() { + CheckRecordVersions(); + CheckHelloVersions(); + EXPECT_EQ(32U, client_recorders_.session_id_length()); + EXPECT_EQ(32U, server_recorders_.session_id_length()); + CheckForCCS(true, true); + } + + private: + struct Recorders { + Recorders() : records_(nullptr), hello_(nullptr) {} + + uint8_t session_id_length() const { + // session_id is always after version (2) and random (32). + uint32_t len = 0; + EXPECT_TRUE(hello_->buffer().Read(2 + 32, 1, &len)); + return static_cast<uint8_t>(len); + } + + void CheckForCCS(bool expected) const { + EXPECT_LT(0U, records_->count()); + for (size_t i = 0; i < records_->count(); ++i) { + // Only the second record can be a CCS. + bool expected_match = expected && (i == 1); + EXPECT_EQ(expected_match, + ssl_ct_change_cipher_spec == + records_->record(i).header.content_type()); + } + } + + void Install(std::shared_ptr<TlsAgent>& agent) { + if (records_ && records_->agent() == agent) { + // Avoid replacing the filters if they are already installed on this + // agent. This ensures that InstallFilters() can be used after + // MakeNewServer() without losing state on the client filters. + return; + } + records_.reset(new TlsRecordRecorder(agent)); + hello_.reset(new TlsHandshakeRecorder( + agent, std::set<uint8_t>( + {kTlsHandshakeClientHello, kTlsHandshakeServerHello}))); + agent->SetFilter(std::make_shared<ChainedPacketFilter>( + ChainedPacketFilterInit({records_, hello_}))); + } + + std::shared_ptr<TlsRecordRecorder> records_; + std::shared_ptr<TlsHandshakeRecorder> hello_; + }; + + void CheckRecordsAreTls12(const std::string& agent, + const std::shared_ptr<TlsRecordRecorder>& records, + size_t start) { + EXPECT_LE(start, records->count()); + for (size_t i = start; i < records->count(); ++i) { + EXPECT_EQ(SSL_LIBRARY_VERSION_TLS_1_2, + records->record(i).header.version()) + << agent << ": record " << i << " has wrong version"; + } + } + + Recorders client_recorders_; + Recorders server_recorders_; +}; + +TEST_F(Tls13CompatTest, Disabled) { + InstallFilters(); + Connect(); + CheckForRegularHandshake(); +} + +TEST_F(Tls13CompatTest, Enabled) { + EnableCompatMode(); + InstallFilters(); + Connect(); + CheckForCompatHandshake(); +} + +TEST_F(Tls13CompatTest, EnabledZeroRtt) { + SetupForZeroRtt(); + EnableCompatMode(); + InstallFilters(); + + client_->Set0RttEnabled(true); + server_->Set0RttEnabled(true); + ExpectResumption(RESUME_TICKET); + ZeroRttSendReceive(true, true); + CheckForCCS(true, true); + Handshake(); + ExpectEarlyDataAccepted(true); + CheckConnected(); + + CheckForCompatHandshake(); +} + +TEST_F(Tls13CompatTest, EnabledHrr) { + EnableCompatMode(); + InstallFilters(); + + // Force a HelloRetryRequest. The server sends CCS immediately. + server_->ConfigNamedGroups({ssl_grp_ec_secp384r1}); + client_->StartConnect(); + server_->StartConnect(); + client_->Handshake(); + server_->Handshake(); + CheckForCCS(false, true); + + Handshake(); + CheckConnected(); + CheckForCompatHandshake(); +} + +TEST_F(Tls13CompatTest, EnabledStatelessHrr) { + EnableCompatMode(); + InstallFilters(); + + // Force a HelloRetryRequest + server_->ConfigNamedGroups({ssl_grp_ec_secp384r1}); + client_->StartConnect(); + server_->StartConnect(); + client_->Handshake(); + server_->Handshake(); + + // The server should send CCS before HRR. + CheckForCCS(false, true); + + // A new server should complete the handshake, and not send CCS. + MakeNewServer(); + InstallFilters(); + server_->ConfigNamedGroups({ssl_grp_ec_secp384r1}); + + Handshake(); + CheckConnected(); + CheckRecordVersions(); + CheckHelloVersions(); + CheckForCCS(true, false); +} + +TEST_F(Tls13CompatTest, EnabledHrrZeroRtt) { + SetupForZeroRtt(); + EnableCompatMode(); + InstallFilters(); + server_->ConfigNamedGroups({ssl_grp_ec_secp384r1}); + + // With 0-RTT, the client sends CCS immediately. With HRR, the server sends + // CCS immediately too. + client_->Set0RttEnabled(true); + server_->Set0RttEnabled(true); + ExpectResumption(RESUME_TICKET); + ZeroRttSendReceive(true, false); + CheckForCCS(true, true); + + Handshake(); + ExpectEarlyDataAccepted(false); + CheckConnected(); + CheckForCompatHandshake(); +} + +TEST_F(Tls13CompatTest, EnabledAcceptedEch) { + EnsureTlsSetup(); + SetupEch(client_, server_); + EnableCompatMode(); + InstallFilters(); + Connect(); + CheckForCompatHandshake(); +} + +TEST_F(Tls13CompatTest, EnabledRejectedEch) { + EnsureTlsSetup(); + // Configure ECH on the client only, and expect CCS. + SetupEch(client_, server_, HpkeDhKemX25519Sha256, false, true, false); + EnableCompatMode(); + InstallFilters(); + ExpectAlert(client_, kTlsAlertEchRequired); + ConnectExpectFailOneSide(TlsAgent::CLIENT); + client_->CheckErrorCode(SSL_ERROR_ECH_RETRY_WITHOUT_ECH); + CheckForCompatHandshake(); + // Reset expectations for the TlsAgent dtor. + server_->ExpectReceiveAlert(kTlsAlertCloseNotify, kTlsAlertWarning); +} + +class TlsSessionIDEchoFilter : public TlsHandshakeFilter { + public: + TlsSessionIDEchoFilter(const std::shared_ptr<TlsAgent>& a) + : TlsHandshakeFilter( + a, {kTlsHandshakeClientHello, kTlsHandshakeServerHello}) {} + + protected: + virtual PacketFilter::Action FilterHandshake(const HandshakeHeader& header, + const DataBuffer& input, + DataBuffer* output) { + TlsParser parser(input); + + // Skip version + random. + EXPECT_TRUE(parser.Skip(2 + 32)); + + // Capture CH.legacy_session_id. + if (header.handshake_type() == kTlsHandshakeClientHello) { + EXPECT_TRUE(parser.ReadVariable(&sid_, 1)); + return KEEP; + } + + // Check that server sends one too. + uint32_t sid_len = 0; + EXPECT_TRUE(parser.Read(&sid_len, 1)); + EXPECT_EQ(sid_len, sid_.len()); + + // Echo the one we captured. + *output = input; + output->Write(parser.consumed(), sid_.data(), sid_.len()); + + return CHANGE; + } + + private: + DataBuffer sid_; +}; + +TEST_F(TlsConnectTest, EchoTLS13CompatibilitySessionID) { + ConfigureSessionCache(RESUME_SESSIONID, RESUME_SESSIONID); + + client_->SetOption(SSL_ENABLE_TLS13_COMPAT_MODE, PR_TRUE); + + client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2, + SSL_LIBRARY_VERSION_TLS_1_3); + + server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2, + SSL_LIBRARY_VERSION_TLS_1_2); + + server_->SetFilter(MakeTlsFilter<TlsSessionIDEchoFilter>(client_)); + ConnectExpectAlert(client_, kTlsAlertIllegalParameter); + + client_->CheckErrorCode(SSL_ERROR_RX_MALFORMED_SERVER_HELLO); + server_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT); +} + +class TlsSessionIDInjectFilter : public TlsHandshakeFilter { + public: + TlsSessionIDInjectFilter(const std::shared_ptr<TlsAgent>& a) + : TlsHandshakeFilter(a, {kTlsHandshakeServerHello}) {} + + protected: + virtual PacketFilter::Action FilterHandshake(const HandshakeHeader& header, + const DataBuffer& input, + DataBuffer* output) { + TlsParser parser(input); + + // Skip version + random. + EXPECT_TRUE(parser.Skip(2 + 32)); + + *output = input; + + // Inject a Session ID. + const uint8_t fake_sid[SSL3_SESSIONID_BYTES] = {0xff}; + output->Write(parser.consumed(), sizeof(fake_sid), 1); + output->Splice(fake_sid, sizeof(fake_sid), parser.consumed() + 1, 0); + + return CHANGE; + } +}; + +TEST_F(TlsConnectTest, TLS13NonCompatModeSessionID) { + ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3); + + MakeTlsFilter<TlsSessionIDInjectFilter>(server_); + client_->ExpectSendAlert(kTlsAlertIllegalParameter); + server_->ExpectSendAlert(kTlsAlertUnexpectedMessage); + ConnectExpectFail(); + + client_->CheckErrorCode(SSL_ERROR_RX_MALFORMED_SERVER_HELLO); + server_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_RECORD_TYPE); +} + +static const uint8_t kCannedCcs[] = { + ssl_ct_change_cipher_spec, + SSL_LIBRARY_VERSION_TLS_1_2 >> 8, + SSL_LIBRARY_VERSION_TLS_1_2 & 0xff, + 0, + 1, // length + 1 // change_cipher_spec_choice +}; + +// A ChangeCipherSpec is ignored by a server because we have to tolerate it for +// compatibility mode. That doesn't mean that we have to tolerate it +// unconditionally. If we negotiate 1.3, we expect to see a cookie extension. +TEST_F(TlsConnectStreamTls13, ChangeCipherSpecBeforeClientHello13) { + EnsureTlsSetup(); + server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2, + SSL_LIBRARY_VERSION_TLS_1_3); + client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2, + SSL_LIBRARY_VERSION_TLS_1_3); + // Client sends CCS before starting the handshake. + client_->SendDirect(DataBuffer(kCannedCcs, sizeof(kCannedCcs))); + ConnectExpectAlert(server_, kTlsAlertUnexpectedMessage); + server_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_CHANGE_CIPHER); + client_->CheckErrorCode(SSL_ERROR_HANDSHAKE_UNEXPECTED_ALERT); +} + +// A ChangeCipherSpec is ignored by a server because we have to tolerate it for +// compatibility mode. That doesn't mean that we have to tolerate it +// unconditionally. If we negotiate 1.3, we expect to see a cookie extension. +TEST_F(TlsConnectStreamTls13, ChangeCipherSpecBeforeClientHelloTwice) { + EnsureTlsSetup(); + server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2, + SSL_LIBRARY_VERSION_TLS_1_3); + client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2, + SSL_LIBRARY_VERSION_TLS_1_3); + // Client sends CCS before starting the handshake. + client_->SendDirect(DataBuffer(kCannedCcs, sizeof(kCannedCcs))); + client_->SendDirect(DataBuffer(kCannedCcs, sizeof(kCannedCcs))); + ConnectExpectAlert(server_, kTlsAlertUnexpectedMessage); + server_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_CHANGE_CIPHER); + client_->CheckErrorCode(SSL_ERROR_HANDSHAKE_UNEXPECTED_ALERT); +} + +// The server accepts a ChangeCipherSpec even if the client advertises +// an empty session ID. +TEST_F(TlsConnectStreamTls13, ChangeCipherSpecAfterClientHelloEmptySid) { + EnsureTlsSetup(); + ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3); + + StartConnect(); + client_->Handshake(); // Send ClientHello + client_->SendDirect(DataBuffer(kCannedCcs, sizeof(kCannedCcs))); // Send CCS + + Handshake(); + CheckConnected(); +} + +// The server rejects multiple ChangeCipherSpec even if the client +// indicates compatibility mode with non-empty session ID. +TEST_F(Tls13CompatTest, ChangeCipherSpecAfterClientHelloTwice) { + EnsureTlsSetup(); + ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3); + EnableCompatMode(); + + StartConnect(); + client_->Handshake(); // Send ClientHello + // Send CCS twice in a row + client_->SendDirect(DataBuffer(kCannedCcs, sizeof(kCannedCcs))); + client_->SendDirect(DataBuffer(kCannedCcs, sizeof(kCannedCcs))); + + server_->ExpectSendAlert(kTlsAlertUnexpectedMessage); + server_->Handshake(); // Consume ClientHello and CCS. + server_->CheckErrorCode(SSL_ERROR_RX_MALFORMED_CHANGE_CIPHER); +} + +// The client accepts a ChangeCipherSpec even if it advertises an empty +// session ID. +TEST_F(TlsConnectStreamTls13, ChangeCipherSpecAfterServerHelloEmptySid) { + EnsureTlsSetup(); + ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3); + + // To replace Finished with a CCS below + auto filter = MakeTlsFilter<TlsHandshakeDropper>(server_); + filter->SetHandshakeTypes({kTlsHandshakeFinished}); + filter->EnableDecryption(); + + StartConnect(); + client_->Handshake(); // Send ClientHello + server_->Handshake(); // Consume ClientHello, and + // send ServerHello..CertificateVerify + // Send CCS + server_->SendDirect(DataBuffer(kCannedCcs, sizeof(kCannedCcs))); + + // No alert is sent from the client. As Finished is dropped, we + // can't use Handshake() and CheckConnected(). + client_->Handshake(); +} + +// The client rejects multiple ChangeCipherSpec in a row even if the +// client indicates compatibility mode with non-empty session ID. +TEST_F(Tls13CompatTest, ChangeCipherSpecAfterServerHelloTwice) { + EnsureTlsSetup(); + ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3); + EnableCompatMode(); + + // To replace Finished with a CCS below + auto filter = MakeTlsFilter<TlsHandshakeDropper>(server_); + filter->SetHandshakeTypes({kTlsHandshakeFinished}); + filter->EnableDecryption(); + + StartConnect(); + client_->Handshake(); // Send ClientHello + server_->Handshake(); // Consume ClientHello, and + // send ServerHello..CertificateVerify + // the ServerHello is followed by CCS + // Send another CCS + server_->SendDirect(DataBuffer(kCannedCcs, sizeof(kCannedCcs))); + client_->ExpectSendAlert(kTlsAlertUnexpectedMessage); + client_->Handshake(); // Consume ClientHello and CCS + client_->CheckErrorCode(SSL_ERROR_RX_MALFORMED_CHANGE_CIPHER); +} + +// If we negotiate 1.2, we abort. +TEST_F(TlsConnectStreamTls13, ChangeCipherSpecBeforeClientHello12) { + EnsureTlsSetup(); + server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2, + SSL_LIBRARY_VERSION_TLS_1_3); + client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2, + SSL_LIBRARY_VERSION_TLS_1_2); + // Client sends CCS before starting the handshake. + client_->SendDirect(DataBuffer(kCannedCcs, sizeof(kCannedCcs))); + ConnectExpectAlert(server_, kTlsAlertUnexpectedMessage); + server_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_CHANGE_CIPHER); + client_->CheckErrorCode(SSL_ERROR_HANDSHAKE_UNEXPECTED_ALERT); +} + +TEST_F(TlsConnectStreamTls13, ChangeCipherSpecAfterFinished13) { + EnsureTlsSetup(); + ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3); + Connect(); + SendReceive(10); + // Client sends CCS after the handshake. + client_->SendDirect(DataBuffer(kCannedCcs, sizeof(kCannedCcs))); + server_->ExpectSendAlert(kTlsAlertUnexpectedMessage); + server_->ExpectReadWriteError(); + server_->ReadBytes(); + EXPECT_EQ(SSL_ERROR_RX_UNEXPECTED_RECORD_TYPE, server_->error_code()); +} + +TEST_F(TlsConnectDatagram13, CompatModeDtlsClient) { + EnsureTlsSetup(); + client_->SetOption(SSL_ENABLE_TLS13_COMPAT_MODE, PR_TRUE); + auto client_records = MakeTlsFilter<TlsRecordRecorder>(client_); + auto server_records = MakeTlsFilter<TlsRecordRecorder>(server_); + Connect(); + + ASSERT_EQ(2U, client_records->count()); // CH, Fin + EXPECT_EQ(ssl_ct_handshake, client_records->record(0).header.content_type()); + EXPECT_EQ(kCtDtlsCiphertext, + (client_records->record(1).header.content_type() & + kCtDtlsCiphertextMask)); + + ASSERT_EQ(6U, server_records->count()); // SH, EE, CT, CV, Fin, Ack + EXPECT_EQ(ssl_ct_handshake, server_records->record(0).header.content_type()); + for (size_t i = 1; i < server_records->count(); ++i) { + EXPECT_EQ(kCtDtlsCiphertext, + (server_records->record(i).header.content_type() & + kCtDtlsCiphertextMask)); + } +} + +class AddSessionIdFilter : public TlsHandshakeFilter { + public: + AddSessionIdFilter(const std::shared_ptr<TlsAgent>& client) + : TlsHandshakeFilter(client, {ssl_hs_client_hello}) {} + + protected: + PacketFilter::Action FilterHandshake(const HandshakeHeader& header, + const DataBuffer& input, + DataBuffer* output) override { + uint32_t session_id_len = 0; + EXPECT_TRUE(input.Read(2 + 32, 1, &session_id_len)); + EXPECT_EQ(0U, session_id_len); + uint8_t session_id[33] = {32}; // 32 for length, the rest zero. + *output = input; + output->Splice(session_id, sizeof(session_id), 34, 1); + return CHANGE; + } +}; + +// Adding a session ID to a DTLS ClientHello should not trigger compatibility +// mode. It should be ignored instead. +TEST_F(TlsConnectDatagram13, CompatModeDtlsServer) { + EnsureTlsSetup(); + auto client_records = std::make_shared<TlsRecordRecorder>(client_); + client_->SetFilter( + std::make_shared<ChainedPacketFilter>(ChainedPacketFilterInit( + {client_records, std::make_shared<AddSessionIdFilter>(client_)}))); + auto server_hello = + std::make_shared<TlsHandshakeRecorder>(server_, kTlsHandshakeServerHello); + auto server_records = std::make_shared<TlsRecordRecorder>(server_); + server_->SetFilter(std::make_shared<ChainedPacketFilter>( + ChainedPacketFilterInit({server_records, server_hello}))); + StartConnect(); + client_->Handshake(); + server_->Handshake(); + // The client will consume the ServerHello, but discard everything else + // because it doesn't decrypt. And don't wait around for the client to ACK. + client_->Handshake(); + + ASSERT_EQ(1U, client_records->count()); + EXPECT_EQ(ssl_ct_handshake, client_records->record(0).header.content_type()); + + ASSERT_EQ(5U, server_records->count()); // SH, EE, CT, CV, Fin + EXPECT_EQ(ssl_ct_handshake, server_records->record(0).header.content_type()); + for (size_t i = 1; i < server_records->count(); ++i) { + EXPECT_EQ(kCtDtlsCiphertext, + (server_records->record(i).header.content_type() & + kCtDtlsCiphertextMask)); + } + + uint32_t session_id_len = 0; + EXPECT_TRUE(server_hello->buffer().Read(2 + 32, 1, &session_id_len)); + EXPECT_EQ(0U, session_id_len); +} + +TEST_F(Tls13CompatTest, ConnectWith12ThenAttemptToResume13CompatMode) { + ConfigureSessionCache(RESUME_SESSIONID, RESUME_SESSIONID); + ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_2); + Connect(); + + Reset(); + ExpectResumption(RESUME_NONE); + version_ = SSL_LIBRARY_VERSION_TLS_1_3; + client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2, + SSL_LIBRARY_VERSION_TLS_1_3); + server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2, + SSL_LIBRARY_VERSION_TLS_1_3); + EnableCompatMode(); + Connect(); +} + +} // namespace nss_test diff --git a/security/nss/gtests/ssl_gtest/ssl_v2_client_hello_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_v2_client_hello_unittest.cc new file mode 100644 index 0000000000..9aa6542d66 --- /dev/null +++ b/security/nss/gtests/ssl_gtest/ssl_v2_client_hello_unittest.cc @@ -0,0 +1,414 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set ts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this file, + * You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#include "pk11pub.h" +#include "ssl.h" +#include "sslerr.h" +#include "sslproto.h" + +extern "C" { +// This is not something that should make you happy. +#include "libssl_internals.h" +} + +#include "gtest_utils.h" +#include "tls_connect.h" +#include "tls_filter.h" + +namespace nss_test { + +// Replaces the client hello with an SSLv2 version once. +class SSLv2ClientHelloFilter : public PacketFilter { + public: + SSLv2ClientHelloFilter(const std::shared_ptr<TlsAgent>& client, + uint16_t version) + : replaced_(false), + client_(client), + version_(version), + pad_len_(0), + reported_pad_len_(0), + client_random_len_(16), + ciphers_(0), + send_escape_(false) {} + + void SetVersion(uint16_t version) { version_ = version; } + + void SetCipherSuites(const std::vector<uint16_t>& ciphers) { + ciphers_ = ciphers; + } + + // Set a padding length and announce it correctly. + void SetPadding(uint8_t pad_len) { SetPadding(pad_len, pad_len); } + + // Set a padding length and allow to lie about its length. + void SetPadding(uint8_t pad_len, uint8_t reported_pad_len) { + pad_len_ = pad_len; + reported_pad_len_ = reported_pad_len; + } + + void SetClientRandomLength(uint16_t client_random_len) { + client_random_len_ = client_random_len; + } + + void SetSendEscape(bool send_escape) { send_escape_ = send_escape; } + + protected: + virtual PacketFilter::Action Filter(const DataBuffer& input, + DataBuffer* output) { + if (replaced_) { + return KEEP; + } + + // Replace only the very first packet. + replaced_ = true; + + // The SSLv2 client hello size. + size_t packet_len = SSL_HL_CLIENT_HELLO_HBYTES + (ciphers_.size() * 3) + + client_random_len_ + pad_len_; + + size_t idx = 0; + *output = input; + output->Allocate(packet_len); + output->Truncate(packet_len); + + // Write record length. + if (pad_len_ > 0) { + size_t masked_len = 0x3fff & packet_len; + if (send_escape_) { + masked_len |= 0x4000; + } + + idx = output->Write(idx, masked_len, 2); + idx = output->Write(idx, reported_pad_len_, 1); + } else { + PR_ASSERT(!send_escape_); + idx = output->Write(idx, 0x8000 | packet_len, 2); + } + + // Remember header length. + size_t hdr_len = idx; + + // Write client hello. + idx = output->Write(idx, SSL_MT_CLIENT_HELLO, 1); + idx = output->Write(idx, version_, 2); + + // Cipher list length. + idx = output->Write(idx, (ciphers_.size() * 3), 2); + + // Session ID length. + idx = output->Write(idx, static_cast<uint32_t>(0), 2); + + // ClientRandom length. + idx = output->Write(idx, client_random_len_, 2); + + // Cipher suites. + for (auto cipher : ciphers_) { + idx = output->Write(idx, static_cast<uint32_t>(cipher), 3); + } + + // Challenge. + std::vector<uint8_t> challenge(client_random_len_); + PK11_GenerateRandom(challenge.data(), challenge.size()); + idx = output->Write(idx, challenge.data(), challenge.size()); + + // Add padding if any. + if (pad_len_ > 0) { + std::vector<uint8_t> pad(pad_len_); + idx = output->Write(idx, pad.data(), pad.size()); + } + + // Update the client random so that the handshake succeeds. + SECStatus rv = SSLInt_UpdateSSLv2ClientRandom( + client_.lock()->ssl_fd(), challenge.data(), challenge.size(), + output->data() + hdr_len, output->len() - hdr_len); + EXPECT_EQ(SECSuccess, rv); + + return CHANGE; + } + + private: + bool replaced_; + std::weak_ptr<TlsAgent> client_; + uint16_t version_; + uint8_t pad_len_; + uint8_t reported_pad_len_; + uint16_t client_random_len_; + std::vector<uint16_t> ciphers_; + bool send_escape_; +}; + +class SSLv2ClientHelloTestF : public TlsConnectTestBase { + public: + SSLv2ClientHelloTestF() + : TlsConnectTestBase(ssl_variant_stream, 0), filter_(nullptr) {} + + SSLv2ClientHelloTestF(SSLProtocolVariant variant, uint16_t version) + : TlsConnectTestBase(variant, version), filter_(nullptr) {} + + void SetUp() override { + TlsConnectTestBase::SetUp(); + filter_ = MakeTlsFilter<SSLv2ClientHelloFilter>(client_, version_); + server_->SetOption(SSL_ENABLE_V2_COMPATIBLE_HELLO, PR_TRUE); + } + + void SetExpectedVersion(uint16_t version) { + TlsConnectTestBase::SetExpectedVersion(version); + filter_->SetVersion(version); + } + + void SetAvailableCipherSuite(uint16_t cipher) { + filter_->SetCipherSuites(std::vector<uint16_t>(1, cipher)); + } + + void SetAvailableCipherSuites(const std::vector<uint16_t>& ciphers) { + filter_->SetCipherSuites(ciphers); + } + + void SetPadding(uint8_t pad_len) { filter_->SetPadding(pad_len); } + + void SetPadding(uint8_t pad_len, uint8_t reported_pad_len) { + filter_->SetPadding(pad_len, reported_pad_len); + } + + void SetClientRandomLength(uint16_t client_random_len) { + filter_->SetClientRandomLength(client_random_len); + } + + void SetSendEscape(bool send_escape) { filter_->SetSendEscape(send_escape); } + + private: + std::shared_ptr<SSLv2ClientHelloFilter> filter_; +}; + +// Parameterized version of SSLv2ClientHelloTestF we can +// use with TEST_P to test multiple TLS versions easily. +class SSLv2ClientHelloTest : public SSLv2ClientHelloTestF, + public ::testing::WithParamInterface<uint16_t> { + public: + SSLv2ClientHelloTest() + : SSLv2ClientHelloTestF(ssl_variant_stream, GetParam()) {} +}; + +// Test negotiating TLS 1.0 - 1.2. +TEST_P(SSLv2ClientHelloTest, Connect) { + SetAvailableCipherSuite(TLS_DHE_RSA_WITH_AES_128_CBC_SHA); + Connect(); +} + +TEST_P(SSLv2ClientHelloTest, ConnectDisabled) { + server_->SetOption(SSL_ENABLE_V2_COMPATIBLE_HELLO, PR_FALSE); + SetAvailableCipherSuite(TLS_DHE_RSA_WITH_AES_128_CBC_SHA); + + StartConnect(); + client_->Handshake(); // Send the modified ClientHello. + server_->Handshake(); // Read some. + // The problem here is that the v2 ClientHello puts the version where the v3 + // ClientHello puts a version number. So the version number (0x0301+) appears + // to be a length and server blocks waiting for that much data. + EXPECT_EQ(PR_WOULD_BLOCK_ERROR, PORT_GetError()); + + // This is usually what happens with v2-compatible: the server hangs. + // But to be certain, feed in more data to see if an error comes out. + uint8_t zeros[SSL_LIBRARY_VERSION_TLS_1_2] = {0}; + client_->SendDirect(DataBuffer(zeros, sizeof(zeros))); + ExpectAlert(server_, kTlsAlertUnexpectedMessage); + server_->Handshake(); + client_->Handshake(); +} + +// Sending a v2 ClientHello after a no-op v3 record must fail. +TEST_P(SSLv2ClientHelloTest, ConnectAfterEmptyV3Record) { + DataBuffer buffer; + + size_t idx = 0; + idx = buffer.Write(idx, 0x16, 1); // handshake + idx = buffer.Write(idx, 0x0301, 2); // record_version + (void)buffer.Write(idx, 0U, 2); // length=0 + + SetAvailableCipherSuite(TLS_DHE_RSA_WITH_AES_128_CBC_SHA); + EnsureTlsSetup(); + client_->SendDirect(buffer); + + // Need padding so the connection doesn't just time out. With a v2 + // ClientHello parsed as a v3 record we will use the record version + // as the record length. + SetPadding(255); + + ConnectExpectAlert(server_, kTlsAlertUnexpectedMessage); + EXPECT_EQ(SSL_ERROR_RX_UNKNOWN_RECORD_TYPE, server_->error_code()); +} + +// Test negotiating TLS 1.3. +TEST_F(SSLv2ClientHelloTestF, Connect13) { + EnsureTlsSetup(); + SetExpectedVersion(SSL_LIBRARY_VERSION_TLS_1_3); + ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3); + + std::vector<uint16_t> cipher_suites = {TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256}; + SetAvailableCipherSuites(cipher_suites); + + ConnectExpectAlert(server_, kTlsAlertIllegalParameter); + EXPECT_EQ(SSL_ERROR_RX_MALFORMED_CLIENT_HELLO, server_->error_code()); +} + +// Test negotiating an EC suite. +TEST_P(SSLv2ClientHelloTest, NegotiateECSuite) { + SetAvailableCipherSuite(TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA); + Connect(); +} + +// Test negotiating TLS 1.0 - 1.2 with a padded client hello. +TEST_P(SSLv2ClientHelloTest, AddPadding) { + SetAvailableCipherSuite(TLS_DHE_RSA_WITH_AES_128_CBC_SHA); + SetPadding(255); + Connect(); +} + +// Test that sending a security escape fails the handshake. +TEST_P(SSLv2ClientHelloTest, SendSecurityEscape) { + SetAvailableCipherSuite(TLS_DHE_RSA_WITH_AES_128_CBC_SHA); + + // Send a security escape. + SetSendEscape(true); + + // Set a big padding so that the server fails instead of timing out. + SetPadding(255); + + ConnectExpectAlert(server_, kTlsAlertUnexpectedMessage); +} + +// Invalid SSLv2 client hello padding must fail the handshake. +TEST_P(SSLv2ClientHelloTest, AddErroneousPadding) { + SetAvailableCipherSuite(TLS_DHE_RSA_WITH_AES_128_CBC_SHA); + + // Append 5 bytes of padding but say it's only 4. + SetPadding(5, 4); + + ConnectExpectAlert(server_, kTlsAlertIllegalParameter); + EXPECT_EQ(SSL_ERROR_RX_MALFORMED_CLIENT_HELLO, server_->error_code()); +} + +// Invalid SSLv2 client hello padding must fail the handshake. +TEST_P(SSLv2ClientHelloTest, AddErroneousPadding2) { + SetAvailableCipherSuite(TLS_DHE_RSA_WITH_AES_128_CBC_SHA); + + // Append 5 bytes of padding but say it's 6. + SetPadding(5, 6); + + ConnectExpectAlert(server_, kTlsAlertIllegalParameter); + EXPECT_EQ(SSL_ERROR_RX_MALFORMED_CLIENT_HELLO, server_->error_code()); +} + +// Wrong amount of bytes for the ClientRandom must fail the handshake. +TEST_P(SSLv2ClientHelloTest, SmallClientRandom) { + SetAvailableCipherSuite(TLS_DHE_RSA_WITH_AES_128_CBC_SHA); + + // Send a ClientRandom that's too small. + SetClientRandomLength(15); + + ConnectExpectAlert(server_, kTlsAlertIllegalParameter); + EXPECT_EQ(SSL_ERROR_RX_MALFORMED_CLIENT_HELLO, server_->error_code()); +} + +// Test sending the maximum accepted number of ClientRandom bytes. +TEST_P(SSLv2ClientHelloTest, MaxClientRandom) { + SetAvailableCipherSuite(TLS_DHE_RSA_WITH_AES_128_CBC_SHA); + SetClientRandomLength(32); + Connect(); +} + +// Wrong amount of bytes for the ClientRandom must fail the handshake. +TEST_P(SSLv2ClientHelloTest, BigClientRandom) { + SetAvailableCipherSuite(TLS_DHE_RSA_WITH_AES_128_CBC_SHA); + + // Send a ClientRandom that's too big. + SetClientRandomLength(33); + + ConnectExpectAlert(server_, kTlsAlertIllegalParameter); + EXPECT_EQ(SSL_ERROR_RX_MALFORMED_CLIENT_HELLO, server_->error_code()); +} + +// Connection must fail if we require safe renegotiation but the client doesn't +// include TLS_EMPTY_RENEGOTIATION_INFO_SCSV in the list of cipher suites. +TEST_P(SSLv2ClientHelloTest, RequireSafeRenegotiation) { + server_->SetOption(SSL_REQUIRE_SAFE_NEGOTIATION, PR_TRUE); + SetAvailableCipherSuite(TLS_DHE_RSA_WITH_AES_128_CBC_SHA); + ConnectExpectAlert(server_, kTlsAlertHandshakeFailure); + EXPECT_EQ(SSL_ERROR_UNSAFE_NEGOTIATION, server_->error_code()); +} + +// Connection must succeed when requiring safe renegotiation and the client +// includes TLS_EMPTY_RENEGOTIATION_INFO_SCSV in the list of cipher suites. +TEST_P(SSLv2ClientHelloTest, RequireSafeRenegotiationWithSCSV) { + server_->SetOption(SSL_REQUIRE_SAFE_NEGOTIATION, PR_TRUE); + std::vector<uint16_t> cipher_suites = {TLS_DHE_RSA_WITH_AES_128_CBC_SHA, + TLS_EMPTY_RENEGOTIATION_INFO_SCSV}; + SetAvailableCipherSuites(cipher_suites); + Connect(); +} + +TEST_P(SSLv2ClientHelloTest, CheckServerRandom) { + ConfigureSessionCache(RESUME_NONE, RESUME_NONE); + SetAvailableCipherSuite(TLS_DHE_RSA_WITH_AES_128_CBC_SHA); + + static const size_t random_len = 32; + uint8_t srandom1[random_len]; + uint8_t z[random_len] = {0}; + + auto sh = MakeTlsFilter<TlsHandshakeRecorder>(server_, ssl_hs_server_hello); + Connect(); + ASSERT_TRUE(sh->buffer().len() > (random_len + 2)); + memcpy(srandom1, sh->buffer().data() + 2, random_len); + EXPECT_NE(0, memcmp(srandom1, z, random_len)); + + Reset(); + sh = MakeTlsFilter<TlsHandshakeRecorder>(server_, ssl_hs_server_hello); + Connect(); + ASSERT_TRUE(sh->buffer().len() > (random_len + 2)); + const uint8_t* srandom2 = sh->buffer().data() + 2; + + EXPECT_NE(0, memcmp(srandom2, z, random_len)); + EXPECT_NE(0, memcmp(srandom1, srandom2, random_len)); +} + +// Connect to the server with TLS 1.1, signalling that this is a fallback from +// a higher version. As the server doesn't support anything higher than TLS 1.1 +// it must accept the connection. +TEST_F(SSLv2ClientHelloTestF, FallbackSCSV) { + EnsureTlsSetup(); + SetExpectedVersion(SSL_LIBRARY_VERSION_TLS_1_1); + ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_1); + + std::vector<uint16_t> cipher_suites = {TLS_DHE_RSA_WITH_AES_128_CBC_SHA, + TLS_FALLBACK_SCSV}; + SetAvailableCipherSuites(cipher_suites); + Connect(); +} + +// Connect to the server with TLS 1.1, signalling that this is a fallback from +// a higher version. As the server supports TLS 1.2 though it must reject the +// connection due to a possible downgrade attack. +TEST_F(SSLv2ClientHelloTestF, InappropriateFallbackSCSV) { + SetExpectedVersion(SSL_LIBRARY_VERSION_TLS_1_1); + client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_1, + SSL_LIBRARY_VERSION_TLS_1_1); + server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_1, + SSL_LIBRARY_VERSION_TLS_1_2); + + std::vector<uint16_t> cipher_suites = {TLS_DHE_RSA_WITH_AES_128_CBC_SHA, + TLS_FALLBACK_SCSV}; + SetAvailableCipherSuites(cipher_suites); + + ConnectExpectAlert(server_, kTlsAlertInappropriateFallback); + EXPECT_EQ(SSL_ERROR_INAPPROPRIATE_FALLBACK_ALERT, server_->error_code()); +} + +INSTANTIATE_TEST_SUITE_P(VersionsStream10Pre13, SSLv2ClientHelloTest, + TlsConnectTestBase::kTlsV10); +INSTANTIATE_TEST_SUITE_P(VersionsStreamPre13, SSLv2ClientHelloTest, + TlsConnectTestBase::kTlsV11V12); + +} // namespace nss_test diff --git a/security/nss/gtests/ssl_gtest/ssl_version_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_version_unittest.cc new file mode 100644 index 0000000000..079d865a8f --- /dev/null +++ b/security/nss/gtests/ssl_gtest/ssl_version_unittest.cc @@ -0,0 +1,456 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set ts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this file, + * You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#include "secerr.h" +#include "ssl.h" +#include "ssl3prot.h" +#include "sslerr.h" +#include "sslproto.h" + +#include "gtest_utils.h" +#include "nss_scoped_ptrs.h" +#include "tls_connect.h" +#include "tls_filter.h" +#include "tls_parser.h" + +namespace nss_test { + +TEST_P(TlsConnectStream, ServerNegotiateTls10) { + uint16_t minver, maxver; + client_->GetVersionRange(&minver, &maxver); + client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_0, maxver); + server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_0, + SSL_LIBRARY_VERSION_TLS_1_0); + Connect(); +} + +TEST_P(TlsConnectGeneric, ServerNegotiateTls11) { + if (version_ < SSL_LIBRARY_VERSION_TLS_1_1) GTEST_SKIP(); + + uint16_t minver, maxver; + client_->GetVersionRange(&minver, &maxver); + client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_1, maxver); + server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_1, + SSL_LIBRARY_VERSION_TLS_1_1); + Connect(); +} + +TEST_P(TlsConnectGeneric, ServerNegotiateTls12) { + if (version_ < SSL_LIBRARY_VERSION_TLS_1_2) GTEST_SKIP(); + + uint16_t minver, maxver; + client_->GetVersionRange(&minver, &maxver); + client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2, maxver); + server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2, + SSL_LIBRARY_VERSION_TLS_1_2); + Connect(); +} + +// Test the ServerRandom version hack from +// [draft-ietf-tls-tls13-11 Section 6.3.1.1]. +// The first three tests test for active tampering. The next +// two validate that we can also detect fallback using the +// SSL_SetDowngradeCheckVersion() API. +TEST_F(TlsConnectTest, TestDowngradeDetectionToTls11) { + client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_0, + SSL_LIBRARY_VERSION_TLS_1_2); + server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_0, + SSL_LIBRARY_VERSION_TLS_1_2); + client_->SetOption(SSL_ENABLE_HELLO_DOWNGRADE_CHECK, PR_TRUE); + MakeTlsFilter<TlsMessageVersionSetter>(client_, kTlsHandshakeClientHello, + SSL_LIBRARY_VERSION_TLS_1_1); + ConnectExpectAlert(client_, kTlsAlertIllegalParameter); + client_->CheckErrorCode(SSL_ERROR_RX_MALFORMED_SERVER_HELLO); + server_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT); +} + +// Attempt to negotiate the bogus DTLS 1.1 version. +TEST_F(DtlsConnectTest, TestDtlsVersion11) { + MakeTlsFilter<TlsMessageVersionSetter>(client_, kTlsHandshakeClientHello, + ((~0x0101) & 0xffff)); + ConnectExpectAlert(server_, kTlsAlertProtocolVersion); + client_->CheckErrorCode(SSL_ERROR_PROTOCOL_VERSION_ALERT); + server_->CheckErrorCode(SSL_ERROR_UNSUPPORTED_VERSION); +} + +TEST_F(TlsConnectTest, TestDowngradeDetectionToTls12) { + client_->SetOption(SSL_ENABLE_HELLO_DOWNGRADE_CHECK, PR_TRUE); + MakeTlsFilter<TlsExtensionDropper>(client_, ssl_tls13_supported_versions_xtn); + client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2, + SSL_LIBRARY_VERSION_TLS_1_3); + server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2, + SSL_LIBRARY_VERSION_TLS_1_3); + ConnectExpectAlert(client_, kTlsAlertIllegalParameter); + client_->CheckErrorCode(SSL_ERROR_RX_MALFORMED_SERVER_HELLO); + server_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT); +} + +// Disabling downgrade checks will be caught when the Finished MAC check fails. +TEST_F(TlsConnectTest, TestDisableDowngradeDetection) { + client_->SetOption(SSL_ENABLE_HELLO_DOWNGRADE_CHECK, PR_FALSE); + MakeTlsFilter<TlsExtensionDropper>(client_, ssl_tls13_supported_versions_xtn); + client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2, + SSL_LIBRARY_VERSION_TLS_1_3); + server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2, + SSL_LIBRARY_VERSION_TLS_1_3); + ConnectExpectAlert(server_, kTlsAlertDecryptError); + client_->CheckErrorCode(SSL_ERROR_DECRYPT_ERROR_ALERT); + server_->CheckErrorCode(SSL_ERROR_BAD_HANDSHAKE_HASH_VALUE); +} + +typedef std::tuple<SSLProtocolVariant, + uint16_t, // client version + uint16_t> // server version + TlsDowngradeProfile; + +class TlsDowngradeTest + : public TlsConnectTestBase, + public ::testing::WithParamInterface<TlsDowngradeProfile> { + public: + TlsDowngradeTest() + : TlsConnectTestBase(std::get<0>(GetParam()), std::get<1>(GetParam())), + c_ver(std::get<1>(GetParam())), + s_ver(std::get<2>(GetParam())) {} + + protected: + const uint16_t c_ver; + const uint16_t s_ver; +}; + +TEST_P(TlsDowngradeTest, TlsDowngradeSentinelTest) { + static const uint8_t tls12_downgrade_random[] = {0x44, 0x4F, 0x57, 0x4E, + 0x47, 0x52, 0x44, 0x01}; + static const uint8_t tls1_downgrade_random[] = {0x44, 0x4F, 0x57, 0x4E, + 0x47, 0x52, 0x44, 0x00}; + static const size_t kRandomLen = 32; + + if (c_ver > s_ver) { + GTEST_SKIP(); + } + + client_->SetVersionRange(c_ver, c_ver); + server_->SetVersionRange(c_ver, s_ver); + + auto sh = MakeTlsFilter<TlsHandshakeRecorder>(server_, ssl_hs_server_hello); + Connect(); + ASSERT_TRUE(sh->buffer().len() > (kRandomLen + 2)); + + const uint8_t* downgrade_sentinel = + sh->buffer().data() + 2 + kRandomLen - sizeof(tls1_downgrade_random); + if (c_ver < s_ver) { + if (c_ver == SSL_LIBRARY_VERSION_TLS_1_2) { + EXPECT_EQ(0, memcmp(downgrade_sentinel, tls12_downgrade_random, + sizeof(tls12_downgrade_random))); + } else { + EXPECT_EQ(0, memcmp(downgrade_sentinel, tls1_downgrade_random, + sizeof(tls1_downgrade_random))); + } + } else { + EXPECT_NE(0, memcmp(downgrade_sentinel, tls12_downgrade_random, + sizeof(tls12_downgrade_random))); + EXPECT_NE(0, memcmp(downgrade_sentinel, tls1_downgrade_random, + sizeof(tls1_downgrade_random))); + } +} + +// TLS 1.1 clients do not check the random values, so we should +// instead get a handshake failure alert from the server. +TEST_F(TlsConnectTest, TestDowngradeDetectionToTls10) { + // Setting the option here has no effect. + client_->SetOption(SSL_ENABLE_HELLO_DOWNGRADE_CHECK, PR_TRUE); + MakeTlsFilter<TlsMessageVersionSetter>(client_, kTlsHandshakeClientHello, + SSL_LIBRARY_VERSION_TLS_1_0); + client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_0, + SSL_LIBRARY_VERSION_TLS_1_1); + server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_0, + SSL_LIBRARY_VERSION_TLS_1_2); + ConnectExpectAlert(server_, kTlsAlertDecryptError); + server_->CheckErrorCode(SSL_ERROR_BAD_HANDSHAKE_HASH_VALUE); + client_->CheckErrorCode(SSL_ERROR_DECRYPT_ERROR_ALERT); +} + +TEST_F(TlsConnectTest, TestFallbackFromTls12) { + client_->SetOption(SSL_ENABLE_HELLO_DOWNGRADE_CHECK, PR_TRUE); + client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_1, + SSL_LIBRARY_VERSION_TLS_1_1); + server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_1, + SSL_LIBRARY_VERSION_TLS_1_2); + client_->SetDowngradeCheckVersion(SSL_LIBRARY_VERSION_TLS_1_2); + ConnectExpectAlert(client_, kTlsAlertIllegalParameter); + client_->CheckErrorCode(SSL_ERROR_RX_MALFORMED_SERVER_HELLO); + server_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT); +} + +static SECStatus AllowFalseStart(PRFileDesc* fd, void* arg, + PRBool* can_false_start) { + bool* false_start_attempted = reinterpret_cast<bool*>(arg); + *false_start_attempted = true; + *can_false_start = PR_TRUE; + return SECSuccess; +} + +// If we disable the downgrade check, the sentinel is still generated, and we +// disable false start instead. +TEST_F(TlsConnectTest, DisableFalseStartOnFallback) { + // Don't call client_->EnableFalseStart(), because that sets the client up for + // success, and we want false start to fail. + client_->SetOption(SSL_ENABLE_FALSE_START, PR_TRUE); + bool false_start_attempted = false; + EXPECT_EQ(SECSuccess, + SSL_SetCanFalseStartCallback(client_->ssl_fd(), AllowFalseStart, + &false_start_attempted)); + + client_->SetOption(SSL_ENABLE_HELLO_DOWNGRADE_CHECK, PR_FALSE); + client_->SetDowngradeCheckVersion(SSL_LIBRARY_VERSION_TLS_1_3); + client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2, + SSL_LIBRARY_VERSION_TLS_1_2); + server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2, + SSL_LIBRARY_VERSION_TLS_1_3); + Connect(); + EXPECT_FALSE(false_start_attempted); +} + +TEST_F(TlsConnectTest, TestFallbackFromTls13) { + client_->SetOption(SSL_ENABLE_HELLO_DOWNGRADE_CHECK, PR_TRUE); + client_->SetDowngradeCheckVersion(SSL_LIBRARY_VERSION_TLS_1_3); + client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2, + SSL_LIBRARY_VERSION_TLS_1_2); + server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_1, + SSL_LIBRARY_VERSION_TLS_1_3); + ConnectExpectAlert(client_, kTlsAlertIllegalParameter); + client_->CheckErrorCode(SSL_ERROR_RX_MALFORMED_SERVER_HELLO); + server_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT); +} + +TEST_P(TlsConnectGeneric, TestFallbackSCSVVersionMatch) { + client_->SetOption(SSL_ENABLE_FALLBACK_SCSV, PR_TRUE); + Connect(); +} + +TEST_P(TlsConnectGenericPre13, TestFallbackSCSVVersionMismatch) { + client_->SetOption(SSL_ENABLE_FALLBACK_SCSV, PR_TRUE); + server_->SetVersionRange(version_, version_ + 1); + ConnectExpectAlert(server_, kTlsAlertInappropriateFallback); + client_->CheckErrorCode(SSL_ERROR_INAPPROPRIATE_FALLBACK_ALERT); + server_->CheckErrorCode(SSL_ERROR_INAPPROPRIATE_FALLBACK_ALERT); +} + +// The TLS v1.3 spec section C.4 states that 'Implementations MUST NOT send or +// accept any records with a version less than { 3, 0 }'. Thus we will not +// allow version ranges including both SSL v3 and TLS v1.3. +TEST_F(TlsConnectTest, DisallowSSLv3HelloWithTLSv13Enabled) { + SECStatus rv; + SSLVersionRange vrange = {SSL_LIBRARY_VERSION_3_0, + SSL_LIBRARY_VERSION_TLS_1_3}; + + EnsureTlsSetup(); + rv = SSL_VersionRangeSet(client_->ssl_fd(), &vrange); + EXPECT_EQ(SECFailure, rv); + + rv = SSL_VersionRangeSet(server_->ssl_fd(), &vrange); + EXPECT_EQ(SECFailure, rv); +} + +TEST_P(TlsConnectGeneric, AlertBeforeServerHello) { + EnsureTlsSetup(); + client_->ExpectReceiveAlert(kTlsAlertUnrecognizedName, kTlsAlertWarning); + StartConnect(); + client_->Handshake(); // Send ClientHello. + static const uint8_t kWarningAlert[] = {kTlsAlertWarning, + kTlsAlertUnrecognizedName}; + DataBuffer alert; + TlsAgentTestBase::MakeRecord(variant_, ssl_ct_alert, + SSL_LIBRARY_VERSION_TLS_1_0, kWarningAlert, + PR_ARRAY_SIZE(kWarningAlert), &alert); + client_->adapter()->PacketReceived(alert); + Handshake(); + CheckConnected(); +} + +class Tls13NoSupportedVersions : public TlsConnectStreamTls12 { + protected: + void Run(uint16_t overwritten_client_version, uint16_t max_server_version) { + client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2, + SSL_LIBRARY_VERSION_TLS_1_2); + server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2, max_server_version); + MakeTlsFilter<TlsMessageVersionSetter>(client_, kTlsHandshakeClientHello, + overwritten_client_version); + auto capture = + MakeTlsFilter<TlsHandshakeRecorder>(server_, kTlsHandshakeServerHello); + ConnectExpectAlert(server_, kTlsAlertDecryptError); + client_->CheckErrorCode(SSL_ERROR_DECRYPT_ERROR_ALERT); + server_->CheckErrorCode(SSL_ERROR_BAD_HANDSHAKE_HASH_VALUE); + const DataBuffer& server_hello = capture->buffer(); + ASSERT_GT(server_hello.len(), 2U); + uint32_t ver; + ASSERT_TRUE(server_hello.Read(0, 2, &ver)); + ASSERT_EQ(static_cast<uint32_t>(SSL_LIBRARY_VERSION_TLS_1_2), ver); + } +}; + +// If we offer a 1.3 ClientHello w/o supported_versions, the server should +// negotiate 1.2. +TEST_F(Tls13NoSupportedVersions, + Tls13ClientHelloWithoutSupportedVersionsServer12) { + Run(SSL_LIBRARY_VERSION_TLS_1_3, SSL_LIBRARY_VERSION_TLS_1_2); +} + +TEST_F(Tls13NoSupportedVersions, + Tls13ClientHelloWithoutSupportedVersionsServer13) { + Run(SSL_LIBRARY_VERSION_TLS_1_3, SSL_LIBRARY_VERSION_TLS_1_3); +} + +TEST_F(Tls13NoSupportedVersions, + Tls14ClientHelloWithoutSupportedVersionsServer13) { + Run(SSL_LIBRARY_VERSION_TLS_1_3 + 1, SSL_LIBRARY_VERSION_TLS_1_3); +} + +// Offer 1.3 but with ClientHello.legacy_version == TLS 1.4. This +// causes a bad MAC error when we read EncryptedExtensions. +TEST_F(TlsConnectStreamTls13, Tls14ClientHelloWithSupportedVersions) { + MakeTlsFilter<TlsMessageVersionSetter>(client_, kTlsHandshakeClientHello, + SSL_LIBRARY_VERSION_TLS_1_3 + 1); + auto capture = MakeTlsFilter<TlsExtensionCapture>( + server_, ssl_tls13_supported_versions_xtn); + client_->ExpectSendAlert(kTlsAlertBadRecordMac); + server_->ExpectSendAlert(kTlsAlertBadRecordMac); + ConnectExpectFail(); + client_->CheckErrorCode(SSL_ERROR_BAD_MAC_READ); + server_->CheckErrorCode(SSL_ERROR_BAD_MAC_READ); + + ASSERT_EQ(2U, capture->extension().len()); + uint32_t version = 0; + ASSERT_TRUE(capture->extension().Read(0, 2, &version)); + // This way we don't need to change with new draft version. + ASSERT_LT(static_cast<uint32_t>(SSL_LIBRARY_VERSION_TLS_1_2), version); +} + +// Offer 1.3 but with Server/ClientHello.legacy_version == SSL 3.0. This +// causes a protocol version alert. See RFC 8446 Appendix D.5. +TEST_F(TlsConnectStreamTls13, Ssl30ClientHelloWithSupportedVersions) { + MakeTlsFilter<TlsMessageVersionSetter>(client_, kTlsHandshakeClientHello, + SSL_LIBRARY_VERSION_3_0); + ConnectExpectAlert(server_, kTlsAlertProtocolVersion); +} + +TEST_F(TlsConnectStreamTls13, Ssl30ServerHelloWithSupportedVersions) { + MakeTlsFilter<TlsMessageVersionSetter>(server_, kTlsHandshakeServerHello, + SSL_LIBRARY_VERSION_3_0); + StartConnect(); + client_->ExpectSendAlert(kTlsAlertProtocolVersion); + /* Since the handshake is not finished the client will send an unencrypted + * alert. The server is expected to close the connection with a unexpected + * message alert. */ + server_->ExpectSendAlert(kTlsAlertUnexpectedMessage); + Handshake(); +} + +// Verify the client sends only DTLS versions in supported_versions +TEST_F(DtlsConnectTest, DtlsSupportedVersionsEncoding) { + client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_1, + SSL_LIBRARY_VERSION_TLS_1_3); + server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_1, + SSL_LIBRARY_VERSION_TLS_1_3); + auto capture = MakeTlsFilter<TlsExtensionCapture>( + client_, ssl_tls13_supported_versions_xtn); + Connect(); + + ASSERT_EQ(7U, capture->extension().len()); + uint32_t version = 0; + ASSERT_TRUE(capture->extension().Read(1, 2, &version)); + EXPECT_EQ(0x7f00 | DTLS_1_3_DRAFT_VERSION, static_cast<int>(version)); + ASSERT_TRUE(capture->extension().Read(3, 2, &version)); + EXPECT_EQ(SSL_LIBRARY_VERSION_DTLS_1_2_WIRE, static_cast<int>(version)); + ASSERT_TRUE(capture->extension().Read(5, 2, &version)); + EXPECT_EQ(SSL_LIBRARY_VERSION_DTLS_1_0_WIRE, static_cast<int>(version)); +} + +// Verify the DTLS 1.3 supported_versions interop workaround. +TEST_F(DtlsConnectTest, Dtls13VersionWorkaround) { + static const uint16_t kExpectVersionsWorkaround[] = { + 0x7f00 | DTLS_1_3_DRAFT_VERSION, SSL_LIBRARY_VERSION_DTLS_1_2_WIRE, + SSL_LIBRARY_VERSION_TLS_1_2, SSL_LIBRARY_VERSION_DTLS_1_0_WIRE, + SSL_LIBRARY_VERSION_TLS_1_1}; + const int min_ver = SSL_LIBRARY_VERSION_TLS_1_1, + max_ver = SSL_LIBRARY_VERSION_TLS_1_3; + + // Toggle the workaround, then verify both encodings are present. + EnsureTlsSetup(); + SSL_SetDtls13VersionWorkaround(client_->ssl_fd(), PR_TRUE); + SSL_SetDtls13VersionWorkaround(client_->ssl_fd(), PR_FALSE); + SSL_SetDtls13VersionWorkaround(client_->ssl_fd(), PR_TRUE); + client_->SetVersionRange(min_ver, max_ver); + server_->SetVersionRange(min_ver, max_ver); + auto capture = MakeTlsFilter<TlsExtensionCapture>( + client_, ssl_tls13_supported_versions_xtn); + Connect(); + + uint32_t version = 0; + size_t off = 1; + ASSERT_EQ(1 + sizeof(kExpectVersionsWorkaround), capture->extension().len()); + for (unsigned int i = 0; i < PR_ARRAY_SIZE(kExpectVersionsWorkaround); i++) { + ASSERT_TRUE(capture->extension().Read(off, 2, &version)); + EXPECT_EQ(kExpectVersionsWorkaround[i], static_cast<uint16_t>(version)); + off += 2; + } +} + +// Verify the client sends only TLS versions in supported_versions +TEST_F(TlsConnectTest, TlsSupportedVersionsEncoding) { + client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_0, + SSL_LIBRARY_VERSION_TLS_1_3); + server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_0, + SSL_LIBRARY_VERSION_TLS_1_3); + auto capture = MakeTlsFilter<TlsExtensionCapture>( + client_, ssl_tls13_supported_versions_xtn); + Connect(); + + ASSERT_EQ(9U, capture->extension().len()); + uint32_t version = 0; + ASSERT_TRUE(capture->extension().Read(1, 2, &version)); + EXPECT_EQ(SSL_LIBRARY_VERSION_TLS_1_3, static_cast<int>(version)); + ASSERT_TRUE(capture->extension().Read(3, 2, &version)); + EXPECT_EQ(SSL_LIBRARY_VERSION_TLS_1_2, static_cast<int>(version)); + ASSERT_TRUE(capture->extension().Read(5, 2, &version)); + EXPECT_EQ(SSL_LIBRARY_VERSION_TLS_1_1, static_cast<int>(version)); + ASSERT_TRUE(capture->extension().Read(7, 2, &version)); + EXPECT_EQ(SSL_LIBRARY_VERSION_TLS_1_0, static_cast<int>(version)); +} + +/* Test that on reception of unsupported ClientHello.legacy_version the TLS 1.3 + * server sends the correct alert. + * + * If the "supported_versions" extension is absent and the server only supports + * versions greater than ClientHello.legacy_version, the server MUST abort the + * handshake with a "protocol_version" alert [RFC8446, Appendix D.2]. */ +TEST_P(TlsConnectGenericPre13, ClientHelloUnsupportedTlsVersion) { + StartConnect(); + + if (variant_ == ssl_variant_stream) { + server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_3, + SSL_LIBRARY_VERSION_TLS_1_3); + } else { + server_->SetVersionRange(SSL_LIBRARY_VERSION_DTLS_1_3, + SSL_LIBRARY_VERSION_DTLS_1_3); + } + + // Try to handshake + client_->Handshake(); + // Expect protocol version alert + server_->ExpectSendAlert(kTlsAlertProtocolVersion); + server_->Handshake(); + // Digest alert at peer + client_->ExpectReceiveAlert(kTlsAlertProtocolVersion); + client_->ReadBytes(); +} + +INSTANTIATE_TEST_SUITE_P( + TlsDowngradeSentinelTest, TlsDowngradeTest, + ::testing::Combine(TlsConnectTestBase::kTlsVariantsStream, + TlsConnectTestBase::kTlsVAll, + TlsConnectTestBase::kTlsV12Plus)); + +} // namespace nss_test diff --git a/security/nss/gtests/ssl_gtest/ssl_versionpolicy_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_versionpolicy_unittest.cc new file mode 100644 index 0000000000..91d8080377 --- /dev/null +++ b/security/nss/gtests/ssl_gtest/ssl_versionpolicy_unittest.cc @@ -0,0 +1,385 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set ts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this file, + * You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#include "nss.h" +#include "secerr.h" +#include "ssl.h" +#include "ssl3prot.h" +#include "sslerr.h" +#include "sslproto.h" + +#include "gtest_utils.h" +#include "nss_scoped_ptrs.h" +#include "tls_connect.h" +#include "tls_filter.h" +#include "tls_parser.h" + +#include <iostream> + +namespace nss_test { + +std::string GetSSLVersionString(uint16_t v) { + switch (v) { + case SSL_LIBRARY_VERSION_3_0: + return "ssl3"; + case SSL_LIBRARY_VERSION_TLS_1_0: + return "tls1.0"; + case SSL_LIBRARY_VERSION_TLS_1_1: + return "tls1.1"; + case SSL_LIBRARY_VERSION_TLS_1_2: + return "tls1.2"; + case SSL_LIBRARY_VERSION_TLS_1_3: + return "tls1.3"; + case SSL_LIBRARY_VERSION_NONE: + return "NONE"; + } + if (v < SSL_LIBRARY_VERSION_3_0) { + return "undefined-too-low"; + } + return "undefined-too-high"; +} + +inline std::ostream& operator<<(std::ostream& stream, + const SSLVersionRange& vr) { + return stream << GetSSLVersionString(vr.min) << "," + << GetSSLVersionString(vr.max); +} + +class VersionRangeWithLabel { + public: + VersionRangeWithLabel(const std::string& txt, const SSLVersionRange& vr) + : label_(txt), vr_(vr) {} + VersionRangeWithLabel(const std::string& txt, uint16_t start, uint16_t end) + : label_(txt) { + vr_.min = start; + vr_.max = end; + } + VersionRangeWithLabel(const std::string& label) : label_(label) { + vr_.min = vr_.max = SSL_LIBRARY_VERSION_NONE; + } + + void WriteStream(std::ostream& stream) const { + stream << " " << label_ << ": " << vr_; + } + + uint16_t min() const { return vr_.min; } + uint16_t max() const { return vr_.max; } + SSLVersionRange range() const { return vr_; } + + private: + std::string label_; + SSLVersionRange vr_; +}; + +inline std::ostream& operator<<(std::ostream& stream, + const VersionRangeWithLabel& vrwl) { + vrwl.WriteStream(stream); + return stream; +} + +typedef std::tuple<SSLProtocolVariant, // variant + uint16_t, // policy min + uint16_t, // policy max + uint16_t, // input min + uint16_t> // input max + PolicyVersionRangeInput; + +class TestPolicyVersionRange + : public TlsConnectTestBase, + public ::testing::WithParamInterface<PolicyVersionRangeInput> { + public: + TestPolicyVersionRange() + : TlsConnectTestBase(std::get<0>(GetParam()), 0), + variant_(std::get<0>(GetParam())), + policy_("policy", std::get<1>(GetParam()), std::get<2>(GetParam())), + input_("input", std::get<3>(GetParam()), std::get<4>(GetParam())), + library_("supported-by-library", + ((variant_ == ssl_variant_stream) + ? SSL_LIBRARY_VERSION_MIN_SUPPORTED_STREAM + : SSL_LIBRARY_VERSION_MIN_SUPPORTED_DATAGRAM), + SSL_LIBRARY_VERSION_MAX_SUPPORTED) { + TlsConnectTestBase::SkipVersionChecks(); + } + + void SetPolicy(const SSLVersionRange& policy) { + NSS_SetAlgorithmPolicy(SEC_OID_APPLY_SSL_POLICY, NSS_USE_POLICY_IN_SSL, 0); + + SECStatus rv; + rv = NSS_OptionSet(NSS_TLS_VERSION_MIN_POLICY, policy.min); + ASSERT_EQ(SECSuccess, rv); + rv = NSS_OptionSet(NSS_TLS_VERSION_MAX_POLICY, policy.max); + ASSERT_EQ(SECSuccess, rv); + rv = NSS_OptionSet(NSS_DTLS_VERSION_MIN_POLICY, policy.min); + ASSERT_EQ(SECSuccess, rv); + rv = NSS_OptionSet(NSS_DTLS_VERSION_MAX_POLICY, policy.max); + ASSERT_EQ(SECSuccess, rv); + } + + void CreateDummySocket(std::shared_ptr<DummyPrSocket>* dummy_socket, + ScopedPRFileDesc* ssl_fd) { + (*dummy_socket).reset(new DummyPrSocket("dummy", variant_)); + *ssl_fd = (*dummy_socket)->CreateFD(); + if (variant_ == ssl_variant_stream) { + SSL_ImportFD(nullptr, ssl_fd->get()); + } else { + DTLS_ImportFD(nullptr, ssl_fd->get()); + } + } + + bool GetOverlap(const SSLVersionRange& r1, const SSLVersionRange& r2, + SSLVersionRange* overlap) { + if (r1.min == SSL_LIBRARY_VERSION_NONE || + r1.max == SSL_LIBRARY_VERSION_NONE || + r2.min == SSL_LIBRARY_VERSION_NONE || + r2.max == SSL_LIBRARY_VERSION_NONE) { + return false; + } + + SSLVersionRange temp; + temp.min = PR_MAX(r1.min, r2.min); + temp.max = PR_MIN(r1.max, r2.max); + + if (temp.min > temp.max) { + return false; + } + + *overlap = temp; + return true; + } + + bool IsValidInputForVersionRangeSet(SSLVersionRange* expectedEffectiveRange) { + if (input_.min() <= SSL_LIBRARY_VERSION_3_0 && + input_.max() >= SSL_LIBRARY_VERSION_TLS_1_3) { + // This is always invalid input, independent of policy + return false; + } + + if (input_.min() < library_.min() || input_.max() > library_.max() || + input_.min() > input_.max()) { + // Asking for unsupported ranges is invalid input for VersionRangeSet + // APIs, regardless of overlap. + return false; + } + + SSLVersionRange overlap_with_library; + if (!GetOverlap(input_.range(), library_.range(), &overlap_with_library)) { + return false; + } + + SSLVersionRange overlap_with_library_and_policy; + if (!GetOverlap(overlap_with_library, policy_.range(), + &overlap_with_library_and_policy)) { + return false; + } + + RemoveConflictingVersions(variant_, &overlap_with_library_and_policy); + *expectedEffectiveRange = overlap_with_library_and_policy; + return true; + } + + void RemoveConflictingVersions(SSLProtocolVariant variant, + SSLVersionRange* r) { + ASSERT_TRUE(r != nullptr); + if (r->max >= SSL_LIBRARY_VERSION_TLS_1_3 && + r->min < SSL_LIBRARY_VERSION_TLS_1_0) { + r->min = SSL_LIBRARY_VERSION_TLS_1_0; + } + } + + void SetUp() override { + TlsConnectTestBase::SetUp(); + SetPolicy(policy_.range()); + } + + void TearDown() override { + TlsConnectTestBase::TearDown(); + saved_version_policy_.RestoreOriginalPolicy(); + } + + protected: + class VersionPolicy { + public: + VersionPolicy() { SaveOriginalPolicy(); } + + void RestoreOriginalPolicy() { + SECStatus rv; + rv = NSS_OptionSet(NSS_TLS_VERSION_MIN_POLICY, saved_min_tls_); + ASSERT_EQ(SECSuccess, rv); + rv = NSS_OptionSet(NSS_TLS_VERSION_MAX_POLICY, saved_max_tls_); + ASSERT_EQ(SECSuccess, rv); + rv = NSS_OptionSet(NSS_DTLS_VERSION_MIN_POLICY, saved_min_dtls_); + ASSERT_EQ(SECSuccess, rv); + rv = NSS_OptionSet(NSS_DTLS_VERSION_MAX_POLICY, saved_max_dtls_); + ASSERT_EQ(SECSuccess, rv); + } + + private: + void SaveOriginalPolicy() { + SECStatus rv; + rv = NSS_OptionGet(NSS_TLS_VERSION_MIN_POLICY, &saved_min_tls_); + ASSERT_EQ(SECSuccess, rv); + rv = NSS_OptionGet(NSS_TLS_VERSION_MAX_POLICY, &saved_max_tls_); + ASSERT_EQ(SECSuccess, rv); + rv = NSS_OptionGet(NSS_DTLS_VERSION_MIN_POLICY, &saved_min_dtls_); + ASSERT_EQ(SECSuccess, rv); + rv = NSS_OptionGet(NSS_DTLS_VERSION_MAX_POLICY, &saved_max_dtls_); + ASSERT_EQ(SECSuccess, rv); + } + + int32_t saved_min_tls_; + int32_t saved_max_tls_; + int32_t saved_min_dtls_; + int32_t saved_max_dtls_; + }; + + VersionPolicy saved_version_policy_; + + SSLProtocolVariant variant_; + const VersionRangeWithLabel policy_; + const VersionRangeWithLabel input_; + const VersionRangeWithLabel library_; +}; + +static const uint16_t kExpandedVersionsArr[] = { + /* clang-format off */ + SSL_LIBRARY_VERSION_3_0 - 1, + SSL_LIBRARY_VERSION_3_0, + SSL_LIBRARY_VERSION_TLS_1_0, + SSL_LIBRARY_VERSION_TLS_1_1, + SSL_LIBRARY_VERSION_TLS_1_2, +#ifndef NSS_DISABLE_TLS_1_3 + SSL_LIBRARY_VERSION_TLS_1_3, +#endif + SSL_LIBRARY_VERSION_MAX_SUPPORTED + 1 + /* clang-format on */ +}; +static ::testing::internal::ParamGenerator<uint16_t> kExpandedVersions = + ::testing::ValuesIn(kExpandedVersionsArr); + +TEST_P(TestPolicyVersionRange, TestAllTLSVersionsAndPolicyCombinations) { + ASSERT_TRUE(variant_ == ssl_variant_stream || + variant_ == ssl_variant_datagram) + << "testing unsupported ssl variant"; + + std::cerr << "testing: " << variant_ << policy_ << input_ << library_ + << std::endl; + + SSLVersionRange supported_range; + SECStatus rv = SSL_VersionRangeGetSupported(variant_, &supported_range); + VersionRangeWithLabel supported("SSL_VersionRangeGetSupported", + supported_range); + + std::cerr << supported << std::endl; + + std::shared_ptr<DummyPrSocket> dummy_socket; + ScopedPRFileDesc ssl_fd; + CreateDummySocket(&dummy_socket, &ssl_fd); + + SECStatus rv_socket; + SSLVersionRange overlap_policy_and_lib; + if (!GetOverlap(policy_.range(), library_.range(), &overlap_policy_and_lib)) { + EXPECT_EQ(SECFailure, rv) + << "expected SSL_VersionRangeGetSupported to fail with invalid policy"; + + SSLVersionRange enabled_range; + rv = SSL_VersionRangeGetDefault(variant_, &enabled_range); + EXPECT_EQ(SECFailure, rv) + << "expected SSL_VersionRangeGetDefault to fail with invalid policy"; + + SSLVersionRange enabled_range_on_socket; + rv_socket = SSL_VersionRangeGet(ssl_fd.get(), &enabled_range_on_socket); + EXPECT_EQ(SECFailure, rv_socket) + << "expected SSL_VersionRangeGet to fail with invalid policy"; + + ConnectExpectFail(); + return; + } + + EXPECT_EQ(SECSuccess, rv) + << "expected SSL_VersionRangeGetSupported to succeed with valid policy"; + + EXPECT_TRUE(supported_range.min != SSL_LIBRARY_VERSION_NONE && + supported_range.max != SSL_LIBRARY_VERSION_NONE) + << "expected SSL_VersionRangeGetSupported to return real values with " + "valid policy"; + + RemoveConflictingVersions(variant_, &overlap_policy_and_lib); + VersionRangeWithLabel overlap_info("overlap", overlap_policy_and_lib); + + EXPECT_TRUE(supported_range == overlap_policy_and_lib) + << "expected range from GetSupported to be identical with calculated " + "overlap " + << overlap_info; + + // We don't know which versions are "enabled by default" by the library, + // therefore we don't know if there's overlap between the default + // and the policy, and therefore, we don't if TLS connections should + // be successful or fail in this combination. + // Therefore we don't test if we can connect, without having configured a + // version range explicitly. + + // Now start testing with supplied input. + + SSLVersionRange expected_effective_range; + bool is_valid_input = + IsValidInputForVersionRangeSet(&expected_effective_range); + + SSLVersionRange temp_input = input_.range(); + rv = SSL_VersionRangeSetDefault(variant_, &temp_input); + rv_socket = SSL_VersionRangeSet(ssl_fd.get(), &temp_input); + + if (!is_valid_input) { + EXPECT_EQ(SECFailure, rv) + << "expected failure return from SSL_VersionRangeSetDefault"; + + EXPECT_EQ(SECFailure, rv_socket) + << "expected failure return from SSL_VersionRangeSet"; + return; + } + + EXPECT_EQ(SECSuccess, rv) + << "expected successful return from SSL_VersionRangeSetDefault"; + + EXPECT_EQ(SECSuccess, rv_socket) + << "expected successful return from SSL_VersionRangeSet"; + + SSLVersionRange effective; + SSLVersionRange effective_socket; + + rv = SSL_VersionRangeGetDefault(variant_, &effective); + EXPECT_EQ(SECSuccess, rv) + << "expected successful return from SSL_VersionRangeGetDefault"; + + rv_socket = SSL_VersionRangeGet(ssl_fd.get(), &effective_socket); + EXPECT_EQ(SECSuccess, rv_socket) + << "expected successful return from SSL_VersionRangeGet"; + + VersionRangeWithLabel expected_info("expectation", expected_effective_range); + VersionRangeWithLabel effective_info("effectively-enabled", effective); + + EXPECT_TRUE(expected_effective_range == effective) + << "range returned by SSL_VersionRangeGetDefault doesn't match " + "expectation: " + << expected_info << effective_info; + + EXPECT_TRUE(expected_effective_range == effective_socket) + << "range returned by SSL_VersionRangeGet doesn't match " + "expectation: " + << expected_info << effective_info; + + // Because we found overlap between policy and supported versions, + // and because we have used SetDefault to enable at least one version, + // it should be possible to execute an SSL/TLS connection. + Connect(); +} + +INSTANTIATE_TEST_SUITE_P(TLSVersionRanges, TestPolicyVersionRange, + ::testing::Combine(TlsConnectTestBase::kTlsVariantsAll, + kExpandedVersions, + kExpandedVersions, + kExpandedVersions, + kExpandedVersions)); +} // namespace nss_test diff --git a/security/nss/gtests/ssl_gtest/test_io.cc b/security/nss/gtests/ssl_gtest/test_io.cc new file mode 100644 index 0000000000..e4651a2352 --- /dev/null +++ b/security/nss/gtests/ssl_gtest/test_io.cc @@ -0,0 +1,278 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set ts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this file, + * You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#include "test_io.h" + +#include <algorithm> +#include <cassert> +#include <iostream> +#include <memory> + +#include "prerror.h" +#include "prlog.h" +#include "prthread.h" + +extern bool g_ssl_gtest_verbose; + +namespace nss_test { + +#define LOG(a) std::cerr << name_ << ": " << a << std::endl +#define LOGV(a) \ + do { \ + if (g_ssl_gtest_verbose) LOG(a); \ + } while (false) + +PRDescIdentity DummyPrSocket::LayerId() { + static PRDescIdentity id = PR_GetUniqueIdentity("dummysocket"); + return id; +} + +ScopedPRFileDesc DummyPrSocket::CreateFD() { + return DummyIOLayerMethods::CreateFD(DummyPrSocket::LayerId(), this); +} + +void DummyPrSocket::Reset() { + auto p = peer_.lock(); + peer_.reset(); + if (p) { + p->peer_.reset(); + p->Reset(); + } + while (!input_.empty()) { + input_.pop(); + } + filter_ = nullptr; + write_error_ = 0; +} + +void DummyPrSocket::PacketReceived(const DataBuffer &packet) { + input_.push(Packet(packet)); +} + +int32_t DummyPrSocket::Read(PRFileDesc *f, void *data, int32_t len) { + PR_ASSERT(variant_ == ssl_variant_stream); + if (variant_ != ssl_variant_stream) { + PR_SetError(PR_INVALID_METHOD_ERROR, 0); + return -1; + } + + auto dst = peer_.lock(); + if (!dst) { + PR_SetError(PR_NOT_CONNECTED_ERROR, 0); + return -1; + } + + if (input_.empty()) { + LOGV("Read --> wouldblock " << len); + PR_SetError(PR_WOULD_BLOCK_ERROR, 0); + return -1; + } + + auto &front = input_.front(); + size_t to_read = + std::min(static_cast<size_t>(len), front.len() - front.offset()); + memcpy(data, static_cast<const void *>(front.data() + front.offset()), + to_read); + front.Advance(to_read); + + if (!front.remaining()) { + input_.pop(); + } + + return static_cast<int32_t>(to_read); +} + +int32_t DummyPrSocket::Recv(PRFileDesc *f, void *buf, int32_t buflen, + int32_t flags, PRIntervalTime to) { + PR_ASSERT(flags == 0); + if (flags != 0) { + PR_SetError(PR_NOT_IMPLEMENTED_ERROR, 0); + return -1; + } + + if (variant() != ssl_variant_datagram) { + return Read(f, buf, buflen); + } + + auto dst = peer_.lock(); + if (!dst) { + PR_SetError(PR_NOT_CONNECTED_ERROR, 0); + return -1; + } + + if (input_.empty()) { + PR_SetError(PR_WOULD_BLOCK_ERROR, 0); + return -1; + } + + auto &front = input_.front(); + if (static_cast<size_t>(buflen) < front.len()) { + PR_SetError(PR_BUFFER_OVERFLOW_ERROR, 0); + return -1; + } + + size_t count = front.len(); + memcpy(buf, front.data(), count); + + input_.pop(); + return static_cast<int32_t>(count); +} + +int32_t DummyPrSocket::Write(PRFileDesc *f, const void *buf, int32_t length) { + if (write_error_) { + PR_SetError(write_error_, 0); + return -1; + } + + auto dst = peer_.lock(); + if (!dst) { + PR_SetError(PR_NOT_CONNECTED_ERROR, 0); + return -1; + } + + DataBuffer packet(static_cast<const uint8_t *>(buf), + static_cast<size_t>(length)); + DataBuffer filtered; + PacketFilter::Action action = PacketFilter::KEEP; + if (filter_) { + LOGV("Original packet: " << packet); + action = filter_->Process(packet, &filtered); + } + switch (action) { + case PacketFilter::CHANGE: + LOG("Filtered packet: " << filtered); + dst->PacketReceived(filtered); + break; + case PacketFilter::DROP: + LOG("Drop packet"); + break; + case PacketFilter::KEEP: + dst->PacketReceived(packet); + break; + } + // libssl can't handle it if this reports something other than the length + // of what was passed in (or less, but we're not doing partial writes). + return static_cast<int32_t>(packet.len()); +} + +Poller *Poller::instance; + +Poller *Poller::Instance() { + if (!instance) instance = new Poller(); + + return instance; +} + +void Poller::Shutdown() { + delete instance; + instance = nullptr; +} + +void Poller::Wait(Event event, std::shared_ptr<DummyPrSocket> &adapter, + PollTarget *target, PollCallback cb) { + assert(event < TIMER_EVENT); + if (event >= TIMER_EVENT) return; + + std::unique_ptr<Waiter> waiter; + auto it = waiters_.find(adapter); + if (it == waiters_.end()) { + waiter.reset(new Waiter(adapter)); + } else { + waiter = std::move(it->second); + } + + waiter->targets_[event] = target; + waiter->callbacks_[event] = cb; + waiters_[adapter] = std::move(waiter); +} + +void Poller::Cancel(Event event, std::shared_ptr<DummyPrSocket> &adapter) { + auto it = waiters_.find(adapter); + if (it == waiters_.end()) { + return; + } + + auto &waiter = it->second; + waiter->targets_[event] = nullptr; + waiter->callbacks_[event] = nullptr; + + // Clean up if there are no callbacks. + for (size_t i = 0; i < TIMER_EVENT; ++i) { + if (waiter->callbacks_[i]) return; + } + + waiters_.erase(adapter); +} + +void Poller::SetTimer(uint32_t timer_ms, PollTarget *target, PollCallback cb, + std::shared_ptr<Timer> *timer) { + auto t = std::make_shared<Timer>(PR_Now() + timer_ms * 1000, target, cb); + timers_.push(t); + if (timer) *timer = t; +} + +bool Poller::Poll() { + if (g_ssl_gtest_verbose) { + std::cerr << "Poll() waiters = " << waiters_.size() + << " timers = " << timers_.size() << std::endl; + } + PRIntervalTime timeout = PR_INTERVAL_NO_TIMEOUT; + PRTime now = PR_Now(); + bool fired = false; + + // Figure out the timer for the select. + if (!timers_.empty()) { + auto first_timer = timers_.top(); + if (now >= first_timer->deadline_) { + // Timer expired. + timeout = PR_INTERVAL_NO_WAIT; + } else { + timeout = + PR_MillisecondsToInterval((first_timer->deadline_ - now) / 1000); + } + } + + for (auto it = waiters_.begin(); it != waiters_.end(); ++it) { + auto &waiter = it->second; + + if (waiter->callbacks_[READABLE_EVENT]) { + if (waiter->io_->readable()) { + PollCallback callback = waiter->callbacks_[READABLE_EVENT]; + PollTarget *target = waiter->targets_[READABLE_EVENT]; + waiter->callbacks_[READABLE_EVENT] = nullptr; + waiter->targets_[READABLE_EVENT] = nullptr; + callback(target, READABLE_EVENT); + fired = true; + } + } + } + + if (fired) timeout = PR_INTERVAL_NO_WAIT; + + // Can't wait forever and also have nothing readable now. + if (timeout == PR_INTERVAL_NO_TIMEOUT) return false; + + // Sleep. + if (timeout != PR_INTERVAL_NO_WAIT) { + PR_Sleep(timeout); + } + + // Now process anything that timed out. + now = PR_Now(); + while (!timers_.empty()) { + if (now < timers_.top()->deadline_) break; + + auto timer = timers_.top(); + timers_.pop(); + if (timer->callback_) { + timer->callback_(timer->target_, TIMER_EVENT); + } + } + + return true; +} + +} // namespace nss_test diff --git a/security/nss/gtests/ssl_gtest/test_io.h b/security/nss/gtests/ssl_gtest/test_io.h new file mode 100644 index 0000000000..e262fb123e --- /dev/null +++ b/security/nss/gtests/ssl_gtest/test_io.h @@ -0,0 +1,187 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set ts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this file, + * You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#ifndef test_io_h_ +#define test_io_h_ + +#include <string.h> +#include <map> +#include <memory> +#include <ostream> +#include <queue> +#include <string> + +#include "databuffer.h" +#include "dummy_io.h" +#include "prio.h" +#include "nss_scoped_ptrs.h" +#include "sslt.h" + +namespace nss_test { + +class DataBuffer; +class DummyPrSocket; // Fwd decl. + +// Allow us to inspect a packet before it is written. +class PacketFilter { + public: + enum Action { + KEEP, // keep the original packet unmodified + CHANGE, // change the packet to a different value + DROP // drop the packet + }; + explicit PacketFilter(bool on = true) : enabled_(on) {} + virtual ~PacketFilter() {} + + bool enabled() const { return enabled_; } + + virtual Action Process(const DataBuffer& input, DataBuffer* output) { + if (!enabled_) { + return KEEP; + } + return Filter(input, output); + } + void Enable() { enabled_ = true; } + void Disable() { enabled_ = false; } + + // The packet filter takes input and has the option of mutating it. + // + // A filter that modifies the data places the modified data in *output and + // returns CHANGE. A filter that does not modify data returns LEAVE, in which + // case the value in *output is ignored. A Filter can return DROP, in which + // case the packet is dropped (and *output is ignored). + virtual Action Filter(const DataBuffer& input, DataBuffer* output) = 0; + + private: + bool enabled_; +}; + +class DummyPrSocket : public DummyIOLayerMethods { + public: + DummyPrSocket(const std::string& name, SSLProtocolVariant var) + : name_(name), + variant_(var), + peer_(), + input_(), + filter_(nullptr), + write_error_(0) {} + virtual ~DummyPrSocket() {} + + static PRDescIdentity LayerId(); + + // Create a file descriptor that will reference this object. The fd must not + // live longer than this adapter; call PR_Close() before. + ScopedPRFileDesc CreateFD(); + + std::weak_ptr<DummyPrSocket>& peer() { return peer_; } + void SetPeer(const std::shared_ptr<DummyPrSocket>& p) { peer_ = p; } + void SetPacketFilter(const std::shared_ptr<PacketFilter>& filter) { + filter_ = filter; + } + // Drops peer, packet filter and any outstanding packets. + void Reset(); + + void PacketReceived(const DataBuffer& data); + int32_t Read(PRFileDesc* f, void* data, int32_t len) override; + int32_t Recv(PRFileDesc* f, void* buf, int32_t buflen, int32_t flags, + PRIntervalTime to) override; + int32_t Write(PRFileDesc* f, const void* buf, int32_t length) override; + void SetWriteError(PRErrorCode code) { write_error_ = code; } + + SSLProtocolVariant variant() const { return variant_; } + bool readable() const { return !input_.empty(); } + + private: + class Packet : public DataBuffer { + public: + Packet(const DataBuffer& buf) : DataBuffer(buf), offset_(0) {} + + void Advance(size_t delta) { + PR_ASSERT(offset_ + delta <= len()); + offset_ = std::min(len(), offset_ + delta); + } + + size_t offset() const { return offset_; } + size_t remaining() const { return len() - offset_; } + + private: + size_t offset_; + }; + + const std::string name_; + SSLProtocolVariant variant_; + std::weak_ptr<DummyPrSocket> peer_; + std::queue<Packet> input_; + std::shared_ptr<PacketFilter> filter_; + PRErrorCode write_error_; +}; + +// Marker interface. +class PollTarget {}; + +enum Event { READABLE_EVENT, TIMER_EVENT /* Must be last */ }; + +typedef void (*PollCallback)(PollTarget*, Event); + +class Poller { + public: + static Poller* Instance(); // Get a singleton. + static void Shutdown(); // Shut it down. + + class Timer { + public: + Timer(PRTime deadline, PollTarget* target, PollCallback callback) + : deadline_(deadline), target_(target), callback_(callback) {} + void Cancel() { callback_ = nullptr; } + + PRTime deadline_; + PollTarget* target_; + PollCallback callback_; + }; + + void Wait(Event event, std::shared_ptr<DummyPrSocket>& adapter, + PollTarget* target, PollCallback cb); + void Cancel(Event event, std::shared_ptr<DummyPrSocket>& adapter); + void SetTimer(uint32_t timer_ms, PollTarget* target, PollCallback cb, + std::shared_ptr<Timer>* handle); + bool Poll(); + + private: + Poller() : waiters_(), timers_() {} + ~Poller() {} + + class Waiter { + public: + Waiter(std::shared_ptr<DummyPrSocket> io) : io_(io) { + memset(&targets_[0], 0, sizeof(targets_)); + memset(&callbacks_[0], 0, sizeof(callbacks_)); + } + + void WaitFor(Event event, PollCallback callback); + + std::shared_ptr<DummyPrSocket> io_; + PollTarget* targets_[TIMER_EVENT]; + PollCallback callbacks_[TIMER_EVENT]; + }; + + class TimerComparator { + public: + bool operator()(const std::shared_ptr<Timer> lhs, + const std::shared_ptr<Timer> rhs) { + return lhs->deadline_ > rhs->deadline_; + } + }; + + static Poller* instance; + std::map<std::shared_ptr<DummyPrSocket>, std::unique_ptr<Waiter>> waiters_; + std::priority_queue<std::shared_ptr<Timer>, + std::vector<std::shared_ptr<Timer>>, TimerComparator> + timers_; +}; + +} // namespace nss_test + +#endif diff --git a/security/nss/gtests/ssl_gtest/tls_agent.cc b/security/nss/gtests/ssl_gtest/tls_agent.cc new file mode 100644 index 0000000000..8ec2f40f75 --- /dev/null +++ b/security/nss/gtests/ssl_gtest/tls_agent.cc @@ -0,0 +1,1432 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set ts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this file, + * You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#include "tls_agent.h" +#include "databuffer.h" +#include "keyhi.h" +#include "pk11func.h" +#include "ssl.h" +#include "sslerr.h" +#include "sslexp.h" +#include "sslproto.h" +#include "tls_filter.h" +#include "tls_parser.h" + +// This is an internal header, used to get DTLS_1_3_DRAFT_VERSION. +#include "ssl3prot.h" + +extern "C" { +// This is not something that should make you happy. +#include "libssl_internals.h" +} + +#define GTEST_HAS_RTTI 0 +#include "gtest/gtest.h" +#include "gtest_utils.h" +#include "nss_scoped_ptrs.h" + +extern std::string g_working_dir_path; + +namespace nss_test { + +const char* TlsAgent::states[] = {"INIT", "CONNECTING", "CONNECTED", "ERROR"}; + +const std::string TlsAgent::kClient = "client"; // both sign and encrypt +const std::string TlsAgent::kRsa2048 = "rsa2048"; // bigger +const std::string TlsAgent::kRsa8192 = "rsa8192"; // biggest allowed +const std::string TlsAgent::kServerRsa = "rsa"; // both sign and encrypt +const std::string TlsAgent::kServerRsaSign = "rsa_sign"; +const std::string TlsAgent::kServerRsaPss = "rsa_pss"; +const std::string TlsAgent::kServerRsaDecrypt = "rsa_decrypt"; +const std::string TlsAgent::kServerEcdsa256 = "ecdsa256"; +const std::string TlsAgent::kServerEcdsa384 = "ecdsa384"; +const std::string TlsAgent::kServerEcdsa521 = "ecdsa521"; +const std::string TlsAgent::kServerEcdhRsa = "ecdh_rsa"; +const std::string TlsAgent::kServerEcdhEcdsa = "ecdh_ecdsa"; +const std::string TlsAgent::kServerDsa = "dsa"; +const std::string TlsAgent::kDelegatorEcdsa256 = "delegator_ecdsa256"; +const std::string TlsAgent::kDelegatorRsae2048 = "delegator_rsae2048"; +const std::string TlsAgent::kDelegatorRsaPss2048 = "delegator_rsa_pss2048"; + +static const uint8_t kCannedTls13ServerHello[] = { + 0x03, 0x03, 0x9c, 0xbc, 0x14, 0x9b, 0x0e, 0x2e, 0xfa, 0x0d, 0xf3, + 0xf0, 0x5c, 0x70, 0x7a, 0xe0, 0xd1, 0x9b, 0x3e, 0x5a, 0x44, 0x6b, + 0xdf, 0xe5, 0xc2, 0x28, 0x64, 0xf7, 0x00, 0xc1, 0x9c, 0x08, 0x76, + 0x08, 0x00, 0x13, 0x01, 0x00, 0x00, 0x2e, 0x00, 0x33, 0x00, 0x24, + 0x00, 0x1d, 0x00, 0x20, 0xc2, 0xcf, 0x23, 0x17, 0x64, 0x23, 0x03, + 0xf0, 0xfb, 0x45, 0x98, 0x26, 0xd1, 0x65, 0x24, 0xa1, 0x6c, 0xa9, + 0x80, 0x8f, 0x2c, 0xac, 0x0a, 0xea, 0x53, 0x3a, 0xcb, 0xe3, 0x08, + 0x84, 0xae, 0x19, 0x00, 0x2b, 0x00, 0x02, 0x03, 0x04}; + +TlsAgent::TlsAgent(const std::string& nm, Role rl, SSLProtocolVariant var) + : name_(nm), + variant_(var), + role_(rl), + server_key_bits_(0), + adapter_(new DummyPrSocket(role_str(), var)), + ssl_fd_(nullptr), + state_(STATE_INIT), + timer_handle_(nullptr), + falsestart_enabled_(false), + expected_version_(0), + expected_cipher_suite_(0), + expect_client_auth_(false), + expect_ech_(false), + expect_psk_(ssl_psk_none), + can_falsestart_hook_called_(false), + sni_hook_called_(false), + auth_certificate_hook_called_(false), + expected_received_alert_(kTlsAlertCloseNotify), + expected_received_alert_level_(kTlsAlertWarning), + expected_sent_alert_(kTlsAlertCloseNotify), + expected_sent_alert_level_(kTlsAlertWarning), + handshake_callback_called_(false), + resumption_callback_called_(false), + error_code_(0), + send_ctr_(0), + recv_ctr_(0), + expect_readwrite_error_(false), + handshake_callback_(), + auth_certificate_callback_(), + sni_callback_(), + skip_version_checks_(false), + resumption_token_(), + policy_() { + memset(&info_, 0, sizeof(info_)); + memset(&csinfo_, 0, sizeof(csinfo_)); + SECStatus rv = SSL_VersionRangeGetDefault(variant_, &vrange_); + EXPECT_EQ(SECSuccess, rv); +} + +TlsAgent::~TlsAgent() { + if (timer_handle_) { + timer_handle_->Cancel(); + } + + if (adapter_) { + Poller::Instance()->Cancel(READABLE_EVENT, adapter_); + } + + // Add failures manually, if any, so we don't throw in a destructor. + if (expected_received_alert_ != kTlsAlertCloseNotify || + expected_received_alert_level_ != kTlsAlertWarning) { + ADD_FAILURE() << "Wrong expected_received_alert status: " << role_str(); + } + if (expected_sent_alert_ != kTlsAlertCloseNotify || + expected_sent_alert_level_ != kTlsAlertWarning) { + ADD_FAILURE() << "Wrong expected_sent_alert status: " << role_str(); + } +} + +void TlsAgent::SetState(State s) { + if (state_ == s) return; + + LOG("Changing state from " << state_ << " to " << s); + state_ = s; +} + +/*static*/ bool TlsAgent::LoadCertificate(const std::string& name, + ScopedCERTCertificate* cert, + ScopedSECKEYPrivateKey* priv) { + cert->reset(PK11_FindCertFromNickname(name.c_str(), nullptr)); + EXPECT_NE(nullptr, cert); + if (!cert) return false; + EXPECT_NE(nullptr, cert->get()); + if (!cert->get()) return false; + + priv->reset(PK11_FindKeyByAnyCert(cert->get(), nullptr)); + EXPECT_NE(nullptr, priv); + if (!priv) return false; + EXPECT_NE(nullptr, priv->get()); + if (!priv->get()) return false; + + return true; +} + +// Loads a key pair from the certificate identified by |id|. +/*static*/ bool TlsAgent::LoadKeyPairFromCert(const std::string& name, + ScopedSECKEYPublicKey* pub, + ScopedSECKEYPrivateKey* priv) { + ScopedCERTCertificate cert; + if (!TlsAgent::LoadCertificate(name, &cert, priv)) { + return false; + } + + pub->reset(SECKEY_ExtractPublicKey(&cert->subjectPublicKeyInfo)); + if (!pub->get()) { + return false; + } + + return true; +} + +void TlsAgent::DelegateCredential(const std::string& name, + const ScopedSECKEYPublicKey& dc_pub, + SSLSignatureScheme dc_cert_verify_alg, + PRUint32 dc_valid_for, PRTime now, + SECItem* dc) { + ScopedCERTCertificate cert; + ScopedSECKEYPrivateKey cert_priv; + EXPECT_TRUE(TlsAgent::LoadCertificate(name, &cert, &cert_priv)) + << "Could not load delegate certificate: " << name + << "; test db corrupt?"; + + EXPECT_EQ(SECSuccess, + SSL_DelegateCredential(cert.get(), cert_priv.get(), dc_pub.get(), + dc_cert_verify_alg, dc_valid_for, now, dc)); +} + +void TlsAgent::EnableDelegatedCredentials() { + ASSERT_TRUE(EnsureTlsSetup()); + SetOption(SSL_ENABLE_DELEGATED_CREDENTIALS, PR_TRUE); +} + +void TlsAgent::AddDelegatedCredential(const std::string& dc_name, + SSLSignatureScheme dc_cert_verify_alg, + PRUint32 dc_valid_for, PRTime now) { + ASSERT_TRUE(EnsureTlsSetup()); + + ScopedSECKEYPublicKey pub; + ScopedSECKEYPrivateKey priv; + EXPECT_TRUE(TlsAgent::LoadKeyPairFromCert(dc_name, &pub, &priv)); + + StackSECItem dc; + TlsAgent::DelegateCredential(name_, pub, dc_cert_verify_alg, dc_valid_for, + now, &dc); + + SSLExtraServerCertData extra_data = {ssl_auth_null, nullptr, nullptr, + nullptr, &dc, priv.get()}; + EXPECT_TRUE(ConfigServerCert(name_, true, &extra_data)); +} + +bool TlsAgent::ConfigServerCert(const std::string& id, bool updateKeyBits, + const SSLExtraServerCertData* serverCertData) { + ScopedCERTCertificate cert; + ScopedSECKEYPrivateKey priv; + if (!TlsAgent::LoadCertificate(id, &cert, &priv)) { + return false; + } + + if (updateKeyBits) { + ScopedSECKEYPublicKey pub(CERT_ExtractPublicKey(cert.get())); + EXPECT_NE(nullptr, pub.get()); + if (!pub.get()) return false; + server_key_bits_ = SECKEY_PublicKeyStrengthInBits(pub.get()); + } + + SECStatus rv = + SSL_ConfigSecureServer(ssl_fd(), nullptr, nullptr, ssl_kea_null); + EXPECT_EQ(SECFailure, rv); + rv = SSL_ConfigServerCert(ssl_fd(), cert.get(), priv.get(), serverCertData, + serverCertData ? sizeof(*serverCertData) : 0); + return rv == SECSuccess; +} + +bool TlsAgent::EnsureTlsSetup(PRFileDesc* modelSocket) { + // Don't set up twice + if (ssl_fd_) return true; + NssManagePolicy policyManage(policy_, option_); + + ScopedPRFileDesc dummy_fd(adapter_->CreateFD()); + EXPECT_NE(nullptr, dummy_fd); + if (!dummy_fd) { + return false; + } + if (adapter_->variant() == ssl_variant_stream) { + ssl_fd_.reset(SSL_ImportFD(modelSocket, dummy_fd.get())); + } else { + ssl_fd_.reset(DTLS_ImportFD(modelSocket, dummy_fd.get())); + } + + EXPECT_NE(nullptr, ssl_fd_); + if (!ssl_fd_) { + return false; + } + dummy_fd.release(); // Now subsumed by ssl_fd_. + + SECStatus rv; + if (!skip_version_checks_) { + rv = SSL_VersionRangeSet(ssl_fd(), &vrange_); + EXPECT_EQ(SECSuccess, rv); + if (rv != SECSuccess) return false; + } + + ScopedCERTCertList anchors(CERT_NewCertList()); + rv = SSL_SetTrustAnchors(ssl_fd(), anchors.get()); + if (rv != SECSuccess) return false; + + if (role_ == SERVER) { + EXPECT_TRUE(ConfigServerCert(name_, true)); + + rv = SSL_SNISocketConfigHook(ssl_fd(), SniHook, this); + EXPECT_EQ(SECSuccess, rv); + if (rv != SECSuccess) return false; + + rv = SSL_SetMaxEarlyDataSize(ssl_fd(), 1024); + EXPECT_EQ(SECSuccess, rv); + if (rv != SECSuccess) return false; + } else { + rv = SSL_SetURL(ssl_fd(), "server"); + EXPECT_EQ(SECSuccess, rv); + if (rv != SECSuccess) return false; + } + + rv = SSL_AuthCertificateHook(ssl_fd(), AuthCertificateHook, this); + EXPECT_EQ(SECSuccess, rv); + if (rv != SECSuccess) return false; + + rv = SSL_AlertReceivedCallback(ssl_fd(), AlertReceivedCallback, this); + EXPECT_EQ(SECSuccess, rv); + if (rv != SECSuccess) return false; + + rv = SSL_AlertSentCallback(ssl_fd(), AlertSentCallback, this); + EXPECT_EQ(SECSuccess, rv); + if (rv != SECSuccess) return false; + + rv = SSL_HandshakeCallback(ssl_fd(), HandshakeCallback, this); + EXPECT_EQ(SECSuccess, rv); + if (rv != SECSuccess) return false; + + // All these tests depend on having this disabled to start with. + SetOption(SSL_ENABLE_EXTENDED_MASTER_SECRET, PR_FALSE); + + return true; +} + +bool TlsAgent::MaybeSetResumptionToken() { + if (!resumption_token_.empty()) { + LOG("setting external resumption token"); + SECStatus rv = SSL_SetResumptionToken(ssl_fd(), resumption_token_.data(), + resumption_token_.size()); + + // rv is SECFailure with error set to SSL_ERROR_BAD_RESUMPTION_TOKEN_ERROR + // if the resumption token was bad (expired/malformed/etc.). + if (expect_psk_ == ssl_psk_resume) { + // Only in case we expect resumption this has to be successful. We might + // not expect resumption due to some reason but the token is totally fine. + EXPECT_EQ(SECSuccess, rv); + } + if (rv != SECSuccess) { + EXPECT_EQ(SSL_ERROR_BAD_RESUMPTION_TOKEN_ERROR, PORT_GetError()); + resumption_token_.clear(); + EXPECT_FALSE(expect_psk_ == ssl_psk_resume); + if (expect_psk_ == ssl_psk_resume) return false; + } + } + + return true; +} + +void TlsAgent::SetAntiReplayContext(ScopedSSLAntiReplayContext& ctx) { + EXPECT_EQ(SECSuccess, SSL_SetAntiReplayContext(ssl_fd(), ctx.get())); +} + +// Defaults to a Sync callback returning success +void TlsAgent::SetupClientAuth(ClientAuthCallbackType callbackType, + bool callbackSuccess) { + EXPECT_TRUE(EnsureTlsSetup()); + ASSERT_EQ(CLIENT, role_); + + client_auth_callback_type_ = callbackType; + client_auth_callback_success_ = callbackSuccess; + + if (callbackType == ClientAuthCallbackType::kNone && !callbackSuccess) { + // Don't set a callback for this case. + return; + } + EXPECT_EQ(SECSuccess, + SSL_GetClientAuthDataHook(ssl_fd(), GetClientAuthDataHook, + reinterpret_cast<void*>(this))); +} + +void CheckCertReqAgainstDefaultCAs(const CERTDistNames* caNames) { + ScopedCERTDistNames expected(CERT_GetSSLCACerts(nullptr)); + + ASSERT_EQ(expected->nnames, caNames->nnames); + + for (size_t i = 0; i < static_cast<size_t>(expected->nnames); ++i) { + EXPECT_EQ(SECEqual, + SECITEM_CompareItem(&(expected->names[i]), &(caNames->names[i]))); + } +} + +// Complete processing of Client Certificate Selection +// A No-op if the agent is using synchronous client cert selection. +// Otherwise, calls SSL_ClientCertCallbackComplete. +// kAsyncDelay triggers a call to SSL_ForceHandshake prior to completion to +// ensure that the socket is correctly blocked. +void TlsAgent::ClientAuthCallbackComplete() { + ASSERT_EQ(CLIENT, role_); + + if (client_auth_callback_type_ != ClientAuthCallbackType::kAsyncDelay && + client_auth_callback_type_ != ClientAuthCallbackType::kAsyncImmediate) { + return; + } + client_auth_callback_fired_++; + EXPECT_TRUE(client_auth_callback_awaiting_); + + std::cerr << "client: calling SSL_ClientCertCallbackComplete with status " + << (client_auth_callback_success_ ? "success" : "failed") + << std::endl; + + client_auth_callback_awaiting_ = false; + + if (client_auth_callback_type_ == ClientAuthCallbackType::kAsyncDelay) { + std::cerr + << "Running Handshake prior to running SSL_ClientCertCallbackComplete" + << std::endl; + SECStatus rv = SSL_ForceHandshake(ssl_fd()); + EXPECT_EQ(rv, SECFailure); + EXPECT_EQ(PORT_GetError(), PR_WOULD_BLOCK_ERROR); + } + + ScopedCERTCertificate cert; + ScopedSECKEYPrivateKey priv; + if (client_auth_callback_success_) { + ASSERT_TRUE(TlsAgent::LoadCertificate(name(), &cert, &priv)); + EXPECT_EQ(SECSuccess, + SSL_ClientCertCallbackComplete(ssl_fd(), SECSuccess, + priv.release(), cert.release())); + } else { + EXPECT_EQ(SECSuccess, SSL_ClientCertCallbackComplete(ssl_fd(), SECFailure, + nullptr, nullptr)); + } +} + +SECStatus TlsAgent::GetClientAuthDataHook(void* self, PRFileDesc* fd, + CERTDistNames* caNames, + CERTCertificate** clientCert, + SECKEYPrivateKey** clientKey) { + TlsAgent* agent = reinterpret_cast<TlsAgent*>(self); + EXPECT_EQ(CLIENT, agent->role_); + agent->client_auth_callback_fired_++; + + switch (agent->client_auth_callback_type_) { + case ClientAuthCallbackType::kAsyncDelay: + case ClientAuthCallbackType::kAsyncImmediate: + std::cerr << "Waiting for complete call" << std::endl; + agent->client_auth_callback_awaiting_ = true; + return SECWouldBlock; + case ClientAuthCallbackType::kSync: + case ClientAuthCallbackType::kNone: + // Handle the sync case. None && Success is treated as Sync and Success. + if (!agent->client_auth_callback_success_) { + return SECFailure; + } + ScopedCERTCertificate peerCert(SSL_PeerCertificate(agent->ssl_fd())); + EXPECT_TRUE(peerCert) << "Client should be able to see the server cert"; + + // See bug 1573945 + // CheckCertReqAgainstDefaultCAs(caNames); + + ScopedCERTCertificate cert; + ScopedSECKEYPrivateKey priv; + if (!TlsAgent::LoadCertificate(agent->name(), &cert, &priv)) { + return SECFailure; + } + + *clientCert = cert.release(); + *clientKey = priv.release(); + return SECSuccess; + } + /* This is unreachable, but some old compilers can't tell that. */ + PORT_Assert(0); + PORT_SetError(SEC_ERROR_LIBRARY_FAILURE); + return SECFailure; +} + +// Increments by 1 for each callback +bool TlsAgent::CheckClientAuthCallbacksCompleted(uint8_t expected) { + EXPECT_EQ(CLIENT, role_); + return expected == client_auth_callback_fired_; +} + +bool TlsAgent::GetPeerChainLength(size_t* count) { + CERTCertList* chain = SSL_PeerCertificateChain(ssl_fd()); + if (!chain) return false; + *count = 0; + + for (PRCList* cursor = PR_NEXT_LINK(&chain->list); cursor != &chain->list; + cursor = PR_NEXT_LINK(cursor)) { + CERTCertListNode* node = (CERTCertListNode*)cursor; + std::cerr << node->cert->subjectName << std::endl; + ++(*count); + } + + CERT_DestroyCertList(chain); + + return true; +} + +void TlsAgent::CheckCipherSuite(uint16_t suite) { + EXPECT_EQ(csinfo_.cipherSuite, suite); +} + +void TlsAgent::RequestClientAuth(bool requireAuth) { + ASSERT_EQ(SERVER, role_); + + SetOption(SSL_REQUEST_CERTIFICATE, PR_TRUE); + SetOption(SSL_REQUIRE_CERTIFICATE, requireAuth ? PR_TRUE : PR_FALSE); + + EXPECT_EQ(SECSuccess, SSL_AuthCertificateHook( + ssl_fd(), &TlsAgent::ClientAuthenticated, this)); + expect_client_auth_ = true; +} + +void TlsAgent::StartConnect(PRFileDesc* model) { + EXPECT_TRUE(EnsureTlsSetup(model)); + + SECStatus rv; + rv = SSL_ResetHandshake(ssl_fd(), role_ == SERVER ? PR_TRUE : PR_FALSE); + EXPECT_EQ(SECSuccess, rv); + SetState(STATE_CONNECTING); +} + +void TlsAgent::DisableAllCiphers() { + for (size_t i = 0; i < SSL_NumImplementedCiphers; ++i) { + SECStatus rv = + SSL_CipherPrefSet(ssl_fd(), SSL_ImplementedCiphers[i], PR_FALSE); + EXPECT_EQ(SECSuccess, rv); + } +} + +// Not actually all groups, just the onece that we are actually willing +// to use. +const std::vector<SSLNamedGroup> kAllDHEGroups = { + ssl_grp_ec_curve25519, ssl_grp_ec_secp256r1, ssl_grp_ec_secp384r1, + ssl_grp_ec_secp521r1, ssl_grp_ffdhe_2048, ssl_grp_ffdhe_3072, + ssl_grp_ffdhe_4096, ssl_grp_ffdhe_6144, ssl_grp_ffdhe_8192}; + +const std::vector<SSLNamedGroup> kECDHEGroups = { + ssl_grp_ec_curve25519, ssl_grp_ec_secp256r1, ssl_grp_ec_secp384r1, + ssl_grp_ec_secp521r1}; + +const std::vector<SSLNamedGroup> kFFDHEGroups = { + ssl_grp_ffdhe_2048, ssl_grp_ffdhe_3072, ssl_grp_ffdhe_4096, + ssl_grp_ffdhe_6144, ssl_grp_ffdhe_8192}; + +// Defined because the big DHE groups are ridiculously slow. +const std::vector<SSLNamedGroup> kFasterDHEGroups = { + ssl_grp_ec_curve25519, ssl_grp_ec_secp256r1, ssl_grp_ec_secp384r1, + ssl_grp_ffdhe_2048, ssl_grp_ffdhe_3072}; + +void TlsAgent::EnableCiphersByKeyExchange(SSLKEAType kea) { + EXPECT_TRUE(EnsureTlsSetup()); + + for (size_t i = 0; i < SSL_NumImplementedCiphers; ++i) { + SSLCipherSuiteInfo csinfo; + + SECStatus rv = SSL_GetCipherSuiteInfo(SSL_ImplementedCiphers[i], &csinfo, + sizeof(csinfo)); + ASSERT_EQ(SECSuccess, rv); + EXPECT_EQ(sizeof(csinfo), csinfo.length); + + if ((csinfo.keaType == kea) || (csinfo.keaType == ssl_kea_tls13_any)) { + rv = SSL_CipherPrefSet(ssl_fd(), SSL_ImplementedCiphers[i], PR_TRUE); + EXPECT_EQ(SECSuccess, rv); + } + } +} + +void TlsAgent::EnableGroupsByKeyExchange(SSLKEAType kea) { + switch (kea) { + case ssl_kea_dh: + ConfigNamedGroups(kFFDHEGroups); + break; + case ssl_kea_ecdh: + ConfigNamedGroups(kECDHEGroups); + break; + default: + break; + } +} + +void TlsAgent::EnableGroupsByAuthType(SSLAuthType authType) { + if (authType == ssl_auth_ecdh_rsa || authType == ssl_auth_ecdh_ecdsa || + authType == ssl_auth_ecdsa || authType == ssl_auth_tls13_any) { + ConfigNamedGroups(kECDHEGroups); + } +} + +void TlsAgent::EnableCiphersByAuthType(SSLAuthType authType) { + EXPECT_TRUE(EnsureTlsSetup()); + + for (size_t i = 0; i < SSL_NumImplementedCiphers; ++i) { + SSLCipherSuiteInfo csinfo; + + SECStatus rv = SSL_GetCipherSuiteInfo(SSL_ImplementedCiphers[i], &csinfo, + sizeof(csinfo)); + ASSERT_EQ(SECSuccess, rv); + + if ((csinfo.authType == authType) || + (csinfo.keaType == ssl_kea_tls13_any)) { + rv = SSL_CipherPrefSet(ssl_fd(), SSL_ImplementedCiphers[i], PR_TRUE); + EXPECT_EQ(SECSuccess, rv); + } + } +} + +void TlsAgent::EnableSingleCipher(uint16_t cipher) { + DisableAllCiphers(); + SECStatus rv = SSL_CipherPrefSet(ssl_fd(), cipher, PR_TRUE); + EXPECT_EQ(SECSuccess, rv); +} + +void TlsAgent::ConfigNamedGroups(const std::vector<SSLNamedGroup>& groups) { + EXPECT_TRUE(EnsureTlsSetup()); + SECStatus rv = SSL_NamedGroupConfig(ssl_fd(), &groups[0], groups.size()); + EXPECT_EQ(SECSuccess, rv); +} + +void TlsAgent::Set0RttEnabled(bool en) { + SetOption(SSL_ENABLE_0RTT_DATA, en ? PR_TRUE : PR_FALSE); +} + +void TlsAgent::SetVersionRange(uint16_t minver, uint16_t maxver) { + vrange_.min = minver; + vrange_.max = maxver; + + if (ssl_fd()) { + SECStatus rv = SSL_VersionRangeSet(ssl_fd(), &vrange_); + EXPECT_EQ(SECSuccess, rv); + } +} + +SECStatus ResumptionTokenCallback(PRFileDesc* fd, + const PRUint8* resumptionToken, + unsigned int len, void* ctx) { + EXPECT_NE(nullptr, resumptionToken); + if (!resumptionToken) { + return SECFailure; + } + + std::vector<uint8_t> new_token(resumptionToken, resumptionToken + len); + reinterpret_cast<TlsAgent*>(ctx)->SetResumptionToken(new_token); + reinterpret_cast<TlsAgent*>(ctx)->SetResumptionCallbackCalled(); + return SECSuccess; +} + +void TlsAgent::SetResumptionTokenCallback() { + EXPECT_TRUE(EnsureTlsSetup()); + SECStatus rv = + SSL_SetResumptionTokenCallback(ssl_fd(), ResumptionTokenCallback, this); + EXPECT_EQ(SECSuccess, rv); +} + +void TlsAgent::GetVersionRange(uint16_t* minver, uint16_t* maxver) { + *minver = vrange_.min; + *maxver = vrange_.max; +} + +void TlsAgent::SetExpectedVersion(uint16_t ver) { expected_version_ = ver; } + +void TlsAgent::SetServerKeyBits(uint16_t bits) { server_key_bits_ = bits; } + +void TlsAgent::ExpectReadWriteError() { expect_readwrite_error_ = true; } + +void TlsAgent::SkipVersionChecks() { skip_version_checks_ = true; } + +void TlsAgent::SetSignatureSchemes(const SSLSignatureScheme* schemes, + size_t count) { + EXPECT_TRUE(EnsureTlsSetup()); + EXPECT_LE(count, SSL_SignatureMaxCount()); + EXPECT_EQ(SECSuccess, + SSL_SignatureSchemePrefSet(ssl_fd(), schemes, + static_cast<unsigned int>(count))); + EXPECT_EQ(SECFailure, SSL_SignatureSchemePrefSet(ssl_fd(), schemes, 0)) + << "setting no schemes should fail and do nothing"; + + std::vector<SSLSignatureScheme> configuredSchemes(count); + unsigned int configuredCount; + EXPECT_EQ(SECFailure, + SSL_SignatureSchemePrefGet(ssl_fd(), nullptr, &configuredCount, 1)) + << "get schemes, schemes is nullptr"; + EXPECT_EQ(SECFailure, + SSL_SignatureSchemePrefGet(ssl_fd(), &configuredSchemes[0], + &configuredCount, 0)) + << "get schemes, too little space"; + EXPECT_EQ(SECFailure, + SSL_SignatureSchemePrefGet(ssl_fd(), &configuredSchemes[0], nullptr, + configuredSchemes.size())) + << "get schemes, countOut is nullptr"; + + EXPECT_EQ(SECSuccess, SSL_SignatureSchemePrefGet( + ssl_fd(), &configuredSchemes[0], &configuredCount, + configuredSchemes.size())); + // SignatureSchemePrefSet drops unsupported algorithms silently, so the + // number that are configured might be fewer. + EXPECT_LE(configuredCount, count); + unsigned int i = 0; + for (unsigned int j = 0; j < count && i < configuredCount; ++j) { + if (i < configuredCount && schemes[j] == configuredSchemes[i]) { + ++i; + } + } + EXPECT_EQ(i, configuredCount) << "schemes in use were all set"; +} + +void TlsAgent::CheckKEA(SSLKEAType kea, SSLNamedGroup kea_group, + size_t kea_size) const { + EXPECT_EQ(STATE_CONNECTED, state_); + EXPECT_EQ(kea, info_.keaType); + if (kea_size == 0) { + switch (kea_group) { + case ssl_grp_ec_curve25519: + kea_size = 255; + break; + case ssl_grp_ec_secp256r1: + kea_size = 256; + break; + case ssl_grp_ec_secp384r1: + kea_size = 384; + break; + case ssl_grp_ffdhe_2048: + kea_size = 2048; + break; + case ssl_grp_ffdhe_3072: + kea_size = 3072; + break; + case ssl_grp_ffdhe_custom: + break; + default: + if (kea == ssl_kea_rsa) { + kea_size = server_key_bits_; + } else { + EXPECT_TRUE(false) << "need to update group sizes"; + } + } + } + if (kea_group != ssl_grp_ffdhe_custom) { + EXPECT_EQ(kea_size, info_.keaKeyBits); + EXPECT_EQ(kea_group, info_.keaGroup); + } +} + +void TlsAgent::CheckOriginalKEA(SSLNamedGroup kea_group) const { + if (kea_group != ssl_grp_ffdhe_custom) { + EXPECT_EQ(kea_group, info_.originalKeaGroup); + } +} + +void TlsAgent::CheckAuthType(SSLAuthType auth, + SSLSignatureScheme sig_scheme) const { + EXPECT_EQ(STATE_CONNECTED, state_); + EXPECT_EQ(auth, info_.authType); + if (auth != ssl_auth_psk) { + EXPECT_EQ(server_key_bits_, info_.authKeyBits); + } + if (expected_version_ < SSL_LIBRARY_VERSION_TLS_1_2) { + switch (auth) { + case ssl_auth_rsa_sign: + sig_scheme = ssl_sig_rsa_pkcs1_sha1md5; + break; + case ssl_auth_ecdsa: + sig_scheme = ssl_sig_ecdsa_sha1; + break; + default: + break; + } + } + EXPECT_EQ(sig_scheme, info_.signatureScheme); + + if (info_.protocolVersion >= SSL_LIBRARY_VERSION_TLS_1_3) { + return; + } + + // Check authAlgorithm, which is the old value for authType. This is a second + // switch statement because default label is different. + switch (auth) { + case ssl_auth_rsa_sign: + case ssl_auth_rsa_pss: + EXPECT_EQ(ssl_auth_rsa_decrypt, csinfo_.authAlgorithm) + << "authAlgorithm for RSA is always decrypt"; + break; + case ssl_auth_ecdh_rsa: + EXPECT_EQ(ssl_auth_rsa_decrypt, csinfo_.authAlgorithm) + << "authAlgorithm for ECDH_RSA is RSA decrypt (i.e., wrong)"; + break; + case ssl_auth_ecdh_ecdsa: + EXPECT_EQ(ssl_auth_ecdsa, csinfo_.authAlgorithm) + << "authAlgorithm for ECDH_ECDSA is ECDSA (i.e., wrong)"; + break; + default: + EXPECT_EQ(auth, csinfo_.authAlgorithm) + << "authAlgorithm is (usually) the same as authType"; + break; + } +} + +void TlsAgent::EnableFalseStart() { + EXPECT_TRUE(EnsureTlsSetup()); + + falsestart_enabled_ = true; + EXPECT_EQ(SECSuccess, SSL_SetCanFalseStartCallback( + ssl_fd(), CanFalseStartCallback, this)); + SetOption(SSL_ENABLE_FALSE_START, PR_TRUE); +} + +void TlsAgent::ExpectEch(bool expected) { expect_ech_ = expected; } + +void TlsAgent::ExpectPsk(SSLPskType psk) { expect_psk_ = psk; } + +void TlsAgent::ExpectResumption() { expect_psk_ = ssl_psk_resume; } + +void TlsAgent::EnableAlpn(const uint8_t* val, size_t len) { + EXPECT_TRUE(EnsureTlsSetup()); + EXPECT_EQ(SECSuccess, SSL_SetNextProtoNego(ssl_fd(), val, len)); +} + +void TlsAgent::AddPsk(const ScopedPK11SymKey& psk, std::string label, + SSLHashType hash, uint16_t zeroRttSuite) { + EXPECT_TRUE(EnsureTlsSetup()); + EXPECT_EQ(SECSuccess, SSL_AddExternalPsk0Rtt( + ssl_fd(), psk.get(), + reinterpret_cast<const uint8_t*>(label.data()), + label.length(), hash, zeroRttSuite, 1000)); +} + +void TlsAgent::RemovePsk(std::string label) { + EXPECT_EQ(SECSuccess, + SSL_RemoveExternalPsk( + ssl_fd(), reinterpret_cast<const uint8_t*>(label.data()), + label.length())); +} + +void TlsAgent::CheckAlpn(SSLNextProtoState expected_state, + const std::string& expected) const { + SSLNextProtoState alpn_state; + char chosen[10]; + unsigned int chosen_len; + SECStatus rv = SSL_GetNextProto(ssl_fd(), &alpn_state, + reinterpret_cast<unsigned char*>(chosen), + &chosen_len, sizeof(chosen)); + EXPECT_EQ(SECSuccess, rv); + EXPECT_EQ(expected_state, alpn_state); + if (alpn_state == SSL_NEXT_PROTO_NO_SUPPORT) { + EXPECT_EQ("", expected); + } else { + EXPECT_NE("", expected); + EXPECT_EQ(expected, std::string(chosen, chosen_len)); + } +} + +void TlsAgent::CheckEpochs(uint16_t expected_read, + uint16_t expected_write) const { + uint16_t read_epoch = 0; + uint16_t write_epoch = 0; + EXPECT_EQ(SECSuccess, + SSL_GetCurrentEpoch(ssl_fd(), &read_epoch, &write_epoch)); + EXPECT_EQ(expected_read, read_epoch) << role_str() << " read epoch"; + EXPECT_EQ(expected_write, write_epoch) << role_str() << " write epoch"; +} + +void TlsAgent::EnableSrtp() { + EXPECT_TRUE(EnsureTlsSetup()); + const uint16_t ciphers[] = {SRTP_AES128_CM_HMAC_SHA1_80, + SRTP_AES128_CM_HMAC_SHA1_32}; + EXPECT_EQ(SECSuccess, + SSL_SetSRTPCiphers(ssl_fd(), ciphers, PR_ARRAY_SIZE(ciphers))); +} + +void TlsAgent::CheckSrtp() const { + uint16_t actual; + EXPECT_EQ(SECSuccess, SSL_GetSRTPCipher(ssl_fd(), &actual)); + EXPECT_EQ(SRTP_AES128_CM_HMAC_SHA1_80, actual); +} + +void TlsAgent::CheckErrorCode(int32_t expected) const { + EXPECT_EQ(STATE_ERROR, state_); + EXPECT_EQ(expected, error_code_) + << "Got error code " << PORT_ErrorToName(error_code_) << " expecting " + << PORT_ErrorToName(expected) << std::endl; +} + +static uint8_t GetExpectedAlertLevel(uint8_t alert) { + if (alert == kTlsAlertCloseNotify) { + return kTlsAlertWarning; + } + return kTlsAlertFatal; +} + +void TlsAgent::ExpectReceiveAlert(uint8_t alert, uint8_t level) { + expected_received_alert_ = alert; + if (level == 0) { + expected_received_alert_level_ = GetExpectedAlertLevel(alert); + } else { + expected_received_alert_level_ = level; + } +} + +void TlsAgent::ExpectSendAlert(uint8_t alert, uint8_t level) { + expected_sent_alert_ = alert; + if (level == 0) { + expected_sent_alert_level_ = GetExpectedAlertLevel(alert); + } else { + expected_sent_alert_level_ = level; + } +} + +void TlsAgent::CheckAlert(bool sent, const SSLAlert* alert) { + LOG(((alert->level == kTlsAlertWarning) ? "Warning" : "Fatal") + << " alert " << (sent ? "sent" : "received") << ": " + << static_cast<int>(alert->description)); + + auto& expected = sent ? expected_sent_alert_ : expected_received_alert_; + auto& expected_level = + sent ? expected_sent_alert_level_ : expected_received_alert_level_; + /* Silently pass close_notify in case the test has already ended. */ + if (expected == kTlsAlertCloseNotify && expected_level == kTlsAlertWarning && + alert->description == expected && alert->level == expected_level) { + return; + } + + EXPECT_EQ(expected, alert->description); + EXPECT_EQ(expected_level, alert->level); + expected = kTlsAlertCloseNotify; + expected_level = kTlsAlertWarning; +} + +void TlsAgent::WaitForErrorCode(int32_t expected, uint32_t delay) const { + ASSERT_EQ(0, error_code_); + WAIT_(error_code_ != 0, delay); + EXPECT_EQ(expected, error_code_) + << "Got error code " << PORT_ErrorToName(error_code_) << " expecting " + << PORT_ErrorToName(expected) << std::endl; +} + +void TlsAgent::CheckPreliminaryInfo() { + SSLPreliminaryChannelInfo preinfo; + EXPECT_EQ(SECSuccess, + SSL_GetPreliminaryChannelInfo(ssl_fd(), &preinfo, sizeof(preinfo))); + EXPECT_EQ(sizeof(preinfo), preinfo.length); + EXPECT_TRUE(preinfo.valuesSet & ssl_preinfo_version); + + // A version of 0 is invalid and indicates no expectation. This value is + // initialized to 0 so that tests that don't explicitly set an expected + // version can negotiate a version. + if (!expected_version_) { + expected_version_ = preinfo.protocolVersion; + } + EXPECT_EQ(expected_version_, preinfo.protocolVersion); + + // As with the version; 0 is the null cipher suite (and also invalid). + if (!expected_cipher_suite_) { + expected_cipher_suite_ = preinfo.cipherSuite; + } + EXPECT_EQ(expected_cipher_suite_, preinfo.cipherSuite); +} + +// Check that all the expected callbacks have been called. +void TlsAgent::CheckCallbacks() const { + // If false start happens, the handshake is reported as being complete at the + // point that false start happens. + if (expect_psk_ == ssl_psk_resume || !falsestart_enabled_) { + EXPECT_TRUE(handshake_callback_called_); + } + + // These callbacks shouldn't fire if we are resuming, except on TLS 1.3. + if (role_ == SERVER) { + PRBool have_sni = SSLInt_ExtensionNegotiated(ssl_fd(), ssl_server_name_xtn); + EXPECT_EQ(((expect_psk_ != ssl_psk_resume && have_sni) || + expected_version_ >= SSL_LIBRARY_VERSION_TLS_1_3), + sni_hook_called_); + } else { + EXPECT_EQ(expect_psk_ == ssl_psk_none, auth_certificate_hook_called_); + // Note that this isn't unconditionally called, even with false start on. + // But the callback is only skipped if a cipher that is ridiculously weak + // (80 bits) is chosen. Don't test that: plan to remove bad ciphers. + EXPECT_EQ(falsestart_enabled_ && expect_psk_ != ssl_psk_resume, + can_falsestart_hook_called_); + } +} + +void TlsAgent::ResetPreliminaryInfo() { + expected_version_ = 0; + expected_cipher_suite_ = 0; +} + +void TlsAgent::UpdatePreliminaryChannelInfo() { + SECStatus rv = + SSL_GetPreliminaryChannelInfo(ssl_fd(), &pre_info_, sizeof(pre_info_)); + EXPECT_EQ(SECSuccess, rv); + EXPECT_EQ(sizeof(pre_info_), pre_info_.length); +} + +void TlsAgent::ValidateCipherSpecs() { + PRInt32 cipherSpecs = SSLInt_CountCipherSpecs(ssl_fd()); + // We use one ciphersuite in each direction. + PRInt32 expected = 2; + if (variant_ == ssl_variant_datagram) { + // For DTLS 1.3, the client retains the cipher spec for early data and the + // handshake so that it can retransmit EndOfEarlyData and its final flight. + // It also retains the handshake read cipher spec so that it can read ACKs + // from the server. The server retains the handshake read cipher spec so it + // can read the client's retransmitted Finished. + if (expected_version_ >= SSL_LIBRARY_VERSION_TLS_1_3) { + if (role_ == CLIENT) { + expected = info_.earlyDataAccepted ? 5 : 4; + } else { + expected = 3; + } + } else { + // For DTLS 1.1 and 1.2, the last endpoint to send maintains a cipher spec + // until the holddown timer runs down. + if (expect_psk_ == ssl_psk_resume) { + if (role_ == CLIENT) { + expected = 3; + } + } else { + if (role_ == SERVER) { + expected = 3; + } + } + } + } + // This function will be run before the handshake completes if false start is + // enabled. In that case, the client will still be reading cleartext, but + // will have a spec prepared for reading ciphertext. With DTLS, the client + // will also have a spec retained for retransmission of handshake messages. + if (role_ == CLIENT && falsestart_enabled_ && !handshake_callback_called_) { + EXPECT_GT(SSL_LIBRARY_VERSION_TLS_1_3, expected_version_); + expected = (variant_ == ssl_variant_datagram) ? 4 : 3; + } + EXPECT_EQ(expected, cipherSpecs); + if (expected != cipherSpecs) { + SSLInt_PrintCipherSpecs(role_str().c_str(), ssl_fd()); + } +} + +void TlsAgent::Connected() { + if (state_ == STATE_CONNECTED) { + return; + } + + LOG("Handshake success"); + CheckPreliminaryInfo(); + CheckCallbacks(); + + SECStatus rv = SSL_GetChannelInfo(ssl_fd(), &info_, sizeof(info_)); + EXPECT_EQ(SECSuccess, rv); + EXPECT_EQ(sizeof(info_), info_.length); + + EXPECT_EQ(expect_psk_ == ssl_psk_resume, info_.resumed == PR_TRUE); + EXPECT_EQ(expect_psk_, info_.pskType); + EXPECT_EQ(expect_ech_, info_.echAccepted); + + // Preliminary values are exposed through callbacks during the handshake. + // If either expected values were set or the callbacks were called, check + // that the final values are correct. + UpdatePreliminaryChannelInfo(); + EXPECT_EQ(expected_version_, info_.protocolVersion); + EXPECT_EQ(expected_cipher_suite_, info_.cipherSuite); + + rv = SSL_GetCipherSuiteInfo(info_.cipherSuite, &csinfo_, sizeof(csinfo_)); + EXPECT_EQ(SECSuccess, rv); + EXPECT_EQ(sizeof(csinfo_), csinfo_.length); + + ValidateCipherSpecs(); + + SetState(STATE_CONNECTED); +} + +void TlsAgent::CheckClientAuthCompleted(uint8_t handshakes) { + EXPECT_FALSE(client_auth_callback_awaiting_); + switch (client_auth_callback_type_) { + case ClientAuthCallbackType::kNone: + if (!client_auth_callback_success_) { + EXPECT_TRUE(CheckClientAuthCallbacksCompleted(0)); + break; + } + case ClientAuthCallbackType::kSync: + EXPECT_TRUE(CheckClientAuthCallbacksCompleted(handshakes)); + break; + case ClientAuthCallbackType::kAsyncDelay: + case ClientAuthCallbackType::kAsyncImmediate: + EXPECT_TRUE(CheckClientAuthCallbacksCompleted(2 * handshakes)); + break; + } +} + +void TlsAgent::EnableExtendedMasterSecret() { + SetOption(SSL_ENABLE_EXTENDED_MASTER_SECRET, PR_TRUE); +} + +void TlsAgent::CheckExtendedMasterSecret(bool expected) { + if (version() >= SSL_LIBRARY_VERSION_TLS_1_3) { + expected = PR_TRUE; + } + ASSERT_EQ(expected, info_.extendedMasterSecretUsed != PR_FALSE) + << "unexpected extended master secret state for " << name_; +} + +void TlsAgent::CheckEarlyDataAccepted(bool expected) { + if (version() < SSL_LIBRARY_VERSION_TLS_1_3) { + expected = false; + } + ASSERT_EQ(expected, info_.earlyDataAccepted != PR_FALSE) + << "unexpected early data state for " << name_; +} + +void TlsAgent::CheckSecretsDestroyed() { + ASSERT_EQ(PR_TRUE, SSLInt_CheckSecretsDestroyed(ssl_fd())); +} + +void TlsAgent::SetDowngradeCheckVersion(uint16_t ver) { + ASSERT_TRUE(EnsureTlsSetup()); + + SECStatus rv = SSL_SetDowngradeCheckVersion(ssl_fd(), ver); + ASSERT_EQ(SECSuccess, rv); +} + +void TlsAgent::Handshake() { + LOGV("Handshake"); + SECStatus rv = SSL_ForceHandshake(ssl_fd()); + if (client_auth_callback_awaiting_) { + ClientAuthCallbackComplete(); + rv = SSL_ForceHandshake(ssl_fd()); + } + if (rv == SECSuccess) { + Connected(); + Poller::Instance()->Wait(READABLE_EVENT, adapter_, this, + &TlsAgent::ReadableCallback); + return; + } + + int32_t err = PR_GetError(); + if (err == PR_WOULD_BLOCK_ERROR) { + LOGV("Would have blocked"); + if (variant_ == ssl_variant_datagram) { + if (timer_handle_) { + timer_handle_->Cancel(); + timer_handle_ = nullptr; + } + + PRIntervalTime timeout; + rv = DTLS_GetHandshakeTimeout(ssl_fd(), &timeout); + if (rv == SECSuccess) { + Poller::Instance()->SetTimer( + timeout + 1, this, &TlsAgent::ReadableCallback, &timer_handle_); + } + } + Poller::Instance()->Wait(READABLE_EVENT, adapter_, this, + &TlsAgent::ReadableCallback); + return; + } + + LOG("Handshake failed with error " << PORT_ErrorToName(err) << ": " + << PORT_ErrorToString(err)); + error_code_ = err; + SetState(STATE_ERROR); +} + +void TlsAgent::PrepareForRenegotiate() { + EXPECT_EQ(STATE_CONNECTED, state_); + + SetState(STATE_CONNECTING); +} + +void TlsAgent::StartRenegotiate() { + PrepareForRenegotiate(); + + SECStatus rv = SSL_ReHandshake(ssl_fd(), PR_TRUE); + EXPECT_EQ(SECSuccess, rv); +} + +void TlsAgent::SendDirect(const DataBuffer& buf) { + LOG("Send Direct " << buf); + auto peer = adapter_->peer().lock(); + if (peer) { + peer->PacketReceived(buf); + } else { + LOG("Send Direct peer absent"); + } +} + +void TlsAgent::SendRecordDirect(const TlsRecord& record) { + DataBuffer buf; + + auto rv = record.header.Write(&buf, 0, record.buffer); + EXPECT_EQ(record.header.header_length() + record.buffer.len(), rv); + SendDirect(buf); +} + +static bool ErrorIsFatal(PRErrorCode code) { + return code != PR_WOULD_BLOCK_ERROR && code != SSL_ERROR_RX_SHORT_DTLS_READ; +} + +void TlsAgent::SendData(size_t bytes, size_t blocksize) { + uint8_t block[16385]; // One larger than the maximum record size. + + ASSERT_LE(blocksize, sizeof(block)); + + while (bytes) { + size_t tosend = std::min(blocksize, bytes); + + for (size_t i = 0; i < tosend; ++i) { + block[i] = 0xff & send_ctr_; + ++send_ctr_; + } + + SendBuffer(DataBuffer(block, tosend)); + bytes -= tosend; + } +} + +void TlsAgent::SendBuffer(const DataBuffer& buf) { + LOGV("Writing " << buf.len() << " bytes"); + int32_t rv = PR_Write(ssl_fd(), buf.data(), buf.len()); + if (expect_readwrite_error_) { + EXPECT_GT(0, rv); + EXPECT_NE(PR_WOULD_BLOCK_ERROR, error_code_); + error_code_ = PR_GetError(); + expect_readwrite_error_ = false; + } else { + ASSERT_EQ(buf.len(), static_cast<size_t>(rv)); + } +} + +bool TlsAgent::SendEncryptedRecord(const std::shared_ptr<TlsCipherSpec>& spec, + uint64_t seq, uint8_t ct, + const DataBuffer& buf) { + // Ensure that we are doing TLS 1.3. + EXPECT_GE(expected_version_, SSL_LIBRARY_VERSION_TLS_1_3); + if (variant_ != ssl_variant_datagram) { + ADD_FAILURE(); + return false; + } + + LOGV("Encrypting " << buf.len() << " bytes"); + uint8_t dtls13_ct = kCtDtlsCiphertext | kCtDtlsCiphertext16bSeqno | + kCtDtlsCiphertextLengthPresent; + TlsRecordHeader header(variant_, expected_version_, dtls13_ct, seq); + TlsRecordHeader out_header(header); + DataBuffer padded = buf; + padded.Write(padded.len(), ct, 1); + DataBuffer ciphertext; + if (!spec->Protect(header, padded, &ciphertext, &out_header)) { + return false; + } + + DataBuffer record; + auto rv = out_header.Write(&record, 0, ciphertext); + EXPECT_EQ(out_header.header_length() + ciphertext.len(), rv); + SendDirect(record); + return true; +} + +void TlsAgent::ReadBytes(size_t amount) { + uint8_t block[16384]; + + size_t remaining = amount; + while (remaining > 0) { + int32_t rv = PR_Read(ssl_fd(), block, (std::min)(amount, sizeof(block))); + LOGV("ReadBytes " << rv); + + if (rv > 0) { + size_t count = static_cast<size_t>(rv); + for (size_t i = 0; i < count; ++i) { + ASSERT_EQ(recv_ctr_ & 0xff, block[i]); + recv_ctr_++; + } + remaining -= rv; + } else { + PRErrorCode err = 0; + if (rv < 0) { + err = PR_GetError(); + LOG("Read error " << PORT_ErrorToName(err) << ": " + << PORT_ErrorToString(err)); + if (err != PR_WOULD_BLOCK_ERROR && expect_readwrite_error_) { + error_code_ = err; + expect_readwrite_error_ = false; + } + } + if (err != 0 && ErrorIsFatal(err)) { + // If we hit a fatal error, we're done. + remaining = 0; + } + break; + } + } + + // If closed, then don't bother waiting around. + if (remaining) { + LOGV("Re-arming"); + Poller::Instance()->Wait(READABLE_EVENT, adapter_, this, + &TlsAgent::ReadableCallback); + } +} + +void TlsAgent::ResetSentBytes(size_t bytes) { send_ctr_ = bytes; } + +void TlsAgent::SetOption(int32_t option, int value) { + ASSERT_TRUE(EnsureTlsSetup()); + EXPECT_EQ(SECSuccess, SSL_OptionSet(ssl_fd(), option, value)); +} + +void TlsAgent::ConfigureSessionCache(SessionResumptionMode mode) { + SetOption(SSL_NO_CACHE, mode & RESUME_SESSIONID ? PR_FALSE : PR_TRUE); + SetOption(SSL_ENABLE_SESSION_TICKETS, + mode & RESUME_TICKET ? PR_TRUE : PR_FALSE); +} + +void TlsAgent::EnableECDHEServerKeyReuse() { + ASSERT_EQ(TlsAgent::SERVER, role_); + SetOption(SSL_REUSE_SERVER_ECDHE_KEY, PR_TRUE); +} + +static const std::string kTlsRolesAllArr[] = {"CLIENT", "SERVER"}; +::testing::internal::ParamGenerator<std::string> + TlsAgentTestBase::kTlsRolesAll = ::testing::ValuesIn(kTlsRolesAllArr); + +void TlsAgentTestBase::SetUp() { + SSL_ConfigServerSessionIDCache(1024, 0, 0, g_working_dir_path.c_str()); +} + +void TlsAgentTestBase::TearDown() { + agent_ = nullptr; + SSL_ClearSessionCache(); + SSL_ShutdownServerSessionIDCache(); +} + +void TlsAgentTestBase::Reset(const std::string& server_name) { + agent_.reset( + new TlsAgent(role_ == TlsAgent::CLIENT ? TlsAgent::kClient : server_name, + role_, variant_)); + if (version_) { + agent_->SetVersionRange(version_, version_); + } + agent_->adapter()->SetPeer(sink_adapter_); + agent_->StartConnect(); +} + +void TlsAgentTestBase::EnsureInit() { + if (!agent_) { + Reset(); + } + const std::vector<SSLNamedGroup> groups = { + ssl_grp_ec_curve25519, ssl_grp_ec_secp256r1, ssl_grp_ec_secp384r1, + ssl_grp_ffdhe_2048}; + agent_->ConfigNamedGroups(groups); +} + +void TlsAgentTestBase::ExpectAlert(uint8_t alert) { + EnsureInit(); + agent_->ExpectSendAlert(alert); +} + +void TlsAgentTestBase::ProcessMessage(const DataBuffer& buffer, + TlsAgent::State expected_state, + int32_t error_code) { + std::cerr << "Process message: " << buffer << std::endl; + EnsureInit(); + agent_->adapter()->PacketReceived(buffer); + agent_->Handshake(); + + ASSERT_EQ(expected_state, agent_->state()); + + if (expected_state == TlsAgent::STATE_ERROR) { + ASSERT_EQ(error_code, agent_->error_code()); + } +} + +void TlsAgentTestBase::MakeRecord(SSLProtocolVariant variant, uint8_t type, + uint16_t version, const uint8_t* buf, + size_t len, DataBuffer* out, + uint64_t sequence_number) { + // Fixup the content type for DTLSCiphertext + if (variant == ssl_variant_datagram && + version >= SSL_LIBRARY_VERSION_TLS_1_3 && + type == ssl_ct_application_data) { + type = kCtDtlsCiphertext | kCtDtlsCiphertext16bSeqno | + kCtDtlsCiphertextLengthPresent; + } + + size_t index = 0; + if (variant == ssl_variant_stream) { + index = out->Write(index, type, 1); + index = out->Write(index, version, 2); + } else if (version >= SSL_LIBRARY_VERSION_TLS_1_3 && + (type & kCtDtlsCiphertextMask) == kCtDtlsCiphertext) { + uint32_t epoch = (sequence_number >> 48) & 0x3; + index = out->Write(index, type | epoch, 1); + uint32_t seqno = sequence_number & ((1ULL << 16) - 1); + index = out->Write(index, seqno, 2); + } else { + index = out->Write(index, type, 1); + index = out->Write(index, TlsVersionToDtlsVersion(version), 2); + index = out->Write(index, sequence_number >> 32, 4); + index = out->Write(index, sequence_number & PR_UINT32_MAX, 4); + } + index = out->Write(index, len, 2); + out->Write(index, buf, len); +} + +void TlsAgentTestBase::MakeRecord(uint8_t type, uint16_t version, + const uint8_t* buf, size_t len, + DataBuffer* out, uint64_t seq_num) const { + MakeRecord(variant_, type, version, buf, len, out, seq_num); +} + +void TlsAgentTestBase::MakeHandshakeMessage(uint8_t hs_type, + const uint8_t* data, size_t hs_len, + DataBuffer* out, + uint64_t seq_num) const { + return MakeHandshakeMessageFragment(hs_type, data, hs_len, out, seq_num, 0, + 0); +} + +void TlsAgentTestBase::MakeHandshakeMessageFragment( + uint8_t hs_type, const uint8_t* data, size_t hs_len, DataBuffer* out, + uint64_t seq_num, uint32_t fragment_offset, + uint32_t fragment_length) const { + size_t index = 0; + if (!fragment_length) fragment_length = hs_len; + index = out->Write(index, hs_type, 1); // Handshake record type. + index = out->Write(index, hs_len, 3); // Handshake length + if (variant_ == ssl_variant_datagram) { + index = out->Write(index, seq_num, 2); + index = out->Write(index, fragment_offset, 3); + index = out->Write(index, fragment_length, 3); + } + if (data) { + index = out->Write(index, data, fragment_length); + } else { + for (size_t i = 0; i < fragment_length; ++i) { + index = out->Write(index, 1, 1); + } + } +} + +void TlsAgentTestBase::MakeTrivialHandshakeRecord(uint8_t hs_type, + size_t hs_len, + DataBuffer* out) { + size_t index = 0; + index = out->Write(index, ssl_ct_handshake, 1); // Content Type + index = out->Write(index, 3, 1); // Version high + index = out->Write(index, 1, 1); // Version low + index = out->Write(index, 4 + hs_len, 2); // Length + + index = out->Write(index, hs_type, 1); // Handshake record type. + index = out->Write(index, hs_len, 3); // Handshake length + for (size_t i = 0; i < hs_len; ++i) { + index = out->Write(index, 1, 1); + } +} + +DataBuffer TlsAgentTestBase::MakeCannedTls13ServerHello() { + DataBuffer sh(kCannedTls13ServerHello, sizeof(kCannedTls13ServerHello)); + if (variant_ == ssl_variant_datagram) { + sh.Write(0, SSL_LIBRARY_VERSION_DTLS_1_2_WIRE, 2); + // The version should be at the end. + uint32_t v; + EXPECT_TRUE(sh.Read(sh.len() - 2, 2, &v)); + EXPECT_EQ(static_cast<uint32_t>(SSL_LIBRARY_VERSION_TLS_1_3), v); + sh.Write(sh.len() - 2, 0x7f00 | DTLS_1_3_DRAFT_VERSION, 2); + } + return sh; +} + +} // namespace nss_test diff --git a/security/nss/gtests/ssl_gtest/tls_agent.h b/security/nss/gtests/ssl_gtest/tls_agent.h new file mode 100644 index 0000000000..35375e0c11 --- /dev/null +++ b/security/nss/gtests/ssl_gtest/tls_agent.h @@ -0,0 +1,588 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set ts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this file, + * You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#ifndef tls_agent_h_ +#define tls_agent_h_ + +#include "prio.h" +#include "ssl.h" +#include "sslproto.h" + +#include <functional> +#include <iostream> + +#include "nss_policy.h" +#include "test_io.h" + +#define GTEST_HAS_RTTI 0 +#include "gtest/gtest.h" +#include "nss_scoped_ptrs.h" +#include "scoped_ptrs_ssl.h" + +extern bool g_ssl_gtest_verbose; + +namespace nss_test { + +#define LOG(msg) std::cerr << role_str() << ": " << msg << std::endl +#define LOGV(msg) \ + do { \ + if (g_ssl_gtest_verbose) LOG(msg); \ + } while (false) + +enum SessionResumptionMode { + RESUME_NONE = 0, + RESUME_SESSIONID = 1, + RESUME_TICKET = 2, + RESUME_BOTH = RESUME_SESSIONID | RESUME_TICKET +}; + +enum class ClientAuthCallbackType { + kAsyncImmediate, + kAsyncDelay, + kSync, + kNone, +}; + +class PacketFilter; +class TlsAgent; +class TlsCipherSpec; +struct TlsRecord; + +const extern std::vector<SSLNamedGroup> kAllDHEGroups; +const extern std::vector<SSLNamedGroup> kECDHEGroups; +const extern std::vector<SSLNamedGroup> kFFDHEGroups; +const extern std::vector<SSLNamedGroup> kFasterDHEGroups; + +// These functions are called from callbacks. They use bare pointers because +// TlsAgent sets up the callback and it doesn't know who owns it. +typedef std::function<SECStatus(TlsAgent* agent, bool checksig, bool isServer)> + AuthCertificateCallbackFunction; + +typedef std::function<void(TlsAgent* agent)> HandshakeCallbackFunction; + +typedef std::function<int32_t(TlsAgent* agent, const SECItem* srvNameArr, + PRUint32 srvNameArrSize)> + SniCallbackFunction; + +class TlsAgent : public PollTarget { + public: + enum Role { CLIENT, SERVER }; + enum State { STATE_INIT, STATE_CONNECTING, STATE_CONNECTED, STATE_ERROR }; + + static const std::string kClient; // the client key is sign only + static const std::string kRsa2048; // bigger sign and encrypt for either + static const std::string kRsa8192; // biggest sign and encrypt for either + static const std::string kServerRsa; // both sign and encrypt + static const std::string kServerRsaSign; + static const std::string kServerRsaPss; + static const std::string kServerRsaDecrypt; + static const std::string kServerEcdsa256; + static const std::string kServerEcdsa384; + static const std::string kServerEcdsa521; + static const std::string kServerEcdhEcdsa; + static const std::string kServerEcdhRsa; + static const std::string kServerDsa; + static const std::string kDelegatorEcdsa256; // draft-ietf-tls-subcerts + static const std::string kDelegatorRsae2048; // draft-ietf-tls-subcerts + static const std::string kDelegatorRsaPss2048; // draft-ietf-tls-subcerts + + TlsAgent(const std::string& name, Role role, SSLProtocolVariant variant); + virtual ~TlsAgent(); + + void SetPeer(std::shared_ptr<TlsAgent>& peer) { + adapter_->SetPeer(peer->adapter_); + } + + void SetFilter(std::shared_ptr<PacketFilter> filter) { + adapter_->SetPacketFilter(filter); + } + void ClearFilter() { adapter_->SetPacketFilter(nullptr); } + + void StartConnect(PRFileDesc* model = nullptr); + void CheckKEA(SSLKEAType kea_type, SSLNamedGroup group, + size_t kea_size = 0) const; + void CheckOriginalKEA(SSLNamedGroup kea_group) const; + void CheckAuthType(SSLAuthType auth_type, + SSLSignatureScheme sig_scheme) const; + + void DisableAllCiphers(); + void EnableCiphersByAuthType(SSLAuthType authType); + void EnableCiphersByKeyExchange(SSLKEAType kea); + void EnableGroupsByKeyExchange(SSLKEAType kea); + void EnableGroupsByAuthType(SSLAuthType authType); + void EnableSingleCipher(uint16_t cipher); + + void Handshake(); + // Marks the internal state as CONNECTING in anticipation of renegotiation. + void PrepareForRenegotiate(); + // Prepares for renegotiation, then actually triggers it. + void StartRenegotiate(); + void SetAntiReplayContext(ScopedSSLAntiReplayContext& ctx); + + static bool LoadCertificate(const std::string& name, + ScopedCERTCertificate* cert, + ScopedSECKEYPrivateKey* priv); + static bool LoadKeyPairFromCert(const std::string& name, + ScopedSECKEYPublicKey* pub, + ScopedSECKEYPrivateKey* priv); + + // Delegated credentials. + // + // Generate a delegated credential and sign it using the certificate + // associated with |name|. + static void DelegateCredential(const std::string& name, + const ScopedSECKEYPublicKey& dcPub, + SSLSignatureScheme dcCertVerifyAlg, + PRUint32 dcValidFor, PRTime now, SECItem* dc); + // Indicate support for the delegated credentials extension. + void EnableDelegatedCredentials(); + // Generate and configure a delegated credential to use in the handshake with + // clients that support this extension.. + void AddDelegatedCredential(const std::string& dc_name, + SSLSignatureScheme dcCertVerifyAlg, + PRUint32 dcValidFor, PRTime now); + void UpdatePreliminaryChannelInfo(); + + bool ConfigServerCert(const std::string& name, bool updateKeyBits = false, + const SSLExtraServerCertData* serverCertData = nullptr); + bool ConfigServerCertWithChain(const std::string& name); + bool EnsureTlsSetup(PRFileDesc* modelSocket = nullptr); + + void SetupClientAuth( + ClientAuthCallbackType callbackType = ClientAuthCallbackType::kSync, + bool callbackSuccess = true); + void RequestClientAuth(bool requireAuth); + void ClientAuthCallbackComplete(); + bool CheckClientAuthCallbacksCompleted(uint8_t expected); + void CheckClientAuthCompleted(uint8_t handshakes = 1); + void SetOption(int32_t option, int value); + void ConfigureSessionCache(SessionResumptionMode mode); + void Set0RttEnabled(bool en); + void SetFallbackSCSVEnabled(bool en); + void SetVersionRange(uint16_t minver, uint16_t maxver); + void GetVersionRange(uint16_t* minver, uint16_t* maxver); + void CheckPreliminaryInfo(); + void ResetPreliminaryInfo(); + void SetExpectedVersion(uint16_t version); + void SetServerKeyBits(uint16_t bits); + void ExpectReadWriteError(); + void EnableFalseStart(); + void ExpectEch(bool expected = true); + bool GetEchExpected() const { return expect_ech_; } + void ExpectPsk(SSLPskType psk = ssl_psk_external); + void ExpectResumption(); + void SkipVersionChecks(); + void SetSignatureSchemes(const SSLSignatureScheme* schemes, size_t count); + void EnableAlpn(const uint8_t* val, size_t len); + void CheckAlpn(SSLNextProtoState expected_state, + const std::string& expected = "") const; + void EnableSrtp(); + void CheckSrtp() const; + void CheckEpochs(uint16_t expected_read, uint16_t expected_write) const; + void CheckErrorCode(int32_t expected) const; + void WaitForErrorCode(int32_t expected, uint32_t delay) const; + // Send data on the socket, encrypting it. + void SendData(size_t bytes, size_t blocksize = 1024); + void SendBuffer(const DataBuffer& buf); + bool SendEncryptedRecord(const std::shared_ptr<TlsCipherSpec>& spec, + uint64_t seq, uint8_t ct, const DataBuffer& buf); + // Send data directly to the underlying socket, skipping the TLS layer. + void SendDirect(const DataBuffer& buf); + void SendRecordDirect(const TlsRecord& record); + void AddPsk(const ScopedPK11SymKey& psk, std::string label, SSLHashType hash, + uint16_t zeroRttSuite = TLS_NULL_WITH_NULL_NULL); + void RemovePsk(std::string label); + void ReadBytes(size_t max = 16384U); + void ResetSentBytes(size_t bytes = 0); // Hack to test drops. + void EnableExtendedMasterSecret(); + void CheckExtendedMasterSecret(bool expected); + void CheckEarlyDataAccepted(bool expected); + void CheckEchAccepted(bool expected); + void SetDowngradeCheckVersion(uint16_t version); + void CheckSecretsDestroyed(); + void ConfigNamedGroups(const std::vector<SSLNamedGroup>& groups); + void EnableECDHEServerKeyReuse(); + bool GetPeerChainLength(size_t* count); + void CheckCipherSuite(uint16_t cipher_suite); + void SetResumptionTokenCallback(); + bool MaybeSetResumptionToken(); + void SetResumptionToken(const std::vector<uint8_t>& resumption_token) { + resumption_token_ = resumption_token; + } + const std::vector<uint8_t>& GetResumptionToken() const { + return resumption_token_; + } + void GetTokenInfo(ScopedSSLResumptionTokenInfo& token) { + SECStatus rv = SSL_GetResumptionTokenInfo( + resumption_token_.data(), resumption_token_.size(), token.get(), + sizeof(SSLResumptionTokenInfo)); + ASSERT_EQ(SECSuccess, rv); + } + void SetResumptionCallbackCalled() { resumption_callback_called_ = true; } + bool resumption_callback_called() const { + return resumption_callback_called_; + } + + const std::string& name() const { return name_; } + + Role role() const { return role_; } + std::string role_str() const { return role_ == SERVER ? "server" : "client"; } + + SSLProtocolVariant variant() const { return variant_; } + + State state() const { return state_; } + + const CERTCertificate* peer_cert() const { + return SSL_PeerCertificate(ssl_fd_.get()); + } + + const char* state_str() const { return state_str(state()); } + + static const char* state_str(State state) { return states[state]; } + + NssManagedFileDesc ssl_fd() const { + return NssManagedFileDesc(ssl_fd_.get(), policy_, option_); + } + std::shared_ptr<DummyPrSocket>& adapter() { return adapter_; } + + const SSLChannelInfo& info() const { + EXPECT_EQ(STATE_CONNECTED, state_); + return info_; + } + + const SSLPreliminaryChannelInfo& pre_info() const { return pre_info_; } + + bool is_compressed() const { + return info().compressionMethod != ssl_compression_null; + } + uint16_t server_key_bits() const { return server_key_bits_; } + uint16_t min_version() const { return vrange_.min; } + uint16_t max_version() const { return vrange_.max; } + uint16_t version() const { return info().protocolVersion; } + + bool cipher_suite(uint16_t* suite) const { + if (state_ != STATE_CONNECTED) return false; + + *suite = info_.cipherSuite; + return true; + } + + void expected_cipher_suite(uint16_t suite) { expected_cipher_suite_ = suite; } + + std::string cipher_suite_name() const { + if (state_ != STATE_CONNECTED) return "UNKNOWN"; + + return csinfo_.cipherSuiteName; + } + + std::vector<uint8_t> session_id() const { + return std::vector<uint8_t>(info_.sessionID, + info_.sessionID + info_.sessionIDLength); + } + + bool auth_type(SSLAuthType* a) const { + if (state_ != STATE_CONNECTED) return false; + + *a = info_.authType; + return true; + } + + bool kea_type(SSLKEAType* k) const { + if (state_ != STATE_CONNECTED) return false; + + *k = info_.keaType; + return true; + } + + size_t received_bytes() const { return recv_ctr_; } + PRErrorCode error_code() const { return error_code_; } + + bool can_falsestart_hook_called() const { + return can_falsestart_hook_called_; + } + + void SetHandshakeCallback(HandshakeCallbackFunction handshake_callback) { + handshake_callback_ = handshake_callback; + } + + void SetAuthCertificateCallback( + AuthCertificateCallbackFunction auth_certificate_callback) { + auth_certificate_callback_ = auth_certificate_callback; + } + + void SetSniCallback(SniCallbackFunction sni_callback) { + sni_callback_ = sni_callback; + } + + void ExpectReceiveAlert(uint8_t alert, uint8_t level = 0); + void ExpectSendAlert(uint8_t alert, uint8_t level = 0); + + std::string alpn_value_to_use_ = ""; + // set the given policy before this agent runs + void SetPolicy(SECOidTag oid, PRUint32 set, PRUint32 clear) { + policy_ = NssPolicy(oid, set, clear); + } + void SetNssOption(PRInt32 id, PRInt32 value) { + option_ = NssOption(id, value); + } + + private: + const static char* states[]; + + void SetState(State state); + void ValidateCipherSpecs(); + + // Dummy auth certificate hook. + static SECStatus AuthCertificateHook(void* arg, PRFileDesc* fd, + PRBool checksig, PRBool isServer) { + TlsAgent* agent = reinterpret_cast<TlsAgent*>(arg); + agent->CheckPreliminaryInfo(); + agent->auth_certificate_hook_called_ = true; + if (agent->auth_certificate_callback_) { + return agent->auth_certificate_callback_(agent, checksig ? true : false, + isServer ? true : false); + } + return SECSuccess; + } + + // Client auth certificate hook. + static SECStatus ClientAuthenticated(void* arg, PRFileDesc* fd, + PRBool checksig, PRBool isServer) { + TlsAgent* agent = reinterpret_cast<TlsAgent*>(arg); + EXPECT_TRUE(agent->expect_client_auth_); + EXPECT_EQ(PR_TRUE, isServer); + if (agent->auth_certificate_callback_) { + return agent->auth_certificate_callback_(agent, checksig ? true : false, + isServer ? true : false); + } + return SECSuccess; + } + + static SECStatus GetClientAuthDataHook(void* self, PRFileDesc* fd, + CERTDistNames* caNames, + CERTCertificate** cert, + SECKEYPrivateKey** privKey); + + static void ReadableCallback(PollTarget* self, Event event) { + TlsAgent* agent = static_cast<TlsAgent*>(self); + if (event == TIMER_EVENT) { + agent->timer_handle_ = nullptr; + } + agent->ReadableCallback_int(); + } + + void ReadableCallback_int() { + LOGV("Readable"); + switch (state_) { + case STATE_CONNECTING: + Handshake(); + break; + case STATE_CONNECTED: + ReadBytes(); + break; + default: + break; + } + } + + static PRInt32 SniHook(PRFileDesc* fd, const SECItem* srvNameArr, + PRUint32 srvNameArrSize, void* arg) { + TlsAgent* agent = reinterpret_cast<TlsAgent*>(arg); + agent->CheckPreliminaryInfo(); + agent->sni_hook_called_ = true; + EXPECT_EQ(1UL, srvNameArrSize); + if (agent->sni_callback_) { + return agent->sni_callback_(agent, srvNameArr, srvNameArrSize); + } + return 0; // First configuration. + } + + static SECStatus CanFalseStartCallback(PRFileDesc* fd, void* arg, + PRBool* canFalseStart) { + TlsAgent* agent = reinterpret_cast<TlsAgent*>(arg); + agent->CheckPreliminaryInfo(); + EXPECT_TRUE(agent->falsestart_enabled_); + EXPECT_FALSE(agent->can_falsestart_hook_called_); + agent->can_falsestart_hook_called_ = true; + *canFalseStart = true; + return SECSuccess; + } + + void CheckAlert(bool sent, const SSLAlert* alert); + + static void AlertReceivedCallback(const PRFileDesc* fd, void* arg, + const SSLAlert* alert) { + reinterpret_cast<TlsAgent*>(arg)->CheckAlert(false, alert); + } + + static void AlertSentCallback(const PRFileDesc* fd, void* arg, + const SSLAlert* alert) { + reinterpret_cast<TlsAgent*>(arg)->CheckAlert(true, alert); + } + + static void HandshakeCallback(PRFileDesc* fd, void* arg) { + TlsAgent* agent = reinterpret_cast<TlsAgent*>(arg); + agent->handshake_callback_called_ = true; + agent->Connected(); + if (agent->handshake_callback_) { + agent->handshake_callback_(agent); + } + } + + void DisableLameGroups(); + void ConfigStrongECGroups(bool en); + void ConfigAllDHGroups(bool en); + void CheckCallbacks() const; + void Connected(); + + const std::string name_; + SSLProtocolVariant variant_; + Role role_; + uint16_t server_key_bits_; + std::shared_ptr<DummyPrSocket> adapter_; + ScopedPRFileDesc ssl_fd_; + State state_; + std::shared_ptr<Poller::Timer> timer_handle_; + bool falsestart_enabled_; + uint16_t expected_version_; + uint16_t expected_cipher_suite_; + bool expect_client_auth_; + bool expect_ech_; + SSLPskType expect_psk_; + bool can_falsestart_hook_called_; + bool sni_hook_called_; + bool auth_certificate_hook_called_; + uint8_t expected_received_alert_; + uint8_t expected_received_alert_level_; + uint8_t expected_sent_alert_; + uint8_t expected_sent_alert_level_; + bool handshake_callback_called_; + bool resumption_callback_called_; + SSLChannelInfo info_; + SSLPreliminaryChannelInfo pre_info_; + SSLCipherSuiteInfo csinfo_; + SSLVersionRange vrange_; + PRErrorCode error_code_; + size_t send_ctr_; + size_t recv_ctr_; + bool expect_readwrite_error_; + HandshakeCallbackFunction handshake_callback_; + AuthCertificateCallbackFunction auth_certificate_callback_; + SniCallbackFunction sni_callback_; + bool skip_version_checks_; + std::vector<uint8_t> resumption_token_; + NssPolicy policy_; + NssOption option_; + ClientAuthCallbackType client_auth_callback_type_ = + ClientAuthCallbackType::kNone; + bool client_auth_callback_success_ = false; + uint8_t client_auth_callback_fired_ = 0; + bool client_auth_callback_awaiting_ = false; +}; + +inline std::ostream& operator<<(std::ostream& stream, + const TlsAgent::State& state) { + return stream << TlsAgent::state_str(state); +} + +class TlsAgentTestBase : public ::testing::Test { + public: + static ::testing::internal::ParamGenerator<std::string> kTlsRolesAll; + + TlsAgentTestBase(TlsAgent::Role role, SSLProtocolVariant variant, + uint16_t version = 0) + : agent_(nullptr), + role_(role), + variant_(variant), + version_(version), + sink_adapter_(new DummyPrSocket("sink", variant)) {} + virtual ~TlsAgentTestBase() {} + + void SetUp(); + void TearDown(); + + void ExpectAlert(uint8_t alert); + + static void MakeRecord(SSLProtocolVariant variant, uint8_t type, + uint16_t version, const uint8_t* buf, size_t len, + DataBuffer* out, uint64_t seq_num = 0); + void MakeRecord(uint8_t type, uint16_t version, const uint8_t* buf, + size_t len, DataBuffer* out, uint64_t seq_num = 0) const; + void MakeHandshakeMessage(uint8_t hs_type, const uint8_t* data, size_t hs_len, + DataBuffer* out, uint64_t seq_num = 0) const; + void MakeHandshakeMessageFragment(uint8_t hs_type, const uint8_t* data, + size_t hs_len, DataBuffer* out, + uint64_t seq_num, uint32_t fragment_offset, + uint32_t fragment_length) const; + DataBuffer MakeCannedTls13ServerHello(); + static void MakeTrivialHandshakeRecord(uint8_t hs_type, size_t hs_len, + DataBuffer* out); + static inline TlsAgent::Role ToRole(const std::string& str) { + return str == "CLIENT" ? TlsAgent::CLIENT : TlsAgent::SERVER; + } + + void Init(const std::string& server_name = TlsAgent::kServerRsa); + void Reset(const std::string& server_name = TlsAgent::kServerRsa); + + protected: + void EnsureInit(); + void ProcessMessage(const DataBuffer& buffer, TlsAgent::State expected_state, + int32_t error_code = 0); + + std::shared_ptr<TlsAgent> agent_; + TlsAgent::Role role_; + SSLProtocolVariant variant_; + uint16_t version_; + // This adapter is here just to accept packets from this agent. + std::shared_ptr<DummyPrSocket> sink_adapter_; +}; + +class TlsAgentTest + : public TlsAgentTestBase, + public ::testing::WithParamInterface< + std::tuple<std::string, SSLProtocolVariant, uint16_t>> { + public: + TlsAgentTest() + : TlsAgentTestBase(ToRole(std::get<0>(GetParam())), + std::get<1>(GetParam()), std::get<2>(GetParam())) {} +}; + +class TlsAgentTestClient : public TlsAgentTestBase, + public ::testing::WithParamInterface< + std::tuple<SSLProtocolVariant, uint16_t>> { + public: + TlsAgentTestClient() + : TlsAgentTestBase(TlsAgent::CLIENT, std::get<0>(GetParam()), + std::get<1>(GetParam())) {} +}; + +class TlsAgentTestClient13 : public TlsAgentTestClient {}; + +class TlsAgentStreamTestClient : public TlsAgentTestBase { + public: + TlsAgentStreamTestClient() + : TlsAgentTestBase(TlsAgent::CLIENT, ssl_variant_stream) {} +}; + +class TlsAgentStreamTestServer : public TlsAgentTestBase { + public: + TlsAgentStreamTestServer() + : TlsAgentTestBase(TlsAgent::SERVER, ssl_variant_stream) {} +}; + +class TlsAgentDgramTestClient : public TlsAgentTestBase { + public: + TlsAgentDgramTestClient() + : TlsAgentTestBase(TlsAgent::CLIENT, ssl_variant_datagram) {} +}; + +inline bool operator==(const SSLVersionRange& vr1, const SSLVersionRange& vr2) { + return vr1.min == vr2.min && vr1.max == vr2.max; +} + +} // namespace nss_test + +#endif diff --git a/security/nss/gtests/ssl_gtest/tls_connect.cc b/security/nss/gtests/ssl_gtest/tls_connect.cc new file mode 100644 index 0000000000..fd10e34a79 --- /dev/null +++ b/security/nss/gtests/ssl_gtest/tls_connect.cc @@ -0,0 +1,1065 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set ts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this file, + * You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#include "tls_connect.h" +#include "sslexp.h" +extern "C" { +#include "libssl_internals.h" +} + +#include <iostream> + +#include "databuffer.h" +#include "gtest_utils.h" +#include "nss_scoped_ptrs.h" +#include "sslproto.h" + +extern std::string g_working_dir_path; + +namespace nss_test { + +static const SSLProtocolVariant kTlsVariantsStreamArr[] = {ssl_variant_stream}; +::testing::internal::ParamGenerator<SSLProtocolVariant> + TlsConnectTestBase::kTlsVariantsStream = + ::testing::ValuesIn(kTlsVariantsStreamArr); +static const SSLProtocolVariant kTlsVariantsDatagramArr[] = { + ssl_variant_datagram}; +::testing::internal::ParamGenerator<SSLProtocolVariant> + TlsConnectTestBase::kTlsVariantsDatagram = + ::testing::ValuesIn(kTlsVariantsDatagramArr); +static const SSLProtocolVariant kTlsVariantsAllArr[] = {ssl_variant_stream, + ssl_variant_datagram}; +::testing::internal::ParamGenerator<SSLProtocolVariant> + TlsConnectTestBase::kTlsVariantsAll = + ::testing::ValuesIn(kTlsVariantsAllArr); + +static const uint16_t kTlsV10Arr[] = {SSL_LIBRARY_VERSION_TLS_1_0}; +::testing::internal::ParamGenerator<uint16_t> TlsConnectTestBase::kTlsV10 = + ::testing::ValuesIn(kTlsV10Arr); +static const uint16_t kTlsV11Arr[] = {SSL_LIBRARY_VERSION_TLS_1_1}; +::testing::internal::ParamGenerator<uint16_t> TlsConnectTestBase::kTlsV11 = + ::testing::ValuesIn(kTlsV11Arr); +static const uint16_t kTlsV12Arr[] = {SSL_LIBRARY_VERSION_TLS_1_2}; +::testing::internal::ParamGenerator<uint16_t> TlsConnectTestBase::kTlsV12 = + ::testing::ValuesIn(kTlsV12Arr); +static const uint16_t kTlsV10V11Arr[] = {SSL_LIBRARY_VERSION_TLS_1_0, + SSL_LIBRARY_VERSION_TLS_1_1}; +::testing::internal::ParamGenerator<uint16_t> TlsConnectTestBase::kTlsV10V11 = + ::testing::ValuesIn(kTlsV10V11Arr); +static const uint16_t kTlsV10ToV12Arr[] = {SSL_LIBRARY_VERSION_TLS_1_0, + SSL_LIBRARY_VERSION_TLS_1_1, + SSL_LIBRARY_VERSION_TLS_1_2}; +::testing::internal::ParamGenerator<uint16_t> TlsConnectTestBase::kTlsV10ToV12 = + ::testing::ValuesIn(kTlsV10ToV12Arr); +static const uint16_t kTlsV11V12Arr[] = {SSL_LIBRARY_VERSION_TLS_1_1, + SSL_LIBRARY_VERSION_TLS_1_2}; +::testing::internal::ParamGenerator<uint16_t> TlsConnectTestBase::kTlsV11V12 = + ::testing::ValuesIn(kTlsV11V12Arr); + +static const uint16_t kTlsV11PlusArr[] = { +#ifndef NSS_DISABLE_TLS_1_3 + SSL_LIBRARY_VERSION_TLS_1_3, +#endif + SSL_LIBRARY_VERSION_TLS_1_2, SSL_LIBRARY_VERSION_TLS_1_1}; +::testing::internal::ParamGenerator<uint16_t> TlsConnectTestBase::kTlsV11Plus = + ::testing::ValuesIn(kTlsV11PlusArr); +static const uint16_t kTlsV12PlusArr[] = { +#ifndef NSS_DISABLE_TLS_1_3 + SSL_LIBRARY_VERSION_TLS_1_3, +#endif + SSL_LIBRARY_VERSION_TLS_1_2}; +::testing::internal::ParamGenerator<uint16_t> TlsConnectTestBase::kTlsV12Plus = + ::testing::ValuesIn(kTlsV12PlusArr); +static const uint16_t kTlsV13Arr[] = {SSL_LIBRARY_VERSION_TLS_1_3}; +::testing::internal::ParamGenerator<uint16_t> TlsConnectTestBase::kTlsV13 = + ::testing::ValuesIn(kTlsV13Arr); +static const uint16_t kTlsVAllArr[] = { +#ifndef NSS_DISABLE_TLS_1_3 + SSL_LIBRARY_VERSION_TLS_1_3, +#endif + SSL_LIBRARY_VERSION_TLS_1_2, SSL_LIBRARY_VERSION_TLS_1_1, + SSL_LIBRARY_VERSION_TLS_1_0}; +::testing::internal::ParamGenerator<uint16_t> TlsConnectTestBase::kTlsVAll = + ::testing::ValuesIn(kTlsVAllArr); + +std::string VersionString(uint16_t version) { + switch (version) { + case 0: + return "(no version)"; + case SSL_LIBRARY_VERSION_3_0: + return "1.0"; + case SSL_LIBRARY_VERSION_TLS_1_0: + return "1.0"; + case SSL_LIBRARY_VERSION_TLS_1_1: + return "1.1"; + case SSL_LIBRARY_VERSION_TLS_1_2: + return "1.2"; + case SSL_LIBRARY_VERSION_TLS_1_3: + return "1.3"; + default: + std::cerr << "Invalid version: " << version << std::endl; + EXPECT_TRUE(false); + return ""; + } +} + +// The default anti-replay window for tests. Tests that rely on a different +// value call ResetAntiReplay directly. +static PRTime kAntiReplayWindow = 100 * PR_USEC_PER_SEC; + +TlsConnectTestBase::TlsConnectTestBase(SSLProtocolVariant variant, + uint16_t version) + : variant_(variant), + client_(new TlsAgent(TlsAgent::kClient, TlsAgent::CLIENT, variant_)), + server_(new TlsAgent(TlsAgent::kServerRsa, TlsAgent::SERVER, variant_)), + client_model_(nullptr), + server_model_(nullptr), + version_(version), + expected_resumption_mode_(RESUME_NONE), + expected_resumptions_(0), + session_ids_(), + expect_extended_master_secret_(false), + expect_early_data_accepted_(false), + skip_version_checks_(false) { + std::string v; + if (variant_ == ssl_variant_datagram && + version_ == SSL_LIBRARY_VERSION_TLS_1_1) { + v = "1.0"; + } else { + v = VersionString(version_); + } + std::cerr << "Version: " << variant_ << " " << v << std::endl; +} + +TlsConnectTestBase::~TlsConnectTestBase() {} + +// Check the group of each of the supported groups +void TlsConnectTestBase::CheckGroups( + const DataBuffer& groups, std::function<void(SSLNamedGroup)> check_group) { + DuplicateGroupChecker group_set; + uint32_t tmp = 0; + EXPECT_TRUE(groups.Read(0, 2, &tmp)); + EXPECT_EQ(groups.len() - 2, static_cast<size_t>(tmp)); + for (size_t i = 2; i < groups.len(); i += 2) { + EXPECT_TRUE(groups.Read(i, 2, &tmp)); + SSLNamedGroup group = static_cast<SSLNamedGroup>(tmp); + group_set.AddAndCheckGroup(group); + check_group(group); + } +} + +// Check the group of each of the shares +void TlsConnectTestBase::CheckShares( + const DataBuffer& shares, std::function<void(SSLNamedGroup)> check_group) { + DuplicateGroupChecker group_set; + uint32_t tmp = 0; + EXPECT_TRUE(shares.Read(0, 2, &tmp)); + EXPECT_EQ(shares.len() - 2, static_cast<size_t>(tmp)); + size_t i; + for (i = 2; i < shares.len(); i += 4 + tmp) { + ASSERT_TRUE(shares.Read(i, 2, &tmp)); + SSLNamedGroup group = static_cast<SSLNamedGroup>(tmp); + group_set.AddAndCheckGroup(group); + check_group(group); + ASSERT_TRUE(shares.Read(i + 2, 2, &tmp)); + } + EXPECT_EQ(shares.len(), i); +} + +void TlsConnectTestBase::CheckEpochs(uint16_t client_epoch, + uint16_t server_epoch) const { + client_->CheckEpochs(server_epoch, client_epoch); + server_->CheckEpochs(client_epoch, server_epoch); +} + +void TlsConnectTestBase::ClearStats() { + // Clear statistics. + SSL3Statistics* stats = SSL_GetStatistics(); + memset(stats, 0, sizeof(*stats)); +} + +void TlsConnectTestBase::ClearServerCache() { + SSL_ShutdownServerSessionIDCache(); + SSLInt_ClearSelfEncryptKey(); + SSL_ConfigServerSessionIDCache(1024, 0, 0, g_working_dir_path.c_str()); +} + +void TlsConnectTestBase::SaveAlgorithmPolicy() { + saved_policies_.clear(); + for (auto it = algorithms_.begin(); it != algorithms_.end(); ++it) { + uint32_t policy; + SECStatus rv = NSS_GetAlgorithmPolicy(*it, &policy); + ASSERT_EQ(SECSuccess, rv); + saved_policies_.push_back(std::make_tuple(*it, policy)); + } + saved_options_.clear(); + for (auto it : options_) { + int32_t option; + SECStatus rv = NSS_OptionGet(it, &option); + ASSERT_EQ(SECSuccess, rv); + saved_options_.push_back(std::make_tuple(it, option)); + } +} + +void TlsConnectTestBase::RestoreAlgorithmPolicy() { + for (auto it = saved_policies_.begin(); it != saved_policies_.end(); ++it) { + auto algorithm = std::get<0>(*it); + auto policy = std::get<1>(*it); + SECStatus rv = NSS_SetAlgorithmPolicy( + algorithm, policy, NSS_USE_POLICY_IN_SSL | NSS_USE_ALG_IN_SSL_KX); + ASSERT_EQ(SECSuccess, rv); + } + for (auto it = saved_options_.begin(); it != saved_options_.end(); ++it) { + auto option_id = std::get<0>(*it); + auto option = std::get<1>(*it); + SECStatus rv = NSS_OptionSet(option_id, option); + ASSERT_EQ(SECSuccess, rv); + } +} + +PRTime TlsConnectTestBase::TimeFunc(void* arg) { + return *reinterpret_cast<PRTime*>(arg); +} + +void TlsConnectTestBase::SetUp() { + SSL_ConfigServerSessionIDCache(1024, 0, 0, g_working_dir_path.c_str()); + SSLInt_ClearSelfEncryptKey(); + now_ = PR_Now(); + ResetAntiReplay(kAntiReplayWindow); + ClearStats(); + SaveAlgorithmPolicy(); + Init(); +} + +void TlsConnectTestBase::TearDown() { + client_ = nullptr; + server_ = nullptr; + + SSL_ClearSessionCache(); + SSLInt_ClearSelfEncryptKey(); + SSL_ShutdownServerSessionIDCache(); + RestoreAlgorithmPolicy(); +} + +void TlsConnectTestBase::Init() { + client_->SetPeer(server_); + server_->SetPeer(client_); + + if (version_) { + ConfigureVersion(version_); + } +} + +void TlsConnectTestBase::ResetAntiReplay(PRTime window) { + SSLAntiReplayContext* p_anti_replay = nullptr; + EXPECT_EQ(SECSuccess, + SSL_CreateAntiReplayContext(now_, window, 1, 3, &p_anti_replay)); + EXPECT_NE(nullptr, p_anti_replay); + anti_replay_.reset(p_anti_replay); +} + +ScopedSECItem TlsConnectTestBase::MakeEcKeyParams(SSLNamedGroup group) { + auto groupDef = ssl_LookupNamedGroup(group); + EXPECT_NE(nullptr, groupDef); + + auto oidData = SECOID_FindOIDByTag(groupDef->oidTag); + EXPECT_NE(nullptr, oidData); + ScopedSECItem params( + SECITEM_AllocItem(nullptr, nullptr, (2 + oidData->oid.len))); + EXPECT_TRUE(!!params); + params->data[0] = SEC_ASN1_OBJECT_ID; + params->data[1] = oidData->oid.len; + memcpy(params->data + 2, oidData->oid.data, oidData->oid.len); + return params; +} + +void TlsConnectTestBase::GenerateEchConfig( + HpkeKemId kem_id, const std::vector<HpkeSymmetricSuite>& cipher_suites, + const std::string& public_name, uint16_t max_name_len, DataBuffer& record, + ScopedSECKEYPublicKey& pubKey, ScopedSECKEYPrivateKey& privKey) { + bool gen_keys = !pubKey && !privKey; + + SECKEYPublicKey* pub = nullptr; + SECKEYPrivateKey* priv = nullptr; + + if (gen_keys) { + ScopedSECItem ecParams = MakeEcKeyParams(ssl_grp_ec_curve25519); + priv = SECKEY_CreateECPrivateKey(ecParams.get(), &pub, nullptr); + } else { + priv = privKey.get(); + pub = pubKey.get(); + } + ASSERT_NE(nullptr, priv); + PRUint8 encoded[1024]; + unsigned int encoded_len = 0; + SECStatus rv = SSL_EncodeEchConfigId( + 77, public_name.c_str(), max_name_len, kem_id, pub, cipher_suites.data(), + cipher_suites.size(), encoded, &encoded_len, sizeof(encoded)); + EXPECT_EQ(SECSuccess, rv); + EXPECT_GT(encoded_len, 0U); + + if (gen_keys) { + pubKey.reset(pub); + privKey.reset(priv); + } + record.Truncate(0); + record.Write(0, encoded, encoded_len); +} + +void TlsConnectTestBase::SetupEch(std::shared_ptr<TlsAgent>& client, + std::shared_ptr<TlsAgent>& server, + HpkeKemId kem_id, bool expect_ech, + bool set_client_config, + bool set_server_config, int max_name_len) { + EXPECT_TRUE(set_server_config || set_client_config); + ScopedSECKEYPublicKey pub; + ScopedSECKEYPrivateKey priv; + DataBuffer record; + static const std::vector<HpkeSymmetricSuite> kDefaultSuites = { + {HpkeKdfHkdfSha256, HpkeAeadChaCha20Poly1305}, + {HpkeKdfHkdfSha256, HpkeAeadAes128Gcm}}; + + GenerateEchConfig(kem_id, kDefaultSuites, "public.name", max_name_len, record, + pub, priv); + ASSERT_NE(0U, record.len()); + SECStatus rv; + if (set_server_config) { + rv = SSL_SetServerEchConfigs(server->ssl_fd(), pub.get(), priv.get(), + record.data(), record.len()); + ASSERT_EQ(SECSuccess, rv); + } + if (set_client_config) { + rv = SSL_SetClientEchConfigs(client->ssl_fd(), record.data(), record.len()); + ASSERT_EQ(SECSuccess, rv); + } + + /* Filter expect_ech, which typically defaults to true. Parameterized tests + * running DTLS or TLS < 1.3 should expect only a non-ECH result. */ + bool expect = expect_ech && variant_ != ssl_variant_datagram && + version_ >= SSL_LIBRARY_VERSION_TLS_1_3 && set_client_config && + set_server_config; + client->ExpectEch(expect); + server->ExpectEch(expect); +} + +void TlsConnectTestBase::Reset() { + // Take a copy of the names because they are about to disappear. + std::string server_name = server_->name(); + std::string client_name = client_->name(); + Reset(server_name, client_name); +} + +void TlsConnectTestBase::Reset(const std::string& server_name, + const std::string& client_name) { + auto token = client_->GetResumptionToken(); + client_.reset(new TlsAgent(client_name, TlsAgent::CLIENT, variant_)); + client_->SetResumptionToken(token); + server_.reset(new TlsAgent(server_name, TlsAgent::SERVER, variant_)); + if (skip_version_checks_) { + client_->SkipVersionChecks(); + server_->SkipVersionChecks(); + } + + std::cerr << "Reset server:" << server_name << ", client:" << client_name + << std::endl; + Init(); +} + +void TlsConnectTestBase::MakeNewServer() { + auto replacement = std::make_shared<TlsAgent>( + server_->name(), TlsAgent::SERVER, server_->variant()); + server_ = replacement; + if (version_) { + server_->SetVersionRange(version_, version_); + } + client_->SetPeer(server_); + server_->SetPeer(client_); + server_->StartConnect(); +} + +void TlsConnectTestBase::ExpectResumption(SessionResumptionMode expected, + uint8_t num_resumptions) { + expected_resumption_mode_ = expected; + if (expected != RESUME_NONE) { + client_->ExpectResumption(); + server_->ExpectResumption(); + expected_resumptions_ = num_resumptions; + } + EXPECT_EQ(expected_resumptions_ == 0, expected == RESUME_NONE); +} + +void TlsConnectTestBase::EnsureTlsSetup() { + EXPECT_TRUE(server_->EnsureTlsSetup( + server_model_ ? server_model_->ssl_fd().get() : nullptr)); + EXPECT_TRUE(client_->EnsureTlsSetup( + client_model_ ? client_model_->ssl_fd().get() : nullptr)); + server_->SetAntiReplayContext(anti_replay_); + EXPECT_EQ(SECSuccess, SSL_SetTimeFunc(client_->ssl_fd(), + TlsConnectTestBase::TimeFunc, &now_)); + EXPECT_EQ(SECSuccess, SSL_SetTimeFunc(server_->ssl_fd(), + TlsConnectTestBase::TimeFunc, &now_)); +} + +void TlsConnectTestBase::Handshake() { + client_->SetServerKeyBits(server_->server_key_bits()); + client_->Handshake(); + server_->Handshake(); + + ASSERT_TRUE_WAIT((client_->state() != TlsAgent::STATE_CONNECTING) && + (server_->state() != TlsAgent::STATE_CONNECTING), + 5000); +} + +void TlsConnectTestBase::EnableExtendedMasterSecret() { + client_->EnableExtendedMasterSecret(); + server_->EnableExtendedMasterSecret(); + ExpectExtendedMasterSecret(true); +} + +void TlsConnectTestBase::Connect() { + StartConnect(); + client_->MaybeSetResumptionToken(); + Handshake(); + CheckConnected(); +} + +void TlsConnectTestBase::StartConnect() { + EnsureTlsSetup(); + server_->StartConnect(); + client_->StartConnect(); +} + +void TlsConnectTestBase::ConnectWithCipherSuite(uint16_t cipher_suite) { + EnsureTlsSetup(); + client_->EnableSingleCipher(cipher_suite); + + Connect(); + SendReceive(); + + // Check that we used the right cipher suite. + uint16_t actual; + EXPECT_TRUE(client_->cipher_suite(&actual)); + EXPECT_EQ(cipher_suite, actual); + EXPECT_TRUE(server_->cipher_suite(&actual)); + EXPECT_EQ(cipher_suite, actual); +} + +void TlsConnectTestBase::CheckConnected() { + // Have the client read handshake twice to make sure we get the + // NST and the ACK. + if (client_->version() >= SSL_LIBRARY_VERSION_TLS_1_3 && + variant_ == ssl_variant_datagram) { + client_->Handshake(); + client_->Handshake(); + auto suites = SSLInt_CountCipherSpecs(client_->ssl_fd()); + // Verify that we dropped the client's retransmission cipher suites. + EXPECT_EQ(2, suites) << "Client has the wrong number of suites"; + if (suites != 2) { + SSLInt_PrintCipherSpecs("client", client_->ssl_fd()); + } + } + EXPECT_EQ(client_->version(), server_->version()); + if (!skip_version_checks_) { + // Check the version is as expected + EXPECT_EQ(std::min(client_->max_version(), server_->max_version()), + client_->version()); + } + + EXPECT_EQ(TlsAgent::STATE_CONNECTED, client_->state()); + EXPECT_EQ(TlsAgent::STATE_CONNECTED, server_->state()); + + uint16_t cipher_suite1, cipher_suite2; + ASSERT_TRUE(client_->cipher_suite(&cipher_suite1)); + ASSERT_TRUE(server_->cipher_suite(&cipher_suite2)); + EXPECT_EQ(cipher_suite1, cipher_suite2); + + std::cerr << "Connected with version " << client_->version() + << " cipher suite " << client_->cipher_suite_name() << std::endl; + + if (client_->version() < SSL_LIBRARY_VERSION_TLS_1_3) { + // Check and store session ids. + std::vector<uint8_t> sid_c1 = client_->session_id(); + EXPECT_EQ(32U, sid_c1.size()); + std::vector<uint8_t> sid_s1 = server_->session_id(); + EXPECT_EQ(32U, sid_s1.size()); + EXPECT_EQ(sid_c1, sid_s1); + session_ids_.push_back(sid_c1); + } + + CheckExtendedMasterSecret(); + CheckEarlyDataAccepted(); + CheckResumption(expected_resumption_mode_); + client_->CheckSecretsDestroyed(); + server_->CheckSecretsDestroyed(); +} + +void TlsConnectTestBase::CheckEarlyDataLimit( + const std::shared_ptr<TlsAgent>& agent, size_t expected_size) { + SSLPreliminaryChannelInfo preinfo; + SECStatus rv = + SSL_GetPreliminaryChannelInfo(agent->ssl_fd(), &preinfo, sizeof(preinfo)); + EXPECT_EQ(SECSuccess, rv); + EXPECT_EQ(expected_size, static_cast<size_t>(preinfo.maxEarlyDataSize)); +} + +void TlsConnectTestBase::CheckKeys(SSLKEAType kea_type, SSLNamedGroup kea_group, + SSLAuthType auth_type, + SSLSignatureScheme sig_scheme) const { + if (kea_group != ssl_grp_none) { + client_->CheckKEA(kea_type, kea_group); + server_->CheckKEA(kea_type, kea_group); + } + server_->CheckAuthType(auth_type, sig_scheme); + client_->CheckAuthType(auth_type, sig_scheme); +} + +void TlsConnectTestBase::CheckKeys(SSLKEAType kea_type, + SSLAuthType auth_type) const { + SSLNamedGroup group; + switch (kea_type) { + case ssl_kea_ecdh: + group = ssl_grp_ec_curve25519; + break; + case ssl_kea_dh: + group = ssl_grp_ffdhe_2048; + break; + case ssl_kea_rsa: + group = ssl_grp_none; + break; + default: + EXPECT_TRUE(false) << "unexpected KEA"; + group = ssl_grp_none; + break; + } + + SSLSignatureScheme scheme; + switch (auth_type) { + case ssl_auth_rsa_decrypt: + scheme = ssl_sig_none; + break; + case ssl_auth_rsa_sign: + if (version_ >= SSL_LIBRARY_VERSION_TLS_1_2) { + scheme = ssl_sig_rsa_pss_rsae_sha256; + } else { + scheme = ssl_sig_rsa_pkcs1_sha256; + } + break; + case ssl_auth_rsa_pss: + scheme = ssl_sig_rsa_pss_rsae_sha256; + break; + case ssl_auth_ecdsa: + scheme = ssl_sig_ecdsa_secp256r1_sha256; + break; + case ssl_auth_dsa: + scheme = ssl_sig_dsa_sha1; + break; + default: + EXPECT_TRUE(false) << "unexpected auth type"; + scheme = static_cast<SSLSignatureScheme>(0x0100); + break; + } + CheckKeys(kea_type, group, auth_type, scheme); +} + +void TlsConnectTestBase::CheckKeys() const { + CheckKeys(ssl_kea_ecdh, ssl_auth_rsa_sign); +} + +void TlsConnectTestBase::CheckKeysResumption(SSLKEAType kea_type, + SSLNamedGroup kea_group, + SSLNamedGroup original_kea_group, + SSLAuthType auth_type, + SSLSignatureScheme sig_scheme) { + CheckKeys(kea_type, kea_group, auth_type, sig_scheme); + EXPECT_TRUE(expected_resumption_mode_ != RESUME_NONE); + client_->CheckOriginalKEA(original_kea_group); + server_->CheckOriginalKEA(original_kea_group); +} + +void TlsConnectTestBase::ConnectExpectFail() { + StartConnect(); + Handshake(); + ASSERT_EQ(TlsAgent::STATE_ERROR, client_->state()); + ASSERT_EQ(TlsAgent::STATE_ERROR, server_->state()); +} + +void TlsConnectTestBase::ExpectAlert(std::shared_ptr<TlsAgent>& sender, + uint8_t alert) { + EnsureTlsSetup(); + auto receiver = (sender == client_) ? server_ : client_; + sender->ExpectSendAlert(alert); + receiver->ExpectReceiveAlert(alert); +} + +void TlsConnectTestBase::ConnectExpectAlert(std::shared_ptr<TlsAgent>& sender, + uint8_t alert) { + ExpectAlert(sender, alert); + ConnectExpectFail(); +} + +void TlsConnectTestBase::ConnectExpectFailOneSide(TlsAgent::Role failing_side) { + StartConnect(); + client_->SetServerKeyBits(server_->server_key_bits()); + client_->Handshake(); + server_->Handshake(); + + auto failing_agent = server_; + if (failing_side == TlsAgent::CLIENT) { + failing_agent = client_; + } + ASSERT_TRUE_WAIT(failing_agent->state() == TlsAgent::STATE_ERROR, 5000); +} + +void TlsConnectTestBase::ConfigureVersion(uint16_t version) { + version_ = version; + client_->SetVersionRange(version, version); + server_->SetVersionRange(version, version); +} + +void TlsConnectTestBase::SetExpectedVersion(uint16_t version) { + client_->SetExpectedVersion(version); + server_->SetExpectedVersion(version); +} + +void TlsConnectTestBase::AddPsk(const ScopedPK11SymKey& psk, std::string label, + SSLHashType hash, uint16_t zeroRttSuite) { + client_->AddPsk(psk, label, hash, zeroRttSuite); + server_->AddPsk(psk, label, hash, zeroRttSuite); + client_->ExpectPsk(); + server_->ExpectPsk(); +} + +void TlsConnectTestBase::DisableAllCiphers() { + EnsureTlsSetup(); + client_->DisableAllCiphers(); + server_->DisableAllCiphers(); +} + +void TlsConnectTestBase::EnableOnlyStaticRsaCiphers() { + DisableAllCiphers(); + + client_->EnableCiphersByKeyExchange(ssl_kea_rsa); + server_->EnableCiphersByKeyExchange(ssl_kea_rsa); +} + +void TlsConnectTestBase::EnableOnlyDheCiphers() { + if (version_ < SSL_LIBRARY_VERSION_TLS_1_3) { + DisableAllCiphers(); + client_->EnableCiphersByKeyExchange(ssl_kea_dh); + server_->EnableCiphersByKeyExchange(ssl_kea_dh); + } else { + client_->ConfigNamedGroups(kFFDHEGroups); + server_->ConfigNamedGroups(kFFDHEGroups); + } +} + +void TlsConnectTestBase::EnableSomeEcdhCiphers() { + if (version_ < SSL_LIBRARY_VERSION_TLS_1_3) { + client_->EnableCiphersByAuthType(ssl_auth_ecdh_rsa); + client_->EnableCiphersByAuthType(ssl_auth_ecdh_ecdsa); + server_->EnableCiphersByAuthType(ssl_auth_ecdh_rsa); + server_->EnableCiphersByAuthType(ssl_auth_ecdh_ecdsa); + } else { + client_->ConfigNamedGroups(kECDHEGroups); + server_->ConfigNamedGroups(kECDHEGroups); + } +} + +void TlsConnectTestBase::ConfigureSelfEncrypt() { + ScopedCERTCertificate cert; + ScopedSECKEYPrivateKey privKey; + ASSERT_TRUE( + TlsAgent::LoadCertificate(TlsAgent::kServerRsaDecrypt, &cert, &privKey)); + + ScopedSECKEYPublicKey pubKey(CERT_ExtractPublicKey(cert.get())); + ASSERT_TRUE(pubKey); + + EXPECT_EQ(SECSuccess, + SSL_SetSessionTicketKeyPair(pubKey.get(), privKey.get())); +} + +void TlsConnectTestBase::ConfigureSessionCache(SessionResumptionMode client, + SessionResumptionMode server) { + client_->ConfigureSessionCache(client); + server_->ConfigureSessionCache(server); + if ((server & RESUME_TICKET) != 0) { + ConfigureSelfEncrypt(); + } +} + +void TlsConnectTestBase::CheckResumption(SessionResumptionMode expected) { + EXPECT_NE(RESUME_BOTH, expected); + + int resume_count = expected ? expected_resumptions_ : 0; + int stateless_count = (expected & RESUME_TICKET) ? expected_resumptions_ : 0; + + // Note: hch == server counter; hsh == client counter. + SSL3Statistics* stats = SSL_GetStatistics(); + EXPECT_EQ(resume_count, stats->hch_sid_cache_hits); + EXPECT_EQ(resume_count, stats->hsh_sid_cache_hits); + + EXPECT_EQ(stateless_count, stats->hch_sid_stateless_resumes); + EXPECT_EQ(stateless_count, stats->hsh_sid_stateless_resumes); + + if (expected != RESUME_NONE) { + if (client_->version() < SSL_LIBRARY_VERSION_TLS_1_3 && + client_->GetResumptionToken().size() == 0) { + // Check that the last two session ids match. + ASSERT_EQ(1U + expected_resumptions_, session_ids_.size()); + EXPECT_EQ(session_ids_[session_ids_.size() - 1], + session_ids_[session_ids_.size() - 2]); + } else { + // We've either chosen TLS 1.3 or are using an external resumption token, + // both of which only use tickets. + EXPECT_TRUE(expected & RESUME_TICKET); + } + } +} + +static SECStatus NextProtoCallbackServer(void* arg, PRFileDesc* fd, + const unsigned char* protos, + unsigned int protos_len, + unsigned char* protoOut, + unsigned int* protoOutLen, + unsigned int protoMaxLen) { + EXPECT_EQ(protoMaxLen, 255U); + TlsAgent* agent = reinterpret_cast<TlsAgent*>(arg); + // Check that agent->alpn_value_to_use_ is in protos. + if (protos_len < 1) { + return SECFailure; + } + for (size_t i = 0; i < protos_len;) { + size_t l = protos[i]; + EXPECT_LT(i + l, protos_len); + if (i + l >= protos_len) { + return SECFailure; + } + std::string protos_s(reinterpret_cast<const char*>(protos + i + 1), l); + if (protos_s == agent->alpn_value_to_use_) { + size_t s_len = agent->alpn_value_to_use_.size(); + EXPECT_LE(s_len, 255U); + memcpy(protoOut, &agent->alpn_value_to_use_[0], s_len); + *protoOutLen = s_len; + return SECSuccess; + } + i += l + 1; + } + return SECFailure; +} + +void TlsConnectTestBase::EnableAlpn() { + client_->EnableAlpn(alpn_dummy_val_, sizeof(alpn_dummy_val_)); + server_->EnableAlpn(alpn_dummy_val_, sizeof(alpn_dummy_val_)); +} + +void TlsConnectTestBase::EnableAlpnWithCallback( + const std::vector<uint8_t>& client_vals, std::string server_choice) { + EnsureTlsSetup(); + server_->alpn_value_to_use_ = server_choice; + EXPECT_EQ(SECSuccess, + SSL_SetNextProtoNego(client_->ssl_fd(), client_vals.data(), + client_vals.size())); + SECStatus rv = SSL_SetNextProtoCallback( + server_->ssl_fd(), NextProtoCallbackServer, server_.get()); + EXPECT_EQ(SECSuccess, rv); +} + +void TlsConnectTestBase::EnableAlpn(const std::vector<uint8_t>& vals) { + client_->EnableAlpn(vals.data(), vals.size()); + server_->EnableAlpn(vals.data(), vals.size()); +} + +void TlsConnectTestBase::EnsureModelSockets() { + // Make sure models agents are available. + if (!client_model_) { + ASSERT_EQ(server_model_, nullptr); + client_model_.reset( + new TlsAgent(TlsAgent::kClient, TlsAgent::CLIENT, variant_)); + server_model_.reset( + new TlsAgent(TlsAgent::kServerRsa, TlsAgent::SERVER, variant_)); + if (skip_version_checks_) { + client_model_->SkipVersionChecks(); + server_model_->SkipVersionChecks(); + } + } +} + +void TlsConnectTestBase::CheckAlpn(const std::string& val) { + client_->CheckAlpn(SSL_NEXT_PROTO_SELECTED, val); + server_->CheckAlpn(SSL_NEXT_PROTO_NEGOTIATED, val); +} + +void TlsConnectTestBase::EnableSrtp() { + client_->EnableSrtp(); + server_->EnableSrtp(); +} + +void TlsConnectTestBase::CheckSrtp() const { + client_->CheckSrtp(); + server_->CheckSrtp(); +} + +void TlsConnectTestBase::SendReceive(size_t total) { + ASSERT_GT(total, client_->received_bytes()); + ASSERT_GT(total, server_->received_bytes()); + client_->SendData(total - server_->received_bytes()); + server_->SendData(total - client_->received_bytes()); + Receive(total); // Receive() is cumulative +} + +// Do a first connection so we can do 0-RTT on the second one. +void TlsConnectTestBase::SetupForZeroRtt() { + // Force rollover of the anti-replay window. + // If we don't do this, then all 0-RTT attempts will be rejected. + RolloverAntiReplay(); + + ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET); + ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3); + server_->Set0RttEnabled(true); // So we signal that we allow 0-RTT. + Connect(); + SendReceive(); // Need to read so that we absorb the session ticket. + CheckKeys(); + + Reset(); + StartConnect(); +} + +// Do a first connection so we can do resumption +void TlsConnectTestBase::SetupForResume() { + EnsureTlsSetup(); + ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET); + Connect(); + SendReceive(); // Need to read so that we absorb the session ticket. + CheckKeys(); + + Reset(); +} + +void TlsConnectTestBase::ZeroRttSendReceive( + bool expect_writable, bool expect_readable, + std::function<bool()> post_clienthello_check) { + const char* k0RttData = "ABCDEF"; + const PRInt32 k0RttDataLen = static_cast<PRInt32>(strlen(k0RttData)); + + client_->Handshake(); // Send ClientHello. + if (post_clienthello_check) { + if (!post_clienthello_check()) return; + } + PRInt32 rv = + PR_Write(client_->ssl_fd(), k0RttData, k0RttDataLen); // 0-RTT write. + if (expect_writable) { + EXPECT_EQ(k0RttDataLen, rv); + } else { + EXPECT_EQ(SECFailure, rv); + } + server_->Handshake(); // Consume ClientHello + + std::vector<uint8_t> buf(k0RttDataLen); + rv = PR_Read(server_->ssl_fd(), buf.data(), k0RttDataLen); // 0-RTT read + if (expect_readable) { + std::cerr << "0-RTT read " << rv << " bytes\n"; + EXPECT_EQ(k0RttDataLen, rv); + } else { + EXPECT_EQ(SECFailure, rv); + EXPECT_EQ(PR_WOULD_BLOCK_ERROR, PORT_GetError()) + << "Unexpected error: " << PORT_ErrorToName(PORT_GetError()); + } + + // Do a second read. This should fail. + rv = PR_Read(server_->ssl_fd(), buf.data(), k0RttDataLen); + EXPECT_EQ(SECFailure, rv); + EXPECT_EQ(PR_WOULD_BLOCK_ERROR, PORT_GetError()); +} + +void TlsConnectTestBase::Receive(size_t amount) { + WAIT_(client_->received_bytes() == amount && + server_->received_bytes() == amount, + 2000); + ASSERT_EQ(amount, client_->received_bytes()); + ASSERT_EQ(amount, server_->received_bytes()); +} + +void TlsConnectTestBase::ExpectExtendedMasterSecret(bool expected) { + expect_extended_master_secret_ = expected; +} + +void TlsConnectTestBase::CheckExtendedMasterSecret() { + client_->CheckExtendedMasterSecret(expect_extended_master_secret_); + server_->CheckExtendedMasterSecret(expect_extended_master_secret_); +} + +void TlsConnectTestBase::ExpectEarlyDataAccepted(bool expected) { + expect_early_data_accepted_ = expected; +} + +void TlsConnectTestBase::CheckEarlyDataAccepted() { + client_->CheckEarlyDataAccepted(expect_early_data_accepted_); + server_->CheckEarlyDataAccepted(expect_early_data_accepted_); +} + +void TlsConnectTestBase::EnableECDHEServerKeyReuse() { + server_->EnableECDHEServerKeyReuse(); +} + +void TlsConnectTestBase::SkipVersionChecks() { + skip_version_checks_ = true; + client_->SkipVersionChecks(); + server_->SkipVersionChecks(); +} + +// Shift the DTLS timers, to the minimum time necessary to let the next timer +// run on either client or server. This allows tests to skip waiting without +// having timers run out of order. +void TlsConnectTestBase::ShiftDtlsTimers() { + PRIntervalTime time_shift = PR_INTERVAL_NO_TIMEOUT; + PRIntervalTime time; + SECStatus rv = DTLS_GetHandshakeTimeout(client_->ssl_fd(), &time); + if (rv == SECSuccess) { + time_shift = time; + } + rv = DTLS_GetHandshakeTimeout(server_->ssl_fd(), &time); + if (rv == SECSuccess && + (time < time_shift || time_shift == PR_INTERVAL_NO_TIMEOUT)) { + time_shift = time; + } + + if (time_shift != PR_INTERVAL_NO_TIMEOUT) { + AdvanceTime(PR_IntervalToMicroseconds(time_shift)); + EXPECT_EQ(SECSuccess, + SSLInt_ShiftDtlsTimers(client_->ssl_fd(), time_shift)); + EXPECT_EQ(SECSuccess, + SSLInt_ShiftDtlsTimers(server_->ssl_fd(), time_shift)); + } +} + +void TlsConnectTestBase::AdvanceTime(PRTime time_shift) { now_ += time_shift; } + +// Advance time by a full anti-replay window. +void TlsConnectTestBase::RolloverAntiReplay() { + AdvanceTime(kAntiReplayWindow); +} + +TlsConnectGeneric::TlsConnectGeneric() + : TlsConnectTestBase(std::get<0>(GetParam()), std::get<1>(GetParam())) {} + +TlsConnectPre12::TlsConnectPre12() + : TlsConnectTestBase(std::get<0>(GetParam()), std::get<1>(GetParam())) {} + +TlsConnectTls12::TlsConnectTls12() + : TlsConnectTestBase(GetParam(), SSL_LIBRARY_VERSION_TLS_1_2) {} + +TlsConnectTls12Plus::TlsConnectTls12Plus() + : TlsConnectTestBase(std::get<0>(GetParam()), std::get<1>(GetParam())) {} + +TlsConnectTls13::TlsConnectTls13() + : TlsConnectTestBase(GetParam(), SSL_LIBRARY_VERSION_TLS_1_3) {} + +TlsConnectGenericResumption::TlsConnectGenericResumption() + : TlsConnectTestBase(std::get<0>(GetParam()), std::get<1>(GetParam())), + external_cache_(std::get<2>(GetParam())) {} + +TlsConnectTls13ResumptionToken::TlsConnectTls13ResumptionToken() + : TlsConnectTestBase(GetParam(), SSL_LIBRARY_VERSION_TLS_1_3) {} + +TlsConnectGenericResumptionToken::TlsConnectGenericResumptionToken() + : TlsConnectTestBase(std::get<0>(GetParam()), std::get<1>(GetParam())) {} + +void TlsKeyExchangeTest::EnsureKeyShareSetup() { + EnsureTlsSetup(); + groups_capture_ = + std::make_shared<TlsExtensionCapture>(client_, ssl_supported_groups_xtn); + shares_capture_ = + std::make_shared<TlsExtensionCapture>(client_, ssl_tls13_key_share_xtn); + shares_capture2_ = std::make_shared<TlsExtensionCapture>( + client_, ssl_tls13_key_share_xtn, true); + std::vector<std::shared_ptr<PacketFilter>> captures = { + groups_capture_, shares_capture_, shares_capture2_}; + client_->SetFilter(std::make_shared<ChainedPacketFilter>(captures)); + capture_hrr_ = MakeTlsFilter<TlsHandshakeRecorder>( + server_, kTlsHandshakeHelloRetryRequest); +} + +void TlsKeyExchangeTest::ConfigNamedGroups( + const std::vector<SSLNamedGroup>& groups) { + client_->ConfigNamedGroups(groups); + server_->ConfigNamedGroups(groups); +} + +std::vector<SSLNamedGroup> TlsKeyExchangeTest::GetGroupDetails( + const std::shared_ptr<TlsExtensionCapture>& capture) { + EXPECT_TRUE(capture->captured()); + const DataBuffer& ext = capture->extension(); + + uint32_t tmp = 0; + EXPECT_TRUE(ext.Read(0, 2, &tmp)); + EXPECT_EQ(ext.len() - 2, static_cast<size_t>(tmp)); + EXPECT_TRUE(ext.len() % 2 == 0); + + std::vector<SSLNamedGroup> groups; + for (size_t i = 1; i < ext.len() / 2; i += 1) { + EXPECT_TRUE(ext.Read(2 * i, 2, &tmp)); + groups.push_back(static_cast<SSLNamedGroup>(tmp)); + } + return groups; +} + +std::vector<SSLNamedGroup> TlsKeyExchangeTest::GetShareDetails( + const std::shared_ptr<TlsExtensionCapture>& capture) { + EXPECT_TRUE(capture->captured()); + const DataBuffer& ext = capture->extension(); + + uint32_t tmp = 0; + EXPECT_TRUE(ext.Read(0, 2, &tmp)); + EXPECT_EQ(ext.len() - 2, static_cast<size_t>(tmp)); + + std::vector<SSLNamedGroup> shares; + size_t i = 2; + while (i < ext.len()) { + EXPECT_TRUE(ext.Read(i, 2, &tmp)); + shares.push_back(static_cast<SSLNamedGroup>(tmp)); + EXPECT_TRUE(ext.Read(i + 2, 2, &tmp)); + i += 4 + tmp; + } + EXPECT_EQ(ext.len(), i); + return shares; +} + +void TlsKeyExchangeTest::CheckKEXDetails( + const std::vector<SSLNamedGroup>& expected_groups, + const std::vector<SSLNamedGroup>& expected_shares, bool expect_hrr) { + std::vector<SSLNamedGroup> groups = GetGroupDetails(groups_capture_); + EXPECT_EQ(expected_groups, groups); + + if (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) { + ASSERT_LT(0U, expected_shares.size()); + std::vector<SSLNamedGroup> shares = GetShareDetails(shares_capture_); + EXPECT_EQ(expected_shares, shares); + } else { + EXPECT_FALSE(shares_capture_->captured()); + } + + EXPECT_EQ(expect_hrr, capture_hrr_->buffer().len() != 0); +} + +void TlsKeyExchangeTest::CheckKEXDetails( + const std::vector<SSLNamedGroup>& expected_groups, + const std::vector<SSLNamedGroup>& expected_shares) { + CheckKEXDetails(expected_groups, expected_shares, false); +} + +void TlsKeyExchangeTest::CheckKEXDetails( + const std::vector<SSLNamedGroup>& expected_groups, + const std::vector<SSLNamedGroup>& expected_shares, + SSLNamedGroup expected_share2) { + CheckKEXDetails(expected_groups, expected_shares, true); + + for (auto it : expected_shares) { + EXPECT_NE(expected_share2, it); + } + std::vector<SSLNamedGroup> expected_shares2 = {expected_share2}; + EXPECT_EQ(expected_shares2, GetShareDetails(shares_capture2_)); +} +} // namespace nss_test diff --git a/security/nss/gtests/ssl_gtest/tls_connect.h b/security/nss/gtests/ssl_gtest/tls_connect.h new file mode 100644 index 0000000000..6a4795f83e --- /dev/null +++ b/security/nss/gtests/ssl_gtest/tls_connect.h @@ -0,0 +1,390 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set ts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this file, + * You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#ifndef tls_connect_h_ +#define tls_connect_h_ + +#include <tuple> + +#include "sslproto.h" +#include "sslt.h" +#include "nss.h" + +#include "tls_agent.h" +#include "tls_filter.h" + +#define GTEST_HAS_RTTI 0 +#include "gtest/gtest.h" + +namespace nss_test { + +extern std::string VersionString(uint16_t version); + +// A generic TLS connection test base. +class TlsConnectTestBase : public ::testing::Test { + public: + static ::testing::internal::ParamGenerator<SSLProtocolVariant> + kTlsVariantsStream; + static ::testing::internal::ParamGenerator<SSLProtocolVariant> + kTlsVariantsDatagram; + static ::testing::internal::ParamGenerator<SSLProtocolVariant> + kTlsVariantsAll; + static ::testing::internal::ParamGenerator<uint16_t> kTlsV10; + static ::testing::internal::ParamGenerator<uint16_t> kTlsV11; + static ::testing::internal::ParamGenerator<uint16_t> kTlsV12; + static ::testing::internal::ParamGenerator<uint16_t> kTlsV10V11; + static ::testing::internal::ParamGenerator<uint16_t> kTlsV11V12; + static ::testing::internal::ParamGenerator<uint16_t> kTlsV10ToV12; + static ::testing::internal::ParamGenerator<uint16_t> kTlsV13; + static ::testing::internal::ParamGenerator<uint16_t> kTlsV11Plus; + static ::testing::internal::ParamGenerator<uint16_t> kTlsV12Plus; + static ::testing::internal::ParamGenerator<uint16_t> kTlsVAll; + + TlsConnectTestBase(SSLProtocolVariant variant, uint16_t version); + virtual ~TlsConnectTestBase(); + + virtual void SetUp(); + virtual void TearDown(); + + PRTime now() const { return now_; } + + // Initialize client and server. + void Init(); + // Clear the statistics. + void ClearStats(); + // Clear the server session cache. + void ClearServerCache(); + // Make sure TLS is configured for a connection. + virtual void EnsureTlsSetup(); + // Reset and keep the same certificate names + void Reset(); + // Reset, and update the certificate names on both peers + void Reset(const std::string& server_name, + const std::string& client_name = "client"); + // Replace the server. + void MakeNewServer(); + + // Set up + void StartConnect(); + // Run the handshake. + void Handshake(); + // Connect and check that it works. + void Connect(); + // Check that the connection was successfully established. + void CheckConnected(); + // Connect and expect it to fail. + void ConnectExpectFail(); + void ExpectAlert(std::shared_ptr<TlsAgent>& sender, uint8_t alert); + void ConnectExpectAlert(std::shared_ptr<TlsAgent>& sender, uint8_t alert); + void ConnectExpectFailOneSide(TlsAgent::Role failingSide); + void ConnectWithCipherSuite(uint16_t cipher_suite); + void CheckEarlyDataLimit(const std::shared_ptr<TlsAgent>& agent, + size_t expected_size); + // Check that the keys used in the handshake match expectations. + void CheckKeys(SSLKEAType kea_type, SSLNamedGroup kea_group, + SSLAuthType auth_type, SSLSignatureScheme sig_scheme) const; + // This version guesses some of the values. + void CheckKeys(SSLKEAType kea_type, SSLAuthType auth_type) const; + // This version assumes defaults. + void CheckKeys() const; + // Check that keys on resumed sessions. + void CheckKeysResumption(SSLKEAType kea_type, SSLNamedGroup kea_group, + SSLNamedGroup original_kea_group, + SSLAuthType auth_type, + SSLSignatureScheme sig_scheme); + void CheckGroups(const DataBuffer& groups, + std::function<void(SSLNamedGroup)> check_group); + void CheckShares(const DataBuffer& shares, + std::function<void(SSLNamedGroup)> check_group); + void CheckEpochs(uint16_t client_epoch, uint16_t server_epoch) const; + + void ConfigureVersion(uint16_t version); + void SetExpectedVersion(uint16_t version); + // Expect resumption of a particular type. + void ExpectResumption(SessionResumptionMode expected, + uint8_t num_resumed = 1); + void DisableAllCiphers(); + void EnableOnlyStaticRsaCiphers(); + void EnableOnlyDheCiphers(); + void EnableSomeEcdhCiphers(); + void EnableExtendedMasterSecret(); + void ConfigureSelfEncrypt(); + void ConfigureSessionCache(SessionResumptionMode client, + SessionResumptionMode server); + void EnableAlpn(); + void EnableAlpnWithCallback(const std::vector<uint8_t>& client, + std::string server_choice); + void EnableAlpn(const std::vector<uint8_t>& vals); + void EnsureModelSockets(); + void CheckAlpn(const std::string& val); + void EnableSrtp(); + void CheckSrtp() const; + void SendReceive(size_t total = 50); + void AddPsk(const ScopedPK11SymKey& psk, std::string label, SSLHashType hash, + uint16_t zeroRttSuite = TLS_NULL_WITH_NULL_NULL); + void RemovePsk(std::string label); + void SetupForZeroRtt(); + void SetupForResume(); + void ZeroRttSendReceive( + bool expect_writable, bool expect_readable, + std::function<bool()> post_clienthello_check = nullptr); + void Receive(size_t amount); + void ExpectExtendedMasterSecret(bool expected); + void ExpectEarlyDataAccepted(bool expected); + void EnableECDHEServerKeyReuse(); + void SkipVersionChecks(); + + // Move the DTLS timers for both endpoints to pop the next timer. + void ShiftDtlsTimers(); + void AdvanceTime(PRTime time_shift); + + void ResetAntiReplay(PRTime window); + void RolloverAntiReplay(); + + void SaveAlgorithmPolicy(); + void RestoreAlgorithmPolicy(); + + static ScopedSECItem MakeEcKeyParams(SSLNamedGroup group); + static void GenerateEchConfig( + HpkeKemId kem_id, const std::vector<HpkeSymmetricSuite>& cipher_suites, + const std::string& public_name, uint16_t max_name_len, DataBuffer& record, + ScopedSECKEYPublicKey& pubKey, ScopedSECKEYPrivateKey& privKey); + void SetupEch(std::shared_ptr<TlsAgent>& client, + std::shared_ptr<TlsAgent>& server, + HpkeKemId kem_id = HpkeDhKemX25519Sha256, + bool expect_ech = true, bool set_client_config = true, + bool set_server_config = true, int maxConfigSize = 100); + + protected: + SSLProtocolVariant variant_; + std::shared_ptr<TlsAgent> client_; + std::shared_ptr<TlsAgent> server_; + std::unique_ptr<TlsAgent> client_model_; + std::unique_ptr<TlsAgent> server_model_; + uint16_t version_; + SessionResumptionMode expected_resumption_mode_; + uint8_t expected_resumptions_; + std::vector<std::vector<uint8_t>> session_ids_; + ScopedSSLAntiReplayContext anti_replay_; + + // A simple value of "a", "b". Note that the preferred value of "a" is placed + // at the end, because the NSS API follows the now defunct NPN specification, + // which places the preferred (and default) entry at the end of the list. + // NSS will move this final entry to the front when used with ALPN. + const uint8_t alpn_dummy_val_[4] = {0x01, 0x62, 0x01, 0x61}; + + // A list of algorithm IDs whose policies need to be preserved + // around test cases. In particular, DSA is checked in + // ssl_extension_unittest.cc. + const std::vector<SECOidTag> algorithms_ = {SEC_OID_APPLY_SSL_POLICY, + SEC_OID_ANSIX9_DSA_SIGNATURE, + SEC_OID_CURVE25519, SEC_OID_SHA1}; + std::vector<std::tuple<SECOidTag, uint32_t>> saved_policies_; + const std::vector<PRInt32> options_ = { + NSS_RSA_MIN_KEY_SIZE, NSS_DH_MIN_KEY_SIZE, NSS_DSA_MIN_KEY_SIZE, + NSS_TLS_VERSION_MIN_POLICY, NSS_TLS_VERSION_MAX_POLICY}; + std::vector<std::tuple<PRInt32, uint32_t>> saved_options_; + + private: + void CheckResumption(SessionResumptionMode expected); + void CheckExtendedMasterSecret(); + void CheckEarlyDataAccepted(); + static PRTime TimeFunc(void* arg); + + bool expect_extended_master_secret_; + bool expect_early_data_accepted_; + bool skip_version_checks_; + PRTime now_; + + // Track groups and make sure that there are no duplicates. + class DuplicateGroupChecker { + public: + void AddAndCheckGroup(SSLNamedGroup group) { + EXPECT_EQ(groups_.end(), groups_.find(group)) + << "Group " << group << " should not be duplicated"; + groups_.insert(group); + } + + private: + std::set<SSLNamedGroup> groups_; + }; +}; + +// A non-parametrized TLS test base. +class TlsConnectTest : public TlsConnectTestBase { + public: + TlsConnectTest() : TlsConnectTestBase(ssl_variant_stream, 0) {} +}; + +// A non-parametrized DTLS-only test base. +class DtlsConnectTest : public TlsConnectTestBase { + public: + DtlsConnectTest() : TlsConnectTestBase(ssl_variant_datagram, 0) {} +}; + +// A TLS-only test base. +class TlsConnectStream : public TlsConnectTestBase, + public ::testing::WithParamInterface<uint16_t> { + public: + TlsConnectStream() : TlsConnectTestBase(ssl_variant_stream, GetParam()) {} +}; + +// A TLS-only test base for tests before 1.3 +class TlsConnectStreamPre13 : public TlsConnectStream {}; + +// A DTLS-only test base. +class TlsConnectDatagram : public TlsConnectTestBase, + public ::testing::WithParamInterface<uint16_t> { + public: + TlsConnectDatagram() : TlsConnectTestBase(ssl_variant_datagram, GetParam()) {} +}; + +// A generic test class that can be either stream or datagram and a single +// version of TLS. This is configured in ssl_loopback_unittest.cc. +class TlsConnectGeneric : public TlsConnectTestBase, + public ::testing::WithParamInterface< + std::tuple<SSLProtocolVariant, uint16_t>> { + public: + TlsConnectGeneric(); +}; + +class TlsConnectGenericResumption + : public TlsConnectTestBase, + public ::testing::WithParamInterface< + std::tuple<SSLProtocolVariant, uint16_t, bool>> { + private: + bool external_cache_; + + public: + TlsConnectGenericResumption(); + + virtual void EnsureTlsSetup() { + TlsConnectTestBase::EnsureTlsSetup(); + // Enable external resumption token cache. + if (external_cache_) { + client_->SetResumptionTokenCallback(); + } + } + + bool use_external_cache() const { return external_cache_; } +}; + +class TlsConnectTls13ResumptionToken + : public TlsConnectTestBase, + public ::testing::WithParamInterface<SSLProtocolVariant> { + public: + TlsConnectTls13ResumptionToken(); + + virtual void EnsureTlsSetup() { + TlsConnectTestBase::EnsureTlsSetup(); + client_->SetResumptionTokenCallback(); + } +}; + +class TlsConnectGenericResumptionToken + : public TlsConnectTestBase, + public ::testing::WithParamInterface< + std::tuple<SSLProtocolVariant, uint16_t>> { + public: + TlsConnectGenericResumptionToken(); + + virtual void EnsureTlsSetup() { + TlsConnectTestBase::EnsureTlsSetup(); + client_->SetResumptionTokenCallback(); + } +}; + +// A Pre TLS 1.2 generic test. +class TlsConnectPre12 : public TlsConnectTestBase, + public ::testing::WithParamInterface< + std::tuple<SSLProtocolVariant, uint16_t>> { + public: + TlsConnectPre12(); +}; + +// A TLS 1.2 only generic test. +class TlsConnectTls12 + : public TlsConnectTestBase, + public ::testing::WithParamInterface<SSLProtocolVariant> { + public: + TlsConnectTls12(); +}; + +// A TLS 1.2 only stream test. +class TlsConnectStreamTls12 : public TlsConnectTestBase { + public: + TlsConnectStreamTls12() + : TlsConnectTestBase(ssl_variant_stream, SSL_LIBRARY_VERSION_TLS_1_2) {} +}; + +// A TLS 1.2+ generic test. +class TlsConnectTls12Plus : public TlsConnectTestBase, + public ::testing::WithParamInterface< + std::tuple<SSLProtocolVariant, uint16_t>> { + public: + TlsConnectTls12Plus(); +}; + +// A TLS 1.3 only generic test. +class TlsConnectTls13 + : public TlsConnectTestBase, + public ::testing::WithParamInterface<SSLProtocolVariant> { + public: + TlsConnectTls13(); +}; + +// A TLS 1.3 only stream test. +class TlsConnectStreamTls13 : public TlsConnectTestBase { + public: + TlsConnectStreamTls13() + : TlsConnectTestBase(ssl_variant_stream, SSL_LIBRARY_VERSION_TLS_1_3) {} +}; + +class TlsConnectDatagram13 : public TlsConnectTestBase { + public: + TlsConnectDatagram13() + : TlsConnectTestBase(ssl_variant_datagram, SSL_LIBRARY_VERSION_TLS_1_3) {} +}; + +class TlsConnectDatagramPre13 : public TlsConnectDatagram { + public: + TlsConnectDatagramPre13() {} +}; + +// A variant that is used only with Pre13. +class TlsConnectGenericPre13 : public TlsConnectGeneric {}; + +class TlsKeyExchangeTest : public TlsConnectGeneric { + protected: + std::shared_ptr<TlsExtensionCapture> groups_capture_; + std::shared_ptr<TlsExtensionCapture> shares_capture_; + std::shared_ptr<TlsExtensionCapture> shares_capture2_; + std::shared_ptr<TlsHandshakeRecorder> capture_hrr_; + + void EnsureKeyShareSetup(); + void ConfigNamedGroups(const std::vector<SSLNamedGroup>& groups); + std::vector<SSLNamedGroup> GetGroupDetails( + const std::shared_ptr<TlsExtensionCapture>& capture); + std::vector<SSLNamedGroup> GetShareDetails( + const std::shared_ptr<TlsExtensionCapture>& capture); + void CheckKEXDetails(const std::vector<SSLNamedGroup>& expectedGroups, + const std::vector<SSLNamedGroup>& expectedShares); + void CheckKEXDetails(const std::vector<SSLNamedGroup>& expectedGroups, + const std::vector<SSLNamedGroup>& expectedShares, + SSLNamedGroup expectedShare2); + + private: + void CheckKEXDetails(const std::vector<SSLNamedGroup>& expectedGroups, + const std::vector<SSLNamedGroup>& expectedShares, + bool expect_hrr); +}; + +class TlsKeyExchangeTest13 : public TlsKeyExchangeTest {}; +class TlsKeyExchangeTestPre13 : public TlsKeyExchangeTest {}; + +} // namespace nss_test + +#endif diff --git a/security/nss/gtests/ssl_gtest/tls_ech_unittest.cc b/security/nss/gtests/ssl_gtest/tls_ech_unittest.cc new file mode 100644 index 0000000000..1e70a6ee59 --- /dev/null +++ b/security/nss/gtests/ssl_gtest/tls_ech_unittest.cc @@ -0,0 +1,2913 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set ts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this file, + * You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#include "secerr.h" +#include "ssl.h" + +#include "gtest_utils.h" +#include "pk11pub.h" +#include "tls_agent.h" +#include "tls_connect.h" +#include "util.h" +#include "tls13ech.h" + +namespace nss_test { + +class TlsAgentEchTest : public TlsAgentTestClient13 { + protected: + void InstallEchConfig(const DataBuffer& echconfig, PRErrorCode err = 0) { + SECStatus rv = SSL_SetClientEchConfigs(agent_->ssl_fd(), echconfig.data(), + echconfig.len()); + if (err == 0) { + ASSERT_EQ(SECSuccess, rv); + } else { + ASSERT_EQ(SECFailure, rv); + ASSERT_EQ(err, PORT_GetError()); + } + } +}; + +#include "cpputil.h" // Unused function error if included without HPKE. + +static std::string kPublicName("public.name"); + +static const std::vector<HpkeSymmetricSuite> kDefaultSuites = { + {HpkeKdfHkdfSha256, HpkeAeadChaCha20Poly1305}, + {HpkeKdfHkdfSha256, HpkeAeadAes128Gcm}}; +static const std::vector<HpkeSymmetricSuite> kSuiteChaCha = { + {HpkeKdfHkdfSha256, HpkeAeadChaCha20Poly1305}}; +static const std::vector<HpkeSymmetricSuite> kSuiteAes = { + {HpkeKdfHkdfSha256, HpkeAeadAes128Gcm}}; +std::vector<HpkeSymmetricSuite> kBogusSuite = { + {static_cast<HpkeKdfId>(0xfefe), static_cast<HpkeAeadId>(0xfefe)}}; +static const std::vector<HpkeSymmetricSuite> kUnknownFirstSuite = { + {static_cast<HpkeKdfId>(0xfefe), static_cast<HpkeAeadId>(0xfefe)}, + {HpkeKdfHkdfSha256, HpkeAeadAes128Gcm}}; + +class TlsConnectStreamTls13Ech : public TlsConnectTestBase { + public: + TlsConnectStreamTls13Ech() + : TlsConnectTestBase(ssl_variant_stream, SSL_LIBRARY_VERSION_TLS_1_3) {} + + void ReplayChWithMalformedInner(const std::string& ch, uint8_t server_alert, + uint32_t server_code, uint32_t client_code) { + std::vector<uint8_t> ch_vec = hex_string_to_bytes(ch); + DataBuffer ch_buf; + ScopedSECKEYPublicKey pub; + ScopedSECKEYPrivateKey priv; + EnsureTlsSetup(); + ImportFixedEchKeypair(pub, priv); + SetMutualEchConfigs(pub, priv); + + TlsAgentTestBase::MakeRecord(variant_, ssl_ct_handshake, + SSL_LIBRARY_VERSION_TLS_1_3, ch_vec.data(), + ch_vec.size(), &ch_buf, 0); + StartConnect(); + client_->SendDirect(ch_buf); + ExpectAlert(server_, server_alert); + server_->Handshake(); + server_->CheckErrorCode(server_code); + client_->ExpectReceiveAlert(server_alert, kTlsAlertFatal); + client_->Handshake(); + client_->CheckErrorCode(client_code); + } + + // Setup Client/Server with mismatched AEADs + void SetupForEchRetry() { + ScopedSECKEYPublicKey server_pub; + ScopedSECKEYPrivateKey server_priv; + ScopedSECKEYPublicKey client_pub; + ScopedSECKEYPrivateKey client_priv; + DataBuffer server_rec; + DataBuffer client_rec; + TlsConnectTestBase::GenerateEchConfig(HpkeDhKemX25519Sha256, kSuiteChaCha, + kPublicName, 100, server_rec, + server_pub, server_priv); + ASSERT_EQ(SECSuccess, + SSL_SetServerEchConfigs(server_->ssl_fd(), server_pub.get(), + server_priv.get(), server_rec.data(), + server_rec.len())); + + TlsConnectTestBase::GenerateEchConfig(HpkeDhKemX25519Sha256, kSuiteAes, + kPublicName, 100, client_rec, + client_pub, client_priv); + ASSERT_EQ(SECSuccess, + SSL_SetClientEchConfigs(client_->ssl_fd(), client_rec.data(), + client_rec.len())); + } + + // Parse a captured SNI extension and validate the contained name. + void CheckSniExtension(const DataBuffer& data, + const std::string expected_name) { + TlsParser parser(data.data(), data.len()); + uint32_t tmp; + ASSERT_TRUE(parser.Read(&tmp, 2)); + ASSERT_EQ(parser.remaining(), tmp); + ASSERT_TRUE(parser.Read(&tmp, 1)); + ASSERT_EQ(0U, tmp); /* sni_nametype_hostname */ + DataBuffer name; + ASSERT_TRUE(parser.ReadVariable(&name, 2)); + ASSERT_EQ(0U, parser.remaining()); + // Manual comparison to silence coverity false-positives. + ASSERT_EQ(name.len(), kPublicName.length()); + ASSERT_EQ(0, + memcmp(kPublicName.c_str(), name.data(), kPublicName.length())); + } + + void DoEchRetry(const ScopedSECKEYPublicKey& server_pub, + const ScopedSECKEYPrivateKey& server_priv, + const DataBuffer& server_rec) { + StackSECItem retry_configs; + ASSERT_EQ(SECSuccess, + SSL_GetEchRetryConfigs(client_->ssl_fd(), &retry_configs)); + ASSERT_NE(0U, retry_configs.len); + + // Reset expectations for the TlsAgent dtor. + server_->ExpectReceiveAlert(kTlsAlertCloseNotify, kTlsAlertWarning); + Reset(); + EnsureTlsSetup(); + ASSERT_EQ(SECSuccess, + SSL_SetServerEchConfigs(server_->ssl_fd(), server_pub.get(), + server_priv.get(), server_rec.data(), + server_rec.len())); + ASSERT_EQ(SECSuccess, + SSL_SetClientEchConfigs(client_->ssl_fd(), retry_configs.data, + retry_configs.len)); + client_->ExpectEch(); + server_->ExpectEch(); + Connect(); + } + + void ImportFixedEchKeypair(ScopedSECKEYPublicKey& pub, + ScopedSECKEYPrivateKey& priv) { + ScopedPK11SlotInfo slot(PK11_GetInternalSlot()); + if (!slot) { + ADD_FAILURE() << "No slot"; + return; + } + std::vector<uint8_t> pkcs8_r = hex_string_to_bytes(kFixedServerKey); + SECItem pkcs8_r_item = {siBuffer, toUcharPtr(pkcs8_r.data()), + static_cast<unsigned int>(pkcs8_r.size())}; + + SECKEYPrivateKey* tmp_priv = nullptr; + ASSERT_EQ(SECSuccess, PK11_ImportDERPrivateKeyInfoAndReturnKey( + slot.get(), &pkcs8_r_item, nullptr, nullptr, + false, false, KU_ALL, &tmp_priv, nullptr)); + priv.reset(tmp_priv); + SECKEYPublicKey* tmp_pub = SECKEY_ConvertToPublicKey(tmp_priv); + pub.reset(tmp_pub); + ASSERT_NE(nullptr, tmp_pub); + } + + void SetMutualEchConfigs(ScopedSECKEYPublicKey& pub, + ScopedSECKEYPrivateKey& priv) { + DataBuffer echconfig; + TlsConnectTestBase::GenerateEchConfig(HpkeDhKemX25519Sha256, kDefaultSuites, + kPublicName, 100, echconfig, pub, + priv); + ASSERT_EQ(SECSuccess, + SSL_SetServerEchConfigs(server_->ssl_fd(), pub.get(), priv.get(), + echconfig.data(), echconfig.len())); + ASSERT_EQ(SECSuccess, + SSL_SetClientEchConfigs(client_->ssl_fd(), echconfig.data(), + echconfig.len())); + } + + void ValidatePublicNames(const std::vector<std::string>& names, + SECStatus expected) { + static const std::vector<HpkeSymmetricSuite> kSuites = { + {HpkeKdfHkdfSha256, HpkeAeadAes128Gcm}}; + + ScopedSECItem ecParams = MakeEcKeyParams(ssl_grp_ec_curve25519); + ScopedSECKEYPublicKey pub; + ScopedSECKEYPrivateKey priv; + SECKEYPublicKey* pub_p = nullptr; + SECKEYPrivateKey* priv_p = + SECKEY_CreateECPrivateKey(ecParams.get(), &pub_p, nullptr); + pub.reset(pub_p); + priv.reset(priv_p); + ASSERT_TRUE(!!pub); + ASSERT_TRUE(!!priv); + + EnsureTlsSetup(); + + DataBuffer cfg; + SECStatus rv; + for (auto name : names) { + if (g_ssl_gtest_verbose) { + std::cout << ((expected == SECFailure) ? "in" : "") + << "valid public_name: " << name << std::endl; + } + GenerateEchConfig(HpkeDhKemX25519Sha256, kSuites, name, 100, cfg, pub, + priv); + + rv = SSL_SetServerEchConfigs(server_->ssl_fd(), pub.get(), priv.get(), + cfg.data(), cfg.len()); + EXPECT_EQ(expected, rv); + + rv = SSL_SetClientEchConfigs(client_->ssl_fd(), cfg.data(), cfg.len()); + EXPECT_EQ(expected, rv); + } + } + + private: + // Testing certan invalid CHInner configurations is tricky, particularly + // since the CHOuter forms AAD and isn't available in filters. Instead of + // generating these inputs on the fly, use a fixed server keypair so that + // the input can be generated once (e.g. via a debugger) and replayed in + // each invocation of the test. + std::string kFixedServerKey = + "3067020100301406072a8648ce3d020106092b06010401da470f01044c304a" + "02010104205a8aa0d2476b28521588e0c704b14db82cdd4970d340d293a957" + "6deaee9ec1c7a1230321008756e2580c07c1d2ffcb662f5fadc6d6ff13da85" + "abd7adfecf984aaa102c1269"; +}; + +static void CheckCertVerifyPublicName(TlsAgent* agent) { + agent->UpdatePreliminaryChannelInfo(); + EXPECT_NE(0U, (agent->pre_info().valuesSet & ssl_preinfo_ech)); + EXPECT_EQ(agent->GetEchExpected(), agent->pre_info().echAccepted); + + // Check that echPublicName is only exposed in the rejection + // case, so that the application can use it for CertVerfiy. + if (agent->GetEchExpected()) { + EXPECT_EQ(nullptr, agent->pre_info().echPublicName); + } else { + EXPECT_NE(nullptr, agent->pre_info().echPublicName); + if (agent->pre_info().echPublicName) { + EXPECT_EQ(0, + strcmp(kPublicName.c_str(), agent->pre_info().echPublicName)); + } + } +} + +static SECStatus AuthCompleteFail(TlsAgent* agent, PRBool, PRBool) { + CheckCertVerifyPublicName(agent); + return SECFailure; +} + +// Given two EchConfigList structures, e.g. from GenerateEchConfig, construct +// a single list containing all entries. +static DataBuffer MakeEchConfigList(DataBuffer config1, DataBuffer config2) { + DataBuffer sizedConfigListBuffer; + + sizedConfigListBuffer.Write(2, config1.data() + 2, config1.len() - 2); + sizedConfigListBuffer.Write(sizedConfigListBuffer.len(), config2.data() + 2, + config2.len() - 2); + sizedConfigListBuffer.Write(0, sizedConfigListBuffer.len() - 2, 2); + + PR_ASSERT(sizedConfigListBuffer.len() == config1.len() + config2.len() - 2); + return sizedConfigListBuffer; +} + +TEST_P(TlsAgentEchTest, EchConfigsSupportedYesNo) { + if (variant_ == ssl_variant_datagram) { + GTEST_SKIP(); + } + ScopedSECKEYPublicKey pub; + ScopedSECKEYPrivateKey priv; + // ECHConfig 2 cipher_suites are unsupported. + DataBuffer config1; + TlsConnectTestBase::GenerateEchConfig(HpkeDhKemX25519Sha256, kSuiteAes, + kPublicName, 100, config1, pub, priv); + DataBuffer config2; + TlsConnectTestBase::GenerateEchConfig(HpkeDhKemX25519Sha256, kBogusSuite, + kPublicName, 100, config2, pub, priv); + EnsureInit(); + EXPECT_EQ(SECSuccess, SSL_EnableTls13GreaseEch(agent_->ssl_fd(), + PR_FALSE)); // Don't GREASE + + DataBuffer sizedConfigListBuffer = MakeEchConfigList(config1, config2); + InstallEchConfig(sizedConfigListBuffer, 0); + auto filter = MakeTlsFilter<TlsExtensionCapture>( + agent_, ssl_tls13_encrypted_client_hello_xtn); + agent_->Handshake(); + ASSERT_EQ(TlsAgent::STATE_CONNECTING, agent_->state()); + ASSERT_TRUE(filter->captured()); +} + +TEST_P(TlsAgentEchTest, EchConfigsSupportedNoYes) { + if (variant_ == ssl_variant_datagram) { + GTEST_SKIP(); + } + ScopedSECKEYPublicKey pub; + ScopedSECKEYPrivateKey priv; + DataBuffer config2; + TlsConnectTestBase::GenerateEchConfig(HpkeDhKemX25519Sha256, kSuiteAes, + kPublicName, 100, config2, pub, priv); + DataBuffer config1; + TlsConnectTestBase::GenerateEchConfig(HpkeDhKemX25519Sha256, kBogusSuite, + kPublicName, 100, config1, pub, priv); + // ECHConfig 1 cipher_suites are unsupported. + DataBuffer sizedConfigListBuffer = MakeEchConfigList(config1, config2); + EnsureInit(); + EXPECT_EQ(SECSuccess, SSL_EnableTls13GreaseEch(agent_->ssl_fd(), + PR_FALSE)); // Don't GREASE + InstallEchConfig(sizedConfigListBuffer, 0); + auto filter = MakeTlsFilter<TlsExtensionCapture>( + agent_, ssl_tls13_encrypted_client_hello_xtn); + agent_->Handshake(); + ASSERT_EQ(TlsAgent::STATE_CONNECTING, agent_->state()); + ASSERT_TRUE(filter->captured()); +} + +TEST_P(TlsAgentEchTest, EchConfigsSupportedNoNo) { + if (variant_ == ssl_variant_datagram) { + GTEST_SKIP(); + } + + ScopedSECKEYPublicKey pub; + ScopedSECKEYPrivateKey priv; + DataBuffer config2; + TlsConnectTestBase::GenerateEchConfig(HpkeDhKemX25519Sha256, kBogusSuite, + kPublicName, 100, config2, pub, priv); + DataBuffer config1; + TlsConnectTestBase::GenerateEchConfig(HpkeDhKemX25519Sha256, kBogusSuite, + kPublicName, 100, config1, pub, priv); + // ECHConfig 1 and 2 cipher_suites are unsupported. + DataBuffer sizedConfigListBuffer = MakeEchConfigList(config1, config2); + EnsureInit(); + EXPECT_EQ(SECSuccess, SSL_EnableTls13GreaseEch(agent_->ssl_fd(), + PR_FALSE)); // Don't GREASE + InstallEchConfig(sizedConfigListBuffer, SEC_ERROR_INVALID_ARGS); + auto filter = MakeTlsFilter<TlsExtensionCapture>( + agent_, ssl_tls13_encrypted_client_hello_xtn); + agent_->Handshake(); + ASSERT_EQ(TlsAgent::STATE_CONNECTING, agent_->state()); + ASSERT_FALSE(filter->captured()); +} + +TEST_P(TlsAgentEchTest, ShortEchConfig) { + EnsureInit(); + ScopedSECKEYPublicKey pub; + ScopedSECKEYPrivateKey priv; + DataBuffer echconfig; + TlsConnectTestBase::GenerateEchConfig(HpkeDhKemX25519Sha256, kDefaultSuites, + kPublicName, 100, echconfig, pub, priv); + echconfig.Truncate(echconfig.len() - 1); + InstallEchConfig(echconfig, SEC_ERROR_BAD_DATA); + EXPECT_EQ(SECSuccess, SSL_EnableTls13GreaseEch(agent_->ssl_fd(), + PR_FALSE)); // Don't GREASE + auto filter = MakeTlsFilter<TlsExtensionCapture>( + agent_, ssl_tls13_encrypted_client_hello_xtn); + agent_->Handshake(); + ASSERT_EQ(TlsAgent::STATE_CONNECTING, agent_->state()); + ASSERT_FALSE(filter->captured()); +} + +TEST_P(TlsAgentEchTest, LongEchConfig) { + EnsureInit(); + ScopedSECKEYPublicKey pub; + ScopedSECKEYPrivateKey priv; + DataBuffer echconfig; + TlsConnectTestBase::GenerateEchConfig(HpkeDhKemX25519Sha256, kDefaultSuites, + kPublicName, 100, echconfig, pub, priv); + echconfig.Write(echconfig.len(), 1, 1); // Append one byte + InstallEchConfig(echconfig, SEC_ERROR_BAD_DATA); + EXPECT_EQ(SECSuccess, SSL_EnableTls13GreaseEch(agent_->ssl_fd(), + PR_FALSE)); // Don't GREASE + auto filter = MakeTlsFilter<TlsExtensionCapture>( + agent_, ssl_tls13_encrypted_client_hello_xtn); + agent_->Handshake(); + ASSERT_EQ(TlsAgent::STATE_CONNECTING, agent_->state()); + ASSERT_FALSE(filter->captured()); +} + +TEST_P(TlsAgentEchTest, UnsupportedEchConfigVersion) { + EnsureInit(); + ScopedSECKEYPublicKey pub; + ScopedSECKEYPrivateKey priv; + DataBuffer echconfig; + static const uint8_t bad_version[] = {0xff, 0xff}; + DataBuffer bad_ver_buf(bad_version, sizeof(bad_version)); + TlsConnectTestBase::GenerateEchConfig(HpkeDhKemX25519Sha256, kDefaultSuites, + kPublicName, 100, echconfig, pub, priv); + echconfig.Splice(bad_ver_buf, 2, 2); + InstallEchConfig(echconfig, SEC_ERROR_INVALID_ARGS); + EXPECT_EQ(SECSuccess, SSL_EnableTls13GreaseEch(agent_->ssl_fd(), + PR_FALSE)); // Don't GREASE + auto filter = MakeTlsFilter<TlsExtensionCapture>( + agent_, ssl_tls13_encrypted_client_hello_xtn); + agent_->Handshake(); + ASSERT_EQ(TlsAgent::STATE_CONNECTING, agent_->state()); + ASSERT_FALSE(filter->captured()); +} + +TEST_P(TlsAgentEchTest, UnsupportedHpkeKem) { + EnsureInit(); + ScopedSECKEYPublicKey pub; + ScopedSECKEYPrivateKey priv; + DataBuffer echconfig; + // SSL_EncodeEchConfigId encodes without validation. + TlsConnectTestBase::GenerateEchConfig(static_cast<HpkeKemId>(0xff), + kDefaultSuites, kPublicName, 100, + echconfig, pub, priv); + InstallEchConfig(echconfig, SEC_ERROR_INVALID_ARGS); + EXPECT_EQ(SECSuccess, SSL_EnableTls13GreaseEch(agent_->ssl_fd(), + PR_FALSE)); // Don't GREASE + auto filter = MakeTlsFilter<TlsExtensionCapture>( + agent_, ssl_tls13_encrypted_client_hello_xtn); + agent_->Handshake(); + ASSERT_EQ(TlsAgent::STATE_CONNECTING, agent_->state()); + ASSERT_FALSE(filter->captured()); +} + +TEST_P(TlsAgentEchTest, EchRejectIgnoreAllUnknownSuites) { + EnsureInit(); + ScopedSECKEYPublicKey pub; + ScopedSECKEYPrivateKey priv; + DataBuffer echconfig; + TlsConnectTestBase::GenerateEchConfig(HpkeDhKemX25519Sha256, kBogusSuite, + kPublicName, 100, echconfig, pub, priv); + InstallEchConfig(echconfig, SEC_ERROR_INVALID_ARGS); + EXPECT_EQ(SECSuccess, SSL_EnableTls13GreaseEch(agent_->ssl_fd(), + PR_FALSE)); // Don't GREASE + auto filter = MakeTlsFilter<TlsExtensionCapture>( + agent_, ssl_tls13_encrypted_client_hello_xtn); + agent_->Handshake(); + ASSERT_FALSE(filter->captured()); +} + +TEST_P(TlsAgentEchTest, EchConfigRejectEmptyPublicName) { + EnsureInit(); + ScopedSECKEYPublicKey pub; + ScopedSECKEYPrivateKey priv; + DataBuffer echconfig; + TlsConnectTestBase::GenerateEchConfig(HpkeDhKemX25519Sha256, kBogusSuite, "", + 100, echconfig, pub, priv); + InstallEchConfig(echconfig, SSL_ERROR_RX_MALFORMED_ECH_CONFIG); + EXPECT_EQ(SECSuccess, SSL_EnableTls13GreaseEch(agent_->ssl_fd(), + PR_FALSE)); // Don't GREASE + auto filter = MakeTlsFilter<TlsExtensionCapture>( + agent_, ssl_tls13_encrypted_client_hello_xtn); + agent_->Handshake(); + ASSERT_FALSE(filter->captured()); +} + +TEST_F(TlsConnectStreamTls13, EchAcceptIgnoreSingleUnknownSuite) { + EnsureTlsSetup(); + DataBuffer echconfig; + ScopedSECKEYPublicKey pub; + ScopedSECKEYPrivateKey priv; + TlsConnectTestBase::GenerateEchConfig(HpkeDhKemX25519Sha256, + kUnknownFirstSuite, kPublicName, 100, + echconfig, pub, priv); + ASSERT_EQ(SECSuccess, + SSL_SetClientEchConfigs(client_->ssl_fd(), echconfig.data(), + echconfig.len())); + ASSERT_EQ(SECSuccess, + SSL_SetServerEchConfigs(server_->ssl_fd(), pub.get(), priv.get(), + echconfig.data(), echconfig.len())); + + client_->ExpectEch(); + server_->ExpectEch(); + Connect(); +} + +TEST_P(TlsAgentEchTest, ApiInvalidArgs) { + EnsureInit(); + // SetClient + EXPECT_EQ(SECFailure, SSL_SetClientEchConfigs(agent_->ssl_fd(), nullptr, 1)); + + EXPECT_EQ(SECFailure, + SSL_SetClientEchConfigs(agent_->ssl_fd(), + reinterpret_cast<const uint8_t*>(1), 0)); + + // SetServer + EXPECT_EQ(SECFailure, + SSL_SetServerEchConfigs(agent_->ssl_fd(), nullptr, + reinterpret_cast<SECKEYPrivateKey*>(1), + reinterpret_cast<const uint8_t*>(1), 1)); + EXPECT_EQ(SECFailure, + SSL_SetServerEchConfigs( + agent_->ssl_fd(), reinterpret_cast<SECKEYPublicKey*>(1), + nullptr, reinterpret_cast<const uint8_t*>(1), 1)); + EXPECT_EQ(SECFailure, + SSL_SetServerEchConfigs( + agent_->ssl_fd(), reinterpret_cast<SECKEYPublicKey*>(1), + reinterpret_cast<SECKEYPrivateKey*>(1), nullptr, 1)); + EXPECT_EQ(SECFailure, + SSL_SetServerEchConfigs(agent_->ssl_fd(), + reinterpret_cast<SECKEYPublicKey*>(1), + reinterpret_cast<SECKEYPrivateKey*>(1), + reinterpret_cast<const uint8_t*>(1), 0)); + + // GetRetries + EXPECT_EQ(SECFailure, SSL_GetEchRetryConfigs(agent_->ssl_fd(), nullptr)); + + // EncodeEchConfigId + EXPECT_EQ(SECFailure, + SSL_EncodeEchConfigId(0, nullptr, 1, static_cast<HpkeKemId>(1), + reinterpret_cast<SECKEYPublicKey*>(1), + reinterpret_cast<HpkeSymmetricSuite*>(1), 1, + reinterpret_cast<uint8_t*>(1), + reinterpret_cast<unsigned int*>(1), 1)); + + EXPECT_EQ(SECFailure, + SSL_EncodeEchConfigId(0, "name", 1, static_cast<HpkeKemId>(1), + reinterpret_cast<SECKEYPublicKey*>(1), + nullptr, 1, reinterpret_cast<uint8_t*>(1), + reinterpret_cast<unsigned int*>(1), 1)); + EXPECT_EQ(SECFailure, + SSL_EncodeEchConfigId(0, "name", 1, static_cast<HpkeKemId>(1), + reinterpret_cast<SECKEYPublicKey*>(1), + reinterpret_cast<HpkeSymmetricSuite*>(1), 0, + reinterpret_cast<uint8_t*>(1), + reinterpret_cast<unsigned int*>(1), 1)); + + EXPECT_EQ(SECFailure, SSL_EncodeEchConfigId( + 0, "name", 1, static_cast<HpkeKemId>(1), nullptr, + reinterpret_cast<HpkeSymmetricSuite*>(1), 1, + reinterpret_cast<uint8_t*>(1), + reinterpret_cast<unsigned int*>(1), 1)); + + EXPECT_EQ(SECFailure, + SSL_EncodeEchConfigId(0, nullptr, 0, static_cast<HpkeKemId>(1), + reinterpret_cast<SECKEYPublicKey*>(1), + reinterpret_cast<HpkeSymmetricSuite*>(1), 1, + reinterpret_cast<uint8_t*>(1), + reinterpret_cast<unsigned int*>(1), 1)); + + EXPECT_EQ(SECFailure, SSL_EncodeEchConfigId( + 0, "name", 1, static_cast<HpkeKemId>(1), + reinterpret_cast<SECKEYPublicKey*>(1), + reinterpret_cast<HpkeSymmetricSuite*>(1), 1, + nullptr, reinterpret_cast<unsigned int*>(1), 1)); + + EXPECT_EQ(SECFailure, + SSL_EncodeEchConfigId(0, "name", 1, static_cast<HpkeKemId>(1), + reinterpret_cast<SECKEYPublicKey*>(1), + reinterpret_cast<HpkeSymmetricSuite*>(1), 1, + reinterpret_cast<uint8_t*>(1), nullptr, 1)); +} + +TEST_P(TlsAgentEchTest, NoEarlyRetryConfigs) { + EnsureInit(); + StackSECItem retry_configs; + EXPECT_EQ(SECFailure, + SSL_GetEchRetryConfigs(agent_->ssl_fd(), &retry_configs)); + EXPECT_EQ(SSL_ERROR_HANDSHAKE_NOT_COMPLETED, PORT_GetError()); + + ScopedSECKEYPublicKey pub; + ScopedSECKEYPrivateKey priv; + DataBuffer echconfig; + TlsConnectTestBase::GenerateEchConfig(HpkeDhKemX25519Sha256, kDefaultSuites, + kPublicName, 100, echconfig, pub, priv); + InstallEchConfig(echconfig, 0); + + EXPECT_EQ(SECFailure, + SSL_GetEchRetryConfigs(agent_->ssl_fd(), &retry_configs)); + EXPECT_EQ(SSL_ERROR_HANDSHAKE_NOT_COMPLETED, PORT_GetError()); +} + +TEST_P(TlsAgentEchTest, NoSniSoNoEch) { + EnsureInit(); + ScopedSECKEYPublicKey pub; + ScopedSECKEYPrivateKey priv; + DataBuffer echconfig; + TlsConnectTestBase::GenerateEchConfig(HpkeDhKemX25519Sha256, kDefaultSuites, + kPublicName, 100, echconfig, pub, priv); + SSL_SetURL(agent_->ssl_fd(), ""); + InstallEchConfig(echconfig, 0); + SSL_SetURL(agent_->ssl_fd(), ""); + EXPECT_EQ(SECSuccess, SSL_EnableTls13GreaseEch(agent_->ssl_fd(), + PR_FALSE)); // Don't GREASE + auto filter = MakeTlsFilter<TlsExtensionCapture>( + agent_, ssl_tls13_encrypted_client_hello_xtn); + agent_->Handshake(); + ASSERT_FALSE(filter->captured()); +} + +TEST_P(TlsAgentEchTest, NoEchConfigSoNoEch) { + EnsureInit(); + ScopedSECKEYPublicKey pub; + ScopedSECKEYPrivateKey priv; + DataBuffer echconfig; + EXPECT_EQ(SECSuccess, SSL_EnableTls13GreaseEch(agent_->ssl_fd(), + PR_FALSE)); // Don't GREASE + auto filter = MakeTlsFilter<TlsExtensionCapture>( + agent_, ssl_tls13_encrypted_client_hello_xtn); + agent_->Handshake(); + ASSERT_FALSE(filter->captured()); +} + +TEST_P(TlsAgentEchTest, EchConfigDuplicateExtensions) { + EnsureInit(); + ScopedSECKEYPublicKey pub; + ScopedSECKEYPrivateKey priv; + DataBuffer echconfig; + TlsConnectTestBase::GenerateEchConfig(HpkeDhKemX25519Sha256, kDefaultSuites, + kPublicName, 100, echconfig, pub, priv); + + static const uint8_t duped_xtn[] = {0x00, 0x08, 0x00, 0x01, 0x00, + 0x00, 0x00, 0x01, 0x00, 0x00}; + DataBuffer buf(duped_xtn, sizeof(duped_xtn)); + echconfig.Truncate(echconfig.len() - 2); + echconfig.Append(buf); + uint32_t len; + ASSERT_TRUE(echconfig.Read(0, 2, &len)); + len += buf.len() - 2; + DataBuffer new_len; + ASSERT_TRUE(new_len.Write(0, len, 2)); + echconfig.Splice(new_len, 0, 2); + new_len.Truncate(0); + + ASSERT_TRUE(echconfig.Read(4, 2, &len)); + len += buf.len() - 2; + ASSERT_TRUE(new_len.Write(0, len, 2)); + echconfig.Splice(new_len, 4, 2); + + InstallEchConfig(echconfig, SEC_ERROR_EXTENSION_VALUE_INVALID); + EXPECT_EQ(SECSuccess, SSL_EnableTls13GreaseEch(agent_->ssl_fd(), + PR_FALSE)); // Don't GREASE + auto filter = MakeTlsFilter<TlsExtensionCapture>( + agent_, ssl_tls13_encrypted_client_hello_xtn); + agent_->Handshake(); + ASSERT_EQ(TlsAgent::STATE_CONNECTING, agent_->state()); + ASSERT_FALSE(filter->captured()); +} + +TEST_F(TlsConnectStreamTls13Ech, EchFixedConfig) { + ScopedSECKEYPublicKey pub; + ScopedSECKEYPrivateKey priv; + EnsureTlsSetup(); + ImportFixedEchKeypair(pub, priv); + SetMutualEchConfigs(pub, priv); + + client_->ExpectEch(); + server_->ExpectEch(); + Connect(); +} + +// The next set of tests all use a fixed server key and a pre-built ClientHello. +// This ClientHelo can be constructed using the above EchFixedConfig test, +// modifying tls13_ConstructInnerExtensionsFromOuter as indicated. For this +// small number of tests, these fixed values are easier to construct than +// constructing ClientHello in the test that can be successfully decrypted. + +// Test an encoded ClientHelloInner containing an extra extensionType +// in outer_extensions, for which there is no corresponding (uncompressed) +// extension in ClientHelloOuter. +TEST_F(TlsConnectStreamTls13Ech, EchOuterExtensionsReferencesMissing) { + // Construct this by prepending 0xabcd to ssl_tls13_outer_extensions_xtn. + std::string ch = + "010001fc030390901d039ca83262d9115a5f98f43ddb2553241a8de5c46d9f118c4c29c2" + "64bc000006130113031302010001cd00000010000e00000b7075626c69632e6e616d65ff" + "01000100000a00140012001d00170018001901000101010201030104003300260024001d" + "00206df5f908d1c02320e246694c765d5ec1c0f7d7aef2b1b00b17c36331623d332d002b" + "0003020304000d0018001604030503060302030804080508060401050106010201002d00" + "020101001c00024001fe0d00f900000100034d00209a4f67b0744d1fba23aa4bacfadb2a" + "c706562dae04d80a83ae668a6f2dd6ef2700cfab1671182341df246d66c3aca873e8c714" + "bc2b1c3b576653609533c486df0bdcf63ab4e4e7d0b67fadf4e3504eec96f72e6778b15d" + "69c9a9594a041348a7130f67a1a7cac796a0e6d6fca505438355278a9a8fd55e44218441" + "9927a1e084ac7d7adeb2f0c19faafba430876bf0cdf4d195b2d06428b3de13120f65748a" + "468f8997a2c3bf1dd7f3996a0f2c70dea6c88149df182b3c3b78a8da8bb709a9ed9d77c6" + "5dc09accdfeb66c90db26b99a35052a8cbaf7bb9307a1e17d90a7aa9f768f5f446559d08" + "69bccc83eda9d2b347a00015004200000000000000000000000000000000000000000000" + "000000000000000000000000000000000000000000000000000000000000000000000000" + "0000000000000000"; + ReplayChWithMalformedInner(ch, kTlsAlertIllegalParameter, + SSL_ERROR_RX_MALFORMED_ECH_EXTENSION, + SSL_ERROR_ILLEGAL_PARAMETER_ALERT); +} + +TEST_F(TlsConnectStreamTls13Ech, EchOuterExtensionsInsideInner) { + // Construct this by appending ssl_tls13_outer_extensions_xtn to the + // references in + // ssl_tls13_outer_extensions_xtn. + std::string ch = + "010001fc03035e2268bc7133079cd33eb088253393e561d80c5ee6f9a238aff022e1e10d" + "4c82000006130113031302010001cd00000010000e00000b7075626c69632e6e616d65ff" + "01000100000a00140012001d00170018001901000101010201030104003300260024001d" + "00200e071fd982854d50236ed0e4e7981460840f03d03fd84b44c409fe486203b252002b" + "0003020304000d0018001604030503060302030804080508060401050106010201002d00" + "020101001c00024001fe0d00f900000100034d002099a032502ea4fd3c85b858ae1c59df" + "6a374f3698ed6bca188cf75c432c78cf5a00cf28dde32de7ade40abb16d550c1eec3dad4" + "a03c85efb95ec605837deae92a419285116e5cb8223ea53cff2b605e66f28e96d37e9b4e" + "3035fb1cfa125fa053d6770091b5731c9fb03e872a82991dfdd24ad8399fcc76db7fadba" + "029e064beb02c1282684a93e777bcefbca3dd143dfc225d2e65c80dbf3819ebda288e32c" + "3a1f8a27bb3aa9480dee2a4307073da3e15ee03dba386223d9399ad796af80c646f85406" + "282c34fd9406d25752087f08140e1be834e8a149f0bebfc2b3db16ccba83c37051e2e75d" + "e8a4e999ef385c74c96d0015004200000000000000000000000000000000000000000000" + "000000000000000000000000000000000000000000000000000000000000000000000000" + "0000000000000000"; + ReplayChWithMalformedInner(ch, kTlsAlertIllegalParameter, + SSL_ERROR_RX_MALFORMED_ECH_EXTENSION, + SSL_ERROR_ILLEGAL_PARAMETER_ALERT); +} + +TEST_F(TlsConnectStreamTls13Ech, EchOuterExtensionsDuplicateReference) { + // Construct this by altering tls13_ConstructInnerExtensionsFromOuter to have + // each extension inserted in the reference list twice and running the + // EchFixedConfig test. + std::string ch = + "010001fc0303d8717df80286fcd8b4242ed846995c6473e290678231046bb1bfc7848460" + "b122000006130113031302010001cd00000010000e00000b7075626c69632e6e616d65ff" + "01000100000a00140012001d00170018001901000101010201030104003300260024001d" + "00206f21d5fdf7bf81943939a03656c1195ad347cec453bf7a16d0773ffef481d22f002b" + "0003020304000d0018001604030503060302030804080508060401050106010201002d00" + "020101001c00024001fe0d011900000100034d002027eb9b641ba8ffc3a4028d00d1f5bd" + "e190736b1ea5a79513dee0a551cc6fe55200efc2ed6bf501f100896eb91221ce512c20c3" + "c5c110e7be6a5d340854ff5ac0175312631b021fd5a5c9841549989f415c4041a4b384b1" + "dba1d6b4182cc48904f993a15eab6bf7787b267ca65acef51c019508e0c9b382086a71d8" + "517cf19644d66d396efc066a4d37916d67b0e5fe08d52dd94d068dd85b9a245aaffac4ff" + "66d9a5221fd5805473bb7584eb7f218357c00aff890d2f2edf1c092c648c888b5cba1ca6" + "26817fda7765fcedfbc418b90b1841d878ed443593cafb61fa8fb708c53977615b45f545" + "2a8236cab3ec121cdc91a2de6a79437cae9d09e781339fddcac005ce62fd65d50e33faa2" + "2366955a0374001500220000000000000000000000000000000000000000000000000000" + "0000000000000000"; + ReplayChWithMalformedInner(ch, kTlsAlertIllegalParameter, + SSL_ERROR_RX_MALFORMED_ECH_EXTENSION, + SSL_ERROR_ILLEGAL_PARAMETER_ALERT); +} + +TEST_F(TlsConnectStreamTls13Ech, EchOuterExtensionsOutOfOrder) { + // Construct this by altering tls13_ConstructInnerExtensionsFromOuter to leave + // a gap at the start and insert a 'late' extension there. + std::string ch = + "010001fc0303fabff6caf4d8b1eb1db5945c96badefec4b33188b97121e6a206e82b74bd" + "a855000006130113031302010001cd00000010000e00000b7075626c69632e6e616d65ff" + "01000100000a00140012001d00170018001901000101010201030104003300260024001d" + "00208fe970fc0c908f0c51734f18467e640df1d45a6ace2948b5c4bf73ee52ab3160002b" + "0003020304000d0018001604030503060302030804080508060401050106010201002d00" + "020101001c00024001fe0d00f900000100034d00203339239f8925c3f9b89f4ced17c3b3" + "1c649299d7e10b3cdbc115de2a57d90d2200cf006e62866516380e8a16763bee5b2a75a8" + "74e8698c459f474d0e952c2fd3300bef1decd6f259b8ac2912684ef69b7a7be2520fbf15" + "5e0c3f88998789976ca1fbcaa40616fc513e3353540db091da76ca98007532974550d3da" + "aaddb799baf60adbc5800df30e187251427fe9de707d18a270352ee44f6eb37f0d8c72a1" + "5f9ffb5dd4bbb6045473c8d99b7a5c2c8cc59027f346cbe6ef240d5cf1919f58a998d427" + "0f8c882d03d22ec4df4079e15a639452ea4c24023f6bcad89566ce6a32b1dad6ddf6b436" + "3e6759bd48bed1b30a840015004200000000000000000000000000000000000000000000" + "000000000000000000000000000000000000000000000000000000000000000000000000" + "0000000000000000"; + ReplayChWithMalformedInner(ch, kTlsAlertIllegalParameter, + SSL_ERROR_RX_MALFORMED_ECH_EXTENSION, + SSL_ERROR_ILLEGAL_PARAMETER_ALERT); +} + +// Drop supported_versions from CHInner, make sure we don't negotiate 1.2+ECH. +TEST_F(TlsConnectStreamTls13Ech, EchVersion12Inner) { + // Construct this by removing ssl_tls13_supported_versions_xtn entirely. + std::string ch = + "010001fc030338e9ebcde2b87ef779c4d9a9b9870aef3978130b254fbf168a92644c97c1" + "c5cb000006130113031302010001cd00000010000e00000b7075626c69632e6e616d65ff" + "01000100000a00140012001d00170018001901000101010201030104003300260024001d" + "002081b3ea444fd631f9264e01276bcc1a6771aed3b5a8a396446467d1c820e52b25002b" + "0003020304000d0018001604030503060302030804080508060401050106010201002d00" + "020101001c00024001fe0d00f900000100034d00205864042b43f4d4d544558fbcba410f" + "ebfb78ddfc5528672a7f7d9e70abc3eb6300cf6ff3271da628139bddc4a58ee92db26170" + "7310dee54d88c8a96a8d998b8608d5f10260b7e201e5dc8cafa13917a3fdfdf399082959" + "8adf3c291decf640f696e64c4e22bafb81565587c50dd829ccad68bd00babeaba7d8a7a5" + "400ad3200dbae674c549953ca6d3298ed751a9bc215a33be444fe908bf1c6f374cc139f9" + "98339f58b8fd3510a670e4102e3f7de21586ebd70c3fb1df8bb6b9e5dbc0db147dbac6d0" + "72dfc6cdf17ecee5c019c311b37ef9f5ceabb7edbdf87a4a04041c4d8b512a16517c5380" + "e8d4f6e3b2412b4a6c030015004200000000000000000000000000000000000000000000" + "000000000000000000000000000000000000000000000000000000000000000000000000" + "0000000000000000"; + ReplayChWithMalformedInner(ch, kTlsAlertIllegalParameter, + SSL_ERROR_UNSUPPORTED_VERSION, + SSL_ERROR_ILLEGAL_PARAMETER_ALERT); +} + +// Use CHInner supported_versions to negotiate 1.2. +TEST_F(TlsConnectStreamTls13Ech, EchVersion12InnerSupportedVersions) { + // Construct this by changing ssl_tls13_supported_versions_xtn to write + // TLS 1.2 instead of TLS 1.3. + std::string ch = + "010001fc0303f7146bdc88c399feb49c62b796db2f8b1330e25292a889edf7c65231d0be" + "b95f000006130113031302010001cd00000010000e00000b7075626c69632e6e616d65ff" + "01000100000a00140012001d00170018001901000101010201030104003300260024001d" + "0020d31f8eb204efba49dbdbf40bb046b1e0b90fa3f034260d60f351d4b15e614e7f002b" + "0003020304000d0018001604030503060302030804080508060401050106010201002d00" + "020101001c00024001fe0d00f900000100034d0020eaa25e92721e65fd405577bf2fd322" + "857e60f8766a595929fc404c9a01ef441200cf04992c693fbc8eac87726b336a11abc411" + "541ceff50d533d4cf4d6e1078479acb5446675b652f22d6db04daf0c3640ec2429ba4f51" + "99c00daa43e9a7d85bd6733041feeca0b38ee6ca07042c7e67d40cd3e236499f3f9d92ab" + "e4642e483c75d77c247b0228bc773c09551d15845c35663afd1805c5b3adb136ffa6d94f" + "b7cbfe93d5d33c894b2a6437ad9a2278d5863ed20db652a6084c9e95a8dfaf821d0b474a" + "7efc2839f110edb4a73376ecab629b26b1eea63304899c49a07157fbbee67c786686cb04" + "a53666a74e1e003aefc70015004200000000000000000000000000000000000000000000" + "000000000000000000000000000000000000000000000000000000000000000000000000" + "0000000000000000"; + ReplayChWithMalformedInner(ch, kTlsAlertProtocolVersion, + SSL_ERROR_UNSUPPORTED_VERSION, + SSL_ERROR_PROTOCOL_VERSION_ALERT); +} + +// Replay a CH for which CHInner lacks the required ech xtn of inner type +TEST_F(TlsConnectStreamTls13Ech, EchInnerMissing) { + // Construct by omitting the ech inner extension + std::string ch = + "010001fc0303fa9cd9cf5b77bb4083f69a1d169d44b356faea0d6a0aee6d50412de6fef7" + "8d22000006130113031302010001cd00000010000e00000b7075626c69632e6e616d65ff" + "01000100000a00140012001d00170018001901000101010201030104003300260024001d" + "0020c329f1dde4d51b50f68c21053b545290b250af527b2832d3acf2c6af9b8b8d5c002b" + "0003020304000d0018001604030503060302030804080508060401050106010201002d00" + "020101001c00024001fe0d00f900000100034d00207e2a0397b7d2776ae468057d630243" + "b01388cf80680b074323adf4091aba7b4c00cff4b649fb5b3a0719c1e085c7006a95eaad" + "32375b717a42d009c075e6246342fdc1e847c528495f90378ff5b4912da5190f7e8bfa1c" + "c9744b50e9e469cd7cd12bcb5f6534b7d617459d2efa4d796ad244567c49f1d22feb08a5" + "8e8ebdce059c28883dd69ca401e189f3ef438c3f0bf3d377e6727a1f6abf3a8a8cc149ee" + "60a1aa5ba4a50e99d2519216762558e9613a238bd630b5822f549575d9402f8da066aaef" + "2e0e6a7a04583b041925e0ef4575107c4436f9af26e561c0ab733cd88bee6a20e6414128" + "ea0ba1c73612bb62c1e90015004200000000000000000000000000000000000000000000" + "000000000000000000000000000000000000000000000000000000000000000000000000" + "0000000000000000"; + ReplayChWithMalformedInner(ch, kTlsAlertIllegalParameter, + SSL_ERROR_MISSING_ECH_EXTENSION, + SSL_ERROR_ILLEGAL_PARAMETER_ALERT); +} + +TEST_F(TlsConnectStreamTls13Ech, EchInnerWrongSize) { + // Construct by including ech inner with wrong size + std::string ch = + "010001fc03035f8410dab9e49b0833d13390f3fe0b3c6321d842961c9cc46b59a0b5b8e1" + "4e0b000006130113031302010001cd00000010000e00000b7075626c69632e6e616d65ff" + "01000100000a00140012001d00170018001901000101010201030104003300260024001d" + "0020526a56087d685e574accb0e87d6781bc553612479e56460fe6a497fa1cd74e2e002b" + "0003020304000d0018001604030503060302030804080508060401050106010201002d00" + "020101001c00024001fe0d00f900000100034d00200d096bf6ac0c3bcb79d70677da0e0d" + "249b40bc5ba6b8727654619fe6567d0b0700cfd13e136d2d041e3cd993b252386d97e98d" + "c972d29d28e0281a210fa56156b95e4371a6610a0b3e65f1b842875fb456de9b9c0e03f8" + "aa4d1055057ac3e20e5fa45b837ccbb06ef3856c71f1f63e91b60bfb5f3415f26e9a0d3c" + "4d404d5d5aaa6dca8d57cf2e6b4aaf399fa7271b0c1eedbfdd85fbc9711b0446eb9c9535" + "a74f3e5a71e2e22dc8d89980f96233ec9b80fbe4f295ff7903bade407fc544c8d76df4fb" + "ce4b8d79cea0ff7e0b0736ecbeaf5a146a4f81a930e788ae144cf2219e90dc3594165a7e" + "2a0b64f6189a87a348840015004200000000000000000000000000000000000000000000" + "000000000000000000000000000000000000000000000000000000000000000000000000" + "0000000000000000"; + ReplayChWithMalformedInner(ch, kTlsAlertDecodeError, + SSL_ERROR_RX_MALFORMED_ESNI_EXTENSION, + SSL_ERROR_DECODE_ERROR_ALERT); +} + +TEST_F(TlsConnectStreamTls13Ech, InnerWithEchAndEchIsInner) { + // Construct by appending an empty ssl_tls13_encrypted_client_hello_xtn of + // type outer to + // CHInner. + std::string ch = + "010001fc0303527df5a8dbcf390c184c5274295283fdba78d05784170d8f3cb8c7d84747" + "afb5000006130113031302010001cd00000010000e00000b7075626c69632e6e616d65ff" + "01000100000a00140012001d00170018001901000101010201030104003300260024001d" + "002099461dcfcdc7804a0f34bf3ca49ac39776a7ef4d8edd30fab3599ff59b09f826002b" + "0003020304000d0018001604030503060302030804080508060401050106010201002d00" + "020101001c00024001fe0d00f900000100034d00201da1341e8ba21ff90e025d2438d4e5" + "b4e8b376befc57cf8c9afb484e6f051b2f00cff747491b810705e5cc8d8a1302468000d9" + "8660d659d8382a6fc23ca1a582def728eabb363771328035565048213b1d725b20f757be" + "63d6956cd861aa9d33adcc913de2443695f70e130af96fd2b078dd662478a29bd17a4479" + "715c949b5fc118456d0243c9d1819cecd0f5fbd1c78dadd6fcd09abe41ca97a00c97efb3" + "894c9d4bab60dcd150b55608f6260723a08e112e39e6a43f645f85a08085054f27f269bc" + "1acb9ff5007b04eaef3414767666472e4e24c2a2953f5dc68aeb5207d556f1b872a810b6" + "686cf83a09db8b474df70015004200000000000000000000000000000000000000000000" + "000000000000000000000000000000000000000000000000000000000000000000000000" + "0000000000000000"; + ReplayChWithMalformedInner(ch, kTlsAlertIllegalParameter, + SSL_ERROR_RX_UNEXPECTED_EXTENSION, + SSL_ERROR_ILLEGAL_PARAMETER_ALERT); +} + +TEST_F(TlsConnectStreamTls13, EchWithInnerExtNotSplit) { + static uint8_t type_val[1] = {1}; + DataBuffer type_buffer(type_val, sizeof(type_val)); + + EnsureTlsSetup(); + EXPECT_EQ(SECSuccess, SSL_EnableTls13GreaseEch(client_->ssl_fd(), PR_FALSE)); + MakeTlsFilter<TlsExtensionAppender>(client_, kTlsHandshakeClientHello, + ssl_tls13_encrypted_client_hello_xtn, + type_buffer); + ConnectExpectAlert(server_, kTlsAlertIllegalParameter); + client_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT); + server_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_EXTENSION); +} + +/* Parameters + * Length of SNI for first connection + * Length of SNI for second connection + * Use GREASE for first connection? + * Use GREASE for second connection? + * For both connections, SNI length to pad to. + */ +class EchCHPaddingTest : public TlsConnectStreamTls13, + public testing::WithParamInterface< + std::tuple<int, int, bool, bool, int>> {}; + +TEST_P(EchCHPaddingTest, EchChPaddingEqual) { + auto parameters = GetParam(); + std::string name_str1 = std::string(std::get<0>(parameters), 'a'); + std::string name_str2 = std::string(std::get<1>(parameters), 'a'); + const char* name1 = name_str1.c_str(); + const char* name2 = name_str2.c_str(); + bool grease_mode1 = std::get<2>(parameters); + bool grease_mode2 = std::get<3>(parameters); + uint8_t max_name_len = std::get<4>(parameters); + + // Connection 1 + EnsureTlsSetup(); + SSL_SetURL(client_->ssl_fd(), name1); + if (grease_mode1) { + EXPECT_EQ(SECSuccess, SSL_EnableTls13GreaseEch(client_->ssl_fd(), PR_TRUE)); + EXPECT_EQ(SECSuccess, + SSL_SetTls13GreaseEchSize(client_->ssl_fd(), max_name_len)); + client_->ExpectEch(false); + server_->ExpectEch(false); + } else { + SetupEch(client_, server_, HpkeDhKemX25519Sha256, true, true, true, + max_name_len); + } + auto filter1 = MakeTlsFilter<TlsExtensionCapture>( + client_, ssl_tls13_encrypted_client_hello_xtn); + Connect(); + size_t echXtnLen1 = filter1->extension().len(); + + Reset(); + + // Connection 2 + EnsureTlsSetup(); + SSL_SetURL(client_->ssl_fd(), name2); + if (grease_mode2) { + EXPECT_EQ(SECSuccess, SSL_EnableTls13GreaseEch(client_->ssl_fd(), PR_TRUE)); + EXPECT_EQ(SECSuccess, + SSL_SetTls13GreaseEchSize(client_->ssl_fd(), max_name_len)); + client_->ExpectEch(false); + server_->ExpectEch(false); + } else { + SetupEch(client_, server_, HpkeDhKemX25519Sha256, true, true, true, + max_name_len); + } + auto filter2 = MakeTlsFilter<TlsExtensionCapture>( + client_, ssl_tls13_encrypted_client_hello_xtn); + Connect(); + size_t echXtnLen2 = filter2->extension().len(); + + // We always expect an ECH extension. + ASSERT_TRUE(echXtnLen2 > 0 && echXtnLen1 > 0); + // We expect the ECH extension to round to the same multiple of 32. + // Note: It will not be 0 % 32 because we pad the Payload, but have a number + // of extra bytes from the rest of the ECH extension (e.g. ciphersuite) + ASSERT_EQ(echXtnLen1 % 32, echXtnLen2 % 32); + // Both connections should have the same size after padding. + if (name_str1.size() <= max_name_len && name_str2.size() <= max_name_len) { + ASSERT_EQ(echXtnLen1, echXtnLen2); + } +} + +#define ECH_PADDING_TEST_INSTANTIATE(name, values) \ + INSTANTIATE_TEST_SUITE_P(name, EchCHPaddingTest, \ + testing::Combine(values, values, testing::Bool(), \ + testing::Bool(), values)) + +const int kExtremalSNILengths[] = {1, 128, 255}; +const int kNormalSNILengths[] = {17, 24, 100}; +const int kLongSNILengths[] = {90, 167, 214}; + +/* Each invocation with N lengths, results in 4N^3 test cases, so we test + * 3 lots of (4*3^3) rather than all permutations. */ +ECH_PADDING_TEST_INSTANTIATE(extremal, testing::ValuesIn(kExtremalSNILengths)); +ECH_PADDING_TEST_INSTANTIATE(normal, testing::ValuesIn(kNormalSNILengths)); +ECH_PADDING_TEST_INSTANTIATE(lengthy, testing::ValuesIn(kLongSNILengths)); + +// Check the server rejects ClientHellos with bad padding +TEST_F(TlsConnectStreamTls13Ech, EchChPaddingChecked) { + // Generate this string by changing the padding in + // tls13_GenPaddingClientHelloInner + std::string ch = + "010001fc03037473367a6eb6773391081b403908fc0c0026aac706889c59ca694d0c1188" + "c4b3000006130113031302010001cd00000010000e00000b7075626c69632e6e616d65ff" + "01000100000a00140012001d00170018001901000101010201030104003300260024001d" + "0020f7d8ad5fea0165e115e984e11c43f1d8f255bd8f772b893432d8d7721e91785a002b" + "0003020304000d0018001604030503060302030804080508060401050106010201002d00" + "020101001c00024001fe0d00f900000100034d00207e0ad8e83f8a9c89e1ae4fd65b8091" + "01e496bbb5f29ce20b299ce58937e2563300cff471a787585e15ae5aff5e4fee7ec988ba" + "72f8a95db41e793568b0301d553251f0826dc0c3ff658e4e029ef840ae86fa80af4b11b5" + "3a33fab99887bf8df18bc87abbb1f578f7964848d91a2023cbe7609fcc31bd721865009c" + "ad68c09e438d677f7c56af76e62c168bdb373bb88962471dacc4ddf654e435cd903f6555" + "4c9a93ffd2541cd7bce520e7215d15495184b781ca8c138cedd573fbdef1d40e5de82c33" + "5c9c43370102ecb0b66dd27efc719a9a54589b6e6b599b1b0146e121eae0ab5b2070c12f" + "4f4f2b099808294a459f0015004200000000000000000000000000000000000000000000" + "000000000000000000000000000000000000000000000000000000000000000000000000" + "0000000000000000"; + ReplayChWithMalformedInner(ch, kTlsAlertIllegalParameter, + SSL_ERROR_RX_MALFORMED_ECH_EXTENSION, + SSL_ERROR_ILLEGAL_PARAMETER_ALERT); +} + +TEST_F(TlsConnectStreamTls13Ech, EchConfigList) { + ScopedSECKEYPublicKey pub; + ScopedSECKEYPrivateKey priv; + EnsureTlsSetup(); + + DataBuffer config1; + TlsConnectTestBase::GenerateEchConfig(HpkeDhKemX25519Sha256, kSuiteAes, + kPublicName, 100, config1, pub, priv); + DataBuffer config2; + TlsConnectTestBase::GenerateEchConfig(HpkeDhKemX25519Sha256, kSuiteAes, + kPublicName, 100, config2, pub, priv); + DataBuffer configList = MakeEchConfigList(config1, config2); + SECStatus rv = + SSL_SetServerEchConfigs(server_->ssl_fd(), pub.get(), priv.get(), + configList.data(), configList.len()); + printf("%u", rv); + ASSERT_EQ(rv, SECSuccess); +} + +TEST_F(TlsConnectStreamTls13Ech, EchConfigsTrialDecrypt) { + // Apply two ECHConfigs on the server. They are identical with the exception + // of the public key: the first ECHConfig contains a public key for which we + // lack the private value. Use an SSLInt function to zero all the config_ids + // (client and server), then confirm that trial decryption works. + ScopedSECKEYPublicKey pub; + ScopedSECKEYPrivateKey priv; + EnsureTlsSetup(); + ImportFixedEchKeypair(pub, priv); + ScopedSECKEYPublicKey pub2; + ScopedSECKEYPrivateKey priv2; + DataBuffer config2; + TlsConnectTestBase::GenerateEchConfig(HpkeDhKemX25519Sha256, kSuiteAes, + kPublicName, 100, config2, pub, priv); + DataBuffer config1; + TlsConnectTestBase::GenerateEchConfig(HpkeDhKemX25519Sha256, kSuiteAes, + kPublicName, 100, config1, pub2, priv2); + // Zero the config id for both, only public key differs. + config2.Write(7, (uint32_t)0, 1); + config1.Write(7, (uint32_t)0, 1); + // Server only knows private key for conf2 + DataBuffer configList = MakeEchConfigList(config1, config2); + ASSERT_EQ(SECSuccess, + SSL_SetServerEchConfigs(server_->ssl_fd(), pub.get(), priv.get(), + configList.data(), configList.len())); + ASSERT_EQ(SECSuccess, SSL_SetClientEchConfigs(client_->ssl_fd(), + config2.data(), config2.len())); + client_->ExpectEch(); + server_->ExpectEch(); + Connect(); +} + +TEST_F(TlsConnectStreamTls13Ech, EchAcceptBasic) { + EnsureTlsSetup(); + SetupEch(client_, server_); + auto c_filter_sni = + MakeTlsFilter<TlsExtensionCapture>(client_, ssl_server_name_xtn); + Connect(); + ASSERT_TRUE(c_filter_sni->captured()); + CheckSniExtension(c_filter_sni->extension(), kPublicName); +} + +TEST_F(TlsConnectStreamTls13, EchAcceptWithResume) { + EnsureTlsSetup(); + SetupEch(client_, server_); + ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET); + Connect(); + SendReceive(); // Need to read so that we absorb the session ticket. + CheckKeys(); + + Reset(); + EnsureTlsSetup(); + SetupEch(client_, server_); + ExpectResumption(RESUME_TICKET); + auto filter = + MakeTlsFilter<TlsExtensionCapture>(client_, ssl_tls13_pre_shared_key_xtn); + StartConnect(); + Handshake(); + CheckConnected(); + // Make sure that the PSK extension is only in CHInner. + ASSERT_TRUE(filter->captured()); +} + +TEST_F(TlsConnectStreamTls13, EchAcceptWithExternalPsk) { + static const std::string kPskId = "testing123"; + EnsureTlsSetup(); + SetupEch(client_, server_); + + ScopedPK11SlotInfo slot(PK11_GetInternalSlot()); + ASSERT_TRUE(!!slot); + ScopedPK11SymKey key( + PK11_KeyGen(slot.get(), CKM_HKDF_KEY_GEN, nullptr, 16, nullptr)); + ASSERT_TRUE(!!key); + AddPsk(key, kPskId, ssl_hash_sha256); + + // Not permitted in outer. + auto filter = + MakeTlsFilter<TlsExtensionCapture>(client_, ssl_tls13_pre_shared_key_xtn); + StartConnect(); + Handshake(); + CheckConnected(); + SendReceive(); + CheckKeys(ssl_kea_ecdh, ssl_grp_ec_curve25519, ssl_auth_psk, ssl_sig_none); + // The PSK extension is present in CHOuter. + ASSERT_TRUE(filter->captured()); + + // But the PSK in CHOuter is completely different. + // (Failure/collision chance means kPskId needs to be longish.) + uint32_t v = 0; + ASSERT_TRUE(filter->extension().Read(0, 2, &v)); + ASSERT_EQ(v, kPskId.size() + 2 + 4) << "check size of identities"; + ASSERT_TRUE(filter->extension().Read(2, 2, &v)); + ASSERT_EQ(v, kPskId.size()) << "check size of identity"; + bool different = false; + for (size_t i = 0; i < kPskId.size(); ++i) { + ASSERT_TRUE(filter->extension().Read(i + 4, 1, &v)); + different |= v != static_cast<uint8_t>(kPskId[i]); + } + ASSERT_TRUE(different); +} + +// If an earlier version is negotiated, False Start must be disabled. +TEST_F(TlsConnectStreamTls13, EchDowngradeNoFalseStart) { + EnsureTlsSetup(); + SetupEch(client_, server_, HpkeDhKemX25519Sha256, false, true, false); + MakeTlsFilter<TlsExtensionDropper>(client_, + ssl_tls13_encrypted_client_hello_xtn); + client_->EnableFalseStart(); + client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2, + SSL_LIBRARY_VERSION_TLS_1_3); + server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2, + SSL_LIBRARY_VERSION_TLS_1_2); + + StartConnect(); + client_->Handshake(); + server_->Handshake(); + client_->Handshake(); + EXPECT_FALSE(client_->can_falsestart_hook_called()); + + // Make sure the write is blocked. + client_->ExpectReadWriteError(); + client_->SendData(10); +} + +SSLHelloRetryRequestAction RetryEchHello(PRBool firstHello, + const PRUint8* clientToken, + unsigned int clientTokenLen, + PRUint8* appToken, + unsigned int* appTokenLen, + unsigned int appTokenMax, void* arg) { + auto* called = reinterpret_cast<size_t*>(arg); + ++*called; + + EXPECT_EQ(0U, clientTokenLen); + return firstHello ? ssl_hello_retry_request : ssl_hello_retry_accept; +} + +// Generate HRR on CH1 Inner +TEST_F(TlsConnectStreamTls13, EchAcceptWithHrr) { + ScopedSECKEYPublicKey pub; + ScopedSECKEYPrivateKey priv; + DataBuffer echconfig; + ConfigureSelfEncrypt(); + EnsureTlsSetup(); + TlsConnectTestBase::GenerateEchConfig(HpkeDhKemX25519Sha256, kDefaultSuites, + kPublicName, 100, echconfig, pub, priv); + ASSERT_EQ(SECSuccess, + SSL_SetServerEchConfigs(server_->ssl_fd(), pub.get(), priv.get(), + echconfig.data(), echconfig.len())); + ASSERT_EQ(SECSuccess, + SSL_SetClientEchConfigs(client_->ssl_fd(), echconfig.data(), + echconfig.len())); + client_->ExpectEch(); + server_->ExpectEch(); + + size_t cb_called = 0; + EXPECT_EQ(SECSuccess, SSL_HelloRetryRequestCallback( + server_->ssl_fd(), RetryEchHello, &cb_called)); + + auto server_hrr_ech_xtn = MakeTlsFilter<TlsExtensionCapture>( + server_, ssl_tls13_encrypted_client_hello_xtn); + // Start the handshake. + client_->StartConnect(); + server_->StartConnect(); + client_->Handshake(); + server_->Handshake(); + MakeNewServer(); + ASSERT_EQ(SECSuccess, + SSL_SetServerEchConfigs(server_->ssl_fd(), pub.get(), priv.get(), + echconfig.data(), echconfig.len())); + client_->ExpectEch(); + server_->ExpectEch(); + Handshake(); + ASSERT_TRUE(server_hrr_ech_xtn->captured()); + EXPECT_EQ(1U, cb_called); + CheckConnected(); + SendReceive(); +} + +TEST_F(TlsConnectStreamTls13Ech, EchGreaseSize) { + EnsureTlsSetup(); + EXPECT_EQ(SECSuccess, SSL_EnableTls13GreaseEch(client_->ssl_fd(), PR_TRUE)); + + auto greased_ext = MakeTlsFilter<TlsExtensionCapture>( + client_, ssl_tls13_encrypted_client_hello_xtn); + Connect(); + ASSERT_TRUE(greased_ext->captured()); + + Reset(); + EnsureTlsSetup(); + + ScopedSECKEYPublicKey pub; + ScopedSECKEYPrivateKey priv; + ImportFixedEchKeypair(pub, priv); + SetMutualEchConfigs(pub, priv); + + auto real_ext = MakeTlsFilter<TlsExtensionCapture>( + client_, ssl_tls13_encrypted_client_hello_xtn); + client_->ExpectEch(); + server_->ExpectEch(); + Connect(); + + ASSERT_TRUE(real_ext->captured()); + ASSERT_EQ(real_ext->extension().len(), greased_ext->extension().len()); +} + +TEST_F(TlsConnectStreamTls13Ech, EchGreaseClientDisable) { + ScopedSECKEYPublicKey pub; + ScopedSECKEYPrivateKey priv; + DataBuffer echconfig; + EnsureTlsSetup(); + TlsConnectTestBase::GenerateEchConfig(HpkeDhKemX25519Sha256, kDefaultSuites, + kPublicName, 100, echconfig, pub, priv); + ASSERT_EQ(SECSuccess, + SSL_SetServerEchConfigs(server_->ssl_fd(), pub.get(), priv.get(), + echconfig.data(), echconfig.len())); + EXPECT_EQ(SECSuccess, SSL_EnableTls13GreaseEch(client_->ssl_fd(), PR_FALSE)); + + auto c_filter_esni = MakeTlsFilter<TlsExtensionCapture>( + client_, ssl_tls13_encrypted_client_hello_xtn); + + Connect(); + ASSERT_TRUE(!c_filter_esni->captured()); +} + +TEST_F(TlsConnectStreamTls13Ech, EchHrrGreaseServerDisable) { + ScopedSECKEYPublicKey pub; + ScopedSECKEYPrivateKey priv; + DataBuffer echconfig; + ConfigureSelfEncrypt(); + EnsureTlsSetup(); + EXPECT_EQ(SECSuccess, SSL_EnableTls13GreaseEch(client_->ssl_fd(), PR_TRUE)); + EXPECT_EQ(SECSuccess, SSL_EnableTls13GreaseEch(server_->ssl_fd(), PR_FALSE)); + size_t cb_called = 0; + EXPECT_EQ(SECSuccess, SSL_HelloRetryRequestCallback( + server_->ssl_fd(), RetryEchHello, &cb_called)); + + auto server_hrr_ech_xtn = MakeTlsFilter<TlsExtensionCapture>( + server_, ssl_tls13_encrypted_client_hello_xtn); + // Start the handshake. + client_->StartConnect(); + server_->StartConnect(); + client_->Handshake(); + server_->Handshake(); + MakeNewServer(); + EXPECT_EQ(SECSuccess, SSL_EnableTls13GreaseEch(server_->ssl_fd(), PR_FALSE)); + Handshake(); + ASSERT_TRUE(!server_hrr_ech_xtn->captured()); + EXPECT_EQ(1U, cb_called); + CheckConnected(); + SendReceive(); +} + +TEST_F(TlsConnectStreamTls13Ech, EchGreaseSizePsk) { + // Original connection without ECH + ConfigureSessionCache(RESUME_BOTH, RESUME_BOTH); + Connect(); + SendReceive(); + + // Resumption with only GREASE + Reset(); + ConfigureSessionCache(RESUME_BOTH, RESUME_BOTH); + ExpectResumption(RESUME_TICKET); + EXPECT_EQ(SECSuccess, SSL_EnableTls13GreaseEch(client_->ssl_fd(), PR_TRUE)); + + auto greased_ext = MakeTlsFilter<TlsExtensionCapture>( + client_, ssl_tls13_encrypted_client_hello_xtn); + Connect(); + SendReceive(); + ASSERT_TRUE(greased_ext->captured()); + + // Finally, resume with ECH enabled + // ECH state does not determine whether resumption succeeds + // or is attempted, so this should work fine. + Reset(); + ConfigureSessionCache(RESUME_BOTH, RESUME_BOTH); + ExpectResumption(RESUME_TICKET, 2); + + ScopedSECKEYPublicKey pub; + ScopedSECKEYPrivateKey priv; + ImportFixedEchKeypair(pub, priv); + SetMutualEchConfigs(pub, priv); + + auto real_ext = MakeTlsFilter<TlsExtensionCapture>( + client_, ssl_tls13_encrypted_client_hello_xtn); + client_->ExpectEch(); + server_->ExpectEch(); + Connect(); + ASSERT_TRUE(real_ext->captured()); + + ASSERT_EQ(real_ext->extension().len(), greased_ext->extension().len()); +} + +// Send GREASE ECH in CH1. CH2 must send exactly the same GREASE ECH contents. +TEST_F(TlsConnectStreamTls13, GreaseEchHrrMatches) { + ConfigureSelfEncrypt(); + EnsureTlsSetup(); + size_t cb_called = 0; + EXPECT_EQ(SECSuccess, SSL_HelloRetryRequestCallback( + server_->ssl_fd(), RetryEchHello, &cb_called)); + + EXPECT_EQ(SECSuccess, SSL_EnableTls13GreaseEch(client_->ssl_fd(), + PR_TRUE)); // GREASE + auto capture = MakeTlsFilter<TlsExtensionCapture>( + client_, ssl_tls13_encrypted_client_hello_xtn); + + // Start the handshake. + client_->StartConnect(); + server_->StartConnect(); + client_->Handshake(); // Send CH1 + EXPECT_TRUE(capture->captured()); + DataBuffer ch1_grease = capture->extension(); + + server_->Handshake(); + MakeNewServer(); + capture = MakeTlsFilter<TlsExtensionCapture>( + client_, ssl_tls13_encrypted_client_hello_xtn); + + EXPECT_FALSE(capture->captured()); + client_->Handshake(); // Send CH2 + EXPECT_TRUE(capture->captured()); + EXPECT_EQ(ch1_grease, capture->extension()); + + EXPECT_EQ(1U, cb_called); + server_->StartConnect(); + Handshake(); + CheckConnected(); +} + +TEST_F(TlsConnectStreamTls13Ech, EchRejectMisizedEchXtn) { + ScopedSECKEYPublicKey pub; + ScopedSECKEYPrivateKey priv; + DataBuffer echconfig; + ConfigureSelfEncrypt(); + EnsureTlsSetup(); + EXPECT_EQ(SECSuccess, SSL_EnableTls13GreaseEch(client_->ssl_fd(), PR_TRUE)); + EXPECT_EQ(SECSuccess, SSL_EnableTls13GreaseEch(server_->ssl_fd(), PR_TRUE)); + size_t cb_called = 0; + EXPECT_EQ(SECSuccess, SSL_HelloRetryRequestCallback( + server_->ssl_fd(), RetryEchHello, &cb_called)); + auto server_hrr_ext_xtn_fake = MakeTlsFilter<TlsExtensionResizer>( + server_, ssl_tls13_encrypted_client_hello_xtn, 34); + // Start the handshake. + client_->StartConnect(); + server_->StartConnect(); + client_->Handshake(); + server_->Handshake(); + // Process the hello retry. + server_->ExpectReceiveAlert(kTlsAlertDecodeError, kTlsAlertFatal); + client_->ExpectSendAlert(kTlsAlertDecodeError); + Handshake(); + client_->CheckErrorCode(SSL_ERROR_RX_MALFORMED_ECH_EXTENSION); + server_->CheckErrorCode(SSL_ERROR_DECODE_ERROR_ALERT); + EXPECT_EQ(1U, cb_called); +} + +TEST_F(TlsConnectStreamTls13Ech, EchRejectDroppedEchXtn) { + ScopedSECKEYPublicKey pub; + ScopedSECKEYPrivateKey priv; + DataBuffer echconfig; + ConfigureSelfEncrypt(); + EnsureTlsSetup(); + TlsConnectTestBase::GenerateEchConfig(HpkeDhKemX25519Sha256, kDefaultSuites, + kPublicName, 100, echconfig, pub, priv); + ASSERT_EQ(SECSuccess, + SSL_SetServerEchConfigs(server_->ssl_fd(), pub.get(), priv.get(), + echconfig.data(), echconfig.len())); + ASSERT_EQ(SECSuccess, + SSL_SetClientEchConfigs(client_->ssl_fd(), echconfig.data(), + echconfig.len())); + size_t cb_called = 0; + EXPECT_EQ(SECSuccess, SSL_HelloRetryRequestCallback( + server_->ssl_fd(), RetryEchHello, &cb_called)); + auto server_hrr_ext_xtn_fake = MakeTlsFilter<TlsExtensionDropper>( + server_, ssl_tls13_encrypted_client_hello_xtn); + // Start the handshake. + client_->StartConnect(); + server_->StartConnect(); + client_->Handshake(); + server_->Handshake(); + MakeNewServer(); + ASSERT_EQ(SECSuccess, + SSL_SetServerEchConfigs(server_->ssl_fd(), pub.get(), priv.get(), + echconfig.data(), echconfig.len())); + // Process the hello retry. + server_->ExpectSendAlert(kTlsAlertBadRecordMac); + client_->ExpectSendAlert(kTlsAlertBadRecordMac); + Handshake(); + client_->CheckErrorCode(SSL_ERROR_BAD_MAC_READ); + server_->CheckErrorCode(SSL_ERROR_BAD_MAC_READ); + EXPECT_EQ(1U, cb_called); +} + +// Generate an HRR on CHInner. Mangle the Hrr Xtn causing client to reject ECH +// which then causes a MAC mismatch. +TEST_F(TlsConnectStreamTls13Ech, EchRejectMangledHrrXtn) { + ScopedSECKEYPublicKey pub; + ScopedSECKEYPrivateKey priv; + DataBuffer echconfig; + ConfigureSelfEncrypt(); + EnsureTlsSetup(); + TlsConnectTestBase::GenerateEchConfig(HpkeDhKemX25519Sha256, kDefaultSuites, + kPublicName, 100, echconfig, pub, priv); + ASSERT_EQ(SECSuccess, + SSL_SetServerEchConfigs(server_->ssl_fd(), pub.get(), priv.get(), + echconfig.data(), echconfig.len())); + ASSERT_EQ(SECSuccess, + SSL_SetClientEchConfigs(client_->ssl_fd(), echconfig.data(), + echconfig.len())); + + size_t cb_called = 0; + EXPECT_EQ(SECSuccess, SSL_HelloRetryRequestCallback( + server_->ssl_fd(), RetryEchHello, &cb_called)); + auto server_hrr_ech_xtn = MakeTlsFilter<TlsExtensionDamager>( + server_, ssl_tls13_encrypted_client_hello_xtn, 4); + // Start the handshake. + client_->StartConnect(); + server_->StartConnect(); + client_->Handshake(); + server_->Handshake(); + MakeNewServer(); + ASSERT_EQ(SECSuccess, + SSL_SetServerEchConfigs(server_->ssl_fd(), pub.get(), priv.get(), + echconfig.data(), echconfig.len())); + client_->ExpectEch(false); + server_->ExpectEch(false); + server_->ExpectSendAlert(kTlsAlertBadRecordMac); + client_->ExpectSendAlert(kTlsAlertBadRecordMac); + Handshake(); + client_->CheckErrorCode(SSL_ERROR_BAD_MAC_READ); + server_->CheckErrorCode(SSL_ERROR_BAD_MAC_READ); + EXPECT_EQ(1U, cb_called); +} + +// First capture an ECH CH Xtn. +// Start new connection, inject ECH CH Xtn. +// Server will respond with ECH HRR Xtn. +// Check Client correctly panics. +TEST_F(TlsConnectStreamTls13Ech, EchClientRejectSpuriousHrrXtn) { + ScopedSECKEYPublicKey pub; + ScopedSECKEYPrivateKey priv; + DataBuffer echconfig; + ConfigureSelfEncrypt(); + EnsureTlsSetup(); + TlsConnectTestBase::GenerateEchConfig(HpkeDhKemX25519Sha256, kDefaultSuites, + kPublicName, 100, echconfig, pub, priv); + ASSERT_EQ(SECSuccess, + SSL_SetServerEchConfigs(server_->ssl_fd(), pub.get(), priv.get(), + echconfig.data(), echconfig.len())); + EXPECT_EQ(SECSuccess, SSL_EnableTls13GreaseEch(client_->ssl_fd(), PR_TRUE)); + client_->ExpectEch(false); + server_->ExpectEch(false); + auto client_ech_xtn_capture = MakeTlsFilter<TlsExtensionCapture>( + client_, ssl_tls13_encrypted_client_hello_xtn); + Connect(); + ASSERT_TRUE(client_ech_xtn_capture->captured()); + + // Now configure client without ECH. Server with ECH. + Reset(); + EnsureTlsSetup(); + EXPECT_EQ(SECSuccess, SSL_EnableTls13GreaseEch(client_->ssl_fd(), PR_FALSE)); + ASSERT_EQ(SECSuccess, + SSL_SetServerEchConfigs(server_->ssl_fd(), pub.get(), priv.get(), + echconfig.data(), echconfig.len())); + client_->ExpectEch(false); + server_->ExpectEch(false); + size_t cb_called = 0; + EXPECT_EQ(SECSuccess, SSL_HelloRetryRequestCallback( + server_->ssl_fd(), RetryEchHello, &cb_called)); + + // Inject CH ECH Xtn into CH. + DataBuffer buff = DataBuffer(client_ech_xtn_capture->extension()); + auto client_ech_xtn = MakeTlsFilter<TlsExtensionAppender>( + client_, kTlsHandshakeClientHello, ssl_tls13_encrypted_client_hello_xtn, + buff); + + // Connect and check we see the HRR extension and alert. + auto server_hrr_ech_xtn = MakeTlsFilter<TlsExtensionCapture>( + server_, ssl_tls13_encrypted_client_hello_xtn); + server_hrr_ech_xtn->SetHandshakeTypes({kTlsHandshakeHelloRetryRequest}); + + ConnectExpectAlert(client_, kTlsAlertUnsupportedExtension); + + client_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_EXTENSION); + server_->CheckErrorCode(SSL_ERROR_UNSUPPORTED_EXTENSION_ALERT); + ASSERT_TRUE(server_hrr_ech_xtn->captured()); +} + +// Fail to decrypt CH2. Unlike CH1, this generates an alert. +TEST_F(TlsConnectStreamTls13, EchFailDecryptCH2) { + EnsureTlsSetup(); + SetupEch(client_, server_); + size_t cb_called = 0; + EXPECT_EQ(SECSuccess, SSL_HelloRetryRequestCallback( + server_->ssl_fd(), RetryEchHello, &cb_called)); + + client_->StartConnect(); + server_->StartConnect(); + client_->Handshake(); + server_->Handshake(); + EXPECT_EQ(1U, cb_called); + // Stop the callback from being called in future handshakes. + EXPECT_EQ(SECSuccess, + SSL_HelloRetryRequestCallback(server_->ssl_fd(), nullptr, nullptr)); + + MakeTlsFilter<TlsExtensionDamager>(client_, + ssl_tls13_encrypted_client_hello_xtn, 80); + ExpectAlert(server_, kTlsAlertDecryptError); + Handshake(); + client_->CheckErrorCode(SSL_ERROR_DECRYPT_ERROR_ALERT); + server_->CheckErrorCode(SSL_ERROR_RX_MALFORMED_ECH_EXTENSION); +} + +// Change the ECH advertisement between CH1 and CH2. Use GREASE for simplicity. +TEST_F(TlsConnectStreamTls13, EchHrrChangeCh2OfferingYN) { + ConfigureSelfEncrypt(); + EnsureTlsSetup(); + size_t cb_called = 0; + EXPECT_EQ(SECSuccess, SSL_HelloRetryRequestCallback( + server_->ssl_fd(), RetryEchHello, &cb_called)); + + // Start the handshake, send GREASE ECH. + EXPECT_EQ(SECSuccess, SSL_EnableTls13GreaseEch(client_->ssl_fd(), + PR_TRUE)); // GREASE + client_->StartConnect(); + server_->StartConnect(); + client_->Handshake(); + server_->Handshake(); + MakeNewServer(); + EXPECT_EQ(SECSuccess, SSL_EnableTls13GreaseEch(client_->ssl_fd(), + PR_FALSE)); // Don't GREASE + ExpectAlert(server_, kTlsAlertMissingExtension); + Handshake(); + client_->CheckErrorCode(SSL_ERROR_MISSING_EXTENSION_ALERT); + server_->CheckErrorCode(SSL_ERROR_BAD_2ND_CLIENT_HELLO); + EXPECT_EQ(1U, cb_called); +} + +TEST_F(TlsConnectStreamTls13, EchHrrChangeCh2OfferingNY) { + ConfigureSelfEncrypt(); + EnsureTlsSetup(); + SetupEch(client_, server_); + size_t cb_called = 0; + EXPECT_EQ(SECSuccess, SSL_HelloRetryRequestCallback( + server_->ssl_fd(), RetryEchHello, &cb_called)); + + MakeTlsFilter<TlsExtensionDropper>(client_, + ssl_tls13_encrypted_client_hello_xtn); + // Start the handshake. + client_->StartConnect(); + server_->StartConnect(); + client_->Handshake(); + server_->Handshake(); + MakeNewServer(); + client_->ClearFilter(); // Let the second ECH offering through. + ExpectAlert(server_, kTlsAlertIllegalParameter); + Handshake(); + client_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT); + server_->CheckErrorCode(SSL_ERROR_BAD_2ND_CLIENT_HELLO); + EXPECT_EQ(1U, cb_called); +} + +// Change the ECHCipherSuite between CH1 and CH2. Expect alert. +TEST_F(TlsConnectStreamTls13, EchHrrChangeCipherSuite) { + ConfigureSelfEncrypt(); + EnsureTlsSetup(); + SetupEch(client_, server_); + + size_t cb_called = 0; + EXPECT_EQ(SECSuccess, SSL_HelloRetryRequestCallback( + server_->ssl_fd(), RetryEchHello, &cb_called)); + // Start the handshake and trigger HRR. + client_->StartConnect(); + server_->StartConnect(); + client_->Handshake(); + server_->Handshake(); + MakeNewServer(); + + // Damage the first byte of the ciphersuite (offset 1) + MakeTlsFilter<TlsExtensionDamager>(client_, + ssl_tls13_encrypted_client_hello_xtn, 1); + + ExpectAlert(server_, kTlsAlertIllegalParameter); + Handshake(); + client_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT); + server_->CheckErrorCode(SSL_ERROR_BAD_2ND_CLIENT_HELLO); + EXPECT_EQ(1U, cb_called); +} + +// Configure an external PSK. Generate an HRR off CH1Inner (which contains +// the PSK extension). Use the same PSK in CH2 and connect. +TEST_F(TlsConnectStreamTls13, EchAcceptWithHrrAndPsk) { + ScopedSECKEYPublicKey pub; + ScopedSECKEYPrivateKey priv; + DataBuffer echconfig; + ConfigureSelfEncrypt(); + EnsureTlsSetup(); + + TlsConnectTestBase::GenerateEchConfig(HpkeDhKemX25519Sha256, kDefaultSuites, + kPublicName, 100, echconfig, pub, priv); + ASSERT_EQ(SECSuccess, + SSL_SetServerEchConfigs(server_->ssl_fd(), pub.get(), priv.get(), + echconfig.data(), echconfig.len())); + ASSERT_EQ(SECSuccess, + SSL_SetClientEchConfigs(client_->ssl_fd(), echconfig.data(), + echconfig.len())); + client_->ExpectEch(); + server_->ExpectEch(); + + size_t cb_called = 0; + EXPECT_EQ(SECSuccess, SSL_HelloRetryRequestCallback( + server_->ssl_fd(), RetryEchHello, &cb_called)); + + static const uint8_t key_buf[16] = {0}; + SECItem key_item = {siBuffer, const_cast<uint8_t*>(&key_buf[0]), + sizeof(key_buf)}; + const char* label = "foo"; + ScopedPK11SlotInfo slot(PK11_GetInternalSlot()); + ASSERT_TRUE(!!slot); + ScopedPK11SymKey key(PK11_ImportSymKey(slot.get(), CKM_HKDF_KEY_GEN, + PK11_OriginUnwrap, CKA_DERIVE, + &key_item, nullptr)); + ASSERT_TRUE(!!key); + AddPsk(key, std::string(label), ssl_hash_sha256); + + // Start the handshake. + client_->StartConnect(); + server_->StartConnect(); + client_->Handshake(); + server_->Handshake(); + MakeNewServer(); + ASSERT_EQ(SECSuccess, + SSL_SetServerEchConfigs(server_->ssl_fd(), pub.get(), priv.get(), + echconfig.data(), echconfig.len())); + client_->ExpectEch(); + server_->ExpectEch(); + EXPECT_EQ(SECSuccess, + SSL_AddExternalPsk0Rtt(server_->ssl_fd(), key.get(), + reinterpret_cast<const uint8_t*>(label), + strlen(label), ssl_hash_sha256, 0, 1000)); + server_->ExpectPsk(); + Handshake(); + EXPECT_EQ(1U, cb_called); + CheckConnected(); + SendReceive(); +} + +// Generate an HRR on CHOuter. Reject ECH on the second CH. +TEST_F(TlsConnectStreamTls13Ech, EchRejectWithHrr) { + ScopedSECKEYPublicKey pub; + ScopedSECKEYPrivateKey priv; + DataBuffer echconfig; + ConfigureSelfEncrypt(); + EnsureTlsSetup(); + SetupForEchRetry(); + + size_t cb_called = 0; + EXPECT_EQ(SECSuccess, SSL_HelloRetryRequestCallback( + server_->ssl_fd(), RetryEchHello, &cb_called)); + EXPECT_EQ(SECSuccess, SSL_EnableTls13GreaseEch(server_->ssl_fd(), PR_TRUE)); + auto server_hrr_ech_xtn = MakeTlsFilter<TlsExtensionCapture>( + server_, ssl_tls13_encrypted_client_hello_xtn); + // Start the handshake. + client_->StartConnect(); + server_->StartConnect(); + client_->Handshake(); + server_->Handshake(); + MakeNewServer(); + EXPECT_EQ(SECSuccess, SSL_EnableTls13GreaseEch(server_->ssl_fd(), PR_TRUE)); + client_->ExpectEch(false); + server_->ExpectEch(false); + ExpectAlert(client_, kTlsAlertEchRequired); + Handshake(); + ASSERT_TRUE(server_hrr_ech_xtn->captured()); + client_->CheckErrorCode(SSL_ERROR_ECH_RETRY_WITHOUT_ECH); + server_->ExpectReceiveAlert(kTlsAlertEchRequired, kTlsAlertFatal); + server_->Handshake(); + EXPECT_EQ(1U, cb_called); +} + +// Server can't change its mind on ECH after HRR. We change the confirmation +// value and the server panics accordingly. +TEST_F(TlsConnectStreamTls13Ech, EchHrrServerYN) { + ScopedSECKEYPublicKey pub; + ScopedSECKEYPrivateKey priv; + DataBuffer echconfig; + ConfigureSelfEncrypt(); + EnsureTlsSetup(); + TlsConnectTestBase::GenerateEchConfig(HpkeDhKemX25519Sha256, kDefaultSuites, + kPublicName, 100, echconfig, pub, priv); + ASSERT_EQ(SECSuccess, + SSL_SetServerEchConfigs(server_->ssl_fd(), pub.get(), priv.get(), + echconfig.data(), echconfig.len())); + ASSERT_EQ(SECSuccess, + SSL_SetClientEchConfigs(client_->ssl_fd(), echconfig.data(), + echconfig.len())); + client_->ExpectEch(); + server_->ExpectEch(); + + size_t cb_called = 0; + EXPECT_EQ(SECSuccess, SSL_HelloRetryRequestCallback( + server_->ssl_fd(), RetryEchHello, &cb_called)); + + auto server_hrr_ech_xtn = MakeTlsFilter<TlsExtensionCapture>( + server_, ssl_tls13_encrypted_client_hello_xtn); + // Start the handshake. + client_->StartConnect(); + server_->StartConnect(); + client_->Handshake(); + server_->Handshake(); + MakeNewServer(); + ASSERT_EQ(SECSuccess, + SSL_SetServerEchConfigs(server_->ssl_fd(), pub.get(), priv.get(), + echconfig.data(), echconfig.len())); + client_->ExpectEch(); + server_->ExpectEch(); + client_->ExpectSendAlert(kTlsAlertIllegalParameter); + server_->ExpectSendAlert(kTlsAlertUnexpectedMessage); + auto server_random_damager = MakeTlsFilter<ServerHelloRandomChanger>(server_); + Handshake(); + client_->CheckErrorCode(SSL_ERROR_RX_MALFORMED_SERVER_HELLO); + ASSERT_TRUE(server_hrr_ech_xtn->captured()); + EXPECT_EQ(1U, cb_called); +} + +// Client sends GREASE'd ECH Xtn, server reponds with HRR in GREASE mode +// Check HRR responses are present and differ. +TEST_F(TlsConnectStreamTls13Ech, EchHrrServerGreaseChanges) { + ScopedSECKEYPublicKey pub; + ScopedSECKEYPrivateKey priv; + DataBuffer echconfig; + ConfigureSelfEncrypt(); + EnsureTlsSetup(); + EXPECT_EQ(SECSuccess, SSL_EnableTls13GreaseEch(client_->ssl_fd(), PR_TRUE)); + EXPECT_EQ(SECSuccess, SSL_EnableTls13GreaseEch(server_->ssl_fd(), PR_TRUE)); + size_t cb_called = 0; + EXPECT_EQ(SECSuccess, SSL_HelloRetryRequestCallback( + server_->ssl_fd(), RetryEchHello, &cb_called)); + + auto server_hrr_ech_xtn_1 = MakeTlsFilter<TlsExtensionCapture>( + server_, ssl_tls13_encrypted_client_hello_xtn); + // Start the handshake. + client_->StartConnect(); + server_->StartConnect(); + client_->Handshake(); + server_->Handshake(); + ASSERT_TRUE(server_hrr_ech_xtn_1->captured()); + EXPECT_EQ(1U, cb_called); + + /* Run the connection again */ + Reset(); + EnsureTlsSetup(); + EXPECT_EQ(SECSuccess, SSL_EnableTls13GreaseEch(server_->ssl_fd(), PR_TRUE)); + EXPECT_EQ(SECSuccess, SSL_EnableTls13GreaseEch(client_->ssl_fd(), PR_TRUE)); + cb_called = 0; + EXPECT_EQ(SECSuccess, SSL_HelloRetryRequestCallback( + server_->ssl_fd(), RetryEchHello, &cb_called)); + + auto server_hrr_ech_xtn_2 = MakeTlsFilter<TlsExtensionCapture>( + server_, ssl_tls13_encrypted_client_hello_xtn); + // Start the handshake. + client_->StartConnect(); + server_->StartConnect(); + client_->Handshake(); + server_->Handshake(); + ASSERT_TRUE(server_hrr_ech_xtn_2->captured()); + EXPECT_EQ(1U, cb_called); + + ASSERT_TRUE(server_hrr_ech_xtn_1->extension().len() == + server_hrr_ech_xtn_2->extension().len()); + ASSERT_TRUE(memcmp(server_hrr_ech_xtn_1->extension().data(), + server_hrr_ech_xtn_2->extension().data(), + server_hrr_ech_xtn_1->extension().len())); +} + +// Reject ECH on CH1 and CH2. PSKs are no longer allowed +// in CHOuter, but we can still make sure the handshake succeeds. +// This prompts an ech_required alert when the handshake completes. +TEST_F(TlsConnectStreamTls13, EchRejectWithHrrAndPsk) { + ScopedSECKEYPublicKey pub; + ScopedSECKEYPrivateKey priv; + DataBuffer echconfig; + ConfigureSelfEncrypt(); + EnsureTlsSetup(); + TlsConnectTestBase::GenerateEchConfig(HpkeDhKemX25519Sha256, kDefaultSuites, + kPublicName, 100, echconfig, pub, priv); + ASSERT_EQ(SECSuccess, + SSL_SetClientEchConfigs(client_->ssl_fd(), echconfig.data(), + echconfig.len())); + + size_t cb_called = 0; + EXPECT_EQ(SECSuccess, SSL_HelloRetryRequestCallback( + server_->ssl_fd(), RetryEchHello, &cb_called)); + + // Add a PSK to both endpoints. + static const uint8_t key_buf[16] = {0}; + SECItem key_item = {siBuffer, const_cast<uint8_t*>(&key_buf[0]), + sizeof(key_buf)}; + const char* label = "foo"; + ScopedPK11SlotInfo slot(PK11_GetInternalSlot()); + ASSERT_TRUE(!!slot); + ScopedPK11SymKey key(PK11_ImportSymKey(slot.get(), CKM_HKDF_KEY_GEN, + PK11_OriginUnwrap, CKA_DERIVE, + &key_item, nullptr)); + ASSERT_TRUE(!!key); + AddPsk(key, std::string(label), ssl_hash_sha256); + client_->ExpectPsk(ssl_psk_none); + + // Start the handshake. + client_->StartConnect(); + server_->StartConnect(); + client_->Handshake(); + server_->Handshake(); + MakeNewServer(); + client_->ExpectEch(false); + server_->ExpectEch(false); + EXPECT_EQ(SECSuccess, + SSL_AddExternalPsk0Rtt(server_->ssl_fd(), key.get(), + reinterpret_cast<const uint8_t*>(label), + strlen(label), ssl_hash_sha256, 0, 1000)); + // Don't call ExpectPsk + ExpectAlert(client_, kTlsAlertEchRequired); + Handshake(); + client_->CheckErrorCode(SSL_ERROR_ECH_RETRY_WITHOUT_ECH); + server_->ExpectReceiveAlert(kTlsAlertEchRequired, kTlsAlertFatal); + server_->Handshake(); + EXPECT_EQ(1U, cb_called); +} + +// ECH (both connections), resumption rejected. +TEST_F(TlsConnectStreamTls13, EchRejectResume) { + EnsureTlsSetup(); + ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET); + SetupEch(client_, server_); + Connect(); + SendReceive(); + + Reset(); + ClearServerCache(); // Invalidate the ticket + ConfigureSessionCache(RESUME_BOTH, RESUME_NONE); + ExpectResumption(RESUME_NONE); + SetupEch(client_, server_); + Connect(); + SendReceive(); +} + +// ECH (both connections) + 0-RTT +TEST_F(TlsConnectStreamTls13, EchZeroRttBoth) { + EnsureTlsSetup(); + SetupEch(client_, server_); + SetupForZeroRtt(); + client_->Set0RttEnabled(true); + server_->Set0RttEnabled(true); + SetupEch(client_, server_); + ExpectResumption(RESUME_TICKET); + ZeroRttSendReceive(true, true); + Handshake(); + ExpectEarlyDataAccepted(true); + CheckConnected(); + SendReceive(); +} + +// ECH (first connection only) + 0-RTT +TEST_F(TlsConnectStreamTls13, EchZeroRttFirst) { + EnsureTlsSetup(); + SetupEch(client_, server_); + SetupForZeroRtt(); + client_->Set0RttEnabled(true); + server_->Set0RttEnabled(true); + ExpectResumption(RESUME_TICKET); + ZeroRttSendReceive(true, true); + Handshake(); + ExpectEarlyDataAccepted(true); + CheckConnected(); + SendReceive(); +} + +// ECH (second connection only) + 0-RTT +TEST_F(TlsConnectStreamTls13, EchZeroRttSecond) { + EnsureTlsSetup(); + SetupForZeroRtt(); // Get a ticket + client_->Set0RttEnabled(true); + server_->Set0RttEnabled(true); + SetupEch(client_, server_); + ExpectResumption(RESUME_TICKET); + ZeroRttSendReceive(true, true); + Handshake(); + ExpectEarlyDataAccepted(true); + CheckConnected(); + SendReceive(); +} + +// ECH (first connection only, reject on second) + 0-RTT +TEST_F(TlsConnectStreamTls13, EchZeroRttRejectSecond) { + EnsureTlsSetup(); + SetupEch(client_, server_); + SetupForZeroRtt(); + client_->Set0RttEnabled(true); + server_->Set0RttEnabled(true); + + // Setup ECH only on the client. + SetupEch(client_, server_, HpkeDhKemX25519Sha256, false, true, false); + + ExpectResumption(RESUME_NONE); + ExpectAlert(client_, kTlsAlertEchRequired); + ZeroRttSendReceive(true, false); + server_->Handshake(); + client_->Handshake(); + client_->CheckErrorCode(SSL_ERROR_ECH_RETRY_WITHOUT_ECH); + + ExpectEarlyDataAccepted(false); + server_->ExpectReceiveAlert(kTlsAlertEchRequired, kTlsAlertFatal); + server_->Handshake(); + // Reset expectations for the TlsAgent dtor. + server_->ExpectReceiveAlert(kTlsAlertCloseNotify, kTlsAlertWarning); +} + +// Test a critical extension in ECHConfig +TEST_F(TlsConnectStreamTls13, EchRejectUnknownCriticalExtension) { + EnsureTlsSetup(); + ScopedSECKEYPublicKey pub; + ScopedSECKEYPrivateKey priv; + DataBuffer echconfig; + DataBuffer crit_rec; + DataBuffer len_buf; + uint64_t tmp; + + static const uint8_t crit_extensions[] = {0x00, 0x04, 0xff, 0xff, 0x00, 0x00}; + static const uint8_t extensions[] = {0x00, 0x04, 0x7f, 0xff, 0x00, 0x00}; + DataBuffer crit_exts(crit_extensions, sizeof(crit_extensions)); + DataBuffer non_crit_exts(extensions, sizeof(extensions)); + + TlsConnectTestBase::GenerateEchConfig(HpkeDhKemX25519Sha256, kSuiteChaCha, + kPublicName, 100, echconfig, pub, priv); + echconfig.Truncate(echconfig.len() - 2); // Eat the empty extensions. + crit_rec.Assign(echconfig); + ASSERT_TRUE(crit_rec.Read(0, 2, &tmp)); + len_buf.Write(0, tmp + crit_exts.len() - 2, 2); // two bytes of length + crit_rec.Splice(len_buf, 0, 2); + len_buf.Truncate(0); + + ASSERT_TRUE(crit_rec.Read(4, 2, &tmp)); + len_buf.Write(0, tmp + crit_exts.len() - 2, 2); // two bytes of length + crit_rec.Append(crit_exts); + crit_rec.Splice(len_buf, 4, 2); + len_buf.Truncate(0); + + ASSERT_TRUE(echconfig.Read(0, 2, &tmp)); + len_buf.Write(0, tmp + non_crit_exts.len() - 2, 2); + echconfig.Append(non_crit_exts); + echconfig.Splice(len_buf, 0, 2); + ASSERT_TRUE(echconfig.Read(4, 2, &tmp)); + len_buf.Write(0, tmp + non_crit_exts.len() - 2, 2); + echconfig.Splice(len_buf, 4, 2); + + /* Expect that retry configs containing unsupported mandatory extensions can + * not be set and lead to SEC_ERROR_INVALID_ARGS. */ + EXPECT_EQ(SECFailure, + SSL_SetClientEchConfigs(client_->ssl_fd(), crit_rec.data(), + crit_rec.len())); + EXPECT_EQ(SEC_ERROR_INVALID_ARGS, PORT_GetError()); + EXPECT_EQ(SECSuccess, SSL_EnableTls13GreaseEch(client_->ssl_fd(), + PR_FALSE)); // Don't GREASE + auto filter = MakeTlsFilter<TlsExtensionCapture>( + client_, ssl_tls13_encrypted_client_hello_xtn); + StartConnect(); + client_->Handshake(); + ASSERT_EQ(TlsAgent::STATE_CONNECTING, client_->state()); + ASSERT_FALSE(filter->captured()); + + // Now try a variant with non-critical extensions, it should work. + Reset(); + EnsureTlsSetup(); + EXPECT_EQ(SECSuccess, + SSL_SetClientEchConfigs(client_->ssl_fd(), echconfig.data(), + echconfig.len())); + filter = MakeTlsFilter<TlsExtensionCapture>( + client_, ssl_tls13_encrypted_client_hello_xtn); + StartConnect(); + client_->Handshake(); + ASSERT_EQ(TlsAgent::STATE_CONNECTING, client_->state()); + ASSERT_TRUE(filter->captured()); +} + +// Secure disable without ECH +TEST_F(TlsConnectStreamTls13, EchRejectAuthCertSuccessNoRetries) { + EnsureTlsSetup(); + SetupEch(client_, server_, HpkeDhKemX25519Sha256, false, true, false); + ExpectAlert(client_, kTlsAlertEchRequired); + ConnectExpectFailOneSide(TlsAgent::CLIENT); + client_->CheckErrorCode(SSL_ERROR_ECH_RETRY_WITHOUT_ECH); + server_->ExpectReceiveAlert(kTlsAlertEchRequired, kTlsAlertFatal); + server_->Handshake(); + // Reset expectations for the TlsAgent dtor. + server_->ExpectReceiveAlert(kTlsAlertCloseNotify, kTlsAlertWarning); +} + +// When authenticating to the public name, the client MUST NOT +// send a certificate in response to a certificate request. +TEST_F(TlsConnectStreamTls13, EchRejectSuppressClientCert) { + EnsureTlsSetup(); + SetupEch(client_, server_, HpkeDhKemX25519Sha256, false, true, false); + client_->SetupClientAuth(); + server_->RequestClientAuth(true); + auto cert_capture = + MakeTlsFilter<TlsHandshakeRecorder>(client_, kTlsHandshakeCertificate); + cert_capture->EnableDecryption(); + + StartConnect(); + client_->ExpectSendAlert(kTlsAlertEchRequired); + server_->ExpectSendAlert(kTlsAlertCertificateRequired); + ConnectExpectFail(); + + static const uint8_t empty_cert[4] = {0}; + EXPECT_EQ(DataBuffer(empty_cert, sizeof(empty_cert)), cert_capture->buffer()); +} + +// Secure disable with incompatible ECHConfig +TEST_F(TlsConnectStreamTls13, EchRejectAuthCertSuccessIncompatibleRetries) { + EnsureTlsSetup(); + ScopedSECKEYPublicKey server_pub; + ScopedSECKEYPrivateKey server_priv; + ScopedSECKEYPublicKey client_pub; + ScopedSECKEYPrivateKey client_priv; + DataBuffer server_rec; + DataBuffer client_rec; + + TlsConnectTestBase::GenerateEchConfig(HpkeDhKemX25519Sha256, kSuiteChaCha, + kPublicName, 100, server_rec, + server_pub, server_priv); + ASSERT_EQ(SECSuccess, + SSL_SetServerEchConfigs(server_->ssl_fd(), server_pub.get(), + server_priv.get(), server_rec.data(), + server_rec.len())); + + TlsConnectTestBase::GenerateEchConfig(HpkeDhKemX25519Sha256, kSuiteAes, + kPublicName, 100, client_rec, + client_pub, client_priv); + ASSERT_EQ(SECSuccess, + SSL_SetClientEchConfigs(client_->ssl_fd(), client_rec.data(), + client_rec.len())); + + // Change the first ECHConfig version to one we don't understand. + server_rec.Write(2, 0xfefe, 2); + // Skip the ECHConfigs length, the server sender will re-encode. + ASSERT_EQ(SECSuccess, SSLInt_SetRawEchConfigForRetry(server_->ssl_fd(), + &server_rec.data()[2], + server_rec.len() - 2)); + + ExpectAlert(client_, kTlsAlertEchRequired); + ConnectExpectFailOneSide(TlsAgent::CLIENT); + client_->CheckErrorCode(SSL_ERROR_ECH_RETRY_WITHOUT_ECH); + server_->ExpectReceiveAlert(kTlsAlertEchRequired, kTlsAlertFatal); + server_->Handshake(); + // Reset expectations for the TlsAgent dtor. + server_->ExpectReceiveAlert(kTlsAlertCloseNotify, kTlsAlertWarning); +} + +// Check that an otherwise-accepted ECH fails expectedly +// with a bad certificate. +TEST_F(TlsConnectStreamTls13, EchRejectAuthCertFail) { + EnsureTlsSetup(); + SetupEch(client_, server_); + client_->SetAuthCertificateCallback(AuthCompleteFail); + ConnectExpectAlert(client_, kTlsAlertBadCertificate); + client_->CheckErrorCode(SSL_ERROR_BAD_CERTIFICATE); + server_->CheckErrorCode(SSL_ERROR_BAD_CERT_ALERT); + EXPECT_EQ(TlsAgent::STATE_ERROR, client_->state()); +} + +TEST_F(TlsConnectStreamTls13Ech, EchShortClientEncryptedCH) { + EnsureTlsSetup(); + SetupForEchRetry(); + auto filter = MakeTlsFilter<TlsExtensionResizer>( + client_, ssl_tls13_encrypted_client_hello_xtn, 1); + ConnectExpectAlert(server_, kTlsAlertDecodeError); + client_->CheckErrorCode(SSL_ERROR_DECODE_ERROR_ALERT); + server_->CheckErrorCode(SSL_ERROR_RX_MALFORMED_ECH_EXTENSION); +} + +TEST_F(TlsConnectStreamTls13Ech, EchLongClientEncryptedCH) { + EnsureTlsSetup(); + SetupForEchRetry(); + auto filter = MakeTlsFilter<TlsExtensionResizer>( + client_, ssl_tls13_encrypted_client_hello_xtn, 1000); + ConnectExpectAlert(server_, kTlsAlertDecodeError); + client_->CheckErrorCode(SSL_ERROR_DECODE_ERROR_ALERT); + server_->CheckErrorCode(SSL_ERROR_RX_MALFORMED_ECH_EXTENSION); +} + +TEST_F(TlsConnectStreamTls13Ech, EchShortServerEncryptedCH) { + EnsureTlsSetup(); + SetupForEchRetry(); + auto filter = MakeTlsFilter<TlsExtensionResizer>( + server_, ssl_tls13_encrypted_client_hello_xtn, 1); + filter->EnableDecryption(); + ConnectExpectAlert(client_, kTlsAlertDecodeError); + client_->CheckErrorCode(SSL_ERROR_RX_MALFORMED_ECH_CONFIG); + server_->CheckErrorCode(SSL_ERROR_DECODE_ERROR_ALERT); +} + +TEST_F(TlsConnectStreamTls13Ech, EchLongServerEncryptedCH) { + EnsureTlsSetup(); + SetupForEchRetry(); + auto filter = MakeTlsFilter<TlsExtensionResizer>( + server_, ssl_tls13_encrypted_client_hello_xtn, 1000); + filter->EnableDecryption(); + ConnectExpectAlert(client_, kTlsAlertDecodeError); + client_->CheckErrorCode(SSL_ERROR_RX_MALFORMED_ECH_CONFIG); + server_->CheckErrorCode(SSL_ERROR_DECODE_ERROR_ALERT); +} + +// Check that if authCertificate fails, retry_configs +// are not available to the application. +TEST_F(TlsConnectStreamTls13Ech, EchInsecureFallbackNoRetries) { + EnsureTlsSetup(); + StackSECItem retry_configs; + SetupForEchRetry(); + + // Use the filter to make sure retry_configs are sent. + auto filter = MakeTlsFilter<TlsExtensionCapture>( + server_, ssl_tls13_encrypted_client_hello_xtn); + filter->EnableDecryption(); + + client_->SetAuthCertificateCallback(AuthCompleteFail); + ConnectExpectAlert(client_, kTlsAlertBadCertificate); + client_->CheckErrorCode(SSL_ERROR_BAD_CERTIFICATE); + server_->CheckErrorCode(SSL_ERROR_BAD_CERT_ALERT); + EXPECT_EQ(TlsAgent::STATE_ERROR, client_->state()); + EXPECT_EQ(SECFailure, + SSL_GetEchRetryConfigs(client_->ssl_fd(), &retry_configs)); + EXPECT_EQ(SSL_ERROR_HANDSHAKE_NOT_COMPLETED, PORT_GetError()); + ASSERT_EQ(0U, retry_configs.len); + EXPECT_TRUE(filter->captured()); +} + +// Test that mismatched ECHConfigContents triggers a retry. +TEST_F(TlsConnectStreamTls13Ech, EchMismatchHpkeCiphersRetry) { + EnsureTlsSetup(); + ScopedSECKEYPublicKey server_pub; + ScopedSECKEYPrivateKey server_priv; + ScopedSECKEYPublicKey client_pub; + ScopedSECKEYPrivateKey client_priv; + DataBuffer server_rec; + DataBuffer client_rec; + + TlsConnectTestBase::GenerateEchConfig(HpkeDhKemX25519Sha256, kSuiteChaCha, + kPublicName, 100, server_rec, + server_pub, server_priv); + TlsConnectTestBase::GenerateEchConfig(HpkeDhKemX25519Sha256, kSuiteAes, + kPublicName, 100, client_rec, + client_pub, client_priv); + + ASSERT_EQ(SECSuccess, + SSL_SetServerEchConfigs(server_->ssl_fd(), server_pub.get(), + server_priv.get(), server_rec.data(), + server_rec.len())); + ASSERT_EQ(SECSuccess, + SSL_SetClientEchConfigs(client_->ssl_fd(), client_rec.data(), + client_rec.len())); + + ExpectAlert(client_, kTlsAlertEchRequired); + ConnectExpectFailOneSide(TlsAgent::CLIENT); + client_->CheckErrorCode(SSL_ERROR_ECH_RETRY_WITH_ECH); + server_->ExpectReceiveAlert(kTlsAlertEchRequired, kTlsAlertFatal); + server_->Handshake(); + DoEchRetry(server_pub, server_priv, server_rec); +} + +// Test that mismatched ECH server keypair triggers a retry. +TEST_F(TlsConnectStreamTls13Ech, EchMismatchKeysRetry) { + EnsureTlsSetup(); + ScopedSECKEYPublicKey server_pub; + ScopedSECKEYPrivateKey server_priv; + ScopedSECKEYPublicKey client_pub; + ScopedSECKEYPrivateKey client_priv; + DataBuffer server_rec; + DataBuffer client_rec; + TlsConnectTestBase::GenerateEchConfig(HpkeDhKemX25519Sha256, kDefaultSuites, + kPublicName, 100, server_rec, + server_pub, server_priv); + TlsConnectTestBase::GenerateEchConfig(HpkeDhKemX25519Sha256, kDefaultSuites, + kPublicName, 100, client_rec, + client_pub, client_priv); + ASSERT_EQ(SECSuccess, + SSL_SetServerEchConfigs(server_->ssl_fd(), server_pub.get(), + server_priv.get(), server_rec.data(), + server_rec.len())); + ASSERT_EQ(SECSuccess, + SSL_SetClientEchConfigs(client_->ssl_fd(), client_rec.data(), + client_rec.len())); + + client_->ExpectSendAlert(kTlsAlertEchRequired); + ConnectExpectFailOneSide(TlsAgent::CLIENT); + client_->CheckErrorCode(SSL_ERROR_ECH_RETRY_WITH_ECH); + server_->ExpectReceiveAlert(kTlsAlertEchRequired, kTlsAlertFatal); + server_->Handshake(); + DoEchRetry(server_pub, server_priv, server_rec); +} + +// Check that the client validates any server response to GREASE ECH +TEST_F(TlsConnectStreamTls13, EchValidateGreaseResponse) { + EnsureTlsSetup(); + ScopedSECKEYPublicKey server_pub; + ScopedSECKEYPrivateKey server_priv; + DataBuffer server_rec; + TlsConnectTestBase::GenerateEchConfig(HpkeDhKemX25519Sha256, kDefaultSuites, + kPublicName, 100, server_rec, + server_pub, server_priv); + ASSERT_EQ(SECSuccess, + SSL_SetServerEchConfigs(server_->ssl_fd(), server_pub.get(), + server_priv.get(), server_rec.data(), + server_rec.len())); + + // Damage the length and expect an alert. + auto filter = MakeTlsFilter<TlsExtensionDamager>( + server_, ssl_tls13_encrypted_client_hello_xtn, 0); + filter->EnableDecryption(); + EXPECT_EQ(SECSuccess, SSL_EnableTls13GreaseEch(client_->ssl_fd(), + PR_TRUE)); // GREASE + ConnectExpectAlert(client_, kTlsAlertDecodeError); + client_->CheckErrorCode(SSL_ERROR_RX_MALFORMED_ECH_CONFIG); + server_->CheckErrorCode(SSL_ERROR_DECODE_ERROR_ALERT); + + // If the retry_config contains an unknown version, it should be ignored. + Reset(); + EnsureTlsSetup(); + ASSERT_EQ(SECSuccess, + SSL_SetServerEchConfigs(server_->ssl_fd(), server_pub.get(), + server_priv.get(), server_rec.data(), + server_rec.len())); + server_rec.Write(2, 0xfefe, 2); + // Skip the ECHConfigs length, the server sender will re-encode. + ASSERT_EQ(SECSuccess, SSLInt_SetRawEchConfigForRetry(server_->ssl_fd(), + &server_rec.data()[2], + server_rec.len() - 2)); + EXPECT_EQ(SECSuccess, SSL_EnableTls13GreaseEch(client_->ssl_fd(), + PR_TRUE)); // GREASE + Connect(); + + // Lastly, if we DO support the retry_config, GREASE ECH should ignore it. + Reset(); + EnsureTlsSetup(); + server_rec.Write(2, ssl_tls13_encrypted_client_hello_xtn, 2); + ASSERT_EQ(SECSuccess, + SSL_SetServerEchConfigs(server_->ssl_fd(), server_pub.get(), + server_priv.get(), server_rec.data(), + server_rec.len())); + EXPECT_EQ(SECSuccess, SSL_EnableTls13GreaseEch(client_->ssl_fd(), + PR_TRUE)); // GREASE + Connect(); +} + +// Test a tampered CHInner (decrypt failure). +// Expect negotiation on outer, which fails due to the tampered transcript. +TEST_F(TlsConnectStreamTls13, EchBadCiphertext) { + EnsureTlsSetup(); + SetupEch(client_, server_); + /* Target the payload: + struct { + ECHCipherSuite suite; // 4B + opaque config_id<0..255>; // 32B + opaque enc<1..2^16-1>; // 32B for X25519 + opaque payload<1..2^16-1>; + } ClientEncryptedCH; + */ + MakeTlsFilter<TlsExtensionDamager>(client_, + ssl_tls13_encrypted_client_hello_xtn, 80); + client_->ExpectSendAlert(kTlsAlertBadRecordMac); + server_->ExpectSendAlert(kTlsAlertBadRecordMac); + ConnectExpectFail(); +} + +// Test a tampered CHOuter (decrypt failure on AAD). +// Expect negotiation on outer, which fails due to the tampered transcript. +TEST_F(TlsConnectStreamTls13, EchOuterBinding) { + EnsureTlsSetup(); + SetupEch(client_, server_); + client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2, + SSL_LIBRARY_VERSION_TLS_1_3); + + static const uint8_t supported_vers_13[] = {0x02, 0x03, 0x04}; + DataBuffer buf(supported_vers_13, sizeof(supported_vers_13)); + MakeTlsFilter<TlsExtensionReplacer>(client_, ssl_tls13_supported_versions_xtn, + buf); + client_->ExpectSendAlert(kTlsAlertBadRecordMac); + server_->ExpectSendAlert(kTlsAlertBadRecordMac); + ConnectExpectFail(); +} + +// Altering the CH after the Ech Xtn should also cause a failure. +TEST_F(TlsConnectStreamTls13, EchOuterBindingAfterXtn) { + EnsureTlsSetup(); + SetupEch(client_, server_); + client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2, + SSL_LIBRARY_VERSION_TLS_1_3); + + static const uint8_t supported_vers_13[] = {0x02, 0x03, 0x04}; + DataBuffer buf(supported_vers_13, sizeof(supported_vers_13)); + MakeTlsFilter<TlsExtensionAppender>(client_, kTlsHandshakeClientHello, 5044, + buf); + client_->ExpectSendAlert(kTlsAlertBadRecordMac); + server_->ExpectSendAlert(kTlsAlertBadRecordMac); + ConnectExpectFail(); +} + +// Test a bad (unknown) ECHCipherSuite. +// Expect negotiation on outer, which fails due to the tampered transcript. +TEST_F(TlsConnectStreamTls13, EchBadCiphersuite) { + EnsureTlsSetup(); + SetupEch(client_, server_); + /* Make KDF unknown */ + MakeTlsFilter<TlsExtensionDamager>(client_, + ssl_tls13_encrypted_client_hello_xtn, 1); + client_->ExpectSendAlert(kTlsAlertBadRecordMac); + server_->ExpectSendAlert(kTlsAlertBadRecordMac); + ConnectExpectFail(); + + Reset(); + EnsureTlsSetup(); + SetupEch(client_, server_); + /* Make AEAD unknown */ + MakeTlsFilter<TlsExtensionDamager>(client_, + ssl_tls13_encrypted_client_hello_xtn, 4); + client_->ExpectSendAlert(kTlsAlertBadRecordMac); + server_->ExpectSendAlert(kTlsAlertBadRecordMac); + ConnectExpectFail(); +} + +/* ECH (configured) client connects to a 1.2 server, this MUST lead to an + * 'ech_required' alert being sent by the client when handling the handshake + * finished messages [draft-ietf-tls-esni-14, Section 6.1.6]. */ +TEST_F(TlsConnectStreamTls13, EchToTls12Server) { + EnsureTlsSetup(); + SetupEch(client_, server_); + client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2, + SSL_LIBRARY_VERSION_TLS_1_3); + server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2, + SSL_LIBRARY_VERSION_TLS_1_2); + + client_->ExpectEch(false); + server_->ExpectEch(false); + + client_->ExpectSendAlert(kTlsAlertEchRequired, kTlsAlertFatal); + server_->ExpectReceiveAlert(kTlsAlertEchRequired, kTlsAlertFatal); + ConnectExpectFailOneSide(TlsAgent::CLIENT); + client_->CheckErrorCode(SSL_ERROR_ECH_RETRY_WITHOUT_ECH); + + /* Reset expectations for the TlsAgent deconstructor. */ + server_->ExpectReceiveAlert(kTlsAlertCloseNotify, kTlsAlertWarning); +} + +TEST_F(TlsConnectStreamTls13, NoEchFromTls12Client) { + EnsureTlsSetup(); + SetupEch(client_, server_); + client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2, + SSL_LIBRARY_VERSION_TLS_1_2); + server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2, + SSL_LIBRARY_VERSION_TLS_1_3); + auto filter = MakeTlsFilter<TlsExtensionCapture>( + client_, ssl_tls13_encrypted_client_hello_xtn); + client_->ExpectEch(false); + server_->ExpectEch(false); + SetExpectedVersion(SSL_LIBRARY_VERSION_TLS_1_2); + Connect(); + ASSERT_FALSE(filter->captured()); +} + +TEST_F(TlsConnectStreamTls13, EchOuterWith12Max) { + EnsureTlsSetup(); + SetupEch(client_, server_); + client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2, + SSL_LIBRARY_VERSION_TLS_1_3); + server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2, + SSL_LIBRARY_VERSION_TLS_1_3); + + static const uint8_t supported_vers_12[] = {0x02, 0x03, 0x03}; + DataBuffer buf(supported_vers_12, sizeof(supported_vers_12)); + + // The server will set the downgrade sentinel. The client needs + // to ignore it for this test. + client_->SetOption(SSL_ENABLE_HELLO_DOWNGRADE_CHECK, PR_FALSE); + + StartConnect(); + MakeTlsFilter<TlsExtensionReplacer>(client_, ssl_tls13_supported_versions_xtn, + buf); + + // Server should ignore the extension if 1.2 is negotiated. + // Here the CHInner is not modified, so if Accepted we'd connect. + auto filter = MakeTlsFilter<TlsExtensionCapture>( + server_, ssl_tls13_encrypted_client_hello_xtn); + client_->ExpectEch(false); + server_->ExpectEch(false); + ConnectExpectAlert(server_, kTlsAlertDecryptError); + client_->CheckErrorCode(SSL_ERROR_DECRYPT_ERROR_ALERT); + server_->CheckErrorCode(SSL_ERROR_BAD_HANDSHAKE_HASH_VALUE); + ASSERT_FALSE(filter->captured()); +} + +TEST_F(TlsConnectStreamTls13, EchOuterExtensionsInCHOuter) { + EnsureTlsSetup(); + uint8_t outer[2] = {0}; + DataBuffer outer_buf(outer, sizeof(outer)); + MakeTlsFilter<TlsExtensionAppender>(client_, kTlsHandshakeClientHello, + ssl_tls13_outer_extensions_xtn, + outer_buf); + + ConnectExpectAlert(server_, kTlsAlertIllegalParameter); + client_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT); + server_->CheckErrorCode(SSL_ERROR_RX_MALFORMED_CLIENT_HELLO); +} + +static SECStatus NoopExtensionHandler(PRFileDesc* fd, SSLHandshakeType message, + const PRUint8* data, unsigned int len, + SSLAlertDescription* alert, void* arg) { + return SECSuccess; +} + +static PRBool EmptyExtensionWriter(PRFileDesc* fd, SSLHandshakeType message, + PRUint8* data, unsigned int* len, + unsigned int maxLen, void* arg) { + return true; +} + +static PRBool LargeExtensionWriter(PRFileDesc* fd, SSLHandshakeType message, + PRUint8* data, unsigned int* len, + unsigned int maxLen, void* arg) { + unsigned int length = 1024; + PR_ASSERT(length <= maxLen); + memset(data, 0, length); + *len = length; + return true; +} + +static PRBool OuterOnlyExtensionWriter(PRFileDesc* fd, SSLHandshakeType message, + PRUint8* data, unsigned int* len, + unsigned int maxLen, void* arg) { + if (message == ssl_hs_ech_outer_client_hello) { + return LargeExtensionWriter(fd, message, data, len, maxLen, arg); + } + return false; +} + +static PRBool InnerOnlyExtensionWriter(PRFileDesc* fd, SSLHandshakeType message, + PRUint8* data, unsigned int* len, + unsigned int maxLen, void* arg) { + if (message == ssl_hs_client_hello) { + return LargeExtensionWriter(fd, message, data, len, maxLen, arg); + } + return false; +} + +static PRBool InnerOuterDiffExtensionWriter(PRFileDesc* fd, + SSLHandshakeType message, + PRUint8* data, unsigned int* len, + unsigned int maxLen, void* arg) { + unsigned int length = 1024; + PR_ASSERT(length <= maxLen); + memset(data, (message == ssl_hs_client_hello) ? 1 : 0, length); + *len = length; + return true; +} + +TEST_F(TlsConnectStreamTls13Ech, EchCustomExtensionWriter) { + EnsureTlsSetup(); + SetupEch(client_, server_); + + EXPECT_EQ(SECSuccess, SSL_InstallExtensionHooks( + client_->ssl_fd(), 62028, EmptyExtensionWriter, + nullptr, NoopExtensionHandler, nullptr)); + + client_->ExpectEch(); + server_->ExpectEch(); + Connect(); +} + +TEST_F(TlsConnectStreamTls13Ech, EchCustomExtensionWriterOuterOnly) { + EnsureTlsSetup(); + SetupEch(client_, server_); + + EXPECT_EQ(SECSuccess, SSL_InstallExtensionHooks( + client_->ssl_fd(), 62028, OuterOnlyExtensionWriter, + nullptr, NoopExtensionHandler, nullptr)); + EXPECT_EQ(SECSuccess, + SSL_CallExtensionWriterOnEchInner(client_->ssl_fd(), true)); + + client_->ExpectEch(); + server_->ExpectEch(); + Connect(); +} + +TEST_F(TlsConnectStreamTls13Ech, EchCustomExtensionWriterInnerOnly) { + EnsureTlsSetup(); + SetupEch(client_, server_); + + EXPECT_EQ(SECSuccess, SSL_InstallExtensionHooks( + client_->ssl_fd(), 62028, InnerOnlyExtensionWriter, + nullptr, NoopExtensionHandler, nullptr)); + EXPECT_EQ(SECSuccess, + SSL_CallExtensionWriterOnEchInner(client_->ssl_fd(), true)); + + client_->ExpectEch(); + server_->ExpectEch(); + Connect(); +} + +// Write different values to inner and outer CH. +TEST_F(TlsConnectStreamTls13Ech, EchCustomExtensionWriterDifferent) { + EnsureTlsSetup(); + SetupEch(client_, server_); + + EXPECT_EQ(SECSuccess, + SSL_InstallExtensionHooks(client_->ssl_fd(), 62028, + InnerOuterDiffExtensionWriter, nullptr, + NoopExtensionHandler, nullptr)); + EXPECT_EQ(SECSuccess, + SSL_CallExtensionWriterOnEchInner(client_->ssl_fd(), true)); + auto filter = MakeTlsFilter<TlsExtensionCapture>( + client_, ssl_tls13_encrypted_client_hello_xtn); + client_->ExpectEch(); + server_->ExpectEch(); + Connect(); + ASSERT_TRUE(filter->extension().len() > 1024); +} + +// Test that basic compression works +TEST_F(TlsConnectStreamTls13Ech, EchCustomExtensionWriterCompressionBasic) { + EnsureTlsSetup(); + SetupEch(client_, server_); + + // This will be compressed. + EXPECT_EQ(SECSuccess, SSL_InstallExtensionHooks( + client_->ssl_fd(), 62028, LargeExtensionWriter, + nullptr, NoopExtensionHandler, nullptr)); + EXPECT_EQ(SECSuccess, + SSL_CallExtensionWriterOnEchInner(client_->ssl_fd(), true)); + auto filter = MakeTlsFilter<TlsExtensionCapture>( + client_, ssl_tls13_encrypted_client_hello_xtn); + client_->ExpectEch(); + server_->ExpectEch(); + Connect(); + size_t echXtnLen = filter->extension().len(); + ASSERT_TRUE(echXtnLen > 0 && echXtnLen < 1024); +} + +// Test that compression works when things change. +TEST_F(TlsConnectStreamTls13Ech, + EchCustomExtensionWriterCompressSomeDifferent) { + EnsureTlsSetup(); + SetupEch(client_, server_); + + // This will be compressed. + EXPECT_EQ(SECSuccess, SSL_InstallExtensionHooks( + client_->ssl_fd(), 62028, LargeExtensionWriter, + nullptr, NoopExtensionHandler, nullptr)); + // This can't be. + EXPECT_EQ(SECSuccess, + SSL_InstallExtensionHooks(client_->ssl_fd(), 62029, + InnerOuterDiffExtensionWriter, nullptr, + NoopExtensionHandler, nullptr)); + // This will be compressed. + EXPECT_EQ(SECSuccess, SSL_InstallExtensionHooks( + client_->ssl_fd(), 62030, LargeExtensionWriter, + nullptr, NoopExtensionHandler, nullptr)); + EXPECT_EQ(SECSuccess, + SSL_CallExtensionWriterOnEchInner(client_->ssl_fd(), true)); + auto filter = MakeTlsFilter<TlsExtensionCapture>( + client_, ssl_tls13_encrypted_client_hello_xtn); + client_->ExpectEch(); + server_->ExpectEch(); + Connect(); + auto echXtnLen = filter->extension().len(); + /* Exactly one custom xtn plus change */ + ASSERT_TRUE(echXtnLen > 1024 && echXtnLen < 2048); +} + +// An outer-only extension stops compression. +TEST_F(TlsConnectStreamTls13Ech, + EchCustomExtensionWriterCompressSomeOuterOnly) { + EnsureTlsSetup(); + SetupEch(client_, server_); + + // This will be compressed. + EXPECT_EQ(SECSuccess, SSL_InstallExtensionHooks( + client_->ssl_fd(), 62028, LargeExtensionWriter, + nullptr, NoopExtensionHandler, nullptr)); + // This can't be as it appears in the outer only. + EXPECT_EQ(SECSuccess, SSL_InstallExtensionHooks( + client_->ssl_fd(), 62029, OuterOnlyExtensionWriter, + nullptr, NoopExtensionHandler, nullptr)); + // This will be compressed + EXPECT_EQ(SECSuccess, SSL_InstallExtensionHooks( + client_->ssl_fd(), 62030, LargeExtensionWriter, + nullptr, NoopExtensionHandler, nullptr)); + EXPECT_EQ(SECSuccess, + SSL_CallExtensionWriterOnEchInner(client_->ssl_fd(), true)); + auto filter = MakeTlsFilter<TlsExtensionCapture>( + client_, ssl_tls13_encrypted_client_hello_xtn); + client_->ExpectEch(); + server_->ExpectEch(); + Connect(); + size_t echXtnLen = filter->extension().len(); + ASSERT_TRUE(echXtnLen > 0 && echXtnLen < 1024); +} + +// An inner only extension does not stop compression. +TEST_F(TlsConnectStreamTls13Ech, EchCustomExtensionWriterCompressAllInnerOnly) { + EnsureTlsSetup(); + SetupEch(client_, server_); + + // This will be compressed. + EXPECT_EQ(SECSuccess, SSL_InstallExtensionHooks( + client_->ssl_fd(), 62028, LargeExtensionWriter, + nullptr, NoopExtensionHandler, nullptr)); + // This can't be as it appears in the inner only. + EXPECT_EQ(SECSuccess, SSL_InstallExtensionHooks( + client_->ssl_fd(), 62029, InnerOnlyExtensionWriter, + nullptr, NoopExtensionHandler, nullptr)); + // This will be compressed. + EXPECT_EQ(SECSuccess, SSL_InstallExtensionHooks( + client_->ssl_fd(), 62030, LargeExtensionWriter, + nullptr, NoopExtensionHandler, nullptr)); + EXPECT_EQ(SECSuccess, + SSL_CallExtensionWriterOnEchInner(client_->ssl_fd(), true)); + auto filter = MakeTlsFilter<TlsExtensionCapture>( + client_, ssl_tls13_encrypted_client_hello_xtn); + client_->ExpectEch(); + server_->ExpectEch(); + Connect(); + size_t echXtnLen = filter->extension().len(); + ASSERT_TRUE(echXtnLen > 1024 && echXtnLen < 2048); +} + +TEST_F(TlsConnectStreamTls13Ech, EchAcceptCustomXtn) { + EnsureTlsSetup(); + SetupEch(client_, server_); + + EXPECT_EQ(SECSuccess, SSL_InstallExtensionHooks( + client_->ssl_fd(), 62028, LargeExtensionWriter, + nullptr, NoopExtensionHandler, nullptr)); + + EXPECT_EQ(SECSuccess, + SSL_CallExtensionWriterOnEchInner(client_->ssl_fd(), true)); + + EXPECT_EQ(SECSuccess, SSL_InstallExtensionHooks( + server_->ssl_fd(), 62028, LargeExtensionWriter, + nullptr, NoopExtensionHandler, nullptr)); + auto filter = MakeTlsFilter<TlsExtensionCapture>(server_, 62028); + client_->ExpectEch(); + server_->ExpectEch(); + Connect(); +} + +// Test that we reject Outer Xtn in SH if accepting ECH Inner +TEST_F(TlsConnectStreamTls13Ech, EchRejectOuterXtnOnInner) { + EnsureTlsSetup(); + SetupEch(client_, server_); + + EXPECT_EQ(SECSuccess, SSL_InstallExtensionHooks( + client_->ssl_fd(), 62028, OuterOnlyExtensionWriter, + nullptr, NoopExtensionHandler, nullptr)); + + EXPECT_EQ(SECSuccess, + SSL_CallExtensionWriterOnEchInner(client_->ssl_fd(), true)); + + // Put the same extension on the Server Hello + EXPECT_EQ(SECSuccess, SSL_InstallExtensionHooks( + server_->ssl_fd(), 62028, LargeExtensionWriter, + nullptr, NoopExtensionHandler, nullptr)); + auto filter = MakeTlsFilter<TlsExtensionCapture>(server_, 62028); + client_->ExpectEch(false); + server_->ExpectEch(false); + client_->ExpectSendAlert(kTlsAlertUnsupportedExtension); + // The server will be expecting an alert encrypted under a different key. + server_->ExpectSendAlert(kTlsAlertUnexpectedMessage); + ConnectExpectFail(); + ASSERT_TRUE(filter->captured()); + client_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_EXTENSION); +} + +// Test that we reject Inner Xtn in SH if accepting ECH Outer +TEST_F(TlsConnectStreamTls13Ech, EchRejectInnerXtnOnOuter) { + EnsureTlsSetup(); + + // Setup ECH only on the client + SetupEch(client_, server_, HpkeDhKemX25519Sha256, false, true, false); + + EXPECT_EQ(SECSuccess, SSL_InstallExtensionHooks( + client_->ssl_fd(), 62028, InnerOnlyExtensionWriter, + nullptr, NoopExtensionHandler, nullptr)); + + EXPECT_EQ(SECSuccess, + SSL_CallExtensionWriterOnEchInner(client_->ssl_fd(), true)); + + // Put the same extension on the Server Hello + EXPECT_EQ(SECSuccess, SSL_InstallExtensionHooks( + server_->ssl_fd(), 62028, LargeExtensionWriter, + nullptr, NoopExtensionHandler, nullptr)); + auto filter = MakeTlsFilter<TlsExtensionCapture>(server_, 62028); + client_->ExpectEch(false); + server_->ExpectEch(false); + client_->ExpectSendAlert(kTlsAlertUnsupportedExtension); + // The server will be expecting an alert encrypted under a different key. + server_->ExpectSendAlert(kTlsAlertUnexpectedMessage); + ConnectExpectFail(); + ASSERT_TRUE(filter->captured()); + client_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_EXTENSION); +} + +// Test that we reject an Inner Xtn in SH, if accepting Ech Inner and +// we didn't advertise it on SH Outer. +TEST_F(TlsConnectStreamTls13Ech, EchRejectInnerXtnNotOnOuter) { + EnsureTlsSetup(); + + // Setup ECH only on the client + SetupEch(client_, server_); + + EXPECT_EQ(SECSuccess, SSL_InstallExtensionHooks( + client_->ssl_fd(), 62028, InnerOnlyExtensionWriter, + nullptr, NoopExtensionHandler, nullptr)); + + EXPECT_EQ(SECSuccess, + SSL_CallExtensionWriterOnEchInner(client_->ssl_fd(), true)); + + // Put the same extension on the Server Hello + EXPECT_EQ(SECSuccess, SSL_InstallExtensionHooks( + server_->ssl_fd(), 62028, LargeExtensionWriter, + nullptr, NoopExtensionHandler, nullptr)); + auto filter = MakeTlsFilter<TlsExtensionCapture>(server_, 62028); + client_->ExpectEch(false); + server_->ExpectEch(false); + client_->ExpectSendAlert(kTlsAlertUnsupportedExtension); + // The server will be expecting an alert encrypted under a different key. + server_->ExpectSendAlert(kTlsAlertUnexpectedMessage); + ConnectExpectFail(); + ASSERT_TRUE(filter->captured()); + client_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_EXTENSION); +} + +// At draft-09: If a CH containing the ech_is_inner extension is received, the +// server acts as backend server in split-mode by responding with the ECH +// acceptance signal. The signal value itself depends on the handshake secret, +// which we've broken by appending ech_is_inner. For now, just check that the +// server negotiates ech_is_inner (which is what triggers sending the signal). +TEST_F(TlsConnectStreamTls13, EchBackendAcceptance) { + DataBuffer ch_buf; + static uint8_t inner_value[1] = {1}; + DataBuffer inner_buffer(inner_value, sizeof(inner_value)); + + EnsureTlsSetup(); + StartConnect(); + EXPECT_EQ(SECSuccess, SSL_EnableTls13GreaseEch(client_->ssl_fd(), PR_FALSE)); + MakeTlsFilter<TlsExtensionAppender>(client_, kTlsHandshakeClientHello, + ssl_tls13_encrypted_client_hello_xtn, + inner_buffer); + + EXPECT_EQ(SECSuccess, SSL_EnableTls13BackendEch(server_->ssl_fd(), PR_TRUE)); + client_->Handshake(); + server_->Handshake(); + + ExpectAlert(client_, kTlsAlertBadRecordMac); + client_->Handshake(); + EXPECT_EQ(TlsAgent::STATE_ERROR, client_->state()); + EXPECT_EQ(PR_TRUE, + SSLInt_ExtensionNegotiated(server_->ssl_fd(), + ssl_tls13_encrypted_client_hello_xtn)); + server_->ExpectReceiveAlert(kTlsAlertCloseNotify, kTlsAlertWarning); +} + +// A public_name that includes an IP address has to be rejected. +TEST_F(TlsConnectStreamTls13Ech, EchPublicNameIp) { + static const std::vector<std::string> kIps = { + "0.0.0.0", + "1.1.1.1", + "255.255.255.255", + "255.255.65535", + "255.16777215", + "4294967295", + "0377.0377.0377.0377", + "0377.0377.0177777", + "0377.077777777", + "037777777777", + "00377.00377.00377.00377", + "00377.00377.00177777", + "00377.0077777777", + "0037777777777", + "0xff.0xff.0xff.0xff", + "0xff.0xff.0xffff", + "0xff.0xffffff", + "0xffffffff", + "0XFF.0XFF.0XFF.0XFF", + "0XFF.0XFF.0XFFFF", + "0XFF.0XFFFFFF", + "0XFFFFFFFF", + "0x0ff.0x0ff.0x0ff.0x0ff", + "0x0ff.0x0ff.0x0ffff", + "0x0ff.0x0ffffff", + "0x0ffffffff", + "00000000000000000000000000000000000000000", + "00000000000000000000000000000000000000001", + "127.0.0.1", + "127.0.1", + "127.1", + "2130706433", + "017700000001", + }; + ValidatePublicNames(kIps, SECFailure); +} + +// These are nearly IP addresses. +TEST_F(TlsConnectStreamTls13Ech, EchPublicNameNotIp) { + static const std::vector<std::string> kNotIps = { + "0.0.0.0.0", + "1.2.3.4.5", + "999999999999999999999999999999999", + "07777777777777777777777777777777777777777", + "111111111100000000001111111111000000000011111111110000000000123", + "256.255.255.255", + "255.256.255.255", + "255.255.256.255", + "255.255.255.256", + "255.255.65536", + "255.16777216", + "4294967296", + "0400.0377.0377.0377", + "0377.0400.0377.0377", + "0377.0377.0400.0377", + "0377.0377.0377.0400", + "0377.0377.0200000", + "0377.0100000000", + "040000000000", + "0x100.0xff.0xff.0xff", + "0xff.0x100.0xff.0xff", + "0xff.0xff.0x100.0xff", + "0xff.0xff.0xff.0x100", + "0xff.0xff.0x10000", + "0xff.0x1000000", + "0x100000000", + "08", + "09", + "a", + "0xg", + "0XG", + "0x", + "0x.1.2.3", + "test-name", + "test-name.test", + "TEST-NAME", + "under_score", + "_under_score", + "under_score_", + }; + ValidatePublicNames(kNotIps, SECSuccess); +} + +TEST_F(TlsConnectStreamTls13Ech, EchPublicNameNotLdh) { + static const std::vector<std::string> kNotLdh = { + ".", + "name.", + ".name", + "test..name", + "1111111111000000000011111111110000000000111111111100000000001234", + "-name", + "name-", + "test-.name", + "!", + u8"\u2077", + }; + ValidatePublicNames(kNotLdh, SECFailure); +} + +TEST_F(TlsConnectStreamTls13, EchClientHelloExtensionPermutation) { + EnsureTlsSetup(); + PR_ASSERT(SSL_OptionSet(client_->ssl_fd(), + SSL_ENABLE_CH_EXTENSION_PERMUTATION, + PR_TRUE) == SECSuccess); + SetupEch(client_, server_); + + client_->ExpectEch(); + server_->ExpectEch(); + Connect(); +} + +TEST_F(TlsConnectStreamTls13, EchGreaseClientHelloExtensionPermutation) { + EnsureTlsSetup(); + PR_ASSERT(SSL_OptionSet(client_->ssl_fd(), + SSL_ENABLE_CH_EXTENSION_PERMUTATION, + PR_TRUE) == SECSuccess); + PR_ASSERT(SSL_EnableTls13GreaseEch(client_->ssl_fd(), PR_FALSE) == + SECSuccess); + Connect(); +} + +INSTANTIATE_TEST_SUITE_P(EchAgentTest, TlsAgentEchTest, + ::testing::Combine(TlsConnectTestBase::kTlsVariantsAll, + TlsConnectTestBase::kTlsV13)); + +} // namespace nss_test diff --git a/security/nss/gtests/ssl_gtest/tls_filter.cc b/security/nss/gtests/ssl_gtest/tls_filter.cc new file mode 100644 index 0000000000..ab52a07e84 --- /dev/null +++ b/security/nss/gtests/ssl_gtest/tls_filter.cc @@ -0,0 +1,1293 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set ts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this file, + * You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#include "tls_filter.h" +#include "sslproto.h" + +extern "C" { +// This is not something that should make you happy. +#include "libssl_internals.h" +} + +#include <cassert> +#include <iostream> +#include "gtest_utils.h" +#include "tls_agent.h" +#include "tls_filter.h" +#include "tls_parser.h" +#include "tls_protect.h" + +namespace nss_test { + +void TlsVersioned::WriteStream(std::ostream& stream) const { + stream << (is_dtls() ? "DTLS " : "TLS "); + switch (version()) { + case 0: + stream << "(no version)"; + break; + case SSL_LIBRARY_VERSION_TLS_1_0: + stream << "1.0"; + break; + case SSL_LIBRARY_VERSION_TLS_1_1: + stream << (is_dtls() ? "1.0" : "1.1"); + break; + case SSL_LIBRARY_VERSION_TLS_1_2: + stream << "1.2"; + break; + case SSL_LIBRARY_VERSION_TLS_1_3: + stream << "1.3"; + break; + default: + stream << "Invalid version: " << version(); + break; + } +} + +TlsRecordFilter::TlsRecordFilter(const std::shared_ptr<TlsAgent>& a) + : agent_(a) { + cipher_specs_.emplace_back(a->variant() == ssl_variant_datagram, 0); +} + +void TlsRecordFilter::EnableDecryption() { + EXPECT_EQ(SECSuccess, + SSL_SecretCallback(agent()->ssl_fd(), SecretCallback, this)); + decrypting_ = true; +} + +void TlsRecordFilter::SecretCallback(PRFileDesc* fd, PRUint16 epoch, + SSLSecretDirection dir, PK11SymKey* secret, + void* arg) { + TlsRecordFilter* self = static_cast<TlsRecordFilter*>(arg); + if (g_ssl_gtest_verbose) { + std::cerr << self->agent()->role_str() << ": " << dir + << " secret changed for epoch " << epoch << std::endl; + } + + if (dir == ssl_secret_read) { + return; + } + + for (auto& spec : self->cipher_specs_) { + ASSERT_NE(spec.epoch(), epoch) << "duplicate spec for epoch " << epoch; + } + + SSLPreliminaryChannelInfo preinfo; + EXPECT_EQ(SECSuccess, + SSL_GetPreliminaryChannelInfo(self->agent()->ssl_fd(), &preinfo, + sizeof(preinfo))); + EXPECT_EQ(sizeof(preinfo), preinfo.length); + + // Check the version. + if (preinfo.valuesSet & ssl_preinfo_version) { + EXPECT_EQ(SSL_LIBRARY_VERSION_TLS_1_3, preinfo.protocolVersion); + } else { + EXPECT_EQ(1U, epoch); + } + + uint16_t suite; + if (epoch == 1) { + // 0-RTT + EXPECT_TRUE(preinfo.valuesSet & ssl_preinfo_0rtt_cipher_suite); + suite = preinfo.zeroRttCipherSuite; + } else { + EXPECT_TRUE(preinfo.valuesSet & ssl_preinfo_cipher_suite); + suite = preinfo.cipherSuite; + } + + SSLCipherSuiteInfo cipherinfo; + EXPECT_EQ(SECSuccess, + SSL_GetCipherSuiteInfo(suite, &cipherinfo, sizeof(cipherinfo))); + EXPECT_EQ(sizeof(cipherinfo), cipherinfo.length); + + self->cipher_specs_.emplace_back(self->is_dtls_agent(), epoch); + EXPECT_TRUE(self->cipher_specs_.back().SetKeys(&cipherinfo, secret)); +} + +bool TlsRecordFilter::is_dtls_agent() const { + return agent()->variant() == ssl_variant_datagram; +} + +bool TlsRecordFilter::is_dtls13() const { + if (!is_dtls_agent()) { + return false; + } + if (agent()->state() == TlsAgent::STATE_CONNECTED) { + return agent()->version() >= SSL_LIBRARY_VERSION_TLS_1_3; + } + SSLPreliminaryChannelInfo info; + EXPECT_EQ(SECSuccess, SSL_GetPreliminaryChannelInfo(agent()->ssl_fd(), &info, + sizeof(info))); + return (info.protocolVersion >= SSL_LIBRARY_VERSION_TLS_1_3) || + info.canSendEarlyData; +} + +bool TlsRecordFilter::is_dtls13_ciphertext(uint8_t ct) const { + return is_dtls13() && (ct & kCtDtlsCiphertextMask) == kCtDtlsCiphertext; +} + +// Gets the cipher spec that matches the specified epoch. +TlsCipherSpec& TlsRecordFilter::spec(uint16_t write_epoch) { + for (auto& sp : cipher_specs_) { + if (sp.epoch() == write_epoch) { + return sp; + } + } + + // If we aren't decrypting, provide a cipher spec that does nothing other than + // count sequence numbers. + EXPECT_FALSE(decrypting_) << "No spec available for epoch " << write_epoch; + ; + cipher_specs_.emplace_back(is_dtls_agent(), write_epoch); + return cipher_specs_.back(); +} + +PacketFilter::Action TlsRecordFilter::Filter(const DataBuffer& input, + DataBuffer* output) { + // Disable during shutdown. + if (!agent()) { + return KEEP; + } + + bool changed = false; + size_t offset = 0U; + + output->Allocate(input.len()); + TlsParser parser(input); + + // This uses the current write spec for the purposes of parsing the epoch and + // sequence number from the header. This might be wrong because we can + // receive records from older specs, but guessing is good enough: + // - In DTLS, parsing the sequence number corrects any errors. + // - In TLS, we don't use the sequence number unless decrypting, where we use + // trial decryption to get the right epoch. + uint16_t write_epoch = 0; + SECStatus rv = SSL_GetCurrentEpoch(agent()->ssl_fd(), nullptr, &write_epoch); + if (rv != SECSuccess) { + ADD_FAILURE() << "unable to read epoch"; + return KEEP; + } + uint64_t guess_seqno = static_cast<uint64_t>(write_epoch) << 48; + + while (parser.remaining()) { + TlsRecordHeader header; + DataBuffer record; + if (!header.Parse(is_dtls13(), guess_seqno, &parser, &record)) { + ADD_FAILURE() << "not a valid record"; + return KEEP; + } + + if (FilterRecord(header, record, &offset, output) != KEEP) { + changed = true; + } else { + offset = header.Write(output, offset, record); + } + } + output->Truncate(offset); + + // Record how many packets we actually touched. + if (changed) { + ++count_; + return (offset == 0) ? DROP : CHANGE; + } + + return KEEP; +} + +PacketFilter::Action TlsRecordFilter::FilterRecord( + const TlsRecordHeader& header, const DataBuffer& record, size_t* offset, + DataBuffer* output) { + DataBuffer filtered; + uint8_t inner_content_type; + DataBuffer plaintext; + uint16_t protection_epoch = 0; + TlsRecordHeader out_header(header); + + if (!Unprotect(header, record, &protection_epoch, &inner_content_type, + &plaintext, &out_header)) { + std::cerr << agent()->role_str() << ": unprotect failed: " << header << ":" + << record << std::endl; + return KEEP; + } + + auto& protection_spec = spec(protection_epoch); + TlsRecordHeader real_header(out_header.variant(), out_header.version(), + inner_content_type, out_header.sequence_number()); + + PacketFilter::Action action = FilterRecord(real_header, plaintext, &filtered); + // In stream mode, even if something doesn't change we need to re-encrypt if + // previous packets were dropped. + if (action == KEEP) { + if (out_header.is_dtls() || !protection_spec.record_dropped()) { + // Count every outgoing packet. + protection_spec.RecordProtected(); + return KEEP; + } + filtered = plaintext; + } + + if (action == DROP) { + std::cerr << "record drop: " << out_header << ":" << record << std::endl; + protection_spec.RecordDropped(); + return DROP; + } + + EXPECT_GT(0x10000U, filtered.len()); + if (action != KEEP) { + std::cerr << "record old: " << plaintext << std::endl; + std::cerr << "record new: " << filtered << std::endl; + } + + uint64_t seq_num = protection_spec.next_out_seqno(); + if (!decrypting_ && out_header.is_dtls()) { + // Copy over the epoch, which isn't tracked when not decrypting. + seq_num |= out_header.sequence_number() & (0xffffULL << 48); + } + out_header.sequence_number(seq_num); + + DataBuffer ciphertext; + bool rv = Protect(protection_spec, out_header, inner_content_type, filtered, + &ciphertext, &out_header); + if (!rv) { + return KEEP; + } + *offset = out_header.Write(output, *offset, ciphertext); + return CHANGE; +} + +size_t TlsRecordHeader::header_length() const { + // If we have a header, return it's length. + if (header_.len()) { + return header_.len(); + } + + // Otherwise make a dummy header and return the length. + DataBuffer buf; + return WriteHeader(&buf, 0, 0); +} + +bool TlsRecordHeader::MaskSequenceNumber() { + return MaskSequenceNumber(sn_mask()); +} + +bool TlsRecordHeader::MaskSequenceNumber(const DataBuffer& mask_buf) { + if (mask_buf.empty()) { + return false; + } + + DataBuffer mask; + if (is_dtls13_ciphertext()) { + uint64_t seqno = sequence_number(); + uint8_t len = content_type() & kCtDtlsCiphertext16bSeqno ? 2 : 1; + uint16_t seqno_bitmask = (1 << len * 8) - 1; + DataBuffer val; + if (val.Write(0, seqno & seqno_bitmask, len) != len) { + return false; + } + +#ifdef UNSAFE_FUZZER_MODE + // Use a null mask. + mask.Allocate(mask_buf.len()); +#endif + mask.Append(mask_buf); + val.data()[0] ^= mask.data()[0]; + if (len == 2 && mask.len() > 1) { + val.data()[1] ^= mask.data()[1]; + } + uint32_t tmp; + if (!val.Read(0, len, &tmp)) { + return false; + } + + seqno = (seqno & ~seqno_bitmask) | tmp; + seqno_is_masked_ = !seqno_is_masked_; + if (!seqno_is_masked_) { + seqno = ParseSequenceNumber(guess_seqno_, seqno, len * 8, 2); + } + sequence_number_ = seqno; + + // Now update the header bytes + if (header_.len() > 1) { + header_.data()[1] ^= mask.data()[0]; + if ((content_type() & kCtDtlsCiphertext16bSeqno) && header().len() > 2) { + header_.data()[2] ^= mask.data()[1]; + } + } + } + + sn_mask_ = mask; + return true; +} + +uint64_t TlsRecordHeader::RecoverSequenceNumber(uint64_t guess_seqno, + uint32_t partial, + size_t partial_bits) { + EXPECT_GE(32U, partial_bits); + uint64_t mask = (1ULL << partial_bits) - 1; + // First we determine the highest possible value. This is half the + // expressible range above the expected value (|guess_seqno|), less 1. + // + // We subtract the extra 1 from the cap so that when given a choice between + // the equidistant expected+N and expected-N we want to chose the lower. With + // 0-RTT, we sometimes have to recover an epoch of 1 when we expect an epoch + // of 3 and with 2 partial bits, the alternative result of 5 is wrong. + uint64_t cap = guess_seqno + (1ULL << (partial_bits - 1)) - 1; + // Add the partial piece in. e.g., xxxx789a and 1234 becomes xxxx1234. + uint64_t seq_no = (cap & ~mask) | partial; + // If the partial value is higher than the same partial piece from the cap, + // then the real value has to be lower. e.g., xxxx1234 can't become xxxx5678. + if (partial > (cap & mask) && (seq_no >= (1ULL << partial_bits))) { + seq_no -= 1ULL << partial_bits; + } + return seq_no; +} + +// Determine the full epoch and sequence number from an expected and raw value. +// The expected, raw, and output values are packed as they are in DTLS 1.2 and +// earlier: with 16 bits of epoch and 48 bits of sequence number. The raw value +// is packed this way (even before recovery) so that we don't need to track a +// moving value between two calls (one to recover the epoch, and one after +// unmasking to recover the sequence number). +uint64_t TlsRecordHeader::ParseSequenceNumber(uint64_t expected, uint64_t raw, + size_t seq_no_bits, + size_t epoch_bits) { + uint64_t epoch_mask = (1ULL << epoch_bits) - 1; + uint64_t ep = RecoverSequenceNumber(expected >> 48, (raw >> 48) & epoch_mask, + epoch_bits); + if (ep > (expected >> 48)) { + // If the epoch has changed, reset the expected sequence number. + expected = 0; + } else { + // Otherwise, retain just the sequence number part. + expected &= (1ULL << 48) - 1; + } + uint64_t seq_no_mask = (1ULL << seq_no_bits) - 1; + uint64_t seq_no = (raw & seq_no_mask); + if (!seqno_is_masked_) { + seq_no = RecoverSequenceNumber(expected, seq_no, seq_no_bits); + } + + return (ep << 48) | seq_no; +} + +bool TlsRecordHeader::Parse(bool is_dtls13, uint64_t seqno, TlsParser* parser, + DataBuffer* body) { + auto mark = parser->consumed(); + + if (!parser->Read(&content_type_)) { + return false; + } + + if (is_dtls13) { + variant_ = ssl_variant_datagram; + version_ = SSL_LIBRARY_VERSION_TLS_1_3; + +#ifndef UNSAFE_FUZZER_MODE + // Deal with the DTLSCipherText header. + if (is_dtls13_ciphertext()) { + uint8_t seq_no_bytes = + (content_type_ & kCtDtlsCiphertext16bSeqno) ? 2 : 1; + uint32_t tmp; + + if (!parser->Read(&tmp, seq_no_bytes)) { + return false; + } + + // Store the guess if masked. If and when seqno_bytesenceNumber is called, + // the value will be unmasked and recovered. This assumes we only call + // Parse() on headers containing masked values. + seqno_is_masked_ = true; + guess_seqno_ = seqno; + uint64_t ep = content_type_ & 0x03; + sequence_number_ = (ep << 48) | tmp; + + // Recover the full epoch. Note the sequence number portion holds the + // masked value until a call to Mask() reveals it (as indicated by + // |seqno_is_masked_|). + sequence_number_ = + ParseSequenceNumber(seqno, sequence_number_, seq_no_bytes * 8, 2); + + uint32_t len_bytes = + (content_type_ & kCtDtlsCiphertextLengthPresent) ? 2 : 0; + if (len_bytes) { + if (!parser->Read(&tmp, 2)) { + return false; + } + } + + if (!parser->ReadFromMark(&header_, parser->consumed() - mark, mark)) { + return false; + } + + return len_bytes ? parser->Read(body, tmp) + : parser->Read(body, parser->remaining()); + } + + // The full DTLSPlainText header can only be used for a few types. + EXPECT_TRUE(content_type_ == ssl_ct_alert || + content_type_ == ssl_ct_handshake || + content_type_ == ssl_ct_ack); +#endif + } + + uint32_t ver; + if (!parser->Read(&ver, 2)) { + return false; + } + if (!is_dtls13) { + variant_ = IsDtls(ver) ? ssl_variant_datagram : ssl_variant_stream; + } + version_ = NormalizeTlsVersion(ver); + + if (is_dtls()) { + // If this is DTLS, read the sequence number. + uint32_t tmp; + if (!parser->Read(&tmp, 4)) { + return false; + } + sequence_number_ = static_cast<uint64_t>(tmp) << 32; + if (!parser->Read(&tmp, 4)) { + return false; + } + sequence_number_ |= static_cast<uint64_t>(tmp); + } else { + sequence_number_ = seqno; + } + if (!parser->ReadFromMark(&header_, parser->consumed() + 2 - mark, mark)) { + return false; + } + return parser->ReadVariable(body, 2); +} + +size_t TlsRecordHeader::WriteHeader(DataBuffer* buffer, size_t offset, + size_t body_len) const { + if (is_dtls13_ciphertext()) { + uint8_t seq_no_bytes = (content_type_ & kCtDtlsCiphertext16bSeqno) ? 2 : 1; + // application_data records in TLS 1.3 have a different header format. + uint32_t e = (sequence_number_ >> 48) & 0x3; + uint32_t seqno = sequence_number_ & ((1ULL << seq_no_bytes * 8) - 1); + uint8_t new_content_type_ = content_type_ | e; + offset = buffer->Write(offset, new_content_type_, 1); + offset = buffer->Write(offset, seqno, seq_no_bytes); + + if (content_type_ & kCtDtlsCiphertextLengthPresent) { + offset = buffer->Write(offset, body_len, 2); + } + } else { + offset = buffer->Write(offset, content_type_, 1); + uint16_t v = is_dtls() ? TlsVersionToDtlsVersion(version_) : version_; + offset = buffer->Write(offset, v, 2); + if (is_dtls()) { + // write epoch (2 octet), and seqnum (6 octet) + offset = buffer->Write(offset, sequence_number_ >> 32, 4); + offset = buffer->Write(offset, sequence_number_ & 0xffffffff, 4); + } + offset = buffer->Write(offset, body_len, 2); + } + + return offset; +} + +size_t TlsRecordHeader::Write(DataBuffer* buffer, size_t offset, + const DataBuffer& body) const { + offset = WriteHeader(buffer, offset, body.len()); + offset = buffer->Write(offset, body); + return offset; +} + +bool TlsRecordFilter::Unprotect(const TlsRecordHeader& header, + const DataBuffer& ciphertext, + uint16_t* protection_epoch, + uint8_t* inner_content_type, + DataBuffer* plaintext, + TlsRecordHeader* out_header) { + if (!decrypting_ || !header.is_protected()) { + // Maintain the epoch and sequence number for plaintext records. + uint16_t ep = 0; + if (is_dtls_agent()) { + ep = static_cast<uint16_t>(header.sequence_number() >> 48); + } + spec(ep).RecordUnprotected(header.sequence_number()); + *protection_epoch = ep; + *inner_content_type = header.content_type(); + *plaintext = ciphertext; + return true; + } + + uint16_t ep = 0; + if (is_dtls_agent()) { + ep = static_cast<uint16_t>(header.sequence_number() >> 48); + if (!spec(ep).Unprotect(header, ciphertext, plaintext, out_header)) { + return false; + } + } else { + // In TLS, records aren't clearly labelled with their epoch, and we + // can't just use the newest keys because the same flight of messages can + // contain multiple epochs. So... trial decrypt! + for (size_t i = cipher_specs_.size() - 1; i > 0; --i) { + if (cipher_specs_[i].Unprotect(header, ciphertext, plaintext, + out_header)) { + ep = cipher_specs_[i].epoch(); + break; + } + } + if (!ep) { + return false; + } + } + + size_t len = plaintext->len(); + while (len > 0 && !plaintext->data()[len - 1]) { + --len; + } + if (!len) { + // Bogus padding. + return false; + } + + *protection_epoch = ep; + *inner_content_type = plaintext->data()[len - 1]; + plaintext->Truncate(len - 1); + if (g_ssl_gtest_verbose) { + std::cerr << agent()->role_str() << ": unprotect: epoch=" << ep + << " seq=" << std::hex << header.sequence_number() << std::dec + << " " << *plaintext << std::endl; + } + + return true; +} + +bool TlsRecordFilter::Protect(TlsCipherSpec& protection_spec, + const TlsRecordHeader& header, + uint8_t inner_content_type, + const DataBuffer& plaintext, + DataBuffer* ciphertext, + TlsRecordHeader* out_header, size_t padding) { + if (!protection_spec.is_protected()) { + // Not protected, just keep the sequence numbers updated. + protection_spec.RecordProtected(); + *ciphertext = plaintext; + return true; + } + + DataBuffer padded; + padded.Allocate(plaintext.len() + 1 + padding); + size_t offset = padded.Write(0, plaintext.data(), plaintext.len()); + padded.Write(offset, inner_content_type, 1); + + bool ok = protection_spec.Protect(header, padded, ciphertext, out_header); + if (!ok) { + ADD_FAILURE() << "protect fail"; + } else if (g_ssl_gtest_verbose) { + std::cerr << agent()->role_str() + << ": protect: epoch=" << protection_spec.epoch() + << " seq=" << std::hex << header.sequence_number() << std::dec + << " " << *ciphertext << std::endl; + } + return ok; +} + +bool IsHelloRetry(const DataBuffer& body) { + static const uint8_t ssl_hello_retry_random[] = { + 0xCF, 0x21, 0xAD, 0x74, 0xE5, 0x9A, 0x61, 0x11, 0xBE, 0x1D, 0x8C, + 0x02, 0x1E, 0x65, 0xB8, 0x91, 0xC2, 0xA2, 0x11, 0x16, 0x7A, 0xBB, + 0x8C, 0x5E, 0x07, 0x9E, 0x09, 0xE2, 0xC8, 0xA8, 0x33, 0x9C}; + return memcmp(body.data() + 2, ssl_hello_retry_random, + sizeof(ssl_hello_retry_random)) == 0; +} + +bool TlsHandshakeFilter::IsFilteredType(const HandshakeHeader& header, + const DataBuffer& body) { + if (handshake_types_.empty()) { + return true; + } + + uint8_t type = header.handshake_type(); + if (type == kTlsHandshakeServerHello) { + if (IsHelloRetry(body)) { + type = kTlsHandshakeHelloRetryRequest; + } + } + return handshake_types_.count(type) > 0U; +} + +PacketFilter::Action TlsHandshakeFilter::FilterRecord( + const TlsRecordHeader& record_header, const DataBuffer& input, + DataBuffer* output) { + // Check that the first byte is as requested. + if (record_header.content_type() != ssl_ct_handshake) { + return KEEP; + } + + bool changed = false; + size_t offset = 0U; + output->Allocate(input.len()); // Preallocate a little. + + TlsParser parser(input); + while (parser.remaining()) { + HandshakeHeader header; + DataBuffer handshake; + bool complete = false; + if (!header.Parse(&parser, record_header, preceding_fragment_, &handshake, + &complete)) { + return KEEP; + } + + if (!complete) { + EXPECT_TRUE(record_header.is_dtls()); + // Save the fragment and drop it from this record. Fragments are + // coalesced with the last fragment of the handshake message. + changed = true; + preceding_fragment_.Assign(handshake); + continue; + } + preceding_fragment_.Truncate(0); + + DataBuffer filtered; + PacketFilter::Action action; + if (!IsFilteredType(header, handshake)) { + action = KEEP; + } else { + action = FilterHandshake(header, handshake, &filtered); + } + if (action == DROP) { + changed = true; + std::cerr << "handshake drop: " << handshake << std::endl; + continue; + } + + const DataBuffer* source = &handshake; + if (action == CHANGE) { + EXPECT_GT(0x1000000U, filtered.len()); + changed = true; + std::cerr << "handshake old: " << handshake << std::endl; + std::cerr << "handshake new: " << filtered << std::endl; + source = &filtered; + } else if (preceding_fragment_.len()) { + changed = true; + } + + offset = header.Write(output, offset, *source); + } + output->Truncate(offset); + return changed ? (offset ? CHANGE : DROP) : KEEP; +} + +bool TlsHandshakeFilter::HandshakeHeader::ReadLength( + TlsParser* parser, const TlsRecordHeader& header, uint32_t expected_offset, + uint32_t* length, bool* last_fragment) { + uint32_t message_length; + if (!parser->Read(&message_length, 3)) { + return false; // malformed + } + + if (!header.is_dtls()) { + *last_fragment = true; + *length = message_length; + return true; // nothing left to do + } + + // Read and check DTLS parameters + uint32_t message_seq_tmp; + if (!parser->Read(&message_seq_tmp, 2)) { // sequence number + return false; + } + message_seq_ = message_seq_tmp; + + uint32_t offset = 0; + if (!parser->Read(&offset, 3)) { + return false; + } + // We only parse if the fragments are all complete and in order. + if (offset != expected_offset) { + EXPECT_NE(0U, header.epoch()) + << "Received out of order handshake fragment for epoch 0"; + return false; + } + + // For DTLS, we return the length of just this fragment. + if (!parser->Read(length, 3)) { + return false; + } + + // It's a fragment if the entire message is longer than what we have. + *last_fragment = message_length == (*length + offset); + return true; +} + +bool TlsHandshakeFilter::HandshakeHeader::Parse( + TlsParser* parser, const TlsRecordHeader& record_header, + const DataBuffer& preceding_fragment, DataBuffer* body, bool* complete) { + *complete = false; + + variant_ = record_header.variant(); + version_ = record_header.version(); + if (!parser->Read(&handshake_type_)) { + return false; // malformed + } + + uint32_t length; + if (!ReadLength(parser, record_header, preceding_fragment.len(), &length, + complete)) { + return false; + } + + if (!parser->Read(body, length)) { + return false; + } + if (preceding_fragment.len()) { + body->Splice(preceding_fragment, 0); + } + return true; +} + +size_t TlsHandshakeFilter::HandshakeHeader::WriteFragment( + DataBuffer* buffer, size_t offset, const DataBuffer& body, + size_t fragment_offset, size_t fragment_length) const { + EXPECT_TRUE(is_dtls()); + EXPECT_GE(body.len(), fragment_offset + fragment_length); + offset = buffer->Write(offset, handshake_type(), 1); + offset = buffer->Write(offset, body.len(), 3); + offset = buffer->Write(offset, message_seq_, 2); + offset = buffer->Write(offset, fragment_offset, 3); + offset = buffer->Write(offset, fragment_length, 3); + offset = + buffer->Write(offset, body.data() + fragment_offset, fragment_length); + return offset; +} + +size_t TlsHandshakeFilter::HandshakeHeader::Write( + DataBuffer* buffer, size_t offset, const DataBuffer& body) const { + if (is_dtls()) { + return WriteFragment(buffer, offset, body, 0U, body.len()); + } + offset = buffer->Write(offset, handshake_type(), 1); + offset = buffer->Write(offset, body.len(), 3); + offset = buffer->Write(offset, body); + return offset; +} + +PacketFilter::Action TlsHandshakeRecorder::FilterHandshake( + const HandshakeHeader& header, const DataBuffer& input, + DataBuffer* output) { + // Only do this once. + if (buffer_.len()) { + return KEEP; + } + + buffer_ = input; + return KEEP; +} + +PacketFilter::Action TlsInspectorReplaceHandshakeMessage::FilterHandshake( + const HandshakeHeader& header, const DataBuffer& input, + DataBuffer* output) { + *output = buffer_; + return CHANGE; +} + +PacketFilter::Action TlsRecordRecorder::FilterRecord( + const TlsRecordHeader& header, const DataBuffer& input, + DataBuffer* output) { + if (!filter_ || (header.content_type() == ct_)) { + records_.push_back({header, input}); + } + return KEEP; +} + +PacketFilter::Action TlsConversationRecorder::FilterRecord( + const TlsRecordHeader& header, const DataBuffer& input, + DataBuffer* output) { + buffer_.Append(input); + return KEEP; +} + +PacketFilter::Action TlsHeaderRecorder::FilterRecord(const TlsRecordHeader& hdr, + const DataBuffer& input, + DataBuffer* output) { + headers_.push_back(hdr); + return KEEP; +} + +const TlsRecordHeader* TlsHeaderRecorder::header(size_t index) { + if (index > headers_.size() + 1) { + return nullptr; + } + return &headers_[index]; +} + +PacketFilter::Action ChainedPacketFilter::Filter(const DataBuffer& input, + DataBuffer* output) { + DataBuffer in(input); + bool changed = false; + for (auto it = filters_.begin(); it != filters_.end(); ++it) { + PacketFilter::Action action = (*it)->Process(in, output); + if (action == DROP) { + return DROP; + } + + if (action == CHANGE) { + in = *output; + changed = true; + } + } + return changed ? CHANGE : KEEP; +} + +bool FindClientHelloExtensions(TlsParser* parser, const TlsVersioned& header) { + if (!parser->Skip(2 + 32)) { // version + random + return false; + } + if (!parser->SkipVariable(1)) { // session ID + return false; + } + if (header.is_dtls() && !parser->SkipVariable(1)) { // DTLS cookie + return false; + } + if (!parser->SkipVariable(2)) { // cipher suites + return false; + } + if (!parser->SkipVariable(1)) { // compression methods + return false; + } + return true; +} + +bool FindServerHelloExtensions(TlsParser* parser, const TlsVersioned& header) { + uint32_t vtmp; + if (!parser->Read(&vtmp, 2)) { + return false; + } + uint16_t version = static_cast<uint16_t>(vtmp); + if (!parser->Skip(32)) { // random + return false; + } + if (NormalizeTlsVersion(version) <= SSL_LIBRARY_VERSION_TLS_1_2) { + if (!parser->SkipVariable(1)) { // session ID + return false; + } + } + if (!parser->Skip(2)) { // cipher suite + return false; + } + if (NormalizeTlsVersion(version) <= SSL_LIBRARY_VERSION_TLS_1_2) { + if (!parser->Skip(1)) { // compression method + return false; + } + } + return true; +} + +bool FindEncryptedExtensions(TlsParser* parser, const TlsVersioned& header) { + return true; +} + +static bool FindCertReqExtensions(TlsParser* parser, + const TlsVersioned& header) { + if (!parser->SkipVariable(1)) { // request context + return false; + } + return true; +} + +// Only look at the EE cert for this one. +static bool FindCertificateExtensions(TlsParser* parser, + const TlsVersioned& header) { + if (!parser->SkipVariable(1)) { // request context + return false; + } + if (!parser->Skip(3)) { // length of certificate list + return false; + } + if (!parser->SkipVariable(3)) { // ASN1Cert + return false; + } + return true; +} + +static bool FindNewSessionTicketExtensions(TlsParser* parser, + const TlsVersioned& header) { + if (!parser->Skip(8)) { // lifetime, age add + return false; + } + if (!parser->SkipVariable(1)) { // ticket_nonce + return false; + } + if (!parser->SkipVariable(2)) { // ticket + return false; + } + return true; +} + +static const std::map<uint16_t, TlsExtensionFinder> kExtensionFinders = { + {kTlsHandshakeClientHello, FindClientHelloExtensions}, + {kTlsHandshakeServerHello, FindServerHelloExtensions}, + {kTlsHandshakeEncryptedExtensions, FindEncryptedExtensions}, + {kTlsHandshakeCertificateRequest, FindCertReqExtensions}, + {kTlsHandshakeCertificate, FindCertificateExtensions}, + {kTlsHandshakeNewSessionTicket, FindNewSessionTicketExtensions}}; + +bool TlsExtensionFilter::FindExtensions(TlsParser* parser, + const HandshakeHeader& header) { + auto it = kExtensionFinders.find(header.handshake_type()); + if (it == kExtensionFinders.end()) { + return false; + } + return (it->second)(parser, header); +} + +PacketFilter::Action TlsExtensionFilter::FilterHandshake( + const HandshakeHeader& header, const DataBuffer& input, + DataBuffer* output) { + TlsParser parser(input); + if (!FindExtensions(&parser, header)) { + return KEEP; + } + return FilterExtensions(&parser, input, output); +} + +PacketFilter::Action TlsExtensionFilter::FilterExtensions( + TlsParser* parser, const DataBuffer& input, DataBuffer* output) { + size_t length_offset = parser->consumed(); + uint32_t all_extensions; + if (!parser->Read(&all_extensions, 2)) { + return KEEP; // no extensions, odd but OK + } + if (all_extensions != parser->remaining()) { + return KEEP; // malformed + } + + bool changed = false; + + // Write out the start of the message. + output->Allocate(input.len()); + size_t offset = output->Write(0, input.data(), parser->consumed()); + + while (parser->remaining()) { + uint32_t extension_type; + if (!parser->Read(&extension_type, 2)) { + return KEEP; // malformed + } + + DataBuffer extension; + if (!parser->ReadVariable(&extension, 2)) { + return KEEP; // malformed + } + + DataBuffer filtered; + PacketFilter::Action action = + FilterExtension(extension_type, extension, &filtered); + if (action == DROP) { + changed = true; + std::cerr << "extension drop: " << extension << std::endl; + continue; + } + + const DataBuffer* source = &extension; + if (action == CHANGE) { + EXPECT_GT(0x10000U, filtered.len()); + changed = true; + std::cerr << "extension old: " << extension << std::endl; + std::cerr << "extension new: " << filtered << std::endl; + source = &filtered; + } + + // Write out extension. + offset = output->Write(offset, extension_type, 2); + offset = output->Write(offset, source->len(), 2); + if (source->len() > 0) { + offset = output->Write(offset, *source); + } + } + output->Truncate(offset); + + if (changed) { + size_t newlen = output->len() - length_offset - 2; + EXPECT_GT(0x10000U, newlen); + if (newlen >= 0x10000) { + return KEEP; // bad: size increased too much + } + output->Write(length_offset, newlen, 2); + return CHANGE; + } + return KEEP; +} + +PacketFilter::Action TlsExtensionOrderCapture::FilterExtension( + uint16_t extension_type, const DataBuffer& input, DataBuffer* output) { + order.push_back(extension_type); + return KEEP; +} + +PacketFilter::Action TlsExtensionCapture::FilterExtension( + uint16_t extension_type, const DataBuffer& input, DataBuffer* output) { + if (extension_type == extension_ && (last_ || !captured_)) { + data_.Assign(input); + captured_ = true; + } + return KEEP; +} + +PacketFilter::Action TlsExtensionReplacer::FilterExtension( + uint16_t extension_type, const DataBuffer& input, DataBuffer* output) { + if (extension_type != extension_) { + return KEEP; + } + + *output = data_; + return CHANGE; +} + +PacketFilter::Action TlsExtensionResizer::FilterExtension( + uint16_t extension_type, const DataBuffer& input, DataBuffer* output) { + if (extension_type != extension_) { + return KEEP; + } + + if (input.len() <= length_) { + DataBuffer buf(length_ - input.len()); + output->Append(buf); + return CHANGE; + } + + output->Assign(input.data(), length_); + return CHANGE; +} + +PacketFilter::Action TlsExtensionAppender::FilterHandshake( + const HandshakeHeader& header, const DataBuffer& input, + DataBuffer* output) { + TlsParser parser(input); + if (!TlsExtensionFilter::FindExtensions(&parser, header)) { + return KEEP; + } + *output = input; + + // Increase the length of the extensions block. + if (!UpdateLength(output, parser.consumed(), 2)) { + return KEEP; + } + + // Extensions in Certificate are nested twice. Increase the size of the + // certificate list. + if (header.handshake_type() == kTlsHandshakeCertificate) { + TlsParser p2(input); + if (!p2.SkipVariable(1)) { + ADD_FAILURE(); + return KEEP; + } + if (!UpdateLength(output, p2.consumed(), 3)) { + return KEEP; + } + } + + size_t offset = output->len(); + offset = output->Write(offset, extension_, 2); + WriteVariable(output, offset, data_, 2); + + return CHANGE; +} + +bool TlsExtensionAppender::UpdateLength(DataBuffer* output, size_t offset, + size_t size) { + uint32_t len; + if (!output->Read(offset, size, &len)) { + ADD_FAILURE(); + return false; + } + + len += 4 + data_.len(); + output->Write(offset, len, size); + return true; +} + +PacketFilter::Action TlsExtensionDropper::FilterExtension( + uint16_t extension_type, const DataBuffer& input, DataBuffer* output) { + if (extension_type == extension_) { + return DROP; + } + return KEEP; +} + +PacketFilter::Action TlsExtensionDamager::FilterExtension( + uint16_t extension_type, const DataBuffer& input, DataBuffer* output) { + if (extension_type != extension_) { + return KEEP; + } + + *output = input; + output->data()[index_] += 73; // Increment selected for maximum damage + return CHANGE; +} + +PacketFilter::Action TlsExtensionInjector::FilterHandshake( + const HandshakeHeader& header, const DataBuffer& input, + DataBuffer* output) { + TlsParser parser(input); + if (!TlsExtensionFilter::FindExtensions(&parser, header)) { + return KEEP; + } + size_t offset = parser.consumed(); + + *output = input; + + // Increase the size of the extensions. + uint16_t ext_len; + memcpy(&ext_len, output->data() + offset, sizeof(ext_len)); + ext_len = htons(ntohs(ext_len) + data_.len() + 4); + memcpy(output->data() + offset, &ext_len, sizeof(ext_len)); + + // Insert the extension type and length. + DataBuffer type_length; + type_length.Allocate(4); + type_length.Write(0, extension_, 2); + type_length.Write(2, data_.len(), 2); + output->Splice(type_length, offset + 2); + + // Insert the payload. + if (data_.len() > 0) { + output->Splice(data_, offset + 6); + } + + return CHANGE; +} + +PacketFilter::Action AfterRecordN::FilterRecord(const TlsRecordHeader& header, + const DataBuffer& body, + DataBuffer* out) { + if (counter_++ == record_) { + DataBuffer buf; + header.Write(&buf, 0, body); + agent()->SendDirect(buf); + dest_.lock()->Handshake(); + func_(); + return DROP; + } + + return KEEP; +} + +PacketFilter::Action TlsClientHelloVersionChanger::FilterHandshake( + const HandshakeHeader& header, const DataBuffer& input, + DataBuffer* output) { + EXPECT_EQ(SECSuccess, + SSLInt_IncrementClientHandshakeVersion(server_.lock()->ssl_fd())); + return KEEP; +} + +PacketFilter::Action SelectiveDropFilter::Filter(const DataBuffer& input, + DataBuffer* output) { + if (counter_ >= 32) { + return KEEP; + } + return ((1 << counter_++) & pattern_) ? DROP : KEEP; +} + +PacketFilter::Action SelectiveRecordDropFilter::FilterRecord( + const TlsRecordHeader& header, const DataBuffer& data, + DataBuffer* changed) { + if (counter_ >= 32) { + return KEEP; + } + return ((1 << counter_++) & pattern_) ? DROP : KEEP; +} + +/* static */ uint32_t SelectiveRecordDropFilter::ToPattern( + std::initializer_list<size_t> records) { + uint32_t pattern = 0; + for (auto it = records.begin(); it != records.end(); ++it) { + EXPECT_GT(32U, *it); + assert(*it < 32U); + pattern |= 1 << *it; + } + return pattern; +} + +PacketFilter::Action TlsMessageVersionSetter::FilterHandshake( + const HandshakeHeader& header, const DataBuffer& input, + DataBuffer* output) { + *output = input; + output->Write(0, version_, 2); + return CHANGE; +} + +PacketFilter::Action SelectedCipherSuiteReplacer::FilterHandshake( + const HandshakeHeader& header, const DataBuffer& input, + DataBuffer* output) { + *output = input; + uint32_t temp = 0; + EXPECT_TRUE(input.Read(0, 2, &temp)); + EXPECT_EQ(header.version(), NormalizeTlsVersion(temp)); + // Cipher suite is after version(2), random(32) + // and [legacy_]session_id(<0..32>). + size_t pos = 34; + EXPECT_TRUE(input.Read(pos, 1, &temp)); + pos += 1 + temp; + + output->Write(pos, static_cast<uint32_t>(cipher_suite_), 2); + return CHANGE; +} + +PacketFilter::Action ServerHelloRandomChanger::FilterHandshake( + const HandshakeHeader& header, const DataBuffer& input, + DataBuffer* output) { + *output = input; + uint32_t temp = 0; + size_t pos = 30; + EXPECT_TRUE(input.Read(pos, 2, &temp)); + output->Write(pos, (temp ^ 0xffff), 2); + return CHANGE; +} + +PacketFilter::Action ClientHelloPreambleCapture::FilterHandshake( + const HandshakeHeader& header, const DataBuffer& input, + DataBuffer* output) { + EXPECT_TRUE(header.handshake_type() == kTlsHandshakeClientHello); + + if (captured_) { + return KEEP; + } + captured_ = true; + + DataBuffer temp; + TlsParser parser(input); + EXPECT_TRUE(parser.Read(&temp, 2 + 32)); // Version + Random + EXPECT_TRUE(parser.ReadVariable(&temp, 1)); // Session ID + if (is_dtls_agent()) { + EXPECT_TRUE(parser.ReadVariable(&temp, 1)); // Cookie + } + EXPECT_TRUE(parser.ReadVariable(&temp, 2)); // Ciphersuites + EXPECT_TRUE(parser.ReadVariable(&temp, 1)); // Compression + + // Copy the preamble into a new buffer + data_ = input; + data_.Truncate(parser.consumed()); + + return KEEP; +} + +PacketFilter::Action ClientHelloCiphersuiteCapture::FilterHandshake( + const HandshakeHeader& header, const DataBuffer& input, + DataBuffer* output) { + EXPECT_TRUE(header.handshake_type() == kTlsHandshakeClientHello); + + if (captured_) { + return KEEP; + } + captured_ = true; + + TlsParser parser(input); + EXPECT_TRUE(parser.Skip(2 + 32)); // Version + Random + EXPECT_TRUE(parser.SkipVariable(1)); // Session ID + if (is_dtls_agent()) { + EXPECT_TRUE(parser.SkipVariable(1)); // Cookie + } + + EXPECT_TRUE(parser.ReadVariable(&data_, 2)); // Ciphersuites + + return KEEP; +} + +} // namespace nss_test diff --git a/security/nss/gtests/ssl_gtest/tls_filter.h b/security/nss/gtests/ssl_gtest/tls_filter.h new file mode 100644 index 0000000000..7c45aab12f --- /dev/null +++ b/security/nss/gtests/ssl_gtest/tls_filter.h @@ -0,0 +1,1013 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set ts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this file, + * You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#ifndef tls_filter_h_ +#define tls_filter_h_ + +#include <functional> +#include <memory> +#include <set> +#include <vector> +#include "pk11pub.h" +#include "sslt.h" +#include "sslproto.h" +#include "test_io.h" +#include "tls_agent.h" +#include "tls_parser.h" +#include "tls_protect.h" + +extern "C" { +#include "libssl_internals.h" +} + +namespace nss_test { + +class TlsCipherSpec; + +class TlsSendCipherSpecCapturer { + public: + TlsSendCipherSpecCapturer(const std::shared_ptr<TlsAgent>& agent) + : agent_(agent), send_cipher_specs_() { + EXPECT_EQ(SECSuccess, + SSL_SecretCallback(agent_->ssl_fd(), SecretCallback, this)); + } + + std::shared_ptr<TlsCipherSpec> spec(size_t i) { + if (i >= send_cipher_specs_.size()) { + return nullptr; + } + return send_cipher_specs_[i]; + } + + private: + static void SecretCallback(PRFileDesc* fd, PRUint16 epoch, + SSLSecretDirection dir, PK11SymKey* secret, + void* arg) { + auto self = static_cast<TlsSendCipherSpecCapturer*>(arg); + std::cerr << self->agent_->role_str() << ": capture " << dir + << " secret for epoch " << epoch << std::endl; + + if (dir == ssl_secret_read) { + return; + } + + SSLPreliminaryChannelInfo preinfo; + EXPECT_EQ(SECSuccess, + SSL_GetPreliminaryChannelInfo(self->agent_->ssl_fd(), &preinfo, + sizeof(preinfo))); + EXPECT_EQ(sizeof(preinfo), preinfo.length); + EXPECT_TRUE(preinfo.valuesSet & ssl_preinfo_cipher_suite); + + // Check the version: + EXPECT_TRUE(preinfo.valuesSet & ssl_preinfo_version); + ASSERT_GE(SSL_LIBRARY_VERSION_TLS_1_3, preinfo.protocolVersion); + + SSLCipherSuiteInfo cipherinfo; + EXPECT_EQ(SECSuccess, + SSL_GetCipherSuiteInfo(preinfo.cipherSuite, &cipherinfo, + sizeof(cipherinfo))); + EXPECT_EQ(sizeof(cipherinfo), cipherinfo.length); + + auto spec = std::make_shared<TlsCipherSpec>(true, epoch); + EXPECT_TRUE(spec->SetKeys(&cipherinfo, secret)); + self->send_cipher_specs_.push_back(spec); + } + + std::shared_ptr<TlsAgent> agent_; + std::vector<std::shared_ptr<TlsCipherSpec>> send_cipher_specs_; +}; + +class TlsVersioned { + public: + TlsVersioned() : variant_(ssl_variant_stream), version_(0) {} + TlsVersioned(SSLProtocolVariant var, uint16_t ver) + : variant_(var), version_(ver) {} + + bool is_dtls() const { return variant_ == ssl_variant_datagram; } + SSLProtocolVariant variant() const { return variant_; } + uint16_t version() const { return version_; } + + void WriteStream(std::ostream& stream) const; + + protected: + SSLProtocolVariant variant_; + uint16_t version_; +}; + +class TlsRecordHeader : public TlsVersioned { + public: + TlsRecordHeader() + : TlsVersioned(), + content_type_(0), + guess_seqno_(0), + seqno_is_masked_(false), + sequence_number_(0), + header_() {} + TlsRecordHeader(SSLProtocolVariant var, uint16_t ver, uint8_t ct, + uint64_t seqno) + : TlsVersioned(var, ver), + content_type_(ct), + guess_seqno_(0), + seqno_is_masked_(false), + sequence_number_(seqno), + header_(), + sn_mask_() {} + + bool is_protected() const { + // *TLS < 1.3 + if (version() < SSL_LIBRARY_VERSION_TLS_1_3 && + content_type() == ssl_ct_application_data) { + return true; + } + + // TLS 1.3 + if (!is_dtls() && version() >= SSL_LIBRARY_VERSION_TLS_1_3 && + content_type() == ssl_ct_application_data) { + return true; + } + + // DTLS 1.3 + return is_dtls13_ciphertext(); + } + + uint8_t content_type() const { return content_type_; } + uint16_t epoch() const { + return static_cast<uint16_t>(sequence_number_ >> 48); + } + uint64_t sequence_number() const { return sequence_number_; } + void sequence_number(uint64_t seqno) { sequence_number_ = seqno; } + const DataBuffer& sn_mask() const { return sn_mask_; } + bool is_dtls13_ciphertext() const { + return is_dtls() && (version() >= SSL_LIBRARY_VERSION_TLS_1_3) && + (content_type() & kCtDtlsCiphertextMask) == kCtDtlsCiphertext; + } + + size_t header_length() const; + const DataBuffer& header() const { return header_; } + + bool MaskSequenceNumber(); + bool MaskSequenceNumber(const DataBuffer& mask_buf); + + // Parse the header; return true if successful; body in an outparam if OK. + bool Parse(bool is_dtls13, uint64_t sequence_number, TlsParser* parser, + DataBuffer* body); + // Write the header and body to a buffer at the given offset. + // Return the offset of the end of the write. + size_t Write(DataBuffer* buffer, size_t offset, const DataBuffer& body) const; + size_t WriteHeader(DataBuffer* buffer, size_t offset, size_t body_len) const; + + private: + static uint64_t RecoverSequenceNumber(uint64_t guess_seqno, uint32_t partial, + size_t partial_bits); + uint64_t ParseSequenceNumber(uint64_t expected, uint64_t raw, + size_t seq_no_bits, size_t epoch_bits); + + uint8_t content_type_; + uint64_t guess_seqno_; + bool seqno_is_masked_; + uint64_t sequence_number_; + DataBuffer header_; + DataBuffer sn_mask_; +}; + +struct TlsRecord { + const TlsRecordHeader header; + const DataBuffer buffer; +}; + +// Make a filter and install it on a TlsAgent. +template <class T, typename... Args> +inline std::shared_ptr<T> MakeTlsFilter(const std::shared_ptr<TlsAgent>& agent, + Args&&... args) { + auto filter = std::make_shared<T>(agent, std::forward<Args>(args)...); + agent->SetFilter(filter); + return filter; +} + +// Abstract filter that operates on entire (D)TLS records. +class TlsRecordFilter : public PacketFilter { + public: + TlsRecordFilter(const std::shared_ptr<TlsAgent>& a); + + std::shared_ptr<TlsAgent> agent() const { return agent_.lock(); } + + // External interface. Overrides PacketFilter. + PacketFilter::Action Filter(const DataBuffer& input, DataBuffer* output); + + // Report how many packets were altered by the filter. + size_t filtered_packets() const { return count_; } + + // Enable decryption. This only works properly for TLS 1.3 and above. + // Enabling it for lower version tests will cause undefined + // behavior. + void EnableDecryption(); + bool decrypting() const { return decrypting_; }; + bool Unprotect(const TlsRecordHeader& header, const DataBuffer& cipherText, + uint16_t* protection_epoch, uint8_t* inner_content_type, + DataBuffer* plaintext, TlsRecordHeader* out_header); + bool Protect(TlsCipherSpec& protection_spec, const TlsRecordHeader& header, + uint8_t inner_content_type, const DataBuffer& plaintext, + DataBuffer* ciphertext, TlsRecordHeader* out_header, + size_t padding = 0); + + protected: + // There are two filter functions which can be overriden. Both are + // called with the header and the record but the outer one is called + // with a raw pointer to let you write into the buffer and lets you + // do anything with this section of the stream. The inner one + // just lets you change the record contents. By default, the + // outer one calls the inner one, so if you override the outer + // one, the inner one is never called unless you call it yourself. + virtual PacketFilter::Action FilterRecord(const TlsRecordHeader& header, + const DataBuffer& record, + size_t* offset, DataBuffer* output); + + // The record filter receives the record contentType, version and DTLS + // sequence number (which is zero for TLS), plus the existing record payload. + // It returns an action (KEEP, CHANGE, DROP). It writes to the `changed` + // outparam with the new record contents if it chooses to CHANGE the record. + virtual PacketFilter::Action FilterRecord(const TlsRecordHeader& header, + const DataBuffer& data, + DataBuffer* changed) { + return KEEP; + } + + bool is_dtls_agent() const; + bool is_dtls13() const; + bool is_dtls13_ciphertext(uint8_t ct) const; + TlsCipherSpec& spec(uint16_t epoch); + + private: + static void SecretCallback(PRFileDesc* fd, PRUint16 epoch, + SSLSecretDirection dir, PK11SymKey* secret, + void* arg); + + std::weak_ptr<TlsAgent> agent_; + size_t count_ = 0; + std::vector<TlsCipherSpec> cipher_specs_; + bool decrypting_ = false; +}; + +inline std::ostream& operator<<(std::ostream& stream, const TlsVersioned& v) { + v.WriteStream(stream); + return stream; +} + +inline std::ostream& operator<<(std::ostream& stream, + const TlsRecordHeader& hdr) { + hdr.WriteStream(stream); + stream << ' '; + switch (hdr.content_type()) { + case ssl_ct_change_cipher_spec: + stream << "CCS"; + break; + case ssl_ct_alert: + stream << "Alert"; + break; + case ssl_ct_handshake: + stream << "Handshake"; + break; + case ssl_ct_application_data: + stream << "Data"; + break; + case ssl_ct_ack: + stream << "ACK"; + break; + default: + stream << '<' << static_cast<int>(hdr.content_type()) << '>'; + break; + } + return stream << ' ' << std::hex << hdr.sequence_number() << std::dec; +} + +// Abstract filter that operates on handshake messages rather than records. +// This assumes that the handshake messages are written in a block as entire +// records and that they don't span records or anything crazy like that. +class TlsHandshakeFilter : public TlsRecordFilter { + public: + TlsHandshakeFilter(const std::shared_ptr<TlsAgent>& a) + : TlsRecordFilter(a), handshake_types_(), preceding_fragment_() {} + TlsHandshakeFilter(const std::shared_ptr<TlsAgent>& a, + const std::set<uint8_t>& types) + : TlsRecordFilter(a), handshake_types_(types), preceding_fragment_() {} + + // This filter can be set to be selective based on handshake message type. If + // this function isn't used (or the set is empty), then all handshake messages + // will be filtered. + void SetHandshakeTypes(const std::set<uint8_t>& types) { + handshake_types_ = types; + } + + class HandshakeHeader : public TlsVersioned { + public: + HandshakeHeader() : TlsVersioned(), handshake_type_(0), message_seq_(0) {} + + uint8_t handshake_type() const { return handshake_type_; } + bool Parse(TlsParser* parser, const TlsRecordHeader& record_header, + const DataBuffer& preceding_fragment, DataBuffer* body, + bool* complete); + size_t Write(DataBuffer* buffer, size_t offset, + const DataBuffer& body) const; + size_t WriteFragment(DataBuffer* buffer, size_t offset, + const DataBuffer& body, size_t fragment_offset, + size_t fragment_length) const; + + private: + // Reads the length from the record header. + // This also reads the DTLS fragment information and checks it. + bool ReadLength(TlsParser* parser, const TlsRecordHeader& header, + uint32_t expected_offset, uint32_t* length, + bool* last_fragment); + + uint8_t handshake_type_; + uint16_t message_seq_; + // fragment_offset is always zero in these tests. + }; + + protected: + virtual PacketFilter::Action FilterRecord(const TlsRecordHeader& header, + const DataBuffer& input, + DataBuffer* output); + virtual PacketFilter::Action FilterHandshake(const HandshakeHeader& header, + const DataBuffer& input, + DataBuffer* output) = 0; + + private: + bool IsFilteredType(const HandshakeHeader& header, + const DataBuffer& handshake); + + std::set<uint8_t> handshake_types_; + DataBuffer preceding_fragment_; +}; + +// Make a copy of the first instance of a handshake message. +class TlsHandshakeRecorder : public TlsHandshakeFilter { + public: + TlsHandshakeRecorder(const std::shared_ptr<TlsAgent>& a, + uint8_t handshake_type) + : TlsHandshakeFilter(a, {handshake_type}), buffer_() {} + TlsHandshakeRecorder(const std::shared_ptr<TlsAgent>& a, + const std::set<uint8_t>& handshake_types) + : TlsHandshakeFilter(a, handshake_types), buffer_() {} + + virtual PacketFilter::Action FilterHandshake(const HandshakeHeader& header, + const DataBuffer& input, + DataBuffer* output); + + void Reset() { buffer_.Truncate(0); } + + const DataBuffer& buffer() const { return buffer_; } + + private: + DataBuffer buffer_; +}; + +// Replace all instances of a handshake message. +class TlsInspectorReplaceHandshakeMessage : public TlsHandshakeFilter { + public: + TlsInspectorReplaceHandshakeMessage(const std::shared_ptr<TlsAgent>& a, + uint8_t handshake_type, + const DataBuffer& replacement) + : TlsHandshakeFilter(a, {handshake_type}), buffer_(replacement) {} + + virtual PacketFilter::Action FilterHandshake(const HandshakeHeader& header, + const DataBuffer& input, + DataBuffer* output); + + private: + DataBuffer buffer_; +}; + +// Make a copy of each record of a given type. +class TlsRecordRecorder : public TlsRecordFilter { + public: + TlsRecordRecorder(const std::shared_ptr<TlsAgent>& a, uint8_t ct) + : TlsRecordFilter(a), filter_(true), ct_(ct), records_() {} + TlsRecordRecorder(const std::shared_ptr<TlsAgent>& a) + : TlsRecordFilter(a), + filter_(false), + ct_(ssl_ct_handshake), // dummy (<optional> is C++14) + records_() {} + virtual PacketFilter::Action FilterRecord(const TlsRecordHeader& header, + const DataBuffer& input, + DataBuffer* output); + + size_t count() const { return records_.size(); } + void Clear() { records_.clear(); } + + const TlsRecord& record(size_t i) const { return records_[i]; } + + private: + bool filter_; + uint8_t ct_; + std::vector<TlsRecord> records_; +}; + +// Make a copy of the complete conversation. +class TlsConversationRecorder : public TlsRecordFilter { + public: + TlsConversationRecorder(const std::shared_ptr<TlsAgent>& a, + DataBuffer& buffer) + : TlsRecordFilter(a), buffer_(buffer) {} + + virtual PacketFilter::Action FilterRecord(const TlsRecordHeader& header, + const DataBuffer& input, + DataBuffer* output); + + private: + DataBuffer buffer_; +}; + +// Make a copy of the records +class TlsHeaderRecorder : public TlsRecordFilter { + public: + TlsHeaderRecorder(const std::shared_ptr<TlsAgent>& a) : TlsRecordFilter(a) {} + virtual PacketFilter::Action FilterRecord(const TlsRecordHeader& header, + const DataBuffer& input, + DataBuffer* output); + const TlsRecordHeader* header(size_t index); + + private: + std::vector<TlsRecordHeader> headers_; +}; + +typedef std::initializer_list<std::shared_ptr<PacketFilter>> + ChainedPacketFilterInit; + +// Runs multiple packet filters in series. +class ChainedPacketFilter : public PacketFilter { + public: + ChainedPacketFilter() {} + ChainedPacketFilter(const std::vector<std::shared_ptr<PacketFilter>> filters) + : filters_(filters.begin(), filters.end()) {} + ChainedPacketFilter(ChainedPacketFilterInit il) : filters_(il) {} + virtual ~ChainedPacketFilter() {} + + virtual PacketFilter::Action Filter(const DataBuffer& input, + DataBuffer* output); + + // Takes ownership of the filter. + void Add(std::shared_ptr<PacketFilter> filter) { filters_.push_back(filter); } + + private: + std::vector<std::shared_ptr<PacketFilter>> filters_; +}; + +typedef std::function<bool(TlsParser* parser, const TlsVersioned& header)> + TlsExtensionFinder; + +class TlsExtensionFilter : public TlsHandshakeFilter { + public: + TlsExtensionFilter(const std::shared_ptr<TlsAgent>& a) + : TlsHandshakeFilter(a, + {kTlsHandshakeClientHello, kTlsHandshakeServerHello, + kTlsHandshakeHelloRetryRequest, + kTlsHandshakeEncryptedExtensions}) {} + + TlsExtensionFilter(const std::shared_ptr<TlsAgent>& a, + const std::set<uint8_t>& types) + : TlsHandshakeFilter(a, types) {} + + static bool FindExtensions(TlsParser* parser, const HandshakeHeader& header); + + protected: + PacketFilter::Action FilterHandshake(const HandshakeHeader& header, + const DataBuffer& input, + DataBuffer* output) override; + + virtual PacketFilter::Action FilterExtension(uint16_t extension_type, + const DataBuffer& input, + DataBuffer* output) = 0; + + private: + PacketFilter::Action FilterExtensions(TlsParser* parser, + const DataBuffer& input, + DataBuffer* output); +}; + +class TlsExtensionOrderCapture : public TlsExtensionFilter { + public: + TlsExtensionOrderCapture(const std::shared_ptr<TlsAgent>& a, uint8_t message) + : TlsExtensionFilter(a, {message}){}; + + std::vector<uint16_t> order; + + protected: + PacketFilter::Action FilterExtension(uint16_t extension_type, + const DataBuffer& input, + DataBuffer* output) override; +}; + +class TlsExtensionCapture : public TlsExtensionFilter { + public: + TlsExtensionCapture(const std::shared_ptr<TlsAgent>& a, uint16_t ext, + bool last = false) + : TlsExtensionFilter(a), + extension_(ext), + captured_(false), + last_(last), + data_() {} + + const DataBuffer& extension() const { return data_; } + bool captured() const { return captured_; } + + protected: + PacketFilter::Action FilterExtension(uint16_t extension_type, + const DataBuffer& input, + DataBuffer* output) override; + + private: + const uint16_t extension_; + bool captured_; + bool last_; + DataBuffer data_; +}; + +class TlsExtensionReplacer : public TlsExtensionFilter { + public: + TlsExtensionReplacer(const std::shared_ptr<TlsAgent>& a, uint16_t extension, + const DataBuffer& data) + : TlsExtensionFilter(a), extension_(extension), data_(data) {} + PacketFilter::Action FilterExtension(uint16_t extension_type, + const DataBuffer& input, + DataBuffer* output) override; + + private: + const uint16_t extension_; + const DataBuffer data_; +}; + +class TlsExtensionResizer : public TlsExtensionFilter { + public: + TlsExtensionResizer(const std::shared_ptr<TlsAgent>& a, uint16_t extension, + size_t length) + : TlsExtensionFilter(a), extension_(extension), length_(length) {} + PacketFilter::Action FilterExtension(uint16_t extension_type, + const DataBuffer& input, + DataBuffer* output) override; + + private: + uint16_t extension_; + size_t length_; +}; + +class TlsExtensionAppender : public TlsHandshakeFilter { + public: + TlsExtensionAppender(const std::shared_ptr<TlsAgent>& a, + uint8_t handshake_type, uint16_t ext, DataBuffer& data) + : TlsHandshakeFilter(a, {handshake_type}), extension_(ext), data_(data) {} + + virtual PacketFilter::Action FilterHandshake(const HandshakeHeader& header, + const DataBuffer& input, + DataBuffer* output); + + private: + bool UpdateLength(DataBuffer* output, size_t offset, size_t size); + + const uint16_t extension_; + const DataBuffer data_; +}; + +class TlsExtensionDropper : public TlsExtensionFilter { + public: + TlsExtensionDropper(const std::shared_ptr<TlsAgent>& a, uint16_t extension) + : TlsExtensionFilter(a), extension_(extension) {} + PacketFilter::Action FilterExtension(uint16_t extension_type, + const DataBuffer&, DataBuffer*) override; + + private: + uint16_t extension_; +}; + +class TlsHandshakeDropper : public TlsHandshakeFilter { + public: + TlsHandshakeDropper(const std::shared_ptr<TlsAgent>& a) + : TlsHandshakeFilter(a) {} + + protected: + PacketFilter::Action FilterHandshake(const HandshakeHeader& header, + const DataBuffer& input, + DataBuffer* output) override { + return DROP; + } +}; + +class TlsEncryptedHandshakeMessageReplacer : public TlsRecordFilter { + public: + TlsEncryptedHandshakeMessageReplacer(const std::shared_ptr<TlsAgent>& a, + uint8_t old_ct, uint8_t new_ct) + : TlsRecordFilter(a), old_ct_(old_ct), new_ct_(new_ct) {} + + protected: + PacketFilter::Action FilterRecord(const TlsRecordHeader& header, + const DataBuffer& record, size_t* offset, + DataBuffer* output) override { + if (header.content_type() != ssl_ct_application_data) { + return KEEP; + } + + uint16_t protection_epoch = 0; + uint8_t inner_content_type; + DataBuffer plaintext; + TlsRecordHeader out_header; + if (!Unprotect(header, record, &protection_epoch, &inner_content_type, + &plaintext, &out_header) || + !plaintext.len()) { + return KEEP; + } + + if (inner_content_type != ssl_ct_handshake) { + return KEEP; + } + + size_t off = 0; + uint32_t msg_len = 0; + uint32_t msg_type = 255; // Not a real message + do { + if (!plaintext.Read(off, 1, &msg_type) || msg_type == old_ct_) { + break; + } + + // Increment and check next messages + if (!plaintext.Read(++off, 3, &msg_len)) { + break; + } + off += 3 + msg_len; + } while (msg_type != old_ct_); + + if (msg_type == old_ct_) { + plaintext.Write(off, new_ct_, 1); + } + + DataBuffer ciphertext; + bool ok = Protect(spec(protection_epoch), out_header, inner_content_type, + plaintext, &ciphertext, &out_header); + if (!ok) { + return KEEP; + } + *offset = out_header.Write(output, *offset, ciphertext); + return CHANGE; + } + + private: + uint8_t old_ct_; + uint8_t new_ct_; +}; + +class TlsExtensionInjector : public TlsHandshakeFilter { + public: + TlsExtensionInjector(const std::shared_ptr<TlsAgent>& a, uint16_t ext, + const DataBuffer& data) + : TlsHandshakeFilter(a), extension_(ext), data_(data) {} + + protected: + PacketFilter::Action FilterHandshake(const HandshakeHeader& header, + const DataBuffer& input, + DataBuffer* output) override; + + private: + const uint16_t extension_; + const DataBuffer data_; +}; + +class TlsExtensionDamager : public TlsExtensionFilter { + public: + TlsExtensionDamager(const std::shared_ptr<TlsAgent>& a, uint16_t extension, + size_t index) + : TlsExtensionFilter(a), extension_(extension), index_(index) {} + virtual PacketFilter::Action FilterExtension(uint16_t extension_type, + const DataBuffer& input, + DataBuffer* output); + + private: + uint16_t extension_; + size_t index_; +}; + +typedef std::function<void(void)> VoidFunction; + +class AfterRecordN : public TlsRecordFilter { + public: + AfterRecordN(const std::shared_ptr<TlsAgent>& src, + const std::shared_ptr<TlsAgent>& dest, unsigned int record, + VoidFunction func) + : TlsRecordFilter(src), + dest_(dest), + record_(record), + func_(func), + counter_(0) {} + + virtual PacketFilter::Action FilterRecord(const TlsRecordHeader& header, + const DataBuffer& body, + DataBuffer* out) override; + + private: + std::weak_ptr<TlsAgent> dest_; + unsigned int record_; + VoidFunction func_; + unsigned int counter_; +}; + +// When we see the ClientKeyExchange from |client|, increment the +// ClientHelloVersion on |server|. +class TlsClientHelloVersionChanger : public TlsHandshakeFilter { + public: + TlsClientHelloVersionChanger(const std::shared_ptr<TlsAgent>& client, + const std::shared_ptr<TlsAgent>& server) + : TlsHandshakeFilter(client, {kTlsHandshakeClientKeyExchange}), + server_(server) {} + + virtual PacketFilter::Action FilterHandshake(const HandshakeHeader& header, + const DataBuffer& input, + DataBuffer* output); + + private: + std::weak_ptr<TlsAgent> server_; +}; + +// Damage a record. +class TlsRecordLastByteDamager : public TlsRecordFilter { + public: + TlsRecordLastByteDamager(const std::shared_ptr<TlsAgent>& a) + : TlsRecordFilter(a) {} + + protected: + PacketFilter::Action FilterRecord(const TlsRecordHeader& header, + const DataBuffer& data, + DataBuffer* changed) override { + *changed = data; + changed->data()[changed->len() - 1]++; + return CHANGE; + } +}; + +// This class selectively drops complete writes. This relies on the fact that +// writes in libssl are on record boundaries. +class SelectiveDropFilter : public PacketFilter { + public: + SelectiveDropFilter(uint32_t pattern) : pattern_(pattern), counter_(0) {} + + protected: + virtual PacketFilter::Action Filter(const DataBuffer& input, + DataBuffer* output) override; + + private: + const uint32_t pattern_; + uint8_t counter_; +}; + +// This class selectively drops complete records. The difference from +// SelectiveDropFilter is that if multiple DTLS records are in the same +// datagram, we just drop one. +class SelectiveRecordDropFilter : public TlsRecordFilter { + public: + SelectiveRecordDropFilter(const std::shared_ptr<TlsAgent>& a, + uint32_t pattern, bool on = true) + : TlsRecordFilter(a), pattern_(pattern), counter_(0) { + if (!on) { + Disable(); + } + } + SelectiveRecordDropFilter(const std::shared_ptr<TlsAgent>& a, + std::initializer_list<size_t> records) + : SelectiveRecordDropFilter(a, ToPattern(records), true) {} + + void Reset(uint32_t pattern) { + counter_ = 0; + PacketFilter::Enable(); + pattern_ = pattern; + } + + void Reset(std::initializer_list<size_t> records) { + Reset(ToPattern(records)); + } + + protected: + PacketFilter::Action FilterRecord(const TlsRecordHeader& header, + const DataBuffer& data, + DataBuffer* changed) override; + + private: + static uint32_t ToPattern(std::initializer_list<size_t> records); + + uint32_t pattern_; + uint8_t counter_; +}; + +// Set the version value in the ClientHello, ServerHello or HelloRetryRequest +class TlsMessageVersionSetter : public TlsHandshakeFilter { + public: + TlsMessageVersionSetter(const std::shared_ptr<TlsAgent>& a, uint8_t message, + uint16_t version) + : TlsHandshakeFilter(a, {message}), version_(version) { + PR_ASSERT(message == kTlsHandshakeClientHello || + message == kTlsHandshakeServerHello || + message == kTlsHandshakeHelloRetryRequest); + } + + virtual PacketFilter::Action FilterHandshake(const HandshakeHeader& header, + const DataBuffer& input, + DataBuffer* output); + + private: + uint16_t version_; +}; + +// Damages the last byte of a handshake message. +class TlsLastByteDamager : public TlsHandshakeFilter { + public: + TlsLastByteDamager(const std::shared_ptr<TlsAgent>& a, uint8_t type) + : TlsHandshakeFilter(a), type_(type) {} + PacketFilter::Action FilterHandshake( + const TlsHandshakeFilter::HandshakeHeader& header, + const DataBuffer& input, DataBuffer* output) override { + if (header.handshake_type() != type_) { + return KEEP; + } + + *output = input; + + output->data()[output->len() - 1]++; + return CHANGE; + } + + private: + uint8_t type_; +}; + +class SelectedCipherSuiteReplacer : public TlsHandshakeFilter { + public: + SelectedCipherSuiteReplacer(const std::shared_ptr<TlsAgent>& a, + uint16_t suite) + : TlsHandshakeFilter(a, {kTlsHandshakeServerHello}), + cipher_suite_(suite) {} + + protected: + PacketFilter::Action FilterHandshake(const HandshakeHeader& header, + const DataBuffer& input, + DataBuffer* output) override; + + private: + uint16_t cipher_suite_; +}; + +class ClientHelloPreambleCapture : public TlsHandshakeFilter { + public: + ClientHelloPreambleCapture(const std::shared_ptr<TlsAgent>& a) + : TlsHandshakeFilter(a, {kTlsHandshakeClientHello}), + captured_(false), + data_() {} + + const DataBuffer& contents() const { return data_; } + bool captured() const { return captured_; } + + protected: + PacketFilter::Action FilterHandshake(const HandshakeHeader& header, + const DataBuffer& input, + DataBuffer* output) override; + + private: + bool captured_; + DataBuffer data_; +}; + +class ClientHelloCiphersuiteCapture : public TlsHandshakeFilter { + public: + ClientHelloCiphersuiteCapture(const std::shared_ptr<TlsAgent>& a) + : TlsHandshakeFilter(a, {kTlsHandshakeClientHello}), + captured_(false), + data_() {} + + const DataBuffer& contents() const { return data_; } + bool captured() const { return captured_; } + + protected: + PacketFilter::Action FilterHandshake(const HandshakeHeader& header, + const DataBuffer& input, + DataBuffer* output) override; + + private: + bool captured_; + DataBuffer data_; +}; + +class ServerHelloRandomChanger : public TlsHandshakeFilter { + public: + ServerHelloRandomChanger(const std::shared_ptr<TlsAgent>& a) + : TlsHandshakeFilter(a, {kTlsHandshakeServerHello}) {} + + protected: + PacketFilter::Action FilterHandshake(const HandshakeHeader& header, + const DataBuffer& input, + DataBuffer* output) override; +}; + +// Replace SignatureAndHashAlgorithm of a SKE. +class DHEServerKEXSigAlgReplacer : public TlsHandshakeFilter { + public: + DHEServerKEXSigAlgReplacer(const std::shared_ptr<TlsAgent>& server, + uint16_t sig_scheme) + : TlsHandshakeFilter(server, {kTlsHandshakeServerKeyExchange}), + sig_scheme_(sig_scheme) {} + + protected: + virtual PacketFilter::Action FilterHandshake(const HandshakeHeader& header, + const DataBuffer& input, + DataBuffer* output) { + *output = input; + + uint32_t len; + uint32_t idx = 0; + EXPECT_TRUE(output->Read(idx, 2, &len)); + idx += 2 + len; + EXPECT_TRUE(output->Read(idx, 2, &len)); + idx += 2 + len; + EXPECT_TRUE(output->Read(idx, 2, &len)); + idx += 2 + len; + output->Write(idx, sig_scheme_, 2); + + return CHANGE; + } + + private: + uint16_t sig_scheme_; +}; + +// Replace SignatureAndHashAlgorithm of a SKE. +class ECCServerKEXSigAlgReplacer : public TlsHandshakeFilter { + public: + ECCServerKEXSigAlgReplacer(const std::shared_ptr<TlsAgent>& server, + uint16_t sig_scheme) + : TlsHandshakeFilter(server, {kTlsHandshakeServerKeyExchange}), + sig_scheme_(sig_scheme) {} + + protected: + virtual PacketFilter::Action FilterHandshake(const HandshakeHeader& header, + const DataBuffer& input, + DataBuffer* output) { + *output = input; + + uint32_t point_len; + EXPECT_TRUE(output->Read(3, 1, &point_len)); + output->Write(4 + point_len, sig_scheme_, 2); + + return CHANGE; + } + + private: + uint16_t sig_scheme_; +}; + +// Replace NamedCurve of a ECDHE SKE. +class ECCServerKEXNamedCurveReplacer : public TlsHandshakeFilter { + public: + ECCServerKEXNamedCurveReplacer(const std::shared_ptr<TlsAgent>& server, + uint16_t curve_name) + : TlsHandshakeFilter(server, {kTlsHandshakeServerKeyExchange}), + curve_name_(curve_name) {} + + protected: + virtual PacketFilter::Action FilterHandshake(const HandshakeHeader& header, + const DataBuffer& input, + DataBuffer* output) { + *output = input; + + uint32_t curve_type; + EXPECT_TRUE(output->Read(0, 1, &curve_type)); + EXPECT_EQ(curve_type, ec_type_named); + output->Write(1, curve_name_, 2); + + return CHANGE; + } + + private: + uint16_t curve_name_; +}; + +// Replaces the signature scheme in a CertificateVerify message. +class TlsReplaceSignatureSchemeFilter : public TlsHandshakeFilter { + public: + TlsReplaceSignatureSchemeFilter(const std::shared_ptr<TlsAgent>& a, + uint16_t scheme) + : TlsHandshakeFilter(a, {kTlsHandshakeCertificateVerify}), + scheme_(scheme) {} + + protected: + virtual PacketFilter::Action FilterHandshake(const HandshakeHeader& header, + const DataBuffer& input, + DataBuffer* output) { + *output = input; + output->Write(0, scheme_, 2); + return CHANGE; + } + + private: + uint16_t scheme_; +}; + +} // namespace nss_test + +#endif diff --git a/security/nss/gtests/ssl_gtest/tls_grease_unittest.cc b/security/nss/gtests/ssl_gtest/tls_grease_unittest.cc new file mode 100644 index 0000000000..c89c41be04 --- /dev/null +++ b/security/nss/gtests/ssl_gtest/tls_grease_unittest.cc @@ -0,0 +1,878 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set ts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this file, + * You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#include "secerr.h" +#include "ssl.h" + +#include "gtest_utils.h" +#include "tls_connect.h" +#include "util.h" + +namespace nss_test { + +const uint8_t kTlsGreaseExtensionMessages[] = {kTlsHandshakeEncryptedExtensions, + kTlsHandshakeCertificate}; + +const uint16_t kTlsGreaseValues[] = { + 0x0a0a, 0x1a1a, 0x2a2a, 0x3a3a, 0x4a4a, 0x5a5a, 0x6a6a, 0x7a7a, + 0x8a8a, 0x9a9a, 0xaaaa, 0xbaba, 0xcaca, 0xdada, 0xeaea, 0xfafa}; + +const uint8_t kTlsGreasePskValues[] = {0x0B, 0x2A, 0x49, 0x68, + 0x87, 0xA6, 0xC5, 0xE4}; + +size_t countGreaseInBuffer(const DataBuffer& list) { + if (!list.len()) { + return 0; + } + size_t occurrence = 0; + for (uint16_t greaseVal : kTlsGreaseValues) { + for (size_t i = 0; i < (list.len() - 1); i += 2) { + uint16_t sample = list.data()[i + 1] + (list.data()[i] << 8); + if (greaseVal == sample) { + occurrence++; + } + } + } + return occurrence; +} + +class GreasePresenceAbsenceTestBase : public TlsConnectTestBase { + public: + GreasePresenceAbsenceTestBase(SSLProtocolVariant variant, uint16_t version, + bool shouldGrease) + : TlsConnectTestBase(variant, version), set_grease_(shouldGrease){}; + + void SetupGrease() { + EnsureTlsSetup(); + ASSERT_EQ(SSL_OptionSet(client_->ssl_fd(), SSL_ENABLE_GREASE, set_grease_), + SECSuccess); + ASSERT_EQ(SSL_OptionSet(server_->ssl_fd(), SSL_ENABLE_GREASE, set_grease_), + SECSuccess); + } + + bool expectGrease() { + return set_grease_ && version_ >= SSL_LIBRARY_VERSION_TLS_1_3; + } + + void checkGreasePresence(const int ifEnabled, const int ifDisabled, + const DataBuffer& buffer) { + size_t expected = expectGrease() ? size_t(ifEnabled) : size_t(ifDisabled); + EXPECT_EQ(expected, countGreaseInBuffer(buffer)); + } + + private: + bool set_grease_; +}; + +class GreasePresenceAbsenceTestAllVersions + : public GreasePresenceAbsenceTestBase, + public ::testing::WithParamInterface< + std::tuple<SSLProtocolVariant, uint16_t, bool>> { + public: + GreasePresenceAbsenceTestAllVersions() + : GreasePresenceAbsenceTestBase(std::get<0>(GetParam()), + std::get<1>(GetParam()), + std::get<2>(GetParam())){}; +}; + +// Varies stream/datagram, TLS Version and whether GREASE is enabled +INSTANTIATE_TEST_SUITE_P(GreaseTests, GreasePresenceAbsenceTestAllVersions, + ::testing::Combine(TlsConnectTestBase::kTlsVariantsAll, + TlsConnectTestBase::kTlsV11Plus, + ::testing::Values(true, false))); + +// Varies whether GREASE is enabled for TLS13 only +class GreasePresenceAbsenceTestTlsStream13 + : public GreasePresenceAbsenceTestBase, + public ::testing::WithParamInterface<bool> { + public: + GreasePresenceAbsenceTestTlsStream13() + : GreasePresenceAbsenceTestBase( + ssl_variant_stream, SSL_LIBRARY_VERSION_TLS_1_3, GetParam()){}; +}; + +INSTANTIATE_TEST_SUITE_P(GreaseTests, GreasePresenceAbsenceTestTlsStream13, + ::testing::Values(true, false)); + +// These tests check for the presence / absence of GREASE values in the various +// positions that we are permitted to add them. For positions which existed in +// prior versions of TLS, we check that enabling GREASE is only effective when +// negotiating TLS1.3 or higher and that disabling GREASE results in the absence +// of any GREASE values. +// For positions that specific to TLS1.3, we only check that enabling/disabling +// GREASE results in the correct presence/absence of the GREASE value. + +TEST_P(GreasePresenceAbsenceTestAllVersions, ClientGreaseCiphersuites) { + SetupGrease(); + + auto ch1 = MakeTlsFilter<ClientHelloCiphersuiteCapture>(client_); + Connect(); + EXPECT_TRUE(ch1->captured()); + + checkGreasePresence(1, 0, ch1->contents()); +} + +TEST_P(GreasePresenceAbsenceTestAllVersions, ClientGreaseNamedGroups) { + SetupGrease(); + + auto ch1 = + MakeTlsFilter<TlsExtensionCapture>(client_, ssl_supported_groups_xtn); + Connect(); + EXPECT_TRUE(ch1->captured()); + + checkGreasePresence(1, 0, ch1->extension()); +} + +TEST_P(GreasePresenceAbsenceTestAllVersions, ClientGreaseKeyShare) { + SetupGrease(); + + auto ch1 = + MakeTlsFilter<TlsExtensionCapture>(client_, ssl_tls13_key_share_xtn); + Connect(); + EXPECT_TRUE((version_ >= SSL_LIBRARY_VERSION_TLS_1_3) == ch1->captured()); + + checkGreasePresence(1, 0, ch1->extension()); +} + +TEST_P(GreasePresenceAbsenceTestAllVersions, ClientGreaseSigAlg) { + SetupGrease(); + + auto ch1 = + MakeTlsFilter<TlsExtensionCapture>(client_, ssl_signature_algorithms_xtn); + Connect(); + EXPECT_TRUE((version_ >= SSL_LIBRARY_VERSION_TLS_1_2) == ch1->captured()); + + checkGreasePresence(1, 0, ch1->extension()); +} + +TEST_P(GreasePresenceAbsenceTestAllVersions, ClientGreaseSupportedVersions) { + SetupGrease(); + + auto ch1 = MakeTlsFilter<TlsExtensionCapture>( + client_, ssl_tls13_supported_versions_xtn); + Connect(); + EXPECT_TRUE((version_ >= SSL_LIBRARY_VERSION_TLS_1_3) == ch1->captured()); + + // Supported Versions have a 1 byte length field. + TlsParser extParser(ch1->extension()); + DataBuffer versions; + extParser.ReadVariable(&versions, 1); + + checkGreasePresence(1, 0, versions); +} + +TEST_P(GreasePresenceAbsenceTestTlsStream13, ClientGreasePskExchange) { + SetupGrease(); + + auto ch1 = MakeTlsFilter<TlsExtensionCapture>( + client_, ssl_tls13_psk_key_exchange_modes_xtn); + Connect(); + EXPECT_TRUE(ch1->captured()); + + // PSK Exchange Modes have a 1 byte length field + TlsParser extParser(ch1->extension()); + DataBuffer modes; + extParser.ReadVariable(&modes, 1); + + // Scan for single byte GREASE PSK Values + size_t numGrease = 0; + for (uint8_t greaseVal : kTlsGreasePskValues) { + for (unsigned long i = 0; i < modes.len(); i++) { + if (greaseVal == modes.data()[i]) { + numGrease++; + } + } + } + + EXPECT_EQ(expectGrease() ? size_t(1) : size_t(0), numGrease); +} + +TEST_P(GreasePresenceAbsenceTestAllVersions, ClientGreaseAlpn) { + SetupGrease(); + EnableAlpn(); + + auto ch1 = + MakeTlsFilter<TlsExtensionCapture>(client_, ssl_app_layer_protocol_xtn); + Connect(); + EXPECT_TRUE((version_ >= SSL_LIBRARY_VERSION_TLS_1_1) == ch1->captured()); + + // ALPN Xtns have a redundant two-byte length + TlsParser alpnParser(ch1->extension()); + alpnParser.Skip(2); // Skip the length + DataBuffer alpnEntry; + + // Each ALPN entry has a single byte length prefixed. + size_t greaseAlpnEntrys = 0; + while (alpnParser.remaining()) { + alpnParser.ReadVariable(&alpnEntry, 1); + if (alpnEntry.len() == 2) { + greaseAlpnEntrys += countGreaseInBuffer(alpnEntry); + } + } + + EXPECT_EQ(expectGrease() ? size_t(1) : size_t(0), greaseAlpnEntrys); +} + +TEST_P(GreasePresenceAbsenceTestAllVersions, GreaseClientHelloExtension) { + SetupGrease(); + + auto ch1 = + MakeTlsFilter<TlsHandshakeRecorder>(client_, kTlsHandshakeClientHello); + Connect(); + EXPECT_TRUE(ch1->buffer().len() > 0); + + TlsParser extParser(ch1->buffer()); + EXPECT_TRUE(extParser.Skip(2 + 32)); // Version + Random + EXPECT_TRUE(extParser.SkipVariable(1)); // Session ID + if (variant_ == ssl_variant_datagram) { + EXPECT_TRUE(extParser.SkipVariable(1)); // Cookie + } + EXPECT_TRUE(extParser.SkipVariable(2)); // Ciphersuites + EXPECT_TRUE(extParser.SkipVariable(1)); // Compression Methods + EXPECT_TRUE(extParser.Skip(2)); // Extension Lengths + + // Scan for a 1-byte and a 0-byte extension. + uint32_t extType; + DataBuffer extBuf; + bool foundSmall = false; + bool foundLarge = false; + size_t numFound = 0; + while (extParser.remaining()) { + extParser.Read(&extType, 2); + extParser.ReadVariable(&extBuf, 2); + for (uint16_t greaseVal : kTlsGreaseValues) { + if (greaseVal == extType) { + numFound++; + foundSmall |= extBuf.len() == 0; + foundLarge |= extBuf.len() > 0; + } + } + } + + EXPECT_EQ(foundSmall, expectGrease()); + EXPECT_EQ(foundLarge, expectGrease()); + EXPECT_EQ(numFound, expectGrease() ? size_t(2) : size_t(0)); +} + +TEST_P(GreasePresenceAbsenceTestTlsStream13, GreaseCertificateRequestSigAlg) { + SetupGrease(); + client_->SetupClientAuth(); + server_->RequestClientAuth(true); + + auto cr = + MakeTlsFilter<TlsExtensionCapture>(server_, ssl_signature_algorithms_xtn); + cr->SetHandshakeTypes({kTlsHandshakeCertificateRequest}); + cr->EnableDecryption(); + Connect(); + EXPECT_TRUE(cr->captured()); + + checkGreasePresence(1, 0, cr->extension()); +} + +TEST_P(GreasePresenceAbsenceTestTlsStream13, + GreaseCertificateRequestExtension) { + SetupGrease(); + client_->SetupClientAuth(); + server_->RequestClientAuth(true); + + auto cr = MakeTlsFilter<TlsHandshakeRecorder>( + server_, kTlsHandshakeCertificateRequest); + cr->EnableDecryption(); + Connect(); + EXPECT_TRUE(cr->buffer().len() > 0); + + TlsParser extParser(cr->buffer()); + EXPECT_TRUE(extParser.SkipVariable(1)); // Context + EXPECT_TRUE(extParser.Skip(2)); // Extension Lengths + + uint32_t extType; + DataBuffer extBuf; + bool found = false; + // Scan for a single, empty extension + while (extParser.remaining()) { + extParser.Read(&extType, 2); + extParser.ReadVariable(&extBuf, 2); + for (uint16_t greaseVal : kTlsGreaseValues) { + if (greaseVal == extType) { + EXPECT_TRUE(!found); + EXPECT_EQ(extBuf.len(), size_t(0)); + found = true; + } + } + } + + EXPECT_EQ(expectGrease(), found); +} + +TEST_P(GreasePresenceAbsenceTestTlsStream13, GreaseNewSessionTicketExtension) { + SetupGrease(); + + auto nst = MakeTlsFilter<TlsHandshakeRecorder>(server_, + kTlsHandshakeNewSessionTicket); + nst->EnableDecryption(); + Connect(); + EXPECT_EQ(SECSuccess, SSL_SendSessionTicket(server_->ssl_fd(), nullptr, 0)); + EXPECT_TRUE(nst->buffer().len() > 0); + + TlsParser extParser(nst->buffer()); + EXPECT_TRUE(extParser.Skip(4)); // lifetime + EXPECT_TRUE(extParser.Skip(4)); // age + EXPECT_TRUE(extParser.SkipVariable(1)); // Nonce + EXPECT_TRUE(extParser.SkipVariable(2)); // Ticket + EXPECT_TRUE(extParser.Skip(2)); // Extension Length + + uint32_t extType; + DataBuffer extBuf; + bool found = false; + // Scan for a single, empty extension + while (extParser.remaining()) { + extParser.Read(&extType, 2); + extParser.ReadVariable(&extBuf, 2); + for (uint16_t greaseVal : kTlsGreaseValues) { + if (greaseVal == extType) { + EXPECT_TRUE(!found); + EXPECT_EQ(extBuf.len(), size_t(0)); + found = true; + } + } + } + + EXPECT_EQ(expectGrease(), found); +} + +// Generic Client GREASE test +TEST_P(TlsConnectGeneric, ClientGrease) { + EnsureTlsSetup(); + ASSERT_EQ(SSL_OptionSet(client_->ssl_fd(), SSL_ENABLE_GREASE, PR_TRUE), + SECSuccess); + Connect(); +} + +// Generic Server GREASE test +TEST_P(TlsConnectGeneric, ServerGrease) { + EnsureTlsSetup(); + ASSERT_EQ(SSL_OptionSet(server_->ssl_fd(), SSL_ENABLE_GREASE, PR_TRUE), + SECSuccess); + Connect(); +} + +// Generic GREASE test +TEST_P(TlsConnectGeneric, Grease) { + EnsureTlsSetup(); + ASSERT_EQ(SSL_OptionSet(client_->ssl_fd(), SSL_ENABLE_GREASE, PR_TRUE), + SECSuccess); + ASSERT_EQ(SSL_OptionSet(server_->ssl_fd(), SSL_ENABLE_GREASE, PR_TRUE), + SECSuccess); + Connect(); +} + +// Check that GREASE values can be correctly reconstructed after HRR. +TEST_P(TlsConnectGeneric, GreaseHRR) { + EnsureTlsSetup(); + const std::vector<SSLNamedGroup> client_groups = { + ssl_grp_ec_curve25519, ssl_grp_ec_secp256r1, ssl_grp_ec_secp384r1}; + const std::vector<SSLNamedGroup> server_groups = { + ssl_grp_ec_secp256r1, ssl_grp_ec_secp384r1, ssl_grp_ec_curve25519}; + client_->ConfigNamedGroups(client_groups); + server_->ConfigNamedGroups(server_groups); + ASSERT_EQ(SSL_OptionSet(client_->ssl_fd(), SSL_ENABLE_GREASE, PR_TRUE), + SECSuccess); + ASSERT_EQ(SSL_OptionSet(server_->ssl_fd(), SSL_ENABLE_GREASE, PR_TRUE), + SECSuccess); + Connect(); +} + +// Check that GREASE additions interact correctly with psk-only handshake. +TEST_F(TlsConnectStreamTls13, GreasePsk) { + EnsureTlsSetup(); + ASSERT_EQ(SSL_OptionSet(client_->ssl_fd(), SSL_ENABLE_GREASE, PR_TRUE), + SECSuccess); + ASSERT_EQ(SSL_OptionSet(server_->ssl_fd(), SSL_ENABLE_GREASE, PR_TRUE), + SECSuccess); + + ScopedPK11SlotInfo slot(PK11_GetInternalSlot()); + const uint8_t kPskDummyVal_[16] = {0x01, 0x02, 0x03, 0x04, 0x05, + 0x06, 0x07, 0x08, 0x09, 0x0a, + 0x0b, 0x0c, 0x0d, 0x0e, 0x0f}; + SECItem psk_item; + psk_item.type = siBuffer; + psk_item.len = sizeof(kPskDummyVal_); + psk_item.data = const_cast<uint8_t*>(kPskDummyVal_); + PK11SymKey* key = + PK11_ImportSymKey(slot.get(), CKM_HKDF_KEY_GEN, PK11_OriginUnwrap, + CKA_DERIVE, &psk_item, NULL); + + ScopedPK11SymKey scoped_psk_(key); + const std::string kPskDummyLabel_ = "NSS PSK GTEST label"; + const SSLHashType kPskHash_ = ssl_hash_sha384; + AddPsk(scoped_psk_, kPskDummyLabel_, kPskHash_); + + Connect(); + SendReceive(); + CheckKeys(ssl_kea_ecdh, ssl_grp_ec_curve25519, ssl_auth_psk, ssl_sig_none); +} + +// Test that ECH and GREASE work together successfully +TEST_F(TlsConnectStreamTls13, GreaseAndECH) { + EnsureTlsSetup(); + SetupEch(client_, server_); + ASSERT_EQ(SSL_OptionSet(client_->ssl_fd(), SSL_ENABLE_GREASE, PR_TRUE), + SECSuccess); + ASSERT_EQ(SSL_OptionSet(server_->ssl_fd(), SSL_ENABLE_GREASE, PR_TRUE), + SECSuccess); + Connect(); +} + +// Test that TLS12 Server handles Client GREASE correctly +TEST_F(TlsConnectTest, GreaseTLS12Server) { + EnsureTlsSetup(); + ASSERT_EQ(SSL_OptionSet(client_->ssl_fd(), SSL_ENABLE_GREASE, PR_TRUE), + SECSuccess); + server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2, + SSL_LIBRARY_VERSION_TLS_1_2); + client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2, + SSL_LIBRARY_VERSION_TLS_1_3); + Connect(); +} + +// Test that TLS12 Client handles Server GREASE correctly +TEST_F(TlsConnectTest, GreaseTLS12Client) { + EnsureTlsSetup(); + ASSERT_EQ(SSL_OptionSet(server_->ssl_fd(), SSL_ENABLE_GREASE, PR_TRUE), + SECSuccess); + server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2, + SSL_LIBRARY_VERSION_TLS_1_3); + client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2, + SSL_LIBRARY_VERSION_TLS_1_2); + Connect(); +} + +class GreaseOnlyTestStreamTls13 : public TlsConnectStreamTls13 { + public: + GreaseOnlyTestStreamTls13() : TlsConnectStreamTls13() {} + + void ConnectWithCustomChExpectFail(const std::string& ch, + uint8_t server_alert, uint32_t server_code, + uint32_t client_code) { + std::vector<uint8_t> ch_vec = hex_string_to_bytes(ch); + DataBuffer ch_buf; + EnsureTlsSetup(); + + TlsAgentTestBase::MakeRecord(variant_, ssl_ct_handshake, + SSL_LIBRARY_VERSION_TLS_1_3, ch_vec.data(), + ch_vec.size(), &ch_buf, 0); + StartConnect(); + client_->SendDirect(ch_buf); + ExpectAlert(server_, server_alert); + server_->Handshake(); + server_->CheckErrorCode(server_code); + client_->ExpectReceiveAlert(server_alert, kTlsAlertFatal); + client_->Handshake(); + client_->CheckErrorCode(client_code); + } +}; + +// Client: Offer only GREASE CipherSuite value +TEST_F(GreaseOnlyTestStreamTls13, GreaseOnlyClientCipherSuite) { + // 0xdada + std::string ch = + "010000b003038afacda2963358e98f464f3ff0680ed3a9d382a8c3eac5e5604f5721add9" + "855c000002dada010000850000000b0009000006736572766572ff01000100000a001400" + "12001d00170018001901000101010201030104003300260024001d0020683668992de470" + "38660ee37bafc7392b05b8a94402ea1f3463ad3cfd7a694a46002b0003020304000d0018" + "001604030503060302030804080508060401050106010201002d00020101001c0002400" + "1"; + + ConnectWithCustomChExpectFail(ch, kTlsAlertHandshakeFailure, + SSL_ERROR_NO_CYPHER_OVERLAP, + SSL_ERROR_NO_CYPHER_OVERLAP); +} + +// Client: Offer only GREASE SupportedGroups value +TEST_F(GreaseOnlyTestStreamTls13, GreaseOnlyClientSupportedGroup) { + // 0x3a3a + std::string ch = + "010000a40303484a4e14f547404da6115d7f73bbb0f1c9d65e66ac073dee6c4a62f72de9" + "a36f000006130113031302010000750000000b0009000006736572766572ff0100010000" + "0a000400023a3a003300260024001d0020e75cb8e217c95176954e8b5fb95843882462ce" + "2cd3fcfe67cf31463a05ea3d57002b0003020304000d0018001604030503060302030804" + "080508060401050106010201002d00020101001c00024001"; + + ConnectWithCustomChExpectFail(ch, kTlsAlertHandshakeFailure, + SSL_ERROR_NO_CYPHER_OVERLAP, + SSL_ERROR_NO_CYPHER_OVERLAP); +} + +// Client: Offer only GREASE SigAlgs value +TEST_F(GreaseOnlyTestStreamTls13, GreaseOnlyClientSignatureAlgorithm) { + // 0x8a8a + std::string ch = + "010000a00303dfd8e2438a8d1b9f48d921dfc08959108807bd1105238bb3da2a2a8e3db0" + "6990000006130113031302010000710000000b0009000006736572766572ff0100010000" + "0a00140012001d00170018001901000101010201030104003300260024001d002074bb2c" + "94996d3ffc7ae5792f0c3c58676358a85ea304cd029fa3d6551013b333002b0003020304" + "000d000400028a8a002d00020101001c00024001"; + + ConnectWithCustomChExpectFail(ch, kTlsAlertHandshakeFailure, + SSL_ERROR_UNSUPPORTED_SIGNATURE_ALGORITHM, + SSL_ERROR_NO_CYPHER_OVERLAP); +} + +// Client: Offer only GREASE SupportedVersions value +TEST_F(GreaseOnlyTestStreamTls13, GreaseOnlyClientSupportedVersion) { + // 0xeaea + std::string ch = + "010000b203037e3618abae0dd0b3f06a504c47354551d1d5be36e9c3e1eac9c139c246b1" + "66da000006130113031302010000830000000b0009000006736572766572ff0100010000" + "0a00140012001d00170018001901000101010201030104003300260024001d00206b1816" + "577ff2e69d4d2661419150eaefa0328ffd396425cf1733ec06536b4e55002b000100000d" + "0018001604030503060302030804080508060401050106010201002d00020101001c0002" + "4001"; + + ConnectWithCustomChExpectFail(ch, kTlsAlertIllegalParameter, + SSL_ERROR_RX_MALFORMED_CLIENT_HELLO, + SSL_ERROR_ILLEGAL_PARAMETER_ALERT); +} + +class GreaseTestStreamTls12 + : public TlsConnectStreamTls12, + public ::testing::WithParamInterface<uint16_t /* GREASE */> { + public: + GreaseTestStreamTls12() : TlsConnectStreamTls12(), grease_(GetParam()){}; + + void ConnectExpectSigAlgFail() { + client_->ExpectSendAlert(kTlsAlertIllegalParameter); + server_->ExpectReceiveAlert(kTlsAlertIllegalParameter); + ConnectExpectFail(); + client_->CheckErrorCode(SSL_ERROR_UNSUPPORTED_SIGNATURE_ALGORITHM); + server_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT); + } + + protected: + uint16_t grease_; +}; + +class TlsCertificateRequestSigAlgSetterFilter : public TlsHandshakeFilter { + public: + TlsCertificateRequestSigAlgSetterFilter(const std::shared_ptr<TlsAgent>& a, + uint16_t sigAlg) + : TlsHandshakeFilter(a, {kTlsHandshakeCertificateRequest}), + sigAlg_(sigAlg) {} + virtual PacketFilter::Action FilterHandshake( + const TlsHandshakeFilter::HandshakeHeader& header, + const DataBuffer& input, DataBuffer* output) { + TlsParser parser(input); + DataBuffer cert_types; + if (!parser.ReadVariable(&cert_types, 1)) { + ADD_FAILURE(); + return KEEP; + } + + if (!parser.SkipVariable(2)) { + ADD_FAILURE(); + return KEEP; + } + + DataBuffer cas; + if (!parser.ReadVariable(&cas, 2)) { + ADD_FAILURE(); + return KEEP; + } + + size_t idx = 0; + + // Write certificate types. + idx = output->Write(idx, cert_types.len(), 1); + idx = output->Write(idx, cert_types); + + // Write signature algorithm. + idx = output->Write(idx, sizeof(sigAlg_), 2); + idx = output->Write(idx, sigAlg_, 2); + + // Write certificate authorities. + idx = output->Write(idx, cas.len(), 2); + idx = output->Write(idx, cas); + + return CHANGE; + } + + private: + uint16_t sigAlg_; +}; + +// Server: Offer only GREASE CertificateRequest SigAlg value +TEST_P(GreaseTestStreamTls12, GreaseOnlyServerTLS12CertificateRequestSigAlg) { + EnsureTlsSetup(); + client_->SetupClientAuth(); + server_->RequestClientAuth(true); + MakeTlsFilter<TlsCertificateRequestSigAlgSetterFilter>(server_, grease_); + + client_->ExpectSendAlert(kTlsAlertHandshakeFailure); + server_->ExpectReceiveAlert(kTlsAlertHandshakeFailure); + ConnectExpectFail(); + server_->CheckErrorCode(SSL_ERROR_HANDSHAKE_FAILURE_ALERT); + client_->CheckErrorCode(SSL_ERROR_UNSUPPORTED_SIGNATURE_ALGORITHM); +} + +// Illegally GREASE ServerKeyExchange ECC SignatureAlgorithm +TEST_P(GreaseTestStreamTls12, GreasedTLS12ServerKexEccSigAlg) { + MakeTlsFilter<ECCServerKEXSigAlgReplacer>(server_, grease_); + EnableSomeEcdhCiphers(); + + client_->ExpectSendAlert(kTlsAlertIllegalParameter); + server_->ExpectReceiveAlert(kTlsAlertIllegalParameter); + ConnectExpectFail(); + client_->CheckErrorCode(SSL_ERROR_UNSUPPORTED_SIGNATURE_ALGORITHM); + server_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT); +} + +// Illegally GREASE ServerKeyExchange DHE SignatureAlgorithm +TEST_P(GreaseTestStreamTls12, GreasedTLS12ServerKexDheSigAlg) { + MakeTlsFilter<DHEServerKEXSigAlgReplacer>(server_, grease_); + EnableOnlyDheCiphers(); + + client_->ExpectSendAlert(kTlsAlertIllegalParameter); + server_->ExpectReceiveAlert(kTlsAlertIllegalParameter); + ConnectExpectFail(); + client_->CheckErrorCode(SSL_ERROR_UNSUPPORTED_SIGNATURE_ALGORITHM); + server_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT); +} + +// Illegally GREASE ServerKeyExchange ECDHE NamedCurve +TEST_P(GreaseTestStreamTls12, GreasedTLS12ServerKexEcdheNamedCurve) { + MakeTlsFilter<ECCServerKEXNamedCurveReplacer>(server_, grease_); + EnableSomeEcdhCiphers(); + + client_->ExpectSendAlert(kTlsAlertHandshakeFailure); + server_->ExpectReceiveAlert(kTlsAlertHandshakeFailure); + ConnectExpectFail(); + server_->CheckErrorCode(SSL_ERROR_HANDSHAKE_FAILURE_ALERT); + client_->CheckErrorCode(SEC_ERROR_UNSUPPORTED_ELLIPTIC_CURVE); +} + +// Illegally GREASE TLS12 Client CertificateVerify SignatureAlgorithm +TEST_P(GreaseTestStreamTls12, GreasedTLS12ClientCertificateVerifySigAlg) { + client_->SetupClientAuth(); + server_->RequestClientAuth(true); + MakeTlsFilter<TlsReplaceSignatureSchemeFilter>(client_, grease_); + + server_->ExpectSendAlert(kTlsAlertIllegalParameter); + client_->ExpectReceiveAlert(kTlsAlertIllegalParameter); + ConnectExpectFail(); + client_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT); + server_->CheckErrorCode(SSL_ERROR_UNSUPPORTED_SIGNATURE_ALGORITHM); +} + +class GreaseTestStreamTls13 + : public TlsConnectStreamTls13, + public ::testing::WithParamInterface<uint16_t /* GREASE */> { + public: + GreaseTestStreamTls13() : grease_(GetParam()){}; + + protected: + uint16_t grease_; +}; + +// Illegally GREASE TLS13 Client CertificateVerify SignatureAlgorithm +TEST_P(GreaseTestStreamTls13, GreasedTLS13ClientCertificateVerifySigAlg) { + client_->SetupClientAuth(); + server_->RequestClientAuth(true); + auto filter = + MakeTlsFilter<TlsReplaceSignatureSchemeFilter>(client_, grease_); + filter->EnableDecryption(); + + server_->ExpectSendAlert(kTlsAlertIllegalParameter); + client_->ExpectReceiveAlert(kTlsAlertIllegalParameter); + + // Manually trigger handshake to avoid race conditions + StartConnect(); + client_->Handshake(); + server_->Handshake(); + client_->Handshake(); + server_->Handshake(); + client_->Handshake(); + + server_->CheckErrorCode(SSL_ERROR_RX_MALFORMED_CERT_VERIFY); + client_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT); +} + +// Illegally GREASE TLS13 Server CertificateVerify SignatureAlgorithm +TEST_P(GreaseTestStreamTls13, GreasedTLS13ServerCertificateVerifySigAlg) { + EnsureTlsSetup(); + auto filter = + MakeTlsFilter<TlsReplaceSignatureSchemeFilter>(server_, grease_); + filter->EnableDecryption(); + + client_->ExpectSendAlert(kTlsAlertIllegalParameter); + server_->ExpectReceiveAlert(kTlsAlertIllegalParameter); + ConnectExpectFail(); + client_->CheckErrorCode(SSL_ERROR_RX_MALFORMED_CERT_VERIFY); +} + +// Illegally GREASE HelloRetryRequest version value +TEST_P(GreaseTestStreamTls13, GreasedHelloRetryRequestVersion) { + EnsureTlsSetup(); + // Trigger HelloRetryRequest + MakeTlsFilter<TlsExtensionDropper>(client_, ssl_tls13_key_share_xtn); + auto filter = MakeTlsFilter<TlsMessageVersionSetter>( + server_, kTlsHandshakeHelloRetryRequest, grease_); + filter->EnableDecryption(); + + client_->ExpectSendAlert(kTlsAlertIllegalParameter); + server_->ExpectReceiveAlert(kTlsAlertIllegalParameter); + ConnectExpectFail(); + client_->CheckErrorCode(SSL_ERROR_RX_MALFORMED_SERVER_HELLO); + server_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT); +} + +class GreaseTestStreamTls123 + : public TlsConnectTestBase, + public ::testing::WithParamInterface< + std::tuple<uint16_t /* version */, uint16_t /* GREASE */>> { + public: + GreaseTestStreamTls123() + : TlsConnectTestBase(ssl_variant_stream, std::get<0>(GetParam())), + grease_(std::get<1>(GetParam())){}; + + void ConnectExpectIllegalGreaseFail() { + client_->ExpectSendAlert(kTlsAlertIllegalParameter); + if (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) { + // Server expects handshake but receives encrypted alert. + server_->ExpectSendAlert(kTlsAlertUnexpectedMessage); + } else { + server_->ExpectReceiveAlert(kTlsAlertIllegalParameter); + } + ConnectExpectFail(); + } + + protected: + uint16_t grease_; +}; + +// Illegally GREASE TLS12 and TLS13 ServerHello version value +TEST_P(GreaseTestStreamTls123, GreasedServerHelloVersion) { + EnsureTlsSetup(); + auto filter = MakeTlsFilter<TlsMessageVersionSetter>( + server_, kTlsHandshakeServerHello, grease_); + if (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) { + filter->EnableDecryption(); + } + ConnectExpectIllegalGreaseFail(); + client_->CheckErrorCode(SSL_ERROR_RX_MALFORMED_SERVER_HELLO); +} + +// Illegally GREASE TLS12 and TLS13 selected CipherSuite value +TEST_P(GreaseTestStreamTls123, GreasedServerHelloCipherSuite) { + EnsureTlsSetup(); + auto filter = MakeTlsFilter<SelectedCipherSuiteReplacer>(server_, grease_); + if (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) { + filter->EnableDecryption(); + } + ConnectExpectIllegalGreaseFail(); + client_->CheckErrorCode(SSL_ERROR_NO_CYPHER_OVERLAP); +} + +class GreaseExtensionTestStreamTls13 + : public TlsConnectStreamTls13, + public ::testing::WithParamInterface< + std::tuple<uint8_t /* message */, uint16_t /* GREASE */>> { + public: + GreaseExtensionTestStreamTls13() + : TlsConnectStreamTls13(), + message_(std::get<0>(GetParam())), + grease_(std::get<1>(GetParam())){}; + + protected: + uint8_t message_; + uint16_t grease_; +}; + +// Illegally GREASE TLS13 Server EncryptedExtensions and Certificate Extensions +// NSS currently allows offering unkown extensions in HelloRetryRequests! +TEST_P(GreaseExtensionTestStreamTls13, GreasedServerExtensions) { + EnsureTlsSetup(); + DataBuffer empty = DataBuffer(1); + auto filter = + MakeTlsFilter<TlsExtensionAppender>(server_, message_, grease_, empty); + filter->EnableDecryption(); + + server_->ExpectReceiveAlert(kTlsAlertUnsupportedExtension); + client_->ExpectSendAlert(kTlsAlertUnsupportedExtension); + ConnectExpectFail(); + client_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_EXTENSION); + server_->CheckErrorCode(SSL_ERROR_UNSUPPORTED_EXTENSION_ALERT); +} + +// Illegally GREASE TLS12 and TLS13 ServerHello Extensions +TEST_P(GreaseTestStreamTls123, GreasedServerHelloExtensions) { + EnsureTlsSetup(); + DataBuffer empty = DataBuffer(1); + auto filter = MakeTlsFilter<TlsExtensionAppender>( + server_, kTlsHandshakeServerHello, grease_, empty); + + if (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) { + filter->EnableDecryption(); + server_->ExpectSendAlert(kTlsAlertUnexpectedMessage); + } else { + server_->ExpectReceiveAlert(kTlsAlertUnsupportedExtension); + } + client_->ExpectSendAlert(kTlsAlertUnsupportedExtension); + ConnectExpectFail(); + client_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_EXTENSION); +} + +// Illegally GREASE TLS13 Client Certificate Extensions +// Server ignores injected client extensions and fails on CertificateVerify +TEST_P(GreaseTestStreamTls13, GreasedClientCertificateExtensions) { + client_->SetupClientAuth(); + server_->RequestClientAuth(true); + DataBuffer empty = DataBuffer(1); + auto filter = MakeTlsFilter<TlsExtensionAppender>( + client_, kTlsHandshakeCertificate, grease_, empty); + filter->EnableDecryption(); + + server_->ExpectSendAlert(kTlsAlertDecryptError); + client_->ExpectReceiveAlert(kTlsAlertDecryptError); + + // Manually trigger handshake to avoid race conditions + StartConnect(); + client_->Handshake(); + server_->Handshake(); + client_->Handshake(); + server_->Handshake(); + client_->Handshake(); + + server_->CheckErrorCode(SEC_ERROR_BAD_SIGNATURE); + client_->CheckErrorCode(SSL_ERROR_DECRYPT_ERROR_ALERT); +} + +TEST_F(TlsConnectStreamTls13, GreaseClientHelloExtensionPermutation) { + EnsureTlsSetup(); + PR_ASSERT(SSL_OptionSet(client_->ssl_fd(), + SSL_ENABLE_CH_EXTENSION_PERMUTATION, + PR_TRUE) == SECSuccess); + PR_ASSERT(SSL_OptionSet(client_->ssl_fd(), SSL_ENABLE_GREASE, PR_TRUE) == + SECSuccess); + Connect(); +} + +INSTANTIATE_TEST_SUITE_P(GreaseTestTls12, GreaseTestStreamTls12, + ::testing::ValuesIn(kTlsGreaseValues)); + +INSTANTIATE_TEST_SUITE_P(GreaseTestTls13, GreaseTestStreamTls13, + ::testing::ValuesIn(kTlsGreaseValues)); + +INSTANTIATE_TEST_SUITE_P( + GreaseTestTls123, GreaseTestStreamTls123, + ::testing::Combine(TlsConnectTestBase::kTlsV12Plus, + ::testing::ValuesIn(kTlsGreaseValues))); + +INSTANTIATE_TEST_SUITE_P( + GreaseExtensionTest, GreaseExtensionTestStreamTls13, + testing::Combine(testing::ValuesIn(kTlsGreaseExtensionMessages), + testing::ValuesIn(kTlsGreaseValues))); + +} // namespace nss_test diff --git a/security/nss/gtests/ssl_gtest/tls_hkdf_unittest.cc b/security/nss/gtests/ssl_gtest/tls_hkdf_unittest.cc new file mode 100644 index 0000000000..3e1e30bb86 --- /dev/null +++ b/security/nss/gtests/ssl_gtest/tls_hkdf_unittest.cc @@ -0,0 +1,433 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set ts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this file, + * You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#include <memory> +#include "nss.h" +#include "pk11pub.h" +#include "secerr.h" +#include "sslproto.h" +#include "sslexp.h" +#include "tls13hkdf.h" + +#include "databuffer.h" +#include "gtest_utils.h" +#include "nss_scoped_ptrs.h" + +namespace nss_test { + +const uint8_t kKey1Data[] = { + 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, + 0x0c, 0x0d, 0x0e, 0x0f, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, + 0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f, 0x20, 0x21, 0x22, 0x23, + 0x24, 0x25, 0x26, 0x27, 0x28, 0x29, 0x2a, 0x2b, 0x2c, 0x2d, 0x2e, 0x2f}; +const DataBuffer kKey1(kKey1Data, sizeof(kKey1Data)); + +// The same as key1 but with the first byte +// 0x01. +const uint8_t kKey2Data[] = { + 0x01, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, + 0x0c, 0x0d, 0x0e, 0x0f, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, + 0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f, 0x20, 0x21, 0x22, 0x23, + 0x24, 0x25, 0x26, 0x27, 0x28, 0x29, 0x2a, 0x2b, 0x2c, 0x2d, 0x2e, 0x2f}; +const DataBuffer kKey2(kKey2Data, sizeof(kKey2Data)); + +const char kLabelMasterSecret[] = "master secret"; + +const uint8_t kSessionHash[] = { + 0xf0, 0xf1, 0xf2, 0xf3, 0xf4, 0xf5, 0xf6, 0xf7, 0xf8, 0xf9, 0xfa, 0xfb, + 0xfc, 0xfd, 0xfe, 0xff, 0xe0, 0xe1, 0xe2, 0xe3, 0xe4, 0xe5, 0xe6, 0xe7, + 0xe8, 0xe9, 0xea, 0xeb, 0xec, 0xed, 0xee, 0xef, 0xd0, 0xd1, 0xd2, 0xd3, + 0xd4, 0xd5, 0xd6, 0xd7, 0xd8, 0xd9, 0xda, 0xdb, 0xdc, 0xdd, 0xde, 0xdf, + 0xe0, 0xe1, 0xe2, 0xe3, 0xe4, 0xe5, 0xe6, 0xe7, 0xe8, 0xe9, 0xea, 0xeb, + 0xec, 0xed, 0xee, 0xef, 0xf0, 0xf1, 0xf2, 0xf3, 0xf4, 0xf5, 0xf6, 0xf7, + 0xf8, 0xf9, 0xfa, 0xfb, 0xfc, 0xfd, 0xfe, 0xff, 0xe0, 0xe1, 0xe2, 0xe3, + 0xe4, 0xe5, 0xe6, 0xe7, 0xe8, 0xe9, 0xea, 0xeb, 0xec, 0xed, 0xee, 0xef, + 0xf0, 0xf1, 0xf2, 0xf3, 0xf4, 0xf5, 0xf6, 0xf7, 0xf8, 0xf9, 0xfa, 0xfb, + 0xfc, 0xfd, 0xfe, 0xff, +}; + +const size_t kHashLength[] = { + 0, /* ssl_hash_none */ + 16, /* ssl_hash_md5 */ + 20, /* ssl_hash_sha1 */ + 28, /* ssl_hash_sha224 */ + 32, /* ssl_hash_sha256 */ + 48, /* ssl_hash_sha384 */ + 64, /* ssl_hash_sha512 */ +}; + +size_t GetHashLength(SSLHashType hash) { + size_t i = static_cast<size_t>(hash); + if (i < PR_ARRAY_SIZE(kHashLength)) { + return kHashLength[i]; + } + ADD_FAILURE() << "Unknown hash: " << hash; + return 0; +} + +PRUint16 GetSomeCipherSuiteForHash(SSLHashType hash) { + switch (hash) { + case ssl_hash_sha256: + return TLS_AES_128_GCM_SHA256; + case ssl_hash_sha384: + return TLS_AES_256_GCM_SHA384; + default: + ADD_FAILURE() << "Unknown hash: " << hash; + } + return 0; +} + +const std::string kHashName[] = {"None", "MD5", "SHA-1", "SHA-224", + "SHA-256", "SHA-384", "SHA-512"}; + +static void ImportKey(ScopedPK11SymKey* to, const DataBuffer& key, + SSLHashType hash_type, PK11SlotInfo* slot) { + ASSERT_LT(hash_type, sizeof(kHashLength)); + ASSERT_LE(kHashLength[hash_type], key.len()); + SECItem key_item = {siBuffer, const_cast<uint8_t*>(key.data()), + static_cast<unsigned int>(GetHashLength(hash_type))}; + + PK11SymKey* inner = + PK11_ImportSymKey(slot, CKM_SSL3_MASTER_KEY_DERIVE, PK11_OriginUnwrap, + CKA_DERIVE, &key_item, NULL); + ASSERT_NE(nullptr, inner); + to->reset(inner); +} + +static void DumpData(const std::string& label, const uint8_t* buf, size_t len) { + DataBuffer d(buf, len); + + std::cerr << label << ": " << d << std::endl; +} + +void DumpKey(const std::string& label, ScopedPK11SymKey& key) { + SECStatus rv = PK11_ExtractKeyValue(key.get()); + ASSERT_EQ(SECSuccess, rv); + + SECItem* key_data = PK11_GetKeyData(key.get()); + ASSERT_NE(nullptr, key_data); + + DumpData(label, key_data->data, key_data->len); +} + +extern "C" { +extern char ssl_trace; +extern FILE* ssl_trace_iob; +} + +class TlsHkdfTest : public ::testing::Test, + public ::testing::WithParamInterface<SSLHashType> { + public: + TlsHkdfTest() + : k1_(), k2_(), hash_type_(GetParam()), slot_(PK11_GetInternalSlot()) { + EXPECT_NE(nullptr, slot_); + char* ev = getenv("SSLTRACE"); + if (ev && ev[0]) { + ssl_trace = atoi(ev); + ssl_trace_iob = stderr; + } + } + + void SetUp() { + ImportKey(&k1_, kKey1, hash_type_, slot_.get()); + ImportKey(&k2_, kKey2, hash_type_, slot_.get()); + } + + void VerifyKey(const ScopedPK11SymKey& key, CK_MECHANISM_TYPE expected_mech, + const DataBuffer& expected_value) { + EXPECT_EQ(expected_mech, PK11_GetMechanism(key.get())); + + SECStatus rv = PK11_ExtractKeyValue(key.get()); + ASSERT_EQ(SECSuccess, rv); + + SECItem* key_data = PK11_GetKeyData(key.get()); + ASSERT_NE(nullptr, key_data); + + EXPECT_EQ(expected_value.len(), key_data->len); + EXPECT_EQ( + 0, memcmp(expected_value.data(), key_data->data, expected_value.len())); + } + + void HkdfExtract(const ScopedPK11SymKey& ikmk1, const ScopedPK11SymKey& ikmk2, + SSLHashType base_hash, const DataBuffer& expected) { + std::cerr << "Hash = " << kHashName[base_hash] << std::endl; + + PK11SymKey* prk = nullptr; + SECStatus rv = tls13_HkdfExtract(ikmk1.get(), ikmk2.get(), base_hash, &prk); + ASSERT_EQ(SECSuccess, rv); + ScopedPK11SymKey prkk(prk); + + DumpKey("Output", prkk); + VerifyKey(prkk, CKM_HKDF_DERIVE, expected); + + // Now test the public wrapper. + PRUint16 cs = GetSomeCipherSuiteForHash(base_hash); + rv = SSL_HkdfExtract(SSL_LIBRARY_VERSION_TLS_1_3, cs, ikmk1.get(), + ikmk2.get(), &prk); + ASSERT_EQ(SECSuccess, rv); + ASSERT_NE(nullptr, prk); + VerifyKey(ScopedPK11SymKey(prk), CKM_HKDF_DERIVE, expected); + } + + void HkdfExpandLabel(ScopedPK11SymKey* prk, SSLHashType base_hash, + const uint8_t* session_hash, size_t session_hash_len, + const char* label, size_t label_len, + const DataBuffer& expected) { + ASSERT_NE(nullptr, prk); + std::cerr << "Hash = " << kHashName[base_hash] << std::endl; + + std::vector<uint8_t> output(expected.len()); + + SECStatus rv = tls13_HkdfExpandLabelRaw( + prk->get(), base_hash, session_hash, session_hash_len, label, label_len, + ssl_variant_stream, &output[0], output.size()); + ASSERT_EQ(SECSuccess, rv); + DumpData("Output", &output[0], output.size()); + EXPECT_EQ(0, memcmp(expected.data(), &output[0], expected.len())); + + // Verify that the public API produces the same result. + PRUint16 cs = GetSomeCipherSuiteForHash(base_hash); + PK11SymKey* secret; + rv = SSL_HkdfExpandLabel(SSL_LIBRARY_VERSION_TLS_1_3, cs, prk->get(), + session_hash, session_hash_len, label, label_len, + &secret); + EXPECT_EQ(SECSuccess, rv); + ASSERT_NE(nullptr, secret); + VerifyKey(ScopedPK11SymKey(secret), CKM_HKDF_DERIVE, expected); + + // Verify that a key can be created with a different key type and size. + rv = SSL_HkdfExpandLabelWithMech( + SSL_LIBRARY_VERSION_TLS_1_3, cs, prk->get(), session_hash, + session_hash_len, label, label_len, CKM_DES3_CBC_PAD, 24, &secret); + EXPECT_EQ(SECSuccess, rv); + ASSERT_NE(nullptr, secret); + ScopedPK11SymKey with_mech(secret); + EXPECT_EQ(static_cast<CK_MECHANISM_TYPE>(CKM_DES3_CBC_PAD), + PK11_GetMechanism(with_mech.get())); + // Just verify that the key is the right size. + rv = PK11_ExtractKeyValue(with_mech.get()); + ASSERT_EQ(SECSuccess, rv); + SECItem* key_data = PK11_GetKeyData(with_mech.get()); + ASSERT_NE(nullptr, key_data); + EXPECT_EQ(24U, key_data->len); + } + + protected: + ScopedPK11SymKey k1_; + ScopedPK11SymKey k2_; + SSLHashType hash_type_; + + private: + ScopedPK11SlotInfo slot_; +}; + +TEST_P(TlsHkdfTest, HkdfNullNull) { + const uint8_t tv[][48] = { + {/* ssl_hash_none */}, + {/* ssl_hash_md5 */}, + {/* ssl_hash_sha1 */}, + {/* ssl_hash_sha224 */}, + {0x33, 0xad, 0x0a, 0x1c, 0x60, 0x7e, 0xc0, 0x3b, 0x09, 0xe6, 0xcd, + 0x98, 0x93, 0x68, 0x0c, 0xe2, 0x10, 0xad, 0xf3, 0x00, 0xaa, 0x1f, + 0x26, 0x60, 0xe1, 0xb2, 0x2e, 0x10, 0xf1, 0x70, 0xf9, 0x2a}, + {0x7e, 0xe8, 0x20, 0x6f, 0x55, 0x70, 0x02, 0x3e, 0x6d, 0xc7, 0x51, 0x9e, + 0xb1, 0x07, 0x3b, 0xc4, 0xe7, 0x91, 0xad, 0x37, 0xb5, 0xc3, 0x82, 0xaa, + 0x10, 0xba, 0x18, 0xe2, 0x35, 0x7e, 0x71, 0x69, 0x71, 0xf9, 0x36, 0x2f, + 0x2c, 0x2f, 0xe2, 0xa7, 0x6b, 0xfd, 0x78, 0xdf, 0xec, 0x4e, 0xa9, 0xb5}}; + + const DataBuffer expected_data(tv[hash_type_], GetHashLength(hash_type_)); + HkdfExtract(nullptr, nullptr, hash_type_, expected_data); +} + +TEST_P(TlsHkdfTest, HkdfKey1Only) { + const uint8_t tv[][48] = { + {/* ssl_hash_none */}, + {/* ssl_hash_md5 */}, + {/* ssl_hash_sha1 */}, + {/* ssl_hash_sha224 */}, + {0x41, 0x6c, 0x53, 0x92, 0xb9, 0xf3, 0x6d, 0xf1, 0x88, 0xe9, 0x0e, + 0xb1, 0x4d, 0x17, 0xbf, 0x0d, 0xa1, 0x90, 0xbf, 0xdb, 0x7f, 0x1f, + 0x49, 0x56, 0xe6, 0xe5, 0x66, 0xa5, 0x69, 0xc8, 0xb1, 0x5c}, + {0x51, 0xb1, 0xd5, 0xb4, 0x59, 0x79, 0x79, 0x08, 0x4a, 0x15, 0xb2, 0xdb, + 0x84, 0xd3, 0xd6, 0xbc, 0xfc, 0x93, 0x45, 0xd9, 0xdc, 0x74, 0xda, 0x1a, + 0x57, 0xc2, 0x76, 0x9f, 0x3f, 0x83, 0x45, 0x2f, 0xf6, 0xf3, 0x56, 0x1f, + 0x58, 0x63, 0xdb, 0x88, 0xda, 0x40, 0xce, 0x63, 0x7d, 0x24, 0x37, 0xf3}}; + + const DataBuffer expected_data(tv[hash_type_], GetHashLength(hash_type_)); + HkdfExtract(k1_, nullptr, hash_type_, expected_data); +} + +TEST_P(TlsHkdfTest, HkdfKey2Only) { + const uint8_t tv[][48] = { + {/* ssl_hash_none */}, + {/* ssl_hash_md5 */}, + {/* ssl_hash_sha1 */}, + {/* ssl_hash_sha224 */}, + {0x16, 0xaf, 0x00, 0x54, 0x3a, 0x56, 0xc8, 0x26, 0xa2, 0xa7, 0xfc, + 0xb6, 0x34, 0x66, 0x8a, 0xfd, 0x36, 0xdc, 0x8e, 0xce, 0xc4, 0xd2, + 0x6c, 0x7a, 0xdc, 0xe3, 0x70, 0x36, 0x3d, 0x60, 0xfa, 0x0b}, + {0x7b, 0x40, 0xf9, 0xef, 0x91, 0xff, 0xc9, 0xd1, 0x29, 0x24, 0x5c, 0xbf, + 0xf8, 0x82, 0x76, 0x68, 0xae, 0x4b, 0x63, 0xe8, 0x03, 0xdd, 0x39, 0xa8, + 0xd4, 0x6a, 0xf6, 0xe5, 0xec, 0xea, 0xf8, 0x7d, 0x91, 0x71, 0x81, 0xf1, + 0xdb, 0x3b, 0xaf, 0xbf, 0xde, 0x71, 0x61, 0x15, 0xeb, 0xb5, 0x5f, 0x68}}; + + const DataBuffer expected_data(tv[hash_type_], GetHashLength(hash_type_)); + HkdfExtract(nullptr, k2_, hash_type_, expected_data); +} + +TEST_P(TlsHkdfTest, HkdfKey1Key2) { + const uint8_t tv[][48] = { + {/* ssl_hash_none */}, + {/* ssl_hash_md5 */}, + {/* ssl_hash_sha1 */}, + {/* ssl_hash_sha224 */}, + {0xa5, 0x68, 0x02, 0x5a, 0x95, 0xc9, 0x7f, 0x55, 0x38, 0xbc, 0xf7, + 0x97, 0xcc, 0x0f, 0xd5, 0xf6, 0xa8, 0x8d, 0x15, 0xbc, 0x0e, 0x85, + 0x74, 0x70, 0x3c, 0xa3, 0x65, 0xbd, 0x76, 0xcf, 0x9f, 0xd3}, + {0x01, 0x93, 0xc0, 0x07, 0x3f, 0x6a, 0x83, 0x0e, 0x2e, 0x4f, 0xb2, 0x58, + 0xe4, 0x00, 0x08, 0x5c, 0x68, 0x9c, 0x37, 0x32, 0x00, 0x37, 0xff, 0xc3, + 0x1c, 0x5b, 0x98, 0x0b, 0x02, 0x92, 0x3f, 0xfd, 0x73, 0x5a, 0x6f, 0x2a, + 0x95, 0xa3, 0xee, 0xf6, 0xd6, 0x8e, 0x6f, 0x86, 0xea, 0x63, 0xf8, 0x33}}; + + const DataBuffer expected_data(tv[hash_type_], GetHashLength(hash_type_)); + HkdfExtract(k1_, k2_, hash_type_, expected_data); +} + +TEST_P(TlsHkdfTest, HkdfExpandLabel) { + const uint8_t tv[][48] = { + {/* ssl_hash_none */}, + {/* ssl_hash_md5 */}, + {/* ssl_hash_sha1 */}, + {/* ssl_hash_sha224 */}, + {0x3e, 0x4e, 0x6e, 0xd0, 0xbc, 0xc4, 0xf4, 0xff, 0xf0, 0xf5, 0x69, + 0xd0, 0x6c, 0x1e, 0x0e, 0x10, 0x32, 0xaa, 0xd7, 0xa3, 0xef, 0xf6, + 0xa8, 0x65, 0x8e, 0xbe, 0xee, 0xc7, 0x1f, 0x01, 0x6d, 0x3c}, + {0x41, 0xea, 0x77, 0x09, 0x8c, 0x90, 0x04, 0x10, 0xec, 0xbc, 0x37, 0xd8, + 0x5b, 0x54, 0xcd, 0x7b, 0x08, 0x15, 0x13, 0x20, 0xed, 0x1e, 0x3f, 0x54, + 0x74, 0xf7, 0x8b, 0x06, 0x38, 0x28, 0x06, 0x37, 0x75, 0x23, 0xa2, 0xb7, + 0x34, 0xb1, 0x72, 0x2e, 0x59, 0x6d, 0x5a, 0x31, 0xf5, 0x53, 0xab, 0x99}}; + + const DataBuffer expected_data(tv[hash_type_], GetHashLength(hash_type_)); + HkdfExpandLabel(&k1_, hash_type_, kSessionHash, GetHashLength(hash_type_), + kLabelMasterSecret, strlen(kLabelMasterSecret), + expected_data); +} + +TEST_P(TlsHkdfTest, HkdfExpandLabelNoHash) { + const uint8_t tv[][48] = { + {/* ssl_hash_none */}, + {/* ssl_hash_md5 */}, + {/* ssl_hash_sha1 */}, + {/* ssl_hash_sha224 */}, + {0xb7, 0x08, 0x00, 0xe3, 0x8e, 0x48, 0x68, 0x91, 0xb1, 0x0f, 0x5e, + 0x6f, 0x22, 0x53, 0x6b, 0x84, 0x69, 0x75, 0xaa, 0xa3, 0x2a, 0xe7, + 0xde, 0xaa, 0xc3, 0xd1, 0xb4, 0x05, 0x22, 0x5c, 0x68, 0xf5}, + {0x13, 0xd3, 0x36, 0x9f, 0x3c, 0x78, 0xa0, 0x32, 0x40, 0xee, 0x16, 0xe9, + 0x11, 0x12, 0x66, 0xc7, 0x51, 0xad, 0xd8, 0x3c, 0xa1, 0xa3, 0x97, 0x74, + 0xd7, 0x45, 0xff, 0xa7, 0x88, 0x9e, 0x52, 0x17, 0x2e, 0xaa, 0x3a, 0xd2, + 0x35, 0xd8, 0xd5, 0x35, 0xfd, 0x65, 0x70, 0x9f, 0xa9, 0xf9, 0xfa, 0x23}}; + + const DataBuffer expected_data(tv[hash_type_], GetHashLength(hash_type_)); + HkdfExpandLabel(&k1_, hash_type_, nullptr, 0, kLabelMasterSecret, + strlen(kLabelMasterSecret), expected_data); +} + +TEST_P(TlsHkdfTest, BadExtractWrapperInput) { + PK11SymKey* key = nullptr; + + // Bad version. + EXPECT_EQ(SECFailure, + SSL_HkdfExtract(SSL_LIBRARY_VERSION_TLS_1_2, TLS_AES_128_GCM_SHA256, + k1_.get(), k2_.get(), &key)); + EXPECT_EQ(SEC_ERROR_INVALID_ARGS, PORT_GetError()); + + // Bad ciphersuite. + EXPECT_EQ(SECFailure, + SSL_HkdfExtract(SSL_LIBRARY_VERSION_TLS_1_3, TLS_RSA_WITH_NULL_SHA, + k1_.get(), k2_.get(), &key)); + EXPECT_EQ(SEC_ERROR_INVALID_ARGS, PORT_GetError()); + + // Old ciphersuite. + EXPECT_EQ(SECFailure, SSL_HkdfExtract(SSL_LIBRARY_VERSION_TLS_1_3, + TLS_RSA_WITH_AES_128_CBC_SHA, k1_.get(), + k2_.get(), &key)); + EXPECT_EQ(SEC_ERROR_INVALID_ARGS, PORT_GetError()); + + // NULL outparam.. + EXPECT_EQ(SECFailure, SSL_HkdfExtract(SSL_LIBRARY_VERSION_TLS_1_3, + TLS_RSA_WITH_AES_128_CBC_SHA, k1_.get(), + k2_.get(), nullptr)); + EXPECT_EQ(SEC_ERROR_INVALID_ARGS, PORT_GetError()); + + EXPECT_EQ(nullptr, key); +} + +TEST_P(TlsHkdfTest, BadExpandLabelWrapperInput) { + PK11SymKey* key = nullptr; + static const char* kLabel = "label"; + + // Bad version. + EXPECT_EQ( + SECFailure, + SSL_HkdfExpandLabel(SSL_LIBRARY_VERSION_TLS_1_2, TLS_AES_128_GCM_SHA256, + k1_.get(), nullptr, 0, kLabel, strlen(kLabel), &key)); + EXPECT_EQ(SEC_ERROR_INVALID_ARGS, PORT_GetError()); + + // Bad ciphersuite. + EXPECT_EQ( + SECFailure, + SSL_HkdfExpandLabel(SSL_LIBRARY_VERSION_TLS_1_3, TLS_RSA_WITH_NULL_MD5, + k1_.get(), nullptr, 0, kLabel, strlen(kLabel), &key)); + EXPECT_EQ(SEC_ERROR_INVALID_ARGS, PORT_GetError()); + + // Old ciphersuite. + EXPECT_EQ(SECFailure, + SSL_HkdfExpandLabel(SSL_LIBRARY_VERSION_TLS_1_3, + TLS_RSA_WITH_AES_128_CBC_SHA, k1_.get(), + nullptr, 0, kLabel, strlen(kLabel), &key)); + EXPECT_EQ(SEC_ERROR_INVALID_ARGS, PORT_GetError()); + + // Null PRK. + EXPECT_EQ(SECFailure, SSL_HkdfExpandLabel( + SSL_LIBRARY_VERSION_TLS_1_2, TLS_AES_128_GCM_SHA256, + nullptr, nullptr, 0, kLabel, strlen(kLabel), &key)); + EXPECT_EQ(SEC_ERROR_INVALID_ARGS, PORT_GetError()); + + // Null, non-zero-length handshake hash. + EXPECT_EQ( + SECFailure, + SSL_HkdfExpandLabel(SSL_LIBRARY_VERSION_TLS_1_2, TLS_AES_128_GCM_SHA256, + k1_.get(), nullptr, 2, kLabel, strlen(kLabel), &key)); + + EXPECT_EQ(SEC_ERROR_INVALID_ARGS, PORT_GetError()); + // Null, non-zero-length label. + EXPECT_EQ(SECFailure, + SSL_HkdfExpandLabel(SSL_LIBRARY_VERSION_TLS_1_3, + TLS_AES_128_GCM_SHA256, k1_.get(), nullptr, 0, + nullptr, strlen(kLabel), &key)); + EXPECT_EQ(SEC_ERROR_INVALID_ARGS, PORT_GetError()); + + // Null, empty label. + EXPECT_EQ(SECFailure, SSL_HkdfExpandLabel(SSL_LIBRARY_VERSION_TLS_1_3, + TLS_AES_128_GCM_SHA256, k1_.get(), + nullptr, 0, nullptr, 0, &key)); + EXPECT_EQ(SEC_ERROR_INVALID_ARGS, PORT_GetError()); + + // Null key pointer.. + EXPECT_EQ(SECFailure, + SSL_HkdfExpandLabel(SSL_LIBRARY_VERSION_TLS_1_3, + TLS_AES_128_GCM_SHA256, k1_.get(), nullptr, 0, + kLabel, strlen(kLabel), nullptr)); + EXPECT_EQ(SEC_ERROR_INVALID_ARGS, PORT_GetError()); + + EXPECT_EQ(nullptr, key); +} + +static const SSLHashType kHashTypes[] = {ssl_hash_sha256, ssl_hash_sha384}; +INSTANTIATE_TEST_SUITE_P(AllHashFuncs, TlsHkdfTest, + ::testing::ValuesIn(kHashTypes)); + +} // namespace nss_test diff --git a/security/nss/gtests/ssl_gtest/tls_protect.cc b/security/nss/gtests/ssl_gtest/tls_protect.cc new file mode 100644 index 0000000000..6187660a5c --- /dev/null +++ b/security/nss/gtests/ssl_gtest/tls_protect.cc @@ -0,0 +1,148 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set ts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this file, + * You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#include "tls_protect.h" +#include "sslproto.h" +#include "tls_filter.h" + +namespace nss_test { + +static uint64_t FirstSeqno(bool dtls, uint16_t epoc) { + if (dtls) { + return static_cast<uint64_t>(epoc) << 48; + } + return 0; +} + +TlsCipherSpec::TlsCipherSpec(bool dtls, uint16_t epoc) + : dtls_(dtls), + epoch_(epoc), + in_seqno_(FirstSeqno(dtls, epoc)), + out_seqno_(FirstSeqno(dtls, epoc)) {} + +bool TlsCipherSpec::SetKeys(SSLCipherSuiteInfo* cipherinfo, + PK11SymKey* secret) { + SSLAeadContext* aead_ctx; + SSLProtocolVariant variant = + dtls_ ? ssl_variant_datagram : ssl_variant_stream; + SECStatus rv = + SSL_MakeVariantAead(SSL_LIBRARY_VERSION_TLS_1_3, cipherinfo->cipherSuite, + variant, secret, "", 0, // Use the default labels. + &aead_ctx); + if (rv != SECSuccess) { + return false; + } + aead_.reset(aead_ctx); + + SSLMaskingContext* mask_ctx; + const char kHkdfPurposeSn[] = "sn"; + rv = SSL_CreateVariantMaskingContext( + SSL_LIBRARY_VERSION_TLS_1_3, cipherinfo->cipherSuite, variant, secret, + kHkdfPurposeSn, strlen(kHkdfPurposeSn), &mask_ctx); + if (rv != SECSuccess) { + return false; + } + mask_.reset(mask_ctx); + return true; +} + +bool TlsCipherSpec::Unprotect(const TlsRecordHeader& header, + const DataBuffer& ciphertext, + DataBuffer* plaintext, + TlsRecordHeader* out_header) { + if (!aead_ || !out_header) { + return false; + } + *out_header = header; + + // Make space. + plaintext->Allocate(ciphertext.len()); + + unsigned int len; + uint64_t seqno = dtls_ ? header.sequence_number() : in_seqno_; + SECStatus rv; + + if (header.is_dtls13_ciphertext()) { + if (!mask_ || !out_header) { + return false; + } + PORT_Assert(ciphertext.len() >= 16); + DataBuffer mask(2); + rv = SSL_CreateMask(mask_.get(), ciphertext.data(), ciphertext.len(), + mask.data(), mask.len()); + if (rv != SECSuccess) { + return false; + } + + if (!out_header->MaskSequenceNumber(mask)) { + return false; + } + seqno = out_header->sequence_number(); + } + + auto header_bytes = out_header->header(); + rv = SSL_AeadDecrypt(aead_.get(), seqno, header_bytes.data(), + header_bytes.len(), ciphertext.data(), ciphertext.len(), + plaintext->data(), &len, plaintext->len()); + if (rv != SECSuccess) { + return false; + } + + RecordUnprotected(seqno); + plaintext->Truncate(static_cast<size_t>(len)); + + return true; +} + +bool TlsCipherSpec::Protect(const TlsRecordHeader& header, + const DataBuffer& plaintext, DataBuffer* ciphertext, + TlsRecordHeader* out_header) { + if (!aead_ || !out_header) { + return false; + } + + *out_header = header; + + // Make a padded buffer. + ciphertext->Allocate(plaintext.len() + + 32); // Room for any plausible auth tag + unsigned int len; + + DataBuffer header_bytes; + (void)header.WriteHeader(&header_bytes, 0, plaintext.len() + 16); + uint64_t seqno = dtls_ ? header.sequence_number() : out_seqno_; + + SECStatus rv = + SSL_AeadEncrypt(aead_.get(), seqno, header_bytes.data(), + header_bytes.len(), plaintext.data(), plaintext.len(), + ciphertext->data(), &len, ciphertext->len()); + if (rv != SECSuccess) { + return false; + } + + if (header.is_dtls13_ciphertext()) { + if (!mask_ || !out_header) { + return false; + } + PORT_Assert(ciphertext->len() >= 16); + DataBuffer mask(2); + rv = SSL_CreateMask(mask_.get(), ciphertext->data(), ciphertext->len(), + mask.data(), mask.len()); + if (rv != SECSuccess) { + return false; + } + if (!out_header->MaskSequenceNumber(mask)) { + return false; + } + } + + RecordProtected(); + ciphertext->Truncate(len); + + return true; +} + +} // namespace nss_test diff --git a/security/nss/gtests/ssl_gtest/tls_protect.h b/security/nss/gtests/ssl_gtest/tls_protect.h new file mode 100644 index 0000000000..d7ea2aa128 --- /dev/null +++ b/security/nss/gtests/ssl_gtest/tls_protect.h @@ -0,0 +1,60 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set ts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this file, + * You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#ifndef tls_protection_h_ +#define tls_protection_h_ + +#include <cstdint> +#include <memory> + +#include "pk11pub.h" +#include "sslt.h" +#include "sslexp.h" + +#include "databuffer.h" +#include "scoped_ptrs_ssl.h" + +namespace nss_test { +class TlsRecordHeader; + +// Our analog of ssl3CipherSpec +class TlsCipherSpec { + public: + TlsCipherSpec(bool dtls, uint16_t epoc); + bool SetKeys(SSLCipherSuiteInfo* cipherinfo, PK11SymKey* secret); + + bool Protect(const TlsRecordHeader& header, const DataBuffer& plaintext, + DataBuffer* ciphertext, TlsRecordHeader* out_header); + bool Unprotect(const TlsRecordHeader& header, const DataBuffer& ciphertext, + DataBuffer* plaintext, TlsRecordHeader* out_header); + + uint16_t epoch() const { return epoch_; } + uint64_t next_in_seqno() const { return in_seqno_; } + void RecordUnprotected(uint64_t seqno) { + // Reordering happens, so don't let this go backwards. + in_seqno_ = (std::max)(in_seqno_, seqno + 1); + } + uint64_t next_out_seqno() { return out_seqno_; } + void RecordProtected() { out_seqno_++; } + + void RecordDropped() { record_dropped_ = true; } + bool record_dropped() const { return record_dropped_; } + + bool is_protected() const { return aead_ != nullptr; } + + private: + bool dtls_; + uint16_t epoch_; + uint64_t in_seqno_; + uint64_t out_seqno_; + bool record_dropped_ = false; + ScopedSSLAeadContext aead_; + ScopedSSLMaskingContext mask_; +}; + +} // namespace nss_test + +#endif diff --git a/security/nss/gtests/ssl_gtest/tls_psk_unittest.cc b/security/nss/gtests/ssl_gtest/tls_psk_unittest.cc new file mode 100644 index 0000000000..678a9ff585 --- /dev/null +++ b/security/nss/gtests/ssl_gtest/tls_psk_unittest.cc @@ -0,0 +1,515 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set ts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this file, + * You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#include <functional> +#include <memory> +#include "secerr.h" +#include "ssl.h" +#include "sslerr.h" +#include "sslproto.h" + +#include "gtest_utils.h" +#include "tls_connect.h" + +namespace nss_test { + +class Tls13PskTest : public TlsConnectTestBase, + public ::testing::WithParamInterface< + std::tuple<SSLProtocolVariant, uint16_t>> { + public: + Tls13PskTest() + : TlsConnectTestBase(std::get<0>(GetParam()), + SSL_LIBRARY_VERSION_TLS_1_3), + suite_(std::get<1>(GetParam())) {} + + void SetUp() override { + TlsConnectTestBase::SetUp(); + scoped_psk_.reset(GetPsk()); + ASSERT_TRUE(!!scoped_psk_); + } + + private: + PK11SymKey* GetPsk() { + ScopedPK11SlotInfo slot(PK11_GetInternalSlot()); + if (!slot) { + ADD_FAILURE(); + return nullptr; + } + + SECItem psk_item; + psk_item.type = siBuffer; + psk_item.len = sizeof(kPskDummyVal_); + psk_item.data = const_cast<uint8_t*>(kPskDummyVal_); + + PK11SymKey* key = + PK11_ImportSymKey(slot.get(), CKM_HKDF_KEY_GEN, PK11_OriginUnwrap, + CKA_DERIVE, &psk_item, NULL); + if (!key) { + ADD_FAILURE(); + } + return key; + } + + protected: + ScopedPK11SymKey scoped_psk_; + const uint16_t suite_; + const uint8_t kPskDummyVal_[16] = {0x01, 0x02, 0x03, 0x04, 0x05, + 0x06, 0x07, 0x08, 0x09, 0x0a, + 0x0b, 0x0c, 0x0d, 0x0e, 0x0f}; + const std::string kPskDummyLabel_ = "NSS PSK GTEST label"; + const SSLHashType kPskHash_ = ssl_hash_sha384; +}; + +// TLS 1.3 PSK connection test. +TEST_P(Tls13PskTest, NormalExternal) { + AddPsk(scoped_psk_, kPskDummyLabel_, kPskHash_); + Connect(); + SendReceive(); + CheckKeys(ssl_kea_ecdh, ssl_grp_ec_curve25519, ssl_auth_psk, ssl_sig_none); + client_->RemovePsk(kPskDummyLabel_); + server_->RemovePsk(kPskDummyLabel_); + + // Removing it again should fail. + EXPECT_EQ(SECFailure, SSL_RemoveExternalPsk(client_->ssl_fd(), + reinterpret_cast<const uint8_t*>( + kPskDummyLabel_.data()), + kPskDummyLabel_.length())); + EXPECT_EQ(SECFailure, SSL_RemoveExternalPsk(server_->ssl_fd(), + reinterpret_cast<const uint8_t*>( + kPskDummyLabel_.data()), + kPskDummyLabel_.length())); +} + +TEST_P(Tls13PskTest, KeyTooLarge) { + ScopedPK11SlotInfo slot(PK11_GetInternalSlot()); + ASSERT_TRUE(!!slot); + ScopedPK11SymKey scoped_psk(PK11_KeyGen( + slot.get(), CKM_GENERIC_SECRET_KEY_GEN, nullptr, 128, nullptr)); + AddPsk(scoped_psk_, kPskDummyLabel_, kPskHash_); + Connect(); + SendReceive(); + CheckKeys(ssl_kea_ecdh, ssl_grp_ec_curve25519, ssl_auth_psk, ssl_sig_none); +} + +// Attempt to use a PSK with the wrong PRF hash. +// "Clients MUST verify that...the server selected a cipher suite +// indicating a Hash associated with the PSK" +TEST_P(Tls13PskTest, ClientVerifyHashType) { + AddPsk(scoped_psk_, kPskDummyLabel_, kPskHash_); + MakeTlsFilter<SelectedCipherSuiteReplacer>(server_, + TLS_CHACHA20_POLY1305_SHA256); + client_->ExpectSendAlert(kTlsAlertIllegalParameter); + if (variant_ == ssl_variant_stream) { + server_->ExpectSendAlert(kTlsAlertUnexpectedMessage); + ConnectExpectFail(); + EXPECT_EQ(SSL_ERROR_RX_UNEXPECTED_RECORD_TYPE, server_->error_code()); + } else { + ConnectExpectFailOneSide(TlsAgent::CLIENT); + } + EXPECT_EQ(SSL_ERROR_RX_MALFORMED_SERVER_HELLO, client_->error_code()); +} + +// Different EPSKs (by label) on each endpoint. Expect cert auth. +TEST_P(Tls13PskTest, LabelMismatch) { + client_->AddPsk(scoped_psk_, std::string("foo"), kPskHash_); + server_->AddPsk(scoped_psk_, std::string("bar"), kPskHash_); + Connect(); + CheckKeys(ssl_kea_ecdh, ssl_auth_rsa_sign); +} + +SSLHelloRetryRequestAction RetryFirstHello( + PRBool firstHello, const PRUint8* clientToken, unsigned int clientTokenLen, + PRUint8* appToken, unsigned int* appTokenLen, unsigned int appTokenMax, + void* arg) { + auto* called = reinterpret_cast<size_t*>(arg); + ++*called; + EXPECT_EQ(0U, clientTokenLen); + EXPECT_EQ(*called, firstHello ? 1U : 2U); + return firstHello ? ssl_hello_retry_request : ssl_hello_retry_accept; +} + +// Test resumption PSK with HRR. +TEST_P(Tls13PskTest, ResPskRetryStateless) { + ConfigureSelfEncrypt(); + ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET); + Connect(); + SendReceive(); // Need to read so that we absorb the session ticket. + CheckKeys(); + + Reset(); + StartConnect(); + size_t cb_called = 0; + EXPECT_EQ(SECSuccess, SSL_HelloRetryRequestCallback( + server_->ssl_fd(), RetryFirstHello, &cb_called)); + ExpectResumption(RESUME_TICKET); + Handshake(); + CheckConnected(); + EXPECT_EQ(2U, cb_called); + CheckKeys(ssl_kea_ecdh, ssl_auth_rsa_sign); + SendReceive(); +} + +// Test external PSK with HRR. +TEST_P(Tls13PskTest, ExtPskRetryStateless) { + AddPsk(scoped_psk_, kPskDummyLabel_, kPskHash_); + size_t cb_called = 0; + EXPECT_EQ(SECSuccess, SSL_HelloRetryRequestCallback( + server_->ssl_fd(), RetryFirstHello, &cb_called)); + StartConnect(); + client_->Handshake(); + server_->Handshake(); + EXPECT_EQ(1U, cb_called); + auto replacement = std::make_shared<TlsAgent>( + server_->name(), TlsAgent::SERVER, server_->variant()); + server_ = replacement; + server_->SetVersionRange(version_, version_); + client_->SetPeer(server_); + server_->SetPeer(client_); + server_->AddPsk(scoped_psk_, kPskDummyLabel_, kPskHash_); + server_->ExpectPsk(); + server_->StartConnect(); + Handshake(); + CheckConnected(); + SendReceive(); + CheckKeys(ssl_kea_ecdh, ssl_grp_ec_curve25519, ssl_auth_psk, ssl_sig_none); +} + +// Server not configured with PSK and sends a certificate instead of +// a selected_identity. Client should attempt certificate authentication. +TEST_P(Tls13PskTest, ClientOnly) { + client_->AddPsk(scoped_psk_, kPskDummyLabel_, kPskHash_); + Connect(); + CheckKeys(ssl_kea_ecdh, ssl_auth_rsa_sign); +} + +// Set a PSK, remove psk_key_exchange_modes. +TEST_P(Tls13PskTest, DropKexModes) { + AddPsk(scoped_psk_, kPskDummyLabel_, kPskHash_); + StartConnect(); + MakeTlsFilter<TlsExtensionDropper>(client_, + ssl_tls13_psk_key_exchange_modes_xtn); + ConnectExpectAlert(server_, kTlsAlertMissingExtension); + client_->CheckErrorCode(SSL_ERROR_MISSING_EXTENSION_ALERT); + server_->CheckErrorCode(SSL_ERROR_MISSING_PSK_KEY_EXCHANGE_MODES); +} + +// "Clients MUST verify that...a server "key_share" extension is present +// if required by the ClientHello "psk_key_exchange_modes" extension." +// As we don't support PSK without DH, it is always required. +TEST_P(Tls13PskTest, DropRequiredKeyShare) { + AddPsk(scoped_psk_, kPskDummyLabel_, kPskHash_); + StartConnect(); + MakeTlsFilter<TlsExtensionDropper>(server_, ssl_tls13_key_share_xtn); + client_->ExpectSendAlert(kTlsAlertMissingExtension); + if (variant_ == ssl_variant_stream) { + server_->ExpectSendAlert(kTlsAlertUnexpectedMessage); + ConnectExpectFail(); + } else { + ConnectExpectFailOneSide(TlsAgent::CLIENT); + } + client_->CheckErrorCode(SSL_ERROR_MISSING_KEY_SHARE); +} + +// "Clients MUST verify that...the server's selected_identity is +// within the range supplied by the client". We send one OfferedPsk. +TEST_P(Tls13PskTest, InvalidSelectedIdentity) { + static const uint8_t selected_identity[] = {0x00, 0x01}; + DataBuffer buf(selected_identity, sizeof(selected_identity)); + AddPsk(scoped_psk_, kPskDummyLabel_, kPskHash_); + StartConnect(); + MakeTlsFilter<TlsExtensionReplacer>(server_, ssl_tls13_pre_shared_key_xtn, + buf); + client_->ExpectSendAlert(kTlsAlertIllegalParameter); + if (variant_ == ssl_variant_stream) { + server_->ExpectSendAlert(kTlsAlertUnexpectedMessage); + ConnectExpectFail(); + } else { + ConnectExpectFailOneSide(TlsAgent::CLIENT); + } + client_->CheckErrorCode(SSL_ERROR_MALFORMED_PRE_SHARED_KEY); +} + +// Resume-eligible reconnect with an EPSK configured. +// Expect the EPSK to be used. +TEST_P(Tls13PskTest, PreferEpsk) { + ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET); + Connect(); + SendReceive(); // Need to read so that we absorb the session ticket. + CheckKeys(); + + Reset(); + AddPsk(scoped_psk_, kPskDummyLabel_, kPskHash_); + ExpectResumption(RESUME_NONE); + StartConnect(); + Handshake(); + CheckConnected(); + SendReceive(); + CheckKeys(ssl_kea_ecdh, ssl_grp_ec_curve25519, ssl_auth_psk, ssl_sig_none); +} + +// Enable resumption, but connect (initially) with an EPSK. +// Expect no session ticket. +TEST_P(Tls13PskTest, SuppressNewSessionTicket) { + AddPsk(scoped_psk_, kPskDummyLabel_, kPskHash_); + ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET); + auto nst_capture = + MakeTlsFilter<TlsHandshakeRecorder>(server_, ssl_hs_new_session_ticket); + nst_capture->EnableDecryption(); + Connect(); + SendReceive(); + CheckKeys(ssl_kea_ecdh, ssl_grp_ec_curve25519, ssl_auth_psk, ssl_sig_none); + EXPECT_EQ(SECFailure, SSL_SendSessionTicket(server_->ssl_fd(), nullptr, 0)); + EXPECT_EQ(0U, nst_capture->buffer().len()); + if (variant_ == ssl_variant_stream) { + EXPECT_EQ(SSL_ERROR_FEATURE_DISABLED, PORT_GetError()); + } else { + EXPECT_EQ(SSL_ERROR_FEATURE_NOT_SUPPORTED_FOR_VERSION, PORT_GetError()); + } + + Reset(); + ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET); + AddPsk(scoped_psk_, kPskDummyLabel_, kPskHash_); + ExpectResumption(RESUME_NONE); + Connect(); + SendReceive(); + CheckKeys(ssl_kea_ecdh, ssl_grp_ec_curve25519, ssl_auth_psk, ssl_sig_none); +} + +TEST_P(Tls13PskTest, BadConfigValues) { + EXPECT_TRUE(client_->EnsureTlsSetup()); + std::vector<uint8_t> label{'L', 'A', 'B', 'E', 'L'}; + EXPECT_EQ(SECFailure, + SSL_AddExternalPsk(client_->ssl_fd(), nullptr, label.data(), + label.size(), kPskHash_)); + EXPECT_EQ(SECFailure, SSL_AddExternalPsk(client_->ssl_fd(), scoped_psk_.get(), + nullptr, label.size(), kPskHash_)); + + EXPECT_EQ(SECFailure, SSL_AddExternalPsk(client_->ssl_fd(), scoped_psk_.get(), + label.data(), 0, kPskHash_)); + EXPECT_EQ(SECSuccess, + SSL_AddExternalPsk(client_->ssl_fd(), scoped_psk_.get(), + label.data(), label.size(), ssl_hash_sha256)); + + EXPECT_EQ(SECFailure, + SSL_RemoveExternalPsk(client_->ssl_fd(), nullptr, label.size())); + + EXPECT_EQ(SECFailure, + SSL_RemoveExternalPsk(client_->ssl_fd(), label.data(), 0)); + + EXPECT_EQ(SECSuccess, SSL_RemoveExternalPsk(client_->ssl_fd(), label.data(), + label.size())); +} + +// If the server has an EPSK configured with a ciphersuite not supported +// by the client, it should use certificate authentication. +TEST_P(Tls13PskTest, FallbackUnsupportedCiphersuite) { + client_->AddPsk(scoped_psk_, kPskDummyLabel_, ssl_hash_sha256, + TLS_AES_128_GCM_SHA256); + server_->AddPsk(scoped_psk_, kPskDummyLabel_, ssl_hash_sha256, + TLS_CHACHA20_POLY1305_SHA256); + + client_->EnableSingleCipher(TLS_AES_128_GCM_SHA256); + Connect(); + SendReceive(); + CheckKeys(ssl_kea_ecdh, ssl_auth_rsa_sign); +} + +// That fallback should not occur if there is no cipher overlap. +TEST_P(Tls13PskTest, ExplicitSuiteNoOverlap) { + client_->AddPsk(scoped_psk_, kPskDummyLabel_, ssl_hash_sha256, + TLS_AES_128_GCM_SHA256); + server_->AddPsk(scoped_psk_, kPskDummyLabel_, ssl_hash_sha256, + TLS_CHACHA20_POLY1305_SHA256); + + client_->EnableSingleCipher(TLS_AES_128_GCM_SHA256); + server_->EnableSingleCipher(TLS_CHACHA20_POLY1305_SHA256); + ConnectExpectAlert(server_, kTlsAlertHandshakeFailure); + server_->CheckErrorCode(SSL_ERROR_NO_CYPHER_OVERLAP); + client_->CheckErrorCode(SSL_ERROR_NO_CYPHER_OVERLAP); +} + +TEST_P(Tls13PskTest, SuppressHandshakeCertReq) { + AddPsk(scoped_psk_, kPskDummyLabel_, kPskHash_); + server_->SetOption(SSL_REQUEST_CERTIFICATE, PR_TRUE); + server_->SetOption(SSL_REQUIRE_CERTIFICATE, PR_TRUE); + const std::set<uint8_t> hs_types = {ssl_hs_certificate, + ssl_hs_certificate_request}; + auto cr_cert_capture = MakeTlsFilter<TlsHandshakeRecorder>(server_, hs_types); + cr_cert_capture->EnableDecryption(); + + Connect(); + SendReceive(); + CheckKeys(ssl_kea_ecdh, ssl_grp_ec_curve25519, ssl_auth_psk, ssl_sig_none); + EXPECT_EQ(0U, cr_cert_capture->buffer().len()); +} + +TEST_P(Tls13PskTest, DisallowClientConfigWithoutServerCert) { + AddPsk(scoped_psk_, kPskDummyLabel_, kPskHash_); + server_->SetOption(SSL_REQUEST_CERTIFICATE, PR_TRUE); + server_->SetOption(SSL_REQUIRE_CERTIFICATE, PR_TRUE); + const std::set<uint8_t> hs_types = {ssl_hs_certificate, + ssl_hs_certificate_request}; + auto cr_cert_capture = MakeTlsFilter<TlsHandshakeRecorder>(server_, hs_types); + cr_cert_capture->EnableDecryption(); + + EXPECT_EQ(SECSuccess, SSLInt_RemoveServerCertificates(server_->ssl_fd())); + + ConnectExpectAlert(server_, kTlsAlertHandshakeFailure); + server_->CheckErrorCode(SSL_ERROR_NO_CERTIFICATE); + client_->CheckErrorCode(SSL_ERROR_NO_CYPHER_OVERLAP); + EXPECT_EQ(0U, cr_cert_capture->buffer().len()); +} + +TEST_F(TlsConnectStreamTls13, ClientRejectHandshakeCertReq) { + // Stream only, as the filter doesn't support DTLS 1.3 yet. + ScopedPK11SlotInfo slot(PK11_GetInternalSlot()); + ASSERT_TRUE(!!slot); + ScopedPK11SymKey scoped_psk(PK11_KeyGen( + slot.get(), CKM_GENERIC_SECRET_KEY_GEN, nullptr, 32, nullptr)); + AddPsk(scoped_psk, std::string("foo"), ssl_hash_sha256); + // Inject a CR after EE. This would be legal if not for ssl_auth_psk. + auto filter = MakeTlsFilter<TlsEncryptedHandshakeMessageReplacer>( + server_, kTlsHandshakeFinished, kTlsHandshakeCertificateRequest); + filter->EnableDecryption(); + + ExpectAlert(client_, kTlsAlertUnexpectedMessage); + ConnectExpectFail(); + client_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_CERT_REQUEST); + server_->CheckErrorCode(SSL_ERROR_HANDSHAKE_UNEXPECTED_ALERT); +} + +TEST_F(TlsConnectStreamTls13, RejectPha) { + // Stream only, as the filter doesn't support DTLS 1.3 yet. + ScopedPK11SlotInfo slot(PK11_GetInternalSlot()); + ASSERT_TRUE(!!slot); + ScopedPK11SymKey scoped_psk(PK11_KeyGen( + slot.get(), CKM_GENERIC_SECRET_KEY_GEN, nullptr, 32, nullptr)); + AddPsk(scoped_psk, std::string("foo"), ssl_hash_sha256); + server_->SetOption(SSL_ENABLE_POST_HANDSHAKE_AUTH, PR_TRUE); + auto kuToCr = MakeTlsFilter<TlsEncryptedHandshakeMessageReplacer>( + server_, kTlsHandshakeKeyUpdate, kTlsHandshakeCertificateRequest); + kuToCr->EnableDecryption(); + Connect(); + + // Make sure the direct path is blocked. + EXPECT_EQ(SECFailure, SSL_SendCertificateRequest(server_->ssl_fd())); + EXPECT_EQ(SSL_ERROR_FEATURE_DISABLED, PORT_GetError()); + + // Inject a PHA CR. Since this is not allowed, send KeyUpdate + // and change the message type. + EXPECT_EQ(SECSuccess, SSL_KeyUpdate(server_->ssl_fd(), PR_TRUE)); + ExpectAlert(client_, kTlsAlertUnexpectedMessage); + client_->Handshake(); // Eat the CR. + server_->Handshake(); + client_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_CERT_REQUEST); + server_->CheckErrorCode(SSL_ERROR_HANDSHAKE_UNEXPECTED_ALERT); +} + +class Tls13PskTestWithCiphers : public Tls13PskTest {}; + +TEST_P(Tls13PskTestWithCiphers, 0RttCiphers) { + RolloverAntiReplay(); + AddPsk(scoped_psk_, kPskDummyLabel_, tls13_GetHashForCipherSuite(suite_), + suite_); + StartConnect(); + client_->Set0RttEnabled(true); + server_->Set0RttEnabled(true); + ZeroRttSendReceive(true, true); + Handshake(); + ExpectEarlyDataAccepted(true); + CheckConnected(); + SendReceive(); + CheckKeys(ssl_kea_ecdh, ssl_grp_ec_curve25519, ssl_auth_psk, ssl_sig_none); +} + +TEST_P(Tls13PskTestWithCiphers, 0RttMaxEarlyData) { + EnsureTlsSetup(); + RolloverAntiReplay(); + const char* big_message = "0123456789abcdef"; + const size_t short_size = strlen(big_message) - 1; + const PRInt32 short_length = static_cast<PRInt32>(short_size); + + // Set up the PSK + EXPECT_EQ(SECSuccess, + SSL_AddExternalPsk0Rtt( + client_->ssl_fd(), scoped_psk_.get(), + reinterpret_cast<const uint8_t*>(kPskDummyLabel_.data()), + kPskDummyLabel_.length(), tls13_GetHashForCipherSuite(suite_), + suite_, short_length)); + EXPECT_EQ(SECSuccess, + SSL_AddExternalPsk0Rtt( + server_->ssl_fd(), scoped_psk_.get(), + reinterpret_cast<const uint8_t*>(kPskDummyLabel_.data()), + kPskDummyLabel_.length(), tls13_GetHashForCipherSuite(suite_), + suite_, short_length)); + client_->ExpectPsk(); + server_->ExpectPsk(); + client_->expected_cipher_suite(suite_); + server_->expected_cipher_suite(suite_); + StartConnect(); + client_->Set0RttEnabled(true); + server_->Set0RttEnabled(true); + client_->Handshake(); + CheckEarlyDataLimit(client_, short_size); + + PRInt32 sent; + // Writing more than the limit will succeed in TLS, but fail in DTLS. + if (variant_ == ssl_variant_stream) { + sent = PR_Write(client_->ssl_fd(), big_message, + static_cast<PRInt32>(strlen(big_message))); + } else { + sent = PR_Write(client_->ssl_fd(), big_message, + static_cast<PRInt32>(strlen(big_message))); + EXPECT_GE(0, sent); + EXPECT_EQ(PR_WOULD_BLOCK_ERROR, PORT_GetError()); + + // Try an exact-sized write now. + sent = PR_Write(client_->ssl_fd(), big_message, short_length); + } + EXPECT_EQ(short_length, sent); + + // Even a single octet write should now fail. + sent = PR_Write(client_->ssl_fd(), big_message, 1); + EXPECT_GE(0, sent); + EXPECT_EQ(PR_WOULD_BLOCK_ERROR, PORT_GetError()); + + // Process the ClientHello and read 0-RTT. + server_->Handshake(); + CheckEarlyDataLimit(server_, short_size); + + std::vector<uint8_t> buf(short_size + 1); + PRInt32 read = PR_Read(server_->ssl_fd(), buf.data(), buf.capacity()); + EXPECT_EQ(short_length, read); + EXPECT_EQ(0, memcmp(big_message, buf.data(), short_size)); + + // Second read fails. + read = PR_Read(server_->ssl_fd(), buf.data(), buf.capacity()); + EXPECT_EQ(SECFailure, read); + EXPECT_EQ(PR_WOULD_BLOCK_ERROR, PORT_GetError()); + + Handshake(); + ExpectEarlyDataAccepted(true); + CheckConnected(); + SendReceive(); +} + +static const uint16_t k0RttCipherDefs[] = {TLS_CHACHA20_POLY1305_SHA256, + TLS_AES_128_GCM_SHA256, + TLS_AES_256_GCM_SHA384}; + +static const uint16_t kDefaultSuite[] = {TLS_CHACHA20_POLY1305_SHA256}; + +INSTANTIATE_TEST_SUITE_P( + Tls13PskTest, Tls13PskTest, + ::testing::Combine(TlsConnectTestBase::kTlsVariantsAll, + ::testing::ValuesIn(kDefaultSuite))); + +INSTANTIATE_TEST_SUITE_P( + Tls13PskTestWithCiphers, Tls13PskTestWithCiphers, + ::testing::Combine(TlsConnectTestBase::kTlsVariantsAll, + ::testing::ValuesIn(k0RttCipherDefs))); + +} // namespace nss_test diff --git a/security/nss/gtests/ssl_gtest/tls_subcerts_unittest.cc b/security/nss/gtests/ssl_gtest/tls_subcerts_unittest.cc new file mode 100644 index 0000000000..5e01dee518 --- /dev/null +++ b/security/nss/gtests/ssl_gtest/tls_subcerts_unittest.cc @@ -0,0 +1,723 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set ts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this file, + * You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#include <ctime> + +#include "prtime.h" +#include "secerr.h" +#include "ssl.h" +#include "nss.h" +#include "blapit.h" + +#include "gtest_utils.h" +#include "tls_agent.h" +#include "tls_connect.h" + +namespace nss_test { + +const std::string kEcdsaDelegatorId = TlsAgent::kDelegatorEcdsa256; +const std::string kRsaeDelegatorId = TlsAgent::kDelegatorRsae2048; +const std::string kPssDelegatorId = TlsAgent::kDelegatorRsaPss2048; +const std::string kDCId = TlsAgent::kServerEcdsa256; +const SSLSignatureScheme kDCScheme = ssl_sig_ecdsa_secp256r1_sha256; +const PRUint32 kDCValidFor = 60 * 60 * 24 * 7 /* 1 week (seconds) */; + +static void CheckPreliminaryPeerDelegCred( + const std::shared_ptr<TlsAgent>& client, bool expected, + PRUint32 key_bits = 0, SSLSignatureScheme sig_scheme = ssl_sig_none) { + EXPECT_NE(0U, (client->pre_info().valuesSet & ssl_preinfo_peer_auth)); + EXPECT_EQ(expected, client->pre_info().peerDelegCred); + if (expected) { + EXPECT_EQ(key_bits, client->pre_info().authKeyBits); + EXPECT_EQ(sig_scheme, client->pre_info().signatureScheme); + } +} + +static void CheckPeerDelegCred(const std::shared_ptr<TlsAgent>& client, + bool expected, PRUint32 key_bits = 0) { + EXPECT_EQ(expected, client->info().peerDelegCred); + EXPECT_EQ(expected, client->pre_info().peerDelegCred); + if (expected) { + EXPECT_EQ(key_bits, client->info().authKeyBits); + EXPECT_EQ(key_bits, client->pre_info().authKeyBits); + EXPECT_EQ(client->info().signatureScheme, + client->pre_info().signatureScheme); + } +} + +// AuthCertificate callbacks to simulate DC validation +static SECStatus CheckPreliminaryDC(TlsAgent* agent, bool checksig, + bool isServer) { + agent->UpdatePreliminaryChannelInfo(); + EXPECT_EQ(PR_TRUE, agent->pre_info().peerDelegCred); + EXPECT_EQ(256U, agent->pre_info().authKeyBits); + EXPECT_EQ(ssl_sig_ecdsa_secp256r1_sha256, agent->pre_info().signatureScheme); + return SECSuccess; +} + +static SECStatus CheckPreliminaryNoDC(TlsAgent* agent, bool checksig, + bool isServer) { + agent->UpdatePreliminaryChannelInfo(); + EXPECT_EQ(PR_FALSE, agent->pre_info().peerDelegCred); + return SECSuccess; +} + +// AuthCertificate callbacks for modifying DC attributes. +// This allows testing tls13_CertificateVerify for rejection +// of DC attributes that have changed since AuthCertificateHook +// may have handled them. +static SECStatus ModifyDCAuthKeyBits(TlsAgent* agent, bool checksig, + bool isServer) { + return SSLInt_TweakChannelInfoForDC(agent->ssl_fd(), + PR_TRUE, // Change authKeyBits + PR_FALSE); // Change scheme +} + +static SECStatus ModifyDCScheme(TlsAgent* agent, bool checksig, bool isServer) { + return SSLInt_TweakChannelInfoForDC(agent->ssl_fd(), + PR_FALSE, // Change authKeyBits + PR_TRUE); // Change scheme +} + +// Attempt to configure a DC when either the DC or DC private key is missing. +TEST_P(TlsConnectTls13, DCNotConfigured) { + // Load and delegate the credential. + ScopedSECKEYPublicKey pub; + ScopedSECKEYPrivateKey priv; + EXPECT_TRUE(TlsAgent::LoadKeyPairFromCert(kDCId, &pub, &priv)); + + StackSECItem dc; + TlsAgent::DelegateCredential(kEcdsaDelegatorId, pub, kDCScheme, kDCValidFor, + now(), &dc); + + // Attempt to install the certificate and DC with a missing DC private key. + EnsureTlsSetup(); + SSLExtraServerCertData extra_data_missing_dc_priv_key = { + ssl_auth_null, nullptr, nullptr, nullptr, &dc, nullptr}; + EXPECT_FALSE(server_->ConfigServerCert(kEcdsaDelegatorId, true, + &extra_data_missing_dc_priv_key)); + + // Attempt to install the certificate and with only the DC private key. + EnsureTlsSetup(); + SSLExtraServerCertData extra_data_missing_dc = { + ssl_auth_null, nullptr, nullptr, nullptr, nullptr, priv.get()}; + EXPECT_FALSE(server_->ConfigServerCert(kEcdsaDelegatorId, true, + &extra_data_missing_dc)); +} + +// Connected with ECDSA-P256. +TEST_P(TlsConnectTls13, DCConnectEcdsaP256) { + Reset(kEcdsaDelegatorId); + client_->EnableDelegatedCredentials(); + server_->AddDelegatedCredential(TlsAgent::kServerEcdsa256, + ssl_sig_ecdsa_secp256r1_sha256, kDCValidFor, + now()); + + auto cfilter = MakeTlsFilter<TlsExtensionCapture>( + client_, ssl_delegated_credentials_xtn); + Connect(); + + EXPECT_TRUE(cfilter->captured()); + CheckPeerDelegCred(client_, true, 256); + EXPECT_EQ(ssl_sig_ecdsa_secp256r1_sha256, client_->info().signatureScheme); +} + +// Connected with ECDSA-P384. +TEST_P(TlsConnectTls13, DCConnectEcdsaP483) { + Reset(kEcdsaDelegatorId); + client_->EnableDelegatedCredentials(); + server_->AddDelegatedCredential(TlsAgent::kServerEcdsa384, + ssl_sig_ecdsa_secp384r1_sha384, kDCValidFor, + now()); + + auto cfilter = MakeTlsFilter<TlsExtensionCapture>( + client_, ssl_delegated_credentials_xtn); + Connect(); + + EXPECT_TRUE(cfilter->captured()); + CheckPeerDelegCred(client_, true, 384); + EXPECT_EQ(ssl_sig_ecdsa_secp384r1_sha384, client_->info().signatureScheme); +} + +// Connected with ECDSA-P521. +TEST_P(TlsConnectTls13, DCConnectEcdsaP521) { + Reset(kEcdsaDelegatorId); + client_->EnableDelegatedCredentials(); + server_->AddDelegatedCredential(TlsAgent::kServerEcdsa521, + ssl_sig_ecdsa_secp521r1_sha512, kDCValidFor, + now()); + client_->EnableDelegatedCredentials(); + + auto cfilter = MakeTlsFilter<TlsExtensionCapture>( + client_, ssl_delegated_credentials_xtn); + Connect(); + + EXPECT_TRUE(cfilter->captured()); + CheckPeerDelegCred(client_, true, 521); + EXPECT_EQ(ssl_sig_ecdsa_secp521r1_sha512, client_->info().signatureScheme); +} + +// Connected with RSA-PSS, using a PSS SPKI and ECDSA delegation cert. +TEST_P(TlsConnectTls13, DCConnectRsaPssEcdsa) { + Reset(kEcdsaDelegatorId); + + // Need to enable PSS-PSS, which is not on by default. + static const SSLSignatureScheme kSchemes[] = {ssl_sig_ecdsa_secp256r1_sha256, + ssl_sig_rsa_pss_pss_sha256}; + client_->SetSignatureSchemes(kSchemes, PR_ARRAY_SIZE(kSchemes)); + server_->SetSignatureSchemes(kSchemes, PR_ARRAY_SIZE(kSchemes)); + + client_->EnableDelegatedCredentials(); + server_->AddDelegatedCredential( + TlsAgent::kServerRsaPss, ssl_sig_rsa_pss_pss_sha256, kDCValidFor, now()); + + auto cfilter = MakeTlsFilter<TlsExtensionCapture>( + client_, ssl_delegated_credentials_xtn); + Connect(); + + EXPECT_TRUE(cfilter->captured()); + CheckPeerDelegCred(client_, true, 1024); + EXPECT_EQ(ssl_sig_rsa_pss_pss_sha256, client_->info().signatureScheme); +} + +// Connected with RSA-PSS, using a PSS SPKI and PSS delegation cert. +TEST_P(TlsConnectTls13, DCConnectRsaPssRsaPss) { + Reset(kPssDelegatorId); + + // Need to enable PSS-PSS, which is not on by default. + static const SSLSignatureScheme kSchemes[] = {ssl_sig_ecdsa_secp256r1_sha256, + ssl_sig_rsa_pss_pss_sha256}; + client_->SetSignatureSchemes(kSchemes, PR_ARRAY_SIZE(kSchemes)); + server_->SetSignatureSchemes(kSchemes, PR_ARRAY_SIZE(kSchemes)); + + client_->EnableDelegatedCredentials(); + server_->AddDelegatedCredential( + TlsAgent::kServerRsaPss, ssl_sig_rsa_pss_pss_sha256, kDCValidFor, now()); + + auto cfilter = MakeTlsFilter<TlsExtensionCapture>( + client_, ssl_delegated_credentials_xtn); + Connect(); + + EXPECT_TRUE(cfilter->captured()); + CheckPeerDelegCred(client_, true, 1024); + EXPECT_EQ(ssl_sig_rsa_pss_pss_sha256, client_->info().signatureScheme); +} + +// Connected with ECDSA-P256 using a PSS delegation cert. +TEST_P(TlsConnectTls13, DCConnectEcdsaP256RsaPss) { + Reset(kPssDelegatorId); + + // Need to enable PSS-PSS, which is not on by default. + static const SSLSignatureScheme kSchemes[] = {ssl_sig_ecdsa_secp256r1_sha256, + ssl_sig_rsa_pss_pss_sha256}; + client_->SetSignatureSchemes(kSchemes, PR_ARRAY_SIZE(kSchemes)); + server_->SetSignatureSchemes(kSchemes, PR_ARRAY_SIZE(kSchemes)); + + client_->EnableDelegatedCredentials(); + server_->AddDelegatedCredential(TlsAgent::kServerEcdsa256, + ssl_sig_ecdsa_secp256r1_sha256, kDCValidFor, + now()); + + auto cfilter = MakeTlsFilter<TlsExtensionCapture>( + client_, ssl_delegated_credentials_xtn); + Connect(); + + EXPECT_TRUE(cfilter->captured()); + CheckPeerDelegCred(client_, true, 256); + EXPECT_EQ(ssl_sig_ecdsa_secp256r1_sha256, client_->info().signatureScheme); +} + +// Simulate the client receiving a DC containing algorithms not advertised. +// Do this by tweaking the client's supported sigSchemes after the CH. +TEST_P(TlsConnectTls13, DCReceiveUnadvertisedScheme) { + Reset(kEcdsaDelegatorId); + static const SSLSignatureScheme kClientSchemes[] = { + ssl_sig_ecdsa_secp256r1_sha256, ssl_sig_ecdsa_secp384r1_sha384}; + static const SSLSignatureScheme kServerSchemes[] = { + ssl_sig_ecdsa_secp384r1_sha384, ssl_sig_ecdsa_secp256r1_sha256}; + static const SSLSignatureScheme kEcdsaP256Only[] = { + ssl_sig_ecdsa_secp256r1_sha256}; + client_->SetSignatureSchemes(kClientSchemes, PR_ARRAY_SIZE(kClientSchemes)); + server_->SetSignatureSchemes(kServerSchemes, PR_ARRAY_SIZE(kServerSchemes)); + client_->EnableDelegatedCredentials(); + server_->AddDelegatedCredential(TlsAgent::kServerEcdsa384, + ssl_sig_ecdsa_secp384r1_sha384, kDCValidFor, + now()); + StartConnect(); + client_->Handshake(); // CH with P256/P384. + server_->Handshake(); // Respond with P384 DC. + // Tell the client it only advertised P256. + SECStatus rv = SSLInt_SetDCAdvertisedSigSchemes( + client_->ssl_fd(), kEcdsaP256Only, PR_ARRAY_SIZE(kEcdsaP256Only)); + EXPECT_EQ(SECSuccess, rv); + ExpectAlert(client_, kTlsAlertIllegalParameter); + Handshake(); + client_->CheckErrorCode(SSL_ERROR_UNSUPPORTED_SIGNATURE_ALGORITHM); + server_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT); +} + +// Server schemes includes only RSAE schemes. Connection should succeed +// without delegation. +TEST_P(TlsConnectTls13, DCConnectServerRsaeOnly) { + Reset(kRsaeDelegatorId); + static const SSLSignatureScheme kClientSchemes[] = { + ssl_sig_rsa_pss_rsae_sha256, ssl_sig_rsa_pss_pss_sha256}; + static const SSLSignatureScheme kServerSchemes[] = { + ssl_sig_rsa_pss_rsae_sha256}; + client_->SetSignatureSchemes(kClientSchemes, PR_ARRAY_SIZE(kClientSchemes)); + server_->SetSignatureSchemes(kServerSchemes, PR_ARRAY_SIZE(kServerSchemes)); + client_->EnableDelegatedCredentials(); + Connect(); + + CheckPeerDelegCred(client_, false); +} + +// Connect with an RSA-PSS DC SPKI, and an RSAE Delegator SPKI. +TEST_P(TlsConnectTls13, DCConnectRsaeDelegator) { + Reset(kRsaeDelegatorId); + + static const SSLSignatureScheme kSchemes[] = {ssl_sig_rsa_pss_rsae_sha256, + ssl_sig_rsa_pss_pss_sha256}; + client_->SetSignatureSchemes(kSchemes, PR_ARRAY_SIZE(kSchemes)); + server_->SetSignatureSchemes(kSchemes, PR_ARRAY_SIZE(kSchemes)); + + client_->EnableDelegatedCredentials(); + server_->AddDelegatedCredential( + TlsAgent::kServerRsaPss, ssl_sig_rsa_pss_pss_sha256, kDCValidFor, now()); + ConnectExpectAlert(client_, kTlsAlertIllegalParameter); + server_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT); + client_->CheckErrorCode(SSL_ERROR_UNSUPPORTED_SIGNATURE_ALGORITHM); +} + +// Client schemes includes only RSAE schemes. Connection should succeed +// without delegation, and no DC extension should be present in the CH. +TEST_P(TlsConnectTls13, DCConnectClientRsaeOnly) { + Reset(kRsaeDelegatorId); + static const SSLSignatureScheme kClientSchemes[] = { + ssl_sig_rsa_pss_rsae_sha256}; + static const SSLSignatureScheme kServerSchemes[] = { + ssl_sig_rsa_pss_rsae_sha256, ssl_sig_rsa_pss_pss_sha256}; + client_->SetSignatureSchemes(kClientSchemes, PR_ARRAY_SIZE(kClientSchemes)); + server_->SetSignatureSchemes(kServerSchemes, PR_ARRAY_SIZE(kServerSchemes)); + client_->EnableDelegatedCredentials(); + auto cfilter = MakeTlsFilter<TlsExtensionCapture>( + client_, ssl_delegated_credentials_xtn); + Connect(); + EXPECT_FALSE(cfilter->captured()); + CheckPeerDelegCred(client_, false); +} + +// Test fallback. DC extension will not advertise RSAE schemes. +// The server will attempt to set one, but decline to after seeing +// the client-advertised schemes does not include it. Expect non- +// delegated success. +TEST_P(TlsConnectTls13, DCConnectRsaeDcSpki) { + Reset(kRsaeDelegatorId); + + static const SSLSignatureScheme kSchemes[] = {ssl_sig_rsa_pss_rsae_sha256, + ssl_sig_rsa_pss_pss_sha256}; + client_->SetSignatureSchemes(kSchemes, PR_ARRAY_SIZE(kSchemes)); + server_->SetSignatureSchemes(kSchemes, PR_ARRAY_SIZE(kSchemes)); + client_->EnableDelegatedCredentials(); + + EnsureTlsSetup(); + ScopedSECKEYPublicKey pub; + ScopedSECKEYPrivateKey priv; + EXPECT_TRUE( + TlsAgent::LoadKeyPairFromCert(TlsAgent::kDelegatorRsae2048, &pub, &priv)); + + StackSECItem dc; + server_->DelegateCredential(server_->name(), pub, ssl_sig_rsa_pss_rsae_sha256, + kDCValidFor, now(), &dc); + + SSLExtraServerCertData extra_data = {ssl_auth_null, nullptr, nullptr, + nullptr, &dc, priv.get()}; + EXPECT_TRUE(server_->ConfigServerCert(server_->name(), true, &extra_data)); + auto sfilter = MakeTlsFilter<TlsExtensionCapture>( + server_, ssl_delegated_credentials_xtn); + Connect(); + EXPECT_FALSE(sfilter->captured()); + CheckPeerDelegCred(client_, false); +} + +// Generate a weak key. We can't do this in the fixture because certutil +// won't sign with such a tiny key. That's OK, because this is fast(ish). +static void GenerateWeakRsaKey(ScopedSECKEYPrivateKey& priv, + ScopedSECKEYPublicKey& pub) { + ScopedPK11SlotInfo slot(PK11_GetInternalSlot()); + ASSERT_TRUE(slot); + PK11RSAGenParams rsaparams; +// The absolute minimum size of RSA key that we can use with SHA-256 is +// 256bit (hash) + 256bit (salt) + 8 (start byte) + 8 (end byte) = 528. +#define RSA_WEAK_KEY 528 +#if RSA_MIN_MODULUS_BITS < RSA_WEAK_KEY + rsaparams.keySizeInBits = 528; +#else + rsaparams.keySizeInBits = RSA_MIN_MODULUS_BITS + 1; +#endif + rsaparams.pe = 65537; + + SECKEYPublicKey* p_pub = nullptr; + priv.reset(PK11_GenerateKeyPair(slot.get(), CKM_RSA_PKCS_KEY_PAIR_GEN, + &rsaparams, &p_pub, false, false, nullptr)); + pub.reset(p_pub); + PR_ASSERT(priv); + return; +} + +// Fail to connect with a weak RSA key. +TEST_P(TlsConnectTls13, DCWeakKey) { + Reset(kPssDelegatorId); + EnsureTlsSetup(); + static const SSLSignatureScheme kSchemes[] = {ssl_sig_rsa_pss_rsae_sha256, + ssl_sig_rsa_pss_pss_sha256}; + client_->SetSignatureSchemes(kSchemes, PR_ARRAY_SIZE(kSchemes)); + server_->SetSignatureSchemes(kSchemes, PR_ARRAY_SIZE(kSchemes)); +#if RSA_MIN_MODULUS_BITS > RSA_WEAK_KEY + // save the MIN POLICY length. + PRInt32 minRsa; + + ASSERT_EQ(SECSuccess, NSS_OptionGet(NSS_RSA_MIN_KEY_SIZE, &minRsa)); +#if RSA_MIN_MODULUS_BITS >= 2048 + ASSERT_EQ(SECSuccess, + NSS_OptionSet(NSS_RSA_MIN_KEY_SIZE, RSA_MIN_MODULUS_BITS + 1024)); +#else + ASSERT_EQ(SECSuccess, NSS_OptionSet(NSS_RSA_MIN_KEY_SIZE, 2048)); +#endif +#endif + + ScopedSECKEYPrivateKey dc_priv; + ScopedSECKEYPublicKey dc_pub; + GenerateWeakRsaKey(dc_priv, dc_pub); + ASSERT_TRUE(dc_priv); + + // Construct a DC. + StackSECItem dc; + TlsAgent::DelegateCredential(kPssDelegatorId, dc_pub, + ssl_sig_rsa_pss_pss_sha256, kDCValidFor, now(), + &dc); + + // Configure the DC on the server. + SSLExtraServerCertData extra_data = {ssl_auth_null, nullptr, nullptr, + nullptr, &dc, dc_priv.get()}; + EXPECT_TRUE(server_->ConfigServerCert(kPssDelegatorId, true, &extra_data)); + + client_->EnableDelegatedCredentials(); + + auto cfilter = MakeTlsFilter<TlsExtensionCapture>( + client_, ssl_delegated_credentials_xtn); + ConnectExpectAlert(client_, kTlsAlertInsufficientSecurity); +#if RSA_MIN_MODULUS_BITS > RSA_WEAK_KEY + ASSERT_EQ(SECSuccess, NSS_OptionSet(NSS_RSA_MIN_KEY_SIZE, minRsa)); +#endif +} + +class ReplaceDCSigScheme : public TlsHandshakeFilter { + public: + ReplaceDCSigScheme(const std::shared_ptr<TlsAgent>& a) + : TlsHandshakeFilter(a, {ssl_hs_certificate_verify}) {} + + protected: + PacketFilter::Action FilterHandshake(const HandshakeHeader& header, + const DataBuffer& input, + DataBuffer* output) override { + *output = input; + output->Write(0, ssl_sig_ecdsa_secp384r1_sha384, 2); + return CHANGE; + } +}; + +// Aborted because of incorrect DC signature algorithm indication. +TEST_P(TlsConnectTls13, DCAbortBadExpectedCertVerifyAlg) { + Reset(kEcdsaDelegatorId); + client_->EnableDelegatedCredentials(); + server_->AddDelegatedCredential(TlsAgent::kServerEcdsa256, + ssl_sig_ecdsa_secp256r1_sha256, kDCValidFor, + now()); + auto filter = MakeTlsFilter<ReplaceDCSigScheme>(server_); + filter->EnableDecryption(); + ConnectExpectAlert(client_, kTlsAlertIllegalParameter); + client_->CheckErrorCode(SSL_ERROR_DC_CERT_VERIFY_ALG_MISMATCH); + server_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT); +} + +// Aborted because of invalid DC signature. +TEST_P(TlsConnectTls13, DCAbortBadSignature) { + Reset(kEcdsaDelegatorId); + EnsureTlsSetup(); + client_->EnableDelegatedCredentials(); + + ScopedSECKEYPublicKey pub; + ScopedSECKEYPrivateKey priv; + EXPECT_TRUE(TlsAgent::LoadKeyPairFromCert(kDCId, &pub, &priv)); + + StackSECItem dc; + TlsAgent::DelegateCredential(kEcdsaDelegatorId, pub, kDCScheme, kDCValidFor, + now(), &dc); + ASSERT_TRUE(dc.data != nullptr); + + // Flip the last bit of the DC so that the signature is invalid. + dc.data[dc.len - 1] ^= 0x01; + + SSLExtraServerCertData extra_data = {ssl_auth_null, nullptr, nullptr, + nullptr, &dc, priv.get()}; + EXPECT_TRUE(server_->ConfigServerCert(kEcdsaDelegatorId, true, &extra_data)); + + ConnectExpectAlert(client_, kTlsAlertIllegalParameter); + client_->CheckErrorCode(SSL_ERROR_DC_BAD_SIGNATURE); + server_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT); +} + +// Aborted because of expired DC. +TEST_P(TlsConnectTls13, DCAbortExpired) { + Reset(kEcdsaDelegatorId); + server_->AddDelegatedCredential(kDCId, kDCScheme, kDCValidFor, now()); + client_->EnableDelegatedCredentials(); + // When the client checks the time, it will be at least one second after the + // DC expired. + AdvanceTime((static_cast<PRTime>(kDCValidFor) + 1) * PR_USEC_PER_SEC); + ConnectExpectAlert(client_, kTlsAlertIllegalParameter); + client_->CheckErrorCode(SSL_ERROR_DC_EXPIRED); + server_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT); +} + +// Aborted due to remaining TTL > max validity period. +TEST_P(TlsConnectTls13, DCAbortExcessiveTTL) { + Reset(kEcdsaDelegatorId); + server_->AddDelegatedCredential(kDCId, kDCScheme, + kDCValidFor + 1 /* seconds */, now()); + client_->EnableDelegatedCredentials(); + ConnectExpectAlert(client_, kTlsAlertIllegalParameter); + client_->CheckErrorCode(SSL_ERROR_DC_INAPPROPRIATE_VALIDITY_PERIOD); + server_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT); +} + +// Aborted because of invalid key usage. +TEST_P(TlsConnectTls13, DCAbortBadKeyUsage) { + // The sever does not have the delegationUsage extension. + Reset(TlsAgent::kServerEcdsa256); + client_->EnableDelegatedCredentials(); + server_->AddDelegatedCredential(kDCId, kDCScheme, kDCValidFor, now()); + ConnectExpectAlert(client_, kTlsAlertIllegalParameter); +} + +// Connected without DC because of no client indication. +TEST_P(TlsConnectTls13, DCConnectNoClientSupport) { + Reset(kEcdsaDelegatorId); + server_->AddDelegatedCredential(kDCId, kDCScheme, kDCValidFor, now()); + + auto cfilter = MakeTlsFilter<TlsExtensionCapture>( + client_, ssl_delegated_credentials_xtn); + Connect(); + + EXPECT_FALSE(cfilter->captured()); + CheckPeerDelegCred(client_, false); +} + +// Connected without DC because of no server DC. +TEST_P(TlsConnectTls13, DCConnectNoServerSupport) { + Reset(kEcdsaDelegatorId); + client_->EnableDelegatedCredentials(); + + auto cfilter = MakeTlsFilter<TlsExtensionCapture>( + client_, ssl_delegated_credentials_xtn); + Connect(); + + EXPECT_TRUE(cfilter->captured()); + CheckPeerDelegCred(client_, false); +} + +// Connected without DC because client doesn't support TLS 1.3. +TEST_P(TlsConnectTls13, DCConnectClientNoTls13) { + Reset(kEcdsaDelegatorId); + client_->EnableDelegatedCredentials(); + server_->AddDelegatedCredential(kDCId, kDCScheme, kDCValidFor, now()); + + client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2, + SSL_LIBRARY_VERSION_TLS_1_2); + server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2, + SSL_LIBRARY_VERSION_TLS_1_3); + + auto cfilter = MakeTlsFilter<TlsExtensionCapture>( + client_, ssl_delegated_credentials_xtn); + Connect(); + + // Should fallback to TLS 1.2 and not negotiate a DC. + EXPECT_FALSE(cfilter->captured()); + CheckPeerDelegCred(client_, false); +} + +// Connected without DC because server doesn't support TLS 1.3. +TEST_P(TlsConnectTls13, DCConnectServerNoTls13) { + Reset(kEcdsaDelegatorId); + client_->EnableDelegatedCredentials(); + server_->AddDelegatedCredential(kDCId, kDCScheme, kDCValidFor, now()); + + client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2, + SSL_LIBRARY_VERSION_TLS_1_3); + server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2, + SSL_LIBRARY_VERSION_TLS_1_2); + + auto cfilter = MakeTlsFilter<TlsExtensionCapture>( + client_, ssl_delegated_credentials_xtn); + Connect(); + + // Should fallback to TLS 1.2 and not negotiate a DC. The client will still + // send the indication because it supports 1.3. + EXPECT_TRUE(cfilter->captured()); + CheckPeerDelegCred(client_, false); +} + +// Connected without DC because client doesn't support the signature scheme. +TEST_P(TlsConnectTls13, DCConnectExpectedCertVerifyAlgNotSupported) { + Reset(kEcdsaDelegatorId); + client_->EnableDelegatedCredentials(); + static const SSLSignatureScheme kClientSchemes[] = { + ssl_sig_ecdsa_secp256r1_sha256, + }; + client_->SetSignatureSchemes(kClientSchemes, PR_ARRAY_SIZE(kClientSchemes)); + + server_->AddDelegatedCredential(TlsAgent::kServerEcdsa521, + ssl_sig_ecdsa_secp521r1_sha512, kDCValidFor, + now()); + + auto cfilter = MakeTlsFilter<TlsExtensionCapture>( + client_, ssl_delegated_credentials_xtn); + Connect(); + + // Client sends indication, but the server doesn't send a DC. + EXPECT_TRUE(cfilter->captured()); + CheckPeerDelegCred(client_, false); +} + +// Check that preliminary channel info properly reflects the DC. +TEST_P(TlsConnectTls13, DCCheckPreliminaryInfo) { + Reset(kEcdsaDelegatorId); + EnsureTlsSetup(); + client_->EnableDelegatedCredentials(); + server_->AddDelegatedCredential(TlsAgent::kServerEcdsa256, + ssl_sig_ecdsa_secp256r1_sha256, kDCValidFor, + now()); + + auto filter = MakeTlsFilter<TlsHandshakeDropper>(server_); + filter->SetHandshakeTypes( + {kTlsHandshakeCertificateVerify, kTlsHandshakeFinished}); + filter->EnableDecryption(); + StartConnect(); + client_->Handshake(); // Send ClientHello + server_->Handshake(); // Send ServerHello + + client_->SetAuthCertificateCallback(CheckPreliminaryDC); + client_->Handshake(); // Process response + + client_->UpdatePreliminaryChannelInfo(); + CheckPreliminaryPeerDelegCred(client_, true, 256, + ssl_sig_ecdsa_secp256r1_sha256); +} + +// Check that preliminary channel info properly reflects a lack of DC. +TEST_P(TlsConnectTls13, DCCheckPreliminaryInfoNoDC) { + Reset(kEcdsaDelegatorId); + EnsureTlsSetup(); + client_->EnableDelegatedCredentials(); + auto filter = MakeTlsFilter<TlsHandshakeDropper>(server_); + filter->SetHandshakeTypes( + {kTlsHandshakeCertificateVerify, kTlsHandshakeFinished}); + filter->EnableDecryption(); + StartConnect(); + client_->Handshake(); // Send ClientHello + server_->Handshake(); // Send ServerHello + + client_->SetAuthCertificateCallback(CheckPreliminaryNoDC); + client_->Handshake(); // Process response + + client_->UpdatePreliminaryChannelInfo(); + CheckPreliminaryPeerDelegCred(client_, false); +} + +// Tweak the scheme in between |Cert| and |CertVerify|. +TEST_P(TlsConnectTls13, DCRejectModifiedDCScheme) { + Reset(kEcdsaDelegatorId); + client_->EnableDelegatedCredentials(); + client_->SetAuthCertificateCallback(ModifyDCScheme); + server_->AddDelegatedCredential(TlsAgent::kServerEcdsa521, + ssl_sig_ecdsa_secp521r1_sha512, kDCValidFor, + now()); + ConnectExpectAlert(client_, kTlsAlertIllegalParameter); + server_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT); + client_->CheckErrorCode(SSL_ERROR_DC_CERT_VERIFY_ALG_MISMATCH); +} + +// Tweak the authKeyBits in between |Cert| and |CertVerify|. +TEST_P(TlsConnectTls13, DCRejectModifiedDCAuthKeyBits) { + Reset(kEcdsaDelegatorId); + client_->EnableDelegatedCredentials(); + client_->SetAuthCertificateCallback(ModifyDCAuthKeyBits); + server_->AddDelegatedCredential(TlsAgent::kServerEcdsa521, + ssl_sig_ecdsa_secp521r1_sha512, kDCValidFor, + now()); + ConnectExpectAlert(client_, kTlsAlertIllegalParameter); + server_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT); + client_->CheckErrorCode(SSL_ERROR_DC_CERT_VERIFY_ALG_MISMATCH); +} + +class DCDelegation : public ::testing::Test {}; + +TEST_F(DCDelegation, DCDelegations) { + PRTime now = PR_Now(); + ScopedCERTCertificate cert; + ScopedSECKEYPrivateKey priv; + ASSERT_TRUE(TlsAgent::LoadCertificate(kEcdsaDelegatorId, &cert, &priv)); + + ScopedSECKEYPublicKey pub_rsa; + ScopedSECKEYPrivateKey priv_rsa; + ASSERT_TRUE( + TlsAgent::LoadKeyPairFromCert(TlsAgent::kServerRsa, &pub_rsa, &priv_rsa)); + + StackSECItem dc; + EXPECT_EQ(SECFailure, + SSL_DelegateCredential(cert.get(), priv.get(), pub_rsa.get(), + ssl_sig_ecdsa_secp256r1_sha256, kDCValidFor, + now, &dc)); + EXPECT_EQ(SSL_ERROR_INCORRECT_SIGNATURE_ALGORITHM, PORT_GetError()); + + // Using different PSS hashes should be OK. + EXPECT_EQ(SECSuccess, SSL_DelegateCredential( + cert.get(), priv.get(), pub_rsa.get(), + ssl_sig_rsa_pss_pss_sha256, kDCValidFor, now, &dc)); + // Make sure to reset |dc| after each success. + dc.Reset(); + EXPECT_EQ(SECSuccess, SSL_DelegateCredential( + cert.get(), priv.get(), pub_rsa.get(), + ssl_sig_rsa_pss_pss_sha384, kDCValidFor, now, &dc)); + dc.Reset(); + EXPECT_EQ(SECSuccess, SSL_DelegateCredential( + cert.get(), priv.get(), pub_rsa.get(), + ssl_sig_rsa_pss_pss_sha512, kDCValidFor, now, &dc)); + dc.Reset(); + + ScopedSECKEYPublicKey pub_ecdsa; + ScopedSECKEYPrivateKey priv_ecdsa; + ASSERT_TRUE(TlsAgent::LoadKeyPairFromCert(TlsAgent::kServerEcdsa256, + &pub_ecdsa, &priv_ecdsa)); + + EXPECT_EQ(SECFailure, + SSL_DelegateCredential(cert.get(), priv.get(), pub_ecdsa.get(), + ssl_sig_rsa_pss_rsae_sha256, kDCValidFor, + now, &dc)); + EXPECT_EQ(SSL_ERROR_INCORRECT_SIGNATURE_ALGORITHM, PORT_GetError()); + EXPECT_EQ(SECFailure, SSL_DelegateCredential( + cert.get(), priv.get(), pub_ecdsa.get(), + ssl_sig_rsa_pss_pss_sha256, kDCValidFor, now, &dc)); + EXPECT_EQ(SSL_ERROR_INCORRECT_SIGNATURE_ALGORITHM, PORT_GetError()); + EXPECT_EQ(SECFailure, + SSL_DelegateCredential(cert.get(), priv.get(), pub_ecdsa.get(), + ssl_sig_ecdsa_secp384r1_sha384, kDCValidFor, + now, &dc)); + EXPECT_EQ(SSL_ERROR_INCORRECT_SIGNATURE_ALGORITHM, PORT_GetError()); +} + +} // namespace nss_test |