summaryrefslogtreecommitdiffstats
path: root/security/nss/gtests/ssl_gtest
diff options
context:
space:
mode:
authorDaniel Baumann <daniel.baumann@progress-linux.org>2024-04-07 19:33:14 +0000
committerDaniel Baumann <daniel.baumann@progress-linux.org>2024-04-07 19:33:14 +0000
commit36d22d82aa202bb199967e9512281e9a53db42c9 (patch)
tree105e8c98ddea1c1e4784a60a5a6410fa416be2de /security/nss/gtests/ssl_gtest
parentInitial commit. (diff)
downloadfirefox-esr-36d22d82aa202bb199967e9512281e9a53db42c9.tar.xz
firefox-esr-36d22d82aa202bb199967e9512281e9a53db42c9.zip
Adding upstream version 115.7.0esr.upstream/115.7.0esrupstream
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'security/nss/gtests/ssl_gtest')
-rw-r--r--security/nss/gtests/ssl_gtest/Makefile58
-rw-r--r--security/nss/gtests/ssl_gtest/bloomfilter_unittest.cc108
-rw-r--r--security/nss/gtests/ssl_gtest/gtest_utils.h57
-rw-r--r--security/nss/gtests/ssl_gtest/libssl_internals.c501
-rw-r--r--security/nss/gtests/ssl_gtest/libssl_internals.h56
-rw-r--r--security/nss/gtests/ssl_gtest/manifest.mn77
-rw-r--r--security/nss/gtests/ssl_gtest/nss_policy.h107
-rw-r--r--security/nss/gtests/ssl_gtest/rsa8193.h209
-rw-r--r--security/nss/gtests/ssl_gtest/selfencrypt_unittest.cc281
-rw-r--r--security/nss/gtests/ssl_gtest/ssl_0rtt_unittest.cc1183
-rw-r--r--security/nss/gtests/ssl_gtest/ssl_aead_unittest.cc218
-rw-r--r--security/nss/gtests/ssl_gtest/ssl_agent_unittest.cc235
-rw-r--r--security/nss/gtests/ssl_gtest/ssl_auth_unittest.cc2261
-rw-r--r--security/nss/gtests/ssl_gtest/ssl_cert_ext_unittest.cc246
-rw-r--r--security/nss/gtests/ssl_gtest/ssl_cipherorder_unittest.cc241
-rw-r--r--security/nss/gtests/ssl_gtest/ssl_ciphersuite_unittest.cc531
-rw-r--r--security/nss/gtests/ssl_gtest/ssl_custext_unittest.cc499
-rw-r--r--security/nss/gtests/ssl_gtest/ssl_damage_unittest.cc104
-rw-r--r--security/nss/gtests/ssl_gtest/ssl_debug_env_unittest.cc51
-rw-r--r--security/nss/gtests/ssl_gtest/ssl_dhe_unittest.cc802
-rw-r--r--security/nss/gtests/ssl_gtest/ssl_drop_unittest.cc914
-rw-r--r--security/nss/gtests/ssl_gtest/ssl_ecdh_unittest.cc728
-rw-r--r--security/nss/gtests/ssl_gtest/ssl_ems_unittest.cc96
-rw-r--r--security/nss/gtests/ssl_gtest/ssl_exporter_unittest.cc188
-rw-r--r--security/nss/gtests/ssl_gtest/ssl_extension_unittest.cc1513
-rw-r--r--security/nss/gtests/ssl_gtest/ssl_fragment_unittest.cc169
-rw-r--r--security/nss/gtests/ssl_gtest/ssl_fuzz_unittest.cc252
-rw-r--r--security/nss/gtests/ssl_gtest/ssl_gather_unittest.cc156
-rw-r--r--security/nss/gtests/ssl_gtest/ssl_gtest.cc52
-rw-r--r--security/nss/gtests/ssl_gtest/ssl_gtest.gyp135
-rw-r--r--security/nss/gtests/ssl_gtest/ssl_hrr_unittest.cc1364
-rw-r--r--security/nss/gtests/ssl_gtest/ssl_keylog_unittest.cc164
-rw-r--r--security/nss/gtests/ssl_gtest/ssl_keyupdate_unittest.cc209
-rw-r--r--security/nss/gtests/ssl_gtest/ssl_loopback_unittest.cc801
-rw-r--r--security/nss/gtests/ssl_gtest/ssl_masking_unittest.cc350
-rw-r--r--security/nss/gtests/ssl_gtest/ssl_misc_unittest.cc20
-rw-r--r--security/nss/gtests/ssl_gtest/ssl_record_unittest.cc826
-rw-r--r--security/nss/gtests/ssl_gtest/ssl_recordsep_unittest.cc679
-rw-r--r--security/nss/gtests/ssl_gtest/ssl_recordsize_unittest.cc726
-rw-r--r--security/nss/gtests/ssl_gtest/ssl_renegotiation_unittest.cc235
-rw-r--r--security/nss/gtests/ssl_gtest/ssl_resumption_unittest.cc1522
-rw-r--r--security/nss/gtests/ssl_gtest/ssl_skip_unittest.cc246
-rw-r--r--security/nss/gtests/ssl_gtest/ssl_staticrsa_unittest.cc139
-rw-r--r--security/nss/gtests/ssl_gtest/ssl_tls13compat_unittest.cc573
-rw-r--r--security/nss/gtests/ssl_gtest/ssl_v2_client_hello_unittest.cc414
-rw-r--r--security/nss/gtests/ssl_gtest/ssl_version_unittest.cc456
-rw-r--r--security/nss/gtests/ssl_gtest/ssl_versionpolicy_unittest.cc385
-rw-r--r--security/nss/gtests/ssl_gtest/test_io.cc278
-rw-r--r--security/nss/gtests/ssl_gtest/test_io.h187
-rw-r--r--security/nss/gtests/ssl_gtest/tls_agent.cc1432
-rw-r--r--security/nss/gtests/ssl_gtest/tls_agent.h588
-rw-r--r--security/nss/gtests/ssl_gtest/tls_connect.cc1065
-rw-r--r--security/nss/gtests/ssl_gtest/tls_connect.h390
-rw-r--r--security/nss/gtests/ssl_gtest/tls_ech_unittest.cc2913
-rw-r--r--security/nss/gtests/ssl_gtest/tls_filter.cc1293
-rw-r--r--security/nss/gtests/ssl_gtest/tls_filter.h1013
-rw-r--r--security/nss/gtests/ssl_gtest/tls_grease_unittest.cc878
-rw-r--r--security/nss/gtests/ssl_gtest/tls_hkdf_unittest.cc433
-rw-r--r--security/nss/gtests/ssl_gtest/tls_protect.cc148
-rw-r--r--security/nss/gtests/ssl_gtest/tls_protect.h60
-rw-r--r--security/nss/gtests/ssl_gtest/tls_psk_unittest.cc515
-rw-r--r--security/nss/gtests/ssl_gtest/tls_subcerts_unittest.cc723
62 files changed, 33053 insertions, 0 deletions
diff --git a/security/nss/gtests/ssl_gtest/Makefile b/security/nss/gtests/ssl_gtest/Makefile
new file mode 100644
index 0000000000..46f0303576
--- /dev/null
+++ b/security/nss/gtests/ssl_gtest/Makefile
@@ -0,0 +1,58 @@
+#! gmake
+#
+# This Source Code Form is subject to the terms of the Mozilla Public
+# License, v. 2.0. If a copy of the MPL was not distributed with this
+# file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+#######################################################################
+# (1) Include initial platform-independent assignments (MANDATORY). #
+#######################################################################
+
+include manifest.mn
+
+#######################################################################
+# (2) Include "global" configuration information. (OPTIONAL) #
+#######################################################################
+
+include $(CORE_DEPTH)/coreconf/config.mk
+
+#######################################################################
+# (3) Include "component" configuration information. (OPTIONAL) #
+#######################################################################
+
+
+#######################################################################
+# (4) Include "local" platform-dependent assignments (OPTIONAL). #
+#######################################################################
+
+include ../common/gtest.mk
+
+CFLAGS += -I$(CORE_DEPTH)/lib/ssl
+
+ifdef NSS_DISABLE_TLS_1_3
+NSS_DISABLE_TLS_1_3=1
+# Run parameterized tests only, for which we can easily exclude TLS 1.3
+CPPSRCS := $(filter-out $(shell grep -l '^TEST_F' $(CPPSRCS)), $(CPPSRCS))
+CFLAGS += -DNSS_DISABLE_TLS_1_3
+endif
+
+ifdef NSS_ALLOW_SSLKEYLOGFILE
+SSLKEYLOGFILE_FILES = ssl_keylog_unittest.cc
+else
+SSLKEYLOGFILE_FILES = $(NULL)
+endif
+
+#######################################################################
+# (5) Execute "global" rules. (OPTIONAL) #
+#######################################################################
+
+include $(CORE_DEPTH)/coreconf/rules.mk
+
+#######################################################################
+# (6) Execute "component" rules. (OPTIONAL) #
+#######################################################################
+
+
+#######################################################################
+# (7) Execute "local" rules. (OPTIONAL). #
+#######################################################################
diff --git a/security/nss/gtests/ssl_gtest/bloomfilter_unittest.cc b/security/nss/gtests/ssl_gtest/bloomfilter_unittest.cc
new file mode 100644
index 0000000000..ccb2cd88ef
--- /dev/null
+++ b/security/nss/gtests/ssl_gtest/bloomfilter_unittest.cc
@@ -0,0 +1,108 @@
+/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
+/* vim: set ts=2 et sw=2 tw=80: */
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this file,
+ * You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+extern "C" {
+#include "sslbloom.h"
+}
+
+#include "gtest_utils.h"
+
+namespace nss_test {
+
+// Some random-ish inputs to test with. These don't result in collisions in any
+// of the configurations that are tested below.
+static const uint8_t kHashes1[] = {
+ 0x79, 0x53, 0xb8, 0xdd, 0x6b, 0x98, 0xce, 0x00, 0xb7, 0xdc, 0xe8,
+ 0x03, 0x70, 0x8c, 0xe3, 0xac, 0x06, 0x8b, 0x22, 0xfd, 0x0e, 0x34,
+ 0x48, 0xe6, 0xe5, 0xe0, 0x8a, 0xd6, 0x16, 0x18, 0xe5, 0x48};
+static const uint8_t kHashes2[] = {
+ 0xc6, 0xdd, 0x6e, 0xc4, 0x76, 0xb8, 0x55, 0xf2, 0xa4, 0xfc, 0x59,
+ 0x04, 0xa4, 0x90, 0xdc, 0xa7, 0xa7, 0x0d, 0x94, 0x8f, 0xc2, 0xdc,
+ 0x15, 0x6d, 0x48, 0x93, 0x9d, 0x05, 0xbb, 0x9a, 0xbc, 0xc1};
+
+typedef struct {
+ unsigned int k;
+ unsigned int bits;
+} BloomFilterConfig;
+
+class BloomFilterTest
+ : public ::testing::Test,
+ public ::testing::WithParamInterface<BloomFilterConfig> {
+ public:
+ BloomFilterTest() : filter_() {}
+
+ void SetUp() { Init(); }
+
+ void TearDown() { sslBloom_Destroy(&filter_); }
+
+ protected:
+ void Init() {
+ if (filter_.filter) {
+ sslBloom_Destroy(&filter_);
+ }
+ ASSERT_EQ(SECSuccess,
+ sslBloom_Init(&filter_, GetParam().k, GetParam().bits));
+ }
+
+ bool Check(const uint8_t* hashes) {
+ return sslBloom_Check(&filter_, hashes) ? true : false;
+ }
+
+ void Add(const uint8_t* hashes, bool expect_collision = false) {
+ EXPECT_EQ(expect_collision, sslBloom_Add(&filter_, hashes) ? true : false);
+ EXPECT_TRUE(Check(hashes));
+ }
+
+ sslBloomFilter filter_;
+};
+
+TEST_P(BloomFilterTest, InitOnly) {}
+
+TEST_P(BloomFilterTest, AddToEmpty) {
+ EXPECT_FALSE(Check(kHashes1));
+ Add(kHashes1);
+}
+
+TEST_P(BloomFilterTest, AddTwo) {
+ Add(kHashes1);
+ Add(kHashes2);
+}
+
+TEST_P(BloomFilterTest, AddOneTwice) {
+ Add(kHashes1);
+ Add(kHashes1, true);
+}
+
+TEST_P(BloomFilterTest, Zero) {
+ Add(kHashes1);
+ sslBloom_Zero(&filter_);
+ EXPECT_FALSE(Check(kHashes1));
+ EXPECT_FALSE(Check(kHashes2));
+}
+
+TEST_P(BloomFilterTest, Fill) {
+ sslBloom_Fill(&filter_);
+ EXPECT_TRUE(Check(kHashes1));
+ EXPECT_TRUE(Check(kHashes2));
+}
+
+static const BloomFilterConfig kBloomFilterConfigurations[] = {
+ {1, 1}, // 1 hash, 1 bit input - high chance of collision.
+ {1, 2}, // 1 hash, 2 bits - smaller than the basic unit size.
+ {1, 3}, // 1 hash, 3 bits - same as basic unit size.
+ {1, 4}, // 1 hash, 4 bits - 2 octets each.
+ {3, 10}, // 3 hashes over a reasonable number of bits.
+ {3, 3}, // Test that we can read multiple bits.
+ {4, 15}, // A credible filter.
+ {2, 18}, // A moderately large allocation.
+ {16, 16}, // Insane, use all of the bits from the hashes.
+ {16, 9}, // This also uses all of the bits from the hashes.
+};
+
+INSTANTIATE_TEST_SUITE_P(BloomFilterConfigurations, BloomFilterTest,
+ ::testing::ValuesIn(kBloomFilterConfigurations));
+
+} // namespace nss_test
diff --git a/security/nss/gtests/ssl_gtest/gtest_utils.h b/security/nss/gtests/ssl_gtest/gtest_utils.h
new file mode 100644
index 0000000000..2344c3cea9
--- /dev/null
+++ b/security/nss/gtests/ssl_gtest/gtest_utils.h
@@ -0,0 +1,57 @@
+/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
+/* vim: set ts=2 et sw=2 tw=80: */
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this file,
+ * You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+#ifndef gtest_utils_h__
+#define gtest_utils_h__
+
+#define GTEST_HAS_RTTI 0
+#include "gtest/gtest.h"
+#include "test_io.h"
+
+namespace nss_test {
+
+// Gtest utilities
+class Timeout : public PollTarget {
+ public:
+ Timeout(int32_t timer_ms) : handle_(nullptr) {
+ Poller::Instance()->SetTimer(timer_ms, this, &Timeout::ExpiredCallback,
+ &handle_);
+ }
+ ~Timeout() {
+ if (handle_) {
+ handle_->Cancel();
+ }
+ }
+
+ static void ExpiredCallback(PollTarget* target, Event event) {
+ Timeout* timeout = static_cast<Timeout*>(target);
+ timeout->handle_ = nullptr;
+ }
+
+ bool timed_out() const { return !handle_; }
+
+ private:
+ std::shared_ptr<Poller::Timer> handle_;
+};
+
+} // namespace nss_test
+
+#define WAIT_(expression, timeout) \
+ do { \
+ Timeout tm(timeout); \
+ while (!(expression)) { \
+ Poller::Instance()->Poll(); \
+ if (tm.timed_out()) break; \
+ } \
+ } while (0)
+
+#define ASSERT_TRUE_WAIT(expression, timeout) \
+ do { \
+ WAIT_(expression, timeout); \
+ ASSERT_TRUE(expression); \
+ } while (0)
+
+#endif
diff --git a/security/nss/gtests/ssl_gtest/libssl_internals.c b/security/nss/gtests/ssl_gtest/libssl_internals.c
new file mode 100644
index 0000000000..c6b03c530f
--- /dev/null
+++ b/security/nss/gtests/ssl_gtest/libssl_internals.c
@@ -0,0 +1,501 @@
+/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
+/* vim: set ts=2 et sw=2 tw=80: */
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this file,
+ * You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+/* This file contains functions for frobbing the internals of libssl */
+#include "libssl_internals.h"
+
+#include "nss.h"
+#include "pk11hpke.h"
+#include "pk11pub.h"
+#include "pk11priv.h"
+#include "tls13ech.h"
+#include "seccomon.h"
+#include "selfencrypt.h"
+#include "secmodti.h"
+#include "sslproto.h"
+
+SECStatus SSLInt_RemoveServerCertificates(PRFileDesc *fd) {
+ if (!fd) {
+ return SECFailure;
+ }
+ sslSocket *ss = ssl_FindSocket(fd);
+ if (!ss) {
+ return SECFailure;
+ }
+
+ PRCList *cursor;
+ while (!PR_CLIST_IS_EMPTY(&ss->serverCerts)) {
+ cursor = PR_LIST_TAIL(&ss->serverCerts);
+ PR_REMOVE_LINK(cursor);
+ ssl_FreeServerCert((sslServerCert *)cursor);
+ }
+ return SECSuccess;
+}
+
+SECStatus SSLInt_SetDCAdvertisedSigSchemes(PRFileDesc *fd,
+ const SSLSignatureScheme *schemes,
+ uint32_t num_sig_schemes) {
+ if (!fd) {
+ return SECFailure;
+ }
+ sslSocket *ss = ssl_FindSocket(fd);
+ if (!ss) {
+ return SECFailure;
+ }
+
+ // Alloc and copy, libssl will free.
+ SSLSignatureScheme *dc_schemes =
+ PORT_ZNewArray(SSLSignatureScheme, num_sig_schemes);
+ if (!dc_schemes) {
+ return SECFailure;
+ }
+ memcpy(dc_schemes, schemes, sizeof(SSLSignatureScheme) * num_sig_schemes);
+
+ if (ss->xtnData.delegCredSigSchemesAdvertised) {
+ PORT_Free(ss->xtnData.delegCredSigSchemesAdvertised);
+ }
+ ss->xtnData.delegCredSigSchemesAdvertised = dc_schemes;
+ ss->xtnData.numDelegCredSigSchemesAdvertised = num_sig_schemes;
+ return SECSuccess;
+}
+
+SECStatus SSLInt_TweakChannelInfoForDC(PRFileDesc *fd, PRBool changeAuthKeyBits,
+ PRBool changeScheme) {
+ if (!fd) {
+ return SECFailure;
+ }
+ sslSocket *ss = ssl_FindSocket(fd);
+ if (!ss) {
+ return SECFailure;
+ }
+
+ // Just toggle so we'll always have a valid value.
+ if (changeScheme) {
+ ss->sec.signatureScheme = (ss->sec.signatureScheme == ssl_sig_ed25519)
+ ? ssl_sig_ecdsa_secp256r1_sha256
+ : ssl_sig_ed25519;
+ }
+ if (changeAuthKeyBits) {
+ ss->sec.authKeyBits = ss->sec.authKeyBits ? ss->sec.authKeyBits * 2 : 384;
+ }
+
+ return SECSuccess;
+}
+
+SECStatus SSLInt_GetHandshakeRandoms(PRFileDesc *fd, SSL3Random client_random,
+ SSL3Random server_random) {
+ if (!fd) {
+ return SECFailure;
+ }
+ sslSocket *ss = ssl_FindSocket(fd);
+ if (!ss) {
+ return SECFailure;
+ }
+
+ if (client_random) {
+ memcpy(client_random, ss->ssl3.hs.client_random, sizeof(SSL3Random));
+ }
+ if (server_random) {
+ memcpy(server_random, ss->ssl3.hs.server_random, sizeof(SSL3Random));
+ }
+ return SECSuccess;
+}
+
+SECStatus SSLInt_IncrementClientHandshakeVersion(PRFileDesc *fd) {
+ sslSocket *ss = ssl_FindSocket(fd);
+ if (!ss) {
+ return SECFailure;
+ }
+
+ ++ss->clientHelloVersion;
+
+ return SECSuccess;
+}
+
+/* Use this function to update the ClientRandom of a client's handshake state
+ * after replacing its ClientHello message. We for example need to do this
+ * when replacing an SSLv3 ClientHello with its SSLv2 equivalent. */
+SECStatus SSLInt_UpdateSSLv2ClientRandom(PRFileDesc *fd, uint8_t *rnd,
+ size_t rnd_len, uint8_t *msg,
+ size_t msg_len) {
+ sslSocket *ss = ssl_FindSocket(fd);
+ if (!ss) {
+ return SECFailure;
+ }
+
+ ssl3_RestartHandshakeHashes(ss);
+
+ // Ensure we don't overrun hs.client_random.
+ rnd_len = PR_MIN(SSL3_RANDOM_LENGTH, rnd_len);
+
+ // Zero the client_random.
+ PORT_Memset(ss->ssl3.hs.client_random, 0, SSL3_RANDOM_LENGTH);
+
+ // Copy over the challenge bytes.
+ size_t offset = SSL3_RANDOM_LENGTH - rnd_len;
+ PORT_Memcpy(ss->ssl3.hs.client_random + offset, rnd, rnd_len);
+
+ // Rehash the SSLv2 client hello message.
+ return ssl3_UpdateHandshakeHashes(ss, msg, msg_len);
+}
+
+PRBool SSLInt_ExtensionNegotiated(PRFileDesc *fd, PRUint16 ext) {
+ sslSocket *ss = ssl_FindSocket(fd);
+ return (PRBool)(ss && ssl3_ExtensionNegotiated(ss, ext));
+}
+
+// Tests should not use this function directly, because the keys may
+// still be in cache. Instead, use TlsConnectTestBase::ClearServerCache.
+void SSLInt_ClearSelfEncryptKey() { ssl_ResetSelfEncryptKeys(); }
+
+sslSelfEncryptKeys *ssl_GetSelfEncryptKeysInt();
+
+void SSLInt_SetSelfEncryptMacKey(PK11SymKey *key) {
+ sslSelfEncryptKeys *keys = ssl_GetSelfEncryptKeysInt();
+
+ PK11_FreeSymKey(keys->macKey);
+ keys->macKey = key;
+}
+
+SECStatus SSLInt_SetMTU(PRFileDesc *fd, PRUint16 mtu) {
+ sslSocket *ss = ssl_FindSocket(fd);
+ if (!ss) {
+ return SECFailure;
+ }
+ ss->ssl3.mtu = mtu;
+ ss->ssl3.hs.rtRetries = 0; /* Avoid DTLS shrinking the MTU any more. */
+ return SECSuccess;
+}
+
+PRInt32 SSLInt_CountCipherSpecs(PRFileDesc *fd) {
+ PRCList *cur_p;
+ PRInt32 ct = 0;
+
+ sslSocket *ss = ssl_FindSocket(fd);
+ if (!ss) {
+ return -1;
+ }
+
+ for (cur_p = PR_NEXT_LINK(&ss->ssl3.hs.cipherSpecs);
+ cur_p != &ss->ssl3.hs.cipherSpecs; cur_p = PR_NEXT_LINK(cur_p)) {
+ ++ct;
+ }
+ return ct;
+}
+
+void SSLInt_PrintCipherSpecs(const char *label, PRFileDesc *fd) {
+ PRCList *cur_p;
+
+ sslSocket *ss = ssl_FindSocket(fd);
+ if (!ss) {
+ return;
+ }
+
+ fprintf(stderr, "Cipher specs for %s\n", label);
+ for (cur_p = PR_NEXT_LINK(&ss->ssl3.hs.cipherSpecs);
+ cur_p != &ss->ssl3.hs.cipherSpecs; cur_p = PR_NEXT_LINK(cur_p)) {
+ ssl3CipherSpec *spec = (ssl3CipherSpec *)cur_p;
+ fprintf(stderr, " %s spec epoch=%d (%s) refct=%d\n", SPEC_DIR(spec),
+ spec->epoch, spec->phase, spec->refCt);
+ }
+}
+
+/* DTLS timers are separate from the time that the rest of the stack uses.
+ * Force a timer expiry by backdating when all active timers were started.
+ * We could set the remaining time to 0 but then backoff would not work properly
+ * if we decide to test it. */
+SECStatus SSLInt_ShiftDtlsTimers(PRFileDesc *fd, PRIntervalTime shift) {
+ size_t i;
+ sslSocket *ss = ssl_FindSocket(fd);
+ if (!ss) {
+ return SECFailure;
+ }
+
+ for (i = 0; i < PR_ARRAY_SIZE(ss->ssl3.hs.timers); ++i) {
+ if (ss->ssl3.hs.timers[i].cb) {
+ ss->ssl3.hs.timers[i].started -= shift;
+ }
+ }
+ return SECSuccess;
+}
+
+#define CHECK_SECRET(secret) \
+ if (ss->ssl3.hs.secret) { \
+ fprintf(stderr, "%s != NULL\n", #secret); \
+ return PR_FALSE; \
+ }
+
+PRBool SSLInt_CheckSecretsDestroyed(PRFileDesc *fd) {
+ sslSocket *ss = ssl_FindSocket(fd);
+ if (!ss) {
+ return PR_FALSE;
+ }
+
+ CHECK_SECRET(currentSecret);
+ CHECK_SECRET(dheSecret);
+ CHECK_SECRET(clientEarlyTrafficSecret);
+ CHECK_SECRET(clientHsTrafficSecret);
+ CHECK_SECRET(serverHsTrafficSecret);
+
+ return PR_TRUE;
+}
+
+PRBool sslint_DamageTrafficSecret(PRFileDesc *fd, size_t offset) {
+ unsigned char data[32] = {0};
+ PK11SymKey **keyPtr;
+ PK11SlotInfo *slot = PK11_GetInternalSlot();
+ SECItem key_item = {siBuffer, data, sizeof(data)};
+ sslSocket *ss = ssl_FindSocket(fd);
+ if (!ss) {
+ return PR_FALSE;
+ }
+ if (!slot) {
+ return PR_FALSE;
+ }
+ keyPtr = (PK11SymKey **)((char *)&ss->ssl3.hs + offset);
+ if (!*keyPtr) {
+ return PR_FALSE;
+ }
+ PK11_FreeSymKey(*keyPtr);
+ *keyPtr = PK11_ImportSymKey(slot, CKM_NSS_HKDF_SHA256, PK11_OriginUnwrap,
+ CKA_DERIVE, &key_item, NULL);
+ PK11_FreeSlot(slot);
+ if (!*keyPtr) {
+ return PR_FALSE;
+ }
+
+ return PR_TRUE;
+}
+
+PRBool SSLInt_DamageClientHsTrafficSecret(PRFileDesc *fd) {
+ return sslint_DamageTrafficSecret(
+ fd, offsetof(SSL3HandshakeState, clientHsTrafficSecret));
+}
+
+PRBool SSLInt_DamageServerHsTrafficSecret(PRFileDesc *fd) {
+ return sslint_DamageTrafficSecret(
+ fd, offsetof(SSL3HandshakeState, serverHsTrafficSecret));
+}
+
+PRBool SSLInt_DamageEarlyTrafficSecret(PRFileDesc *fd) {
+ return sslint_DamageTrafficSecret(
+ fd, offsetof(SSL3HandshakeState, clientEarlyTrafficSecret));
+}
+
+SECStatus SSLInt_Set0RttAlpn(PRFileDesc *fd, PRUint8 *data, unsigned int len) {
+ sslSocket *ss = ssl_FindSocket(fd);
+ if (!ss) {
+ return SECFailure;
+ }
+
+ ss->xtnData.nextProtoState = SSL_NEXT_PROTO_EARLY_VALUE;
+ if (ss->xtnData.nextProto.data) {
+ SECITEM_FreeItem(&ss->xtnData.nextProto, PR_FALSE);
+ }
+ if (!SECITEM_AllocItem(NULL, &ss->xtnData.nextProto, len)) {
+ return SECFailure;
+ }
+ PORT_Memcpy(ss->xtnData.nextProto.data, data, len);
+
+ return SECSuccess;
+}
+
+PRBool SSLInt_HasCertWithAuthType(PRFileDesc *fd, SSLAuthType authType) {
+ sslSocket *ss = ssl_FindSocket(fd);
+ if (!ss) {
+ return PR_FALSE;
+ }
+
+ return (PRBool)(!!ssl_FindServerCert(ss, authType, NULL));
+}
+
+PRBool SSLInt_SendAlert(PRFileDesc *fd, uint8_t level, uint8_t type) {
+ sslSocket *ss = ssl_FindSocket(fd);
+ if (!ss) {
+ return PR_FALSE;
+ }
+
+ SECStatus rv = SSL3_SendAlert(ss, level, type);
+ if (rv != SECSuccess) return PR_FALSE;
+
+ return PR_TRUE;
+}
+
+SECStatus SSLInt_AdvanceReadSeqNum(PRFileDesc *fd, PRUint64 to) {
+ sslSocket *ss;
+ ssl3CipherSpec *spec;
+
+ ss = ssl_FindSocket(fd);
+ if (!ss) {
+ return SECFailure;
+ }
+ if (to > RECORD_SEQ_MAX) {
+ PORT_SetError(SEC_ERROR_INVALID_ARGS);
+ return SECFailure;
+ }
+ ssl_GetSpecWriteLock(ss);
+ spec = ss->ssl3.crSpec;
+ spec->nextSeqNum = to;
+
+ /* For DTLS, we need to fix the record sequence number. For this, we can just
+ * scrub the entire structure on the assumption that the new sequence number
+ * is far enough past the last received sequence number. */
+ if (spec->nextSeqNum <=
+ spec->recvdRecords.right + DTLS_RECVD_RECORDS_WINDOW) {
+ PORT_SetError(SEC_ERROR_INVALID_ARGS);
+ return SECFailure;
+ }
+ dtls_RecordSetRecvd(&spec->recvdRecords, spec->nextSeqNum - 1);
+
+ ssl_ReleaseSpecWriteLock(ss);
+ return SECSuccess;
+}
+
+SECStatus SSLInt_AdvanceWriteSeqNum(PRFileDesc *fd, PRUint64 to) {
+ sslSocket *ss;
+ ssl3CipherSpec *spec;
+ PK11Context *pk11ctxt;
+ const ssl3BulkCipherDef *cipher_def;
+
+ ss = ssl_FindSocket(fd);
+ if (!ss) {
+ return SECFailure;
+ }
+ if (to >= RECORD_SEQ_MAX) {
+ PORT_SetError(SEC_ERROR_INVALID_ARGS);
+ return SECFailure;
+ }
+ ssl_GetSpecWriteLock(ss);
+ spec = ss->ssl3.cwSpec;
+ cipher_def = spec->cipherDef;
+ spec->nextSeqNum = to;
+ if (cipher_def->type != type_aead) {
+ ssl_ReleaseSpecWriteLock(ss);
+ return SECSuccess;
+ }
+ /* If we are using aead, we need to advance the counter in the
+ * internal IV generator as well.
+ * This could be in the token or software. */
+ pk11ctxt = spec->cipherContext;
+ /* If counter is in the token, we need to switch it to software,
+ * since we don't have access to the internal state of the token. We do
+ * that by turning on the simulated message interface, then setting up the
+ * software IV generator */
+ if (pk11ctxt->ivCounter == 0) {
+ _PK11_ContextSetAEADSimulation(pk11ctxt);
+ pk11ctxt->ivLen = cipher_def->iv_size + cipher_def->explicit_nonce_size;
+ pk11ctxt->ivMaxCount = PR_UINT64(0xffffffffffffffff);
+ if ((cipher_def->explicit_nonce_size == 0) ||
+ (spec->version >= SSL_LIBRARY_VERSION_TLS_1_3)) {
+ pk11ctxt->ivFixedBits =
+ (pk11ctxt->ivLen - sizeof(sslSequenceNumber)) * BPB;
+ pk11ctxt->ivGen = CKG_GENERATE_COUNTER_XOR;
+ } else {
+ pk11ctxt->ivFixedBits = cipher_def->iv_size * BPB;
+ pk11ctxt->ivGen = CKG_GENERATE_COUNTER;
+ }
+ /* DTLS included the epoch in the fixed portion of the IV */
+ if (IS_DTLS(ss)) {
+ pk11ctxt->ivFixedBits += 2 * BPB;
+ }
+ }
+ /* now we can update the internal counter (either we are already using
+ * the software IV generator, or we just switched to it above */
+ pk11ctxt->ivCounter = to;
+
+ ssl_ReleaseSpecWriteLock(ss);
+ return SECSuccess;
+}
+
+SECStatus SSLInt_AdvanceWriteSeqByAWindow(PRFileDesc *fd, PRInt32 extra) {
+ sslSocket *ss;
+ sslSequenceNumber to;
+
+ ss = ssl_FindSocket(fd);
+ if (!ss) {
+ return SECFailure;
+ }
+ ssl_GetSpecReadLock(ss);
+ to = ss->ssl3.cwSpec->nextSeqNum + DTLS_RECVD_RECORDS_WINDOW + extra;
+ ssl_ReleaseSpecReadLock(ss);
+ return SSLInt_AdvanceWriteSeqNum(fd, to);
+}
+
+SECStatus SSLInt_AdvanceDtls13DecryptFailures(PRFileDesc *fd, PRUint64 to) {
+ sslSocket *ss = ssl_FindSocket(fd);
+ if (!ss) {
+ return SECFailure;
+ }
+
+ ssl_GetSpecWriteLock(ss);
+ ssl3CipherSpec *spec = ss->ssl3.crSpec;
+ if (spec->cipherDef->type != type_aead) {
+ ssl_ReleaseSpecWriteLock(ss);
+ return SECFailure;
+ }
+
+ spec->deprotectionFailures = to;
+ ssl_ReleaseSpecWriteLock(ss);
+ return SECSuccess;
+}
+
+SSLKEAType SSLInt_GetKEAType(SSLNamedGroup group) {
+ const sslNamedGroupDef *groupDef = ssl_LookupNamedGroup(group);
+ if (!groupDef) return ssl_kea_null;
+
+ return groupDef->keaType;
+}
+
+SECStatus SSLInt_SetSocketMaxEarlyDataSize(PRFileDesc *fd, uint32_t size) {
+ sslSocket *ss;
+
+ ss = ssl_FindSocket(fd);
+ if (!ss) {
+ return SECFailure;
+ }
+
+ /* This only works when resuming. */
+ if (!ss->statelessResume) {
+ PORT_SetError(SEC_INTERNAL_ONLY);
+ return SECFailure;
+ }
+
+ /* Modifying both specs allows this to be used on either peer. */
+ ssl_GetSpecWriteLock(ss);
+ ss->ssl3.crSpec->earlyDataRemaining = size;
+ ss->ssl3.cwSpec->earlyDataRemaining = size;
+ ssl_ReleaseSpecWriteLock(ss);
+
+ return SECSuccess;
+}
+
+SECStatus SSLInt_HasPendingHandshakeData(PRFileDesc *fd, PRBool *pending) {
+ sslSocket *ss = ssl_FindSocket(fd);
+ if (!ss) {
+ return SECFailure;
+ }
+
+ ssl_GetSSL3HandshakeLock(ss);
+ *pending = ss->ssl3.hs.msg_body.len > 0;
+ ssl_ReleaseSSL3HandshakeLock(ss);
+ return SECSuccess;
+}
+
+SECStatus SSLInt_SetRawEchConfigForRetry(PRFileDesc *fd, const uint8_t *buf,
+ size_t len) {
+ sslSocket *ss = ssl_FindSocket(fd);
+ if (!ss) {
+ return SECFailure;
+ }
+
+ sslEchConfig *cfg = (sslEchConfig *)PR_LIST_HEAD(&ss->echConfigs);
+ SECITEM_FreeItem(&cfg->raw, PR_FALSE);
+ SECITEM_AllocItem(NULL, &cfg->raw, len);
+ PORT_Memcpy(cfg->raw.data, buf, len);
+ return SECSuccess;
+}
+
+PRBool SSLInt_IsIp(PRUint8 *s, unsigned int len) { return tls13_IsIp(s, len); }
diff --git a/security/nss/gtests/ssl_gtest/libssl_internals.h b/security/nss/gtests/ssl_gtest/libssl_internals.h
new file mode 100644
index 0000000000..70c6520c54
--- /dev/null
+++ b/security/nss/gtests/ssl_gtest/libssl_internals.h
@@ -0,0 +1,56 @@
+/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
+/* vim: set ts=2 et sw=2 tw=80: */
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this file,
+ * You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+#ifndef libssl_internals_h_
+#define libssl_internals_h_
+
+#include <stdint.h>
+
+#include "prio.h"
+#include "seccomon.h"
+#include "ssl.h"
+#include "sslimpl.h"
+#include "sslt.h"
+
+SECStatus SSLInt_IncrementClientHandshakeVersion(PRFileDesc *fd);
+
+SECStatus SSLInt_UpdateSSLv2ClientRandom(PRFileDesc *fd, uint8_t *rnd,
+ size_t rnd_len, uint8_t *msg,
+ size_t msg_len);
+SECStatus SSLInt_GetHandshakeRandoms(PRFileDesc *fd, SSL3Random client_random,
+ SSL3Random server_random);
+PRBool SSLInt_ExtensionNegotiated(PRFileDesc *fd, PRUint16 ext);
+void SSLInt_ClearSelfEncryptKey();
+void SSLInt_SetSelfEncryptMacKey(PK11SymKey *key);
+PRInt32 SSLInt_CountCipherSpecs(PRFileDesc *fd);
+void SSLInt_PrintCipherSpecs(const char *label, PRFileDesc *fd);
+SECStatus SSLInt_ShiftDtlsTimers(PRFileDesc *fd, PRIntervalTime shift);
+SECStatus SSLInt_SetMTU(PRFileDesc *fd, PRUint16 mtu);
+PRBool SSLInt_CheckSecretsDestroyed(PRFileDesc *fd);
+PRBool SSLInt_DamageClientHsTrafficSecret(PRFileDesc *fd);
+PRBool SSLInt_DamageServerHsTrafficSecret(PRFileDesc *fd);
+PRBool SSLInt_DamageEarlyTrafficSecret(PRFileDesc *fd);
+SECStatus SSLInt_Set0RttAlpn(PRFileDesc *fd, PRUint8 *data, unsigned int len);
+PRBool SSLInt_HasCertWithAuthType(PRFileDesc *fd, SSLAuthType authType);
+PRBool SSLInt_SendAlert(PRFileDesc *fd, uint8_t level, uint8_t type);
+SECStatus SSLInt_AdvanceDtls13DecryptFailures(PRFileDesc *fd, PRUint64 to);
+SECStatus SSLInt_AdvanceWriteSeqNum(PRFileDesc *fd, PRUint64 to);
+SECStatus SSLInt_AdvanceReadSeqNum(PRFileDesc *fd, PRUint64 to);
+SECStatus SSLInt_AdvanceWriteSeqByAWindow(PRFileDesc *fd, PRInt32 extra);
+SSLKEAType SSLInt_GetKEAType(SSLNamedGroup group);
+SECStatus SSLInt_HasPendingHandshakeData(PRFileDesc *fd, PRBool *pending);
+SECStatus SSLInt_SetSocketMaxEarlyDataSize(PRFileDesc *fd, uint32_t size);
+SECStatus SSLInt_TweakChannelInfoForDC(PRFileDesc *fd, PRBool changeAuthKeyBits,
+ PRBool changeScheme);
+SECStatus SSLInt_SetDCAdvertisedSigSchemes(PRFileDesc *fd,
+ const SSLSignatureScheme *schemes,
+ uint32_t num_sig_schemes);
+SECStatus SSLInt_RemoveServerCertificates(PRFileDesc *fd);
+SECStatus SSLInt_SetRawEchConfigForRetry(PRFileDesc *fd, const uint8_t *buf,
+ size_t len);
+PRBool SSLInt_IsIp(PRUint8 *s, unsigned int len);
+
+#endif // ifndef libssl_internals_h_
diff --git a/security/nss/gtests/ssl_gtest/manifest.mn b/security/nss/gtests/ssl_gtest/manifest.mn
new file mode 100644
index 0000000000..af3081e8e4
--- /dev/null
+++ b/security/nss/gtests/ssl_gtest/manifest.mn
@@ -0,0 +1,77 @@
+#
+# This Source Code Form is subject to the terms of the Mozilla Public
+# License, v. 2.0. If a copy of the MPL was not distributed with this
+# file, You can obtain one at http://mozilla.org/MPL/2.0/.
+CORE_DEPTH = ../..
+DEPTH = ../..
+MODULE = nss
+
+# These sources have access to libssl internals
+CSRCS = \
+ libssl_internals.c \
+ $(NULL)
+
+CPPSRCS = \
+ bloomfilter_unittest.cc \
+ ssl_0rtt_unittest.cc \
+ ssl_aead_unittest.cc \
+ ssl_agent_unittest.cc \
+ ssl_auth_unittest.cc \
+ ssl_cert_ext_unittest.cc \
+ ssl_cipherorder_unittest.cc \
+ ssl_ciphersuite_unittest.cc \
+ ssl_custext_unittest.cc \
+ ssl_damage_unittest.cc \
+ ssl_debug_env_unittest.cc \
+ ssl_dhe_unittest.cc \
+ ssl_drop_unittest.cc \
+ ssl_ecdh_unittest.cc \
+ ssl_ems_unittest.cc \
+ ssl_exporter_unittest.cc \
+ ssl_extension_unittest.cc \
+ ssl_fragment_unittest.cc \
+ ssl_fuzz_unittest.cc \
+ ssl_gather_unittest.cc \
+ ssl_gtest.cc \
+ ssl_hrr_unittest.cc \
+ ssl_keyupdate_unittest.cc \
+ ssl_loopback_unittest.cc \
+ ssl_masking_unittest.cc \
+ ssl_misc_unittest.cc \
+ ssl_record_unittest.cc \
+ ssl_recordsep_unittest.cc \
+ ssl_recordsize_unittest.cc \
+ ssl_resumption_unittest.cc \
+ ssl_renegotiation_unittest.cc \
+ ssl_skip_unittest.cc \
+ ssl_staticrsa_unittest.cc \
+ ssl_tls13compat_unittest.cc \
+ ssl_v2_client_hello_unittest.cc \
+ ssl_version_unittest.cc \
+ ssl_versionpolicy_unittest.cc \
+ selfencrypt_unittest.cc \
+ test_io.cc \
+ tls_agent.cc \
+ tls_connect.cc \
+ tls_hkdf_unittest.cc \
+ tls_filter.cc \
+ tls_protect.cc \
+ tls_psk_unittest.cc \
+ tls_subcerts_unittest.cc \
+ tls_ech_unittest.cc \
+ $(SSLKEYLOGFILE_FILES) \
+ $(NULL)
+
+INCLUDES += -I$(CORE_DEPTH)/gtests/google_test/gtest/include \
+ -I$(CORE_DEPTH)/gtests/common \
+ -I$(CORE_DEPTH)/cpputil
+
+REQUIRES = nspr nss libdbm gtest cpputil
+
+PROGRAM = ssl_gtest
+EXTRA_LIBS += \
+ $(DIST)/lib/$(LIB_PREFIX)gtest.$(LIB_SUFFIX) \
+ $(DIST)/lib/$(LIB_PREFIX)cpputil.$(LIB_SUFFIX) \
+ $(NULL)
+
+USE_STATIC_LIBS = 1
diff --git a/security/nss/gtests/ssl_gtest/nss_policy.h b/security/nss/gtests/ssl_gtest/nss_policy.h
new file mode 100644
index 0000000000..ceab03becc
--- /dev/null
+++ b/security/nss/gtests/ssl_gtest/nss_policy.h
@@ -0,0 +1,107 @@
+/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
+/* vim: set ts=2 et sw=2 tw=80: */
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this file,
+ * You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+#ifndef nss_policy_h_
+#define nss_policy_h_
+
+#include "prtypes.h"
+#include "secoid.h"
+#include "nss.h"
+
+namespace nss_test {
+
+// container class to hold all a temp policy
+class NssPolicy {
+ public:
+ NssPolicy() : oid_(SEC_OID_UNKNOWN), set_(0), clear_(0) {}
+ NssPolicy(SECOidTag _oid, PRUint32 _set, PRUint32 _clear)
+ : oid_(_oid), set_(_set), clear_(_clear) {}
+ NssPolicy(const NssPolicy &p)
+ : oid_(p.oid_), set_(p.set_), clear_(p.clear_) {}
+ // clone the current policy for this oid
+ NssPolicy(SECOidTag _oid) : oid_(_oid), set_(0), clear_(0) {
+ NSS_GetAlgorithmPolicy(_oid, &set_);
+ clear_ = ~set_;
+ }
+ SECOidTag oid(void) const { return oid_; }
+ PRUint32 set(void) const { return set_; }
+ PRUint32 clear(void) const { return clear_; }
+ operator bool() const { return oid_ != SEC_OID_UNKNOWN; }
+
+ private:
+ SECOidTag oid_;
+ PRUint32 set_;
+ PRUint32 clear_;
+};
+
+// container class to hold a temp option
+class NssOption {
+ public:
+ NssOption() : id_(-1), value_(0) {}
+ NssOption(PRInt32 _id, PRInt32 _value) : id_(_id), value_(_value) {}
+ NssOption(const NssOption &o) : id_(o.id_), value_(o.value_) {}
+ // clone the current option for this id
+ NssOption(PRInt32 _id) : id_(_id), value_(0) { NSS_OptionGet(id_, &value_); }
+ PRInt32 id(void) const { return id_; }
+ PRInt32 value(void) const { return value_; }
+ operator bool() const { return id_ != -1; }
+
+ private:
+ PRInt32 id_;
+ PRInt32 value_;
+};
+
+// set the policy indicated in NssPolicy and restor the old policy
+// when we go out of scope
+class NssManagePolicy {
+ public:
+ NssManagePolicy(const NssPolicy &p, const NssOption &o)
+ : policy_(p), save_policy_(~(PRUint32)0), option_(o), save_option_(0) {
+ if (p) {
+ (void)NSS_GetAlgorithmPolicy(p.oid(), &save_policy_);
+ (void)NSS_SetAlgorithmPolicy(p.oid(), p.set(), p.clear());
+ }
+ if (o) {
+ (void)NSS_OptionGet(o.id(), &save_option_);
+ (void)NSS_OptionSet(o.id(), o.value());
+ }
+ }
+ ~NssManagePolicy() {
+ if (policy_) {
+ (void)NSS_SetAlgorithmPolicy(policy_.oid(), save_policy_, ~save_policy_);
+ }
+ if (option_) {
+ (void)NSS_OptionSet(option_.id(), save_option_);
+ }
+ }
+
+ private:
+ NssPolicy policy_;
+ PRUint32 save_policy_;
+ NssOption option_;
+ PRInt32 save_option_;
+};
+
+// wrapping PRFileDesc this way ensures that tests that attempt to access
+// PRFileDesc always correctly apply
+// the policy that was bound to that socket with TlsAgent::SetPolicy().
+class NssManagedFileDesc {
+ public:
+ NssManagedFileDesc(PRFileDesc *fd, const NssPolicy &policy,
+ const NssOption &option)
+ : fd_(fd), managed_policy_(policy, option) {}
+ PRFileDesc *get(void) const { return fd_; }
+ operator PRFileDesc *() const { return fd_; }
+ bool operator==(PRFileDesc *fd) const { return fd_ == fd; }
+
+ private:
+ PRFileDesc *fd_;
+ NssManagePolicy managed_policy_;
+};
+
+} // namespace nss_test
+
+#endif
diff --git a/security/nss/gtests/ssl_gtest/rsa8193.h b/security/nss/gtests/ssl_gtest/rsa8193.h
new file mode 100644
index 0000000000..1ac8503bc0
--- /dev/null
+++ b/security/nss/gtests/ssl_gtest/rsa8193.h
@@ -0,0 +1,209 @@
+/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
+/* vim: set ts=2 et sw=2 tw=80: */
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this file,
+ * You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+// openssl req -nodes -x509 -newkey rsa:8193 -out cert.pem -days 365
+static const uint8_t rsa8193[] = {
+ 0x30, 0x82, 0x09, 0x61, 0x30, 0x82, 0x05, 0x48, 0xa0, 0x03, 0x02, 0x01,
+ 0x02, 0x02, 0x09, 0x00, 0xaf, 0xff, 0x37, 0x91, 0x3e, 0x44, 0xae, 0x57,
+ 0x30, 0x0d, 0x06, 0x09, 0x2a, 0x86, 0x48, 0x86, 0xf7, 0x0d, 0x01, 0x01,
+ 0x0b, 0x05, 0x00, 0x30, 0x45, 0x31, 0x0b, 0x30, 0x09, 0x06, 0x03, 0x55,
+ 0x04, 0x06, 0x13, 0x02, 0x41, 0x55, 0x31, 0x13, 0x30, 0x11, 0x06, 0x03,
+ 0x55, 0x04, 0x08, 0x0c, 0x0a, 0x53, 0x6f, 0x6d, 0x65, 0x2d, 0x53, 0x74,
+ 0x61, 0x74, 0x65, 0x31, 0x21, 0x30, 0x1f, 0x06, 0x03, 0x55, 0x04, 0x0a,
+ 0x0c, 0x18, 0x49, 0x6e, 0x74, 0x65, 0x72, 0x6e, 0x65, 0x74, 0x20, 0x57,
+ 0x69, 0x64, 0x67, 0x69, 0x74, 0x73, 0x20, 0x50, 0x74, 0x79, 0x20, 0x4c,
+ 0x74, 0x64, 0x30, 0x1e, 0x17, 0x0d, 0x31, 0x38, 0x30, 0x35, 0x31, 0x37,
+ 0x30, 0x39, 0x34, 0x32, 0x32, 0x39, 0x5a, 0x17, 0x0d, 0x31, 0x39, 0x30,
+ 0x35, 0x31, 0x37, 0x30, 0x39, 0x34, 0x32, 0x32, 0x39, 0x5a, 0x30, 0x45,
+ 0x31, 0x0b, 0x30, 0x09, 0x06, 0x03, 0x55, 0x04, 0x06, 0x13, 0x02, 0x41,
+ 0x55, 0x31, 0x13, 0x30, 0x11, 0x06, 0x03, 0x55, 0x04, 0x08, 0x0c, 0x0a,
+ 0x53, 0x6f, 0x6d, 0x65, 0x2d, 0x53, 0x74, 0x61, 0x74, 0x65, 0x31, 0x21,
+ 0x30, 0x1f, 0x06, 0x03, 0x55, 0x04, 0x0a, 0x0c, 0x18, 0x49, 0x6e, 0x74,
+ 0x65, 0x72, 0x6e, 0x65, 0x74, 0x20, 0x57, 0x69, 0x64, 0x67, 0x69, 0x74,
+ 0x73, 0x20, 0x50, 0x74, 0x79, 0x20, 0x4c, 0x74, 0x64, 0x30, 0x82, 0x04,
+ 0x22, 0x30, 0x0d, 0x06, 0x09, 0x2a, 0x86, 0x48, 0x86, 0xf7, 0x0d, 0x01,
+ 0x01, 0x01, 0x05, 0x00, 0x03, 0x82, 0x04, 0x0f, 0x00, 0x30, 0x82, 0x04,
+ 0x0a, 0x02, 0x82, 0x04, 0x01, 0x01, 0x77, 0xd6, 0xa9, 0x93, 0x4e, 0x15,
+ 0xb5, 0x67, 0x70, 0x8e, 0xc3, 0x77, 0x4f, 0xc9, 0x8a, 0x06, 0xd9, 0xb9,
+ 0xa6, 0x41, 0xb8, 0xfa, 0x4a, 0x13, 0x26, 0xdc, 0x2b, 0xc5, 0x82, 0xa0,
+ 0x74, 0x8c, 0x1e, 0xe9, 0xc0, 0x70, 0x15, 0x56, 0xec, 0x1f, 0x7e, 0x91,
+ 0x6e, 0x31, 0x42, 0x8b, 0xd5, 0xe2, 0x0e, 0x9c, 0xeb, 0xff, 0xbc, 0xf9,
+ 0x42, 0xd3, 0xb9, 0x1c, 0x5e, 0x46, 0x80, 0x90, 0x5f, 0xe1, 0x59, 0x22,
+ 0x13, 0x71, 0xd3, 0xd6, 0x66, 0x7a, 0xe0, 0x56, 0x04, 0x10, 0x59, 0x01,
+ 0xb3, 0xb6, 0xd2, 0xc7, 0xa7, 0x3b, 0xbc, 0xe6, 0x38, 0x44, 0xd5, 0x71,
+ 0x66, 0x1d, 0xb2, 0x63, 0x2f, 0xa9, 0x5e, 0x80, 0x92, 0x3c, 0x21, 0x0e,
+ 0xe1, 0xda, 0xd6, 0x1d, 0xcb, 0xce, 0xac, 0xe1, 0x5f, 0x97, 0x45, 0x8f,
+ 0xc1, 0x64, 0x16, 0xa6, 0x88, 0x2a, 0x36, 0x4a, 0x76, 0x64, 0x8f, 0x83,
+ 0x7a, 0x1d, 0xd8, 0x91, 0x90, 0x7b, 0x58, 0xb8, 0x1c, 0x7f, 0x56, 0x57,
+ 0x35, 0xfb, 0xf3, 0x1a, 0xcb, 0x7c, 0x66, 0x66, 0x04, 0x95, 0xee, 0x3a,
+ 0x80, 0xf0, 0xd4, 0x12, 0x3a, 0x7e, 0x7e, 0x5e, 0xb8, 0x55, 0x29, 0x23,
+ 0x06, 0xd3, 0x85, 0x0c, 0x99, 0x91, 0x42, 0xee, 0x5a, 0x30, 0x7f, 0x52,
+ 0x20, 0xb3, 0xe2, 0xe7, 0x39, 0x69, 0xb6, 0xfc, 0x42, 0x1e, 0x98, 0xd3,
+ 0x31, 0xa2, 0xfa, 0x81, 0x52, 0x69, 0x6d, 0x23, 0xf8, 0xc4, 0xc3, 0x3c,
+ 0x9b, 0x48, 0x75, 0xa8, 0xc7, 0xe7, 0x61, 0x81, 0x1f, 0xf7, 0xce, 0x10,
+ 0xaa, 0x13, 0xcb, 0x6e, 0x19, 0xc0, 0x4f, 0x6f, 0x90, 0xa8, 0x41, 0xea,
+ 0x49, 0xdf, 0xe4, 0xef, 0x84, 0x54, 0xb5, 0x37, 0xaf, 0x12, 0x75, 0x1a,
+ 0x11, 0x4b, 0x58, 0x7f, 0x63, 0x22, 0x33, 0xb1, 0xc8, 0x4d, 0xf2, 0x41,
+ 0x10, 0xbc, 0x37, 0xb5, 0xd5, 0xb2, 0x21, 0x32, 0x35, 0x9d, 0xf3, 0x8d,
+ 0xab, 0x66, 0x9d, 0x19, 0x12, 0x71, 0x45, 0xb3, 0x82, 0x5a, 0x5c, 0xff,
+ 0x2d, 0xcf, 0xf4, 0x5b, 0x56, 0xb8, 0x08, 0xb3, 0xd2, 0x43, 0x8c, 0xac,
+ 0xd2, 0xf8, 0xcc, 0x6d, 0x90, 0x97, 0xff, 0x12, 0x74, 0x97, 0xf8, 0xa4,
+ 0xe3, 0x95, 0xae, 0x92, 0xdc, 0x7e, 0x9d, 0x2b, 0xb4, 0x94, 0xc3, 0x8d,
+ 0x80, 0xe7, 0x77, 0x5c, 0x5b, 0xbb, 0x43, 0xdc, 0xa6, 0xe9, 0xbe, 0x20,
+ 0xcc, 0x9d, 0x8e, 0xa4, 0x2b, 0xf2, 0x72, 0xdc, 0x44, 0x61, 0x0f, 0xad,
+ 0x1a, 0x5e, 0xa5, 0x48, 0xe4, 0x42, 0xc5, 0xe4, 0xf1, 0x6d, 0x33, 0xdb,
+ 0xb2, 0x1b, 0x9f, 0xb2, 0xff, 0x18, 0x0e, 0x62, 0x35, 0x99, 0xed, 0x22,
+ 0x19, 0x4a, 0x5e, 0xb3, 0x3c, 0x07, 0x8f, 0x6e, 0x22, 0x5b, 0x16, 0x4a,
+ 0x9f, 0xef, 0xf3, 0xe7, 0xd6, 0x48, 0xe1, 0xb4, 0x3b, 0xab, 0x1b, 0x9e,
+ 0x53, 0xd7, 0x1b, 0xd9, 0x2d, 0x51, 0x8f, 0xe4, 0x1c, 0xab, 0xdd, 0xb9,
+ 0xe2, 0xee, 0xe4, 0xdd, 0x60, 0x04, 0x86, 0x6b, 0x4e, 0x7a, 0xc8, 0x09,
+ 0x51, 0xd1, 0x9b, 0x36, 0x9a, 0x36, 0x7f, 0xe8, 0x6b, 0x09, 0x6c, 0xee,
+ 0xad, 0x3a, 0x2f, 0xa8, 0x63, 0x92, 0x23, 0x2f, 0x7e, 0x00, 0xe2, 0xd1,
+ 0xbb, 0xd9, 0x5b, 0x5b, 0xfa, 0x4b, 0x83, 0x00, 0x19, 0x28, 0xfb, 0x7e,
+ 0xfe, 0x58, 0xab, 0xb7, 0x33, 0x45, 0x8f, 0x75, 0x9a, 0x54, 0x3d, 0x77,
+ 0x06, 0x75, 0x61, 0x4f, 0x5c, 0x93, 0xa0, 0xf9, 0xe8, 0xcf, 0xf6, 0x04,
+ 0x14, 0xda, 0x1b, 0x2e, 0x79, 0x35, 0xb8, 0xb4, 0xfa, 0x08, 0x27, 0x9a,
+ 0x03, 0x70, 0x78, 0x97, 0x8f, 0xae, 0x2e, 0xd5, 0x1c, 0xe0, 0x4d, 0x91,
+ 0x3a, 0xfe, 0x1a, 0x64, 0xd8, 0x49, 0xdf, 0x6c, 0x66, 0xac, 0xc9, 0x57,
+ 0x06, 0x72, 0xc0, 0xc0, 0x09, 0x71, 0x6a, 0xd0, 0xb0, 0x7d, 0x35, 0x3f,
+ 0x53, 0x17, 0x49, 0x38, 0x92, 0x22, 0x55, 0xf6, 0x58, 0x56, 0xa2, 0x42,
+ 0x77, 0x94, 0xb7, 0x28, 0x0a, 0xa0, 0xd2, 0xda, 0x25, 0xc1, 0xcc, 0x52,
+ 0x51, 0xd6, 0xba, 0x18, 0x0f, 0x0d, 0xe3, 0x7d, 0xd1, 0xda, 0xd9, 0x0c,
+ 0x5e, 0x3a, 0xca, 0xe9, 0xf1, 0xf5, 0x65, 0xfc, 0xc3, 0x99, 0x72, 0x25,
+ 0xf2, 0xc0, 0xa1, 0x8c, 0x43, 0x9d, 0xb2, 0xc9, 0xb1, 0x1a, 0x24, 0x34,
+ 0x57, 0xd8, 0xa7, 0x52, 0xa3, 0x39, 0x6e, 0x0b, 0xec, 0xbd, 0x5e, 0xc9,
+ 0x1f, 0x74, 0xed, 0xae, 0xe6, 0x4e, 0x49, 0xe8, 0x87, 0x3e, 0x46, 0x0d,
+ 0x40, 0x30, 0xda, 0x9d, 0xcf, 0xf5, 0x03, 0x1f, 0x38, 0x29, 0x3b, 0x66,
+ 0xe5, 0xc0, 0x89, 0x4c, 0xfc, 0x09, 0x62, 0x37, 0x01, 0xf9, 0x01, 0xab,
+ 0x8d, 0x53, 0x9c, 0x36, 0x5d, 0x36, 0x66, 0x8d, 0x87, 0xf4, 0xab, 0x37,
+ 0xb7, 0xf7, 0xe3, 0xdf, 0xc1, 0x52, 0xc0, 0x1d, 0x09, 0x92, 0x21, 0x47,
+ 0x49, 0x9a, 0x19, 0x38, 0x05, 0x62, 0xf3, 0x47, 0x80, 0x89, 0x1e, 0x70,
+ 0xa1, 0x57, 0xb7, 0x72, 0xd0, 0x41, 0x7a, 0x5c, 0x6a, 0x13, 0x8b, 0x6c,
+ 0xda, 0xdf, 0x6b, 0x01, 0x15, 0x20, 0xfa, 0xc8, 0x67, 0xee, 0xb2, 0x13,
+ 0xd8, 0x5f, 0x84, 0x30, 0x44, 0x8e, 0xf9, 0x2a, 0xae, 0x17, 0x53, 0x49,
+ 0xaa, 0x34, 0x31, 0x12, 0x31, 0xec, 0xf3, 0x25, 0x27, 0x53, 0x6b, 0xb5,
+ 0x63, 0xa6, 0xbc, 0xf1, 0x77, 0xd4, 0xb4, 0x77, 0xd1, 0xee, 0xad, 0x62,
+ 0x9d, 0x2c, 0x2e, 0x11, 0x0a, 0xd1, 0x87, 0xfe, 0xef, 0x77, 0x0e, 0xd1,
+ 0x38, 0xfe, 0xcc, 0x88, 0xaa, 0x1c, 0x06, 0x93, 0x25, 0x56, 0xfe, 0x0c,
+ 0x52, 0xe9, 0x7f, 0x4c, 0x3b, 0x2a, 0xfb, 0x40, 0x62, 0x29, 0x0a, 0x1d,
+ 0x58, 0x78, 0x8b, 0x09, 0x25, 0xaa, 0xc6, 0x8f, 0x66, 0x8f, 0xd1, 0x93,
+ 0x5a, 0xd6, 0x68, 0x35, 0x69, 0x13, 0x5d, 0x42, 0x35, 0x95, 0xcb, 0xc4,
+ 0xec, 0x17, 0x92, 0x96, 0xcb, 0x4a, 0xb9, 0x8f, 0xe5, 0xc4, 0x4a, 0xe7,
+ 0x54, 0x52, 0x4c, 0x64, 0x06, 0xac, 0x2f, 0x13, 0x32, 0x02, 0x47, 0x13,
+ 0x5c, 0xa2, 0x66, 0xdc, 0x36, 0x0c, 0x4f, 0xbb, 0x89, 0x58, 0x85, 0x16,
+ 0xf1, 0xf1, 0xff, 0xd2, 0x86, 0x54, 0x29, 0xb3, 0x7e, 0x2a, 0xbd, 0xf9,
+ 0x53, 0x8c, 0xa0, 0x60, 0x60, 0xb2, 0x90, 0x7f, 0x3a, 0x11, 0x5f, 0x2a,
+ 0x50, 0x74, 0x2a, 0xd1, 0x68, 0x78, 0xdb, 0x31, 0x1b, 0x8b, 0xee, 0xee,
+ 0x18, 0x97, 0xf3, 0x50, 0x84, 0xc1, 0x8f, 0xe1, 0xc6, 0x01, 0xb4, 0x16,
+ 0x65, 0x25, 0x0c, 0x03, 0xab, 0xed, 0x4f, 0xd6, 0xe6, 0x16, 0x23, 0xcc,
+ 0x42, 0x93, 0xff, 0xfa, 0x92, 0x63, 0x33, 0x9e, 0x36, 0xb0, 0xdc, 0x9a,
+ 0xb6, 0xaa, 0xd7, 0x48, 0xfe, 0x27, 0x01, 0xcf, 0x67, 0xc0, 0x75, 0xa0,
+ 0x86, 0x9a, 0xec, 0xa7, 0x2e, 0xb8, 0x7b, 0x00, 0x7f, 0xd4, 0xe3, 0xb3,
+ 0xfc, 0x48, 0xab, 0x50, 0x20, 0xd4, 0x0d, 0x58, 0x26, 0xc0, 0x3c, 0x09,
+ 0x0b, 0x80, 0x9e, 0xaf, 0x14, 0x3c, 0x0c, 0x6e, 0x69, 0xbc, 0x6c, 0x4e,
+ 0x50, 0x33, 0xb0, 0x07, 0x64, 0x6e, 0x77, 0x96, 0xc2, 0xe6, 0x3b, 0xd7,
+ 0xfe, 0xdc, 0xa4, 0x2f, 0x18, 0x5b, 0x53, 0xe5, 0xdd, 0xb6, 0xce, 0xeb,
+ 0x16, 0xb4, 0x25, 0xc6, 0xcb, 0xf2, 0x65, 0x3c, 0x4f, 0x94, 0xa5, 0x11,
+ 0x18, 0xeb, 0x7b, 0x62, 0x1d, 0xd5, 0x02, 0x35, 0x76, 0xf6, 0xb5, 0x20,
+ 0x27, 0x21, 0x9b, 0xab, 0xf4, 0xb6, 0x8f, 0x1a, 0x70, 0x1d, 0x12, 0xe3,
+ 0xb9, 0x8e, 0x29, 0x52, 0x25, 0xf4, 0xba, 0xb4, 0x25, 0x2c, 0x91, 0x11,
+ 0xf2, 0xae, 0x7b, 0xbe, 0xb6, 0x67, 0xd6, 0x08, 0xf8, 0x6f, 0xe7, 0xb0,
+ 0x16, 0xc5, 0xf6, 0xd5, 0xfb, 0x07, 0x71, 0x5b, 0x0e, 0xe1, 0x02, 0x03,
+ 0x01, 0x00, 0x01, 0xa3, 0x53, 0x30, 0x51, 0x30, 0x1d, 0x06, 0x03, 0x55,
+ 0x1d, 0x0e, 0x04, 0x16, 0x04, 0x14, 0xaa, 0xe7, 0x7f, 0xcf, 0xf8, 0xb4,
+ 0xe0, 0x8d, 0x39, 0x9a, 0x1d, 0x4f, 0x86, 0xa2, 0xac, 0x56, 0x32, 0xd9,
+ 0x58, 0xe3, 0x30, 0x1f, 0x06, 0x03, 0x55, 0x1d, 0x23, 0x04, 0x18, 0x30,
+ 0x16, 0x80, 0x14, 0xaa, 0xe7, 0x7f, 0xcf, 0xf8, 0xb4, 0xe0, 0x8d, 0x39,
+ 0x9a, 0x1d, 0x4f, 0x86, 0xa2, 0xac, 0x56, 0x32, 0xd9, 0x58, 0xe3, 0x30,
+ 0x0f, 0x06, 0x03, 0x55, 0x1d, 0x13, 0x01, 0x01, 0xff, 0x04, 0x05, 0x30,
+ 0x03, 0x01, 0x01, 0xff, 0x30, 0x0d, 0x06, 0x09, 0x2a, 0x86, 0x48, 0x86,
+ 0xf7, 0x0d, 0x01, 0x01, 0x0b, 0x05, 0x00, 0x03, 0x82, 0x04, 0x02, 0x00,
+ 0x00, 0x0a, 0x0a, 0x81, 0xb5, 0x2e, 0xac, 0x52, 0xab, 0x0f, 0xeb, 0xad,
+ 0x96, 0xd6, 0xd6, 0x59, 0x8f, 0x55, 0x15, 0x56, 0x70, 0xda, 0xd5, 0x75,
+ 0x47, 0x12, 0x9a, 0x0e, 0xd1, 0x65, 0x68, 0xe0, 0x51, 0x89, 0x59, 0xcc,
+ 0xe3, 0x5a, 0x1b, 0x85, 0x14, 0xa3, 0x1d, 0x9b, 0x3f, 0xd1, 0xa4, 0x42,
+ 0xb0, 0x89, 0x12, 0x93, 0xd3, 0x54, 0x19, 0x04, 0xa2, 0xaf, 0xaa, 0x60,
+ 0xca, 0x03, 0xc2, 0xae, 0x62, 0x8c, 0xb6, 0x31, 0x03, 0xd6, 0xa5, 0xf3,
+ 0x5e, 0x8d, 0x5c, 0x69, 0x4c, 0x7d, 0x81, 0x49, 0x20, 0x25, 0x41, 0xa4,
+ 0x2a, 0x95, 0x87, 0x36, 0xa3, 0x9b, 0x9e, 0x9f, 0xed, 0x85, 0xf3, 0xb1,
+ 0xf1, 0xe9, 0x1b, 0xbb, 0xe3, 0xbc, 0x3b, 0x11, 0x36, 0xca, 0xb9, 0x5f,
+ 0xee, 0x64, 0xde, 0x2a, 0x99, 0x27, 0x91, 0xc0, 0x54, 0x9e, 0x7a, 0xd4,
+ 0x89, 0x8c, 0xa0, 0xe3, 0xfd, 0x44, 0x6f, 0x02, 0x38, 0x3c, 0xee, 0x52,
+ 0x48, 0x1b, 0xd4, 0x25, 0x2b, 0xcb, 0x8e, 0xa8, 0x1b, 0x09, 0xd6, 0x30,
+ 0x51, 0x15, 0x6c, 0x5c, 0x03, 0x76, 0xad, 0x64, 0x45, 0x50, 0xa2, 0xe1,
+ 0x3c, 0x5a, 0x67, 0x87, 0xff, 0x8c, 0xed, 0x9a, 0x8d, 0x04, 0xc1, 0xac,
+ 0xf9, 0xca, 0xf5, 0x2a, 0x05, 0x9c, 0xdd, 0x78, 0xce, 0x99, 0x78, 0x7b,
+ 0xcd, 0x43, 0x10, 0x40, 0xf7, 0xb5, 0x27, 0x12, 0xec, 0xe9, 0xb2, 0x3f,
+ 0xf4, 0x5d, 0xd9, 0xbb, 0xf8, 0xc4, 0xc9, 0xa4, 0x46, 0x20, 0x41, 0x7f,
+ 0xeb, 0x79, 0xb0, 0x51, 0x8c, 0xf7, 0xc3, 0x2c, 0x16, 0xfe, 0x42, 0x59,
+ 0x77, 0xfe, 0x53, 0xfe, 0x19, 0x57, 0x58, 0x44, 0x6d, 0x12, 0xe2, 0x95,
+ 0xd0, 0xd3, 0x5a, 0xb5, 0x2d, 0xe5, 0x7e, 0xb4, 0xb3, 0xa9, 0xcc, 0x7d,
+ 0x53, 0x77, 0x81, 0x01, 0x0f, 0x0a, 0xf6, 0x86, 0x3c, 0x7d, 0xb5, 0x2c,
+ 0xbf, 0x62, 0xc3, 0xf5, 0x38, 0x89, 0x13, 0x84, 0x1f, 0x44, 0x2d, 0x87,
+ 0x5c, 0x23, 0x9e, 0x05, 0x62, 0x56, 0x3d, 0x71, 0x4d, 0xd0, 0xe3, 0x15,
+ 0xe9, 0x09, 0x9c, 0x1a, 0xc0, 0x9a, 0x19, 0x8b, 0x9c, 0xe9, 0xae, 0xde,
+ 0x62, 0x05, 0x23, 0xe2, 0xd0, 0x3f, 0xf5, 0xef, 0x04, 0x96, 0x4c, 0x87,
+ 0x34, 0x2f, 0xd5, 0x90, 0xde, 0xbf, 0x4b, 0x56, 0x12, 0x5f, 0xc6, 0xdc,
+ 0xa4, 0x1c, 0xc4, 0x53, 0x0c, 0xf9, 0xb4, 0xe4, 0x2c, 0xe7, 0x48, 0xbd,
+ 0xb1, 0xac, 0xf1, 0xc1, 0x8d, 0x53, 0x47, 0x84, 0xc0, 0x78, 0x0a, 0x5e,
+ 0xc2, 0x16, 0xff, 0xef, 0x97, 0x5b, 0x33, 0x85, 0x92, 0xcd, 0xd4, 0xbb,
+ 0x64, 0xee, 0xed, 0x17, 0x18, 0x43, 0x32, 0x99, 0x32, 0x36, 0x25, 0xf4,
+ 0x21, 0x3c, 0x2f, 0x55, 0xdc, 0x16, 0x06, 0x4d, 0x86, 0xa3, 0xa9, 0x34,
+ 0x22, 0xd5, 0xc3, 0xc8, 0x64, 0x3c, 0x4e, 0x3a, 0x69, 0xbd, 0xcf, 0xd7,
+ 0xee, 0x3f, 0x0d, 0x15, 0xeb, 0xfb, 0xbd, 0x91, 0x7f, 0xef, 0x48, 0xec,
+ 0x86, 0xb2, 0x78, 0xf7, 0x53, 0x90, 0x38, 0xb5, 0x04, 0x9c, 0xb7, 0xd7,
+ 0x9e, 0xaa, 0x15, 0xf7, 0xcd, 0xc2, 0x17, 0xd5, 0x8f, 0x82, 0x98, 0xa3,
+ 0xaf, 0x59, 0xf1, 0x71, 0xda, 0x6e, 0xaf, 0x97, 0x6d, 0x77, 0x72, 0xfd,
+ 0xa8, 0x80, 0x25, 0xce, 0x46, 0x04, 0x6e, 0x40, 0x15, 0x24, 0xc0, 0xf9,
+ 0xbf, 0x13, 0x16, 0x72, 0xcb, 0xb7, 0x10, 0xc7, 0x0a, 0xd6, 0x66, 0x96,
+ 0x5b, 0x27, 0x4d, 0x66, 0xc4, 0x2f, 0x21, 0x90, 0x9f, 0x8c, 0x24, 0xa0,
+ 0x0e, 0xa2, 0x89, 0x92, 0xd2, 0x44, 0x63, 0x06, 0xb2, 0xab, 0x07, 0x26,
+ 0xde, 0x03, 0x1d, 0xdb, 0x2a, 0x42, 0x5b, 0x4c, 0xf6, 0xfe, 0x53, 0xfa,
+ 0x80, 0x45, 0x8d, 0x75, 0xf6, 0x0e, 0x1d, 0xcc, 0x4c, 0x3b, 0xb0, 0x80,
+ 0x6d, 0x4c, 0xed, 0x7c, 0xe0, 0xd2, 0xe7, 0x62, 0x59, 0xb1, 0x5a, 0x5d,
+ 0x3a, 0xec, 0x86, 0x04, 0xfe, 0x26, 0xd1, 0x18, 0xed, 0x56, 0x7d, 0x67,
+ 0x56, 0x24, 0x6d, 0x7c, 0x6e, 0x8f, 0xc8, 0xa0, 0xba, 0x42, 0x0a, 0x33,
+ 0x38, 0x7a, 0x09, 0x03, 0xc2, 0xbf, 0x9b, 0x01, 0xdd, 0x03, 0x5a, 0xba,
+ 0x76, 0x04, 0xb1, 0xc3, 0x40, 0x23, 0x53, 0xbd, 0x64, 0x4e, 0x0f, 0xe7,
+ 0xc3, 0x4e, 0x48, 0xea, 0x19, 0x2b, 0x1c, 0xe4, 0x3d, 0x93, 0xd8, 0xf6,
+ 0xfb, 0xda, 0x3d, 0xeb, 0xed, 0xc2, 0xbd, 0x14, 0x57, 0x40, 0xde, 0xd1,
+ 0x74, 0x54, 0x1b, 0xa8, 0x39, 0xda, 0x73, 0x56, 0xd4, 0xbe, 0xab, 0xec,
+ 0xc7, 0x17, 0x4f, 0x91, 0xb6, 0xf6, 0xcb, 0x24, 0xc6, 0x1c, 0x07, 0xc4,
+ 0xf3, 0xd0, 0x5e, 0x8d, 0xfa, 0x44, 0x98, 0x5c, 0x87, 0x36, 0x75, 0xb6,
+ 0xa5, 0x31, 0xaa, 0xab, 0x7d, 0x38, 0x66, 0xb3, 0x18, 0x58, 0x65, 0x97,
+ 0x06, 0xfd, 0x61, 0x81, 0x71, 0xc5, 0x17, 0x8b, 0x19, 0x03, 0xc8, 0x58,
+ 0xec, 0x05, 0xca, 0x7b, 0x0f, 0xec, 0x9d, 0xb4, 0xbc, 0xa3, 0x20, 0x2e,
+ 0xf8, 0xe4, 0xb1, 0x82, 0xdc, 0x5a, 0xd2, 0x92, 0x9c, 0x43, 0x5d, 0x16,
+ 0x5b, 0x90, 0x80, 0xe4, 0xfb, 0x6e, 0x24, 0x6b, 0x8c, 0x1a, 0x35, 0xab,
+ 0xbd, 0x77, 0x7f, 0xf9, 0x61, 0x80, 0xa5, 0xab, 0xa3, 0x39, 0xc2, 0xc9,
+ 0x69, 0x3c, 0xfc, 0xb3, 0x9a, 0x05, 0x45, 0x03, 0x88, 0x8f, 0x8e, 0x23,
+ 0xf2, 0x0c, 0x4c, 0x54, 0xb9, 0x40, 0x3a, 0x31, 0x1a, 0x22, 0x67, 0x43,
+ 0x4a, 0x3e, 0xa0, 0x8c, 0x2d, 0x4d, 0x4f, 0xfc, 0xb5, 0x9b, 0x1f, 0xe1,
+ 0xef, 0x02, 0x54, 0xab, 0x8d, 0x75, 0x4d, 0x93, 0xba, 0x76, 0xe1, 0xbc,
+ 0x42, 0x7f, 0x6c, 0xcb, 0xf5, 0x47, 0xd6, 0x8a, 0xac, 0x5d, 0xe9, 0xbb,
+ 0x3a, 0x65, 0x2c, 0x81, 0xe5, 0xff, 0x27, 0x7e, 0x60, 0x64, 0x80, 0x42,
+ 0x8d, 0x36, 0x6b, 0x07, 0x76, 0x6a, 0xf1, 0xdf, 0x96, 0x17, 0x93, 0x21,
+ 0x5d, 0xe4, 0x6c, 0xce, 0x1c, 0xb9, 0x82, 0x45, 0x05, 0x61, 0xe2, 0x41,
+ 0x96, 0x03, 0x7d, 0x10, 0x8b, 0x3e, 0xc7, 0xe5, 0xcf, 0x08, 0xeb, 0x81,
+ 0xd3, 0x82, 0x1b, 0x04, 0x96, 0x93, 0x5a, 0xe2, 0x8c, 0x8e, 0x50, 0x33,
+ 0xf6, 0xf9, 0xf0, 0xfb, 0xb1, 0xd7, 0xc6, 0x97, 0xaa, 0xef, 0x0b, 0x87,
+ 0xe1, 0x34, 0x97, 0x78, 0x2e, 0x7c, 0x46, 0x11, 0xd5, 0x3c, 0xec, 0x38,
+ 0x70, 0x59, 0x14, 0x65, 0x4d, 0x0e, 0xd1, 0xeb, 0x49, 0xb3, 0x99, 0x6f,
+ 0x87, 0xf1, 0x79, 0x21, 0xd9, 0x5c, 0x37, 0xb2, 0xfe, 0xc4, 0x7a, 0xc1,
+ 0x67, 0xbd, 0x02, 0xfc, 0x02, 0xab, 0x2f, 0xf5, 0x0f, 0xa7, 0xae, 0x90,
+ 0xc2, 0xaf, 0xdb, 0xd1, 0x96, 0xb2, 0x92, 0x5a, 0xfb, 0xca, 0x28, 0x74,
+ 0x17, 0xed, 0xda, 0x2c, 0x9f, 0xb4, 0x2d, 0xf5, 0x71, 0x20, 0x64, 0x2d,
+ 0x44, 0xe5, 0xa3, 0xa0, 0x94, 0x6f, 0x20, 0xb3, 0x73, 0x96, 0x40, 0x06,
+ 0x9b, 0x25, 0x47, 0x4b, 0xe0, 0x63, 0x91, 0xd9, 0xda, 0xf3, 0xc3, 0xe5,
+ 0x3a, 0x3c, 0xb7, 0x5f, 0xab, 0x1e, 0x51, 0x17, 0x4f, 0xec, 0xc1, 0x6d,
+ 0x82, 0x79, 0x8e, 0xba, 0x7c, 0x47, 0x8e, 0x99, 0x00, 0x17, 0x9e, 0xda,
+ 0x10, 0x42, 0x70, 0x25, 0x42, 0x84, 0xc8, 0xb1, 0x95, 0x56, 0xb2, 0x08,
+ 0xa0, 0x4f, 0xdc, 0xcd, 0x9e, 0x31, 0x4b, 0x0c, 0x0b, 0x03, 0x5d, 0x2c,
+ 0x26, 0xbc, 0xa9, 0x4b, 0x19, 0xdf, 0x90, 0x01, 0x9a, 0xe0, 0x06, 0x05,
+ 0x13, 0x34, 0x9d, 0x34, 0xb8, 0xef, 0x13, 0x3a, 0x20, 0xf5, 0x74, 0x02,
+ 0x70, 0x3b, 0x41, 0x60, 0x1f, 0x5e, 0x76, 0x0a, 0xb1, 0x17, 0xd5, 0xcf,
+ 0x79, 0xef, 0xf7, 0xab, 0xe7, 0xd6, 0x0f, 0xad, 0x85, 0x2c, 0x52, 0x67,
+ 0xb5, 0xa0, 0x4a, 0xfd, 0xaf};
diff --git a/security/nss/gtests/ssl_gtest/selfencrypt_unittest.cc b/security/nss/gtests/ssl_gtest/selfencrypt_unittest.cc
new file mode 100644
index 0000000000..24f000454b
--- /dev/null
+++ b/security/nss/gtests/ssl_gtest/selfencrypt_unittest.cc
@@ -0,0 +1,281 @@
+/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
+/* vim: set ts=2 et sw=2 tw=80: */
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this file,
+ * You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+#include <functional>
+#include <memory>
+#include "nss.h"
+#include "pk11pub.h"
+#include "prerror.h"
+#include "secerr.h"
+#include "ssl.h"
+#include "sslerr.h"
+extern "C" {
+#include "sslimpl.h"
+#include "selfencrypt.h"
+}
+
+#include "databuffer.h"
+#include "gtest_utils.h"
+#include "nss_scoped_ptrs.h"
+
+namespace nss_test {
+
+static const uint8_t kAesKey1Buf[] = {0x00, 0x01, 0x02, 0x03, 0x04, 0x05,
+ 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b,
+ 0x0c, 0x0d, 0x0e, 0x0f};
+static const DataBuffer kAesKey1(kAesKey1Buf, sizeof(kAesKey1Buf));
+
+static const uint8_t kAesKey2Buf[] = {0x10, 0x11, 0x12, 0x13, 0x14, 0x15,
+ 0x16, 0x17, 0x18, 0x19, 0x1a, 0x1b,
+ 0x1c, 0x1d, 0x1e, 0x1f};
+static const DataBuffer kAesKey2(kAesKey2Buf, sizeof(kAesKey2Buf));
+
+static const uint8_t kHmacKey1Buf[] = {
+ 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a,
+ 0x0b, 0x0c, 0x0d, 0x0e, 0x0f, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15,
+ 0x16, 0x17, 0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f};
+static const DataBuffer kHmacKey1(kHmacKey1Buf, sizeof(kHmacKey1Buf));
+
+static const uint8_t kHmacKey2Buf[] = {
+ 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, 0x19, 0x1a,
+ 0x1b, 0x1c, 0x1d, 0x1e, 0x1f, 0x20, 0x21, 0x22, 0x23, 0x24, 0x25,
+ 0x26, 0x27, 0x28, 0x29, 0x2a, 0x2b, 0x2c, 0x2d, 0x2e, 0x2f};
+static const DataBuffer kHmacKey2(kHmacKey2Buf, sizeof(kHmacKey2Buf));
+
+static const uint8_t* kKeyName1 =
+ reinterpret_cast<const unsigned char*>("KEY1KEY1KEY1KEY1");
+static const uint8_t* kKeyName2 =
+ reinterpret_cast<const uint8_t*>("KEY2KEY2KEY2KEY2");
+
+static void ImportKey(const DataBuffer& key, PK11SlotInfo* slot,
+ CK_MECHANISM_TYPE mech, CK_ATTRIBUTE_TYPE cka,
+ ScopedPK11SymKey* to) {
+ SECItem key_item = {siBuffer, const_cast<uint8_t*>(key.data()),
+ static_cast<unsigned int>(key.len())};
+
+ PK11SymKey* inner =
+ PK11_ImportSymKey(slot, mech, PK11_OriginUnwrap, cka, &key_item, nullptr);
+ ASSERT_NE(nullptr, inner);
+ to->reset(inner);
+}
+
+extern "C" {
+extern char ssl_trace;
+extern FILE* ssl_trace_iob;
+}
+
+class SelfEncryptTestBase : public ::testing::Test {
+ public:
+ SelfEncryptTestBase(size_t message_size)
+ : aes1_(),
+ aes2_(),
+ hmac1_(),
+ hmac2_(),
+ message_(),
+ slot_(PK11_GetInternalSlot()) {
+ EXPECT_NE(nullptr, slot_);
+ char* ev = getenv("SSLTRACE");
+ if (ev && ev[0]) {
+ ssl_trace = atoi(ev);
+ ssl_trace_iob = stderr;
+ }
+ message_.Allocate(message_size);
+ for (size_t i = 0; i < message_.len(); ++i) {
+ message_.data()[i] = i;
+ }
+ }
+
+ void SetUp() {
+ message_.Allocate(100);
+ for (size_t i = 0; i < 100; ++i) {
+ message_.data()[i] = i;
+ }
+ ImportKey(kAesKey1, slot_.get(), CKM_AES_CBC, CKA_ENCRYPT, &aes1_);
+ ImportKey(kAesKey2, slot_.get(), CKM_AES_CBC, CKA_ENCRYPT, &aes2_);
+ ImportKey(kHmacKey1, slot_.get(), CKM_SHA256_HMAC, CKA_SIGN, &hmac1_);
+ ImportKey(kHmacKey2, slot_.get(), CKM_SHA256_HMAC, CKA_SIGN, &hmac2_);
+ }
+
+ void SelfTest(
+ const uint8_t* writeKeyName, const ScopedPK11SymKey& writeAes,
+ const ScopedPK11SymKey& writeHmac, const uint8_t* readKeyName,
+ const ScopedPK11SymKey& readAes, const ScopedPK11SymKey& readHmac,
+ PRErrorCode protect_error_code = 0, PRErrorCode unprotect_error_code = 0,
+ std::function<void(uint8_t* ciphertext, unsigned int* ciphertext_len)>
+ mutate = nullptr) {
+ uint8_t ciphertext[1000];
+ unsigned int ciphertext_len;
+ uint8_t plaintext[1000];
+ unsigned int plaintext_len;
+
+ SECStatus rv = ssl_SelfEncryptProtectInt(
+ writeAes.get(), writeHmac.get(), writeKeyName, message_.data(),
+ message_.len(), ciphertext, &ciphertext_len, sizeof(ciphertext));
+ if (rv != SECSuccess) {
+ std::cerr << "Error: " << PORT_ErrorToName(PORT_GetError()) << std::endl;
+ }
+ if (protect_error_code) {
+ ASSERT_EQ(protect_error_code, PORT_GetError());
+ return;
+ }
+ ASSERT_EQ(SECSuccess, rv);
+
+ if (mutate) {
+ mutate(ciphertext, &ciphertext_len);
+ }
+ rv = ssl_SelfEncryptUnprotectInt(readAes.get(), readHmac.get(), readKeyName,
+ ciphertext, ciphertext_len, plaintext,
+ &plaintext_len, sizeof(plaintext));
+ if (rv != SECSuccess) {
+ std::cerr << "Error: " << PORT_ErrorToName(PORT_GetError()) << std::endl;
+ }
+ if (!unprotect_error_code) {
+ ASSERT_EQ(SECSuccess, rv);
+ EXPECT_EQ(message_.len(), plaintext_len);
+ EXPECT_EQ(0, memcmp(message_.data(), plaintext, message_.len()));
+ } else {
+ ASSERT_EQ(SECFailure, rv);
+ EXPECT_EQ(unprotect_error_code, PORT_GetError());
+ }
+ }
+
+ protected:
+ ScopedPK11SymKey aes1_;
+ ScopedPK11SymKey aes2_;
+ ScopedPK11SymKey hmac1_;
+ ScopedPK11SymKey hmac2_;
+ DataBuffer message_;
+
+ private:
+ ScopedPK11SlotInfo slot_;
+};
+
+class SelfEncryptTestVariable : public SelfEncryptTestBase,
+ public ::testing::WithParamInterface<size_t> {
+ public:
+ SelfEncryptTestVariable() : SelfEncryptTestBase(GetParam()) {}
+};
+
+class SelfEncryptTest128 : public SelfEncryptTestBase {
+ public:
+ SelfEncryptTest128() : SelfEncryptTestBase(128) {}
+};
+
+TEST_P(SelfEncryptTestVariable, SuccessCase) {
+ SelfTest(kKeyName1, aes1_, hmac1_, kKeyName1, aes1_, hmac1_);
+}
+
+TEST_P(SelfEncryptTestVariable, WrongMacKey) {
+ SelfTest(kKeyName1, aes1_, hmac1_, kKeyName1, aes1_, hmac2_, 0,
+ SEC_ERROR_BAD_DATA);
+}
+
+TEST_P(SelfEncryptTestVariable, WrongKeyName) {
+ SelfTest(kKeyName1, aes1_, hmac1_, kKeyName2, aes1_, hmac1_, 0,
+ SEC_ERROR_NOT_A_RECIPIENT);
+}
+
+TEST_P(SelfEncryptTestVariable, AddAByte) {
+ SelfTest(kKeyName1, aes1_, hmac1_, kKeyName1, aes1_, hmac1_, 0,
+ SEC_ERROR_BAD_DATA,
+ [](uint8_t* ciphertext, unsigned int* ciphertext_len) {
+ (*ciphertext_len)++;
+ });
+}
+
+TEST_P(SelfEncryptTestVariable, SubtractAByte) {
+ SelfTest(kKeyName1, aes1_, hmac1_, kKeyName1, aes1_, hmac1_, 0,
+ SEC_ERROR_BAD_DATA,
+ [](uint8_t* ciphertext, unsigned int* ciphertext_len) {
+ (*ciphertext_len)--;
+ });
+}
+
+TEST_P(SelfEncryptTestVariable, BogusIv) {
+ SelfTest(kKeyName1, aes1_, hmac1_, kKeyName1, aes1_, hmac1_, 0,
+ SEC_ERROR_BAD_DATA,
+ [](uint8_t* ciphertext, unsigned int* ciphertext_len) {
+ ciphertext[16]++;
+ });
+}
+
+TEST_P(SelfEncryptTestVariable, BogusCiphertext) {
+ SelfTest(kKeyName1, aes1_, hmac1_, kKeyName1, aes1_, hmac1_, 0,
+ SEC_ERROR_BAD_DATA,
+ [](uint8_t* ciphertext, unsigned int* ciphertext_len) {
+ ciphertext[32]++;
+ });
+}
+
+TEST_P(SelfEncryptTestVariable, BadMac) {
+ SelfTest(kKeyName1, aes1_, hmac1_, kKeyName1, aes1_, hmac1_, 0,
+ SEC_ERROR_BAD_DATA,
+ [](uint8_t* ciphertext, unsigned int* ciphertext_len) {
+ ciphertext[*ciphertext_len - 1]++;
+ });
+}
+
+TEST_F(SelfEncryptTest128, DISABLED_BadPadding) {
+ SelfTest(kKeyName1, aes1_, hmac1_, kKeyName1, aes2_, hmac1_, 0,
+ SEC_ERROR_BAD_DATA);
+}
+
+TEST_F(SelfEncryptTest128, ShortKeyName) {
+ SelfTest(kKeyName1, aes1_, hmac1_, kKeyName1, aes1_, hmac1_, 0,
+ SEC_ERROR_BAD_DATA,
+ [](uint8_t* ciphertext, unsigned int* ciphertext_len) {
+ *ciphertext_len = 15;
+ });
+}
+
+TEST_F(SelfEncryptTest128, ShortIv) {
+ SelfTest(kKeyName1, aes1_, hmac1_, kKeyName1, aes1_, hmac1_, 0,
+ SEC_ERROR_BAD_DATA,
+ [](uint8_t* ciphertext, unsigned int* ciphertext_len) {
+ *ciphertext_len = 31;
+ });
+}
+
+TEST_F(SelfEncryptTest128, ShortCiphertextLen) {
+ SelfTest(kKeyName1, aes1_, hmac1_, kKeyName1, aes1_, hmac1_, 0,
+ SEC_ERROR_BAD_DATA,
+ [](uint8_t* ciphertext, unsigned int* ciphertext_len) {
+ *ciphertext_len = 32;
+ });
+}
+
+TEST_F(SelfEncryptTest128, ShortCiphertext) {
+ SelfTest(kKeyName1, aes1_, hmac1_, kKeyName1, aes1_, hmac1_, 0,
+ SEC_ERROR_BAD_DATA,
+ [](uint8_t* ciphertext, unsigned int* ciphertext_len) {
+ *ciphertext_len -= 17;
+ });
+}
+
+TEST_F(SelfEncryptTest128, MacWithAESKeyEncrypt) {
+ SelfTest(kKeyName1, aes1_, aes1_, kKeyName1, aes1_, hmac1_,
+ SEC_ERROR_LIBRARY_FAILURE);
+}
+
+TEST_F(SelfEncryptTest128, AESWithMacKeyEncrypt) {
+ SelfTest(kKeyName1, hmac1_, hmac1_, kKeyName1, aes1_, hmac1_,
+ SEC_ERROR_INVALID_KEY);
+}
+
+TEST_F(SelfEncryptTest128, MacWithAESKeyDecrypt) {
+ SelfTest(kKeyName1, aes1_, hmac1_, kKeyName1, aes1_, aes1_, 0,
+ SEC_ERROR_LIBRARY_FAILURE);
+}
+
+TEST_F(SelfEncryptTest128, AESWithMacKeyDecrypt) {
+ SelfTest(kKeyName1, aes1_, hmac1_, kKeyName1, hmac1_, hmac1_, 0,
+ SEC_ERROR_INVALID_KEY);
+}
+
+INSTANTIATE_TEST_SUITE_P(VariousSizes, SelfEncryptTestVariable,
+ ::testing::Values(0, 15, 16, 31, 255, 256, 257));
+
+} // namespace nss_test
diff --git a/security/nss/gtests/ssl_gtest/ssl_0rtt_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_0rtt_unittest.cc
new file mode 100644
index 0000000000..51ec9d3ee5
--- /dev/null
+++ b/security/nss/gtests/ssl_gtest/ssl_0rtt_unittest.cc
@@ -0,0 +1,1183 @@
+/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
+/* vim: set ts=2 et sw=2 tw=80: */
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this file,
+ * You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+#include "secerr.h"
+#include "ssl.h"
+#include "sslerr.h"
+#include "sslexp.h"
+#include "sslproto.h"
+
+extern "C" {
+// This is not something that should make you happy.
+#include "libssl_internals.h"
+}
+
+#include "cpputil.h"
+#include "gtest_utils.h"
+#include "nss_scoped_ptrs.h"
+#include "tls_connect.h"
+#include "tls_filter.h"
+#include "tls_parser.h"
+
+namespace nss_test {
+
+TEST_P(TlsConnectTls13, ZeroRtt) {
+ SetupForZeroRtt();
+ client_->Set0RttEnabled(true);
+ server_->Set0RttEnabled(true);
+ ExpectResumption(RESUME_TICKET);
+ ZeroRttSendReceive(true, true);
+ Handshake();
+ ExpectEarlyDataAccepted(true);
+ CheckConnected();
+ SendReceive();
+}
+
+TEST_P(TlsConnectTls13, ZeroRttServerRejectByOption) {
+ SetupForZeroRtt();
+ client_->Set0RttEnabled(true);
+ ExpectResumption(RESUME_TICKET);
+ ZeroRttSendReceive(true, false);
+ Handshake();
+ CheckConnected();
+ SendReceive();
+}
+
+TEST_P(TlsConnectTls13, ZeroRttApplicationReject) {
+ SetupForZeroRtt();
+ client_->Set0RttEnabled(true);
+ server_->Set0RttEnabled(true);
+ ExpectResumption(RESUME_TICKET);
+
+ auto reject_0rtt = [](PRBool firstHello, const PRUint8* clientToken,
+ unsigned int clientTokenLen, PRUint8* appToken,
+ unsigned int* appTokenLen, unsigned int appTokenMax,
+ void* arg) {
+ auto* called = reinterpret_cast<bool*>(arg);
+ *called = true;
+
+ EXPECT_TRUE(firstHello);
+ EXPECT_EQ(0U, clientTokenLen);
+ return ssl_hello_retry_reject_0rtt;
+ };
+
+ bool cb_run = false;
+ EXPECT_EQ(SECSuccess, SSL_HelloRetryRequestCallback(server_->ssl_fd(),
+ reject_0rtt, &cb_run));
+ ZeroRttSendReceive(true, false);
+ Handshake();
+ EXPECT_TRUE(cb_run);
+ CheckConnected();
+ SendReceive();
+}
+
+TEST_P(TlsConnectTls13, ZeroRttApparentReplayAfterRestart) {
+ // The test fixtures enable anti-replay in SetUp(). This results in 0-RTT
+ // being rejected until at least one window passes. SetupFor0Rtt() forces a
+ // rollover of the anti-replay filters, which clears that state and allows
+ // 0-RTT to work. Make the first connection manually to avoid that rollover
+ // and cause 0-RTT to be rejected.
+
+ ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET);
+ ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3);
+ server_->Set0RttEnabled(true); // So we signal that we allow 0-RTT.
+ Connect();
+ SendReceive(); // Need to read so that we absorb the session ticket.
+ CheckKeys();
+
+ Reset();
+ StartConnect();
+ client_->Set0RttEnabled(true);
+ server_->Set0RttEnabled(true);
+ ExpectResumption(RESUME_TICKET);
+ ZeroRttSendReceive(true, false);
+ Handshake();
+ CheckConnected();
+ SendReceive();
+}
+
+class TlsZeroRttReplayTest : public TlsConnectTls13 {
+ private:
+ class SaveFirstPacket : public PacketFilter {
+ public:
+ PacketFilter::Action Filter(const DataBuffer& input,
+ DataBuffer* output) override {
+ if (!packet_.len() && input.len()) {
+ packet_ = input;
+ }
+ return KEEP;
+ }
+
+ const DataBuffer& packet() const { return packet_; }
+
+ private:
+ DataBuffer packet_;
+ };
+
+ protected:
+ void RunTest(bool rollover, const ScopedPK11SymKey& epsk) {
+ // Now run a true 0-RTT handshake, but capture the first packet.
+ auto first_packet = std::make_shared<SaveFirstPacket>();
+ client_->SetFilter(first_packet);
+ client_->Set0RttEnabled(true);
+ server_->Set0RttEnabled(true);
+ ZeroRttSendReceive(true, true);
+ Handshake();
+ EXPECT_LT(0U, first_packet->packet().len());
+ ExpectEarlyDataAccepted(true);
+ CheckConnected();
+ SendReceive();
+
+ if (rollover) {
+ RolloverAntiReplay();
+ }
+
+ // Now replay that packet against the server.
+ Reset();
+ server_->StartConnect();
+ server_->Set0RttEnabled(true);
+ server_->SetAntiReplayContext(anti_replay_);
+ if (epsk) {
+ AddPsk(epsk, std::string("foo"), ssl_hash_sha256,
+ TLS_CHACHA20_POLY1305_SHA256);
+ }
+
+ // Capture the early_data extension, which should not appear.
+ auto early_data_ext =
+ MakeTlsFilter<TlsExtensionCapture>(server_, ssl_tls13_early_data_xtn);
+
+ // Finally, replay the ClientHello and force the server to consume it. Stop
+ // after the server sends its first flight; the client will not be able to
+ // complete this handshake.
+ server_->adapter()->PacketReceived(first_packet->packet());
+ server_->Handshake();
+ EXPECT_FALSE(early_data_ext->captured());
+ }
+
+ void RunResPskTest(bool rollover) {
+ // Run the initial handshake
+ SetupForZeroRtt();
+ ExpectResumption(RESUME_TICKET);
+ RunTest(rollover, ScopedPK11SymKey(nullptr));
+ }
+
+ void RunExtPskTest(bool rollover) {
+ ScopedPK11SlotInfo slot(PK11_GetInternalSlot());
+ ASSERT_NE(nullptr, slot);
+
+ const std::vector<uint8_t> kPskDummyVal(16, 0xFF);
+ SECItem psk_item = {siBuffer, toUcharPtr(kPskDummyVal.data()),
+ static_cast<unsigned int>(kPskDummyVal.size())};
+ PK11SymKey* key =
+ PK11_ImportSymKey(slot.get(), CKM_HKDF_KEY_GEN, PK11_OriginUnwrap,
+ CKA_DERIVE, &psk_item, NULL);
+ ASSERT_NE(nullptr, key);
+ ScopedPK11SymKey scoped_psk(key);
+ RolloverAntiReplay();
+ AddPsk(scoped_psk, std::string("foo"), ssl_hash_sha256,
+ TLS_CHACHA20_POLY1305_SHA256);
+ StartConnect();
+ RunTest(rollover, scoped_psk);
+ }
+};
+
+TEST_P(TlsZeroRttReplayTest, ResPskZeroRttReplay) { RunResPskTest(false); }
+
+TEST_P(TlsZeroRttReplayTest, ExtPskZeroRttReplay) { RunExtPskTest(false); }
+
+TEST_P(TlsZeroRttReplayTest, ZeroRttReplayAfterRollover) {
+ RunResPskTest(true);
+}
+
+// Test that we don't try to send 0-RTT data when the server sent
+// us a ticket without the 0-RTT flags.
+TEST_P(TlsConnectTls13, ZeroRttOptionsSetLate) {
+ ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET);
+ Connect();
+ SendReceive(); // Need to read so that we absorb the session ticket.
+ CheckKeys(ssl_kea_ecdh, ssl_auth_rsa_sign);
+ Reset();
+ StartConnect();
+ // Now turn on 0-RTT but too late for the ticket.
+ client_->Set0RttEnabled(true);
+ server_->Set0RttEnabled(true);
+ ExpectResumption(RESUME_TICKET);
+ ZeroRttSendReceive(false, false);
+ Handshake();
+ CheckConnected();
+ SendReceive();
+}
+
+// Make sure that a session ticket sent well after the original handshake
+// can be used for 0-RTT.
+// Stream because DTLS doesn't support SSL_SendSessionTicket.
+TEST_F(TlsConnectStreamTls13, ZeroRttUsingLateTicket) {
+ // Use a small-ish anti-replay window.
+ ResetAntiReplay(100 * PR_USEC_PER_MSEC);
+ RolloverAntiReplay();
+
+ ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET);
+ server_->Set0RttEnabled(true);
+ Connect();
+ CheckKeys();
+
+ // Now move time forward 30s and send a ticket.
+ AdvanceTime(30 * PR_USEC_PER_SEC);
+ EXPECT_EQ(SECSuccess, SSL_SendSessionTicket(server_->ssl_fd(), NULL, 0));
+ SendReceive();
+ Reset();
+ StartConnect();
+
+ client_->Set0RttEnabled(true);
+ server_->Set0RttEnabled(true);
+ ExpectResumption(RESUME_TICKET);
+ ZeroRttSendReceive(true, true);
+ Handshake();
+ ExpectEarlyDataAccepted(true);
+ CheckConnected();
+ SendReceive();
+}
+
+// Check that post-handshake authentication with a long RTT doesn't
+// make things worse.
+TEST_F(TlsConnectStreamTls13, ZeroRttUsingLateTicketPha) {
+ // Use a small-ish anti-replay window.
+ ResetAntiReplay(100 * PR_USEC_PER_MSEC);
+ RolloverAntiReplay();
+
+ ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET);
+ server_->Set0RttEnabled(true);
+ client_->SetupClientAuth();
+ client_->SetOption(SSL_ENABLE_POST_HANDSHAKE_AUTH, PR_TRUE);
+ Connect();
+ CheckKeys();
+
+ // Add post-handshake authentication, with some added delays.
+ AdvanceTime(10 * PR_USEC_PER_SEC);
+ EXPECT_EQ(SECSuccess, SSL_SendCertificateRequest(server_->ssl_fd()));
+ AdvanceTime(10 * PR_USEC_PER_SEC);
+ server_->SendData(50);
+ client_->ReadBytes(50);
+ client_->SendData(50);
+ server_->ReadBytes(50);
+
+ AdvanceTime(10 * PR_USEC_PER_SEC);
+ EXPECT_EQ(SECSuccess, SSL_SendSessionTicket(server_->ssl_fd(), NULL, 0));
+ server_->SendData(100);
+ client_->ReadBytes(100);
+ Reset();
+ StartConnect();
+
+ client_->Set0RttEnabled(true);
+ server_->Set0RttEnabled(true);
+ ExpectResumption(RESUME_TICKET);
+ ZeroRttSendReceive(true, true);
+ Handshake();
+ ExpectEarlyDataAccepted(true);
+ CheckConnected();
+ SendReceive();
+}
+
+// Same, but with client authentication on the first connection.
+TEST_F(TlsConnectStreamTls13, ZeroRttUsingLateTicketClientAuth) {
+ // Use a small-ish anti-replay window.
+ ResetAntiReplay(100 * PR_USEC_PER_MSEC);
+ RolloverAntiReplay();
+
+ ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET);
+ client_->SetupClientAuth();
+ server_->RequestClientAuth(true);
+ server_->Set0RttEnabled(true);
+ Connect();
+ CheckKeys();
+
+ // Now move time forward 30s and send a ticket.
+ AdvanceTime(30 * PR_USEC_PER_SEC);
+ EXPECT_EQ(SECSuccess, SSL_SendSessionTicket(server_->ssl_fd(), NULL, 0));
+ SendReceive();
+ Reset();
+ StartConnect();
+
+ client_->Set0RttEnabled(true);
+ server_->Set0RttEnabled(true);
+ ExpectResumption(RESUME_TICKET);
+ ZeroRttSendReceive(true, true);
+ Handshake();
+ ExpectEarlyDataAccepted(true);
+ CheckConnected();
+ SendReceive();
+}
+
+TEST_P(TlsConnectTls13, ZeroRttServerForgetTicket) {
+ SetupForZeroRtt();
+ client_->Set0RttEnabled(true);
+ server_->Set0RttEnabled(true);
+ ClearServerCache();
+ ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET);
+ ExpectResumption(RESUME_NONE);
+ ZeroRttSendReceive(true, false);
+ Handshake();
+ CheckConnected();
+ SendReceive();
+}
+
+TEST_P(TlsConnectTls13, ZeroRttServerOnly) {
+ ExpectResumption(RESUME_NONE);
+ server_->Set0RttEnabled(true);
+ StartConnect();
+
+ // Client sends ordinary ClientHello.
+ client_->Handshake();
+
+ // Verify that the server doesn't get data.
+ uint8_t buf[100];
+ PRInt32 rv = PR_Read(server_->ssl_fd(), buf, sizeof(buf));
+ EXPECT_EQ(SECFailure, rv);
+ EXPECT_EQ(PR_WOULD_BLOCK_ERROR, PORT_GetError());
+
+ // Now make sure that things complete.
+ Handshake();
+ CheckConnected();
+ SendReceive();
+ CheckKeys();
+}
+
+// Advancing time after sending the ClientHello means that the ticket age that
+// arrives at the server is too low. The server then rejects early data if this
+// delay exceeds half the anti-replay window.
+TEST_P(TlsConnectTls13, ZeroRttRejectOldTicket) {
+ static const PRTime kWindow = 10 * PR_USEC_PER_SEC;
+ ResetAntiReplay(kWindow);
+ SetupForZeroRtt();
+
+ Reset();
+ StartConnect();
+ client_->Set0RttEnabled(true);
+ server_->Set0RttEnabled(true);
+ ExpectResumption(RESUME_TICKET);
+ ZeroRttSendReceive(true, false, [this]() {
+ AdvanceTime(1 + kWindow / 2);
+ return true;
+ });
+ Handshake();
+ ExpectEarlyDataAccepted(false);
+ CheckConnected();
+ SendReceive();
+}
+
+// In this test, we falsely inflate the estimate of the RTT by delaying the
+// ServerHello on the first handshake. This results in the server estimating a
+// higher value of the ticket age than the client ultimately provides. Add a
+// small tolerance for variation in ticket age and the ticket will appear to
+// arrive prematurely, causing the server to reject early data.
+TEST_P(TlsConnectTls13, ZeroRttRejectPrematureTicket) {
+ static const PRTime kWindow = 10 * PR_USEC_PER_SEC;
+ ResetAntiReplay(kWindow);
+ ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET);
+ ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3);
+ server_->Set0RttEnabled(true);
+ StartConnect();
+ client_->Handshake(); // ClientHello
+ server_->Handshake(); // ServerHello
+ AdvanceTime(1 + kWindow / 2);
+ Handshake(); // Remainder of handshake
+ CheckConnected();
+ SendReceive();
+ CheckKeys();
+
+ Reset();
+ client_->Set0RttEnabled(true);
+ server_->Set0RttEnabled(true);
+ ExpectResumption(RESUME_TICKET);
+ ExpectEarlyDataAccepted(false);
+ StartConnect();
+ ZeroRttSendReceive(true, false);
+ Handshake();
+ CheckConnected();
+ SendReceive();
+}
+
+TEST_P(TlsConnectTls13, TestTls13ZeroRttAlpn) {
+ EnableAlpn();
+ SetupForZeroRtt();
+ EnableAlpn();
+ client_->Set0RttEnabled(true);
+ server_->Set0RttEnabled(true);
+ ExpectResumption(RESUME_TICKET);
+ ExpectEarlyDataAccepted(true);
+ ZeroRttSendReceive(true, true, [this]() {
+ client_->CheckAlpn(SSL_NEXT_PROTO_EARLY_VALUE, "a");
+ return true;
+ });
+ Handshake();
+ CheckConnected();
+ SendReceive();
+ CheckAlpn("a");
+}
+
+// NOTE: In this test and those below, the client always sends
+// post-ServerHello alerts with the handshake keys, even if the server
+// has accepted 0-RTT. In some cases, as with errors in
+// EncryptedExtensions, the client can't know the server's behavior,
+// and in others it's just simpler. What the server is expecting
+// depends on whether it accepted 0-RTT or not. Eventually, we may
+// make the server trial decrypt.
+//
+// Have the server negotiate a different ALPN value, and therefore
+// reject 0-RTT.
+TEST_P(TlsConnectTls13, TestTls13ZeroRttAlpnChangeServer) {
+ EnableAlpn();
+ SetupForZeroRtt();
+ static const uint8_t client_alpn[] = {0x01, 0x61, 0x01, 0x62}; // "a", "b"
+ static const uint8_t server_alpn[] = {0x01, 0x62}; // "b"
+ client_->EnableAlpn(client_alpn, sizeof(client_alpn));
+ server_->EnableAlpn(server_alpn, sizeof(server_alpn));
+ client_->Set0RttEnabled(true);
+ server_->Set0RttEnabled(true);
+ ExpectResumption(RESUME_TICKET);
+ ZeroRttSendReceive(true, false, [this]() {
+ client_->CheckAlpn(SSL_NEXT_PROTO_EARLY_VALUE, "a");
+ return true;
+ });
+ Handshake();
+ CheckConnected();
+ SendReceive();
+ CheckAlpn("b");
+}
+
+// Check that the client validates the ALPN selection of the server.
+// Stomp the ALPN on the client after sending the ClientHello so
+// that the server selection appears to be incorrect. The client
+// should then fail the connection.
+TEST_P(TlsConnectTls13, TestTls13ZeroRttNoAlpnServer) {
+ EnableAlpn();
+ SetupForZeroRtt();
+ client_->Set0RttEnabled(true);
+ server_->Set0RttEnabled(true);
+ EnableAlpn();
+ ExpectResumption(RESUME_TICKET);
+ ZeroRttSendReceive(true, true, [this]() {
+ PRUint8 b[] = {'b'};
+ client_->CheckAlpn(SSL_NEXT_PROTO_EARLY_VALUE, "a");
+ EXPECT_EQ(SECSuccess, SSLInt_Set0RttAlpn(client_->ssl_fd(), b, sizeof(b)));
+ client_->CheckAlpn(SSL_NEXT_PROTO_EARLY_VALUE, "b");
+ client_->ExpectSendAlert(kTlsAlertIllegalParameter);
+ return true;
+ });
+ if (variant_ == ssl_variant_stream) {
+ server_->ExpectSendAlert(kTlsAlertBadRecordMac);
+ Handshake();
+ server_->CheckErrorCode(SSL_ERROR_BAD_MAC_READ);
+ } else {
+ client_->Handshake();
+ }
+ client_->CheckErrorCode(SSL_ERROR_NEXT_PROTOCOL_DATA_INVALID);
+}
+
+// Set up with no ALPN and then set the client so it thinks it has ALPN.
+// The server responds without the extension and the client returns an
+// error.
+TEST_P(TlsConnectTls13, TestTls13ZeroRttNoAlpnClient) {
+ SetupForZeroRtt();
+ client_->Set0RttEnabled(true);
+ server_->Set0RttEnabled(true);
+ ExpectResumption(RESUME_TICKET);
+ ZeroRttSendReceive(true, true, [this]() {
+ PRUint8 b[] = {'b'};
+ EXPECT_EQ(SECSuccess, SSLInt_Set0RttAlpn(client_->ssl_fd(), b, 1));
+ client_->CheckAlpn(SSL_NEXT_PROTO_EARLY_VALUE, "b");
+ client_->ExpectSendAlert(kTlsAlertIllegalParameter);
+ return true;
+ });
+ if (variant_ == ssl_variant_stream) {
+ server_->ExpectSendAlert(kTlsAlertBadRecordMac);
+ Handshake();
+ server_->CheckErrorCode(SSL_ERROR_BAD_MAC_READ);
+ } else {
+ client_->Handshake();
+ }
+ client_->CheckErrorCode(SSL_ERROR_NEXT_PROTOCOL_DATA_INVALID);
+}
+
+// Remove the old ALPN value and so the client will not offer early data.
+TEST_P(TlsConnectTls13, TestTls13ZeroRttAlpnChangeBoth) {
+ EnableAlpn();
+ SetupForZeroRtt();
+ static const std::vector<uint8_t> alpn({0x01, 0x62}); // "b"
+ EnableAlpn(alpn);
+ client_->Set0RttEnabled(true);
+ server_->Set0RttEnabled(true);
+ ExpectResumption(RESUME_TICKET);
+ ZeroRttSendReceive(true, false, [this]() {
+ client_->CheckAlpn(SSL_NEXT_PROTO_NO_SUPPORT);
+ return false;
+ });
+ Handshake();
+ CheckConnected();
+ SendReceive();
+ CheckAlpn("b");
+}
+
+// The client should abort the connection when sending a 0-rtt handshake but
+// the servers responds with a TLS 1.2 ServerHello. (no app data sent)
+TEST_P(TlsConnectTls13, TestTls13ZeroRttDowngrade) {
+ ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET);
+ server_->Set0RttEnabled(true); // set ticket_allow_early_data
+ Connect();
+
+ SendReceive(); // Need to read so that we absorb the session tickets.
+ CheckKeys();
+
+ Reset();
+ ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET);
+ client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2,
+ SSL_LIBRARY_VERSION_TLS_1_3);
+ server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2,
+ SSL_LIBRARY_VERSION_TLS_1_2);
+ StartConnect();
+ // We will send the early data xtn without sending actual early data. Thus
+ // a 1.2 server shouldn't fail until the client sends an alert because the
+ // client sends end_of_early_data only after reading the server's flight.
+ client_->Set0RttEnabled(true);
+
+ client_->ExpectSendAlert(kTlsAlertIllegalParameter);
+ if (variant_ == ssl_variant_stream) {
+ server_->ExpectSendAlert(kTlsAlertUnexpectedMessage);
+ }
+ client_->Handshake();
+ server_->Handshake();
+ ASSERT_TRUE_WAIT(
+ (client_->error_code() == SSL_ERROR_DOWNGRADE_WITH_EARLY_DATA), 2000);
+
+ // DTLS will timeout as we bump the epoch when installing the early app data
+ // cipher suite. Thus the encrypted alert will be ignored.
+ if (variant_ == ssl_variant_stream) {
+ // The client sends an encrypted alert message.
+ ASSERT_TRUE_WAIT(
+ (server_->error_code() == SSL_ERROR_RX_UNEXPECTED_APPLICATION_DATA),
+ 2000);
+ }
+}
+
+// The client should abort the connection when sending a 0-rtt handshake but
+// the servers responds with a TLS 1.2 ServerHello. (with app data)
+TEST_P(TlsConnectTls13, TestTls13ZeroRttDowngradeEarlyData) {
+ const char* k0RttData = "ABCDEF";
+ const PRInt32 k0RttDataLen = static_cast<PRInt32>(strlen(k0RttData));
+
+ ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET);
+ server_->Set0RttEnabled(true); // set ticket_allow_early_data
+ Connect();
+
+ SendReceive(); // Need to read so that we absorb the session tickets.
+ CheckKeys();
+
+ Reset();
+ ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET);
+ client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2,
+ SSL_LIBRARY_VERSION_TLS_1_3);
+ server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2,
+ SSL_LIBRARY_VERSION_TLS_1_2);
+ StartConnect();
+ // Send the early data xtn in the CH, followed by early app data. The server
+ // will fail right after sending its flight, when receiving the early data.
+ client_->Set0RttEnabled(true);
+ client_->Handshake(); // Send ClientHello.
+ PRInt32 rv =
+ PR_Write(client_->ssl_fd(), k0RttData, k0RttDataLen); // 0-RTT write.
+ EXPECT_EQ(k0RttDataLen, rv);
+
+ if (variant_ == ssl_variant_stream) {
+ // When the server receives the early data, it will fail.
+ server_->ExpectSendAlert(kTlsAlertUnexpectedMessage);
+ server_->Handshake(); // Consume ClientHello
+ EXPECT_EQ(TlsAgent::STATE_ERROR, server_->state());
+ server_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_APPLICATION_DATA);
+ } else {
+ // If it's datagram, we just discard the early data.
+ server_->Handshake(); // Consume ClientHello
+ EXPECT_EQ(TlsAgent::STATE_CONNECTING, server_->state());
+ }
+
+ // The client now reads the ServerHello and fails.
+ ASSERT_EQ(TlsAgent::STATE_CONNECTING, client_->state());
+ client_->ExpectSendAlert(kTlsAlertIllegalParameter);
+ client_->Handshake();
+ client_->CheckErrorCode(SSL_ERROR_DOWNGRADE_WITH_EARLY_DATA);
+}
+
+TEST_P(TlsConnectTls13, SendTooMuchEarlyData) {
+ EnsureTlsSetup();
+ const char* big_message = "0123456789abcdef";
+ const size_t short_size = strlen(big_message) - 1;
+ const PRInt32 short_length = static_cast<PRInt32>(short_size);
+ EXPECT_EQ(SECSuccess,
+ SSL_SetMaxEarlyDataSize(server_->ssl_fd(),
+ static_cast<PRUint32>(short_size)));
+ SetupForZeroRtt();
+
+ client_->Set0RttEnabled(true);
+ server_->Set0RttEnabled(true);
+ ExpectResumption(RESUME_TICKET);
+
+ client_->Handshake();
+ CheckEarlyDataLimit(client_, short_size);
+
+ PRInt32 sent;
+ // Writing more than the limit will succeed in TLS, but fail in DTLS.
+ if (variant_ == ssl_variant_stream) {
+ sent = PR_Write(client_->ssl_fd(), big_message,
+ static_cast<PRInt32>(strlen(big_message)));
+ } else {
+ sent = PR_Write(client_->ssl_fd(), big_message,
+ static_cast<PRInt32>(strlen(big_message)));
+ EXPECT_GE(0, sent);
+ EXPECT_EQ(PR_WOULD_BLOCK_ERROR, PORT_GetError());
+
+ // Try an exact-sized write now.
+ sent = PR_Write(client_->ssl_fd(), big_message, short_length);
+ }
+ EXPECT_EQ(short_length, sent);
+
+ // Even a single octet write should now fail.
+ sent = PR_Write(client_->ssl_fd(), big_message, 1);
+ EXPECT_GE(0, sent);
+ EXPECT_EQ(PR_WOULD_BLOCK_ERROR, PORT_GetError());
+
+ // Process the ClientHello and read 0-RTT.
+ server_->Handshake();
+ CheckEarlyDataLimit(server_, short_size);
+
+ std::vector<uint8_t> buf(short_size + 1);
+ PRInt32 read = PR_Read(server_->ssl_fd(), buf.data(), buf.capacity());
+ EXPECT_EQ(short_length, read);
+ EXPECT_EQ(0, memcmp(big_message, buf.data(), short_size));
+
+ // Second read fails.
+ read = PR_Read(server_->ssl_fd(), buf.data(), buf.capacity());
+ EXPECT_EQ(SECFailure, read);
+ EXPECT_EQ(PR_WOULD_BLOCK_ERROR, PORT_GetError());
+
+ Handshake();
+ ExpectEarlyDataAccepted(true);
+ CheckConnected();
+ SendReceive();
+}
+
+TEST_P(TlsConnectTls13, ReceiveTooMuchEarlyData) {
+ EnsureTlsSetup();
+
+ const size_t limit = 5;
+ EXPECT_EQ(SECSuccess, SSL_SetMaxEarlyDataSize(server_->ssl_fd(), limit));
+ SetupForZeroRtt();
+
+ client_->Set0RttEnabled(true);
+ server_->Set0RttEnabled(true);
+ ExpectResumption(RESUME_TICKET);
+
+ client_->Handshake(); // Send ClientHello
+ CheckEarlyDataLimit(client_, limit);
+
+ server_->Handshake(); // Process ClientHello, send server flight.
+
+ // Lift the limit on the client.
+ EXPECT_EQ(SECSuccess,
+ SSLInt_SetSocketMaxEarlyDataSize(client_->ssl_fd(), 1000));
+
+ // Send message
+ const char* message = "0123456789abcdef";
+ const PRInt32 message_len = static_cast<PRInt32>(strlen(message));
+ EXPECT_EQ(message_len, PR_Write(client_->ssl_fd(), message, message_len));
+
+ if (variant_ == ssl_variant_stream) {
+ // This error isn't fatal for DTLS.
+ ExpectAlert(server_, kTlsAlertUnexpectedMessage);
+ }
+
+ server_->Handshake(); // This reads the early data and maybe throws an error.
+ if (variant_ == ssl_variant_stream) {
+ server_->CheckErrorCode(SSL_ERROR_TOO_MUCH_EARLY_DATA);
+ } else {
+ EXPECT_EQ(TlsAgent::STATE_CONNECTING, server_->state());
+ }
+ CheckEarlyDataLimit(server_, limit);
+
+ // Attempt to read early data. This will get an error.
+ std::vector<uint8_t> buf(strlen(message) + 1);
+ EXPECT_GT(0, PR_Read(server_->ssl_fd(), buf.data(), buf.capacity()));
+ if (variant_ == ssl_variant_stream) {
+ EXPECT_EQ(SSL_ERROR_HANDSHAKE_FAILED, PORT_GetError());
+ } else {
+ EXPECT_EQ(PR_WOULD_BLOCK_ERROR, PORT_GetError());
+ }
+
+ client_->Handshake(); // Process the server's first flight.
+ if (variant_ == ssl_variant_stream) {
+ client_->Handshake(); // Process the alert.
+ client_->CheckErrorCode(SSL_ERROR_HANDSHAKE_UNEXPECTED_ALERT);
+ } else {
+ server_->Handshake(); // Finish connecting.
+ EXPECT_EQ(TlsAgent::STATE_CONNECTED, server_->state());
+ }
+}
+
+class PacketCoalesceFilter : public PacketFilter {
+ public:
+ PacketCoalesceFilter() : packet_data_() {}
+
+ void SendCoalesced(std::shared_ptr<TlsAgent> agent) {
+ agent->SendDirect(packet_data_);
+ }
+
+ protected:
+ PacketFilter::Action Filter(const DataBuffer& input,
+ DataBuffer* output) override {
+ packet_data_.Write(packet_data_.len(), input);
+ return DROP;
+ }
+
+ private:
+ DataBuffer packet_data_;
+};
+
+TEST_P(TlsConnectTls13, ZeroRttOrdering) {
+ SetupForZeroRtt();
+ client_->Set0RttEnabled(true);
+ server_->Set0RttEnabled(true);
+ ExpectResumption(RESUME_TICKET);
+
+ // Send out the ClientHello.
+ client_->Handshake();
+
+ // Now, coalesce the next three things from the client: early data, second
+ // flight and 1-RTT data.
+ auto coalesce = std::make_shared<PacketCoalesceFilter>();
+ client_->SetFilter(coalesce);
+
+ // Send (and hold) early data.
+ static const std::vector<uint8_t> early_data = {3, 2, 1};
+ EXPECT_EQ(static_cast<PRInt32>(early_data.size()),
+ PR_Write(client_->ssl_fd(), early_data.data(), early_data.size()));
+
+ // Send (and hold) the second client handshake flight.
+ // The client sends EndOfEarlyData after seeing the server Finished.
+ server_->Handshake();
+ client_->Handshake();
+
+ // Send (and hold) 1-RTT data.
+ static const std::vector<uint8_t> late_data = {7, 8, 9, 10};
+ EXPECT_EQ(static_cast<PRInt32>(late_data.size()),
+ PR_Write(client_->ssl_fd(), late_data.data(), late_data.size()));
+
+ // Now release them all at once.
+ coalesce->SendCoalesced(client_);
+
+ // Now ensure that the three steps are exposed in the right order on the
+ // server: delivery of early data, handshake callback, delivery of 1-RTT.
+ size_t step = 0;
+ server_->SetHandshakeCallback([&step](TlsAgent*) {
+ EXPECT_EQ(1U, step);
+ ++step;
+ });
+
+ std::vector<uint8_t> buf(10);
+ PRInt32 read = PR_Read(server_->ssl_fd(), buf.data(), buf.size());
+ ASSERT_EQ(static_cast<PRInt32>(early_data.size()), read);
+ buf.resize(read);
+ EXPECT_EQ(early_data, buf);
+ EXPECT_EQ(0U, step);
+ ++step;
+
+ // The third read should be after the handshake callback and should return the
+ // data that was sent after the handshake completed.
+ buf.resize(10);
+ read = PR_Read(server_->ssl_fd(), buf.data(), buf.size());
+ ASSERT_EQ(static_cast<PRInt32>(late_data.size()), read);
+ buf.resize(read);
+ EXPECT_EQ(late_data, buf);
+ EXPECT_EQ(2U, step);
+}
+
+// Early data remains available after the handshake completes for TLS.
+TEST_F(TlsConnectStreamTls13, ZeroRttLateReadTls) {
+ SetupForZeroRtt();
+ client_->Set0RttEnabled(true);
+ server_->Set0RttEnabled(true);
+ ExpectResumption(RESUME_TICKET);
+ client_->Handshake(); // ClientHello
+
+ // Write some early data.
+ const uint8_t data[] = {1, 2, 3, 4, 5, 6, 7, 8};
+ PRInt32 rv = PR_Write(client_->ssl_fd(), data, sizeof(data));
+ EXPECT_EQ(static_cast<PRInt32>(sizeof(data)), rv);
+
+ // Consume the ClientHello and generate ServerHello..Finished.
+ server_->Handshake();
+
+ // Read some of the data.
+ std::vector<uint8_t> small_buffer(1 + sizeof(data) / 2);
+ rv = PR_Read(server_->ssl_fd(), small_buffer.data(), small_buffer.size());
+ EXPECT_EQ(static_cast<PRInt32>(small_buffer.size()), rv);
+ EXPECT_EQ(0, memcmp(data, small_buffer.data(), small_buffer.size()));
+
+ Handshake(); // Complete the handshake.
+ ExpectEarlyDataAccepted(true);
+ CheckConnected();
+
+ // After the handshake, it should be possible to read the remainder.
+ uint8_t big_buf[100];
+ rv = PR_Read(server_->ssl_fd(), big_buf, sizeof(big_buf));
+ EXPECT_EQ(static_cast<PRInt32>(sizeof(data) - small_buffer.size()), rv);
+ EXPECT_EQ(0, memcmp(&data[small_buffer.size()], big_buf,
+ sizeof(data) - small_buffer.size()));
+
+ // And that's all there is to read.
+ rv = PR_Read(server_->ssl_fd(), big_buf, sizeof(big_buf));
+ EXPECT_GT(0, rv);
+ EXPECT_EQ(PR_WOULD_BLOCK_ERROR, PORT_GetError());
+}
+
+// Early data that arrives before the handshake can be read after the handshake
+// is complete.
+TEST_F(TlsConnectDatagram13, ZeroRttLateReadDtls) {
+ SetupForZeroRtt();
+ client_->Set0RttEnabled(true);
+ server_->Set0RttEnabled(true);
+ ExpectResumption(RESUME_TICKET);
+ client_->Handshake(); // ClientHello
+
+ // Write some early data.
+ const uint8_t data[] = {1, 2, 3};
+ PRInt32 written = PR_Write(client_->ssl_fd(), data, sizeof(data));
+ EXPECT_EQ(static_cast<PRInt32>(sizeof(data)), written);
+
+ Handshake(); // Complete the handshake.
+ ExpectEarlyDataAccepted(true);
+ CheckConnected();
+
+ // Reading at the server should return the early data, which was buffered.
+ uint8_t buf[sizeof(data) + 1] = {0};
+ PRInt32 read = PR_Read(server_->ssl_fd(), buf, sizeof(buf));
+ EXPECT_EQ(static_cast<PRInt32>(sizeof(data)), read);
+ EXPECT_EQ(0, memcmp(data, buf, sizeof(data)));
+}
+
+class PacketHolder : public PacketFilter {
+ public:
+ PacketHolder() = default;
+
+ virtual Action Filter(const DataBuffer& input, DataBuffer* output) {
+ packet_ = input;
+ Disable();
+ return DROP;
+ }
+
+ const DataBuffer& packet() const { return packet_; }
+
+ private:
+ DataBuffer packet_;
+};
+
+// Early data that arrives late is discarded for DTLS.
+TEST_F(TlsConnectDatagram13, ZeroRttLateArrivalDtls) {
+ SetupForZeroRtt();
+ client_->Set0RttEnabled(true);
+ server_->Set0RttEnabled(true);
+ ExpectResumption(RESUME_TICKET);
+ client_->Handshake(); // ClientHello
+
+ // Write some early data. Twice, so that we can read bits of it.
+ const uint8_t data[] = {1, 2, 3};
+ PRInt32 written = PR_Write(client_->ssl_fd(), data, sizeof(data));
+ EXPECT_EQ(static_cast<PRInt32>(sizeof(data)), written);
+
+ // Block and capture the next packet.
+ auto holder = std::make_shared<PacketHolder>();
+ client_->SetFilter(holder);
+ written = PR_Write(client_->ssl_fd(), data, sizeof(data));
+ EXPECT_EQ(static_cast<PRInt32>(sizeof(data)), written);
+ EXPECT_FALSE(holder->enabled()) << "the filter should disable itself";
+
+ // Consume the ClientHello and generate ServerHello..Finished.
+ server_->Handshake();
+
+ // Read some of the data.
+ std::vector<uint8_t> small_buffer(sizeof(data));
+ PRInt32 read =
+ PR_Read(server_->ssl_fd(), small_buffer.data(), small_buffer.size());
+
+ EXPECT_EQ(static_cast<PRInt32>(small_buffer.size()), read);
+ EXPECT_EQ(0, memcmp(data, small_buffer.data(), small_buffer.size()));
+
+ Handshake(); // Complete the handshake.
+ ExpectEarlyDataAccepted(true);
+ CheckConnected();
+
+ server_->SendDirect(holder->packet());
+
+ // Reading now should return nothing, even though a valid packet was
+ // delivered.
+ read = PR_Read(server_->ssl_fd(), small_buffer.data(), small_buffer.size());
+ EXPECT_GT(0, read);
+ EXPECT_EQ(PR_WOULD_BLOCK_ERROR, PORT_GetError());
+}
+
+// Early data reads in TLS should be coalesced.
+TEST_F(TlsConnectStreamTls13, ZeroRttCoalesceReadTls) {
+ SetupForZeroRtt();
+ client_->Set0RttEnabled(true);
+ server_->Set0RttEnabled(true);
+ ExpectResumption(RESUME_TICKET);
+ client_->Handshake(); // ClientHello
+
+ // Write some early data. In two writes.
+ const uint8_t data[] = {1, 2, 3, 4, 5, 6};
+ PRInt32 written = PR_Write(client_->ssl_fd(), data, 1);
+ EXPECT_EQ(1, written);
+
+ written = PR_Write(client_->ssl_fd(), data + 1, sizeof(data) - 1);
+ EXPECT_EQ(static_cast<PRInt32>(sizeof(data) - 1), written);
+
+ // Consume the ClientHello and generate ServerHello..Finished.
+ server_->Handshake();
+
+ // Read all of the data.
+ std::vector<uint8_t> buffer(sizeof(data));
+ PRInt32 read = PR_Read(server_->ssl_fd(), buffer.data(), buffer.size());
+ EXPECT_EQ(static_cast<PRInt32>(sizeof(data)), read);
+ EXPECT_EQ(0, memcmp(data, buffer.data(), sizeof(data)));
+
+ Handshake(); // Complete the handshake.
+ ExpectEarlyDataAccepted(true);
+ CheckConnected();
+}
+
+// Early data reads in DTLS should not be coalesced.
+TEST_F(TlsConnectDatagram13, ZeroRttNoCoalesceReadDtls) {
+ SetupForZeroRtt();
+ client_->Set0RttEnabled(true);
+ server_->Set0RttEnabled(true);
+ ExpectResumption(RESUME_TICKET);
+ client_->Handshake(); // ClientHello
+
+ // Write some early data. In two writes.
+ const uint8_t data[] = {1, 2, 3, 4, 5, 6};
+ PRInt32 written = PR_Write(client_->ssl_fd(), data, 1);
+ EXPECT_EQ(1, written);
+
+ written = PR_Write(client_->ssl_fd(), data + 1, sizeof(data) - 1);
+ EXPECT_EQ(static_cast<PRInt32>(sizeof(data) - 1), written);
+
+ // Consume the ClientHello and generate ServerHello..Finished.
+ server_->Handshake();
+
+ // Try to read all of the data.
+ std::vector<uint8_t> buffer(sizeof(data));
+ PRInt32 read = PR_Read(server_->ssl_fd(), buffer.data(), buffer.size());
+ EXPECT_EQ(1, read);
+ EXPECT_EQ(0, memcmp(data, buffer.data(), 1));
+
+ // Read the remainder.
+ read = PR_Read(server_->ssl_fd(), buffer.data(), buffer.size());
+ EXPECT_EQ(static_cast<PRInt32>(sizeof(data) - 1), read);
+ EXPECT_EQ(0, memcmp(data + 1, buffer.data(), sizeof(data) - 1));
+
+ Handshake(); // Complete the handshake.
+ ExpectEarlyDataAccepted(true);
+ CheckConnected();
+}
+
+// Early data reads in DTLS should fail if the buffer is too small.
+TEST_F(TlsConnectDatagram13, ZeroRttShortReadDtls) {
+ SetupForZeroRtt();
+ client_->Set0RttEnabled(true);
+ server_->Set0RttEnabled(true);
+ ExpectResumption(RESUME_TICKET);
+ client_->Handshake(); // ClientHello
+
+ // Write some early data. In two writes.
+ const uint8_t data[] = {1, 2, 3, 4, 5, 6};
+ PRInt32 written = PR_Write(client_->ssl_fd(), data, sizeof(data));
+ EXPECT_EQ(static_cast<PRInt32>(sizeof(data)), written);
+
+ // Consume the ClientHello and generate ServerHello..Finished.
+ server_->Handshake();
+
+ // Try to read all of the data into a small buffer.
+ std::vector<uint8_t> buffer(sizeof(data));
+ PRInt32 read = PR_Read(server_->ssl_fd(), buffer.data(), 1);
+ EXPECT_GT(0, read);
+ EXPECT_EQ(SSL_ERROR_RX_SHORT_DTLS_READ, PORT_GetError());
+
+ // Read again with more space.
+ read = PR_Read(server_->ssl_fd(), buffer.data(), buffer.size());
+ EXPECT_EQ(static_cast<PRInt32>(sizeof(data)), read);
+ EXPECT_EQ(0, memcmp(data, buffer.data(), sizeof(data)));
+
+ Handshake(); // Complete the handshake.
+ ExpectEarlyDataAccepted(true);
+ CheckConnected();
+}
+
+// There are few ways in which TLS uses the clock and most of those operate on
+// timescales that would be ridiculous to wait for in a test. This is the one
+// test we have that uses the real clock. It tests that time passes by checking
+// that a small sleep results in rejection of early data. 0-RTT has a
+// configurable timer, which makes it ideal for this.
+TEST_F(TlsConnectStreamTls13, TimePassesByDefault) {
+ // Calling EnsureTlsSetup() replaces the time function on client and server,
+ // and sets up anti-replay, which we don't want, so initialize each directly.
+ client_->EnsureTlsSetup();
+ server_->EnsureTlsSetup();
+ // StartConnect() calls EnsureTlsSetup(), so avoid that too.
+ client_->StartConnect();
+ server_->StartConnect();
+
+ // Set a tiny anti-replay window. This has to be at least 2 milliseconds to
+ // have any chance of being relevant as that is the smallest window that we
+ // can detect. Anything smaller rounds to zero.
+ static const unsigned int kTinyWindowMs = 5;
+ ResetAntiReplay(static_cast<PRTime>(kTinyWindowMs * PR_USEC_PER_MSEC));
+ server_->SetAntiReplayContext(anti_replay_);
+
+ ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET);
+ ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3);
+ server_->Set0RttEnabled(true);
+ Handshake();
+ CheckConnected();
+ SendReceive(); // Absorb a session ticket.
+ CheckKeys();
+
+ // Clear the first window.
+ PR_Sleep(PR_MillisecondsToInterval(kTinyWindowMs));
+
+ Reset();
+ client_->EnsureTlsSetup();
+ server_->EnsureTlsSetup();
+ client_->StartConnect();
+ server_->StartConnect();
+
+ // Early data is rejected by the server only if time passes for it as well.
+ client_->Set0RttEnabled(true);
+ server_->Set0RttEnabled(true);
+ ExpectResumption(RESUME_TICKET);
+ ZeroRttSendReceive(true, false, []() {
+ // Sleep long enough that we minimize the risk of our RTT estimation being
+ // duped by stutters in test execution. This is very long to allow for
+ // flaky and low-end hardware, especially what our CI runs on.
+ PR_Sleep(PR_MillisecondsToInterval(1000));
+ return true;
+ });
+ Handshake();
+ ExpectEarlyDataAccepted(false);
+ CheckConnected();
+}
+
+// Test that SSL_CreateAntiReplayContext doesn't pass bad inputs.
+TEST_F(TlsConnectStreamTls13, BadAntiReplayArgs) {
+ SSLAntiReplayContext* p;
+ // Zero or negative window.
+ EXPECT_EQ(SECFailure, SSL_CreateAntiReplayContext(0, -1, 1, 1, &p));
+ EXPECT_EQ(SEC_ERROR_INVALID_ARGS, PORT_GetError());
+ EXPECT_EQ(SECFailure, SSL_CreateAntiReplayContext(0, 0, 1, 1, &p));
+ EXPECT_EQ(SEC_ERROR_INVALID_ARGS, PORT_GetError());
+ // Zero k.
+ EXPECT_EQ(SECFailure, SSL_CreateAntiReplayContext(0, 1, 0, 1, &p));
+ EXPECT_EQ(SEC_ERROR_INVALID_ARGS, PORT_GetError());
+ // Zero bits.
+ EXPECT_EQ(SECFailure, SSL_CreateAntiReplayContext(0, 1, 1, 0, &p));
+ EXPECT_EQ(SEC_ERROR_INVALID_ARGS, PORT_GetError());
+ EXPECT_EQ(SECFailure, SSL_CreateAntiReplayContext(0, 1, 1, 1, nullptr));
+ EXPECT_EQ(SEC_ERROR_INVALID_ARGS, PORT_GetError());
+
+ // Prove that these parameters do work, even if they are useless..
+ EXPECT_EQ(SECSuccess, SSL_CreateAntiReplayContext(0, 1, 1, 1, &p));
+ ASSERT_NE(nullptr, p);
+ ScopedSSLAntiReplayContext ctx(p);
+
+ // The socket isn't a client or server until later, so configuring a client
+ // should work OK.
+ client_->EnsureTlsSetup();
+ EXPECT_EQ(SECSuccess, SSL_SetAntiReplayContext(client_->ssl_fd(), ctx.get()));
+ EXPECT_EQ(SECSuccess, SSL_SetAntiReplayContext(client_->ssl_fd(), nullptr));
+}
+
+// See also TlsConnectGenericResumption.ResumeServerIncompatibleCipher
+TEST_P(TlsConnectTls13, ZeroRttDifferentCompatibleCipher) {
+ EnsureTlsSetup();
+ server_->EnableSingleCipher(TLS_AES_128_GCM_SHA256);
+ SetupForZeroRtt();
+ client_->Set0RttEnabled(true);
+ server_->Set0RttEnabled(true);
+ // Change the ciphersuite. Resumption is OK because the hash is the same, but
+ // early data will be rejected.
+ server_->EnableSingleCipher(TLS_CHACHA20_POLY1305_SHA256);
+ ExpectResumption(RESUME_TICKET);
+
+ StartConnect();
+ ZeroRttSendReceive(true, false);
+
+ Handshake();
+ ExpectEarlyDataAccepted(false);
+ CheckConnected();
+ SendReceive();
+}
+
+// See also TlsConnectGenericResumption.ResumeServerIncompatibleCipher
+TEST_P(TlsConnectTls13, ZeroRttDifferentIncompatibleCipher) {
+ EnsureTlsSetup();
+ server_->EnableSingleCipher(TLS_AES_256_GCM_SHA384);
+ SetupForZeroRtt();
+ client_->Set0RttEnabled(true);
+ server_->Set0RttEnabled(true);
+ // Resumption is rejected because the hash is different.
+ server_->EnableSingleCipher(TLS_CHACHA20_POLY1305_SHA256);
+ ExpectResumption(RESUME_NONE);
+
+ StartConnect();
+ ZeroRttSendReceive(true, false);
+
+ Handshake();
+ ExpectEarlyDataAccepted(false);
+ CheckConnected();
+ SendReceive();
+}
+
+// The client failing to provide EndOfEarlyData results in failure.
+// After 0-RTT working perfectly, things fall apart later.
+// The server is unable to detect the change in keys, so it fails decryption.
+// The client thinks everything has worked until it gets the alert.
+TEST_F(TlsConnectStreamTls13, SuppressEndOfEarlyDataClientOnly) {
+ SetupForZeroRtt();
+ client_->Set0RttEnabled(true);
+ server_->Set0RttEnabled(true);
+ client_->SetOption(SSL_SUPPRESS_END_OF_EARLY_DATA, true);
+ ExpectResumption(RESUME_TICKET);
+ ZeroRttSendReceive(true, true);
+ ExpectAlert(server_, kTlsAlertBadRecordMac);
+ Handshake();
+ EXPECT_EQ(TlsAgent::STATE_CONNECTED, client_->state());
+ EXPECT_EQ(TlsAgent::STATE_ERROR, server_->state());
+ server_->CheckErrorCode(SSL_ERROR_BAD_MAC_READ);
+ client_->Handshake();
+ EXPECT_EQ(TlsAgent::STATE_ERROR, client_->state());
+ client_->CheckErrorCode(SSL_ERROR_BAD_MAC_ALERT);
+}
+
+TEST_P(TlsConnectGeneric, SuppressEndOfEarlyDataNoZeroRtt) {
+ EnsureTlsSetup();
+ client_->SetOption(SSL_SUPPRESS_END_OF_EARLY_DATA, true);
+ server_->SetOption(SSL_SUPPRESS_END_OF_EARLY_DATA, true);
+ Connect();
+ SendReceive();
+}
+
+#ifndef NSS_DISABLE_TLS_1_3
+INSTANTIATE_TEST_SUITE_P(Tls13ZeroRttReplayTest, TlsZeroRttReplayTest,
+ TlsConnectTestBase::kTlsVariantsAll);
+#endif
+
+} // namespace nss_test
diff --git a/security/nss/gtests/ssl_gtest/ssl_aead_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_aead_unittest.cc
new file mode 100644
index 0000000000..d94683be30
--- /dev/null
+++ b/security/nss/gtests/ssl_gtest/ssl_aead_unittest.cc
@@ -0,0 +1,218 @@
+/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
+/* vim: set ts=2 et sw=2 tw=80: */
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this file,
+ * You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+#include <memory>
+
+#include "keyhi.h"
+#include "pk11pub.h"
+#include "secerr.h"
+#include "ssl.h"
+#include "sslerr.h"
+#include "sslexp.h"
+#include "sslproto.h"
+
+#include "gtest_utils.h"
+#include "nss_scoped_ptrs.h"
+#include "scoped_ptrs_ssl.h"
+#include "tls_connect.h"
+
+namespace nss_test {
+
+// From tls_hkdf_unittest.cc:
+extern size_t GetHashLength(SSLHashType ht);
+
+class AeadTest : public ::testing::Test {
+ public:
+ AeadTest() : slot_(PK11_GetInternalSlot()) {}
+
+ void InitSecret(SSLHashType hash_type) {
+ static const uint8_t kData[64] = {'s', 'e', 'c', 'r', 'e', 't'};
+ SECItem key_item = {siBuffer, const_cast<uint8_t *>(kData),
+ static_cast<unsigned int>(GetHashLength(hash_type))};
+ PK11SymKey *s =
+ PK11_ImportSymKey(slot_.get(), CKM_SSL3_MASTER_KEY_DERIVE,
+ PK11_OriginUnwrap, CKA_DERIVE, &key_item, NULL);
+ ASSERT_NE(nullptr, s);
+ secret_.reset(s);
+ }
+
+ void SetUp() override {
+ InitSecret(ssl_hash_sha256);
+ PORT_SetError(0);
+ }
+
+ protected:
+ static void EncryptDecrypt(const ScopedSSLAeadContext &ctx,
+ const uint8_t *ciphertext, size_t ciphertext_len) {
+ static const uint8_t kAad[] = {'a', 'a', 'd'};
+ static const uint8_t kPlaintext[] = {'t', 'e', 'x', 't'};
+ static const size_t kMaxSize = 32;
+
+ ASSERT_GE(kMaxSize, ciphertext_len);
+ ASSERT_LT(0U, ciphertext_len);
+
+ uint8_t output[kMaxSize] = {0};
+ unsigned int output_len = 0;
+ EXPECT_EQ(SECSuccess, SSL_AeadEncrypt(ctx.get(), 0, kAad, sizeof(kAad),
+ kPlaintext, sizeof(kPlaintext),
+ output, &output_len, sizeof(output)));
+ ASSERT_EQ(ciphertext_len, static_cast<size_t>(output_len));
+ EXPECT_EQ(0, memcmp(ciphertext, output, ciphertext_len));
+
+ memset(output, 0, sizeof(output));
+ EXPECT_EQ(SECSuccess, SSL_AeadDecrypt(ctx.get(), 0, kAad, sizeof(kAad),
+ ciphertext, ciphertext_len, output,
+ &output_len, sizeof(output)));
+ ASSERT_EQ(sizeof(kPlaintext), static_cast<size_t>(output_len));
+ EXPECT_EQ(0, memcmp(kPlaintext, output, sizeof(kPlaintext)));
+
+ // Now for some tests of decryption failure.
+ // Truncate the input.
+ EXPECT_EQ(SECFailure, SSL_AeadDecrypt(ctx.get(), 0, kAad, sizeof(kAad),
+ ciphertext, ciphertext_len - 1,
+ output, &output_len, sizeof(output)));
+ EXPECT_EQ(SEC_ERROR_BAD_DATA, PORT_GetError());
+
+ // Skip the first byte of the AAD.
+ EXPECT_EQ(
+ SECFailure,
+ SSL_AeadDecrypt(ctx.get(), 0, kAad + 1, sizeof(kAad) - 1, ciphertext,
+ ciphertext_len, output, &output_len, sizeof(output)));
+ EXPECT_EQ(SEC_ERROR_BAD_DATA, PORT_GetError());
+
+ uint8_t input[kMaxSize] = {0};
+ // Toggle a byte of the input.
+ memcpy(input, ciphertext, ciphertext_len);
+ input[0] ^= 9;
+ EXPECT_EQ(SECFailure, SSL_AeadDecrypt(ctx.get(), 0, kAad, sizeof(kAad),
+ input, ciphertext_len, output,
+ &output_len, sizeof(output)));
+ EXPECT_EQ(SEC_ERROR_BAD_DATA, PORT_GetError());
+
+ // Toggle the last byte (the auth tag).
+ memcpy(input, ciphertext, ciphertext_len);
+ input[ciphertext_len - 1] ^= 77;
+ EXPECT_EQ(SECFailure, SSL_AeadDecrypt(ctx.get(), 0, kAad, sizeof(kAad),
+ input, ciphertext_len, output,
+ &output_len, sizeof(output)));
+ EXPECT_EQ(SEC_ERROR_BAD_DATA, PORT_GetError());
+
+ // Toggle some of the AAD.
+ memcpy(input, kAad, sizeof(kAad));
+ input[1] ^= 23;
+ EXPECT_EQ(SECFailure, SSL_AeadDecrypt(ctx.get(), 0, input, sizeof(kAad),
+ ciphertext, ciphertext_len, output,
+ &output_len, sizeof(output)));
+ EXPECT_EQ(SEC_ERROR_BAD_DATA, PORT_GetError());
+ }
+
+ protected:
+ ScopedPK11SymKey secret_;
+
+ private:
+ ScopedPK11SlotInfo slot_;
+};
+
+// These tests all use fixed inputs: a fixed secret, a fixed label, and fixed
+// inputs. So they have fixed outputs.
+static const char *kLabel = "test ";
+static const uint8_t kCiphertextAes128Gcm[] = {
+ 0x11, 0x14, 0xfc, 0x58, 0x4f, 0x44, 0xff, 0x8c, 0xb6, 0xd8,
+ 0x20, 0xb3, 0xfb, 0x50, 0xd9, 0x3b, 0xd4, 0xc6, 0xe1, 0x14};
+static const uint8_t kCiphertextAes256Gcm[] = {
+ 0xf7, 0x27, 0x35, 0x80, 0x88, 0xaf, 0x99, 0x85, 0xf2, 0x83,
+ 0xca, 0xbb, 0x95, 0x42, 0x09, 0x3f, 0x9c, 0xf3, 0x29, 0xf0};
+static const uint8_t kCiphertextChaCha20Poly1305[] = {
+ 0x4e, 0x89, 0x2c, 0xfa, 0xfc, 0x8c, 0x40, 0x55, 0x6d, 0x7e,
+ 0x99, 0xac, 0x8e, 0x54, 0x58, 0xb1, 0x18, 0xd2, 0x66, 0x22};
+
+TEST_F(AeadTest, AeadBadVersion) {
+ SSLAeadContext *ctx = nullptr;
+ ASSERT_EQ(SECFailure,
+ SSL_MakeAead(SSL_LIBRARY_VERSION_TLS_1_2, TLS_AES_128_GCM_SHA256,
+ secret_.get(), kLabel, strlen(kLabel), &ctx));
+ EXPECT_EQ(nullptr, ctx);
+}
+
+TEST_F(AeadTest, AeadUnsupportedCipher) {
+ SSLAeadContext *ctx = nullptr;
+ ASSERT_EQ(SECFailure,
+ SSL_MakeAead(SSL_LIBRARY_VERSION_TLS_1_3, TLS_RSA_WITH_NULL_MD5,
+ secret_.get(), kLabel, strlen(kLabel), &ctx));
+ EXPECT_EQ(nullptr, ctx);
+}
+
+TEST_F(AeadTest, AeadOlderCipher) {
+ SSLAeadContext *ctx = nullptr;
+ ASSERT_EQ(
+ SECFailure,
+ SSL_MakeAead(SSL_LIBRARY_VERSION_TLS_1_3, TLS_RSA_WITH_AES_128_CBC_SHA,
+ secret_.get(), kLabel, strlen(kLabel), &ctx));
+ EXPECT_EQ(nullptr, ctx);
+}
+
+TEST_F(AeadTest, AeadNoLabel) {
+ SSLAeadContext *ctx = nullptr;
+ ASSERT_EQ(SECFailure,
+ SSL_MakeAead(SSL_LIBRARY_VERSION_TLS_1_3, TLS_AES_128_GCM_SHA256,
+ secret_.get(), nullptr, 12, &ctx));
+ EXPECT_EQ(nullptr, ctx);
+}
+
+TEST_F(AeadTest, AeadLongLabel) {
+ SSLAeadContext *ctx = nullptr;
+ ASSERT_EQ(SECFailure,
+ SSL_MakeAead(SSL_LIBRARY_VERSION_TLS_1_3, TLS_AES_128_GCM_SHA256,
+ secret_.get(), "", 254, &ctx));
+ EXPECT_EQ(SEC_ERROR_INVALID_ARGS, PORT_GetError());
+ EXPECT_EQ(nullptr, ctx);
+}
+
+TEST_F(AeadTest, AeadNoPointer) {
+ SSLAeadContext *ctx = nullptr;
+ ASSERT_EQ(SECFailure,
+ SSL_MakeAead(SSL_LIBRARY_VERSION_TLS_1_3, TLS_AES_128_GCM_SHA256,
+ secret_.get(), kLabel, strlen(kLabel), nullptr));
+ EXPECT_EQ(SEC_ERROR_INVALID_ARGS, PORT_GetError());
+ EXPECT_EQ(nullptr, ctx);
+}
+
+TEST_F(AeadTest, AeadAes128Gcm) {
+ SSLAeadContext *ctxInit = nullptr;
+ ASSERT_EQ(SECSuccess,
+ SSL_MakeAead(SSL_LIBRARY_VERSION_TLS_1_3, TLS_AES_128_GCM_SHA256,
+ secret_.get(), kLabel, strlen(kLabel), &ctxInit));
+ ScopedSSLAeadContext ctx(ctxInit);
+ EXPECT_NE(nullptr, ctx);
+
+ EncryptDecrypt(ctx, kCiphertextAes128Gcm, sizeof(kCiphertextAes128Gcm));
+}
+
+TEST_F(AeadTest, AeadAes256Gcm) {
+ SSLAeadContext *ctxInit = nullptr;
+ ASSERT_EQ(SECSuccess,
+ SSL_MakeAead(SSL_LIBRARY_VERSION_TLS_1_3, TLS_AES_256_GCM_SHA384,
+ secret_.get(), kLabel, strlen(kLabel), &ctxInit));
+ ScopedSSLAeadContext ctx(ctxInit);
+ EXPECT_NE(nullptr, ctx);
+
+ EncryptDecrypt(ctx, kCiphertextAes256Gcm, sizeof(kCiphertextAes256Gcm));
+}
+
+TEST_F(AeadTest, AeadChaCha20Poly1305) {
+ SSLAeadContext *ctxInit = nullptr;
+ ASSERT_EQ(
+ SECSuccess,
+ SSL_MakeAead(SSL_LIBRARY_VERSION_TLS_1_3, TLS_CHACHA20_POLY1305_SHA256,
+ secret_.get(), kLabel, strlen(kLabel), &ctxInit));
+ ScopedSSLAeadContext ctx(ctxInit);
+ EXPECT_NE(nullptr, ctx);
+
+ EncryptDecrypt(ctx, kCiphertextChaCha20Poly1305,
+ sizeof(kCiphertextChaCha20Poly1305));
+}
+
+} // namespace nss_test
diff --git a/security/nss/gtests/ssl_gtest/ssl_agent_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_agent_unittest.cc
new file mode 100644
index 0000000000..283bfec169
--- /dev/null
+++ b/security/nss/gtests/ssl_gtest/ssl_agent_unittest.cc
@@ -0,0 +1,235 @@
+/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
+/* vim: set ts=2 et sw=2 tw=80: */
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this file,
+ * You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+#include "ssl.h"
+#include "sslerr.h"
+#include "sslproto.h"
+
+#include <memory>
+
+#include "databuffer.h"
+#include "tls_agent.h"
+#include "tls_connect.h"
+#include "tls_filter.h"
+#include "tls_parser.h"
+
+namespace nss_test {
+
+// This is a 1-RTT ClientHello with ECDHE.
+const static uint8_t kCannedTls13ClientHello[] = {
+ 0x01, 0x00, 0x00, 0xcf, 0x03, 0x03, 0x6c, 0xb3, 0x46, 0x81, 0xc8, 0x1a,
+ 0xf9, 0xd2, 0x05, 0x97, 0x48, 0x7c, 0xa8, 0x31, 0x03, 0x1c, 0x06, 0xa8,
+ 0x62, 0xb1, 0x90, 0xd6, 0x21, 0x44, 0x7f, 0xc1, 0x9b, 0x87, 0x3e, 0xad,
+ 0x91, 0x85, 0x00, 0x00, 0x06, 0x13, 0x01, 0x13, 0x03, 0x13, 0x02, 0x01,
+ 0x00, 0x00, 0xa0, 0x00, 0x00, 0x00, 0x0b, 0x00, 0x09, 0x00, 0x00, 0x06,
+ 0x73, 0x65, 0x72, 0x76, 0x65, 0x72, 0xff, 0x01, 0x00, 0x01, 0x00, 0x00,
+ 0x0a, 0x00, 0x12, 0x00, 0x10, 0x00, 0x17, 0x00, 0x18, 0x00, 0x19, 0x01,
+ 0x00, 0x01, 0x01, 0x01, 0x02, 0x01, 0x03, 0x01, 0x04, 0x00, 0x33, 0x00,
+ 0x47, 0x00, 0x45, 0x00, 0x17, 0x00, 0x41, 0x04, 0x86, 0x4a, 0xb9, 0xdc,
+ 0x6a, 0x38, 0xa7, 0xce, 0xe7, 0xc2, 0x4f, 0xa6, 0x28, 0xb9, 0xdc, 0x65,
+ 0xbf, 0x73, 0x47, 0x3c, 0x9c, 0x65, 0x8c, 0x47, 0x6d, 0x57, 0x22, 0x8a,
+ 0xc2, 0xb3, 0xc6, 0x80, 0x72, 0x86, 0x08, 0x86, 0x8f, 0x52, 0xc5, 0xcb,
+ 0xbf, 0x2a, 0xb5, 0x59, 0x64, 0xcc, 0x0c, 0x49, 0x95, 0x36, 0xe4, 0xd9,
+ 0x2f, 0xd4, 0x24, 0x66, 0x71, 0x6f, 0x5d, 0x70, 0xe2, 0xa0, 0xea, 0x26,
+ 0x00, 0x2b, 0x00, 0x03, 0x02, 0x03, 0x04, 0x00, 0x0d, 0x00, 0x20, 0x00,
+ 0x1e, 0x04, 0x03, 0x05, 0x03, 0x06, 0x03, 0x02, 0x03, 0x08, 0x04, 0x08,
+ 0x05, 0x08, 0x06, 0x04, 0x01, 0x05, 0x01, 0x06, 0x01, 0x02, 0x01, 0x04,
+ 0x02, 0x05, 0x02, 0x06, 0x02, 0x02, 0x02};
+static const size_t kFirstFragmentSize = 20;
+static const char *k0RttData = "ABCDEF";
+
+TEST_P(TlsAgentTest, EarlyFinished) {
+ DataBuffer buffer;
+ MakeTrivialHandshakeRecord(kTlsHandshakeFinished, 0, &buffer);
+ ExpectAlert(kTlsAlertUnexpectedMessage);
+ ProcessMessage(buffer, TlsAgent::STATE_ERROR,
+ SSL_ERROR_RX_UNEXPECTED_FINISHED);
+}
+
+TEST_P(TlsAgentTest, EarlyCertificateVerify) {
+ DataBuffer buffer;
+ MakeTrivialHandshakeRecord(kTlsHandshakeCertificateVerify, 0, &buffer);
+ ExpectAlert(kTlsAlertUnexpectedMessage);
+ ProcessMessage(buffer, TlsAgent::STATE_ERROR,
+ SSL_ERROR_RX_UNEXPECTED_CERT_VERIFY);
+}
+
+TEST_P(TlsAgentTestClient13, CannedHello) {
+ DataBuffer buffer;
+ EnsureInit();
+ DataBuffer server_hello;
+ auto sh = MakeCannedTls13ServerHello();
+ MakeHandshakeMessage(kTlsHandshakeServerHello, sh.data(), sh.len(),
+ &server_hello);
+ MakeRecord(ssl_ct_handshake, SSL_LIBRARY_VERSION_TLS_1_3, server_hello.data(),
+ server_hello.len(), &buffer);
+ ProcessMessage(buffer, TlsAgent::STATE_CONNECTING);
+}
+
+TEST_P(TlsAgentTestClient13, EncryptedExtensionsInClear) {
+ DataBuffer server_hello;
+ auto sh = MakeCannedTls13ServerHello();
+ MakeHandshakeMessage(kTlsHandshakeServerHello, sh.data(), sh.len(),
+ &server_hello);
+ DataBuffer encrypted_extensions;
+ MakeHandshakeMessage(kTlsHandshakeEncryptedExtensions, nullptr, 0,
+ &encrypted_extensions, 1);
+ server_hello.Append(encrypted_extensions);
+ DataBuffer buffer;
+ MakeRecord(ssl_ct_handshake, SSL_LIBRARY_VERSION_TLS_1_3, server_hello.data(),
+ server_hello.len(), &buffer);
+ EnsureInit();
+ ExpectAlert(kTlsAlertUnexpectedMessage);
+ ProcessMessage(buffer, TlsAgent::STATE_ERROR,
+ SSL_ERROR_RX_UNEXPECTED_HANDSHAKE);
+}
+
+TEST_F(TlsAgentStreamTestClient, EncryptedExtensionsInClearTwoPieces) {
+ DataBuffer server_hello;
+ auto sh = MakeCannedTls13ServerHello();
+ MakeHandshakeMessage(kTlsHandshakeServerHello, sh.data(), sh.len(),
+ &server_hello);
+ DataBuffer encrypted_extensions;
+ MakeHandshakeMessage(kTlsHandshakeEncryptedExtensions, nullptr, 0,
+ &encrypted_extensions, 1);
+ server_hello.Append(encrypted_extensions);
+ DataBuffer buffer;
+ MakeRecord(ssl_ct_handshake, SSL_LIBRARY_VERSION_TLS_1_3, server_hello.data(),
+ kFirstFragmentSize, &buffer);
+
+ DataBuffer buffer2;
+ MakeRecord(ssl_ct_handshake, SSL_LIBRARY_VERSION_TLS_1_3,
+ server_hello.data() + kFirstFragmentSize,
+ server_hello.len() - kFirstFragmentSize, &buffer2);
+
+ EnsureInit();
+ agent_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_3,
+ SSL_LIBRARY_VERSION_TLS_1_3);
+ ProcessMessage(buffer, TlsAgent::STATE_CONNECTING);
+ ExpectAlert(kTlsAlertUnexpectedMessage);
+ ProcessMessage(buffer2, TlsAgent::STATE_ERROR,
+ SSL_ERROR_RX_UNEXPECTED_HANDSHAKE);
+}
+
+TEST_F(TlsAgentDgramTestClient, EncryptedExtensionsInClearTwoPieces) {
+ auto sh = MakeCannedTls13ServerHello();
+ DataBuffer server_hello_frag1;
+ MakeHandshakeMessageFragment(kTlsHandshakeServerHello, sh.data(), sh.len(),
+ &server_hello_frag1, 0, 0, kFirstFragmentSize);
+ DataBuffer server_hello_frag2;
+ MakeHandshakeMessageFragment(kTlsHandshakeServerHello,
+ sh.data() + kFirstFragmentSize, sh.len(),
+ &server_hello_frag2, 0, kFirstFragmentSize,
+ sh.len() - kFirstFragmentSize);
+ DataBuffer encrypted_extensions;
+ MakeHandshakeMessage(kTlsHandshakeEncryptedExtensions, nullptr, 0,
+ &encrypted_extensions, 1);
+ server_hello_frag2.Append(encrypted_extensions);
+ DataBuffer buffer;
+ MakeRecord(ssl_ct_handshake, SSL_LIBRARY_VERSION_TLS_1_3,
+ server_hello_frag1.data(), server_hello_frag1.len(), &buffer);
+
+ DataBuffer buffer2;
+ MakeRecord(ssl_ct_handshake, SSL_LIBRARY_VERSION_TLS_1_3,
+ server_hello_frag2.data(), server_hello_frag2.len(), &buffer2, 1);
+
+ EnsureInit();
+ agent_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_3,
+ SSL_LIBRARY_VERSION_TLS_1_3);
+ ProcessMessage(buffer, TlsAgent::STATE_CONNECTING);
+ ExpectAlert(kTlsAlertUnexpectedMessage);
+ ProcessMessage(buffer2, TlsAgent::STATE_ERROR,
+ SSL_ERROR_RX_UNEXPECTED_HANDSHAKE);
+}
+
+TEST_F(TlsAgentDgramTestClient, AckWithBogusLengthField) {
+ EnsureInit();
+ // Length doesn't match
+ const uint8_t ackBuf[] = {0x00, 0x08, 0x00};
+ DataBuffer record;
+ MakeRecord(variant_, ssl_ct_ack, SSL_LIBRARY_VERSION_TLS_1_2, ackBuf,
+ sizeof(ackBuf), &record, 0);
+ agent_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_3,
+ SSL_LIBRARY_VERSION_TLS_1_3);
+ ProcessMessage(record, TlsAgent::STATE_ERROR,
+ SSL_ERROR_RX_MALFORMED_DTLS_ACK);
+}
+
+TEST_F(TlsAgentDgramTestClient, AckWithNonEvenLength) {
+ EnsureInit();
+ // Length isn't a multiple of 8
+ const uint8_t ackBuf[] = {0x00, 0x01, 0x00};
+ DataBuffer record;
+ MakeRecord(variant_, ssl_ct_ack, SSL_LIBRARY_VERSION_TLS_1_2, ackBuf,
+ sizeof(ackBuf), &record, 0);
+ agent_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_3,
+ SSL_LIBRARY_VERSION_TLS_1_3);
+ // Because we haven't negotiated the version,
+ // ssl3_DecodeError() sends an older (pre-TLS error).
+ ExpectAlert(kTlsAlertIllegalParameter);
+ ProcessMessage(record, TlsAgent::STATE_ERROR,
+ SSL_ERROR_RX_MALFORMED_DTLS_ACK);
+}
+
+TEST_F(TlsAgentStreamTestClient, Set0RttOptionThenWrite) {
+ EnsureInit();
+ agent_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_1,
+ SSL_LIBRARY_VERSION_TLS_1_3);
+ agent_->StartConnect();
+ agent_->Set0RttEnabled(true);
+ auto filter =
+ MakeTlsFilter<TlsHandshakeRecorder>(agent_, kTlsHandshakeClientHello);
+ PRInt32 rv = PR_Write(agent_->ssl_fd(), k0RttData, strlen(k0RttData));
+ EXPECT_EQ(-1, rv);
+ int32_t err = PORT_GetError();
+ EXPECT_EQ(PR_WOULD_BLOCK_ERROR, err);
+ EXPECT_LT(0UL, filter->buffer().len());
+}
+
+TEST_F(TlsAgentStreamTestClient, Set0RttOptionThenRead) {
+ EnsureInit();
+ agent_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_1,
+ SSL_LIBRARY_VERSION_TLS_1_3);
+ agent_->StartConnect();
+ agent_->Set0RttEnabled(true);
+ DataBuffer buffer;
+ MakeRecord(ssl_ct_application_data, SSL_LIBRARY_VERSION_TLS_1_3,
+ reinterpret_cast<const uint8_t *>(k0RttData), strlen(k0RttData),
+ &buffer);
+ ExpectAlert(kTlsAlertUnexpectedMessage);
+ ProcessMessage(buffer, TlsAgent::STATE_ERROR,
+ SSL_ERROR_RX_UNEXPECTED_APPLICATION_DATA);
+}
+
+// The server is allowing 0-RTT but the client doesn't offer it,
+// so trial decryption isn't engaged and 0-RTT messages cause
+// an error.
+TEST_F(TlsAgentStreamTestServer, Set0RttOptionClientHelloThenRead) {
+ EnsureInit();
+ agent_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_1,
+ SSL_LIBRARY_VERSION_TLS_1_3);
+ agent_->StartConnect();
+ agent_->Set0RttEnabled(true);
+ DataBuffer buffer;
+ MakeRecord(ssl_ct_handshake, SSL_LIBRARY_VERSION_TLS_1_3,
+ kCannedTls13ClientHello, sizeof(kCannedTls13ClientHello), &buffer);
+ ProcessMessage(buffer, TlsAgent::STATE_CONNECTING);
+ MakeRecord(ssl_ct_application_data, SSL_LIBRARY_VERSION_TLS_1_3,
+ reinterpret_cast<const uint8_t *>(k0RttData), strlen(k0RttData),
+ &buffer);
+ ExpectAlert(kTlsAlertBadRecordMac);
+ ProcessMessage(buffer, TlsAgent::STATE_ERROR, SSL_ERROR_BAD_MAC_READ);
+}
+
+INSTANTIATE_TEST_SUITE_P(
+ AgentTests, TlsAgentTest,
+ ::testing::Combine(TlsAgentTestBase::kTlsRolesAll,
+ TlsConnectTestBase::kTlsVariantsStream,
+ TlsConnectTestBase::kTlsVAll));
+INSTANTIATE_TEST_SUITE_P(ClientTests13, TlsAgentTestClient13,
+ ::testing::Combine(TlsConnectTestBase::kTlsVariantsAll,
+ TlsConnectTestBase::kTlsV13));
+} // namespace nss_test
diff --git a/security/nss/gtests/ssl_gtest/ssl_auth_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_auth_unittest.cc
new file mode 100644
index 0000000000..edd2479a78
--- /dev/null
+++ b/security/nss/gtests/ssl_gtest/ssl_auth_unittest.cc
@@ -0,0 +1,2261 @@
+/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
+/* vim: set ts=2 et sw=2 tw=80: */
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this file,
+ * You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+#include "secerr.h"
+#include "ssl.h"
+#include "sslerr.h"
+#include "sslproto.h"
+
+extern "C" {
+// This is not something that should make you happy.
+#include "libssl_internals.h"
+}
+
+#include "gtest_utils.h"
+#include "nss_scoped_ptrs.h"
+#include "tls_connect.h"
+#include "tls_filter.h"
+#include "tls_parser.h"
+
+namespace nss_test {
+
+TEST_P(TlsConnectGeneric, ServerAuthBigRsa) {
+ Reset(TlsAgent::kRsa2048);
+ Connect();
+ CheckKeys();
+}
+
+TEST_P(TlsConnectGeneric, ServerAuthRsaChain) {
+ Reset("rsa_chain");
+ Connect();
+ CheckKeys();
+ size_t chain_length;
+ EXPECT_TRUE(client_->GetPeerChainLength(&chain_length));
+ EXPECT_EQ(2UL, chain_length);
+}
+
+TEST_P(TlsConnectTls12Plus, ServerAuthRsaPss) {
+ static const SSLSignatureScheme kSignatureSchemePss[] = {
+ ssl_sig_rsa_pss_pss_sha256};
+
+ Reset(TlsAgent::kServerRsaPss);
+ client_->SetSignatureSchemes(kSignatureSchemePss,
+ PR_ARRAY_SIZE(kSignatureSchemePss));
+ server_->SetSignatureSchemes(kSignatureSchemePss,
+ PR_ARRAY_SIZE(kSignatureSchemePss));
+ Connect();
+ CheckKeys(ssl_kea_ecdh, ssl_grp_ec_curve25519, ssl_auth_rsa_pss,
+ ssl_sig_rsa_pss_pss_sha256);
+}
+
+// PSS doesn't work with TLS 1.0 or 1.1 because we can't signal it.
+TEST_P(TlsConnectPre12, ServerAuthRsaPssFails) {
+ static const SSLSignatureScheme kSignatureSchemePss[] = {
+ ssl_sig_rsa_pss_pss_sha256};
+
+ Reset(TlsAgent::kServerRsaPss);
+ client_->SetSignatureSchemes(kSignatureSchemePss,
+ PR_ARRAY_SIZE(kSignatureSchemePss));
+ server_->SetSignatureSchemes(kSignatureSchemePss,
+ PR_ARRAY_SIZE(kSignatureSchemePss));
+ ConnectExpectAlert(server_, kTlsAlertHandshakeFailure);
+ server_->CheckErrorCode(SSL_ERROR_NO_CYPHER_OVERLAP);
+ client_->CheckErrorCode(SSL_ERROR_NO_CYPHER_OVERLAP);
+}
+
+// Check that a PSS certificate with no parameters works.
+TEST_P(TlsConnectTls12Plus, ServerAuthRsaPssNoParameters) {
+ static const SSLSignatureScheme kSignatureSchemePss[] = {
+ ssl_sig_rsa_pss_pss_sha256};
+
+ Reset("rsa_pss_noparam");
+ client_->SetSignatureSchemes(kSignatureSchemePss,
+ PR_ARRAY_SIZE(kSignatureSchemePss));
+ server_->SetSignatureSchemes(kSignatureSchemePss,
+ PR_ARRAY_SIZE(kSignatureSchemePss));
+ Connect();
+ CheckKeys(ssl_kea_ecdh, ssl_grp_ec_curve25519, ssl_auth_rsa_pss,
+ ssl_sig_rsa_pss_pss_sha256);
+}
+
+TEST_P(TlsConnectGeneric, ServerAuthRsaPssChain) {
+ Reset("rsa_pss_chain");
+ Connect();
+ CheckKeys();
+ size_t chain_length;
+ EXPECT_TRUE(client_->GetPeerChainLength(&chain_length));
+ EXPECT_EQ(2UL, chain_length);
+}
+
+TEST_P(TlsConnectGeneric, ServerAuthRsaCARsaPssChain) {
+ Reset("rsa_ca_rsa_pss_chain");
+ Connect();
+ CheckKeys();
+ size_t chain_length;
+ EXPECT_TRUE(client_->GetPeerChainLength(&chain_length));
+ EXPECT_EQ(2UL, chain_length);
+}
+
+TEST_P(TlsConnectGeneric, ServerAuthRejected) {
+ EnsureTlsSetup();
+ client_->SetAuthCertificateCallback(
+ [](TlsAgent*, PRBool, PRBool) -> SECStatus { return SECFailure; });
+ ConnectExpectAlert(client_, kTlsAlertBadCertificate);
+ client_->CheckErrorCode(SSL_ERROR_BAD_CERTIFICATE);
+ server_->CheckErrorCode(SSL_ERROR_BAD_CERT_ALERT);
+ EXPECT_EQ(TlsAgent::STATE_ERROR, client_->state());
+}
+
+struct AuthCompleteArgs : public PollTarget {
+ AuthCompleteArgs(const std::shared_ptr<TlsAgent>& a, PRErrorCode c)
+ : agent(a), code(c) {}
+
+ std::shared_ptr<TlsAgent> agent;
+ PRErrorCode code;
+};
+
+static void CallAuthComplete(PollTarget* target, Event event) {
+ EXPECT_EQ(TIMER_EVENT, event);
+ auto args = reinterpret_cast<AuthCompleteArgs*>(target);
+ std::cerr << args->agent->role_str() << ": call SSL_AuthCertificateComplete "
+ << (args->code ? PR_ErrorToName(args->code) : "no error")
+ << std::endl;
+ EXPECT_EQ(SECSuccess,
+ SSL_AuthCertificateComplete(args->agent->ssl_fd(), args->code));
+ args->agent->Handshake(); // Make the TlsAgent aware of the error.
+ delete args;
+}
+
+// Install an AuthCertificateCallback that blocks when called. Then
+// SSL_AuthCertificateComplete is called on a very short timer. This allows any
+// processing that might follow the callback to complete.
+static void SetDeferredAuthCertificateCallback(std::shared_ptr<TlsAgent> agent,
+ PRErrorCode code) {
+ auto args = new AuthCompleteArgs(agent, code);
+ agent->SetAuthCertificateCallback(
+ [args](TlsAgent*, PRBool, PRBool) -> SECStatus {
+ // This can't be 0 or we race the message from the client to the server,
+ // and tests assume that we lose that race.
+ std::shared_ptr<Poller::Timer> timer_handle;
+ Poller::Instance()->SetTimer(1U, args, CallAuthComplete, &timer_handle);
+ return SECWouldBlock;
+ });
+}
+
+TEST_P(TlsConnectTls13, ServerAuthRejectAsync) {
+ SetDeferredAuthCertificateCallback(client_, SEC_ERROR_REVOKED_CERTIFICATE);
+ ConnectExpectAlert(client_, kTlsAlertCertificateRevoked);
+ // We only detect the error here when we attempt to handshake, so all the
+ // client learns is that the handshake has already failed.
+ client_->CheckErrorCode(SSL_ERROR_HANDSHAKE_FAILED);
+ server_->CheckErrorCode(SSL_ERROR_REVOKED_CERT_ALERT);
+}
+
+// In TLS 1.2 and earlier, this will result in the client sending its Finished
+// before learning that the server certificate is bad. That means that the
+// server will believe that the handshake is complete.
+TEST_P(TlsConnectGenericPre13, ServerAuthRejectAsync) {
+ SetDeferredAuthCertificateCallback(client_, SEC_ERROR_EXPIRED_CERTIFICATE);
+ client_->ExpectSendAlert(kTlsAlertCertificateExpired);
+ server_->ExpectReceiveAlert(kTlsAlertCertificateExpired);
+ ConnectExpectFailOneSide(TlsAgent::CLIENT);
+ client_->CheckErrorCode(SSL_ERROR_HANDSHAKE_FAILED);
+
+ // The server might not receive the alert that the client sends, which would
+ // cause the test to fail when it cleans up. Reset expectations.
+ server_->ExpectReceiveAlert(kTlsAlertCloseNotify, kTlsAlertWarning);
+}
+
+class TlsCertificateRequestContextRecorder : public TlsHandshakeFilter {
+ public:
+ TlsCertificateRequestContextRecorder(const std::shared_ptr<TlsAgent>& a,
+ uint8_t handshake_type)
+ : TlsHandshakeFilter(a, {handshake_type}), buffer_(), filtered_(false) {
+ EnableDecryption();
+ }
+
+ bool filtered() const { return filtered_; }
+ const DataBuffer& buffer() const { return buffer_; }
+
+ protected:
+ virtual PacketFilter::Action FilterHandshake(const HandshakeHeader& header,
+ const DataBuffer& input,
+ DataBuffer* output) {
+ assert(1 < input.len());
+ size_t len = input.data()[0];
+ assert(len + 1 < input.len());
+ buffer_.Assign(input.data() + 1, len);
+ filtered_ = true;
+ return KEEP;
+ }
+
+ private:
+ DataBuffer buffer_;
+ bool filtered_;
+};
+
+using ClientAuthParam =
+ std::tuple<SSLProtocolVariant, uint16_t, ClientAuthCallbackType>;
+
+class TlsConnectClientAuth
+ : public TlsConnectTestBase,
+ public testing::WithParamInterface<ClientAuthParam> {
+ public:
+ TlsConnectClientAuth()
+ : TlsConnectTestBase(std::get<0>(GetParam()), std::get<1>(GetParam())) {}
+};
+
+// Wrapper classes for tests that target specific versions
+
+class TlsConnectClientAuth13 : public TlsConnectClientAuth {};
+
+class TlsConnectClientAuth12 : public TlsConnectClientAuth {};
+
+class TlsConnectClientAuthStream13 : public TlsConnectClientAuth {};
+
+class TlsConnectClientAuthPre13 : public TlsConnectClientAuth {};
+
+class TlsConnectClientAuth12Plus : public TlsConnectClientAuth {};
+
+std::string getClientAuthTestName(
+ testing::TestParamInfo<ClientAuthParam> info) {
+ auto param = info.param;
+ auto variant = std::get<0>(param);
+ auto version = std::get<1>(param);
+ auto callback_type = std::get<2>(param);
+
+ std::string output = std::string();
+ switch (variant) {
+ case ssl_variant_stream:
+ output.append("TLS");
+ break;
+ case ssl_variant_datagram:
+ output.append("DTLS");
+ break;
+ }
+ output.append(VersionString(version).replace(1, 1, ""));
+ switch (callback_type) {
+ case ClientAuthCallbackType::kAsyncImmediate:
+ output.append("AsyncImmediate");
+ break;
+ case ClientAuthCallbackType::kAsyncDelay:
+ output.append("AsyncDelay");
+ break;
+ case ClientAuthCallbackType::kSync:
+ output.append("Sync");
+ break;
+ case ClientAuthCallbackType::kNone:
+ output.append("None");
+ break;
+ }
+ return output;
+}
+
+auto kClientAuthCallbacks = testing::Values(
+ ClientAuthCallbackType::kAsyncImmediate,
+ ClientAuthCallbackType::kAsyncDelay, ClientAuthCallbackType::kSync,
+ ClientAuthCallbackType::kNone);
+
+INSTANTIATE_TEST_SUITE_P(
+ ClientAuthGenericStream, TlsConnectClientAuth,
+ testing::Combine(TlsConnectTestBase::kTlsVariantsStream,
+ TlsConnectTestBase::kTlsVAll, kClientAuthCallbacks),
+ getClientAuthTestName);
+
+INSTANTIATE_TEST_SUITE_P(
+ ClientAuthGenericDatagram, TlsConnectClientAuth,
+ testing::Combine(TlsConnectTestBase::kTlsVariantsDatagram,
+ TlsConnectTestBase::kTlsV11Plus, kClientAuthCallbacks),
+ getClientAuthTestName);
+
+INSTANTIATE_TEST_SUITE_P(ClientAuth13, TlsConnectClientAuth13,
+ testing::Combine(TlsConnectTestBase::kTlsVariantsAll,
+ TlsConnectTestBase::kTlsV13,
+ kClientAuthCallbacks),
+ getClientAuthTestName);
+
+INSTANTIATE_TEST_SUITE_P(
+ ClientAuth13, TlsConnectClientAuthStream13,
+ testing::Combine(TlsConnectTestBase::kTlsVariantsStream,
+ TlsConnectTestBase::kTlsV13, kClientAuthCallbacks),
+ getClientAuthTestName);
+
+INSTANTIATE_TEST_SUITE_P(ClientAuth12, TlsConnectClientAuth12,
+ testing::Combine(TlsConnectTestBase::kTlsVariantsAll,
+ TlsConnectTestBase::kTlsV12,
+ kClientAuthCallbacks),
+ getClientAuthTestName);
+
+INSTANTIATE_TEST_SUITE_P(
+ ClientAuthPre13Stream, TlsConnectClientAuthPre13,
+ testing::Combine(TlsConnectTestBase::kTlsVariantsStream,
+ TlsConnectTestBase::kTlsV10ToV12, kClientAuthCallbacks),
+ getClientAuthTestName);
+
+INSTANTIATE_TEST_SUITE_P(
+ ClientAuthPre13Datagram, TlsConnectClientAuthPre13,
+ testing::Combine(TlsConnectTestBase::kTlsVariantsDatagram,
+ TlsConnectTestBase::kTlsV11V12, kClientAuthCallbacks),
+ getClientAuthTestName);
+
+INSTANTIATE_TEST_SUITE_P(ClientAuth12Plus, TlsConnectClientAuth12Plus,
+ testing::Combine(TlsConnectTestBase::kTlsVariantsAll,
+ TlsConnectTestBase::kTlsV12Plus,
+ kClientAuthCallbacks),
+ getClientAuthTestName);
+
+TEST_P(TlsConnectClientAuth, ClientAuth) {
+ EnsureTlsSetup();
+ client_->SetupClientAuth(std::get<2>(GetParam()), true);
+ server_->RequestClientAuth(true);
+ Connect();
+ CheckKeys();
+ client_->CheckClientAuthCompleted();
+}
+
+// All stream only tests; PostHandshakeAuth isn't supported for DTLS.
+
+TEST_P(TlsConnectClientAuthStream13, PostHandshakeAuth) {
+ EnsureTlsSetup();
+ client_->SetupClientAuth(std::get<2>(GetParam()), true);
+ auto capture_cert_req = MakeTlsFilter<TlsCertificateRequestContextRecorder>(
+ server_, kTlsHandshakeCertificateRequest);
+ auto capture_certificate =
+ MakeTlsFilter<TlsCertificateRequestContextRecorder>(
+ client_, kTlsHandshakeCertificate);
+ client_->SetOption(SSL_ENABLE_POST_HANDSHAKE_AUTH, PR_TRUE);
+ size_t called = 0;
+ server_->SetAuthCertificateCallback(
+ [&called](TlsAgent*, PRBool, PRBool) -> SECStatus {
+ called++;
+ return SECSuccess;
+ });
+ Connect();
+ EXPECT_EQ(0U, called);
+ EXPECT_FALSE(capture_cert_req->filtered());
+ EXPECT_FALSE(capture_certificate->filtered());
+ // Send CertificateRequest.
+ EXPECT_EQ(SECSuccess, SSL_SendCertificateRequest(server_->ssl_fd()))
+ << "Unexpected error: " << PORT_ErrorToName(PORT_GetError());
+ // Need to do a round-trip so that the post-handshake message is
+ // handled on both client and server.
+ server_->SendData(50);
+ client_->ReadBytes(50);
+ client_->ClientAuthCallbackComplete();
+ client_->SendData(50);
+ server_->ReadBytes(50);
+
+ EXPECT_EQ(1U, called);
+ ASSERT_TRUE(capture_cert_req->filtered());
+ ASSERT_TRUE(capture_certificate->filtered());
+
+ client_->CheckClientAuthCompleted();
+ // Check if a non-empty request context is generated and it is
+ // properly sent back.
+ EXPECT_LT(0U, capture_cert_req->buffer().len());
+ EXPECT_EQ(capture_cert_req->buffer().len(),
+ capture_certificate->buffer().len());
+ EXPECT_EQ(0, memcmp(capture_cert_req->buffer().data(),
+ capture_certificate->buffer().data(),
+ capture_cert_req->buffer().len()));
+ ScopedCERTCertificate cert1(SSL_PeerCertificate(server_->ssl_fd()));
+ ASSERT_NE(nullptr, cert1.get());
+ ScopedCERTCertificate cert2(SSL_LocalCertificate(client_->ssl_fd()));
+ ASSERT_NE(nullptr, cert2.get());
+ EXPECT_TRUE(SECITEM_ItemsAreEqual(&cert1->derCert, &cert2->derCert));
+}
+
+TEST_P(TlsConnectClientAuthStream13, PostHandshakeAuthAfterResumption) {
+ ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET);
+ ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3);
+ Connect();
+
+ SendReceive(); // Need to read so that we absorb the session tickets.
+ CheckKeys();
+
+ // Resume the connection.
+ Reset();
+
+ ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET);
+ ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3);
+ ExpectResumption(RESUME_TICKET);
+
+ client_->SetupClientAuth(std::get<2>(GetParam()), true);
+ client_->SetOption(SSL_ENABLE_POST_HANDSHAKE_AUTH, PR_TRUE);
+ Connect();
+ SendReceive();
+
+ size_t called = 0;
+ server_->SetAuthCertificateCallback(
+ [&called](TlsAgent*, PRBool, PRBool) -> SECStatus {
+ called++;
+ return SECSuccess;
+ });
+ EXPECT_EQ(SECSuccess, SSL_SendCertificateRequest(server_->ssl_fd()))
+ << "Unexpected error: " << PORT_ErrorToName(PORT_GetError());
+
+ server_->SendData(50);
+ client_->ReadBytes(50);
+ client_->ClientAuthCallbackComplete();
+ client_->SendData(50);
+ server_->ReadBytes(50);
+
+ client_->CheckClientAuthCompleted();
+ EXPECT_EQ(1U, called);
+
+ ScopedCERTCertificate cert1(SSL_PeerCertificate(server_->ssl_fd()));
+ ASSERT_NE(nullptr, cert1.get());
+ ScopedCERTCertificate cert2(SSL_LocalCertificate(client_->ssl_fd()));
+ ASSERT_NE(nullptr, cert2.get());
+ EXPECT_TRUE(SECITEM_ItemsAreEqual(&cert1->derCert, &cert2->derCert));
+}
+
+static SECStatus GetClientAuthDataHook(void* self, PRFileDesc* fd,
+ CERTDistNames* caNames,
+ CERTCertificate** clientCert,
+ SECKEYPrivateKey** clientKey) {
+ ScopedCERTCertificate cert;
+ ScopedSECKEYPrivateKey priv;
+ // use a different certificate than TlsAgent::kClient
+ if (!TlsAgent::LoadCertificate(TlsAgent::kRsa2048, &cert, &priv)) {
+ return SECFailure;
+ }
+
+ *clientCert = cert.release();
+ *clientKey = priv.release();
+ return SECSuccess;
+}
+
+typedef struct AutoClientTestStr {
+ SECStatus result;
+ const std::string cert;
+} AutoClientTest;
+
+typedef struct AutoClientResultsStr {
+ AutoClientTest isRsa2048;
+ AutoClientTest isClient;
+ AutoClientTest isNull;
+ bool hookCalled;
+} AutoClientResults;
+
+void VerifyClientCertMatch(CERTCertificate* clientCert,
+ const std::string expectedName) {
+ const char* name = clientCert->nickname;
+ std::cout << "Match name=\"" << name << "\" expected=\"" << expectedName
+ << "\"" << std::endl;
+ EXPECT_TRUE(PORT_Strcmp(name, expectedName.c_str()) == 0)
+ << " Certmismatch: \"" << name << "\" != \"" << expectedName << "\"";
+}
+
+static SECStatus GetAutoClientAuthDataHook(void* expectResults, PRFileDesc* fd,
+ CERTDistNames* caNames,
+ CERTCertificate** clientCert,
+ SECKEYPrivateKey** clientKey) {
+ AutoClientResults& results = *(AutoClientResults*)expectResults;
+ SECStatus rv;
+
+ results.hookCalled = true;
+ *clientCert = NULL;
+ *clientKey = NULL;
+ rv = NSS_GetClientAuthData((void*)TlsAgent::kRsa2048.c_str(), fd, caNames,
+ clientCert, clientKey);
+ if (rv == SECSuccess) {
+ VerifyClientCertMatch(*clientCert, results.isRsa2048.cert);
+ CERT_DestroyCertificate(*clientCert);
+ SECKEY_DestroyPrivateKey(*clientKey);
+ *clientCert = NULL;
+ *clientKey = NULL;
+ }
+ EXPECT_EQ(results.isRsa2048.result, rv);
+
+ rv = NSS_GetClientAuthData((void*)TlsAgent::kClient.c_str(), fd, caNames,
+ clientCert, clientKey);
+ if (rv == SECSuccess) {
+ VerifyClientCertMatch(*clientCert, results.isClient.cert);
+ CERT_DestroyCertificate(*clientCert);
+ SECKEY_DestroyPrivateKey(*clientKey);
+ *clientCert = NULL;
+ *clientKey = NULL;
+ }
+ EXPECT_EQ(results.isClient.result, rv);
+ EXPECT_EQ(*clientCert, nullptr);
+ EXPECT_EQ(*clientKey, nullptr);
+ rv = NSS_GetClientAuthData(NULL, fd, caNames, clientCert, clientKey);
+ if (rv == SECSuccess) {
+ VerifyClientCertMatch(*clientCert, results.isNull.cert);
+ // return this result
+ }
+ EXPECT_EQ(results.isNull.result, rv);
+ return rv;
+}
+
+// while I would have liked to use a new INSTANTIATE macro the
+// generates the following three tests, figuring out how to make that
+// work on top of the existing TlsConnect* plumbing hurts my head.
+TEST_P(TlsConnectTls12, AutoClientSelectRsaPss) {
+ AutoClientResults rsa = {{SECSuccess, TlsAgent::kRsa2048},
+ {SECSuccess, TlsAgent::kClient},
+ {SECSuccess, TlsAgent::kDelegatorRsaPss2048},
+ false};
+ static const SSLSignatureScheme kSchemes[] = {ssl_sig_rsa_pss_pss_sha256,
+ ssl_sig_rsa_pkcs1_sha256,
+ ssl_sig_rsa_pkcs1_sha1};
+ Reset("rsa_pss_noparam");
+ client_->SetupClientAuth();
+ server_->RequestClientAuth(true);
+ EXPECT_EQ(SECSuccess,
+ SSL_GetClientAuthDataHook(client_->ssl_fd(),
+ GetAutoClientAuthDataHook, (void*)&rsa));
+ server_->SetSignatureSchemes(kSchemes, PR_ARRAY_SIZE(kSchemes));
+ client_->SetSignatureSchemes(kSchemes, PR_ARRAY_SIZE(kSchemes));
+ Connect();
+ EXPECT_TRUE(rsa.hookCalled);
+}
+
+TEST_P(TlsConnectTls12, AutoClientSelectEcc) {
+ AutoClientResults ecc = {{SECFailure, TlsAgent::kClient},
+ {SECFailure, TlsAgent::kClient},
+ {SECSuccess, TlsAgent::kDelegatorEcdsa256},
+ false};
+ static const SSLSignatureScheme kSchemes[] = {ssl_sig_ecdsa_secp256r1_sha256};
+ client_->SetupClientAuth();
+ server_->RequestClientAuth(true);
+ EXPECT_EQ(SECSuccess,
+ SSL_GetClientAuthDataHook(client_->ssl_fd(),
+ GetAutoClientAuthDataHook, (void*)&ecc));
+ server_->SetSignatureSchemes(kSchemes, PR_ARRAY_SIZE(kSchemes));
+ client_->SetSignatureSchemes(kSchemes, PR_ARRAY_SIZE(kSchemes));
+ Connect();
+ EXPECT_TRUE(ecc.hookCalled);
+}
+
+TEST_P(TlsConnectTls12, AutoClientSelectDsa) {
+ AutoClientResults dsa = {{SECFailure, TlsAgent::kClient},
+ {SECFailure, TlsAgent::kClient},
+ {SECSuccess, TlsAgent::kServerDsa},
+ false};
+ static const SSLSignatureScheme kSchemes[] = {ssl_sig_dsa_sha256};
+ client_->SetupClientAuth();
+ server_->RequestClientAuth(true);
+ EXPECT_EQ(SECSuccess,
+ SSL_GetClientAuthDataHook(client_->ssl_fd(),
+ GetAutoClientAuthDataHook, (void*)&dsa));
+ server_->SetSignatureSchemes(kSchemes, PR_ARRAY_SIZE(kSchemes));
+ client_->SetSignatureSchemes(kSchemes, PR_ARRAY_SIZE(kSchemes));
+ Connect();
+ EXPECT_TRUE(dsa.hookCalled);
+}
+
+TEST_P(TlsConnectClientAuthStream13, PostHandshakeAuthMultiple) {
+ client_->SetupClientAuth(std::get<2>(GetParam()), true);
+ EXPECT_EQ(SECSuccess, SSL_OptionSet(client_->ssl_fd(),
+ SSL_ENABLE_POST_HANDSHAKE_AUTH, PR_TRUE));
+ size_t called = 0;
+ server_->SetAuthCertificateCallback(
+ [&called](TlsAgent*, PRBool, PRBool) -> SECStatus {
+ called++;
+ return SECSuccess;
+ });
+ Connect();
+ EXPECT_EQ(0U, called);
+ EXPECT_EQ(nullptr, SSL_PeerCertificate(server_->ssl_fd()));
+ // Send 1st CertificateRequest.
+ EXPECT_EQ(SECSuccess, SSL_SendCertificateRequest(server_->ssl_fd()))
+ << "Unexpected error: " << PORT_ErrorToName(PORT_GetError());
+
+ server_->SendData(50);
+ client_->ReadBytes(50);
+ client_->ClientAuthCallbackComplete();
+ client_->ReadBytes(50);
+ client_->SendData(50);
+ server_->ReadBytes(50);
+ EXPECT_EQ(1U, called);
+ client_->CheckClientAuthCompleted(1);
+ ScopedCERTCertificate cert1(SSL_PeerCertificate(server_->ssl_fd()));
+ ASSERT_NE(nullptr, cert1.get());
+ ScopedCERTCertificate cert2(SSL_LocalCertificate(client_->ssl_fd()));
+ ASSERT_NE(nullptr, cert2.get());
+ EXPECT_TRUE(SECITEM_ItemsAreEqual(&cert1->derCert, &cert2->derCert));
+ // Send 2nd CertificateRequest.
+ client_->SetupClientAuth(std::get<2>(GetParam()), true);
+ EXPECT_EQ(SECSuccess, SSL_SendCertificateRequest(server_->ssl_fd()))
+ << "Unexpected error: " << PORT_ErrorToName(PORT_GetError());
+
+ server_->SendData(50);
+ client_->ReadBytes(50);
+ client_->ClientAuthCallbackComplete();
+ client_->ReadBytes(50);
+ client_->SendData(50);
+ server_->ReadBytes(50);
+ client_->CheckClientAuthCompleted(2);
+ EXPECT_EQ(2U, called);
+ ScopedCERTCertificate cert3(SSL_PeerCertificate(server_->ssl_fd()));
+ ASSERT_NE(nullptr, cert3.get());
+ ScopedCERTCertificate cert4(SSL_LocalCertificate(client_->ssl_fd()));
+ ASSERT_NE(nullptr, cert4.get());
+ EXPECT_TRUE(SECITEM_ItemsAreEqual(&cert3->derCert, &cert4->derCert));
+}
+
+TEST_P(TlsConnectClientAuthStream13, PostHandshakeAuthConcurrent) {
+ client_->SetupClientAuth(std::get<2>(GetParam()), true);
+ EXPECT_EQ(SECSuccess, SSL_OptionSet(client_->ssl_fd(),
+ SSL_ENABLE_POST_HANDSHAKE_AUTH, PR_TRUE));
+ Connect();
+ // Send 1st CertificateRequest.
+ EXPECT_EQ(SECSuccess, SSL_SendCertificateRequest(server_->ssl_fd()))
+ << "Unexpected error: " << PORT_ErrorToName(PORT_GetError());
+ // Send 2nd CertificateRequest.
+ EXPECT_EQ(SECFailure, SSL_SendCertificateRequest(server_->ssl_fd()));
+ EXPECT_EQ(PR_WOULD_BLOCK_ERROR, PORT_GetError());
+}
+
+TEST_P(TlsConnectClientAuthStream13, PostHandshakeAuthBeforeKeyUpdate) {
+ client_->SetupClientAuth(std::get<2>(GetParam()), true);
+ EXPECT_EQ(SECSuccess, SSL_OptionSet(client_->ssl_fd(),
+ SSL_ENABLE_POST_HANDSHAKE_AUTH, PR_TRUE));
+ Connect();
+ // Send CertificateRequest.
+ EXPECT_EQ(SECSuccess, SSL_SendCertificateRequest(server_->ssl_fd()))
+ << "Unexpected error: " << PORT_ErrorToName(PORT_GetError());
+ // Send KeyUpdate.
+ EXPECT_EQ(SECFailure, SSL_KeyUpdate(server_->ssl_fd(), PR_TRUE));
+ EXPECT_EQ(PR_WOULD_BLOCK_ERROR, PORT_GetError());
+}
+
+TEST_P(TlsConnectClientAuthStream13, PostHandshakeAuthDuringClientKeyUpdate) {
+ client_->SetupClientAuth(std::get<2>(GetParam()), true);
+ ;
+ EXPECT_EQ(SECSuccess, SSL_OptionSet(client_->ssl_fd(),
+ SSL_ENABLE_POST_HANDSHAKE_AUTH, PR_TRUE));
+ Connect();
+ CheckEpochs(3, 3);
+ // Send CertificateRequest from server.
+ EXPECT_EQ(SECSuccess, SSL_SendCertificateRequest(server_->ssl_fd()))
+ << "Unexpected error: " << PORT_ErrorToName(PORT_GetError());
+ // Send KeyUpdate from client.
+ EXPECT_EQ(SECSuccess, SSL_KeyUpdate(client_->ssl_fd(), PR_TRUE));
+ server_->SendData(50); // server sends CertificateRequest
+ client_->SendData(50); // client sends KeyUpdate
+ server_->ReadBytes(50); // server receives KeyUpdate and defers response
+ CheckEpochs(4, 3);
+ client_->ReadBytes(60); // client receives CertificateRequest
+ client_->ClientAuthCallbackComplete();
+ client_->ReadBytes(50); // Finish reading the remaining bytes
+ client_->SendData(
+ 50); // client sends Certificate, CertificateVerify, Finished
+ server_->ReadBytes(
+ 50); // server receives Certificate, CertificateVerify, Finished
+ client_->CheckClientAuthCompleted();
+ client_->CheckEpochs(3, 4);
+ server_->CheckEpochs(4, 4);
+ server_->SendData(50); // server sends KeyUpdate
+ client_->ReadBytes(50); // client receives KeyUpdate
+ client_->CheckEpochs(4, 4);
+}
+
+TEST_P(TlsConnectClientAuthStream13, PostHandshakeAuthMissingExtension) {
+ client_->SetupClientAuth(std::get<2>(GetParam()), true);
+ Connect();
+ // Send CertificateRequest, should fail due to missing
+ // post_handshake_auth extension.
+ EXPECT_EQ(SECFailure, SSL_SendCertificateRequest(server_->ssl_fd()));
+ EXPECT_EQ(SSL_ERROR_MISSING_POST_HANDSHAKE_AUTH_EXTENSION, PORT_GetError());
+}
+
+TEST_P(TlsConnectClientAuthStream13, PostHandshakeAuthAfterClientAuth) {
+ client_->SetupClientAuth(std::get<2>(GetParam()), true);
+ server_->RequestClientAuth(true);
+ EXPECT_EQ(SECSuccess, SSL_OptionSet(client_->ssl_fd(),
+ SSL_ENABLE_POST_HANDSHAKE_AUTH, PR_TRUE));
+ size_t called = 0;
+ server_->SetAuthCertificateCallback(
+ [&called](TlsAgent*, PRBool, PRBool) -> SECStatus {
+ called++;
+ return SECSuccess;
+ });
+ Connect();
+ EXPECT_EQ(1U, called);
+ ScopedCERTCertificate cert1(SSL_PeerCertificate(server_->ssl_fd()));
+ ASSERT_NE(nullptr, cert1.get());
+ ScopedCERTCertificate cert2(SSL_LocalCertificate(client_->ssl_fd()));
+ ASSERT_NE(nullptr, cert2.get());
+ EXPECT_TRUE(SECITEM_ItemsAreEqual(&cert1->derCert, &cert2->derCert));
+ // Send CertificateRequest.
+ EXPECT_EQ(SECSuccess, SSL_GetClientAuthDataHook(
+ client_->ssl_fd(), GetClientAuthDataHook, nullptr));
+ EXPECT_EQ(SECSuccess, SSL_SendCertificateRequest(server_->ssl_fd()))
+ << "Unexpected error: " << PORT_ErrorToName(PORT_GetError());
+ server_->SendData(50);
+ client_->ReadBytes(50);
+ client_->SendData(50);
+ server_->ReadBytes(50);
+ EXPECT_EQ(2U, called);
+ ScopedCERTCertificate cert3(SSL_PeerCertificate(server_->ssl_fd()));
+ ASSERT_NE(nullptr, cert3.get());
+ ScopedCERTCertificate cert4(SSL_LocalCertificate(client_->ssl_fd()));
+ ASSERT_NE(nullptr, cert4.get());
+ EXPECT_TRUE(SECITEM_ItemsAreEqual(&cert3->derCert, &cert4->derCert));
+ EXPECT_FALSE(SECITEM_ItemsAreEqual(&cert3->derCert, &cert1->derCert));
+}
+
+// Damages the request context in a CertificateRequest message.
+// We don't modify a Certificate message instead, so that the client
+// can compute CertificateVerify correctly.
+class TlsDamageCertificateRequestContextFilter : public TlsHandshakeFilter {
+ public:
+ TlsDamageCertificateRequestContextFilter(const std::shared_ptr<TlsAgent>& a)
+ : TlsHandshakeFilter(a, {kTlsHandshakeCertificateRequest}) {
+ EnableDecryption();
+ }
+
+ protected:
+ virtual PacketFilter::Action FilterHandshake(const HandshakeHeader& header,
+ const DataBuffer& input,
+ DataBuffer* output) {
+ *output = input;
+ assert(1 < output->len());
+ // The request context has a 1 octet length.
+ output->data()[1] ^= 73;
+ return CHANGE;
+ }
+};
+
+TEST_P(TlsConnectClientAuthStream13, PostHandshakeAuthContextMismatch) {
+ EnsureTlsSetup();
+ MakeTlsFilter<TlsDamageCertificateRequestContextFilter>(server_);
+ client_->SetupClientAuth(std::get<2>(GetParam()), true);
+ EXPECT_EQ(SECSuccess, SSL_OptionSet(client_->ssl_fd(),
+ SSL_ENABLE_POST_HANDSHAKE_AUTH, PR_TRUE));
+ Connect();
+ // Send CertificateRequest.
+ EXPECT_EQ(SECSuccess, SSL_SendCertificateRequest(server_->ssl_fd()))
+ << "Unexpected error: " << PORT_ErrorToName(PORT_GetError());
+ server_->SendData(50);
+ client_->ReadBytes(50);
+ client_->ClientAuthCallbackComplete();
+ client_->ReadBytes(50);
+ client_->SendData(50);
+ server_->ExpectSendAlert(kTlsAlertIllegalParameter);
+ server_->ReadBytes(50);
+ EXPECT_EQ(SSL_ERROR_RX_MALFORMED_CERTIFICATE, PORT_GetError());
+ server_->ExpectReadWriteError();
+ server_->SendData(50);
+ client_->ExpectReceiveAlert(kTlsAlertIllegalParameter);
+ client_->ReadBytes(50);
+ EXPECT_EQ(SSL_ERROR_ILLEGAL_PARAMETER_ALERT, PORT_GetError());
+}
+
+// Replaces signature in a CertificateVerify message.
+class TlsDamageSignatureFilter : public TlsHandshakeFilter {
+ public:
+ TlsDamageSignatureFilter(const std::shared_ptr<TlsAgent>& a)
+ : TlsHandshakeFilter(a, {kTlsHandshakeCertificateVerify}) {
+ EnableDecryption();
+ }
+
+ protected:
+ virtual PacketFilter::Action FilterHandshake(const HandshakeHeader& header,
+ const DataBuffer& input,
+ DataBuffer* output) {
+ *output = input;
+ assert(2 < output->len());
+ // The signature follows a 2-octet signature scheme.
+ output->data()[2] ^= 73;
+ return CHANGE;
+ }
+};
+
+TEST_P(TlsConnectClientAuthStream13, PostHandshakeAuthBadSignature) {
+ EnsureTlsSetup();
+ MakeTlsFilter<TlsDamageSignatureFilter>(client_);
+ client_->SetupClientAuth(std::get<2>(GetParam()), true);
+ EXPECT_EQ(SECSuccess, SSL_OptionSet(client_->ssl_fd(),
+ SSL_ENABLE_POST_HANDSHAKE_AUTH, PR_TRUE));
+ Connect();
+ // Send CertificateRequest.
+ EXPECT_EQ(SECSuccess, SSL_SendCertificateRequest(server_->ssl_fd()))
+ << "Unexpected error: " << PORT_ErrorToName(PORT_GetError());
+ server_->SendData(50);
+ client_->ReadBytes(50);
+ client_->ClientAuthCallbackComplete();
+ client_->SendData(50);
+ client_->CheckClientAuthCompleted();
+ server_->ExpectSendAlert(kTlsAlertDecodeError);
+ server_->ReadBytes(50);
+ EXPECT_EQ(SSL_ERROR_RX_MALFORMED_CERT_VERIFY, PORT_GetError());
+}
+
+TEST_P(TlsConnectClientAuthStream13, PostHandshakeAuthDecline) {
+ EnsureTlsSetup();
+ auto capture_cert_req = MakeTlsFilter<TlsCertificateRequestContextRecorder>(
+ server_, kTlsHandshakeCertificateRequest);
+ auto capture_certificate =
+ MakeTlsFilter<TlsCertificateRequestContextRecorder>(
+ client_, kTlsHandshakeCertificate);
+ client_->SetupClientAuth(std::get<2>(GetParam()), true);
+ EXPECT_EQ(SECSuccess, SSL_OptionSet(client_->ssl_fd(),
+ SSL_ENABLE_POST_HANDSHAKE_AUTH, PR_TRUE));
+ EXPECT_EQ(SECSuccess,
+ SSL_OptionSet(server_->ssl_fd(), SSL_REQUIRE_CERTIFICATE,
+ SSL_REQUIRE_ALWAYS));
+ // Client to decline the certificate request.
+ EXPECT_EQ(SECSuccess,
+ SSL_GetClientAuthDataHook(
+ client_->ssl_fd(),
+ [](void*, PRFileDesc*, CERTDistNames*, CERTCertificate**,
+ SECKEYPrivateKey**) -> SECStatus { return SECFailure; },
+ nullptr));
+ size_t called = 0;
+ server_->SetAuthCertificateCallback(
+ [&called](TlsAgent*, PRBool, PRBool) -> SECStatus {
+ called++;
+ return SECSuccess;
+ });
+ Connect();
+ EXPECT_EQ(0U, called);
+ // Send CertificateRequest.
+ EXPECT_EQ(SECSuccess, SSL_SendCertificateRequest(server_->ssl_fd()))
+ << "Unexpected error: " << PORT_ErrorToName(PORT_GetError());
+ server_->SendData(50); // send Certificate Request
+ client_->ReadBytes(50); // read Certificate Request
+ client_->SendData(50); // send empty Certificate+Finished
+ server_->ExpectSendAlert(kTlsAlertCertificateRequired);
+ server_->ReadBytes(50); // read empty Certificate+Finished
+ server_->ExpectReadWriteError();
+ server_->SendData(50); // send alert
+ // AuthCertificateCallback is not called, because the client sends
+ // an empty certificate_list.
+ EXPECT_EQ(0U, called);
+ EXPECT_TRUE(capture_cert_req->filtered());
+ EXPECT_TRUE(capture_certificate->filtered());
+ // Check if a non-empty request context is generated and it is
+ // properly sent back.
+ EXPECT_LT(0U, capture_cert_req->buffer().len());
+ EXPECT_EQ(capture_cert_req->buffer().len(),
+ capture_certificate->buffer().len());
+ EXPECT_EQ(0, memcmp(capture_cert_req->buffer().data(),
+ capture_certificate->buffer().data(),
+ capture_cert_req->buffer().len()));
+}
+
+// Check if post-handshake auth still works when session tickets are enabled:
+// https://bugzilla.mozilla.org/show_bug.cgi?id=1553443
+TEST_P(TlsConnectClientAuthStream13,
+ PostHandshakeAuthWithSessionTicketsEnabled) {
+ EnsureTlsSetup();
+ client_->SetupClientAuth(std::get<2>(GetParam()), true);
+ EXPECT_EQ(SECSuccess, SSL_OptionSet(client_->ssl_fd(),
+ SSL_ENABLE_POST_HANDSHAKE_AUTH, PR_TRUE));
+ EXPECT_EQ(SECSuccess, SSL_OptionSet(client_->ssl_fd(),
+ SSL_ENABLE_SESSION_TICKETS, PR_TRUE));
+ EXPECT_EQ(SECSuccess, SSL_OptionSet(server_->ssl_fd(),
+ SSL_ENABLE_SESSION_TICKETS, PR_TRUE));
+ size_t called = 0;
+ server_->SetAuthCertificateCallback(
+ [&called](TlsAgent*, PRBool, PRBool) -> SECStatus {
+ called++;
+ return SECSuccess;
+ });
+ Connect();
+ EXPECT_EQ(0U, called);
+ // Send CertificateRequest.
+ EXPECT_EQ(SECSuccess, SSL_GetClientAuthDataHook(
+ client_->ssl_fd(), GetClientAuthDataHook, nullptr));
+ EXPECT_EQ(SECSuccess, SSL_SendCertificateRequest(server_->ssl_fd()))
+ << "Unexpected error: " << PORT_ErrorToName(PORT_GetError());
+ server_->SendData(50);
+ client_->ReadBytes(50);
+ client_->SendData(50);
+ server_->ReadBytes(50);
+ EXPECT_EQ(1U, called);
+ ScopedCERTCertificate cert1(SSL_PeerCertificate(server_->ssl_fd()));
+ ASSERT_NE(nullptr, cert1.get());
+ ScopedCERTCertificate cert2(SSL_LocalCertificate(client_->ssl_fd()));
+ ASSERT_NE(nullptr, cert2.get());
+ EXPECT_TRUE(SECITEM_ItemsAreEqual(&cert1->derCert, &cert2->derCert));
+}
+
+TEST_P(TlsConnectClientAuthPre13, ClientAuthRequiredRejected) {
+ client_->SetupClientAuth(std::get<2>(GetParam()), false);
+ server_->RequestClientAuth(true);
+ ConnectExpectAlert(server_, kTlsAlertBadCertificate);
+ client_->CheckErrorCode(SSL_ERROR_BAD_CERT_ALERT);
+ server_->CheckErrorCode(SSL_ERROR_NO_CERTIFICATE);
+}
+
+// In TLS 1.3, the client will claim that the connection is done and then
+// receive the alert afterwards. So drive the handshake manually.
+TEST_P(TlsConnectClientAuth13, ClientAuthRequiredRejected) {
+ client_->SetupClientAuth(std::get<2>(GetParam()), false);
+ server_->RequestClientAuth(true);
+ StartConnect();
+ client_->Handshake(); // CH
+ server_->Handshake(); // SH.. (no resumption)
+
+ client_->Handshake(); // Next message
+ ASSERT_EQ(TlsAgent::STATE_CONNECTED, client_->state());
+ client_->CheckClientAuthCompleted();
+ ExpectAlert(server_, kTlsAlertCertificateRequired);
+ server_->Handshake(); // Alert
+ server_->CheckErrorCode(SSL_ERROR_NO_CERTIFICATE);
+ client_->Handshake(); // Receive Alert
+ client_->CheckErrorCode(SSL_ERROR_RX_CERTIFICATE_REQUIRED_ALERT);
+}
+
+TEST_P(TlsConnectClientAuth, ClientAuthRequestedRejected) {
+ client_->SetupClientAuth(std::get<2>(GetParam()), false);
+ server_->RequestClientAuth(false);
+ Connect();
+ CheckKeys();
+}
+
+TEST_P(TlsConnectClientAuth, ClientAuthEcdsa) {
+ Reset(TlsAgent::kServerEcdsa256);
+ client_->SetupClientAuth(std::get<2>(GetParam()), true);
+ server_->RequestClientAuth(true);
+ Connect();
+ CheckKeys(ssl_kea_ecdh, ssl_auth_ecdsa);
+}
+
+TEST_P(TlsConnectClientAuth, ClientAuthWithEch) {
+ Reset(TlsAgent::kServerEcdsa256);
+ EnsureTlsSetup();
+ SetupEch(client_, server_);
+ client_->SetupClientAuth(std::get<2>(GetParam()), true);
+ server_->RequestClientAuth(true);
+ Connect();
+ CheckKeys(ssl_kea_ecdh, ssl_auth_ecdsa);
+}
+
+TEST_P(TlsConnectClientAuth, ClientAuthBigRsa) {
+ Reset(TlsAgent::kServerRsa, TlsAgent::kRsa2048);
+ client_->SetupClientAuth(std::get<2>(GetParam()), true);
+ server_->RequestClientAuth(true);
+ Connect();
+ CheckKeys();
+}
+
+// Offset is the position in the captured buffer where the signature sits.
+static void CheckSigScheme(std::shared_ptr<TlsHandshakeRecorder>& capture,
+ size_t offset, std::shared_ptr<TlsAgent>& peer,
+ uint16_t expected_scheme, size_t expected_size) {
+ EXPECT_LT(offset + 2U, capture->buffer().len());
+
+ uint32_t scheme = 0;
+ capture->buffer().Read(offset, 2, &scheme);
+ EXPECT_EQ(expected_scheme, static_cast<uint16_t>(scheme));
+
+ ScopedCERTCertificate remote_cert(SSL_PeerCertificate(peer->ssl_fd()));
+ ASSERT_NE(nullptr, remote_cert.get());
+ ScopedSECKEYPublicKey remote_key(CERT_ExtractPublicKey(remote_cert.get()));
+ ASSERT_NE(nullptr, remote_key.get());
+ EXPECT_EQ(expected_size, SECKEY_PublicKeyStrengthInBits(remote_key.get()));
+}
+
+// The server should prefer SHA-256 by default, even for the small key size used
+// in the default certificate.
+TEST_P(TlsConnectTls12, ServerAuthCheckSigAlg) {
+ EnsureTlsSetup();
+ auto capture_ske = MakeTlsFilter<TlsHandshakeRecorder>(
+ server_, kTlsHandshakeServerKeyExchange);
+ Connect();
+ CheckKeys();
+
+ const DataBuffer& buffer = capture_ske->buffer();
+ EXPECT_LT(3U, buffer.len());
+ EXPECT_EQ(3U, buffer.data()[0]) << "curve_type == named_curve";
+ uint32_t tmp;
+ EXPECT_TRUE(buffer.Read(1, 2, &tmp)) << "read NamedCurve";
+ EXPECT_EQ(ssl_grp_ec_curve25519, tmp);
+ EXPECT_TRUE(buffer.Read(3, 1, &tmp)) << " read ECPoint";
+ CheckSigScheme(capture_ske, 4 + tmp, client_, ssl_sig_rsa_pss_rsae_sha256,
+ 1024);
+}
+
+TEST_P(TlsConnectClientAuth12, ClientAuthCheckSigAlg) {
+ EnsureTlsSetup();
+ auto capture_cert_verify = MakeTlsFilter<TlsHandshakeRecorder>(
+ client_, kTlsHandshakeCertificateVerify);
+ client_->SetupClientAuth(std::get<2>(GetParam()), true);
+ server_->RequestClientAuth(true);
+ Connect();
+ CheckKeys();
+
+ CheckSigScheme(capture_cert_verify, 0, server_, ssl_sig_rsa_pkcs1_sha1, 1024);
+}
+
+TEST_P(TlsConnectClientAuth12, ClientAuthBigRsaCheckSigAlg) {
+ Reset(TlsAgent::kServerRsa, TlsAgent::kRsa2048);
+ auto capture_cert_verify = MakeTlsFilter<TlsHandshakeRecorder>(
+ client_, kTlsHandshakeCertificateVerify);
+ client_->SetupClientAuth(std::get<2>(GetParam()), true);
+ server_->RequestClientAuth(true);
+ Connect();
+ CheckKeys();
+ CheckSigScheme(capture_cert_verify, 0, server_, ssl_sig_rsa_pss_rsae_sha256,
+ 2048);
+}
+
+// Check if CertificateVerify signed with rsa_pss_rsae_* is properly
+// rejected when the certificate is RSA-PSS.
+//
+// This only works under TLS 1.2, because PSS doesn't work with TLS
+// 1.0 or TLS 1.1 and the TLS 1.3 1-RTT handshake is partially
+// successful at the client side.
+TEST_P(TlsConnectClientAuth12, ClientAuthInconsistentRsaeSignatureScheme) {
+ static const SSLSignatureScheme kSignatureSchemePss[] = {
+ ssl_sig_rsa_pss_pss_sha256, ssl_sig_rsa_pss_rsae_sha256};
+
+ Reset(TlsAgent::kServerRsa, "rsa_pss");
+ client_->SetSignatureSchemes(kSignatureSchemePss,
+ PR_ARRAY_SIZE(kSignatureSchemePss));
+ server_->SetSignatureSchemes(kSignatureSchemePss,
+ PR_ARRAY_SIZE(kSignatureSchemePss));
+ client_->SetupClientAuth(std::get<2>(GetParam()), true);
+ server_->RequestClientAuth(true);
+
+ EnsureTlsSetup();
+
+ MakeTlsFilter<TlsReplaceSignatureSchemeFilter>(client_,
+ ssl_sig_rsa_pss_rsae_sha256);
+
+ ConnectExpectAlert(server_, kTlsAlertIllegalParameter);
+}
+
+// Check if CertificateVerify signed with rsa_pss_pss_* is properly
+// rejected when the certificate is RSA.
+//
+// This only works under TLS 1.2, because PSS doesn't work with TLS
+// 1.0 or TLS 1.1 and the TLS 1.3 1-RTT handshake is partially
+// successful at the client side.
+TEST_P(TlsConnectClientAuth12, ClientAuthInconsistentPssSignatureScheme) {
+ static const SSLSignatureScheme kSignatureSchemePss[] = {
+ ssl_sig_rsa_pss_rsae_sha256, ssl_sig_rsa_pss_pss_sha256};
+
+ Reset(TlsAgent::kServerRsa, "rsa");
+ client_->SetSignatureSchemes(kSignatureSchemePss,
+ PR_ARRAY_SIZE(kSignatureSchemePss));
+ server_->SetSignatureSchemes(kSignatureSchemePss,
+ PR_ARRAY_SIZE(kSignatureSchemePss));
+ client_->SetupClientAuth(std::get<2>(GetParam()), true);
+ server_->RequestClientAuth(true);
+
+ EnsureTlsSetup();
+
+ MakeTlsFilter<TlsReplaceSignatureSchemeFilter>(client_,
+ ssl_sig_rsa_pss_pss_sha256);
+
+ ConnectExpectAlert(server_, kTlsAlertIllegalParameter);
+}
+
+TEST_P(TlsConnectClientAuth13, ClientAuthPkcs1SignatureScheme) {
+ static const SSLSignatureScheme kSignatureScheme[] = {
+ ssl_sig_rsa_pkcs1_sha256, ssl_sig_rsa_pss_rsae_sha256};
+
+ Reset(TlsAgent::kServerRsa, "rsa");
+ client_->SetSignatureSchemes(kSignatureScheme,
+ PR_ARRAY_SIZE(kSignatureScheme));
+ server_->SetSignatureSchemes(kSignatureScheme,
+ PR_ARRAY_SIZE(kSignatureScheme));
+ client_->SetupClientAuth(std::get<2>(GetParam()), true);
+ server_->RequestClientAuth(true);
+
+ auto capture_cert_verify = MakeTlsFilter<TlsHandshakeRecorder>(
+ client_, kTlsHandshakeCertificateVerify);
+ capture_cert_verify->EnableDecryption();
+
+ Connect();
+ CheckSigScheme(capture_cert_verify, 0, server_, ssl_sig_rsa_pss_rsae_sha256,
+ 1024);
+}
+
+// Client should refuse to connect without a usable signature scheme.
+TEST_P(TlsConnectClientAuth13, ClientAuthPkcs1SignatureSchemeOnly) {
+ static const SSLSignatureScheme kSignatureScheme[] = {
+ ssl_sig_rsa_pkcs1_sha256};
+
+ Reset(TlsAgent::kServerRsa, "rsa");
+ client_->SetSignatureSchemes(kSignatureScheme,
+ PR_ARRAY_SIZE(kSignatureScheme));
+ client_->SetupClientAuth(std::get<2>(GetParam()), true);
+ client_->StartConnect();
+ client_->Handshake();
+ EXPECT_EQ(TlsAgent::STATE_ERROR, client_->state());
+ client_->CheckErrorCode(SSL_ERROR_NO_SUPPORTED_SIGNATURE_ALGORITHM);
+}
+
+// Though the client has a usable signature scheme, when a certificate is
+// requested, it can't produce one.
+TEST_P(TlsConnectClientAuth13, ClientAuthPkcs1AndEcdsaScheme) {
+ static const SSLSignatureScheme kSignatureScheme[] = {
+ ssl_sig_rsa_pkcs1_sha256, ssl_sig_ecdsa_secp256r1_sha256};
+
+ Reset(TlsAgent::kServerRsa, "rsa");
+ client_->SetSignatureSchemes(kSignatureScheme,
+ PR_ARRAY_SIZE(kSignatureScheme));
+ client_->SetupClientAuth(std::get<2>(GetParam()), true);
+ server_->RequestClientAuth(true);
+
+ ConnectExpectAlert(server_, kTlsAlertHandshakeFailure);
+ server_->CheckErrorCode(SSL_ERROR_UNSUPPORTED_SIGNATURE_ALGORITHM);
+ client_->CheckErrorCode(SSL_ERROR_NO_CYPHER_OVERLAP);
+}
+
+class TlsZeroCertificateRequestSigAlgsFilter : public TlsHandshakeFilter {
+ public:
+ TlsZeroCertificateRequestSigAlgsFilter(const std::shared_ptr<TlsAgent>& a)
+ : TlsHandshakeFilter(a, {kTlsHandshakeCertificateRequest}) {}
+ virtual PacketFilter::Action FilterHandshake(
+ const TlsHandshakeFilter::HandshakeHeader& header,
+ const DataBuffer& input, DataBuffer* output) {
+ TlsParser parser(input);
+ std::cerr << "Zeroing CertReq.supported_signature_algorithms" << std::endl;
+
+ DataBuffer cert_types;
+ if (!parser.ReadVariable(&cert_types, 1)) {
+ ADD_FAILURE();
+ return KEEP;
+ }
+
+ if (!parser.SkipVariable(2)) {
+ ADD_FAILURE();
+ return KEEP;
+ }
+
+ DataBuffer cas;
+ if (!parser.ReadVariable(&cas, 2)) {
+ ADD_FAILURE();
+ return KEEP;
+ }
+
+ size_t idx = 0;
+
+ // Write certificate types.
+ idx = output->Write(idx, cert_types.len(), 1);
+ idx = output->Write(idx, cert_types);
+
+ // Write zero signature algorithms.
+ idx = output->Write(idx, 0U, 2);
+
+ // Write certificate authorities.
+ idx = output->Write(idx, cas.len(), 2);
+ idx = output->Write(idx, cas);
+
+ return CHANGE;
+ }
+};
+
+// Check that we send an alert when the server doesn't provide any
+// supported_signature_algorithms in the CertificateRequest message.
+TEST_P(TlsConnectClientAuth12, ClientAuthNoSigAlgs) {
+ EnsureTlsSetup();
+ MakeTlsFilter<TlsZeroCertificateRequestSigAlgsFilter>(server_);
+ auto capture_cert_verify = MakeTlsFilter<TlsHandshakeRecorder>(
+ client_, kTlsHandshakeCertificateVerify);
+ client_->SetupClientAuth(std::get<2>(GetParam()), true);
+ server_->RequestClientAuth(true);
+
+ ConnectExpectAlert(client_, kTlsAlertHandshakeFailure);
+
+ server_->CheckErrorCode(SSL_ERROR_HANDSHAKE_FAILURE_ALERT);
+ client_->CheckErrorCode(SSL_ERROR_UNSUPPORTED_SIGNATURE_ALGORITHM);
+}
+
+static SECStatus GetEcClientAuthDataHook(void* self, PRFileDesc* fd,
+ CERTDistNames* caNames,
+ CERTCertificate** clientCert,
+ SECKEYPrivateKey** clientKey) {
+ ScopedCERTCertificate cert;
+ ScopedSECKEYPrivateKey priv;
+ // use a different certificate than TlsAgent::kClient
+ if (!TlsAgent::LoadCertificate(TlsAgent::kServerEcdsa256, &cert, &priv)) {
+ return SECFailure;
+ }
+
+ *clientCert = cert.release();
+ *clientKey = priv.release();
+ return SECSuccess;
+}
+
+TEST_P(TlsConnectClientAuth12Plus, ClientAuthDisjointSchemes) {
+ EnsureTlsSetup();
+ client_->SetupClientAuth(std::get<2>(GetParam()), true);
+ server_->RequestClientAuth(true);
+
+ SSLSignatureScheme server_scheme = ssl_sig_rsa_pss_rsae_sha256;
+ std::vector<SSLSignatureScheme> client_schemes{
+ ssl_sig_rsa_pss_rsae_sha256, ssl_sig_ecdsa_secp256r1_sha256};
+ SECStatus rv =
+ SSL_SignatureSchemePrefSet(server_->ssl_fd(), &server_scheme, 1);
+ EXPECT_EQ(SECSuccess, rv);
+ rv = SSL_SignatureSchemePrefSet(
+ client_->ssl_fd(), client_schemes.data(),
+ static_cast<unsigned int>(client_schemes.size()));
+ EXPECT_EQ(SECSuccess, rv);
+
+ // Select an EC cert that's incompatible with server schemes.
+ EXPECT_EQ(SECSuccess,
+ SSL_GetClientAuthDataHook(client_->ssl_fd(),
+ GetEcClientAuthDataHook, nullptr));
+
+ StartConnect();
+ client_->Handshake(); // CH
+ server_->Handshake(); // SH
+ client_->Handshake();
+ if (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) {
+ ASSERT_EQ(TlsAgent::STATE_CONNECTED, client_->state());
+ ExpectAlert(server_, kTlsAlertCertificateRequired);
+ server_->Handshake(); // Alert
+ server_->CheckErrorCode(SSL_ERROR_NO_CERTIFICATE);
+ client_->Handshake(); // Receive Alert
+ client_->CheckErrorCode(SSL_ERROR_RX_CERTIFICATE_REQUIRED_ALERT);
+ } else {
+ ASSERT_EQ(TlsAgent::STATE_CONNECTING, client_->state());
+ ExpectAlert(server_, kTlsAlertBadCertificate);
+ server_->Handshake(); // Alert
+ server_->CheckErrorCode(SSL_ERROR_NO_CERTIFICATE);
+ client_->Handshake(); // Receive Alert
+ client_->CheckErrorCode(SSL_ERROR_BAD_CERT_ALERT);
+ }
+}
+
+TEST_P(TlsConnectClientAuthStream13, PostHandshakeAuthDisjointSchemes) {
+ EnsureTlsSetup();
+ SSLSignatureScheme server_scheme = ssl_sig_rsa_pss_rsae_sha256;
+ std::vector<SSLSignatureScheme> client_schemes{
+ ssl_sig_rsa_pss_rsae_sha256, ssl_sig_ecdsa_secp256r1_sha256};
+ SECStatus rv =
+ SSL_SignatureSchemePrefSet(server_->ssl_fd(), &server_scheme, 1);
+ EXPECT_EQ(SECSuccess, rv);
+ rv = SSL_SignatureSchemePrefSet(
+ client_->ssl_fd(), client_schemes.data(),
+ static_cast<unsigned int>(client_schemes.size()));
+ EXPECT_EQ(SECSuccess, rv);
+
+ client_->SetupClientAuth(std::get<2>(GetParam()), true);
+ client_->SetOption(SSL_ENABLE_POST_HANDSHAKE_AUTH, PR_TRUE);
+
+ // Select an EC cert that's incompatible with server schemes.
+ EXPECT_EQ(SECSuccess,
+ SSL_GetClientAuthDataHook(client_->ssl_fd(),
+ GetEcClientAuthDataHook, nullptr));
+
+ Connect();
+
+ // Send CertificateRequest.
+ EXPECT_EQ(SECSuccess, SSL_SendCertificateRequest(server_->ssl_fd()))
+ << "Unexpected error: " << PORT_ErrorToName(PORT_GetError());
+
+ // Need to do a round-trip so that the post-handshake message is
+ // handled on both client and server.
+ server_->SendData(50);
+ client_->ReadBytes(50);
+ client_->SendData(50);
+ server_->ReadBytes(50);
+
+ ScopedCERTCertificate cert1(SSL_PeerCertificate(server_->ssl_fd()));
+ ASSERT_EQ(nullptr, cert1.get());
+ ScopedCERTCertificate cert2(SSL_LocalCertificate(client_->ssl_fd()));
+ ASSERT_EQ(nullptr, cert2.get());
+}
+
+static const SSLSignatureScheme kSignatureSchemeEcdsaSha384[] = {
+ ssl_sig_ecdsa_secp384r1_sha384};
+static const SSLSignatureScheme kSignatureSchemeEcdsaSha256[] = {
+ ssl_sig_ecdsa_secp256r1_sha256};
+static const SSLSignatureScheme kSignatureSchemeRsaSha384[] = {
+ ssl_sig_rsa_pkcs1_sha384};
+static const SSLSignatureScheme kSignatureSchemeRsaSha256[] = {
+ ssl_sig_rsa_pkcs1_sha256};
+
+static SSLNamedGroup NamedGroupForEcdsa384(uint16_t version) {
+ // NSS tries to match the group size to the symmetric cipher. In TLS 1.1 and
+ // 1.0, TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA is the highest priority suite, so
+ // we use P-384. With TLS 1.2 on we pick AES-128 GCM so use x25519.
+ if (version <= SSL_LIBRARY_VERSION_TLS_1_1) {
+ return ssl_grp_ec_secp384r1;
+ }
+ return ssl_grp_ec_curve25519;
+}
+
+// When signature algorithms match up, this should connect successfully; even
+// for TLS 1.1 and 1.0, where they should be ignored.
+TEST_P(TlsConnectGeneric, SignatureAlgorithmServerAuth) {
+ Reset(TlsAgent::kServerEcdsa384);
+ client_->SetSignatureSchemes(kSignatureSchemeEcdsaSha384,
+ PR_ARRAY_SIZE(kSignatureSchemeEcdsaSha384));
+ server_->SetSignatureSchemes(kSignatureSchemeEcdsaSha384,
+ PR_ARRAY_SIZE(kSignatureSchemeEcdsaSha384));
+ Connect();
+ CheckKeys(ssl_kea_ecdh, NamedGroupForEcdsa384(version_), ssl_auth_ecdsa,
+ ssl_sig_ecdsa_secp384r1_sha384);
+}
+
+// Here the client picks a single option, which should work in all versions.
+// Defaults on the server include the first option.
+TEST_P(TlsConnectGeneric, SignatureAlgorithmClientOnly) {
+ const SSLSignatureAndHashAlg clientAlgorithms[] = {
+ {ssl_hash_sha384, ssl_sign_ecdsa},
+ {ssl_hash_sha384, ssl_sign_rsa}, // supported but unusable
+ {ssl_hash_md5, ssl_sign_ecdsa} // unsupported and ignored
+ };
+ Reset(TlsAgent::kServerEcdsa384);
+ EnsureTlsSetup();
+ // Use the old API for this function.
+ EXPECT_EQ(SECSuccess,
+ SSL_SignaturePrefSet(client_->ssl_fd(), clientAlgorithms,
+ PR_ARRAY_SIZE(clientAlgorithms)));
+ Connect();
+ CheckKeys(ssl_kea_ecdh, NamedGroupForEcdsa384(version_), ssl_auth_ecdsa,
+ ssl_sig_ecdsa_secp384r1_sha384);
+}
+
+// Here the server picks a single option, which should work in all versions.
+// Defaults on the client include the provided option.
+TEST_P(TlsConnectGeneric, SignatureAlgorithmServerOnly) {
+ Reset(TlsAgent::kServerEcdsa384);
+ server_->SetSignatureSchemes(kSignatureSchemeEcdsaSha384,
+ PR_ARRAY_SIZE(kSignatureSchemeEcdsaSha384));
+ Connect();
+ CheckKeys(ssl_kea_ecdh, NamedGroupForEcdsa384(version_), ssl_auth_ecdsa,
+ ssl_sig_ecdsa_secp384r1_sha384);
+}
+
+// In TLS 1.2, curve and hash aren't bound together.
+TEST_P(TlsConnectTls12, SignatureSchemeCurveMismatch) {
+ Reset(TlsAgent::kServerEcdsa256);
+ client_->SetSignatureSchemes(kSignatureSchemeEcdsaSha384,
+ PR_ARRAY_SIZE(kSignatureSchemeEcdsaSha384));
+ Connect();
+}
+
+// In TLS 1.3, curve and hash are coupled.
+TEST_P(TlsConnectTls13, SignatureSchemeCurveMismatch) {
+ Reset(TlsAgent::kServerEcdsa256);
+ client_->SetSignatureSchemes(kSignatureSchemeEcdsaSha384,
+ PR_ARRAY_SIZE(kSignatureSchemeEcdsaSha384));
+ ConnectExpectAlert(server_, kTlsAlertHandshakeFailure);
+ server_->CheckErrorCode(SSL_ERROR_UNSUPPORTED_SIGNATURE_ALGORITHM);
+ client_->CheckErrorCode(SSL_ERROR_NO_CYPHER_OVERLAP);
+}
+
+// Configuring a P-256 cert with only SHA-384 signatures is OK in TLS 1.2.
+TEST_P(TlsConnectTls12, SignatureSchemeBadConfig) {
+ Reset(TlsAgent::kServerEcdsa256); // P-256 cert can't be used.
+ server_->SetSignatureSchemes(kSignatureSchemeEcdsaSha384,
+ PR_ARRAY_SIZE(kSignatureSchemeEcdsaSha384));
+ Connect();
+}
+
+// A P-256 certificate in TLS 1.3 needs a SHA-256 signature scheme.
+TEST_P(TlsConnectTls13, SignatureSchemeBadConfig) {
+ Reset(TlsAgent::kServerEcdsa256); // P-256 cert can't be used.
+ server_->SetSignatureSchemes(kSignatureSchemeEcdsaSha384,
+ PR_ARRAY_SIZE(kSignatureSchemeEcdsaSha384));
+ ConnectExpectAlert(server_, kTlsAlertHandshakeFailure);
+ server_->CheckErrorCode(SSL_ERROR_UNSUPPORTED_SIGNATURE_ALGORITHM);
+ client_->CheckErrorCode(SSL_ERROR_NO_CYPHER_OVERLAP);
+}
+
+// Where there is no overlap on signature schemes, we still connect successfully
+// if we aren't going to use a signature.
+TEST_P(TlsConnectGenericPre13, SignatureAlgorithmNoOverlapStaticRsa) {
+ client_->SetSignatureSchemes(kSignatureSchemeRsaSha384,
+ PR_ARRAY_SIZE(kSignatureSchemeRsaSha384));
+ server_->SetSignatureSchemes(kSignatureSchemeRsaSha256,
+ PR_ARRAY_SIZE(kSignatureSchemeRsaSha256));
+ EnableOnlyStaticRsaCiphers();
+ Connect();
+ CheckKeys(ssl_kea_rsa, ssl_auth_rsa_decrypt);
+}
+
+TEST_P(TlsConnectTls12Plus, SignatureAlgorithmNoOverlapEcdsa) {
+ Reset(TlsAgent::kServerEcdsa256);
+ client_->SetSignatureSchemes(kSignatureSchemeEcdsaSha384,
+ PR_ARRAY_SIZE(kSignatureSchemeEcdsaSha384));
+ server_->SetSignatureSchemes(kSignatureSchemeEcdsaSha256,
+ PR_ARRAY_SIZE(kSignatureSchemeEcdsaSha256));
+ ConnectExpectAlert(server_, kTlsAlertHandshakeFailure);
+ client_->CheckErrorCode(SSL_ERROR_NO_CYPHER_OVERLAP);
+ server_->CheckErrorCode(SSL_ERROR_UNSUPPORTED_SIGNATURE_ALGORITHM);
+}
+
+// Pre 1.2, a mismatch on signature algorithms shouldn't affect anything.
+TEST_P(TlsConnectPre12, SignatureAlgorithmNoOverlapEcdsa) {
+ Reset(TlsAgent::kServerEcdsa256);
+ client_->SetSignatureSchemes(kSignatureSchemeEcdsaSha384,
+ PR_ARRAY_SIZE(kSignatureSchemeEcdsaSha384));
+ server_->SetSignatureSchemes(kSignatureSchemeEcdsaSha256,
+ PR_ARRAY_SIZE(kSignatureSchemeEcdsaSha256));
+ Connect();
+}
+
+// The signature_algorithms extension is mandatory in TLS 1.3.
+TEST_P(TlsConnectTls13, SignatureAlgorithmDrop) {
+ MakeTlsFilter<TlsExtensionDropper>(client_, ssl_signature_algorithms_xtn);
+ ConnectExpectAlert(server_, kTlsAlertMissingExtension);
+ client_->CheckErrorCode(SSL_ERROR_MISSING_EXTENSION_ALERT);
+ server_->CheckErrorCode(SSL_ERROR_MISSING_SIGNATURE_ALGORITHMS_EXTENSION);
+}
+
+// TLS 1.2 has trouble detecting this sort of modification: it uses SHA1 and
+// only fails when the Finished is checked.
+TEST_P(TlsConnectTls12, SignatureAlgorithmDrop) {
+ MakeTlsFilter<TlsExtensionDropper>(client_, ssl_signature_algorithms_xtn);
+ ConnectExpectAlert(server_, kTlsAlertDecryptError);
+ client_->CheckErrorCode(SSL_ERROR_DECRYPT_ERROR_ALERT);
+ server_->CheckErrorCode(SSL_ERROR_BAD_HANDSHAKE_HASH_VALUE);
+}
+
+TEST_P(TlsConnectTls13, UnsupportedSignatureSchemeAlert) {
+ EnsureTlsSetup();
+ auto filter =
+ MakeTlsFilter<TlsReplaceSignatureSchemeFilter>(server_, ssl_sig_none);
+ filter->EnableDecryption();
+
+ ConnectExpectAlert(client_, kTlsAlertIllegalParameter);
+ server_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT);
+ client_->CheckErrorCode(SSL_ERROR_RX_MALFORMED_CERT_VERIFY);
+}
+
+TEST_P(TlsConnectTls13, InconsistentSignatureSchemeAlert) {
+ EnsureTlsSetup();
+
+ // This won't work because we use an RSA cert by default.
+ auto filter = MakeTlsFilter<TlsReplaceSignatureSchemeFilter>(
+ server_, ssl_sig_ecdsa_secp256r1_sha256);
+ filter->EnableDecryption();
+
+ ConnectExpectAlert(client_, kTlsAlertIllegalParameter);
+ server_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT);
+ client_->CheckErrorCode(SSL_ERROR_INCORRECT_SIGNATURE_ALGORITHM);
+}
+
+TEST_P(TlsConnectTls12, RequestClientAuthWithSha384) {
+ server_->SetSignatureSchemes(kSignatureSchemeRsaSha384,
+ PR_ARRAY_SIZE(kSignatureSchemeRsaSha384));
+ server_->RequestClientAuth(false);
+ Connect();
+}
+
+class BeforeFinished : public TlsRecordFilter {
+ private:
+ enum HandshakeState { BEFORE_CCS, AFTER_CCS, DONE };
+
+ public:
+ BeforeFinished(const std::shared_ptr<TlsAgent>& server,
+ const std::shared_ptr<TlsAgent>& client,
+ VoidFunction before_ccs, VoidFunction before_finished)
+ : TlsRecordFilter(server),
+ client_(client),
+ before_ccs_(before_ccs),
+ before_finished_(before_finished),
+ state_(BEFORE_CCS) {}
+
+ protected:
+ virtual PacketFilter::Action FilterRecord(const TlsRecordHeader& header,
+ const DataBuffer& body,
+ DataBuffer* out) {
+ switch (state_) {
+ case BEFORE_CCS:
+ // Awaken when we see the CCS.
+ if (header.content_type() == ssl_ct_change_cipher_spec) {
+ before_ccs_();
+
+ // Write the CCS out as a separate write, so that we can make
+ // progress. Ordinarily, libssl sends the CCS and Finished together,
+ // but that means that they both get processed together.
+ DataBuffer ccs;
+ header.Write(&ccs, 0, body);
+ agent()->SendDirect(ccs);
+ client_.lock()->Handshake();
+ state_ = AFTER_CCS;
+ // Request that the original record be dropped by the filter.
+ return DROP;
+ }
+ break;
+
+ case AFTER_CCS:
+ EXPECT_EQ(ssl_ct_handshake, header.content_type());
+ // This could check that data contains a Finished message, but it's
+ // encrypted, so that's too much extra work.
+
+ before_finished_();
+ state_ = DONE;
+ break;
+
+ case DONE:
+ break;
+ }
+ return KEEP;
+ }
+
+ private:
+ std::weak_ptr<TlsAgent> client_;
+ VoidFunction before_ccs_;
+ VoidFunction before_finished_;
+ HandshakeState state_;
+};
+
+// Running code after the client has started processing the encrypted part of
+// the server's first flight, but before the Finished is processed is very hard
+// in TLS 1.3. These encrypted messages are sent in a single encrypted blob.
+// The following test uses DTLS to make it possible to force the client to
+// process the handshake in pieces.
+//
+// The first encrypted message from the server is dropped, and the MTU is
+// reduced to just below the original message size so that the server sends two
+// messages. The Finished message is then processed separately.
+class BeforeFinished13 : public PacketFilter {
+ private:
+ enum HandshakeState {
+ INIT,
+ BEFORE_FIRST_FRAGMENT,
+ BEFORE_SECOND_FRAGMENT,
+ DONE
+ };
+
+ public:
+ BeforeFinished13(const std::shared_ptr<TlsAgent>& server,
+ const std::shared_ptr<TlsAgent>& client,
+ VoidFunction before_finished)
+ : server_(server),
+ client_(client),
+ before_finished_(before_finished),
+ records_(0) {}
+
+ protected:
+ virtual PacketFilter::Action Filter(const DataBuffer& input,
+ DataBuffer* output) {
+ switch (++records_) {
+ case 1:
+ // Packet 1 is the server's entire first flight. Drop it.
+ EXPECT_EQ(SECSuccess,
+ SSLInt_SetMTU(server_.lock()->ssl_fd(), input.len() - 1));
+ return DROP;
+
+ // Packet 2 is the first part of the server's retransmitted first
+ // flight. Keep that.
+
+ case 3:
+ // Packet 3 is the second part of the server's retransmitted first
+ // flight. Before passing that on, make sure that the client processes
+ // packet 2, then call the before_finished_() callback.
+ client_.lock()->Handshake();
+ before_finished_();
+ break;
+
+ default:
+ break;
+ }
+ return KEEP;
+ }
+
+ private:
+ std::weak_ptr<TlsAgent> server_;
+ std::weak_ptr<TlsAgent> client_;
+ VoidFunction before_finished_;
+ size_t records_;
+};
+
+static SECStatus AuthCompleteBlock(TlsAgent*, PRBool, PRBool) {
+ return SECWouldBlock;
+}
+
+// This test uses an AuthCertificateCallback that blocks. A filter is used to
+// split the server's first flight into two pieces. Before the second piece is
+// processed by the client, SSL_AuthCertificateComplete() is called.
+TEST_F(TlsConnectDatagram13, AuthCompleteBeforeFinished) {
+ client_->SetAuthCertificateCallback(AuthCompleteBlock);
+ MakeTlsFilter<BeforeFinished13>(server_, client_, [this]() {
+ EXPECT_EQ(SECSuccess, SSL_AuthCertificateComplete(client_->ssl_fd(), 0));
+ });
+ Connect();
+}
+
+// This test uses a simple AuthCertificateCallback. Due to the way that the
+// entire server flight is processed, the call to SSL_AuthCertificateComplete
+// will trigger after the Finished message is processed.
+TEST_P(TlsConnectTls13, AuthCompleteAfterFinished) {
+ SetDeferredAuthCertificateCallback(client_, 0); // 0 = success.
+ Connect();
+}
+
+TEST_P(TlsConnectGenericPre13, ClientWriteBetweenCCSAndFinishedWithFalseStart) {
+ client_->EnableFalseStart();
+ MakeTlsFilter<BeforeFinished>(
+ server_, client_,
+ [this]() { EXPECT_TRUE(client_->can_falsestart_hook_called()); },
+ [this]() {
+ // Write something, which used to fail: bug 1235366.
+ client_->SendData(10);
+ });
+
+ Connect();
+ server_->SendData(10);
+ Receive(10);
+}
+
+TEST_P(TlsConnectGenericPre13, AuthCompleteBeforeFinishedWithFalseStart) {
+ client_->EnableFalseStart();
+ client_->SetAuthCertificateCallback(AuthCompleteBlock);
+ MakeTlsFilter<BeforeFinished>(
+ server_, client_,
+ []() {
+ // Do nothing before CCS
+ },
+ [this]() {
+ EXPECT_FALSE(client_->can_falsestart_hook_called());
+ // AuthComplete before Finished still enables false start.
+ EXPECT_EQ(SECSuccess,
+ SSL_AuthCertificateComplete(client_->ssl_fd(), 0));
+ EXPECT_TRUE(client_->can_falsestart_hook_called());
+ client_->SendData(10);
+ });
+
+ Connect();
+ server_->SendData(10);
+ Receive(10);
+}
+
+class EnforceNoActivity : public PacketFilter {
+ protected:
+ PacketFilter::Action Filter(const DataBuffer& input,
+ DataBuffer* output) override {
+ std::cerr << "Unexpected packet: " << input << std::endl;
+ EXPECT_TRUE(false) << "should not send anything";
+ return KEEP;
+ }
+};
+
+// In this test, we want to make sure that the server completes its handshake,
+// but the client does not. Because the AuthCertificate callback blocks and we
+// never call SSL_AuthCertificateComplete(), the client should never report that
+// it has completed the handshake. Manually call Handshake(), alternating sides
+// between client and server, until the desired state is reached.
+TEST_P(TlsConnectGenericPre13, AuthCompleteDelayed) {
+ client_->SetAuthCertificateCallback(AuthCompleteBlock);
+
+ StartConnect();
+ client_->Handshake(); // Send ClientHello
+ server_->Handshake(); // Send ServerHello
+ client_->Handshake(); // Send ClientKeyExchange and Finished
+ server_->Handshake(); // Send Finished
+ // The server should now report that it is connected
+ EXPECT_EQ(TlsAgent::STATE_CONNECTED, server_->state());
+
+ // The client should send nothing from here on.
+ client_->SetFilter(std::make_shared<EnforceNoActivity>());
+ client_->Handshake();
+ EXPECT_EQ(TlsAgent::STATE_CONNECTING, client_->state());
+
+ // This should allow the handshake to complete now.
+ EXPECT_EQ(SECSuccess, SSL_AuthCertificateComplete(client_->ssl_fd(), 0));
+ client_->Handshake(); // Transition to connected
+ EXPECT_EQ(TlsAgent::STATE_CONNECTED, client_->state());
+ EXPECT_EQ(TlsAgent::STATE_CONNECTED, server_->state());
+
+ // Remove filter before closing or the close_notify alert will trigger it.
+ client_->ClearFilter();
+}
+
+TEST_P(TlsConnectGenericPre13, AuthCompleteFailDelayed) {
+ client_->SetAuthCertificateCallback(AuthCompleteBlock);
+
+ StartConnect();
+ client_->Handshake(); // Send ClientHello
+ server_->Handshake(); // Send ServerHello
+ client_->Handshake(); // Send ClientKeyExchange and Finished
+ server_->Handshake(); // Send Finished
+ // The server should now report that it is connected
+ EXPECT_EQ(TlsAgent::STATE_CONNECTED, server_->state());
+
+ // The client should send nothing from here on.
+ client_->SetFilter(std::make_shared<EnforceNoActivity>());
+ client_->Handshake();
+ EXPECT_EQ(TlsAgent::STATE_CONNECTING, client_->state());
+
+ // Report failure.
+ client_->ClearFilter();
+ client_->ExpectSendAlert(kTlsAlertBadCertificate);
+ EXPECT_EQ(SECSuccess, SSL_AuthCertificateComplete(client_->ssl_fd(),
+ SSL_ERROR_BAD_CERTIFICATE));
+ client_->Handshake(); // Fail
+ EXPECT_EQ(TlsAgent::STATE_ERROR, client_->state());
+}
+
+// TLS 1.3 handles a delayed AuthComplete callback differently since the
+// shape of the handshake is different.
+TEST_P(TlsConnectTls13, AuthCompleteDelayed) {
+ client_->SetAuthCertificateCallback(AuthCompleteBlock);
+
+ StartConnect();
+ client_->Handshake(); // Send ClientHello
+ server_->Handshake(); // Send ServerHello
+ EXPECT_EQ(TlsAgent::STATE_CONNECTING, client_->state());
+ EXPECT_EQ(TlsAgent::STATE_CONNECTING, server_->state());
+
+ // The client will send nothing until AuthCertificateComplete is called.
+ client_->SetFilter(std::make_shared<EnforceNoActivity>());
+ client_->Handshake();
+ EXPECT_EQ(TlsAgent::STATE_CONNECTING, client_->state());
+
+ // This should allow the handshake to complete now.
+ client_->ClearFilter();
+ EXPECT_EQ(SECSuccess, SSL_AuthCertificateComplete(client_->ssl_fd(), 0));
+ client_->Handshake(); // Send Finished
+ server_->Handshake(); // Transition to connected and send NewSessionTicket
+ EXPECT_EQ(TlsAgent::STATE_CONNECTED, client_->state());
+ EXPECT_EQ(TlsAgent::STATE_CONNECTED, server_->state());
+}
+
+TEST_P(TlsConnectTls13, AuthCompleteFailDelayed) {
+ client_->SetAuthCertificateCallback(AuthCompleteBlock);
+
+ StartConnect();
+ client_->Handshake(); // Send ClientHello
+ server_->Handshake(); // Send ServerHello
+ EXPECT_EQ(TlsAgent::STATE_CONNECTING, client_->state());
+ EXPECT_EQ(TlsAgent::STATE_CONNECTING, server_->state());
+
+ // The client will send nothing until AuthCertificateComplete is called.
+ client_->SetFilter(std::make_shared<EnforceNoActivity>());
+ client_->Handshake();
+ EXPECT_EQ(TlsAgent::STATE_CONNECTING, client_->state());
+
+ // Report failure.
+ client_->ClearFilter();
+ ExpectAlert(client_, kTlsAlertBadCertificate);
+ EXPECT_EQ(SECSuccess, SSL_AuthCertificateComplete(client_->ssl_fd(),
+ SSL_ERROR_BAD_CERTIFICATE));
+ client_->Handshake(); // This should now fail.
+ server_->Handshake(); // Get the error.
+ EXPECT_EQ(TlsAgent::STATE_ERROR, client_->state());
+ EXPECT_EQ(TlsAgent::STATE_ERROR, server_->state());
+}
+
+static SECStatus AuthCompleteFail(TlsAgent*, PRBool, PRBool) {
+ PORT_SetError(SSL_ERROR_BAD_CERTIFICATE);
+ return SECFailure;
+}
+
+TEST_P(TlsConnectGeneric, AuthFailImmediate) {
+ client_->SetAuthCertificateCallback(AuthCompleteFail);
+
+ StartConnect();
+ ConnectExpectAlert(client_, kTlsAlertBadCertificate);
+ client_->CheckErrorCode(SSL_ERROR_BAD_CERTIFICATE);
+}
+
+static const SSLExtraServerCertData ServerCertDataRsaPkcs1Decrypt = {
+ ssl_auth_rsa_decrypt, nullptr, nullptr, nullptr, nullptr, nullptr};
+static const SSLExtraServerCertData ServerCertDataRsaPkcs1Sign = {
+ ssl_auth_rsa_sign, nullptr, nullptr, nullptr, nullptr, nullptr};
+static const SSLExtraServerCertData ServerCertDataRsaPss = {
+ ssl_auth_rsa_pss, nullptr, nullptr, nullptr, nullptr, nullptr};
+
+// Test RSA cert with usage=[signature, encipherment].
+TEST_F(TlsAgentStreamTestServer, ConfigureCertRsaPkcs1SignAndKEX) {
+ Reset(TlsAgent::kServerRsa);
+
+ PRFileDesc* ssl_fd = agent_->ssl_fd();
+ EXPECT_TRUE(SSLInt_HasCertWithAuthType(ssl_fd, ssl_auth_rsa_decrypt));
+ EXPECT_TRUE(SSLInt_HasCertWithAuthType(ssl_fd, ssl_auth_rsa_sign));
+ EXPECT_FALSE(SSLInt_HasCertWithAuthType(ssl_fd, ssl_auth_rsa_pss));
+
+ // Configuring for only rsa_sign or rsa_decrypt should work.
+ EXPECT_TRUE(agent_->ConfigServerCert(TlsAgent::kServerRsa, false,
+ &ServerCertDataRsaPkcs1Decrypt));
+ EXPECT_TRUE(agent_->ConfigServerCert(TlsAgent::kServerRsa, false,
+ &ServerCertDataRsaPkcs1Sign));
+ EXPECT_FALSE(agent_->ConfigServerCert(TlsAgent::kServerRsa, false,
+ &ServerCertDataRsaPss));
+}
+
+// Test RSA cert with usage=[signature].
+TEST_F(TlsAgentStreamTestServer, ConfigureCertRsaPkcs1Sign) {
+ Reset(TlsAgent::kServerRsaSign);
+
+ PRFileDesc* ssl_fd = agent_->ssl_fd();
+ EXPECT_FALSE(SSLInt_HasCertWithAuthType(ssl_fd, ssl_auth_rsa_decrypt));
+ EXPECT_TRUE(SSLInt_HasCertWithAuthType(ssl_fd, ssl_auth_rsa_sign));
+ EXPECT_FALSE(SSLInt_HasCertWithAuthType(ssl_fd, ssl_auth_rsa_pss));
+
+ // Configuring for only rsa_decrypt should fail.
+ EXPECT_FALSE(agent_->ConfigServerCert(TlsAgent::kServerRsaSign, false,
+ &ServerCertDataRsaPkcs1Decrypt));
+
+ // Configuring for only rsa_sign should work.
+ EXPECT_TRUE(agent_->ConfigServerCert(TlsAgent::kServerRsaSign, false,
+ &ServerCertDataRsaPkcs1Sign));
+ EXPECT_FALSE(agent_->ConfigServerCert(TlsAgent::kServerRsaSign, false,
+ &ServerCertDataRsaPss));
+}
+
+// Test RSA cert with usage=[encipherment].
+TEST_F(TlsAgentStreamTestServer, ConfigureCertRsaPkcs1KEX) {
+ Reset(TlsAgent::kServerRsaDecrypt);
+
+ PRFileDesc* ssl_fd = agent_->ssl_fd();
+ EXPECT_TRUE(SSLInt_HasCertWithAuthType(ssl_fd, ssl_auth_rsa_decrypt));
+ EXPECT_FALSE(SSLInt_HasCertWithAuthType(ssl_fd, ssl_auth_rsa_sign));
+ EXPECT_FALSE(SSLInt_HasCertWithAuthType(ssl_fd, ssl_auth_rsa_pss));
+
+ // Configuring for only rsa_sign or rsa_pss should fail.
+ EXPECT_FALSE(agent_->ConfigServerCert(TlsAgent::kServerRsaDecrypt, false,
+ &ServerCertDataRsaPkcs1Sign));
+ EXPECT_FALSE(agent_->ConfigServerCert(TlsAgent::kServerRsaDecrypt, false,
+ &ServerCertDataRsaPss));
+
+ // Configuring for only rsa_decrypt should work.
+ EXPECT_TRUE(agent_->ConfigServerCert(TlsAgent::kServerRsaDecrypt, false,
+ &ServerCertDataRsaPkcs1Decrypt));
+}
+
+// Test configuring an RSA-PSS cert.
+TEST_F(TlsAgentStreamTestServer, ConfigureCertRsaPss) {
+ Reset(TlsAgent::kServerRsaPss);
+
+ PRFileDesc* ssl_fd = agent_->ssl_fd();
+ EXPECT_FALSE(SSLInt_HasCertWithAuthType(ssl_fd, ssl_auth_rsa_decrypt));
+ EXPECT_FALSE(SSLInt_HasCertWithAuthType(ssl_fd, ssl_auth_rsa_sign));
+ EXPECT_TRUE(SSLInt_HasCertWithAuthType(ssl_fd, ssl_auth_rsa_pss));
+
+ // Configuring for only rsa_sign or rsa_decrypt should fail.
+ EXPECT_FALSE(agent_->ConfigServerCert(TlsAgent::kServerRsaPss, false,
+ &ServerCertDataRsaPkcs1Sign));
+ EXPECT_FALSE(agent_->ConfigServerCert(TlsAgent::kServerRsaPss, false,
+ &ServerCertDataRsaPkcs1Decrypt));
+
+ // Configuring for only rsa_pss should work.
+ EXPECT_TRUE(agent_->ConfigServerCert(TlsAgent::kServerRsaPss, false,
+ &ServerCertDataRsaPss));
+}
+
+// A server should refuse to even start a handshake with
+// misconfigured certificate and signature scheme.
+TEST_P(TlsConnectTls12Plus, MisconfiguredCertScheme) {
+ Reset(TlsAgent::kServerDsa);
+ static const SSLSignatureScheme kScheme[] = {ssl_sig_ecdsa_secp256r1_sha256};
+ server_->SetSignatureSchemes(kScheme, PR_ARRAY_SIZE(kScheme));
+ ConnectExpectAlert(server_, kTlsAlertHandshakeFailure);
+ if (version_ < SSL_LIBRARY_VERSION_TLS_1_3) {
+ // TLS 1.2 disables cipher suites, which leads to a different error.
+ server_->CheckErrorCode(SSL_ERROR_NO_CYPHER_OVERLAP);
+ } else {
+ server_->CheckErrorCode(SSL_ERROR_NO_SUPPORTED_SIGNATURE_ALGORITHM);
+ }
+ client_->CheckErrorCode(SSL_ERROR_NO_CYPHER_OVERLAP);
+}
+
+// In TLS 1.2, disabling an EC group causes ECDSA to be invalid.
+TEST_P(TlsConnectTls12, Tls12CertDisabledGroup) {
+ Reset(TlsAgent::kServerEcdsa256);
+ static const std::vector<SSLNamedGroup> k25519 = {ssl_grp_ec_curve25519};
+ server_->ConfigNamedGroups(k25519);
+ ConnectExpectAlert(server_, kTlsAlertHandshakeFailure);
+ server_->CheckErrorCode(SSL_ERROR_NO_CYPHER_OVERLAP);
+ client_->CheckErrorCode(SSL_ERROR_NO_CYPHER_OVERLAP);
+}
+
+// In TLS 1.3, ECDSA configuration only depends on the signature scheme.
+TEST_P(TlsConnectTls13, Tls13CertDisabledGroup) {
+ Reset(TlsAgent::kServerEcdsa256);
+ static const std::vector<SSLNamedGroup> k25519 = {ssl_grp_ec_curve25519};
+ server_->ConfigNamedGroups(k25519);
+ Connect();
+}
+
+// A client should refuse to even start a handshake with only DSA.
+TEST_P(TlsConnectTls13, Tls13DsaOnlyClient) {
+ static const SSLSignatureScheme kDsa[] = {ssl_sig_dsa_sha256};
+ client_->SetSignatureSchemes(kDsa, PR_ARRAY_SIZE(kDsa));
+ client_->StartConnect();
+ client_->Handshake();
+ EXPECT_EQ(TlsAgent::STATE_ERROR, client_->state());
+ client_->CheckErrorCode(SSL_ERROR_NO_SUPPORTED_SIGNATURE_ALGORITHM);
+}
+
+TEST_P(TlsConnectTls13, Tls13DsaOnlyServer) {
+ Reset(TlsAgent::kServerDsa);
+ static const SSLSignatureScheme kDsa[] = {ssl_sig_dsa_sha256};
+ server_->SetSignatureSchemes(kDsa, PR_ARRAY_SIZE(kDsa));
+ ConnectExpectAlert(server_, kTlsAlertHandshakeFailure);
+ server_->CheckErrorCode(SSL_ERROR_NO_SUPPORTED_SIGNATURE_ALGORITHM);
+ client_->CheckErrorCode(SSL_ERROR_NO_CYPHER_OVERLAP);
+}
+
+TEST_P(TlsConnectTls13, Tls13Pkcs1OnlyClient) {
+ static const SSLSignatureScheme kPkcs1[] = {ssl_sig_rsa_pkcs1_sha256};
+ client_->SetSignatureSchemes(kPkcs1, PR_ARRAY_SIZE(kPkcs1));
+ client_->StartConnect();
+ client_->Handshake();
+ EXPECT_EQ(TlsAgent::STATE_ERROR, client_->state());
+ client_->CheckErrorCode(SSL_ERROR_NO_SUPPORTED_SIGNATURE_ALGORITHM);
+}
+
+TEST_P(TlsConnectTls13, Tls13Pkcs1OnlyServer) {
+ static const SSLSignatureScheme kPkcs1[] = {ssl_sig_rsa_pkcs1_sha256};
+ server_->SetSignatureSchemes(kPkcs1, PR_ARRAY_SIZE(kPkcs1));
+ ConnectExpectAlert(server_, kTlsAlertHandshakeFailure);
+ server_->CheckErrorCode(SSL_ERROR_NO_SUPPORTED_SIGNATURE_ALGORITHM);
+ client_->CheckErrorCode(SSL_ERROR_NO_CYPHER_OVERLAP);
+}
+
+TEST_P(TlsConnectTls13, Tls13DsaIsNotAdvertisedClient) {
+ EnsureTlsSetup();
+ static const SSLSignatureScheme kSchemes[] = {ssl_sig_dsa_sha256,
+ ssl_sig_rsa_pss_rsae_sha256};
+ client_->SetSignatureSchemes(kSchemes, PR_ARRAY_SIZE(kSchemes));
+ auto capture =
+ MakeTlsFilter<TlsExtensionCapture>(client_, ssl_signature_algorithms_xtn);
+ Connect();
+ // We should only have the one signature algorithm advertised.
+ static const uint8_t kExpectedExt[] = {0, 2, ssl_sig_rsa_pss_rsae_sha256 >> 8,
+ ssl_sig_rsa_pss_rsae_sha256 & 0xff};
+ ASSERT_EQ(DataBuffer(kExpectedExt, sizeof(kExpectedExt)),
+ capture->extension());
+}
+
+TEST_P(TlsConnectTls13, Tls13DsaIsNotAdvertisedServer) {
+ EnsureTlsSetup();
+ static const SSLSignatureScheme kSchemes[] = {ssl_sig_dsa_sha256,
+ ssl_sig_rsa_pss_rsae_sha256};
+ server_->SetSignatureSchemes(kSchemes, PR_ARRAY_SIZE(kSchemes));
+ auto capture = MakeTlsFilter<TlsExtensionCapture>(
+ server_, ssl_signature_algorithms_xtn, true);
+ capture->SetHandshakeTypes({kTlsHandshakeCertificateRequest});
+ capture->EnableDecryption();
+ server_->RequestClientAuth(false); // So we get a CertificateRequest.
+ Connect();
+ // We should only have the one signature algorithm advertised.
+ static const uint8_t kExpectedExt[] = {0, 2, ssl_sig_rsa_pss_rsae_sha256 >> 8,
+ ssl_sig_rsa_pss_rsae_sha256 & 0xff};
+ ASSERT_EQ(DataBuffer(kExpectedExt, sizeof(kExpectedExt)),
+ capture->extension());
+}
+
+TEST_P(TlsConnectTls13, Tls13RsaPkcs1IsAdvertisedClient) {
+ EnsureTlsSetup();
+ static const SSLSignatureScheme kSchemes[] = {ssl_sig_rsa_pkcs1_sha256,
+ ssl_sig_rsa_pss_rsae_sha256};
+ client_->SetSignatureSchemes(kSchemes, PR_ARRAY_SIZE(kSchemes));
+ auto capture =
+ MakeTlsFilter<TlsExtensionCapture>(client_, ssl_signature_algorithms_xtn);
+ Connect();
+ // We should only have the one signature algorithm advertised.
+ static const uint8_t kExpectedExt[] = {0,
+ 4,
+ ssl_sig_rsa_pss_rsae_sha256 >> 8,
+ ssl_sig_rsa_pss_rsae_sha256 & 0xff,
+ ssl_sig_rsa_pkcs1_sha256 >> 8,
+ ssl_sig_rsa_pkcs1_sha256 & 0xff};
+ ASSERT_EQ(DataBuffer(kExpectedExt, sizeof(kExpectedExt)),
+ capture->extension());
+}
+
+TEST_P(TlsConnectTls13, Tls13RsaPkcs1IsAdvertisedServer) {
+ EnsureTlsSetup();
+ static const SSLSignatureScheme kSchemes[] = {ssl_sig_rsa_pkcs1_sha256,
+ ssl_sig_rsa_pss_rsae_sha256};
+ server_->SetSignatureSchemes(kSchemes, PR_ARRAY_SIZE(kSchemes));
+ auto capture = MakeTlsFilter<TlsExtensionCapture>(
+ server_, ssl_signature_algorithms_xtn, true);
+ capture->SetHandshakeTypes({kTlsHandshakeCertificateRequest});
+ capture->EnableDecryption();
+ server_->RequestClientAuth(false); // So we get a CertificateRequest.
+ Connect();
+ // We should only have the one signature algorithm advertised.
+ static const uint8_t kExpectedExt[] = {0,
+ 4,
+ ssl_sig_rsa_pss_rsae_sha256 >> 8,
+ ssl_sig_rsa_pss_rsae_sha256 & 0xff,
+ ssl_sig_rsa_pkcs1_sha256 >> 8,
+ ssl_sig_rsa_pkcs1_sha256 & 0xff};
+ ASSERT_EQ(DataBuffer(kExpectedExt, sizeof(kExpectedExt)),
+ capture->extension());
+}
+
+// variant, version, certificate, auth type, signature scheme
+typedef std::tuple<SSLProtocolVariant, uint16_t, std::string, SSLAuthType,
+ SSLSignatureScheme>
+ SignatureSchemeProfile;
+
+class TlsSignatureSchemeConfiguration
+ : public TlsConnectTestBase,
+ public ::testing::WithParamInterface<SignatureSchemeProfile> {
+ public:
+ TlsSignatureSchemeConfiguration()
+ : TlsConnectTestBase(std::get<0>(GetParam()), std::get<1>(GetParam())),
+ certificate_(std::get<2>(GetParam())),
+ auth_type_(std::get<3>(GetParam())),
+ signature_scheme_(std::get<4>(GetParam())) {}
+
+ protected:
+ void TestSignatureSchemeConfig(std::shared_ptr<TlsAgent>& configPeer) {
+ EnsureTlsSetup();
+ configPeer->SetSignatureSchemes(&signature_scheme_, 1);
+ Connect();
+ CheckKeys(ssl_kea_ecdh, ssl_grp_ec_curve25519, auth_type_,
+ signature_scheme_);
+ }
+
+ std::string certificate_;
+ SSLAuthType auth_type_;
+ SSLSignatureScheme signature_scheme_;
+};
+
+TEST_P(TlsSignatureSchemeConfiguration, SignatureSchemeConfigServer) {
+ Reset(certificate_);
+ TestSignatureSchemeConfig(server_);
+}
+
+TEST_P(TlsSignatureSchemeConfiguration, SignatureSchemeConfigClient) {
+ Reset(certificate_);
+ auto capture =
+ MakeTlsFilter<TlsExtensionCapture>(client_, ssl_signature_algorithms_xtn);
+ TestSignatureSchemeConfig(client_);
+
+ const DataBuffer& ext = capture->extension();
+ ASSERT_EQ(2U + 2U, ext.len());
+ uint32_t v = 0;
+ ASSERT_TRUE(ext.Read(0, 2, &v));
+ EXPECT_EQ(2U, v);
+ ASSERT_TRUE(ext.Read(2, 2, &v));
+ EXPECT_EQ(signature_scheme_, static_cast<SSLSignatureScheme>(v));
+}
+
+TEST_P(TlsSignatureSchemeConfiguration, SignatureSchemeConfigBoth) {
+ Reset(certificate_);
+ EnsureTlsSetup();
+ client_->SetSignatureSchemes(&signature_scheme_, 1);
+ server_->SetSignatureSchemes(&signature_scheme_, 1);
+ Connect();
+ CheckKeys(ssl_kea_ecdh, ssl_grp_ec_curve25519, auth_type_, signature_scheme_);
+}
+
+class Tls12CertificateRequestReplacer : public TlsHandshakeFilter {
+ public:
+ Tls12CertificateRequestReplacer(const std::shared_ptr<TlsAgent>& a,
+ SSLSignatureScheme scheme)
+ : TlsHandshakeFilter(a, {kTlsHandshakeCertificateRequest}),
+ scheme_(scheme) {}
+
+ virtual PacketFilter::Action FilterHandshake(const HandshakeHeader& header,
+ const DataBuffer& input,
+ DataBuffer* output) {
+ uint32_t offset = 0;
+
+ if (header.handshake_type() != ssl_hs_certificate_request) {
+ return KEEP;
+ }
+
+ *output = input;
+
+ uint32_t types_len = 0;
+ if (!output->Read(offset, 1, &types_len)) {
+ ADD_FAILURE();
+ return KEEP;
+ }
+ offset += 1 + types_len;
+ uint32_t scheme_len = 0;
+ if (!output->Read(offset, 2, &scheme_len)) {
+ ADD_FAILURE();
+ return KEEP;
+ }
+ DataBuffer schemes;
+ schemes.Write(0, 2, 2);
+ schemes.Write(2, scheme_, 2);
+ output->Write(offset, 2, schemes.len());
+ output->Splice(schemes, offset + 2, scheme_len);
+
+ return CHANGE;
+ }
+
+ private:
+ SSLSignatureScheme scheme_;
+};
+
+//
+// Test how policy interacts with client auth connections
+//
+
+// TLS/DTLS version algorithm policy
+typedef std::tuple<SSLProtocolVariant, uint16_t, SECOidTag, PRUint32>
+ PolicySignatureSchemeProfile;
+
+// Only TLS 1.2 handles client auth schemes inside
+// the certificate request packet, so our failure tests for
+// those kinds of connections only occur here.
+class TlsConnectAuthWithPolicyTls12
+ : public TlsConnectTestBase,
+ public ::testing::WithParamInterface<PolicySignatureSchemeProfile> {
+ public:
+ TlsConnectAuthWithPolicyTls12()
+ : TlsConnectTestBase(std::get<0>(GetParam()), std::get<1>(GetParam())) {
+ alg_ = std::get<2>(GetParam());
+ policy_ = std::get<3>(GetParam());
+ // use the algorithm to select which single scheme to deploy
+ // We use these schemes to force servers sending schemes the client
+ // didn't advertise to make sure the client will still filter these
+ // by policy and detect that no valid schemes were presented, rather
+ // than sending an empty client auth message.
+ switch (alg_) {
+ case SEC_OID_SHA256:
+ case SEC_OID_PKCS1_RSA_PSS_SIGNATURE:
+ scheme_ = ssl_sig_rsa_pss_pss_sha256;
+ break;
+ case SEC_OID_PKCS1_RSA_ENCRYPTION:
+ scheme_ = ssl_sig_rsa_pkcs1_sha256;
+ break;
+ case SEC_OID_ANSIX962_EC_PUBLIC_KEY:
+ scheme_ = ssl_sig_ecdsa_secp256r1_sha256;
+ break;
+ default:
+ ADD_FAILURE() << "need to update algorithm table in "
+ "TlsConnectAuthWithPolicyTls12";
+ scheme_ = ssl_sig_none;
+ break;
+ }
+ }
+
+ protected:
+ SECOidTag alg_;
+ PRUint32 policy_;
+ SSLSignatureScheme scheme_;
+};
+
+// Only TLS 1.2 and greater looks at schemes extensions on client auth
+class TlsConnectAuthWithPolicyTls12Plus
+ : public TlsConnectTestBase,
+ public ::testing::WithParamInterface<PolicySignatureSchemeProfile> {
+ public:
+ TlsConnectAuthWithPolicyTls12Plus()
+ : TlsConnectTestBase(std::get<0>(GetParam()), std::get<1>(GetParam())) {
+ alg_ = std::get<2>(GetParam());
+ policy_ = std::get<3>(GetParam());
+ }
+
+ protected:
+ SECOidTag alg_;
+ PRUint32 policy_;
+};
+
+// make sure we can turn single algorithms off by policy an still connect
+// this is basically testing that we are properly filtering our schemes
+// by policy before communicating them to the server, and that the
+// server is respecting our choices
+TEST_P(TlsConnectAuthWithPolicyTls12Plus, PolicySuccessTest) {
+ // in TLS 1.3, RSA PKCS1 is restricted. If we are also
+ // restricting RSA PSS by policy, we can't use the default
+ // RSA certificate as the server cert, switch to ECDSA
+ if ((version_ >= SSL_LIBRARY_VERSION_TLS_1_3) &&
+ (alg_ == SEC_OID_PKCS1_RSA_PSS_SIGNATURE)) {
+ Reset(TlsAgent::kServerEcdsa256);
+ }
+ client_->SetPolicy(alg_, 0, policy_); // Disable policy for client
+ client_->SetupClientAuth();
+ server_->RequestClientAuth(false);
+ Connect();
+}
+
+// make sure we fail if the server ignores our policy preference and
+// requests client auth with a scheme we don't support
+TEST_P(TlsConnectAuthWithPolicyTls12, PolicyFailureTest) {
+ client_->SetPolicy(alg_, 0, policy_);
+ client_->SetupClientAuth();
+ server_->RequestClientAuth(false);
+ MakeTlsFilter<Tls12CertificateRequestReplacer>(server_, scheme_);
+ ConnectExpectAlert(client_, kTlsAlertHandshakeFailure);
+ client_->CheckErrorCode(SSL_ERROR_UNSUPPORTED_SIGNATURE_ALGORITHM);
+ server_->CheckErrorCode(SSL_ERROR_HANDSHAKE_FAILURE_ALERT);
+}
+
+INSTANTIATE_TEST_SUITE_P(
+ SignaturesWithPolicyFail, TlsConnectAuthWithPolicyTls12,
+ ::testing::Combine(TlsConnectTestBase::kTlsVariantsAll,
+ TlsConnectTestBase::kTlsV12,
+ ::testing::Values(SEC_OID_SHA256,
+ SEC_OID_PKCS1_RSA_PSS_SIGNATURE,
+ SEC_OID_PKCS1_RSA_ENCRYPTION,
+ SEC_OID_ANSIX962_EC_PUBLIC_KEY),
+ ::testing::Values(NSS_USE_ALG_IN_SSL_KX,
+ NSS_USE_ALG_IN_ANY_SIGNATURE)));
+
+INSTANTIATE_TEST_SUITE_P(
+ SignaturesWithPolicySuccess, TlsConnectAuthWithPolicyTls12Plus,
+ ::testing::Combine(TlsConnectTestBase::kTlsVariantsAll,
+ TlsConnectTestBase::kTlsV12Plus,
+ ::testing::Values(SEC_OID_SHA256,
+ SEC_OID_PKCS1_RSA_PSS_SIGNATURE,
+ SEC_OID_PKCS1_RSA_ENCRYPTION,
+ SEC_OID_ANSIX962_EC_PUBLIC_KEY),
+ ::testing::Values(NSS_USE_ALG_IN_SSL_KX,
+ NSS_USE_ALG_IN_ANY_SIGNATURE)));
+
+INSTANTIATE_TEST_SUITE_P(
+ SignatureSchemeRsa, TlsSignatureSchemeConfiguration,
+ ::testing::Combine(
+ TlsConnectTestBase::kTlsVariantsAll, TlsConnectTestBase::kTlsV12,
+ ::testing::Values(TlsAgent::kServerRsaSign),
+ ::testing::Values(ssl_auth_rsa_sign),
+ ::testing::Values(ssl_sig_rsa_pkcs1_sha256, ssl_sig_rsa_pkcs1_sha384,
+ ssl_sig_rsa_pkcs1_sha512, ssl_sig_rsa_pss_rsae_sha256,
+ ssl_sig_rsa_pss_rsae_sha384)));
+// RSASSA-PKCS1-v1_5 is not allowed to be used in TLS 1.3
+INSTANTIATE_TEST_SUITE_P(
+ SignatureSchemeRsaTls13, TlsSignatureSchemeConfiguration,
+ ::testing::Combine(TlsConnectTestBase::kTlsVariantsAll,
+ TlsConnectTestBase::kTlsV13,
+ ::testing::Values(TlsAgent::kServerRsaSign),
+ ::testing::Values(ssl_auth_rsa_sign),
+ ::testing::Values(ssl_sig_rsa_pss_rsae_sha256,
+ ssl_sig_rsa_pss_rsae_sha384)));
+// PSS with SHA-512 needs a bigger key to work.
+INSTANTIATE_TEST_SUITE_P(
+ SignatureSchemeBigRsa, TlsSignatureSchemeConfiguration,
+ ::testing::Combine(TlsConnectTestBase::kTlsVariantsAll,
+ TlsConnectTestBase::kTlsV12Plus,
+ ::testing::Values(TlsAgent::kRsa2048),
+ ::testing::Values(ssl_auth_rsa_sign),
+ ::testing::Values(ssl_sig_rsa_pss_rsae_sha512)));
+INSTANTIATE_TEST_SUITE_P(
+ SignatureSchemeRsaSha1, TlsSignatureSchemeConfiguration,
+ ::testing::Combine(TlsConnectTestBase::kTlsVariantsAll,
+ TlsConnectTestBase::kTlsV12,
+ ::testing::Values(TlsAgent::kServerRsa),
+ ::testing::Values(ssl_auth_rsa_sign),
+ ::testing::Values(ssl_sig_rsa_pkcs1_sha1)));
+INSTANTIATE_TEST_SUITE_P(
+ SignatureSchemeEcdsaP256, TlsSignatureSchemeConfiguration,
+ ::testing::Combine(TlsConnectTestBase::kTlsVariantsAll,
+ TlsConnectTestBase::kTlsV12Plus,
+ ::testing::Values(TlsAgent::kServerEcdsa256),
+ ::testing::Values(ssl_auth_ecdsa),
+ ::testing::Values(ssl_sig_ecdsa_secp256r1_sha256)));
+INSTANTIATE_TEST_SUITE_P(
+ SignatureSchemeEcdsaP384, TlsSignatureSchemeConfiguration,
+ ::testing::Combine(TlsConnectTestBase::kTlsVariantsAll,
+ TlsConnectTestBase::kTlsV12Plus,
+ ::testing::Values(TlsAgent::kServerEcdsa384),
+ ::testing::Values(ssl_auth_ecdsa),
+ ::testing::Values(ssl_sig_ecdsa_secp384r1_sha384)));
+INSTANTIATE_TEST_SUITE_P(
+ SignatureSchemeEcdsaP521, TlsSignatureSchemeConfiguration,
+ ::testing::Combine(TlsConnectTestBase::kTlsVariantsAll,
+ TlsConnectTestBase::kTlsV12Plus,
+ ::testing::Values(TlsAgent::kServerEcdsa521),
+ ::testing::Values(ssl_auth_ecdsa),
+ ::testing::Values(ssl_sig_ecdsa_secp521r1_sha512)));
+INSTANTIATE_TEST_SUITE_P(
+ SignatureSchemeEcdsaSha1, TlsSignatureSchemeConfiguration,
+ ::testing::Combine(TlsConnectTestBase::kTlsVariantsAll,
+ TlsConnectTestBase::kTlsV12,
+ ::testing::Values(TlsAgent::kServerEcdsa256,
+ TlsAgent::kServerEcdsa384),
+ ::testing::Values(ssl_auth_ecdsa),
+ ::testing::Values(ssl_sig_ecdsa_sha1)));
+} // namespace nss_test
diff --git a/security/nss/gtests/ssl_gtest/ssl_cert_ext_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_cert_ext_unittest.cc
new file mode 100644
index 0000000000..26e5fb5028
--- /dev/null
+++ b/security/nss/gtests/ssl_gtest/ssl_cert_ext_unittest.cc
@@ -0,0 +1,246 @@
+/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
+/* vim: set ts=2 et sw=2 tw=80: */
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this file,
+ * You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+#include "ssl.h"
+#include "sslerr.h"
+#include "sslproto.h"
+
+#include <memory>
+
+#include "tls_connect.h"
+#include "tls_filter.h"
+#include "tls_parser.h"
+
+namespace nss_test {
+
+// Tests for Certificate Transparency (RFC 6962)
+// These don't work with TLS 1.3: see bug 1252745.
+
+// Helper class - stores signed certificate timestamps as provided
+// by the relevant callbacks on the client.
+class SignedCertificateTimestampsExtractor {
+ public:
+ SignedCertificateTimestampsExtractor(std::shared_ptr<TlsAgent>& client)
+ : client_(client) {
+ client->SetAuthCertificateCallback(
+ [this](TlsAgent* agent, bool checksig, bool isServer) -> SECStatus {
+ const SECItem* scts = SSL_PeerSignedCertTimestamps(agent->ssl_fd());
+ EXPECT_TRUE(scts);
+ if (!scts) {
+ return SECFailure;
+ }
+ auth_timestamps_.reset(new DataBuffer(scts->data, scts->len));
+ return SECSuccess;
+ });
+ client->SetHandshakeCallback([this](TlsAgent* agent) {
+ const SECItem* scts = SSL_PeerSignedCertTimestamps(agent->ssl_fd());
+ ASSERT_TRUE(scts);
+ handshake_timestamps_.reset(new DataBuffer(scts->data, scts->len));
+ });
+ }
+
+ void assertTimestamps(const DataBuffer& timestamps) {
+ ASSERT_NE(nullptr, auth_timestamps_);
+ EXPECT_EQ(timestamps, *auth_timestamps_);
+
+ ASSERT_NE(nullptr, handshake_timestamps_);
+ EXPECT_EQ(timestamps, *handshake_timestamps_);
+
+ const SECItem* current =
+ SSL_PeerSignedCertTimestamps(client_.lock()->ssl_fd());
+ EXPECT_EQ(timestamps, DataBuffer(current->data, current->len));
+ }
+
+ private:
+ std::weak_ptr<TlsAgent> client_;
+ std::unique_ptr<DataBuffer> auth_timestamps_;
+ std::unique_ptr<DataBuffer> handshake_timestamps_;
+};
+
+static const uint8_t kSctValue[] = {0x01, 0x23, 0x45, 0x67, 0x89};
+static const SECItem kSctItem = {siBuffer, const_cast<uint8_t*>(kSctValue),
+ sizeof(kSctValue)};
+static const DataBuffer kSctBuffer(kSctValue, sizeof(kSctValue));
+static const SSLExtraServerCertData kExtraSctData = {
+ ssl_auth_null, nullptr, nullptr, &kSctItem, nullptr, nullptr};
+
+// Test timestamps extraction during a successful handshake.
+TEST_P(TlsConnectGenericPre13, SignedCertificateTimestampsLegacy) {
+ EnsureTlsSetup();
+
+ // We have to use the legacy API consistently here for configuring certs.
+ // Also, this doesn't work in TLS 1.3 because this only configures the SCT for
+ // RSA decrypt and PKCS#1 signing, not PSS.
+ ScopedCERTCertificate cert;
+ ScopedSECKEYPrivateKey priv;
+ ASSERT_TRUE(TlsAgent::LoadCertificate(TlsAgent::kServerRsa, &cert, &priv));
+ EXPECT_EQ(SECSuccess, SSL_ConfigSecureServerWithCertChain(
+ server_->ssl_fd(), cert.get(), nullptr, priv.get(),
+ ssl_kea_rsa));
+ EXPECT_EQ(SECSuccess, SSL_SetSignedCertTimestamps(server_->ssl_fd(),
+ &kSctItem, ssl_kea_rsa));
+
+ client_->SetOption(SSL_ENABLE_SIGNED_CERT_TIMESTAMPS, PR_TRUE);
+ SignedCertificateTimestampsExtractor timestamps_extractor(client_);
+
+ Connect();
+
+ timestamps_extractor.assertTimestamps(kSctBuffer);
+}
+
+TEST_P(TlsConnectGeneric, SignedCertificateTimestampsSuccess) {
+ EnsureTlsSetup();
+ EXPECT_TRUE(
+ server_->ConfigServerCert(TlsAgent::kServerRsa, true, &kExtraSctData));
+ client_->SetOption(SSL_ENABLE_SIGNED_CERT_TIMESTAMPS, PR_TRUE);
+ SignedCertificateTimestampsExtractor timestamps_extractor(client_);
+
+ Connect();
+
+ timestamps_extractor.assertTimestamps(kSctBuffer);
+}
+
+// Test SSL_PeerSignedCertTimestamps returning zero-length SECItem
+// when the client / the server / both have not enabled the feature.
+TEST_P(TlsConnectGeneric, SignedCertificateTimestampsInactiveClient) {
+ EnsureTlsSetup();
+ EXPECT_TRUE(
+ server_->ConfigServerCert(TlsAgent::kServerRsa, true, &kExtraSctData));
+ SignedCertificateTimestampsExtractor timestamps_extractor(client_);
+
+ Connect();
+ timestamps_extractor.assertTimestamps(DataBuffer());
+}
+
+TEST_P(TlsConnectGeneric, SignedCertificateTimestampsInactiveServer) {
+ EnsureTlsSetup();
+ client_->SetOption(SSL_ENABLE_SIGNED_CERT_TIMESTAMPS, PR_TRUE);
+ SignedCertificateTimestampsExtractor timestamps_extractor(client_);
+
+ Connect();
+ timestamps_extractor.assertTimestamps(DataBuffer());
+}
+
+TEST_P(TlsConnectGeneric, SignedCertificateTimestampsInactiveBoth) {
+ EnsureTlsSetup();
+ SignedCertificateTimestampsExtractor timestamps_extractor(client_);
+
+ Connect();
+ timestamps_extractor.assertTimestamps(DataBuffer());
+}
+
+// Check that the given agent doesn't have an OCSP response for its peer.
+static SECStatus CheckNoOCSP(TlsAgent* agent, bool checksig, bool isServer) {
+ const SECItemArray* ocsp = SSL_PeerStapledOCSPResponses(agent->ssl_fd());
+ EXPECT_TRUE(ocsp);
+ EXPECT_EQ(0U, ocsp->len);
+ return SECSuccess;
+}
+
+static const uint8_t kOcspValue1[] = {1, 2, 3, 4, 5, 6};
+static const uint8_t kOcspValue2[] = {7, 8, 9};
+static const SECItem kOcspItems[] = {
+ {siBuffer, const_cast<uint8_t*>(kOcspValue1), sizeof(kOcspValue1)},
+ {siBuffer, const_cast<uint8_t*>(kOcspValue2), sizeof(kOcspValue2)}};
+static const SECItemArray kOcspResponses = {const_cast<SECItem*>(kOcspItems),
+ PR_ARRAY_SIZE(kOcspItems)};
+const static SSLExtraServerCertData kOcspExtraData = {
+ ssl_auth_null, nullptr, &kOcspResponses, nullptr, nullptr, nullptr};
+
+TEST_P(TlsConnectGeneric, NoOcsp) {
+ EnsureTlsSetup();
+ client_->SetAuthCertificateCallback(CheckNoOCSP);
+ Connect();
+}
+
+// The client doesn't get OCSP stapling unless it asks.
+TEST_P(TlsConnectGeneric, OcspNotRequested) {
+ EnsureTlsSetup();
+ client_->SetAuthCertificateCallback(CheckNoOCSP);
+ EXPECT_TRUE(
+ server_->ConfigServerCert(TlsAgent::kServerRsa, true, &kOcspExtraData));
+ Connect();
+}
+
+// Even if the client asks, the server has nothing unless it is configured.
+TEST_P(TlsConnectGeneric, OcspNotProvided) {
+ EnsureTlsSetup();
+ client_->SetOption(SSL_ENABLE_OCSP_STAPLING, PR_TRUE);
+ client_->SetAuthCertificateCallback(CheckNoOCSP);
+ Connect();
+}
+
+TEST_P(TlsConnectGenericPre13, OcspMangled) {
+ EnsureTlsSetup();
+ client_->SetOption(SSL_ENABLE_OCSP_STAPLING, PR_TRUE);
+ EXPECT_TRUE(
+ server_->ConfigServerCert(TlsAgent::kServerRsa, true, &kOcspExtraData));
+
+ static const uint8_t val[] = {1};
+ auto replacer = MakeTlsFilter<TlsExtensionReplacer>(
+ server_, ssl_cert_status_xtn, DataBuffer(val, sizeof(val)));
+ ConnectExpectAlert(client_, kTlsAlertIllegalParameter);
+ client_->CheckErrorCode(SSL_ERROR_RX_MALFORMED_SERVER_HELLO);
+ server_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT);
+}
+
+TEST_P(TlsConnectGeneric, OcspSuccess) {
+ EnsureTlsSetup();
+ client_->SetOption(SSL_ENABLE_OCSP_STAPLING, PR_TRUE);
+ auto capture_ocsp =
+ MakeTlsFilter<TlsExtensionCapture>(server_, ssl_cert_status_xtn);
+
+ // The value should be available during the AuthCertificateCallback
+ client_->SetAuthCertificateCallback([](TlsAgent* agent, bool checksig,
+ bool isServer) -> SECStatus {
+ const SECItemArray* ocsp = SSL_PeerStapledOCSPResponses(agent->ssl_fd());
+ if (!ocsp) {
+ return SECFailure;
+ }
+ EXPECT_EQ(1U, ocsp->len) << "We only provide the first item";
+ EXPECT_EQ(0, SECITEM_CompareItem(&kOcspItems[0], &ocsp->items[0]));
+ return SECSuccess;
+ });
+ EXPECT_TRUE(
+ server_->ConfigServerCert(TlsAgent::kServerRsa, true, &kOcspExtraData));
+
+ Connect();
+ // In TLS 1.3, the server doesn't provide a visible ServerHello extension.
+ // For earlier versions, the extension is just empty.
+ EXPECT_EQ(0U, capture_ocsp->extension().len());
+}
+
+TEST_P(TlsConnectGeneric, OcspHugeSuccess) {
+ EnsureTlsSetup();
+ client_->SetOption(SSL_ENABLE_OCSP_STAPLING, PR_TRUE);
+
+ uint8_t hugeOcspValue[16385];
+ memset(hugeOcspValue, 0xa1, sizeof(hugeOcspValue));
+ const SECItem hugeOcspItems[] = {
+ {siBuffer, const_cast<uint8_t*>(hugeOcspValue), sizeof(hugeOcspValue)}};
+ const SECItemArray hugeOcspResponses = {const_cast<SECItem*>(hugeOcspItems),
+ PR_ARRAY_SIZE(hugeOcspItems)};
+ const SSLExtraServerCertData hugeOcspExtraData = {
+ ssl_auth_null, nullptr, &hugeOcspResponses, nullptr, nullptr, nullptr};
+
+ // The value should be available during the AuthCertificateCallback
+ client_->SetAuthCertificateCallback([&](TlsAgent* agent, bool checksig,
+ bool isServer) -> SECStatus {
+ const SECItemArray* ocsp = SSL_PeerStapledOCSPResponses(agent->ssl_fd());
+ if (!ocsp) {
+ return SECFailure;
+ }
+ EXPECT_EQ(1U, ocsp->len) << "We only provide the first item";
+ EXPECT_EQ(0, SECITEM_CompareItem(&hugeOcspItems[0], &ocsp->items[0]));
+ return SECSuccess;
+ });
+ EXPECT_TRUE(server_->ConfigServerCert(TlsAgent::kServerRsa, true,
+ &hugeOcspExtraData));
+
+ Connect();
+}
+
+} // namespace nss_test
diff --git a/security/nss/gtests/ssl_gtest/ssl_cipherorder_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_cipherorder_unittest.cc
new file mode 100644
index 0000000000..1e4f817e95
--- /dev/null
+++ b/security/nss/gtests/ssl_gtest/ssl_cipherorder_unittest.cc
@@ -0,0 +1,241 @@
+/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
+/* vim: set ts=2 et sw=2 tw=80: */
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this file,
+ * You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+#include "ssl.h"
+#include "sslerr.h"
+#include "sslproto.h"
+
+#include <memory>
+
+#include "tls_connect.h"
+#include "tls_filter.h"
+
+namespace nss_test {
+
+class TlsCipherOrderTest : public TlsConnectTestBase {
+ protected:
+ virtual void ConfigureTLS() {
+ EnsureTlsSetup();
+ ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3);
+ }
+
+ virtual SECStatus BuildTestLists(std::vector<uint16_t> &cs_initial_list,
+ std::vector<uint16_t> &cs_new_list) {
+ // This is the current CipherSuites order of enabled CipherSuites as defined
+ // in ssl3con.c
+ const PRUint16 *kCipherSuites = SSL_GetImplementedCiphers();
+
+ for (unsigned int i = 0; i < kNumImplementedCiphers; i++) {
+ PRBool pref = PR_FALSE, policy = PR_FALSE;
+ SECStatus rv;
+ rv = SSL_CipherPolicyGet(kCipherSuites[i], &policy);
+ if (rv != SECSuccess) {
+ return SECFailure;
+ }
+ rv = SSL_CipherPrefGetDefault(kCipherSuites[i], &pref);
+ if (rv != SECSuccess) {
+ return SECFailure;
+ }
+ if (pref && policy) {
+ cs_initial_list.push_back(kCipherSuites[i]);
+ }
+ }
+
+ // We will test set function with the first 15 enabled ciphers.
+ const PRUint16 kNumCiphersToSet = 15;
+ for (unsigned int i = 0; i < kNumCiphersToSet; i++) {
+ cs_new_list.push_back(cs_initial_list[i]);
+ }
+ cs_new_list[0] = cs_initial_list[1];
+ cs_new_list[1] = cs_initial_list[0];
+ return SECSuccess;
+ }
+
+ public:
+ TlsCipherOrderTest() : TlsConnectTestBase(ssl_variant_stream, 0) {}
+ const unsigned int kNumImplementedCiphers = SSL_GetNumImplementedCiphers();
+};
+
+const PRUint16 kCSUnsupported[] = {20196, 10101};
+const PRUint16 kNumCSUnsupported = PR_ARRAY_SIZE(kCSUnsupported);
+const PRUint16 kCSEmpty[] = {0};
+
+// Get the active CipherSuites odered as they were compiled
+TEST_F(TlsCipherOrderTest, CipherOrderGet) {
+ std::vector<uint16_t> initial_cs_order;
+ std::vector<uint16_t> new_cs_order;
+ SECStatus result = BuildTestLists(initial_cs_order, new_cs_order);
+ ASSERT_EQ(result, SECSuccess);
+ ConfigureTLS();
+
+ std::vector<uint16_t> current_cs_order(SSL_GetNumImplementedCiphers() + 1);
+ unsigned int current_num_active_cs = 0;
+ result = SSL_CipherSuiteOrderGet(client_->ssl_fd(), current_cs_order.data(),
+ &current_num_active_cs);
+ ASSERT_EQ(result, SECSuccess);
+ ASSERT_EQ(current_num_active_cs, initial_cs_order.size());
+ for (unsigned int i = 0; i < initial_cs_order.size(); i++) {
+ EXPECT_EQ(initial_cs_order[i], current_cs_order[i]);
+ }
+ // Get the chosen CipherSuite during the Handshake without any modification.
+ Connect();
+ SSLChannelInfo channel;
+ result = SSL_GetChannelInfo(client_->ssl_fd(), &channel, sizeof channel);
+ ASSERT_EQ(result, SECSuccess);
+ EXPECT_EQ(channel.cipherSuite, initial_cs_order[0]);
+}
+
+// The "server" used for gtests honor only its ciphersuites order.
+// So, we apply the new set for the server instead of client.
+// This is enough to test the effect of SSL_CipherSuiteOrderSet function.
+TEST_F(TlsCipherOrderTest, CipherOrderSet) {
+ std::vector<uint16_t> initial_cs_order;
+ std::vector<uint16_t> new_cs_order;
+ SECStatus result = BuildTestLists(initial_cs_order, new_cs_order);
+ ASSERT_EQ(result, SECSuccess);
+ ConfigureTLS();
+
+ // change the server_ ciphersuites order.
+ result = SSL_CipherSuiteOrderSet(server_->ssl_fd(), new_cs_order.data(),
+ new_cs_order.size());
+ ASSERT_EQ(result, SECSuccess);
+
+ // The function expect an array. We are using vector for VStudio
+ // compatibility.
+ std::vector<uint16_t> current_cs_order(SSL_GetNumImplementedCiphers() + 1);
+ unsigned int current_num_active_cs = 0;
+ result = SSL_CipherSuiteOrderGet(server_->ssl_fd(), current_cs_order.data(),
+ &current_num_active_cs);
+ ASSERT_EQ(result, SECSuccess);
+ ASSERT_EQ(current_num_active_cs, new_cs_order.size());
+ for (unsigned int i = 0; i < new_cs_order.size(); i++) {
+ ASSERT_EQ(new_cs_order[i], current_cs_order[i]);
+ }
+
+ Connect();
+ SSLChannelInfo channel;
+ // changes in server_ order reflect in client chosen ciphersuite.
+ result = SSL_GetChannelInfo(client_->ssl_fd(), &channel, sizeof channel);
+ ASSERT_EQ(result, SECSuccess);
+ EXPECT_EQ(channel.cipherSuite, new_cs_order[0]);
+}
+
+// Duplicate socket configuration from a model.
+TEST_F(TlsCipherOrderTest, CipherOrderCopySocket) {
+ std::vector<uint16_t> initial_cs_order;
+ std::vector<uint16_t> new_cs_order;
+ SECStatus result = BuildTestLists(initial_cs_order, new_cs_order);
+ ASSERT_EQ(result, SECSuccess);
+ ConfigureTLS();
+
+ // Use the existing sockets for this test.
+ result = SSL_CipherSuiteOrderSet(client_->ssl_fd(), new_cs_order.data(),
+ new_cs_order.size());
+ ASSERT_EQ(result, SECSuccess);
+
+ std::vector<uint16_t> current_cs_order(SSL_GetNumImplementedCiphers() + 1);
+ unsigned int current_num_active_cs = 0;
+ result = SSL_CipherSuiteOrderGet(server_->ssl_fd(), current_cs_order.data(),
+ &current_num_active_cs);
+ ASSERT_EQ(result, SECSuccess);
+ ASSERT_EQ(current_num_active_cs, initial_cs_order.size());
+ for (unsigned int i = 0; i < current_num_active_cs; i++) {
+ ASSERT_EQ(initial_cs_order[i], current_cs_order[i]);
+ }
+
+ // Import/Duplicate configurations from client_ to server_
+ PRFileDesc *rv = SSL_ImportFD(client_->ssl_fd(), server_->ssl_fd());
+ EXPECT_NE(nullptr, rv);
+
+ result = SSL_CipherSuiteOrderGet(server_->ssl_fd(), current_cs_order.data(),
+ &current_num_active_cs);
+ ASSERT_EQ(result, SECSuccess);
+ ASSERT_EQ(current_num_active_cs, new_cs_order.size());
+ for (unsigned int i = 0; i < new_cs_order.size(); i++) {
+ EXPECT_EQ(new_cs_order.data()[i], current_cs_order[i]);
+ }
+}
+
+// If the infomed num of elements is lower than the actual list size, only the
+// first "informed num" elements will be considered. The rest is ignored.
+TEST_F(TlsCipherOrderTest, CipherOrderSetLower) {
+ std::vector<uint16_t> initial_cs_order;
+ std::vector<uint16_t> new_cs_order;
+ SECStatus result = BuildTestLists(initial_cs_order, new_cs_order);
+ ASSERT_EQ(result, SECSuccess);
+ ConfigureTLS();
+
+ result = SSL_CipherSuiteOrderSet(client_->ssl_fd(), new_cs_order.data(),
+ new_cs_order.size() - 1);
+ ASSERT_EQ(result, SECSuccess);
+
+ std::vector<uint16_t> current_cs_order(SSL_GetNumImplementedCiphers() + 1);
+ unsigned int current_num_active_cs = 0;
+ result = SSL_CipherSuiteOrderGet(client_->ssl_fd(), current_cs_order.data(),
+ &current_num_active_cs);
+ ASSERT_EQ(result, SECSuccess);
+ ASSERT_EQ(current_num_active_cs, new_cs_order.size() - 1);
+ for (unsigned int i = 0; i < new_cs_order.size() - 1; i++) {
+ ASSERT_EQ(new_cs_order.data()[i], current_cs_order[i]);
+ }
+}
+
+// Testing Errors Controls
+TEST_F(TlsCipherOrderTest, CipherOrderSetControls) {
+ std::vector<uint16_t> initial_cs_order;
+ std::vector<uint16_t> new_cs_order;
+ SECStatus result = BuildTestLists(initial_cs_order, new_cs_order);
+ ASSERT_EQ(result, SECSuccess);
+ ConfigureTLS();
+
+ // Create a new vector with diplicated entries
+ std::vector<uint16_t> repeated_cs_order(SSL_GetNumImplementedCiphers() + 1);
+ std::copy(initial_cs_order.begin(), initial_cs_order.end(),
+ repeated_cs_order.begin());
+ repeated_cs_order[0] = repeated_cs_order[1];
+
+ // Repeated ciphersuites in the list
+ result = SSL_CipherSuiteOrderSet(client_->ssl_fd(), repeated_cs_order.data(),
+ initial_cs_order.size());
+ EXPECT_EQ(result, SECFailure);
+
+ // Zero size for the sent list
+ result = SSL_CipherSuiteOrderSet(client_->ssl_fd(), new_cs_order.data(), 0);
+ EXPECT_EQ(result, SECFailure);
+
+ // Wrong size, greater than actual
+ result = SSL_CipherSuiteOrderSet(client_->ssl_fd(), new_cs_order.data(),
+ SSL_GetNumImplementedCiphers() + 1);
+ EXPECT_EQ(result, SECFailure);
+
+ // Wrong ciphersuites, not implemented
+ result = SSL_CipherSuiteOrderSet(client_->ssl_fd(), kCSUnsupported,
+ kNumCSUnsupported);
+ EXPECT_EQ(result, SECFailure);
+
+ // Null list
+ result =
+ SSL_CipherSuiteOrderSet(client_->ssl_fd(), nullptr, new_cs_order.size());
+ EXPECT_EQ(result, SECFailure);
+
+ // Empty list
+ result =
+ SSL_CipherSuiteOrderSet(client_->ssl_fd(), kCSEmpty, new_cs_order.size());
+ EXPECT_EQ(result, SECFailure);
+
+ // Confirm that the controls are working, as the current ciphersuites
+ // remained untouched
+ std::vector<uint16_t> current_cs_order(SSL_GetNumImplementedCiphers() + 1);
+ unsigned int current_num_active_cs = 0;
+ result = SSL_CipherSuiteOrderGet(client_->ssl_fd(), current_cs_order.data(),
+ &current_num_active_cs);
+ ASSERT_EQ(result, SECSuccess);
+ ASSERT_EQ(current_num_active_cs, initial_cs_order.size());
+ for (unsigned int i = 0; i < initial_cs_order.size(); i++) {
+ ASSERT_EQ(initial_cs_order[i], current_cs_order[i]);
+ }
+}
+} // namespace nss_test
diff --git a/security/nss/gtests/ssl_gtest/ssl_ciphersuite_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_ciphersuite_unittest.cc
new file mode 100644
index 0000000000..db0618e042
--- /dev/null
+++ b/security/nss/gtests/ssl_gtest/ssl_ciphersuite_unittest.cc
@@ -0,0 +1,531 @@
+/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
+/* vim: set ts=2 et sw=2 tw=80: */
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this file,
+ * You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+#include <functional>
+#include <memory>
+#include "secerr.h"
+#include "ssl.h"
+#include "sslerr.h"
+#include "sslproto.h"
+
+extern "C" {
+// This is not something that should make you happy.
+#include "libssl_internals.h"
+}
+
+#include "gtest_utils.h"
+#include "tls_connect.h"
+#include "tls_parser.h"
+
+namespace nss_test {
+
+// variant, version, cipher suite
+typedef std::tuple<SSLProtocolVariant, uint16_t, uint16_t, SSLNamedGroup,
+ SSLSignatureScheme>
+ CipherSuiteProfile;
+
+class TlsCipherSuiteTestBase : public TlsConnectTestBase {
+ public:
+ TlsCipherSuiteTestBase(SSLProtocolVariant variant, uint16_t version,
+ uint16_t cipher_suite, SSLNamedGroup group,
+ SSLSignatureScheme sig_scheme)
+ : TlsConnectTestBase(variant, version),
+ cipher_suite_(cipher_suite),
+ group_(group),
+ sig_scheme_(sig_scheme),
+ csinfo_({0}) {
+ SECStatus rv =
+ SSL_GetCipherSuiteInfo(cipher_suite_, &csinfo_, sizeof(csinfo_));
+ EXPECT_EQ(SECSuccess, rv);
+ if (rv == SECSuccess) {
+ std::cerr << "Cipher suite: " << csinfo_.cipherSuiteName << std::endl;
+ }
+ auth_type_ = csinfo_.authType;
+ kea_type_ = csinfo_.keaType;
+ }
+
+ protected:
+ void EnableSingleCipher() {
+ EnsureTlsSetup();
+ // It doesn't matter which does this, but the test is better if both do it.
+ client_->EnableSingleCipher(cipher_suite_);
+ server_->EnableSingleCipher(cipher_suite_);
+
+ if (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) {
+ std::vector<SSLNamedGroup> groups = {group_};
+ if (cert_group_ != ssl_grp_none) {
+ groups.push_back(cert_group_);
+ }
+ client_->ConfigNamedGroups(groups);
+ server_->ConfigNamedGroups(groups);
+ kea_type_ = SSLInt_GetKEAType(group_);
+
+ client_->SetSignatureSchemes(&sig_scheme_, 1);
+ server_->SetSignatureSchemes(&sig_scheme_, 1);
+ }
+ }
+
+ virtual void SetupCertificate() {
+ if (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) {
+ switch (sig_scheme_) {
+ case ssl_sig_rsa_pss_rsae_sha256:
+ std::cerr << "Signature scheme: rsa_pss_rsae_sha256" << std::endl;
+ Reset(TlsAgent::kServerRsaSign);
+ auth_type_ = ssl_auth_rsa_sign;
+ break;
+ case ssl_sig_rsa_pss_rsae_sha384:
+ std::cerr << "Signature scheme: rsa_pss_rsae_sha384" << std::endl;
+ Reset(TlsAgent::kServerRsaSign);
+ auth_type_ = ssl_auth_rsa_sign;
+ break;
+ case ssl_sig_rsa_pss_rsae_sha512:
+ // You can't fit SHA-512 PSS in a 1024-bit key.
+ std::cerr << "Signature scheme: rsa_pss_rsae_sha512" << std::endl;
+ Reset(TlsAgent::kRsa2048);
+ auth_type_ = ssl_auth_rsa_sign;
+ break;
+ case ssl_sig_rsa_pss_pss_sha256:
+ std::cerr << "Signature scheme: rsa_pss_pss_sha256" << std::endl;
+ Reset(TlsAgent::kServerRsaPss);
+ auth_type_ = ssl_auth_rsa_pss;
+ break;
+ case ssl_sig_rsa_pss_pss_sha384:
+ std::cerr << "Signature scheme: rsa_pss_pss_sha384" << std::endl;
+ Reset("rsa_pss384");
+ auth_type_ = ssl_auth_rsa_pss;
+ break;
+ case ssl_sig_rsa_pss_pss_sha512:
+ std::cerr << "Signature scheme: rsa_pss_pss_sha512" << std::endl;
+ Reset("rsa_pss512");
+ auth_type_ = ssl_auth_rsa_pss;
+ break;
+ case ssl_sig_ecdsa_secp256r1_sha256:
+ std::cerr << "Signature scheme: ecdsa_secp256r1_sha256" << std::endl;
+ Reset(TlsAgent::kServerEcdsa256);
+ auth_type_ = ssl_auth_ecdsa;
+ cert_group_ = ssl_grp_ec_secp256r1;
+ break;
+ case ssl_sig_ecdsa_secp384r1_sha384:
+ std::cerr << "Signature scheme: ecdsa_secp384r1_sha384" << std::endl;
+ Reset(TlsAgent::kServerEcdsa384);
+ auth_type_ = ssl_auth_ecdsa;
+ cert_group_ = ssl_grp_ec_secp384r1;
+ break;
+ default:
+ ADD_FAILURE() << "Unsupported signature scheme: " << sig_scheme_;
+ break;
+ }
+ } else {
+ switch (csinfo_.authType) {
+ case ssl_auth_rsa_sign:
+ Reset(TlsAgent::kServerRsaSign);
+ break;
+ case ssl_auth_rsa_decrypt:
+ Reset(TlsAgent::kServerRsaDecrypt);
+ break;
+ case ssl_auth_ecdsa:
+ Reset(TlsAgent::kServerEcdsa256);
+ cert_group_ = ssl_grp_ec_secp256r1;
+ break;
+ case ssl_auth_ecdh_ecdsa:
+ Reset(TlsAgent::kServerEcdhEcdsa);
+ cert_group_ = ssl_grp_ec_secp256r1;
+ break;
+ case ssl_auth_ecdh_rsa:
+ Reset(TlsAgent::kServerEcdhRsa);
+ break;
+ case ssl_auth_dsa:
+ Reset(TlsAgent::kServerDsa);
+ break;
+ default:
+ ASSERT_TRUE(false) << "Unsupported cipher suite: " << cipher_suite_;
+ break;
+ }
+ }
+ }
+
+ void ConnectAndCheckCipherSuite() {
+ Connect();
+ SendReceive();
+
+ // Check that we used the right cipher suite, auth type and kea type.
+ uint16_t actual = TLS_NULL_WITH_NULL_NULL;
+ EXPECT_TRUE(client_->cipher_suite(&actual));
+ EXPECT_EQ(cipher_suite_, actual);
+ EXPECT_TRUE(server_->cipher_suite(&actual));
+ EXPECT_EQ(cipher_suite_, actual);
+ SSLAuthType auth = ssl_auth_size;
+ EXPECT_TRUE(client_->auth_type(&auth));
+ EXPECT_EQ(auth_type_, auth);
+ EXPECT_TRUE(server_->auth_type(&auth));
+ EXPECT_EQ(auth_type_, auth);
+ SSLKEAType kea = ssl_kea_size;
+ EXPECT_TRUE(client_->kea_type(&kea));
+ EXPECT_EQ(kea_type_, kea);
+ EXPECT_TRUE(server_->kea_type(&kea));
+ EXPECT_EQ(kea_type_, kea);
+ }
+
+ // Get the expected limit on the number of records that can be sent for the
+ // cipher suite.
+ uint64_t record_limit() const {
+ switch (csinfo_.symCipher) {
+ case ssl_calg_rc4:
+ case ssl_calg_3des:
+ return 1ULL << 20;
+ case ssl_calg_aes:
+ case ssl_calg_aes_gcm:
+ return 0x5aULL << 28;
+ case ssl_calg_null:
+ case ssl_calg_chacha20:
+ return (1ULL << 48) - 1;
+ case ssl_calg_rc2:
+ case ssl_calg_des:
+ case ssl_calg_idea:
+ case ssl_calg_fortezza:
+ case ssl_calg_camellia:
+ case ssl_calg_seed:
+ break;
+ }
+ ADD_FAILURE() << "No limit for " << csinfo_.cipherSuiteName;
+ return 0;
+ }
+
+ uint64_t last_safe_write() const {
+ uint64_t limit = record_limit() - 1;
+ if (version_ < SSL_LIBRARY_VERSION_TLS_1_1 &&
+ (csinfo_.symCipher == ssl_calg_3des ||
+ csinfo_.symCipher == ssl_calg_aes)) {
+ // 1/n-1 record splitting needs space for two records.
+ limit--;
+ }
+ return limit;
+ }
+
+ protected:
+ uint16_t cipher_suite_;
+ SSLAuthType auth_type_;
+ SSLKEAType kea_type_;
+ SSLNamedGroup group_;
+ SSLNamedGroup cert_group_ = ssl_grp_none;
+ SSLSignatureScheme sig_scheme_;
+ SSLCipherSuiteInfo csinfo_;
+};
+
+class TlsCipherSuiteTest
+ : public TlsCipherSuiteTestBase,
+ public ::testing::WithParamInterface<CipherSuiteProfile> {
+ public:
+ TlsCipherSuiteTest()
+ : TlsCipherSuiteTestBase(std::get<0>(GetParam()), std::get<1>(GetParam()),
+ std::get<2>(GetParam()), std::get<3>(GetParam()),
+ std::get<4>(GetParam())) {}
+
+ protected:
+ bool SkipIfCipherSuiteIsDSA() {
+ bool isDSA = csinfo_.authType == ssl_auth_dsa;
+ if (isDSA) {
+ std::cerr << "Skipping DSA suite: " << csinfo_.cipherSuiteName
+ << std::endl;
+ }
+ return isDSA;
+ }
+};
+
+TEST_P(TlsCipherSuiteTest, SingleCipherSuite) {
+ SetupCertificate();
+ EnableSingleCipher();
+ ConnectAndCheckCipherSuite();
+}
+
+TEST_P(TlsCipherSuiteTest, ResumeCipherSuite) {
+ if (SkipIfCipherSuiteIsDSA()) {
+ GTEST_SKIP() << "Tickets not supported with DSA (bug 1174677).";
+ }
+
+ SetupCertificate(); // This is only needed once.
+
+ ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET);
+ EnableSingleCipher();
+
+ ConnectAndCheckCipherSuite();
+
+ Reset();
+ ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET);
+ EnableSingleCipher();
+ ExpectResumption(RESUME_TICKET);
+ ConnectAndCheckCipherSuite();
+}
+
+TEST_P(TlsCipherSuiteTest, ReadLimit) {
+ SetupCertificate();
+ EnableSingleCipher();
+ TlsSendCipherSpecCapturer capturer(client_);
+ ConnectAndCheckCipherSuite();
+ if (version_ < SSL_LIBRARY_VERSION_TLS_1_3) {
+ uint64_t last = last_safe_write();
+ EXPECT_EQ(SECSuccess, SSLInt_AdvanceWriteSeqNum(client_->ssl_fd(), last));
+ EXPECT_EQ(SECSuccess, SSLInt_AdvanceReadSeqNum(server_->ssl_fd(), last));
+
+ client_->SendData(10, 10);
+ server_->ReadBytes(); // This should be OK.
+ server_->ReadBytes(); // Read twice to flush any 1,N-1 record splitting.
+ } else {
+ // In TLS 1.3, reading or writing triggers a KeyUpdate. That would mean
+ // that the sequence numbers would reset and we wouldn't hit the limit. So
+ // move the sequence number to the limit directly and don't test sending and
+ // receiving just before the limit.
+ uint64_t last = record_limit();
+ EXPECT_EQ(SECSuccess, SSLInt_AdvanceReadSeqNum(server_->ssl_fd(), last));
+ }
+
+ // The payload needs to be big enough to pass for encrypted. The code checks
+ // the limit before it tries to decrypt.
+ static const uint8_t payload[32] = {6};
+ DataBuffer record;
+ uint64_t epoch;
+ if (variant_ == ssl_variant_datagram) {
+ if (version_ == SSL_LIBRARY_VERSION_TLS_1_3) {
+ epoch = 3; // Application traffic keys.
+ } else {
+ epoch = 1;
+ }
+ } else {
+ epoch = 0;
+ }
+
+ uint64_t seqno = (epoch << 48) | record_limit();
+
+ // DTLS 1.3 masks the sequence number
+ if (variant_ == ssl_variant_datagram &&
+ version_ >= SSL_LIBRARY_VERSION_TLS_1_3) {
+ auto spec = capturer.spec(1);
+ ASSERT_NE(nullptr, spec.get());
+ ASSERT_EQ(3, spec->epoch());
+
+ DataBuffer pt, ct;
+ uint8_t dtls13_ctype = kCtDtlsCiphertext | kCtDtlsCiphertext16bSeqno |
+ kCtDtlsCiphertextLengthPresent;
+ TlsRecordHeader hdr(variant_, version_, dtls13_ctype, seqno);
+ pt.Assign(payload, sizeof(payload));
+ TlsRecordHeader out_hdr;
+ spec->Protect(hdr, pt, &ct, &out_hdr);
+
+ auto rv = out_hdr.Write(&record, 0, ct);
+ EXPECT_EQ(out_hdr.header_length() + ct.len(), rv);
+ } else {
+ TlsAgentTestBase::MakeRecord(variant_, ssl_ct_application_data, version_,
+ payload, sizeof(payload), &record, seqno);
+ }
+
+ client_->SendDirect(record);
+ server_->ExpectReadWriteError();
+ server_->ReadBytes();
+ EXPECT_EQ(SSL_ERROR_TOO_MANY_RECORDS, server_->error_code());
+}
+
+TEST_P(TlsCipherSuiteTest, WriteLimit) {
+ // This asserts in TLS 1.3 because we expect an automatic update.
+ if (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) {
+ GTEST_SKIP();
+ }
+ SetupCertificate();
+ EnableSingleCipher();
+ ConnectAndCheckCipherSuite();
+ EXPECT_EQ(SECSuccess,
+ SSLInt_AdvanceWriteSeqNum(client_->ssl_fd(), last_safe_write()));
+ client_->SendData(10, 10);
+ client_->ExpectReadWriteError();
+ client_->SendData(10, 10);
+ EXPECT_EQ(SSL_ERROR_TOO_MANY_RECORDS, client_->error_code());
+}
+
+// This awful macro makes the test instantiations easier to read.
+#define INSTANTIATE_CIPHER_TEST_P(name, modes, versions, groups, sigalgs, ...) \
+ static const uint16_t k##name##CiphersArr[] = {__VA_ARGS__}; \
+ static const ::testing::internal::ParamGenerator<uint16_t> \
+ k##name##Ciphers = ::testing::ValuesIn(k##name##CiphersArr); \
+ INSTANTIATE_TEST_SUITE_P( \
+ CipherSuite##name, TlsCipherSuiteTest, \
+ ::testing::Combine(TlsConnectTestBase::kTlsVariants##modes, \
+ TlsConnectTestBase::kTls##versions, k##name##Ciphers, \
+ groups, sigalgs));
+
+static const auto kDummyNamedGroupParams = ::testing::Values(ssl_grp_none);
+static const auto kDummySignatureSchemesParams =
+ ::testing::Values(ssl_sig_none);
+
+static SSLSignatureScheme kSignatureSchemesParamsArr[] = {
+ ssl_sig_rsa_pkcs1_sha256, ssl_sig_rsa_pkcs1_sha384,
+ ssl_sig_rsa_pkcs1_sha512, ssl_sig_ecdsa_secp256r1_sha256,
+ ssl_sig_ecdsa_secp384r1_sha384, ssl_sig_rsa_pss_rsae_sha256,
+ ssl_sig_rsa_pss_rsae_sha384, ssl_sig_rsa_pss_rsae_sha512,
+ ssl_sig_rsa_pss_pss_sha256, ssl_sig_rsa_pss_pss_sha384,
+ ssl_sig_rsa_pss_pss_sha512};
+
+static SSLSignatureScheme kSignatureSchemesParamsArrTls13[] = {
+ ssl_sig_ecdsa_secp256r1_sha256, ssl_sig_ecdsa_secp384r1_sha384,
+ ssl_sig_rsa_pss_rsae_sha256, ssl_sig_rsa_pss_rsae_sha384,
+ ssl_sig_rsa_pss_rsae_sha512, ssl_sig_rsa_pss_pss_sha256,
+ ssl_sig_rsa_pss_pss_sha384, ssl_sig_rsa_pss_pss_sha512};
+
+INSTANTIATE_CIPHER_TEST_P(RC4, Stream, V10ToV12, kDummyNamedGroupParams,
+ kDummySignatureSchemesParams,
+ TLS_RSA_WITH_RC4_128_SHA,
+ TLS_ECDH_ECDSA_WITH_RC4_128_SHA,
+ TLS_ECDHE_ECDSA_WITH_RC4_128_SHA,
+ TLS_ECDH_RSA_WITH_RC4_128_SHA,
+ TLS_ECDHE_RSA_WITH_RC4_128_SHA);
+INSTANTIATE_CIPHER_TEST_P(AEAD12, All, V12, kDummyNamedGroupParams,
+ kDummySignatureSchemesParams,
+ TLS_RSA_WITH_AES_128_GCM_SHA256,
+ TLS_RSA_WITH_AES_256_GCM_SHA384,
+ TLS_DHE_DSS_WITH_AES_128_GCM_SHA256,
+ TLS_DHE_DSS_WITH_AES_256_GCM_SHA384,
+ TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA384,
+ TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA384);
+INSTANTIATE_CIPHER_TEST_P(AEAD, All, V12, kDummyNamedGroupParams,
+ kDummySignatureSchemesParams,
+ TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
+ TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384,
+ TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
+ TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384,
+ TLS_DHE_RSA_WITH_AES_128_GCM_SHA256,
+ TLS_DHE_RSA_WITH_AES_256_GCM_SHA384,
+ TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256,
+ TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256,
+ TLS_DHE_RSA_WITH_CHACHA20_POLY1305_SHA256);
+INSTANTIATE_CIPHER_TEST_P(
+ CBC12, All, V12, kDummyNamedGroupParams, kDummySignatureSchemesParams,
+ TLS_DHE_RSA_WITH_AES_256_CBC_SHA256, TLS_RSA_WITH_AES_256_CBC_SHA256,
+ TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256,
+ TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256, TLS_DHE_RSA_WITH_AES_128_CBC_SHA256,
+ TLS_RSA_WITH_AES_128_CBC_SHA256, TLS_DHE_DSS_WITH_AES_128_CBC_SHA256,
+ TLS_DHE_DSS_WITH_AES_256_CBC_SHA256);
+INSTANTIATE_CIPHER_TEST_P(
+ CBCStream, Stream, V10ToV12, kDummyNamedGroupParams,
+ kDummySignatureSchemesParams, TLS_ECDH_ECDSA_WITH_NULL_SHA,
+ TLS_ECDH_ECDSA_WITH_3DES_EDE_CBC_SHA, TLS_ECDH_ECDSA_WITH_AES_128_CBC_SHA,
+ TLS_ECDH_ECDSA_WITH_AES_256_CBC_SHA, TLS_ECDHE_ECDSA_WITH_NULL_SHA,
+ TLS_ECDHE_ECDSA_WITH_3DES_EDE_CBC_SHA, TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA,
+ TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA, TLS_ECDH_RSA_WITH_NULL_SHA,
+ TLS_ECDH_RSA_WITH_3DES_EDE_CBC_SHA, TLS_ECDH_RSA_WITH_AES_128_CBC_SHA,
+ TLS_ECDH_RSA_WITH_AES_256_CBC_SHA, TLS_ECDHE_RSA_WITH_NULL_SHA,
+ TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA, TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA,
+ TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA);
+INSTANTIATE_CIPHER_TEST_P(
+ CBCDatagram, Datagram, V11V12, kDummyNamedGroupParams,
+ kDummySignatureSchemesParams, TLS_ECDH_ECDSA_WITH_3DES_EDE_CBC_SHA,
+ TLS_ECDH_ECDSA_WITH_AES_128_CBC_SHA, TLS_ECDH_ECDSA_WITH_AES_256_CBC_SHA,
+ TLS_ECDHE_ECDSA_WITH_3DES_EDE_CBC_SHA, TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA,
+ TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA, TLS_ECDH_RSA_WITH_3DES_EDE_CBC_SHA,
+ TLS_ECDH_RSA_WITH_AES_128_CBC_SHA, TLS_ECDH_RSA_WITH_AES_256_CBC_SHA,
+ TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA, TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA,
+ TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA);
+INSTANTIATE_CIPHER_TEST_P(
+ TLS12SigSchemes, All, V12, ::testing::ValuesIn(kFasterDHEGroups),
+ ::testing::ValuesIn(kSignatureSchemesParamsArr),
+ TLS_DHE_RSA_WITH_AES_256_CBC_SHA256, TLS_RSA_WITH_AES_256_CBC_SHA256,
+ TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256,
+ TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256, TLS_DHE_RSA_WITH_AES_128_CBC_SHA256,
+ TLS_RSA_WITH_AES_128_CBC_SHA256, TLS_DHE_DSS_WITH_AES_128_CBC_SHA256,
+ TLS_DHE_DSS_WITH_AES_256_CBC_SHA256);
+#ifndef NSS_DISABLE_TLS_1_3
+INSTANTIATE_CIPHER_TEST_P(TLS13, All, V13,
+ ::testing::ValuesIn(kFasterDHEGroups),
+ ::testing::ValuesIn(kSignatureSchemesParamsArrTls13),
+ TLS_AES_128_GCM_SHA256, TLS_CHACHA20_POLY1305_SHA256,
+ TLS_AES_256_GCM_SHA384);
+INSTANTIATE_CIPHER_TEST_P(TLS13AllGroups, All, V13,
+ ::testing::ValuesIn(kAllDHEGroups),
+ ::testing::Values(ssl_sig_ecdsa_secp384r1_sha384),
+ TLS_AES_256_GCM_SHA384);
+#endif
+
+// Fields are: version, cipher suite, bulk cipher name, secretKeySize
+struct SecStatusParams {
+ uint16_t version;
+ uint16_t cipher_suite;
+ std::string name;
+ int keySize;
+};
+
+inline std::ostream &operator<<(std::ostream &stream,
+ const SecStatusParams &vals) {
+ SSLCipherSuiteInfo csinfo;
+ SECStatus rv =
+ SSL_GetCipherSuiteInfo(vals.cipher_suite, &csinfo, sizeof(csinfo));
+ if (rv != SECSuccess) {
+ return stream << "Error invoking SSL_GetCipherSuiteInfo()";
+ }
+
+ return stream << "TLS " << VersionString(vals.version) << ", "
+ << csinfo.cipherSuiteName << ", name = \"" << vals.name
+ << "\", key size = " << vals.keySize;
+}
+
+class SecurityStatusTest
+ : public TlsCipherSuiteTestBase,
+ public ::testing::WithParamInterface<SecStatusParams> {
+ public:
+ SecurityStatusTest()
+ : TlsCipherSuiteTestBase(ssl_variant_stream, GetParam().version,
+ GetParam().cipher_suite, ssl_grp_none,
+ ssl_sig_none) {}
+};
+
+// SSL_SecurityStatus produces fairly useless output when compared to
+// SSL_GetCipherSuiteInfo and SSL_GetChannelInfo, but we can't break it, so we
+// need to check it.
+TEST_P(SecurityStatusTest, CheckSecurityStatus) {
+ SetupCertificate();
+ EnableSingleCipher();
+ ConnectAndCheckCipherSuite();
+
+ int on;
+ char *cipher;
+ int keySize;
+ int secretKeySize;
+ char *issuer;
+ char *subject;
+ EXPECT_EQ(SECSuccess,
+ SSL_SecurityStatus(client_->ssl_fd(), &on, &cipher, &keySize,
+ &secretKeySize, &issuer, &subject));
+ if (std::string(cipher) == "NULL") {
+ EXPECT_EQ(0, on);
+ } else {
+ EXPECT_NE(0, on);
+ }
+ EXPECT_EQ(GetParam().name, std::string(cipher));
+ // All the ciphers we support have secret key size == key size.
+ EXPECT_EQ(GetParam().keySize, keySize);
+ EXPECT_EQ(GetParam().keySize, secretKeySize);
+ EXPECT_LT(0U, strlen(issuer));
+ EXPECT_LT(0U, strlen(subject));
+
+ PORT_Free(cipher);
+ PORT_Free(issuer);
+ PORT_Free(subject);
+}
+
+static const SecStatusParams kSecStatusTestValuesArr[] = {
+ {SSL_LIBRARY_VERSION_TLS_1_0, TLS_ECDHE_RSA_WITH_NULL_SHA, "NULL", 0},
+ {SSL_LIBRARY_VERSION_TLS_1_0, TLS_RSA_WITH_RC4_128_SHA, "RC4", 128},
+ {SSL_LIBRARY_VERSION_TLS_1_0, TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA,
+ "3DES-EDE-CBC", 168},
+ {SSL_LIBRARY_VERSION_TLS_1_0, TLS_RSA_WITH_AES_128_CBC_SHA, "AES-128", 128},
+ {SSL_LIBRARY_VERSION_TLS_1_2, TLS_RSA_WITH_AES_256_CBC_SHA256, "AES-256",
+ 256},
+ {SSL_LIBRARY_VERSION_TLS_1_2, TLS_RSA_WITH_AES_128_GCM_SHA256,
+ "AES-128-GCM", 128},
+ {SSL_LIBRARY_VERSION_TLS_1_2, TLS_RSA_WITH_AES_256_GCM_SHA384,
+ "AES-256-GCM", 256},
+ {SSL_LIBRARY_VERSION_TLS_1_2, TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256,
+ "ChaCha20-Poly1305", 256}};
+INSTANTIATE_TEST_SUITE_P(TestSecurityStatus, SecurityStatusTest,
+ ::testing::ValuesIn(kSecStatusTestValuesArr));
+
+} // namespace nss_test
diff --git a/security/nss/gtests/ssl_gtest/ssl_custext_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_custext_unittest.cc
new file mode 100644
index 0000000000..7ed0e5d934
--- /dev/null
+++ b/security/nss/gtests/ssl_gtest/ssl_custext_unittest.cc
@@ -0,0 +1,499 @@
+/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
+/* vim: set ts=2 et sw=2 tw=80: */
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this file,
+ * You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+#include "ssl.h"
+#include "ssl3prot.h"
+#include "sslerr.h"
+#include "sslproto.h"
+#include "sslexp.h"
+
+#include <memory>
+
+#include "tls_connect.h"
+
+namespace nss_test {
+
+static void IncrementCounterArg(void *arg) {
+ if (arg) {
+ auto *called = reinterpret_cast<size_t *>(arg);
+ ++*called;
+ }
+}
+
+static PRBool NoopExtensionWriter(PRFileDesc *fd, SSLHandshakeType message,
+ PRUint8 *data, unsigned int *len,
+ unsigned int maxLen, void *arg) {
+ IncrementCounterArg(arg);
+ return PR_FALSE;
+}
+
+static PRBool EmptyExtensionWriter(PRFileDesc *fd, SSLHandshakeType message,
+ PRUint8 *data, unsigned int *len,
+ unsigned int maxLen, void *arg) {
+ IncrementCounterArg(arg);
+ return PR_TRUE;
+}
+
+static SECStatus NoopExtensionHandler(PRFileDesc *fd, SSLHandshakeType message,
+ const PRUint8 *data, unsigned int len,
+ SSLAlertDescription *alert, void *arg) {
+ return SECSuccess;
+}
+
+// All of the (current) set of supported extensions, plus a few extra.
+static const uint16_t kManyExtensions[] = {
+ ssl_server_name_xtn,
+ ssl_cert_status_xtn,
+ ssl_supported_groups_xtn,
+ ssl_ec_point_formats_xtn,
+ ssl_signature_algorithms_xtn,
+ ssl_signature_algorithms_cert_xtn,
+ ssl_use_srtp_xtn,
+ ssl_app_layer_protocol_xtn,
+ ssl_signed_cert_timestamp_xtn,
+ ssl_padding_xtn,
+ ssl_extended_master_secret_xtn,
+ ssl_session_ticket_xtn,
+ ssl_tls13_key_share_xtn,
+ ssl_tls13_pre_shared_key_xtn,
+ ssl_tls13_early_data_xtn,
+ ssl_tls13_supported_versions_xtn,
+ ssl_tls13_cookie_xtn,
+ ssl_tls13_psk_key_exchange_modes_xtn,
+ ssl_tls13_ticket_early_data_info_xtn,
+ ssl_tls13_certificate_authorities_xtn,
+ ssl_next_proto_nego_xtn,
+ ssl_renegotiation_info_xtn,
+ ssl_record_size_limit_xtn,
+ ssl_tls13_encrypted_client_hello_xtn,
+ 1,
+ 0xffff};
+// The list here includes all extensions we expect to use (SSL_MAX_EXTENSIONS),
+// plus the deprecated values (see sslt.h), and two extra dummy values.
+PR_STATIC_ASSERT((SSL_MAX_EXTENSIONS + 5) == PR_ARRAY_SIZE(kManyExtensions));
+
+void InstallManyWriters(std::shared_ptr<TlsAgent> agent,
+ SSLExtensionWriter writer, size_t *installed = nullptr,
+ size_t *called = nullptr) {
+ for (size_t i = 0; i < PR_ARRAY_SIZE(kManyExtensions); ++i) {
+ SSLExtensionSupport support = ssl_ext_none;
+ SECStatus rv = SSL_GetExtensionSupport(kManyExtensions[i], &support);
+ ASSERT_EQ(SECSuccess, rv) << "SSL_GetExtensionSupport cannot fail";
+
+ rv = SSL_InstallExtensionHooks(agent->ssl_fd(), kManyExtensions[i], writer,
+ called, NoopExtensionHandler, nullptr);
+ if (support == ssl_ext_native_only) {
+ EXPECT_EQ(SECFailure, rv);
+ EXPECT_EQ(SEC_ERROR_INVALID_ARGS, PORT_GetError());
+ } else {
+ if (installed) {
+ ++*installed;
+ }
+ EXPECT_EQ(SECSuccess, rv);
+ }
+ }
+}
+
+TEST_F(TlsConnectStreamTls13, CustomExtensionAllNoopClient) {
+ EnsureTlsSetup();
+ size_t installed = 0;
+ size_t called = 0;
+ InstallManyWriters(client_, NoopExtensionWriter, &installed, &called);
+ EXPECT_LT(0U, installed);
+ Connect();
+ EXPECT_EQ(installed, called);
+}
+
+TEST_F(TlsConnectStreamTls13, CustomExtensionAllNoopServer) {
+ EnsureTlsSetup();
+ size_t installed = 0;
+ size_t called = 0;
+ InstallManyWriters(server_, NoopExtensionWriter, &installed, &called);
+ EXPECT_LT(0U, installed);
+ Connect();
+ // Extension writers are all called for each of ServerHello,
+ // EncryptedExtensions, and Certificate.
+ EXPECT_EQ(installed * 3, called);
+}
+
+TEST_F(TlsConnectStreamTls13, CustomExtensionEmptyWriterClient) {
+ EnsureTlsSetup();
+ InstallManyWriters(client_, EmptyExtensionWriter);
+ InstallManyWriters(server_, EmptyExtensionWriter);
+ Connect();
+}
+
+TEST_F(TlsConnectStreamTls13, CustomExtensionEmptyWriterServer) {
+ EnsureTlsSetup();
+ InstallManyWriters(server_, EmptyExtensionWriter);
+ // Sending extensions that the client doesn't expect leads to extensions
+ // appearing even if the client didn't send one, or in the wrong messages.
+ client_->ExpectSendAlert(kTlsAlertUnsupportedExtension);
+ server_->ExpectSendAlert(kTlsAlertUnexpectedMessage);
+ ConnectExpectFail();
+}
+
+// Install an writer to disable sending of a natively-supported extension.
+TEST_F(TlsConnectStreamTls13, CustomExtensionWriterDisable) {
+ EnsureTlsSetup();
+
+ // This option enables sending the extension via the native support.
+ SECStatus rv = SSL_OptionSet(client_->ssl_fd(),
+ SSL_ENABLE_SIGNED_CERT_TIMESTAMPS, PR_TRUE);
+ EXPECT_EQ(SECSuccess, rv);
+
+ // This installs an override that doesn't do anything. You have to specify
+ // something; passing all nullptr values removes an existing handler.
+ rv = SSL_InstallExtensionHooks(
+ client_->ssl_fd(), ssl_signed_cert_timestamp_xtn, NoopExtensionWriter,
+ nullptr, NoopExtensionHandler, nullptr);
+ EXPECT_EQ(SECSuccess, rv);
+ auto capture = MakeTlsFilter<TlsExtensionCapture>(
+ client_, ssl_signed_cert_timestamp_xtn);
+
+ Connect();
+ // So nothing will be sent.
+ EXPECT_FALSE(capture->captured());
+}
+
+// An extension that is unlikely to be parsed as valid.
+static uint8_t kNonsenseExtension[] = {91, 82, 73, 64, 55, 46, 37, 28, 19};
+
+static PRBool NonsenseExtensionWriter(PRFileDesc *fd, SSLHandshakeType message,
+ PRUint8 *data, unsigned int *len,
+ unsigned int maxLen, void *arg) {
+ TlsAgent *agent = reinterpret_cast<TlsAgent *>(arg);
+ EXPECT_NE(nullptr, agent);
+ EXPECT_NE(nullptr, data);
+ EXPECT_NE(nullptr, len);
+ EXPECT_EQ(0U, *len);
+ EXPECT_LT(0U, maxLen);
+ EXPECT_EQ(agent->ssl_fd(), fd);
+
+ if (message != ssl_hs_client_hello && message != ssl_hs_server_hello &&
+ message != ssl_hs_encrypted_extensions) {
+ return PR_FALSE;
+ }
+
+ *len = static_cast<unsigned int>(sizeof(kNonsenseExtension));
+ EXPECT_GE(maxLen, *len);
+ if (maxLen < *len) {
+ return PR_FALSE;
+ }
+ PORT_Memcpy(data, kNonsenseExtension, *len);
+ return PR_TRUE;
+}
+
+// Override the extension handler for an natively-supported and produce
+// nonsense, which results in a handshake failure.
+TEST_F(TlsConnectStreamTls13, CustomExtensionOverride) {
+ EnsureTlsSetup();
+
+ // This option enables sending the extension via the native support.
+ SECStatus rv = SSL_OptionSet(client_->ssl_fd(),
+ SSL_ENABLE_SIGNED_CERT_TIMESTAMPS, PR_TRUE);
+ EXPECT_EQ(SECSuccess, rv);
+
+ // This installs an override that sends nonsense.
+ rv = SSL_InstallExtensionHooks(
+ client_->ssl_fd(), ssl_signed_cert_timestamp_xtn, NonsenseExtensionWriter,
+ client_.get(), NoopExtensionHandler, nullptr);
+ EXPECT_EQ(SECSuccess, rv);
+
+ // Capture it to see what we got.
+ auto capture = MakeTlsFilter<TlsExtensionCapture>(
+ client_, ssl_signed_cert_timestamp_xtn);
+
+ ConnectExpectAlert(server_, kTlsAlertDecodeError);
+
+ EXPECT_TRUE(capture->captured());
+ EXPECT_EQ(DataBuffer(kNonsenseExtension, sizeof(kNonsenseExtension)),
+ capture->extension());
+}
+
+static SECStatus NonsenseExtensionHandler(PRFileDesc *fd,
+ SSLHandshakeType message,
+ const PRUint8 *data, unsigned int len,
+ SSLAlertDescription *alert,
+ void *arg) {
+ TlsAgent *agent = reinterpret_cast<TlsAgent *>(arg);
+ EXPECT_EQ(agent->ssl_fd(), fd);
+ if (agent->role() == TlsAgent::SERVER) {
+ EXPECT_EQ(ssl_hs_client_hello, message);
+ } else {
+ EXPECT_TRUE(message == ssl_hs_server_hello ||
+ message == ssl_hs_encrypted_extensions);
+ }
+ EXPECT_EQ(DataBuffer(kNonsenseExtension, sizeof(kNonsenseExtension)),
+ DataBuffer(data, len));
+ EXPECT_NE(nullptr, alert);
+ return SECSuccess;
+}
+
+// Send nonsense in an extension from client to server.
+TEST_F(TlsConnectStreamTls13, CustomExtensionClientToServer) {
+ EnsureTlsSetup();
+
+ // This installs an override that sends nonsense.
+ const uint16_t extension_code = 0xffe5;
+ SECStatus rv = SSL_InstallExtensionHooks(
+ client_->ssl_fd(), extension_code, NonsenseExtensionWriter, client_.get(),
+ NoopExtensionHandler, nullptr);
+ EXPECT_EQ(SECSuccess, rv);
+
+ // Capture it to see what we got.
+ auto capture = MakeTlsFilter<TlsExtensionCapture>(client_, extension_code);
+
+ // Handle it so that the handshake completes.
+ rv = SSL_InstallExtensionHooks(server_->ssl_fd(), extension_code,
+ NoopExtensionWriter, nullptr,
+ NonsenseExtensionHandler, server_.get());
+ EXPECT_EQ(SECSuccess, rv);
+
+ Connect();
+
+ EXPECT_TRUE(capture->captured());
+ EXPECT_EQ(DataBuffer(kNonsenseExtension, sizeof(kNonsenseExtension)),
+ capture->extension());
+}
+
+static PRBool NonsenseExtensionWriterSH(PRFileDesc *fd,
+ SSLHandshakeType message, PRUint8 *data,
+ unsigned int *len, unsigned int maxLen,
+ void *arg) {
+ if (message == ssl_hs_server_hello) {
+ return NonsenseExtensionWriter(fd, message, data, len, maxLen, arg);
+ }
+ return PR_FALSE;
+}
+
+// Send nonsense in an extension from server to client, in ServerHello.
+TEST_F(TlsConnectStreamTls13, CustomExtensionServerToClientSH) {
+ EnsureTlsSetup();
+
+ // This installs an override that sends nothing but expects nonsense.
+ const uint16_t extension_code = 0xff5e;
+ SECStatus rv = SSL_InstallExtensionHooks(
+ client_->ssl_fd(), extension_code, EmptyExtensionWriter, nullptr,
+ NonsenseExtensionHandler, client_.get());
+ EXPECT_EQ(SECSuccess, rv);
+
+ // Have the server send nonsense.
+ rv = SSL_InstallExtensionHooks(server_->ssl_fd(), extension_code,
+ NonsenseExtensionWriterSH, server_.get(),
+ NoopExtensionHandler, nullptr);
+ EXPECT_EQ(SECSuccess, rv);
+
+ // Capture the extension from the ServerHello only and check it.
+ auto capture = MakeTlsFilter<TlsExtensionCapture>(server_, extension_code);
+ capture->SetHandshakeTypes({kTlsHandshakeServerHello});
+
+ Connect();
+
+ EXPECT_TRUE(capture->captured());
+ EXPECT_EQ(DataBuffer(kNonsenseExtension, sizeof(kNonsenseExtension)),
+ capture->extension());
+}
+
+static PRBool NonsenseExtensionWriterEE(PRFileDesc *fd,
+ SSLHandshakeType message, PRUint8 *data,
+ unsigned int *len, unsigned int maxLen,
+ void *arg) {
+ if (message == ssl_hs_encrypted_extensions) {
+ return NonsenseExtensionWriter(fd, message, data, len, maxLen, arg);
+ }
+ return PR_FALSE;
+}
+
+// Send nonsense in an extension from server to client, in EncryptedExtensions.
+TEST_F(TlsConnectStreamTls13, CustomExtensionServerToClientEE) {
+ EnsureTlsSetup();
+
+ // This installs an override that sends nothing but expects nonsense.
+ const uint16_t extension_code = 0xff5e;
+ SECStatus rv = SSL_InstallExtensionHooks(
+ client_->ssl_fd(), extension_code, EmptyExtensionWriter, nullptr,
+ NonsenseExtensionHandler, client_.get());
+ EXPECT_EQ(SECSuccess, rv);
+
+ // Have the server send nonsense.
+ rv = SSL_InstallExtensionHooks(server_->ssl_fd(), extension_code,
+ NonsenseExtensionWriterEE, server_.get(),
+ NoopExtensionHandler, nullptr);
+ EXPECT_EQ(SECSuccess, rv);
+
+ // Capture the extension from the EncryptedExtensions only and check it.
+ auto capture = MakeTlsFilter<TlsExtensionCapture>(server_, extension_code);
+ capture->SetHandshakeTypes({kTlsHandshakeEncryptedExtensions});
+ capture->EnableDecryption();
+
+ Connect();
+
+ EXPECT_TRUE(capture->captured());
+ EXPECT_EQ(DataBuffer(kNonsenseExtension, sizeof(kNonsenseExtension)),
+ capture->extension());
+}
+
+TEST_F(TlsConnectStreamTls13, CustomExtensionUnsolicitedServer) {
+ EnsureTlsSetup();
+
+ const uint16_t extension_code = 0xff5e;
+ SECStatus rv = SSL_InstallExtensionHooks(
+ server_->ssl_fd(), extension_code, NonsenseExtensionWriter, server_.get(),
+ NoopExtensionHandler, nullptr);
+ EXPECT_EQ(SECSuccess, rv);
+
+ // Capture it to see what we got.
+ auto capture = MakeTlsFilter<TlsExtensionCapture>(server_, extension_code);
+
+ client_->ExpectSendAlert(kTlsAlertUnsupportedExtension);
+ server_->ExpectSendAlert(kTlsAlertUnexpectedMessage);
+ ConnectExpectFail();
+
+ EXPECT_TRUE(capture->captured());
+ EXPECT_EQ(DataBuffer(kNonsenseExtension, sizeof(kNonsenseExtension)),
+ capture->extension());
+}
+
+SECStatus RejectExtensionHandler(PRFileDesc *fd, SSLHandshakeType message,
+ const PRUint8 *data, unsigned int len,
+ SSLAlertDescription *alert, void *arg) {
+ return SECFailure;
+}
+
+TEST_F(TlsConnectStreamTls13, CustomExtensionServerReject) {
+ EnsureTlsSetup();
+
+ // This installs an override that sends nonsense.
+ const uint16_t extension_code = 0xffe7;
+ SECStatus rv = SSL_InstallExtensionHooks(client_->ssl_fd(), extension_code,
+ EmptyExtensionWriter, nullptr,
+ NoopExtensionHandler, nullptr);
+ EXPECT_EQ(SECSuccess, rv);
+
+ // Reject the extension for no good reason.
+ rv = SSL_InstallExtensionHooks(server_->ssl_fd(), extension_code,
+ NoopExtensionWriter, nullptr,
+ RejectExtensionHandler, nullptr);
+ EXPECT_EQ(SECSuccess, rv);
+
+ ConnectExpectAlert(server_, kTlsAlertHandshakeFailure);
+}
+
+// Send nonsense in an extension from client to server.
+TEST_F(TlsConnectStreamTls13, CustomExtensionClientReject) {
+ EnsureTlsSetup();
+
+ // This installs an override that sends nothing but expects nonsense.
+ const uint16_t extension_code = 0xff58;
+ SECStatus rv = SSL_InstallExtensionHooks(client_->ssl_fd(), extension_code,
+ EmptyExtensionWriter, nullptr,
+ RejectExtensionHandler, nullptr);
+ EXPECT_EQ(SECSuccess, rv);
+
+ // Have the server send nonsense.
+ rv = SSL_InstallExtensionHooks(server_->ssl_fd(), extension_code,
+ EmptyExtensionWriter, nullptr,
+ NoopExtensionHandler, nullptr);
+ EXPECT_EQ(SECSuccess, rv);
+
+ client_->ExpectSendAlert(kTlsAlertHandshakeFailure);
+ server_->ExpectSendAlert(kTlsAlertUnexpectedMessage);
+ ConnectExpectFail();
+}
+
+static const uint8_t kCustomAlert = 0xf6;
+
+SECStatus AlertExtensionHandler(PRFileDesc *fd, SSLHandshakeType message,
+ const PRUint8 *data, unsigned int len,
+ SSLAlertDescription *alert, void *arg) {
+ *alert = kCustomAlert;
+ return SECFailure;
+}
+
+TEST_F(TlsConnectStreamTls13, CustomExtensionServerRejectAlert) {
+ EnsureTlsSetup();
+
+ // This installs an override that sends nonsense.
+ const uint16_t extension_code = 0xffea;
+ SECStatus rv = SSL_InstallExtensionHooks(client_->ssl_fd(), extension_code,
+ EmptyExtensionWriter, nullptr,
+ NoopExtensionHandler, nullptr);
+ EXPECT_EQ(SECSuccess, rv);
+
+ // Reject the extension for no good reason.
+ rv = SSL_InstallExtensionHooks(server_->ssl_fd(), extension_code,
+ NoopExtensionWriter, nullptr,
+ AlertExtensionHandler, nullptr);
+ EXPECT_EQ(SECSuccess, rv);
+
+ ConnectExpectAlert(server_, kCustomAlert);
+}
+
+// Send nonsense in an extension from client to server.
+TEST_F(TlsConnectStreamTls13, CustomExtensionClientRejectAlert) {
+ EnsureTlsSetup();
+
+ // This installs an override that sends nothing but expects nonsense.
+ const uint16_t extension_code = 0xff5a;
+ SECStatus rv = SSL_InstallExtensionHooks(client_->ssl_fd(), extension_code,
+ EmptyExtensionWriter, nullptr,
+ AlertExtensionHandler, nullptr);
+ EXPECT_EQ(SECSuccess, rv);
+
+ // Have the server send nonsense.
+ rv = SSL_InstallExtensionHooks(server_->ssl_fd(), extension_code,
+ EmptyExtensionWriter, nullptr,
+ NoopExtensionHandler, nullptr);
+ EXPECT_EQ(SECSuccess, rv);
+
+ client_->ExpectSendAlert(kCustomAlert);
+ server_->ExpectSendAlert(kTlsAlertUnexpectedMessage);
+ ConnectExpectFail();
+}
+
+// Configure a custom extension hook badly.
+TEST_F(TlsConnectStreamTls13, CustomExtensionOnlyWriter) {
+ EnsureTlsSetup();
+
+ // This installs an override that sends nothing but expects nonsense.
+ SECStatus rv =
+ SSL_InstallExtensionHooks(client_->ssl_fd(), 0xff6c, EmptyExtensionWriter,
+ nullptr, nullptr, nullptr);
+ EXPECT_EQ(SECFailure, rv);
+ EXPECT_EQ(SEC_ERROR_INVALID_ARGS, PORT_GetError());
+}
+
+TEST_F(TlsConnectStreamTls13, CustomExtensionOnlyHandler) {
+ EnsureTlsSetup();
+
+ // This installs an override that sends nothing but expects nonsense.
+ SECStatus rv =
+ SSL_InstallExtensionHooks(client_->ssl_fd(), 0xff6d, nullptr, nullptr,
+ NoopExtensionHandler, nullptr);
+ EXPECT_EQ(SECFailure, rv);
+ EXPECT_EQ(SEC_ERROR_INVALID_ARGS, PORT_GetError());
+}
+
+TEST_F(TlsConnectStreamTls13, CustomExtensionOverrunBuffer) {
+ EnsureTlsSetup();
+ // This doesn't actually overrun the buffer, but it says that it does.
+ auto overrun_writer = [](PRFileDesc *fd, SSLHandshakeType message,
+ PRUint8 *data, unsigned int *len,
+ unsigned int maxLen, void *arg) -> PRBool {
+ *len = maxLen + 1;
+ return PR_TRUE;
+ };
+ SECStatus rv =
+ SSL_InstallExtensionHooks(client_->ssl_fd(), 0xff71, overrun_writer,
+ nullptr, NoopExtensionHandler, nullptr);
+ EXPECT_EQ(SECSuccess, rv);
+ client_->StartConnect();
+ client_->Handshake();
+ client_->CheckErrorCode(SEC_ERROR_APPLICATION_CALLBACK_ERROR);
+}
+
+} // namespace nss_test
diff --git a/security/nss/gtests/ssl_gtest/ssl_damage_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_damage_unittest.cc
new file mode 100644
index 0000000000..9cbe9566f1
--- /dev/null
+++ b/security/nss/gtests/ssl_gtest/ssl_damage_unittest.cc
@@ -0,0 +1,104 @@
+/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
+/* vim: set ts=2 et sw=2 tw=80: */
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this file,
+ * You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+#include <functional>
+#include <memory>
+#include "secerr.h"
+#include "ssl.h"
+#include "sslerr.h"
+#include "sslproto.h"
+
+extern "C" {
+// This is not something that should make you happy.
+#include "libssl_internals.h"
+}
+
+#include "gtest_utils.h"
+#include "nss_scoped_ptrs.h"
+#include "tls_connect.h"
+#include "tls_filter.h"
+#include "tls_parser.h"
+
+namespace nss_test {
+
+TEST_F(TlsConnectTest, DamageSecretHandleClientFinished) {
+ client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_1,
+ SSL_LIBRARY_VERSION_TLS_1_3);
+ server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_1,
+ SSL_LIBRARY_VERSION_TLS_1_3);
+ StartConnect();
+ client_->Handshake();
+ server_->Handshake();
+ std::cerr << "Damaging HS secret" << std::endl;
+ SSLInt_DamageClientHsTrafficSecret(server_->ssl_fd());
+ client_->Handshake();
+ // The client thinks it has connected.
+ EXPECT_EQ(TlsAgent::STATE_CONNECTED, client_->state());
+
+ ExpectAlert(server_, kTlsAlertDecryptError);
+ server_->Handshake();
+ server_->CheckErrorCode(SSL_ERROR_BAD_HANDSHAKE_HASH_VALUE);
+ client_->Handshake();
+ client_->CheckErrorCode(SSL_ERROR_DECRYPT_ERROR_ALERT);
+}
+
+TEST_F(TlsConnectTest, DamageSecretHandleServerFinished) {
+ client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_1,
+ SSL_LIBRARY_VERSION_TLS_1_3);
+ server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_1,
+ SSL_LIBRARY_VERSION_TLS_1_3);
+ MakeTlsFilter<AfterRecordN>(
+ server_, client_,
+ 0, // ServerHello.
+ [this]() { SSLInt_DamageServerHsTrafficSecret(client_->ssl_fd()); });
+ ConnectExpectAlert(client_, kTlsAlertDecryptError);
+ client_->CheckErrorCode(SSL_ERROR_BAD_HANDSHAKE_HASH_VALUE);
+}
+
+TEST_P(TlsConnectGenericPre13, DamageServerSignature) {
+ EnsureTlsSetup();
+ auto filter = MakeTlsFilter<TlsLastByteDamager>(
+ server_, kTlsHandshakeServerKeyExchange);
+ ExpectAlert(client_, kTlsAlertDecryptError);
+ ConnectExpectFail();
+ client_->CheckErrorCode(SEC_ERROR_BAD_SIGNATURE);
+ server_->CheckErrorCode(SSL_ERROR_DECRYPT_ERROR_ALERT);
+}
+
+TEST_P(TlsConnectTls13, DamageServerSignature) {
+ EnsureTlsSetup();
+ auto filter = MakeTlsFilter<TlsLastByteDamager>(
+ server_, kTlsHandshakeCertificateVerify);
+ filter->EnableDecryption();
+ ConnectExpectAlert(client_, kTlsAlertDecryptError);
+ client_->CheckErrorCode(SEC_ERROR_BAD_SIGNATURE);
+}
+
+TEST_P(TlsConnectGeneric, DamageClientSignature) {
+ EnsureTlsSetup();
+ client_->SetupClientAuth();
+ server_->RequestClientAuth(true);
+ auto filter = MakeTlsFilter<TlsLastByteDamager>(
+ client_, kTlsHandshakeCertificateVerify);
+ if (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) {
+ filter->EnableDecryption();
+ }
+ server_->ExpectSendAlert(kTlsAlertDecryptError);
+ // Do these handshakes by hand to avoid race condition on
+ // the client processing the server's alert.
+ StartConnect();
+ client_->Handshake();
+ server_->Handshake();
+ client_->Handshake();
+ server_->Handshake();
+ EXPECT_EQ(version_ >= SSL_LIBRARY_VERSION_TLS_1_3
+ ? TlsAgent::STATE_CONNECTED
+ : TlsAgent::STATE_CONNECTING,
+ client_->state());
+ server_->CheckErrorCode(SEC_ERROR_BAD_SIGNATURE);
+}
+
+} // namespace nss_test
diff --git a/security/nss/gtests/ssl_gtest/ssl_debug_env_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_debug_env_unittest.cc
new file mode 100644
index 0000000000..77b4d69afc
--- /dev/null
+++ b/security/nss/gtests/ssl_gtest/ssl_debug_env_unittest.cc
@@ -0,0 +1,51 @@
+/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
+/* vim: set ts=2 et sw=2 tw=80: */
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this file,
+ * You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+#include <cstdlib>
+#include <fstream>
+#include <sstream>
+
+#include "gtest_utils.h"
+#include "tls_connect.h"
+
+namespace nss_test {
+
+extern "C" {
+extern FILE* ssl_trace_iob;
+
+#ifdef NSS_ALLOW_SSLKEYLOGFILE
+extern FILE* ssl_keylog_iob;
+#endif
+}
+
+// These tests ensure that when the associated environment variables are unset
+// that the lazily-initialized defaults are what they are supposed to be.
+
+#ifdef DEBUG
+TEST_P(TlsConnectGeneric, DebugEnvTraceFileNotSet) {
+ char* ev = PR_GetEnvSecure("SSLDEBUGFILE");
+ if (ev && ev[0]) {
+ GTEST_SKIP();
+ }
+
+ Connect();
+ EXPECT_EQ(stderr, ssl_trace_iob);
+}
+#endif
+
+#ifdef NSS_ALLOW_SSLKEYLOGFILE
+TEST_P(TlsConnectGeneric, DebugEnvKeylogFileNotSet) {
+ char* ev = PR_GetEnvSecure("SSLKEYLOGFILE");
+ if (ev && ev[0]) {
+ GTEST_SKIP();
+ }
+
+ Connect();
+ EXPECT_EQ(nullptr, ssl_keylog_iob);
+}
+#endif
+
+} // namespace nss_test
diff --git a/security/nss/gtests/ssl_gtest/ssl_dhe_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_dhe_unittest.cc
new file mode 100644
index 0000000000..09beb2a6d9
--- /dev/null
+++ b/security/nss/gtests/ssl_gtest/ssl_dhe_unittest.cc
@@ -0,0 +1,802 @@
+/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
+/* vim: set ts=2 et sw=2 tw=80: */
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this file,
+ * You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+#include <functional>
+#include <memory>
+#include <set>
+#include "secerr.h"
+#include "ssl.h"
+#include "sslerr.h"
+#include "sslproto.h"
+
+#include "gtest_utils.h"
+#include "nss_scoped_ptrs.h"
+#include "tls_connect.h"
+#include "tls_filter.h"
+#include "tls_parser.h"
+
+namespace nss_test {
+
+TEST_P(TlsConnectGeneric, ConnectDhe) {
+ EnableOnlyDheCiphers();
+ Connect();
+ CheckKeys(ssl_kea_dh, ssl_grp_ffdhe_2048, ssl_auth_rsa_sign,
+ ssl_sig_rsa_pss_rsae_sha256);
+}
+
+TEST_P(TlsConnectTls13, SharesForBothEcdheAndDhe) {
+ EnsureTlsSetup();
+ client_->ConfigNamedGroups(kAllDHEGroups);
+
+ auto groups_capture =
+ std::make_shared<TlsExtensionCapture>(client_, ssl_supported_groups_xtn);
+ auto shares_capture =
+ std::make_shared<TlsExtensionCapture>(client_, ssl_tls13_key_share_xtn);
+ std::vector<std::shared_ptr<PacketFilter>> captures = {groups_capture,
+ shares_capture};
+ client_->SetFilter(std::make_shared<ChainedPacketFilter>(captures));
+
+ Connect();
+
+ CheckKeys();
+
+ bool ec, dh;
+ auto track_group_type = [&ec, &dh](SSLNamedGroup group) {
+ if ((group & 0xff00U) == 0x100U) {
+ dh = true;
+ } else {
+ ec = true;
+ }
+ };
+ CheckGroups(groups_capture->extension(), track_group_type);
+ CheckShares(shares_capture->extension(), track_group_type);
+ EXPECT_TRUE(ec) << "Should include an EC group and share";
+ EXPECT_TRUE(dh) << "Should include an FFDHE group and share";
+}
+
+TEST_P(TlsConnectGeneric, ConnectFfdheClient) {
+ EnableOnlyDheCiphers();
+ client_->SetOption(SSL_REQUIRE_DH_NAMED_GROUPS, PR_TRUE);
+ auto groups_capture =
+ std::make_shared<TlsExtensionCapture>(client_, ssl_supported_groups_xtn);
+ auto shares_capture =
+ std::make_shared<TlsExtensionCapture>(client_, ssl_tls13_key_share_xtn);
+ std::vector<std::shared_ptr<PacketFilter>> captures = {groups_capture,
+ shares_capture};
+ client_->SetFilter(std::make_shared<ChainedPacketFilter>(captures));
+
+ Connect();
+
+ CheckKeys(ssl_kea_dh, ssl_auth_rsa_sign);
+ auto is_ffdhe = [](SSLNamedGroup group) {
+ // The group has to be in this range.
+ EXPECT_LE(ssl_grp_ffdhe_2048, group);
+ EXPECT_GE(ssl_grp_ffdhe_8192, group);
+ };
+ CheckGroups(groups_capture->extension(), is_ffdhe);
+ if (version_ == SSL_LIBRARY_VERSION_TLS_1_3) {
+ CheckShares(shares_capture->extension(), is_ffdhe);
+ } else {
+ EXPECT_EQ(0U, shares_capture->extension().len());
+ }
+}
+
+// Requiring the FFDHE extension on the server alone means that clients won't be
+// able to connect using a DHE suite. They should still connect in TLS 1.3,
+// because the client automatically sends the supported groups extension.
+TEST_P(TlsConnectGenericPre13, ConnectFfdheServer) {
+ EnableOnlyDheCiphers();
+ server_->SetOption(SSL_REQUIRE_DH_NAMED_GROUPS, PR_TRUE);
+
+ if (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) {
+ Connect();
+ CheckKeys(ssl_kea_dh, ssl_auth_rsa_sign);
+ } else {
+ ConnectExpectAlert(server_, kTlsAlertHandshakeFailure);
+ client_->CheckErrorCode(SSL_ERROR_NO_CYPHER_OVERLAP);
+ server_->CheckErrorCode(SSL_ERROR_NO_CYPHER_OVERLAP);
+ }
+}
+
+class TlsDheServerKeyExchangeDamager : public TlsHandshakeFilter {
+ public:
+ TlsDheServerKeyExchangeDamager(const std::shared_ptr<TlsAgent>& a)
+ : TlsHandshakeFilter(a, {kTlsHandshakeServerKeyExchange}) {}
+ virtual PacketFilter::Action FilterHandshake(
+ const TlsHandshakeFilter::HandshakeHeader& header,
+ const DataBuffer& input, DataBuffer* output) {
+ // Damage the first octet of dh_p. Anything other than the known prime will
+ // be rejected as "weak" when we have SSL_REQUIRE_DH_NAMED_GROUPS enabled.
+ *output = input;
+ output->data()[3] ^= 73;
+ return CHANGE;
+ }
+};
+
+// Changing the prime in the server's key share results in an error. This will
+// invalidate the signature over the ServerKeyShare. That's ok, NSS won't check
+// the signature until everything else has been checked.
+TEST_P(TlsConnectGenericPre13, DamageServerKeyShare) {
+ EnableOnlyDheCiphers();
+ client_->SetOption(SSL_REQUIRE_DH_NAMED_GROUPS, PR_TRUE);
+ MakeTlsFilter<TlsDheServerKeyExchangeDamager>(server_);
+
+ ConnectExpectAlert(client_, kTlsAlertIllegalParameter);
+
+ client_->CheckErrorCode(SSL_ERROR_WEAK_SERVER_EPHEMERAL_DH_KEY);
+ server_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT);
+}
+
+class TlsDheSkeChangeY : public TlsHandshakeFilter {
+ public:
+ enum ChangeYTo {
+ kYZero,
+ kYOne,
+ kYPMinusOne,
+ kYGreaterThanP,
+ kYTooLarge,
+ kYZeroPad
+ };
+
+ TlsDheSkeChangeY(const std::shared_ptr<TlsAgent>& a, uint8_t handshake_type,
+ ChangeYTo change)
+ : TlsHandshakeFilter(a, {handshake_type}), change_Y_(change) {}
+
+ protected:
+ void ChangeY(const DataBuffer& input, DataBuffer* output, size_t offset,
+ const DataBuffer& prime) {
+ static const uint8_t kExtraZero = 0;
+ static const uint8_t kTooLargeExtra = 1;
+
+ uint32_t dh_Ys_len;
+ EXPECT_TRUE(input.Read(offset, 2, &dh_Ys_len));
+ EXPECT_LT(offset + dh_Ys_len, input.len());
+ offset += 2;
+
+ // This isn't generally true, but our code pads.
+ EXPECT_EQ(prime.len(), dh_Ys_len)
+ << "Length of dh_Ys must equal length of dh_p";
+
+ *output = input;
+ switch (change_Y_) {
+ case kYZero:
+ memset(output->data() + offset, 0, prime.len());
+ break;
+
+ case kYOne:
+ memset(output->data() + offset, 0, prime.len() - 1);
+ output->Write(offset + prime.len() - 1, 1U, 1);
+ break;
+
+ case kYPMinusOne:
+ output->Write(offset, prime);
+ EXPECT_TRUE(output->data()[offset + prime.len() - 1] & 0x01)
+ << "P must at least be odd";
+ --output->data()[offset + prime.len() - 1];
+ break;
+
+ case kYGreaterThanP:
+ // Set the first 32 octets of Y to 0xff, except the first which we set
+ // to p[0]. This will make Y > p. That is, unless p is Mersenne, or
+ // improbably large (but still the same bit length). We currently only
+ // use a fixed prime that isn't a problem for this code.
+ EXPECT_LT(0, prime.data()[0]) << "dh_p should not be zero-padded";
+ offset = output->Write(offset, prime.data()[0], 1);
+ memset(output->data() + offset, 0xff, 31);
+ break;
+
+ case kYTooLarge:
+ // Increase the dh_Ys length.
+ output->Write(offset - 2, prime.len() + sizeof(kTooLargeExtra), 2);
+ // Then insert the octet.
+ output->Splice(&kTooLargeExtra, sizeof(kTooLargeExtra), offset);
+ break;
+
+ case kYZeroPad:
+ output->Write(offset - 2, prime.len() + sizeof(kExtraZero), 2);
+ output->Splice(&kExtraZero, sizeof(kExtraZero), offset);
+ break;
+ }
+ }
+
+ private:
+ ChangeYTo change_Y_;
+};
+
+class TlsDheSkeChangeYServer : public TlsDheSkeChangeY {
+ public:
+ TlsDheSkeChangeYServer(const std::shared_ptr<TlsAgent>& a, ChangeYTo change,
+ bool modify)
+ : TlsDheSkeChangeY(a, kTlsHandshakeServerKeyExchange, change),
+ modify_(modify),
+ p_() {}
+
+ const DataBuffer& prime() const { return p_; }
+
+ protected:
+ virtual PacketFilter::Action FilterHandshake(
+ const TlsHandshakeFilter::HandshakeHeader& header,
+ const DataBuffer& input, DataBuffer* output) override {
+ size_t offset = 2;
+ // Read dh_p
+ uint32_t dh_len = 0;
+ EXPECT_TRUE(input.Read(0, 2, &dh_len));
+ EXPECT_GT(input.len(), offset + dh_len);
+ p_.Assign(input.data() + offset, dh_len);
+ offset += dh_len;
+
+ // Skip dh_g to find dh_Ys
+ EXPECT_TRUE(input.Read(offset, 2, &dh_len));
+ offset += 2 + dh_len;
+
+ if (modify_) {
+ ChangeY(input, output, offset, p_);
+ return CHANGE;
+ }
+ return KEEP;
+ }
+
+ private:
+ bool modify_;
+ DataBuffer p_;
+};
+
+class TlsDheSkeChangeYClient : public TlsDheSkeChangeY {
+ public:
+ TlsDheSkeChangeYClient(
+ const std::shared_ptr<TlsAgent>& a, ChangeYTo change,
+ std::shared_ptr<const TlsDheSkeChangeYServer> server_filter)
+ : TlsDheSkeChangeY(a, kTlsHandshakeClientKeyExchange, change),
+ server_filter_(server_filter) {}
+
+ protected:
+ virtual PacketFilter::Action FilterHandshake(
+ const TlsHandshakeFilter::HandshakeHeader& header,
+ const DataBuffer& input, DataBuffer* output) override {
+ ChangeY(input, output, 0, server_filter_->prime());
+ return CHANGE;
+ }
+
+ private:
+ std::shared_ptr<const TlsDheSkeChangeYServer> server_filter_;
+};
+
+/* This matrix includes: variant (stream/datagram), TLS version, what change to
+ * make to dh_Ys, whether the client will be configured to require DH named
+ * groups. Test all combinations. */
+typedef std::tuple<SSLProtocolVariant, uint16_t, TlsDheSkeChangeY::ChangeYTo,
+ bool>
+ DamageDHYProfile;
+class TlsDamageDHYTest
+ : public TlsConnectTestBase,
+ public ::testing::WithParamInterface<DamageDHYProfile> {
+ public:
+ TlsDamageDHYTest()
+ : TlsConnectTestBase(std::get<0>(GetParam()), std::get<1>(GetParam())) {}
+};
+
+TEST_P(TlsDamageDHYTest, DamageServerY) {
+ EnableOnlyDheCiphers();
+ if (std::get<3>(GetParam())) {
+ client_->SetOption(SSL_REQUIRE_DH_NAMED_GROUPS, PR_TRUE);
+ }
+ TlsDheSkeChangeY::ChangeYTo change = std::get<2>(GetParam());
+ MakeTlsFilter<TlsDheSkeChangeYServer>(server_, change, true);
+
+ if (change == TlsDheSkeChangeY::kYZeroPad) {
+ ExpectAlert(client_, kTlsAlertDecryptError);
+ } else {
+ ExpectAlert(client_, kTlsAlertIllegalParameter);
+ }
+ ConnectExpectFail();
+ if (change == TlsDheSkeChangeY::kYZeroPad) {
+ // Zero padding Y only manifests in a signature failure.
+ // In TLS 1.0 and 1.1, the client reports a device error.
+ if (version_ < SSL_LIBRARY_VERSION_TLS_1_2) {
+ client_->CheckErrorCode(SEC_ERROR_PKCS11_DEVICE_ERROR);
+ } else {
+ client_->CheckErrorCode(SEC_ERROR_BAD_SIGNATURE);
+ }
+ server_->CheckErrorCode(SSL_ERROR_DECRYPT_ERROR_ALERT);
+ } else {
+ client_->CheckErrorCode(SSL_ERROR_RX_MALFORMED_DHE_KEY_SHARE);
+ server_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT);
+ }
+}
+
+TEST_P(TlsDamageDHYTest, DamageClientY) {
+ EnableOnlyDheCiphers();
+ if (std::get<3>(GetParam())) {
+ client_->SetOption(SSL_REQUIRE_DH_NAMED_GROUPS, PR_TRUE);
+ }
+ // The filter on the server is required to capture the prime.
+ auto server_filter = MakeTlsFilter<TlsDheSkeChangeYServer>(
+ server_, TlsDheSkeChangeY::kYZero, false);
+
+ // The client filter does the damage.
+ TlsDheSkeChangeY::ChangeYTo change = std::get<2>(GetParam());
+ MakeTlsFilter<TlsDheSkeChangeYClient>(client_, change, server_filter);
+
+ if (change == TlsDheSkeChangeY::kYZeroPad) {
+ ExpectAlert(server_, kTlsAlertDecryptError);
+ } else {
+ ExpectAlert(server_, kTlsAlertHandshakeFailure);
+ }
+ ConnectExpectFail();
+ if (change == TlsDheSkeChangeY::kYZeroPad) {
+ // Zero padding Y only manifests in a finished error.
+ client_->CheckErrorCode(SSL_ERROR_DECRYPT_ERROR_ALERT);
+ server_->CheckErrorCode(SSL_ERROR_BAD_HANDSHAKE_HASH_VALUE);
+ } else {
+ client_->CheckErrorCode(SSL_ERROR_HANDSHAKE_FAILURE_ALERT);
+ server_->CheckErrorCode(SSL_ERROR_RX_MALFORMED_DHE_KEY_SHARE);
+ }
+}
+
+static const TlsDheSkeChangeY::ChangeYTo kAllYArr[] = {
+ TlsDheSkeChangeY::kYZero, TlsDheSkeChangeY::kYOne,
+ TlsDheSkeChangeY::kYPMinusOne, TlsDheSkeChangeY::kYGreaterThanP,
+ TlsDheSkeChangeY::kYTooLarge, TlsDheSkeChangeY::kYZeroPad};
+static ::testing::internal::ParamGenerator<TlsDheSkeChangeY::ChangeYTo> kAllY =
+ ::testing::ValuesIn(kAllYArr);
+static const bool kTrueFalseArr[] = {true, false};
+static ::testing::internal::ParamGenerator<bool> kTrueFalse =
+ ::testing::ValuesIn(kTrueFalseArr);
+
+INSTANTIATE_TEST_SUITE_P(
+ DamageYStream, TlsDamageDHYTest,
+ ::testing::Combine(TlsConnectTestBase::kTlsVariantsStream,
+ TlsConnectTestBase::kTlsV10ToV12, kAllY, kTrueFalse));
+INSTANTIATE_TEST_SUITE_P(
+ DamageYDatagram, TlsDamageDHYTest,
+ ::testing::Combine(TlsConnectTestBase::kTlsVariantsDatagram,
+ TlsConnectTestBase::kTlsV11V12, kAllY, kTrueFalse));
+
+class TlsDheSkeMakePEven : public TlsHandshakeFilter {
+ public:
+ TlsDheSkeMakePEven(const std::shared_ptr<TlsAgent>& a)
+ : TlsHandshakeFilter(a, {kTlsHandshakeServerKeyExchange}) {}
+
+ virtual PacketFilter::Action FilterHandshake(
+ const TlsHandshakeFilter::HandshakeHeader& header,
+ const DataBuffer& input, DataBuffer* output) {
+ // Find the end of dh_p
+ uint32_t dh_len = 0;
+ EXPECT_TRUE(input.Read(0, 2, &dh_len));
+ EXPECT_GT(input.len(), 2 + dh_len) << "enough space for dh_p";
+ size_t offset = 2 + dh_len - 1;
+ EXPECT_TRUE((input.data()[offset] & 0x01) == 0x01) << "p should be odd";
+
+ *output = input;
+ output->data()[offset] &= 0xfe;
+
+ return CHANGE;
+ }
+};
+
+// Even without requiring named groups, an even value for p is bad news.
+TEST_P(TlsConnectGenericPre13, MakeDhePEven) {
+ EnableOnlyDheCiphers();
+ MakeTlsFilter<TlsDheSkeMakePEven>(server_);
+
+ ConnectExpectAlert(client_, kTlsAlertIllegalParameter);
+
+ client_->CheckErrorCode(SSL_ERROR_RX_MALFORMED_DHE_KEY_SHARE);
+ server_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT);
+}
+
+class TlsDheSkeZeroPadP : public TlsHandshakeFilter {
+ public:
+ TlsDheSkeZeroPadP(const std::shared_ptr<TlsAgent>& a)
+ : TlsHandshakeFilter(a, {kTlsHandshakeServerKeyExchange}) {}
+
+ virtual PacketFilter::Action FilterHandshake(
+ const TlsHandshakeFilter::HandshakeHeader& header,
+ const DataBuffer& input, DataBuffer* output) {
+ *output = input;
+ uint32_t dh_len = 0;
+ EXPECT_TRUE(input.Read(0, 2, &dh_len));
+ static const uint8_t kZeroPad = 0;
+ output->Write(0, dh_len + sizeof(kZeroPad), 2); // increment the length
+ output->Splice(&kZeroPad, sizeof(kZeroPad), 2); // insert a zero
+
+ return CHANGE;
+ }
+};
+
+// Zero padding only causes signature failure.
+TEST_P(TlsConnectGenericPre13, PadDheP) {
+ EnableOnlyDheCiphers();
+ MakeTlsFilter<TlsDheSkeZeroPadP>(server_);
+
+ ConnectExpectAlert(client_, kTlsAlertDecryptError);
+
+ // In TLS 1.0 and 1.1, the client reports a device error.
+ if (version_ < SSL_LIBRARY_VERSION_TLS_1_2) {
+ client_->CheckErrorCode(SEC_ERROR_PKCS11_DEVICE_ERROR);
+ } else {
+ client_->CheckErrorCode(SEC_ERROR_BAD_SIGNATURE);
+ }
+ server_->CheckErrorCode(SSL_ERROR_DECRYPT_ERROR_ALERT);
+}
+
+// The server should not pick the weak DH group if the client includes FFDHE
+// named groups in the supported_groups extension. The server then picks a
+// commonly-supported named DH group and this connects.
+//
+// Note: This test case can take ages to generate the weak DH key.
+TEST_P(TlsConnectGenericPre13, WeakDHGroup) {
+ EnableOnlyDheCiphers();
+ client_->SetOption(SSL_REQUIRE_DH_NAMED_GROUPS, PR_TRUE);
+ EXPECT_EQ(SECSuccess,
+ SSL_EnableWeakDHEPrimeGroup(server_->ssl_fd(), PR_TRUE));
+
+ Connect();
+}
+
+TEST_P(TlsConnectGeneric, Ffdhe3072) {
+ EnableOnlyDheCiphers();
+ static const std::vector<SSLNamedGroup> groups = {ssl_grp_ffdhe_3072};
+ client_->ConfigNamedGroups(groups);
+
+ Connect();
+}
+
+// Even though the client doesn't have DHE groups enabled the server assumes it
+// does. Because the client doesn't require named groups it accepts FF3072 as
+// custom group.
+TEST_P(TlsConnectGenericPre13, NamedGroupMismatchPre13) {
+ EnableOnlyDheCiphers();
+ static const std::vector<SSLNamedGroup> server_groups = {ssl_grp_ffdhe_3072};
+ static const std::vector<SSLNamedGroup> client_groups = {
+ ssl_grp_ec_secp256r1};
+ server_->ConfigNamedGroups(server_groups);
+ client_->ConfigNamedGroups(client_groups);
+
+ Connect();
+ CheckKeys(ssl_kea_dh, ssl_grp_ffdhe_custom, ssl_auth_rsa_sign,
+ ssl_sig_rsa_pss_rsae_sha256);
+}
+
+// Same test but for TLS 1.3. This has to fail.
+TEST_P(TlsConnectTls13, NamedGroupMismatch13) {
+ EnableOnlyDheCiphers();
+ static const std::vector<SSLNamedGroup> server_groups = {ssl_grp_ffdhe_3072};
+ static const std::vector<SSLNamedGroup> client_groups = {
+ ssl_grp_ec_secp256r1};
+ server_->ConfigNamedGroups(server_groups);
+ client_->ConfigNamedGroups(client_groups);
+
+ ConnectExpectAlert(server_, kTlsAlertHandshakeFailure);
+ server_->CheckErrorCode(SSL_ERROR_NO_CYPHER_OVERLAP);
+ client_->CheckErrorCode(SSL_ERROR_NO_CYPHER_OVERLAP);
+}
+
+// Replace the key share in the server key exchange message with one that's
+// larger than 8192 bits.
+class TooLongDHEServerKEXFilter : public TlsHandshakeFilter {
+ public:
+ TooLongDHEServerKEXFilter(const std::shared_ptr<TlsAgent>& server)
+ : TlsHandshakeFilter(server, {kTlsHandshakeServerKeyExchange}) {}
+
+ protected:
+ virtual PacketFilter::Action FilterHandshake(const HandshakeHeader& header,
+ const DataBuffer& input,
+ DataBuffer* output) {
+ // Replace the server key exchange message very large DH shares that are
+ // not supported by NSS.
+ const uint32_t share_len = 0x401;
+ const uint8_t zero_share[share_len] = {0x80};
+ size_t offset = 0;
+ // Write dh_p.
+ offset = output->Write(offset, share_len, 2);
+ offset = output->Write(offset, zero_share, share_len);
+ // Write dh_g.
+ offset = output->Write(offset, share_len, 2);
+ offset = output->Write(offset, zero_share, share_len);
+ // Write dh_Y.
+ offset = output->Write(offset, share_len, 2);
+ offset = output->Write(offset, zero_share, share_len);
+
+ return CHANGE;
+ }
+};
+
+TEST_P(TlsConnectGenericPre13, TooBigDHGroup) {
+ EnableOnlyDheCiphers();
+ MakeTlsFilter<TooLongDHEServerKEXFilter>(server_);
+ client_->SetOption(SSL_REQUIRE_DH_NAMED_GROUPS, PR_FALSE);
+ ConnectExpectAlert(client_, kTlsAlertIllegalParameter);
+ server_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT);
+ client_->CheckErrorCode(SSL_ERROR_DH_KEY_TOO_LONG);
+}
+
+// Even though the client doesn't have DHE groups enabled the server assumes it
+// does. The client requires named groups and thus does not accept FF3072 as
+// custom group in contrast to the previous test.
+TEST_P(TlsConnectGenericPre13, RequireNamedGroupsMismatchPre13) {
+ EnableOnlyDheCiphers();
+ client_->SetOption(SSL_REQUIRE_DH_NAMED_GROUPS, PR_TRUE);
+ static const std::vector<SSLNamedGroup> server_groups = {ssl_grp_ffdhe_3072};
+ static const std::vector<SSLNamedGroup> client_groups = {ssl_grp_ec_secp256r1,
+ ssl_grp_ffdhe_2048};
+ server_->ConfigNamedGroups(server_groups);
+ client_->ConfigNamedGroups(client_groups);
+
+ ConnectExpectAlert(server_, kTlsAlertHandshakeFailure);
+ server_->CheckErrorCode(SSL_ERROR_NO_CYPHER_OVERLAP);
+ client_->CheckErrorCode(SSL_ERROR_NO_CYPHER_OVERLAP);
+}
+
+TEST_P(TlsConnectGenericPre13, PreferredFfdhe) {
+ EnableOnlyDheCiphers();
+ static const SSLDHEGroupType groups[] = {ssl_ff_dhe_3072_group,
+ ssl_ff_dhe_2048_group};
+ EXPECT_EQ(SECSuccess, SSL_DHEGroupPrefSet(server_->ssl_fd(), groups,
+ PR_ARRAY_SIZE(groups)));
+
+ Connect();
+ client_->CheckKEA(ssl_kea_dh, ssl_grp_ffdhe_3072, 3072);
+ server_->CheckKEA(ssl_kea_dh, ssl_grp_ffdhe_3072, 3072);
+ client_->CheckAuthType(ssl_auth_rsa_sign, ssl_sig_rsa_pss_rsae_sha256);
+ server_->CheckAuthType(ssl_auth_rsa_sign, ssl_sig_rsa_pss_rsae_sha256);
+}
+
+TEST_P(TlsConnectGenericPre13, MismatchDHE) {
+ EnableOnlyDheCiphers();
+ client_->SetOption(SSL_REQUIRE_DH_NAMED_GROUPS, PR_TRUE);
+ static const SSLDHEGroupType serverGroups[] = {ssl_ff_dhe_3072_group};
+ EXPECT_EQ(SECSuccess, SSL_DHEGroupPrefSet(server_->ssl_fd(), serverGroups,
+ PR_ARRAY_SIZE(serverGroups)));
+ static const SSLDHEGroupType clientGroups[] = {ssl_ff_dhe_2048_group};
+ EXPECT_EQ(SECSuccess, SSL_DHEGroupPrefSet(client_->ssl_fd(), clientGroups,
+ PR_ARRAY_SIZE(clientGroups)));
+
+ ConnectExpectAlert(server_, kTlsAlertHandshakeFailure);
+ server_->CheckErrorCode(SSL_ERROR_NO_CYPHER_OVERLAP);
+ client_->CheckErrorCode(SSL_ERROR_NO_CYPHER_OVERLAP);
+}
+
+TEST_P(TlsConnectTls13, ResumeFfdhe) {
+ EnableOnlyDheCiphers();
+ ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET);
+ Connect();
+ SendReceive(); // Need to read so that we absorb the session ticket.
+ CheckKeys(ssl_kea_dh, ssl_grp_ffdhe_2048, ssl_auth_rsa_sign,
+ ssl_sig_rsa_pss_rsae_sha256);
+
+ Reset();
+ ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET);
+ EnableOnlyDheCiphers();
+ auto clientCapture =
+ MakeTlsFilter<TlsExtensionCapture>(client_, ssl_tls13_pre_shared_key_xtn);
+ auto serverCapture =
+ MakeTlsFilter<TlsExtensionCapture>(server_, ssl_tls13_pre_shared_key_xtn);
+ ExpectResumption(RESUME_TICKET);
+ Connect();
+ CheckKeys(ssl_kea_dh, ssl_grp_ffdhe_2048, ssl_auth_rsa_sign,
+ ssl_sig_rsa_pss_rsae_sha256);
+ ASSERT_LT(0UL, clientCapture->extension().len());
+ ASSERT_LT(0UL, serverCapture->extension().len());
+}
+
+class TlsDheSkeChangeSignature : public TlsHandshakeFilter {
+ public:
+ TlsDheSkeChangeSignature(const std::shared_ptr<TlsAgent>& a, uint16_t version,
+ const uint8_t* data, size_t len)
+ : TlsHandshakeFilter(a, {kTlsHandshakeServerKeyExchange}),
+ version_(version),
+ data_(data),
+ len_(len) {}
+
+ protected:
+ virtual PacketFilter::Action FilterHandshake(const HandshakeHeader& header,
+ const DataBuffer& input,
+ DataBuffer* output) {
+ TlsParser parser(input);
+ EXPECT_TRUE(parser.SkipVariable(2)); // dh_p
+ EXPECT_TRUE(parser.SkipVariable(2)); // dh_g
+ EXPECT_TRUE(parser.SkipVariable(2)); // dh_Ys
+
+ // Copy DH params to output.
+ size_t offset = output->Write(0, input.data(), parser.consumed());
+
+ if (version_ == SSL_LIBRARY_VERSION_TLS_1_2) {
+ // Write signature algorithm.
+ offset = output->Write(offset, ssl_sig_dsa_sha256, 2);
+ }
+
+ // Write new signature.
+ offset = output->Write(offset, len_, 2);
+ offset = output->Write(offset, data_, len_);
+
+ return CHANGE;
+ }
+
+ private:
+ uint16_t version_;
+ const uint8_t* data_;
+ size_t len_;
+};
+
+TEST_P(TlsConnectGenericPre13, InvalidDERSignatureFfdhe) {
+ const uint8_t kBogusDheSignature[] = {
+ 0x30, 0x69, 0x3c, 0x02, 0x1c, 0x7d, 0x0b, 0x2f, 0x64, 0x00, 0x27,
+ 0xae, 0xcf, 0x1e, 0x28, 0x08, 0x6a, 0x7f, 0xb1, 0xbd, 0x78, 0xb5,
+ 0x3b, 0x8c, 0x8f, 0x59, 0xed, 0x8f, 0xee, 0x78, 0xeb, 0x2c, 0xe9,
+ 0x02, 0x1c, 0x6d, 0x7f, 0x3c, 0x0f, 0xf4, 0x44, 0x35, 0x0b, 0xb2,
+ 0x6d, 0xdc, 0xb8, 0x21, 0x87, 0xdd, 0x0d, 0xb9, 0x46, 0x09, 0x3e,
+ 0xef, 0x81, 0x5b, 0x37, 0x09, 0x39, 0xeb};
+
+ Reset(TlsAgent::kServerDsa);
+
+ const std::vector<SSLNamedGroup> client_groups = {ssl_grp_ffdhe_2048};
+ client_->ConfigNamedGroups(client_groups);
+
+ MakeTlsFilter<TlsDheSkeChangeSignature>(server_, version_, kBogusDheSignature,
+ sizeof(kBogusDheSignature));
+
+ ConnectExpectAlert(client_, kTlsAlertDecryptError);
+ client_->CheckErrorCode(SSL_ERROR_BAD_HANDSHAKE_HASH_VALUE);
+}
+
+TEST_P(TlsConnectTls12, ConnectInconsistentSigAlgDHE) {
+ EnableOnlyDheCiphers();
+
+ MakeTlsFilter<DHEServerKEXSigAlgReplacer>(server_,
+ ssl_sig_ecdsa_secp256r1_sha256);
+ ConnectExpectAlert(client_, kTlsAlertIllegalParameter);
+}
+
+static void CheckSkeSigScheme(
+ std::shared_ptr<TlsHandshakeRecorder>& capture_ske,
+ uint16_t expected_scheme) {
+ TlsParser parser(capture_ske->buffer());
+ EXPECT_TRUE(parser.SkipVariable(2)) << " read dh_p";
+ EXPECT_TRUE(parser.SkipVariable(2)) << " read dh_q";
+ EXPECT_TRUE(parser.SkipVariable(2)) << " read dh_Ys";
+
+ uint32_t tmp;
+ EXPECT_TRUE(parser.Read(&tmp, 2)) << " read sig_scheme";
+ EXPECT_EQ(expected_scheme, static_cast<uint16_t>(tmp));
+}
+
+TEST_P(TlsConnectTls12, ConnectSigAlgEnabledByPolicyDhe) {
+ EnableOnlyDheCiphers();
+
+ const std::vector<SSLSignatureScheme> schemes = {ssl_sig_rsa_pkcs1_sha1,
+ ssl_sig_rsa_pkcs1_sha384};
+
+ EnsureTlsSetup();
+ client_->SetSignatureSchemes(schemes.data(), schemes.size());
+ server_->SetSignatureSchemes(schemes.data(), schemes.size());
+ auto capture_ske = MakeTlsFilter<TlsHandshakeRecorder>(
+ server_, kTlsHandshakeServerKeyExchange);
+
+ StartConnect();
+ client_->Handshake(); // Send ClientHello
+
+ // Enable SHA-1 by policy.
+ SECStatus rv = NSS_SetAlgorithmPolicy(SEC_OID_SHA1, NSS_USE_ALG_IN_SSL_KX, 0);
+ ASSERT_EQ(SECSuccess, rv);
+ rv = NSS_SetAlgorithmPolicy(SEC_OID_APPLY_SSL_POLICY, NSS_USE_POLICY_IN_SSL,
+ 0);
+ ASSERT_EQ(SECSuccess, rv);
+
+ Handshake(); // Remainder of handshake
+ // The server should now report that it is connected
+ EXPECT_EQ(TlsAgent::STATE_CONNECTED, server_->state());
+
+ CheckSkeSigScheme(capture_ske, ssl_sig_rsa_pkcs1_sha1);
+}
+
+TEST_P(TlsConnectTls12, ConnectSigAlgDisabledByPolicyDhe) {
+ EnableOnlyDheCiphers();
+
+ const std::vector<SSLSignatureScheme> schemes = {ssl_sig_rsa_pkcs1_sha1,
+ ssl_sig_rsa_pkcs1_sha384};
+
+ EnsureTlsSetup();
+ client_->SetSignatureSchemes(schemes.data(), schemes.size());
+ server_->SetSignatureSchemes(schemes.data(), schemes.size());
+ auto capture_ske = MakeTlsFilter<TlsHandshakeRecorder>(
+ server_, kTlsHandshakeServerKeyExchange);
+
+ StartConnect();
+ client_->Handshake(); // Send ClientHello
+
+ // Disable SHA-1 by policy after sending ClientHello so that CH
+ // includes SHA-1 signature scheme.
+ SECStatus rv = NSS_SetAlgorithmPolicy(SEC_OID_SHA1, 0, NSS_USE_ALG_IN_SSL_KX);
+ ASSERT_EQ(SECSuccess, rv);
+ rv = NSS_SetAlgorithmPolicy(SEC_OID_APPLY_SSL_POLICY, NSS_USE_POLICY_IN_SSL,
+ 0);
+ ASSERT_EQ(SECSuccess, rv);
+
+ Handshake(); // Remainder of handshake
+ // The server should now report that it is connected
+ EXPECT_EQ(TlsAgent::STATE_CONNECTED, server_->state());
+
+ CheckSkeSigScheme(capture_ske, ssl_sig_rsa_pkcs1_sha384);
+}
+
+TEST_P(TlsConnectPre12, ConnectSigAlgDisabledWeakGroupByOption3072DhePre12) {
+ EnableOnlyDheCiphers();
+
+ // explicitly enable the weak groups
+ EXPECT_EQ(SECSuccess,
+ SSL_EnableWeakDHEPrimeGroup(server_->ssl_fd(), PR_TRUE));
+ EXPECT_EQ(SECSuccess,
+ SSL_EnableWeakDHEPrimeGroup(client_->ssl_fd(), PR_TRUE));
+ server_->SetNssOption(NSS_DH_MIN_KEY_SIZE, 3072);
+ Connect();
+ client_->CheckKEA(ssl_kea_dh, ssl_grp_ffdhe_3072, 3072);
+ server_->CheckKEA(ssl_kea_dh, ssl_grp_ffdhe_3072, 3072);
+}
+
+TEST_P(TlsConnectPre12, ConnectSigAlgDisabledWeakGroupByOption2048DhePre12) {
+ EnableOnlyDheCiphers();
+
+ // explicitly enable the weak groups
+ EXPECT_EQ(SECSuccess,
+ SSL_EnableWeakDHEPrimeGroup(server_->ssl_fd(), PR_TRUE));
+ EXPECT_EQ(SECSuccess,
+ SSL_EnableWeakDHEPrimeGroup(client_->ssl_fd(), PR_TRUE));
+ server_->SetNssOption(NSS_DH_MIN_KEY_SIZE, 2048);
+ Connect();
+ client_->CheckKEA(ssl_kea_dh, ssl_grp_ffdhe_2048, 2048);
+ server_->CheckKEA(ssl_kea_dh, ssl_grp_ffdhe_2048, 2048);
+}
+
+TEST_P(TlsConnectPre12, ConnectSigAlgDisabledByPolicyDhePre12) {
+ EnableOnlyDheCiphers();
+
+ EnsureTlsSetup();
+ StartConnect();
+ client_->Handshake(); // Send ClientHello
+
+ // Disable SHA-1 by policy. This will cause the connection fail as
+ // TLS 1.1 or earlier uses combined SHA-1 + MD5 signature.
+ SECStatus rv = NSS_SetAlgorithmPolicy(SEC_OID_SHA1, 0, NSS_USE_ALG_IN_SSL_KX);
+ ASSERT_EQ(SECSuccess, rv);
+ rv = NSS_SetAlgorithmPolicy(SEC_OID_APPLY_SSL_POLICY, NSS_USE_POLICY_IN_SSL,
+ 0);
+ ASSERT_EQ(SECSuccess, rv);
+
+ server_->ExpectSendAlert(kTlsAlertHandshakeFailure);
+ client_->ExpectReceiveAlert(kTlsAlertHandshakeFailure);
+
+ // Remainder of handshake
+ Handshake();
+
+ server_->CheckErrorCode(SSL_ERROR_UNSUPPORTED_HASH_ALGORITHM);
+}
+
+TEST_P(TlsConnectTls12, ConnectSigAlgDisablePreferredGroupByOption3072Dhe) {
+ EnableOnlyDheCiphers();
+ static const SSLDHEGroupType dhe_groups[] = {
+ ssl_ff_dhe_2048_group, // first in the lists is the preferred group
+ ssl_ff_dhe_3072_group};
+
+ server_->SetNssOption(NSS_DH_MIN_KEY_SIZE, 3072);
+ EXPECT_EQ(SECSuccess, SSL_DHEGroupPrefSet(server_->ssl_fd(), &dhe_groups[0],
+ PR_ARRAY_SIZE(dhe_groups)));
+ Connect();
+ // our option size should override the preferred group
+ client_->CheckKEA(ssl_kea_dh, ssl_grp_ffdhe_3072, 3072);
+ server_->CheckKEA(ssl_kea_dh, ssl_grp_ffdhe_3072, 3072);
+}
+
+TEST_P(TlsConnectTls12, ConnectSigAlgDisableGroupByOption3072Dhe) {
+ EnableOnlyDheCiphers();
+
+ server_->SetNssOption(NSS_DH_MIN_KEY_SIZE, 3072);
+ Connect();
+ client_->CheckKEA(ssl_kea_dh, ssl_grp_ffdhe_3072, 3072);
+ server_->CheckKEA(ssl_kea_dh, ssl_grp_ffdhe_3072, 3072);
+}
+
+} // namespace nss_test
diff --git a/security/nss/gtests/ssl_gtest/ssl_drop_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_drop_unittest.cc
new file mode 100644
index 0000000000..98b29921ea
--- /dev/null
+++ b/security/nss/gtests/ssl_gtest/ssl_drop_unittest.cc
@@ -0,0 +1,914 @@
+/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
+/* vim: set ts=2 et sw=2 tw=80: */
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this file,
+ * You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+#include "secerr.h"
+#include "ssl.h"
+#include "sslexp.h"
+
+extern "C" {
+// This is not something that should make you happy.
+#include "libssl_internals.h"
+}
+
+#include "gtest_utils.h"
+#include "nss_scoped_ptrs.h"
+#include "tls_connect.h"
+#include "tls_filter.h"
+#include "tls_parser.h"
+
+namespace nss_test {
+
+TEST_P(TlsConnectDatagramPre13, DropClientFirstFlightOnce) {
+ client_->SetFilter(std::make_shared<SelectiveDropFilter>(0x1));
+ Connect();
+ SendReceive();
+}
+
+TEST_P(TlsConnectDatagramPre13, DropServerFirstFlightOnce) {
+ server_->SetFilter(std::make_shared<SelectiveDropFilter>(0x1));
+ Connect();
+ SendReceive();
+}
+
+// This drops the first transmission from both the client and server of all
+// flights that they send. Note: In DTLS 1.3, the shorter handshake means that
+// this will also drop some application data, so we can't call SendReceive().
+TEST_P(TlsConnectDatagramPre13, DropAllFirstTransmissions) {
+ client_->SetFilter(std::make_shared<SelectiveDropFilter>(0x15));
+ server_->SetFilter(std::make_shared<SelectiveDropFilter>(0x5));
+ Connect();
+}
+
+// This drops the server's first flight three times.
+TEST_P(TlsConnectDatagramPre13, DropServerFirstFlightThrice) {
+ server_->SetFilter(std::make_shared<SelectiveDropFilter>(0x7));
+ Connect();
+}
+
+// This drops the client's second flight once
+TEST_P(TlsConnectDatagramPre13, DropClientSecondFlightOnce) {
+ client_->SetFilter(std::make_shared<SelectiveDropFilter>(0x2));
+ Connect();
+}
+
+// This drops the client's second flight three times.
+TEST_P(TlsConnectDatagramPre13, DropClientSecondFlightThrice) {
+ client_->SetFilter(std::make_shared<SelectiveDropFilter>(0xe));
+ Connect();
+}
+
+// This drops the server's second flight three times.
+TEST_P(TlsConnectDatagramPre13, DropServerSecondFlightThrice) {
+ server_->SetFilter(std::make_shared<SelectiveDropFilter>(0xe));
+ Connect();
+}
+
+static void CheckAcks(const std::shared_ptr<TlsRecordRecorder>& acks,
+ size_t index, std::vector<uint64_t> expected) {
+ ASSERT_LT(index, acks->count());
+ const DataBuffer& buf = acks->record(index).buffer;
+ size_t offset = 2;
+ uint64_t len;
+
+ EXPECT_EQ(2 + expected.size() * 8, buf.len());
+ ASSERT_TRUE(buf.Read(0, 2, &len));
+ ASSERT_EQ(static_cast<size_t>(len + 2), buf.len());
+ if ((2 + expected.size() * 8) != buf.len()) {
+ while (offset < buf.len()) {
+ uint64_t ack;
+ ASSERT_TRUE(buf.Read(offset, 8, &ack));
+ offset += 8;
+ std::cerr << "Ack=0x" << std::hex << ack << std::dec << std::endl;
+ }
+ return;
+ }
+
+ for (size_t i = 0; i < expected.size(); ++i) {
+ uint64_t a = expected[i];
+ uint64_t ack;
+ ASSERT_TRUE(buf.Read(offset, 8, &ack));
+ offset += 8;
+ if (a != ack) {
+ ADD_FAILURE() << "Wrong ack " << i << " expected=0x" << std::hex << a
+ << " got=0x" << ack << std::dec;
+ }
+ }
+}
+
+class TlsDropDatagram13 : public TlsConnectDatagram13,
+ public ::testing::WithParamInterface<bool> {
+ public:
+ TlsDropDatagram13()
+ : client_filters_(),
+ server_filters_(),
+ expected_client_acks_(0),
+ expected_server_acks_(1) {}
+
+ void SetUp() override {
+ TlsConnectDatagram13::SetUp();
+ ConfigureSessionCache(RESUME_NONE, RESUME_NONE);
+ int short_header = GetParam() ? PR_TRUE : PR_FALSE;
+ client_->SetOption(SSL_ENABLE_DTLS_SHORT_HEADER, short_header);
+ server_->SetOption(SSL_ENABLE_DTLS_SHORT_HEADER, short_header);
+ SetFilters();
+ }
+
+ void SetFilters() {
+ EnsureTlsSetup();
+ client_filters_.Init(client_);
+ server_filters_.Init(server_);
+ }
+
+ void HandshakeAndAck(const std::shared_ptr<TlsAgent>& agent) {
+ agent->Handshake(); // Read flight.
+ ShiftDtlsTimers();
+ agent->Handshake(); // Generate ACK.
+ }
+
+ void ShrinkPostServerHelloMtu() {
+ // Abuse the custom extension mechanism to modify the MTU so that the
+ // Certificate message is split into two pieces.
+ ASSERT_EQ(
+ SECSuccess,
+ SSL_InstallExtensionHooks(
+ server_->ssl_fd(), 1,
+ [](PRFileDesc* fd, SSLHandshakeType message, PRUint8* data,
+ unsigned int* len, unsigned int maxLen, void* arg) -> PRBool {
+ SSLInt_SetMTU(fd, 500); // Splits the certificate.
+ return PR_FALSE;
+ },
+ nullptr,
+ [](PRFileDesc* fd, SSLHandshakeType message, const PRUint8* data,
+ unsigned int len, SSLAlertDescription* alert,
+ void* arg) -> SECStatus { return SECSuccess; },
+ nullptr));
+ }
+
+ protected:
+ class DropAckChain {
+ public:
+ DropAckChain()
+ : records_(nullptr), ack_(nullptr), drop_(nullptr), chain_(nullptr) {}
+
+ void Init(const std::shared_ptr<TlsAgent>& agent) {
+ records_ = std::make_shared<TlsRecordRecorder>(agent);
+ ack_ = std::make_shared<TlsRecordRecorder>(agent, ssl_ct_ack);
+ ack_->EnableDecryption();
+ drop_ = std::make_shared<SelectiveRecordDropFilter>(agent, 0, false);
+ chain_ = std::make_shared<ChainedPacketFilter>(
+ ChainedPacketFilterInit({records_, ack_, drop_}));
+ agent->SetFilter(chain_);
+ }
+
+ const TlsRecord& record(size_t i) const { return records_->record(i); }
+
+ std::shared_ptr<TlsRecordRecorder> records_;
+ std::shared_ptr<TlsRecordRecorder> ack_;
+ std::shared_ptr<SelectiveRecordDropFilter> drop_;
+ std::shared_ptr<PacketFilter> chain_;
+ };
+
+ void CheckedHandshakeSendReceive() {
+ Handshake();
+ CheckPostHandshake();
+ }
+
+ void CheckPostHandshake() {
+ CheckConnected();
+ SendReceive();
+ EXPECT_EQ(expected_client_acks_, client_filters_.ack_->count());
+ EXPECT_EQ(expected_server_acks_, server_filters_.ack_->count());
+ }
+
+ protected:
+ DropAckChain client_filters_;
+ DropAckChain server_filters_;
+ size_t expected_client_acks_;
+ size_t expected_server_acks_;
+};
+
+// All of these tests produce a minimum one ACK, from the server
+// to the client upon receiving the client Finished.
+// Dropping complete first and second flights does not produce
+// ACKs
+TEST_P(TlsDropDatagram13, DropClientFirstFlightOnce) {
+ client_filters_.drop_->Reset({0});
+ StartConnect();
+ client_->Handshake();
+ server_->Handshake();
+ CheckedHandshakeSendReceive();
+ CheckAcks(server_filters_.ack_, 0, {0x0002000000000000ULL});
+}
+
+TEST_P(TlsDropDatagram13, DropServerFirstFlightOnce) {
+ server_filters_.drop_->Reset(0xff);
+ StartConnect();
+ client_->Handshake();
+ // Send the first flight, all dropped.
+ server_->Handshake();
+ server_filters_.drop_->Disable();
+ CheckedHandshakeSendReceive();
+ CheckAcks(server_filters_.ack_, 0, {0x0002000000000000ULL});
+}
+
+// Dropping the server's first record also does not produce
+// an ACK because the next record is ignored.
+// TODO(ekr@rtfm.com): We should generate an empty ACK.
+TEST_P(TlsDropDatagram13, DropServerFirstRecordOnce) {
+ server_filters_.drop_->Reset({0});
+ StartConnect();
+ client_->Handshake();
+ server_->Handshake();
+ Handshake();
+ CheckedHandshakeSendReceive();
+ CheckAcks(server_filters_.ack_, 0, {0x0002000000000000ULL});
+}
+
+// Dropping the second packet of the server's flight should
+// produce an ACK.
+TEST_P(TlsDropDatagram13, DropServerSecondRecordOnce) {
+ server_filters_.drop_->Reset({1});
+ StartConnect();
+ client_->Handshake();
+ server_->Handshake();
+ HandshakeAndAck(client_);
+ expected_client_acks_ = 1;
+ CheckedHandshakeSendReceive();
+ CheckAcks(client_filters_.ack_, 0, {0}); // ServerHello
+ CheckAcks(server_filters_.ack_, 0, {0x0002000000000000ULL});
+}
+
+// Drop the server ACK and verify that the client retransmits
+// the ClientHello.
+TEST_P(TlsDropDatagram13, DropServerAckOnce) {
+ StartConnect();
+ client_->Handshake();
+ server_->Handshake();
+ // At this point the server has sent it's first flight,
+ // so make it drop the ACK.
+ server_filters_.drop_->Reset({0});
+ client_->Handshake(); // Send the client Finished.
+ server_->Handshake(); // Receive the Finished and send the ACK.
+ EXPECT_EQ(TlsAgent::STATE_CONNECTED, client_->state());
+ EXPECT_EQ(TlsAgent::STATE_CONNECTED, server_->state());
+ // Wait for the DTLS timeout to make sure we retransmit the
+ // Finished.
+ ShiftDtlsTimers();
+ client_->Handshake(); // Retransmit the Finished.
+ server_->Handshake(); // Read the Finished and send an ACK.
+ uint8_t buf[1];
+ PRInt32 rv = PR_Read(client_->ssl_fd(), buf, sizeof(buf));
+ expected_server_acks_ = 2;
+ EXPECT_GT(0, rv);
+ EXPECT_EQ(PR_WOULD_BLOCK_ERROR, PORT_GetError());
+ CheckPostHandshake();
+ // There should be two copies of the finished ACK
+ CheckAcks(server_filters_.ack_, 0, {0x0002000000000000ULL});
+ CheckAcks(server_filters_.ack_, 1, {0x0002000000000000ULL});
+}
+
+// Drop the client certificate verify.
+TEST_P(TlsDropDatagram13, DropClientCertVerify) {
+ StartConnect();
+ client_->SetupClientAuth();
+ server_->RequestClientAuth(true);
+ client_->Handshake();
+ server_->Handshake();
+ // Have the client drop Cert Verify
+ client_filters_.drop_->Reset({1});
+ expected_server_acks_ = 2;
+ CheckedHandshakeSendReceive();
+ // Ack of the Cert.
+ CheckAcks(server_filters_.ack_, 0, {0x0002000000000000ULL});
+ // Ack of the whole client handshake.
+ CheckAcks(
+ server_filters_.ack_, 1,
+ {0x0002000000000000ULL, // CH (we drop everything after this on client)
+ 0x0002000000000003ULL, // CT (2)
+ 0x0002000000000004ULL}); // FIN (2)
+}
+
+// Shrink the MTU down so that certs get split and drop the first piece.
+TEST_P(TlsDropDatagram13, DropFirstHalfOfServerCertificate) {
+ server_filters_.drop_->Reset({2});
+ StartConnect();
+ ShrinkPostServerHelloMtu();
+ client_->Handshake();
+ server_->Handshake();
+ // Check that things got split.
+ EXPECT_EQ(6UL,
+ server_filters_.records_->count()); // SH, EE, CT1, CT2, CV, FIN
+ size_t ct1_size = server_filters_.record(2).buffer.len();
+ server_filters_.records_->Clear();
+ expected_client_acks_ = 1;
+ HandshakeAndAck(client_);
+ server_->Handshake(); // Retransmit
+ EXPECT_EQ(3UL, server_filters_.records_->count()); // CT2, CV, FIN
+ // Check that the first record is CT1 (which is identical to the same
+ // as the previous CT1).
+ EXPECT_EQ(ct1_size, server_filters_.record(0).buffer.len());
+ CheckedHandshakeSendReceive();
+ CheckAcks(client_filters_.ack_, 0,
+ {0, // SH
+ 0x0002000000000000ULL, // EE
+ 0x0002000000000002ULL}); // CT2
+ CheckAcks(server_filters_.ack_, 0, {0x0002000000000000ULL});
+}
+
+// Shrink the MTU down so that certs get split and drop the second piece.
+TEST_P(TlsDropDatagram13, DropSecondHalfOfServerCertificate) {
+ server_filters_.drop_->Reset({3});
+ StartConnect();
+ ShrinkPostServerHelloMtu();
+ client_->Handshake();
+ server_->Handshake();
+ // Check that things got split.
+ EXPECT_EQ(6UL,
+ server_filters_.records_->count()); // SH, EE, CT1, CT2, CV, FIN
+ size_t ct1_size = server_filters_.record(3).buffer.len();
+ server_filters_.records_->Clear();
+ expected_client_acks_ = 1;
+ HandshakeAndAck(client_);
+ server_->Handshake(); // Retransmit
+ EXPECT_EQ(3UL, server_filters_.records_->count()); // CT1, CV, FIN
+ // Check that the first record is CT1
+ EXPECT_EQ(ct1_size, server_filters_.record(0).buffer.len());
+ CheckedHandshakeSendReceive();
+ CheckAcks(client_filters_.ack_, 0,
+ {
+ 0, // SH
+ 0x0002000000000000ULL, // EE
+ 0x0002000000000001ULL, // CT1
+ });
+ CheckAcks(server_filters_.ack_, 0, {0x0002000000000000ULL});
+}
+
+// In this test, the Certificate message is sent four times, we drop all or part
+// of the first three attempts:
+// 1. Without fragmentation so that we can see how big it is - we drop that.
+// 2. In two pieces - we drop half AND the resulting ACK.
+// 3. In three pieces - we drop the middle piece.
+//
+// After that we let all the ACKs through and allow the handshake to complete
+// without further interference.
+//
+// This allows us to test that ranges of handshake messages are sent correctly
+// even when there are overlapping acknowledgments; that ACKs with duplicate or
+// overlapping message ranges are handled properly; and that extra
+// retransmissions are handled properly.
+class TlsFragmentationAndRecoveryTest : public TlsDropDatagram13 {
+ public:
+ TlsFragmentationAndRecoveryTest() : cert_len_(0) {}
+
+ protected:
+ void RunTest(size_t dropped_half) {
+ FirstFlightDropCertificate();
+
+ SecondAttemptDropHalf(dropped_half);
+ size_t dropped_half_size = server_record_len(dropped_half);
+ size_t second_flight_count = server_filters_.records_->count();
+
+ ThirdAttemptDropMiddle();
+ size_t repaired_third_size = server_record_len((dropped_half == 0) ? 0 : 2);
+ size_t third_flight_count = server_filters_.records_->count();
+
+ AckAndCompleteRetransmission();
+ size_t final_server_flight_count = server_filters_.records_->count();
+ EXPECT_LE(3U, final_server_flight_count); // CT(sixth), CV, Fin
+ CheckSizeOfSixth(dropped_half_size, repaired_third_size);
+
+ SendDelayedAck();
+ // Same number of messages as the last flight.
+ EXPECT_EQ(final_server_flight_count, server_filters_.records_->count());
+ // Double check that the Certificate size is still correct.
+ CheckSizeOfSixth(dropped_half_size, repaired_third_size);
+
+ CompleteHandshake(final_server_flight_count);
+
+ // This is the ACK for the first attempt to send a whole certificate.
+ std::vector<uint64_t> client_acks = {
+ 0, // SH
+ 0x0002000000000000ULL // EE
+ };
+ CheckAcks(client_filters_.ack_, 0, client_acks);
+ // And from the second attempt for the half was kept (we delayed this ACK).
+ client_acks.push_back(0x0002000000000000ULL + second_flight_count +
+ ~dropped_half % 2);
+ CheckAcks(client_filters_.ack_, 1, client_acks);
+ // And the third attempt where the first and last thirds got through.
+ client_acks.push_back(0x0002000000000000ULL + second_flight_count +
+ third_flight_count - 1);
+ client_acks.push_back(0x0002000000000000ULL + second_flight_count +
+ third_flight_count + 1);
+ CheckAcks(client_filters_.ack_, 2, client_acks);
+ CheckAcks(server_filters_.ack_, 0, {0x0002000000000000ULL});
+ }
+
+ private:
+ void FirstFlightDropCertificate() {
+ StartConnect();
+ client_->Handshake();
+
+ // Note: 1 << N is the Nth packet, starting from zero.
+ server_filters_.drop_->Reset(1 << 2); // Drop Cert0.
+ server_->Handshake();
+ EXPECT_EQ(5U, server_filters_.records_->count()); // SH, EE, CT, CV, Fin
+ cert_len_ = server_filters_.records_->record(2).buffer.len();
+
+ HandshakeAndAck(client_);
+ EXPECT_EQ(2U, client_filters_.records_->count());
+ }
+
+ // Lower the MTU so that the server has to split the certificate in two
+ // pieces. The server resends Certificate (in two), plus CV and Fin.
+ void SecondAttemptDropHalf(size_t dropped_half) {
+ ASSERT_LE(0U, dropped_half);
+ ASSERT_GT(2U, dropped_half);
+ server_filters_.records_->Clear();
+ server_filters_.drop_->Reset({dropped_half}); // Drop Cert1[half]
+ SplitServerMtu(2);
+ server_->Handshake();
+ EXPECT_LE(4U, server_filters_.records_->count()); // CT x2, CV, Fin
+
+ // Generate and capture the ACK from the client.
+ client_filters_.drop_->Reset({0});
+ HandshakeAndAck(client_);
+ EXPECT_EQ(3U, client_filters_.records_->count());
+ }
+
+ // Lower the MTU again so that the server sends Certificate cut into three
+ // pieces. Drop the middle piece.
+ void ThirdAttemptDropMiddle() {
+ server_filters_.records_->Clear();
+ server_filters_.drop_->Reset({1}); // Drop Cert2[1] (of 3)
+ SplitServerMtu(3);
+ // Because we dropped the client ACK, the server retransmits on a timer.
+ ShiftDtlsTimers();
+ server_->Handshake();
+ EXPECT_LE(5U, server_filters_.records_->count()); // CT x3, CV, Fin
+ }
+
+ void AckAndCompleteRetransmission() {
+ // Generate ACKs.
+ HandshakeAndAck(client_);
+ // The server should send the final sixth of the certificate: the client has
+ // acknowledged the first half and the last third. Also send CV and Fin.
+ server_filters_.records_->Clear();
+ server_->Handshake();
+ }
+
+ void CheckSizeOfSixth(size_t size_of_half, size_t size_of_third) {
+ // Work out if the final sixth is the right size. We get the records with
+ // overheads added, which obscures the length of the payload. We want to
+ // ensure that the server only sent the missing sixth of the Certificate.
+ //
+ // We captured |size_of_half + overhead| and |size_of_third + overhead| and
+ // want to calculate |size_of_third - size_of_third + overhead|. We can't
+ // calculate |overhead|, but it is is (currently) always a handshake message
+ // header, a content type, and an authentication tag:
+ static const size_t record_overhead = 12 + 1 + 16;
+ EXPECT_EQ(size_of_half - size_of_third + record_overhead,
+ server_filters_.records_->record(0).buffer.len());
+ }
+
+ void SendDelayedAck() {
+ // Send the ACK we held back. The reordered ACK doesn't add new
+ // information,
+ // but triggers an extra retransmission of the missing records again (even
+ // though the client has all that it needs).
+ client_->SendRecordDirect(client_filters_.records_->record(2));
+ server_filters_.records_->Clear();
+ server_->Handshake();
+ }
+
+ void CompleteHandshake(size_t extra_retransmissions) {
+ // All this messing around shouldn't cause a failure...
+ Handshake();
+ // ...but it leaves a mess. Add an extra few calls to Handshake() for the
+ // client so that it absorbs the extra retransmissions.
+ for (size_t i = 0; i < extra_retransmissions; ++i) {
+ client_->Handshake();
+ }
+ CheckConnected();
+ }
+
+ // Split the server MTU so that the Certificate is split into |count| pieces.
+ // The calculation doesn't need to be perfect as long as the Certificate
+ // message is split into the right number of pieces.
+ void SplitServerMtu(size_t count) {
+ // Set the MTU based on the formula:
+ // bare_size = cert_len_ - actual_overhead
+ // MTU = ceil(bare_size / count) + pessimistic_overhead
+ //
+ // actual_overhead is the amount of actual overhead on the record we
+ // captured, which is (note that our length doesn't include the header):
+ static const size_t actual_overhead = 12 + // handshake message header
+ 1 + // content type
+ 16; // authentication tag
+ size_t bare_size = cert_len_ - actual_overhead;
+
+ // pessimistic_overhead is the amount of expansion that NSS assumes will be
+ // added to each handshake record. Right now, that is DTLS_MIN_FRAGMENT:
+ static const size_t pessimistic_overhead =
+ 12 + // handshake message header
+ 1 + // content type
+ 13 + // record header length
+ 64; // maximum record expansion: IV, MAC and block cipher expansion
+
+ size_t mtu = (bare_size + count - 1) / count + pessimistic_overhead;
+ if (g_ssl_gtest_verbose) {
+ std::cerr << "server: set MTU to " << mtu << std::endl;
+ }
+ EXPECT_EQ(SECSuccess, SSLInt_SetMTU(server_->ssl_fd(), mtu));
+ }
+
+ size_t server_record_len(size_t index) const {
+ return server_filters_.records_->record(index).buffer.len();
+ }
+
+ size_t cert_len_;
+};
+
+TEST_P(TlsFragmentationAndRecoveryTest, DropFirstHalf) { RunTest(0); }
+
+TEST_P(TlsFragmentationAndRecoveryTest, DropSecondHalf) { RunTest(1); }
+
+TEST_P(TlsDropDatagram13, NoDropsDuringZeroRtt) {
+ SetupForZeroRtt();
+ SetFilters();
+ std::cerr << "Starting second handshake" << std::endl;
+ client_->Set0RttEnabled(true);
+ server_->Set0RttEnabled(true);
+ ExpectResumption(RESUME_TICKET);
+ ZeroRttSendReceive(true, true);
+ Handshake();
+ ExpectEarlyDataAccepted(true);
+ CheckConnected();
+ SendReceive();
+ EXPECT_EQ(0U, client_filters_.ack_->count());
+ CheckAcks(server_filters_.ack_, 0,
+ {0x0001000000000001ULL, // EOED
+ 0x0002000000000000ULL}); // Finished
+}
+
+TEST_P(TlsDropDatagram13, DropEEDuringZeroRtt) {
+ SetupForZeroRtt();
+ SetFilters();
+ std::cerr << "Starting second handshake" << std::endl;
+ client_->Set0RttEnabled(true);
+ server_->Set0RttEnabled(true);
+ ExpectResumption(RESUME_TICKET);
+ server_filters_.drop_->Reset({1});
+ ZeroRttSendReceive(true, true);
+ HandshakeAndAck(client_);
+ Handshake();
+ ExpectEarlyDataAccepted(true);
+ CheckConnected();
+ SendReceive();
+ CheckAcks(client_filters_.ack_, 0, {0});
+ CheckAcks(server_filters_.ack_, 0,
+ {0x0001000000000002ULL, // EOED
+ 0x0002000000000000ULL}); // Finished
+}
+
+class TlsReorderDatagram13 : public TlsDropDatagram13 {
+ public:
+ TlsReorderDatagram13() {}
+
+ // Send records from the records buffer in the given order.
+ void ReSend(TlsAgent::Role side, std::vector<size_t> indices) {
+ std::shared_ptr<TlsAgent> agent;
+ std::shared_ptr<TlsRecordRecorder> records;
+
+ if (side == TlsAgent::CLIENT) {
+ agent = client_;
+ records = client_filters_.records_;
+ } else {
+ agent = server_;
+ records = server_filters_.records_;
+ }
+
+ for (auto i : indices) {
+ agent->SendRecordDirect(records->record(i));
+ }
+ }
+};
+
+// Reorder the server records so that EE comes at the end
+// of the flight and will still produce an ACK.
+TEST_P(TlsDropDatagram13, ReorderServerEE) {
+ server_filters_.drop_->Reset({1});
+ StartConnect();
+ client_->Handshake();
+ server_->Handshake();
+ // We dropped EE, now reinject.
+ server_->SendRecordDirect(server_filters_.record(1));
+ expected_client_acks_ = 1;
+ HandshakeAndAck(client_);
+ CheckedHandshakeSendReceive();
+ CheckAcks(client_filters_.ack_, 0,
+ {
+ 0, // SH
+ 0x0002000000000000, // EE
+ });
+ CheckAcks(server_filters_.ack_, 0, {0x0002000000000000ULL});
+}
+
+// The client sends an out of order non-handshake message
+// but with the handshake key.
+TEST_F(TlsConnectDatagram13, SendOutOfOrderAppWithHandshakeKey) {
+ StartConnect();
+ // Capturing secrets means that we can't use decrypting filters on the client.
+ TlsSendCipherSpecCapturer capturer(client_);
+ client_->Handshake();
+ server_->Handshake();
+ client_->Handshake();
+ EXPECT_EQ(TlsAgent::STATE_CONNECTED, client_->state());
+ server_->Handshake();
+ EXPECT_EQ(TlsAgent::STATE_CONNECTED, server_->state());
+ // After the client sends Finished, inject an app data record
+ // with the handshake key. This should produce an alert.
+ uint8_t buf[] = {'a', 'b', 'c'};
+ auto spec = capturer.spec(0);
+ ASSERT_NE(nullptr, spec.get());
+ ASSERT_EQ(2, spec->epoch());
+
+ uint8_t dtls13_ct = kCtDtlsCiphertext | kCtDtlsCiphertext16bSeqno |
+ kCtDtlsCiphertextLengthPresent;
+ ASSERT_TRUE(client_->SendEncryptedRecord(spec, 0x0002000000000002, dtls13_ct,
+ DataBuffer(buf, sizeof(buf))));
+
+ // Now have the server consume the bogus message.
+ server_->ExpectSendAlert(illegal_parameter, kTlsAlertFatal);
+ server_->Handshake();
+ EXPECT_EQ(TlsAgent::STATE_ERROR, server_->state());
+ EXPECT_EQ(SSL_ERROR_RX_UNKNOWN_RECORD_TYPE, PORT_GetError());
+}
+
+TEST_F(TlsConnectDatagram13, SendOutOfOrderHsNonsenseWithHandshakeKey) {
+ StartConnect();
+ TlsSendCipherSpecCapturer capturer(client_);
+ auto acks = MakeTlsFilter<TlsRecordRecorder>(server_, ssl_ct_ack);
+ acks->EnableDecryption();
+
+ client_->Handshake();
+ server_->Handshake();
+ client_->Handshake();
+ EXPECT_EQ(TlsAgent::STATE_CONNECTED, client_->state());
+ server_->Handshake();
+ EXPECT_EQ(TlsAgent::STATE_CONNECTED, server_->state());
+ // Inject a new bogus handshake record, which the server responds
+ // to by just ACKing the original one (we ignore the contents).
+ uint8_t buf[] = {'a', 'b', 'c'};
+ auto spec = capturer.spec(0);
+ ASSERT_NE(nullptr, spec.get());
+ ASSERT_EQ(2, spec->epoch());
+ ASSERT_TRUE(client_->SendEncryptedRecord(spec, 0x0002000000000002,
+ ssl_ct_handshake,
+ DataBuffer(buf, sizeof(buf))));
+ server_->Handshake();
+ EXPECT_EQ(2UL, acks->count());
+ // The server acknowledges client Finished twice.
+ CheckAcks(acks, 0, {0x0002000000000000ULL});
+ CheckAcks(acks, 1, {0x0002000000000000ULL});
+}
+
+// Shrink the MTU down so that certs get split and then swap the first and
+// second pieces of the server certificate.
+TEST_P(TlsReorderDatagram13, ReorderServerCertificate) {
+ StartConnect();
+ ShrinkPostServerHelloMtu();
+ client_->Handshake();
+ // Drop the entire handshake flight so we can reorder.
+ server_filters_.drop_->Reset(0xff);
+ server_->Handshake();
+ // Check that things got split.
+ EXPECT_EQ(6UL,
+ server_filters_.records_->count()); // CH, EE, CT1, CT2, CV, FIN
+ // Now re-send things in a different order.
+ ReSend(TlsAgent::SERVER, std::vector<size_t>{0, 1, 3, 2, 4, 5});
+ // Clear.
+ server_filters_.drop_->Disable();
+ server_filters_.records_->Clear();
+ // Wait for client to send ACK.
+ ShiftDtlsTimers();
+ CheckedHandshakeSendReceive();
+ EXPECT_EQ(2UL, server_filters_.records_->count()); // ACK + Data
+ CheckAcks(server_filters_.ack_, 0, {0x0002000000000000ULL});
+}
+
+TEST_P(TlsReorderDatagram13, DataAfterEOEDDuringZeroRtt) {
+ SetupForZeroRtt();
+ SetFilters();
+ std::cerr << "Starting second handshake" << std::endl;
+ client_->Set0RttEnabled(true);
+ server_->Set0RttEnabled(true);
+ ExpectResumption(RESUME_TICKET);
+ // Send the client's first flight of zero RTT data.
+ ZeroRttSendReceive(true, true);
+ // Now send another client application data record but
+ // capture it.
+ client_filters_.records_->Clear();
+ client_filters_.drop_->Reset(0xff);
+ const char* k0RttData = "123456";
+ const PRInt32 k0RttDataLen = static_cast<PRInt32>(strlen(k0RttData));
+ PRInt32 rv =
+ PR_Write(client_->ssl_fd(), k0RttData, k0RttDataLen); // 0-RTT write.
+ EXPECT_EQ(k0RttDataLen, rv);
+ EXPECT_EQ(1UL, client_filters_.records_->count()); // data
+ server_->Handshake();
+ client_->Handshake();
+ ExpectEarlyDataAccepted(true);
+ // The server still hasn't received anything at this point.
+ EXPECT_EQ(3UL, client_filters_.records_->count()); // data, EOED, FIN
+ EXPECT_EQ(TlsAgent::STATE_CONNECTED, client_->state());
+ EXPECT_EQ(TlsAgent::STATE_CONNECTING, server_->state());
+ // Now re-send the client's messages: EOED, data, FIN
+ ReSend(TlsAgent::CLIENT, std::vector<size_t>({1, 0, 2}));
+ server_->Handshake();
+ CheckConnected();
+ EXPECT_EQ(0U, client_filters_.ack_->count());
+ // Acknowledgements for EOED and Finished.
+ CheckAcks(server_filters_.ack_, 0,
+ {0x0001000000000002ULL, 0x0002000000000000ULL});
+ uint8_t buf[8];
+ rv = PR_Read(server_->ssl_fd(), buf, sizeof(buf));
+ EXPECT_EQ(-1, rv);
+ EXPECT_EQ(PR_WOULD_BLOCK_ERROR, PORT_GetError());
+}
+
+TEST_P(TlsReorderDatagram13, DataAfterFinDuringZeroRtt) {
+ SetupForZeroRtt();
+ SetFilters();
+ std::cerr << "Starting second handshake" << std::endl;
+ client_->Set0RttEnabled(true);
+ server_->Set0RttEnabled(true);
+ ExpectResumption(RESUME_TICKET);
+ // Send the client's first flight of zero RTT data.
+ ZeroRttSendReceive(true, true);
+ // Now send another client application data record but
+ // capture it.
+ client_filters_.records_->Clear();
+ client_filters_.drop_->Reset(0xff);
+ const char* k0RttData = "123456";
+ const PRInt32 k0RttDataLen = static_cast<PRInt32>(strlen(k0RttData));
+ PRInt32 rv =
+ PR_Write(client_->ssl_fd(), k0RttData, k0RttDataLen); // 0-RTT write.
+ EXPECT_EQ(k0RttDataLen, rv);
+ EXPECT_EQ(1UL, client_filters_.records_->count()); // data
+ server_->Handshake();
+ client_->Handshake();
+ ExpectEarlyDataAccepted(true);
+ // The server still hasn't received anything at this point.
+ EXPECT_EQ(3UL, client_filters_.records_->count()); // EOED, FIN, Data
+ EXPECT_EQ(TlsAgent::STATE_CONNECTED, client_->state());
+ EXPECT_EQ(TlsAgent::STATE_CONNECTING, server_->state());
+ // Now re-send the client's messages: EOED, FIN, Data
+ ReSend(TlsAgent::CLIENT, std::vector<size_t>({1, 2, 0}));
+ server_->Handshake();
+ CheckConnected();
+ EXPECT_EQ(0U, client_filters_.ack_->count());
+ // Acknowledgements for EOED and Finished.
+ CheckAcks(server_filters_.ack_, 0,
+ {0x0001000000000002ULL, 0x0002000000000000ULL});
+ uint8_t buf[8];
+ rv = PR_Read(server_->ssl_fd(), buf, sizeof(buf));
+ EXPECT_EQ(-1, rv);
+ EXPECT_EQ(PR_WOULD_BLOCK_ERROR, PORT_GetError());
+}
+
+static void GetCipherAndLimit(uint16_t version, uint16_t* cipher,
+ uint64_t* limit = nullptr) {
+ uint64_t l;
+ if (!limit) limit = &l;
+
+ if (version < SSL_LIBRARY_VERSION_TLS_1_2) {
+ *cipher = TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA;
+ *limit = 0x5aULL << 28;
+ } else if (version == SSL_LIBRARY_VERSION_TLS_1_2) {
+ *cipher = TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256;
+ *limit = (1ULL << 48) - 1;
+ } else {
+ // This test probably isn't especially useful for TLS 1.3, which has a much
+ // shorter sequence number encoding. That space can probably be searched in
+ // a reasonable amount of time.
+ *cipher = TLS_CHACHA20_POLY1305_SHA256;
+ // Assume that we are starting with an expected sequence number of 0.
+ *limit = (1ULL << 15) - 1;
+ }
+}
+
+// This simulates a huge number of drops on one side.
+// See Bug 12965514 where a large gap was handled very inefficiently.
+TEST_P(TlsConnectDatagram, MissLotsOfPackets) {
+ uint16_t cipher;
+ uint64_t limit;
+
+ GetCipherAndLimit(version_, &cipher, &limit);
+
+ EnsureTlsSetup();
+ server_->EnableSingleCipher(cipher);
+ Connect();
+
+ // Note that the limit for ChaCha is 2^48-1.
+ EXPECT_EQ(SECSuccess,
+ SSLInt_AdvanceWriteSeqNum(client_->ssl_fd(), limit - 10));
+ SendReceive();
+}
+
+// Send a sequence number of 0xfffd and it should be interpreted as that
+// (and not -3 or UINT64_MAX - 2).
+TEST_F(TlsConnectDatagram13, UnderflowSequenceNumber) {
+ Connect();
+ // This is only valid if short headers are disabled.
+ client_->SetOption(SSL_ENABLE_DTLS_SHORT_HEADER, PR_FALSE);
+ EXPECT_EQ(SECSuccess,
+ SSLInt_AdvanceWriteSeqNum(client_->ssl_fd(), (1ULL << 16) - 3));
+ SendReceive();
+}
+
+class TlsConnectDatagram12Plus : public TlsConnectDatagram {
+ public:
+ TlsConnectDatagram12Plus() : TlsConnectDatagram() {}
+};
+
+// This simulates missing a window's worth of packets.
+TEST_P(TlsConnectDatagram12Plus, MissAWindow) {
+ EnsureTlsSetup();
+ uint16_t cipher;
+ GetCipherAndLimit(version_, &cipher);
+ server_->EnableSingleCipher(cipher);
+ Connect();
+ EXPECT_EQ(SECSuccess, SSLInt_AdvanceWriteSeqByAWindow(client_->ssl_fd(), 0));
+ SendReceive();
+}
+
+TEST_P(TlsConnectDatagram12Plus, MissAWindowAndOne) {
+ EnsureTlsSetup();
+ uint16_t cipher;
+ GetCipherAndLimit(version_, &cipher);
+ server_->EnableSingleCipher(cipher);
+ Connect();
+
+ EXPECT_EQ(SECSuccess, SSLInt_AdvanceWriteSeqByAWindow(client_->ssl_fd(), 1));
+ SendReceive();
+}
+
+// This filter replaces the first record it sees with junk application data.
+class TlsReplaceFirstRecordWithJunk : public TlsRecordFilter {
+ public:
+ TlsReplaceFirstRecordWithJunk(const std::shared_ptr<TlsAgent>& a)
+ : TlsRecordFilter(a), replaced_(false) {}
+
+ protected:
+ PacketFilter::Action FilterRecord(const TlsRecordHeader& header,
+ const DataBuffer& record, size_t* offset,
+ DataBuffer* output) override {
+ if (replaced_) {
+ return KEEP;
+ }
+ replaced_ = true;
+
+ uint8_t dtls13_ct = kCtDtlsCiphertext | kCtDtlsCiphertext16bSeqno |
+ kCtDtlsCiphertextLengthPresent;
+ TlsRecordHeader out_header(
+ header.variant(), header.version(),
+ is_dtls13() ? dtls13_ct : ssl_ct_application_data,
+ header.sequence_number());
+
+ static const uint8_t junk[] = {1, 2, 3, 4};
+ *offset = out_header.Write(output, *offset, DataBuffer(junk, sizeof(junk)));
+ return CHANGE;
+ }
+
+ private:
+ bool replaced_;
+};
+
+// DTLS needs to discard application_data that it receives prior to handshake
+// completion, not generate an error.
+TEST_P(TlsConnectDatagram, ReplaceFirstServerRecordWithApplicationData) {
+ MakeTlsFilter<TlsReplaceFirstRecordWithJunk>(server_);
+ Connect();
+}
+
+TEST_P(TlsConnectDatagram, ReplaceFirstClientRecordWithApplicationData) {
+ MakeTlsFilter<TlsReplaceFirstRecordWithJunk>(client_);
+ Connect();
+}
+
+INSTANTIATE_TEST_SUITE_P(Datagram12Plus, TlsConnectDatagram12Plus,
+ TlsConnectTestBase::kTlsV12Plus);
+INSTANTIATE_TEST_SUITE_P(DatagramPre13, TlsConnectDatagramPre13,
+ TlsConnectTestBase::kTlsV11V12);
+INSTANTIATE_TEST_SUITE_P(DatagramDrop13, TlsDropDatagram13,
+ ::testing::Values(true, false));
+INSTANTIATE_TEST_SUITE_P(DatagramReorder13, TlsReorderDatagram13,
+ ::testing::Values(true, false));
+INSTANTIATE_TEST_SUITE_P(DatagramFragment13, TlsFragmentationAndRecoveryTest,
+ ::testing::Values(true, false));
+
+} // namespace nss_test
diff --git a/security/nss/gtests/ssl_gtest/ssl_ecdh_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_ecdh_unittest.cc
new file mode 100644
index 0000000000..89bd0c679a
--- /dev/null
+++ b/security/nss/gtests/ssl_gtest/ssl_ecdh_unittest.cc
@@ -0,0 +1,728 @@
+/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
+/* vim: set ts=2 et sw=2 tw=80: */
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this file,
+ * You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+#include <functional>
+#include <memory>
+#include "secerr.h"
+#include "ssl.h"
+#include "sslerr.h"
+#include "sslproto.h"
+
+extern "C" {
+// This is not something that should make you happy.
+#include "libssl_internals.h"
+}
+
+#include "gtest_utils.h"
+#include "nss_scoped_ptrs.h"
+#include "tls_connect.h"
+#include "tls_filter.h"
+#include "tls_parser.h"
+
+namespace nss_test {
+
+TEST_P(TlsConnectGenericPre13, ConnectEcdh) {
+ SetExpectedVersion(std::get<1>(GetParam()));
+ Reset(TlsAgent::kServerEcdhEcdsa);
+ DisableAllCiphers();
+ EnableSomeEcdhCiphers();
+
+ Connect();
+ CheckKeys(ssl_kea_ecdh, ssl_grp_ec_secp256r1, ssl_auth_ecdh_ecdsa,
+ ssl_sig_none);
+}
+
+TEST_P(TlsConnectGenericPre13, ConnectEcdhWithoutDisablingSuites) {
+ SetExpectedVersion(std::get<1>(GetParam()));
+ Reset(TlsAgent::kServerEcdhEcdsa);
+ EnableSomeEcdhCiphers();
+
+ Connect();
+ CheckKeys(ssl_kea_ecdh, ssl_grp_ec_secp256r1, ssl_auth_ecdh_ecdsa,
+ ssl_sig_none);
+}
+
+TEST_P(TlsConnectGeneric, ConnectEcdhe) {
+ Connect();
+ CheckKeys();
+}
+
+// If we pick a 256-bit cipher suite and use a P-384 certificate, the server
+// should choose P-384 for key exchange too. Only valid for TLS == 1.2 because
+// we don't have 256-bit ciphers before then and 1.3 doesn't try to couple
+// DHE size to symmetric size.
+TEST_P(TlsConnectTls12, ConnectEcdheP384) {
+ Reset(TlsAgent::kServerEcdsa384);
+ ConnectWithCipherSuite(TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256);
+ CheckKeys(ssl_kea_ecdh, ssl_grp_ec_secp384r1, ssl_auth_ecdsa,
+ ssl_sig_ecdsa_secp256r1_sha256);
+}
+
+TEST_P(TlsConnectGeneric, ConnectEcdheP384Client) {
+ EnsureTlsSetup();
+ const std::vector<SSLNamedGroup> groups = {ssl_grp_ec_secp384r1,
+ ssl_grp_ffdhe_2048};
+ client_->ConfigNamedGroups(groups);
+ server_->ConfigNamedGroups(groups);
+ Connect();
+ CheckKeys(ssl_kea_ecdh, ssl_grp_ec_secp384r1, ssl_auth_rsa_sign,
+ ssl_sig_rsa_pss_rsae_sha256);
+}
+
+// This causes a HelloRetryRequest in TLS 1.3. Earlier versions don't care.
+TEST_P(TlsConnectGeneric, ConnectEcdheP384Server) {
+ EnsureTlsSetup();
+ auto hrr_capture = MakeTlsFilter<TlsHandshakeRecorder>(
+ server_, kTlsHandshakeHelloRetryRequest);
+ const std::vector<SSLNamedGroup> groups = {ssl_grp_ec_secp384r1};
+ server_->ConfigNamedGroups(groups);
+ Connect();
+ CheckKeys(ssl_kea_ecdh, ssl_grp_ec_secp384r1, ssl_auth_rsa_sign,
+ ssl_sig_rsa_pss_rsae_sha256);
+ EXPECT_EQ(version_ == SSL_LIBRARY_VERSION_TLS_1_3,
+ hrr_capture->buffer().len() != 0);
+}
+
+// This enables only P-256 on the client and disables it on the server.
+// This test will fail when we add other groups that identify as ECDHE.
+TEST_P(TlsConnectGeneric, ConnectEcdheGroupMismatch) {
+ EnsureTlsSetup();
+ const std::vector<SSLNamedGroup> client_groups = {ssl_grp_ec_secp256r1,
+ ssl_grp_ffdhe_2048};
+ const std::vector<SSLNamedGroup> server_groups = {ssl_grp_ffdhe_2048};
+ client_->ConfigNamedGroups(client_groups);
+ server_->ConfigNamedGroups(server_groups);
+
+ Connect();
+ CheckKeys(ssl_kea_dh, ssl_auth_rsa_sign);
+}
+
+TEST_P(TlsKeyExchangeTest, P384Priority) {
+ // P256, P384 and P521 are enabled. Both prefer P384.
+ const std::vector<SSLNamedGroup> groups = {
+ ssl_grp_ec_secp384r1, ssl_grp_ec_secp256r1, ssl_grp_ec_secp521r1};
+ EnsureKeyShareSetup();
+ ConfigNamedGroups(groups);
+ client_->DisableAllCiphers();
+ client_->EnableCiphersByKeyExchange(ssl_kea_ecdh);
+ Connect();
+
+ CheckKeys(ssl_kea_ecdh, ssl_grp_ec_secp384r1, ssl_auth_rsa_sign,
+ ssl_sig_rsa_pss_rsae_sha256);
+
+ std::vector<SSLNamedGroup> shares = {ssl_grp_ec_secp384r1};
+ CheckKEXDetails(groups, shares);
+}
+
+TEST_P(TlsKeyExchangeTest, DuplicateGroupConfig) {
+ const std::vector<SSLNamedGroup> groups = {
+ ssl_grp_ec_secp384r1, ssl_grp_ec_secp384r1, ssl_grp_ec_secp384r1,
+ ssl_grp_ec_secp256r1, ssl_grp_ec_secp256r1};
+ EnsureKeyShareSetup();
+ ConfigNamedGroups(groups);
+ client_->DisableAllCiphers();
+ client_->EnableCiphersByKeyExchange(ssl_kea_ecdh);
+ Connect();
+
+ CheckKeys(ssl_kea_ecdh, ssl_grp_ec_secp384r1, ssl_auth_rsa_sign,
+ ssl_sig_rsa_pss_rsae_sha256);
+
+ std::vector<SSLNamedGroup> shares = {ssl_grp_ec_secp384r1};
+ std::vector<SSLNamedGroup> expectedGroups = {ssl_grp_ec_secp384r1,
+ ssl_grp_ec_secp256r1};
+ CheckKEXDetails(expectedGroups, shares);
+}
+
+TEST_P(TlsKeyExchangeTest, P384PriorityDHEnabled) {
+ // P256, P384, P521, and FFDHE2048 are enabled. Both prefer P384.
+ const std::vector<SSLNamedGroup> groups = {
+ ssl_grp_ec_secp384r1, ssl_grp_ffdhe_2048, ssl_grp_ec_secp256r1,
+ ssl_grp_ec_secp521r1};
+ EnsureKeyShareSetup();
+ ConfigNamedGroups(groups);
+ Connect();
+
+ CheckKeys(ssl_kea_ecdh, ssl_grp_ec_secp384r1, ssl_auth_rsa_sign,
+ ssl_sig_rsa_pss_rsae_sha256);
+
+ if (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) {
+ std::vector<SSLNamedGroup> shares = {ssl_grp_ec_secp384r1};
+ CheckKEXDetails(groups, shares);
+ } else {
+ std::vector<SSLNamedGroup> oldtlsgroups = {
+ ssl_grp_ec_secp384r1, ssl_grp_ec_secp256r1, ssl_grp_ec_secp521r1};
+ CheckKEXDetails(oldtlsgroups, std::vector<SSLNamedGroup>());
+ }
+}
+
+TEST_P(TlsConnectGenericPre13, P384PriorityOnServer) {
+ EnsureTlsSetup();
+ client_->DisableAllCiphers();
+ client_->EnableCiphersByKeyExchange(ssl_kea_ecdh);
+
+ // The server prefers P384. It has to win.
+ const std::vector<SSLNamedGroup> server_groups = {
+ ssl_grp_ec_secp384r1, ssl_grp_ec_secp256r1, ssl_grp_ec_secp521r1};
+ server_->ConfigNamedGroups(server_groups);
+
+ Connect();
+
+ CheckKeys(ssl_kea_ecdh, ssl_grp_ec_secp384r1, ssl_auth_rsa_sign,
+ ssl_sig_rsa_pss_rsae_sha256);
+}
+
+TEST_P(TlsConnectGenericPre13, P384PriorityFromModelSocket) {
+ EnsureModelSockets();
+
+ /* Both prefer P384, set on the model socket. */
+ const std::vector<SSLNamedGroup> groups = {
+ ssl_grp_ec_secp384r1, ssl_grp_ec_secp256r1, ssl_grp_ec_secp521r1,
+ ssl_grp_ffdhe_2048};
+ client_model_->ConfigNamedGroups(groups);
+ server_model_->ConfigNamedGroups(groups);
+
+ Connect();
+
+ CheckKeys(ssl_kea_ecdh, ssl_grp_ec_secp384r1, ssl_auth_rsa_sign,
+ ssl_sig_rsa_pss_rsae_sha256);
+}
+
+class TlsKeyExchangeGroupCapture : public TlsHandshakeFilter {
+ public:
+ TlsKeyExchangeGroupCapture(const std::shared_ptr<TlsAgent> &a)
+ : TlsHandshakeFilter(a, {kTlsHandshakeServerKeyExchange}),
+ group_(ssl_grp_none) {}
+
+ SSLNamedGroup group() const { return group_; }
+
+ protected:
+ virtual PacketFilter::Action FilterHandshake(const HandshakeHeader &header,
+ const DataBuffer &input,
+ DataBuffer *output) {
+ uint32_t value = 0;
+ EXPECT_TRUE(input.Read(0, 1, &value));
+ EXPECT_EQ(3U, value) << "curve type has to be 3";
+
+ EXPECT_TRUE(input.Read(1, 2, &value));
+ group_ = static_cast<SSLNamedGroup>(value);
+
+ return KEEP;
+ }
+
+ private:
+ SSLNamedGroup group_;
+};
+
+// If we strip the client's supported groups extension, the server should assume
+// P-256 is supported by the client (<= 1.2 only).
+TEST_P(TlsConnectGenericPre13, DropSupportedGroupExtensionP256) {
+ EnsureTlsSetup();
+ MakeTlsFilter<TlsExtensionDropper>(client_, ssl_supported_groups_xtn);
+ auto group_capture = MakeTlsFilter<TlsKeyExchangeGroupCapture>(server_);
+
+ ConnectExpectAlert(server_, kTlsAlertDecryptError);
+ client_->CheckErrorCode(SSL_ERROR_DECRYPT_ERROR_ALERT);
+ server_->CheckErrorCode(SSL_ERROR_BAD_HANDSHAKE_HASH_VALUE);
+
+ EXPECT_EQ(ssl_grp_ec_secp256r1, group_capture->group());
+}
+
+// Supported groups is mandatory in TLS 1.3.
+TEST_P(TlsConnectTls13, DropSupportedGroupExtension) {
+ EnsureTlsSetup();
+ MakeTlsFilter<TlsExtensionDropper>(client_, ssl_supported_groups_xtn);
+ ConnectExpectAlert(server_, kTlsAlertMissingExtension);
+ client_->CheckErrorCode(SSL_ERROR_MISSING_EXTENSION_ALERT);
+ server_->CheckErrorCode(SSL_ERROR_MISSING_SUPPORTED_GROUPS_EXTENSION);
+}
+
+// If we only have a lame group, we fall back to static RSA.
+TEST_P(TlsConnectGenericPre13, UseLameGroup) {
+ const std::vector<SSLNamedGroup> groups = {ssl_grp_ec_secp192r1};
+ client_->ConfigNamedGroups(groups);
+ server_->ConfigNamedGroups(groups);
+ Connect();
+ CheckKeys(ssl_kea_rsa, ssl_grp_none, ssl_auth_rsa_decrypt, ssl_sig_none);
+}
+
+// In TLS 1.3, we can't generate the ClientHello.
+TEST_P(TlsConnectTls13, UseLameGroup) {
+ const std::vector<SSLNamedGroup> groups = {ssl_grp_ec_sect283k1};
+ client_->ConfigNamedGroups(groups);
+ server_->ConfigNamedGroups(groups);
+ client_->StartConnect();
+ client_->Handshake();
+ client_->CheckErrorCode(SSL_ERROR_NO_CIPHERS_SUPPORTED);
+}
+
+TEST_P(TlsConnectStreamPre13, ConfiguredGroupsRenegotiate) {
+ EnsureTlsSetup();
+ client_->DisableAllCiphers();
+ client_->EnableCiphersByKeyExchange(ssl_kea_ecdh);
+
+ const std::vector<SSLNamedGroup> client_groups = {ssl_grp_ec_secp256r1};
+ const std::vector<SSLNamedGroup> server_groups = {ssl_grp_ec_secp256r1,
+ ssl_grp_ec_secp256r1};
+ client_->ConfigNamedGroups(client_groups);
+ server_->ConfigNamedGroups(server_groups);
+
+ Connect();
+
+ CheckKeys(ssl_kea_ecdh, ssl_grp_ec_secp256r1, ssl_auth_rsa_sign,
+ ssl_sig_rsa_pss_rsae_sha256);
+ CheckConnected();
+
+ // The renegotiation has to use the same preferences as the original session.
+ server_->PrepareForRenegotiate();
+ client_->StartRenegotiate();
+ Handshake();
+ CheckKeys(ssl_kea_ecdh, ssl_grp_ec_secp256r1, ssl_auth_rsa_sign,
+ ssl_sig_rsa_pss_rsae_sha256);
+}
+
+TEST_P(TlsKeyExchangeTest, Curve25519) {
+ Reset(TlsAgent::kServerEcdsa256);
+ const std::vector<SSLNamedGroup> groups = {
+ ssl_grp_ec_curve25519, ssl_grp_ec_secp256r1, ssl_grp_ec_secp521r1};
+ EnsureKeyShareSetup();
+ ConfigNamedGroups(groups);
+ Connect();
+
+ CheckKeys(ssl_kea_ecdh, ssl_grp_ec_curve25519, ssl_auth_ecdsa,
+ ssl_sig_ecdsa_secp256r1_sha256);
+ const std::vector<SSLNamedGroup> shares = {ssl_grp_ec_curve25519};
+ CheckKEXDetails(groups, shares);
+}
+
+TEST_P(TlsConnectGenericPre13, GroupPreferenceServerPriority) {
+ EnsureTlsSetup();
+ client_->DisableAllCiphers();
+ client_->EnableCiphersByKeyExchange(ssl_kea_ecdh);
+
+ // The client prefers P256 while the server prefers 25519.
+ // The server's preference has to win.
+ const std::vector<SSLNamedGroup> client_groups = {ssl_grp_ec_secp256r1,
+ ssl_grp_ec_curve25519};
+ const std::vector<SSLNamedGroup> server_groups = {ssl_grp_ec_curve25519,
+ ssl_grp_ec_secp256r1};
+ client_->ConfigNamedGroups(client_groups);
+ server_->ConfigNamedGroups(server_groups);
+
+ Connect();
+
+ CheckKeys(ssl_kea_ecdh, ssl_grp_ec_curve25519, ssl_auth_rsa_sign,
+ ssl_sig_rsa_pss_rsae_sha256);
+}
+
+#ifndef NSS_DISABLE_TLS_1_3
+TEST_P(TlsKeyExchangeTest13, Curve25519P256EqualPriorityClient13) {
+ EnsureKeyShareSetup();
+
+ // The client sends a P256 key share while the server prefers 25519.
+ // We have to accept P256 without retry.
+ const std::vector<SSLNamedGroup> client_groups = {ssl_grp_ec_secp256r1,
+ ssl_grp_ec_curve25519};
+ const std::vector<SSLNamedGroup> server_groups = {ssl_grp_ec_curve25519,
+ ssl_grp_ec_secp256r1};
+ client_->ConfigNamedGroups(client_groups);
+ server_->ConfigNamedGroups(server_groups);
+
+ Connect();
+
+ CheckKeys(ssl_kea_ecdh, ssl_grp_ec_secp256r1, ssl_auth_rsa_sign,
+ ssl_sig_rsa_pss_rsae_sha256);
+ const std::vector<SSLNamedGroup> shares = {ssl_grp_ec_secp256r1};
+ CheckKEXDetails(client_groups, shares);
+}
+
+TEST_P(TlsKeyExchangeTest13, Curve25519P256EqualPriorityServer13) {
+ EnsureKeyShareSetup();
+
+ // The client sends a 25519 key share while the server prefers P256.
+ // We have to accept 25519 without retry.
+ const std::vector<SSLNamedGroup> client_groups = {ssl_grp_ec_curve25519,
+ ssl_grp_ec_secp256r1};
+ const std::vector<SSLNamedGroup> server_groups = {ssl_grp_ec_secp256r1,
+ ssl_grp_ec_curve25519};
+ client_->ConfigNamedGroups(client_groups);
+ server_->ConfigNamedGroups(server_groups);
+
+ Connect();
+
+ CheckKeys(ssl_kea_ecdh, ssl_grp_ec_curve25519, ssl_auth_rsa_sign,
+ ssl_sig_rsa_pss_rsae_sha256);
+ const std::vector<SSLNamedGroup> shares = {ssl_grp_ec_curve25519};
+ CheckKEXDetails(client_groups, shares);
+}
+
+TEST_P(TlsKeyExchangeTest13, EqualPriorityTestRetryECServer13) {
+ EnsureKeyShareSetup();
+
+ // The client sends a 25519 key share while the server prefers P256.
+ // The server prefers P-384 over x25519, so it must not consider P-256 and
+ // x25519 to be equivalent. It will therefore request a P-256 share
+ // with a HelloRetryRequest.
+ const std::vector<SSLNamedGroup> client_groups = {
+ ssl_grp_ec_curve25519, ssl_grp_ec_secp256r1, ssl_grp_ec_secp384r1};
+ const std::vector<SSLNamedGroup> server_groups = {
+ ssl_grp_ec_secp256r1, ssl_grp_ec_secp384r1, ssl_grp_ec_curve25519};
+ client_->ConfigNamedGroups(client_groups);
+ server_->ConfigNamedGroups(server_groups);
+
+ Connect();
+
+ CheckKeys(ssl_kea_ecdh, ssl_grp_ec_secp256r1, ssl_auth_rsa_sign,
+ ssl_sig_rsa_pss_rsae_sha256);
+ const std::vector<SSLNamedGroup> shares = {ssl_grp_ec_curve25519};
+ CheckKEXDetails(client_groups, shares, ssl_grp_ec_secp256r1);
+}
+
+TEST_P(TlsKeyExchangeTest13, NotEqualPriorityWithIntermediateGroup13) {
+ EnsureKeyShareSetup();
+
+ // The client sends a 25519 key share while the server prefers P256.
+ // The server prefers ffdhe_2048 over x25519, so it must not consider the
+ // P-256 and x25519 to be equivalent. It will therefore request a P-256 share
+ // with a HelloRetryRequest.
+ const std::vector<SSLNamedGroup> client_groups = {
+ ssl_grp_ec_curve25519, ssl_grp_ec_secp256r1, ssl_grp_ffdhe_2048};
+ const std::vector<SSLNamedGroup> server_groups = {
+ ssl_grp_ec_secp256r1, ssl_grp_ffdhe_2048, ssl_grp_ec_curve25519};
+ client_->ConfigNamedGroups(client_groups);
+ server_->ConfigNamedGroups(server_groups);
+
+ Connect();
+
+ CheckKeys(ssl_kea_ecdh, ssl_grp_ec_secp256r1, ssl_auth_rsa_sign,
+ ssl_sig_rsa_pss_rsae_sha256);
+ const std::vector<SSLNamedGroup> shares = {ssl_grp_ec_curve25519};
+ CheckKEXDetails(client_groups, shares, ssl_grp_ec_secp256r1);
+}
+
+TEST_P(TlsKeyExchangeTest13,
+ NotEqualPriorityWithUnsupportedFFIntermediateGroup13) {
+ EnsureKeyShareSetup();
+
+ // As in the previous test, the server prefers ffdhe_2048. Thus, even though
+ // the client doesn't support this group, the server must not regard x25519 as
+ // equivalent to P-256.
+ const std::vector<SSLNamedGroup> client_groups = {ssl_grp_ec_curve25519,
+ ssl_grp_ec_secp256r1};
+ const std::vector<SSLNamedGroup> server_groups = {
+ ssl_grp_ec_secp256r1, ssl_grp_ffdhe_2048, ssl_grp_ec_curve25519};
+ client_->ConfigNamedGroups(client_groups);
+ server_->ConfigNamedGroups(server_groups);
+
+ Connect();
+
+ CheckKeys(ssl_kea_ecdh, ssl_grp_ec_secp256r1, ssl_auth_rsa_sign,
+ ssl_sig_rsa_pss_rsae_sha256);
+ const std::vector<SSLNamedGroup> shares = {ssl_grp_ec_curve25519};
+ CheckKEXDetails(client_groups, shares, ssl_grp_ec_secp256r1);
+}
+
+TEST_P(TlsKeyExchangeTest13,
+ NotEqualPriorityWithUnsupportedECIntermediateGroup13) {
+ EnsureKeyShareSetup();
+
+ // As in the previous test, the server prefers P-384. Thus, even though
+ // the client doesn't support this group, the server must not regard x25519 as
+ // equivalent to P-256. The server sends a HelloRetryRequest.
+ const std::vector<SSLNamedGroup> client_groups = {ssl_grp_ec_curve25519,
+ ssl_grp_ec_secp256r1};
+ const std::vector<SSLNamedGroup> server_groups = {
+ ssl_grp_ec_secp256r1, ssl_grp_ec_secp384r1, ssl_grp_ec_curve25519};
+ client_->ConfigNamedGroups(client_groups);
+ server_->ConfigNamedGroups(server_groups);
+
+ Connect();
+
+ CheckKeys(ssl_kea_ecdh, ssl_grp_ec_secp256r1, ssl_auth_rsa_sign,
+ ssl_sig_rsa_pss_rsae_sha256);
+ const std::vector<SSLNamedGroup> shares = {ssl_grp_ec_curve25519};
+ CheckKEXDetails(client_groups, shares, ssl_grp_ec_secp256r1);
+}
+
+TEST_P(TlsKeyExchangeTest13, EqualPriority13) {
+ EnsureKeyShareSetup();
+
+ // The client sends a 25519 key share while the server prefers P256.
+ // We have to accept 25519 without retry because it's considered equivalent to
+ // P256 by the server.
+ const std::vector<SSLNamedGroup> client_groups = {
+ ssl_grp_ec_curve25519, ssl_grp_ffdhe_2048, ssl_grp_ec_secp256r1};
+ const std::vector<SSLNamedGroup> server_groups = {ssl_grp_ec_secp256r1,
+ ssl_grp_ec_curve25519};
+ client_->ConfigNamedGroups(client_groups);
+ server_->ConfigNamedGroups(server_groups);
+
+ Connect();
+
+ CheckKeys();
+ const std::vector<SSLNamedGroup> shares = {ssl_grp_ec_curve25519};
+ CheckKEXDetails(client_groups, shares);
+}
+#endif
+
+TEST_P(TlsConnectGeneric, P256ClientAndCurve25519Server) {
+ EnsureTlsSetup();
+ client_->DisableAllCiphers();
+ client_->EnableCiphersByKeyExchange(ssl_kea_ecdh);
+
+ // The client sends a P256 key share while the server prefers 25519.
+ const std::vector<SSLNamedGroup> client_groups = {ssl_grp_ec_secp256r1};
+ const std::vector<SSLNamedGroup> server_groups = {ssl_grp_ec_curve25519};
+
+ client_->ConfigNamedGroups(client_groups);
+ server_->ConfigNamedGroups(server_groups);
+
+ ConnectExpectAlert(server_, kTlsAlertHandshakeFailure);
+ client_->CheckErrorCode(SSL_ERROR_NO_CYPHER_OVERLAP);
+ server_->CheckErrorCode(SSL_ERROR_NO_CYPHER_OVERLAP);
+}
+
+TEST_P(TlsKeyExchangeTest13, MultipleClientShares) {
+ EnsureKeyShareSetup();
+
+ // The client sends 25519 and P256 key shares. The server prefers P256,
+ // which must be chosen here.
+ const std::vector<SSLNamedGroup> client_groups = {ssl_grp_ec_curve25519,
+ ssl_grp_ec_secp256r1};
+ const std::vector<SSLNamedGroup> server_groups = {ssl_grp_ec_secp256r1,
+ ssl_grp_ec_curve25519};
+ client_->ConfigNamedGroups(client_groups);
+ server_->ConfigNamedGroups(server_groups);
+
+ // Generate a key share on the client for both curves.
+ EXPECT_EQ(SECSuccess, SSL_SendAdditionalKeyShares(client_->ssl_fd(), 1));
+
+ Connect();
+
+ // The server would accept 25519 but its preferred group (P256) has to win.
+ CheckKeys(ssl_kea_ecdh, ssl_grp_ec_secp256r1, ssl_auth_rsa_sign,
+ ssl_sig_rsa_pss_rsae_sha256);
+ const std::vector<SSLNamedGroup> shares = {ssl_grp_ec_curve25519,
+ ssl_grp_ec_secp256r1};
+ CheckKEXDetails(client_groups, shares);
+}
+
+// Replace the point in the client key exchange message with an empty one
+class ECCClientKEXFilter : public TlsHandshakeFilter {
+ public:
+ ECCClientKEXFilter(const std::shared_ptr<TlsAgent> &client)
+ : TlsHandshakeFilter(client, {kTlsHandshakeClientKeyExchange}) {}
+
+ protected:
+ virtual PacketFilter::Action FilterHandshake(const HandshakeHeader &header,
+ const DataBuffer &input,
+ DataBuffer *output) {
+ // Replace the client key exchange message with an empty point
+ output->Allocate(1);
+ output->Write(0, 0U, 1); // set point length 0
+ return CHANGE;
+ }
+};
+
+// Replace the point in the server key exchange message with an empty one
+class ECCServerKEXFilter : public TlsHandshakeFilter {
+ public:
+ ECCServerKEXFilter(const std::shared_ptr<TlsAgent> &server)
+ : TlsHandshakeFilter(server, {kTlsHandshakeServerKeyExchange}) {}
+
+ protected:
+ virtual PacketFilter::Action FilterHandshake(const HandshakeHeader &header,
+ const DataBuffer &input,
+ DataBuffer *output) {
+ // Replace the server key exchange message with an empty point
+ output->Allocate(4);
+ output->Write(0, 3U, 1); // named curve
+ uint32_t curve = 0;
+ EXPECT_TRUE(input.Read(1, 2, &curve)); // get curve id
+ output->Write(1, curve, 2); // write curve id
+ output->Write(3, 0U, 1); // point length 0
+ return CHANGE;
+ }
+};
+
+TEST_P(TlsConnectGenericPre13, ConnectECDHEmptyServerPoint) {
+ MakeTlsFilter<ECCServerKEXFilter>(server_);
+ ConnectExpectAlert(client_, kTlsAlertIllegalParameter);
+ client_->CheckErrorCode(SSL_ERROR_RX_MALFORMED_SERVER_KEY_EXCH);
+}
+
+TEST_P(TlsConnectGenericPre13, ConnectECDHEmptyClientPoint) {
+ MakeTlsFilter<ECCClientKEXFilter>(client_);
+ ConnectExpectAlert(server_, kTlsAlertIllegalParameter);
+ server_->CheckErrorCode(SSL_ERROR_RX_MALFORMED_CLIENT_KEY_EXCH);
+}
+
+// Damage ECParams/ECPoint of a SKE.
+class ECCServerKEXDamager : public TlsHandshakeFilter {
+ public:
+ ECCServerKEXDamager(const std::shared_ptr<TlsAgent> &server, ECType ec_type,
+ SSLNamedGroup named_curve)
+ : TlsHandshakeFilter(server, {kTlsHandshakeServerKeyExchange}),
+ ec_type_(ec_type),
+ named_curve_(named_curve) {}
+
+ protected:
+ virtual PacketFilter::Action FilterHandshake(const HandshakeHeader &header,
+ const DataBuffer &input,
+ DataBuffer *output) {
+ size_t offset = 0;
+ output->Allocate(5);
+ offset = output->Write(offset, ec_type_, 1);
+ offset = output->Write(offset, named_curve_, 2);
+ // Write a point with fmt != EC_POINT_FORM_UNCOMPRESSED.
+ offset = output->Write(offset, 1U, 1);
+ (void)output->Write(offset, 0x02, 1); // EC_POINT_FORM_COMPRESSED_Y0
+ return CHANGE;
+ }
+
+ private:
+ ECType ec_type_;
+ SSLNamedGroup named_curve_;
+};
+
+TEST_P(TlsConnectGenericPre13, ConnectUnsupportedCurveType) {
+ EnsureTlsSetup();
+ client_->DisableAllCiphers();
+ client_->EnableCiphersByKeyExchange(ssl_kea_ecdh);
+
+ MakeTlsFilter<ECCServerKEXDamager>(server_, ec_type_explicitPrime,
+ ssl_grp_none);
+ ConnectExpectAlert(client_, kTlsAlertHandshakeFailure);
+ client_->CheckErrorCode(SEC_ERROR_UNSUPPORTED_ELLIPTIC_CURVE);
+}
+
+TEST_P(TlsConnectGenericPre13, ConnectUnsupportedCurve) {
+ EnsureTlsSetup();
+ client_->DisableAllCiphers();
+ client_->EnableCiphersByKeyExchange(ssl_kea_ecdh);
+
+ MakeTlsFilter<ECCServerKEXDamager>(server_, ec_type_named,
+ ssl_grp_ffdhe_2048);
+ ConnectExpectAlert(client_, kTlsAlertHandshakeFailure);
+ client_->CheckErrorCode(SEC_ERROR_UNSUPPORTED_ELLIPTIC_CURVE);
+}
+
+TEST_P(TlsConnectGenericPre13, ConnectUnsupportedPointFormat) {
+ EnsureTlsSetup();
+ client_->DisableAllCiphers();
+ client_->EnableCiphersByKeyExchange(ssl_kea_ecdh);
+
+ MakeTlsFilter<ECCServerKEXDamager>(server_, ec_type_named,
+ ssl_grp_ec_secp256r1);
+ ConnectExpectAlert(client_, kTlsAlertHandshakeFailure);
+ client_->CheckErrorCode(SEC_ERROR_UNSUPPORTED_EC_POINT_FORM);
+}
+
+TEST_P(TlsConnectTls12, ConnectUnsupportedSigAlg) {
+ EnsureTlsSetup();
+ client_->DisableAllCiphers();
+ client_->EnableCiphersByKeyExchange(ssl_kea_ecdh);
+
+ MakeTlsFilter<ECCServerKEXSigAlgReplacer>(server_, ssl_sig_none);
+ ConnectExpectAlert(client_, kTlsAlertIllegalParameter);
+ client_->CheckErrorCode(SSL_ERROR_UNSUPPORTED_SIGNATURE_ALGORITHM);
+}
+
+TEST_P(TlsConnectTls12, ConnectIncorrectSigAlg) {
+ EnsureTlsSetup();
+ client_->DisableAllCiphers();
+ client_->EnableCiphersByKeyExchange(ssl_kea_ecdh);
+
+ MakeTlsFilter<ECCServerKEXSigAlgReplacer>(server_,
+ ssl_sig_ecdsa_secp256r1_sha256);
+ ConnectExpectAlert(client_, kTlsAlertIllegalParameter);
+ client_->CheckErrorCode(SSL_ERROR_INCORRECT_SIGNATURE_ALGORITHM);
+}
+
+static void CheckSkeSigScheme(
+ std::shared_ptr<TlsHandshakeRecorder> &capture_ske,
+ uint16_t expected_scheme) {
+ TlsParser parser(capture_ske->buffer());
+ uint32_t tmp = 0;
+ EXPECT_TRUE(parser.Read(&tmp, 1)) << " read curve_type";
+ EXPECT_EQ(3U, tmp) << "curve type has to be 3";
+ EXPECT_TRUE(parser.Skip(2)) << " read namedcurve";
+ EXPECT_TRUE(parser.SkipVariable(1)) << " read public";
+
+ EXPECT_TRUE(parser.Read(&tmp, 2)) << " read sig_scheme";
+ EXPECT_EQ(expected_scheme, static_cast<uint16_t>(tmp));
+}
+
+TEST_P(TlsConnectTls12, ConnectSigAlgEnabledByPolicy) {
+ EnsureTlsSetup();
+ client_->DisableAllCiphers();
+ client_->EnableCiphersByKeyExchange(ssl_kea_ecdh);
+
+ const std::vector<SSLSignatureScheme> schemes = {ssl_sig_rsa_pkcs1_sha1,
+ ssl_sig_rsa_pkcs1_sha384};
+
+ client_->SetSignatureSchemes(schemes.data(), schemes.size());
+ server_->SetSignatureSchemes(schemes.data(), schemes.size());
+ auto capture_ske = MakeTlsFilter<TlsHandshakeRecorder>(
+ server_, kTlsHandshakeServerKeyExchange);
+
+ StartConnect();
+ client_->Handshake(); // Send ClientHello
+
+ // Enable SHA-1 by policy.
+ SECStatus rv = NSS_SetAlgorithmPolicy(SEC_OID_SHA1, NSS_USE_ALG_IN_SSL_KX, 0);
+ ASSERT_EQ(SECSuccess, rv);
+ rv = NSS_SetAlgorithmPolicy(SEC_OID_APPLY_SSL_POLICY, NSS_USE_POLICY_IN_SSL,
+ 0);
+ ASSERT_EQ(SECSuccess, rv);
+
+ Handshake(); // Remainder of handshake
+ // The server should now report that it is connected
+ EXPECT_EQ(TlsAgent::STATE_CONNECTED, server_->state());
+
+ CheckSkeSigScheme(capture_ske, ssl_sig_rsa_pkcs1_sha1);
+}
+
+TEST_P(TlsConnectTls12, ConnectSigAlgDisabledByPolicy) {
+ EnsureTlsSetup();
+ client_->DisableAllCiphers();
+ client_->EnableCiphersByKeyExchange(ssl_kea_ecdh);
+
+ const std::vector<SSLSignatureScheme> schemes = {ssl_sig_rsa_pkcs1_sha1,
+ ssl_sig_rsa_pkcs1_sha384};
+
+ client_->SetSignatureSchemes(schemes.data(), schemes.size());
+ server_->SetSignatureSchemes(schemes.data(), schemes.size());
+ auto capture_ske = MakeTlsFilter<TlsHandshakeRecorder>(
+ server_, kTlsHandshakeServerKeyExchange);
+
+ StartConnect();
+ client_->Handshake(); // Send ClientHello
+
+ // Disable SHA-1 by policy.
+ SECStatus rv = NSS_SetAlgorithmPolicy(SEC_OID_SHA1, 0, NSS_USE_ALG_IN_SSL_KX);
+ ASSERT_EQ(SECSuccess, rv);
+ rv = NSS_SetAlgorithmPolicy(SEC_OID_APPLY_SSL_POLICY, NSS_USE_POLICY_IN_SSL,
+ 0);
+ ASSERT_EQ(SECSuccess, rv);
+
+ Handshake(); // Remainder of handshake
+ // The server should now report that it is connected
+ EXPECT_EQ(TlsAgent::STATE_CONNECTED, server_->state());
+
+ CheckSkeSigScheme(capture_ske, ssl_sig_rsa_pkcs1_sha384);
+}
+
+INSTANTIATE_TEST_SUITE_P(KeyExchangeTest, TlsKeyExchangeTest,
+ ::testing::Combine(TlsConnectTestBase::kTlsVariantsAll,
+ TlsConnectTestBase::kTlsV11Plus));
+
+#ifndef NSS_DISABLE_TLS_1_3
+INSTANTIATE_TEST_SUITE_P(KeyExchangeTest, TlsKeyExchangeTest13,
+ ::testing::Combine(TlsConnectTestBase::kTlsVariantsAll,
+ TlsConnectTestBase::kTlsV13));
+#endif
+
+} // namespace nss_test
diff --git a/security/nss/gtests/ssl_gtest/ssl_ems_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_ems_unittest.cc
new file mode 100644
index 0000000000..39b2d58736
--- /dev/null
+++ b/security/nss/gtests/ssl_gtest/ssl_ems_unittest.cc
@@ -0,0 +1,96 @@
+/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
+/* vim: set ts=2 et sw=2 tw=80: */
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this file,
+ * You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+#include "secerr.h"
+#include "ssl.h"
+#include "sslerr.h"
+#include "sslproto.h"
+
+#include "gtest_utils.h"
+#include "nss_scoped_ptrs.h"
+#include "tls_connect.h"
+#include "tls_filter.h"
+#include "tls_parser.h"
+
+namespace nss_test {
+
+TEST_P(TlsConnectGenericPre13, ConnectExtendedMasterSecret) {
+ EnableExtendedMasterSecret();
+ Connect();
+ Reset();
+ ExpectResumption(RESUME_SESSIONID);
+ EnableExtendedMasterSecret();
+ Connect();
+}
+
+TEST_P(TlsConnectTls12, ConnectExtendedMasterSecretSha384) {
+ EnableExtendedMasterSecret();
+ server_->EnableSingleCipher(TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384);
+ ConnectWithCipherSuite(TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384);
+}
+
+TEST_P(TlsConnectGenericPre13, ConnectExtendedMasterSecretStaticRSA) {
+ EnableOnlyStaticRsaCiphers();
+ EnableExtendedMasterSecret();
+ Connect();
+}
+
+TEST_P(TlsConnectGenericPre13, ConnectExtendedMasterSecretECDHE) {
+ EnableExtendedMasterSecret();
+ Connect();
+
+ Reset();
+ EnableExtendedMasterSecret();
+ ExpectResumption(RESUME_SESSIONID);
+ Connect();
+}
+
+TEST_P(TlsConnectGenericPre13, ConnectExtendedMasterSecretTicket) {
+ ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET);
+ EnableExtendedMasterSecret();
+ Connect();
+
+ Reset();
+ ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET);
+
+ EnableExtendedMasterSecret();
+ ExpectResumption(RESUME_TICKET);
+ Connect();
+}
+
+TEST_P(TlsConnectGenericPre13, ConnectExtendedMasterSecretClientOnly) {
+ client_->EnableExtendedMasterSecret();
+ ExpectExtendedMasterSecret(false);
+ Connect();
+}
+
+TEST_P(TlsConnectGenericPre13, ConnectExtendedMasterSecretServerOnly) {
+ server_->EnableExtendedMasterSecret();
+ ExpectExtendedMasterSecret(false);
+ Connect();
+}
+
+TEST_P(TlsConnectGenericPre13, ConnectExtendedMasterSecretResumeWithout) {
+ EnableExtendedMasterSecret();
+ Connect();
+
+ Reset();
+ server_->EnableExtendedMasterSecret();
+ ConnectExpectAlert(server_, kTlsAlertHandshakeFailure);
+}
+
+TEST_P(TlsConnectGenericPre13, ConnectNormalResumeWithExtendedMasterSecret) {
+ ConfigureSessionCache(RESUME_SESSIONID, RESUME_SESSIONID);
+ ExpectExtendedMasterSecret(false);
+ Connect();
+
+ Reset();
+ EnableExtendedMasterSecret();
+ ExpectResumption(RESUME_NONE);
+ Connect();
+}
+
+} // namespace nss_test
diff --git a/security/nss/gtests/ssl_gtest/ssl_exporter_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_exporter_unittest.cc
new file mode 100644
index 0000000000..26ed6bc0ed
--- /dev/null
+++ b/security/nss/gtests/ssl_gtest/ssl_exporter_unittest.cc
@@ -0,0 +1,188 @@
+/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
+/* vim: set ts=2 et sw=2 tw=80: */
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this file,
+ * You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+#include "ssl.h"
+
+#include "gtest_utils.h"
+#include "tls_connect.h"
+
+namespace nss_test {
+
+static const char* kExporterLabel = "EXPORTER-duck";
+static const uint8_t kExporterContext[] = {0x12, 0x34, 0x56};
+
+static void ExportAndCompare(std::shared_ptr<TlsAgent>& client,
+ std::shared_ptr<TlsAgent>& server, bool context) {
+ static const size_t exporter_len = 10;
+ uint8_t client_value[exporter_len] = {0};
+ EXPECT_EQ(SECSuccess,
+ SSL_ExportKeyingMaterial(
+ client->ssl_fd(), kExporterLabel, strlen(kExporterLabel),
+ context ? PR_TRUE : PR_FALSE, kExporterContext,
+ sizeof(kExporterContext), client_value, sizeof(client_value)));
+ uint8_t server_value[exporter_len] = {0xff};
+ EXPECT_EQ(SECSuccess,
+ SSL_ExportKeyingMaterial(
+ server->ssl_fd(), kExporterLabel, strlen(kExporterLabel),
+ context ? PR_TRUE : PR_FALSE, kExporterContext,
+ sizeof(kExporterContext), server_value, sizeof(server_value)));
+ EXPECT_EQ(0, memcmp(client_value, server_value, sizeof(client_value)));
+}
+
+TEST_P(TlsConnectGeneric, ExporterBasic) {
+ EnsureTlsSetup();
+ if (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) {
+ server_->EnableSingleCipher(TLS_AES_128_GCM_SHA256);
+ } else {
+ server_->EnableSingleCipher(TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA);
+ }
+ Connect();
+ CheckKeys();
+ ExportAndCompare(client_, server_, false);
+}
+
+TEST_P(TlsConnectGeneric, ExporterContext) {
+ EnsureTlsSetup();
+ if (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) {
+ server_->EnableSingleCipher(TLS_AES_128_GCM_SHA256);
+ } else {
+ server_->EnableSingleCipher(TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA);
+ }
+ Connect();
+ CheckKeys();
+ ExportAndCompare(client_, server_, true);
+}
+
+// Bug 1312976 - SHA-384 doesn't work in 1.2 right now.
+TEST_P(TlsConnectTls13, ExporterSha384) {
+ EnsureTlsSetup();
+ client_->EnableSingleCipher(TLS_AES_256_GCM_SHA384);
+ Connect();
+ CheckKeys();
+ ExportAndCompare(client_, server_, false);
+}
+
+TEST_P(TlsConnectTls13, ExporterContextEmptyIsSameAsNone) {
+ EnsureTlsSetup();
+ if (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) {
+ server_->EnableSingleCipher(TLS_AES_128_GCM_SHA256);
+ } else {
+ server_->EnableSingleCipher(TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA);
+ }
+ Connect();
+ CheckKeys();
+ ExportAndCompare(client_, server_, false);
+}
+
+TEST_P(TlsConnectGenericPre13, ExporterContextLengthTooLong) {
+ static const uint8_t kExporterContextTooLong[PR_UINT16_MAX] = {
+ 0x12, 0x34, 0x56, 0x78, 0x9A, 0xBC, 0xDE, 0xFF};
+
+ EnsureTlsSetup();
+ Connect();
+ CheckKeys();
+
+ static const size_t exporter_len = 10;
+ uint8_t client_value[exporter_len] = {0};
+ EXPECT_EQ(SECFailure,
+ SSL_ExportKeyingMaterial(client_->ssl_fd(), kExporterLabel,
+ strlen(kExporterLabel), PR_TRUE,
+ kExporterContextTooLong,
+ sizeof(kExporterContextTooLong),
+ client_value, sizeof(client_value)));
+ EXPECT_EQ(PORT_GetError(), SEC_ERROR_INVALID_ARGS);
+ uint8_t server_value[exporter_len] = {0xff};
+ EXPECT_EQ(SECFailure,
+ SSL_ExportKeyingMaterial(server_->ssl_fd(), kExporterLabel,
+ strlen(kExporterLabel), PR_TRUE,
+ kExporterContextTooLong,
+ sizeof(kExporterContextTooLong),
+ server_value, sizeof(server_value)));
+ EXPECT_EQ(PORT_GetError(), SEC_ERROR_INVALID_ARGS);
+}
+
+// This has a weird signature so that it can be passed to the SNI callback.
+int32_t RegularExporterShouldFail(TlsAgent* agent, const SECItem* srvNameArr,
+ PRUint32 srvNameArrSize) {
+ uint8_t val[10];
+ EXPECT_EQ(SECFailure, SSL_ExportKeyingMaterial(
+ agent->ssl_fd(), kExporterLabel,
+ strlen(kExporterLabel), PR_TRUE, kExporterContext,
+ sizeof(kExporterContext), val, sizeof(val)))
+ << "regular exporter should fail";
+ return 0;
+}
+
+TEST_P(TlsConnectTls13, EarlyExporter) {
+ SetupForZeroRtt();
+ client_->Set0RttEnabled(true);
+ server_->Set0RttEnabled(true);
+ ExpectResumption(RESUME_TICKET);
+
+ client_->Handshake(); // Send ClientHello.
+ uint8_t client_value[10] = {0};
+ RegularExporterShouldFail(client_.get(), nullptr, 0);
+
+ EXPECT_EQ(SECSuccess,
+ SSL_ExportEarlyKeyingMaterial(
+ client_->ssl_fd(), kExporterLabel, strlen(kExporterLabel),
+ kExporterContext, sizeof(kExporterContext), client_value,
+ sizeof(client_value)));
+
+ server_->SetSniCallback(RegularExporterShouldFail);
+ server_->Handshake(); // Handle ClientHello.
+ uint8_t server_value[10] = {0};
+ EXPECT_EQ(SECSuccess,
+ SSL_ExportEarlyKeyingMaterial(
+ server_->ssl_fd(), kExporterLabel, strlen(kExporterLabel),
+ kExporterContext, sizeof(kExporterContext), server_value,
+ sizeof(server_value)));
+ EXPECT_EQ(0, memcmp(client_value, server_value, sizeof(client_value)));
+
+ Handshake();
+ ExpectEarlyDataAccepted(true);
+ CheckConnected();
+ SendReceive();
+}
+
+TEST_P(TlsConnectTls13, EarlyExporterExternalPsk) {
+ RolloverAntiReplay();
+ ScopedPK11SlotInfo slot(PK11_GetInternalSlot());
+ ASSERT_TRUE(!!slot);
+ ScopedPK11SymKey scoped_psk(
+ PK11_KeyGen(slot.get(), CKM_HKDF_KEY_GEN, nullptr, 16, nullptr));
+ AddPsk(scoped_psk, std::string("foo"), ssl_hash_sha256,
+ TLS_CHACHA20_POLY1305_SHA256);
+ StartConnect();
+ client_->Set0RttEnabled(true);
+ server_->Set0RttEnabled(true);
+ client_->Handshake(); // Send ClientHello.
+ uint8_t client_value[10] = {0};
+ RegularExporterShouldFail(client_.get(), nullptr, 0);
+
+ EXPECT_EQ(SECSuccess,
+ SSL_ExportEarlyKeyingMaterial(
+ client_->ssl_fd(), kExporterLabel, strlen(kExporterLabel),
+ kExporterContext, sizeof(kExporterContext), client_value,
+ sizeof(client_value)));
+
+ server_->SetSniCallback(RegularExporterShouldFail);
+ server_->Handshake(); // Handle ClientHello.
+ uint8_t server_value[10] = {0};
+ EXPECT_EQ(SECSuccess,
+ SSL_ExportEarlyKeyingMaterial(
+ server_->ssl_fd(), kExporterLabel, strlen(kExporterLabel),
+ kExporterContext, sizeof(kExporterContext), server_value,
+ sizeof(server_value)));
+ EXPECT_EQ(0, memcmp(client_value, server_value, sizeof(client_value)));
+
+ Handshake();
+ ExpectEarlyDataAccepted(true);
+ CheckConnected();
+ SendReceive();
+}
+
+} // namespace nss_test
diff --git a/security/nss/gtests/ssl_gtest/ssl_extension_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_extension_unittest.cc
new file mode 100644
index 0000000000..eb45e71422
--- /dev/null
+++ b/security/nss/gtests/ssl_gtest/ssl_extension_unittest.cc
@@ -0,0 +1,1513 @@
+/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
+/* vim: set ts=2 et sw=2 tw=80: */
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this file,
+ * You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+#include "ssl.h"
+#include "ssl3prot.h"
+#include "sslerr.h"
+#include "sslproto.h"
+
+// This is only to get DTLS_1_3_DRAFT_VERSION
+#include "ssl3prot.h"
+
+#include <memory>
+
+#include "tls_connect.h"
+#include "tls_filter.h"
+#include "tls_parser.h"
+
+namespace nss_test {
+
+class Dtls13LegacyCookieInjector : public TlsHandshakeFilter {
+ public:
+ Dtls13LegacyCookieInjector(const std::shared_ptr<TlsAgent>& a)
+ : TlsHandshakeFilter(a, {kTlsHandshakeClientHello}) {}
+
+ virtual PacketFilter::Action FilterHandshake(const HandshakeHeader& header,
+ const DataBuffer& input,
+ DataBuffer* output) {
+ const uint8_t cookie_bytes[] = {0x03, 0x0A, 0x0B, 0x0C};
+ uint32_t offset = 2 /* version */ + 32 /* random */;
+
+ if (agent()->variant() != ssl_variant_datagram) {
+ ADD_FAILURE();
+ return KEEP;
+ }
+
+ if (header.handshake_type() != ssl_hs_client_hello) {
+ return KEEP;
+ }
+
+ DataBuffer cookie(cookie_bytes, sizeof(cookie_bytes));
+ *output = input;
+
+ // Add the SID length (if any) to locate the cookie.
+ uint32_t sid_len = 0;
+ if (!output->Read(offset, 1, &sid_len)) {
+ ADD_FAILURE();
+ return KEEP;
+ }
+ offset += 1 + sid_len;
+ output->Splice(cookie, offset, 1);
+
+ return CHANGE;
+ }
+
+ private:
+ DataBuffer cookie_;
+};
+
+class TlsExtensionTruncator : public TlsExtensionFilter {
+ public:
+ TlsExtensionTruncator(const std::shared_ptr<TlsAgent>& a, uint16_t extension,
+ size_t length)
+ : TlsExtensionFilter(a), extension_(extension), length_(length) {}
+ virtual PacketFilter::Action FilterExtension(uint16_t extension_type,
+ const DataBuffer& input,
+ DataBuffer* output) {
+ if (extension_type != extension_) {
+ return KEEP;
+ }
+ if (input.len() <= length_) {
+ return KEEP;
+ }
+
+ output->Assign(input.data(), length_);
+ return CHANGE;
+ }
+
+ private:
+ uint16_t extension_;
+ size_t length_;
+};
+
+class TlsExtensionTestBase : public TlsConnectTestBase {
+ protected:
+ TlsExtensionTestBase(SSLProtocolVariant variant, uint16_t version)
+ : TlsConnectTestBase(variant, version) {}
+
+ void ClientHelloErrorTest(std::shared_ptr<PacketFilter> filter,
+ uint8_t desc = kTlsAlertDecodeError) {
+ client_->SetFilter(filter);
+ ConnectExpectAlert(server_, desc);
+ }
+
+ void ServerHelloErrorTest(std::shared_ptr<PacketFilter> filter,
+ uint8_t desc = kTlsAlertDecodeError) {
+ server_->SetFilter(filter);
+ ConnectExpectAlert(client_, desc);
+ }
+
+ static void InitSimpleSni(DataBuffer* extension) {
+ const char* name = "host.name";
+ const size_t namelen = PL_strlen(name);
+ extension->Allocate(namelen + 5);
+ extension->Write(0, namelen + 3, 2);
+ extension->Write(2, static_cast<uint32_t>(0), 1); // 0 == hostname
+ extension->Write(3, namelen, 2);
+ extension->Write(5, reinterpret_cast<const uint8_t*>(name), namelen);
+ }
+
+ void HrrThenRemoveExtensionsTest(SSLExtensionType type, PRInt32 client_error,
+ PRInt32 server_error) {
+ static const std::vector<SSLNamedGroup> client_groups = {
+ ssl_grp_ec_secp384r1, ssl_grp_ec_curve25519};
+ static const std::vector<SSLNamedGroup> server_groups = {
+ ssl_grp_ec_curve25519, ssl_grp_ec_secp384r1};
+ client_->ConfigNamedGroups(client_groups);
+ server_->ConfigNamedGroups(server_groups);
+ EnsureTlsSetup();
+ StartConnect();
+ client_->Handshake(); // Send ClientHello
+ server_->Handshake(); // Send HRR.
+ MakeTlsFilter<TlsExtensionDropper>(client_, type);
+ Handshake();
+ client_->CheckErrorCode(client_error);
+ server_->CheckErrorCode(server_error);
+ }
+};
+
+class TlsExtensionTestDtls : public TlsExtensionTestBase,
+ public ::testing::WithParamInterface<uint16_t> {
+ public:
+ TlsExtensionTestDtls()
+ : TlsExtensionTestBase(ssl_variant_datagram, GetParam()) {}
+};
+
+class TlsExtensionTest12Plus : public TlsExtensionTestBase,
+ public ::testing::WithParamInterface<
+ std::tuple<SSLProtocolVariant, uint16_t>> {
+ public:
+ TlsExtensionTest12Plus()
+ : TlsExtensionTestBase(std::get<0>(GetParam()), std::get<1>(GetParam())) {
+ }
+};
+
+class TlsExtensionTest12 : public TlsExtensionTestBase,
+ public ::testing::WithParamInterface<
+ std::tuple<SSLProtocolVariant, uint16_t>> {
+ public:
+ TlsExtensionTest12()
+ : TlsExtensionTestBase(std::get<0>(GetParam()), std::get<1>(GetParam())) {
+ }
+};
+
+class TlsExtensionTest13
+ : public TlsExtensionTestBase,
+ public ::testing::WithParamInterface<SSLProtocolVariant> {
+ public:
+ TlsExtensionTest13()
+ : TlsExtensionTestBase(GetParam(), SSL_LIBRARY_VERSION_TLS_1_3) {}
+
+ void ConnectWithBogusVersionList(const uint8_t* buf, size_t len) {
+ DataBuffer versions_buf(buf, len);
+ MakeTlsFilter<TlsExtensionReplacer>(
+ client_, ssl_tls13_supported_versions_xtn, versions_buf);
+ ConnectExpectAlert(server_, kTlsAlertIllegalParameter);
+ client_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT);
+ server_->CheckErrorCode(SSL_ERROR_RX_MALFORMED_CLIENT_HELLO);
+ }
+
+ void ConnectWithReplacementVersionList(uint16_t version) {
+ // Convert the version encoding for DTLS, if needed.
+ if (variant_ == ssl_variant_datagram) {
+ switch (version) {
+ case SSL_LIBRARY_VERSION_TLS_1_3:
+#ifdef DTLS_1_3_DRAFT_VERSION
+ version = 0x7f00 | DTLS_1_3_DRAFT_VERSION;
+#else
+ version = SSL_LIBRARY_VERSION_DTLS_1_3_WIRE;
+#endif
+ break;
+ case SSL_LIBRARY_VERSION_TLS_1_2:
+ version = SSL_LIBRARY_VERSION_DTLS_1_2_WIRE;
+ break;
+ case SSL_LIBRARY_VERSION_TLS_1_1:
+ /* TLS_1_1 maps to DTLS_1_0, see sslproto.h. */
+ version = SSL_LIBRARY_VERSION_DTLS_1_0_WIRE;
+ break;
+ default:
+ PORT_Assert(0);
+ }
+ }
+
+ DataBuffer versions_buf;
+ size_t index = versions_buf.Write(0, 2, 1);
+ versions_buf.Write(index, version, 2);
+ MakeTlsFilter<TlsExtensionReplacer>(
+ client_, ssl_tls13_supported_versions_xtn, versions_buf);
+ ConnectExpectFail();
+ }
+};
+
+class TlsExtensionTest13Stream : public TlsExtensionTestBase {
+ public:
+ TlsExtensionTest13Stream()
+ : TlsExtensionTestBase(ssl_variant_stream, SSL_LIBRARY_VERSION_TLS_1_3) {}
+};
+
+class TlsExtensionTestGeneric : public TlsExtensionTestBase,
+ public ::testing::WithParamInterface<
+ std::tuple<SSLProtocolVariant, uint16_t>> {
+ public:
+ TlsExtensionTestGeneric()
+ : TlsExtensionTestBase(std::get<0>(GetParam()), std::get<1>(GetParam())) {
+ }
+};
+
+class TlsExtensionTestPre13 : public TlsExtensionTestBase,
+ public ::testing::WithParamInterface<
+ std::tuple<SSLProtocolVariant, uint16_t>> {
+ public:
+ TlsExtensionTestPre13()
+ : TlsExtensionTestBase(std::get<0>(GetParam()), std::get<1>(GetParam())) {
+ }
+};
+
+TEST_P(TlsExtensionTestGeneric, DamageSniLength) {
+ ClientHelloErrorTest(
+ std::make_shared<TlsExtensionDamager>(client_, ssl_server_name_xtn, 1));
+}
+
+TEST_P(TlsExtensionTestGeneric, DamageSniHostLength) {
+ ClientHelloErrorTest(
+ std::make_shared<TlsExtensionDamager>(client_, ssl_server_name_xtn, 4));
+}
+
+TEST_P(TlsExtensionTestGeneric, TruncateSni) {
+ ClientHelloErrorTest(
+ std::make_shared<TlsExtensionTruncator>(client_, ssl_server_name_xtn, 7));
+}
+
+// A valid extension that appears twice will be reported as unsupported.
+TEST_P(TlsExtensionTestGeneric, RepeatSni) {
+ DataBuffer extension;
+ InitSimpleSni(&extension);
+ ClientHelloErrorTest(std::make_shared<TlsExtensionInjector>(
+ client_, ssl_server_name_xtn, extension),
+ kTlsAlertIllegalParameter);
+}
+
+// An SNI entry with zero length is considered invalid (strangely, not if it is
+// the last entry, which is probably a bug).
+TEST_P(TlsExtensionTestGeneric, BadSni) {
+ DataBuffer simple;
+ InitSimpleSni(&simple);
+ DataBuffer extension;
+ extension.Allocate(simple.len() + 3);
+ extension.Write(0, static_cast<uint32_t>(0), 3);
+ extension.Write(3, simple);
+ ClientHelloErrorTest(std::make_shared<TlsExtensionReplacer>(
+ client_, ssl_server_name_xtn, extension));
+}
+
+TEST_P(TlsExtensionTestGeneric, EmptySni) {
+ DataBuffer extension;
+ extension.Allocate(2);
+ extension.Write(0, static_cast<uint32_t>(0), 2);
+ ClientHelloErrorTest(std::make_shared<TlsExtensionReplacer>(
+ client_, ssl_server_name_xtn, extension));
+}
+
+TEST_P(TlsExtensionTestGeneric, EmptyAlpnExtension) {
+ EnableAlpn();
+ DataBuffer extension;
+ ClientHelloErrorTest(std::make_shared<TlsExtensionReplacer>(
+ client_, ssl_app_layer_protocol_xtn, extension),
+ kTlsAlertIllegalParameter);
+}
+
+// An empty ALPN isn't considered bad, though it does lead to there being no
+// protocol for the server to select.
+TEST_P(TlsExtensionTestGeneric, EmptyAlpnList) {
+ EnableAlpn();
+ const uint8_t val[] = {0x00, 0x00};
+ DataBuffer extension(val, sizeof(val));
+ ClientHelloErrorTest(std::make_shared<TlsExtensionReplacer>(
+ client_, ssl_app_layer_protocol_xtn, extension),
+ kTlsAlertNoApplicationProtocol);
+}
+
+TEST_P(TlsExtensionTestGeneric, OneByteAlpn) {
+ EnableAlpn();
+ ClientHelloErrorTest(std::make_shared<TlsExtensionTruncator>(
+ client_, ssl_app_layer_protocol_xtn, 1));
+}
+
+TEST_P(TlsExtensionTestGeneric, AlpnMissingValue) {
+ EnableAlpn();
+ // This will leave the length of the second entry, but no value.
+ ClientHelloErrorTest(std::make_shared<TlsExtensionTruncator>(
+ client_, ssl_app_layer_protocol_xtn, 5));
+}
+
+TEST_P(TlsExtensionTestGeneric, AlpnZeroLength) {
+ EnableAlpn();
+ const uint8_t val[] = {0x00, 0x03, 0x01, 0x61, 0x00};
+ DataBuffer extension(val, sizeof(val));
+ ClientHelloErrorTest(std::make_shared<TlsExtensionReplacer>(
+ client_, ssl_app_layer_protocol_xtn, extension));
+}
+
+TEST_P(TlsExtensionTestGeneric, AlpnLengthOverflow) {
+ EnableAlpn();
+ const uint8_t val[] = {0x00, 0x03, 0x01, 0x61, 0x01};
+ DataBuffer extension(val, sizeof(val));
+ ClientHelloErrorTest(std::make_shared<TlsExtensionReplacer>(
+ client_, ssl_app_layer_protocol_xtn, extension));
+}
+
+TEST_P(TlsExtensionTestGeneric, AlpnMismatch) {
+ const uint8_t client_alpn[] = {0x01, 0x61};
+ client_->EnableAlpn(client_alpn, sizeof(client_alpn));
+ const uint8_t server_alpn[] = {0x02, 0x61, 0x62};
+ server_->EnableAlpn(server_alpn, sizeof(server_alpn));
+
+ ClientHelloErrorTest(nullptr, kTlsAlertNoApplicationProtocol);
+ client_->CheckErrorCode(SSL_ERROR_NEXT_PROTOCOL_NO_PROTOCOL);
+}
+
+TEST_P(TlsExtensionTestGeneric, AlpnDisabledServer) {
+ const uint8_t client_alpn[] = {0x01, 0x61};
+ client_->EnableAlpn(client_alpn, sizeof(client_alpn));
+ server_->EnableAlpn(nullptr, 0);
+
+ ClientHelloErrorTest(nullptr, kTlsAlertUnsupportedExtension);
+}
+
+TEST_P(TlsConnectGeneric, AlpnDisabled) {
+ server_->EnableAlpn(nullptr, 0);
+ Connect();
+
+ SSLNextProtoState state;
+ uint8_t buf[255] = {0};
+ unsigned int buf_len = 3;
+ EXPECT_EQ(SECSuccess, SSL_GetNextProto(client_->ssl_fd(), &state, buf,
+ &buf_len, sizeof(buf)));
+ EXPECT_EQ(SSL_NEXT_PROTO_NO_SUPPORT, state);
+ EXPECT_EQ(0U, buf_len);
+}
+
+// Many of these tests fail in TLS 1.3 because the extension is encrypted, which
+// prevents modification of the value from the ServerHello.
+TEST_P(TlsExtensionTestPre13, AlpnReturnedEmptyList) {
+ EnableAlpn();
+ const uint8_t val[] = {0x00, 0x00};
+ DataBuffer extension(val, sizeof(val));
+ ServerHelloErrorTest(std::make_shared<TlsExtensionReplacer>(
+ server_, ssl_app_layer_protocol_xtn, extension));
+}
+
+TEST_P(TlsExtensionTestPre13, AlpnReturnedEmptyName) {
+ EnableAlpn();
+ const uint8_t val[] = {0x00, 0x01, 0x00};
+ DataBuffer extension(val, sizeof(val));
+ ServerHelloErrorTest(std::make_shared<TlsExtensionReplacer>(
+ server_, ssl_app_layer_protocol_xtn, extension));
+}
+
+TEST_P(TlsExtensionTestPre13, AlpnReturnedListTrailingData) {
+ EnableAlpn();
+ const uint8_t val[] = {0x00, 0x02, 0x01, 0x61, 0x00};
+ DataBuffer extension(val, sizeof(val));
+ ServerHelloErrorTest(std::make_shared<TlsExtensionReplacer>(
+ server_, ssl_app_layer_protocol_xtn, extension));
+}
+
+TEST_P(TlsExtensionTestPre13, AlpnReturnedExtraEntry) {
+ EnableAlpn();
+ const uint8_t val[] = {0x00, 0x04, 0x01, 0x61, 0x01, 0x62};
+ DataBuffer extension(val, sizeof(val));
+ ServerHelloErrorTest(std::make_shared<TlsExtensionReplacer>(
+ server_, ssl_app_layer_protocol_xtn, extension));
+}
+
+TEST_P(TlsExtensionTestPre13, AlpnReturnedBadListLength) {
+ EnableAlpn();
+ const uint8_t val[] = {0x00, 0x99, 0x01, 0x61, 0x00};
+ DataBuffer extension(val, sizeof(val));
+ ServerHelloErrorTest(std::make_shared<TlsExtensionReplacer>(
+ server_, ssl_app_layer_protocol_xtn, extension));
+}
+
+TEST_P(TlsExtensionTestPre13, AlpnReturnedBadNameLength) {
+ EnableAlpn();
+ const uint8_t val[] = {0x00, 0x02, 0x99, 0x61};
+ DataBuffer extension(val, sizeof(val));
+ ServerHelloErrorTest(std::make_shared<TlsExtensionReplacer>(
+ server_, ssl_app_layer_protocol_xtn, extension));
+}
+
+TEST_P(TlsExtensionTestPre13, AlpnReturnedUnknownName) {
+ EnableAlpn();
+ const uint8_t val[] = {0x00, 0x02, 0x01, 0x67};
+ DataBuffer extension(val, sizeof(val));
+ ServerHelloErrorTest(std::make_shared<TlsExtensionReplacer>(
+ server_, ssl_app_layer_protocol_xtn, extension),
+ kTlsAlertIllegalParameter);
+}
+
+TEST_P(TlsExtensionTestDtls, SrtpShort) {
+ EnableSrtp();
+ ClientHelloErrorTest(
+ std::make_shared<TlsExtensionTruncator>(client_, ssl_use_srtp_xtn, 3));
+}
+
+TEST_P(TlsExtensionTestDtls, SrtpOdd) {
+ EnableSrtp();
+ const uint8_t val[] = {0x00, 0x01, 0xff, 0x00};
+ DataBuffer extension(val, sizeof(val));
+ ClientHelloErrorTest(std::make_shared<TlsExtensionReplacer>(
+ client_, ssl_use_srtp_xtn, extension));
+}
+
+TEST_P(TlsExtensionTest12Plus, SignatureAlgorithmsBadLength) {
+ const uint8_t val[] = {0x00};
+ DataBuffer extension(val, sizeof(val));
+ ClientHelloErrorTest(std::make_shared<TlsExtensionReplacer>(
+ client_, ssl_signature_algorithms_xtn, extension));
+}
+
+TEST_P(TlsExtensionTest12Plus, SignatureAlgorithmsTrailingData) {
+ // make sure the test uses an algorithm that is legal for
+ // tls 1.3 (or tls 1.3 will throw a handshake failure alert
+ // instead of a decode error alert)
+ const uint8_t val[] = {0x00, 0x02, 0x08, 0x09, 0x00}; // sha-256, rsa-pss-pss
+ DataBuffer extension(val, sizeof(val));
+ ClientHelloErrorTest(std::make_shared<TlsExtensionReplacer>(
+ client_, ssl_signature_algorithms_xtn, extension));
+}
+
+TEST_P(TlsExtensionTest12Plus, SignatureAlgorithmsEmpty) {
+ const uint8_t val[] = {0x00, 0x00};
+ DataBuffer extension(val, sizeof(val));
+ ClientHelloErrorTest(std::make_shared<TlsExtensionReplacer>(
+ client_, ssl_signature_algorithms_xtn, extension),
+ kTlsAlertHandshakeFailure);
+}
+
+TEST_P(TlsExtensionTest12Plus, SignatureAlgorithmsNoOverlap) {
+ const uint8_t val[] = {0x00, 0x02, 0xff, 0xff};
+ DataBuffer extension(val, sizeof(val));
+ ClientHelloErrorTest(std::make_shared<TlsExtensionReplacer>(
+ client_, ssl_signature_algorithms_xtn, extension),
+ kTlsAlertHandshakeFailure);
+}
+
+TEST_P(TlsExtensionTest12Plus, SignatureAlgorithmsOddLength) {
+ const uint8_t val[] = {0x00, 0x01, 0x04};
+ DataBuffer extension(val, sizeof(val));
+ ClientHelloErrorTest(std::make_shared<TlsExtensionReplacer>(
+ client_, ssl_signature_algorithms_xtn, extension));
+}
+
+TEST_F(TlsExtensionTest13Stream, SignatureAlgorithmsPrecedingGarbage) {
+ // 31 unknown signature algorithms followed by sha-256, rsa-pss
+ const uint8_t val[] = {
+ 0x00, 0x40, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
+ 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
+ 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
+ 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
+ 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
+ 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x08, 0x04};
+ DataBuffer extension(val, sizeof(val));
+ MakeTlsFilter<TlsExtensionReplacer>(client_, ssl_signature_algorithms_xtn,
+ extension);
+ client_->ExpectSendAlert(kTlsAlertBadRecordMac);
+ server_->ExpectSendAlert(kTlsAlertBadRecordMac);
+ ConnectExpectFail();
+ client_->CheckErrorCode(SSL_ERROR_BAD_MAC_READ);
+ server_->CheckErrorCode(SSL_ERROR_BAD_MAC_READ);
+}
+
+TEST_P(TlsExtensionTestGeneric, NoSupportedGroups) {
+ ClientHelloErrorTest(
+ std::make_shared<TlsExtensionDropper>(client_, ssl_supported_groups_xtn),
+ version_ < SSL_LIBRARY_VERSION_TLS_1_3 ? kTlsAlertDecryptError
+ : kTlsAlertMissingExtension);
+}
+
+TEST_P(TlsExtensionTestGeneric, SupportedCurvesShort) {
+ const uint8_t val[] = {0x00, 0x01, 0x00};
+ DataBuffer extension(val, sizeof(val));
+ ClientHelloErrorTest(std::make_shared<TlsExtensionReplacer>(
+ client_, ssl_elliptic_curves_xtn, extension));
+}
+
+TEST_P(TlsExtensionTestGeneric, SupportedCurvesBadLength) {
+ const uint8_t val[] = {0x09, 0x99, 0x00, 0x00};
+ DataBuffer extension(val, sizeof(val));
+ ClientHelloErrorTest(std::make_shared<TlsExtensionReplacer>(
+ client_, ssl_elliptic_curves_xtn, extension));
+}
+
+TEST_P(TlsExtensionTestGeneric, SupportedCurvesTrailingData) {
+ const uint8_t val[] = {0x00, 0x02, 0x00, 0x00, 0x00};
+ DataBuffer extension(val, sizeof(val));
+ ClientHelloErrorTest(std::make_shared<TlsExtensionReplacer>(
+ client_, ssl_elliptic_curves_xtn, extension));
+}
+
+TEST_P(TlsExtensionTest12, SupportedCurvesDisableX25519) {
+ // Disable session resumption.
+ ConfigureSessionCache(RESUME_NONE, RESUME_NONE);
+
+ // Ensure that we can enable its use in the key exchange.
+ SECStatus rv =
+ NSS_SetAlgorithmPolicy(SEC_OID_CURVE25519, NSS_USE_ALG_IN_SSL_KX, 0);
+ ASSERT_EQ(SECSuccess, rv);
+ rv = NSS_SetAlgorithmPolicy(SEC_OID_APPLY_SSL_POLICY, NSS_USE_POLICY_IN_SSL,
+ 0);
+ ASSERT_EQ(SECSuccess, rv);
+
+ auto capture1 =
+ MakeTlsFilter<TlsExtensionCapture>(client_, ssl_elliptic_curves_xtn);
+ Connect();
+
+ EXPECT_TRUE(capture1->captured());
+ const DataBuffer& ext1 = capture1->extension();
+
+ uint32_t count;
+ ASSERT_TRUE(ext1.Read(0, 2, &count));
+
+ // Whether or not we've seen x25519 offered in this handshake.
+ bool seen1_x25519 = false;
+ for (size_t offset = 2; offset <= count; offset++) {
+ uint32_t val;
+ ASSERT_TRUE(ext1.Read(offset, 2, &val));
+ if (val == ssl_grp_ec_curve25519) {
+ seen1_x25519 = true;
+ break;
+ }
+ }
+ ASSERT_TRUE(seen1_x25519);
+
+ // Ensure that we can disable its use in the key exchange.
+ rv = NSS_SetAlgorithmPolicy(SEC_OID_CURVE25519, 0, NSS_USE_ALG_IN_SSL_KX);
+ ASSERT_EQ(SECSuccess, rv);
+ rv = NSS_SetAlgorithmPolicy(SEC_OID_APPLY_SSL_POLICY, NSS_USE_POLICY_IN_SSL,
+ 0);
+ ASSERT_EQ(SECSuccess, rv);
+
+ // Clean up after the last run.
+ Reset();
+ auto capture2 =
+ MakeTlsFilter<TlsExtensionCapture>(client_, ssl_elliptic_curves_xtn);
+ Connect();
+
+ EXPECT_TRUE(capture2->captured());
+ const DataBuffer& ext2 = capture2->extension();
+
+ ASSERT_TRUE(ext2.Read(0, 2, &count));
+
+ // Whether or not we've seen x25519 offered in this handshake.
+ bool seen2_x25519 = false;
+ for (size_t offset = 2; offset <= count; offset++) {
+ uint32_t val;
+ ASSERT_TRUE(ext2.Read(offset, 2, &val));
+
+ if (val == ssl_grp_ec_curve25519) {
+ seen2_x25519 = true;
+ break;
+ }
+ }
+
+ ASSERT_FALSE(seen2_x25519);
+}
+
+TEST_P(TlsExtensionTestPre13, SupportedPointsEmpty) {
+ const uint8_t val[] = {0x00};
+ DataBuffer extension(val, sizeof(val));
+ ClientHelloErrorTest(std::make_shared<TlsExtensionReplacer>(
+ client_, ssl_ec_point_formats_xtn, extension));
+}
+
+TEST_P(TlsExtensionTestPre13, SupportedPointsBadLength) {
+ const uint8_t val[] = {0x99, 0x00, 0x00};
+ DataBuffer extension(val, sizeof(val));
+ ClientHelloErrorTest(std::make_shared<TlsExtensionReplacer>(
+ client_, ssl_ec_point_formats_xtn, extension));
+}
+
+TEST_P(TlsExtensionTestPre13, SupportedPointsTrailingData) {
+ const uint8_t val[] = {0x01, 0x00, 0x00};
+ DataBuffer extension(val, sizeof(val));
+ ClientHelloErrorTest(std::make_shared<TlsExtensionReplacer>(
+ client_, ssl_ec_point_formats_xtn, extension));
+}
+
+TEST_P(TlsExtensionTestPre13, SupportedPointsCompressed) {
+ const uint8_t val[] = {0x01, 0x02};
+ DataBuffer extension(val, sizeof(val));
+ ClientHelloErrorTest(std::make_shared<TlsExtensionReplacer>(
+ client_, ssl_ec_point_formats_xtn, extension),
+ kTlsAlertIllegalParameter);
+}
+
+TEST_P(TlsExtensionTestPre13, SupportedPointsUndefined) {
+ const uint8_t val[] = {0x01, 0xAA};
+ DataBuffer extension(val, sizeof(val));
+ ClientHelloErrorTest(std::make_shared<TlsExtensionReplacer>(
+ client_, ssl_ec_point_formats_xtn, extension),
+ kTlsAlertIllegalParameter);
+}
+
+TEST_P(TlsExtensionTestPre13, RenegotiationInfoBadLength) {
+ const uint8_t val[] = {0x99};
+ DataBuffer extension(val, sizeof(val));
+ ClientHelloErrorTest(std::make_shared<TlsExtensionReplacer>(
+ client_, ssl_renegotiation_info_xtn, extension));
+}
+
+TEST_P(TlsExtensionTestPre13, RenegotiationInfoMismatch) {
+ const uint8_t val[] = {0x01, 0x00};
+ DataBuffer extension(val, sizeof(val));
+ ClientHelloErrorTest(std::make_shared<TlsExtensionReplacer>(
+ client_, ssl_renegotiation_info_xtn, extension));
+}
+
+// The extension has to contain a length.
+TEST_P(TlsExtensionTestPre13, RenegotiationInfoExtensionEmpty) {
+ DataBuffer extension;
+ ClientHelloErrorTest(std::make_shared<TlsExtensionReplacer>(
+ client_, ssl_renegotiation_info_xtn, extension));
+}
+
+// This only works on TLS 1.2, since it relies on static RSA; otherwise libssl
+// picks the wrong cipher suite.
+TEST_P(TlsExtensionTest12, SignatureAlgorithmConfiguration) {
+ const SSLSignatureScheme schemes[] = {ssl_sig_rsa_pss_rsae_sha512,
+ ssl_sig_rsa_pss_rsae_sha384};
+
+ auto capture =
+ MakeTlsFilter<TlsExtensionCapture>(client_, ssl_signature_algorithms_xtn);
+ client_->SetSignatureSchemes(schemes, PR_ARRAY_SIZE(schemes));
+ EnableOnlyStaticRsaCiphers();
+ Connect();
+
+ const DataBuffer& ext = capture->extension();
+ EXPECT_EQ(2 + PR_ARRAY_SIZE(schemes) * 2, ext.len());
+ for (size_t i = 0, cursor = 2;
+ i < PR_ARRAY_SIZE(schemes) && cursor < ext.len(); ++i) {
+ uint32_t v = 0;
+ EXPECT_TRUE(ext.Read(cursor, 2, &v));
+ cursor += 2;
+ EXPECT_EQ(schemes[i], static_cast<SSLSignatureScheme>(v));
+ }
+}
+
+// This only works on TLS 1.2, since it relies on DSA.
+TEST_P(TlsExtensionTest12, SignatureAlgorithmDisableDSA) {
+ const std::vector<SSLSignatureScheme> schemes = {
+ ssl_sig_dsa_sha1, ssl_sig_dsa_sha256, ssl_sig_dsa_sha384,
+ ssl_sig_dsa_sha512, ssl_sig_rsa_pss_rsae_sha256};
+
+ // Connect with DSA enabled by policy.
+ SECStatus rv = NSS_SetAlgorithmPolicy(SEC_OID_ANSIX9_DSA_SIGNATURE,
+ NSS_USE_ALG_IN_SSL_KX, 0);
+ ASSERT_EQ(SECSuccess, rv);
+ rv = NSS_SetAlgorithmPolicy(SEC_OID_APPLY_SSL_POLICY, NSS_USE_POLICY_IN_SSL,
+ 0);
+ ASSERT_EQ(SECSuccess, rv);
+
+ Reset(TlsAgent::kServerDsa);
+ auto capture1 =
+ MakeTlsFilter<TlsExtensionCapture>(client_, ssl_signature_algorithms_xtn);
+ client_->SetSignatureSchemes(schemes.data(), schemes.size());
+ Connect();
+
+ // Check if all the signature algorithms are advertised.
+ EXPECT_TRUE(capture1->captured());
+ const DataBuffer& ext1 = capture1->extension();
+ EXPECT_EQ(2U + 2U * schemes.size(), ext1.len());
+
+ // Connect with DSA disabled by policy.
+ rv = NSS_SetAlgorithmPolicy(SEC_OID_ANSIX9_DSA_SIGNATURE, 0,
+ NSS_USE_ALG_IN_SSL_KX);
+ ASSERT_EQ(SECSuccess, rv);
+ rv = NSS_SetAlgorithmPolicy(SEC_OID_APPLY_SSL_POLICY, NSS_USE_POLICY_IN_SSL,
+ 0);
+ ASSERT_EQ(SECSuccess, rv);
+
+ Reset(TlsAgent::kServerDsa);
+ auto capture2 =
+ MakeTlsFilter<TlsExtensionCapture>(client_, ssl_signature_algorithms_xtn);
+ client_->SetSignatureSchemes(schemes.data(), schemes.size());
+ ConnectExpectAlert(server_, kTlsAlertHandshakeFailure);
+ server_->CheckErrorCode(SSL_ERROR_NO_CYPHER_OVERLAP);
+ client_->CheckErrorCode(SSL_ERROR_NO_CYPHER_OVERLAP);
+
+ // Check if no DSA algorithms are advertised.
+ EXPECT_TRUE(capture2->captured());
+ const DataBuffer& ext2 = capture2->extension();
+ EXPECT_EQ(2U + 2U, ext2.len());
+ uint32_t v = 0;
+ EXPECT_TRUE(ext2.Read(2, 2, &v));
+ EXPECT_EQ(ssl_sig_rsa_pss_rsae_sha256, v);
+}
+
+// Temporary test to verify that we choke on an empty ClientKeyShare.
+// This test will fail when we implement HelloRetryRequest.
+TEST_P(TlsExtensionTest13, EmptyClientKeyShare) {
+ ClientHelloErrorTest(std::make_shared<TlsExtensionTruncator>(
+ client_, ssl_tls13_key_share_xtn, 2),
+ kTlsAlertHandshakeFailure);
+}
+
+// These tests only work in stream mode because the client sends a
+// cleartext alert which causes a MAC error on the server. With
+// stream this causes handshake failure but with datagram, the
+// packet gets dropped.
+TEST_F(TlsExtensionTest13Stream, DropServerKeyShare) {
+ EnsureTlsSetup();
+ MakeTlsFilter<TlsExtensionDropper>(server_, ssl_tls13_key_share_xtn);
+ client_->ExpectSendAlert(kTlsAlertMissingExtension);
+ server_->ExpectSendAlert(kTlsAlertUnexpectedMessage);
+ ConnectExpectFail();
+ EXPECT_EQ(SSL_ERROR_MISSING_KEY_SHARE, client_->error_code());
+ EXPECT_EQ(SSL_ERROR_RX_UNEXPECTED_RECORD_TYPE, server_->error_code());
+}
+
+TEST_F(TlsExtensionTest13Stream, WrongServerKeyShare) {
+ const uint16_t wrong_group = ssl_grp_ec_secp384r1;
+
+ static const uint8_t key_share[] = {
+ wrong_group >> 8,
+ wrong_group & 0xff, // Group we didn't offer.
+ 0x00,
+ 0x02, // length = 2
+ 0x01,
+ 0x02};
+ DataBuffer buf(key_share, sizeof(key_share));
+ EnsureTlsSetup();
+ MakeTlsFilter<TlsExtensionReplacer>(server_, ssl_tls13_key_share_xtn, buf);
+ client_->ExpectSendAlert(kTlsAlertIllegalParameter);
+ server_->ExpectSendAlert(kTlsAlertUnexpectedMessage);
+ ConnectExpectFail();
+ EXPECT_EQ(SSL_ERROR_RX_MALFORMED_KEY_SHARE, client_->error_code());
+ EXPECT_EQ(SSL_ERROR_RX_UNEXPECTED_RECORD_TYPE, server_->error_code());
+}
+
+TEST_F(TlsExtensionTest13Stream, UnknownServerKeyShare) {
+ const uint16_t wrong_group = 0xffff;
+
+ static const uint8_t key_share[] = {
+ wrong_group >> 8,
+ wrong_group & 0xff, // Group we didn't offer.
+ 0x00,
+ 0x02, // length = 2
+ 0x01,
+ 0x02};
+ DataBuffer buf(key_share, sizeof(key_share));
+ EnsureTlsSetup();
+ MakeTlsFilter<TlsExtensionReplacer>(server_, ssl_tls13_key_share_xtn, buf);
+ client_->ExpectSendAlert(kTlsAlertIllegalParameter);
+ server_->ExpectSendAlert(kTlsAlertUnexpectedMessage);
+ ConnectExpectFail();
+ EXPECT_EQ(SSL_ERROR_RX_MALFORMED_KEY_SHARE, client_->error_code());
+ EXPECT_EQ(SSL_ERROR_RX_UNEXPECTED_RECORD_TYPE, server_->error_code());
+}
+
+TEST_F(TlsExtensionTest13Stream, AddServerSignatureAlgorithmsOnResumption) {
+ SetupForResume();
+ DataBuffer empty;
+ MakeTlsFilter<TlsExtensionInjector>(server_, ssl_signature_algorithms_xtn,
+ empty);
+ client_->ExpectSendAlert(kTlsAlertIllegalParameter);
+ server_->ExpectSendAlert(kTlsAlertUnexpectedMessage);
+ ConnectExpectFail();
+ EXPECT_EQ(SSL_ERROR_EXTENSION_DISALLOWED_FOR_VERSION, client_->error_code());
+ EXPECT_EQ(SSL_ERROR_RX_UNEXPECTED_RECORD_TYPE, server_->error_code());
+}
+
+struct PskIdentity {
+ DataBuffer identity;
+ uint32_t obfuscated_ticket_age;
+};
+
+class TlsPreSharedKeyReplacer;
+
+typedef std::function<void(TlsPreSharedKeyReplacer*)>
+ TlsPreSharedKeyReplacerFunc;
+
+class TlsPreSharedKeyReplacer : public TlsExtensionFilter {
+ public:
+ TlsPreSharedKeyReplacer(const std::shared_ptr<TlsAgent>& a,
+ TlsPreSharedKeyReplacerFunc function)
+ : TlsExtensionFilter(a), identities_(), binders_(), function_(function) {}
+
+ static size_t CopyAndMaybeReplace(TlsParser* parser, size_t size,
+ const std::unique_ptr<DataBuffer>& replace,
+ size_t index, DataBuffer* output) {
+ DataBuffer tmp;
+ bool ret = parser->ReadVariable(&tmp, size);
+ EXPECT_EQ(true, ret);
+ if (!ret) return 0;
+ if (replace) {
+ tmp = *replace;
+ }
+
+ return WriteVariable(output, index, tmp, size);
+ }
+
+ PacketFilter::Action FilterExtension(uint16_t extension_type,
+ const DataBuffer& input,
+ DataBuffer* output) {
+ if (extension_type != ssl_tls13_pre_shared_key_xtn) {
+ return KEEP;
+ }
+
+ if (!Decode(input)) {
+ return KEEP;
+ }
+
+ // Call the function.
+ function_(this);
+
+ Encode(output);
+
+ return CHANGE;
+ }
+
+ std::vector<PskIdentity> identities_;
+ std::vector<DataBuffer> binders_;
+
+ private:
+ bool Decode(const DataBuffer& input) {
+ std::unique_ptr<TlsParser> parser(new TlsParser(input));
+ DataBuffer identities;
+
+ if (!parser->ReadVariable(&identities, 2)) {
+ ADD_FAILURE();
+ return false;
+ }
+
+ DataBuffer binders;
+ if (!parser->ReadVariable(&binders, 2)) {
+ ADD_FAILURE();
+ return false;
+ }
+ EXPECT_EQ(0UL, parser->remaining());
+
+ // Now parse the inner sections.
+ parser.reset(new TlsParser(identities));
+ while (parser->remaining()) {
+ PskIdentity identity;
+
+ if (!parser->ReadVariable(&identity.identity, 2)) {
+ ADD_FAILURE();
+ return false;
+ }
+
+ if (!parser->Read(&identity.obfuscated_ticket_age, 4)) {
+ ADD_FAILURE();
+ return false;
+ }
+
+ identities_.push_back(identity);
+ }
+
+ parser.reset(new TlsParser(binders));
+ while (parser->remaining()) {
+ DataBuffer binder;
+
+ if (!parser->ReadVariable(&binder, 1)) {
+ ADD_FAILURE();
+ return false;
+ }
+
+ binders_.push_back(binder);
+ }
+
+ return true;
+ }
+
+ void Encode(DataBuffer* output) {
+ DataBuffer identities;
+ size_t index = 0;
+ for (auto id : identities_) {
+ index = WriteVariable(&identities, index, id.identity, 2);
+ index = identities.Write(index, id.obfuscated_ticket_age, 4);
+ }
+
+ DataBuffer binders;
+ index = 0;
+ for (auto binder : binders_) {
+ index = WriteVariable(&binders, index, binder, 1);
+ }
+
+ output->Truncate(0);
+ index = 0;
+ index = WriteVariable(output, index, identities, 2);
+ index = WriteVariable(output, index, binders, 2);
+ }
+
+ TlsPreSharedKeyReplacerFunc function_;
+};
+
+TEST_F(TlsExtensionTest13Stream, ResumeEmptyPskLabel) {
+ SetupForResume();
+
+ MakeTlsFilter<TlsPreSharedKeyReplacer>(
+ client_, [](TlsPreSharedKeyReplacer* r) {
+ r->identities_[0].identity.Truncate(0);
+ });
+ ConnectExpectAlert(server_, kTlsAlertIllegalParameter);
+ client_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT);
+ server_->CheckErrorCode(SSL_ERROR_RX_MALFORMED_CLIENT_HELLO);
+}
+
+// Flip the first byte of the binder.
+TEST_F(TlsExtensionTest13Stream, ResumeIncorrectBinderValue) {
+ SetupForResume();
+
+ MakeTlsFilter<TlsPreSharedKeyReplacer>(
+ client_, [](TlsPreSharedKeyReplacer* r) {
+ r->binders_[0].Write(0, r->binders_[0].data()[0] ^ 0xff, 1);
+ });
+ ConnectExpectAlert(server_, kTlsAlertDecryptError);
+ client_->CheckErrorCode(SSL_ERROR_DECRYPT_ERROR_ALERT);
+ server_->CheckErrorCode(SSL_ERROR_BAD_HANDSHAKE_HASH_VALUE);
+}
+
+// Do the same with an External PSK.
+TEST_P(TlsConnectTls13, TestTls13PskInvalidBinderValue) {
+ ScopedPK11SlotInfo slot(PK11_GetInternalSlot());
+ ASSERT_TRUE(!!slot);
+ ScopedPK11SymKey key(
+ PK11_KeyGen(slot.get(), CKM_HKDF_KEY_GEN, nullptr, 16, nullptr));
+ ASSERT_TRUE(!!key);
+ AddPsk(key, std::string("foo"), ssl_hash_sha256);
+ StartConnect();
+ ASSERT_TRUE(client_->MaybeSetResumptionToken());
+
+ MakeTlsFilter<TlsPreSharedKeyReplacer>(
+ client_, [](TlsPreSharedKeyReplacer* r) {
+ r->binders_[0].Write(0, r->binders_[0].data()[0] ^ 0xff, 1);
+ });
+ ConnectExpectAlert(server_, kTlsAlertDecryptError);
+ client_->CheckErrorCode(SSL_ERROR_DECRYPT_ERROR_ALERT);
+ server_->CheckErrorCode(SSL_ERROR_BAD_HANDSHAKE_HASH_VALUE);
+}
+
+// Extend the binder by one.
+TEST_F(TlsExtensionTest13Stream, ResumeIncorrectBinderLength) {
+ SetupForResume();
+
+ MakeTlsFilter<TlsPreSharedKeyReplacer>(
+ client_, [](TlsPreSharedKeyReplacer* r) {
+ r->binders_[0].Write(r->binders_[0].len(), 0xff, 1);
+ });
+ ConnectExpectAlert(server_, kTlsAlertIllegalParameter);
+ client_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT);
+ server_->CheckErrorCode(SSL_ERROR_RX_MALFORMED_CLIENT_HELLO);
+}
+
+// Binders must be at least 32 bytes.
+TEST_F(TlsExtensionTest13Stream, ResumeBinderTooShort) {
+ SetupForResume();
+
+ MakeTlsFilter<TlsPreSharedKeyReplacer>(
+ client_, [](TlsPreSharedKeyReplacer* r) { r->binders_[0].Truncate(31); });
+ ConnectExpectAlert(server_, kTlsAlertIllegalParameter);
+ client_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT);
+ server_->CheckErrorCode(SSL_ERROR_RX_MALFORMED_CLIENT_HELLO);
+}
+
+// Duplicate the identity and binder. This will fail with an error
+// processing the binder (because we extended the identity list.)
+TEST_F(TlsExtensionTest13Stream, ResumeTwoPsks) {
+ SetupForResume();
+
+ MakeTlsFilter<TlsPreSharedKeyReplacer>(
+ client_, [](TlsPreSharedKeyReplacer* r) {
+ r->identities_.push_back(r->identities_[0]);
+ r->binders_.push_back(r->binders_[0]);
+ });
+ ConnectExpectAlert(server_, kTlsAlertDecryptError);
+ client_->CheckErrorCode(SSL_ERROR_DECRYPT_ERROR_ALERT);
+ server_->CheckErrorCode(SSL_ERROR_BAD_HANDSHAKE_HASH_VALUE);
+}
+
+// The next two tests have mismatches in the number of identities
+// and binders. This generates an illegal parameter alert.
+TEST_F(TlsExtensionTest13Stream, ResumeTwoIdentitiesOneBinder) {
+ SetupForResume();
+
+ MakeTlsFilter<TlsPreSharedKeyReplacer>(
+ client_, [](TlsPreSharedKeyReplacer* r) {
+ r->identities_.push_back(r->identities_[0]);
+ });
+ ConnectExpectAlert(server_, kTlsAlertIllegalParameter);
+ client_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT);
+ server_->CheckErrorCode(SSL_ERROR_RX_MALFORMED_CLIENT_HELLO);
+}
+
+TEST_F(TlsExtensionTest13Stream, ResumeOneIdentityTwoBinders) {
+ SetupForResume();
+
+ MakeTlsFilter<TlsPreSharedKeyReplacer>(
+ client_, [](TlsPreSharedKeyReplacer* r) {
+ r->binders_.push_back(r->binders_[0]);
+ });
+ ConnectExpectAlert(server_, kTlsAlertIllegalParameter);
+ client_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT);
+ server_->CheckErrorCode(SSL_ERROR_RX_MALFORMED_CLIENT_HELLO);
+}
+
+TEST_F(TlsExtensionTest13Stream, ResumePskExtensionNotLast) {
+ SetupForResume();
+
+ const uint8_t empty_buf[] = {0};
+ DataBuffer empty(empty_buf, 0);
+ // Inject an unused extension after the PSK extension.
+ MakeTlsFilter<TlsExtensionAppender>(client_, kTlsHandshakeClientHello, 0xffff,
+ empty);
+ ConnectExpectAlert(server_, kTlsAlertIllegalParameter);
+ client_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT);
+ server_->CheckErrorCode(SSL_ERROR_RX_MALFORMED_CLIENT_HELLO);
+}
+
+TEST_F(TlsExtensionTest13Stream, ResumeNoKeModes) {
+ SetupForResume();
+
+ DataBuffer empty;
+ MakeTlsFilter<TlsExtensionDropper>(client_,
+ ssl_tls13_psk_key_exchange_modes_xtn);
+ ConnectExpectAlert(server_, kTlsAlertMissingExtension);
+ client_->CheckErrorCode(SSL_ERROR_MISSING_EXTENSION_ALERT);
+ server_->CheckErrorCode(SSL_ERROR_MISSING_PSK_KEY_EXCHANGE_MODES);
+}
+
+// The following test contains valid but unacceptable PreSharedKey
+// modes and therefore produces non-resumption followed by MAC
+// errors.
+TEST_F(TlsExtensionTest13Stream, ResumeBogusKeModes) {
+ SetupForResume();
+ const static uint8_t ke_modes[] = {1, // Length
+ kTls13PskKe};
+
+ DataBuffer modes(ke_modes, sizeof(ke_modes));
+ MakeTlsFilter<TlsExtensionReplacer>(
+ client_, ssl_tls13_psk_key_exchange_modes_xtn, modes);
+ client_->ExpectSendAlert(kTlsAlertBadRecordMac);
+ server_->ExpectSendAlert(kTlsAlertBadRecordMac);
+ ConnectExpectFail();
+ client_->CheckErrorCode(SSL_ERROR_BAD_MAC_READ);
+ server_->CheckErrorCode(SSL_ERROR_BAD_MAC_READ);
+}
+
+TEST_P(TlsExtensionTest13, NoKeModesIfResumptionOff) {
+ ConfigureSessionCache(RESUME_NONE, RESUME_NONE);
+ auto capture = MakeTlsFilter<TlsExtensionCapture>(
+ client_, ssl_tls13_psk_key_exchange_modes_xtn);
+ Connect();
+ EXPECT_FALSE(capture->captured());
+}
+
+// In these tests, we downgrade to TLS 1.2, causing the
+// server to negotiate TLS 1.2.
+// 1. Both sides only support TLS 1.3, so we get a cipher version
+// error.
+TEST_P(TlsExtensionTest13, RemoveTls13FromVersionList) {
+ ExpectAlert(server_, kTlsAlertProtocolVersion);
+ ConnectWithReplacementVersionList(SSL_LIBRARY_VERSION_TLS_1_2);
+ client_->CheckErrorCode(SSL_ERROR_PROTOCOL_VERSION_ALERT);
+ server_->CheckErrorCode(SSL_ERROR_UNSUPPORTED_VERSION);
+}
+
+// 2. Server supports 1.2 and 1.3, client supports 1.2, so we
+// can't negotiate any ciphers.
+TEST_P(TlsExtensionTest13, RemoveTls13FromVersionListServerV12) {
+ server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2,
+ SSL_LIBRARY_VERSION_TLS_1_3);
+ ExpectAlert(server_, kTlsAlertHandshakeFailure);
+ ConnectWithReplacementVersionList(SSL_LIBRARY_VERSION_TLS_1_2);
+ client_->CheckErrorCode(SSL_ERROR_NO_CYPHER_OVERLAP);
+ server_->CheckErrorCode(SSL_ERROR_NO_CYPHER_OVERLAP);
+}
+
+// 3. Server supports 1.2 and 1.3, client supports 1.2 and 1.3
+// but advertises 1.2 (because we changed things).
+TEST_P(TlsExtensionTest13, RemoveTls13FromVersionListBothV12) {
+ client_->SetOption(SSL_ENABLE_HELLO_DOWNGRADE_CHECK, PR_TRUE);
+ client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2,
+ SSL_LIBRARY_VERSION_TLS_1_3);
+ server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2,
+ SSL_LIBRARY_VERSION_TLS_1_3);
+// The downgrade check is disabled in DTLS 1.3, so all that happens when we
+// tamper with the supported versions is that the Finished check fails.
+#ifdef DTLS_1_3_DRAFT_VERSION
+ if (variant_ == ssl_variant_datagram) {
+ ExpectAlert(server_, kTlsAlertDecryptError);
+ } else
+#endif
+ {
+ ExpectAlert(client_, kTlsAlertIllegalParameter);
+ }
+ ConnectWithReplacementVersionList(SSL_LIBRARY_VERSION_TLS_1_2);
+#ifdef DTLS_1_3_DRAFT_VERSION
+ if (variant_ == ssl_variant_datagram) {
+ client_->CheckErrorCode(SSL_ERROR_DECRYPT_ERROR_ALERT);
+ server_->CheckErrorCode(SSL_ERROR_BAD_HANDSHAKE_HASH_VALUE);
+ } else
+#endif
+ {
+ client_->CheckErrorCode(SSL_ERROR_RX_MALFORMED_SERVER_HELLO);
+ server_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT);
+ }
+}
+
+TEST_P(TlsExtensionTest13, HrrThenRemoveSignatureAlgorithms) {
+ ExpectAlert(server_, kTlsAlertMissingExtension);
+ HrrThenRemoveExtensionsTest(ssl_signature_algorithms_xtn,
+ SSL_ERROR_MISSING_EXTENSION_ALERT,
+ SSL_ERROR_MISSING_SIGNATURE_ALGORITHMS_EXTENSION);
+}
+
+TEST_P(TlsExtensionTest13, HrrThenRemoveKeyShare) {
+ ExpectAlert(server_, kTlsAlertIllegalParameter);
+ HrrThenRemoveExtensionsTest(ssl_tls13_key_share_xtn,
+ SSL_ERROR_ILLEGAL_PARAMETER_ALERT,
+ SSL_ERROR_BAD_2ND_CLIENT_HELLO);
+}
+
+TEST_P(TlsExtensionTest13, HrrThenRemoveSupportedGroups) {
+ ExpectAlert(server_, kTlsAlertMissingExtension);
+ HrrThenRemoveExtensionsTest(ssl_supported_groups_xtn,
+ SSL_ERROR_MISSING_EXTENSION_ALERT,
+ SSL_ERROR_MISSING_SUPPORTED_GROUPS_EXTENSION);
+}
+
+TEST_P(TlsExtensionTest13, EmptyVersionList) {
+ static const uint8_t kExt[] = {0x00, 0x00};
+ ConnectWithBogusVersionList(kExt, sizeof(kExt));
+}
+
+TEST_P(TlsExtensionTest13, OddVersionList) {
+ static const uint8_t kExt[] = {0x00, 0x01, 0x00};
+ ConnectWithBogusVersionList(kExt, sizeof(kExt));
+}
+
+TEST_P(TlsExtensionTest13, SignatureAlgorithmsInvalidTls13) {
+ // testing the case where we ask for a invalid parameter for tls13
+ const uint8_t val[] = {0x00, 0x02, 0x04, 0x01}; // sha-256, rsa-pkcs1
+ DataBuffer extension(val, sizeof(val));
+ ClientHelloErrorTest(std::make_shared<TlsExtensionReplacer>(
+ client_, ssl_signature_algorithms_xtn, extension),
+ kTlsAlertHandshakeFailure);
+}
+
+// Use the stream version number for TLS 1.3 (0x0304) in DTLS.
+TEST_F(TlsConnectDatagram13, TlsVersionInDtls) {
+ static const uint8_t kExt[] = {0x02, 0x03, 0x04};
+
+ DataBuffer versions_buf(kExt, sizeof(kExt));
+ MakeTlsFilter<TlsExtensionReplacer>(client_, ssl_tls13_supported_versions_xtn,
+ versions_buf);
+ ConnectExpectAlert(server_, kTlsAlertProtocolVersion);
+ client_->CheckErrorCode(SSL_ERROR_PROTOCOL_VERSION_ALERT);
+ server_->CheckErrorCode(SSL_ERROR_UNSUPPORTED_VERSION);
+}
+
+// TODO: this only tests extensions in server messages. The client can extend
+// Certificate messages, which is not checked here.
+class TlsBogusExtensionTest : public TlsConnectTestBase,
+ public ::testing::WithParamInterface<
+ std::tuple<SSLProtocolVariant, uint16_t>> {
+ public:
+ TlsBogusExtensionTest()
+ : TlsConnectTestBase(std::get<0>(GetParam()), std::get<1>(GetParam())) {}
+
+ protected:
+ virtual void ConnectAndFail(uint8_t message) = 0;
+
+ void AddFilter(uint8_t message, uint16_t extension) {
+ static uint8_t empty_buf[1] = {0};
+ DataBuffer empty(empty_buf, 0);
+ auto filter =
+ MakeTlsFilter<TlsExtensionAppender>(server_, message, extension, empty);
+ if (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) {
+ filter->EnableDecryption();
+ }
+ }
+
+ void Run(uint8_t message, uint16_t extension = 0xff) {
+ EnsureTlsSetup();
+ AddFilter(message, extension);
+ ConnectAndFail(message);
+ }
+};
+
+class TlsBogusExtensionTestPre13 : public TlsBogusExtensionTest {
+ protected:
+ void ConnectAndFail(uint8_t) override {
+ ConnectExpectAlert(client_, kTlsAlertUnsupportedExtension);
+ }
+};
+
+class TlsBogusExtensionTest13 : public TlsBogusExtensionTest {
+ protected:
+ void ConnectAndFail(uint8_t message) override {
+ if (message != kTlsHandshakeServerHello) {
+ ConnectExpectAlert(client_, kTlsAlertUnsupportedExtension);
+ return;
+ }
+
+ FailWithAlert(kTlsAlertUnsupportedExtension);
+ }
+
+ void FailWithAlert(uint8_t alert) {
+ StartConnect();
+ client_->Handshake(); // ClientHello
+ server_->Handshake(); // ServerHello
+
+ client_->ExpectSendAlert(alert);
+ client_->Handshake();
+ if (variant_ == ssl_variant_stream) {
+ server_->ExpectSendAlert(kTlsAlertUnexpectedMessage);
+ }
+ server_->Handshake();
+ }
+};
+
+TEST_P(TlsBogusExtensionTestPre13, AddBogusExtensionServerHello) {
+ Run(kTlsHandshakeServerHello);
+}
+
+TEST_P(TlsBogusExtensionTest13, AddBogusExtensionServerHello) {
+ Run(kTlsHandshakeServerHello);
+}
+
+TEST_P(TlsBogusExtensionTest13, AddBogusExtensionEncryptedExtensions) {
+ Run(kTlsHandshakeEncryptedExtensions);
+}
+
+TEST_P(TlsBogusExtensionTest13, AddBogusExtensionCertificate) {
+ Run(kTlsHandshakeCertificate);
+}
+
+// It's perfectly valid to set unknown extensions in CertificateRequest.
+TEST_P(TlsBogusExtensionTest13, AddBogusExtensionCertificateRequest) {
+ server_->RequestClientAuth(false);
+ AddFilter(kTlsHandshakeCertificateRequest, 0xff);
+ ConnectExpectAlert(client_, kTlsAlertDecryptError);
+ client_->CheckErrorCode(SEC_ERROR_BAD_SIGNATURE);
+}
+
+TEST_P(TlsBogusExtensionTest13, AddBogusExtensionHelloRetryRequest) {
+ static const std::vector<SSLNamedGroup> groups = {ssl_grp_ec_secp384r1};
+ server_->ConfigNamedGroups(groups);
+
+ Run(kTlsHandshakeHelloRetryRequest);
+}
+
+// NewSessionTicket allows unknown extensions AND it isn't protected by the
+// Finished. So adding an unknown extension doesn't cause an error.
+TEST_P(TlsBogusExtensionTest13, AddBogusExtensionNewSessionTicket) {
+ ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET);
+
+ AddFilter(kTlsHandshakeNewSessionTicket, 0xff);
+ Connect();
+ SendReceive();
+ CheckKeys();
+
+ Reset();
+ ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET);
+ ExpectResumption(RESUME_TICKET);
+ Connect();
+ SendReceive();
+}
+
+class TlsDisallowedExtensionTest13 : public TlsBogusExtensionTest {
+ protected:
+ void ConnectAndFail(uint8_t message) override {
+ ConnectExpectAlert(client_, kTlsAlertIllegalParameter);
+ }
+};
+
+TEST_P(TlsDisallowedExtensionTest13, AddVersionExtensionEncryptedExtensions) {
+ Run(kTlsHandshakeEncryptedExtensions, ssl_tls13_supported_versions_xtn);
+}
+
+TEST_P(TlsDisallowedExtensionTest13, AddVersionExtensionCertificate) {
+ Run(kTlsHandshakeCertificate, ssl_tls13_supported_versions_xtn);
+}
+
+TEST_P(TlsDisallowedExtensionTest13, AddVersionExtensionCertificateRequest) {
+ server_->RequestClientAuth(false);
+ Run(kTlsHandshakeCertificateRequest, ssl_tls13_supported_versions_xtn);
+}
+
+/* For unadvertised disallowed extensions an unsupported_extension alert is
+ * thrown since NSS checks for unadvertised extensions before its disallowed
+ * extension check. */
+class TlsDisallowedUnadvertisedExtensionTest13 : public TlsBogusExtensionTest {
+ protected:
+ void ConnectAndFail(uint8_t message) override {
+ uint8_t alert = kTlsAlertUnsupportedExtension;
+ if (message == kTlsHandshakeCertificateRequest) {
+ alert = kTlsAlertIllegalParameter;
+ }
+ ConnectExpectAlert(client_, alert);
+ }
+};
+
+TEST_P(TlsDisallowedUnadvertisedExtensionTest13,
+ AddPSKExtensionEncryptedExtensions) {
+ Run(kTlsHandshakeEncryptedExtensions, ssl_tls13_pre_shared_key_xtn);
+}
+
+TEST_P(TlsDisallowedUnadvertisedExtensionTest13, AddPSKExtensionCertificate) {
+ Run(kTlsHandshakeCertificate, ssl_tls13_pre_shared_key_xtn);
+}
+
+TEST_P(TlsDisallowedUnadvertisedExtensionTest13,
+ AddPSKExtensionCertificateRequest) {
+ server_->RequestClientAuth(false);
+ Run(kTlsHandshakeCertificateRequest, ssl_tls13_pre_shared_key_xtn);
+}
+
+TEST_P(TlsConnectStream, IncludePadding) {
+ EnsureTlsSetup();
+ SSL_EnableTls13GreaseEch(client_->ssl_fd(), PR_FALSE); // Don't GREASE
+
+ // This needs to be long enough to push a TLS 1.0 ClientHello over 255, but
+ // short enough not to push a TLS 1.3 ClientHello over 511.
+ static const char* long_name =
+ "chickenchickenchickenchickenchickenchickenchickenchicken."
+ "chickenchickenchickenchickenchickenchickenchickenchicken."
+ "chickenchickenchickenchickenchicken.";
+ SECStatus rv = SSL_SetURL(client_->ssl_fd(), long_name);
+ EXPECT_EQ(SECSuccess, rv);
+
+ auto capture = MakeTlsFilter<TlsExtensionCapture>(client_, ssl_padding_xtn);
+ client_->StartConnect();
+ client_->Handshake();
+ EXPECT_TRUE(capture->captured());
+}
+
+TEST_F(TlsConnectDatagram13, Dtls13RejectLegacyCookie) {
+ EnsureTlsSetup();
+ MakeTlsFilter<Dtls13LegacyCookieInjector>(client_);
+ ConnectExpectAlert(server_, kTlsAlertIllegalParameter);
+ server_->CheckErrorCode(SSL_ERROR_RX_MALFORMED_CLIENT_HELLO);
+ client_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT);
+}
+
+TEST_P(TlsConnectGeneric, ClientHelloExtensionPermutation) {
+ EnsureTlsSetup();
+ PR_ASSERT(SSL_OptionSet(client_->ssl_fd(),
+ SSL_ENABLE_CH_EXTENSION_PERMUTATION,
+ PR_TRUE) == SECSuccess);
+ Connect();
+}
+
+TEST_F(TlsConnectStreamTls13, ClientHelloExtensionPermutationWithPSK) {
+ EnsureTlsSetup();
+
+ ScopedPK11SlotInfo slot(PK11_GetInternalSlot());
+ const uint8_t kPskDummyVal_[16] = {0x01, 0x02, 0x03, 0x04, 0x05,
+ 0x06, 0x07, 0x08, 0x09, 0x0a,
+ 0x0b, 0x0c, 0x0d, 0x0e, 0x0f};
+ SECItem psk_item;
+ psk_item.type = siBuffer;
+ psk_item.len = sizeof(kPskDummyVal_);
+ psk_item.data = const_cast<uint8_t*>(kPskDummyVal_);
+ PK11SymKey* key =
+ PK11_ImportSymKey(slot.get(), CKM_HKDF_KEY_GEN, PK11_OriginUnwrap,
+ CKA_DERIVE, &psk_item, NULL);
+
+ ScopedPK11SymKey scoped_psk_(key);
+ const std::string kPskDummyLabel_ = "NSS PSK GTEST label";
+ const SSLHashType kPskHash_ = ssl_hash_sha384;
+ AddPsk(scoped_psk_, kPskDummyLabel_, kPskHash_);
+
+ PR_ASSERT(SSL_OptionSet(client_->ssl_fd(),
+ SSL_ENABLE_CH_EXTENSION_PERMUTATION,
+ PR_TRUE) == SECSuccess);
+ Connect();
+ SendReceive();
+ CheckKeys(ssl_kea_ecdh, ssl_grp_ec_curve25519, ssl_auth_psk, ssl_sig_none);
+}
+
+/* This test checks that the ClientHello extension order is actually permuted
+ * if ss->opt.chXtnPermutation is set. It is asserted that at least one out of
+ * 10 extension orders differs from the others.
+ *
+ * This is a probabilistic test: The default TLS 1.3 ClientHello contains 8
+ * extensions, leading to a 1/8! probability for any extension order and the
+ * same probability for two drawn extension orders to coincide.
+ * Since all sequences are compared against each other this leads to a false
+ * positive rate of (1/8!)^(n^2-n).
+ * To achieve a spurious failure rate << 1/2^64, we compare n=10 drawn orders.
+ *
+ * This test assures that randomisation is happening but does not check quality
+ * of the used Fisher-Yates shuffle. */
+TEST_F(TlsConnectStreamTls13,
+ ClientHelloExtensionPermutationProbabilisticTest) {
+ std::vector<std::vector<uint16_t>> orders;
+
+ /* Capture the extension order of 10 ClientHello messages. */
+ for (size_t i = 0; i < 10; i++) {
+ client_->StartConnect();
+ /* Enable ClientHello extension permutation. */
+ ASSERT_TRUE(SSL_OptionSet(client_->ssl_fd(),
+ SSL_ENABLE_CH_EXTENSION_PERMUTATION,
+ PR_TRUE) == SECSuccess);
+ /* Capture extension order filter. */
+ auto filter = MakeTlsFilter<TlsExtensionOrderCapture>(
+ client_, kTlsHandshakeClientHello);
+ /* Send ClientHello. */
+ client_->Handshake();
+ /* Remember extension order. */
+ orders.push_back(filter->order);
+ /* Reset client / server state. */
+ Reset();
+ }
+
+ /* Check for extension order inequality. */
+ size_t inequal = 0;
+ for (auto& outerOrders : orders) {
+ for (auto& innerOrders : orders) {
+ if (outerOrders != innerOrders) {
+ inequal++;
+ }
+ }
+ }
+ ASSERT_TRUE(inequal >= 1);
+}
+
+// The certificate_authorities xtn can be included in a ClientHello [RFC 8446,
+// Section 4.2]
+TEST_F(TlsConnectStreamTls13, ClientHelloCertAuthXtnToleration) {
+ EnsureTlsSetup();
+ uint8_t bodyBuf[3] = {0x00, 0x01, 0xff};
+ DataBuffer body(bodyBuf, sizeof(bodyBuf));
+ auto ch = MakeTlsFilter<TlsExtensionAppender>(
+ client_, kTlsHandshakeClientHello, ssl_tls13_certificate_authorities_xtn,
+ body);
+ // The Connection will fail because the added extension isn't in the client's
+ // transcript not because the extension is unsupported (Bug 1815167).
+ server_->ExpectSendAlert(bad_record_mac);
+ client_->ExpectSendAlert(bad_record_mac);
+ ConnectExpectFail();
+ server_->CheckErrorCode(SSL_ERROR_BAD_MAC_READ);
+ client_->CheckErrorCode(SSL_ERROR_BAD_MAC_READ);
+}
+
+INSTANTIATE_TEST_SUITE_P(
+ ExtensionStream, TlsExtensionTestGeneric,
+ ::testing::Combine(TlsConnectTestBase::kTlsVariantsStream,
+ TlsConnectTestBase::kTlsVAll));
+INSTANTIATE_TEST_SUITE_P(
+ ExtensionDatagram, TlsExtensionTestGeneric,
+ ::testing::Combine(TlsConnectTestBase::kTlsVariantsDatagram,
+ TlsConnectTestBase::kTlsV11Plus));
+INSTANTIATE_TEST_SUITE_P(ExtensionDatagramOnly, TlsExtensionTestDtls,
+ TlsConnectTestBase::kTlsV11Plus);
+
+INSTANTIATE_TEST_SUITE_P(ExtensionTls12, TlsExtensionTest12,
+ ::testing::Combine(TlsConnectTestBase::kTlsVariantsAll,
+ TlsConnectTestBase::kTlsV12));
+
+INSTANTIATE_TEST_SUITE_P(ExtensionTls12Plus, TlsExtensionTest12Plus,
+ ::testing::Combine(TlsConnectTestBase::kTlsVariantsAll,
+ TlsConnectTestBase::kTlsV12Plus));
+
+INSTANTIATE_TEST_SUITE_P(
+ ExtensionPre13Stream, TlsExtensionTestPre13,
+ ::testing::Combine(TlsConnectTestBase::kTlsVariantsStream,
+ TlsConnectTestBase::kTlsV10ToV12));
+INSTANTIATE_TEST_SUITE_P(ExtensionPre13Datagram, TlsExtensionTestPre13,
+ ::testing::Combine(TlsConnectTestBase::kTlsVariantsAll,
+ TlsConnectTestBase::kTlsV11V12));
+
+INSTANTIATE_TEST_SUITE_P(ExtensionTls13, TlsExtensionTest13,
+ TlsConnectTestBase::kTlsVariantsAll);
+
+INSTANTIATE_TEST_SUITE_P(
+ BogusExtensionStream, TlsBogusExtensionTestPre13,
+ ::testing::Combine(TlsConnectTestBase::kTlsVariantsStream,
+ TlsConnectTestBase::kTlsV10ToV12));
+INSTANTIATE_TEST_SUITE_P(
+ BogusExtensionDatagram, TlsBogusExtensionTestPre13,
+ ::testing::Combine(TlsConnectTestBase::kTlsVariantsDatagram,
+ TlsConnectTestBase::kTlsV11V12));
+
+INSTANTIATE_TEST_SUITE_P(BogusExtension13, TlsBogusExtensionTest13,
+ ::testing::Combine(TlsConnectTestBase::kTlsVariantsAll,
+ TlsConnectTestBase::kTlsV13));
+
+INSTANTIATE_TEST_SUITE_P(DisallowedExtension13, TlsDisallowedExtensionTest13,
+ ::testing::Combine(TlsConnectTestBase::kTlsVariantsAll,
+ TlsConnectTestBase::kTlsV13));
+
+INSTANTIATE_TEST_SUITE_P(DisallowedUnadvertisedExtension13,
+ TlsDisallowedUnadvertisedExtensionTest13,
+ ::testing::Combine(TlsConnectTestBase::kTlsVariantsAll,
+ TlsConnectTestBase::kTlsV13));
+
+} // namespace nss_test
diff --git a/security/nss/gtests/ssl_gtest/ssl_fragment_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_fragment_unittest.cc
new file mode 100644
index 0000000000..3752812633
--- /dev/null
+++ b/security/nss/gtests/ssl_gtest/ssl_fragment_unittest.cc
@@ -0,0 +1,169 @@
+/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
+/* vim: set ts=2 et sw=2 tw=80: */
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this file,
+ * You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+#include "secerr.h"
+#include "ssl.h"
+#include "sslerr.h"
+#include "sslproto.h"
+
+#include "gtest_utils.h"
+#include "nss_scoped_ptrs.h"
+#include "tls_connect.h"
+#include "tls_filter.h"
+#include "tls_parser.h"
+
+namespace nss_test {
+
+// This class cuts every unencrypted handshake record into two parts.
+class RecordFragmenter : public PacketFilter {
+ public:
+ RecordFragmenter(bool is_dtls13)
+ : is_dtls13_(is_dtls13), sequence_number_(0), splitting_(true) {}
+
+ private:
+ class HandshakeSplitter {
+ public:
+ HandshakeSplitter(bool is_dtls13, const DataBuffer& input,
+ DataBuffer* output, uint64_t* sequence_number)
+ : is_dtls13_(is_dtls13),
+ input_(input),
+ output_(output),
+ cursor_(0),
+ sequence_number_(sequence_number) {}
+
+ private:
+ void WriteRecord(TlsRecordHeader& record_header,
+ DataBuffer& record_fragment) {
+ TlsRecordHeader fragment_header(
+ record_header.variant(), record_header.version(),
+ record_header.content_type(), *sequence_number_);
+ ++*sequence_number_;
+ if (::g_ssl_gtest_verbose) {
+ std::cerr << "Fragment: " << fragment_header << ' ' << record_fragment
+ << std::endl;
+ }
+ cursor_ = fragment_header.Write(output_, cursor_, record_fragment);
+ }
+
+ bool SplitRecord(TlsRecordHeader& record_header, DataBuffer& record) {
+ TlsParser parser(record);
+ while (parser.remaining()) {
+ TlsHandshakeFilter::HandshakeHeader handshake_header;
+ DataBuffer handshake_body;
+ bool complete = false;
+ if (!handshake_header.Parse(&parser, record_header, DataBuffer(),
+ &handshake_body, &complete)) {
+ ADD_FAILURE() << "couldn't parse handshake header";
+ return false;
+ }
+ if (!complete) {
+ ADD_FAILURE() << "don't want to deal with fragmented messages";
+ return false;
+ }
+
+ DataBuffer record_fragment;
+ // We can't fragment handshake records that are too small.
+ if (handshake_body.len() < 2) {
+ handshake_header.Write(&record_fragment, 0U, handshake_body);
+ WriteRecord(record_header, record_fragment);
+ continue;
+ }
+
+ size_t cut = handshake_body.len() / 2;
+ handshake_header.WriteFragment(&record_fragment, 0U, handshake_body, 0U,
+ cut);
+ WriteRecord(record_header, record_fragment);
+
+ handshake_header.WriteFragment(&record_fragment, 0U, handshake_body,
+ cut, handshake_body.len() - cut);
+ WriteRecord(record_header, record_fragment);
+ }
+ return true;
+ }
+
+ public:
+ bool Split() {
+ TlsParser parser(input_);
+ while (parser.remaining()) {
+ TlsRecordHeader header;
+ DataBuffer record;
+ if (!header.Parse(is_dtls13_, 0, &parser, &record)) {
+ ADD_FAILURE() << "bad record header";
+ return false;
+ }
+
+ if (::g_ssl_gtest_verbose) {
+ std::cerr << "Record: " << header << ' ' << record << std::endl;
+ }
+
+ // Don't touch packets from a non-zero epoch. Leave these unmodified.
+ if ((header.sequence_number() >> 48) != 0ULL) {
+ cursor_ = header.Write(output_, cursor_, record);
+ continue;
+ }
+
+ // Just rewrite the sequence number (CCS only).
+ if (header.content_type() != ssl_ct_handshake) {
+ EXPECT_EQ(ssl_ct_change_cipher_spec, header.content_type());
+ WriteRecord(header, record);
+ continue;
+ }
+
+ if (!SplitRecord(header, record)) {
+ return false;
+ }
+ }
+ return true;
+ }
+
+ private:
+ bool is_dtls13_;
+ const DataBuffer& input_;
+ DataBuffer* output_;
+ size_t cursor_;
+ uint64_t* sequence_number_;
+ };
+
+ protected:
+ virtual PacketFilter::Action Filter(const DataBuffer& input,
+ DataBuffer* output) override {
+ if (!splitting_) {
+ return KEEP;
+ }
+
+ output->Allocate(input.len());
+ HandshakeSplitter splitter(is_dtls13_, input, output, &sequence_number_);
+ if (!splitter.Split()) {
+ // If splitting fails, we obviously reached encrypted packets.
+ // Stop splitting from that point onward.
+ splitting_ = false;
+ return KEEP;
+ }
+
+ return CHANGE;
+ }
+
+ private:
+ bool is_dtls13_;
+ uint64_t sequence_number_;
+ bool splitting_;
+};
+
+TEST_P(TlsConnectDatagram, FragmentClientPackets) {
+ bool is_dtls13 = version_ >= SSL_LIBRARY_VERSION_TLS_1_3;
+ client_->SetFilter(std::make_shared<RecordFragmenter>(is_dtls13));
+ Connect();
+ SendReceive();
+}
+
+TEST_P(TlsConnectDatagram, FragmentServerPackets) {
+ bool is_dtls13 = version_ >= SSL_LIBRARY_VERSION_TLS_1_3;
+ server_->SetFilter(std::make_shared<RecordFragmenter>(is_dtls13));
+ Connect();
+ SendReceive();
+}
+
+} // namespace nss_test
diff --git a/security/nss/gtests/ssl_gtest/ssl_fuzz_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_fuzz_unittest.cc
new file mode 100644
index 0000000000..ef6f7602cf
--- /dev/null
+++ b/security/nss/gtests/ssl_gtest/ssl_fuzz_unittest.cc
@@ -0,0 +1,252 @@
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this file,
+ * You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+#include "blapi.h"
+#include "ssl.h"
+#include "sslimpl.h"
+#include "tls_connect.h"
+
+#include "gtest/gtest.h"
+
+namespace nss_test {
+
+#ifdef UNSAFE_FUZZER_MODE
+#define FUZZ_F(c, f) TEST_F(c, Fuzz_##f)
+#define FUZZ_P(c, f) TEST_P(c, Fuzz_##f)
+#else
+#define FUZZ_F(c, f) TEST_F(c, DISABLED_Fuzz_##f)
+#define FUZZ_P(c, f) TEST_P(c, DISABLED_Fuzz_##f)
+#endif
+
+const uint8_t kShortEmptyFinished[8] = {0};
+const uint8_t kLongEmptyFinished[128] = {0};
+
+class TlsFuzzTest : public TlsConnectGeneric {};
+
+// Record the application data stream.
+class TlsApplicationDataRecorder : public TlsRecordFilter {
+ public:
+ TlsApplicationDataRecorder(const std::shared_ptr<TlsAgent>& a)
+ : TlsRecordFilter(a), buffer_() {}
+
+ virtual PacketFilter::Action FilterRecord(const TlsRecordHeader& header,
+ const DataBuffer& input,
+ DataBuffer* output) {
+ if (header.content_type() == ssl_ct_application_data) {
+ buffer_.Append(input);
+ }
+
+ return KEEP;
+ }
+
+ const DataBuffer& buffer() const { return buffer_; }
+
+ private:
+ DataBuffer buffer_;
+};
+
+// Check that due to the deterministic PRNG we derive
+// the same master secret in two consecutive TLS sessions.
+FUZZ_P(TlsFuzzTest, DeterministicExporter) {
+ const char kLabel[] = "label";
+ std::vector<unsigned char> out1(32), out2(32);
+
+ // Make sure we have RSA blinding params.
+ Connect();
+
+ Reset();
+ ConfigureSessionCache(RESUME_NONE, RESUME_NONE);
+
+ // Reset the RNG state.
+ EXPECT_EQ(SECSuccess, RNG_RandomUpdate(NULL, 0));
+ Connect();
+
+ // Export a key derived from the MS and nonces.
+ SECStatus rv =
+ SSL_ExportKeyingMaterial(client_->ssl_fd(), kLabel, strlen(kLabel), false,
+ NULL, 0, out1.data(), out1.size());
+ EXPECT_EQ(SECSuccess, rv);
+
+ Reset();
+ ConfigureSessionCache(RESUME_NONE, RESUME_NONE);
+
+ // Reset the RNG state.
+ EXPECT_EQ(SECSuccess, RNG_RandomUpdate(NULL, 0));
+ Connect();
+
+ // Export another key derived from the MS and nonces.
+ rv = SSL_ExportKeyingMaterial(client_->ssl_fd(), kLabel, strlen(kLabel),
+ false, NULL, 0, out2.data(), out2.size());
+ EXPECT_EQ(SECSuccess, rv);
+
+ // The two exported keys should be the same.
+ EXPECT_EQ(out1, out2);
+}
+
+// Check that due to the deterministic RNG two consecutive
+// TLS sessions will have the exact same transcript.
+FUZZ_P(TlsFuzzTest, DeterministicTranscript) {
+ // Make sure we have RSA blinding params.
+ Connect();
+
+ // Connect a few times and compare the transcripts byte-by-byte.
+ DataBuffer last;
+ for (size_t i = 0; i < 5; i++) {
+ Reset();
+ ConfigureSessionCache(RESUME_NONE, RESUME_NONE);
+
+ DataBuffer buffer;
+ MakeTlsFilter<TlsConversationRecorder>(client_, buffer);
+ MakeTlsFilter<TlsConversationRecorder>(server_, buffer);
+
+ // Reset the RNG state.
+ EXPECT_EQ(SECSuccess, RNG_RandomUpdate(NULL, 0));
+ Connect();
+
+ // Ensure the filters go away before |buffer| does.
+ client_->ClearFilter();
+ server_->ClearFilter();
+
+ if (last.len() > 0) {
+ EXPECT_EQ(last, buffer);
+ }
+
+ last = buffer;
+ }
+}
+
+// Check that we can establish and use a connection
+// with all supported TLS versions, STREAM and DGRAM.
+// Check that records are NOT encrypted.
+// Check that records don't have a MAC.
+FUZZ_P(TlsFuzzTest, ConnectSendReceive_NullCipher) {
+ // Set up app data filters.
+ auto client_recorder = MakeTlsFilter<TlsApplicationDataRecorder>(client_);
+ auto server_recorder = MakeTlsFilter<TlsApplicationDataRecorder>(server_);
+
+ Connect();
+
+ // Construct the plaintext.
+ DataBuffer buf;
+ buf.Allocate(50);
+ for (size_t i = 0; i < buf.len(); ++i) {
+ buf.data()[i] = i & 0xff;
+ }
+
+ // Send/Receive data.
+ client_->SendBuffer(buf);
+ server_->SendBuffer(buf);
+ Receive(buf.len());
+
+ // Check for plaintext on the wire.
+ EXPECT_EQ(buf, client_recorder->buffer());
+ EXPECT_EQ(buf, server_recorder->buffer());
+}
+
+// Check that an invalid Finished message doesn't abort the connection.
+FUZZ_P(TlsFuzzTest, BogusClientFinished) {
+ EnsureTlsSetup();
+
+ MakeTlsFilter<TlsInspectorReplaceHandshakeMessage>(
+ client_, kTlsHandshakeFinished,
+ DataBuffer(kShortEmptyFinished, sizeof(kShortEmptyFinished)));
+ Connect();
+ SendReceive();
+}
+
+// Check that an invalid Finished message doesn't abort the connection.
+FUZZ_P(TlsFuzzTest, BogusServerFinished) {
+ EnsureTlsSetup();
+
+ MakeTlsFilter<TlsInspectorReplaceHandshakeMessage>(
+ server_, kTlsHandshakeFinished,
+ DataBuffer(kLongEmptyFinished, sizeof(kLongEmptyFinished)));
+ Connect();
+ SendReceive();
+}
+
+// Check that an invalid server auth signature doesn't abort the connection.
+FUZZ_P(TlsFuzzTest, BogusServerAuthSignature) {
+ EnsureTlsSetup();
+ uint8_t msg_type = version_ == SSL_LIBRARY_VERSION_TLS_1_3
+ ? kTlsHandshakeCertificateVerify
+ : kTlsHandshakeServerKeyExchange;
+ MakeTlsFilter<TlsLastByteDamager>(server_, msg_type);
+ Connect();
+ SendReceive();
+}
+
+// Check that an invalid client auth signature doesn't abort the connection.
+FUZZ_P(TlsFuzzTest, BogusClientAuthSignature) {
+ EnsureTlsSetup();
+ client_->SetupClientAuth();
+ server_->RequestClientAuth(true);
+ MakeTlsFilter<TlsLastByteDamager>(client_, kTlsHandshakeCertificateVerify);
+ Connect();
+}
+
+// Check that session ticket resumption works.
+FUZZ_P(TlsFuzzTest, SessionTicketResumption) {
+ ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET);
+ Connect();
+ SendReceive();
+
+ Reset();
+ ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET);
+ ExpectResumption(RESUME_TICKET);
+ Connect();
+ SendReceive();
+}
+
+// Check that session tickets are not encrypted.
+FUZZ_P(TlsFuzzTest, UnencryptedSessionTickets) {
+ ConfigureSessionCache(RESUME_TICKET, RESUME_TICKET);
+
+ auto filter = MakeTlsFilter<TlsHandshakeRecorder>(
+ server_, kTlsHandshakeNewSessionTicket);
+ Connect();
+
+ std::cerr << "ticket" << filter->buffer() << std::endl;
+ size_t offset = 4; // Skip lifetime.
+
+ if (version_ == SSL_LIBRARY_VERSION_TLS_1_3) {
+ offset += 4; // Skip ticket_age_add.
+ uint32_t nonce_len = 0;
+ EXPECT_TRUE(filter->buffer().Read(offset, 1, &nonce_len));
+ offset += 1 + nonce_len;
+ }
+
+ offset += 2; // Skip the ticket length.
+
+ // This bit parses the contents of the ticket, which would ordinarily be
+ // encrypted. Start by checking that we have the right version. This needs
+ // to be updated every time that TLS_EX_SESS_TICKET_VERSION is changed. But
+ // we don't use the #define. That way, any time that code is updated, this
+ // test will fail unless it is manually checked.
+ uint32_t ticket_version;
+ EXPECT_TRUE(filter->buffer().Read(offset, 2, &ticket_version));
+ EXPECT_EQ(0x010aU, ticket_version);
+ offset += 2;
+
+ // Check the protocol version number.
+ uint32_t tls_version = 0;
+ EXPECT_TRUE(filter->buffer().Read(offset, sizeof(version_), &tls_version));
+ EXPECT_EQ(version_, static_cast<decltype(version_)>(tls_version));
+ offset += sizeof(version_);
+
+ // Check the cipher suite.
+ uint32_t suite = 0;
+ EXPECT_TRUE(filter->buffer().Read(offset, 2, &suite));
+ client_->CheckCipherSuite(static_cast<uint16_t>(suite));
+}
+
+INSTANTIATE_TEST_SUITE_P(
+ FuzzStream, TlsFuzzTest,
+ ::testing::Combine(TlsConnectTestBase::kTlsVariantsStream,
+ TlsConnectTestBase::kTlsVAll));
+INSTANTIATE_TEST_SUITE_P(
+ FuzzDatagram, TlsFuzzTest,
+ ::testing::Combine(TlsConnectTestBase::kTlsVariantsDatagram,
+ TlsConnectTestBase::kTlsV11Plus));
+} // namespace nss_test
diff --git a/security/nss/gtests/ssl_gtest/ssl_gather_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_gather_unittest.cc
new file mode 100644
index 0000000000..2b0b722ae2
--- /dev/null
+++ b/security/nss/gtests/ssl_gtest/ssl_gather_unittest.cc
@@ -0,0 +1,156 @@
+/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
+/* vim: set ts=2 et sw=2 tw=80: */
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this file,
+ * You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+#include "gtest_utils.h"
+#include "tls_connect.h"
+
+namespace nss_test {
+
+class GatherV2ClientHelloTest : public TlsConnectTestBase {
+ public:
+ GatherV2ClientHelloTest() : TlsConnectTestBase(ssl_variant_stream, 0) {}
+
+ void ConnectExpectMalformedClientHello(const DataBuffer &data) {
+ EnsureTlsSetup();
+ server_->SetOption(SSL_ENABLE_V2_COMPATIBLE_HELLO, PR_TRUE);
+ server_->ExpectSendAlert(kTlsAlertIllegalParameter);
+ client_->SendDirect(data);
+ server_->StartConnect();
+ server_->Handshake();
+ ASSERT_TRUE_WAIT(
+ (server_->error_code() == SSL_ERROR_RX_MALFORMED_CLIENT_HELLO), 2000);
+ }
+};
+
+// Gather a 5-byte v3 record, with a fragment length exceeding the maximum.
+TEST_F(TlsConnectTest, GatherExcessiveV3Record) {
+ DataBuffer buffer;
+
+ size_t idx = 0;
+ idx = buffer.Write(idx, 0x16, 1); // handshake
+ idx = buffer.Write(idx, 0x0301, 2); // record_version
+ (void)buffer.Write(idx, MAX_FRAGMENT_LENGTH + 2048 + 1, 2); // length=max+1
+
+ EnsureTlsSetup();
+ server_->ExpectSendAlert(kTlsAlertRecordOverflow);
+ client_->SendDirect(buffer);
+ server_->StartConnect();
+ server_->Handshake();
+ ASSERT_TRUE_WAIT((server_->error_code() == SSL_ERROR_RX_RECORD_TOO_LONG),
+ 2000);
+}
+
+// Gather a 3-byte v2 header, with a fragment length of 2.
+TEST_F(GatherV2ClientHelloTest, GatherV2RecordLongHeader) {
+ DataBuffer buffer;
+
+ size_t idx = 0;
+ idx = buffer.Write(idx, 0x0002, 2); // length=2 (long header)
+ idx = buffer.Write(idx, 0U, 1); // padding=0
+ (void)buffer.Write(idx, 0U, 2); // data
+
+ ConnectExpectMalformedClientHello(buffer);
+}
+
+// Gather a 3-byte v2 header, with a fragment length of 1.
+TEST_F(GatherV2ClientHelloTest, GatherV2RecordLongHeader2) {
+ DataBuffer buffer;
+
+ size_t idx = 0;
+ idx = buffer.Write(idx, 0x0001, 2); // length=1 (long header)
+ idx = buffer.Write(idx, 0U, 1); // padding=0
+ idx = buffer.Write(idx, 0U, 1); // data
+ (void)buffer.Write(idx, 0U, 1); // surplus (need 5 bytes total)
+
+ ConnectExpectMalformedClientHello(buffer);
+}
+
+// Gather a 3-byte v2 header, with a zero fragment length.
+TEST_F(GatherV2ClientHelloTest, GatherEmptyV2RecordLongHeader) {
+ DataBuffer buffer;
+
+ size_t idx = 0;
+ idx = buffer.Write(idx, 0U, 2); // length=0 (long header)
+ idx = buffer.Write(idx, 0U, 1); // padding=0
+ (void)buffer.Write(idx, 0U, 2); // surplus (need 5 bytes total)
+
+ ConnectExpectMalformedClientHello(buffer);
+}
+
+// Gather a 2-byte v2 header, with a fragment length of 3.
+TEST_F(GatherV2ClientHelloTest, GatherV2RecordShortHeader) {
+ DataBuffer buffer;
+
+ size_t idx = 0;
+ idx = buffer.Write(idx, 0x8003, 2); // length=3 (short header)
+ (void)buffer.Write(idx, 0U, 3); // data
+
+ ConnectExpectMalformedClientHello(buffer);
+}
+
+// Gather a 2-byte v2 header, with a fragment length of 2.
+TEST_F(GatherV2ClientHelloTest, GatherEmptyV2RecordShortHeader2) {
+ DataBuffer buffer;
+
+ size_t idx = 0;
+ idx = buffer.Write(idx, 0x8002, 2); // length=2 (short header)
+ idx = buffer.Write(idx, 0U, 2); // data
+ (void)buffer.Write(idx, 0U, 1); // surplus (need 5 bytes total)
+
+ ConnectExpectMalformedClientHello(buffer);
+}
+
+// Gather a 2-byte v2 header, with a fragment length of 1.
+TEST_F(GatherV2ClientHelloTest, GatherEmptyV2RecordShortHeader3) {
+ DataBuffer buffer;
+
+ size_t idx = 0;
+ idx = buffer.Write(idx, 0x8001, 2); // length=1 (short header)
+ idx = buffer.Write(idx, 0U, 1); // data
+ (void)buffer.Write(idx, 0U, 2); // surplus (need 5 bytes total)
+
+ ConnectExpectMalformedClientHello(buffer);
+}
+
+// Gather a 2-byte v2 header, with a zero fragment length.
+TEST_F(GatherV2ClientHelloTest, GatherEmptyV2RecordShortHeader) {
+ DataBuffer buffer;
+
+ size_t idx = 0;
+ idx = buffer.Write(idx, 0x8000, 2); // length=0 (short header)
+ (void)buffer.Write(idx, 0U, 3); // surplus (need 5 bytes total)
+
+ ConnectExpectMalformedClientHello(buffer);
+}
+
+/* Test correct gather buffer clearing/freeing and (re-)allocation.
+ *
+ * Freeing and (re-)allocation of the gather buffers after reception of single
+ * records is only done in DEBUG builds. Normally they are created and
+ * destroyed with the SSL socket.
+ *
+ * TLS 1.0 record splitting leads to implicit complete read of the data.
+ *
+ * The NSS DTLS impelmentation does not allow partial reads
+ * (see sslsecur.c, line 535-543). */
+TEST_P(TlsConnectStream, GatherBufferPartialReadTest) {
+ EnsureTlsSetup();
+ Connect();
+
+ client_->SendData(1000);
+
+ if (version_ > SSL_LIBRARY_VERSION_TLS_1_0) {
+ for (unsigned i = 1; i <= 20; i++) {
+ server_->ReadBytes(50);
+ ASSERT_EQ(server_->received_bytes(), 50U * i);
+ }
+ } else {
+ server_->ReadBytes(50);
+ ASSERT_EQ(server_->received_bytes(), 1000U);
+ }
+}
+
+} // namespace nss_test
diff --git a/security/nss/gtests/ssl_gtest/ssl_gtest.cc b/security/nss/gtests/ssl_gtest/ssl_gtest.cc
new file mode 100644
index 0000000000..2fff9d7cbb
--- /dev/null
+++ b/security/nss/gtests/ssl_gtest/ssl_gtest.cc
@@ -0,0 +1,52 @@
+#include "nspr.h"
+#include "nss.h"
+#include "prenv.h"
+#include "ssl.h"
+
+#include <cstdlib>
+
+#include "test_io.h"
+#include "databuffer.h"
+
+#define GTEST_HAS_RTTI 0
+#include "gtest/gtest.h"
+
+std::string g_working_dir_path;
+bool g_ssl_gtest_verbose;
+
+int main(int argc, char** argv) {
+ // Start the tests
+ ::testing::InitGoogleTest(&argc, argv);
+ g_working_dir_path = ".";
+ g_ssl_gtest_verbose = false;
+
+ char* workdir = PR_GetEnvSecure("NSS_GTEST_WORKDIR");
+ if (workdir) g_working_dir_path = workdir;
+
+ for (int i = 0; i < argc; i++) {
+ if (!strcmp(argv[i], "-d")) {
+ g_working_dir_path = argv[i + 1];
+ ++i;
+ } else if (!strcmp(argv[i], "-v")) {
+ g_ssl_gtest_verbose = true;
+ nss_test::DataBuffer::SetLogLimit(16384);
+ }
+ }
+
+ if (NSS_Initialize(g_working_dir_path.c_str(), "", "", SECMOD_DB,
+ NSS_INIT_READONLY) != SECSuccess) {
+ return 1;
+ }
+ if (NSS_SetDomesticPolicy() != SECSuccess) {
+ return 1;
+ }
+ int rv = RUN_ALL_TESTS();
+
+ if (NSS_Shutdown() != SECSuccess) {
+ return 1;
+ }
+
+ nss_test::Poller::Shutdown();
+
+ return rv;
+}
diff --git a/security/nss/gtests/ssl_gtest/ssl_gtest.gyp b/security/nss/gtests/ssl_gtest/ssl_gtest.gyp
new file mode 100644
index 0000000000..d078ce2303
--- /dev/null
+++ b/security/nss/gtests/ssl_gtest/ssl_gtest.gyp
@@ -0,0 +1,135 @@
+# This Source Code Form is subject to the terms of the Mozilla Public
+# License, v. 2.0. If a copy of the MPL was not distributed with this
+# file, You can obtain one at http://mozilla.org/MPL/2.0/.
+{
+ 'includes': [
+ '../../coreconf/config.gypi',
+ '../common/gtest.gypi',
+ ],
+ 'targets': [
+ {
+ 'target_name': 'ssl_gtest',
+ 'type': 'executable',
+ 'sources': [
+ 'bloomfilter_unittest.cc',
+ 'libssl_internals.c',
+ 'selfencrypt_unittest.cc',
+ 'ssl_0rtt_unittest.cc',
+ 'ssl_aead_unittest.cc',
+ 'ssl_agent_unittest.cc',
+ 'ssl_auth_unittest.cc',
+ 'ssl_cert_ext_unittest.cc',
+ 'ssl_cipherorder_unittest.cc',
+ 'ssl_ciphersuite_unittest.cc',
+ 'ssl_custext_unittest.cc',
+ 'ssl_damage_unittest.cc',
+ 'ssl_debug_env_unittest.cc',
+ 'ssl_dhe_unittest.cc',
+ 'ssl_drop_unittest.cc',
+ 'ssl_ecdh_unittest.cc',
+ 'ssl_ems_unittest.cc',
+ 'ssl_exporter_unittest.cc',
+ 'ssl_extension_unittest.cc',
+ 'ssl_fuzz_unittest.cc',
+ 'ssl_fragment_unittest.cc',
+ 'ssl_gather_unittest.cc',
+ 'ssl_gtest.cc',
+ 'ssl_hrr_unittest.cc',
+ 'ssl_keyupdate_unittest.cc',
+ 'ssl_loopback_unittest.cc',
+ 'ssl_masking_unittest.cc',
+ 'ssl_misc_unittest.cc',
+ 'ssl_record_unittest.cc',
+ 'ssl_recordsep_unittest.cc',
+ 'ssl_recordsize_unittest.cc',
+ 'ssl_resumption_unittest.cc',
+ 'ssl_renegotiation_unittest.cc',
+ 'ssl_skip_unittest.cc',
+ 'ssl_staticrsa_unittest.cc',
+ 'ssl_tls13compat_unittest.cc',
+ 'ssl_v2_client_hello_unittest.cc',
+ 'ssl_version_unittest.cc',
+ 'ssl_versionpolicy_unittest.cc',
+ 'test_io.cc',
+ 'tls_agent.cc',
+ 'tls_connect.cc',
+ 'tls_filter.cc',
+ 'tls_hkdf_unittest.cc',
+ 'tls_ech_unittest.cc',
+ 'tls_protect.cc',
+ 'tls_psk_unittest.cc',
+ 'tls_subcerts_unittest.cc',
+ 'tls_grease_unittest.cc'
+ ],
+ 'dependencies': [
+ '<(DEPTH)/exports.gyp:nss_exports',
+ '<(DEPTH)/lib/util/util.gyp:nssutil3',
+ '<(DEPTH)/gtests/google_test/google_test.gyp:gtest',
+ '<(DEPTH)/lib/smime/smime.gyp:smime',
+ '<(DEPTH)/lib/ssl/ssl.gyp:ssl',
+ '<(DEPTH)/lib/nss/nss.gyp:nss_static',
+ '<(DEPTH)/lib/pkcs12/pkcs12.gyp:pkcs12',
+ '<(DEPTH)/lib/pkcs7/pkcs7.gyp:pkcs7',
+ '<(DEPTH)/lib/certhigh/certhigh.gyp:certhi',
+ '<(DEPTH)/lib/cryptohi/cryptohi.gyp:cryptohi',
+ '<(DEPTH)/lib/certdb/certdb.gyp:certdb',
+ '<(DEPTH)/lib/pki/pki.gyp:nsspki',
+ '<(DEPTH)/lib/dev/dev.gyp:nssdev',
+ '<(DEPTH)/lib/base/base.gyp:nssb',
+ '<(DEPTH)/lib/zlib/zlib.gyp:nss_zlib',
+ '<(DEPTH)/cpputil/cpputil.gyp:cpputil',
+ '<(DEPTH)/lib/libpkix/libpkix.gyp:libpkix',
+ ],
+ 'conditions': [
+ [ 'static_libs==1', {
+ 'dependencies': [
+ '<(DEPTH)/lib/pk11wrap/pk11wrap.gyp:pk11wrap_static',
+ ],
+ }, {
+ 'dependencies': [
+ '<(DEPTH)/lib/sqlite/sqlite.gyp:sqlite3',
+ '<(DEPTH)/lib/pk11wrap/pk11wrap.gyp:pk11wrap',
+ '<(DEPTH)/lib/softoken/softoken.gyp:softokn',
+ '<(DEPTH)/lib/freebl/freebl.gyp:freebl',
+ ],
+ }],
+ [ 'disable_dbm==0', {
+ 'dependencies': [
+ '<(DEPTH)/lib/dbm/src/src.gyp:dbm',
+ ],
+ }],
+ [ 'enable_sslkeylogfile==1 and sanitizer_flags==0', {
+ 'sources': [
+ 'ssl_keylog_unittest.cc',
+ ],
+ 'defines': [
+ 'NSS_ALLOW_SSLKEYLOGFILE',
+ ],
+ }],
+ # ssl_gtest fuzz defines should only be determined by the 'fuzz_tls'
+ # flag (so as to match lib/ssl). If gtest.gypi added the define due
+ # to '--fuzz' only, remove it.
+ ['fuzz_tls==1', {
+ 'defines': [
+ 'UNSAFE_FUZZER_MODE',
+ ],
+ }, {
+ 'defines!': [
+ 'UNSAFE_FUZZER_MODE',
+ ],
+ }],
+ ],
+ }
+ ],
+ 'target_defaults': {
+ 'include_dirs': [
+ '../../lib/ssl'
+ ],
+ 'defines': [
+ 'NSS_USE_STATIC_LIBS'
+ ],
+ },
+ 'variables': {
+ 'module': 'nss',
+ }
+}
diff --git a/security/nss/gtests/ssl_gtest/ssl_hrr_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_hrr_unittest.cc
new file mode 100644
index 0000000000..3b81278f47
--- /dev/null
+++ b/security/nss/gtests/ssl_gtest/ssl_hrr_unittest.cc
@@ -0,0 +1,1364 @@
+/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
+/* vim: set ts=2 et sw=2 tw=80: */
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this file,
+ * You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+#include "secerr.h"
+#include "ssl.h"
+#include "sslerr.h"
+#include "sslproto.h"
+
+// This is internal, just to get DTLS_1_3_DRAFT_VERSION.
+#include "ssl3prot.h"
+
+#include "gtest_utils.h"
+#include "nss_scoped_ptrs.h"
+#include "tls_connect.h"
+#include "tls_filter.h"
+#include "tls_parser.h"
+
+namespace nss_test {
+
+TEST_P(TlsConnectTls13, HelloRetryRequestAbortsZeroRtt) {
+ const char* k0RttData = "Such is life";
+ const PRInt32 k0RttDataLen = static_cast<PRInt32>(strlen(k0RttData));
+
+ SetupForZeroRtt(); // initial handshake as normal
+
+ static const std::vector<SSLNamedGroup> groups = {ssl_grp_ec_secp384r1,
+ ssl_grp_ec_secp521r1};
+ server_->ConfigNamedGroups(groups);
+ client_->Set0RttEnabled(true);
+ server_->Set0RttEnabled(true);
+ ExpectResumption(RESUME_TICKET);
+
+ // Send first ClientHello and send 0-RTT data
+ auto capture_early_data =
+ MakeTlsFilter<TlsExtensionCapture>(client_, ssl_tls13_early_data_xtn);
+ client_->Handshake();
+ EXPECT_EQ(k0RttDataLen, PR_Write(client_->ssl_fd(), k0RttData,
+ k0RttDataLen)); // 0-RTT write.
+ EXPECT_TRUE(capture_early_data->captured());
+
+ // Send the HelloRetryRequest
+ auto hrr_capture = MakeTlsFilter<TlsHandshakeRecorder>(
+ server_, kTlsHandshakeHelloRetryRequest);
+ server_->Handshake();
+ EXPECT_LT(0U, hrr_capture->buffer().len());
+
+ // The server can't read
+ std::vector<uint8_t> buf(k0RttDataLen);
+ EXPECT_EQ(SECFailure, PR_Read(server_->ssl_fd(), buf.data(), k0RttDataLen));
+ EXPECT_EQ(PR_WOULD_BLOCK_ERROR, PORT_GetError());
+
+ // Make a new capture for the early data.
+ capture_early_data =
+ MakeTlsFilter<TlsExtensionCapture>(client_, ssl_tls13_early_data_xtn);
+
+ // Complete the handshake successfully
+ Handshake();
+ ExpectEarlyDataAccepted(false); // The server should reject 0-RTT
+ CheckConnected();
+ SendReceive();
+ EXPECT_FALSE(capture_early_data->captured());
+}
+
+// This filter only works for DTLS 1.3 where there is exactly one handshake
+// packet. If the record is split into two packets, or there are multiple
+// handshake packets, this will break.
+class CorrectMessageSeqAfterHrrFilter : public TlsRecordFilter {
+ public:
+ CorrectMessageSeqAfterHrrFilter(const std::shared_ptr<TlsAgent>& a)
+ : TlsRecordFilter(a) {}
+
+ protected:
+ PacketFilter::Action FilterRecord(const TlsRecordHeader& header,
+ const DataBuffer& record, size_t* offset,
+ DataBuffer* output) {
+ if (filtered_packets() > 0 || header.content_type() != ssl_ct_handshake) {
+ return KEEP;
+ }
+
+ DataBuffer buffer(record);
+ TlsRecordHeader new_header(header.variant(), header.version(),
+ header.content_type(),
+ header.sequence_number() + 1);
+
+ // Correct message_seq.
+ buffer.Write(4, 1U, 2);
+
+ *offset = new_header.Write(output, *offset, buffer);
+ return CHANGE;
+ }
+};
+
+TEST_P(TlsConnectTls13, SecondClientHelloRejectEarlyDataXtn) {
+ static const std::vector<SSLNamedGroup> groups = {ssl_grp_ec_secp384r1,
+ ssl_grp_ec_secp521r1};
+
+ SetupForZeroRtt();
+ ExpectResumption(RESUME_TICKET);
+
+ client_->ConfigNamedGroups(groups);
+ server_->ConfigNamedGroups(groups);
+ client_->Set0RttEnabled(true);
+ server_->Set0RttEnabled(true);
+
+ // A new client that tries to resume with 0-RTT but doesn't send the
+ // correct key share(s). The server will respond with an HRR.
+ auto orig_client =
+ std::make_shared<TlsAgent>(client_->name(), TlsAgent::CLIENT, variant_);
+ client_.swap(orig_client);
+ client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_1,
+ SSL_LIBRARY_VERSION_TLS_1_3);
+ client_->ConfigureSessionCache(RESUME_BOTH);
+ client_->Set0RttEnabled(true);
+ client_->StartConnect();
+
+ // Swap in the new client.
+ client_->SetPeer(server_);
+ server_->SetPeer(client_);
+
+ // Send the ClientHello.
+ client_->Handshake();
+ // Process the CH, send an HRR.
+ server_->Handshake();
+
+ // Swap the client we created manually with the one that successfully
+ // received a PSK, and try to resume with 0-RTT. The client doesn't know
+ // about the HRR so it will send the early_data xtn as well as 0-RTT data.
+ client_.swap(orig_client);
+ orig_client.reset();
+
+ // Correct the DTLS message sequence number after an HRR.
+ if (variant_ == ssl_variant_datagram) {
+ MakeTlsFilter<CorrectMessageSeqAfterHrrFilter>(client_);
+ }
+
+ server_->SetPeer(client_);
+ client_->Handshake();
+
+ // Send 0-RTT data.
+ const char* k0RttData = "ABCDEF";
+ const PRInt32 k0RttDataLen = static_cast<PRInt32>(strlen(k0RttData));
+ PRInt32 rv = PR_Write(client_->ssl_fd(), k0RttData, k0RttDataLen);
+ EXPECT_EQ(k0RttDataLen, rv);
+
+ ExpectAlert(server_, kTlsAlertUnsupportedExtension);
+ Handshake();
+ client_->CheckErrorCode(SSL_ERROR_UNSUPPORTED_EXTENSION_ALERT);
+}
+
+class KeyShareReplayer : public TlsExtensionFilter {
+ public:
+ KeyShareReplayer(const std::shared_ptr<TlsAgent>& a)
+ : TlsExtensionFilter(a) {}
+
+ virtual PacketFilter::Action FilterExtension(uint16_t extension_type,
+ const DataBuffer& input,
+ DataBuffer* output) {
+ if (extension_type != ssl_tls13_key_share_xtn) {
+ return KEEP;
+ }
+
+ if (!data_.len()) {
+ data_ = input;
+ return KEEP;
+ }
+
+ *output = data_;
+ return CHANGE;
+ }
+
+ private:
+ DataBuffer data_;
+};
+
+// This forces a HelloRetryRequest by disabling P-256 on the server. However,
+// the second ClientHello is modified so that it omits the requested share. The
+// server should reject this.
+TEST_P(TlsConnectTls13, RetryWithSameKeyShare) {
+ EnsureTlsSetup();
+ MakeTlsFilter<KeyShareReplayer>(client_);
+ static const std::vector<SSLNamedGroup> groups = {ssl_grp_ec_secp384r1,
+ ssl_grp_ec_secp521r1};
+ server_->ConfigNamedGroups(groups);
+ ConnectExpectAlert(server_, kTlsAlertIllegalParameter);
+ EXPECT_EQ(SSL_ERROR_BAD_2ND_CLIENT_HELLO, server_->error_code());
+ EXPECT_EQ(SSL_ERROR_ILLEGAL_PARAMETER_ALERT, client_->error_code());
+}
+
+// Here we modify the second ClientHello so that the client retries with the
+// same shares, even though the server wanted something else.
+TEST_P(TlsConnectTls13, RetryWithTwoShares) {
+ EnsureTlsSetup();
+ EXPECT_EQ(SECSuccess, SSL_SendAdditionalKeyShares(client_->ssl_fd(), 1));
+ MakeTlsFilter<KeyShareReplayer>(client_);
+
+ static const std::vector<SSLNamedGroup> groups = {ssl_grp_ec_secp384r1,
+ ssl_grp_ec_secp521r1};
+ server_->ConfigNamedGroups(groups);
+ ConnectExpectAlert(server_, kTlsAlertIllegalParameter);
+ EXPECT_EQ(SSL_ERROR_BAD_2ND_CLIENT_HELLO, server_->error_code());
+ EXPECT_EQ(SSL_ERROR_ILLEGAL_PARAMETER_ALERT, client_->error_code());
+}
+
+TEST_P(TlsConnectTls13, RetryCallbackAccept) {
+ EnsureTlsSetup();
+
+ auto accept_hello = [](PRBool firstHello, const PRUint8* clientToken,
+ unsigned int clientTokenLen, PRUint8* appToken,
+ unsigned int* appTokenLen, unsigned int appTokenMax,
+ void* arg) {
+ auto* called = reinterpret_cast<bool*>(arg);
+ *called = true;
+
+ EXPECT_TRUE(firstHello);
+ EXPECT_EQ(0U, clientTokenLen);
+ return ssl_hello_retry_accept;
+ };
+
+ bool cb_run = false;
+ EXPECT_EQ(SECSuccess, SSL_HelloRetryRequestCallback(server_->ssl_fd(),
+ accept_hello, &cb_run));
+ Connect();
+ EXPECT_TRUE(cb_run);
+}
+
+TEST_P(TlsConnectTls13, RetryCallbackAcceptGroupMismatch) {
+ EnsureTlsSetup();
+
+ auto accept_hello_twice = [](PRBool firstHello, const PRUint8* clientToken,
+ unsigned int clientTokenLen, PRUint8* appToken,
+ unsigned int* appTokenLen,
+ unsigned int appTokenMax, void* arg) {
+ auto* called = reinterpret_cast<size_t*>(arg);
+ ++*called;
+
+ EXPECT_EQ(0U, clientTokenLen);
+ return ssl_hello_retry_accept;
+ };
+
+ auto capture =
+ MakeTlsFilter<TlsExtensionCapture>(server_, ssl_tls13_cookie_xtn);
+ capture->SetHandshakeTypes({kTlsHandshakeHelloRetryRequest});
+
+ static const std::vector<SSLNamedGroup> groups = {ssl_grp_ec_secp384r1};
+ server_->ConfigNamedGroups(groups);
+
+ size_t cb_run = 0;
+ EXPECT_EQ(SECSuccess, SSL_HelloRetryRequestCallback(
+ server_->ssl_fd(), accept_hello_twice, &cb_run));
+ Connect();
+ EXPECT_EQ(2U, cb_run);
+ EXPECT_TRUE(capture->captured()) << "expected a cookie in HelloRetryRequest";
+}
+
+TEST_P(TlsConnectTls13, RetryCallbackFail) {
+ EnsureTlsSetup();
+
+ auto fail_hello = [](PRBool firstHello, const PRUint8* clientToken,
+ unsigned int clientTokenLen, PRUint8* appToken,
+ unsigned int* appTokenLen, unsigned int appTokenMax,
+ void* arg) {
+ auto* called = reinterpret_cast<bool*>(arg);
+ *called = true;
+
+ EXPECT_TRUE(firstHello);
+ EXPECT_EQ(0U, clientTokenLen);
+ return ssl_hello_retry_fail;
+ };
+
+ bool cb_run = false;
+ EXPECT_EQ(SECSuccess, SSL_HelloRetryRequestCallback(server_->ssl_fd(),
+ fail_hello, &cb_run));
+ ConnectExpectAlert(server_, kTlsAlertHandshakeFailure);
+ server_->CheckErrorCode(SSL_ERROR_APPLICATION_ABORT);
+ EXPECT_TRUE(cb_run);
+}
+
+// Asking for retry twice isn't allowed.
+TEST_P(TlsConnectTls13, RetryCallbackRequestHrrTwice) {
+ EnsureTlsSetup();
+
+ auto bad_callback = [](PRBool firstHello, const PRUint8* clientToken,
+ unsigned int clientTokenLen, PRUint8* appToken,
+ unsigned int* appTokenLen, unsigned int appTokenMax,
+ void* arg) -> SSLHelloRetryRequestAction {
+ return ssl_hello_retry_request;
+ };
+ EXPECT_EQ(SECSuccess, SSL_HelloRetryRequestCallback(server_->ssl_fd(),
+ bad_callback, NULL));
+ ConnectExpectAlert(server_, kTlsAlertInternalError);
+ server_->CheckErrorCode(SSL_ERROR_APP_CALLBACK_ERROR);
+}
+
+// Accepting the CH and modifying the token isn't allowed.
+TEST_P(TlsConnectTls13, RetryCallbackAcceptAndSetToken) {
+ EnsureTlsSetup();
+
+ auto bad_callback = [](PRBool firstHello, const PRUint8* clientToken,
+ unsigned int clientTokenLen, PRUint8* appToken,
+ unsigned int* appTokenLen, unsigned int appTokenMax,
+ void* arg) -> SSLHelloRetryRequestAction {
+ *appTokenLen = 1;
+ return ssl_hello_retry_accept;
+ };
+ EXPECT_EQ(SECSuccess, SSL_HelloRetryRequestCallback(server_->ssl_fd(),
+ bad_callback, NULL));
+ ConnectExpectAlert(server_, kTlsAlertInternalError);
+ server_->CheckErrorCode(SSL_ERROR_APP_CALLBACK_ERROR);
+}
+
+// As above, but with reject.
+TEST_P(TlsConnectTls13, RetryCallbackRejectAndSetToken) {
+ EnsureTlsSetup();
+
+ auto bad_callback = [](PRBool firstHello, const PRUint8* clientToken,
+ unsigned int clientTokenLen, PRUint8* appToken,
+ unsigned int* appTokenLen, unsigned int appTokenMax,
+ void* arg) -> SSLHelloRetryRequestAction {
+ *appTokenLen = 1;
+ return ssl_hello_retry_fail;
+ };
+ EXPECT_EQ(SECSuccess, SSL_HelloRetryRequestCallback(server_->ssl_fd(),
+ bad_callback, NULL));
+ ConnectExpectAlert(server_, kTlsAlertInternalError);
+ server_->CheckErrorCode(SSL_ERROR_APP_CALLBACK_ERROR);
+}
+
+// This is a (pretend) buffer overflow.
+TEST_P(TlsConnectTls13, RetryCallbackSetTooLargeToken) {
+ EnsureTlsSetup();
+
+ auto bad_callback = [](PRBool firstHello, const PRUint8* clientToken,
+ unsigned int clientTokenLen, PRUint8* appToken,
+ unsigned int* appTokenLen, unsigned int appTokenMax,
+ void* arg) -> SSLHelloRetryRequestAction {
+ *appTokenLen = appTokenMax + 1;
+ return ssl_hello_retry_accept;
+ };
+ EXPECT_EQ(SECSuccess, SSL_HelloRetryRequestCallback(server_->ssl_fd(),
+ bad_callback, NULL));
+ ConnectExpectAlert(server_, kTlsAlertInternalError);
+ server_->CheckErrorCode(SSL_ERROR_APP_CALLBACK_ERROR);
+}
+
+SSLHelloRetryRequestAction RetryHello(PRBool firstHello,
+ const PRUint8* clientToken,
+ unsigned int clientTokenLen,
+ PRUint8* appToken,
+ unsigned int* appTokenLen,
+ unsigned int appTokenMax, void* arg) {
+ auto* called = reinterpret_cast<size_t*>(arg);
+ ++*called;
+
+ EXPECT_EQ(0U, clientTokenLen);
+ return firstHello ? ssl_hello_retry_request : ssl_hello_retry_accept;
+}
+
+TEST_P(TlsConnectTls13, RetryCallbackRetry) {
+ EnsureTlsSetup();
+
+ auto capture_hrr = std::make_shared<TlsHandshakeRecorder>(
+ server_, ssl_hs_hello_retry_request);
+ auto capture_key_share =
+ std::make_shared<TlsExtensionCapture>(server_, ssl_tls13_key_share_xtn);
+ capture_key_share->SetHandshakeTypes({kTlsHandshakeHelloRetryRequest});
+ std::vector<std::shared_ptr<PacketFilter>> chain = {capture_hrr,
+ capture_key_share};
+ server_->SetFilter(std::make_shared<ChainedPacketFilter>(chain));
+
+ size_t cb_called = 0;
+ EXPECT_EQ(SECSuccess, SSL_HelloRetryRequestCallback(server_->ssl_fd(),
+ RetryHello, &cb_called));
+
+ // Do the first message exchange.
+ StartConnect();
+ client_->Handshake();
+ server_->Handshake();
+
+ EXPECT_EQ(1U, cb_called) << "callback should be called once here";
+ EXPECT_LT(0U, capture_hrr->buffer().len()) << "HelloRetryRequest expected";
+ EXPECT_FALSE(capture_key_share->captured())
+ << "no key_share extension expected";
+
+ auto capture_cookie =
+ MakeTlsFilter<TlsExtensionCapture>(client_, ssl_tls13_cookie_xtn);
+
+ Handshake();
+ CheckConnected();
+ EXPECT_EQ(2U, cb_called);
+ EXPECT_TRUE(capture_cookie->captured()) << "should have a cookie";
+}
+
+static size_t CountShares(const DataBuffer& key_share) {
+ size_t count = 0;
+ uint32_t len = 0;
+ size_t offset = 2;
+
+ EXPECT_TRUE(key_share.Read(0, 2, &len));
+ EXPECT_EQ(key_share.len() - 2, len);
+ while (offset < key_share.len()) {
+ offset += 2; // Skip KeyShareEntry.group
+ EXPECT_TRUE(key_share.Read(offset, 2, &len));
+ offset += 2 + len; // Skip KeyShareEntry.key_exchange
+ ++count;
+ }
+ return count;
+}
+
+TEST_P(TlsConnectTls13, RetryCallbackRetryWithAdditionalShares) {
+ EnsureTlsSetup();
+ EXPECT_EQ(SECSuccess, SSL_SendAdditionalKeyShares(client_->ssl_fd(), 1));
+
+ auto capture_server =
+ MakeTlsFilter<TlsExtensionCapture>(server_, ssl_tls13_key_share_xtn);
+ capture_server->SetHandshakeTypes({kTlsHandshakeHelloRetryRequest});
+
+ size_t cb_called = 0;
+ EXPECT_EQ(SECSuccess, SSL_HelloRetryRequestCallback(server_->ssl_fd(),
+ RetryHello, &cb_called));
+
+ // Do the first message exchange.
+ StartConnect();
+ client_->Handshake();
+ server_->Handshake();
+
+ EXPECT_EQ(1U, cb_called) << "callback should be called once here";
+ EXPECT_FALSE(capture_server->captured())
+ << "no key_share extension expected from server";
+
+ auto capture_client_2nd =
+ MakeTlsFilter<TlsExtensionCapture>(client_, ssl_tls13_key_share_xtn);
+
+ Handshake();
+ CheckConnected();
+ EXPECT_EQ(2U, cb_called);
+ EXPECT_TRUE(capture_client_2nd->captured()) << "client should send key_share";
+ EXPECT_EQ(2U, CountShares(capture_client_2nd->extension()))
+ << "client should still send two shares";
+}
+
+// The callback should be run even if we have another reason to send
+// HelloRetryRequest. In this case, the server sends HRR because the server
+// wants a P-384 key share and the client didn't offer one.
+TEST_P(TlsConnectTls13, RetryCallbackRetryWithGroupMismatch) {
+ EnsureTlsSetup();
+
+ auto capture_cookie =
+ std::make_shared<TlsExtensionCapture>(server_, ssl_tls13_cookie_xtn);
+ capture_cookie->SetHandshakeTypes({kTlsHandshakeHelloRetryRequest});
+ auto capture_key_share =
+ std::make_shared<TlsExtensionCapture>(server_, ssl_tls13_key_share_xtn);
+ capture_key_share->SetHandshakeTypes({kTlsHandshakeHelloRetryRequest});
+ server_->SetFilter(std::make_shared<ChainedPacketFilter>(
+ ChainedPacketFilterInit{capture_cookie, capture_key_share}));
+
+ static const std::vector<SSLNamedGroup> groups = {ssl_grp_ec_secp384r1};
+ server_->ConfigNamedGroups(groups);
+
+ size_t cb_called = 0;
+ EXPECT_EQ(SECSuccess, SSL_HelloRetryRequestCallback(server_->ssl_fd(),
+ RetryHello, &cb_called));
+ Connect();
+ EXPECT_EQ(2U, cb_called);
+ EXPECT_TRUE(capture_cookie->captured()) << "cookie expected";
+ EXPECT_TRUE(capture_key_share->captured()) << "key_share expected";
+}
+
+static const uint8_t kApplicationToken[] = {0x92, 0x44, 0x00};
+
+SSLHelloRetryRequestAction RetryHelloWithToken(
+ PRBool firstHello, const PRUint8* clientToken, unsigned int clientTokenLen,
+ PRUint8* appToken, unsigned int* appTokenLen, unsigned int appTokenMax,
+ void* arg) {
+ auto* called = reinterpret_cast<size_t*>(arg);
+ ++*called;
+
+ if (firstHello) {
+ memcpy(appToken, kApplicationToken, sizeof(kApplicationToken));
+ *appTokenLen = sizeof(kApplicationToken);
+ return ssl_hello_retry_request;
+ }
+
+ EXPECT_EQ(DataBuffer(kApplicationToken, sizeof(kApplicationToken)),
+ DataBuffer(clientToken, static_cast<size_t>(clientTokenLen)));
+ return ssl_hello_retry_accept;
+}
+
+TEST_P(TlsConnectTls13, RetryCallbackRetryWithToken) {
+ EnsureTlsSetup();
+
+ auto capture_key_share =
+ MakeTlsFilter<TlsExtensionCapture>(server_, ssl_tls13_key_share_xtn);
+ capture_key_share->SetHandshakeTypes({kTlsHandshakeHelloRetryRequest});
+
+ size_t cb_called = 0;
+ EXPECT_EQ(SECSuccess,
+ SSL_HelloRetryRequestCallback(server_->ssl_fd(),
+ RetryHelloWithToken, &cb_called));
+ Connect();
+ EXPECT_EQ(2U, cb_called);
+ EXPECT_FALSE(capture_key_share->captured()) << "no key share expected";
+}
+
+TEST_P(TlsConnectTls13, RetryCallbackRetryWithTokenAndGroupMismatch) {
+ EnsureTlsSetup();
+
+ static const std::vector<SSLNamedGroup> groups = {ssl_grp_ec_secp384r1};
+ server_->ConfigNamedGroups(groups);
+
+ auto capture_key_share =
+ MakeTlsFilter<TlsExtensionCapture>(server_, ssl_tls13_key_share_xtn);
+ capture_key_share->SetHandshakeTypes({kTlsHandshakeHelloRetryRequest});
+
+ size_t cb_called = 0;
+ EXPECT_EQ(SECSuccess,
+ SSL_HelloRetryRequestCallback(server_->ssl_fd(),
+ RetryHelloWithToken, &cb_called));
+ Connect();
+ EXPECT_EQ(2U, cb_called);
+ EXPECT_TRUE(capture_key_share->captured()) << "key share expected";
+}
+
+SSLHelloRetryRequestAction CheckTicketToken(
+ PRBool firstHello, const PRUint8* clientToken, unsigned int clientTokenLen,
+ PRUint8* appToken, unsigned int* appTokenLen, unsigned int appTokenMax,
+ void* arg) {
+ auto* called = reinterpret_cast<bool*>(arg);
+ *called = true;
+
+ EXPECT_TRUE(firstHello);
+ EXPECT_EQ(DataBuffer(kApplicationToken, sizeof(kApplicationToken)),
+ DataBuffer(clientToken, static_cast<size_t>(clientTokenLen)));
+ return ssl_hello_retry_accept;
+}
+
+// Stream because SSL_SendSessionTicket only supports that.
+TEST_F(TlsConnectStreamTls13, RetryCallbackWithSessionTicketToken) {
+ ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET);
+ Connect();
+ EXPECT_EQ(SECSuccess,
+ SSL_SendSessionTicket(server_->ssl_fd(), kApplicationToken,
+ sizeof(kApplicationToken)));
+ SendReceive();
+
+ Reset();
+ ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET);
+ ExpectResumption(RESUME_TICKET);
+
+ bool cb_run = false;
+ EXPECT_EQ(SECSuccess, SSL_HelloRetryRequestCallback(
+ server_->ssl_fd(), CheckTicketToken, &cb_run));
+ Connect();
+ EXPECT_TRUE(cb_run);
+}
+
+void TriggerHelloRetryRequest(std::shared_ptr<TlsAgent>& client,
+ std::shared_ptr<TlsAgent>& server) {
+ size_t cb_called = 0;
+ EXPECT_EQ(SECSuccess, SSL_HelloRetryRequestCallback(server->ssl_fd(),
+ RetryHello, &cb_called));
+
+ // Start the handshake.
+ client->StartConnect();
+ server->StartConnect();
+ client->Handshake();
+ server->Handshake();
+ EXPECT_EQ(1U, cb_called);
+ // Stop the callback from being called in future handshakes.
+ EXPECT_EQ(SECSuccess,
+ SSL_HelloRetryRequestCallback(server->ssl_fd(), nullptr, nullptr));
+}
+
+TEST_P(TlsConnectTls13, VersionNumbersAfterRetry) {
+ ConfigureSelfEncrypt();
+ EnsureTlsSetup();
+ auto r = MakeTlsFilter<TlsRecordRecorder>(client_);
+ TriggerHelloRetryRequest(client_, server_);
+ Handshake();
+ ASSERT_GT(r->count(), 1UL);
+ auto ch1 = r->record(0);
+ if (ch1.header.is_dtls()) {
+ ASSERT_EQ(SSL_LIBRARY_VERSION_TLS_1_1, ch1.header.version());
+ } else {
+ ASSERT_EQ(SSL_LIBRARY_VERSION_TLS_1_0, ch1.header.version());
+ }
+ auto ch2 = r->record(1);
+ ASSERT_EQ(SSL_LIBRARY_VERSION_TLS_1_2, ch2.header.version());
+
+ CheckConnected();
+}
+
+TEST_P(TlsConnectTls13, RetryStateless) {
+ ConfigureSelfEncrypt();
+ EnsureTlsSetup();
+
+ TriggerHelloRetryRequest(client_, server_);
+ MakeNewServer();
+
+ Handshake();
+ CheckConnected();
+ SendReceive();
+}
+
+TEST_P(TlsConnectTls13, RetryStatefulDropCookie) {
+ ConfigureSelfEncrypt();
+ EnsureTlsSetup();
+
+ TriggerHelloRetryRequest(client_, server_);
+ MakeTlsFilter<TlsExtensionDropper>(client_, ssl_tls13_cookie_xtn);
+
+ ExpectAlert(server_, kTlsAlertMissingExtension);
+ Handshake();
+ client_->CheckErrorCode(SSL_ERROR_MISSING_EXTENSION_ALERT);
+ server_->CheckErrorCode(SSL_ERROR_MISSING_COOKIE_EXTENSION);
+}
+
+class TruncateHrrCookie : public TlsExtensionFilter {
+ public:
+ TruncateHrrCookie(const std::shared_ptr<TlsAgent>& a)
+ : TlsExtensionFilter(a) {}
+ virtual PacketFilter::Action FilterExtension(uint16_t extension_type,
+ const DataBuffer& input,
+ DataBuffer* output) {
+ if (extension_type != ssl_tls13_cookie_xtn) {
+ return KEEP;
+ }
+
+ // Claim a zero-length cookie.
+ output->Allocate(2);
+ output->Write(0, static_cast<uint32_t>(0), 2);
+ return CHANGE;
+ }
+};
+
+TEST_P(TlsConnectTls13, RetryCookieEmpty) {
+ ConfigureSelfEncrypt();
+ EnsureTlsSetup();
+
+ TriggerHelloRetryRequest(client_, server_);
+ MakeTlsFilter<TruncateHrrCookie>(client_);
+
+ ExpectAlert(server_, kTlsAlertHandshakeFailure);
+ Handshake();
+ client_->CheckErrorCode(SSL_ERROR_NO_CYPHER_OVERLAP);
+ server_->CheckErrorCode(SSL_ERROR_RX_MALFORMED_CLIENT_HELLO);
+}
+
+class AddJunkToCookie : public TlsExtensionFilter {
+ public:
+ AddJunkToCookie(const std::shared_ptr<TlsAgent>& a) : TlsExtensionFilter(a) {}
+ virtual PacketFilter::Action FilterExtension(uint16_t extension_type,
+ const DataBuffer& input,
+ DataBuffer* output) {
+ if (extension_type != ssl_tls13_cookie_xtn) {
+ return KEEP;
+ }
+
+ *output = input;
+ // Add junk after the cookie.
+ static const uint8_t junk[2] = {1, 2};
+ output->Append(DataBuffer(junk, sizeof(junk)));
+ return CHANGE;
+ }
+};
+
+TEST_P(TlsConnectTls13, RetryCookieWithExtras) {
+ ConfigureSelfEncrypt();
+ EnsureTlsSetup();
+
+ TriggerHelloRetryRequest(client_, server_);
+ MakeTlsFilter<AddJunkToCookie>(client_);
+
+ ExpectAlert(server_, kTlsAlertHandshakeFailure);
+ Handshake();
+ client_->CheckErrorCode(SSL_ERROR_NO_CYPHER_OVERLAP);
+ server_->CheckErrorCode(SSL_ERROR_RX_MALFORMED_CLIENT_HELLO);
+}
+
+// Stream only because DTLS drops bad packets.
+TEST_F(TlsConnectStreamTls13, RetryStatelessDamageFirstClientHello) {
+ ConfigureSelfEncrypt();
+ EnsureTlsSetup();
+
+ auto damage_ch =
+ MakeTlsFilter<TlsExtensionInjector>(client_, 0xfff3, DataBuffer());
+
+ TriggerHelloRetryRequest(client_, server_);
+ MakeNewServer();
+
+ // Key exchange fails when the handshake continues because client and server
+ // disagree about the transcript.
+ client_->ExpectSendAlert(kTlsAlertBadRecordMac);
+ server_->ExpectSendAlert(kTlsAlertBadRecordMac);
+ Handshake();
+ server_->CheckErrorCode(SSL_ERROR_BAD_MAC_READ);
+ client_->CheckErrorCode(SSL_ERROR_BAD_MAC_READ);
+}
+
+TEST_F(TlsConnectStreamTls13, RetryStatelessDamageSecondClientHello) {
+ ConfigureSelfEncrypt();
+ EnsureTlsSetup();
+
+ TriggerHelloRetryRequest(client_, server_);
+ MakeNewServer();
+
+ auto damage_ch =
+ MakeTlsFilter<TlsExtensionInjector>(client_, 0xfff3, DataBuffer());
+
+ // Key exchange fails when the handshake continues because client and server
+ // disagree about the transcript.
+ client_->ExpectSendAlert(kTlsAlertBadRecordMac);
+ server_->ExpectSendAlert(kTlsAlertBadRecordMac);
+ Handshake();
+ server_->CheckErrorCode(SSL_ERROR_BAD_MAC_READ);
+ client_->CheckErrorCode(SSL_ERROR_BAD_MAC_READ);
+}
+
+// Stream because SSL_SendSessionTicket only supports that.
+TEST_F(TlsConnectStreamTls13, SecondClientHelloSendSameTicket) {
+ // This simulates the scenario described at:
+ // https://bugzilla.mozilla.org/show_bug.cgi?id=1481271#c7
+ //
+ // Here two connections are interleaved. Tickets are issued on one
+ // connection. A HelloRetryRequest is triggered on the second connection,
+ // meaning that there are two ClientHellos. We need to check that both
+ // ClientHellos have the same ticket, even if a new ticket is issued on the
+ // other connection in the meantime.
+ //
+ // Connection 1: <handshake>
+ // Connection 1: S->C: NST=X
+ // Connection 2: C->S: CH [PSK_ID=X]
+ // Connection 1: S->C: NST=Y
+ // Connection 2: S->C: HRR
+ // Connection 2: C->S: CH [PSK_ID=Y]
+
+ // Connection 1, send a ticket after handshake is complete.
+ ConfigureSessionCache(RESUME_TICKET, RESUME_TICKET);
+
+ Connect();
+
+ // Set this token so that RetryHelloWithToken() will check that this
+ // is the token that it receives in the HelloRetryRequest callback.
+ EXPECT_EQ(SECSuccess,
+ SSL_SendSessionTicket(server_->ssl_fd(), kApplicationToken,
+ sizeof(kApplicationToken)));
+ SendReceive(50);
+
+ // Connection 2, trigger HRR.
+ auto client2 =
+ std::make_shared<TlsAgent>(client_->name(), TlsAgent::CLIENT, variant_);
+ auto server2 =
+ std::make_shared<TlsAgent>(server_->name(), TlsAgent::SERVER, variant_);
+
+ client2->SetPeer(server2);
+ server2->SetPeer(client2);
+
+ client_.swap(client2);
+ server_.swap(server2);
+
+ ConfigureSessionCache(RESUME_TICKET, RESUME_TICKET);
+
+ ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3);
+
+ client_->StartConnect();
+ server_->StartConnect();
+
+ size_t cb_called = 0;
+ EXPECT_EQ(SECSuccess,
+ SSL_HelloRetryRequestCallback(server_->ssl_fd(),
+ RetryHelloWithToken, &cb_called));
+ client_->Handshake(); // Send ClientHello.
+ server_->Handshake(); // Process ClientHello, send HelloRetryRequest.
+
+ EXPECT_EQ(1U, cb_called) << "callback should be called once here";
+
+ // Connection 1, send another ticket.
+ client_.swap(client2);
+ server_.swap(server2);
+
+ // If the client uses this token, RetryHelloWithToken() will fail the test.
+ const uint8_t kAnotherApplicationToken[] = {0x92, 0x44, 0x01};
+ EXPECT_EQ(SECSuccess,
+ SSL_SendSessionTicket(server_->ssl_fd(), kAnotherApplicationToken,
+ sizeof(kAnotherApplicationToken)));
+ SendReceive(60);
+
+ // Connection 2, continue the handshake.
+ // The client should use kApplicationToken, not kAnotherApplicationToken.
+ client_.swap(client2);
+ server_.swap(server2);
+
+ client_->Handshake();
+ server_->Handshake();
+
+ EXPECT_EQ(2U, cb_called) << "callback should be called twice here";
+}
+
+// Read the cipher suite from the HRR and disable it on the identified agent.
+static void DisableSuiteFromHrr(
+ std::shared_ptr<TlsAgent>& agent,
+ std::shared_ptr<TlsHandshakeRecorder>& capture_hrr) {
+ uint32_t tmp;
+ size_t offset = 2 + 32; // skip version + server_random
+ ASSERT_TRUE(
+ capture_hrr->buffer().Read(offset, 1, &tmp)); // session_id length
+ EXPECT_EQ(0U, tmp);
+ offset += 1 + tmp;
+ ASSERT_TRUE(capture_hrr->buffer().Read(offset, 2, &tmp)); // suite
+ EXPECT_EQ(
+ SECSuccess,
+ SSL_CipherPrefSet(agent->ssl_fd(), static_cast<uint16_t>(tmp), PR_FALSE));
+}
+
+TEST_P(TlsConnectTls13, RetryStatelessDisableSuiteClient) {
+ ConfigureSelfEncrypt();
+ EnsureTlsSetup();
+
+ auto capture_hrr =
+ MakeTlsFilter<TlsHandshakeRecorder>(server_, ssl_hs_hello_retry_request);
+
+ TriggerHelloRetryRequest(client_, server_);
+ MakeNewServer();
+
+ DisableSuiteFromHrr(client_, capture_hrr);
+
+ // The client thinks that the HelloRetryRequest is bad, even though its
+ // because it changed its mind about the cipher suite.
+ ExpectAlert(client_, kTlsAlertIllegalParameter);
+ Handshake();
+ client_->CheckErrorCode(SSL_ERROR_NO_CYPHER_OVERLAP);
+ server_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT);
+}
+
+TEST_P(TlsConnectTls13, RetryStatelessDisableSuiteServer) {
+ ConfigureSelfEncrypt();
+ EnsureTlsSetup();
+
+ auto capture_hrr =
+ MakeTlsFilter<TlsHandshakeRecorder>(server_, ssl_hs_hello_retry_request);
+
+ TriggerHelloRetryRequest(client_, server_);
+ MakeNewServer();
+
+ DisableSuiteFromHrr(server_, capture_hrr);
+
+ ExpectAlert(server_, kTlsAlertIllegalParameter);
+ Handshake();
+ server_->CheckErrorCode(SSL_ERROR_BAD_2ND_CLIENT_HELLO);
+ client_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT);
+}
+
+TEST_P(TlsConnectTls13, RetryStatelessDisableGroupClient) {
+ ConfigureSelfEncrypt();
+ EnsureTlsSetup();
+
+ TriggerHelloRetryRequest(client_, server_);
+ MakeNewServer();
+
+ static const std::vector<SSLNamedGroup> groups = {ssl_grp_ec_secp384r1};
+ client_->ConfigNamedGroups(groups);
+
+ // We're into undefined behavior on the client side, but - at the point this
+ // test was written - the client here doesn't amend its key shares because the
+ // server doesn't ask it to. The server notices that the key share (x25519)
+ // doesn't match the negotiated group (P-384) and objects.
+ ExpectAlert(server_, kTlsAlertIllegalParameter);
+ Handshake();
+ server_->CheckErrorCode(SSL_ERROR_BAD_2ND_CLIENT_HELLO);
+ client_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT);
+}
+
+TEST_P(TlsConnectTls13, RetryStatelessDisableGroupServer) {
+ ConfigureSelfEncrypt();
+ EnsureTlsSetup();
+
+ TriggerHelloRetryRequest(client_, server_);
+ MakeNewServer();
+
+ static const std::vector<SSLNamedGroup> groups = {ssl_grp_ec_secp384r1};
+ server_->ConfigNamedGroups(groups);
+
+ ExpectAlert(server_, kTlsAlertIllegalParameter);
+ Handshake();
+ server_->CheckErrorCode(SSL_ERROR_BAD_2ND_CLIENT_HELLO);
+ client_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT);
+}
+
+TEST_P(TlsConnectTls13, RetryStatelessBadCookie) {
+ ConfigureSelfEncrypt();
+ EnsureTlsSetup();
+
+ TriggerHelloRetryRequest(client_, server_);
+
+ // Now replace the self-encrypt MAC key with a garbage key.
+ static const uint8_t bad_hmac_key[32] = {0};
+ SECItem key_item = {siBuffer, const_cast<uint8_t*>(bad_hmac_key),
+ sizeof(bad_hmac_key)};
+ ScopedPK11SlotInfo slot(PK11_GetInternalSlot());
+ PK11SymKey* hmac_key =
+ PK11_ImportSymKey(slot.get(), CKM_SHA256_HMAC, PK11_OriginUnwrap,
+ CKA_SIGN, &key_item, nullptr);
+ ASSERT_NE(nullptr, hmac_key);
+ SSLInt_SetSelfEncryptMacKey(hmac_key); // Passes ownership.
+
+ MakeNewServer();
+
+ ExpectAlert(server_, kTlsAlertIllegalParameter);
+ Handshake();
+ server_->CheckErrorCode(SSL_ERROR_BAD_2ND_CLIENT_HELLO);
+ client_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT);
+}
+
+// Stream because the server doesn't consume the alert and terminate.
+TEST_F(TlsConnectStreamTls13, RetryWithDifferentCipherSuite) {
+ EnsureTlsSetup();
+ // Force a HelloRetryRequest.
+ static const std::vector<SSLNamedGroup> groups = {ssl_grp_ec_secp384r1};
+ server_->ConfigNamedGroups(groups);
+ // Then switch out the default suite (TLS_AES_128_GCM_SHA256).
+ MakeTlsFilter<SelectedCipherSuiteReplacer>(server_,
+ TLS_CHACHA20_POLY1305_SHA256);
+
+ client_->ExpectSendAlert(kTlsAlertIllegalParameter);
+ server_->ExpectSendAlert(kTlsAlertUnexpectedMessage);
+ ConnectExpectFail();
+ EXPECT_EQ(SSL_ERROR_RX_MALFORMED_SERVER_HELLO, client_->error_code());
+ EXPECT_EQ(SSL_ERROR_RX_UNEXPECTED_RECORD_TYPE, server_->error_code());
+}
+
+// This tests that the second attempt at sending a ClientHello (after receiving
+// a HelloRetryRequest) is correctly retransmitted.
+TEST_F(TlsConnectDatagram13, DropClientSecondFlightWithHelloRetry) {
+ static const std::vector<SSLNamedGroup> groups = {ssl_grp_ec_secp384r1,
+ ssl_grp_ec_secp521r1};
+ server_->ConfigNamedGroups(groups);
+ server_->SetFilter(std::make_shared<SelectiveDropFilter>(0x2));
+ Connect();
+}
+
+class TlsKeyExchange13 : public TlsKeyExchangeTest {};
+
+// This should work, with an HRR, because the server prefers x25519 and the
+// client generates a share for P-384 on the initial ClientHello.
+TEST_P(TlsKeyExchange13, ConnectEcdhePreferenceMismatchHrr) {
+ EnsureKeyShareSetup();
+ static const std::vector<SSLNamedGroup> client_groups = {
+ ssl_grp_ec_secp384r1, ssl_grp_ec_curve25519};
+ static const std::vector<SSLNamedGroup> server_groups = {
+ ssl_grp_ec_curve25519, ssl_grp_ec_secp384r1};
+ client_->ConfigNamedGroups(client_groups);
+ server_->ConfigNamedGroups(server_groups);
+ Connect();
+ CheckKeys();
+ static const std::vector<SSLNamedGroup> expectedShares = {
+ ssl_grp_ec_secp384r1};
+ CheckKEXDetails(client_groups, expectedShares, ssl_grp_ec_curve25519);
+}
+
+TEST_P(TlsKeyExchange13, SecondClientHelloPreambleMatches) {
+ server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_3,
+ SSL_LIBRARY_VERSION_TLS_1_3);
+ client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_1,
+ SSL_LIBRARY_VERSION_TLS_1_3);
+
+ ConfigureSelfEncrypt();
+ static const std::vector<SSLNamedGroup> client_groups = {
+ ssl_grp_ec_secp384r1, ssl_grp_ec_curve25519};
+ static const std::vector<SSLNamedGroup> server_groups = {
+ ssl_grp_ec_curve25519};
+ client_->ConfigNamedGroups(client_groups);
+ server_->ConfigNamedGroups(server_groups);
+
+ auto ch1 = MakeTlsFilter<ClientHelloPreambleCapture>(client_);
+ StartConnect();
+ client_->Handshake();
+ server_->Handshake();
+
+ MakeNewServer();
+ auto ch2 = MakeTlsFilter<ClientHelloPreambleCapture>(client_);
+ Handshake();
+
+ EXPECT_TRUE(ch1->captured());
+ EXPECT_TRUE(ch2->captured());
+ EXPECT_EQ(ch1->contents(), ch2->contents());
+}
+
+// This should work, but not use HRR because the key share for x25519 was
+// pre-generated by the client.
+TEST_P(TlsKeyExchange13, ConnectEcdhePreferenceMismatchHrrExtraShares) {
+ EnsureKeyShareSetup();
+ static const std::vector<SSLNamedGroup> client_groups = {
+ ssl_grp_ec_secp384r1, ssl_grp_ec_curve25519};
+ static const std::vector<SSLNamedGroup> server_groups = {
+ ssl_grp_ec_curve25519, ssl_grp_ec_secp384r1};
+ client_->ConfigNamedGroups(client_groups);
+ server_->ConfigNamedGroups(server_groups);
+ EXPECT_EQ(SECSuccess, SSL_SendAdditionalKeyShares(client_->ssl_fd(), 1));
+
+ Connect();
+ CheckKeys();
+ CheckKEXDetails(client_groups, client_groups);
+}
+
+// The callback should be run even if we have another reason to send
+// HelloRetryRequest. In this case, the server sends HRR because the server
+// wants an X25519 key share and the client didn't offer one.
+TEST_P(TlsKeyExchange13,
+ RetryCallbackRetryWithGroupMismatchAndAdditionalShares) {
+ EnsureKeyShareSetup();
+
+ static const std::vector<SSLNamedGroup> client_groups = {
+ ssl_grp_ec_secp256r1, ssl_grp_ec_secp384r1, ssl_grp_ec_curve25519};
+ client_->ConfigNamedGroups(client_groups);
+ static const std::vector<SSLNamedGroup> server_groups = {
+ ssl_grp_ec_curve25519};
+ server_->ConfigNamedGroups(server_groups);
+ EXPECT_EQ(SECSuccess, SSL_SendAdditionalKeyShares(client_->ssl_fd(), 1));
+
+ auto capture_server =
+ std::make_shared<TlsExtensionCapture>(server_, ssl_tls13_key_share_xtn);
+ capture_server->SetHandshakeTypes({kTlsHandshakeHelloRetryRequest});
+ server_->SetFilter(std::make_shared<ChainedPacketFilter>(
+ ChainedPacketFilterInit{capture_hrr_, capture_server}));
+
+ size_t cb_called = 0;
+ EXPECT_EQ(SECSuccess, SSL_HelloRetryRequestCallback(server_->ssl_fd(),
+ RetryHello, &cb_called));
+
+ // Do the first message exchange.
+ StartConnect();
+ client_->Handshake();
+ server_->Handshake();
+
+ EXPECT_EQ(1U, cb_called) << "callback should be called once here";
+ EXPECT_TRUE(capture_server->captured()) << "key_share extension expected";
+
+ uint32_t server_group = 0;
+ EXPECT_TRUE(capture_server->extension().Read(0, 2, &server_group));
+ EXPECT_EQ(ssl_grp_ec_curve25519, static_cast<SSLNamedGroup>(server_group));
+
+ Handshake();
+ CheckConnected();
+ EXPECT_EQ(2U, cb_called);
+ EXPECT_TRUE(shares_capture2_->captured()) << "client should send shares";
+
+ CheckKeys();
+ static const std::vector<SSLNamedGroup> client_shares(
+ client_groups.begin(), client_groups.begin() + 2);
+ CheckKEXDetails(client_groups, client_shares, server_groups[0]);
+}
+
+TEST_F(TlsConnectTest, Select12AfterHelloRetryRequest) {
+ EnsureTlsSetup();
+ client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2,
+ SSL_LIBRARY_VERSION_TLS_1_3);
+ server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2,
+ SSL_LIBRARY_VERSION_TLS_1_3);
+ static const std::vector<SSLNamedGroup> client_groups = {
+ ssl_grp_ec_secp256r1, ssl_grp_ec_secp521r1};
+ client_->ConfigNamedGroups(client_groups);
+ static const std::vector<SSLNamedGroup> server_groups = {
+ ssl_grp_ec_secp384r1, ssl_grp_ec_secp521r1};
+ server_->ConfigNamedGroups(server_groups);
+ StartConnect();
+
+ client_->Handshake();
+ server_->Handshake();
+
+ // Here we replace the TLS server with one that does TLS 1.2 only.
+ // This will happily send the client a TLS 1.2 ServerHello.
+ server_.reset(new TlsAgent(server_->name(), TlsAgent::SERVER, variant_));
+ client_->SetPeer(server_);
+ server_->SetPeer(client_);
+ server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2,
+ SSL_LIBRARY_VERSION_TLS_1_2);
+ server_->StartConnect();
+ ExpectAlert(client_, kTlsAlertIllegalParameter);
+ Handshake();
+ EXPECT_EQ(SSL_ERROR_ILLEGAL_PARAMETER_ALERT, server_->error_code());
+ EXPECT_EQ(SSL_ERROR_RX_MALFORMED_SERVER_HELLO, client_->error_code());
+}
+
+// This class increments the low byte of the first Handshake.message_seq
+// field in every handshake record.
+class MessageSeqIncrementer : public TlsRecordFilter {
+ public:
+ MessageSeqIncrementer(const std::shared_ptr<TlsAgent>& a)
+ : TlsRecordFilter(a) {}
+
+ protected:
+ PacketFilter::Action FilterRecord(const TlsRecordHeader& header,
+ const DataBuffer& data,
+ DataBuffer* changed) override {
+ if (header.content_type() != ssl_ct_handshake) {
+ return KEEP;
+ }
+
+ *changed = data;
+ // struct { uint8 msg_type; uint24 length; uint16 message_seq; ... }
+ // Handshake;
+ changed->data()[5]++;
+ EXPECT_NE(0, changed->data()[5]); // Check for overflow.
+ return CHANGE;
+ }
+};
+
+// A server that receives a ClientHello with message_seq == 1
+// assumes that this is after a stateless HelloRetryRequest.
+// However, it should reject the ClientHello if it lacks a cookie.
+TEST_F(TlsConnectDatagram13, MessageSeq1ClientHello) {
+ EnsureTlsSetup();
+ MakeTlsFilter<MessageSeqIncrementer>(client_);
+ ConnectExpectAlert(server_, kTlsAlertMissingExtension);
+ EXPECT_EQ(SSL_ERROR_MISSING_COOKIE_EXTENSION, server_->error_code());
+ EXPECT_EQ(SSL_ERROR_MISSING_EXTENSION_ALERT, client_->error_code());
+}
+
+class HelloRetryRequestAgentTest : public TlsAgentTestClient {
+ protected:
+ void SetUp() override {
+ TlsAgentTestClient::SetUp();
+ EnsureInit();
+ agent_->StartConnect();
+ }
+
+ void MakeCannedHrr(const uint8_t* body, size_t len, DataBuffer* hrr_record,
+ uint32_t seq_num = 0) const {
+ DataBuffer hrr_data;
+ const uint8_t ssl_hello_retry_random[] = {
+ 0xCF, 0x21, 0xAD, 0x74, 0xE5, 0x9A, 0x61, 0x11, 0xBE, 0x1D, 0x8C,
+ 0x02, 0x1E, 0x65, 0xB8, 0x91, 0xC2, 0xA2, 0x11, 0x16, 0x7A, 0xBB,
+ 0x8C, 0x5E, 0x07, 0x9E, 0x09, 0xE2, 0xC8, 0xA8, 0x33, 0x9C};
+
+ hrr_data.Allocate(len + 6);
+ size_t i = 0;
+ i = hrr_data.Write(i,
+ variant_ == ssl_variant_datagram
+ ? SSL_LIBRARY_VERSION_DTLS_1_2_WIRE
+ : SSL_LIBRARY_VERSION_TLS_1_2,
+ 2);
+ i = hrr_data.Write(i, ssl_hello_retry_random,
+ sizeof(ssl_hello_retry_random));
+ i = hrr_data.Write(i, static_cast<uint32_t>(0), 1); // session_id
+ i = hrr_data.Write(i, TLS_AES_128_GCM_SHA256, 2);
+ i = hrr_data.Write(i, ssl_compression_null, 1);
+ // Add extensions. First a length, which includes the supported version.
+ i = hrr_data.Write(i, static_cast<uint32_t>(len) + 6, 2);
+ // Now the supported version.
+ i = hrr_data.Write(i, ssl_tls13_supported_versions_xtn, 2);
+ i = hrr_data.Write(i, 2, 2);
+ i = hrr_data.Write(i,
+ (variant_ == ssl_variant_datagram)
+ ? (0x7f00 | DTLS_1_3_DRAFT_VERSION)
+ : SSL_LIBRARY_VERSION_TLS_1_3,
+ 2);
+ if (len) {
+ hrr_data.Write(i, body, len);
+ }
+ DataBuffer hrr;
+ MakeHandshakeMessage(kTlsHandshakeServerHello, hrr_data.data(),
+ hrr_data.len(), &hrr, seq_num);
+ MakeRecord(ssl_ct_handshake, SSL_LIBRARY_VERSION_TLS_1_3, hrr.data(),
+ hrr.len(), hrr_record, seq_num);
+ }
+
+ void MakeGroupHrr(SSLNamedGroup group, DataBuffer* hrr_record,
+ uint32_t seq_num = 0) const {
+ const uint8_t group_hrr[] = {
+ static_cast<uint8_t>(ssl_tls13_key_share_xtn >> 8),
+ static_cast<uint8_t>(ssl_tls13_key_share_xtn),
+ 0,
+ 2, // length of key share extension
+ static_cast<uint8_t>(group >> 8),
+ static_cast<uint8_t>(group)};
+ MakeCannedHrr(group_hrr, sizeof(group_hrr), hrr_record, seq_num);
+ }
+};
+
+// Send two HelloRetryRequest messages in response to the ClientHello. The are
+// constructed to appear legitimate by asking for a new share in each, so that
+// the client has to count to work out that the server is being unreasonable.
+TEST_P(HelloRetryRequestAgentTest, SendSecondHelloRetryRequest) {
+ DataBuffer hrr;
+ MakeGroupHrr(ssl_grp_ec_secp384r1, &hrr, 0);
+ ProcessMessage(hrr, TlsAgent::STATE_CONNECTING);
+ MakeGroupHrr(ssl_grp_ec_secp521r1, &hrr, 1);
+ ExpectAlert(kTlsAlertUnexpectedMessage);
+ ProcessMessage(hrr, TlsAgent::STATE_ERROR,
+ SSL_ERROR_RX_UNEXPECTED_HELLO_RETRY_REQUEST);
+}
+
+// Here the client receives a HelloRetryRequest with a group that they already
+// provided a share for.
+TEST_P(HelloRetryRequestAgentTest, HandleBogusHelloRetryRequest) {
+ DataBuffer hrr;
+ MakeGroupHrr(ssl_grp_ec_curve25519, &hrr);
+ ExpectAlert(kTlsAlertIllegalParameter);
+ ProcessMessage(hrr, TlsAgent::STATE_ERROR,
+ SSL_ERROR_RX_MALFORMED_HELLO_RETRY_REQUEST);
+}
+
+TEST_P(HelloRetryRequestAgentTest, HandleNoopHelloRetryRequest) {
+ DataBuffer hrr;
+ MakeCannedHrr(nullptr, 0U, &hrr);
+ ExpectAlert(kTlsAlertDecodeError);
+ ProcessMessage(hrr, TlsAgent::STATE_ERROR,
+ SSL_ERROR_RX_MALFORMED_HELLO_RETRY_REQUEST);
+}
+
+class ReplaceRandom : public TlsHandshakeFilter {
+ public:
+ ReplaceRandom(const std::shared_ptr<TlsAgent>& a, const DataBuffer& r)
+ : TlsHandshakeFilter(a, {kTlsHandshakeServerHello}), random_(r) {}
+
+ PacketFilter::Action FilterHandshake(const HandshakeHeader& header,
+ const DataBuffer& input,
+ DataBuffer* output) override {
+ output->Assign(input);
+ output->Write(2, random_);
+ return CHANGE;
+ }
+
+ private:
+ DataBuffer random_;
+};
+
+// Make sure that the TLS 1.3 special value for the ServerHello.random
+// is rejected by earlier versions.
+TEST_P(TlsConnectStreamPre13, HrrRandomOnTls10) {
+ static const uint8_t hrr_random[] = {
+ 0xCF, 0x21, 0xAD, 0x74, 0xE5, 0x9A, 0x61, 0x11, 0xBE, 0x1D, 0x8C,
+ 0x02, 0x1E, 0x65, 0xB8, 0x91, 0xC2, 0xA2, 0x11, 0x16, 0x7A, 0xBB,
+ 0x8C, 0x5E, 0x07, 0x9E, 0x09, 0xE2, 0xC8, 0xA8, 0x33, 0x9C};
+
+ EnsureTlsSetup();
+ MakeTlsFilter<ReplaceRandom>(server_,
+ DataBuffer(hrr_random, sizeof(hrr_random)));
+ ConnectExpectAlert(client_, kTlsAlertIllegalParameter);
+ client_->CheckErrorCode(SSL_ERROR_RX_MALFORMED_SERVER_HELLO);
+ server_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT);
+}
+
+TEST_F(TlsConnectStreamTls13, HrrThenTls12) {
+ StartConnect();
+ size_t cb_called = 0;
+ EXPECT_EQ(SECSuccess, SSL_HelloRetryRequestCallback(server_->ssl_fd(),
+ RetryHello, &cb_called));
+ server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2,
+ SSL_LIBRARY_VERSION_TLS_1_3);
+
+ client_->Handshake(); // Send CH (1.3)
+ server_->Handshake(); // Send HRR.
+ EXPECT_EQ(1U, cb_called);
+
+ // Replace the client with a new TLS 1.2 client. Don't call Init(), since
+ // it will artifically limit the server's vrange.
+ client_.reset(
+ new TlsAgent(client_->name(), TlsAgent::CLIENT, ssl_variant_stream));
+ client_->SetPeer(server_);
+ server_->SetPeer(client_);
+ client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2,
+ SSL_LIBRARY_VERSION_TLS_1_2);
+
+ client_->StartConnect();
+ client_->Handshake(); // Send CH (1.2)
+ ExpectAlert(server_, kTlsAlertProtocolVersion);
+ server_->Handshake();
+ server_->CheckErrorCode(SSL_ERROR_UNSUPPORTED_VERSION);
+ client_->Handshake();
+ client_->CheckErrorCode(SSL_ERROR_PROTOCOL_VERSION_ALERT);
+}
+
+TEST_F(TlsConnectStreamTls13, ZeroRttHrrThenTls12) {
+ SetupForZeroRtt();
+
+ client_->Set0RttEnabled(true);
+ size_t cb_called = 0;
+ EXPECT_EQ(SECSuccess, SSL_HelloRetryRequestCallback(server_->ssl_fd(),
+ RetryHello, &cb_called));
+ server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2,
+ SSL_LIBRARY_VERSION_TLS_1_3);
+
+ client_->Handshake(); // Send CH (1.3)
+ ZeroRttSendReceive(true, false);
+ server_->Handshake(); // Send HRR.
+ EXPECT_EQ(1U, cb_called);
+
+ // Replace the client with a new TLS 1.2 client. Don't call Init(), since
+ // it will artifically limit the server's vrange.
+ client_.reset(
+ new TlsAgent(client_->name(), TlsAgent::CLIENT, ssl_variant_stream));
+ client_->SetPeer(server_);
+ server_->SetPeer(client_);
+ client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2,
+ SSL_LIBRARY_VERSION_TLS_1_2);
+
+ client_->StartConnect();
+ client_->Handshake(); // Send CH (1.2)
+ ExpectAlert(server_, kTlsAlertProtocolVersion);
+ server_->Handshake();
+ server_->CheckErrorCode(SSL_ERROR_UNSUPPORTED_VERSION);
+ client_->Handshake();
+ client_->CheckErrorCode(SSL_ERROR_PROTOCOL_VERSION_ALERT);
+
+ // Try to write something
+ server_->Handshake();
+ client_->ExpectReadWriteError();
+ client_->SendData(1);
+ uint8_t buf[1];
+ EXPECT_EQ(-1, PR_Read(server_->ssl_fd(), buf, sizeof(buf)));
+ EXPECT_EQ(SSL_ERROR_HANDSHAKE_FAILED, PR_GetError());
+}
+
+TEST_F(TlsConnectStreamTls13, HrrThenTls12SupportedVersions) {
+ SetupForZeroRtt();
+ client_->Set0RttEnabled(true);
+ size_t cb_called = 0;
+ EXPECT_EQ(SECSuccess, SSL_HelloRetryRequestCallback(server_->ssl_fd(),
+ RetryHello, &cb_called));
+ server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2,
+ SSL_LIBRARY_VERSION_TLS_1_3);
+
+ client_->Handshake(); // Send CH (1.3)
+ ZeroRttSendReceive(true, false);
+ server_->Handshake(); // Send HRR.
+ EXPECT_EQ(1U, cb_called);
+
+ // Replace the client with a new TLS 1.2 client. Don't call Init(), since
+ // it will artifically limit the server's vrange.
+ client_.reset(
+ new TlsAgent(client_->name(), TlsAgent::CLIENT, ssl_variant_stream));
+ client_->SetPeer(server_);
+ server_->SetPeer(client_);
+ client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_1,
+ SSL_LIBRARY_VERSION_TLS_1_2);
+ // Negotiate via supported_versions
+ static const uint8_t tls12[] = {0x02, 0x03, 0x03};
+ auto replacer = MakeTlsFilter<TlsExtensionInjector>(
+ client_, ssl_tls13_supported_versions_xtn,
+ DataBuffer(tls12, sizeof(tls12)));
+
+ client_->StartConnect();
+ client_->Handshake(); // Send CH (1.2)
+ ExpectAlert(server_, kTlsAlertProtocolVersion);
+ server_->Handshake();
+ server_->CheckErrorCode(SSL_ERROR_UNSUPPORTED_VERSION);
+ client_->Handshake();
+ client_->CheckErrorCode(SSL_ERROR_PROTOCOL_VERSION_ALERT);
+}
+
+INSTANTIATE_TEST_SUITE_P(HelloRetryRequestAgentTests,
+ HelloRetryRequestAgentTest,
+ ::testing::Combine(TlsConnectTestBase::kTlsVariantsAll,
+ TlsConnectTestBase::kTlsV13));
+#ifndef NSS_DISABLE_TLS_1_3
+INSTANTIATE_TEST_SUITE_P(HelloRetryRequestKeyExchangeTests, TlsKeyExchange13,
+ ::testing::Combine(TlsConnectTestBase::kTlsVariantsAll,
+ TlsConnectTestBase::kTlsV13));
+#endif
+
+} // namespace nss_test
diff --git a/security/nss/gtests/ssl_gtest/ssl_keylog_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_keylog_unittest.cc
new file mode 100644
index 0000000000..b7f0351d11
--- /dev/null
+++ b/security/nss/gtests/ssl_gtest/ssl_keylog_unittest.cc
@@ -0,0 +1,164 @@
+/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
+/* vim: set ts=2 et sw=2 tw=80: */
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this file,
+ * You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+#include <cstdlib>
+#include <fstream>
+#include <sstream>
+
+#include "gtest_utils.h"
+#include "tls_connect.h"
+
+namespace nss_test {
+
+static const std::string kKeylogFilePath = "keylog.txt";
+static const std::string kKeylogBlankEnv = "SSLKEYLOGFILE=";
+static const std::string kKeylogSetEnv = kKeylogBlankEnv + kKeylogFilePath;
+
+extern "C" {
+extern FILE* ssl_keylog_iob;
+}
+
+class KeyLogFileTestBase : public TlsConnectGeneric {
+ private:
+ std::string env_to_set_;
+
+ public:
+ virtual void CheckKeyLog() = 0;
+
+ KeyLogFileTestBase(std::string env) : env_to_set_(env) {}
+
+ void SetUp() override {
+ TlsConnectGeneric::SetUp();
+ // Remove previous results (if any).
+ (void)remove(kKeylogFilePath.c_str());
+ PR_SetEnv(env_to_set_.c_str());
+ }
+
+ void ConnectAndCheck() {
+ // This is a child process, ensure that error messages immediately
+ // propagate or else it will not be visible.
+ ::testing::GTEST_FLAG(throw_on_failure) = true;
+
+ if (version_ == SSL_LIBRARY_VERSION_TLS_1_3) {
+ SetupForZeroRtt();
+ client_->Set0RttEnabled(true);
+ server_->Set0RttEnabled(true);
+ ExpectResumption(RESUME_TICKET);
+ ZeroRttSendReceive(true, true);
+ Handshake();
+ ExpectEarlyDataAccepted(true);
+ CheckConnected();
+ SendReceive();
+ } else {
+ Connect();
+ }
+ CheckKeyLog();
+ _exit(0);
+ }
+};
+
+class KeyLogFileTest : public KeyLogFileTestBase {
+ public:
+ KeyLogFileTest() : KeyLogFileTestBase(kKeylogSetEnv) {}
+
+ void CheckKeyLog() override {
+ std::ifstream f(kKeylogFilePath);
+ std::map<std::string, size_t> labels;
+ std::set<std::string> client_randoms;
+ for (std::string line; std::getline(f, line);) {
+ if (line[0] == '#') {
+ continue;
+ }
+
+ std::istringstream iss(line);
+ std::string label, client_random, secret;
+ iss >> label >> client_random >> secret;
+
+ ASSERT_EQ(64U, client_random.size());
+ client_randoms.insert(client_random);
+ labels[label]++;
+ }
+
+ if (version_ < SSL_LIBRARY_VERSION_TLS_1_3) {
+ ASSERT_EQ(1U, client_randoms.size());
+ } else {
+ /* two handshakes for 0-RTT */
+ ASSERT_EQ(2U, client_randoms.size());
+ }
+
+ // Every entry occurs twice (one log from server, one from client).
+ if (version_ < SSL_LIBRARY_VERSION_TLS_1_3) {
+ ASSERT_EQ(2U, labels["CLIENT_RANDOM"]);
+ } else {
+ ASSERT_EQ(2U, labels["CLIENT_EARLY_TRAFFIC_SECRET"]);
+ ASSERT_EQ(2U, labels["EARLY_EXPORTER_SECRET"]);
+ ASSERT_EQ(4U, labels["CLIENT_HANDSHAKE_TRAFFIC_SECRET"]);
+ ASSERT_EQ(4U, labels["SERVER_HANDSHAKE_TRAFFIC_SECRET"]);
+ ASSERT_EQ(4U, labels["CLIENT_TRAFFIC_SECRET_0"]);
+ ASSERT_EQ(4U, labels["SERVER_TRAFFIC_SECRET_0"]);
+ ASSERT_EQ(4U, labels["EXPORTER_SECRET"]);
+ }
+ }
+};
+
+// Tests are run in a separate process to ensure that NSS is not initialized yet
+// and can process the SSLKEYLOGFILE environment variable.
+
+TEST_P(KeyLogFileTest, KeyLogFile) {
+ testing::GTEST_FLAG(death_test_style) = "threadsafe";
+
+ ASSERT_EXIT(ConnectAndCheck(), ::testing::ExitedWithCode(0), "");
+}
+
+INSTANTIATE_TEST_SUITE_P(
+ KeyLogFileDTLS12, KeyLogFileTest,
+ ::testing::Combine(TlsConnectTestBase::kTlsVariantsDatagram,
+ TlsConnectTestBase::kTlsV11V12));
+INSTANTIATE_TEST_SUITE_P(
+ KeyLogFileTLS12, KeyLogFileTest,
+ ::testing::Combine(TlsConnectTestBase::kTlsVariantsStream,
+ TlsConnectTestBase::kTlsV10ToV12));
+#ifndef NSS_DISABLE_TLS_1_3
+INSTANTIATE_TEST_SUITE_P(
+ KeyLogFileTLS13, KeyLogFileTest,
+ ::testing::Combine(TlsConnectTestBase::kTlsVariantsStream,
+ TlsConnectTestBase::kTlsV13));
+#endif
+
+class KeyLogFileUnsetTest : public KeyLogFileTestBase {
+ public:
+ KeyLogFileUnsetTest() : KeyLogFileTestBase(kKeylogBlankEnv) {}
+
+ void CheckKeyLog() override {
+ std::ifstream f(kKeylogFilePath);
+ EXPECT_FALSE(f.good());
+
+ EXPECT_EQ(nullptr, ssl_keylog_iob);
+ }
+};
+
+TEST_P(KeyLogFileUnsetTest, KeyLogFile) {
+ testing::GTEST_FLAG(death_test_style) = "threadsafe";
+
+ ASSERT_EXIT(ConnectAndCheck(), ::testing::ExitedWithCode(0), "");
+}
+
+INSTANTIATE_TEST_SUITE_P(
+ KeyLogFileDTLS12, KeyLogFileUnsetTest,
+ ::testing::Combine(TlsConnectTestBase::kTlsVariantsDatagram,
+ TlsConnectTestBase::kTlsV11V12));
+INSTANTIATE_TEST_SUITE_P(
+ KeyLogFileTLS12, KeyLogFileUnsetTest,
+ ::testing::Combine(TlsConnectTestBase::kTlsVariantsStream,
+ TlsConnectTestBase::kTlsV10ToV12));
+#ifndef NSS_DISABLE_TLS_1_3
+INSTANTIATE_TEST_SUITE_P(
+ KeyLogFileTLS13, KeyLogFileUnsetTest,
+ ::testing::Combine(TlsConnectTestBase::kTlsVariantsStream,
+ TlsConnectTestBase::kTlsV13));
+#endif
+
+} // namespace nss_test
diff --git a/security/nss/gtests/ssl_gtest/ssl_keyupdate_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_keyupdate_unittest.cc
new file mode 100644
index 0000000000..b921d2c1e6
--- /dev/null
+++ b/security/nss/gtests/ssl_gtest/ssl_keyupdate_unittest.cc
@@ -0,0 +1,209 @@
+/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
+/* vim: set ts=2 et sw=2 tw=80: */
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this file,
+ * You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+#include "secerr.h"
+#include "ssl.h"
+#include "sslerr.h"
+#include "sslproto.h"
+
+extern "C" {
+// This is not something that should make you happy.
+#include "libssl_internals.h"
+}
+
+#include "gtest_utils.h"
+#include "nss_scoped_ptrs.h"
+#include "tls_connect.h"
+#include "tls_filter.h"
+#include "tls_parser.h"
+
+namespace nss_test {
+
+// All stream only tests; DTLS isn't supported yet.
+
+TEST_F(TlsConnectTest, KeyUpdateClient) {
+ ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3);
+ Connect();
+ EXPECT_EQ(SECSuccess, SSL_KeyUpdate(client_->ssl_fd(), PR_FALSE));
+ SendReceive(50);
+ SendReceive(60);
+ CheckEpochs(4, 3);
+}
+
+TEST_F(TlsConnectStreamTls13, KeyUpdateTooEarly_Client) {
+ StartConnect();
+ auto filter = MakeTlsFilter<TlsEncryptedHandshakeMessageReplacer>(
+ server_, kTlsHandshakeFinished, kTlsHandshakeKeyUpdate);
+ filter->EnableDecryption();
+
+ client_->Handshake();
+ server_->Handshake();
+ ExpectAlert(client_, kTlsAlertUnexpectedMessage);
+ client_->Handshake();
+ client_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_KEY_UPDATE);
+ server_->Handshake();
+ server_->CheckErrorCode(SSL_ERROR_HANDSHAKE_UNEXPECTED_ALERT);
+}
+
+TEST_F(TlsConnectStreamTls13, KeyUpdateTooEarly_Server) {
+ StartConnect();
+ auto filter = MakeTlsFilter<TlsEncryptedHandshakeMessageReplacer>(
+ client_, kTlsHandshakeFinished, kTlsHandshakeKeyUpdate);
+ filter->EnableDecryption();
+
+ client_->Handshake();
+ server_->Handshake();
+ client_->Handshake();
+ ExpectAlert(server_, kTlsAlertUnexpectedMessage);
+ server_->Handshake();
+ server_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_KEY_UPDATE);
+ client_->Handshake();
+ client_->CheckErrorCode(SSL_ERROR_HANDSHAKE_UNEXPECTED_ALERT);
+}
+
+TEST_F(TlsConnectTest, KeyUpdateClientRequestUpdate) {
+ ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3);
+ Connect();
+ EXPECT_EQ(SECSuccess, SSL_KeyUpdate(client_->ssl_fd(), PR_TRUE));
+ // SendReceive() only gives each peer one chance to read. This isn't enough
+ // when the read on one side generates another handshake message. A second
+ // read gives each peer an extra chance to consume the KeyUpdate.
+ SendReceive(50);
+ SendReceive(60); // Cumulative count.
+ CheckEpochs(4, 4);
+}
+
+TEST_F(TlsConnectTest, KeyUpdateServer) {
+ ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3);
+ Connect();
+ EXPECT_EQ(SECSuccess, SSL_KeyUpdate(server_->ssl_fd(), PR_FALSE));
+ SendReceive(50);
+ SendReceive(60);
+ CheckEpochs(3, 4);
+}
+
+TEST_F(TlsConnectTest, KeyUpdateServerRequestUpdate) {
+ ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3);
+ Connect();
+ EXPECT_EQ(SECSuccess, SSL_KeyUpdate(server_->ssl_fd(), PR_TRUE));
+ SendReceive(50);
+ SendReceive(60);
+ CheckEpochs(4, 4);
+}
+
+TEST_F(TlsConnectTest, KeyUpdateConsecutiveRequests) {
+ ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3);
+ Connect();
+ EXPECT_EQ(SECSuccess, SSL_KeyUpdate(server_->ssl_fd(), PR_TRUE));
+ EXPECT_EQ(SECSuccess, SSL_KeyUpdate(server_->ssl_fd(), PR_TRUE));
+ SendReceive(50);
+ SendReceive(60);
+ // The server should have updated twice, but the client should have declined
+ // to respond to the second request from the server, since it doesn't send
+ // anything in between those two requests.
+ CheckEpochs(4, 5);
+}
+
+// Check that a local update can be immediately followed by a remotely triggered
+// update even if there is no use of the keys.
+TEST_F(TlsConnectTest, KeyUpdateLocalUpdateThenConsecutiveRequests) {
+ ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3);
+ Connect();
+ // This should trigger an update on the client.
+ EXPECT_EQ(SECSuccess, SSL_KeyUpdate(client_->ssl_fd(), PR_FALSE));
+ // The client should update for the first request.
+ EXPECT_EQ(SECSuccess, SSL_KeyUpdate(server_->ssl_fd(), PR_TRUE));
+ // ...but not the second.
+ EXPECT_EQ(SECSuccess, SSL_KeyUpdate(server_->ssl_fd(), PR_TRUE));
+ SendReceive(50);
+ SendReceive(60);
+ // Both should have updated twice.
+ CheckEpochs(5, 5);
+}
+
+TEST_F(TlsConnectTest, KeyUpdateMultiple) {
+ ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3);
+ Connect();
+ EXPECT_EQ(SECSuccess, SSL_KeyUpdate(server_->ssl_fd(), PR_FALSE));
+ EXPECT_EQ(SECSuccess, SSL_KeyUpdate(server_->ssl_fd(), PR_TRUE));
+ EXPECT_EQ(SECSuccess, SSL_KeyUpdate(server_->ssl_fd(), PR_FALSE));
+ EXPECT_EQ(SECSuccess, SSL_KeyUpdate(client_->ssl_fd(), PR_FALSE));
+ SendReceive(50);
+ SendReceive(60);
+ CheckEpochs(5, 6);
+}
+
+// Both ask the other for an update, and both should react.
+TEST_F(TlsConnectTest, KeyUpdateBothRequest) {
+ ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3);
+ Connect();
+ EXPECT_EQ(SECSuccess, SSL_KeyUpdate(client_->ssl_fd(), PR_TRUE));
+ EXPECT_EQ(SECSuccess, SSL_KeyUpdate(server_->ssl_fd(), PR_TRUE));
+ SendReceive(50);
+ SendReceive(60);
+ CheckEpochs(5, 5);
+}
+
+// If the sequence number exceeds the number of writes before an automatic
+// update (currently 3/4 of the max records for the cipher suite), then the
+// stack should send an update automatically (but not request one).
+TEST_F(TlsConnectTest, KeyUpdateAutomaticOnWrite) {
+ ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3);
+ ConnectWithCipherSuite(TLS_AES_128_GCM_SHA256);
+
+ // Set this to one below the write threshold.
+ uint64_t threshold = (0x5aULL << 28) * 3 / 4;
+ EXPECT_EQ(SECSuccess,
+ SSLInt_AdvanceWriteSeqNum(client_->ssl_fd(), threshold));
+ EXPECT_EQ(SECSuccess, SSLInt_AdvanceReadSeqNum(server_->ssl_fd(), threshold));
+
+ // This should be OK.
+ client_->SendData(10);
+ server_->ReadBytes();
+
+ // This should cause the client to update.
+ client_->SendData(10);
+ server_->ReadBytes();
+
+ SendReceive(100);
+ CheckEpochs(4, 3);
+}
+
+// If the sequence number exceeds a certain number of reads (currently 7/8 of
+// the max records for the cipher suite), then the stack should send AND request
+// an update automatically. However, the sender (client) will be above its
+// automatic update threshold, so the KeyUpdate - that it sends with the old
+// cipher spec - will exceed the receiver (server) automatic update threshold.
+// The receiver gets a packet with a sequence number over its automatic read
+// update threshold. Even though the sender has updated, the code that checks
+// the sequence numbers at the receiver doesn't know this and it will request an
+// update. This causes two updates: one from the sender (without requesting a
+// response) and one from the receiver (which does request a response).
+TEST_F(TlsConnectTest, KeyUpdateAutomaticOnRead) {
+ ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3);
+ ConnectWithCipherSuite(TLS_AES_128_GCM_SHA256);
+
+ // Move to right at the read threshold. Unlike the write test, we can't send
+ // packets because that would cause the client to update, which would spoil
+ // the test.
+ uint64_t threshold = ((0x5aULL << 28) * 7 / 8) + 1;
+ EXPECT_EQ(SECSuccess,
+ SSLInt_AdvanceWriteSeqNum(client_->ssl_fd(), threshold));
+ EXPECT_EQ(SECSuccess, SSLInt_AdvanceReadSeqNum(server_->ssl_fd(), threshold));
+
+ // This should cause the client to update, but not early enough to prevent the
+ // server from updating also.
+ client_->SendData(10);
+ server_->ReadBytes();
+
+ // Need two SendReceive() calls to ensure that the update that the server
+ // requested is properly generated and consumed.
+ SendReceive(70);
+ SendReceive(80);
+ CheckEpochs(5, 4);
+}
+
+} // namespace nss_test
diff --git a/security/nss/gtests/ssl_gtest/ssl_loopback_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_loopback_unittest.cc
new file mode 100644
index 0000000000..491f50921f
--- /dev/null
+++ b/security/nss/gtests/ssl_gtest/ssl_loopback_unittest.cc
@@ -0,0 +1,801 @@
+/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
+/* vim: set ts=2 et sw=2 tw=80: */
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this file,
+ * You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+#include <functional>
+#include <memory>
+#include <vector>
+#include "secerr.h"
+#include "ssl.h"
+#include "sslerr.h"
+#include "sslproto.h"
+
+extern "C" {
+// This is not something that should make you happy.
+#include "libssl_internals.h"
+}
+
+#include "gtest_utils.h"
+#include "nss_scoped_ptrs.h"
+#include "tls_connect.h"
+#include "tls_filter.h"
+#include "tls_parser.h"
+
+namespace nss_test {
+
+TEST_P(TlsConnectGeneric, SetupOnly) {}
+
+TEST_P(TlsConnectGeneric, Connect) {
+ SetExpectedVersion(std::get<1>(GetParam()));
+ Connect();
+ CheckKeys();
+}
+
+TEST_P(TlsConnectGeneric, ConnectEcdsa) {
+ SetExpectedVersion(std::get<1>(GetParam()));
+ Reset(TlsAgent::kServerEcdsa256);
+ Connect();
+ CheckKeys(ssl_kea_ecdh, ssl_auth_ecdsa);
+}
+
+TEST_P(TlsConnectGeneric, CipherSuiteMismatch) {
+ EnsureTlsSetup();
+ if (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) {
+ client_->EnableSingleCipher(TLS_AES_128_GCM_SHA256);
+ server_->EnableSingleCipher(TLS_AES_256_GCM_SHA384);
+ } else {
+ client_->EnableSingleCipher(TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA);
+ server_->EnableSingleCipher(TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA);
+ }
+ ConnectExpectAlert(server_, kTlsAlertHandshakeFailure);
+ client_->CheckErrorCode(SSL_ERROR_NO_CYPHER_OVERLAP);
+ server_->CheckErrorCode(SSL_ERROR_NO_CYPHER_OVERLAP);
+}
+
+class TlsAlertRecorder : public TlsRecordFilter {
+ public:
+ TlsAlertRecorder(const std::shared_ptr<TlsAgent>& a)
+ : TlsRecordFilter(a), level_(255), description_(255) {}
+
+ PacketFilter::Action FilterRecord(const TlsRecordHeader& header,
+ const DataBuffer& input,
+ DataBuffer* output) override {
+ if (level_ != 255) { // Already captured.
+ return KEEP;
+ }
+ if (header.content_type() != ssl_ct_alert) {
+ return KEEP;
+ }
+
+ std::cerr << "Alert: " << input << std::endl;
+
+ TlsParser parser(input);
+ EXPECT_TRUE(parser.Read(&level_));
+ EXPECT_TRUE(parser.Read(&description_));
+ return KEEP;
+ }
+
+ uint8_t level() const { return level_; }
+ uint8_t description() const { return description_; }
+
+ private:
+ uint8_t level_;
+ uint8_t description_;
+};
+
+class HelloTruncator : public TlsHandshakeFilter {
+ public:
+ HelloTruncator(const std::shared_ptr<TlsAgent>& a)
+ : TlsHandshakeFilter(
+ a, {kTlsHandshakeClientHello, kTlsHandshakeServerHello}) {}
+ PacketFilter::Action FilterHandshake(const HandshakeHeader& header,
+ const DataBuffer& input,
+ DataBuffer* output) override {
+ output->Assign(input.data(), input.len() - 1);
+ return CHANGE;
+ }
+};
+
+// Verify that when NSS reports that an alert is sent, it is actually sent.
+TEST_P(TlsConnectGeneric, CaptureAlertServer) {
+ MakeTlsFilter<HelloTruncator>(client_);
+ auto alert_recorder = MakeTlsFilter<TlsAlertRecorder>(server_);
+
+ ConnectExpectAlert(server_, kTlsAlertDecodeError);
+ EXPECT_EQ(kTlsAlertFatal, alert_recorder->level());
+ EXPECT_EQ(kTlsAlertDecodeError, alert_recorder->description());
+}
+
+TEST_P(TlsConnectGenericPre13, CaptureAlertClient) {
+ MakeTlsFilter<HelloTruncator>(server_);
+ auto alert_recorder = MakeTlsFilter<TlsAlertRecorder>(client_);
+
+ ConnectExpectAlert(client_, kTlsAlertDecodeError);
+ EXPECT_EQ(kTlsAlertFatal, alert_recorder->level());
+ EXPECT_EQ(kTlsAlertDecodeError, alert_recorder->description());
+}
+
+// In TLS 1.3, the server can't read the client alert.
+TEST_P(TlsConnectTls13, CaptureAlertClient) {
+ MakeTlsFilter<HelloTruncator>(server_);
+ auto alert_recorder = MakeTlsFilter<TlsAlertRecorder>(client_);
+
+ StartConnect();
+
+ client_->Handshake();
+ client_->ExpectSendAlert(kTlsAlertDecodeError);
+ server_->Handshake();
+ client_->Handshake();
+ if (variant_ == ssl_variant_stream) {
+ // DTLS just drops the alert it can't decrypt.
+ server_->ExpectSendAlert(kTlsAlertUnexpectedMessage);
+ }
+ server_->Handshake();
+ EXPECT_EQ(kTlsAlertFatal, alert_recorder->level());
+ EXPECT_EQ(kTlsAlertDecodeError, alert_recorder->description());
+}
+
+TEST_P(TlsConnectGenericPre13, ConnectFalseStart) {
+ client_->EnableFalseStart();
+ Connect();
+ SendReceive();
+}
+
+TEST_P(TlsConnectGeneric, ConnectAlpn) {
+ EnableAlpn();
+ Connect();
+ CheckAlpn("a");
+}
+
+TEST_P(TlsConnectGeneric, ConnectAlpnPriorityA) {
+ // "alpn" "npn"
+ // alpn is the fallback here. npn has the highest priority and should be
+ // picked.
+ const std::vector<uint8_t> alpn = {0x04, 0x61, 0x6c, 0x70, 0x6e,
+ 0x03, 0x6e, 0x70, 0x6e};
+ EnableAlpn(alpn);
+ Connect();
+ CheckAlpn("npn");
+}
+
+TEST_P(TlsConnectGeneric, ConnectAlpnPriorityB) {
+ // "alpn" "npn" "http"
+ // npn has the highest priority and should be picked.
+ const std::vector<uint8_t> alpn = {0x04, 0x61, 0x6c, 0x70, 0x6e, 0x03, 0x6e,
+ 0x70, 0x6e, 0x04, 0x68, 0x74, 0x74, 0x70};
+ EnableAlpn(alpn);
+ Connect();
+ CheckAlpn("npn");
+}
+
+TEST_P(TlsConnectGeneric, ConnectAlpnClone) {
+ EnsureModelSockets();
+ client_model_->EnableAlpn(alpn_dummy_val_, sizeof(alpn_dummy_val_));
+ server_model_->EnableAlpn(alpn_dummy_val_, sizeof(alpn_dummy_val_));
+ Connect();
+ CheckAlpn("a");
+}
+
+TEST_P(TlsConnectGeneric, ConnectAlpnWithCustomCallbackA) {
+ // "ab" "alpn"
+ const std::vector<uint8_t> client_alpn = {0x02, 0x61, 0x62, 0x04,
+ 0x61, 0x6c, 0x70, 0x6e};
+ EnableAlpnWithCallback(client_alpn, "alpn");
+ Connect();
+ CheckAlpn("alpn");
+}
+
+TEST_P(TlsConnectGeneric, ConnectAlpnWithCustomCallbackB) {
+ // "ab" "alpn"
+ const std::vector<uint8_t> client_alpn = {0x02, 0x61, 0x62, 0x04,
+ 0x61, 0x6c, 0x70, 0x6e};
+ EnableAlpnWithCallback(client_alpn, "ab");
+ Connect();
+ CheckAlpn("ab");
+}
+
+TEST_P(TlsConnectGeneric, ConnectAlpnWithCustomCallbackC) {
+ // "cd" "npn" "alpn"
+ const std::vector<uint8_t> client_alpn = {0x02, 0x63, 0x64, 0x03, 0x6e, 0x70,
+ 0x6e, 0x04, 0x61, 0x6c, 0x70, 0x6e};
+ EnableAlpnWithCallback(client_alpn, "npn");
+ Connect();
+ CheckAlpn("npn");
+}
+
+TEST_P(TlsConnectDatagram, ConnectSrtp) {
+ EnableSrtp();
+ Connect();
+ CheckSrtp();
+ SendReceive();
+}
+
+TEST_P(TlsConnectGeneric, ConnectSendReceive) {
+ Connect();
+ SendReceive();
+}
+
+class SaveTlsRecord : public TlsRecordFilter {
+ public:
+ SaveTlsRecord(const std::shared_ptr<TlsAgent>& a, size_t index)
+ : TlsRecordFilter(a), index_(index), count_(0), contents_() {}
+
+ const DataBuffer& contents() const { return contents_; }
+
+ protected:
+ PacketFilter::Action FilterRecord(const TlsRecordHeader& header,
+ const DataBuffer& data,
+ DataBuffer* changed) override {
+ if (count_++ == index_) {
+ contents_ = data;
+ }
+ return KEEP;
+ }
+
+ private:
+ const size_t index_;
+ size_t count_;
+ DataBuffer contents_;
+};
+
+// Check that decrypting filters work and can read any record.
+// This test (currently) only works in TLS 1.3 where we can decrypt.
+TEST_F(TlsConnectStreamTls13, DecryptRecordClient) {
+ EnsureTlsSetup();
+ // 0 = ClientHello, 1 = Finished, 2 = SendReceive, 3 = SendBuffer
+ auto saved = MakeTlsFilter<SaveTlsRecord>(client_, 3);
+ saved->EnableDecryption();
+ Connect();
+ SendReceive();
+
+ static const uint8_t data[] = {0xde, 0xad, 0xdc};
+ DataBuffer buf(data, sizeof(data));
+ client_->SendBuffer(buf);
+ EXPECT_EQ(buf, saved->contents());
+}
+
+TEST_F(TlsConnectStreamTls13, DecryptRecordServer) {
+ EnsureTlsSetup();
+ // Disable tickets so that we are sure to not get NewSessionTicket.
+ EXPECT_EQ(SECSuccess, SSL_OptionSet(server_->ssl_fd(),
+ SSL_ENABLE_SESSION_TICKETS, PR_FALSE));
+ // 0 = ServerHello, 1 = other handshake, 2 = SendReceive, 3 = SendBuffer
+ auto saved = MakeTlsFilter<SaveTlsRecord>(server_, 3);
+ saved->EnableDecryption();
+ Connect();
+ SendReceive();
+
+ static const uint8_t data[] = {0xde, 0xad, 0xd5};
+ DataBuffer buf(data, sizeof(data));
+ server_->SendBuffer(buf);
+ EXPECT_EQ(buf, saved->contents());
+}
+
+class DropTlsRecord : public TlsRecordFilter {
+ public:
+ DropTlsRecord(const std::shared_ptr<TlsAgent>& a, size_t index)
+ : TlsRecordFilter(a), index_(index), count_(0) {}
+
+ protected:
+ PacketFilter::Action FilterRecord(const TlsRecordHeader& header,
+ const DataBuffer& data,
+ DataBuffer* changed) override {
+ if (count_++ == index_) {
+ return DROP;
+ }
+ return KEEP;
+ }
+
+ private:
+ const size_t index_;
+ size_t count_;
+};
+
+// Test that decrypting filters work correctly and are able to drop records.
+TEST_F(TlsConnectStreamTls13, DropRecordServer) {
+ EnsureTlsSetup();
+ // Disable session tickets so that the server doesn't send an extra record.
+ EXPECT_EQ(SECSuccess, SSL_OptionSet(server_->ssl_fd(),
+ SSL_ENABLE_SESSION_TICKETS, PR_FALSE));
+
+ // 0 = ServerHello, 1 = other handshake, 2 = first write
+ auto filter = MakeTlsFilter<DropTlsRecord>(server_, 2);
+ filter->EnableDecryption();
+ Connect();
+ server_->SendData(23, 23); // This should be dropped, so it won't be counted.
+ server_->ResetSentBytes();
+ SendReceive();
+}
+
+TEST_F(TlsConnectStreamTls13, DropRecordClient) {
+ EnsureTlsSetup();
+ // 0 = ClientHello, 1 = Finished, 2 = first write
+ auto filter = MakeTlsFilter<DropTlsRecord>(client_, 2);
+ filter->EnableDecryption();
+ Connect();
+ client_->SendData(26, 26); // This should be dropped, so it won't be counted.
+ client_->ResetSentBytes();
+ SendReceive();
+}
+
+// Check that a server can use 0.5 RTT if client authentication isn't enabled.
+TEST_P(TlsConnectTls13, WriteBeforeClientFinished) {
+ EnsureTlsSetup();
+ StartConnect();
+ client_->Handshake(); // ClientHello
+ server_->Handshake(); // ServerHello
+
+ server_->SendData(10);
+ client_->ReadBytes(10); // Client should emit the Finished as a side-effect.
+ server_->Handshake(); // Server consumes the Finished.
+ CheckConnected();
+}
+
+// We don't allow 0.5 RTT if client authentication is requested.
+TEST_P(TlsConnectTls13, WriteBeforeClientFinishedClientAuth) {
+ client_->SetupClientAuth();
+ server_->RequestClientAuth(false);
+ StartConnect();
+ client_->Handshake(); // ClientHello
+ server_->Handshake(); // ServerHello
+
+ static const uint8_t data[] = {1, 2, 3};
+ EXPECT_GT(0, PR_Write(server_->ssl_fd(), data, sizeof(data)));
+ EXPECT_EQ(PR_WOULD_BLOCK_ERROR, PORT_GetError());
+
+ Handshake();
+ CheckConnected();
+ SendReceive();
+}
+
+// 0.5 RTT should fail with client authentication required.
+TEST_P(TlsConnectTls13, WriteBeforeClientFinishedClientAuthRequired) {
+ client_->SetupClientAuth();
+ server_->RequestClientAuth(true);
+ StartConnect();
+ client_->Handshake(); // ClientHello
+ server_->Handshake(); // ServerHello
+
+ static const uint8_t data[] = {1, 2, 3};
+ EXPECT_GT(0, PR_Write(server_->ssl_fd(), data, sizeof(data)));
+ EXPECT_EQ(PR_WOULD_BLOCK_ERROR, PORT_GetError());
+
+ Handshake();
+ CheckConnected();
+ SendReceive();
+}
+
+// The next two tests takes advantage of the fact that we
+// automatically read the first 1024 bytes, so if
+// we provide 1200 bytes, they overrun the read buffer
+// provided by the calling test.
+
+// DTLS should return an error.
+TEST_P(TlsConnectDatagram, ShortRead) {
+ Connect();
+ client_->ExpectReadWriteError();
+ server_->SendData(50, 50);
+ client_->ReadBytes(20);
+ EXPECT_EQ(0U, client_->received_bytes());
+ EXPECT_EQ(SSL_ERROR_RX_SHORT_DTLS_READ, PORT_GetError());
+
+ // Now send and receive another packet.
+ server_->ResetSentBytes(); // Reset the counter.
+ SendReceive();
+}
+
+// TLS should get the write in two chunks.
+TEST_P(TlsConnectStream, ShortRead) {
+ // This test behaves oddly with TLS 1.0 because of 1/n+1 splitting,
+ // so skip in that case.
+ if (version_ < SSL_LIBRARY_VERSION_TLS_1_1) GTEST_SKIP();
+
+ Connect();
+ server_->SendData(50, 50);
+ // Read the first tranche.
+ client_->ReadBytes(20);
+ ASSERT_EQ(20U, client_->received_bytes());
+ // The second tranche should now immediately be available.
+ client_->ReadBytes();
+ ASSERT_EQ(50U, client_->received_bytes());
+}
+
+// We enable compression via the API but it's disabled internally,
+// so we should never get it.
+TEST_P(TlsConnectGeneric, ConnectWithCompressionEnabled) {
+ EnsureTlsSetup();
+ client_->SetOption(SSL_ENABLE_DEFLATE, PR_TRUE);
+ server_->SetOption(SSL_ENABLE_DEFLATE, PR_TRUE);
+ Connect();
+ EXPECT_FALSE(client_->is_compressed());
+ SendReceive();
+}
+
+class TlsHolddownTest : public TlsConnectDatagram {
+ protected:
+ // This causes all timers to run to completion. It advances the clock and
+ // handshakes on both peers until both peers have no more timers pending,
+ // which should happen at the end of a handshake. This is necessary to ensure
+ // that the relatively long holddown timer expires, but that any other timers
+ // also expire and run correctly.
+ void RunAllTimersDown() {
+ while (true) {
+ PRIntervalTime time;
+ SECStatus rv = DTLS_GetHandshakeTimeout(client_->ssl_fd(), &time);
+ if (rv != SECSuccess) {
+ rv = DTLS_GetHandshakeTimeout(server_->ssl_fd(), &time);
+ if (rv != SECSuccess) {
+ break; // Neither peer has an outstanding timer.
+ }
+ }
+
+ if (g_ssl_gtest_verbose) {
+ std::cerr << "Shifting timers" << std::endl;
+ }
+ ShiftDtlsTimers();
+ Handshake();
+ }
+ }
+};
+
+TEST_P(TlsHolddownTest, TestDtlsHolddownExpiry) {
+ Connect();
+ std::cerr << "Expiring holddown timer" << std::endl;
+ RunAllTimersDown();
+ SendReceive();
+ if (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) {
+ // One for send, one for receive.
+ EXPECT_EQ(2, SSLInt_CountCipherSpecs(client_->ssl_fd()));
+ }
+}
+
+TEST_P(TlsHolddownTest, TestDtlsHolddownExpiryResumption) {
+ ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET);
+ Connect();
+ SendReceive();
+
+ Reset();
+ ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET);
+ ExpectResumption(RESUME_TICKET);
+ Connect();
+ RunAllTimersDown();
+ SendReceive();
+ // One for send, one for receive.
+ EXPECT_EQ(2, SSLInt_CountCipherSpecs(client_->ssl_fd()));
+}
+
+class TlsPreCCSHeaderInjector : public TlsRecordFilter {
+ public:
+ TlsPreCCSHeaderInjector(const std::shared_ptr<TlsAgent>& a)
+ : TlsRecordFilter(a) {}
+ virtual PacketFilter::Action FilterRecord(
+ const TlsRecordHeader& record_header, const DataBuffer& input,
+ size_t* offset, DataBuffer* output) override {
+ if (record_header.content_type() != ssl_ct_change_cipher_spec) {
+ return KEEP;
+ }
+
+ std::cerr << "Injecting Finished header before CCS\n";
+ const uint8_t hhdr[] = {kTlsHandshakeFinished, 0x00, 0x00, 0x0c};
+ DataBuffer hhdr_buf(hhdr, sizeof(hhdr));
+ TlsRecordHeader nhdr(record_header.variant(), record_header.version(),
+ ssl_ct_handshake, 0);
+ *offset = nhdr.Write(output, *offset, hhdr_buf);
+ *offset = record_header.Write(output, *offset, input);
+ return CHANGE;
+ }
+};
+
+TEST_P(TlsConnectStreamPre13, ClientFinishedHeaderBeforeCCS) {
+ MakeTlsFilter<TlsPreCCSHeaderInjector>(client_);
+ ConnectExpectAlert(server_, kTlsAlertUnexpectedMessage);
+ client_->CheckErrorCode(SSL_ERROR_HANDSHAKE_UNEXPECTED_ALERT);
+ server_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_CHANGE_CIPHER);
+}
+
+TEST_P(TlsConnectStreamPre13, ServerFinishedHeaderBeforeCCS) {
+ MakeTlsFilter<TlsPreCCSHeaderInjector>(server_);
+ StartConnect();
+ ExpectAlert(client_, kTlsAlertUnexpectedMessage);
+ Handshake();
+ EXPECT_EQ(TlsAgent::STATE_ERROR, client_->state());
+ client_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_CHANGE_CIPHER);
+ EXPECT_EQ(TlsAgent::STATE_CONNECTED, server_->state());
+ server_->Handshake(); // Make sure alert is consumed.
+}
+
+TEST_P(TlsConnectTls13, UnknownAlert) {
+ Connect();
+ server_->ExpectSendAlert(0xff, kTlsAlertWarning);
+ client_->ExpectReceiveAlert(0xff, kTlsAlertWarning);
+ SSLInt_SendAlert(server_->ssl_fd(), kTlsAlertWarning,
+ 0xff); // Unknown value.
+ client_->ExpectReadWriteError();
+ client_->WaitForErrorCode(SSL_ERROR_RX_UNKNOWN_ALERT, 2000);
+}
+
+TEST_P(TlsConnectTls13, AlertWrongLevel) {
+ Connect();
+ server_->ExpectSendAlert(kTlsAlertUnexpectedMessage, kTlsAlertWarning);
+ client_->ExpectReceiveAlert(kTlsAlertUnexpectedMessage, kTlsAlertWarning);
+ SSLInt_SendAlert(server_->ssl_fd(), kTlsAlertWarning,
+ kTlsAlertUnexpectedMessage);
+ client_->ExpectReadWriteError();
+ client_->WaitForErrorCode(SSL_ERROR_HANDSHAKE_UNEXPECTED_ALERT, 2000);
+}
+
+TEST_P(TlsConnectTls13, UnknownRecord) {
+ static const uint8_t kUknownRecord[] = {
+ 0xff, SSL_LIBRARY_VERSION_TLS_1_2 >> 8,
+ SSL_LIBRARY_VERSION_TLS_1_2 & 0xff, 0, 0};
+
+ Connect();
+ if (variant_ == ssl_variant_stream) {
+ // DTLS just drops the record with an invalid type.
+ server_->ExpectSendAlert(kTlsAlertUnexpectedMessage);
+ }
+ client_->SendDirect(DataBuffer(kUknownRecord, sizeof(kUknownRecord)));
+ server_->ExpectReadWriteError();
+ server_->ReadBytes();
+ if (variant_ == ssl_variant_stream) {
+ EXPECT_EQ(SSL_ERROR_RX_UNEXPECTED_RECORD_TYPE, server_->error_code());
+ } else {
+ EXPECT_EQ(SSL_ERROR_RX_UNKNOWN_RECORD_TYPE, server_->error_code());
+ }
+}
+
+TEST_F(TlsConnectStreamTls13, Tls13FailedWriteSecondFlight) {
+ EnsureTlsSetup();
+ StartConnect();
+ client_->Handshake();
+ server_->Handshake(); // Send first flight.
+ client_->adapter()->SetWriteError(PR_IO_ERROR);
+ client_->Handshake(); // This will get an error, but shouldn't crash.
+ client_->CheckErrorCode(SSL_ERROR_SOCKET_WRITE_FAILURE);
+}
+
+TEST_P(TlsConnectDatagram, BlockedWrite) {
+ Connect();
+
+ // Mark the socket as blocked.
+ client_->adapter()->SetWriteError(PR_WOULD_BLOCK_ERROR);
+ static const uint8_t data[] = {1, 2, 3};
+ int32_t rv = PR_Write(client_->ssl_fd(), data, sizeof(data));
+ EXPECT_GT(0, rv);
+ EXPECT_EQ(PR_WOULD_BLOCK_ERROR, PORT_GetError());
+
+ // Remove the write error and though the previous write failed, future reads
+ // and writes should just work as if it never happened.
+ client_->adapter()->SetWriteError(0);
+ SendReceive();
+}
+
+TEST_F(TlsConnectTest, ConnectSSLv3) {
+ ConfigureVersion(SSL_LIBRARY_VERSION_3_0);
+ EnableOnlyStaticRsaCiphers();
+ Connect();
+ CheckKeys(ssl_kea_rsa, ssl_grp_none, ssl_auth_rsa_decrypt, ssl_sig_none);
+}
+
+TEST_F(TlsConnectTest, ConnectSSLv3ClientAuth) {
+ ConfigureVersion(SSL_LIBRARY_VERSION_3_0);
+ EnableOnlyStaticRsaCiphers();
+ client_->SetupClientAuth();
+ server_->RequestClientAuth(true);
+ Connect();
+ CheckKeys(ssl_kea_rsa, ssl_grp_none, ssl_auth_rsa_decrypt, ssl_sig_none);
+}
+
+static size_t ExpectedCbcLen(size_t in, size_t hmac = 20, size_t block = 16) {
+ // MAC-then-Encrypt expansion formula:
+ return ((in + hmac + (block - 1)) / block) * block;
+}
+
+TEST_F(TlsConnectTest, OneNRecordSplitting) {
+ ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_0);
+ EnsureTlsSetup();
+ ConnectWithCipherSuite(TLS_RSA_WITH_AES_128_CBC_SHA);
+ auto records = MakeTlsFilter<TlsRecordRecorder>(server_);
+ // This should be split into 1, 16384 and 20.
+ DataBuffer big_buffer;
+ big_buffer.Allocate(1 + 16384 + 20);
+ server_->SendBuffer(big_buffer);
+ ASSERT_EQ(3U, records->count());
+ EXPECT_EQ(ExpectedCbcLen(1), records->record(0).buffer.len());
+ EXPECT_EQ(ExpectedCbcLen(16384), records->record(1).buffer.len());
+ EXPECT_EQ(ExpectedCbcLen(20), records->record(2).buffer.len());
+}
+
+// We can't test for randomness easily here, but we can test that we don't
+// produce a zero value, or produce the same value twice. There are 5 values
+// here: two ClientHello.random, two ServerHello.random, and one zero value.
+// Matrix them and fail if any are the same.
+TEST_P(TlsConnectGeneric, CheckRandoms) {
+ ConfigureSessionCache(RESUME_NONE, RESUME_NONE);
+
+ static const size_t random_len = 32;
+ uint8_t crandom1[random_len], srandom1[random_len];
+ uint8_t z[random_len] = {0};
+
+ auto ch = MakeTlsFilter<TlsHandshakeRecorder>(client_, ssl_hs_client_hello);
+ auto sh = MakeTlsFilter<TlsHandshakeRecorder>(server_, ssl_hs_server_hello);
+ Connect();
+ ASSERT_TRUE(ch->buffer().len() > (random_len + 2));
+ ASSERT_TRUE(sh->buffer().len() > (random_len + 2));
+ memcpy(crandom1, ch->buffer().data() + 2, random_len);
+ memcpy(srandom1, sh->buffer().data() + 2, random_len);
+ EXPECT_NE(0, memcmp(crandom1, srandom1, random_len));
+ EXPECT_NE(0, memcmp(crandom1, z, random_len));
+ EXPECT_NE(0, memcmp(srandom1, z, random_len));
+
+ Reset();
+ ch = MakeTlsFilter<TlsHandshakeRecorder>(client_, ssl_hs_client_hello);
+ sh = MakeTlsFilter<TlsHandshakeRecorder>(server_, ssl_hs_server_hello);
+ Connect();
+ ASSERT_TRUE(ch->buffer().len() > (random_len + 2));
+ ASSERT_TRUE(sh->buffer().len() > (random_len + 2));
+ const uint8_t* crandom2 = ch->buffer().data() + 2;
+ const uint8_t* srandom2 = sh->buffer().data() + 2;
+
+ EXPECT_NE(0, memcmp(crandom2, srandom2, random_len));
+ EXPECT_NE(0, memcmp(crandom2, z, random_len));
+ EXPECT_NE(0, memcmp(srandom2, z, random_len));
+
+ EXPECT_NE(0, memcmp(crandom1, crandom2, random_len));
+ EXPECT_NE(0, memcmp(crandom1, srandom2, random_len));
+ EXPECT_NE(0, memcmp(srandom1, crandom2, random_len));
+ EXPECT_NE(0, memcmp(srandom1, srandom2, random_len));
+}
+
+void FailOnCloseNotify(const PRFileDesc* fd, void* arg, const SSLAlert* alert) {
+ ADD_FAILURE() << "received alert " << alert->description;
+}
+
+void CheckCloseNotify(const PRFileDesc* fd, void* arg, const SSLAlert* alert) {
+ *reinterpret_cast<bool*>(arg) = true;
+ EXPECT_EQ(close_notify, alert->description);
+ EXPECT_EQ(alert_warning, alert->level);
+}
+
+TEST_P(TlsConnectGeneric, ShutdownOneSide) {
+ Connect();
+
+ // Setup to check alerts.
+ EXPECT_EQ(SECSuccess, SSL_AlertSentCallback(server_->ssl_fd(),
+ FailOnCloseNotify, nullptr));
+ EXPECT_EQ(SECSuccess, SSL_AlertReceivedCallback(client_->ssl_fd(),
+ FailOnCloseNotify, nullptr));
+
+ bool client_sent = false;
+ EXPECT_EQ(SECSuccess, SSL_AlertSentCallback(client_->ssl_fd(),
+ CheckCloseNotify, &client_sent));
+ bool server_received = false;
+ EXPECT_EQ(SECSuccess,
+ SSL_AlertReceivedCallback(server_->ssl_fd(), CheckCloseNotify,
+ &server_received));
+ EXPECT_EQ(PR_SUCCESS, PR_Shutdown(client_->ssl_fd(), PR_SHUTDOWN_SEND));
+
+ // Make sure that the server reads out the close_notify.
+ uint8_t buf[10];
+ EXPECT_EQ(0, PR_Read(server_->ssl_fd(), buf, sizeof(buf)));
+
+ // Reading and writing should still work in the one open direction.
+ EXPECT_TRUE(client_sent);
+ EXPECT_TRUE(server_received);
+ server_->SendData(10, 10);
+ client_->ReadBytes(10);
+
+ // Now close the other side and do the same checks.
+ bool server_sent = false;
+ EXPECT_EQ(SECSuccess, SSL_AlertSentCallback(server_->ssl_fd(),
+ CheckCloseNotify, &server_sent));
+ bool client_received = false;
+ EXPECT_EQ(SECSuccess,
+ SSL_AlertReceivedCallback(client_->ssl_fd(), CheckCloseNotify,
+ &client_received));
+ EXPECT_EQ(PR_SUCCESS, PR_Shutdown(server_->ssl_fd(), PR_SHUTDOWN_SEND));
+
+ EXPECT_EQ(0, PR_Read(client_->ssl_fd(), buf, sizeof(buf)));
+ EXPECT_TRUE(server_sent);
+ EXPECT_TRUE(client_received);
+}
+
+TEST_P(TlsConnectGeneric, ShutdownOneSideThenCloseTcp) {
+ Connect();
+
+ bool client_sent = false;
+ EXPECT_EQ(SECSuccess, SSL_AlertSentCallback(client_->ssl_fd(),
+ CheckCloseNotify, &client_sent));
+ bool server_received = false;
+ EXPECT_EQ(SECSuccess,
+ SSL_AlertReceivedCallback(server_->ssl_fd(), CheckCloseNotify,
+ &server_received));
+ EXPECT_EQ(PR_SUCCESS, PR_Shutdown(client_->ssl_fd(), PR_SHUTDOWN_SEND));
+
+ // Make sure that the server reads out the close_notify.
+ uint8_t buf[10];
+ EXPECT_EQ(0, PR_Read(server_->ssl_fd(), buf, sizeof(buf)));
+
+ // Now simulate the underlying connection closing.
+ client_->adapter()->Reset();
+
+ // Now close the other side and see that things don't explode.
+ EXPECT_EQ(PR_SUCCESS, PR_Shutdown(server_->ssl_fd(), PR_SHUTDOWN_SEND));
+
+ EXPECT_GT(0, PR_Read(client_->ssl_fd(), buf, sizeof(buf)));
+ EXPECT_EQ(PR_NOT_CONNECTED_ERROR, PR_GetError());
+}
+
+INSTANTIATE_TEST_SUITE_P(
+ GenericStream, TlsConnectGeneric,
+ ::testing::Combine(TlsConnectTestBase::kTlsVariantsStream,
+ TlsConnectTestBase::kTlsVAll));
+INSTANTIATE_TEST_SUITE_P(
+ GenericDatagram, TlsConnectGeneric,
+ ::testing::Combine(TlsConnectTestBase::kTlsVariantsDatagram,
+ TlsConnectTestBase::kTlsV11Plus));
+
+INSTANTIATE_TEST_SUITE_P(StreamOnly, TlsConnectStream,
+ TlsConnectTestBase::kTlsVAll);
+INSTANTIATE_TEST_SUITE_P(DatagramOnly, TlsConnectDatagram,
+ TlsConnectTestBase::kTlsV11Plus);
+INSTANTIATE_TEST_SUITE_P(DatagramHolddown, TlsHolddownTest,
+ TlsConnectTestBase::kTlsV11Plus);
+
+INSTANTIATE_TEST_SUITE_P(
+ Pre12Stream, TlsConnectPre12,
+ ::testing::Combine(TlsConnectTestBase::kTlsVariantsStream,
+ TlsConnectTestBase::kTlsV10V11));
+INSTANTIATE_TEST_SUITE_P(
+ Pre12Datagram, TlsConnectPre12,
+ ::testing::Combine(TlsConnectTestBase::kTlsVariantsDatagram,
+ TlsConnectTestBase::kTlsV11));
+
+INSTANTIATE_TEST_SUITE_P(Version12Only, TlsConnectTls12,
+ TlsConnectTestBase::kTlsVariantsAll);
+#ifndef NSS_DISABLE_TLS_1_3
+INSTANTIATE_TEST_SUITE_P(Version13Only, TlsConnectTls13,
+ TlsConnectTestBase::kTlsVariantsAll);
+#endif
+
+INSTANTIATE_TEST_SUITE_P(
+ Pre13Stream, TlsConnectGenericPre13,
+ ::testing::Combine(TlsConnectTestBase::kTlsVariantsStream,
+ TlsConnectTestBase::kTlsV10ToV12));
+INSTANTIATE_TEST_SUITE_P(
+ Pre13Datagram, TlsConnectGenericPre13,
+ ::testing::Combine(TlsConnectTestBase::kTlsVariantsDatagram,
+ TlsConnectTestBase::kTlsV11V12));
+INSTANTIATE_TEST_SUITE_P(Pre13StreamOnly, TlsConnectStreamPre13,
+ TlsConnectTestBase::kTlsV10ToV12);
+
+INSTANTIATE_TEST_SUITE_P(Version12Plus, TlsConnectTls12Plus,
+ ::testing::Combine(TlsConnectTestBase::kTlsVariantsAll,
+ TlsConnectTestBase::kTlsV12Plus));
+
+INSTANTIATE_TEST_SUITE_P(
+ GenericStream, TlsConnectGenericResumption,
+ ::testing::Combine(TlsConnectTestBase::kTlsVariantsStream,
+ TlsConnectTestBase::kTlsVAll,
+ ::testing::Values(true, false)));
+INSTANTIATE_TEST_SUITE_P(
+ GenericDatagram, TlsConnectGenericResumption,
+ ::testing::Combine(TlsConnectTestBase::kTlsVariantsDatagram,
+ TlsConnectTestBase::kTlsV11Plus,
+ ::testing::Values(true, false)));
+
+INSTANTIATE_TEST_SUITE_P(
+ GenericStream, TlsConnectGenericResumptionToken,
+ ::testing::Combine(TlsConnectTestBase::kTlsVariantsStream,
+ TlsConnectTestBase::kTlsVAll));
+INSTANTIATE_TEST_SUITE_P(
+ GenericDatagram, TlsConnectGenericResumptionToken,
+ ::testing::Combine(TlsConnectTestBase::kTlsVariantsDatagram,
+ TlsConnectTestBase::kTlsV11Plus));
+
+INSTANTIATE_TEST_SUITE_P(GenericDatagram, TlsConnectTls13ResumptionToken,
+ TlsConnectTestBase::kTlsVariantsAll);
+
+} // namespace nss_test
diff --git a/security/nss/gtests/ssl_gtest/ssl_masking_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_masking_unittest.cc
new file mode 100644
index 0000000000..8209a6e4e0
--- /dev/null
+++ b/security/nss/gtests/ssl_gtest/ssl_masking_unittest.cc
@@ -0,0 +1,350 @@
+/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
+/* vim: set ts=2 et sw=2 tw=80: */
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this file,
+ * You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+#include <memory>
+
+#include "keyhi.h"
+#include "pk11pub.h"
+#include "secerr.h"
+#include "ssl.h"
+#include "sslerr.h"
+#include "sslexp.h"
+#include "sslproto.h"
+
+#include "gtest_utils.h"
+#include "nss_scoped_ptrs.h"
+#include "scoped_ptrs_ssl.h"
+#include "tls_connect.h"
+
+namespace nss_test {
+
+// From tls_hkdf_unittest.cc:
+extern size_t GetHashLength(SSLHashType ht);
+
+const std::string kLabel = "sn";
+
+class MaskingTest : public ::testing::Test {
+ public:
+ MaskingTest() : slot_(PK11_GetInternalSlot()) {}
+
+ void InitSecret(SSLHashType hash_type) {
+ ScopedPK11SlotInfo slot(PK11_GetInternalSlot());
+ PK11SymKey *s = PK11_KeyGen(slot_.get(), CKM_GENERIC_SECRET_KEY_GEN,
+ nullptr, AES_128_KEY_LENGTH, nullptr);
+ ASSERT_NE(nullptr, s);
+ secret_.reset(s);
+ }
+
+ void SetUp() override {
+ InitSecret(ssl_hash_sha256);
+ PORT_SetError(0);
+ }
+
+ protected:
+ ScopedPK11SymKey secret_;
+ ScopedPK11SlotInfo slot_;
+ // Should have 4B ctr, 12B nonce for ChaCha, or >=16B ciphertext for AES.
+ // Use the same default size for mask output.
+ static const int kSampleSize = 16;
+ static const int kMaskSize = 16;
+ void CreateMask(PRUint16 ciphersuite, SSLProtocolVariant variant,
+ std::string label, const std::vector<uint8_t> &sample,
+ std::vector<uint8_t> *out_mask) {
+ ASSERT_NE(nullptr, out_mask);
+ SSLMaskingContext *ctx_init = nullptr;
+ EXPECT_EQ(SECSuccess,
+ SSL_CreateVariantMaskingContext(
+ SSL_LIBRARY_VERSION_TLS_1_3, ciphersuite, variant,
+ secret_.get(), label.c_str(), label.size(), &ctx_init));
+ ASSERT_NE(nullptr, ctx_init);
+ ScopedSSLMaskingContext ctx(ctx_init);
+
+ EXPECT_EQ(SECSuccess,
+ SSL_CreateMask(ctx.get(), sample.data(), sample.size(),
+ out_mask->data(), out_mask->size()));
+ bool all_zeros = std::all_of(out_mask->begin(), out_mask->end(),
+ [](uint8_t v) { return v == 0; });
+
+ // If out_mask is short, |all_zeros| will be (expectedly) true often enough
+ // to fail tests.
+ // In this case, just retry to make sure we're not outputting zeros
+ // continuously.
+ if (all_zeros && out_mask->size() < 3) {
+ unsigned int tries = 2;
+ std::vector<uint8_t> tmp_sample = sample;
+ std::vector<uint8_t> tmp_mask(out_mask->size());
+ while (tries--) {
+ tmp_sample.data()[0]++; // Tweak something to get a new mask.
+ EXPECT_EQ(SECSuccess, SSL_CreateMask(ctx.get(), tmp_sample.data(),
+ tmp_sample.size(), tmp_mask.data(),
+ tmp_mask.size()));
+ bool retry_zero = std::all_of(tmp_mask.begin(), tmp_mask.end(),
+ [](uint8_t v) { return v == 0; });
+ if (!retry_zero) {
+ all_zeros = false;
+ break;
+ }
+ }
+ }
+ EXPECT_FALSE(all_zeros);
+ }
+};
+
+class SuiteTest : public MaskingTest,
+ public ::testing::WithParamInterface<uint16_t> {
+ public:
+ SuiteTest() : ciphersuite_(GetParam()) {}
+ void CreateMask(std::string label, const std::vector<uint8_t> &sample,
+ std::vector<uint8_t> *out_mask) {
+ MaskingTest::CreateMask(ciphersuite_, ssl_variant_datagram, label, sample,
+ out_mask);
+ }
+
+ protected:
+ const uint16_t ciphersuite_;
+};
+
+class VariantTest : public MaskingTest,
+ public ::testing::WithParamInterface<SSLProtocolVariant> {
+ public:
+ VariantTest() : variant_(GetParam()) {}
+ void CreateMask(uint16_t ciphersuite, std::string label,
+ const std::vector<uint8_t> &sample,
+ std::vector<uint8_t> *out_mask) {
+ MaskingTest::CreateMask(ciphersuite, variant_, label, sample, out_mask);
+ }
+
+ protected:
+ const SSLProtocolVariant variant_;
+};
+
+class VariantSuiteTest : public MaskingTest,
+ public ::testing::WithParamInterface<
+ std::tuple<SSLProtocolVariant, uint16_t>> {
+ public:
+ VariantSuiteTest()
+ : variant_(std::get<0>(GetParam())),
+ ciphersuite_(std::get<1>(GetParam())) {}
+ void CreateMask(std::string label, const std::vector<uint8_t> &sample,
+ std::vector<uint8_t> *out_mask) {
+ MaskingTest::CreateMask(ciphersuite_, variant_, label, sample, out_mask);
+ }
+
+ protected:
+ const SSLProtocolVariant variant_;
+ const uint16_t ciphersuite_;
+};
+
+TEST_P(VariantSuiteTest, MaskContextNoLabel) {
+ std::vector<uint8_t> sample(kSampleSize);
+ std::vector<uint8_t> mask(kMaskSize);
+ CreateMask(std::string(""), sample, &mask);
+}
+
+TEST_P(VariantSuiteTest, MaskNoSample) {
+ std::vector<uint8_t> mask(kMaskSize);
+ SSLMaskingContext *ctx_init = nullptr;
+ EXPECT_EQ(SECSuccess,
+ SSL_CreateVariantMaskingContext(
+ SSL_LIBRARY_VERSION_TLS_1_3, ciphersuite_, variant_,
+ secret_.get(), kLabel.c_str(), kLabel.size(), &ctx_init));
+ ASSERT_NE(nullptr, ctx_init);
+ ScopedSSLMaskingContext ctx(ctx_init);
+
+ EXPECT_EQ(SECFailure,
+ SSL_CreateMask(ctx.get(), nullptr, 0, mask.data(), mask.size()));
+ EXPECT_EQ(SEC_ERROR_INVALID_ARGS, PORT_GetError());
+
+ EXPECT_EQ(SECFailure, SSL_CreateMask(ctx.get(), nullptr, mask.size(),
+ mask.data(), mask.size()));
+ EXPECT_EQ(SEC_ERROR_INVALID_ARGS, PORT_GetError());
+}
+
+TEST_P(VariantSuiteTest, MaskShortSample) {
+ std::vector<uint8_t> sample(kSampleSize);
+ std::vector<uint8_t> mask(kMaskSize);
+ SSLMaskingContext *ctx_init = nullptr;
+ EXPECT_EQ(SECSuccess,
+ SSL_CreateVariantMaskingContext(
+ SSL_LIBRARY_VERSION_TLS_1_3, ciphersuite_, variant_,
+ secret_.get(), kLabel.c_str(), kLabel.size(), &ctx_init));
+ ASSERT_NE(nullptr, ctx_init);
+ ScopedSSLMaskingContext ctx(ctx_init);
+
+ EXPECT_EQ(SECFailure,
+ SSL_CreateMask(ctx.get(), sample.data(), sample.size() - 1,
+ mask.data(), mask.size()));
+ EXPECT_EQ(SEC_ERROR_INVALID_ARGS, PORT_GetError());
+}
+
+TEST_P(VariantSuiteTest, MaskContextUnsupportedMech) {
+ std::vector<uint8_t> sample(kSampleSize);
+ std::vector<uint8_t> mask(kMaskSize);
+ SSLMaskingContext *ctx_init = nullptr;
+ EXPECT_EQ(SECFailure,
+ SSL_CreateVariantMaskingContext(
+ SSL_LIBRARY_VERSION_TLS_1_3, TLS_RSA_WITH_AES_128_CBC_SHA256,
+ variant_, secret_.get(), nullptr, 0, &ctx_init));
+ EXPECT_EQ(SEC_ERROR_INVALID_ARGS, PORT_GetError());
+ EXPECT_EQ(nullptr, ctx_init);
+}
+
+TEST_P(VariantSuiteTest, MaskContextUnsupportedVersion) {
+ std::vector<uint8_t> sample(kSampleSize);
+ std::vector<uint8_t> mask(kMaskSize);
+ SSLMaskingContext *ctx_init = nullptr;
+ EXPECT_EQ(SECFailure, SSL_CreateVariantMaskingContext(
+ SSL_LIBRARY_VERSION_TLS_1_2, ciphersuite_, variant_,
+ secret_.get(), nullptr, 0, &ctx_init));
+ EXPECT_EQ(SEC_ERROR_INVALID_ARGS, PORT_GetError());
+ EXPECT_EQ(nullptr, ctx_init);
+}
+
+TEST_P(VariantSuiteTest, MaskMaxLength) {
+ uint32_t max_mask_len = kMaskSize;
+ if (ciphersuite_ == TLS_CHACHA20_POLY1305_SHA256) {
+ // Internal limitation for ChaCha20 masks.
+ max_mask_len = 128;
+ }
+
+ std::vector<uint8_t> sample(kSampleSize);
+ std::vector<uint8_t> mask(max_mask_len + 1);
+ SSLMaskingContext *ctx_init = nullptr;
+ EXPECT_EQ(SECSuccess,
+ SSL_CreateVariantMaskingContext(
+ SSL_LIBRARY_VERSION_TLS_1_3, ciphersuite_, variant_,
+ secret_.get(), kLabel.c_str(), kLabel.size(), &ctx_init));
+ ASSERT_NE(nullptr, ctx_init);
+ ScopedSSLMaskingContext ctx(ctx_init);
+
+ EXPECT_EQ(SECSuccess, SSL_CreateMask(ctx.get(), sample.data(), sample.size(),
+ mask.data(), mask.size() - 1));
+ EXPECT_EQ(SECFailure, SSL_CreateMask(ctx.get(), sample.data(), sample.size(),
+ mask.data(), mask.size()));
+ EXPECT_EQ(SEC_ERROR_OUTPUT_LEN, PORT_GetError());
+}
+
+TEST_P(VariantSuiteTest, MaskMinLength) {
+ std::vector<uint8_t> sample(kSampleSize);
+ std::vector<uint8_t> mask(1); // Don't pass a null
+
+ SSLMaskingContext *ctx_init = nullptr;
+ EXPECT_EQ(SECSuccess,
+ SSL_CreateVariantMaskingContext(
+ SSL_LIBRARY_VERSION_TLS_1_3, ciphersuite_, variant_,
+ secret_.get(), kLabel.c_str(), kLabel.size(), &ctx_init));
+ ASSERT_NE(nullptr, ctx_init);
+ ScopedSSLMaskingContext ctx(ctx_init);
+ EXPECT_EQ(SECFailure, SSL_CreateMask(ctx.get(), sample.data(), sample.size(),
+ mask.data(), 0));
+ EXPECT_EQ(SEC_ERROR_INVALID_ARGS, PORT_GetError());
+ EXPECT_EQ(SECSuccess, SSL_CreateMask(ctx.get(), sample.data(), sample.size(),
+ mask.data(), 1));
+}
+
+TEST_P(VariantSuiteTest, MaskRotateLabel) {
+ std::vector<uint8_t> sample(kSampleSize);
+ std::vector<uint8_t> mask1(kMaskSize);
+ std::vector<uint8_t> mask2(kMaskSize);
+ EXPECT_EQ(SECSuccess, PK11_GenerateRandomOnSlot(slot_.get(), sample.data(),
+ sample.size()));
+
+ CreateMask(kLabel, sample, &mask1);
+ CreateMask(std::string("sn1"), sample, &mask2);
+ EXPECT_FALSE(mask1 == mask2);
+}
+
+TEST_P(VariantSuiteTest, MaskRotateSample) {
+ std::vector<uint8_t> sample(kSampleSize);
+ std::vector<uint8_t> mask1(kMaskSize);
+ std::vector<uint8_t> mask2(kMaskSize);
+
+ EXPECT_EQ(SECSuccess, PK11_GenerateRandomOnSlot(slot_.get(), sample.data(),
+ sample.size()));
+ CreateMask(kLabel, sample, &mask1);
+
+ EXPECT_EQ(SECSuccess, PK11_GenerateRandomOnSlot(slot_.get(), sample.data(),
+ sample.size()));
+ CreateMask(kLabel, sample, &mask2);
+ EXPECT_FALSE(mask1 == mask2);
+}
+
+TEST_P(VariantSuiteTest, MaskRederive) {
+ std::vector<uint8_t> sample(kSampleSize);
+ std::vector<uint8_t> mask1(kMaskSize);
+ std::vector<uint8_t> mask2(kMaskSize);
+
+ SECStatus rv =
+ PK11_GenerateRandomOnSlot(slot_.get(), sample.data(), sample.size());
+ EXPECT_EQ(SECSuccess, rv);
+
+ // Check that re-using inputs with a new context produces the same mask.
+ CreateMask(kLabel, sample, &mask1);
+ CreateMask(kLabel, sample, &mask2);
+ EXPECT_TRUE(mask1 == mask2);
+}
+
+TEST_P(SuiteTest, MaskTlsVariantKeySeparation) {
+ std::vector<uint8_t> sample(kSampleSize);
+ std::vector<uint8_t> tls_mask(kMaskSize);
+ std::vector<uint8_t> dtls_mask(kMaskSize);
+ SSLMaskingContext *stream_ctx_init = nullptr;
+ SSLMaskingContext *datagram_ctx_init = nullptr;
+
+ // Init
+ EXPECT_EQ(SECSuccess, SSL_CreateVariantMaskingContext(
+ SSL_LIBRARY_VERSION_TLS_1_3, ciphersuite_,
+ ssl_variant_stream, secret_.get(), kLabel.c_str(),
+ kLabel.size(), &stream_ctx_init));
+ ASSERT_NE(nullptr, stream_ctx_init);
+ EXPECT_EQ(SECSuccess, SSL_CreateVariantMaskingContext(
+ SSL_LIBRARY_VERSION_TLS_1_3, ciphersuite_,
+ ssl_variant_datagram, secret_.get(), kLabel.c_str(),
+ kLabel.size(), &datagram_ctx_init));
+ ASSERT_NE(nullptr, datagram_ctx_init);
+ ScopedSSLMaskingContext tls_ctx(stream_ctx_init);
+ ScopedSSLMaskingContext dtls_ctx(datagram_ctx_init);
+
+ // Derive
+ EXPECT_EQ(SECSuccess,
+ SSL_CreateMask(tls_ctx.get(), sample.data(), sample.size(),
+ tls_mask.data(), tls_mask.size()));
+
+ EXPECT_EQ(SECSuccess,
+ SSL_CreateMask(dtls_ctx.get(), sample.data(), sample.size(),
+ dtls_mask.data(), dtls_mask.size()));
+ EXPECT_NE(tls_mask, dtls_mask);
+}
+
+TEST_P(VariantTest, MaskChaChaRederiveOddSizes) {
+ // Non-block-aligned.
+ std::vector<uint8_t> sample(27);
+ std::vector<uint8_t> mask1(26);
+ std::vector<uint8_t> mask2(25);
+ EXPECT_EQ(SECSuccess, PK11_GenerateRandomOnSlot(slot_.get(), sample.data(),
+ sample.size()));
+ CreateMask(TLS_CHACHA20_POLY1305_SHA256, kLabel, sample, &mask1);
+ CreateMask(TLS_CHACHA20_POLY1305_SHA256, kLabel, sample, &mask2);
+ mask1.pop_back();
+ EXPECT_TRUE(mask1 == mask2);
+}
+
+static const uint16_t kMaskingCiphersuites[] = {TLS_CHACHA20_POLY1305_SHA256,
+ TLS_AES_128_GCM_SHA256,
+ TLS_AES_256_GCM_SHA384};
+::testing::internal::ParamGenerator<uint16_t> kMaskingCiphersuiteParams =
+ ::testing::ValuesIn(kMaskingCiphersuites);
+
+INSTANTIATE_TEST_SUITE_P(GenericMasking, SuiteTest, kMaskingCiphersuiteParams);
+
+INSTANTIATE_TEST_SUITE_P(GenericMasking, VariantTest,
+ TlsConnectTestBase::kTlsVariantsAll);
+
+INSTANTIATE_TEST_SUITE_P(GenericMasking, VariantSuiteTest,
+ ::testing::Combine(TlsConnectTestBase::kTlsVariantsAll,
+ kMaskingCiphersuiteParams));
+
+} // namespace nss_test
diff --git a/security/nss/gtests/ssl_gtest/ssl_misc_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_misc_unittest.cc
new file mode 100644
index 0000000000..2b1b92dcd8
--- /dev/null
+++ b/security/nss/gtests/ssl_gtest/ssl_misc_unittest.cc
@@ -0,0 +1,20 @@
+/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
+/* vim: set ts=2 et sw=2 tw=80: */
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this file,
+ * You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+#include "sslexp.h"
+
+#include "gtest_utils.h"
+
+namespace nss_test {
+
+class MiscTest : public ::testing::Test {};
+
+TEST_F(MiscTest, NonExistentExperimentalAPI) {
+ EXPECT_EQ(nullptr, SSL_GetExperimentalAPI("blah"));
+ EXPECT_EQ(SSL_ERROR_UNSUPPORTED_EXPERIMENTAL_API, PORT_GetError());
+}
+
+} // namespace nss_test
diff --git a/security/nss/gtests/ssl_gtest/ssl_record_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_record_unittest.cc
new file mode 100644
index 0000000000..5378d67af8
--- /dev/null
+++ b/security/nss/gtests/ssl_gtest/ssl_record_unittest.cc
@@ -0,0 +1,826 @@
+/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
+/* vim: set ts=2 et sw=2 tw=80: */
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this file,
+ * You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+#include "nss.h"
+#include "ssl.h"
+#include "sslimpl.h"
+
+#include "databuffer.h"
+#include "gtest_utils.h"
+#include "tls_connect.h"
+#include "tls_filter.h"
+
+namespace nss_test {
+
+const static size_t kMacSize = 20;
+
+class TlsPaddingTest
+ : public ::testing::Test,
+ public ::testing::WithParamInterface<std::tuple<size_t, bool>> {
+ public:
+ TlsPaddingTest() : plaintext_len_(std::get<0>(GetParam())) {
+ size_t extra =
+ (plaintext_len_ + 1) % 16; // Bytes past a block (1 == pad len)
+ // Minimal padding.
+ pad_len_ = extra ? 16 - extra : 0;
+ if (std::get<1>(GetParam())) {
+ // Maximal padding.
+ pad_len_ += 240;
+ }
+ MakePaddedPlaintext();
+ }
+
+ // Makes a plaintext record with correct padding.
+ void MakePaddedPlaintext() {
+ EXPECT_EQ(0UL, (plaintext_len_ + pad_len_ + 1) % 16);
+ size_t i = 0;
+ plaintext_.Allocate(plaintext_len_ + pad_len_ + 1);
+ for (; i < plaintext_len_; ++i) {
+ plaintext_.Write(i, 'A', 1);
+ }
+
+ for (; i < plaintext_len_ + pad_len_ + 1; ++i) {
+ plaintext_.Write(i, pad_len_, 1);
+ }
+ }
+
+ void Unpad(bool expect_success) {
+ std::cerr << "Content length=" << plaintext_len_
+ << " padding length=" << pad_len_
+ << " total length=" << plaintext_.len() << std::endl;
+ std::cerr << "Plaintext: " << plaintext_ << std::endl;
+ sslBuffer s;
+ s.buf = const_cast<unsigned char*>(
+ static_cast<const unsigned char*>(plaintext_.data()));
+ s.len = plaintext_.len();
+ SECStatus rv = ssl_RemoveTLSCBCPadding(&s, kMacSize);
+ if (expect_success) {
+ EXPECT_EQ(SECSuccess, rv);
+ EXPECT_EQ(plaintext_len_, static_cast<size_t>(s.len));
+ } else {
+ EXPECT_EQ(SECFailure, rv);
+ }
+ }
+
+ protected:
+ size_t plaintext_len_;
+ size_t pad_len_;
+ DataBuffer plaintext_;
+};
+
+TEST_P(TlsPaddingTest, Correct) {
+ if (plaintext_len_ >= kMacSize) {
+ Unpad(true);
+ } else {
+ Unpad(false);
+ }
+}
+
+TEST_P(TlsPaddingTest, PadTooLong) {
+ if (plaintext_.len() < 255) {
+ plaintext_.Write(plaintext_.len() - 1, plaintext_.len(), 1);
+ Unpad(false);
+ }
+}
+
+TEST_P(TlsPaddingTest, FirstByteOfPadWrong) {
+ if (pad_len_) {
+ plaintext_.Write(plaintext_len_, plaintext_.data()[plaintext_len_] + 1, 1);
+ Unpad(false);
+ }
+}
+
+TEST_P(TlsPaddingTest, LastByteOfPadWrong) {
+ if (pad_len_) {
+ plaintext_.Write(plaintext_.len() - 2,
+ plaintext_.data()[plaintext_.len() - 1] + 1, 1);
+ Unpad(false);
+ }
+}
+
+class RecordReplacer : public TlsRecordFilter {
+ public:
+ RecordReplacer(const std::shared_ptr<TlsAgent>& a, size_t size)
+ : TlsRecordFilter(a), size_(size) {
+ Disable();
+ }
+
+ PacketFilter::Action FilterRecord(const TlsRecordHeader& header,
+ const DataBuffer& data,
+ DataBuffer* changed) override {
+ EXPECT_EQ(ssl_ct_application_data, header.content_type());
+ changed->Allocate(size_);
+
+ for (size_t i = 0; i < size_; ++i) {
+ changed->data()[i] = i & 0xff;
+ }
+
+ Disable();
+ return CHANGE;
+ }
+
+ private:
+ size_t size_;
+};
+
+TEST_P(TlsConnectStream, BadRecordMac) {
+ EnsureTlsSetup();
+ Connect();
+ client_->SetFilter(std::make_shared<TlsRecordLastByteDamager>(client_));
+ ExpectAlert(server_, kTlsAlertBadRecordMac);
+ client_->SendData(10);
+
+ // Read from the client, get error.
+ uint8_t buf[10];
+ PRInt32 rv = PR_Read(server_->ssl_fd(), buf, sizeof(buf));
+ EXPECT_GT(0, rv);
+ EXPECT_EQ(SSL_ERROR_BAD_MAC_READ, PORT_GetError());
+
+ // Read the server alert.
+ rv = PR_Read(client_->ssl_fd(), buf, sizeof(buf));
+ EXPECT_GT(0, rv);
+ EXPECT_EQ(SSL_ERROR_BAD_MAC_ALERT, PORT_GetError());
+}
+
+TEST_F(TlsConnectStreamTls13, LargeRecord) {
+ EnsureTlsSetup();
+
+ const size_t record_limit = 16384;
+ auto replacer = MakeTlsFilter<RecordReplacer>(client_, record_limit);
+ replacer->EnableDecryption();
+ Connect();
+
+ replacer->Enable();
+ client_->SendData(10);
+ WAIT_(server_->received_bytes() == record_limit, 2000);
+ ASSERT_EQ(record_limit, server_->received_bytes());
+}
+
+TEST_F(TlsConnectStreamTls13, TooLargeRecord) {
+ EnsureTlsSetup();
+
+ const size_t record_limit = 16384;
+ auto replacer = MakeTlsFilter<RecordReplacer>(client_, record_limit + 1);
+ replacer->EnableDecryption();
+ Connect();
+
+ replacer->Enable();
+ ExpectAlert(server_, kTlsAlertRecordOverflow);
+ client_->SendData(10); // This is expanded.
+
+ uint8_t buf[record_limit + 2];
+ PRInt32 rv = PR_Read(server_->ssl_fd(), buf, sizeof(buf));
+ EXPECT_GT(0, rv);
+ EXPECT_EQ(SSL_ERROR_RX_RECORD_TOO_LONG, PORT_GetError());
+
+ // Read the server alert.
+ rv = PR_Read(client_->ssl_fd(), buf, sizeof(buf));
+ EXPECT_GT(0, rv);
+ EXPECT_EQ(SSL_ERROR_RECORD_OVERFLOW_ALERT, PORT_GetError());
+}
+
+class ShortHeaderChecker : public PacketFilter {
+ public:
+ PacketFilter::Action Filter(const DataBuffer& input, DataBuffer* output) {
+ // The first octet should be 0b001000xx.
+ EXPECT_EQ(kCtDtlsCiphertext, (input.data()[0] & ~0x3));
+ return KEEP;
+ }
+};
+
+TEST_F(TlsConnectDatagram13, AeadLimit) {
+ Connect();
+ EXPECT_EQ(SECSuccess, SSLInt_AdvanceDtls13DecryptFailures(server_->ssl_fd(),
+ (1ULL << 36) - 2));
+ SendReceive(50);
+
+ // Expect this to increment the counter. We should still be able to talk.
+ client_->SetFilter(std::make_shared<TlsRecordLastByteDamager>(client_));
+ client_->SendData(10);
+ server_->ReadBytes(10);
+ client_->ClearFilter();
+ client_->ResetSentBytes(50);
+ SendReceive(60);
+
+ // Expect alert when the limit is hit.
+ client_->SetFilter(std::make_shared<TlsRecordLastByteDamager>(client_));
+ client_->SendData(10);
+ ExpectAlert(server_, kTlsAlertBadRecordMac);
+
+ // Check the error on both endpoints.
+ uint8_t buf[10];
+ PRInt32 rv = PR_Read(server_->ssl_fd(), buf, sizeof(buf));
+ EXPECT_EQ(-1, rv);
+ EXPECT_EQ(SSL_ERROR_BAD_MAC_READ, PORT_GetError());
+
+ rv = PR_Read(client_->ssl_fd(), buf, sizeof(buf));
+ EXPECT_EQ(-1, rv);
+ EXPECT_EQ(SSL_ERROR_BAD_MAC_ALERT, PORT_GetError());
+}
+
+TEST_F(TlsConnectDatagram13, ShortHeadersClient) {
+ Connect();
+ client_->SetOption(SSL_ENABLE_DTLS_SHORT_HEADER, PR_TRUE);
+ client_->SetFilter(std::make_shared<ShortHeaderChecker>());
+ SendReceive();
+}
+
+TEST_F(TlsConnectDatagram13, ShortHeadersServer) {
+ Connect();
+ server_->SetOption(SSL_ENABLE_DTLS_SHORT_HEADER, PR_TRUE);
+ server_->SetFilter(std::make_shared<ShortHeaderChecker>());
+ SendReceive();
+}
+
+// Send a DTLSCiphertext header with a 2B sequence number, and no length.
+TEST_F(TlsConnectDatagram13, DtlsAlternateShortHeader) {
+ StartConnect();
+ TlsSendCipherSpecCapturer capturer(client_);
+ Connect();
+ SendReceive(50);
+
+ uint8_t buf[] = {0x32, 0x33, 0x34};
+ auto spec = capturer.spec(1);
+ ASSERT_NE(nullptr, spec.get());
+ ASSERT_EQ(3, spec->epoch());
+
+ uint8_t dtls13_ct = kCtDtlsCiphertext | kCtDtlsCiphertext16bSeqno;
+ TlsRecordHeader header(variant_, SSL_LIBRARY_VERSION_TLS_1_3, dtls13_ct,
+ 0x0003000000000001);
+ TlsRecordHeader out_header(header);
+ DataBuffer msg(buf, sizeof(buf));
+ msg.Write(msg.len(), ssl_ct_application_data, 1);
+ DataBuffer ciphertext;
+ EXPECT_TRUE(spec->Protect(header, msg, &ciphertext, &out_header));
+
+ DataBuffer record;
+ auto rv = out_header.Write(&record, 0, ciphertext);
+ EXPECT_EQ(out_header.header_length() + ciphertext.len(), rv);
+ client_->SendDirect(record);
+
+ server_->ReadBytes(3);
+}
+
+TEST_F(TlsConnectStreamTls13, UnencryptedFinishedMessage) {
+ StartConnect();
+ client_->Handshake(); // Send ClientHello
+ server_->Handshake(); // Send first server flight
+
+ // Record and drop the first record, which is the Finished.
+ auto recorder = std::make_shared<TlsRecordRecorder>(client_);
+ recorder->EnableDecryption();
+ auto dropper = std::make_shared<SelectiveDropFilter>(1);
+ client_->SetFilter(std::make_shared<ChainedPacketFilter>(
+ ChainedPacketFilterInit({recorder, dropper})));
+ client_->Handshake(); // Save and drop CFIN.
+ EXPECT_EQ(TlsAgent::STATE_CONNECTED, client_->state());
+
+ ASSERT_EQ(1U, recorder->count());
+ auto& finished = recorder->record(0);
+
+ DataBuffer d;
+ size_t offset = d.Write(0, ssl_ct_handshake, 1);
+ offset = d.Write(offset, SSL_LIBRARY_VERSION_TLS_1_2, 2);
+ offset = d.Write(offset, finished.buffer.len(), 2);
+ d.Append(finished.buffer);
+ client_->SendDirect(d);
+
+ // Now process the message.
+ ExpectAlert(server_, kTlsAlertUnexpectedMessage);
+ // The server should generate an alert.
+ server_->Handshake();
+ EXPECT_EQ(TlsAgent::STATE_ERROR, server_->state());
+ server_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_RECORD_TYPE);
+ // Have the client consume the alert.
+ client_->Handshake();
+ EXPECT_EQ(TlsAgent::STATE_ERROR, client_->state());
+ client_->CheckErrorCode(SSL_ERROR_HANDSHAKE_UNEXPECTED_ALERT);
+}
+
+const static size_t kContentSizesArr[] = {
+ 1, kMacSize - 1, kMacSize, 30, 31, 32, 36, 256, 257, 287, 288};
+
+auto kContentSizes = ::testing::ValuesIn(kContentSizesArr);
+const static bool kTrueFalseArr[] = {true, false};
+auto kTrueFalse = ::testing::ValuesIn(kTrueFalseArr);
+
+INSTANTIATE_TEST_SUITE_P(TlsPadding, TlsPaddingTest,
+ ::testing::Combine(kContentSizes, kTrueFalse));
+
+/* Filter to modify record header and content */
+class Tls13RecordModifier : public TlsRecordFilter {
+ public:
+ Tls13RecordModifier(const std::shared_ptr<TlsAgent>& a,
+ uint8_t contentType = ssl_ct_handshake, size_t size = 0,
+ size_t padding = 0)
+ : TlsRecordFilter(a),
+ contentType_(contentType),
+ size_(size),
+ padding_(padding) {}
+
+ protected:
+ PacketFilter::Action FilterRecord(const TlsRecordHeader& header,
+ const DataBuffer& record, size_t* offset,
+ DataBuffer* output) override {
+ if (!header.is_protected()) {
+ return KEEP;
+ }
+
+ uint16_t protection_epoch;
+ uint8_t inner_content_type;
+ DataBuffer plaintext;
+ TlsRecordHeader out_header;
+ if (!Unprotect(header, record, &protection_epoch, &inner_content_type,
+ &plaintext, &out_header)) {
+ return KEEP;
+ }
+
+ if (decrypting() && inner_content_type != ssl_ct_application_data) {
+ return KEEP;
+ }
+
+ DataBuffer ciphertext;
+ bool ok = Protect(spec(protection_epoch), out_header, contentType_,
+ DataBuffer(size_), &ciphertext, &out_header, padding_);
+ EXPECT_TRUE(ok);
+ if (!ok) {
+ return KEEP;
+ }
+
+ *offset = out_header.Write(output, *offset, ciphertext);
+ return CHANGE;
+ }
+
+ private:
+ uint8_t contentType_;
+ size_t size_;
+ size_t padding_;
+};
+
+/* Zero-length InnerPlaintext test class
+ *
+ * Parameter = Tuple of:
+ * - TLS variant (datagram/stream)
+ * - Content type to be set in zero-length inner plaintext record
+ * - Padding of record plaintext
+ */
+class ZeroLengthInnerPlaintextSetupTls13
+ : public TlsConnectTestBase,
+ public testing::WithParamInterface<
+ std::tuple<SSLProtocolVariant, SSLContentType, size_t>> {
+ public:
+ ZeroLengthInnerPlaintextSetupTls13()
+ : TlsConnectTestBase(std::get<0>(GetParam()),
+ SSL_LIBRARY_VERSION_TLS_1_3),
+ contentType_(std::get<1>(GetParam())),
+ padding_(std::get<2>(GetParam())){};
+
+ protected:
+ SSLContentType contentType_;
+ size_t padding_;
+};
+
+/* Test correct rejection of TLS 1.3 encrypted handshake/alert records with
+ * zero-length inner plaintext content length with and without padding.
+ *
+ * Implementations MUST NOT send Handshake and Alert records that have a
+ * zero-length TLSInnerPlaintext.content; if such a message is received,
+ * the receiving implementation MUST terminate the connection with an
+ * "unexpected_message" alert [RFC8446, Section 5.4]. */
+TEST_P(ZeroLengthInnerPlaintextSetupTls13, ZeroLengthInnerPlaintextRun) {
+ EnsureTlsSetup();
+
+ // Filter modifies record to be zero-length
+ auto filter =
+ MakeTlsFilter<Tls13RecordModifier>(client_, contentType_, 0, padding_);
+ filter->EnableDecryption();
+ filter->Disable();
+
+ Connect();
+
+ filter->Enable();
+
+ // Record will be overwritten
+ client_->SendData(0xf);
+
+ // Receive corrupt record
+ if (variant_ == ssl_variant_stream) {
+ server_->ExpectSendAlert(kTlsAlertUnexpectedMessage);
+ // 22B = 16B MAC + 1B innerContentType + 5B Header
+ server_->ReadBytes(22);
+ // Process alert at peer
+ client_->ExpectReceiveAlert(kTlsAlertUnexpectedMessage);
+ client_->Handshake();
+ } else { /* DTLS */
+ size_t received = server_->received_bytes();
+ // 22B = 16B MAC + 1B innerContentType + 5B Header
+ server_->ReadBytes(22);
+ // Check that no bytes were received => packet was dropped
+ ASSERT_EQ(received, server_->received_bytes());
+ // Check that we are still connected / not in error state
+ EXPECT_EQ(TlsAgent::STATE_CONNECTED, client_->state());
+ EXPECT_EQ(TlsAgent::STATE_CONNECTED, server_->state());
+ }
+}
+
+// Test for TLS and DTLS
+const SSLProtocolVariant kZeroLengthInnerPlaintextVariants[] = {
+ ssl_variant_stream, ssl_variant_datagram};
+// Test for handshake and alert fragments
+const SSLContentType kZeroLengthInnerPlaintextContentTypes[] = {
+ ssl_ct_handshake, ssl_ct_alert};
+// Test with 0,1 and 100 octets of padding
+const size_t kZeroLengthInnerPlaintextPadding[] = {0, 1, 100};
+
+INSTANTIATE_TEST_SUITE_P(
+ ZeroLengthInnerPlaintextTest, ZeroLengthInnerPlaintextSetupTls13,
+ testing::Combine(testing::ValuesIn(kZeroLengthInnerPlaintextVariants),
+ testing::ValuesIn(kZeroLengthInnerPlaintextContentTypes),
+ testing::ValuesIn(kZeroLengthInnerPlaintextPadding)),
+ [](const testing::TestParamInfo<
+ ZeroLengthInnerPlaintextSetupTls13::ParamType>& inf) {
+ return std::string(std::get<0>(inf.param) == ssl_variant_stream
+ ? "Tls"
+ : "Dtls") +
+ "ZeroLengthInnerPlaintext" +
+ (std::get<1>(inf.param) == ssl_ct_handshake ? "Handshake"
+ : "Alert") +
+ (std::get<2>(inf.param)
+ ? "Padding" + std::to_string(std::get<2>(inf.param)) + "B"
+ : "") +
+ "Test";
+ });
+
+/* Zero-length record test class
+ *
+ * Parameter = Tuple of:
+ * - TLS variant (datagram/stream)
+ * - TLS version
+ * - Content type to be set in zero-length record
+ */
+class ZeroLengthRecordSetup
+ : public TlsConnectTestBase,
+ public testing::WithParamInterface<
+ std::tuple<SSLProtocolVariant, uint16_t, SSLContentType>> {
+ public:
+ ZeroLengthRecordSetup()
+ : TlsConnectTestBase(std::get<0>(GetParam()), std::get<1>(GetParam())),
+ variant_(std::get<0>(GetParam())),
+ contentType_(std::get<2>(GetParam())){};
+
+ void createZeroLengthRecord(DataBuffer& buffer, unsigned epoch = 0,
+ unsigned seqn = 0) {
+ size_t idx = 0;
+ // Set header content type
+ idx = buffer.Write(idx, contentType_, 1);
+ // The record version is not checked during record layer handling
+ idx = buffer.Write(idx, 0xDEAD, 2);
+ // DTLS (version always < TLS 1.3)
+ if (variant_ == ssl_variant_datagram) {
+ // Set epoch (Should be 0 before handshake)
+ idx = buffer.Write(idx, 0U, 2);
+ // Set 6B sequence number (0 if send as first message)
+ idx = buffer.Write(idx, 0U, 2);
+ idx = buffer.Write(idx, 0U, 4);
+ }
+ // Set fragment to be of zero-length
+ (void)buffer.Write(idx, 0U, 2);
+ }
+
+ protected:
+ SSLProtocolVariant variant_;
+ SSLContentType contentType_;
+};
+
+/* Test handling of zero-length (ciphertext/fragment) records before handshake.
+ *
+ * This is only tested before the first handshake, since after it all of these
+ * messages are expected to be encrypted which is impossible for a content
+ * length of zero, always leading to a bad record mac. For TLS 1.3 only
+ * records of application data content type is legal after the handshake.
+ *
+ * Handshake records of length zero will be ignored in the record layer since
+ * the RFC does only specify that such records MUST NOT be sent but it does not
+ * state that an alert should be sent or the connection be terminated
+ * [RFC8446, Section 5.1].
+ *
+ * Even though only handshake messages are handled (ignored) in the record
+ * layer handling, this test covers zero-length records of all content types
+ * for complete coverage of cases.
+ *
+ * !!! Expected TLS (Stream) behavior !!!
+ * - Handshake records of zero length are ignored.
+ * - Alert and ChangeCipherSpec records of zero-length lead to illegal
+ * parameter alerts due to the malformed record content.
+ * - ApplicationData before the handshake leads to an unexpected message alert.
+ *
+ * !!! Expected DTLS (Datagram) behavior !!!
+ * - Handshake message of zero length are ignored.
+ * - Alert messages lead to an illegal parameter alert due to malformed record
+ * content.
+ * - ChangeCipherSpec records before the first handshake are not expected and
+ * ignored (see ssl3con.c, line 3276).
+ * - ApplicationData before the handshake is ignored since it could be a packet
+ * received in incorrect order (see ssl3con.c, line 13353).
+ */
+TEST_P(ZeroLengthRecordSetup, ZeroLengthRecordRun) {
+ EnsureTlsSetup();
+
+ // Send zero-length record
+ DataBuffer buffer;
+ createZeroLengthRecord(buffer);
+ client_->SendDirect(buffer);
+ // This must be set, otherwise handshake completness assertions might fail
+ server_->StartConnect();
+
+ SSLAlertDescription alert = close_notify;
+
+ switch (variant_) {
+ case ssl_variant_datagram:
+ switch (contentType_) {
+ case ssl_ct_alert:
+ // Should actually be ignored, see bug 1829391.
+ alert = illegal_parameter;
+ break;
+ case ssl_ct_ack:
+ if (version_ == SSL_LIBRARY_VERSION_TLS_1_3) {
+ // Skipped due to bug 1829391.
+ GTEST_SKIP();
+ }
+ // DTLS versions < 1.3 correctly ignore the invalid record
+ // so we fall through.
+ case ssl_ct_change_cipher_spec:
+ case ssl_ct_application_data:
+ case ssl_ct_handshake:
+ server_->Handshake();
+ Connect();
+ return;
+ }
+ break;
+ case ssl_variant_stream:
+ switch (contentType_) {
+ case ssl_ct_alert:
+ case ssl_ct_change_cipher_spec:
+ alert = illegal_parameter;
+ break;
+ case ssl_ct_application_data:
+ case ssl_ct_ack:
+ alert = unexpected_message;
+ break;
+ case ssl_ct_handshake:
+ // TLS ignores unprotected zero-length handshake records
+ server_->Handshake();
+ Connect();
+ return;
+ }
+ break;
+ }
+
+ // Assert alert is send for TLS and DTLS alert records
+ server_->ExpectSendAlert(alert);
+ server_->Handshake();
+
+ // Consume alert at peer, expect alert for TLS and DTLS alert records
+ client_->StartConnect();
+ client_->ExpectReceiveAlert(alert);
+ client_->Handshake();
+}
+
+// Test for handshake, alert, change_cipher_spec and application data fragments
+const SSLContentType kZeroLengthRecordContentTypes[] = {
+ ssl_ct_handshake, ssl_ct_alert, ssl_ct_change_cipher_spec,
+ ssl_ct_application_data, ssl_ct_ack};
+
+INSTANTIATE_TEST_SUITE_P(
+ ZeroLengthRecordTest, ZeroLengthRecordSetup,
+ testing::Combine(TlsConnectTestBase::kTlsVariantsAll,
+ TlsConnectTestBase::kTlsV11Plus,
+ testing::ValuesIn(kZeroLengthRecordContentTypes)),
+ [](const testing::TestParamInfo<ZeroLengthRecordSetup::ParamType>& inf) {
+ std::string variant =
+ (std::get<0>(inf.param) == ssl_variant_stream) ? "Tls" : "Dtls";
+ std::string version = VersionString(std::get<1>(inf.param));
+ std::replace(version.begin(), version.end(), '.', '_');
+ std::string contentType;
+ switch (std::get<2>(inf.param)) {
+ case ssl_ct_handshake:
+ contentType = "Handshake";
+ break;
+ case ssl_ct_alert:
+ contentType = "Alert";
+ break;
+ case ssl_ct_application_data:
+ contentType = "ApplicationData";
+ break;
+ case ssl_ct_change_cipher_spec:
+ contentType = "ChangeCipherSpec";
+ break;
+ case ssl_ct_ack:
+ contentType = "Ack";
+ break;
+ }
+ return variant + version + "ZeroLength" + contentType + "Test";
+ });
+
+/* Test correct handling of records with invalid content types.
+ *
+ * TLS:
+ * If a TLS implementation receives an unexpected record type, it MUST
+ * terminate the connection with an "unexpected_message" alert
+ * [RFC8446, Section 5].
+ *
+ * DTLS:
+ * In general, invalid records SHOULD be silently discarded...
+ * [RFC6347, Section 4.1.2.7]. */
+class UndefinedContentTypeSetup : public TlsConnectGeneric {
+ public:
+ UndefinedContentTypeSetup() : TlsConnectGeneric() { StartConnect(); };
+
+ void createUndefinedContentTypeRecord(DataBuffer& buffer, unsigned epoch = 0,
+ unsigned seqn = 0) {
+ // dummy data
+ uint8_t data[] = {0xAA, 0xBB, 0xCC, 0xDD, 0xEE};
+
+ size_t idx = 0;
+ // Set undefined content type
+ idx = buffer.Write(idx, 0xFF, 1);
+ // The record version is not checked during record layer handling
+ idx = buffer.Write(idx, 0xDEAD, 2);
+ // DTLS (version always < TLS 1.3)
+ if (variant_ == ssl_variant_datagram) {
+ // Set epoch (Should be 0 before/during handshake)
+ idx = buffer.Write(idx, epoch, 2);
+ // Set 6B sequence number (0 if send as first message)
+ idx = buffer.Write(idx, 0U, 2);
+ idx = buffer.Write(idx, seqn, 4);
+ }
+ // Set fragment length
+ idx = buffer.Write(idx, 5U, 2);
+ // Add data to record
+ (void)buffer.Write(idx, data, 5);
+ }
+
+ void checkUndefinedContentTypeHandling(std::shared_ptr<TlsAgent> sender,
+ std::shared_ptr<TlsAgent> receiver) {
+ if (variant_ == ssl_variant_stream) {
+ // Handle record and expect alert to be sent
+ receiver->ExpectSendAlert(kTlsAlertUnexpectedMessage);
+ receiver->ReadBytes();
+ /* Digest and assert that the correct alert was received at peer
+ *
+ * The 1.3 server expects all messages other than the ClientHello to be
+ * encrypted and responds with an unexpected message alert to alerts. */
+ if (version_ >= SSL_LIBRARY_VERSION_TLS_1_3 && sender == server_) {
+ sender->ExpectSendAlert(kTlsAlertUnexpectedMessage);
+ } else {
+ sender->ExpectReceiveAlert(kTlsAlertUnexpectedMessage);
+ }
+ sender->ReadBytes();
+ } else { // DTLS drops invalid records silently
+ size_t received = receiver->received_bytes();
+ receiver->ReadBytes();
+ // Ensure no bytes were received/record was dropped
+ ASSERT_EQ(received, receiver->received_bytes());
+ }
+ }
+
+ protected:
+ DataBuffer buffer_;
+};
+
+INSTANTIATE_TEST_SUITE_P(
+ UndefinedContentTypePreHandshakeStream, UndefinedContentTypeSetup,
+ ::testing::Combine(TlsConnectTestBase::kTlsVariantsStream,
+ TlsConnectTestBase::kTlsVAll));
+
+INSTANTIATE_TEST_SUITE_P(
+ UndefinedContentTypePreHandshakeDatagram, UndefinedContentTypeSetup,
+ ::testing::Combine(TlsConnectTestBase::kTlsVariantsDatagram,
+ TlsConnectTestBase::kTlsV11Plus));
+
+TEST_P(UndefinedContentTypeSetup,
+ ServerReceiveUndefinedContentTypePreClientHello) {
+ createUndefinedContentTypeRecord(buffer_);
+
+ // Send undefined content type record
+ client_->SendDirect(buffer_);
+
+ checkUndefinedContentTypeHandling(client_, server_);
+}
+
+TEST_P(UndefinedContentTypeSetup,
+ ServerReceiveUndefinedContentTypePostClientHello) {
+ // Set epoch to 0 (handshake), and sequence number to 1 since hello is sent
+ createUndefinedContentTypeRecord(buffer_, 0, 1);
+
+ // Send ClientHello
+ client_->Handshake();
+ // Send undefined content type record
+ client_->SendDirect(buffer_);
+
+ checkUndefinedContentTypeHandling(client_, server_);
+}
+
+TEST_P(UndefinedContentTypeSetup,
+ ClientReceiveUndefinedContentTypePreClientHello) {
+ createUndefinedContentTypeRecord(buffer_);
+
+ // Send undefined content type record
+ server_->SendDirect(buffer_);
+
+ checkUndefinedContentTypeHandling(server_, client_);
+}
+
+TEST_P(UndefinedContentTypeSetup,
+ ClientReceiveUndefinedContentTypePostClientHello) {
+ // Set epoch to 0 (handshake), and sequence number to 1 since hello is sent
+ createUndefinedContentTypeRecord(buffer_, 0, 1);
+
+ // Send ClientHello
+ client_->Handshake();
+ // Send undefined content type record
+ server_->SendDirect(buffer_);
+
+ checkUndefinedContentTypeHandling(server_, client_);
+}
+
+class RecordOuterContentTypeSetter : public TlsRecordFilter {
+ public:
+ RecordOuterContentTypeSetter(const std::shared_ptr<TlsAgent>& a,
+ uint8_t contentType = ssl_ct_handshake)
+ : TlsRecordFilter(a), contentType_(contentType) {}
+
+ protected:
+ PacketFilter::Action FilterRecord(const TlsRecordHeader& header,
+ const DataBuffer& record, size_t* offset,
+ DataBuffer* output) override {
+ TlsRecordHeader hdr(header.variant(), header.version(), contentType_,
+ header.sequence_number());
+
+ *offset = hdr.Write(output, *offset, record);
+ return CHANGE;
+ }
+
+ private:
+ uint8_t contentType_;
+};
+
+/* Test correct handling of invalid inner and outer record content type.
+ * This is only possible for TLS 1.3, since only for this version decryption
+ * and encryption of manipulated records is supported by the test suite. */
+TEST_P(TlsConnectTls13, UndefinedOuterContentType13) {
+ EnsureTlsSetup();
+ Connect();
+
+ // Manipulate record: set invalid content type 0xff
+ MakeTlsFilter<RecordOuterContentTypeSetter>(client_, 0xff);
+ client_->SendData(50);
+
+ if (variant_ == ssl_variant_stream) {
+ // Handle invalid record
+ server_->ExpectSendAlert(kTlsAlertUnexpectedMessage);
+ server_->ReadBytes();
+ // Handle alert at peer
+ client_->ExpectReceiveAlert(kTlsAlertUnexpectedMessage);
+ client_->ReadBytes();
+ } else {
+ // Make sure DTLS drops invalid record silently
+ size_t received = server_->received_bytes();
+ server_->ReadBytes();
+ ASSERT_EQ(received, server_->received_bytes());
+ }
+}
+
+TEST_P(TlsConnectTls13, UndefinedInnerContentType13) {
+ EnsureTlsSetup();
+
+ // Manipulate record: set invalid content type 0xff and length to 50.
+ auto filter = MakeTlsFilter<Tls13RecordModifier>(client_, 0xff, 50, 0);
+ filter->EnableDecryption();
+ filter->Disable();
+
+ Connect();
+
+ filter->Enable();
+ // Send manipulate record with invalid content type
+ client_->SendData(50);
+
+ if (variant_ == ssl_variant_stream) {
+ // Handle invalid record
+ server_->ExpectSendAlert(kTlsAlertUnexpectedMessage);
+ server_->ReadBytes();
+ // Handle alert at peer
+ client_->ExpectReceiveAlert(kTlsAlertUnexpectedMessage);
+ client_->ReadBytes();
+ } else {
+ // Make sure DTLS drops invalid record silently
+ size_t received = server_->received_bytes();
+ server_->ReadBytes();
+ ASSERT_EQ(received, server_->received_bytes());
+ }
+}
+
+} // namespace nss_test \ No newline at end of file
diff --git a/security/nss/gtests/ssl_gtest/ssl_recordsep_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_recordsep_unittest.cc
new file mode 100644
index 0000000000..8051b58d01
--- /dev/null
+++ b/security/nss/gtests/ssl_gtest/ssl_recordsep_unittest.cc
@@ -0,0 +1,679 @@
+/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
+/* vim: set ts=2 et sw=2 tw=80: */
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this file,
+ * You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+#include "secerr.h"
+#include "ssl.h"
+#include "sslerr.h"
+#include "sslproto.h"
+
+extern "C" {
+// This is not something that should make you happy.
+#include "libssl_internals.h"
+}
+
+#include <queue>
+#include "gtest_utils.h"
+#include "nss_scoped_ptrs.h"
+#include "tls_connect.h"
+#include "tls_filter.h"
+#include "tls_parser.h"
+
+namespace nss_test {
+
+class HandshakeSecretTracker {
+ public:
+ HandshakeSecretTracker(const std::shared_ptr<TlsAgent>& agent,
+ uint16_t first_read_epoch, uint16_t first_write_epoch)
+ : agent_(agent),
+ next_read_epoch_(first_read_epoch),
+ next_write_epoch_(first_write_epoch) {
+ EXPECT_EQ(SECSuccess,
+ SSL_SecretCallback(agent_->ssl_fd(),
+ HandshakeSecretTracker::SecretCb, this));
+ }
+
+ void CheckComplete() const {
+ EXPECT_EQ(0, next_read_epoch_);
+ EXPECT_EQ(0, next_write_epoch_);
+ }
+
+ private:
+ static void SecretCb(PRFileDesc* fd, PRUint16 epoch, SSLSecretDirection dir,
+ PK11SymKey* secret, void* arg) {
+ HandshakeSecretTracker* t = reinterpret_cast<HandshakeSecretTracker*>(arg);
+ t->SecretUpdated(epoch, dir, secret);
+ }
+
+ void SecretUpdated(PRUint16 epoch, SSLSecretDirection dir,
+ PK11SymKey* secret) {
+ if (g_ssl_gtest_verbose) {
+ std::cerr << agent_->role_str() << ": secret callback for " << dir
+ << " epoch " << epoch << std::endl;
+ }
+
+ EXPECT_TRUE(secret);
+ uint16_t* p;
+ if (dir == ssl_secret_read) {
+ p = &next_read_epoch_;
+ } else {
+ ASSERT_EQ(ssl_secret_write, dir);
+ p = &next_write_epoch_;
+ }
+ EXPECT_EQ(*p, epoch);
+ switch (*p) {
+ case 1: // 1 == 0-RTT, next should be handshake.
+ case 2: // 2 == handshake, next should be application data.
+ (*p)++;
+ break;
+
+ case 3: // 3 == application data, there should be no more.
+ // Use 0 as a sentinel value.
+ *p = 0;
+ break;
+
+ default:
+ ADD_FAILURE() << "Unexpected next epoch: " << *p;
+ }
+ }
+
+ std::shared_ptr<TlsAgent> agent_;
+ uint16_t next_read_epoch_;
+ uint16_t next_write_epoch_;
+};
+
+TEST_F(TlsConnectTest, HandshakeSecrets) {
+ ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3);
+ EnsureTlsSetup();
+
+ HandshakeSecretTracker c(client_, 2, 2);
+ HandshakeSecretTracker s(server_, 2, 2);
+
+ Connect();
+ SendReceive();
+
+ c.CheckComplete();
+ s.CheckComplete();
+}
+
+TEST_F(TlsConnectTest, ZeroRttSecrets) {
+ SetupForZeroRtt();
+
+ HandshakeSecretTracker c(client_, 2, 1);
+ HandshakeSecretTracker s(server_, 1, 2);
+
+ client_->Set0RttEnabled(true);
+ server_->Set0RttEnabled(true);
+ ExpectResumption(RESUME_TICKET);
+ ZeroRttSendReceive(true, true);
+ Handshake();
+ ExpectEarlyDataAccepted(true);
+ CheckConnected();
+ SendReceive();
+
+ c.CheckComplete();
+ s.CheckComplete();
+}
+
+class KeyUpdateTracker {
+ public:
+ KeyUpdateTracker(const std::shared_ptr<TlsAgent>& agent,
+ bool expect_read_secret)
+ : agent_(agent), expect_read_secret_(expect_read_secret), called_(false) {
+ EXPECT_EQ(SECSuccess, SSL_SecretCallback(agent_->ssl_fd(),
+ KeyUpdateTracker::SecretCb, this));
+ }
+
+ void CheckCalled() const { EXPECT_TRUE(called_); }
+
+ private:
+ static void SecretCb(PRFileDesc* fd, PRUint16 epoch, SSLSecretDirection dir,
+ PK11SymKey* secret, void* arg) {
+ KeyUpdateTracker* t = reinterpret_cast<KeyUpdateTracker*>(arg);
+ t->SecretUpdated(epoch, dir, secret);
+ }
+
+ void SecretUpdated(PRUint16 epoch, SSLSecretDirection dir,
+ PK11SymKey* secret) {
+ EXPECT_EQ(4U, epoch);
+ EXPECT_EQ(expect_read_secret_, dir == ssl_secret_read);
+ EXPECT_TRUE(secret);
+ called_ = true;
+ }
+
+ std::shared_ptr<TlsAgent> agent_;
+ bool expect_read_secret_;
+ bool called_;
+};
+
+TEST_F(TlsConnectTest, KeyUpdateSecrets) {
+ ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3);
+ Connect();
+ // The update is to the client write secret; the server read secret.
+ KeyUpdateTracker c(client_, false);
+ KeyUpdateTracker s(server_, true);
+ EXPECT_EQ(SECSuccess, SSL_KeyUpdate(client_->ssl_fd(), PR_FALSE));
+ SendReceive(50);
+ SendReceive(60);
+ CheckEpochs(4, 3);
+ c.CheckCalled();
+ s.CheckCalled();
+}
+
+// BadPrSocket is an instance of a PR IO layer that crashes the test if it is
+// ever used for reading or writing. It does that by failing to overwrite any
+// of the DummyIOLayerMethods, which all crash when invoked.
+class BadPrSocket : public DummyIOLayerMethods {
+ public:
+ BadPrSocket(std::shared_ptr<TlsAgent>& agent) : DummyIOLayerMethods() {
+ static PRDescIdentity bad_identity = PR_GetUniqueIdentity("bad NSPR id");
+ fd_ = DummyIOLayerMethods::CreateFD(bad_identity, this);
+
+ // This is terrible, but NSPR doesn't provide an easy way to replace the
+ // bottom layer of an IO stack. Take the DummyPrSocket and replace its
+ // NSPR method vtable with the ones from this object.
+ dummy_layer_ =
+ PR_GetIdentitiesLayer(agent->ssl_fd(), DummyPrSocket::LayerId());
+ EXPECT_TRUE(dummy_layer_);
+ original_methods_ = dummy_layer_->methods;
+ original_secret_ = dummy_layer_->secret;
+ dummy_layer_->methods = fd_->methods;
+ dummy_layer_->secret = reinterpret_cast<PRFilePrivate*>(this);
+ }
+
+ // This will be destroyed before the agent, so we need to restore the state
+ // before we tampered with it.
+ virtual ~BadPrSocket() {
+ dummy_layer_->methods = original_methods_;
+ dummy_layer_->secret = original_secret_;
+ }
+
+ private:
+ ScopedPRFileDesc fd_;
+ PRFileDesc* dummy_layer_;
+ const PRIOMethods* original_methods_;
+ PRFilePrivate* original_secret_;
+};
+
+class StagedRecords {
+ public:
+ StagedRecords(std::shared_ptr<TlsAgent>& agent) : agent_(agent), records_() {
+ EXPECT_EQ(SECSuccess,
+ SSL_RecordLayerWriteCallback(
+ agent_->ssl_fd(), StagedRecords::StageRecordData, this));
+ }
+
+ virtual ~StagedRecords() {
+ // Uninstall so that the callback doesn't fire during cleanup.
+ EXPECT_EQ(SECSuccess,
+ SSL_RecordLayerWriteCallback(agent_->ssl_fd(), nullptr, nullptr));
+ }
+
+ bool empty() const { return records_.empty(); }
+
+ void ForwardAll(std::shared_ptr<TlsAgent>& peer) {
+ EXPECT_NE(agent_, peer) << "can't forward to self";
+ for (auto r : records_) {
+ r.Forward(peer);
+ }
+ records_.clear();
+ }
+
+ // This forwards all saved data and checks the resulting state.
+ void ForwardAll(std::shared_ptr<TlsAgent>& peer,
+ TlsAgent::State expected_state) {
+ ForwardAll(peer);
+ switch (expected_state) {
+ case TlsAgent::STATE_CONNECTED:
+ // The handshake callback should have been called, so check that before
+ // checking that SSL_ForceHandshake succeeds.
+ EXPECT_EQ(expected_state, peer->state());
+ EXPECT_EQ(SECSuccess, SSL_ForceHandshake(peer->ssl_fd()));
+ break;
+
+ case TlsAgent::STATE_CONNECTING:
+ // Check that SSL_ForceHandshake() blocks.
+ EXPECT_EQ(SECFailure, SSL_ForceHandshake(peer->ssl_fd()));
+ EXPECT_EQ(PR_WOULD_BLOCK_ERROR, PORT_GetError());
+ // Update and check the state.
+ peer->Handshake();
+ EXPECT_EQ(TlsAgent::STATE_CONNECTING, peer->state());
+ break;
+
+ default:
+ ADD_FAILURE() << "No idea how to handle this state";
+ }
+ }
+
+ void ForwardPartial(std::shared_ptr<TlsAgent>& peer) {
+ if (records_.empty()) {
+ ADD_FAILURE() << "No records to slice";
+ return;
+ }
+ auto& last = records_.back();
+ auto tail = last.SliceTail();
+ ForwardAll(peer, TlsAgent::STATE_CONNECTING);
+ records_.push_back(tail);
+ EXPECT_EQ(TlsAgent::STATE_CONNECTING, peer->state());
+ }
+
+ private:
+ // A single record.
+ class StagedRecord {
+ public:
+ StagedRecord(const std::string role, uint16_t epoch, SSLContentType ct,
+ const uint8_t* data, size_t len)
+ : role_(role), epoch_(epoch), content_type_(ct), data_(data, len) {
+ if (g_ssl_gtest_verbose) {
+ std::cerr << role_ << ": staged epoch " << epoch_ << " "
+ << content_type_ << ": " << data_ << std::endl;
+ }
+ }
+
+ // This forwards staged data to the identified agent.
+ void Forward(std::shared_ptr<TlsAgent>& peer) {
+ // Now there should be staged data.
+ EXPECT_FALSE(data_.empty());
+ if (g_ssl_gtest_verbose) {
+ std::cerr << role_ << ": forward epoch " << epoch_ << " " << data_
+ << std::endl;
+ }
+ EXPECT_EQ(SECSuccess,
+ SSL_RecordLayerData(peer->ssl_fd(), epoch_, content_type_,
+ data_.data(),
+ static_cast<unsigned int>(data_.len())));
+ }
+
+ // Slices the tail off this record and returns it.
+ StagedRecord SliceTail() {
+ size_t slice = 1;
+ if (data_.len() <= slice) {
+ ADD_FAILURE() << "record too small to slice in two";
+ slice = 0;
+ }
+ size_t keep = data_.len() - slice;
+ StagedRecord tail(role_, epoch_, content_type_, data_.data() + keep,
+ slice);
+ data_.Truncate(keep);
+ return tail;
+ }
+
+ private:
+ std::string role_;
+ uint16_t epoch_;
+ SSLContentType content_type_;
+ DataBuffer data_;
+ };
+
+ // This is an SSLRecordWriteCallback that stages data.
+ static SECStatus StageRecordData(PRFileDesc* fd, PRUint16 epoch,
+ SSLContentType content_type,
+ const PRUint8* data, unsigned int len,
+ void* arg) {
+ auto stage = reinterpret_cast<StagedRecords*>(arg);
+ stage->records_.push_back(StagedRecord(stage->agent_->role_str(), epoch,
+ content_type, data,
+ static_cast<size_t>(len)));
+ return SECSuccess;
+ }
+
+ std::shared_ptr<TlsAgent>& agent_;
+ std::deque<StagedRecord> records_;
+};
+
+// Attempting to feed application data in before the handshake is complete
+// should be caught.
+static void RefuseApplicationData(std::shared_ptr<TlsAgent>& peer,
+ uint16_t epoch) {
+ static const uint8_t d[] = {1, 2, 3};
+ EXPECT_EQ(SECFailure,
+ SSL_RecordLayerData(peer->ssl_fd(), epoch, ssl_ct_application_data,
+ d, static_cast<unsigned int>(sizeof(d))));
+ EXPECT_EQ(SEC_ERROR_INVALID_ARGS, PORT_GetError());
+}
+
+static void SendForwardReceive(std::shared_ptr<TlsAgent>& sender,
+ StagedRecords& sender_stage,
+ std::shared_ptr<TlsAgent>& receiver) {
+ const size_t count = 10;
+ sender->SendData(count, count);
+ sender_stage.ForwardAll(receiver);
+ receiver->ReadBytes(count);
+}
+
+TEST_P(TlsConnectStream, ReplaceRecordLayer) {
+ StartConnect();
+ client_->SetServerKeyBits(server_->server_key_bits());
+
+ // BadPrSocket installs an IO layer that crashes when the SSL layer attempts
+ // to read or write.
+ BadPrSocket bad_layer_client(client_);
+ BadPrSocket bad_layer_server(server_);
+
+ // StagedRecords installs a handler for unprotected data from the socket, and
+ // captures that data.
+ StagedRecords client_stage(client_);
+ StagedRecords server_stage(server_);
+
+ // Both peers should refuse application data from epoch 0.
+ RefuseApplicationData(client_, 0);
+ RefuseApplicationData(server_, 0);
+
+ // This first call forwards nothing, but it causes the client to handshake,
+ // which starts things off. This stages the ClientHello as a result.
+ server_stage.ForwardAll(client_, TlsAgent::STATE_CONNECTING);
+ // This processes the ClientHello and stages the first server flight.
+ client_stage.ForwardAll(server_, TlsAgent::STATE_CONNECTING);
+
+ // In TLS 1.3, this is 0-RTT; in <TLS 1.3, this is application data.
+ // Neither is acceptable.
+ RefuseApplicationData(client_, 1);
+ RefuseApplicationData(server_, 1);
+
+ if (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) {
+ // Application data in handshake is never acceptable.
+ RefuseApplicationData(client_, 2);
+ RefuseApplicationData(server_, 2);
+ // Don't accept real data until the handshake is done.
+ RefuseApplicationData(client_, 3);
+ RefuseApplicationData(server_, 3);
+ // Process the server flight and the client is done.
+ server_stage.ForwardAll(client_, TlsAgent::STATE_CONNECTED);
+ client_stage.ForwardAll(server_, TlsAgent::STATE_CONNECTED);
+ } else {
+ server_stage.ForwardAll(client_, TlsAgent::STATE_CONNECTING);
+ client_stage.ForwardAll(server_, TlsAgent::STATE_CONNECTED);
+ server_stage.ForwardAll(client_, TlsAgent::STATE_CONNECTED);
+ }
+ CheckKeys();
+
+ // Reading and writing application data should work.
+ SendForwardReceive(client_, client_stage, server_);
+ SendForwardReceive(server_, server_stage, client_);
+}
+
+TEST_F(TlsConnectStreamTls13, ReplaceRecordLayerZeroRtt) {
+ SetupForZeroRtt();
+
+ client_->Set0RttEnabled(true);
+ server_->Set0RttEnabled(true);
+ StartConnect();
+ client_->SetServerKeyBits(server_->server_key_bits());
+
+ BadPrSocket bad_layer_client(client_);
+ BadPrSocket bad_layer_server(server_);
+
+ StagedRecords client_stage(client_);
+ StagedRecords server_stage(server_);
+
+ ExpectResumption(RESUME_TICKET);
+
+ // Send ClientHello
+ server_stage.ForwardAll(client_, TlsAgent::STATE_CONNECTING);
+
+ // The client can never accept 0-RTT.
+ RefuseApplicationData(client_, 1);
+
+ // Send some 0-RTT data, which get staged in `client_stage`.
+ const char* kMsg = "EarlyData";
+ const PRInt32 kMsgLen = static_cast<PRInt32>(strlen(kMsg));
+ PRInt32 rv = PR_Write(client_->ssl_fd(), kMsg, kMsgLen);
+ EXPECT_EQ(kMsgLen, rv);
+
+ client_stage.ForwardAll(server_, TlsAgent::STATE_CONNECTING);
+
+ // The server should now have 0-RTT to read.
+ std::vector<uint8_t> buf(kMsgLen);
+ rv = PR_Read(server_->ssl_fd(), buf.data(), kMsgLen);
+ EXPECT_EQ(kMsgLen, rv);
+
+ // The handshake should happily finish.
+ server_stage.ForwardAll(client_, TlsAgent::STATE_CONNECTED);
+ client_stage.ForwardAll(server_, TlsAgent::STATE_CONNECTED);
+ ExpectEarlyDataAccepted(true);
+ CheckConnected();
+
+ // Reading and writing application data should work.
+ SendForwardReceive(client_, client_stage, server_);
+ SendForwardReceive(server_, server_stage, client_);
+}
+
+static SECStatus AuthCompleteBlock(TlsAgent*, PRBool, PRBool) {
+ return SECWouldBlock;
+}
+
+TEST_P(TlsConnectStream, ReplaceRecordLayerAsyncLateAuth) {
+ StartConnect();
+ client_->SetServerKeyBits(server_->server_key_bits());
+
+ BadPrSocket bad_layer_client(client_);
+ BadPrSocket bad_layer_server(server_);
+ StagedRecords client_stage(client_);
+ StagedRecords server_stage(server_);
+
+ client_->SetAuthCertificateCallback(AuthCompleteBlock);
+
+ server_stage.ForwardAll(client_, TlsAgent::STATE_CONNECTING);
+ client_stage.ForwardAll(server_, TlsAgent::STATE_CONNECTING);
+ server_stage.ForwardAll(client_, TlsAgent::STATE_CONNECTING);
+
+ // Prior to TLS 1.3, the client sends its second flight immediately. But in
+ // TLS 1.3, a client won't send a Finished until it is happy with the server
+ // certificate. So blocking certificate validation causes the client to send
+ // nothing.
+ if (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) {
+ ASSERT_TRUE(client_stage.empty());
+
+ // Client should have stopped reading when it saw the Certificate message,
+ // so it will be reading handshake epoch, and writing cleartext.
+ client_->CheckEpochs(2, 0);
+ // Server should be reading handshake, and writing application data.
+ server_->CheckEpochs(2, 3);
+
+ // Handshake again and the client will read the remainder of the server's
+ // flight, but it will remain blocked.
+ client_->Handshake();
+ ASSERT_TRUE(client_stage.empty());
+ EXPECT_EQ(TlsAgent::STATE_CONNECTING, client_->state());
+ } else {
+ // In prior versions, the client's second flight is always sent.
+ ASSERT_FALSE(client_stage.empty());
+ }
+
+ // Now declare the certificate good.
+ EXPECT_EQ(SECSuccess, SSL_AuthCertificateComplete(client_->ssl_fd(), 0));
+ client_->Handshake();
+ ASSERT_FALSE(client_stage.empty());
+
+ if (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) {
+ EXPECT_EQ(TlsAgent::STATE_CONNECTED, client_->state());
+ client_stage.ForwardAll(server_, TlsAgent::STATE_CONNECTED);
+ } else {
+ client_stage.ForwardAll(server_, TlsAgent::STATE_CONNECTED);
+ server_stage.ForwardAll(client_, TlsAgent::STATE_CONNECTED);
+ }
+ CheckKeys();
+
+ // Reading and writing application data should work.
+ SendForwardReceive(client_, client_stage, server_);
+}
+
+TEST_F(TlsConnectStreamTls13, ReplaceRecordLayerAsyncPostHandshake) {
+ StartConnect();
+ client_->SetServerKeyBits(server_->server_key_bits());
+
+ BadPrSocket bad_layer_client(client_);
+ BadPrSocket bad_layer_server(server_);
+ StagedRecords client_stage(client_);
+ StagedRecords server_stage(server_);
+
+ client_->SetAuthCertificateCallback(AuthCompleteBlock);
+
+ server_stage.ForwardAll(client_, TlsAgent::STATE_CONNECTING);
+ client_stage.ForwardAll(server_, TlsAgent::STATE_CONNECTING);
+ server_stage.ForwardAll(client_, TlsAgent::STATE_CONNECTING);
+
+ ASSERT_TRUE(client_stage.empty());
+ client_->Handshake();
+ ASSERT_TRUE(client_stage.empty());
+ EXPECT_EQ(TlsAgent::STATE_CONNECTING, client_->state());
+
+ // Now declare the certificate good.
+ EXPECT_EQ(SECSuccess, SSL_AuthCertificateComplete(client_->ssl_fd(), 0));
+ client_->Handshake();
+ ASSERT_FALSE(client_stage.empty());
+
+ if (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) {
+ EXPECT_EQ(TlsAgent::STATE_CONNECTED, client_->state());
+ client_stage.ForwardAll(server_, TlsAgent::STATE_CONNECTED);
+ } else {
+ client_stage.ForwardAll(server_, TlsAgent::STATE_CONNECTED);
+ server_stage.ForwardAll(client_, TlsAgent::STATE_CONNECTED);
+ }
+ CheckKeys();
+
+ // Reading and writing application data should work.
+ SendForwardReceive(client_, client_stage, server_);
+
+ // Post-handshake messages should work here.
+ EXPECT_EQ(SECSuccess, SSL_SendSessionTicket(server_->ssl_fd(), nullptr, 0));
+ SendForwardReceive(server_, server_stage, client_);
+}
+
+// This test ensures that data is correctly forwarded when the handshake is
+// resumed after asynchronous server certificate authentication, when
+// SSL_AuthCertificateComplete() is called. The logic for resuming the
+// handshake involves a different code path than the usual one, so this test
+// exercises that code fully.
+TEST_F(TlsConnectStreamTls13, ReplaceRecordLayerAsyncEarlyAuth) {
+ StartConnect();
+ client_->SetServerKeyBits(server_->server_key_bits());
+
+ BadPrSocket bad_layer_client(client_);
+ BadPrSocket bad_layer_server(server_);
+ StagedRecords client_stage(client_);
+ StagedRecords server_stage(server_);
+
+ client_->SetAuthCertificateCallback(AuthCompleteBlock);
+
+ server_stage.ForwardAll(client_, TlsAgent::STATE_CONNECTING);
+ client_stage.ForwardAll(server_, TlsAgent::STATE_CONNECTING);
+
+ // Send a partial flight on to the client.
+ // This includes enough to trigger the certificate callback.
+ server_stage.ForwardPartial(client_);
+ EXPECT_TRUE(client_stage.empty());
+
+ // Declare the certificate good.
+ EXPECT_EQ(SECSuccess, SSL_AuthCertificateComplete(client_->ssl_fd(), 0));
+ client_->Handshake();
+ EXPECT_TRUE(client_stage.empty());
+
+ // Send the remainder of the server flight.
+ PRBool pending = PR_FALSE;
+ EXPECT_EQ(SECSuccess,
+ SSLInt_HasPendingHandshakeData(client_->ssl_fd(), &pending));
+ EXPECT_EQ(PR_TRUE, pending);
+ EXPECT_EQ(TlsAgent::STATE_CONNECTING, client_->state());
+ server_stage.ForwardAll(client_, TlsAgent::STATE_CONNECTED);
+ client_stage.ForwardAll(server_, TlsAgent::STATE_CONNECTED);
+ CheckKeys();
+
+ SendForwardReceive(server_, server_stage, client_);
+}
+
+TEST_P(TlsConnectStream, ForwardDataFromWrongEpoch) {
+ const uint8_t data[] = {1};
+ Connect();
+ uint16_t next_epoch;
+ if (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) {
+ EXPECT_EQ(SECFailure,
+ SSL_RecordLayerData(client_->ssl_fd(), 2, ssl_ct_application_data,
+ data, sizeof(data)));
+ EXPECT_EQ(SEC_ERROR_INVALID_ARGS, PORT_GetError())
+ << "Passing data from an old epoch is rejected";
+ next_epoch = 4;
+ } else {
+ // Prior to TLS 1.3, the epoch is only updated once during the handshake.
+ next_epoch = 2;
+ }
+ EXPECT_EQ(SECFailure,
+ SSL_RecordLayerData(client_->ssl_fd(), next_epoch,
+ ssl_ct_application_data, data, sizeof(data)));
+ EXPECT_EQ(PR_WOULD_BLOCK_ERROR, PORT_GetError())
+ << "Passing data from a future epoch blocks";
+}
+
+TEST_F(TlsConnectStreamTls13, ForwardInvalidData) {
+ const uint8_t data[1] = {0};
+
+ EnsureTlsSetup();
+ // Zero-length data.
+ EXPECT_EQ(SECFailure, SSL_RecordLayerData(client_->ssl_fd(), 0,
+ ssl_ct_application_data, data, 0));
+ EXPECT_EQ(SEC_ERROR_INVALID_ARGS, PORT_GetError());
+
+ // NULL data.
+ EXPECT_EQ(SECFailure,
+ SSL_RecordLayerData(client_->ssl_fd(), 0, ssl_ct_application_data,
+ nullptr, 1));
+ EXPECT_EQ(SEC_ERROR_INVALID_ARGS, PORT_GetError());
+}
+
+TEST_F(TlsConnectDatagram13, ForwardDataDtls) {
+ EnsureTlsSetup();
+ const uint8_t data[1] = {0};
+ EXPECT_EQ(SECFailure,
+ SSL_RecordLayerData(client_->ssl_fd(), 0, ssl_ct_application_data,
+ data, sizeof(data)));
+ EXPECT_EQ(SEC_ERROR_INVALID_ARGS, PORT_GetError());
+}
+
+TEST_F(TlsConnectStreamTls13, SuppressEndOfEarlyData) {
+ SetupForZeroRtt();
+
+ client_->Set0RttEnabled(true);
+ server_->Set0RttEnabled(true);
+ client_->SetOption(SSL_SUPPRESS_END_OF_EARLY_DATA, true);
+ server_->SetOption(SSL_SUPPRESS_END_OF_EARLY_DATA, true);
+ StartConnect();
+ client_->SetServerKeyBits(server_->server_key_bits());
+
+ BadPrSocket bad_layer_client(client_);
+ BadPrSocket bad_layer_server(server_);
+
+ StagedRecords client_stage(client_);
+ StagedRecords server_stage(server_);
+
+ ExpectResumption(RESUME_TICKET);
+
+ // Send ClientHello
+ server_stage.ForwardAll(client_, TlsAgent::STATE_CONNECTING);
+
+ // Send some 0-RTT data, which get staged in `client_stage`.
+ const char* kMsg = "ABCDEF";
+ const PRInt32 kMsgLen = static_cast<PRInt32>(strlen(kMsg));
+ PRInt32 rv = PR_Write(client_->ssl_fd(), kMsg, kMsgLen);
+ EXPECT_EQ(kMsgLen, rv);
+
+ client_stage.ForwardAll(server_, TlsAgent::STATE_CONNECTING);
+
+ // The server should now have 0-RTT to read.
+ std::vector<uint8_t> buf(kMsgLen);
+ rv = PR_Read(server_->ssl_fd(), buf.data(), kMsgLen);
+ EXPECT_EQ(kMsgLen, rv);
+
+ // The handshake should happily finish, without the end of the early data.
+ server_stage.ForwardAll(client_, TlsAgent::STATE_CONNECTED);
+ client_stage.ForwardAll(server_, TlsAgent::STATE_CONNECTED);
+ ExpectEarlyDataAccepted(true);
+ CheckConnected();
+
+ // Reading and writing application data should work.
+ SendForwardReceive(client_, client_stage, server_);
+ SendForwardReceive(server_, server_stage, client_);
+}
+
+} // namespace nss_test
diff --git a/security/nss/gtests/ssl_gtest/ssl_recordsize_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_recordsize_unittest.cc
new file mode 100644
index 0000000000..8a84db5749
--- /dev/null
+++ b/security/nss/gtests/ssl_gtest/ssl_recordsize_unittest.cc
@@ -0,0 +1,726 @@
+/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
+/* vim: set ts=2 et sw=2 tw=80: */
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this file,
+ * You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+#include "secerr.h"
+#include "ssl.h"
+#include "sslerr.h"
+#include "sslproto.h"
+
+#include "gtest_utils.h"
+#include "nss_scoped_ptrs.h"
+#include "tls_connect.h"
+#include "tls_filter.h"
+#include "tls_parser.h"
+
+namespace nss_test {
+
+// This class tracks the maximum size of record that was sent, both cleartext
+// and plain. It only tracks records that have an outer type of
+// application_data or DTLSCiphertext. In TLS 1.3, this includes handshake
+// messages.
+class TlsRecordMaximum : public TlsRecordFilter {
+ public:
+ TlsRecordMaximum(const std::shared_ptr<TlsAgent>& a)
+ : TlsRecordFilter(a), max_ciphertext_(0), max_plaintext_(0) {}
+
+ size_t max_ciphertext() const { return max_ciphertext_; }
+ size_t max_plaintext() const { return max_plaintext_; }
+
+ protected:
+ PacketFilter::Action FilterRecord(const TlsRecordHeader& header,
+ const DataBuffer& record, size_t* offset,
+ DataBuffer* output) override {
+ std::cerr << "max: " << record << std::endl;
+ // Ignore unprotected packets.
+ if (!header.is_protected()) {
+ return KEEP;
+ }
+
+ max_ciphertext_ = (std::max)(max_ciphertext_, record.len());
+ return TlsRecordFilter::FilterRecord(header, record, offset, output);
+ }
+
+ PacketFilter::Action FilterRecord(const TlsRecordHeader& header,
+ const DataBuffer& data,
+ DataBuffer* changed) override {
+ max_plaintext_ = (std::max)(max_plaintext_, data.len());
+ return KEEP;
+ }
+
+ private:
+ size_t max_ciphertext_;
+ size_t max_plaintext_;
+};
+
+void CheckRecordSizes(const std::shared_ptr<TlsAgent>& agent,
+ const std::shared_ptr<TlsRecordMaximum>& record_max,
+ size_t config) {
+ uint16_t cipher_suite;
+ ASSERT_TRUE(agent->cipher_suite(&cipher_suite));
+
+ size_t expansion;
+ size_t iv;
+ switch (cipher_suite) {
+ case TLS_AES_128_GCM_SHA256:
+ case TLS_AES_256_GCM_SHA384:
+ case TLS_CHACHA20_POLY1305_SHA256:
+ expansion = 16;
+ iv = 0;
+ break;
+
+ case TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256:
+ expansion = 16;
+ iv = 8;
+ break;
+
+ case TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA:
+ // Expansion is 20 for the MAC. Maximum block padding is 16. Maximum
+ // padding is added when the input plus the MAC is an exact multiple of
+ // the block size.
+ expansion = 20 + 16 - ((config + 20) % 16);
+ iv = 16;
+ break;
+
+ default:
+ ADD_FAILURE() << "No expansion set for ciphersuite "
+ << agent->cipher_suite_name();
+ return;
+ }
+
+ switch (agent->version()) {
+ case SSL_LIBRARY_VERSION_TLS_1_3:
+ EXPECT_EQ(0U, iv) << "No IV for TLS 1.3";
+ // We only have decryption in TLS 1.3.
+ EXPECT_EQ(config - 1, record_max->max_plaintext())
+ << "bad plaintext length for " << agent->role_str();
+ break;
+
+ case SSL_LIBRARY_VERSION_TLS_1_2:
+ case SSL_LIBRARY_VERSION_TLS_1_1:
+ expansion += iv;
+ break;
+
+ case SSL_LIBRARY_VERSION_TLS_1_0:
+ break;
+
+ default:
+ ADD_FAILURE() << "Unexpected version " << agent->version();
+ return;
+ }
+
+ EXPECT_EQ(config + expansion, record_max->max_ciphertext())
+ << "bad ciphertext length for " << agent->role_str();
+}
+
+TEST_P(TlsConnectGeneric, RecordSizeMaximum) {
+ uint16_t max_record_size =
+ (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) ? 16385 : 16384;
+ size_t send_size = (version_ >= SSL_LIBRARY_VERSION_TLS_1_3)
+ ? max_record_size
+ : max_record_size + 1;
+
+ EnsureTlsSetup();
+ auto client_max = MakeTlsFilter<TlsRecordMaximum>(client_);
+ auto server_max = MakeTlsFilter<TlsRecordMaximum>(server_);
+ if (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) {
+ client_max->EnableDecryption();
+ server_max->EnableDecryption();
+ }
+
+ Connect();
+ client_->SendData(send_size, send_size);
+ server_->SendData(send_size, send_size);
+ server_->ReadBytes(send_size);
+ client_->ReadBytes(send_size);
+
+ CheckRecordSizes(client_, client_max, max_record_size);
+ CheckRecordSizes(server_, server_max, max_record_size);
+}
+
+TEST_P(TlsConnectGeneric, RecordSizeMinimumClient) {
+ EnsureTlsSetup();
+ auto server_max = MakeTlsFilter<TlsRecordMaximum>(server_);
+ if (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) {
+ server_max->EnableDecryption();
+ }
+
+ client_->SetOption(SSL_RECORD_SIZE_LIMIT, 64);
+ Connect();
+ SendReceive(127); // Big enough for one record, allowing for 1+N splitting.
+
+ CheckRecordSizes(server_, server_max, 64);
+}
+
+TEST_P(TlsConnectGeneric, RecordSizeMinimumServer) {
+ EnsureTlsSetup();
+ auto client_max = MakeTlsFilter<TlsRecordMaximum>(client_);
+ if (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) {
+ client_max->EnableDecryption();
+ }
+
+ server_->SetOption(SSL_RECORD_SIZE_LIMIT, 64);
+ Connect();
+ SendReceive(127);
+
+ CheckRecordSizes(client_, client_max, 64);
+}
+
+TEST_P(TlsConnectGeneric, RecordSizeAsymmetric) {
+ EnsureTlsSetup();
+ auto client_max = MakeTlsFilter<TlsRecordMaximum>(client_);
+ auto server_max = MakeTlsFilter<TlsRecordMaximum>(server_);
+ if (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) {
+ client_max->EnableDecryption();
+ server_max->EnableDecryption();
+ }
+
+ client_->SetOption(SSL_RECORD_SIZE_LIMIT, 64);
+ server_->SetOption(SSL_RECORD_SIZE_LIMIT, 100);
+ Connect();
+ SendReceive(127);
+
+ CheckRecordSizes(client_, client_max, 100);
+ CheckRecordSizes(server_, server_max, 64);
+}
+
+// This just modifies the encrypted payload so to include a few extra zeros.
+class TlsRecordExpander : public TlsRecordFilter {
+ public:
+ TlsRecordExpander(const std::shared_ptr<TlsAgent>& a, size_t expansion)
+ : TlsRecordFilter(a), expansion_(expansion) {}
+
+ protected:
+ virtual PacketFilter::Action FilterRecord(const TlsRecordHeader& header,
+ const DataBuffer& data,
+ DataBuffer* changed) {
+ if (!header.is_protected()) {
+ // We're targeting application_data records. If the record is
+ // |!is_protected()|, we have two possibilities:
+ if (!decrypting()) {
+ // 1) We're not decrypting, in which this case this is truly an
+ // unencrypted record (Keep).
+ return KEEP;
+ }
+ if (header.content_type() != ssl_ct_application_data) {
+ // 2) We are decrypting, so is_protected() read the internal
+ // content_type. If the internal ct IS NOT application_data, then
+ // it's not our target (Keep).
+ return KEEP;
+ }
+ // Otherwise, the the internal ct IS application_data (Change).
+ }
+
+ changed->Allocate(data.len() + expansion_);
+ changed->Write(0, data.data(), data.len());
+ return CHANGE;
+ }
+
+ private:
+ size_t expansion_;
+};
+
+// Tweak the plaintext of server records so that they exceed the client's limit.
+TEST_F(TlsConnectStreamTls13, RecordSizePlaintextExceed) {
+ EnsureTlsSetup();
+ auto server_expand = MakeTlsFilter<TlsRecordExpander>(server_, 1);
+ server_expand->EnableDecryption();
+
+ client_->SetOption(SSL_RECORD_SIZE_LIMIT, 64);
+ Connect();
+
+ server_->SendData(100);
+
+ client_->ExpectReadWriteError();
+ ExpectAlert(client_, kTlsAlertRecordOverflow);
+ client_->ReadBytes(100);
+ EXPECT_EQ(SSL_ERROR_RX_RECORD_TOO_LONG, client_->error_code());
+
+ // Consume the alert at the server.
+ server_->Handshake();
+ server_->CheckErrorCode(SSL_ERROR_RECORD_OVERFLOW_ALERT);
+}
+
+// Tweak the ciphertext of server records so that they greatly exceed the limit.
+// This requires a much larger expansion than for plaintext to trigger the
+// guard, which runs before decryption (current allowance is 320 octets,
+// see MAX_EXPANSION in ssl3con.c).
+TEST_F(TlsConnectStreamTls13, RecordSizeCiphertextExceed) {
+ EnsureTlsSetup();
+
+ client_->SetOption(SSL_RECORD_SIZE_LIMIT, 64);
+ Connect();
+
+ auto server_expand = MakeTlsFilter<TlsRecordExpander>(server_, 336);
+ server_->SendData(100);
+
+ client_->ExpectReadWriteError();
+ ExpectAlert(client_, kTlsAlertRecordOverflow);
+ client_->ReadBytes(100);
+ EXPECT_EQ(SSL_ERROR_RX_RECORD_TOO_LONG, client_->error_code());
+
+ // Consume the alert at the server.
+ server_->Handshake();
+ server_->CheckErrorCode(SSL_ERROR_RECORD_OVERFLOW_ALERT);
+}
+
+TEST_F(TlsConnectStreamTls13, ClientHelloF5Padding) {
+ EnsureTlsSetup();
+ ScopedPK11SlotInfo slot(PK11_GetInternalSlot());
+ ScopedPK11SymKey key(
+ PK11_KeyGen(slot.get(), CKM_NSS_CHACHA20_POLY1305, nullptr, 32, nullptr));
+
+ auto filter =
+ MakeTlsFilter<TlsHandshakeRecorder>(client_, kTlsHandshakeClientHello);
+
+ // Add PSK with label long enough to push CH length into [256, 511].
+ std::vector<uint8_t> label(100);
+ EXPECT_EQ(SECSuccess,
+ SSL_AddExternalPsk(client_->ssl_fd(), key.get(), label.data(),
+ label.size(), ssl_hash_sha256));
+ StartConnect();
+ client_->Handshake();
+
+ // Filter removes the 4B handshake header.
+ EXPECT_EQ(508UL, filter->buffer().len());
+}
+
+// This indiscriminately adds padding to application data records.
+class TlsRecordPadder : public TlsRecordFilter {
+ public:
+ TlsRecordPadder(const std::shared_ptr<TlsAgent>& a, size_t padding)
+ : TlsRecordFilter(a), padding_(padding) {}
+
+ protected:
+ PacketFilter::Action FilterRecord(const TlsRecordHeader& header,
+ const DataBuffer& record, size_t* offset,
+ DataBuffer* output) override {
+ if (!header.is_protected()) {
+ return KEEP;
+ }
+
+ uint16_t protection_epoch;
+ uint8_t inner_content_type;
+ DataBuffer plaintext;
+ TlsRecordHeader out_header;
+ if (!Unprotect(header, record, &protection_epoch, &inner_content_type,
+ &plaintext, &out_header)) {
+ return KEEP;
+ }
+
+ if (decrypting() && inner_content_type != ssl_ct_application_data) {
+ return KEEP;
+ }
+
+ DataBuffer ciphertext;
+ bool ok = Protect(spec(protection_epoch), out_header, inner_content_type,
+ plaintext, &ciphertext, &out_header, padding_);
+ EXPECT_TRUE(ok);
+ if (!ok) {
+ return KEEP;
+ }
+ *offset = out_header.Write(output, *offset, ciphertext);
+ return CHANGE;
+ }
+
+ private:
+ size_t padding_;
+};
+
+TEST_F(TlsConnectStreamTls13, RecordSizeExceedPad) {
+ EnsureTlsSetup();
+ auto server_max = std::make_shared<TlsRecordMaximum>(server_);
+ auto server_expand = std::make_shared<TlsRecordPadder>(server_, 1);
+ server_->SetFilter(std::make_shared<ChainedPacketFilter>(
+ ChainedPacketFilterInit({server_max, server_expand})));
+ server_expand->EnableDecryption();
+
+ client_->SetOption(SSL_RECORD_SIZE_LIMIT, 64);
+ Connect();
+
+ server_->SendData(100);
+
+ client_->ExpectReadWriteError();
+ ExpectAlert(client_, kTlsAlertRecordOverflow);
+ client_->ReadBytes(100);
+ EXPECT_EQ(SSL_ERROR_RX_RECORD_TOO_LONG, client_->error_code());
+
+ // Consume the alert at the server.
+ server_->Handshake();
+ server_->CheckErrorCode(SSL_ERROR_RECORD_OVERFLOW_ALERT);
+}
+
+TEST_P(TlsConnectGeneric, RecordSizeBadValues) {
+ EnsureTlsSetup();
+ EXPECT_EQ(SECFailure,
+ SSL_OptionSet(client_->ssl_fd(), SSL_RECORD_SIZE_LIMIT, 63));
+ EXPECT_EQ(SECFailure,
+ SSL_OptionSet(client_->ssl_fd(), SSL_RECORD_SIZE_LIMIT, -1));
+ EXPECT_EQ(SECFailure,
+ SSL_OptionSet(server_->ssl_fd(), SSL_RECORD_SIZE_LIMIT, 16386));
+ Connect();
+}
+
+TEST_P(TlsConnectGeneric, RecordSizeGetValues) {
+ EnsureTlsSetup();
+ int v;
+ EXPECT_EQ(SECSuccess,
+ SSL_OptionGet(client_->ssl_fd(), SSL_RECORD_SIZE_LIMIT, &v));
+ EXPECT_EQ(16385, v);
+ client_->SetOption(SSL_RECORD_SIZE_LIMIT, 300);
+ EXPECT_EQ(SECSuccess,
+ SSL_OptionGet(client_->ssl_fd(), SSL_RECORD_SIZE_LIMIT, &v));
+ EXPECT_EQ(300, v);
+ Connect();
+}
+
+// The value of the extension is capped by the maximum version of the client.
+TEST_P(TlsConnectGeneric, RecordSizeCapExtensionClient) {
+ EnsureTlsSetup();
+ client_->SetOption(SSL_RECORD_SIZE_LIMIT, 16385);
+ auto capture =
+ MakeTlsFilter<TlsExtensionCapture>(client_, ssl_record_size_limit_xtn);
+ if (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) {
+ capture->EnableDecryption();
+ }
+ Connect();
+
+ uint64_t val = 0;
+ EXPECT_TRUE(capture->extension().Read(0, 2, &val));
+ if (version_ < SSL_LIBRARY_VERSION_TLS_1_3) {
+ EXPECT_EQ(16384U, val) << "Extension should be capped";
+ } else {
+ EXPECT_EQ(16385U, val);
+ }
+}
+
+// The value of the extension is capped by the maximum version of the server.
+TEST_P(TlsConnectGeneric, RecordSizeCapExtensionServer) {
+ EnsureTlsSetup();
+ server_->SetOption(SSL_RECORD_SIZE_LIMIT, 16385);
+ auto capture =
+ MakeTlsFilter<TlsExtensionCapture>(server_, ssl_record_size_limit_xtn);
+ if (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) {
+ capture->EnableDecryption();
+ }
+ Connect();
+
+ uint64_t val = 0;
+ EXPECT_TRUE(capture->extension().Read(0, 2, &val));
+ if (version_ < SSL_LIBRARY_VERSION_TLS_1_3) {
+ EXPECT_EQ(16384U, val) << "Extension should be capped";
+ } else {
+ EXPECT_EQ(16385U, val);
+ }
+}
+
+// Damage the client extension and the handshake fails, but the server
+// doesn't generate a validation error.
+TEST_P(TlsConnectGenericPre13, RecordSizeClientExtensionInvalid) {
+ EnsureTlsSetup();
+ client_->SetOption(SSL_RECORD_SIZE_LIMIT, 1000);
+ static const uint8_t v[] = {0xf4, 0x1f};
+ MakeTlsFilter<TlsExtensionReplacer>(client_, ssl_record_size_limit_xtn,
+ DataBuffer(v, sizeof(v)));
+ ConnectExpectAlert(server_, kTlsAlertDecryptError);
+}
+
+// Special handling for TLS 1.3, where the alert isn't read.
+TEST_F(TlsConnectStreamTls13, RecordSizeClientExtensionInvalid) {
+ EnsureTlsSetup();
+ client_->SetOption(SSL_RECORD_SIZE_LIMIT, 1000);
+ static const uint8_t v[] = {0xf4, 0x1f};
+ MakeTlsFilter<TlsExtensionReplacer>(client_, ssl_record_size_limit_xtn,
+ DataBuffer(v, sizeof(v)));
+ client_->ExpectSendAlert(kTlsAlertBadRecordMac);
+ server_->ExpectSendAlert(kTlsAlertBadRecordMac);
+ ConnectExpectFail();
+}
+
+TEST_P(TlsConnectGeneric, RecordSizeServerExtensionInvalid) {
+ EnsureTlsSetup();
+ server_->SetOption(SSL_RECORD_SIZE_LIMIT, 1000);
+ static const uint8_t v[] = {0xf4, 0x1f};
+ auto replace = MakeTlsFilter<TlsExtensionReplacer>(
+ server_, ssl_record_size_limit_xtn, DataBuffer(v, sizeof(v)));
+ if (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) {
+ replace->EnableDecryption();
+ }
+ ConnectExpectAlert(client_, kTlsAlertIllegalParameter);
+}
+
+TEST_P(TlsConnectGeneric, RecordSizeServerExtensionExtra) {
+ EnsureTlsSetup();
+ server_->SetOption(SSL_RECORD_SIZE_LIMIT, 1000);
+ static const uint8_t v[] = {0x01, 0x00, 0x00};
+ auto replace = MakeTlsFilter<TlsExtensionReplacer>(
+ server_, ssl_record_size_limit_xtn, DataBuffer(v, sizeof(v)));
+ if (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) {
+ replace->EnableDecryption();
+ }
+ ConnectExpectAlert(client_, kTlsAlertDecodeError);
+}
+
+class RecordSizeDefaultsTest : public ::testing::Test {
+ public:
+ void SetUp() {
+ EXPECT_EQ(SECSuccess,
+ SSL_OptionGetDefault(SSL_RECORD_SIZE_LIMIT, &default_));
+ }
+ void TearDown() {
+ // Make sure to restore the default value at the end.
+ EXPECT_EQ(SECSuccess,
+ SSL_OptionSetDefault(SSL_RECORD_SIZE_LIMIT, default_));
+ }
+
+ private:
+ PRIntn default_ = 0;
+};
+
+TEST_F(RecordSizeDefaultsTest, RecordSizeBadValues) {
+ EXPECT_EQ(SECFailure, SSL_OptionSetDefault(SSL_RECORD_SIZE_LIMIT, 63));
+ EXPECT_EQ(SECFailure, SSL_OptionSetDefault(SSL_RECORD_SIZE_LIMIT, -1));
+ EXPECT_EQ(SECFailure, SSL_OptionSetDefault(SSL_RECORD_SIZE_LIMIT, 16386));
+}
+
+TEST_F(RecordSizeDefaultsTest, RecordSizeGetValue) {
+ int v;
+ EXPECT_EQ(SECSuccess, SSL_OptionGetDefault(SSL_RECORD_SIZE_LIMIT, &v));
+ EXPECT_EQ(16385, v);
+ EXPECT_EQ(SECSuccess, SSL_OptionSetDefault(SSL_RECORD_SIZE_LIMIT, 3000));
+ EXPECT_EQ(SECSuccess, SSL_OptionGetDefault(SSL_RECORD_SIZE_LIMIT, &v));
+ EXPECT_EQ(3000, v);
+}
+
+class TlsCtextResizer : public TlsRecordFilter {
+ public:
+ TlsCtextResizer(const std::shared_ptr<TlsAgent>& a, size_t size)
+ : TlsRecordFilter(a), size_(size) {}
+
+ protected:
+ virtual PacketFilter::Action FilterRecord(const TlsRecordHeader& header,
+ const DataBuffer& data,
+ DataBuffer* changed) {
+ // allocate and initialise buffer
+ changed->Allocate(size_);
+
+ // copy record data (partially)
+ changed->Write(0, data.data(),
+ ((data.len() >= size_) ? size_ : data.len()));
+
+ return CHANGE;
+ }
+
+ private:
+ size_t size_;
+};
+
+/* (D)TLS overlong record test for maximum default record size of
+ * 2^14 + (256 (TLS 1.3) OR 2048 (TLS <= 1.2)
+ * [RFC8446, Section 5.2; RFC5246 , Section 6.2.3].
+ * This should fail the first size check in ssl3gthr.c/ssl3_GatherData().
+ * DTLS Record errors are dropped silently. [RFC6347, Section 4.1.2.7]. */
+TEST_P(TlsConnectGeneric, RecordGatherOverlong) {
+ EnsureTlsSetup();
+
+ size_t max_ctext = MAX_FRAGMENT_LENGTH;
+ if (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) {
+ max_ctext += TLS_1_3_MAX_EXPANSION;
+ } else {
+ max_ctext += TLS_1_2_MAX_EXPANSION;
+ }
+
+ Connect();
+
+ MakeTlsFilter<TlsCtextResizer>(server_, max_ctext + 1);
+ // Dummy record will be overwritten
+ server_->SendData(0xf0);
+
+ /* Drop DTLS Record Errors silently [RFC6347, Section 4.1.2.7]. */
+ if (variant_ == ssl_variant_datagram) {
+ size_t received = client_->received_bytes();
+ client_->ReadBytes(max_ctext + 1);
+ ASSERT_EQ(received, client_->received_bytes());
+ } else {
+ client_->ExpectSendAlert(kTlsAlertRecordOverflow);
+ client_->ReadBytes(max_ctext + 1);
+ server_->ExpectReceiveAlert(kTlsAlertRecordOverflow);
+ server_->Handshake();
+ }
+}
+
+/* (D)TLS overlong record test with recordSizeLimit Extension and plus RFC
+ * specified maximum Expansion: 2^14 + (256 (TLS 1.3) OR 2048 (TLS <= 1.2)
+ * [RFC8446, Section 5.2; RFC5246 , Section 6.2.3].
+ * DTLS Record errors are dropped silently. [RFC6347, Section 4.1.2.7]. */
+TEST_P(TlsConnectGeneric, RecordSizeExtensionOverlong) {
+ EnsureTlsSetup();
+
+ // Set some boundary
+ size_t max_ctext = 1000;
+
+ client_->SetOption(SSL_RECORD_SIZE_LIMIT, max_ctext);
+
+ if (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) {
+ // The record size limit includes the inner content type byte
+ max_ctext += TLS_1_3_MAX_EXPANSION - 1;
+ } else {
+ max_ctext += TLS_1_2_MAX_EXPANSION;
+ }
+
+ Connect();
+
+ MakeTlsFilter<TlsCtextResizer>(server_, max_ctext + 1);
+ // Dummy record will be overwritten
+ server_->SendData(0xf);
+
+ /* Drop DTLS Record Errors silently [RFC6347, Section 4.1.2.7].
+ * For DTLS 1.0 and 1.2 the package is dropped before the size check because
+ * of the modification. This just tests that no error is thrown as required.
+ */
+ if (variant_ == ssl_variant_datagram) {
+ size_t received = client_->received_bytes();
+ client_->ReadBytes(max_ctext + 1);
+ ASSERT_EQ(received, client_->received_bytes());
+ } else {
+ client_->ExpectSendAlert(kTlsAlertRecordOverflow);
+ client_->ReadBytes(max_ctext + 1);
+ server_->ExpectReceiveAlert(kTlsAlertRecordOverflow);
+ server_->Handshake();
+ }
+}
+
+/* For TLS <= 1.2:
+ * MAX_EXPANSION is the amount by which a record might plausibly be expanded
+ * when protected. It's the worst case estimate, so the sum of block cipher
+ * padding (up to 256 octets), HMAC (48 octets for SHA-384), and IV (16
+ * octets for AES). */
+#define MAX_EXPANSION (256 + 48 + 16)
+
+/* (D)TLS overlong record test for specific ciphersuite expansion.
+ * Testing the smallest illegal record.
+ * This check is performed in ssl3con.c/ssl3_UnprotectRecord() OR
+ * tls13con.c/tls13_UnprotectRecord() and enforces stricter size limitations,
+ * dependent on the implemented cipher suites, than the RFC.
+ * DTLS Record errors are dropped silently. [RFC6347, Section 4.1.2.7]. */
+TEST_P(TlsConnectGeneric, RecordExpansionOverlong) {
+ EnsureTlsSetup();
+
+ // Set some boundary
+ size_t max_ctext = 1000;
+
+ client_->SetOption(SSL_RECORD_SIZE_LIMIT, max_ctext);
+
+ if (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) {
+ // For TLS1.3 all ciphers expand the cipherext by 16B
+ // The inner content type byte is included in the record size limit
+ max_ctext += 16;
+ } else {
+ // For TLS<=1.2 the max possible expansion in the NSS implementation is 320
+ max_ctext += MAX_EXPANSION;
+ }
+
+ Connect();
+
+ MakeTlsFilter<TlsCtextResizer>(server_, max_ctext + 1);
+ // Dummy record will be overwritten
+ server_->SendData(0xf);
+
+ /* Drop DTLS Record Errors silently [RFC6347, Section 4.1.2.7].
+ * For DTLS 1.0 and 1.2 the package is dropped before the size check because
+ * of the modification. This just tests that no error is thrown as required/
+ * no bytes are received. */
+ if (variant_ == ssl_variant_datagram) {
+ size_t received = client_->received_bytes();
+ client_->ReadBytes(max_ctext + 1);
+ ASSERT_EQ(received, client_->received_bytes());
+ } else {
+ client_->ExpectSendAlert(kTlsAlertRecordOverflow);
+ client_->ReadBytes(max_ctext + 1);
+ server_->ExpectReceiveAlert(kTlsAlertRecordOverflow);
+ server_->Handshake();
+ }
+}
+
+/* (D)TLS longest allowed record default size test. */
+TEST_P(TlsConnectGeneric, RecordSizeDefaultLong) {
+ EnsureTlsSetup();
+ Connect();
+
+ // Maximum allowed plaintext size
+ size_t max = MAX_FRAGMENT_LENGTH;
+
+ /* For TLS 1.0 the first byte of application data is sent in a single record
+ * as explained in the documentation of SSL_CBC_RANDOM_IV in ssl.h.
+ * Because of that we use TlsCTextResizer to send a record of max size.
+ * A bad record mac alert is expected since we modify the record. */
+ if (version_ == SSL_LIBRARY_VERSION_TLS_1_0 &&
+ variant_ == ssl_variant_stream) {
+ // Set size to maxi plaintext + max allowed expansion
+ MakeTlsFilter<TlsCtextResizer>(server_, max + MAX_EXPANSION);
+ // Dummy record will be overwritten
+ server_->SendData(0xF);
+ // Expect alert
+ client_->ExpectSendAlert(kTlsAlertBadRecordMac);
+ // Receive record
+ client_->ReadBytes(max);
+ // Handle alert on server side
+ server_->ExpectReceiveAlert(kTlsAlertBadRecordMac);
+ server_->Handshake();
+ } else { // Everything but TLS 1.0
+ // Send largest legal plaintext as single record
+ // by setting SendData() block size to max.
+ server_->SendData(max, max);
+ // Receive record
+ client_->ReadBytes(max);
+ // Assert that data was received successfully
+ ASSERT_EQ(client_->received_bytes(), max);
+ }
+}
+
+/* (D)TLS longest allowed record size limit extension test. */
+TEST_P(TlsConnectGeneric, RecordSizeLimitLong) {
+ EnsureTlsSetup();
+
+ // Set some boundary
+ size_t max = 1000;
+ client_->SetOption(SSL_RECORD_SIZE_LIMIT, max);
+
+ Connect();
+
+ // For TLS 1.3 the InnerContentType byte is included in the record size limit
+ if (version_ == SSL_LIBRARY_VERSION_TLS_1_3) {
+ max--;
+ }
+
+ /* For TLS 1.0 the first byte of application data is sent in a single record
+ * as explained in the documentation of SSL_CBC_RANDOM_IV in ssl.h.
+ * Because of that we use TlsCTextResizer to send a record of max size.
+ * A bad record mac alert is expected since we modify the record. */
+ if (version_ == SSL_LIBRARY_VERSION_TLS_1_0 &&
+ variant_ == ssl_variant_stream) {
+ // Set size to maxi plaintext + max allowed expansion
+ MakeTlsFilter<TlsCtextResizer>(server_, max + MAX_EXPANSION);
+ // Dummy record will be overwritten
+ server_->SendData(0xF);
+ // Expect alert
+ client_->ExpectSendAlert(kTlsAlertBadRecordMac);
+ // Receive record
+ client_->ReadBytes(max);
+ // Handle alert on server side
+ server_->ExpectReceiveAlert(kTlsAlertBadRecordMac);
+ server_->Handshake();
+ } else { // Everything but TLS 1.0
+ // Send largest legal plaintext as single record
+ // by setting SendData() block size to max.
+ server_->SendData(max, max);
+ // Receive record
+ client_->ReadBytes(max);
+ // Assert that data was received successfully
+ ASSERT_EQ(client_->received_bytes(), max);
+ }
+}
+
+} // namespace nss_test \ No newline at end of file
diff --git a/security/nss/gtests/ssl_gtest/ssl_renegotiation_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_renegotiation_unittest.cc
new file mode 100644
index 0000000000..3f7074a096
--- /dev/null
+++ b/security/nss/gtests/ssl_gtest/ssl_renegotiation_unittest.cc
@@ -0,0 +1,235 @@
+/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
+/* vim: set ts=2 et sw=2 tw=80: */
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this file,
+ * You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+#include <functional>
+#include <memory>
+#include "secerr.h"
+#include "ssl.h"
+#include "sslerr.h"
+#include "sslproto.h"
+
+extern "C" {
+// This is not something that should make you happy.
+#include "libssl_internals.h"
+}
+
+#include "gtest_utils.h"
+#include "tls_connect.h"
+
+namespace nss_test {
+
+// 1.3 is disabled in the next few tests because we don't
+// presently support resumption in 1.3.
+TEST_P(TlsConnectStreamPre13, RenegotiateClient) {
+ Connect();
+ server_->PrepareForRenegotiate();
+ client_->StartRenegotiate();
+ Handshake();
+ CheckConnected();
+}
+
+TEST_P(TlsConnectStreamPre13, RenegotiateServer) {
+ Connect();
+ client_->PrepareForRenegotiate();
+ server_->StartRenegotiate();
+ Handshake();
+ CheckConnected();
+}
+
+TEST_P(TlsConnectStreamPre13, RenegotiateRandoms) {
+ SSL3Random crand1, crand2, srand1, srand2;
+ Connect();
+ EXPECT_EQ(SECSuccess,
+ SSLInt_GetHandshakeRandoms(client_->ssl_fd(), crand1, srand1));
+
+ // Renegotiate and check that both randoms have changed.
+ client_->PrepareForRenegotiate();
+ server_->StartRenegotiate();
+ Handshake();
+ CheckConnected();
+ EXPECT_EQ(SECSuccess,
+ SSLInt_GetHandshakeRandoms(client_->ssl_fd(), crand2, srand2));
+
+ EXPECT_NE(0, memcmp(crand1, crand2, sizeof(SSL3Random)));
+ EXPECT_NE(0, memcmp(srand1, srand2, sizeof(SSL3Random)));
+}
+
+// The renegotiation options shouldn't cause an error if TLS 1.3 is chosen.
+TEST_F(TlsConnectTest, RenegotiationConfigTls13) {
+ EnsureTlsSetup();
+ ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3);
+ server_->SetOption(SSL_ENABLE_RENEGOTIATION, SSL_RENEGOTIATE_UNRESTRICTED);
+ server_->SetOption(SSL_REQUIRE_SAFE_NEGOTIATION, PR_TRUE);
+ Connect();
+ SendReceive();
+ CheckKeys();
+}
+
+TEST_P(TlsConnectStream, ConnectTls10AndServerRenegotiateHigher) {
+ if (version_ == SSL_LIBRARY_VERSION_TLS_1_0) {
+ GTEST_SKIP();
+ }
+ // Set the client so it will accept any version from 1.0
+ // to |version_|.
+ client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_0, version_);
+ server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_0,
+ SSL_LIBRARY_VERSION_TLS_1_0);
+ // Reset version so that the checks succeed.
+ uint16_t test_version = version_;
+ version_ = SSL_LIBRARY_VERSION_TLS_1_0;
+ Connect();
+
+ // Now renegotiate, with the server being set to do
+ // |version_|.
+ client_->PrepareForRenegotiate();
+ server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_0, test_version);
+ // Reset version and cipher suite so that the preinfo callback
+ // doesn't fail.
+ server_->ResetPreliminaryInfo();
+ server_->StartRenegotiate();
+
+ if (test_version >= SSL_LIBRARY_VERSION_TLS_1_3) {
+ ExpectAlert(server_, kTlsAlertUnexpectedMessage);
+ } else {
+ ExpectAlert(server_, kTlsAlertProtocolVersion);
+ }
+
+ Handshake();
+ if (test_version >= SSL_LIBRARY_VERSION_TLS_1_3) {
+ // In TLS 1.3, the server detects this problem.
+ client_->CheckErrorCode(SSL_ERROR_HANDSHAKE_UNEXPECTED_ALERT);
+ server_->CheckErrorCode(SSL_ERROR_RENEGOTIATION_NOT_ALLOWED);
+ } else {
+ client_->CheckErrorCode(SSL_ERROR_PROTOCOL_VERSION_ALERT);
+ server_->CheckErrorCode(SSL_ERROR_UNSUPPORTED_VERSION);
+ }
+}
+
+TEST_P(TlsConnectStream, ConnectTls10AndClientRenegotiateHigher) {
+ if (version_ == SSL_LIBRARY_VERSION_TLS_1_0) {
+ GTEST_SKIP();
+ }
+ // Set the client so it will accept any version from 1.0
+ // to |version_|.
+ client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_0, version_);
+ server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_0,
+ SSL_LIBRARY_VERSION_TLS_1_0);
+ // Reset version so that the checks succeed.
+ uint16_t test_version = version_;
+ version_ = SSL_LIBRARY_VERSION_TLS_1_0;
+ Connect();
+
+ // Now renegotiate, with the server being set to do
+ // |version_|.
+ server_->PrepareForRenegotiate();
+ server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_0, test_version);
+ // Reset version and cipher suite so that the preinfo callback
+ // doesn't fail.
+ server_->ResetPreliminaryInfo();
+ client_->StartRenegotiate();
+ if (test_version >= SSL_LIBRARY_VERSION_TLS_1_3) {
+ ExpectAlert(server_, kTlsAlertUnexpectedMessage);
+ } else {
+ ExpectAlert(server_, kTlsAlertProtocolVersion);
+ }
+ Handshake();
+ if (test_version >= SSL_LIBRARY_VERSION_TLS_1_3) {
+ // In TLS 1.3, the server detects this problem.
+ client_->CheckErrorCode(SSL_ERROR_HANDSHAKE_UNEXPECTED_ALERT);
+ server_->CheckErrorCode(SSL_ERROR_RENEGOTIATION_NOT_ALLOWED);
+ } else {
+ client_->CheckErrorCode(SSL_ERROR_PROTOCOL_VERSION_ALERT);
+ server_->CheckErrorCode(SSL_ERROR_UNSUPPORTED_VERSION);
+ }
+}
+
+TEST_P(TlsConnectStream, ConnectAndServerRenegotiateLower) {
+ if (version_ == SSL_LIBRARY_VERSION_TLS_1_0) {
+ GTEST_SKIP();
+ }
+ Connect();
+
+ // Now renegotiate with the server set to TLS 1.0.
+ client_->PrepareForRenegotiate();
+ server_->PrepareForRenegotiate();
+ client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_0, version_);
+ // Reset version and cipher suite so that the preinfo callback
+ // doesn't fail.
+ server_->ResetPreliminaryInfo();
+
+ SECStatus rv = SSL_ReHandshake(server_->ssl_fd(), PR_TRUE);
+ if (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) {
+ EXPECT_EQ(SECFailure, rv);
+ return;
+ }
+ ASSERT_EQ(SECSuccess, rv);
+
+ // Now, before handshaking, tweak the server configuration.
+ server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_0,
+ SSL_LIBRARY_VERSION_TLS_1_0);
+
+ // The server should catch the own error.
+ ExpectAlert(server_, kTlsAlertProtocolVersion);
+
+ Handshake();
+ client_->CheckErrorCode(SSL_ERROR_PROTOCOL_VERSION_ALERT);
+ server_->CheckErrorCode(SSL_ERROR_UNSUPPORTED_VERSION);
+}
+
+TEST_P(TlsConnectStream, ConnectAndServerWontRenegotiateLower) {
+ if (version_ == SSL_LIBRARY_VERSION_TLS_1_0) {
+ GTEST_SKIP();
+ }
+ Connect();
+
+ // Now renegotiate with the server set to TLS 1.0.
+ client_->PrepareForRenegotiate();
+ server_->PrepareForRenegotiate();
+ client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_0, version_);
+ server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_0,
+ SSL_LIBRARY_VERSION_TLS_1_0);
+ // Reset version and cipher suite so that the preinfo callback
+ // doesn't fail.
+ server_->ResetPreliminaryInfo();
+
+ EXPECT_EQ(SECFailure, SSL_ReHandshake(server_->ssl_fd(), PR_TRUE));
+}
+
+TEST_P(TlsConnectStream, ConnectAndClientWontRenegotiateLower) {
+ if (version_ == SSL_LIBRARY_VERSION_TLS_1_0) {
+ GTEST_SKIP();
+ }
+ Connect();
+
+ // Now renegotiate with the client set to TLS 1.0.
+ client_->PrepareForRenegotiate();
+ server_->PrepareForRenegotiate();
+ server_->ResetPreliminaryInfo();
+ client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_0,
+ SSL_LIBRARY_VERSION_TLS_1_0);
+ // The client will refuse to renegotiate down.
+ EXPECT_EQ(SECFailure, SSL_ReHandshake(client_->ssl_fd(), PR_TRUE));
+}
+
+TEST_F(TlsConnectTest, Tls13RejectsRehandshakeClient) {
+ EnsureTlsSetup();
+ ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3);
+ Connect();
+ SECStatus rv = SSL_ReHandshake(client_->ssl_fd(), PR_TRUE);
+ EXPECT_EQ(SECFailure, rv);
+ EXPECT_EQ(SSL_ERROR_RENEGOTIATION_NOT_ALLOWED, PORT_GetError());
+}
+
+TEST_F(TlsConnectTest, Tls13RejectsRehandshakeServer) {
+ EnsureTlsSetup();
+ ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3);
+ Connect();
+ SECStatus rv = SSL_ReHandshake(server_->ssl_fd(), PR_TRUE);
+ EXPECT_EQ(SECFailure, rv);
+ EXPECT_EQ(SSL_ERROR_RENEGOTIATION_NOT_ALLOWED, PORT_GetError());
+}
+
+} // namespace nss_test
diff --git a/security/nss/gtests/ssl_gtest/ssl_resumption_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_resumption_unittest.cc
new file mode 100644
index 0000000000..2e23fc096a
--- /dev/null
+++ b/security/nss/gtests/ssl_gtest/ssl_resumption_unittest.cc
@@ -0,0 +1,1522 @@
+/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
+/* vim: set ts=2 et sw=2 tw=80: */
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this file,
+ * You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+#include <functional>
+#include <memory>
+#include "secerr.h"
+#include "ssl.h"
+#include "sslerr.h"
+#include "sslexp.h"
+#include "sslproto.h"
+
+extern "C" {
+// This is not something that should make you happy.
+#include "libssl_internals.h"
+}
+
+#include "gtest_utils.h"
+#include "nss_scoped_ptrs.h"
+#include "scoped_ptrs_ssl.h"
+#include "tls_connect.h"
+#include "tls_filter.h"
+#include "tls_parser.h"
+#include "tls_protect.h"
+
+namespace nss_test {
+
+class TlsServerKeyExchangeEcdhe {
+ public:
+ bool Parse(const DataBuffer& buffer) {
+ TlsParser parser(buffer);
+
+ uint8_t curve_type;
+ if (!parser.Read(&curve_type)) {
+ return false;
+ }
+
+ if (curve_type != 3) { // named_curve
+ return false;
+ }
+
+ uint32_t named_curve;
+ if (!parser.Read(&named_curve, 2)) {
+ return false;
+ }
+
+ return parser.ReadVariable(&public_key_, 1);
+ }
+
+ DataBuffer public_key_;
+};
+
+TEST_P(TlsConnectGenericPre13, ConnectResumed) {
+ ConfigureSessionCache(RESUME_SESSIONID, RESUME_SESSIONID);
+ Connect();
+
+ Reset();
+ ExpectResumption(RESUME_SESSIONID);
+ Connect();
+}
+
+TEST_P(TlsConnectGenericResumption, ConnectClientCacheDisabled) {
+ ConfigureSessionCache(RESUME_NONE, RESUME_SESSIONID);
+ Connect();
+ SendReceive();
+
+ Reset();
+ ExpectResumption(RESUME_NONE);
+ Connect();
+ SendReceive();
+}
+
+TEST_P(TlsConnectGenericResumption, ConnectServerCacheDisabled) {
+ ConfigureSessionCache(RESUME_SESSIONID, RESUME_NONE);
+ Connect();
+ SendReceive();
+
+ Reset();
+ ExpectResumption(RESUME_NONE);
+ Connect();
+ SendReceive();
+}
+
+TEST_P(TlsConnectGenericResumption, ConnectSessionCacheDisabled) {
+ ConfigureSessionCache(RESUME_NONE, RESUME_NONE);
+ Connect();
+ SendReceive();
+
+ Reset();
+ ExpectResumption(RESUME_NONE);
+ Connect();
+ SendReceive();
+}
+
+TEST_P(TlsConnectGenericResumption, ConnectResumeSupportBoth) {
+ // This prefers tickets.
+ ConfigureSessionCache(RESUME_BOTH, RESUME_BOTH);
+ Connect();
+ SendReceive();
+
+ Reset();
+ ConfigureSessionCache(RESUME_BOTH, RESUME_BOTH);
+ ExpectResumption(RESUME_TICKET);
+ Connect();
+ SendReceive();
+}
+
+TEST_P(TlsConnectGenericResumption, ConnectResumeClientTicketServerBoth) {
+ // This causes no resumption because the client needs the
+ // session cache to resume even with tickets.
+ ConfigureSessionCache(RESUME_TICKET, RESUME_BOTH);
+ Connect();
+ SendReceive();
+
+ Reset();
+ ConfigureSessionCache(RESUME_TICKET, RESUME_BOTH);
+ ExpectResumption(RESUME_NONE);
+ Connect();
+ SendReceive();
+}
+
+TEST_P(TlsConnectGenericResumption, ConnectResumeClientBothTicketServerTicket) {
+ // This causes a ticket resumption.
+ ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET);
+ Connect();
+ SendReceive();
+
+ Reset();
+ ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET);
+ ExpectResumption(RESUME_TICKET);
+ Connect();
+ SendReceive();
+}
+
+TEST_P(TlsConnectGenericResumption, ConnectResumeClientServerTicketOnly) {
+ // This causes no resumption because the client needs the
+ // session cache to resume even with tickets.
+ ConfigureSessionCache(RESUME_TICKET, RESUME_TICKET);
+ Connect();
+ SendReceive();
+
+ Reset();
+ ConfigureSessionCache(RESUME_TICKET, RESUME_TICKET);
+ ExpectResumption(RESUME_NONE);
+ Connect();
+ SendReceive();
+}
+
+TEST_P(TlsConnectGenericResumption, ConnectResumeClientBothServerNone) {
+ ConfigureSessionCache(RESUME_BOTH, RESUME_NONE);
+ Connect();
+ SendReceive();
+
+ Reset();
+ ConfigureSessionCache(RESUME_BOTH, RESUME_NONE);
+ ExpectResumption(RESUME_NONE);
+ Connect();
+ SendReceive();
+}
+
+TEST_P(TlsConnectGenericResumption, ConnectResumeClientNoneServerBoth) {
+ ConfigureSessionCache(RESUME_NONE, RESUME_BOTH);
+ Connect();
+ SendReceive();
+
+ Reset();
+ ConfigureSessionCache(RESUME_NONE, RESUME_BOTH);
+ ExpectResumption(RESUME_NONE);
+ Connect();
+ SendReceive();
+}
+
+TEST_P(TlsConnectGenericPre13, ResumeWithHigherVersionTls13) {
+ uint16_t lower_version = version_;
+ ConfigureSessionCache(RESUME_BOTH, RESUME_BOTH);
+ Connect();
+ SendReceive();
+ CheckKeys();
+
+ Reset();
+ ConfigureSessionCache(RESUME_BOTH, RESUME_BOTH);
+ EnsureTlsSetup();
+ auto psk_ext = std::make_shared<TlsExtensionCapture>(
+ client_, ssl_tls13_pre_shared_key_xtn);
+ auto ticket_ext =
+ std::make_shared<TlsExtensionCapture>(client_, ssl_session_ticket_xtn);
+ client_->SetFilter(std::make_shared<ChainedPacketFilter>(
+ ChainedPacketFilterInit({psk_ext, ticket_ext})));
+ SetExpectedVersion(SSL_LIBRARY_VERSION_TLS_1_3);
+ client_->SetVersionRange(lower_version, SSL_LIBRARY_VERSION_TLS_1_3);
+ server_->SetVersionRange(lower_version, SSL_LIBRARY_VERSION_TLS_1_3);
+ ExpectResumption(RESUME_NONE);
+ Connect();
+
+ // The client shouldn't have sent a PSK, though it will send a ticket.
+ EXPECT_FALSE(psk_ext->captured());
+ EXPECT_TRUE(ticket_ext->captured());
+}
+
+class CaptureSessionId : public TlsHandshakeFilter {
+ public:
+ CaptureSessionId(const std::shared_ptr<TlsAgent>& a)
+ : TlsHandshakeFilter(
+ a, {kTlsHandshakeClientHello, kTlsHandshakeServerHello}),
+ sid_() {}
+
+ const DataBuffer& sid() const { return sid_; }
+
+ protected:
+ PacketFilter::Action FilterHandshake(const HandshakeHeader& header,
+ const DataBuffer& input,
+ DataBuffer* output) override {
+ // The session_id is in the same place in both Hello messages:
+ size_t offset = 2 + 32; // Version(2) + Random(32)
+ uint32_t len = 0;
+ EXPECT_TRUE(input.Read(offset, 1, &len));
+ offset++;
+ if (input.len() < offset + len) {
+ ADD_FAILURE() << "session_id overflows the Hello message";
+ return KEEP;
+ }
+ sid_.Assign(input.data() + offset, len);
+ return KEEP;
+ }
+
+ private:
+ DataBuffer sid_;
+};
+
+// Attempting to resume from TLS 1.2 when 1.3 is possible should not result in
+// resumption, though it will appear to be TLS 1.3 compatibility mode if the
+// server uses a session ID.
+TEST_P(TlsConnectGenericPre13, ResumeWithHigherVersionTls13SessionId) {
+ uint16_t lower_version = version_;
+ ConfigureSessionCache(RESUME_SESSIONID, RESUME_SESSIONID);
+ auto original_sid = MakeTlsFilter<CaptureSessionId>(server_);
+ Connect();
+ CheckKeys();
+ EXPECT_EQ(32U, original_sid->sid().len());
+
+ // The client should now attempt to resume with the session ID from the last
+ // connection. This looks like compatibility mode, we just want to ensure
+ // that we get TLS 1.3 rather than 1.2 (and no resumption).
+ Reset();
+ auto client_sid = MakeTlsFilter<CaptureSessionId>(client_);
+ auto server_sid = MakeTlsFilter<CaptureSessionId>(server_);
+ ConfigureSessionCache(RESUME_SESSIONID, RESUME_SESSIONID);
+ SetExpectedVersion(SSL_LIBRARY_VERSION_TLS_1_3);
+ client_->SetVersionRange(lower_version, SSL_LIBRARY_VERSION_TLS_1_3);
+ server_->SetVersionRange(lower_version, SSL_LIBRARY_VERSION_TLS_1_3);
+ ExpectResumption(RESUME_NONE);
+
+ Connect();
+ SendReceive();
+
+ EXPECT_EQ(client_sid->sid(), original_sid->sid());
+ if (variant_ == ssl_variant_stream) {
+ EXPECT_EQ(client_sid->sid(), server_sid->sid());
+ } else {
+ // DTLS servers don't echo the session ID.
+ EXPECT_EQ(0U, server_sid->sid().len());
+ }
+}
+
+TEST_P(TlsConnectPre12, ResumeWithHigherVersionTls12) {
+ uint16_t lower_version = version_;
+ ConfigureSessionCache(RESUME_BOTH, RESUME_BOTH);
+ Connect();
+
+ Reset();
+ ConfigureSessionCache(RESUME_BOTH, RESUME_BOTH);
+ EnsureTlsSetup();
+ SetExpectedVersion(SSL_LIBRARY_VERSION_TLS_1_3);
+ client_->SetVersionRange(lower_version, SSL_LIBRARY_VERSION_TLS_1_3);
+ server_->SetVersionRange(lower_version, SSL_LIBRARY_VERSION_TLS_1_3);
+ ExpectResumption(RESUME_NONE);
+ Connect();
+}
+
+TEST_P(TlsConnectGenericPre13, ResumeWithLowerVersionFromTls13) {
+ uint16_t original_version = version_;
+ ConfigureSessionCache(RESUME_BOTH, RESUME_BOTH);
+ ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3);
+ Connect();
+ SendReceive();
+ CheckKeys();
+
+ Reset();
+ ConfigureSessionCache(RESUME_BOTH, RESUME_BOTH);
+ ConfigureVersion(original_version);
+ ExpectResumption(RESUME_NONE);
+ Connect();
+ SendReceive();
+}
+
+TEST_P(TlsConnectPre12, ResumeWithLowerVersionFromTls12) {
+ uint16_t original_version = version_;
+ ConfigureSessionCache(RESUME_BOTH, RESUME_BOTH);
+ ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_2);
+ Connect();
+ SendReceive();
+ CheckKeys();
+
+ Reset();
+ ConfigureSessionCache(RESUME_BOTH, RESUME_BOTH);
+ ConfigureVersion(original_version);
+ ExpectResumption(RESUME_NONE);
+ Connect();
+ SendReceive();
+}
+
+TEST_P(TlsConnectGeneric, ConnectResumeClientBothTicketServerTicketForget) {
+ // This causes a ticket resumption.
+ ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET);
+ Connect();
+ SendReceive();
+
+ Reset();
+ ClearServerCache();
+ ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET);
+ ExpectResumption(RESUME_NONE);
+ Connect();
+ SendReceive();
+}
+
+// Tickets last two days maximum; this is a time longer than that.
+static const PRTime kLongerThanTicketLifetime =
+ 3LL * 24 * 60 * 60 * PR_USEC_PER_SEC;
+
+TEST_P(TlsConnectGenericResumption, ConnectWithExpiredTicketAtClient) {
+ // This causes a ticket resumption.
+ ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET);
+ Connect();
+ SendReceive();
+
+ AdvanceTime(kLongerThanTicketLifetime);
+
+ Reset();
+ ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET);
+ ExpectResumption(RESUME_NONE);
+
+ // TLS 1.3 uses the pre-shared key extension instead.
+ SSLExtensionType xtn = (version_ >= SSL_LIBRARY_VERSION_TLS_1_3)
+ ? ssl_tls13_pre_shared_key_xtn
+ : ssl_session_ticket_xtn;
+ auto capture = MakeTlsFilter<TlsExtensionCapture>(client_, xtn);
+ Connect();
+
+ if (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) {
+ EXPECT_FALSE(capture->captured());
+ } else {
+ EXPECT_TRUE(capture->captured());
+ EXPECT_EQ(0U, capture->extension().len());
+ }
+}
+
+TEST_P(TlsConnectGeneric, ConnectWithExpiredTicketAtServer) {
+ // This causes a ticket resumption.
+ ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET);
+ Connect();
+ SendReceive();
+
+ Reset();
+ ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET);
+ ExpectResumption(RESUME_NONE);
+
+ SSLExtensionType xtn = (version_ >= SSL_LIBRARY_VERSION_TLS_1_3)
+ ? ssl_tls13_pre_shared_key_xtn
+ : ssl_session_ticket_xtn;
+ auto capture = MakeTlsFilter<TlsExtensionCapture>(client_, xtn);
+ StartConnect();
+ client_->Handshake();
+ EXPECT_TRUE(capture->captured());
+ EXPECT_LT(0U, capture->extension().len());
+
+ AdvanceTime(kLongerThanTicketLifetime);
+
+ Handshake();
+ CheckConnected();
+}
+
+TEST_P(TlsConnectGeneric, ConnectResumeCorruptTicket) {
+ // This causes a ticket resumption.
+ ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET);
+ Connect();
+ SendReceive();
+
+ Reset();
+ static const uint8_t kHmacKey1Buf[32] = {0};
+ static const DataBuffer kHmacKey1(kHmacKey1Buf, sizeof(kHmacKey1Buf));
+
+ SECItem key_item = {siBuffer, const_cast<uint8_t*>(kHmacKey1Buf),
+ sizeof(kHmacKey1Buf)};
+
+ ScopedPK11SlotInfo slot(PK11_GetInternalSlot());
+ PK11SymKey* hmac_key =
+ PK11_ImportSymKey(slot.get(), CKM_SHA256_HMAC, PK11_OriginUnwrap,
+ CKA_SIGN, &key_item, nullptr);
+ ASSERT_NE(nullptr, hmac_key);
+ SSLInt_SetSelfEncryptMacKey(hmac_key);
+ ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET);
+ if (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) {
+ ExpectResumption(RESUME_NONE);
+ Connect();
+ } else {
+ ConnectExpectAlert(server_, illegal_parameter);
+ server_->CheckErrorCode(SSL_ERROR_RX_MALFORMED_CLIENT_HELLO);
+ }
+}
+
+// This callback switches out the "server" cert used on the server with
+// the "client" certificate, which should be the same type.
+static int32_t SwitchCertificates(TlsAgent* agent, const SECItem* srvNameArr,
+ uint32_t srvNameArrSize) {
+ bool ok = agent->ConfigServerCert("client");
+ if (!ok) return SSL_SNI_SEND_ALERT;
+
+ return 0; // first config
+};
+
+TEST_P(TlsConnectGeneric, ServerSNICertSwitch) {
+ Connect();
+ ScopedCERTCertificate cert1(SSL_PeerCertificate(client_->ssl_fd()));
+ ASSERT_NE(nullptr, cert1.get());
+
+ Reset();
+ ConfigureSessionCache(RESUME_NONE, RESUME_NONE);
+
+ server_->SetSniCallback(SwitchCertificates);
+
+ Connect();
+ ScopedCERTCertificate cert2(SSL_PeerCertificate(client_->ssl_fd()));
+ ASSERT_NE(nullptr, cert2.get());
+ CheckKeys();
+ EXPECT_FALSE(SECITEM_ItemsAreEqual(&cert1->derCert, &cert2->derCert));
+}
+
+TEST_P(TlsConnectGeneric, ServerSNICertTypeSwitch) {
+ Reset(TlsAgent::kServerEcdsa256);
+ Connect();
+ ScopedCERTCertificate cert1(SSL_PeerCertificate(client_->ssl_fd()));
+ ASSERT_NE(nullptr, cert1.get());
+
+ Reset();
+ ConfigureSessionCache(RESUME_NONE, RESUME_NONE);
+
+ // Because we configure an RSA certificate here, it only adds a second, unused
+ // certificate, which has no effect on what the server uses.
+ server_->SetSniCallback(SwitchCertificates);
+
+ Connect();
+ ScopedCERTCertificate cert2(SSL_PeerCertificate(client_->ssl_fd()));
+ ASSERT_NE(nullptr, cert2.get());
+ CheckKeys(ssl_kea_ecdh, ssl_auth_ecdsa);
+ EXPECT_TRUE(SECITEM_ItemsAreEqual(&cert1->derCert, &cert2->derCert));
+}
+
+TEST_P(TlsConnectGenericPre13, ConnectEcdheTwiceReuseKey) {
+ auto filter = MakeTlsFilter<TlsHandshakeRecorder>(
+ server_, kTlsHandshakeServerKeyExchange);
+ EnableECDHEServerKeyReuse();
+ Connect();
+ CheckKeys();
+ TlsServerKeyExchangeEcdhe dhe1;
+ EXPECT_TRUE(dhe1.Parse(filter->buffer()));
+
+ // Restart
+ Reset();
+ EnableECDHEServerKeyReuse();
+ auto filter2 = MakeTlsFilter<TlsHandshakeRecorder>(
+ server_, kTlsHandshakeServerKeyExchange);
+ ConfigureSessionCache(RESUME_NONE, RESUME_NONE);
+ Connect();
+ CheckKeys();
+
+ TlsServerKeyExchangeEcdhe dhe2;
+ EXPECT_TRUE(dhe2.Parse(filter2->buffer()));
+
+ // Make sure they are the same.
+ EXPECT_EQ(dhe1.public_key_.len(), dhe2.public_key_.len());
+ EXPECT_TRUE(!memcmp(dhe1.public_key_.data(), dhe2.public_key_.data(),
+ dhe1.public_key_.len()));
+}
+
+// This test parses the ServerKeyExchange, which isn't in 1.3
+TEST_P(TlsConnectGenericPre13, ConnectEcdheTwiceNewKey) {
+ auto filter = MakeTlsFilter<TlsHandshakeRecorder>(
+ server_, kTlsHandshakeServerKeyExchange);
+ Connect();
+ CheckKeys();
+ TlsServerKeyExchangeEcdhe dhe1;
+ EXPECT_TRUE(dhe1.Parse(filter->buffer()));
+
+ // Restart
+ Reset();
+ auto filter2 = MakeTlsFilter<TlsHandshakeRecorder>(
+ server_, kTlsHandshakeServerKeyExchange);
+ ConfigureSessionCache(RESUME_NONE, RESUME_NONE);
+ Connect();
+ CheckKeys();
+
+ TlsServerKeyExchangeEcdhe dhe2;
+ EXPECT_TRUE(dhe2.Parse(filter2->buffer()));
+
+ // Make sure they are different.
+ EXPECT_FALSE((dhe1.public_key_.len() == dhe2.public_key_.len()) &&
+ (!memcmp(dhe1.public_key_.data(), dhe2.public_key_.data(),
+ dhe1.public_key_.len())));
+}
+
+// Verify that TLS 1.3 reports an accurate group on resumption.
+TEST_P(TlsConnectTls13, TestTls13ResumeDifferentGroup) {
+ ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET);
+ Connect();
+ SendReceive(); // Need to read so that we absorb the session ticket.
+ CheckKeys();
+
+ Reset();
+ ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET);
+ ExpectResumption(RESUME_TICKET);
+ client_->ConfigNamedGroups(kFFDHEGroups);
+ server_->ConfigNamedGroups(kFFDHEGroups);
+ Connect();
+ CheckKeys(ssl_kea_dh, ssl_grp_ffdhe_2048, ssl_auth_rsa_sign,
+ ssl_sig_rsa_pss_rsae_sha256);
+}
+
+// Verify that TLS 1.3 server doesn't request certificate in the main
+// handshake, after resumption.
+TEST_P(TlsConnectTls13, TestTls13ResumeNoCertificateRequest) {
+ ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET);
+ client_->SetupClientAuth();
+ server_->RequestClientAuth(true);
+ Connect();
+ SendReceive(); // Need to read so that we absorb the session ticket.
+ ScopedCERTCertificate cert1(SSL_LocalCertificate(client_->ssl_fd()));
+ ASSERT_NE(nullptr, cert1.get());
+
+ Reset();
+ ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET);
+ ExpectResumption(RESUME_TICKET);
+ server_->RequestClientAuth(false);
+ auto cr_capture =
+ MakeTlsFilter<TlsHandshakeRecorder>(server_, ssl_hs_certificate_request);
+ cr_capture->EnableDecryption();
+ Connect();
+ SendReceive();
+ EXPECT_EQ(0U, cr_capture->buffer().len()) << "expect nothing captured yet";
+
+ // Sanity check whether the client certificate matches the one
+ // decrypted from ticket.
+ ScopedCERTCertificate cert2(SSL_PeerCertificate(server_->ssl_fd()));
+ ASSERT_NE(nullptr, cert2.get());
+ EXPECT_TRUE(SECITEM_ItemsAreEqual(&cert1->derCert, &cert2->derCert));
+}
+
+// Here we test that 0.5 RTT is available at the server when resuming, even if
+// configured to request a client certificate. The resumed handshake relies on
+// the authentication from the original handshake, so no certificate is
+// requested this time around. The server can write before the handshake
+// completes because the PSK binder is sufficient authentication for the client.
+TEST_P(TlsConnectTls13, WriteBeforeHandshakeCompleteOnResumption) {
+ ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET);
+ client_->SetupClientAuth();
+ server_->RequestClientAuth(true);
+ Connect();
+ SendReceive(); // Absorb the session ticket.
+ ScopedCERTCertificate cert1(SSL_LocalCertificate(client_->ssl_fd()));
+ ASSERT_NE(nullptr, cert1.get());
+
+ Reset();
+ ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET);
+ ExpectResumption(RESUME_TICKET);
+ server_->RequestClientAuth(false);
+ StartConnect();
+ client_->Handshake(); // ClientHello
+ server_->Handshake(); // ServerHello
+
+ server_->SendData(10);
+ client_->ReadBytes(10); // Client should emit the Finished as a side-effect.
+ server_->Handshake(); // Server consumes the Finished.
+ CheckConnected();
+
+ // Check whether the client certificate matches the one from the ticket.
+ ScopedCERTCertificate cert2(SSL_PeerCertificate(server_->ssl_fd()));
+ ASSERT_NE(nullptr, cert2.get());
+ EXPECT_TRUE(SECITEM_ItemsAreEqual(&cert1->derCert, &cert2->derCert));
+}
+
+// We need to enable different cipher suites at different times in the following
+// tests. Those cipher suites need to be suited to the version.
+static uint16_t ChooseOneCipher(uint16_t version) {
+ if (version >= SSL_LIBRARY_VERSION_TLS_1_3) {
+ return TLS_AES_128_GCM_SHA256;
+ }
+ return TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA;
+}
+
+static uint16_t ChooseIncompatibleCipher(uint16_t version) {
+ if (version >= SSL_LIBRARY_VERSION_TLS_1_3) {
+ return TLS_AES_256_GCM_SHA384;
+ }
+ return TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA;
+}
+
+// Test that we don't resume when we can't negotiate the same cipher. Note that
+// for TLS 1.3, resumption is allowed between compatible ciphers, that is those
+// with the same KDF hash, but we choose an incompatible one here.
+TEST_P(TlsConnectGenericResumption, ResumeClientIncompatibleCipher) {
+ ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET);
+ client_->EnableSingleCipher(ChooseOneCipher(version_));
+ Connect();
+ SendReceive();
+ CheckKeys(ssl_kea_ecdh, ssl_auth_rsa_sign);
+
+ Reset();
+ ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET);
+ ExpectResumption(RESUME_NONE);
+ client_->EnableSingleCipher(ChooseIncompatibleCipher(version_));
+ uint16_t ticket_extension;
+ if (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) {
+ ticket_extension = ssl_tls13_pre_shared_key_xtn;
+ } else {
+ ticket_extension = ssl_session_ticket_xtn;
+ }
+ auto ticket_capture =
+ MakeTlsFilter<TlsExtensionCapture>(client_, ticket_extension);
+ Connect();
+ CheckKeys(ssl_kea_ecdh, ssl_auth_rsa_sign);
+ EXPECT_EQ(0U, ticket_capture->extension().len());
+}
+
+// Test that we don't resume when we can't negotiate the same cipher.
+TEST_P(TlsConnectGenericResumption, ResumeServerIncompatibleCipher) {
+ ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET);
+ server_->EnableSingleCipher(ChooseOneCipher(version_));
+ Connect();
+ SendReceive(); // Absorb the session ticket.
+ CheckKeys();
+
+ Reset();
+ ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET);
+ ExpectResumption(RESUME_NONE);
+ server_->EnableSingleCipher(ChooseIncompatibleCipher(version_));
+ Connect();
+ CheckKeys();
+}
+
+// Test that the client doesn't tolerate the server picking a different cipher
+// suite for resumption.
+TEST_P(TlsConnectStream, ResumptionOverrideCipher) {
+ ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET);
+ server_->EnableSingleCipher(ChooseOneCipher(version_));
+ Connect();
+ SendReceive();
+ CheckKeys(ssl_kea_ecdh, ssl_auth_rsa_sign);
+
+ Reset();
+ ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET);
+ MakeTlsFilter<SelectedCipherSuiteReplacer>(
+ server_, ChooseIncompatibleCipher(version_));
+
+ if (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) {
+ client_->ExpectSendAlert(kTlsAlertIllegalParameter);
+ server_->ExpectSendAlert(kTlsAlertUnexpectedMessage);
+ } else {
+ ExpectAlert(client_, kTlsAlertHandshakeFailure);
+ }
+ ConnectExpectFail();
+ client_->CheckErrorCode(SSL_ERROR_RX_MALFORMED_SERVER_HELLO);
+ if (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) {
+ // The reason this test is stream only: the server is unable to decrypt
+ // the alert that the client sends, see bug 1304603.
+ server_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_RECORD_TYPE);
+ } else {
+ server_->CheckErrorCode(SSL_ERROR_HANDSHAKE_FAILURE_ALERT);
+ }
+}
+
+// In TLS 1.3, it is possible to resume with a different cipher if it has the
+// same hash.
+TEST_P(TlsConnectTls13, ResumeClientCompatibleCipher) {
+ ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET);
+ client_->EnableSingleCipher(TLS_AES_128_GCM_SHA256);
+ Connect();
+ SendReceive(); // Absorb the session ticket.
+ CheckKeys();
+
+ Reset();
+ ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET);
+ ExpectResumption(RESUME_TICKET);
+ client_->EnableSingleCipher(TLS_CHACHA20_POLY1305_SHA256);
+ Connect();
+ CheckKeys();
+}
+
+TEST_P(TlsConnectTls13, ResumeServerCompatibleCipher) {
+ ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET);
+ server_->EnableSingleCipher(TLS_AES_128_GCM_SHA256);
+ Connect();
+ SendReceive(); // Absorb the session ticket.
+ CheckKeys();
+
+ Reset();
+ ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET);
+ ExpectResumption(RESUME_TICKET);
+ server_->EnableSingleCipher(TLS_CHACHA20_POLY1305_SHA256);
+ Connect();
+ CheckKeys();
+}
+
+class SelectedVersionReplacer : public TlsHandshakeFilter {
+ public:
+ SelectedVersionReplacer(const std::shared_ptr<TlsAgent>& a, uint16_t version)
+ : TlsHandshakeFilter(a, {kTlsHandshakeServerHello}), version_(version) {}
+
+ protected:
+ PacketFilter::Action FilterHandshake(const HandshakeHeader& header,
+ const DataBuffer& input,
+ DataBuffer* output) override {
+ *output = input;
+ output->Write(0, static_cast<uint32_t>(version_), 2);
+ return CHANGE;
+ }
+
+ private:
+ uint16_t version_;
+};
+
+// Test how the client handles the case where the server picks a
+// lower version number on resumption.
+TEST_P(TlsConnectGenericPre13, TestResumptionOverrideVersion) {
+ uint16_t override_version = 0;
+ if (variant_ == ssl_variant_stream) {
+ switch (version_) {
+ case SSL_LIBRARY_VERSION_TLS_1_0:
+ GTEST_SKIP();
+ case SSL_LIBRARY_VERSION_TLS_1_1:
+ override_version = SSL_LIBRARY_VERSION_TLS_1_0;
+ break;
+ case SSL_LIBRARY_VERSION_TLS_1_2:
+ override_version = SSL_LIBRARY_VERSION_TLS_1_1;
+ break;
+ default:
+ ASSERT_TRUE(false) << "unknown version";
+ }
+ } else {
+ if (version_ == SSL_LIBRARY_VERSION_TLS_1_2) {
+ override_version = SSL_LIBRARY_VERSION_DTLS_1_0_WIRE;
+ } else {
+ ASSERT_EQ(SSL_LIBRARY_VERSION_TLS_1_1, version_);
+ GTEST_SKIP();
+ }
+ }
+
+ ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET);
+ // Need to use a cipher that is plausible for the lower version.
+ server_->EnableSingleCipher(TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA);
+ Connect();
+ CheckKeys(ssl_kea_ecdh, ssl_auth_rsa_sign);
+
+ Reset();
+ ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET);
+ // Enable the lower version on the client.
+ client_->SetVersionRange(version_ - 1, version_);
+ server_->EnableSingleCipher(TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA);
+ MakeTlsFilter<SelectedVersionReplacer>(server_, override_version);
+
+ ConnectExpectAlert(client_, kTlsAlertHandshakeFailure);
+ client_->CheckErrorCode(SSL_ERROR_RX_MALFORMED_SERVER_HELLO);
+ server_->CheckErrorCode(SSL_ERROR_HANDSHAKE_FAILURE_ALERT);
+}
+
+// Test that two TLS resumptions work and produce the same ticket.
+// This will change after bug 1257047 is fixed.
+TEST_F(TlsConnectTest, TestTls13ResumptionTwice) {
+ ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET);
+ ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3);
+
+ Connect();
+ SendReceive(); // Need to read so that we absorb the session ticket.
+ CheckKeys();
+ uint16_t original_suite;
+ EXPECT_TRUE(client_->cipher_suite(&original_suite));
+
+ Reset();
+ ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET);
+ ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3);
+ ExpectResumption(RESUME_TICKET);
+ auto c1 =
+ MakeTlsFilter<TlsExtensionCapture>(client_, ssl_tls13_pre_shared_key_xtn);
+ Connect();
+ SendReceive();
+ CheckKeys(ssl_kea_ecdh, ssl_grp_ec_curve25519, ssl_auth_rsa_sign,
+ ssl_sig_rsa_pss_rsae_sha256);
+ // The filter will go away when we reset, so save the captured extension.
+ DataBuffer initialTicket(c1->extension());
+ ASSERT_LT(0U, initialTicket.len());
+
+ ScopedCERTCertificate cert1(SSL_PeerCertificate(client_->ssl_fd()));
+ ASSERT_NE(nullptr, cert1.get());
+
+ Reset();
+ ClearStats();
+ ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET);
+ ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3);
+ auto c2 =
+ MakeTlsFilter<TlsExtensionCapture>(client_, ssl_tls13_pre_shared_key_xtn);
+ ExpectResumption(RESUME_TICKET);
+ Connect();
+ SendReceive();
+ CheckKeys(ssl_kea_ecdh, ssl_grp_ec_curve25519, ssl_auth_rsa_sign,
+ ssl_sig_rsa_pss_rsae_sha256);
+ ASSERT_LT(0U, c2->extension().len());
+
+ ScopedCERTCertificate cert2(SSL_PeerCertificate(client_->ssl_fd()));
+ ASSERT_NE(nullptr, cert2.get());
+
+ // Check that the cipher suite is reported the same on both sides, though in
+ // TLS 1.3 resumption actually negotiates a different cipher suite.
+ uint16_t resumed_suite;
+ EXPECT_TRUE(server_->cipher_suite(&resumed_suite));
+ EXPECT_EQ(original_suite, resumed_suite);
+ EXPECT_TRUE(client_->cipher_suite(&resumed_suite));
+ EXPECT_EQ(original_suite, resumed_suite);
+
+ ASSERT_NE(initialTicket, c2->extension());
+}
+
+// Check that resumption works after receiving two NST messages.
+TEST_F(TlsConnectTest, TestTls13ResumptionDuplicateNST) {
+ ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET);
+ ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3);
+ Connect();
+
+ // Clear the session ticket keys to invalidate the old ticket.
+ ClearServerCache();
+ EXPECT_EQ(SECSuccess, SSL_SendSessionTicket(server_->ssl_fd(), NULL, 0));
+
+ SendReceive(); // Need to read so that we absorb the session tickets.
+ CheckKeys();
+
+ // Resume the connection.
+ Reset();
+ ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET);
+ ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3);
+ ExpectResumption(RESUME_TICKET);
+ Connect();
+ SendReceive();
+}
+
+// Check that the value captured in a NewSessionTicket message matches the value
+// captured from a pre_shared_key extension.
+void NstTicketMatchesPskIdentity(const DataBuffer& nst, const DataBuffer& psk) {
+ uint32_t len;
+
+ size_t offset = 4 + 4; // Skip ticket_lifetime and ticket_age_add.
+ ASSERT_TRUE(nst.Read(offset, 1, &len));
+ offset += 1 + len; // Skip ticket_nonce.
+
+ ASSERT_TRUE(nst.Read(offset, 2, &len));
+ offset += 2; // Skip the ticket length.
+ ASSERT_LE(offset + len, nst.len());
+ DataBuffer nst_ticket(nst.data() + offset, static_cast<size_t>(len));
+
+ offset = 2; // Skip the identities length.
+ ASSERT_TRUE(psk.Read(offset, 2, &len));
+ offset += 2; // Skip the identity length.
+ ASSERT_LE(offset + len, psk.len());
+ DataBuffer psk_ticket(psk.data() + offset, static_cast<size_t>(len));
+
+ EXPECT_EQ(nst_ticket, psk_ticket);
+}
+
+TEST_F(TlsConnectTest, TestTls13ResumptionDuplicateNSTWithToken) {
+ ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET);
+ ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3);
+
+ auto nst_capture =
+ MakeTlsFilter<TlsHandshakeRecorder>(server_, ssl_hs_new_session_ticket);
+ nst_capture->EnableDecryption();
+ Connect();
+
+ // Clear the session ticket keys to invalidate the old ticket.
+ ClearServerCache();
+ nst_capture->Reset();
+ uint8_t token[] = {0x20, 0x20, 0xff, 0x00};
+ EXPECT_EQ(SECSuccess,
+ SSL_SendSessionTicket(server_->ssl_fd(), token, sizeof(token)));
+
+ SendReceive(); // Need to read so that we absorb the session tickets.
+ CheckKeys();
+ EXPECT_LT(0U, nst_capture->buffer().len());
+
+ // Resume the connection.
+ Reset();
+ ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET);
+ ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3);
+ ExpectResumption(RESUME_TICKET);
+
+ auto psk_capture =
+ MakeTlsFilter<TlsExtensionCapture>(client_, ssl_tls13_pre_shared_key_xtn);
+ Connect();
+ SendReceive();
+
+ NstTicketMatchesPskIdentity(nst_capture->buffer(), psk_capture->extension());
+}
+
+// Disable SSL_ENABLE_SESSION_TICKETS but ensure that tickets can still be sent
+// by invoking SSL_SendSessionTicket directly (and that the ticket is usable).
+TEST_F(TlsConnectTest, SendSessionTicketWithTicketsDisabled) {
+ ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET);
+ ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3);
+
+ server_->SetOption(SSL_ENABLE_SESSION_TICKETS, PR_FALSE);
+
+ auto nst_capture =
+ MakeTlsFilter<TlsHandshakeRecorder>(server_, ssl_hs_new_session_ticket);
+ nst_capture->EnableDecryption();
+ Connect();
+
+ EXPECT_EQ(0U, nst_capture->buffer().len()) << "expect nothing captured yet";
+
+ EXPECT_EQ(SECSuccess, SSL_SendSessionTicket(server_->ssl_fd(), NULL, 0));
+ EXPECT_LT(0U, nst_capture->buffer().len()) << "should capture now";
+
+ SendReceive(); // Ensure that the client reads the ticket.
+
+ // Resume the connection.
+ Reset();
+ ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET);
+ ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3);
+ ExpectResumption(RESUME_TICKET);
+
+ auto psk_capture =
+ MakeTlsFilter<TlsExtensionCapture>(client_, ssl_tls13_pre_shared_key_xtn);
+ Connect();
+ SendReceive();
+
+ NstTicketMatchesPskIdentity(nst_capture->buffer(), psk_capture->extension());
+}
+
+// Successfully send a session ticket after resuming and then use it.
+TEST_F(TlsConnectTest, SendTicketAfterResumption) {
+ ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET);
+ ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3);
+ Connect();
+
+ SendReceive(); // Need to read so that we absorb the session tickets.
+ CheckKeys();
+
+ // Resume the connection.
+ Reset();
+ ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET);
+ ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3);
+ ExpectResumption(RESUME_TICKET);
+
+ // We need to capture just one ticket, so
+ // disable automatic sending of tickets at the server.
+ // ConfigureSessionCache enables this option, so revert that.
+ server_->SetOption(SSL_ENABLE_SESSION_TICKETS, PR_FALSE);
+ auto nst_capture =
+ MakeTlsFilter<TlsHandshakeRecorder>(server_, ssl_hs_new_session_ticket);
+ nst_capture->EnableDecryption();
+ Connect();
+
+ ClearServerCache();
+ EXPECT_EQ(SECSuccess, SSL_SendSessionTicket(server_->ssl_fd(), NULL, 0));
+ SendReceive();
+
+ // Reset stats so that the counters for resumptions match up.
+ ClearStats();
+ // Resume again and ensure that we get the same ticket.
+ Reset();
+ ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET);
+ ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3);
+ ExpectResumption(RESUME_TICKET);
+
+ auto psk_capture =
+ MakeTlsFilter<TlsExtensionCapture>(client_, ssl_tls13_pre_shared_key_xtn);
+ Connect();
+ SendReceive();
+
+ NstTicketMatchesPskIdentity(nst_capture->buffer(), psk_capture->extension());
+}
+
+// Test calling SSL_SendSessionTicket in inappropriate conditions.
+TEST_F(TlsConnectTest, SendSessionTicketInappropriate) {
+ ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET);
+ ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_2);
+
+ EXPECT_EQ(SECFailure, SSL_SendSessionTicket(client_->ssl_fd(), NULL, 0))
+ << "clients can't send tickets";
+ EXPECT_EQ(SEC_ERROR_INVALID_ARGS, PORT_GetError());
+
+ StartConnect();
+
+ EXPECT_EQ(SECFailure, SSL_SendSessionTicket(server_->ssl_fd(), NULL, 0))
+ << "no ticket before the handshake has started";
+ EXPECT_EQ(SEC_ERROR_INVALID_ARGS, PORT_GetError());
+ Handshake();
+ EXPECT_EQ(SECFailure, SSL_SendSessionTicket(server_->ssl_fd(), NULL, 0))
+ << "no special tickets in TLS 1.2";
+ EXPECT_EQ(SEC_ERROR_INVALID_ARGS, PORT_GetError());
+}
+
+TEST_F(TlsConnectTest, SendSessionTicketMassiveToken) {
+ ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET);
+ ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3);
+ Connect();
+ // It should be safe to set length with a NULL token because the length should
+ // be checked before reading token.
+ EXPECT_EQ(SECFailure, SSL_SendSessionTicket(server_->ssl_fd(), NULL, 0x1ffff))
+ << "this is clearly too big";
+ EXPECT_EQ(SEC_ERROR_INVALID_ARGS, PORT_GetError());
+
+ static const uint8_t big_token[0xffff] = {1};
+ EXPECT_EQ(SECFailure, SSL_SendSessionTicket(server_->ssl_fd(), big_token,
+ sizeof(big_token)))
+ << "this is too big, but that's not immediately obvious";
+ EXPECT_EQ(SEC_ERROR_INVALID_ARGS, PORT_GetError());
+}
+
+TEST_F(TlsConnectDatagram13, SendSessionTicketDtls) {
+ ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET);
+ ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3);
+ Connect();
+ EXPECT_EQ(SECFailure, SSL_SendSessionTicket(server_->ssl_fd(), NULL, 0))
+ << "no extra tickets in DTLS until we have Ack support";
+ EXPECT_EQ(SSL_ERROR_FEATURE_NOT_SUPPORTED_FOR_VERSION, PORT_GetError());
+}
+
+TEST_F(TlsConnectStreamTls13, ExternalResumptionUseSecondTicket) {
+ ConfigureSessionCache(RESUME_BOTH, RESUME_BOTH);
+ ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3);
+
+ struct ResumptionTicketState {
+ std::vector<uint8_t> ticket;
+ size_t invoked = 0;
+ } ticket_state;
+ auto cb = [](PRFileDesc* fd, const PRUint8* ticket, unsigned int ticket_len,
+ void* arg) -> SECStatus {
+ auto state = reinterpret_cast<ResumptionTicketState*>(arg);
+ state->ticket.assign(ticket, ticket + ticket_len);
+ state->invoked++;
+ return SECSuccess;
+ };
+ EXPECT_EQ(SECSuccess, SSL_SetResumptionTokenCallback(client_->ssl_fd(), cb,
+ &ticket_state));
+
+ Connect();
+ EXPECT_EQ(SECSuccess, SSL_SendSessionTicket(server_->ssl_fd(), nullptr, 0));
+ SendReceive();
+ EXPECT_EQ(2U, ticket_state.invoked);
+
+ Reset();
+ ConfigureSessionCache(RESUME_BOTH, RESUME_BOTH);
+ client_->SetResumptionToken(ticket_state.ticket);
+ ExpectResumption(RESUME_TICKET);
+ Connect();
+ SendReceive();
+}
+
+TEST_F(TlsConnectTest, TestTls13ResumptionDowngrade) {
+ ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET);
+ ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3);
+ Connect();
+
+ SendReceive(); // Need to read so that we absorb the session tickets.
+ CheckKeys();
+
+ // Try resuming the connection. This will fail resuming the 1.3 session
+ // from before, but will successfully establish a 1.2 connection.
+ Reset();
+ ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET);
+ client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2,
+ SSL_LIBRARY_VERSION_TLS_1_3);
+ server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2,
+ SSL_LIBRARY_VERSION_TLS_1_2);
+ Connect();
+
+ // Renegotiate to ensure we don't carryover any state
+ // from the 1.3 resumption attempt.
+ client_->SetExpectedVersion(SSL_LIBRARY_VERSION_TLS_1_2);
+ client_->PrepareForRenegotiate();
+ server_->StartRenegotiate();
+ Handshake();
+
+ SendReceive();
+ CheckKeys();
+}
+
+TEST_F(TlsConnectTest, TestTls13ResumptionForcedDowngrade) {
+ ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET);
+ ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3);
+ Connect();
+
+ SendReceive(); // Need to read so that we absorb the session tickets.
+ CheckKeys();
+
+ // Try resuming the connection.
+ Reset();
+ ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3);
+ ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET);
+ // Enable the lower version on the client.
+ client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2,
+ SSL_LIBRARY_VERSION_TLS_1_3);
+
+ // Add filters that set downgrade SH.version to 1.2 and the cipher suite
+ // to one that works with 1.2, so that we don't run into early sanity checks.
+ // We will eventually fail the (sid.version == SH.version) check.
+ std::vector<std::shared_ptr<PacketFilter>> filters;
+ filters.push_back(std::make_shared<SelectedCipherSuiteReplacer>(
+ server_, TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256));
+ filters.push_back(std::make_shared<SelectedVersionReplacer>(
+ server_, SSL_LIBRARY_VERSION_TLS_1_2));
+
+ // Drop a bunch of extensions so that we get past the SH processing. The
+ // version extension says TLS 1.3, which is counter to our goal, the others
+ // are not permitted in TLS 1.2 handshakes.
+ filters.push_back(std::make_shared<TlsExtensionDropper>(
+ server_, ssl_tls13_supported_versions_xtn));
+ filters.push_back(
+ std::make_shared<TlsExtensionDropper>(server_, ssl_tls13_key_share_xtn));
+ filters.push_back(std::make_shared<TlsExtensionDropper>(
+ server_, ssl_tls13_pre_shared_key_xtn));
+ server_->SetFilter(std::make_shared<ChainedPacketFilter>(filters));
+
+ // The client here generates an unexpected_message alert when it receives an
+ // encrypted handshake message from the server (EncryptedExtension). The
+ // client expects to receive an unencrypted TLS 1.2 Certificate message.
+ // The server can't decrypt the alert.
+ client_->ExpectSendAlert(kTlsAlertUnexpectedMessage);
+ server_->ExpectSendAlert(kTlsAlertUnexpectedMessage); // Server can't read
+ ConnectExpectFail();
+ client_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_APPLICATION_DATA);
+ server_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_RECORD_TYPE);
+}
+
+TEST_P(TlsConnectGenericResumption, ReConnectTicket) {
+ ConfigureSessionCache(RESUME_BOTH, RESUME_BOTH);
+ server_->EnableSingleCipher(ChooseOneCipher(version_));
+ Connect();
+ SendReceive();
+ CheckKeys(ssl_kea_ecdh, ssl_grp_ec_curve25519, ssl_auth_rsa_sign,
+ ssl_sig_rsa_pss_rsae_sha256);
+ // Resume
+ Reset();
+ ConfigureSessionCache(RESUME_BOTH, RESUME_BOTH);
+ ExpectResumption(RESUME_TICKET);
+ Connect();
+ // Only the client knows this.
+ CheckKeysResumption(ssl_kea_ecdh, ssl_grp_none, ssl_grp_ec_curve25519,
+ ssl_auth_rsa_sign, ssl_sig_rsa_pss_rsae_sha256);
+}
+
+TEST_P(TlsConnectGenericPre13, ReConnectCache) {
+ ConfigureSessionCache(RESUME_SESSIONID, RESUME_SESSIONID);
+ server_->EnableSingleCipher(ChooseOneCipher(version_));
+ Connect();
+ SendReceive();
+ CheckKeys(ssl_kea_ecdh, ssl_grp_ec_curve25519, ssl_auth_rsa_sign,
+ ssl_sig_rsa_pss_rsae_sha256);
+ // Resume
+ Reset();
+ ExpectResumption(RESUME_SESSIONID);
+ Connect();
+ CheckKeysResumption(ssl_kea_ecdh, ssl_grp_none, ssl_grp_ec_curve25519,
+ ssl_auth_rsa_sign, ssl_sig_rsa_pss_rsae_sha256);
+}
+
+TEST_P(TlsConnectGenericResumption, ReConnectAgainTicket) {
+ ConfigureSessionCache(RESUME_BOTH, RESUME_BOTH);
+ server_->EnableSingleCipher(ChooseOneCipher(version_));
+ Connect();
+ SendReceive();
+ CheckKeys(ssl_kea_ecdh, ssl_grp_ec_curve25519, ssl_auth_rsa_sign,
+ ssl_sig_rsa_pss_rsae_sha256);
+ // Resume
+ Reset();
+ ConfigureSessionCache(RESUME_BOTH, RESUME_BOTH);
+ ExpectResumption(RESUME_TICKET);
+ Connect();
+ // Only the client knows this.
+ CheckKeysResumption(ssl_kea_ecdh, ssl_grp_none, ssl_grp_ec_curve25519,
+ ssl_auth_rsa_sign, ssl_sig_rsa_pss_rsae_sha256);
+ // Resume connection again
+ Reset();
+ ConfigureSessionCache(RESUME_BOTH, RESUME_BOTH);
+ ExpectResumption(RESUME_TICKET, 2);
+ Connect();
+ // Only the client knows this.
+ CheckKeysResumption(ssl_kea_ecdh, ssl_grp_none, ssl_grp_ec_curve25519,
+ ssl_auth_rsa_sign, ssl_sig_rsa_pss_rsae_sha256);
+}
+
+void CheckGetInfoResult(PRTime now, uint32_t alpnSize, uint32_t earlyDataSize,
+ ScopedCERTCertificate& cert,
+ ScopedSSLResumptionTokenInfo& token) {
+ ASSERT_TRUE(cert);
+ ASSERT_TRUE(token->peerCert);
+
+ // Check that the server cert is the correct one.
+ ASSERT_EQ(cert->derCert.len, token->peerCert->derCert.len);
+ EXPECT_EQ(0, memcmp(cert->derCert.data, token->peerCert->derCert.data,
+ cert->derCert.len));
+
+ ASSERT_EQ(alpnSize, token->alpnSelectionLen);
+ EXPECT_EQ(0, memcmp("a", token->alpnSelection, token->alpnSelectionLen));
+
+ ASSERT_EQ(earlyDataSize, token->maxEarlyDataSize);
+
+ ASSERT_LT(now, token->expirationTime);
+}
+
+// The client should generate a new, randomized session_id
+// when resuming using an external token.
+TEST_P(TlsConnectGenericResumptionToken, CheckSessionId) {
+ ConfigureSessionCache(RESUME_BOTH, RESUME_BOTH);
+ auto original_sid = MakeTlsFilter<CaptureSessionId>(client_);
+ Connect();
+ SendReceive();
+
+ Reset();
+ ConfigureSessionCache(RESUME_BOTH, RESUME_BOTH);
+ ExpectResumption(RESUME_TICKET);
+
+ StartConnect();
+ ASSERT_TRUE(client_->MaybeSetResumptionToken());
+ auto resumed_sid = MakeTlsFilter<CaptureSessionId>(client_);
+
+ Handshake();
+ CheckConnected();
+ SendReceive();
+
+ if (version_ < SSL_LIBRARY_VERSION_TLS_1_3) {
+ EXPECT_NE(resumed_sid->sid(), original_sid->sid());
+ EXPECT_EQ(32U, resumed_sid->sid().len());
+ } else {
+ EXPECT_EQ(0U, resumed_sid->sid().len());
+ }
+}
+
+TEST_P(TlsConnectGenericResumptionToken, ConnectResumeGetInfo) {
+ ConfigureSessionCache(RESUME_BOTH, RESUME_BOTH);
+ Connect();
+ SendReceive();
+
+ Reset();
+ ConfigureSessionCache(RESUME_BOTH, RESUME_BOTH);
+ ExpectResumption(RESUME_TICKET);
+
+ StartConnect();
+ ASSERT_TRUE(client_->MaybeSetResumptionToken());
+
+ // Get resumption token infos
+ SSLResumptionTokenInfo tokenInfo = {0};
+ ScopedSSLResumptionTokenInfo token(&tokenInfo);
+ client_->GetTokenInfo(token);
+ ScopedCERTCertificate cert(
+ PK11_FindCertFromNickname(server_->name().c_str(), nullptr));
+ ASSERT_NE(nullptr, cert.get());
+
+ CheckGetInfoResult(now(), 0, 0, cert, token);
+
+ Handshake();
+ CheckConnected();
+
+ SendReceive();
+}
+
+TEST_P(TlsConnectGenericResumptionToken, RefuseExpiredTicketClient) {
+ ConfigureSessionCache(RESUME_BOTH, RESUME_BOTH);
+ Connect();
+ SendReceive();
+
+ // Move the clock to the expiration time of the ticket.
+ SSLResumptionTokenInfo tokenInfo = {0};
+ ScopedSSLResumptionTokenInfo token(&tokenInfo);
+ client_->GetTokenInfo(token);
+ AdvanceTime(token->expirationTime - now());
+
+ Reset();
+ ConfigureSessionCache(RESUME_BOTH, RESUME_BOTH);
+ ExpectResumption(RESUME_TICKET);
+
+ StartConnect();
+ ASSERT_EQ(SECFailure,
+ SSL_SetResumptionToken(client_->ssl_fd(),
+ client_->GetResumptionToken().data(),
+ client_->GetResumptionToken().size()));
+ EXPECT_EQ(SSL_ERROR_BAD_RESUMPTION_TOKEN_ERROR, PORT_GetError());
+}
+
+TEST_P(TlsConnectGenericResumptionToken, RefuseExpiredTicketServer) {
+ ConfigureSessionCache(RESUME_BOTH, RESUME_BOTH);
+ Connect();
+ SendReceive();
+
+ Reset();
+ ConfigureSessionCache(RESUME_BOTH, RESUME_BOTH);
+ ExpectResumption(RESUME_NONE);
+
+ // Start the handshake and send the ClientHello.
+ StartConnect();
+ ASSERT_EQ(SECSuccess,
+ SSL_SetResumptionToken(client_->ssl_fd(),
+ client_->GetResumptionToken().data(),
+ client_->GetResumptionToken().size()));
+ client_->Handshake();
+
+ // Move the clock to the expiration time of the ticket.
+ SSLResumptionTokenInfo tokenInfo = {0};
+ ScopedSSLResumptionTokenInfo token(&tokenInfo);
+ client_->GetTokenInfo(token);
+ AdvanceTime(token->expirationTime - now());
+
+ Handshake();
+ CheckConnected();
+}
+
+TEST_P(TlsConnectGenericResumptionToken, ConnectResumeGetInfoAlpn) {
+ EnableAlpn();
+ ConfigureSessionCache(RESUME_BOTH, RESUME_BOTH);
+ Connect();
+ CheckAlpn("a");
+ SendReceive();
+
+ Reset();
+ EnableAlpn();
+ ConfigureSessionCache(RESUME_BOTH, RESUME_BOTH);
+ ExpectResumption(RESUME_TICKET);
+
+ StartConnect();
+ ASSERT_TRUE(client_->MaybeSetResumptionToken());
+
+ // Get resumption token infos
+ SSLResumptionTokenInfo tokenInfo = {0};
+ ScopedSSLResumptionTokenInfo token(&tokenInfo);
+ client_->GetTokenInfo(token);
+ ScopedCERTCertificate cert(
+ PK11_FindCertFromNickname(server_->name().c_str(), nullptr));
+ ASSERT_NE(nullptr, cert.get());
+
+ CheckGetInfoResult(now(), 1, 0, cert, token);
+
+ Handshake();
+ CheckConnected();
+ CheckAlpn("a");
+
+ SendReceive();
+}
+
+TEST_P(TlsConnectTls13ResumptionToken, ConnectResumeGetInfoZeroRtt) {
+ EnableAlpn();
+ RolloverAntiReplay();
+ ConfigureSessionCache(RESUME_BOTH, RESUME_BOTH);
+ server_->Set0RttEnabled(true);
+ Connect();
+ CheckAlpn("a");
+ SendReceive();
+
+ Reset();
+ EnableAlpn();
+ ConfigureSessionCache(RESUME_BOTH, RESUME_BOTH);
+ ExpectResumption(RESUME_TICKET);
+
+ StartConnect();
+ server_->Set0RttEnabled(true);
+ client_->Set0RttEnabled(true);
+ ASSERT_TRUE(client_->MaybeSetResumptionToken());
+
+ // Get resumption token infos
+ SSLResumptionTokenInfo tokenInfo = {0};
+ ScopedSSLResumptionTokenInfo token(&tokenInfo);
+ client_->GetTokenInfo(token);
+ ScopedCERTCertificate cert(
+ PK11_FindCertFromNickname(server_->name().c_str(), nullptr));
+ ASSERT_NE(nullptr, cert.get());
+ CheckGetInfoResult(now(), 1, 1024, cert, token);
+
+ ZeroRttSendReceive(true, true);
+ Handshake();
+ ExpectEarlyDataAccepted(true);
+ CheckConnected();
+ CheckAlpn("a");
+
+ SendReceive();
+}
+
+// Resumption on sessions with client authentication only works with internal
+// caching.
+TEST_P(TlsConnectGenericResumption, ConnectResumeClientAuth) {
+ ConfigureSessionCache(RESUME_BOTH, RESUME_BOTH);
+ client_->SetupClientAuth();
+ server_->RequestClientAuth(true);
+ Connect();
+ SendReceive();
+ EXPECT_FALSE(client_->resumption_callback_called());
+
+ Reset();
+ ConfigureSessionCache(RESUME_BOTH, RESUME_BOTH);
+ if (use_external_cache()) {
+ ExpectResumption(RESUME_NONE);
+ } else {
+ ExpectResumption(RESUME_TICKET);
+ }
+ Connect();
+ SendReceive();
+}
+
+// Check that resumption is blocked if the server requires client auth.
+TEST_P(TlsConnectGenericResumption, ClientAuthRequiredOnResumption) {
+ ConfigureSessionCache(RESUME_BOTH, RESUME_BOTH);
+ server_->RequestClientAuth(false);
+ Connect();
+ SendReceive();
+
+ Reset();
+ client_->SetupClientAuth();
+ server_->RequestClientAuth(true);
+ ConfigureSessionCache(RESUME_BOTH, RESUME_BOTH);
+ ExpectResumption(RESUME_NONE);
+ Connect();
+ SendReceive();
+}
+
+// Check that resumption is blocked if the server requires client auth and
+// the client fails to provide a certificate.
+TEST_P(TlsConnectGenericResumption, ClientAuthRequiredOnResumptionNoCert) {
+ ConfigureSessionCache(RESUME_BOTH, RESUME_BOTH);
+ server_->RequestClientAuth(false);
+ Connect();
+ SendReceive();
+
+ Reset();
+ server_->RequestClientAuth(true);
+ ConfigureSessionCache(RESUME_BOTH, RESUME_BOTH);
+ // Drive handshake manually because TLS 1.3 needs it.
+ StartConnect();
+ client_->Handshake(); // CH
+ server_->Handshake(); // SH.. (no resumption)
+ client_->Handshake(); // ...
+ if (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) {
+ // In TLS 1.3, the client thinks that everything is OK here.
+ ASSERT_EQ(TlsAgent::STATE_CONNECTED, client_->state());
+ ExpectAlert(server_, kTlsAlertCertificateRequired);
+ server_->Handshake(); // Alert
+ client_->Handshake(); // Receive Alert
+ client_->CheckErrorCode(SSL_ERROR_RX_CERTIFICATE_REQUIRED_ALERT);
+ } else {
+ ExpectAlert(server_, kTlsAlertBadCertificate);
+ server_->Handshake(); // Alert
+ client_->Handshake(); // Receive Alert
+ client_->CheckErrorCode(SSL_ERROR_BAD_CERT_ALERT);
+ }
+ server_->CheckErrorCode(SSL_ERROR_NO_CERTIFICATE);
+}
+
+TEST_F(TlsConnectStreamTls13, ExternalTokenAfterHrr) {
+ ConfigureSessionCache(RESUME_BOTH, RESUME_BOTH);
+ Connect();
+ SendReceive();
+
+ Reset();
+ ConfigureSessionCache(RESUME_BOTH, RESUME_BOTH);
+ ExpectResumption(RESUME_TICKET);
+
+ static const std::vector<SSLNamedGroup> groups = {ssl_grp_ec_secp384r1,
+ ssl_grp_ec_secp521r1};
+ server_->ConfigNamedGroups(groups);
+
+ StartConnect();
+ ASSERT_TRUE(client_->MaybeSetResumptionToken());
+
+ client_->Handshake(); // Send ClientHello.
+ server_->Handshake(); // Process ClientHello, send HelloRetryRequest.
+
+ auto& token = client_->GetResumptionToken();
+ SECStatus rv =
+ SSL_SetResumptionToken(client_->ssl_fd(), token.data(), token.size());
+ ASSERT_EQ(SECFailure, rv);
+ ASSERT_EQ(SEC_ERROR_INVALID_ARGS, PORT_GetError());
+
+ Handshake();
+ CheckConnected();
+ SendReceive();
+}
+
+TEST_F(TlsConnectStreamTls13, ExternalTokenWithPeerId) {
+ ConfigureSessionCache(RESUME_BOTH, RESUME_BOTH);
+ ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3);
+ EXPECT_EQ(SECSuccess, SSL_SetSockPeerID(client_->ssl_fd(), "testPeerId"));
+ std::vector<uint8_t> ticket_state;
+ auto cb = [](PRFileDesc* fd, const PRUint8* ticket, unsigned int ticket_len,
+ void* arg) -> SECStatus {
+ EXPECT_NE(0U, ticket_len);
+ EXPECT_NE(nullptr, ticket);
+ auto ticket_state_ = reinterpret_cast<std::vector<uint8_t>*>(arg);
+ ticket_state_->assign(ticket, ticket + ticket_len);
+ return SECSuccess;
+ };
+ EXPECT_EQ(SECSuccess, SSL_SetResumptionTokenCallback(client_->ssl_fd(), cb,
+ &ticket_state));
+
+ Connect();
+ SendReceive();
+ EXPECT_NE(0U, ticket_state.size());
+
+ Reset();
+ ConfigureSessionCache(RESUME_BOTH, RESUME_BOTH);
+ EXPECT_EQ(SECSuccess, SSL_SetSockPeerID(client_->ssl_fd(), "testPeerId"));
+ client_->SetResumptionToken(ticket_state);
+ ASSERT_TRUE(client_->MaybeSetResumptionToken());
+ ExpectResumption(RESUME_TICKET);
+ Connect();
+ SendReceive();
+}
+
+} // namespace nss_test
diff --git a/security/nss/gtests/ssl_gtest/ssl_skip_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_skip_unittest.cc
new file mode 100644
index 0000000000..606e731033
--- /dev/null
+++ b/security/nss/gtests/ssl_gtest/ssl_skip_unittest.cc
@@ -0,0 +1,246 @@
+/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
+/* vim: set ts=2 et sw=2 tw=80: */
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this file,
+ * You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+#include "sslerr.h"
+
+#include "tls_connect.h"
+#include "tls_filter.h"
+#include "tls_parser.h"
+
+/*
+ * The tests in this file test that the TLS state machine is robust against
+ * attacks that alter the order of handshake messages.
+ *
+ * See <https://www.smacktls.com/smack.pdf> for a description of the problems
+ * that this sort of attack can enable.
+ */
+namespace nss_test {
+
+class TlsHandshakeSkipFilter : public TlsRecordFilter {
+ public:
+ // A TLS record filter that skips handshake messages of the identified type.
+ TlsHandshakeSkipFilter(const std::shared_ptr<TlsAgent>& a,
+ uint8_t handshake_type)
+ : TlsRecordFilter(a), handshake_type_(handshake_type), skipped_(false) {}
+
+ protected:
+ // Takes a record; if it is a handshake record, it removes the first handshake
+ // message that is of handshake_type_ type.
+ virtual PacketFilter::Action FilterRecord(
+ const TlsRecordHeader& record_header, const DataBuffer& input,
+ DataBuffer* output) {
+ if (record_header.content_type() != ssl_ct_handshake) {
+ return KEEP;
+ }
+
+ size_t output_offset = 0U;
+ output->Allocate(input.len());
+
+ TlsParser parser(input);
+ while (parser.remaining()) {
+ size_t start = parser.consumed();
+ TlsHandshakeFilter::HandshakeHeader header;
+ DataBuffer ignored;
+ bool complete = false;
+ if (!header.Parse(&parser, record_header, DataBuffer(), &ignored,
+ &complete)) {
+ ADD_FAILURE() << "Error parsing handshake header";
+ return KEEP;
+ }
+ if (!complete) {
+ ADD_FAILURE() << "Don't want to deal with fragmented input";
+ return KEEP;
+ }
+
+ if (skipped_ || header.handshake_type() != handshake_type_) {
+ size_t entire_length = parser.consumed() - start;
+ output->Write(output_offset, input.data() + start, entire_length);
+ // DTLS sequence numbers need to be rewritten
+ if (skipped_ && header.is_dtls()) {
+ output->data()[start + 5] -= 1;
+ }
+ output_offset += entire_length;
+ } else {
+ std::cerr << "Dropping handshake: "
+ << static_cast<unsigned>(handshake_type_) << std::endl;
+ // We only need to report that the output contains changed data if we
+ // drop a handshake message. But once we've skipped one message, we
+ // have to modify all subsequent handshake messages so that they include
+ // the correct DTLS sequence numbers.
+ skipped_ = true;
+ }
+ }
+ output->Truncate(output_offset);
+ return skipped_ ? CHANGE : KEEP;
+ }
+
+ private:
+ // The type of handshake message to drop.
+ uint8_t handshake_type_;
+ // Whether this filter has ever skipped a handshake message. Track this so
+ // that sequence numbers on DTLS handshake messages can be rewritten in
+ // subsequent calls.
+ bool skipped_;
+};
+
+class TlsSkipTest : public TlsConnectTestBase,
+ public ::testing::WithParamInterface<
+ std::tuple<SSLProtocolVariant, uint16_t>> {
+ protected:
+ TlsSkipTest()
+ : TlsConnectTestBase(std::get<0>(GetParam()), std::get<1>(GetParam())) {}
+
+ void SetUp() override {
+ TlsConnectTestBase::SetUp();
+ EnsureTlsSetup();
+ }
+
+ void ServerSkipTest(std::shared_ptr<PacketFilter> filter,
+ uint8_t alert = kTlsAlertUnexpectedMessage) {
+ server_->SetFilter(filter);
+ ConnectExpectAlert(client_, alert);
+ }
+};
+
+class Tls13SkipTest : public TlsConnectTestBase,
+ public ::testing::WithParamInterface<SSLProtocolVariant> {
+ protected:
+ Tls13SkipTest()
+ : TlsConnectTestBase(GetParam(), SSL_LIBRARY_VERSION_TLS_1_3) {}
+
+ void SetUp() override {
+ TlsConnectTestBase::SetUp();
+ EnsureTlsSetup();
+ }
+
+ void ServerSkipTest(std::shared_ptr<TlsRecordFilter> filter, int32_t error) {
+ filter->EnableDecryption();
+ server_->SetFilter(filter);
+ ExpectAlert(client_, kTlsAlertUnexpectedMessage);
+ ConnectExpectFail();
+ client_->CheckErrorCode(error);
+ server_->CheckErrorCode(SSL_ERROR_HANDSHAKE_UNEXPECTED_ALERT);
+ }
+
+ void ClientSkipTest(std::shared_ptr<TlsRecordFilter> filter, int32_t error) {
+ filter->EnableDecryption();
+ client_->SetFilter(filter);
+ server_->ExpectSendAlert(kTlsAlertUnexpectedMessage);
+ ConnectExpectFailOneSide(TlsAgent::SERVER);
+
+ server_->CheckErrorCode(error);
+ ASSERT_EQ(TlsAgent::STATE_CONNECTED, client_->state());
+
+ client_->Handshake(); // Make sure to consume the alert the server sends.
+ }
+};
+
+TEST_P(TlsSkipTest, SkipCertificateRsa) {
+ EnableOnlyStaticRsaCiphers();
+ ServerSkipTest(std::make_shared<TlsHandshakeSkipFilter>(
+ server_, kTlsHandshakeCertificate));
+ client_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_HELLO_DONE);
+}
+
+TEST_P(TlsSkipTest, SkipCertificateDhe) {
+ ServerSkipTest(std::make_shared<TlsHandshakeSkipFilter>(
+ server_, kTlsHandshakeCertificate));
+ client_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_SERVER_KEY_EXCH);
+}
+
+TEST_P(TlsSkipTest, SkipCertificateEcdhe) {
+ ServerSkipTest(std::make_shared<TlsHandshakeSkipFilter>(
+ server_, kTlsHandshakeCertificate));
+ client_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_SERVER_KEY_EXCH);
+}
+
+TEST_P(TlsSkipTest, SkipCertificateEcdsa) {
+ Reset(TlsAgent::kServerEcdsa256);
+ ServerSkipTest(std::make_shared<TlsHandshakeSkipFilter>(
+ server_, kTlsHandshakeCertificate));
+ client_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_SERVER_KEY_EXCH);
+}
+
+TEST_P(TlsSkipTest, SkipServerKeyExchange) {
+ ServerSkipTest(std::make_shared<TlsHandshakeSkipFilter>(
+ server_, kTlsHandshakeServerKeyExchange));
+ client_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_HELLO_DONE);
+}
+
+TEST_P(TlsSkipTest, SkipServerKeyExchangeEcdsa) {
+ Reset(TlsAgent::kServerEcdsa256);
+ ServerSkipTest(std::make_shared<TlsHandshakeSkipFilter>(
+ server_, kTlsHandshakeServerKeyExchange));
+ client_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_HELLO_DONE);
+}
+
+TEST_P(TlsSkipTest, SkipCertAndKeyExch) {
+ auto chain = std::make_shared<ChainedPacketFilter>(
+ ChainedPacketFilterInit{std::make_shared<TlsHandshakeSkipFilter>(
+ server_, kTlsHandshakeCertificate),
+ std::make_shared<TlsHandshakeSkipFilter>(
+ server_, kTlsHandshakeServerKeyExchange)});
+ ServerSkipTest(chain);
+ client_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_HELLO_DONE);
+}
+
+TEST_P(TlsSkipTest, SkipCertAndKeyExchEcdsa) {
+ Reset(TlsAgent::kServerEcdsa256);
+ auto chain = std::make_shared<ChainedPacketFilter>();
+ chain->Add(std::make_shared<TlsHandshakeSkipFilter>(
+ server_, kTlsHandshakeCertificate));
+ chain->Add(std::make_shared<TlsHandshakeSkipFilter>(
+ server_, kTlsHandshakeServerKeyExchange));
+ ServerSkipTest(chain);
+ client_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_HELLO_DONE);
+}
+
+TEST_P(Tls13SkipTest, SkipEncryptedExtensions) {
+ ServerSkipTest(std::make_shared<TlsHandshakeSkipFilter>(
+ server_, kTlsHandshakeEncryptedExtensions),
+ SSL_ERROR_RX_UNEXPECTED_CERTIFICATE);
+}
+
+TEST_P(Tls13SkipTest, SkipServerCertificate) {
+ ServerSkipTest(std::make_shared<TlsHandshakeSkipFilter>(
+ server_, kTlsHandshakeCertificate),
+ SSL_ERROR_RX_UNEXPECTED_CERT_VERIFY);
+}
+
+TEST_P(Tls13SkipTest, SkipServerCertificateVerify) {
+ ServerSkipTest(std::make_shared<TlsHandshakeSkipFilter>(
+ server_, kTlsHandshakeCertificateVerify),
+ SSL_ERROR_RX_UNEXPECTED_FINISHED);
+}
+
+TEST_P(Tls13SkipTest, SkipClientCertificate) {
+ client_->SetupClientAuth();
+ server_->RequestClientAuth(true);
+ client_->ExpectReceiveAlert(kTlsAlertUnexpectedMessage);
+ ClientSkipTest(std::make_shared<TlsHandshakeSkipFilter>(
+ client_, kTlsHandshakeCertificate),
+ SSL_ERROR_RX_UNEXPECTED_CERT_VERIFY);
+}
+
+TEST_P(Tls13SkipTest, SkipClientCertificateVerify) {
+ client_->SetupClientAuth();
+ server_->RequestClientAuth(true);
+ client_->ExpectReceiveAlert(kTlsAlertUnexpectedMessage);
+ ClientSkipTest(std::make_shared<TlsHandshakeSkipFilter>(
+ client_, kTlsHandshakeCertificateVerify),
+ SSL_ERROR_RX_UNEXPECTED_FINISHED);
+}
+
+INSTANTIATE_TEST_SUITE_P(
+ SkipTls10, TlsSkipTest,
+ ::testing::Combine(TlsConnectTestBase::kTlsVariantsStream,
+ TlsConnectTestBase::kTlsV10));
+INSTANTIATE_TEST_SUITE_P(SkipVariants, TlsSkipTest,
+ ::testing::Combine(TlsConnectTestBase::kTlsVariantsAll,
+ TlsConnectTestBase::kTlsV11V12));
+INSTANTIATE_TEST_SUITE_P(Skip13Variants, Tls13SkipTest,
+ TlsConnectTestBase::kTlsVariantsAll);
+} // namespace nss_test
diff --git a/security/nss/gtests/ssl_gtest/ssl_staticrsa_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_staticrsa_unittest.cc
new file mode 100644
index 0000000000..abddaa5b61
--- /dev/null
+++ b/security/nss/gtests/ssl_gtest/ssl_staticrsa_unittest.cc
@@ -0,0 +1,139 @@
+/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
+/* vim: set ts=2 et sw=2 tw=80: */
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this file,
+ * You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+#include <functional>
+#include <memory>
+#include "secerr.h"
+#include "ssl.h"
+#include "sslerr.h"
+#include "sslproto.h"
+
+extern "C" {
+// This is not something that should make you happy.
+#include "libssl_internals.h"
+}
+
+#include "gtest_utils.h"
+#include "nss_scoped_ptrs.h"
+#include "tls_connect.h"
+#include "tls_filter.h"
+#include "tls_parser.h"
+#include "rsa8193.h"
+
+namespace nss_test {
+
+const uint8_t kBogusClientKeyExchange[] = {
+ 0x01, 0x00, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
+ 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
+ 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
+ 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
+ 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
+ 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
+ 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
+ 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
+ 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
+ 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
+ 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
+};
+
+TEST_P(TlsConnectGenericPre13, ConnectStaticRSA) {
+ EnableOnlyStaticRsaCiphers();
+ Connect();
+ CheckKeys(ssl_kea_rsa, ssl_grp_none, ssl_auth_rsa_decrypt, ssl_sig_none);
+}
+
+// Test that a totally bogus EPMS is handled correctly.
+// This test is stream so we can catch the bad_record_mac alert.
+TEST_P(TlsConnectStreamPre13, ConnectStaticRSABogusCKE) {
+ EnableOnlyStaticRsaCiphers();
+ MakeTlsFilter<TlsInspectorReplaceHandshakeMessage>(
+ client_, kTlsHandshakeClientKeyExchange,
+ DataBuffer(kBogusClientKeyExchange, sizeof(kBogusClientKeyExchange)));
+ ConnectExpectAlert(server_, kTlsAlertBadRecordMac);
+}
+
+// Test that a PMS with a bogus version number is handled correctly.
+// This test is stream so we can catch the bad_record_mac alert.
+TEST_P(TlsConnectStreamPre13, ConnectStaticRSABogusPMSVersionDetect) {
+ EnableOnlyStaticRsaCiphers();
+ MakeTlsFilter<TlsClientHelloVersionChanger>(client_, server_);
+ ConnectExpectAlert(server_, kTlsAlertBadRecordMac);
+}
+
+// Test that a PMS with a bogus version number is ignored when
+// rollback detection is disabled. This is a positive control for
+// ConnectStaticRSABogusPMSVersionDetect.
+TEST_P(TlsConnectGenericPre13, ConnectStaticRSABogusPMSVersionIgnore) {
+ EnableOnlyStaticRsaCiphers();
+ MakeTlsFilter<TlsClientHelloVersionChanger>(client_, server_);
+ server_->SetOption(SSL_ROLLBACK_DETECTION, PR_FALSE);
+ Connect();
+}
+
+// This test is stream so we can catch the bad_record_mac alert.
+TEST_P(TlsConnectStreamPre13, ConnectExtendedMasterSecretStaticRSABogusCKE) {
+ EnableOnlyStaticRsaCiphers();
+ EnableExtendedMasterSecret();
+ MakeTlsFilter<TlsInspectorReplaceHandshakeMessage>(
+ client_, kTlsHandshakeClientKeyExchange,
+ DataBuffer(kBogusClientKeyExchange, sizeof(kBogusClientKeyExchange)));
+ ConnectExpectAlert(server_, kTlsAlertBadRecordMac);
+}
+
+// This test is stream so we can catch the bad_record_mac alert.
+TEST_P(TlsConnectStreamPre13,
+ ConnectExtendedMasterSecretStaticRSABogusPMSVersionDetect) {
+ EnableOnlyStaticRsaCiphers();
+ EnableExtendedMasterSecret();
+ MakeTlsFilter<TlsClientHelloVersionChanger>(client_, server_);
+ ConnectExpectAlert(server_, kTlsAlertBadRecordMac);
+}
+
+TEST_P(TlsConnectStreamPre13,
+ ConnectExtendedMasterSecretStaticRSABogusPMSVersionIgnore) {
+ EnableOnlyStaticRsaCiphers();
+ EnableExtendedMasterSecret();
+ MakeTlsFilter<TlsClientHelloVersionChanger>(client_, server_);
+ server_->SetOption(SSL_ROLLBACK_DETECTION, PR_FALSE);
+ Connect();
+}
+
+// Replace the server certificate with one that uses 8193-bit RSA.
+class TooLargeRSACertFilter : public TlsHandshakeFilter {
+ public:
+ TooLargeRSACertFilter(const std::shared_ptr<TlsAgent> &server)
+ : TlsHandshakeFilter(server, {kTlsHandshakeCertificate}) {}
+
+ protected:
+ virtual PacketFilter::Action FilterHandshake(const HandshakeHeader &header,
+ const DataBuffer &input,
+ DataBuffer *output) {
+ const uint32_t cert_len = sizeof(rsa8193);
+ const uint32_t outer_len = cert_len + 3;
+ size_t offset = 0;
+ offset = output->Write(offset, outer_len, 3);
+ offset = output->Write(offset, cert_len, 3);
+ offset = output->Write(offset, rsa8193, cert_len);
+
+ return CHANGE;
+ }
+};
+
+TEST_P(TlsConnectGenericPre13, TooLargeRSAKeyInCert) {
+ EnableOnlyStaticRsaCiphers();
+ MakeTlsFilter<TooLargeRSACertFilter>(server_);
+ ConnectExpectAlert(client_, kTlsAlertIllegalParameter);
+ client_->CheckErrorCode(SSL_ERROR_CLIENT_KEY_EXCHANGE_FAILURE);
+ server_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT);
+}
+
+TEST_P(TlsConnectGeneric, ServerAuthBiggestRsa) {
+ Reset(TlsAgent::kRsa8192);
+ Connect();
+ CheckKeys();
+}
+
+} // namespace nss_test
diff --git a/security/nss/gtests/ssl_gtest/ssl_tls13compat_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_tls13compat_unittest.cc
new file mode 100644
index 0000000000..2421470a4f
--- /dev/null
+++ b/security/nss/gtests/ssl_gtest/ssl_tls13compat_unittest.cc
@@ -0,0 +1,573 @@
+/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
+/* vim: set ts=2 et sw=2 tw=80: */
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this file,
+ * You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+#include <memory>
+#include <vector>
+#include "ssl.h"
+#include "sslerr.h"
+#include "sslproto.h"
+
+#include "gtest_utils.h"
+#include "tls_connect.h"
+#include "tls_filter.h"
+#include "tls_parser.h"
+
+namespace nss_test {
+
+class Tls13CompatTest : public TlsConnectStreamTls13 {
+ protected:
+ void EnableCompatMode() {
+ client_->SetOption(SSL_ENABLE_TLS13_COMPAT_MODE, PR_TRUE);
+ }
+
+ void InstallFilters() {
+ EnsureTlsSetup();
+ client_recorders_.Install(client_);
+ server_recorders_.Install(server_);
+ }
+
+ void CheckRecordVersions() {
+ ASSERT_EQ(SSL_LIBRARY_VERSION_TLS_1_0,
+ client_recorders_.records_->record(0).header.version());
+ CheckRecordsAreTls12("client", client_recorders_.records_, 1);
+ CheckRecordsAreTls12("server", server_recorders_.records_, 0);
+ }
+
+ void CheckHelloVersions() {
+ uint32_t ver;
+ ASSERT_TRUE(server_recorders_.hello_->buffer().Read(0, 2, &ver));
+ ASSERT_EQ(SSL_LIBRARY_VERSION_TLS_1_2, static_cast<uint16_t>(ver));
+ ASSERT_TRUE(client_recorders_.hello_->buffer().Read(0, 2, &ver));
+ ASSERT_EQ(SSL_LIBRARY_VERSION_TLS_1_2, static_cast<uint16_t>(ver));
+ }
+
+ void CheckForCCS(bool expected_client, bool expected_server) {
+ client_recorders_.CheckForCCS(expected_client);
+ server_recorders_.CheckForCCS(expected_server);
+ }
+
+ void CheckForRegularHandshake() {
+ CheckRecordVersions();
+ CheckHelloVersions();
+ EXPECT_EQ(0U, client_recorders_.session_id_length());
+ EXPECT_EQ(0U, server_recorders_.session_id_length());
+ CheckForCCS(false, false);
+ }
+
+ void CheckForCompatHandshake() {
+ CheckRecordVersions();
+ CheckHelloVersions();
+ EXPECT_EQ(32U, client_recorders_.session_id_length());
+ EXPECT_EQ(32U, server_recorders_.session_id_length());
+ CheckForCCS(true, true);
+ }
+
+ private:
+ struct Recorders {
+ Recorders() : records_(nullptr), hello_(nullptr) {}
+
+ uint8_t session_id_length() const {
+ // session_id is always after version (2) and random (32).
+ uint32_t len = 0;
+ EXPECT_TRUE(hello_->buffer().Read(2 + 32, 1, &len));
+ return static_cast<uint8_t>(len);
+ }
+
+ void CheckForCCS(bool expected) const {
+ EXPECT_LT(0U, records_->count());
+ for (size_t i = 0; i < records_->count(); ++i) {
+ // Only the second record can be a CCS.
+ bool expected_match = expected && (i == 1);
+ EXPECT_EQ(expected_match,
+ ssl_ct_change_cipher_spec ==
+ records_->record(i).header.content_type());
+ }
+ }
+
+ void Install(std::shared_ptr<TlsAgent>& agent) {
+ if (records_ && records_->agent() == agent) {
+ // Avoid replacing the filters if they are already installed on this
+ // agent. This ensures that InstallFilters() can be used after
+ // MakeNewServer() without losing state on the client filters.
+ return;
+ }
+ records_.reset(new TlsRecordRecorder(agent));
+ hello_.reset(new TlsHandshakeRecorder(
+ agent, std::set<uint8_t>(
+ {kTlsHandshakeClientHello, kTlsHandshakeServerHello})));
+ agent->SetFilter(std::make_shared<ChainedPacketFilter>(
+ ChainedPacketFilterInit({records_, hello_})));
+ }
+
+ std::shared_ptr<TlsRecordRecorder> records_;
+ std::shared_ptr<TlsHandshakeRecorder> hello_;
+ };
+
+ void CheckRecordsAreTls12(const std::string& agent,
+ const std::shared_ptr<TlsRecordRecorder>& records,
+ size_t start) {
+ EXPECT_LE(start, records->count());
+ for (size_t i = start; i < records->count(); ++i) {
+ EXPECT_EQ(SSL_LIBRARY_VERSION_TLS_1_2,
+ records->record(i).header.version())
+ << agent << ": record " << i << " has wrong version";
+ }
+ }
+
+ Recorders client_recorders_;
+ Recorders server_recorders_;
+};
+
+TEST_F(Tls13CompatTest, Disabled) {
+ InstallFilters();
+ Connect();
+ CheckForRegularHandshake();
+}
+
+TEST_F(Tls13CompatTest, Enabled) {
+ EnableCompatMode();
+ InstallFilters();
+ Connect();
+ CheckForCompatHandshake();
+}
+
+TEST_F(Tls13CompatTest, EnabledZeroRtt) {
+ SetupForZeroRtt();
+ EnableCompatMode();
+ InstallFilters();
+
+ client_->Set0RttEnabled(true);
+ server_->Set0RttEnabled(true);
+ ExpectResumption(RESUME_TICKET);
+ ZeroRttSendReceive(true, true);
+ CheckForCCS(true, true);
+ Handshake();
+ ExpectEarlyDataAccepted(true);
+ CheckConnected();
+
+ CheckForCompatHandshake();
+}
+
+TEST_F(Tls13CompatTest, EnabledHrr) {
+ EnableCompatMode();
+ InstallFilters();
+
+ // Force a HelloRetryRequest. The server sends CCS immediately.
+ server_->ConfigNamedGroups({ssl_grp_ec_secp384r1});
+ client_->StartConnect();
+ server_->StartConnect();
+ client_->Handshake();
+ server_->Handshake();
+ CheckForCCS(false, true);
+
+ Handshake();
+ CheckConnected();
+ CheckForCompatHandshake();
+}
+
+TEST_F(Tls13CompatTest, EnabledStatelessHrr) {
+ EnableCompatMode();
+ InstallFilters();
+
+ // Force a HelloRetryRequest
+ server_->ConfigNamedGroups({ssl_grp_ec_secp384r1});
+ client_->StartConnect();
+ server_->StartConnect();
+ client_->Handshake();
+ server_->Handshake();
+
+ // The server should send CCS before HRR.
+ CheckForCCS(false, true);
+
+ // A new server should complete the handshake, and not send CCS.
+ MakeNewServer();
+ InstallFilters();
+ server_->ConfigNamedGroups({ssl_grp_ec_secp384r1});
+
+ Handshake();
+ CheckConnected();
+ CheckRecordVersions();
+ CheckHelloVersions();
+ CheckForCCS(true, false);
+}
+
+TEST_F(Tls13CompatTest, EnabledHrrZeroRtt) {
+ SetupForZeroRtt();
+ EnableCompatMode();
+ InstallFilters();
+ server_->ConfigNamedGroups({ssl_grp_ec_secp384r1});
+
+ // With 0-RTT, the client sends CCS immediately. With HRR, the server sends
+ // CCS immediately too.
+ client_->Set0RttEnabled(true);
+ server_->Set0RttEnabled(true);
+ ExpectResumption(RESUME_TICKET);
+ ZeroRttSendReceive(true, false);
+ CheckForCCS(true, true);
+
+ Handshake();
+ ExpectEarlyDataAccepted(false);
+ CheckConnected();
+ CheckForCompatHandshake();
+}
+
+TEST_F(Tls13CompatTest, EnabledAcceptedEch) {
+ EnsureTlsSetup();
+ SetupEch(client_, server_);
+ EnableCompatMode();
+ InstallFilters();
+ Connect();
+ CheckForCompatHandshake();
+}
+
+TEST_F(Tls13CompatTest, EnabledRejectedEch) {
+ EnsureTlsSetup();
+ // Configure ECH on the client only, and expect CCS.
+ SetupEch(client_, server_, HpkeDhKemX25519Sha256, false, true, false);
+ EnableCompatMode();
+ InstallFilters();
+ ExpectAlert(client_, kTlsAlertEchRequired);
+ ConnectExpectFailOneSide(TlsAgent::CLIENT);
+ client_->CheckErrorCode(SSL_ERROR_ECH_RETRY_WITHOUT_ECH);
+ CheckForCompatHandshake();
+ // Reset expectations for the TlsAgent dtor.
+ server_->ExpectReceiveAlert(kTlsAlertCloseNotify, kTlsAlertWarning);
+}
+
+class TlsSessionIDEchoFilter : public TlsHandshakeFilter {
+ public:
+ TlsSessionIDEchoFilter(const std::shared_ptr<TlsAgent>& a)
+ : TlsHandshakeFilter(
+ a, {kTlsHandshakeClientHello, kTlsHandshakeServerHello}) {}
+
+ protected:
+ virtual PacketFilter::Action FilterHandshake(const HandshakeHeader& header,
+ const DataBuffer& input,
+ DataBuffer* output) {
+ TlsParser parser(input);
+
+ // Skip version + random.
+ EXPECT_TRUE(parser.Skip(2 + 32));
+
+ // Capture CH.legacy_session_id.
+ if (header.handshake_type() == kTlsHandshakeClientHello) {
+ EXPECT_TRUE(parser.ReadVariable(&sid_, 1));
+ return KEEP;
+ }
+
+ // Check that server sends one too.
+ uint32_t sid_len = 0;
+ EXPECT_TRUE(parser.Read(&sid_len, 1));
+ EXPECT_EQ(sid_len, sid_.len());
+
+ // Echo the one we captured.
+ *output = input;
+ output->Write(parser.consumed(), sid_.data(), sid_.len());
+
+ return CHANGE;
+ }
+
+ private:
+ DataBuffer sid_;
+};
+
+TEST_F(TlsConnectTest, EchoTLS13CompatibilitySessionID) {
+ ConfigureSessionCache(RESUME_SESSIONID, RESUME_SESSIONID);
+
+ client_->SetOption(SSL_ENABLE_TLS13_COMPAT_MODE, PR_TRUE);
+
+ client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2,
+ SSL_LIBRARY_VERSION_TLS_1_3);
+
+ server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2,
+ SSL_LIBRARY_VERSION_TLS_1_2);
+
+ server_->SetFilter(MakeTlsFilter<TlsSessionIDEchoFilter>(client_));
+ ConnectExpectAlert(client_, kTlsAlertIllegalParameter);
+
+ client_->CheckErrorCode(SSL_ERROR_RX_MALFORMED_SERVER_HELLO);
+ server_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT);
+}
+
+class TlsSessionIDInjectFilter : public TlsHandshakeFilter {
+ public:
+ TlsSessionIDInjectFilter(const std::shared_ptr<TlsAgent>& a)
+ : TlsHandshakeFilter(a, {kTlsHandshakeServerHello}) {}
+
+ protected:
+ virtual PacketFilter::Action FilterHandshake(const HandshakeHeader& header,
+ const DataBuffer& input,
+ DataBuffer* output) {
+ TlsParser parser(input);
+
+ // Skip version + random.
+ EXPECT_TRUE(parser.Skip(2 + 32));
+
+ *output = input;
+
+ // Inject a Session ID.
+ const uint8_t fake_sid[SSL3_SESSIONID_BYTES] = {0xff};
+ output->Write(parser.consumed(), sizeof(fake_sid), 1);
+ output->Splice(fake_sid, sizeof(fake_sid), parser.consumed() + 1, 0);
+
+ return CHANGE;
+ }
+};
+
+TEST_F(TlsConnectTest, TLS13NonCompatModeSessionID) {
+ ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3);
+
+ MakeTlsFilter<TlsSessionIDInjectFilter>(server_);
+ client_->ExpectSendAlert(kTlsAlertIllegalParameter);
+ server_->ExpectSendAlert(kTlsAlertUnexpectedMessage);
+ ConnectExpectFail();
+
+ client_->CheckErrorCode(SSL_ERROR_RX_MALFORMED_SERVER_HELLO);
+ server_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_RECORD_TYPE);
+}
+
+static const uint8_t kCannedCcs[] = {
+ ssl_ct_change_cipher_spec,
+ SSL_LIBRARY_VERSION_TLS_1_2 >> 8,
+ SSL_LIBRARY_VERSION_TLS_1_2 & 0xff,
+ 0,
+ 1, // length
+ 1 // change_cipher_spec_choice
+};
+
+// A ChangeCipherSpec is ignored by a server because we have to tolerate it for
+// compatibility mode. That doesn't mean that we have to tolerate it
+// unconditionally. If we negotiate 1.3, we expect to see a cookie extension.
+TEST_F(TlsConnectStreamTls13, ChangeCipherSpecBeforeClientHello13) {
+ EnsureTlsSetup();
+ server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2,
+ SSL_LIBRARY_VERSION_TLS_1_3);
+ client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2,
+ SSL_LIBRARY_VERSION_TLS_1_3);
+ // Client sends CCS before starting the handshake.
+ client_->SendDirect(DataBuffer(kCannedCcs, sizeof(kCannedCcs)));
+ ConnectExpectAlert(server_, kTlsAlertUnexpectedMessage);
+ server_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_CHANGE_CIPHER);
+ client_->CheckErrorCode(SSL_ERROR_HANDSHAKE_UNEXPECTED_ALERT);
+}
+
+// A ChangeCipherSpec is ignored by a server because we have to tolerate it for
+// compatibility mode. That doesn't mean that we have to tolerate it
+// unconditionally. If we negotiate 1.3, we expect to see a cookie extension.
+TEST_F(TlsConnectStreamTls13, ChangeCipherSpecBeforeClientHelloTwice) {
+ EnsureTlsSetup();
+ server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2,
+ SSL_LIBRARY_VERSION_TLS_1_3);
+ client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2,
+ SSL_LIBRARY_VERSION_TLS_1_3);
+ // Client sends CCS before starting the handshake.
+ client_->SendDirect(DataBuffer(kCannedCcs, sizeof(kCannedCcs)));
+ client_->SendDirect(DataBuffer(kCannedCcs, sizeof(kCannedCcs)));
+ ConnectExpectAlert(server_, kTlsAlertUnexpectedMessage);
+ server_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_CHANGE_CIPHER);
+ client_->CheckErrorCode(SSL_ERROR_HANDSHAKE_UNEXPECTED_ALERT);
+}
+
+// The server accepts a ChangeCipherSpec even if the client advertises
+// an empty session ID.
+TEST_F(TlsConnectStreamTls13, ChangeCipherSpecAfterClientHelloEmptySid) {
+ EnsureTlsSetup();
+ ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3);
+
+ StartConnect();
+ client_->Handshake(); // Send ClientHello
+ client_->SendDirect(DataBuffer(kCannedCcs, sizeof(kCannedCcs))); // Send CCS
+
+ Handshake();
+ CheckConnected();
+}
+
+// The server rejects multiple ChangeCipherSpec even if the client
+// indicates compatibility mode with non-empty session ID.
+TEST_F(Tls13CompatTest, ChangeCipherSpecAfterClientHelloTwice) {
+ EnsureTlsSetup();
+ ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3);
+ EnableCompatMode();
+
+ StartConnect();
+ client_->Handshake(); // Send ClientHello
+ // Send CCS twice in a row
+ client_->SendDirect(DataBuffer(kCannedCcs, sizeof(kCannedCcs)));
+ client_->SendDirect(DataBuffer(kCannedCcs, sizeof(kCannedCcs)));
+
+ server_->ExpectSendAlert(kTlsAlertUnexpectedMessage);
+ server_->Handshake(); // Consume ClientHello and CCS.
+ server_->CheckErrorCode(SSL_ERROR_RX_MALFORMED_CHANGE_CIPHER);
+}
+
+// The client accepts a ChangeCipherSpec even if it advertises an empty
+// session ID.
+TEST_F(TlsConnectStreamTls13, ChangeCipherSpecAfterServerHelloEmptySid) {
+ EnsureTlsSetup();
+ ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3);
+
+ // To replace Finished with a CCS below
+ auto filter = MakeTlsFilter<TlsHandshakeDropper>(server_);
+ filter->SetHandshakeTypes({kTlsHandshakeFinished});
+ filter->EnableDecryption();
+
+ StartConnect();
+ client_->Handshake(); // Send ClientHello
+ server_->Handshake(); // Consume ClientHello, and
+ // send ServerHello..CertificateVerify
+ // Send CCS
+ server_->SendDirect(DataBuffer(kCannedCcs, sizeof(kCannedCcs)));
+
+ // No alert is sent from the client. As Finished is dropped, we
+ // can't use Handshake() and CheckConnected().
+ client_->Handshake();
+}
+
+// The client rejects multiple ChangeCipherSpec in a row even if the
+// client indicates compatibility mode with non-empty session ID.
+TEST_F(Tls13CompatTest, ChangeCipherSpecAfterServerHelloTwice) {
+ EnsureTlsSetup();
+ ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3);
+ EnableCompatMode();
+
+ // To replace Finished with a CCS below
+ auto filter = MakeTlsFilter<TlsHandshakeDropper>(server_);
+ filter->SetHandshakeTypes({kTlsHandshakeFinished});
+ filter->EnableDecryption();
+
+ StartConnect();
+ client_->Handshake(); // Send ClientHello
+ server_->Handshake(); // Consume ClientHello, and
+ // send ServerHello..CertificateVerify
+ // the ServerHello is followed by CCS
+ // Send another CCS
+ server_->SendDirect(DataBuffer(kCannedCcs, sizeof(kCannedCcs)));
+ client_->ExpectSendAlert(kTlsAlertUnexpectedMessage);
+ client_->Handshake(); // Consume ClientHello and CCS
+ client_->CheckErrorCode(SSL_ERROR_RX_MALFORMED_CHANGE_CIPHER);
+}
+
+// If we negotiate 1.2, we abort.
+TEST_F(TlsConnectStreamTls13, ChangeCipherSpecBeforeClientHello12) {
+ EnsureTlsSetup();
+ server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2,
+ SSL_LIBRARY_VERSION_TLS_1_3);
+ client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2,
+ SSL_LIBRARY_VERSION_TLS_1_2);
+ // Client sends CCS before starting the handshake.
+ client_->SendDirect(DataBuffer(kCannedCcs, sizeof(kCannedCcs)));
+ ConnectExpectAlert(server_, kTlsAlertUnexpectedMessage);
+ server_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_CHANGE_CIPHER);
+ client_->CheckErrorCode(SSL_ERROR_HANDSHAKE_UNEXPECTED_ALERT);
+}
+
+TEST_F(TlsConnectStreamTls13, ChangeCipherSpecAfterFinished13) {
+ EnsureTlsSetup();
+ ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3);
+ Connect();
+ SendReceive(10);
+ // Client sends CCS after the handshake.
+ client_->SendDirect(DataBuffer(kCannedCcs, sizeof(kCannedCcs)));
+ server_->ExpectSendAlert(kTlsAlertUnexpectedMessage);
+ server_->ExpectReadWriteError();
+ server_->ReadBytes();
+ EXPECT_EQ(SSL_ERROR_RX_UNEXPECTED_RECORD_TYPE, server_->error_code());
+}
+
+TEST_F(TlsConnectDatagram13, CompatModeDtlsClient) {
+ EnsureTlsSetup();
+ client_->SetOption(SSL_ENABLE_TLS13_COMPAT_MODE, PR_TRUE);
+ auto client_records = MakeTlsFilter<TlsRecordRecorder>(client_);
+ auto server_records = MakeTlsFilter<TlsRecordRecorder>(server_);
+ Connect();
+
+ ASSERT_EQ(2U, client_records->count()); // CH, Fin
+ EXPECT_EQ(ssl_ct_handshake, client_records->record(0).header.content_type());
+ EXPECT_EQ(kCtDtlsCiphertext,
+ (client_records->record(1).header.content_type() &
+ kCtDtlsCiphertextMask));
+
+ ASSERT_EQ(6U, server_records->count()); // SH, EE, CT, CV, Fin, Ack
+ EXPECT_EQ(ssl_ct_handshake, server_records->record(0).header.content_type());
+ for (size_t i = 1; i < server_records->count(); ++i) {
+ EXPECT_EQ(kCtDtlsCiphertext,
+ (server_records->record(i).header.content_type() &
+ kCtDtlsCiphertextMask));
+ }
+}
+
+class AddSessionIdFilter : public TlsHandshakeFilter {
+ public:
+ AddSessionIdFilter(const std::shared_ptr<TlsAgent>& client)
+ : TlsHandshakeFilter(client, {ssl_hs_client_hello}) {}
+
+ protected:
+ PacketFilter::Action FilterHandshake(const HandshakeHeader& header,
+ const DataBuffer& input,
+ DataBuffer* output) override {
+ uint32_t session_id_len = 0;
+ EXPECT_TRUE(input.Read(2 + 32, 1, &session_id_len));
+ EXPECT_EQ(0U, session_id_len);
+ uint8_t session_id[33] = {32}; // 32 for length, the rest zero.
+ *output = input;
+ output->Splice(session_id, sizeof(session_id), 34, 1);
+ return CHANGE;
+ }
+};
+
+// Adding a session ID to a DTLS ClientHello should not trigger compatibility
+// mode. It should be ignored instead.
+TEST_F(TlsConnectDatagram13, CompatModeDtlsServer) {
+ EnsureTlsSetup();
+ auto client_records = std::make_shared<TlsRecordRecorder>(client_);
+ client_->SetFilter(
+ std::make_shared<ChainedPacketFilter>(ChainedPacketFilterInit(
+ {client_records, std::make_shared<AddSessionIdFilter>(client_)})));
+ auto server_hello =
+ std::make_shared<TlsHandshakeRecorder>(server_, kTlsHandshakeServerHello);
+ auto server_records = std::make_shared<TlsRecordRecorder>(server_);
+ server_->SetFilter(std::make_shared<ChainedPacketFilter>(
+ ChainedPacketFilterInit({server_records, server_hello})));
+ StartConnect();
+ client_->Handshake();
+ server_->Handshake();
+ // The client will consume the ServerHello, but discard everything else
+ // because it doesn't decrypt. And don't wait around for the client to ACK.
+ client_->Handshake();
+
+ ASSERT_EQ(1U, client_records->count());
+ EXPECT_EQ(ssl_ct_handshake, client_records->record(0).header.content_type());
+
+ ASSERT_EQ(5U, server_records->count()); // SH, EE, CT, CV, Fin
+ EXPECT_EQ(ssl_ct_handshake, server_records->record(0).header.content_type());
+ for (size_t i = 1; i < server_records->count(); ++i) {
+ EXPECT_EQ(kCtDtlsCiphertext,
+ (server_records->record(i).header.content_type() &
+ kCtDtlsCiphertextMask));
+ }
+
+ uint32_t session_id_len = 0;
+ EXPECT_TRUE(server_hello->buffer().Read(2 + 32, 1, &session_id_len));
+ EXPECT_EQ(0U, session_id_len);
+}
+
+TEST_F(Tls13CompatTest, ConnectWith12ThenAttemptToResume13CompatMode) {
+ ConfigureSessionCache(RESUME_SESSIONID, RESUME_SESSIONID);
+ ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_2);
+ Connect();
+
+ Reset();
+ ExpectResumption(RESUME_NONE);
+ version_ = SSL_LIBRARY_VERSION_TLS_1_3;
+ client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2,
+ SSL_LIBRARY_VERSION_TLS_1_3);
+ server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2,
+ SSL_LIBRARY_VERSION_TLS_1_3);
+ EnableCompatMode();
+ Connect();
+}
+
+} // namespace nss_test
diff --git a/security/nss/gtests/ssl_gtest/ssl_v2_client_hello_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_v2_client_hello_unittest.cc
new file mode 100644
index 0000000000..9aa6542d66
--- /dev/null
+++ b/security/nss/gtests/ssl_gtest/ssl_v2_client_hello_unittest.cc
@@ -0,0 +1,414 @@
+/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
+/* vim: set ts=2 et sw=2 tw=80: */
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this file,
+ * You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+#include "pk11pub.h"
+#include "ssl.h"
+#include "sslerr.h"
+#include "sslproto.h"
+
+extern "C" {
+// This is not something that should make you happy.
+#include "libssl_internals.h"
+}
+
+#include "gtest_utils.h"
+#include "tls_connect.h"
+#include "tls_filter.h"
+
+namespace nss_test {
+
+// Replaces the client hello with an SSLv2 version once.
+class SSLv2ClientHelloFilter : public PacketFilter {
+ public:
+ SSLv2ClientHelloFilter(const std::shared_ptr<TlsAgent>& client,
+ uint16_t version)
+ : replaced_(false),
+ client_(client),
+ version_(version),
+ pad_len_(0),
+ reported_pad_len_(0),
+ client_random_len_(16),
+ ciphers_(0),
+ send_escape_(false) {}
+
+ void SetVersion(uint16_t version) { version_ = version; }
+
+ void SetCipherSuites(const std::vector<uint16_t>& ciphers) {
+ ciphers_ = ciphers;
+ }
+
+ // Set a padding length and announce it correctly.
+ void SetPadding(uint8_t pad_len) { SetPadding(pad_len, pad_len); }
+
+ // Set a padding length and allow to lie about its length.
+ void SetPadding(uint8_t pad_len, uint8_t reported_pad_len) {
+ pad_len_ = pad_len;
+ reported_pad_len_ = reported_pad_len;
+ }
+
+ void SetClientRandomLength(uint16_t client_random_len) {
+ client_random_len_ = client_random_len;
+ }
+
+ void SetSendEscape(bool send_escape) { send_escape_ = send_escape; }
+
+ protected:
+ virtual PacketFilter::Action Filter(const DataBuffer& input,
+ DataBuffer* output) {
+ if (replaced_) {
+ return KEEP;
+ }
+
+ // Replace only the very first packet.
+ replaced_ = true;
+
+ // The SSLv2 client hello size.
+ size_t packet_len = SSL_HL_CLIENT_HELLO_HBYTES + (ciphers_.size() * 3) +
+ client_random_len_ + pad_len_;
+
+ size_t idx = 0;
+ *output = input;
+ output->Allocate(packet_len);
+ output->Truncate(packet_len);
+
+ // Write record length.
+ if (pad_len_ > 0) {
+ size_t masked_len = 0x3fff & packet_len;
+ if (send_escape_) {
+ masked_len |= 0x4000;
+ }
+
+ idx = output->Write(idx, masked_len, 2);
+ idx = output->Write(idx, reported_pad_len_, 1);
+ } else {
+ PR_ASSERT(!send_escape_);
+ idx = output->Write(idx, 0x8000 | packet_len, 2);
+ }
+
+ // Remember header length.
+ size_t hdr_len = idx;
+
+ // Write client hello.
+ idx = output->Write(idx, SSL_MT_CLIENT_HELLO, 1);
+ idx = output->Write(idx, version_, 2);
+
+ // Cipher list length.
+ idx = output->Write(idx, (ciphers_.size() * 3), 2);
+
+ // Session ID length.
+ idx = output->Write(idx, static_cast<uint32_t>(0), 2);
+
+ // ClientRandom length.
+ idx = output->Write(idx, client_random_len_, 2);
+
+ // Cipher suites.
+ for (auto cipher : ciphers_) {
+ idx = output->Write(idx, static_cast<uint32_t>(cipher), 3);
+ }
+
+ // Challenge.
+ std::vector<uint8_t> challenge(client_random_len_);
+ PK11_GenerateRandom(challenge.data(), challenge.size());
+ idx = output->Write(idx, challenge.data(), challenge.size());
+
+ // Add padding if any.
+ if (pad_len_ > 0) {
+ std::vector<uint8_t> pad(pad_len_);
+ idx = output->Write(idx, pad.data(), pad.size());
+ }
+
+ // Update the client random so that the handshake succeeds.
+ SECStatus rv = SSLInt_UpdateSSLv2ClientRandom(
+ client_.lock()->ssl_fd(), challenge.data(), challenge.size(),
+ output->data() + hdr_len, output->len() - hdr_len);
+ EXPECT_EQ(SECSuccess, rv);
+
+ return CHANGE;
+ }
+
+ private:
+ bool replaced_;
+ std::weak_ptr<TlsAgent> client_;
+ uint16_t version_;
+ uint8_t pad_len_;
+ uint8_t reported_pad_len_;
+ uint16_t client_random_len_;
+ std::vector<uint16_t> ciphers_;
+ bool send_escape_;
+};
+
+class SSLv2ClientHelloTestF : public TlsConnectTestBase {
+ public:
+ SSLv2ClientHelloTestF()
+ : TlsConnectTestBase(ssl_variant_stream, 0), filter_(nullptr) {}
+
+ SSLv2ClientHelloTestF(SSLProtocolVariant variant, uint16_t version)
+ : TlsConnectTestBase(variant, version), filter_(nullptr) {}
+
+ void SetUp() override {
+ TlsConnectTestBase::SetUp();
+ filter_ = MakeTlsFilter<SSLv2ClientHelloFilter>(client_, version_);
+ server_->SetOption(SSL_ENABLE_V2_COMPATIBLE_HELLO, PR_TRUE);
+ }
+
+ void SetExpectedVersion(uint16_t version) {
+ TlsConnectTestBase::SetExpectedVersion(version);
+ filter_->SetVersion(version);
+ }
+
+ void SetAvailableCipherSuite(uint16_t cipher) {
+ filter_->SetCipherSuites(std::vector<uint16_t>(1, cipher));
+ }
+
+ void SetAvailableCipherSuites(const std::vector<uint16_t>& ciphers) {
+ filter_->SetCipherSuites(ciphers);
+ }
+
+ void SetPadding(uint8_t pad_len) { filter_->SetPadding(pad_len); }
+
+ void SetPadding(uint8_t pad_len, uint8_t reported_pad_len) {
+ filter_->SetPadding(pad_len, reported_pad_len);
+ }
+
+ void SetClientRandomLength(uint16_t client_random_len) {
+ filter_->SetClientRandomLength(client_random_len);
+ }
+
+ void SetSendEscape(bool send_escape) { filter_->SetSendEscape(send_escape); }
+
+ private:
+ std::shared_ptr<SSLv2ClientHelloFilter> filter_;
+};
+
+// Parameterized version of SSLv2ClientHelloTestF we can
+// use with TEST_P to test multiple TLS versions easily.
+class SSLv2ClientHelloTest : public SSLv2ClientHelloTestF,
+ public ::testing::WithParamInterface<uint16_t> {
+ public:
+ SSLv2ClientHelloTest()
+ : SSLv2ClientHelloTestF(ssl_variant_stream, GetParam()) {}
+};
+
+// Test negotiating TLS 1.0 - 1.2.
+TEST_P(SSLv2ClientHelloTest, Connect) {
+ SetAvailableCipherSuite(TLS_DHE_RSA_WITH_AES_128_CBC_SHA);
+ Connect();
+}
+
+TEST_P(SSLv2ClientHelloTest, ConnectDisabled) {
+ server_->SetOption(SSL_ENABLE_V2_COMPATIBLE_HELLO, PR_FALSE);
+ SetAvailableCipherSuite(TLS_DHE_RSA_WITH_AES_128_CBC_SHA);
+
+ StartConnect();
+ client_->Handshake(); // Send the modified ClientHello.
+ server_->Handshake(); // Read some.
+ // The problem here is that the v2 ClientHello puts the version where the v3
+ // ClientHello puts a version number. So the version number (0x0301+) appears
+ // to be a length and server blocks waiting for that much data.
+ EXPECT_EQ(PR_WOULD_BLOCK_ERROR, PORT_GetError());
+
+ // This is usually what happens with v2-compatible: the server hangs.
+ // But to be certain, feed in more data to see if an error comes out.
+ uint8_t zeros[SSL_LIBRARY_VERSION_TLS_1_2] = {0};
+ client_->SendDirect(DataBuffer(zeros, sizeof(zeros)));
+ ExpectAlert(server_, kTlsAlertUnexpectedMessage);
+ server_->Handshake();
+ client_->Handshake();
+}
+
+// Sending a v2 ClientHello after a no-op v3 record must fail.
+TEST_P(SSLv2ClientHelloTest, ConnectAfterEmptyV3Record) {
+ DataBuffer buffer;
+
+ size_t idx = 0;
+ idx = buffer.Write(idx, 0x16, 1); // handshake
+ idx = buffer.Write(idx, 0x0301, 2); // record_version
+ (void)buffer.Write(idx, 0U, 2); // length=0
+
+ SetAvailableCipherSuite(TLS_DHE_RSA_WITH_AES_128_CBC_SHA);
+ EnsureTlsSetup();
+ client_->SendDirect(buffer);
+
+ // Need padding so the connection doesn't just time out. With a v2
+ // ClientHello parsed as a v3 record we will use the record version
+ // as the record length.
+ SetPadding(255);
+
+ ConnectExpectAlert(server_, kTlsAlertUnexpectedMessage);
+ EXPECT_EQ(SSL_ERROR_RX_UNKNOWN_RECORD_TYPE, server_->error_code());
+}
+
+// Test negotiating TLS 1.3.
+TEST_F(SSLv2ClientHelloTestF, Connect13) {
+ EnsureTlsSetup();
+ SetExpectedVersion(SSL_LIBRARY_VERSION_TLS_1_3);
+ ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3);
+
+ std::vector<uint16_t> cipher_suites = {TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256};
+ SetAvailableCipherSuites(cipher_suites);
+
+ ConnectExpectAlert(server_, kTlsAlertIllegalParameter);
+ EXPECT_EQ(SSL_ERROR_RX_MALFORMED_CLIENT_HELLO, server_->error_code());
+}
+
+// Test negotiating an EC suite.
+TEST_P(SSLv2ClientHelloTest, NegotiateECSuite) {
+ SetAvailableCipherSuite(TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA);
+ Connect();
+}
+
+// Test negotiating TLS 1.0 - 1.2 with a padded client hello.
+TEST_P(SSLv2ClientHelloTest, AddPadding) {
+ SetAvailableCipherSuite(TLS_DHE_RSA_WITH_AES_128_CBC_SHA);
+ SetPadding(255);
+ Connect();
+}
+
+// Test that sending a security escape fails the handshake.
+TEST_P(SSLv2ClientHelloTest, SendSecurityEscape) {
+ SetAvailableCipherSuite(TLS_DHE_RSA_WITH_AES_128_CBC_SHA);
+
+ // Send a security escape.
+ SetSendEscape(true);
+
+ // Set a big padding so that the server fails instead of timing out.
+ SetPadding(255);
+
+ ConnectExpectAlert(server_, kTlsAlertUnexpectedMessage);
+}
+
+// Invalid SSLv2 client hello padding must fail the handshake.
+TEST_P(SSLv2ClientHelloTest, AddErroneousPadding) {
+ SetAvailableCipherSuite(TLS_DHE_RSA_WITH_AES_128_CBC_SHA);
+
+ // Append 5 bytes of padding but say it's only 4.
+ SetPadding(5, 4);
+
+ ConnectExpectAlert(server_, kTlsAlertIllegalParameter);
+ EXPECT_EQ(SSL_ERROR_RX_MALFORMED_CLIENT_HELLO, server_->error_code());
+}
+
+// Invalid SSLv2 client hello padding must fail the handshake.
+TEST_P(SSLv2ClientHelloTest, AddErroneousPadding2) {
+ SetAvailableCipherSuite(TLS_DHE_RSA_WITH_AES_128_CBC_SHA);
+
+ // Append 5 bytes of padding but say it's 6.
+ SetPadding(5, 6);
+
+ ConnectExpectAlert(server_, kTlsAlertIllegalParameter);
+ EXPECT_EQ(SSL_ERROR_RX_MALFORMED_CLIENT_HELLO, server_->error_code());
+}
+
+// Wrong amount of bytes for the ClientRandom must fail the handshake.
+TEST_P(SSLv2ClientHelloTest, SmallClientRandom) {
+ SetAvailableCipherSuite(TLS_DHE_RSA_WITH_AES_128_CBC_SHA);
+
+ // Send a ClientRandom that's too small.
+ SetClientRandomLength(15);
+
+ ConnectExpectAlert(server_, kTlsAlertIllegalParameter);
+ EXPECT_EQ(SSL_ERROR_RX_MALFORMED_CLIENT_HELLO, server_->error_code());
+}
+
+// Test sending the maximum accepted number of ClientRandom bytes.
+TEST_P(SSLv2ClientHelloTest, MaxClientRandom) {
+ SetAvailableCipherSuite(TLS_DHE_RSA_WITH_AES_128_CBC_SHA);
+ SetClientRandomLength(32);
+ Connect();
+}
+
+// Wrong amount of bytes for the ClientRandom must fail the handshake.
+TEST_P(SSLv2ClientHelloTest, BigClientRandom) {
+ SetAvailableCipherSuite(TLS_DHE_RSA_WITH_AES_128_CBC_SHA);
+
+ // Send a ClientRandom that's too big.
+ SetClientRandomLength(33);
+
+ ConnectExpectAlert(server_, kTlsAlertIllegalParameter);
+ EXPECT_EQ(SSL_ERROR_RX_MALFORMED_CLIENT_HELLO, server_->error_code());
+}
+
+// Connection must fail if we require safe renegotiation but the client doesn't
+// include TLS_EMPTY_RENEGOTIATION_INFO_SCSV in the list of cipher suites.
+TEST_P(SSLv2ClientHelloTest, RequireSafeRenegotiation) {
+ server_->SetOption(SSL_REQUIRE_SAFE_NEGOTIATION, PR_TRUE);
+ SetAvailableCipherSuite(TLS_DHE_RSA_WITH_AES_128_CBC_SHA);
+ ConnectExpectAlert(server_, kTlsAlertHandshakeFailure);
+ EXPECT_EQ(SSL_ERROR_UNSAFE_NEGOTIATION, server_->error_code());
+}
+
+// Connection must succeed when requiring safe renegotiation and the client
+// includes TLS_EMPTY_RENEGOTIATION_INFO_SCSV in the list of cipher suites.
+TEST_P(SSLv2ClientHelloTest, RequireSafeRenegotiationWithSCSV) {
+ server_->SetOption(SSL_REQUIRE_SAFE_NEGOTIATION, PR_TRUE);
+ std::vector<uint16_t> cipher_suites = {TLS_DHE_RSA_WITH_AES_128_CBC_SHA,
+ TLS_EMPTY_RENEGOTIATION_INFO_SCSV};
+ SetAvailableCipherSuites(cipher_suites);
+ Connect();
+}
+
+TEST_P(SSLv2ClientHelloTest, CheckServerRandom) {
+ ConfigureSessionCache(RESUME_NONE, RESUME_NONE);
+ SetAvailableCipherSuite(TLS_DHE_RSA_WITH_AES_128_CBC_SHA);
+
+ static const size_t random_len = 32;
+ uint8_t srandom1[random_len];
+ uint8_t z[random_len] = {0};
+
+ auto sh = MakeTlsFilter<TlsHandshakeRecorder>(server_, ssl_hs_server_hello);
+ Connect();
+ ASSERT_TRUE(sh->buffer().len() > (random_len + 2));
+ memcpy(srandom1, sh->buffer().data() + 2, random_len);
+ EXPECT_NE(0, memcmp(srandom1, z, random_len));
+
+ Reset();
+ sh = MakeTlsFilter<TlsHandshakeRecorder>(server_, ssl_hs_server_hello);
+ Connect();
+ ASSERT_TRUE(sh->buffer().len() > (random_len + 2));
+ const uint8_t* srandom2 = sh->buffer().data() + 2;
+
+ EXPECT_NE(0, memcmp(srandom2, z, random_len));
+ EXPECT_NE(0, memcmp(srandom1, srandom2, random_len));
+}
+
+// Connect to the server with TLS 1.1, signalling that this is a fallback from
+// a higher version. As the server doesn't support anything higher than TLS 1.1
+// it must accept the connection.
+TEST_F(SSLv2ClientHelloTestF, FallbackSCSV) {
+ EnsureTlsSetup();
+ SetExpectedVersion(SSL_LIBRARY_VERSION_TLS_1_1);
+ ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_1);
+
+ std::vector<uint16_t> cipher_suites = {TLS_DHE_RSA_WITH_AES_128_CBC_SHA,
+ TLS_FALLBACK_SCSV};
+ SetAvailableCipherSuites(cipher_suites);
+ Connect();
+}
+
+// Connect to the server with TLS 1.1, signalling that this is a fallback from
+// a higher version. As the server supports TLS 1.2 though it must reject the
+// connection due to a possible downgrade attack.
+TEST_F(SSLv2ClientHelloTestF, InappropriateFallbackSCSV) {
+ SetExpectedVersion(SSL_LIBRARY_VERSION_TLS_1_1);
+ client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_1,
+ SSL_LIBRARY_VERSION_TLS_1_1);
+ server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_1,
+ SSL_LIBRARY_VERSION_TLS_1_2);
+
+ std::vector<uint16_t> cipher_suites = {TLS_DHE_RSA_WITH_AES_128_CBC_SHA,
+ TLS_FALLBACK_SCSV};
+ SetAvailableCipherSuites(cipher_suites);
+
+ ConnectExpectAlert(server_, kTlsAlertInappropriateFallback);
+ EXPECT_EQ(SSL_ERROR_INAPPROPRIATE_FALLBACK_ALERT, server_->error_code());
+}
+
+INSTANTIATE_TEST_SUITE_P(VersionsStream10Pre13, SSLv2ClientHelloTest,
+ TlsConnectTestBase::kTlsV10);
+INSTANTIATE_TEST_SUITE_P(VersionsStreamPre13, SSLv2ClientHelloTest,
+ TlsConnectTestBase::kTlsV11V12);
+
+} // namespace nss_test
diff --git a/security/nss/gtests/ssl_gtest/ssl_version_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_version_unittest.cc
new file mode 100644
index 0000000000..079d865a8f
--- /dev/null
+++ b/security/nss/gtests/ssl_gtest/ssl_version_unittest.cc
@@ -0,0 +1,456 @@
+/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
+/* vim: set ts=2 et sw=2 tw=80: */
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this file,
+ * You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+#include "secerr.h"
+#include "ssl.h"
+#include "ssl3prot.h"
+#include "sslerr.h"
+#include "sslproto.h"
+
+#include "gtest_utils.h"
+#include "nss_scoped_ptrs.h"
+#include "tls_connect.h"
+#include "tls_filter.h"
+#include "tls_parser.h"
+
+namespace nss_test {
+
+TEST_P(TlsConnectStream, ServerNegotiateTls10) {
+ uint16_t minver, maxver;
+ client_->GetVersionRange(&minver, &maxver);
+ client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_0, maxver);
+ server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_0,
+ SSL_LIBRARY_VERSION_TLS_1_0);
+ Connect();
+}
+
+TEST_P(TlsConnectGeneric, ServerNegotiateTls11) {
+ if (version_ < SSL_LIBRARY_VERSION_TLS_1_1) GTEST_SKIP();
+
+ uint16_t minver, maxver;
+ client_->GetVersionRange(&minver, &maxver);
+ client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_1, maxver);
+ server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_1,
+ SSL_LIBRARY_VERSION_TLS_1_1);
+ Connect();
+}
+
+TEST_P(TlsConnectGeneric, ServerNegotiateTls12) {
+ if (version_ < SSL_LIBRARY_VERSION_TLS_1_2) GTEST_SKIP();
+
+ uint16_t minver, maxver;
+ client_->GetVersionRange(&minver, &maxver);
+ client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2, maxver);
+ server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2,
+ SSL_LIBRARY_VERSION_TLS_1_2);
+ Connect();
+}
+
+// Test the ServerRandom version hack from
+// [draft-ietf-tls-tls13-11 Section 6.3.1.1].
+// The first three tests test for active tampering. The next
+// two validate that we can also detect fallback using the
+// SSL_SetDowngradeCheckVersion() API.
+TEST_F(TlsConnectTest, TestDowngradeDetectionToTls11) {
+ client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_0,
+ SSL_LIBRARY_VERSION_TLS_1_2);
+ server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_0,
+ SSL_LIBRARY_VERSION_TLS_1_2);
+ client_->SetOption(SSL_ENABLE_HELLO_DOWNGRADE_CHECK, PR_TRUE);
+ MakeTlsFilter<TlsMessageVersionSetter>(client_, kTlsHandshakeClientHello,
+ SSL_LIBRARY_VERSION_TLS_1_1);
+ ConnectExpectAlert(client_, kTlsAlertIllegalParameter);
+ client_->CheckErrorCode(SSL_ERROR_RX_MALFORMED_SERVER_HELLO);
+ server_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT);
+}
+
+// Attempt to negotiate the bogus DTLS 1.1 version.
+TEST_F(DtlsConnectTest, TestDtlsVersion11) {
+ MakeTlsFilter<TlsMessageVersionSetter>(client_, kTlsHandshakeClientHello,
+ ((~0x0101) & 0xffff));
+ ConnectExpectAlert(server_, kTlsAlertProtocolVersion);
+ client_->CheckErrorCode(SSL_ERROR_PROTOCOL_VERSION_ALERT);
+ server_->CheckErrorCode(SSL_ERROR_UNSUPPORTED_VERSION);
+}
+
+TEST_F(TlsConnectTest, TestDowngradeDetectionToTls12) {
+ client_->SetOption(SSL_ENABLE_HELLO_DOWNGRADE_CHECK, PR_TRUE);
+ MakeTlsFilter<TlsExtensionDropper>(client_, ssl_tls13_supported_versions_xtn);
+ client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2,
+ SSL_LIBRARY_VERSION_TLS_1_3);
+ server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2,
+ SSL_LIBRARY_VERSION_TLS_1_3);
+ ConnectExpectAlert(client_, kTlsAlertIllegalParameter);
+ client_->CheckErrorCode(SSL_ERROR_RX_MALFORMED_SERVER_HELLO);
+ server_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT);
+}
+
+// Disabling downgrade checks will be caught when the Finished MAC check fails.
+TEST_F(TlsConnectTest, TestDisableDowngradeDetection) {
+ client_->SetOption(SSL_ENABLE_HELLO_DOWNGRADE_CHECK, PR_FALSE);
+ MakeTlsFilter<TlsExtensionDropper>(client_, ssl_tls13_supported_versions_xtn);
+ client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2,
+ SSL_LIBRARY_VERSION_TLS_1_3);
+ server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2,
+ SSL_LIBRARY_VERSION_TLS_1_3);
+ ConnectExpectAlert(server_, kTlsAlertDecryptError);
+ client_->CheckErrorCode(SSL_ERROR_DECRYPT_ERROR_ALERT);
+ server_->CheckErrorCode(SSL_ERROR_BAD_HANDSHAKE_HASH_VALUE);
+}
+
+typedef std::tuple<SSLProtocolVariant,
+ uint16_t, // client version
+ uint16_t> // server version
+ TlsDowngradeProfile;
+
+class TlsDowngradeTest
+ : public TlsConnectTestBase,
+ public ::testing::WithParamInterface<TlsDowngradeProfile> {
+ public:
+ TlsDowngradeTest()
+ : TlsConnectTestBase(std::get<0>(GetParam()), std::get<1>(GetParam())),
+ c_ver(std::get<1>(GetParam())),
+ s_ver(std::get<2>(GetParam())) {}
+
+ protected:
+ const uint16_t c_ver;
+ const uint16_t s_ver;
+};
+
+TEST_P(TlsDowngradeTest, TlsDowngradeSentinelTest) {
+ static const uint8_t tls12_downgrade_random[] = {0x44, 0x4F, 0x57, 0x4E,
+ 0x47, 0x52, 0x44, 0x01};
+ static const uint8_t tls1_downgrade_random[] = {0x44, 0x4F, 0x57, 0x4E,
+ 0x47, 0x52, 0x44, 0x00};
+ static const size_t kRandomLen = 32;
+
+ if (c_ver > s_ver) {
+ GTEST_SKIP();
+ }
+
+ client_->SetVersionRange(c_ver, c_ver);
+ server_->SetVersionRange(c_ver, s_ver);
+
+ auto sh = MakeTlsFilter<TlsHandshakeRecorder>(server_, ssl_hs_server_hello);
+ Connect();
+ ASSERT_TRUE(sh->buffer().len() > (kRandomLen + 2));
+
+ const uint8_t* downgrade_sentinel =
+ sh->buffer().data() + 2 + kRandomLen - sizeof(tls1_downgrade_random);
+ if (c_ver < s_ver) {
+ if (c_ver == SSL_LIBRARY_VERSION_TLS_1_2) {
+ EXPECT_EQ(0, memcmp(downgrade_sentinel, tls12_downgrade_random,
+ sizeof(tls12_downgrade_random)));
+ } else {
+ EXPECT_EQ(0, memcmp(downgrade_sentinel, tls1_downgrade_random,
+ sizeof(tls1_downgrade_random)));
+ }
+ } else {
+ EXPECT_NE(0, memcmp(downgrade_sentinel, tls12_downgrade_random,
+ sizeof(tls12_downgrade_random)));
+ EXPECT_NE(0, memcmp(downgrade_sentinel, tls1_downgrade_random,
+ sizeof(tls1_downgrade_random)));
+ }
+}
+
+// TLS 1.1 clients do not check the random values, so we should
+// instead get a handshake failure alert from the server.
+TEST_F(TlsConnectTest, TestDowngradeDetectionToTls10) {
+ // Setting the option here has no effect.
+ client_->SetOption(SSL_ENABLE_HELLO_DOWNGRADE_CHECK, PR_TRUE);
+ MakeTlsFilter<TlsMessageVersionSetter>(client_, kTlsHandshakeClientHello,
+ SSL_LIBRARY_VERSION_TLS_1_0);
+ client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_0,
+ SSL_LIBRARY_VERSION_TLS_1_1);
+ server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_0,
+ SSL_LIBRARY_VERSION_TLS_1_2);
+ ConnectExpectAlert(server_, kTlsAlertDecryptError);
+ server_->CheckErrorCode(SSL_ERROR_BAD_HANDSHAKE_HASH_VALUE);
+ client_->CheckErrorCode(SSL_ERROR_DECRYPT_ERROR_ALERT);
+}
+
+TEST_F(TlsConnectTest, TestFallbackFromTls12) {
+ client_->SetOption(SSL_ENABLE_HELLO_DOWNGRADE_CHECK, PR_TRUE);
+ client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_1,
+ SSL_LIBRARY_VERSION_TLS_1_1);
+ server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_1,
+ SSL_LIBRARY_VERSION_TLS_1_2);
+ client_->SetDowngradeCheckVersion(SSL_LIBRARY_VERSION_TLS_1_2);
+ ConnectExpectAlert(client_, kTlsAlertIllegalParameter);
+ client_->CheckErrorCode(SSL_ERROR_RX_MALFORMED_SERVER_HELLO);
+ server_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT);
+}
+
+static SECStatus AllowFalseStart(PRFileDesc* fd, void* arg,
+ PRBool* can_false_start) {
+ bool* false_start_attempted = reinterpret_cast<bool*>(arg);
+ *false_start_attempted = true;
+ *can_false_start = PR_TRUE;
+ return SECSuccess;
+}
+
+// If we disable the downgrade check, the sentinel is still generated, and we
+// disable false start instead.
+TEST_F(TlsConnectTest, DisableFalseStartOnFallback) {
+ // Don't call client_->EnableFalseStart(), because that sets the client up for
+ // success, and we want false start to fail.
+ client_->SetOption(SSL_ENABLE_FALSE_START, PR_TRUE);
+ bool false_start_attempted = false;
+ EXPECT_EQ(SECSuccess,
+ SSL_SetCanFalseStartCallback(client_->ssl_fd(), AllowFalseStart,
+ &false_start_attempted));
+
+ client_->SetOption(SSL_ENABLE_HELLO_DOWNGRADE_CHECK, PR_FALSE);
+ client_->SetDowngradeCheckVersion(SSL_LIBRARY_VERSION_TLS_1_3);
+ client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2,
+ SSL_LIBRARY_VERSION_TLS_1_2);
+ server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2,
+ SSL_LIBRARY_VERSION_TLS_1_3);
+ Connect();
+ EXPECT_FALSE(false_start_attempted);
+}
+
+TEST_F(TlsConnectTest, TestFallbackFromTls13) {
+ client_->SetOption(SSL_ENABLE_HELLO_DOWNGRADE_CHECK, PR_TRUE);
+ client_->SetDowngradeCheckVersion(SSL_LIBRARY_VERSION_TLS_1_3);
+ client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2,
+ SSL_LIBRARY_VERSION_TLS_1_2);
+ server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_1,
+ SSL_LIBRARY_VERSION_TLS_1_3);
+ ConnectExpectAlert(client_, kTlsAlertIllegalParameter);
+ client_->CheckErrorCode(SSL_ERROR_RX_MALFORMED_SERVER_HELLO);
+ server_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT);
+}
+
+TEST_P(TlsConnectGeneric, TestFallbackSCSVVersionMatch) {
+ client_->SetOption(SSL_ENABLE_FALLBACK_SCSV, PR_TRUE);
+ Connect();
+}
+
+TEST_P(TlsConnectGenericPre13, TestFallbackSCSVVersionMismatch) {
+ client_->SetOption(SSL_ENABLE_FALLBACK_SCSV, PR_TRUE);
+ server_->SetVersionRange(version_, version_ + 1);
+ ConnectExpectAlert(server_, kTlsAlertInappropriateFallback);
+ client_->CheckErrorCode(SSL_ERROR_INAPPROPRIATE_FALLBACK_ALERT);
+ server_->CheckErrorCode(SSL_ERROR_INAPPROPRIATE_FALLBACK_ALERT);
+}
+
+// The TLS v1.3 spec section C.4 states that 'Implementations MUST NOT send or
+// accept any records with a version less than { 3, 0 }'. Thus we will not
+// allow version ranges including both SSL v3 and TLS v1.3.
+TEST_F(TlsConnectTest, DisallowSSLv3HelloWithTLSv13Enabled) {
+ SECStatus rv;
+ SSLVersionRange vrange = {SSL_LIBRARY_VERSION_3_0,
+ SSL_LIBRARY_VERSION_TLS_1_3};
+
+ EnsureTlsSetup();
+ rv = SSL_VersionRangeSet(client_->ssl_fd(), &vrange);
+ EXPECT_EQ(SECFailure, rv);
+
+ rv = SSL_VersionRangeSet(server_->ssl_fd(), &vrange);
+ EXPECT_EQ(SECFailure, rv);
+}
+
+TEST_P(TlsConnectGeneric, AlertBeforeServerHello) {
+ EnsureTlsSetup();
+ client_->ExpectReceiveAlert(kTlsAlertUnrecognizedName, kTlsAlertWarning);
+ StartConnect();
+ client_->Handshake(); // Send ClientHello.
+ static const uint8_t kWarningAlert[] = {kTlsAlertWarning,
+ kTlsAlertUnrecognizedName};
+ DataBuffer alert;
+ TlsAgentTestBase::MakeRecord(variant_, ssl_ct_alert,
+ SSL_LIBRARY_VERSION_TLS_1_0, kWarningAlert,
+ PR_ARRAY_SIZE(kWarningAlert), &alert);
+ client_->adapter()->PacketReceived(alert);
+ Handshake();
+ CheckConnected();
+}
+
+class Tls13NoSupportedVersions : public TlsConnectStreamTls12 {
+ protected:
+ void Run(uint16_t overwritten_client_version, uint16_t max_server_version) {
+ client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2,
+ SSL_LIBRARY_VERSION_TLS_1_2);
+ server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2, max_server_version);
+ MakeTlsFilter<TlsMessageVersionSetter>(client_, kTlsHandshakeClientHello,
+ overwritten_client_version);
+ auto capture =
+ MakeTlsFilter<TlsHandshakeRecorder>(server_, kTlsHandshakeServerHello);
+ ConnectExpectAlert(server_, kTlsAlertDecryptError);
+ client_->CheckErrorCode(SSL_ERROR_DECRYPT_ERROR_ALERT);
+ server_->CheckErrorCode(SSL_ERROR_BAD_HANDSHAKE_HASH_VALUE);
+ const DataBuffer& server_hello = capture->buffer();
+ ASSERT_GT(server_hello.len(), 2U);
+ uint32_t ver;
+ ASSERT_TRUE(server_hello.Read(0, 2, &ver));
+ ASSERT_EQ(static_cast<uint32_t>(SSL_LIBRARY_VERSION_TLS_1_2), ver);
+ }
+};
+
+// If we offer a 1.3 ClientHello w/o supported_versions, the server should
+// negotiate 1.2.
+TEST_F(Tls13NoSupportedVersions,
+ Tls13ClientHelloWithoutSupportedVersionsServer12) {
+ Run(SSL_LIBRARY_VERSION_TLS_1_3, SSL_LIBRARY_VERSION_TLS_1_2);
+}
+
+TEST_F(Tls13NoSupportedVersions,
+ Tls13ClientHelloWithoutSupportedVersionsServer13) {
+ Run(SSL_LIBRARY_VERSION_TLS_1_3, SSL_LIBRARY_VERSION_TLS_1_3);
+}
+
+TEST_F(Tls13NoSupportedVersions,
+ Tls14ClientHelloWithoutSupportedVersionsServer13) {
+ Run(SSL_LIBRARY_VERSION_TLS_1_3 + 1, SSL_LIBRARY_VERSION_TLS_1_3);
+}
+
+// Offer 1.3 but with ClientHello.legacy_version == TLS 1.4. This
+// causes a bad MAC error when we read EncryptedExtensions.
+TEST_F(TlsConnectStreamTls13, Tls14ClientHelloWithSupportedVersions) {
+ MakeTlsFilter<TlsMessageVersionSetter>(client_, kTlsHandshakeClientHello,
+ SSL_LIBRARY_VERSION_TLS_1_3 + 1);
+ auto capture = MakeTlsFilter<TlsExtensionCapture>(
+ server_, ssl_tls13_supported_versions_xtn);
+ client_->ExpectSendAlert(kTlsAlertBadRecordMac);
+ server_->ExpectSendAlert(kTlsAlertBadRecordMac);
+ ConnectExpectFail();
+ client_->CheckErrorCode(SSL_ERROR_BAD_MAC_READ);
+ server_->CheckErrorCode(SSL_ERROR_BAD_MAC_READ);
+
+ ASSERT_EQ(2U, capture->extension().len());
+ uint32_t version = 0;
+ ASSERT_TRUE(capture->extension().Read(0, 2, &version));
+ // This way we don't need to change with new draft version.
+ ASSERT_LT(static_cast<uint32_t>(SSL_LIBRARY_VERSION_TLS_1_2), version);
+}
+
+// Offer 1.3 but with Server/ClientHello.legacy_version == SSL 3.0. This
+// causes a protocol version alert. See RFC 8446 Appendix D.5.
+TEST_F(TlsConnectStreamTls13, Ssl30ClientHelloWithSupportedVersions) {
+ MakeTlsFilter<TlsMessageVersionSetter>(client_, kTlsHandshakeClientHello,
+ SSL_LIBRARY_VERSION_3_0);
+ ConnectExpectAlert(server_, kTlsAlertProtocolVersion);
+}
+
+TEST_F(TlsConnectStreamTls13, Ssl30ServerHelloWithSupportedVersions) {
+ MakeTlsFilter<TlsMessageVersionSetter>(server_, kTlsHandshakeServerHello,
+ SSL_LIBRARY_VERSION_3_0);
+ StartConnect();
+ client_->ExpectSendAlert(kTlsAlertProtocolVersion);
+ /* Since the handshake is not finished the client will send an unencrypted
+ * alert. The server is expected to close the connection with a unexpected
+ * message alert. */
+ server_->ExpectSendAlert(kTlsAlertUnexpectedMessage);
+ Handshake();
+}
+
+// Verify the client sends only DTLS versions in supported_versions
+TEST_F(DtlsConnectTest, DtlsSupportedVersionsEncoding) {
+ client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_1,
+ SSL_LIBRARY_VERSION_TLS_1_3);
+ server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_1,
+ SSL_LIBRARY_VERSION_TLS_1_3);
+ auto capture = MakeTlsFilter<TlsExtensionCapture>(
+ client_, ssl_tls13_supported_versions_xtn);
+ Connect();
+
+ ASSERT_EQ(7U, capture->extension().len());
+ uint32_t version = 0;
+ ASSERT_TRUE(capture->extension().Read(1, 2, &version));
+ EXPECT_EQ(0x7f00 | DTLS_1_3_DRAFT_VERSION, static_cast<int>(version));
+ ASSERT_TRUE(capture->extension().Read(3, 2, &version));
+ EXPECT_EQ(SSL_LIBRARY_VERSION_DTLS_1_2_WIRE, static_cast<int>(version));
+ ASSERT_TRUE(capture->extension().Read(5, 2, &version));
+ EXPECT_EQ(SSL_LIBRARY_VERSION_DTLS_1_0_WIRE, static_cast<int>(version));
+}
+
+// Verify the DTLS 1.3 supported_versions interop workaround.
+TEST_F(DtlsConnectTest, Dtls13VersionWorkaround) {
+ static const uint16_t kExpectVersionsWorkaround[] = {
+ 0x7f00 | DTLS_1_3_DRAFT_VERSION, SSL_LIBRARY_VERSION_DTLS_1_2_WIRE,
+ SSL_LIBRARY_VERSION_TLS_1_2, SSL_LIBRARY_VERSION_DTLS_1_0_WIRE,
+ SSL_LIBRARY_VERSION_TLS_1_1};
+ const int min_ver = SSL_LIBRARY_VERSION_TLS_1_1,
+ max_ver = SSL_LIBRARY_VERSION_TLS_1_3;
+
+ // Toggle the workaround, then verify both encodings are present.
+ EnsureTlsSetup();
+ SSL_SetDtls13VersionWorkaround(client_->ssl_fd(), PR_TRUE);
+ SSL_SetDtls13VersionWorkaround(client_->ssl_fd(), PR_FALSE);
+ SSL_SetDtls13VersionWorkaround(client_->ssl_fd(), PR_TRUE);
+ client_->SetVersionRange(min_ver, max_ver);
+ server_->SetVersionRange(min_ver, max_ver);
+ auto capture = MakeTlsFilter<TlsExtensionCapture>(
+ client_, ssl_tls13_supported_versions_xtn);
+ Connect();
+
+ uint32_t version = 0;
+ size_t off = 1;
+ ASSERT_EQ(1 + sizeof(kExpectVersionsWorkaround), capture->extension().len());
+ for (unsigned int i = 0; i < PR_ARRAY_SIZE(kExpectVersionsWorkaround); i++) {
+ ASSERT_TRUE(capture->extension().Read(off, 2, &version));
+ EXPECT_EQ(kExpectVersionsWorkaround[i], static_cast<uint16_t>(version));
+ off += 2;
+ }
+}
+
+// Verify the client sends only TLS versions in supported_versions
+TEST_F(TlsConnectTest, TlsSupportedVersionsEncoding) {
+ client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_0,
+ SSL_LIBRARY_VERSION_TLS_1_3);
+ server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_0,
+ SSL_LIBRARY_VERSION_TLS_1_3);
+ auto capture = MakeTlsFilter<TlsExtensionCapture>(
+ client_, ssl_tls13_supported_versions_xtn);
+ Connect();
+
+ ASSERT_EQ(9U, capture->extension().len());
+ uint32_t version = 0;
+ ASSERT_TRUE(capture->extension().Read(1, 2, &version));
+ EXPECT_EQ(SSL_LIBRARY_VERSION_TLS_1_3, static_cast<int>(version));
+ ASSERT_TRUE(capture->extension().Read(3, 2, &version));
+ EXPECT_EQ(SSL_LIBRARY_VERSION_TLS_1_2, static_cast<int>(version));
+ ASSERT_TRUE(capture->extension().Read(5, 2, &version));
+ EXPECT_EQ(SSL_LIBRARY_VERSION_TLS_1_1, static_cast<int>(version));
+ ASSERT_TRUE(capture->extension().Read(7, 2, &version));
+ EXPECT_EQ(SSL_LIBRARY_VERSION_TLS_1_0, static_cast<int>(version));
+}
+
+/* Test that on reception of unsupported ClientHello.legacy_version the TLS 1.3
+ * server sends the correct alert.
+ *
+ * If the "supported_versions" extension is absent and the server only supports
+ * versions greater than ClientHello.legacy_version, the server MUST abort the
+ * handshake with a "protocol_version" alert [RFC8446, Appendix D.2]. */
+TEST_P(TlsConnectGenericPre13, ClientHelloUnsupportedTlsVersion) {
+ StartConnect();
+
+ if (variant_ == ssl_variant_stream) {
+ server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_3,
+ SSL_LIBRARY_VERSION_TLS_1_3);
+ } else {
+ server_->SetVersionRange(SSL_LIBRARY_VERSION_DTLS_1_3,
+ SSL_LIBRARY_VERSION_DTLS_1_3);
+ }
+
+ // Try to handshake
+ client_->Handshake();
+ // Expect protocol version alert
+ server_->ExpectSendAlert(kTlsAlertProtocolVersion);
+ server_->Handshake();
+ // Digest alert at peer
+ client_->ExpectReceiveAlert(kTlsAlertProtocolVersion);
+ client_->ReadBytes();
+}
+
+INSTANTIATE_TEST_SUITE_P(
+ TlsDowngradeSentinelTest, TlsDowngradeTest,
+ ::testing::Combine(TlsConnectTestBase::kTlsVariantsStream,
+ TlsConnectTestBase::kTlsVAll,
+ TlsConnectTestBase::kTlsV12Plus));
+
+} // namespace nss_test
diff --git a/security/nss/gtests/ssl_gtest/ssl_versionpolicy_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_versionpolicy_unittest.cc
new file mode 100644
index 0000000000..91d8080377
--- /dev/null
+++ b/security/nss/gtests/ssl_gtest/ssl_versionpolicy_unittest.cc
@@ -0,0 +1,385 @@
+/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
+/* vim: set ts=2 et sw=2 tw=80: */
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this file,
+ * You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+#include "nss.h"
+#include "secerr.h"
+#include "ssl.h"
+#include "ssl3prot.h"
+#include "sslerr.h"
+#include "sslproto.h"
+
+#include "gtest_utils.h"
+#include "nss_scoped_ptrs.h"
+#include "tls_connect.h"
+#include "tls_filter.h"
+#include "tls_parser.h"
+
+#include <iostream>
+
+namespace nss_test {
+
+std::string GetSSLVersionString(uint16_t v) {
+ switch (v) {
+ case SSL_LIBRARY_VERSION_3_0:
+ return "ssl3";
+ case SSL_LIBRARY_VERSION_TLS_1_0:
+ return "tls1.0";
+ case SSL_LIBRARY_VERSION_TLS_1_1:
+ return "tls1.1";
+ case SSL_LIBRARY_VERSION_TLS_1_2:
+ return "tls1.2";
+ case SSL_LIBRARY_VERSION_TLS_1_3:
+ return "tls1.3";
+ case SSL_LIBRARY_VERSION_NONE:
+ return "NONE";
+ }
+ if (v < SSL_LIBRARY_VERSION_3_0) {
+ return "undefined-too-low";
+ }
+ return "undefined-too-high";
+}
+
+inline std::ostream& operator<<(std::ostream& stream,
+ const SSLVersionRange& vr) {
+ return stream << GetSSLVersionString(vr.min) << ","
+ << GetSSLVersionString(vr.max);
+}
+
+class VersionRangeWithLabel {
+ public:
+ VersionRangeWithLabel(const std::string& txt, const SSLVersionRange& vr)
+ : label_(txt), vr_(vr) {}
+ VersionRangeWithLabel(const std::string& txt, uint16_t start, uint16_t end)
+ : label_(txt) {
+ vr_.min = start;
+ vr_.max = end;
+ }
+ VersionRangeWithLabel(const std::string& label) : label_(label) {
+ vr_.min = vr_.max = SSL_LIBRARY_VERSION_NONE;
+ }
+
+ void WriteStream(std::ostream& stream) const {
+ stream << " " << label_ << ": " << vr_;
+ }
+
+ uint16_t min() const { return vr_.min; }
+ uint16_t max() const { return vr_.max; }
+ SSLVersionRange range() const { return vr_; }
+
+ private:
+ std::string label_;
+ SSLVersionRange vr_;
+};
+
+inline std::ostream& operator<<(std::ostream& stream,
+ const VersionRangeWithLabel& vrwl) {
+ vrwl.WriteStream(stream);
+ return stream;
+}
+
+typedef std::tuple<SSLProtocolVariant, // variant
+ uint16_t, // policy min
+ uint16_t, // policy max
+ uint16_t, // input min
+ uint16_t> // input max
+ PolicyVersionRangeInput;
+
+class TestPolicyVersionRange
+ : public TlsConnectTestBase,
+ public ::testing::WithParamInterface<PolicyVersionRangeInput> {
+ public:
+ TestPolicyVersionRange()
+ : TlsConnectTestBase(std::get<0>(GetParam()), 0),
+ variant_(std::get<0>(GetParam())),
+ policy_("policy", std::get<1>(GetParam()), std::get<2>(GetParam())),
+ input_("input", std::get<3>(GetParam()), std::get<4>(GetParam())),
+ library_("supported-by-library",
+ ((variant_ == ssl_variant_stream)
+ ? SSL_LIBRARY_VERSION_MIN_SUPPORTED_STREAM
+ : SSL_LIBRARY_VERSION_MIN_SUPPORTED_DATAGRAM),
+ SSL_LIBRARY_VERSION_MAX_SUPPORTED) {
+ TlsConnectTestBase::SkipVersionChecks();
+ }
+
+ void SetPolicy(const SSLVersionRange& policy) {
+ NSS_SetAlgorithmPolicy(SEC_OID_APPLY_SSL_POLICY, NSS_USE_POLICY_IN_SSL, 0);
+
+ SECStatus rv;
+ rv = NSS_OptionSet(NSS_TLS_VERSION_MIN_POLICY, policy.min);
+ ASSERT_EQ(SECSuccess, rv);
+ rv = NSS_OptionSet(NSS_TLS_VERSION_MAX_POLICY, policy.max);
+ ASSERT_EQ(SECSuccess, rv);
+ rv = NSS_OptionSet(NSS_DTLS_VERSION_MIN_POLICY, policy.min);
+ ASSERT_EQ(SECSuccess, rv);
+ rv = NSS_OptionSet(NSS_DTLS_VERSION_MAX_POLICY, policy.max);
+ ASSERT_EQ(SECSuccess, rv);
+ }
+
+ void CreateDummySocket(std::shared_ptr<DummyPrSocket>* dummy_socket,
+ ScopedPRFileDesc* ssl_fd) {
+ (*dummy_socket).reset(new DummyPrSocket("dummy", variant_));
+ *ssl_fd = (*dummy_socket)->CreateFD();
+ if (variant_ == ssl_variant_stream) {
+ SSL_ImportFD(nullptr, ssl_fd->get());
+ } else {
+ DTLS_ImportFD(nullptr, ssl_fd->get());
+ }
+ }
+
+ bool GetOverlap(const SSLVersionRange& r1, const SSLVersionRange& r2,
+ SSLVersionRange* overlap) {
+ if (r1.min == SSL_LIBRARY_VERSION_NONE ||
+ r1.max == SSL_LIBRARY_VERSION_NONE ||
+ r2.min == SSL_LIBRARY_VERSION_NONE ||
+ r2.max == SSL_LIBRARY_VERSION_NONE) {
+ return false;
+ }
+
+ SSLVersionRange temp;
+ temp.min = PR_MAX(r1.min, r2.min);
+ temp.max = PR_MIN(r1.max, r2.max);
+
+ if (temp.min > temp.max) {
+ return false;
+ }
+
+ *overlap = temp;
+ return true;
+ }
+
+ bool IsValidInputForVersionRangeSet(SSLVersionRange* expectedEffectiveRange) {
+ if (input_.min() <= SSL_LIBRARY_VERSION_3_0 &&
+ input_.max() >= SSL_LIBRARY_VERSION_TLS_1_3) {
+ // This is always invalid input, independent of policy
+ return false;
+ }
+
+ if (input_.min() < library_.min() || input_.max() > library_.max() ||
+ input_.min() > input_.max()) {
+ // Asking for unsupported ranges is invalid input for VersionRangeSet
+ // APIs, regardless of overlap.
+ return false;
+ }
+
+ SSLVersionRange overlap_with_library;
+ if (!GetOverlap(input_.range(), library_.range(), &overlap_with_library)) {
+ return false;
+ }
+
+ SSLVersionRange overlap_with_library_and_policy;
+ if (!GetOverlap(overlap_with_library, policy_.range(),
+ &overlap_with_library_and_policy)) {
+ return false;
+ }
+
+ RemoveConflictingVersions(variant_, &overlap_with_library_and_policy);
+ *expectedEffectiveRange = overlap_with_library_and_policy;
+ return true;
+ }
+
+ void RemoveConflictingVersions(SSLProtocolVariant variant,
+ SSLVersionRange* r) {
+ ASSERT_TRUE(r != nullptr);
+ if (r->max >= SSL_LIBRARY_VERSION_TLS_1_3 &&
+ r->min < SSL_LIBRARY_VERSION_TLS_1_0) {
+ r->min = SSL_LIBRARY_VERSION_TLS_1_0;
+ }
+ }
+
+ void SetUp() override {
+ TlsConnectTestBase::SetUp();
+ SetPolicy(policy_.range());
+ }
+
+ void TearDown() override {
+ TlsConnectTestBase::TearDown();
+ saved_version_policy_.RestoreOriginalPolicy();
+ }
+
+ protected:
+ class VersionPolicy {
+ public:
+ VersionPolicy() { SaveOriginalPolicy(); }
+
+ void RestoreOriginalPolicy() {
+ SECStatus rv;
+ rv = NSS_OptionSet(NSS_TLS_VERSION_MIN_POLICY, saved_min_tls_);
+ ASSERT_EQ(SECSuccess, rv);
+ rv = NSS_OptionSet(NSS_TLS_VERSION_MAX_POLICY, saved_max_tls_);
+ ASSERT_EQ(SECSuccess, rv);
+ rv = NSS_OptionSet(NSS_DTLS_VERSION_MIN_POLICY, saved_min_dtls_);
+ ASSERT_EQ(SECSuccess, rv);
+ rv = NSS_OptionSet(NSS_DTLS_VERSION_MAX_POLICY, saved_max_dtls_);
+ ASSERT_EQ(SECSuccess, rv);
+ }
+
+ private:
+ void SaveOriginalPolicy() {
+ SECStatus rv;
+ rv = NSS_OptionGet(NSS_TLS_VERSION_MIN_POLICY, &saved_min_tls_);
+ ASSERT_EQ(SECSuccess, rv);
+ rv = NSS_OptionGet(NSS_TLS_VERSION_MAX_POLICY, &saved_max_tls_);
+ ASSERT_EQ(SECSuccess, rv);
+ rv = NSS_OptionGet(NSS_DTLS_VERSION_MIN_POLICY, &saved_min_dtls_);
+ ASSERT_EQ(SECSuccess, rv);
+ rv = NSS_OptionGet(NSS_DTLS_VERSION_MAX_POLICY, &saved_max_dtls_);
+ ASSERT_EQ(SECSuccess, rv);
+ }
+
+ int32_t saved_min_tls_;
+ int32_t saved_max_tls_;
+ int32_t saved_min_dtls_;
+ int32_t saved_max_dtls_;
+ };
+
+ VersionPolicy saved_version_policy_;
+
+ SSLProtocolVariant variant_;
+ const VersionRangeWithLabel policy_;
+ const VersionRangeWithLabel input_;
+ const VersionRangeWithLabel library_;
+};
+
+static const uint16_t kExpandedVersionsArr[] = {
+ /* clang-format off */
+ SSL_LIBRARY_VERSION_3_0 - 1,
+ SSL_LIBRARY_VERSION_3_0,
+ SSL_LIBRARY_VERSION_TLS_1_0,
+ SSL_LIBRARY_VERSION_TLS_1_1,
+ SSL_LIBRARY_VERSION_TLS_1_2,
+#ifndef NSS_DISABLE_TLS_1_3
+ SSL_LIBRARY_VERSION_TLS_1_3,
+#endif
+ SSL_LIBRARY_VERSION_MAX_SUPPORTED + 1
+ /* clang-format on */
+};
+static ::testing::internal::ParamGenerator<uint16_t> kExpandedVersions =
+ ::testing::ValuesIn(kExpandedVersionsArr);
+
+TEST_P(TestPolicyVersionRange, TestAllTLSVersionsAndPolicyCombinations) {
+ ASSERT_TRUE(variant_ == ssl_variant_stream ||
+ variant_ == ssl_variant_datagram)
+ << "testing unsupported ssl variant";
+
+ std::cerr << "testing: " << variant_ << policy_ << input_ << library_
+ << std::endl;
+
+ SSLVersionRange supported_range;
+ SECStatus rv = SSL_VersionRangeGetSupported(variant_, &supported_range);
+ VersionRangeWithLabel supported("SSL_VersionRangeGetSupported",
+ supported_range);
+
+ std::cerr << supported << std::endl;
+
+ std::shared_ptr<DummyPrSocket> dummy_socket;
+ ScopedPRFileDesc ssl_fd;
+ CreateDummySocket(&dummy_socket, &ssl_fd);
+
+ SECStatus rv_socket;
+ SSLVersionRange overlap_policy_and_lib;
+ if (!GetOverlap(policy_.range(), library_.range(), &overlap_policy_and_lib)) {
+ EXPECT_EQ(SECFailure, rv)
+ << "expected SSL_VersionRangeGetSupported to fail with invalid policy";
+
+ SSLVersionRange enabled_range;
+ rv = SSL_VersionRangeGetDefault(variant_, &enabled_range);
+ EXPECT_EQ(SECFailure, rv)
+ << "expected SSL_VersionRangeGetDefault to fail with invalid policy";
+
+ SSLVersionRange enabled_range_on_socket;
+ rv_socket = SSL_VersionRangeGet(ssl_fd.get(), &enabled_range_on_socket);
+ EXPECT_EQ(SECFailure, rv_socket)
+ << "expected SSL_VersionRangeGet to fail with invalid policy";
+
+ ConnectExpectFail();
+ return;
+ }
+
+ EXPECT_EQ(SECSuccess, rv)
+ << "expected SSL_VersionRangeGetSupported to succeed with valid policy";
+
+ EXPECT_TRUE(supported_range.min != SSL_LIBRARY_VERSION_NONE &&
+ supported_range.max != SSL_LIBRARY_VERSION_NONE)
+ << "expected SSL_VersionRangeGetSupported to return real values with "
+ "valid policy";
+
+ RemoveConflictingVersions(variant_, &overlap_policy_and_lib);
+ VersionRangeWithLabel overlap_info("overlap", overlap_policy_and_lib);
+
+ EXPECT_TRUE(supported_range == overlap_policy_and_lib)
+ << "expected range from GetSupported to be identical with calculated "
+ "overlap "
+ << overlap_info;
+
+ // We don't know which versions are "enabled by default" by the library,
+ // therefore we don't know if there's overlap between the default
+ // and the policy, and therefore, we don't if TLS connections should
+ // be successful or fail in this combination.
+ // Therefore we don't test if we can connect, without having configured a
+ // version range explicitly.
+
+ // Now start testing with supplied input.
+
+ SSLVersionRange expected_effective_range;
+ bool is_valid_input =
+ IsValidInputForVersionRangeSet(&expected_effective_range);
+
+ SSLVersionRange temp_input = input_.range();
+ rv = SSL_VersionRangeSetDefault(variant_, &temp_input);
+ rv_socket = SSL_VersionRangeSet(ssl_fd.get(), &temp_input);
+
+ if (!is_valid_input) {
+ EXPECT_EQ(SECFailure, rv)
+ << "expected failure return from SSL_VersionRangeSetDefault";
+
+ EXPECT_EQ(SECFailure, rv_socket)
+ << "expected failure return from SSL_VersionRangeSet";
+ return;
+ }
+
+ EXPECT_EQ(SECSuccess, rv)
+ << "expected successful return from SSL_VersionRangeSetDefault";
+
+ EXPECT_EQ(SECSuccess, rv_socket)
+ << "expected successful return from SSL_VersionRangeSet";
+
+ SSLVersionRange effective;
+ SSLVersionRange effective_socket;
+
+ rv = SSL_VersionRangeGetDefault(variant_, &effective);
+ EXPECT_EQ(SECSuccess, rv)
+ << "expected successful return from SSL_VersionRangeGetDefault";
+
+ rv_socket = SSL_VersionRangeGet(ssl_fd.get(), &effective_socket);
+ EXPECT_EQ(SECSuccess, rv_socket)
+ << "expected successful return from SSL_VersionRangeGet";
+
+ VersionRangeWithLabel expected_info("expectation", expected_effective_range);
+ VersionRangeWithLabel effective_info("effectively-enabled", effective);
+
+ EXPECT_TRUE(expected_effective_range == effective)
+ << "range returned by SSL_VersionRangeGetDefault doesn't match "
+ "expectation: "
+ << expected_info << effective_info;
+
+ EXPECT_TRUE(expected_effective_range == effective_socket)
+ << "range returned by SSL_VersionRangeGet doesn't match "
+ "expectation: "
+ << expected_info << effective_info;
+
+ // Because we found overlap between policy and supported versions,
+ // and because we have used SetDefault to enable at least one version,
+ // it should be possible to execute an SSL/TLS connection.
+ Connect();
+}
+
+INSTANTIATE_TEST_SUITE_P(TLSVersionRanges, TestPolicyVersionRange,
+ ::testing::Combine(TlsConnectTestBase::kTlsVariantsAll,
+ kExpandedVersions,
+ kExpandedVersions,
+ kExpandedVersions,
+ kExpandedVersions));
+} // namespace nss_test
diff --git a/security/nss/gtests/ssl_gtest/test_io.cc b/security/nss/gtests/ssl_gtest/test_io.cc
new file mode 100644
index 0000000000..e4651a2352
--- /dev/null
+++ b/security/nss/gtests/ssl_gtest/test_io.cc
@@ -0,0 +1,278 @@
+/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
+/* vim: set ts=2 et sw=2 tw=80: */
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this file,
+ * You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+#include "test_io.h"
+
+#include <algorithm>
+#include <cassert>
+#include <iostream>
+#include <memory>
+
+#include "prerror.h"
+#include "prlog.h"
+#include "prthread.h"
+
+extern bool g_ssl_gtest_verbose;
+
+namespace nss_test {
+
+#define LOG(a) std::cerr << name_ << ": " << a << std::endl
+#define LOGV(a) \
+ do { \
+ if (g_ssl_gtest_verbose) LOG(a); \
+ } while (false)
+
+PRDescIdentity DummyPrSocket::LayerId() {
+ static PRDescIdentity id = PR_GetUniqueIdentity("dummysocket");
+ return id;
+}
+
+ScopedPRFileDesc DummyPrSocket::CreateFD() {
+ return DummyIOLayerMethods::CreateFD(DummyPrSocket::LayerId(), this);
+}
+
+void DummyPrSocket::Reset() {
+ auto p = peer_.lock();
+ peer_.reset();
+ if (p) {
+ p->peer_.reset();
+ p->Reset();
+ }
+ while (!input_.empty()) {
+ input_.pop();
+ }
+ filter_ = nullptr;
+ write_error_ = 0;
+}
+
+void DummyPrSocket::PacketReceived(const DataBuffer &packet) {
+ input_.push(Packet(packet));
+}
+
+int32_t DummyPrSocket::Read(PRFileDesc *f, void *data, int32_t len) {
+ PR_ASSERT(variant_ == ssl_variant_stream);
+ if (variant_ != ssl_variant_stream) {
+ PR_SetError(PR_INVALID_METHOD_ERROR, 0);
+ return -1;
+ }
+
+ auto dst = peer_.lock();
+ if (!dst) {
+ PR_SetError(PR_NOT_CONNECTED_ERROR, 0);
+ return -1;
+ }
+
+ if (input_.empty()) {
+ LOGV("Read --> wouldblock " << len);
+ PR_SetError(PR_WOULD_BLOCK_ERROR, 0);
+ return -1;
+ }
+
+ auto &front = input_.front();
+ size_t to_read =
+ std::min(static_cast<size_t>(len), front.len() - front.offset());
+ memcpy(data, static_cast<const void *>(front.data() + front.offset()),
+ to_read);
+ front.Advance(to_read);
+
+ if (!front.remaining()) {
+ input_.pop();
+ }
+
+ return static_cast<int32_t>(to_read);
+}
+
+int32_t DummyPrSocket::Recv(PRFileDesc *f, void *buf, int32_t buflen,
+ int32_t flags, PRIntervalTime to) {
+ PR_ASSERT(flags == 0);
+ if (flags != 0) {
+ PR_SetError(PR_NOT_IMPLEMENTED_ERROR, 0);
+ return -1;
+ }
+
+ if (variant() != ssl_variant_datagram) {
+ return Read(f, buf, buflen);
+ }
+
+ auto dst = peer_.lock();
+ if (!dst) {
+ PR_SetError(PR_NOT_CONNECTED_ERROR, 0);
+ return -1;
+ }
+
+ if (input_.empty()) {
+ PR_SetError(PR_WOULD_BLOCK_ERROR, 0);
+ return -1;
+ }
+
+ auto &front = input_.front();
+ if (static_cast<size_t>(buflen) < front.len()) {
+ PR_SetError(PR_BUFFER_OVERFLOW_ERROR, 0);
+ return -1;
+ }
+
+ size_t count = front.len();
+ memcpy(buf, front.data(), count);
+
+ input_.pop();
+ return static_cast<int32_t>(count);
+}
+
+int32_t DummyPrSocket::Write(PRFileDesc *f, const void *buf, int32_t length) {
+ if (write_error_) {
+ PR_SetError(write_error_, 0);
+ return -1;
+ }
+
+ auto dst = peer_.lock();
+ if (!dst) {
+ PR_SetError(PR_NOT_CONNECTED_ERROR, 0);
+ return -1;
+ }
+
+ DataBuffer packet(static_cast<const uint8_t *>(buf),
+ static_cast<size_t>(length));
+ DataBuffer filtered;
+ PacketFilter::Action action = PacketFilter::KEEP;
+ if (filter_) {
+ LOGV("Original packet: " << packet);
+ action = filter_->Process(packet, &filtered);
+ }
+ switch (action) {
+ case PacketFilter::CHANGE:
+ LOG("Filtered packet: " << filtered);
+ dst->PacketReceived(filtered);
+ break;
+ case PacketFilter::DROP:
+ LOG("Drop packet");
+ break;
+ case PacketFilter::KEEP:
+ dst->PacketReceived(packet);
+ break;
+ }
+ // libssl can't handle it if this reports something other than the length
+ // of what was passed in (or less, but we're not doing partial writes).
+ return static_cast<int32_t>(packet.len());
+}
+
+Poller *Poller::instance;
+
+Poller *Poller::Instance() {
+ if (!instance) instance = new Poller();
+
+ return instance;
+}
+
+void Poller::Shutdown() {
+ delete instance;
+ instance = nullptr;
+}
+
+void Poller::Wait(Event event, std::shared_ptr<DummyPrSocket> &adapter,
+ PollTarget *target, PollCallback cb) {
+ assert(event < TIMER_EVENT);
+ if (event >= TIMER_EVENT) return;
+
+ std::unique_ptr<Waiter> waiter;
+ auto it = waiters_.find(adapter);
+ if (it == waiters_.end()) {
+ waiter.reset(new Waiter(adapter));
+ } else {
+ waiter = std::move(it->second);
+ }
+
+ waiter->targets_[event] = target;
+ waiter->callbacks_[event] = cb;
+ waiters_[adapter] = std::move(waiter);
+}
+
+void Poller::Cancel(Event event, std::shared_ptr<DummyPrSocket> &adapter) {
+ auto it = waiters_.find(adapter);
+ if (it == waiters_.end()) {
+ return;
+ }
+
+ auto &waiter = it->second;
+ waiter->targets_[event] = nullptr;
+ waiter->callbacks_[event] = nullptr;
+
+ // Clean up if there are no callbacks.
+ for (size_t i = 0; i < TIMER_EVENT; ++i) {
+ if (waiter->callbacks_[i]) return;
+ }
+
+ waiters_.erase(adapter);
+}
+
+void Poller::SetTimer(uint32_t timer_ms, PollTarget *target, PollCallback cb,
+ std::shared_ptr<Timer> *timer) {
+ auto t = std::make_shared<Timer>(PR_Now() + timer_ms * 1000, target, cb);
+ timers_.push(t);
+ if (timer) *timer = t;
+}
+
+bool Poller::Poll() {
+ if (g_ssl_gtest_verbose) {
+ std::cerr << "Poll() waiters = " << waiters_.size()
+ << " timers = " << timers_.size() << std::endl;
+ }
+ PRIntervalTime timeout = PR_INTERVAL_NO_TIMEOUT;
+ PRTime now = PR_Now();
+ bool fired = false;
+
+ // Figure out the timer for the select.
+ if (!timers_.empty()) {
+ auto first_timer = timers_.top();
+ if (now >= first_timer->deadline_) {
+ // Timer expired.
+ timeout = PR_INTERVAL_NO_WAIT;
+ } else {
+ timeout =
+ PR_MillisecondsToInterval((first_timer->deadline_ - now) / 1000);
+ }
+ }
+
+ for (auto it = waiters_.begin(); it != waiters_.end(); ++it) {
+ auto &waiter = it->second;
+
+ if (waiter->callbacks_[READABLE_EVENT]) {
+ if (waiter->io_->readable()) {
+ PollCallback callback = waiter->callbacks_[READABLE_EVENT];
+ PollTarget *target = waiter->targets_[READABLE_EVENT];
+ waiter->callbacks_[READABLE_EVENT] = nullptr;
+ waiter->targets_[READABLE_EVENT] = nullptr;
+ callback(target, READABLE_EVENT);
+ fired = true;
+ }
+ }
+ }
+
+ if (fired) timeout = PR_INTERVAL_NO_WAIT;
+
+ // Can't wait forever and also have nothing readable now.
+ if (timeout == PR_INTERVAL_NO_TIMEOUT) return false;
+
+ // Sleep.
+ if (timeout != PR_INTERVAL_NO_WAIT) {
+ PR_Sleep(timeout);
+ }
+
+ // Now process anything that timed out.
+ now = PR_Now();
+ while (!timers_.empty()) {
+ if (now < timers_.top()->deadline_) break;
+
+ auto timer = timers_.top();
+ timers_.pop();
+ if (timer->callback_) {
+ timer->callback_(timer->target_, TIMER_EVENT);
+ }
+ }
+
+ return true;
+}
+
+} // namespace nss_test
diff --git a/security/nss/gtests/ssl_gtest/test_io.h b/security/nss/gtests/ssl_gtest/test_io.h
new file mode 100644
index 0000000000..e262fb123e
--- /dev/null
+++ b/security/nss/gtests/ssl_gtest/test_io.h
@@ -0,0 +1,187 @@
+/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
+/* vim: set ts=2 et sw=2 tw=80: */
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this file,
+ * You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+#ifndef test_io_h_
+#define test_io_h_
+
+#include <string.h>
+#include <map>
+#include <memory>
+#include <ostream>
+#include <queue>
+#include <string>
+
+#include "databuffer.h"
+#include "dummy_io.h"
+#include "prio.h"
+#include "nss_scoped_ptrs.h"
+#include "sslt.h"
+
+namespace nss_test {
+
+class DataBuffer;
+class DummyPrSocket; // Fwd decl.
+
+// Allow us to inspect a packet before it is written.
+class PacketFilter {
+ public:
+ enum Action {
+ KEEP, // keep the original packet unmodified
+ CHANGE, // change the packet to a different value
+ DROP // drop the packet
+ };
+ explicit PacketFilter(bool on = true) : enabled_(on) {}
+ virtual ~PacketFilter() {}
+
+ bool enabled() const { return enabled_; }
+
+ virtual Action Process(const DataBuffer& input, DataBuffer* output) {
+ if (!enabled_) {
+ return KEEP;
+ }
+ return Filter(input, output);
+ }
+ void Enable() { enabled_ = true; }
+ void Disable() { enabled_ = false; }
+
+ // The packet filter takes input and has the option of mutating it.
+ //
+ // A filter that modifies the data places the modified data in *output and
+ // returns CHANGE. A filter that does not modify data returns LEAVE, in which
+ // case the value in *output is ignored. A Filter can return DROP, in which
+ // case the packet is dropped (and *output is ignored).
+ virtual Action Filter(const DataBuffer& input, DataBuffer* output) = 0;
+
+ private:
+ bool enabled_;
+};
+
+class DummyPrSocket : public DummyIOLayerMethods {
+ public:
+ DummyPrSocket(const std::string& name, SSLProtocolVariant var)
+ : name_(name),
+ variant_(var),
+ peer_(),
+ input_(),
+ filter_(nullptr),
+ write_error_(0) {}
+ virtual ~DummyPrSocket() {}
+
+ static PRDescIdentity LayerId();
+
+ // Create a file descriptor that will reference this object. The fd must not
+ // live longer than this adapter; call PR_Close() before.
+ ScopedPRFileDesc CreateFD();
+
+ std::weak_ptr<DummyPrSocket>& peer() { return peer_; }
+ void SetPeer(const std::shared_ptr<DummyPrSocket>& p) { peer_ = p; }
+ void SetPacketFilter(const std::shared_ptr<PacketFilter>& filter) {
+ filter_ = filter;
+ }
+ // Drops peer, packet filter and any outstanding packets.
+ void Reset();
+
+ void PacketReceived(const DataBuffer& data);
+ int32_t Read(PRFileDesc* f, void* data, int32_t len) override;
+ int32_t Recv(PRFileDesc* f, void* buf, int32_t buflen, int32_t flags,
+ PRIntervalTime to) override;
+ int32_t Write(PRFileDesc* f, const void* buf, int32_t length) override;
+ void SetWriteError(PRErrorCode code) { write_error_ = code; }
+
+ SSLProtocolVariant variant() const { return variant_; }
+ bool readable() const { return !input_.empty(); }
+
+ private:
+ class Packet : public DataBuffer {
+ public:
+ Packet(const DataBuffer& buf) : DataBuffer(buf), offset_(0) {}
+
+ void Advance(size_t delta) {
+ PR_ASSERT(offset_ + delta <= len());
+ offset_ = std::min(len(), offset_ + delta);
+ }
+
+ size_t offset() const { return offset_; }
+ size_t remaining() const { return len() - offset_; }
+
+ private:
+ size_t offset_;
+ };
+
+ const std::string name_;
+ SSLProtocolVariant variant_;
+ std::weak_ptr<DummyPrSocket> peer_;
+ std::queue<Packet> input_;
+ std::shared_ptr<PacketFilter> filter_;
+ PRErrorCode write_error_;
+};
+
+// Marker interface.
+class PollTarget {};
+
+enum Event { READABLE_EVENT, TIMER_EVENT /* Must be last */ };
+
+typedef void (*PollCallback)(PollTarget*, Event);
+
+class Poller {
+ public:
+ static Poller* Instance(); // Get a singleton.
+ static void Shutdown(); // Shut it down.
+
+ class Timer {
+ public:
+ Timer(PRTime deadline, PollTarget* target, PollCallback callback)
+ : deadline_(deadline), target_(target), callback_(callback) {}
+ void Cancel() { callback_ = nullptr; }
+
+ PRTime deadline_;
+ PollTarget* target_;
+ PollCallback callback_;
+ };
+
+ void Wait(Event event, std::shared_ptr<DummyPrSocket>& adapter,
+ PollTarget* target, PollCallback cb);
+ void Cancel(Event event, std::shared_ptr<DummyPrSocket>& adapter);
+ void SetTimer(uint32_t timer_ms, PollTarget* target, PollCallback cb,
+ std::shared_ptr<Timer>* handle);
+ bool Poll();
+
+ private:
+ Poller() : waiters_(), timers_() {}
+ ~Poller() {}
+
+ class Waiter {
+ public:
+ Waiter(std::shared_ptr<DummyPrSocket> io) : io_(io) {
+ memset(&targets_[0], 0, sizeof(targets_));
+ memset(&callbacks_[0], 0, sizeof(callbacks_));
+ }
+
+ void WaitFor(Event event, PollCallback callback);
+
+ std::shared_ptr<DummyPrSocket> io_;
+ PollTarget* targets_[TIMER_EVENT];
+ PollCallback callbacks_[TIMER_EVENT];
+ };
+
+ class TimerComparator {
+ public:
+ bool operator()(const std::shared_ptr<Timer> lhs,
+ const std::shared_ptr<Timer> rhs) {
+ return lhs->deadline_ > rhs->deadline_;
+ }
+ };
+
+ static Poller* instance;
+ std::map<std::shared_ptr<DummyPrSocket>, std::unique_ptr<Waiter>> waiters_;
+ std::priority_queue<std::shared_ptr<Timer>,
+ std::vector<std::shared_ptr<Timer>>, TimerComparator>
+ timers_;
+};
+
+} // namespace nss_test
+
+#endif
diff --git a/security/nss/gtests/ssl_gtest/tls_agent.cc b/security/nss/gtests/ssl_gtest/tls_agent.cc
new file mode 100644
index 0000000000..8ec2f40f75
--- /dev/null
+++ b/security/nss/gtests/ssl_gtest/tls_agent.cc
@@ -0,0 +1,1432 @@
+/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
+/* vim: set ts=2 et sw=2 tw=80: */
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this file,
+ * You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+#include "tls_agent.h"
+#include "databuffer.h"
+#include "keyhi.h"
+#include "pk11func.h"
+#include "ssl.h"
+#include "sslerr.h"
+#include "sslexp.h"
+#include "sslproto.h"
+#include "tls_filter.h"
+#include "tls_parser.h"
+
+// This is an internal header, used to get DTLS_1_3_DRAFT_VERSION.
+#include "ssl3prot.h"
+
+extern "C" {
+// This is not something that should make you happy.
+#include "libssl_internals.h"
+}
+
+#define GTEST_HAS_RTTI 0
+#include "gtest/gtest.h"
+#include "gtest_utils.h"
+#include "nss_scoped_ptrs.h"
+
+extern std::string g_working_dir_path;
+
+namespace nss_test {
+
+const char* TlsAgent::states[] = {"INIT", "CONNECTING", "CONNECTED", "ERROR"};
+
+const std::string TlsAgent::kClient = "client"; // both sign and encrypt
+const std::string TlsAgent::kRsa2048 = "rsa2048"; // bigger
+const std::string TlsAgent::kRsa8192 = "rsa8192"; // biggest allowed
+const std::string TlsAgent::kServerRsa = "rsa"; // both sign and encrypt
+const std::string TlsAgent::kServerRsaSign = "rsa_sign";
+const std::string TlsAgent::kServerRsaPss = "rsa_pss";
+const std::string TlsAgent::kServerRsaDecrypt = "rsa_decrypt";
+const std::string TlsAgent::kServerEcdsa256 = "ecdsa256";
+const std::string TlsAgent::kServerEcdsa384 = "ecdsa384";
+const std::string TlsAgent::kServerEcdsa521 = "ecdsa521";
+const std::string TlsAgent::kServerEcdhRsa = "ecdh_rsa";
+const std::string TlsAgent::kServerEcdhEcdsa = "ecdh_ecdsa";
+const std::string TlsAgent::kServerDsa = "dsa";
+const std::string TlsAgent::kDelegatorEcdsa256 = "delegator_ecdsa256";
+const std::string TlsAgent::kDelegatorRsae2048 = "delegator_rsae2048";
+const std::string TlsAgent::kDelegatorRsaPss2048 = "delegator_rsa_pss2048";
+
+static const uint8_t kCannedTls13ServerHello[] = {
+ 0x03, 0x03, 0x9c, 0xbc, 0x14, 0x9b, 0x0e, 0x2e, 0xfa, 0x0d, 0xf3,
+ 0xf0, 0x5c, 0x70, 0x7a, 0xe0, 0xd1, 0x9b, 0x3e, 0x5a, 0x44, 0x6b,
+ 0xdf, 0xe5, 0xc2, 0x28, 0x64, 0xf7, 0x00, 0xc1, 0x9c, 0x08, 0x76,
+ 0x08, 0x00, 0x13, 0x01, 0x00, 0x00, 0x2e, 0x00, 0x33, 0x00, 0x24,
+ 0x00, 0x1d, 0x00, 0x20, 0xc2, 0xcf, 0x23, 0x17, 0x64, 0x23, 0x03,
+ 0xf0, 0xfb, 0x45, 0x98, 0x26, 0xd1, 0x65, 0x24, 0xa1, 0x6c, 0xa9,
+ 0x80, 0x8f, 0x2c, 0xac, 0x0a, 0xea, 0x53, 0x3a, 0xcb, 0xe3, 0x08,
+ 0x84, 0xae, 0x19, 0x00, 0x2b, 0x00, 0x02, 0x03, 0x04};
+
+TlsAgent::TlsAgent(const std::string& nm, Role rl, SSLProtocolVariant var)
+ : name_(nm),
+ variant_(var),
+ role_(rl),
+ server_key_bits_(0),
+ adapter_(new DummyPrSocket(role_str(), var)),
+ ssl_fd_(nullptr),
+ state_(STATE_INIT),
+ timer_handle_(nullptr),
+ falsestart_enabled_(false),
+ expected_version_(0),
+ expected_cipher_suite_(0),
+ expect_client_auth_(false),
+ expect_ech_(false),
+ expect_psk_(ssl_psk_none),
+ can_falsestart_hook_called_(false),
+ sni_hook_called_(false),
+ auth_certificate_hook_called_(false),
+ expected_received_alert_(kTlsAlertCloseNotify),
+ expected_received_alert_level_(kTlsAlertWarning),
+ expected_sent_alert_(kTlsAlertCloseNotify),
+ expected_sent_alert_level_(kTlsAlertWarning),
+ handshake_callback_called_(false),
+ resumption_callback_called_(false),
+ error_code_(0),
+ send_ctr_(0),
+ recv_ctr_(0),
+ expect_readwrite_error_(false),
+ handshake_callback_(),
+ auth_certificate_callback_(),
+ sni_callback_(),
+ skip_version_checks_(false),
+ resumption_token_(),
+ policy_() {
+ memset(&info_, 0, sizeof(info_));
+ memset(&csinfo_, 0, sizeof(csinfo_));
+ SECStatus rv = SSL_VersionRangeGetDefault(variant_, &vrange_);
+ EXPECT_EQ(SECSuccess, rv);
+}
+
+TlsAgent::~TlsAgent() {
+ if (timer_handle_) {
+ timer_handle_->Cancel();
+ }
+
+ if (adapter_) {
+ Poller::Instance()->Cancel(READABLE_EVENT, adapter_);
+ }
+
+ // Add failures manually, if any, so we don't throw in a destructor.
+ if (expected_received_alert_ != kTlsAlertCloseNotify ||
+ expected_received_alert_level_ != kTlsAlertWarning) {
+ ADD_FAILURE() << "Wrong expected_received_alert status: " << role_str();
+ }
+ if (expected_sent_alert_ != kTlsAlertCloseNotify ||
+ expected_sent_alert_level_ != kTlsAlertWarning) {
+ ADD_FAILURE() << "Wrong expected_sent_alert status: " << role_str();
+ }
+}
+
+void TlsAgent::SetState(State s) {
+ if (state_ == s) return;
+
+ LOG("Changing state from " << state_ << " to " << s);
+ state_ = s;
+}
+
+/*static*/ bool TlsAgent::LoadCertificate(const std::string& name,
+ ScopedCERTCertificate* cert,
+ ScopedSECKEYPrivateKey* priv) {
+ cert->reset(PK11_FindCertFromNickname(name.c_str(), nullptr));
+ EXPECT_NE(nullptr, cert);
+ if (!cert) return false;
+ EXPECT_NE(nullptr, cert->get());
+ if (!cert->get()) return false;
+
+ priv->reset(PK11_FindKeyByAnyCert(cert->get(), nullptr));
+ EXPECT_NE(nullptr, priv);
+ if (!priv) return false;
+ EXPECT_NE(nullptr, priv->get());
+ if (!priv->get()) return false;
+
+ return true;
+}
+
+// Loads a key pair from the certificate identified by |id|.
+/*static*/ bool TlsAgent::LoadKeyPairFromCert(const std::string& name,
+ ScopedSECKEYPublicKey* pub,
+ ScopedSECKEYPrivateKey* priv) {
+ ScopedCERTCertificate cert;
+ if (!TlsAgent::LoadCertificate(name, &cert, priv)) {
+ return false;
+ }
+
+ pub->reset(SECKEY_ExtractPublicKey(&cert->subjectPublicKeyInfo));
+ if (!pub->get()) {
+ return false;
+ }
+
+ return true;
+}
+
+void TlsAgent::DelegateCredential(const std::string& name,
+ const ScopedSECKEYPublicKey& dc_pub,
+ SSLSignatureScheme dc_cert_verify_alg,
+ PRUint32 dc_valid_for, PRTime now,
+ SECItem* dc) {
+ ScopedCERTCertificate cert;
+ ScopedSECKEYPrivateKey cert_priv;
+ EXPECT_TRUE(TlsAgent::LoadCertificate(name, &cert, &cert_priv))
+ << "Could not load delegate certificate: " << name
+ << "; test db corrupt?";
+
+ EXPECT_EQ(SECSuccess,
+ SSL_DelegateCredential(cert.get(), cert_priv.get(), dc_pub.get(),
+ dc_cert_verify_alg, dc_valid_for, now, dc));
+}
+
+void TlsAgent::EnableDelegatedCredentials() {
+ ASSERT_TRUE(EnsureTlsSetup());
+ SetOption(SSL_ENABLE_DELEGATED_CREDENTIALS, PR_TRUE);
+}
+
+void TlsAgent::AddDelegatedCredential(const std::string& dc_name,
+ SSLSignatureScheme dc_cert_verify_alg,
+ PRUint32 dc_valid_for, PRTime now) {
+ ASSERT_TRUE(EnsureTlsSetup());
+
+ ScopedSECKEYPublicKey pub;
+ ScopedSECKEYPrivateKey priv;
+ EXPECT_TRUE(TlsAgent::LoadKeyPairFromCert(dc_name, &pub, &priv));
+
+ StackSECItem dc;
+ TlsAgent::DelegateCredential(name_, pub, dc_cert_verify_alg, dc_valid_for,
+ now, &dc);
+
+ SSLExtraServerCertData extra_data = {ssl_auth_null, nullptr, nullptr,
+ nullptr, &dc, priv.get()};
+ EXPECT_TRUE(ConfigServerCert(name_, true, &extra_data));
+}
+
+bool TlsAgent::ConfigServerCert(const std::string& id, bool updateKeyBits,
+ const SSLExtraServerCertData* serverCertData) {
+ ScopedCERTCertificate cert;
+ ScopedSECKEYPrivateKey priv;
+ if (!TlsAgent::LoadCertificate(id, &cert, &priv)) {
+ return false;
+ }
+
+ if (updateKeyBits) {
+ ScopedSECKEYPublicKey pub(CERT_ExtractPublicKey(cert.get()));
+ EXPECT_NE(nullptr, pub.get());
+ if (!pub.get()) return false;
+ server_key_bits_ = SECKEY_PublicKeyStrengthInBits(pub.get());
+ }
+
+ SECStatus rv =
+ SSL_ConfigSecureServer(ssl_fd(), nullptr, nullptr, ssl_kea_null);
+ EXPECT_EQ(SECFailure, rv);
+ rv = SSL_ConfigServerCert(ssl_fd(), cert.get(), priv.get(), serverCertData,
+ serverCertData ? sizeof(*serverCertData) : 0);
+ return rv == SECSuccess;
+}
+
+bool TlsAgent::EnsureTlsSetup(PRFileDesc* modelSocket) {
+ // Don't set up twice
+ if (ssl_fd_) return true;
+ NssManagePolicy policyManage(policy_, option_);
+
+ ScopedPRFileDesc dummy_fd(adapter_->CreateFD());
+ EXPECT_NE(nullptr, dummy_fd);
+ if (!dummy_fd) {
+ return false;
+ }
+ if (adapter_->variant() == ssl_variant_stream) {
+ ssl_fd_.reset(SSL_ImportFD(modelSocket, dummy_fd.get()));
+ } else {
+ ssl_fd_.reset(DTLS_ImportFD(modelSocket, dummy_fd.get()));
+ }
+
+ EXPECT_NE(nullptr, ssl_fd_);
+ if (!ssl_fd_) {
+ return false;
+ }
+ dummy_fd.release(); // Now subsumed by ssl_fd_.
+
+ SECStatus rv;
+ if (!skip_version_checks_) {
+ rv = SSL_VersionRangeSet(ssl_fd(), &vrange_);
+ EXPECT_EQ(SECSuccess, rv);
+ if (rv != SECSuccess) return false;
+ }
+
+ ScopedCERTCertList anchors(CERT_NewCertList());
+ rv = SSL_SetTrustAnchors(ssl_fd(), anchors.get());
+ if (rv != SECSuccess) return false;
+
+ if (role_ == SERVER) {
+ EXPECT_TRUE(ConfigServerCert(name_, true));
+
+ rv = SSL_SNISocketConfigHook(ssl_fd(), SniHook, this);
+ EXPECT_EQ(SECSuccess, rv);
+ if (rv != SECSuccess) return false;
+
+ rv = SSL_SetMaxEarlyDataSize(ssl_fd(), 1024);
+ EXPECT_EQ(SECSuccess, rv);
+ if (rv != SECSuccess) return false;
+ } else {
+ rv = SSL_SetURL(ssl_fd(), "server");
+ EXPECT_EQ(SECSuccess, rv);
+ if (rv != SECSuccess) return false;
+ }
+
+ rv = SSL_AuthCertificateHook(ssl_fd(), AuthCertificateHook, this);
+ EXPECT_EQ(SECSuccess, rv);
+ if (rv != SECSuccess) return false;
+
+ rv = SSL_AlertReceivedCallback(ssl_fd(), AlertReceivedCallback, this);
+ EXPECT_EQ(SECSuccess, rv);
+ if (rv != SECSuccess) return false;
+
+ rv = SSL_AlertSentCallback(ssl_fd(), AlertSentCallback, this);
+ EXPECT_EQ(SECSuccess, rv);
+ if (rv != SECSuccess) return false;
+
+ rv = SSL_HandshakeCallback(ssl_fd(), HandshakeCallback, this);
+ EXPECT_EQ(SECSuccess, rv);
+ if (rv != SECSuccess) return false;
+
+ // All these tests depend on having this disabled to start with.
+ SetOption(SSL_ENABLE_EXTENDED_MASTER_SECRET, PR_FALSE);
+
+ return true;
+}
+
+bool TlsAgent::MaybeSetResumptionToken() {
+ if (!resumption_token_.empty()) {
+ LOG("setting external resumption token");
+ SECStatus rv = SSL_SetResumptionToken(ssl_fd(), resumption_token_.data(),
+ resumption_token_.size());
+
+ // rv is SECFailure with error set to SSL_ERROR_BAD_RESUMPTION_TOKEN_ERROR
+ // if the resumption token was bad (expired/malformed/etc.).
+ if (expect_psk_ == ssl_psk_resume) {
+ // Only in case we expect resumption this has to be successful. We might
+ // not expect resumption due to some reason but the token is totally fine.
+ EXPECT_EQ(SECSuccess, rv);
+ }
+ if (rv != SECSuccess) {
+ EXPECT_EQ(SSL_ERROR_BAD_RESUMPTION_TOKEN_ERROR, PORT_GetError());
+ resumption_token_.clear();
+ EXPECT_FALSE(expect_psk_ == ssl_psk_resume);
+ if (expect_psk_ == ssl_psk_resume) return false;
+ }
+ }
+
+ return true;
+}
+
+void TlsAgent::SetAntiReplayContext(ScopedSSLAntiReplayContext& ctx) {
+ EXPECT_EQ(SECSuccess, SSL_SetAntiReplayContext(ssl_fd(), ctx.get()));
+}
+
+// Defaults to a Sync callback returning success
+void TlsAgent::SetupClientAuth(ClientAuthCallbackType callbackType,
+ bool callbackSuccess) {
+ EXPECT_TRUE(EnsureTlsSetup());
+ ASSERT_EQ(CLIENT, role_);
+
+ client_auth_callback_type_ = callbackType;
+ client_auth_callback_success_ = callbackSuccess;
+
+ if (callbackType == ClientAuthCallbackType::kNone && !callbackSuccess) {
+ // Don't set a callback for this case.
+ return;
+ }
+ EXPECT_EQ(SECSuccess,
+ SSL_GetClientAuthDataHook(ssl_fd(), GetClientAuthDataHook,
+ reinterpret_cast<void*>(this)));
+}
+
+void CheckCertReqAgainstDefaultCAs(const CERTDistNames* caNames) {
+ ScopedCERTDistNames expected(CERT_GetSSLCACerts(nullptr));
+
+ ASSERT_EQ(expected->nnames, caNames->nnames);
+
+ for (size_t i = 0; i < static_cast<size_t>(expected->nnames); ++i) {
+ EXPECT_EQ(SECEqual,
+ SECITEM_CompareItem(&(expected->names[i]), &(caNames->names[i])));
+ }
+}
+
+// Complete processing of Client Certificate Selection
+// A No-op if the agent is using synchronous client cert selection.
+// Otherwise, calls SSL_ClientCertCallbackComplete.
+// kAsyncDelay triggers a call to SSL_ForceHandshake prior to completion to
+// ensure that the socket is correctly blocked.
+void TlsAgent::ClientAuthCallbackComplete() {
+ ASSERT_EQ(CLIENT, role_);
+
+ if (client_auth_callback_type_ != ClientAuthCallbackType::kAsyncDelay &&
+ client_auth_callback_type_ != ClientAuthCallbackType::kAsyncImmediate) {
+ return;
+ }
+ client_auth_callback_fired_++;
+ EXPECT_TRUE(client_auth_callback_awaiting_);
+
+ std::cerr << "client: calling SSL_ClientCertCallbackComplete with status "
+ << (client_auth_callback_success_ ? "success" : "failed")
+ << std::endl;
+
+ client_auth_callback_awaiting_ = false;
+
+ if (client_auth_callback_type_ == ClientAuthCallbackType::kAsyncDelay) {
+ std::cerr
+ << "Running Handshake prior to running SSL_ClientCertCallbackComplete"
+ << std::endl;
+ SECStatus rv = SSL_ForceHandshake(ssl_fd());
+ EXPECT_EQ(rv, SECFailure);
+ EXPECT_EQ(PORT_GetError(), PR_WOULD_BLOCK_ERROR);
+ }
+
+ ScopedCERTCertificate cert;
+ ScopedSECKEYPrivateKey priv;
+ if (client_auth_callback_success_) {
+ ASSERT_TRUE(TlsAgent::LoadCertificate(name(), &cert, &priv));
+ EXPECT_EQ(SECSuccess,
+ SSL_ClientCertCallbackComplete(ssl_fd(), SECSuccess,
+ priv.release(), cert.release()));
+ } else {
+ EXPECT_EQ(SECSuccess, SSL_ClientCertCallbackComplete(ssl_fd(), SECFailure,
+ nullptr, nullptr));
+ }
+}
+
+SECStatus TlsAgent::GetClientAuthDataHook(void* self, PRFileDesc* fd,
+ CERTDistNames* caNames,
+ CERTCertificate** clientCert,
+ SECKEYPrivateKey** clientKey) {
+ TlsAgent* agent = reinterpret_cast<TlsAgent*>(self);
+ EXPECT_EQ(CLIENT, agent->role_);
+ agent->client_auth_callback_fired_++;
+
+ switch (agent->client_auth_callback_type_) {
+ case ClientAuthCallbackType::kAsyncDelay:
+ case ClientAuthCallbackType::kAsyncImmediate:
+ std::cerr << "Waiting for complete call" << std::endl;
+ agent->client_auth_callback_awaiting_ = true;
+ return SECWouldBlock;
+ case ClientAuthCallbackType::kSync:
+ case ClientAuthCallbackType::kNone:
+ // Handle the sync case. None && Success is treated as Sync and Success.
+ if (!agent->client_auth_callback_success_) {
+ return SECFailure;
+ }
+ ScopedCERTCertificate peerCert(SSL_PeerCertificate(agent->ssl_fd()));
+ EXPECT_TRUE(peerCert) << "Client should be able to see the server cert";
+
+ // See bug 1573945
+ // CheckCertReqAgainstDefaultCAs(caNames);
+
+ ScopedCERTCertificate cert;
+ ScopedSECKEYPrivateKey priv;
+ if (!TlsAgent::LoadCertificate(agent->name(), &cert, &priv)) {
+ return SECFailure;
+ }
+
+ *clientCert = cert.release();
+ *clientKey = priv.release();
+ return SECSuccess;
+ }
+ /* This is unreachable, but some old compilers can't tell that. */
+ PORT_Assert(0);
+ PORT_SetError(SEC_ERROR_LIBRARY_FAILURE);
+ return SECFailure;
+}
+
+// Increments by 1 for each callback
+bool TlsAgent::CheckClientAuthCallbacksCompleted(uint8_t expected) {
+ EXPECT_EQ(CLIENT, role_);
+ return expected == client_auth_callback_fired_;
+}
+
+bool TlsAgent::GetPeerChainLength(size_t* count) {
+ CERTCertList* chain = SSL_PeerCertificateChain(ssl_fd());
+ if (!chain) return false;
+ *count = 0;
+
+ for (PRCList* cursor = PR_NEXT_LINK(&chain->list); cursor != &chain->list;
+ cursor = PR_NEXT_LINK(cursor)) {
+ CERTCertListNode* node = (CERTCertListNode*)cursor;
+ std::cerr << node->cert->subjectName << std::endl;
+ ++(*count);
+ }
+
+ CERT_DestroyCertList(chain);
+
+ return true;
+}
+
+void TlsAgent::CheckCipherSuite(uint16_t suite) {
+ EXPECT_EQ(csinfo_.cipherSuite, suite);
+}
+
+void TlsAgent::RequestClientAuth(bool requireAuth) {
+ ASSERT_EQ(SERVER, role_);
+
+ SetOption(SSL_REQUEST_CERTIFICATE, PR_TRUE);
+ SetOption(SSL_REQUIRE_CERTIFICATE, requireAuth ? PR_TRUE : PR_FALSE);
+
+ EXPECT_EQ(SECSuccess, SSL_AuthCertificateHook(
+ ssl_fd(), &TlsAgent::ClientAuthenticated, this));
+ expect_client_auth_ = true;
+}
+
+void TlsAgent::StartConnect(PRFileDesc* model) {
+ EXPECT_TRUE(EnsureTlsSetup(model));
+
+ SECStatus rv;
+ rv = SSL_ResetHandshake(ssl_fd(), role_ == SERVER ? PR_TRUE : PR_FALSE);
+ EXPECT_EQ(SECSuccess, rv);
+ SetState(STATE_CONNECTING);
+}
+
+void TlsAgent::DisableAllCiphers() {
+ for (size_t i = 0; i < SSL_NumImplementedCiphers; ++i) {
+ SECStatus rv =
+ SSL_CipherPrefSet(ssl_fd(), SSL_ImplementedCiphers[i], PR_FALSE);
+ EXPECT_EQ(SECSuccess, rv);
+ }
+}
+
+// Not actually all groups, just the onece that we are actually willing
+// to use.
+const std::vector<SSLNamedGroup> kAllDHEGroups = {
+ ssl_grp_ec_curve25519, ssl_grp_ec_secp256r1, ssl_grp_ec_secp384r1,
+ ssl_grp_ec_secp521r1, ssl_grp_ffdhe_2048, ssl_grp_ffdhe_3072,
+ ssl_grp_ffdhe_4096, ssl_grp_ffdhe_6144, ssl_grp_ffdhe_8192};
+
+const std::vector<SSLNamedGroup> kECDHEGroups = {
+ ssl_grp_ec_curve25519, ssl_grp_ec_secp256r1, ssl_grp_ec_secp384r1,
+ ssl_grp_ec_secp521r1};
+
+const std::vector<SSLNamedGroup> kFFDHEGroups = {
+ ssl_grp_ffdhe_2048, ssl_grp_ffdhe_3072, ssl_grp_ffdhe_4096,
+ ssl_grp_ffdhe_6144, ssl_grp_ffdhe_8192};
+
+// Defined because the big DHE groups are ridiculously slow.
+const std::vector<SSLNamedGroup> kFasterDHEGroups = {
+ ssl_grp_ec_curve25519, ssl_grp_ec_secp256r1, ssl_grp_ec_secp384r1,
+ ssl_grp_ffdhe_2048, ssl_grp_ffdhe_3072};
+
+void TlsAgent::EnableCiphersByKeyExchange(SSLKEAType kea) {
+ EXPECT_TRUE(EnsureTlsSetup());
+
+ for (size_t i = 0; i < SSL_NumImplementedCiphers; ++i) {
+ SSLCipherSuiteInfo csinfo;
+
+ SECStatus rv = SSL_GetCipherSuiteInfo(SSL_ImplementedCiphers[i], &csinfo,
+ sizeof(csinfo));
+ ASSERT_EQ(SECSuccess, rv);
+ EXPECT_EQ(sizeof(csinfo), csinfo.length);
+
+ if ((csinfo.keaType == kea) || (csinfo.keaType == ssl_kea_tls13_any)) {
+ rv = SSL_CipherPrefSet(ssl_fd(), SSL_ImplementedCiphers[i], PR_TRUE);
+ EXPECT_EQ(SECSuccess, rv);
+ }
+ }
+}
+
+void TlsAgent::EnableGroupsByKeyExchange(SSLKEAType kea) {
+ switch (kea) {
+ case ssl_kea_dh:
+ ConfigNamedGroups(kFFDHEGroups);
+ break;
+ case ssl_kea_ecdh:
+ ConfigNamedGroups(kECDHEGroups);
+ break;
+ default:
+ break;
+ }
+}
+
+void TlsAgent::EnableGroupsByAuthType(SSLAuthType authType) {
+ if (authType == ssl_auth_ecdh_rsa || authType == ssl_auth_ecdh_ecdsa ||
+ authType == ssl_auth_ecdsa || authType == ssl_auth_tls13_any) {
+ ConfigNamedGroups(kECDHEGroups);
+ }
+}
+
+void TlsAgent::EnableCiphersByAuthType(SSLAuthType authType) {
+ EXPECT_TRUE(EnsureTlsSetup());
+
+ for (size_t i = 0; i < SSL_NumImplementedCiphers; ++i) {
+ SSLCipherSuiteInfo csinfo;
+
+ SECStatus rv = SSL_GetCipherSuiteInfo(SSL_ImplementedCiphers[i], &csinfo,
+ sizeof(csinfo));
+ ASSERT_EQ(SECSuccess, rv);
+
+ if ((csinfo.authType == authType) ||
+ (csinfo.keaType == ssl_kea_tls13_any)) {
+ rv = SSL_CipherPrefSet(ssl_fd(), SSL_ImplementedCiphers[i], PR_TRUE);
+ EXPECT_EQ(SECSuccess, rv);
+ }
+ }
+}
+
+void TlsAgent::EnableSingleCipher(uint16_t cipher) {
+ DisableAllCiphers();
+ SECStatus rv = SSL_CipherPrefSet(ssl_fd(), cipher, PR_TRUE);
+ EXPECT_EQ(SECSuccess, rv);
+}
+
+void TlsAgent::ConfigNamedGroups(const std::vector<SSLNamedGroup>& groups) {
+ EXPECT_TRUE(EnsureTlsSetup());
+ SECStatus rv = SSL_NamedGroupConfig(ssl_fd(), &groups[0], groups.size());
+ EXPECT_EQ(SECSuccess, rv);
+}
+
+void TlsAgent::Set0RttEnabled(bool en) {
+ SetOption(SSL_ENABLE_0RTT_DATA, en ? PR_TRUE : PR_FALSE);
+}
+
+void TlsAgent::SetVersionRange(uint16_t minver, uint16_t maxver) {
+ vrange_.min = minver;
+ vrange_.max = maxver;
+
+ if (ssl_fd()) {
+ SECStatus rv = SSL_VersionRangeSet(ssl_fd(), &vrange_);
+ EXPECT_EQ(SECSuccess, rv);
+ }
+}
+
+SECStatus ResumptionTokenCallback(PRFileDesc* fd,
+ const PRUint8* resumptionToken,
+ unsigned int len, void* ctx) {
+ EXPECT_NE(nullptr, resumptionToken);
+ if (!resumptionToken) {
+ return SECFailure;
+ }
+
+ std::vector<uint8_t> new_token(resumptionToken, resumptionToken + len);
+ reinterpret_cast<TlsAgent*>(ctx)->SetResumptionToken(new_token);
+ reinterpret_cast<TlsAgent*>(ctx)->SetResumptionCallbackCalled();
+ return SECSuccess;
+}
+
+void TlsAgent::SetResumptionTokenCallback() {
+ EXPECT_TRUE(EnsureTlsSetup());
+ SECStatus rv =
+ SSL_SetResumptionTokenCallback(ssl_fd(), ResumptionTokenCallback, this);
+ EXPECT_EQ(SECSuccess, rv);
+}
+
+void TlsAgent::GetVersionRange(uint16_t* minver, uint16_t* maxver) {
+ *minver = vrange_.min;
+ *maxver = vrange_.max;
+}
+
+void TlsAgent::SetExpectedVersion(uint16_t ver) { expected_version_ = ver; }
+
+void TlsAgent::SetServerKeyBits(uint16_t bits) { server_key_bits_ = bits; }
+
+void TlsAgent::ExpectReadWriteError() { expect_readwrite_error_ = true; }
+
+void TlsAgent::SkipVersionChecks() { skip_version_checks_ = true; }
+
+void TlsAgent::SetSignatureSchemes(const SSLSignatureScheme* schemes,
+ size_t count) {
+ EXPECT_TRUE(EnsureTlsSetup());
+ EXPECT_LE(count, SSL_SignatureMaxCount());
+ EXPECT_EQ(SECSuccess,
+ SSL_SignatureSchemePrefSet(ssl_fd(), schemes,
+ static_cast<unsigned int>(count)));
+ EXPECT_EQ(SECFailure, SSL_SignatureSchemePrefSet(ssl_fd(), schemes, 0))
+ << "setting no schemes should fail and do nothing";
+
+ std::vector<SSLSignatureScheme> configuredSchemes(count);
+ unsigned int configuredCount;
+ EXPECT_EQ(SECFailure,
+ SSL_SignatureSchemePrefGet(ssl_fd(), nullptr, &configuredCount, 1))
+ << "get schemes, schemes is nullptr";
+ EXPECT_EQ(SECFailure,
+ SSL_SignatureSchemePrefGet(ssl_fd(), &configuredSchemes[0],
+ &configuredCount, 0))
+ << "get schemes, too little space";
+ EXPECT_EQ(SECFailure,
+ SSL_SignatureSchemePrefGet(ssl_fd(), &configuredSchemes[0], nullptr,
+ configuredSchemes.size()))
+ << "get schemes, countOut is nullptr";
+
+ EXPECT_EQ(SECSuccess, SSL_SignatureSchemePrefGet(
+ ssl_fd(), &configuredSchemes[0], &configuredCount,
+ configuredSchemes.size()));
+ // SignatureSchemePrefSet drops unsupported algorithms silently, so the
+ // number that are configured might be fewer.
+ EXPECT_LE(configuredCount, count);
+ unsigned int i = 0;
+ for (unsigned int j = 0; j < count && i < configuredCount; ++j) {
+ if (i < configuredCount && schemes[j] == configuredSchemes[i]) {
+ ++i;
+ }
+ }
+ EXPECT_EQ(i, configuredCount) << "schemes in use were all set";
+}
+
+void TlsAgent::CheckKEA(SSLKEAType kea, SSLNamedGroup kea_group,
+ size_t kea_size) const {
+ EXPECT_EQ(STATE_CONNECTED, state_);
+ EXPECT_EQ(kea, info_.keaType);
+ if (kea_size == 0) {
+ switch (kea_group) {
+ case ssl_grp_ec_curve25519:
+ kea_size = 255;
+ break;
+ case ssl_grp_ec_secp256r1:
+ kea_size = 256;
+ break;
+ case ssl_grp_ec_secp384r1:
+ kea_size = 384;
+ break;
+ case ssl_grp_ffdhe_2048:
+ kea_size = 2048;
+ break;
+ case ssl_grp_ffdhe_3072:
+ kea_size = 3072;
+ break;
+ case ssl_grp_ffdhe_custom:
+ break;
+ default:
+ if (kea == ssl_kea_rsa) {
+ kea_size = server_key_bits_;
+ } else {
+ EXPECT_TRUE(false) << "need to update group sizes";
+ }
+ }
+ }
+ if (kea_group != ssl_grp_ffdhe_custom) {
+ EXPECT_EQ(kea_size, info_.keaKeyBits);
+ EXPECT_EQ(kea_group, info_.keaGroup);
+ }
+}
+
+void TlsAgent::CheckOriginalKEA(SSLNamedGroup kea_group) const {
+ if (kea_group != ssl_grp_ffdhe_custom) {
+ EXPECT_EQ(kea_group, info_.originalKeaGroup);
+ }
+}
+
+void TlsAgent::CheckAuthType(SSLAuthType auth,
+ SSLSignatureScheme sig_scheme) const {
+ EXPECT_EQ(STATE_CONNECTED, state_);
+ EXPECT_EQ(auth, info_.authType);
+ if (auth != ssl_auth_psk) {
+ EXPECT_EQ(server_key_bits_, info_.authKeyBits);
+ }
+ if (expected_version_ < SSL_LIBRARY_VERSION_TLS_1_2) {
+ switch (auth) {
+ case ssl_auth_rsa_sign:
+ sig_scheme = ssl_sig_rsa_pkcs1_sha1md5;
+ break;
+ case ssl_auth_ecdsa:
+ sig_scheme = ssl_sig_ecdsa_sha1;
+ break;
+ default:
+ break;
+ }
+ }
+ EXPECT_EQ(sig_scheme, info_.signatureScheme);
+
+ if (info_.protocolVersion >= SSL_LIBRARY_VERSION_TLS_1_3) {
+ return;
+ }
+
+ // Check authAlgorithm, which is the old value for authType. This is a second
+ // switch statement because default label is different.
+ switch (auth) {
+ case ssl_auth_rsa_sign:
+ case ssl_auth_rsa_pss:
+ EXPECT_EQ(ssl_auth_rsa_decrypt, csinfo_.authAlgorithm)
+ << "authAlgorithm for RSA is always decrypt";
+ break;
+ case ssl_auth_ecdh_rsa:
+ EXPECT_EQ(ssl_auth_rsa_decrypt, csinfo_.authAlgorithm)
+ << "authAlgorithm for ECDH_RSA is RSA decrypt (i.e., wrong)";
+ break;
+ case ssl_auth_ecdh_ecdsa:
+ EXPECT_EQ(ssl_auth_ecdsa, csinfo_.authAlgorithm)
+ << "authAlgorithm for ECDH_ECDSA is ECDSA (i.e., wrong)";
+ break;
+ default:
+ EXPECT_EQ(auth, csinfo_.authAlgorithm)
+ << "authAlgorithm is (usually) the same as authType";
+ break;
+ }
+}
+
+void TlsAgent::EnableFalseStart() {
+ EXPECT_TRUE(EnsureTlsSetup());
+
+ falsestart_enabled_ = true;
+ EXPECT_EQ(SECSuccess, SSL_SetCanFalseStartCallback(
+ ssl_fd(), CanFalseStartCallback, this));
+ SetOption(SSL_ENABLE_FALSE_START, PR_TRUE);
+}
+
+void TlsAgent::ExpectEch(bool expected) { expect_ech_ = expected; }
+
+void TlsAgent::ExpectPsk(SSLPskType psk) { expect_psk_ = psk; }
+
+void TlsAgent::ExpectResumption() { expect_psk_ = ssl_psk_resume; }
+
+void TlsAgent::EnableAlpn(const uint8_t* val, size_t len) {
+ EXPECT_TRUE(EnsureTlsSetup());
+ EXPECT_EQ(SECSuccess, SSL_SetNextProtoNego(ssl_fd(), val, len));
+}
+
+void TlsAgent::AddPsk(const ScopedPK11SymKey& psk, std::string label,
+ SSLHashType hash, uint16_t zeroRttSuite) {
+ EXPECT_TRUE(EnsureTlsSetup());
+ EXPECT_EQ(SECSuccess, SSL_AddExternalPsk0Rtt(
+ ssl_fd(), psk.get(),
+ reinterpret_cast<const uint8_t*>(label.data()),
+ label.length(), hash, zeroRttSuite, 1000));
+}
+
+void TlsAgent::RemovePsk(std::string label) {
+ EXPECT_EQ(SECSuccess,
+ SSL_RemoveExternalPsk(
+ ssl_fd(), reinterpret_cast<const uint8_t*>(label.data()),
+ label.length()));
+}
+
+void TlsAgent::CheckAlpn(SSLNextProtoState expected_state,
+ const std::string& expected) const {
+ SSLNextProtoState alpn_state;
+ char chosen[10];
+ unsigned int chosen_len;
+ SECStatus rv = SSL_GetNextProto(ssl_fd(), &alpn_state,
+ reinterpret_cast<unsigned char*>(chosen),
+ &chosen_len, sizeof(chosen));
+ EXPECT_EQ(SECSuccess, rv);
+ EXPECT_EQ(expected_state, alpn_state);
+ if (alpn_state == SSL_NEXT_PROTO_NO_SUPPORT) {
+ EXPECT_EQ("", expected);
+ } else {
+ EXPECT_NE("", expected);
+ EXPECT_EQ(expected, std::string(chosen, chosen_len));
+ }
+}
+
+void TlsAgent::CheckEpochs(uint16_t expected_read,
+ uint16_t expected_write) const {
+ uint16_t read_epoch = 0;
+ uint16_t write_epoch = 0;
+ EXPECT_EQ(SECSuccess,
+ SSL_GetCurrentEpoch(ssl_fd(), &read_epoch, &write_epoch));
+ EXPECT_EQ(expected_read, read_epoch) << role_str() << " read epoch";
+ EXPECT_EQ(expected_write, write_epoch) << role_str() << " write epoch";
+}
+
+void TlsAgent::EnableSrtp() {
+ EXPECT_TRUE(EnsureTlsSetup());
+ const uint16_t ciphers[] = {SRTP_AES128_CM_HMAC_SHA1_80,
+ SRTP_AES128_CM_HMAC_SHA1_32};
+ EXPECT_EQ(SECSuccess,
+ SSL_SetSRTPCiphers(ssl_fd(), ciphers, PR_ARRAY_SIZE(ciphers)));
+}
+
+void TlsAgent::CheckSrtp() const {
+ uint16_t actual;
+ EXPECT_EQ(SECSuccess, SSL_GetSRTPCipher(ssl_fd(), &actual));
+ EXPECT_EQ(SRTP_AES128_CM_HMAC_SHA1_80, actual);
+}
+
+void TlsAgent::CheckErrorCode(int32_t expected) const {
+ EXPECT_EQ(STATE_ERROR, state_);
+ EXPECT_EQ(expected, error_code_)
+ << "Got error code " << PORT_ErrorToName(error_code_) << " expecting "
+ << PORT_ErrorToName(expected) << std::endl;
+}
+
+static uint8_t GetExpectedAlertLevel(uint8_t alert) {
+ if (alert == kTlsAlertCloseNotify) {
+ return kTlsAlertWarning;
+ }
+ return kTlsAlertFatal;
+}
+
+void TlsAgent::ExpectReceiveAlert(uint8_t alert, uint8_t level) {
+ expected_received_alert_ = alert;
+ if (level == 0) {
+ expected_received_alert_level_ = GetExpectedAlertLevel(alert);
+ } else {
+ expected_received_alert_level_ = level;
+ }
+}
+
+void TlsAgent::ExpectSendAlert(uint8_t alert, uint8_t level) {
+ expected_sent_alert_ = alert;
+ if (level == 0) {
+ expected_sent_alert_level_ = GetExpectedAlertLevel(alert);
+ } else {
+ expected_sent_alert_level_ = level;
+ }
+}
+
+void TlsAgent::CheckAlert(bool sent, const SSLAlert* alert) {
+ LOG(((alert->level == kTlsAlertWarning) ? "Warning" : "Fatal")
+ << " alert " << (sent ? "sent" : "received") << ": "
+ << static_cast<int>(alert->description));
+
+ auto& expected = sent ? expected_sent_alert_ : expected_received_alert_;
+ auto& expected_level =
+ sent ? expected_sent_alert_level_ : expected_received_alert_level_;
+ /* Silently pass close_notify in case the test has already ended. */
+ if (expected == kTlsAlertCloseNotify && expected_level == kTlsAlertWarning &&
+ alert->description == expected && alert->level == expected_level) {
+ return;
+ }
+
+ EXPECT_EQ(expected, alert->description);
+ EXPECT_EQ(expected_level, alert->level);
+ expected = kTlsAlertCloseNotify;
+ expected_level = kTlsAlertWarning;
+}
+
+void TlsAgent::WaitForErrorCode(int32_t expected, uint32_t delay) const {
+ ASSERT_EQ(0, error_code_);
+ WAIT_(error_code_ != 0, delay);
+ EXPECT_EQ(expected, error_code_)
+ << "Got error code " << PORT_ErrorToName(error_code_) << " expecting "
+ << PORT_ErrorToName(expected) << std::endl;
+}
+
+void TlsAgent::CheckPreliminaryInfo() {
+ SSLPreliminaryChannelInfo preinfo;
+ EXPECT_EQ(SECSuccess,
+ SSL_GetPreliminaryChannelInfo(ssl_fd(), &preinfo, sizeof(preinfo)));
+ EXPECT_EQ(sizeof(preinfo), preinfo.length);
+ EXPECT_TRUE(preinfo.valuesSet & ssl_preinfo_version);
+
+ // A version of 0 is invalid and indicates no expectation. This value is
+ // initialized to 0 so that tests that don't explicitly set an expected
+ // version can negotiate a version.
+ if (!expected_version_) {
+ expected_version_ = preinfo.protocolVersion;
+ }
+ EXPECT_EQ(expected_version_, preinfo.protocolVersion);
+
+ // As with the version; 0 is the null cipher suite (and also invalid).
+ if (!expected_cipher_suite_) {
+ expected_cipher_suite_ = preinfo.cipherSuite;
+ }
+ EXPECT_EQ(expected_cipher_suite_, preinfo.cipherSuite);
+}
+
+// Check that all the expected callbacks have been called.
+void TlsAgent::CheckCallbacks() const {
+ // If false start happens, the handshake is reported as being complete at the
+ // point that false start happens.
+ if (expect_psk_ == ssl_psk_resume || !falsestart_enabled_) {
+ EXPECT_TRUE(handshake_callback_called_);
+ }
+
+ // These callbacks shouldn't fire if we are resuming, except on TLS 1.3.
+ if (role_ == SERVER) {
+ PRBool have_sni = SSLInt_ExtensionNegotiated(ssl_fd(), ssl_server_name_xtn);
+ EXPECT_EQ(((expect_psk_ != ssl_psk_resume && have_sni) ||
+ expected_version_ >= SSL_LIBRARY_VERSION_TLS_1_3),
+ sni_hook_called_);
+ } else {
+ EXPECT_EQ(expect_psk_ == ssl_psk_none, auth_certificate_hook_called_);
+ // Note that this isn't unconditionally called, even with false start on.
+ // But the callback is only skipped if a cipher that is ridiculously weak
+ // (80 bits) is chosen. Don't test that: plan to remove bad ciphers.
+ EXPECT_EQ(falsestart_enabled_ && expect_psk_ != ssl_psk_resume,
+ can_falsestart_hook_called_);
+ }
+}
+
+void TlsAgent::ResetPreliminaryInfo() {
+ expected_version_ = 0;
+ expected_cipher_suite_ = 0;
+}
+
+void TlsAgent::UpdatePreliminaryChannelInfo() {
+ SECStatus rv =
+ SSL_GetPreliminaryChannelInfo(ssl_fd(), &pre_info_, sizeof(pre_info_));
+ EXPECT_EQ(SECSuccess, rv);
+ EXPECT_EQ(sizeof(pre_info_), pre_info_.length);
+}
+
+void TlsAgent::ValidateCipherSpecs() {
+ PRInt32 cipherSpecs = SSLInt_CountCipherSpecs(ssl_fd());
+ // We use one ciphersuite in each direction.
+ PRInt32 expected = 2;
+ if (variant_ == ssl_variant_datagram) {
+ // For DTLS 1.3, the client retains the cipher spec for early data and the
+ // handshake so that it can retransmit EndOfEarlyData and its final flight.
+ // It also retains the handshake read cipher spec so that it can read ACKs
+ // from the server. The server retains the handshake read cipher spec so it
+ // can read the client's retransmitted Finished.
+ if (expected_version_ >= SSL_LIBRARY_VERSION_TLS_1_3) {
+ if (role_ == CLIENT) {
+ expected = info_.earlyDataAccepted ? 5 : 4;
+ } else {
+ expected = 3;
+ }
+ } else {
+ // For DTLS 1.1 and 1.2, the last endpoint to send maintains a cipher spec
+ // until the holddown timer runs down.
+ if (expect_psk_ == ssl_psk_resume) {
+ if (role_ == CLIENT) {
+ expected = 3;
+ }
+ } else {
+ if (role_ == SERVER) {
+ expected = 3;
+ }
+ }
+ }
+ }
+ // This function will be run before the handshake completes if false start is
+ // enabled. In that case, the client will still be reading cleartext, but
+ // will have a spec prepared for reading ciphertext. With DTLS, the client
+ // will also have a spec retained for retransmission of handshake messages.
+ if (role_ == CLIENT && falsestart_enabled_ && !handshake_callback_called_) {
+ EXPECT_GT(SSL_LIBRARY_VERSION_TLS_1_3, expected_version_);
+ expected = (variant_ == ssl_variant_datagram) ? 4 : 3;
+ }
+ EXPECT_EQ(expected, cipherSpecs);
+ if (expected != cipherSpecs) {
+ SSLInt_PrintCipherSpecs(role_str().c_str(), ssl_fd());
+ }
+}
+
+void TlsAgent::Connected() {
+ if (state_ == STATE_CONNECTED) {
+ return;
+ }
+
+ LOG("Handshake success");
+ CheckPreliminaryInfo();
+ CheckCallbacks();
+
+ SECStatus rv = SSL_GetChannelInfo(ssl_fd(), &info_, sizeof(info_));
+ EXPECT_EQ(SECSuccess, rv);
+ EXPECT_EQ(sizeof(info_), info_.length);
+
+ EXPECT_EQ(expect_psk_ == ssl_psk_resume, info_.resumed == PR_TRUE);
+ EXPECT_EQ(expect_psk_, info_.pskType);
+ EXPECT_EQ(expect_ech_, info_.echAccepted);
+
+ // Preliminary values are exposed through callbacks during the handshake.
+ // If either expected values were set or the callbacks were called, check
+ // that the final values are correct.
+ UpdatePreliminaryChannelInfo();
+ EXPECT_EQ(expected_version_, info_.protocolVersion);
+ EXPECT_EQ(expected_cipher_suite_, info_.cipherSuite);
+
+ rv = SSL_GetCipherSuiteInfo(info_.cipherSuite, &csinfo_, sizeof(csinfo_));
+ EXPECT_EQ(SECSuccess, rv);
+ EXPECT_EQ(sizeof(csinfo_), csinfo_.length);
+
+ ValidateCipherSpecs();
+
+ SetState(STATE_CONNECTED);
+}
+
+void TlsAgent::CheckClientAuthCompleted(uint8_t handshakes) {
+ EXPECT_FALSE(client_auth_callback_awaiting_);
+ switch (client_auth_callback_type_) {
+ case ClientAuthCallbackType::kNone:
+ if (!client_auth_callback_success_) {
+ EXPECT_TRUE(CheckClientAuthCallbacksCompleted(0));
+ break;
+ }
+ case ClientAuthCallbackType::kSync:
+ EXPECT_TRUE(CheckClientAuthCallbacksCompleted(handshakes));
+ break;
+ case ClientAuthCallbackType::kAsyncDelay:
+ case ClientAuthCallbackType::kAsyncImmediate:
+ EXPECT_TRUE(CheckClientAuthCallbacksCompleted(2 * handshakes));
+ break;
+ }
+}
+
+void TlsAgent::EnableExtendedMasterSecret() {
+ SetOption(SSL_ENABLE_EXTENDED_MASTER_SECRET, PR_TRUE);
+}
+
+void TlsAgent::CheckExtendedMasterSecret(bool expected) {
+ if (version() >= SSL_LIBRARY_VERSION_TLS_1_3) {
+ expected = PR_TRUE;
+ }
+ ASSERT_EQ(expected, info_.extendedMasterSecretUsed != PR_FALSE)
+ << "unexpected extended master secret state for " << name_;
+}
+
+void TlsAgent::CheckEarlyDataAccepted(bool expected) {
+ if (version() < SSL_LIBRARY_VERSION_TLS_1_3) {
+ expected = false;
+ }
+ ASSERT_EQ(expected, info_.earlyDataAccepted != PR_FALSE)
+ << "unexpected early data state for " << name_;
+}
+
+void TlsAgent::CheckSecretsDestroyed() {
+ ASSERT_EQ(PR_TRUE, SSLInt_CheckSecretsDestroyed(ssl_fd()));
+}
+
+void TlsAgent::SetDowngradeCheckVersion(uint16_t ver) {
+ ASSERT_TRUE(EnsureTlsSetup());
+
+ SECStatus rv = SSL_SetDowngradeCheckVersion(ssl_fd(), ver);
+ ASSERT_EQ(SECSuccess, rv);
+}
+
+void TlsAgent::Handshake() {
+ LOGV("Handshake");
+ SECStatus rv = SSL_ForceHandshake(ssl_fd());
+ if (client_auth_callback_awaiting_) {
+ ClientAuthCallbackComplete();
+ rv = SSL_ForceHandshake(ssl_fd());
+ }
+ if (rv == SECSuccess) {
+ Connected();
+ Poller::Instance()->Wait(READABLE_EVENT, adapter_, this,
+ &TlsAgent::ReadableCallback);
+ return;
+ }
+
+ int32_t err = PR_GetError();
+ if (err == PR_WOULD_BLOCK_ERROR) {
+ LOGV("Would have blocked");
+ if (variant_ == ssl_variant_datagram) {
+ if (timer_handle_) {
+ timer_handle_->Cancel();
+ timer_handle_ = nullptr;
+ }
+
+ PRIntervalTime timeout;
+ rv = DTLS_GetHandshakeTimeout(ssl_fd(), &timeout);
+ if (rv == SECSuccess) {
+ Poller::Instance()->SetTimer(
+ timeout + 1, this, &TlsAgent::ReadableCallback, &timer_handle_);
+ }
+ }
+ Poller::Instance()->Wait(READABLE_EVENT, adapter_, this,
+ &TlsAgent::ReadableCallback);
+ return;
+ }
+
+ LOG("Handshake failed with error " << PORT_ErrorToName(err) << ": "
+ << PORT_ErrorToString(err));
+ error_code_ = err;
+ SetState(STATE_ERROR);
+}
+
+void TlsAgent::PrepareForRenegotiate() {
+ EXPECT_EQ(STATE_CONNECTED, state_);
+
+ SetState(STATE_CONNECTING);
+}
+
+void TlsAgent::StartRenegotiate() {
+ PrepareForRenegotiate();
+
+ SECStatus rv = SSL_ReHandshake(ssl_fd(), PR_TRUE);
+ EXPECT_EQ(SECSuccess, rv);
+}
+
+void TlsAgent::SendDirect(const DataBuffer& buf) {
+ LOG("Send Direct " << buf);
+ auto peer = adapter_->peer().lock();
+ if (peer) {
+ peer->PacketReceived(buf);
+ } else {
+ LOG("Send Direct peer absent");
+ }
+}
+
+void TlsAgent::SendRecordDirect(const TlsRecord& record) {
+ DataBuffer buf;
+
+ auto rv = record.header.Write(&buf, 0, record.buffer);
+ EXPECT_EQ(record.header.header_length() + record.buffer.len(), rv);
+ SendDirect(buf);
+}
+
+static bool ErrorIsFatal(PRErrorCode code) {
+ return code != PR_WOULD_BLOCK_ERROR && code != SSL_ERROR_RX_SHORT_DTLS_READ;
+}
+
+void TlsAgent::SendData(size_t bytes, size_t blocksize) {
+ uint8_t block[16385]; // One larger than the maximum record size.
+
+ ASSERT_LE(blocksize, sizeof(block));
+
+ while (bytes) {
+ size_t tosend = std::min(blocksize, bytes);
+
+ for (size_t i = 0; i < tosend; ++i) {
+ block[i] = 0xff & send_ctr_;
+ ++send_ctr_;
+ }
+
+ SendBuffer(DataBuffer(block, tosend));
+ bytes -= tosend;
+ }
+}
+
+void TlsAgent::SendBuffer(const DataBuffer& buf) {
+ LOGV("Writing " << buf.len() << " bytes");
+ int32_t rv = PR_Write(ssl_fd(), buf.data(), buf.len());
+ if (expect_readwrite_error_) {
+ EXPECT_GT(0, rv);
+ EXPECT_NE(PR_WOULD_BLOCK_ERROR, error_code_);
+ error_code_ = PR_GetError();
+ expect_readwrite_error_ = false;
+ } else {
+ ASSERT_EQ(buf.len(), static_cast<size_t>(rv));
+ }
+}
+
+bool TlsAgent::SendEncryptedRecord(const std::shared_ptr<TlsCipherSpec>& spec,
+ uint64_t seq, uint8_t ct,
+ const DataBuffer& buf) {
+ // Ensure that we are doing TLS 1.3.
+ EXPECT_GE(expected_version_, SSL_LIBRARY_VERSION_TLS_1_3);
+ if (variant_ != ssl_variant_datagram) {
+ ADD_FAILURE();
+ return false;
+ }
+
+ LOGV("Encrypting " << buf.len() << " bytes");
+ uint8_t dtls13_ct = kCtDtlsCiphertext | kCtDtlsCiphertext16bSeqno |
+ kCtDtlsCiphertextLengthPresent;
+ TlsRecordHeader header(variant_, expected_version_, dtls13_ct, seq);
+ TlsRecordHeader out_header(header);
+ DataBuffer padded = buf;
+ padded.Write(padded.len(), ct, 1);
+ DataBuffer ciphertext;
+ if (!spec->Protect(header, padded, &ciphertext, &out_header)) {
+ return false;
+ }
+
+ DataBuffer record;
+ auto rv = out_header.Write(&record, 0, ciphertext);
+ EXPECT_EQ(out_header.header_length() + ciphertext.len(), rv);
+ SendDirect(record);
+ return true;
+}
+
+void TlsAgent::ReadBytes(size_t amount) {
+ uint8_t block[16384];
+
+ size_t remaining = amount;
+ while (remaining > 0) {
+ int32_t rv = PR_Read(ssl_fd(), block, (std::min)(amount, sizeof(block)));
+ LOGV("ReadBytes " << rv);
+
+ if (rv > 0) {
+ size_t count = static_cast<size_t>(rv);
+ for (size_t i = 0; i < count; ++i) {
+ ASSERT_EQ(recv_ctr_ & 0xff, block[i]);
+ recv_ctr_++;
+ }
+ remaining -= rv;
+ } else {
+ PRErrorCode err = 0;
+ if (rv < 0) {
+ err = PR_GetError();
+ LOG("Read error " << PORT_ErrorToName(err) << ": "
+ << PORT_ErrorToString(err));
+ if (err != PR_WOULD_BLOCK_ERROR && expect_readwrite_error_) {
+ error_code_ = err;
+ expect_readwrite_error_ = false;
+ }
+ }
+ if (err != 0 && ErrorIsFatal(err)) {
+ // If we hit a fatal error, we're done.
+ remaining = 0;
+ }
+ break;
+ }
+ }
+
+ // If closed, then don't bother waiting around.
+ if (remaining) {
+ LOGV("Re-arming");
+ Poller::Instance()->Wait(READABLE_EVENT, adapter_, this,
+ &TlsAgent::ReadableCallback);
+ }
+}
+
+void TlsAgent::ResetSentBytes(size_t bytes) { send_ctr_ = bytes; }
+
+void TlsAgent::SetOption(int32_t option, int value) {
+ ASSERT_TRUE(EnsureTlsSetup());
+ EXPECT_EQ(SECSuccess, SSL_OptionSet(ssl_fd(), option, value));
+}
+
+void TlsAgent::ConfigureSessionCache(SessionResumptionMode mode) {
+ SetOption(SSL_NO_CACHE, mode & RESUME_SESSIONID ? PR_FALSE : PR_TRUE);
+ SetOption(SSL_ENABLE_SESSION_TICKETS,
+ mode & RESUME_TICKET ? PR_TRUE : PR_FALSE);
+}
+
+void TlsAgent::EnableECDHEServerKeyReuse() {
+ ASSERT_EQ(TlsAgent::SERVER, role_);
+ SetOption(SSL_REUSE_SERVER_ECDHE_KEY, PR_TRUE);
+}
+
+static const std::string kTlsRolesAllArr[] = {"CLIENT", "SERVER"};
+::testing::internal::ParamGenerator<std::string>
+ TlsAgentTestBase::kTlsRolesAll = ::testing::ValuesIn(kTlsRolesAllArr);
+
+void TlsAgentTestBase::SetUp() {
+ SSL_ConfigServerSessionIDCache(1024, 0, 0, g_working_dir_path.c_str());
+}
+
+void TlsAgentTestBase::TearDown() {
+ agent_ = nullptr;
+ SSL_ClearSessionCache();
+ SSL_ShutdownServerSessionIDCache();
+}
+
+void TlsAgentTestBase::Reset(const std::string& server_name) {
+ agent_.reset(
+ new TlsAgent(role_ == TlsAgent::CLIENT ? TlsAgent::kClient : server_name,
+ role_, variant_));
+ if (version_) {
+ agent_->SetVersionRange(version_, version_);
+ }
+ agent_->adapter()->SetPeer(sink_adapter_);
+ agent_->StartConnect();
+}
+
+void TlsAgentTestBase::EnsureInit() {
+ if (!agent_) {
+ Reset();
+ }
+ const std::vector<SSLNamedGroup> groups = {
+ ssl_grp_ec_curve25519, ssl_grp_ec_secp256r1, ssl_grp_ec_secp384r1,
+ ssl_grp_ffdhe_2048};
+ agent_->ConfigNamedGroups(groups);
+}
+
+void TlsAgentTestBase::ExpectAlert(uint8_t alert) {
+ EnsureInit();
+ agent_->ExpectSendAlert(alert);
+}
+
+void TlsAgentTestBase::ProcessMessage(const DataBuffer& buffer,
+ TlsAgent::State expected_state,
+ int32_t error_code) {
+ std::cerr << "Process message: " << buffer << std::endl;
+ EnsureInit();
+ agent_->adapter()->PacketReceived(buffer);
+ agent_->Handshake();
+
+ ASSERT_EQ(expected_state, agent_->state());
+
+ if (expected_state == TlsAgent::STATE_ERROR) {
+ ASSERT_EQ(error_code, agent_->error_code());
+ }
+}
+
+void TlsAgentTestBase::MakeRecord(SSLProtocolVariant variant, uint8_t type,
+ uint16_t version, const uint8_t* buf,
+ size_t len, DataBuffer* out,
+ uint64_t sequence_number) {
+ // Fixup the content type for DTLSCiphertext
+ if (variant == ssl_variant_datagram &&
+ version >= SSL_LIBRARY_VERSION_TLS_1_3 &&
+ type == ssl_ct_application_data) {
+ type = kCtDtlsCiphertext | kCtDtlsCiphertext16bSeqno |
+ kCtDtlsCiphertextLengthPresent;
+ }
+
+ size_t index = 0;
+ if (variant == ssl_variant_stream) {
+ index = out->Write(index, type, 1);
+ index = out->Write(index, version, 2);
+ } else if (version >= SSL_LIBRARY_VERSION_TLS_1_3 &&
+ (type & kCtDtlsCiphertextMask) == kCtDtlsCiphertext) {
+ uint32_t epoch = (sequence_number >> 48) & 0x3;
+ index = out->Write(index, type | epoch, 1);
+ uint32_t seqno = sequence_number & ((1ULL << 16) - 1);
+ index = out->Write(index, seqno, 2);
+ } else {
+ index = out->Write(index, type, 1);
+ index = out->Write(index, TlsVersionToDtlsVersion(version), 2);
+ index = out->Write(index, sequence_number >> 32, 4);
+ index = out->Write(index, sequence_number & PR_UINT32_MAX, 4);
+ }
+ index = out->Write(index, len, 2);
+ out->Write(index, buf, len);
+}
+
+void TlsAgentTestBase::MakeRecord(uint8_t type, uint16_t version,
+ const uint8_t* buf, size_t len,
+ DataBuffer* out, uint64_t seq_num) const {
+ MakeRecord(variant_, type, version, buf, len, out, seq_num);
+}
+
+void TlsAgentTestBase::MakeHandshakeMessage(uint8_t hs_type,
+ const uint8_t* data, size_t hs_len,
+ DataBuffer* out,
+ uint64_t seq_num) const {
+ return MakeHandshakeMessageFragment(hs_type, data, hs_len, out, seq_num, 0,
+ 0);
+}
+
+void TlsAgentTestBase::MakeHandshakeMessageFragment(
+ uint8_t hs_type, const uint8_t* data, size_t hs_len, DataBuffer* out,
+ uint64_t seq_num, uint32_t fragment_offset,
+ uint32_t fragment_length) const {
+ size_t index = 0;
+ if (!fragment_length) fragment_length = hs_len;
+ index = out->Write(index, hs_type, 1); // Handshake record type.
+ index = out->Write(index, hs_len, 3); // Handshake length
+ if (variant_ == ssl_variant_datagram) {
+ index = out->Write(index, seq_num, 2);
+ index = out->Write(index, fragment_offset, 3);
+ index = out->Write(index, fragment_length, 3);
+ }
+ if (data) {
+ index = out->Write(index, data, fragment_length);
+ } else {
+ for (size_t i = 0; i < fragment_length; ++i) {
+ index = out->Write(index, 1, 1);
+ }
+ }
+}
+
+void TlsAgentTestBase::MakeTrivialHandshakeRecord(uint8_t hs_type,
+ size_t hs_len,
+ DataBuffer* out) {
+ size_t index = 0;
+ index = out->Write(index, ssl_ct_handshake, 1); // Content Type
+ index = out->Write(index, 3, 1); // Version high
+ index = out->Write(index, 1, 1); // Version low
+ index = out->Write(index, 4 + hs_len, 2); // Length
+
+ index = out->Write(index, hs_type, 1); // Handshake record type.
+ index = out->Write(index, hs_len, 3); // Handshake length
+ for (size_t i = 0; i < hs_len; ++i) {
+ index = out->Write(index, 1, 1);
+ }
+}
+
+DataBuffer TlsAgentTestBase::MakeCannedTls13ServerHello() {
+ DataBuffer sh(kCannedTls13ServerHello, sizeof(kCannedTls13ServerHello));
+ if (variant_ == ssl_variant_datagram) {
+ sh.Write(0, SSL_LIBRARY_VERSION_DTLS_1_2_WIRE, 2);
+ // The version should be at the end.
+ uint32_t v;
+ EXPECT_TRUE(sh.Read(sh.len() - 2, 2, &v));
+ EXPECT_EQ(static_cast<uint32_t>(SSL_LIBRARY_VERSION_TLS_1_3), v);
+ sh.Write(sh.len() - 2, 0x7f00 | DTLS_1_3_DRAFT_VERSION, 2);
+ }
+ return sh;
+}
+
+} // namespace nss_test
diff --git a/security/nss/gtests/ssl_gtest/tls_agent.h b/security/nss/gtests/ssl_gtest/tls_agent.h
new file mode 100644
index 0000000000..35375e0c11
--- /dev/null
+++ b/security/nss/gtests/ssl_gtest/tls_agent.h
@@ -0,0 +1,588 @@
+/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
+/* vim: set ts=2 et sw=2 tw=80: */
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this file,
+ * You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+#ifndef tls_agent_h_
+#define tls_agent_h_
+
+#include "prio.h"
+#include "ssl.h"
+#include "sslproto.h"
+
+#include <functional>
+#include <iostream>
+
+#include "nss_policy.h"
+#include "test_io.h"
+
+#define GTEST_HAS_RTTI 0
+#include "gtest/gtest.h"
+#include "nss_scoped_ptrs.h"
+#include "scoped_ptrs_ssl.h"
+
+extern bool g_ssl_gtest_verbose;
+
+namespace nss_test {
+
+#define LOG(msg) std::cerr << role_str() << ": " << msg << std::endl
+#define LOGV(msg) \
+ do { \
+ if (g_ssl_gtest_verbose) LOG(msg); \
+ } while (false)
+
+enum SessionResumptionMode {
+ RESUME_NONE = 0,
+ RESUME_SESSIONID = 1,
+ RESUME_TICKET = 2,
+ RESUME_BOTH = RESUME_SESSIONID | RESUME_TICKET
+};
+
+enum class ClientAuthCallbackType {
+ kAsyncImmediate,
+ kAsyncDelay,
+ kSync,
+ kNone,
+};
+
+class PacketFilter;
+class TlsAgent;
+class TlsCipherSpec;
+struct TlsRecord;
+
+const extern std::vector<SSLNamedGroup> kAllDHEGroups;
+const extern std::vector<SSLNamedGroup> kECDHEGroups;
+const extern std::vector<SSLNamedGroup> kFFDHEGroups;
+const extern std::vector<SSLNamedGroup> kFasterDHEGroups;
+
+// These functions are called from callbacks. They use bare pointers because
+// TlsAgent sets up the callback and it doesn't know who owns it.
+typedef std::function<SECStatus(TlsAgent* agent, bool checksig, bool isServer)>
+ AuthCertificateCallbackFunction;
+
+typedef std::function<void(TlsAgent* agent)> HandshakeCallbackFunction;
+
+typedef std::function<int32_t(TlsAgent* agent, const SECItem* srvNameArr,
+ PRUint32 srvNameArrSize)>
+ SniCallbackFunction;
+
+class TlsAgent : public PollTarget {
+ public:
+ enum Role { CLIENT, SERVER };
+ enum State { STATE_INIT, STATE_CONNECTING, STATE_CONNECTED, STATE_ERROR };
+
+ static const std::string kClient; // the client key is sign only
+ static const std::string kRsa2048; // bigger sign and encrypt for either
+ static const std::string kRsa8192; // biggest sign and encrypt for either
+ static const std::string kServerRsa; // both sign and encrypt
+ static const std::string kServerRsaSign;
+ static const std::string kServerRsaPss;
+ static const std::string kServerRsaDecrypt;
+ static const std::string kServerEcdsa256;
+ static const std::string kServerEcdsa384;
+ static const std::string kServerEcdsa521;
+ static const std::string kServerEcdhEcdsa;
+ static const std::string kServerEcdhRsa;
+ static const std::string kServerDsa;
+ static const std::string kDelegatorEcdsa256; // draft-ietf-tls-subcerts
+ static const std::string kDelegatorRsae2048; // draft-ietf-tls-subcerts
+ static const std::string kDelegatorRsaPss2048; // draft-ietf-tls-subcerts
+
+ TlsAgent(const std::string& name, Role role, SSLProtocolVariant variant);
+ virtual ~TlsAgent();
+
+ void SetPeer(std::shared_ptr<TlsAgent>& peer) {
+ adapter_->SetPeer(peer->adapter_);
+ }
+
+ void SetFilter(std::shared_ptr<PacketFilter> filter) {
+ adapter_->SetPacketFilter(filter);
+ }
+ void ClearFilter() { adapter_->SetPacketFilter(nullptr); }
+
+ void StartConnect(PRFileDesc* model = nullptr);
+ void CheckKEA(SSLKEAType kea_type, SSLNamedGroup group,
+ size_t kea_size = 0) const;
+ void CheckOriginalKEA(SSLNamedGroup kea_group) const;
+ void CheckAuthType(SSLAuthType auth_type,
+ SSLSignatureScheme sig_scheme) const;
+
+ void DisableAllCiphers();
+ void EnableCiphersByAuthType(SSLAuthType authType);
+ void EnableCiphersByKeyExchange(SSLKEAType kea);
+ void EnableGroupsByKeyExchange(SSLKEAType kea);
+ void EnableGroupsByAuthType(SSLAuthType authType);
+ void EnableSingleCipher(uint16_t cipher);
+
+ void Handshake();
+ // Marks the internal state as CONNECTING in anticipation of renegotiation.
+ void PrepareForRenegotiate();
+ // Prepares for renegotiation, then actually triggers it.
+ void StartRenegotiate();
+ void SetAntiReplayContext(ScopedSSLAntiReplayContext& ctx);
+
+ static bool LoadCertificate(const std::string& name,
+ ScopedCERTCertificate* cert,
+ ScopedSECKEYPrivateKey* priv);
+ static bool LoadKeyPairFromCert(const std::string& name,
+ ScopedSECKEYPublicKey* pub,
+ ScopedSECKEYPrivateKey* priv);
+
+ // Delegated credentials.
+ //
+ // Generate a delegated credential and sign it using the certificate
+ // associated with |name|.
+ static void DelegateCredential(const std::string& name,
+ const ScopedSECKEYPublicKey& dcPub,
+ SSLSignatureScheme dcCertVerifyAlg,
+ PRUint32 dcValidFor, PRTime now, SECItem* dc);
+ // Indicate support for the delegated credentials extension.
+ void EnableDelegatedCredentials();
+ // Generate and configure a delegated credential to use in the handshake with
+ // clients that support this extension..
+ void AddDelegatedCredential(const std::string& dc_name,
+ SSLSignatureScheme dcCertVerifyAlg,
+ PRUint32 dcValidFor, PRTime now);
+ void UpdatePreliminaryChannelInfo();
+
+ bool ConfigServerCert(const std::string& name, bool updateKeyBits = false,
+ const SSLExtraServerCertData* serverCertData = nullptr);
+ bool ConfigServerCertWithChain(const std::string& name);
+ bool EnsureTlsSetup(PRFileDesc* modelSocket = nullptr);
+
+ void SetupClientAuth(
+ ClientAuthCallbackType callbackType = ClientAuthCallbackType::kSync,
+ bool callbackSuccess = true);
+ void RequestClientAuth(bool requireAuth);
+ void ClientAuthCallbackComplete();
+ bool CheckClientAuthCallbacksCompleted(uint8_t expected);
+ void CheckClientAuthCompleted(uint8_t handshakes = 1);
+ void SetOption(int32_t option, int value);
+ void ConfigureSessionCache(SessionResumptionMode mode);
+ void Set0RttEnabled(bool en);
+ void SetFallbackSCSVEnabled(bool en);
+ void SetVersionRange(uint16_t minver, uint16_t maxver);
+ void GetVersionRange(uint16_t* minver, uint16_t* maxver);
+ void CheckPreliminaryInfo();
+ void ResetPreliminaryInfo();
+ void SetExpectedVersion(uint16_t version);
+ void SetServerKeyBits(uint16_t bits);
+ void ExpectReadWriteError();
+ void EnableFalseStart();
+ void ExpectEch(bool expected = true);
+ bool GetEchExpected() const { return expect_ech_; }
+ void ExpectPsk(SSLPskType psk = ssl_psk_external);
+ void ExpectResumption();
+ void SkipVersionChecks();
+ void SetSignatureSchemes(const SSLSignatureScheme* schemes, size_t count);
+ void EnableAlpn(const uint8_t* val, size_t len);
+ void CheckAlpn(SSLNextProtoState expected_state,
+ const std::string& expected = "") const;
+ void EnableSrtp();
+ void CheckSrtp() const;
+ void CheckEpochs(uint16_t expected_read, uint16_t expected_write) const;
+ void CheckErrorCode(int32_t expected) const;
+ void WaitForErrorCode(int32_t expected, uint32_t delay) const;
+ // Send data on the socket, encrypting it.
+ void SendData(size_t bytes, size_t blocksize = 1024);
+ void SendBuffer(const DataBuffer& buf);
+ bool SendEncryptedRecord(const std::shared_ptr<TlsCipherSpec>& spec,
+ uint64_t seq, uint8_t ct, const DataBuffer& buf);
+ // Send data directly to the underlying socket, skipping the TLS layer.
+ void SendDirect(const DataBuffer& buf);
+ void SendRecordDirect(const TlsRecord& record);
+ void AddPsk(const ScopedPK11SymKey& psk, std::string label, SSLHashType hash,
+ uint16_t zeroRttSuite = TLS_NULL_WITH_NULL_NULL);
+ void RemovePsk(std::string label);
+ void ReadBytes(size_t max = 16384U);
+ void ResetSentBytes(size_t bytes = 0); // Hack to test drops.
+ void EnableExtendedMasterSecret();
+ void CheckExtendedMasterSecret(bool expected);
+ void CheckEarlyDataAccepted(bool expected);
+ void CheckEchAccepted(bool expected);
+ void SetDowngradeCheckVersion(uint16_t version);
+ void CheckSecretsDestroyed();
+ void ConfigNamedGroups(const std::vector<SSLNamedGroup>& groups);
+ void EnableECDHEServerKeyReuse();
+ bool GetPeerChainLength(size_t* count);
+ void CheckCipherSuite(uint16_t cipher_suite);
+ void SetResumptionTokenCallback();
+ bool MaybeSetResumptionToken();
+ void SetResumptionToken(const std::vector<uint8_t>& resumption_token) {
+ resumption_token_ = resumption_token;
+ }
+ const std::vector<uint8_t>& GetResumptionToken() const {
+ return resumption_token_;
+ }
+ void GetTokenInfo(ScopedSSLResumptionTokenInfo& token) {
+ SECStatus rv = SSL_GetResumptionTokenInfo(
+ resumption_token_.data(), resumption_token_.size(), token.get(),
+ sizeof(SSLResumptionTokenInfo));
+ ASSERT_EQ(SECSuccess, rv);
+ }
+ void SetResumptionCallbackCalled() { resumption_callback_called_ = true; }
+ bool resumption_callback_called() const {
+ return resumption_callback_called_;
+ }
+
+ const std::string& name() const { return name_; }
+
+ Role role() const { return role_; }
+ std::string role_str() const { return role_ == SERVER ? "server" : "client"; }
+
+ SSLProtocolVariant variant() const { return variant_; }
+
+ State state() const { return state_; }
+
+ const CERTCertificate* peer_cert() const {
+ return SSL_PeerCertificate(ssl_fd_.get());
+ }
+
+ const char* state_str() const { return state_str(state()); }
+
+ static const char* state_str(State state) { return states[state]; }
+
+ NssManagedFileDesc ssl_fd() const {
+ return NssManagedFileDesc(ssl_fd_.get(), policy_, option_);
+ }
+ std::shared_ptr<DummyPrSocket>& adapter() { return adapter_; }
+
+ const SSLChannelInfo& info() const {
+ EXPECT_EQ(STATE_CONNECTED, state_);
+ return info_;
+ }
+
+ const SSLPreliminaryChannelInfo& pre_info() const { return pre_info_; }
+
+ bool is_compressed() const {
+ return info().compressionMethod != ssl_compression_null;
+ }
+ uint16_t server_key_bits() const { return server_key_bits_; }
+ uint16_t min_version() const { return vrange_.min; }
+ uint16_t max_version() const { return vrange_.max; }
+ uint16_t version() const { return info().protocolVersion; }
+
+ bool cipher_suite(uint16_t* suite) const {
+ if (state_ != STATE_CONNECTED) return false;
+
+ *suite = info_.cipherSuite;
+ return true;
+ }
+
+ void expected_cipher_suite(uint16_t suite) { expected_cipher_suite_ = suite; }
+
+ std::string cipher_suite_name() const {
+ if (state_ != STATE_CONNECTED) return "UNKNOWN";
+
+ return csinfo_.cipherSuiteName;
+ }
+
+ std::vector<uint8_t> session_id() const {
+ return std::vector<uint8_t>(info_.sessionID,
+ info_.sessionID + info_.sessionIDLength);
+ }
+
+ bool auth_type(SSLAuthType* a) const {
+ if (state_ != STATE_CONNECTED) return false;
+
+ *a = info_.authType;
+ return true;
+ }
+
+ bool kea_type(SSLKEAType* k) const {
+ if (state_ != STATE_CONNECTED) return false;
+
+ *k = info_.keaType;
+ return true;
+ }
+
+ size_t received_bytes() const { return recv_ctr_; }
+ PRErrorCode error_code() const { return error_code_; }
+
+ bool can_falsestart_hook_called() const {
+ return can_falsestart_hook_called_;
+ }
+
+ void SetHandshakeCallback(HandshakeCallbackFunction handshake_callback) {
+ handshake_callback_ = handshake_callback;
+ }
+
+ void SetAuthCertificateCallback(
+ AuthCertificateCallbackFunction auth_certificate_callback) {
+ auth_certificate_callback_ = auth_certificate_callback;
+ }
+
+ void SetSniCallback(SniCallbackFunction sni_callback) {
+ sni_callback_ = sni_callback;
+ }
+
+ void ExpectReceiveAlert(uint8_t alert, uint8_t level = 0);
+ void ExpectSendAlert(uint8_t alert, uint8_t level = 0);
+
+ std::string alpn_value_to_use_ = "";
+ // set the given policy before this agent runs
+ void SetPolicy(SECOidTag oid, PRUint32 set, PRUint32 clear) {
+ policy_ = NssPolicy(oid, set, clear);
+ }
+ void SetNssOption(PRInt32 id, PRInt32 value) {
+ option_ = NssOption(id, value);
+ }
+
+ private:
+ const static char* states[];
+
+ void SetState(State state);
+ void ValidateCipherSpecs();
+
+ // Dummy auth certificate hook.
+ static SECStatus AuthCertificateHook(void* arg, PRFileDesc* fd,
+ PRBool checksig, PRBool isServer) {
+ TlsAgent* agent = reinterpret_cast<TlsAgent*>(arg);
+ agent->CheckPreliminaryInfo();
+ agent->auth_certificate_hook_called_ = true;
+ if (agent->auth_certificate_callback_) {
+ return agent->auth_certificate_callback_(agent, checksig ? true : false,
+ isServer ? true : false);
+ }
+ return SECSuccess;
+ }
+
+ // Client auth certificate hook.
+ static SECStatus ClientAuthenticated(void* arg, PRFileDesc* fd,
+ PRBool checksig, PRBool isServer) {
+ TlsAgent* agent = reinterpret_cast<TlsAgent*>(arg);
+ EXPECT_TRUE(agent->expect_client_auth_);
+ EXPECT_EQ(PR_TRUE, isServer);
+ if (agent->auth_certificate_callback_) {
+ return agent->auth_certificate_callback_(agent, checksig ? true : false,
+ isServer ? true : false);
+ }
+ return SECSuccess;
+ }
+
+ static SECStatus GetClientAuthDataHook(void* self, PRFileDesc* fd,
+ CERTDistNames* caNames,
+ CERTCertificate** cert,
+ SECKEYPrivateKey** privKey);
+
+ static void ReadableCallback(PollTarget* self, Event event) {
+ TlsAgent* agent = static_cast<TlsAgent*>(self);
+ if (event == TIMER_EVENT) {
+ agent->timer_handle_ = nullptr;
+ }
+ agent->ReadableCallback_int();
+ }
+
+ void ReadableCallback_int() {
+ LOGV("Readable");
+ switch (state_) {
+ case STATE_CONNECTING:
+ Handshake();
+ break;
+ case STATE_CONNECTED:
+ ReadBytes();
+ break;
+ default:
+ break;
+ }
+ }
+
+ static PRInt32 SniHook(PRFileDesc* fd, const SECItem* srvNameArr,
+ PRUint32 srvNameArrSize, void* arg) {
+ TlsAgent* agent = reinterpret_cast<TlsAgent*>(arg);
+ agent->CheckPreliminaryInfo();
+ agent->sni_hook_called_ = true;
+ EXPECT_EQ(1UL, srvNameArrSize);
+ if (agent->sni_callback_) {
+ return agent->sni_callback_(agent, srvNameArr, srvNameArrSize);
+ }
+ return 0; // First configuration.
+ }
+
+ static SECStatus CanFalseStartCallback(PRFileDesc* fd, void* arg,
+ PRBool* canFalseStart) {
+ TlsAgent* agent = reinterpret_cast<TlsAgent*>(arg);
+ agent->CheckPreliminaryInfo();
+ EXPECT_TRUE(agent->falsestart_enabled_);
+ EXPECT_FALSE(agent->can_falsestart_hook_called_);
+ agent->can_falsestart_hook_called_ = true;
+ *canFalseStart = true;
+ return SECSuccess;
+ }
+
+ void CheckAlert(bool sent, const SSLAlert* alert);
+
+ static void AlertReceivedCallback(const PRFileDesc* fd, void* arg,
+ const SSLAlert* alert) {
+ reinterpret_cast<TlsAgent*>(arg)->CheckAlert(false, alert);
+ }
+
+ static void AlertSentCallback(const PRFileDesc* fd, void* arg,
+ const SSLAlert* alert) {
+ reinterpret_cast<TlsAgent*>(arg)->CheckAlert(true, alert);
+ }
+
+ static void HandshakeCallback(PRFileDesc* fd, void* arg) {
+ TlsAgent* agent = reinterpret_cast<TlsAgent*>(arg);
+ agent->handshake_callback_called_ = true;
+ agent->Connected();
+ if (agent->handshake_callback_) {
+ agent->handshake_callback_(agent);
+ }
+ }
+
+ void DisableLameGroups();
+ void ConfigStrongECGroups(bool en);
+ void ConfigAllDHGroups(bool en);
+ void CheckCallbacks() const;
+ void Connected();
+
+ const std::string name_;
+ SSLProtocolVariant variant_;
+ Role role_;
+ uint16_t server_key_bits_;
+ std::shared_ptr<DummyPrSocket> adapter_;
+ ScopedPRFileDesc ssl_fd_;
+ State state_;
+ std::shared_ptr<Poller::Timer> timer_handle_;
+ bool falsestart_enabled_;
+ uint16_t expected_version_;
+ uint16_t expected_cipher_suite_;
+ bool expect_client_auth_;
+ bool expect_ech_;
+ SSLPskType expect_psk_;
+ bool can_falsestart_hook_called_;
+ bool sni_hook_called_;
+ bool auth_certificate_hook_called_;
+ uint8_t expected_received_alert_;
+ uint8_t expected_received_alert_level_;
+ uint8_t expected_sent_alert_;
+ uint8_t expected_sent_alert_level_;
+ bool handshake_callback_called_;
+ bool resumption_callback_called_;
+ SSLChannelInfo info_;
+ SSLPreliminaryChannelInfo pre_info_;
+ SSLCipherSuiteInfo csinfo_;
+ SSLVersionRange vrange_;
+ PRErrorCode error_code_;
+ size_t send_ctr_;
+ size_t recv_ctr_;
+ bool expect_readwrite_error_;
+ HandshakeCallbackFunction handshake_callback_;
+ AuthCertificateCallbackFunction auth_certificate_callback_;
+ SniCallbackFunction sni_callback_;
+ bool skip_version_checks_;
+ std::vector<uint8_t> resumption_token_;
+ NssPolicy policy_;
+ NssOption option_;
+ ClientAuthCallbackType client_auth_callback_type_ =
+ ClientAuthCallbackType::kNone;
+ bool client_auth_callback_success_ = false;
+ uint8_t client_auth_callback_fired_ = 0;
+ bool client_auth_callback_awaiting_ = false;
+};
+
+inline std::ostream& operator<<(std::ostream& stream,
+ const TlsAgent::State& state) {
+ return stream << TlsAgent::state_str(state);
+}
+
+class TlsAgentTestBase : public ::testing::Test {
+ public:
+ static ::testing::internal::ParamGenerator<std::string> kTlsRolesAll;
+
+ TlsAgentTestBase(TlsAgent::Role role, SSLProtocolVariant variant,
+ uint16_t version = 0)
+ : agent_(nullptr),
+ role_(role),
+ variant_(variant),
+ version_(version),
+ sink_adapter_(new DummyPrSocket("sink", variant)) {}
+ virtual ~TlsAgentTestBase() {}
+
+ void SetUp();
+ void TearDown();
+
+ void ExpectAlert(uint8_t alert);
+
+ static void MakeRecord(SSLProtocolVariant variant, uint8_t type,
+ uint16_t version, const uint8_t* buf, size_t len,
+ DataBuffer* out, uint64_t seq_num = 0);
+ void MakeRecord(uint8_t type, uint16_t version, const uint8_t* buf,
+ size_t len, DataBuffer* out, uint64_t seq_num = 0) const;
+ void MakeHandshakeMessage(uint8_t hs_type, const uint8_t* data, size_t hs_len,
+ DataBuffer* out, uint64_t seq_num = 0) const;
+ void MakeHandshakeMessageFragment(uint8_t hs_type, const uint8_t* data,
+ size_t hs_len, DataBuffer* out,
+ uint64_t seq_num, uint32_t fragment_offset,
+ uint32_t fragment_length) const;
+ DataBuffer MakeCannedTls13ServerHello();
+ static void MakeTrivialHandshakeRecord(uint8_t hs_type, size_t hs_len,
+ DataBuffer* out);
+ static inline TlsAgent::Role ToRole(const std::string& str) {
+ return str == "CLIENT" ? TlsAgent::CLIENT : TlsAgent::SERVER;
+ }
+
+ void Init(const std::string& server_name = TlsAgent::kServerRsa);
+ void Reset(const std::string& server_name = TlsAgent::kServerRsa);
+
+ protected:
+ void EnsureInit();
+ void ProcessMessage(const DataBuffer& buffer, TlsAgent::State expected_state,
+ int32_t error_code = 0);
+
+ std::shared_ptr<TlsAgent> agent_;
+ TlsAgent::Role role_;
+ SSLProtocolVariant variant_;
+ uint16_t version_;
+ // This adapter is here just to accept packets from this agent.
+ std::shared_ptr<DummyPrSocket> sink_adapter_;
+};
+
+class TlsAgentTest
+ : public TlsAgentTestBase,
+ public ::testing::WithParamInterface<
+ std::tuple<std::string, SSLProtocolVariant, uint16_t>> {
+ public:
+ TlsAgentTest()
+ : TlsAgentTestBase(ToRole(std::get<0>(GetParam())),
+ std::get<1>(GetParam()), std::get<2>(GetParam())) {}
+};
+
+class TlsAgentTestClient : public TlsAgentTestBase,
+ public ::testing::WithParamInterface<
+ std::tuple<SSLProtocolVariant, uint16_t>> {
+ public:
+ TlsAgentTestClient()
+ : TlsAgentTestBase(TlsAgent::CLIENT, std::get<0>(GetParam()),
+ std::get<1>(GetParam())) {}
+};
+
+class TlsAgentTestClient13 : public TlsAgentTestClient {};
+
+class TlsAgentStreamTestClient : public TlsAgentTestBase {
+ public:
+ TlsAgentStreamTestClient()
+ : TlsAgentTestBase(TlsAgent::CLIENT, ssl_variant_stream) {}
+};
+
+class TlsAgentStreamTestServer : public TlsAgentTestBase {
+ public:
+ TlsAgentStreamTestServer()
+ : TlsAgentTestBase(TlsAgent::SERVER, ssl_variant_stream) {}
+};
+
+class TlsAgentDgramTestClient : public TlsAgentTestBase {
+ public:
+ TlsAgentDgramTestClient()
+ : TlsAgentTestBase(TlsAgent::CLIENT, ssl_variant_datagram) {}
+};
+
+inline bool operator==(const SSLVersionRange& vr1, const SSLVersionRange& vr2) {
+ return vr1.min == vr2.min && vr1.max == vr2.max;
+}
+
+} // namespace nss_test
+
+#endif
diff --git a/security/nss/gtests/ssl_gtest/tls_connect.cc b/security/nss/gtests/ssl_gtest/tls_connect.cc
new file mode 100644
index 0000000000..fd10e34a79
--- /dev/null
+++ b/security/nss/gtests/ssl_gtest/tls_connect.cc
@@ -0,0 +1,1065 @@
+/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
+/* vim: set ts=2 et sw=2 tw=80: */
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this file,
+ * You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+#include "tls_connect.h"
+#include "sslexp.h"
+extern "C" {
+#include "libssl_internals.h"
+}
+
+#include <iostream>
+
+#include "databuffer.h"
+#include "gtest_utils.h"
+#include "nss_scoped_ptrs.h"
+#include "sslproto.h"
+
+extern std::string g_working_dir_path;
+
+namespace nss_test {
+
+static const SSLProtocolVariant kTlsVariantsStreamArr[] = {ssl_variant_stream};
+::testing::internal::ParamGenerator<SSLProtocolVariant>
+ TlsConnectTestBase::kTlsVariantsStream =
+ ::testing::ValuesIn(kTlsVariantsStreamArr);
+static const SSLProtocolVariant kTlsVariantsDatagramArr[] = {
+ ssl_variant_datagram};
+::testing::internal::ParamGenerator<SSLProtocolVariant>
+ TlsConnectTestBase::kTlsVariantsDatagram =
+ ::testing::ValuesIn(kTlsVariantsDatagramArr);
+static const SSLProtocolVariant kTlsVariantsAllArr[] = {ssl_variant_stream,
+ ssl_variant_datagram};
+::testing::internal::ParamGenerator<SSLProtocolVariant>
+ TlsConnectTestBase::kTlsVariantsAll =
+ ::testing::ValuesIn(kTlsVariantsAllArr);
+
+static const uint16_t kTlsV10Arr[] = {SSL_LIBRARY_VERSION_TLS_1_0};
+::testing::internal::ParamGenerator<uint16_t> TlsConnectTestBase::kTlsV10 =
+ ::testing::ValuesIn(kTlsV10Arr);
+static const uint16_t kTlsV11Arr[] = {SSL_LIBRARY_VERSION_TLS_1_1};
+::testing::internal::ParamGenerator<uint16_t> TlsConnectTestBase::kTlsV11 =
+ ::testing::ValuesIn(kTlsV11Arr);
+static const uint16_t kTlsV12Arr[] = {SSL_LIBRARY_VERSION_TLS_1_2};
+::testing::internal::ParamGenerator<uint16_t> TlsConnectTestBase::kTlsV12 =
+ ::testing::ValuesIn(kTlsV12Arr);
+static const uint16_t kTlsV10V11Arr[] = {SSL_LIBRARY_VERSION_TLS_1_0,
+ SSL_LIBRARY_VERSION_TLS_1_1};
+::testing::internal::ParamGenerator<uint16_t> TlsConnectTestBase::kTlsV10V11 =
+ ::testing::ValuesIn(kTlsV10V11Arr);
+static const uint16_t kTlsV10ToV12Arr[] = {SSL_LIBRARY_VERSION_TLS_1_0,
+ SSL_LIBRARY_VERSION_TLS_1_1,
+ SSL_LIBRARY_VERSION_TLS_1_2};
+::testing::internal::ParamGenerator<uint16_t> TlsConnectTestBase::kTlsV10ToV12 =
+ ::testing::ValuesIn(kTlsV10ToV12Arr);
+static const uint16_t kTlsV11V12Arr[] = {SSL_LIBRARY_VERSION_TLS_1_1,
+ SSL_LIBRARY_VERSION_TLS_1_2};
+::testing::internal::ParamGenerator<uint16_t> TlsConnectTestBase::kTlsV11V12 =
+ ::testing::ValuesIn(kTlsV11V12Arr);
+
+static const uint16_t kTlsV11PlusArr[] = {
+#ifndef NSS_DISABLE_TLS_1_3
+ SSL_LIBRARY_VERSION_TLS_1_3,
+#endif
+ SSL_LIBRARY_VERSION_TLS_1_2, SSL_LIBRARY_VERSION_TLS_1_1};
+::testing::internal::ParamGenerator<uint16_t> TlsConnectTestBase::kTlsV11Plus =
+ ::testing::ValuesIn(kTlsV11PlusArr);
+static const uint16_t kTlsV12PlusArr[] = {
+#ifndef NSS_DISABLE_TLS_1_3
+ SSL_LIBRARY_VERSION_TLS_1_3,
+#endif
+ SSL_LIBRARY_VERSION_TLS_1_2};
+::testing::internal::ParamGenerator<uint16_t> TlsConnectTestBase::kTlsV12Plus =
+ ::testing::ValuesIn(kTlsV12PlusArr);
+static const uint16_t kTlsV13Arr[] = {SSL_LIBRARY_VERSION_TLS_1_3};
+::testing::internal::ParamGenerator<uint16_t> TlsConnectTestBase::kTlsV13 =
+ ::testing::ValuesIn(kTlsV13Arr);
+static const uint16_t kTlsVAllArr[] = {
+#ifndef NSS_DISABLE_TLS_1_3
+ SSL_LIBRARY_VERSION_TLS_1_3,
+#endif
+ SSL_LIBRARY_VERSION_TLS_1_2, SSL_LIBRARY_VERSION_TLS_1_1,
+ SSL_LIBRARY_VERSION_TLS_1_0};
+::testing::internal::ParamGenerator<uint16_t> TlsConnectTestBase::kTlsVAll =
+ ::testing::ValuesIn(kTlsVAllArr);
+
+std::string VersionString(uint16_t version) {
+ switch (version) {
+ case 0:
+ return "(no version)";
+ case SSL_LIBRARY_VERSION_3_0:
+ return "1.0";
+ case SSL_LIBRARY_VERSION_TLS_1_0:
+ return "1.0";
+ case SSL_LIBRARY_VERSION_TLS_1_1:
+ return "1.1";
+ case SSL_LIBRARY_VERSION_TLS_1_2:
+ return "1.2";
+ case SSL_LIBRARY_VERSION_TLS_1_3:
+ return "1.3";
+ default:
+ std::cerr << "Invalid version: " << version << std::endl;
+ EXPECT_TRUE(false);
+ return "";
+ }
+}
+
+// The default anti-replay window for tests. Tests that rely on a different
+// value call ResetAntiReplay directly.
+static PRTime kAntiReplayWindow = 100 * PR_USEC_PER_SEC;
+
+TlsConnectTestBase::TlsConnectTestBase(SSLProtocolVariant variant,
+ uint16_t version)
+ : variant_(variant),
+ client_(new TlsAgent(TlsAgent::kClient, TlsAgent::CLIENT, variant_)),
+ server_(new TlsAgent(TlsAgent::kServerRsa, TlsAgent::SERVER, variant_)),
+ client_model_(nullptr),
+ server_model_(nullptr),
+ version_(version),
+ expected_resumption_mode_(RESUME_NONE),
+ expected_resumptions_(0),
+ session_ids_(),
+ expect_extended_master_secret_(false),
+ expect_early_data_accepted_(false),
+ skip_version_checks_(false) {
+ std::string v;
+ if (variant_ == ssl_variant_datagram &&
+ version_ == SSL_LIBRARY_VERSION_TLS_1_1) {
+ v = "1.0";
+ } else {
+ v = VersionString(version_);
+ }
+ std::cerr << "Version: " << variant_ << " " << v << std::endl;
+}
+
+TlsConnectTestBase::~TlsConnectTestBase() {}
+
+// Check the group of each of the supported groups
+void TlsConnectTestBase::CheckGroups(
+ const DataBuffer& groups, std::function<void(SSLNamedGroup)> check_group) {
+ DuplicateGroupChecker group_set;
+ uint32_t tmp = 0;
+ EXPECT_TRUE(groups.Read(0, 2, &tmp));
+ EXPECT_EQ(groups.len() - 2, static_cast<size_t>(tmp));
+ for (size_t i = 2; i < groups.len(); i += 2) {
+ EXPECT_TRUE(groups.Read(i, 2, &tmp));
+ SSLNamedGroup group = static_cast<SSLNamedGroup>(tmp);
+ group_set.AddAndCheckGroup(group);
+ check_group(group);
+ }
+}
+
+// Check the group of each of the shares
+void TlsConnectTestBase::CheckShares(
+ const DataBuffer& shares, std::function<void(SSLNamedGroup)> check_group) {
+ DuplicateGroupChecker group_set;
+ uint32_t tmp = 0;
+ EXPECT_TRUE(shares.Read(0, 2, &tmp));
+ EXPECT_EQ(shares.len() - 2, static_cast<size_t>(tmp));
+ size_t i;
+ for (i = 2; i < shares.len(); i += 4 + tmp) {
+ ASSERT_TRUE(shares.Read(i, 2, &tmp));
+ SSLNamedGroup group = static_cast<SSLNamedGroup>(tmp);
+ group_set.AddAndCheckGroup(group);
+ check_group(group);
+ ASSERT_TRUE(shares.Read(i + 2, 2, &tmp));
+ }
+ EXPECT_EQ(shares.len(), i);
+}
+
+void TlsConnectTestBase::CheckEpochs(uint16_t client_epoch,
+ uint16_t server_epoch) const {
+ client_->CheckEpochs(server_epoch, client_epoch);
+ server_->CheckEpochs(client_epoch, server_epoch);
+}
+
+void TlsConnectTestBase::ClearStats() {
+ // Clear statistics.
+ SSL3Statistics* stats = SSL_GetStatistics();
+ memset(stats, 0, sizeof(*stats));
+}
+
+void TlsConnectTestBase::ClearServerCache() {
+ SSL_ShutdownServerSessionIDCache();
+ SSLInt_ClearSelfEncryptKey();
+ SSL_ConfigServerSessionIDCache(1024, 0, 0, g_working_dir_path.c_str());
+}
+
+void TlsConnectTestBase::SaveAlgorithmPolicy() {
+ saved_policies_.clear();
+ for (auto it = algorithms_.begin(); it != algorithms_.end(); ++it) {
+ uint32_t policy;
+ SECStatus rv = NSS_GetAlgorithmPolicy(*it, &policy);
+ ASSERT_EQ(SECSuccess, rv);
+ saved_policies_.push_back(std::make_tuple(*it, policy));
+ }
+ saved_options_.clear();
+ for (auto it : options_) {
+ int32_t option;
+ SECStatus rv = NSS_OptionGet(it, &option);
+ ASSERT_EQ(SECSuccess, rv);
+ saved_options_.push_back(std::make_tuple(it, option));
+ }
+}
+
+void TlsConnectTestBase::RestoreAlgorithmPolicy() {
+ for (auto it = saved_policies_.begin(); it != saved_policies_.end(); ++it) {
+ auto algorithm = std::get<0>(*it);
+ auto policy = std::get<1>(*it);
+ SECStatus rv = NSS_SetAlgorithmPolicy(
+ algorithm, policy, NSS_USE_POLICY_IN_SSL | NSS_USE_ALG_IN_SSL_KX);
+ ASSERT_EQ(SECSuccess, rv);
+ }
+ for (auto it = saved_options_.begin(); it != saved_options_.end(); ++it) {
+ auto option_id = std::get<0>(*it);
+ auto option = std::get<1>(*it);
+ SECStatus rv = NSS_OptionSet(option_id, option);
+ ASSERT_EQ(SECSuccess, rv);
+ }
+}
+
+PRTime TlsConnectTestBase::TimeFunc(void* arg) {
+ return *reinterpret_cast<PRTime*>(arg);
+}
+
+void TlsConnectTestBase::SetUp() {
+ SSL_ConfigServerSessionIDCache(1024, 0, 0, g_working_dir_path.c_str());
+ SSLInt_ClearSelfEncryptKey();
+ now_ = PR_Now();
+ ResetAntiReplay(kAntiReplayWindow);
+ ClearStats();
+ SaveAlgorithmPolicy();
+ Init();
+}
+
+void TlsConnectTestBase::TearDown() {
+ client_ = nullptr;
+ server_ = nullptr;
+
+ SSL_ClearSessionCache();
+ SSLInt_ClearSelfEncryptKey();
+ SSL_ShutdownServerSessionIDCache();
+ RestoreAlgorithmPolicy();
+}
+
+void TlsConnectTestBase::Init() {
+ client_->SetPeer(server_);
+ server_->SetPeer(client_);
+
+ if (version_) {
+ ConfigureVersion(version_);
+ }
+}
+
+void TlsConnectTestBase::ResetAntiReplay(PRTime window) {
+ SSLAntiReplayContext* p_anti_replay = nullptr;
+ EXPECT_EQ(SECSuccess,
+ SSL_CreateAntiReplayContext(now_, window, 1, 3, &p_anti_replay));
+ EXPECT_NE(nullptr, p_anti_replay);
+ anti_replay_.reset(p_anti_replay);
+}
+
+ScopedSECItem TlsConnectTestBase::MakeEcKeyParams(SSLNamedGroup group) {
+ auto groupDef = ssl_LookupNamedGroup(group);
+ EXPECT_NE(nullptr, groupDef);
+
+ auto oidData = SECOID_FindOIDByTag(groupDef->oidTag);
+ EXPECT_NE(nullptr, oidData);
+ ScopedSECItem params(
+ SECITEM_AllocItem(nullptr, nullptr, (2 + oidData->oid.len)));
+ EXPECT_TRUE(!!params);
+ params->data[0] = SEC_ASN1_OBJECT_ID;
+ params->data[1] = oidData->oid.len;
+ memcpy(params->data + 2, oidData->oid.data, oidData->oid.len);
+ return params;
+}
+
+void TlsConnectTestBase::GenerateEchConfig(
+ HpkeKemId kem_id, const std::vector<HpkeSymmetricSuite>& cipher_suites,
+ const std::string& public_name, uint16_t max_name_len, DataBuffer& record,
+ ScopedSECKEYPublicKey& pubKey, ScopedSECKEYPrivateKey& privKey) {
+ bool gen_keys = !pubKey && !privKey;
+
+ SECKEYPublicKey* pub = nullptr;
+ SECKEYPrivateKey* priv = nullptr;
+
+ if (gen_keys) {
+ ScopedSECItem ecParams = MakeEcKeyParams(ssl_grp_ec_curve25519);
+ priv = SECKEY_CreateECPrivateKey(ecParams.get(), &pub, nullptr);
+ } else {
+ priv = privKey.get();
+ pub = pubKey.get();
+ }
+ ASSERT_NE(nullptr, priv);
+ PRUint8 encoded[1024];
+ unsigned int encoded_len = 0;
+ SECStatus rv = SSL_EncodeEchConfigId(
+ 77, public_name.c_str(), max_name_len, kem_id, pub, cipher_suites.data(),
+ cipher_suites.size(), encoded, &encoded_len, sizeof(encoded));
+ EXPECT_EQ(SECSuccess, rv);
+ EXPECT_GT(encoded_len, 0U);
+
+ if (gen_keys) {
+ pubKey.reset(pub);
+ privKey.reset(priv);
+ }
+ record.Truncate(0);
+ record.Write(0, encoded, encoded_len);
+}
+
+void TlsConnectTestBase::SetupEch(std::shared_ptr<TlsAgent>& client,
+ std::shared_ptr<TlsAgent>& server,
+ HpkeKemId kem_id, bool expect_ech,
+ bool set_client_config,
+ bool set_server_config, int max_name_len) {
+ EXPECT_TRUE(set_server_config || set_client_config);
+ ScopedSECKEYPublicKey pub;
+ ScopedSECKEYPrivateKey priv;
+ DataBuffer record;
+ static const std::vector<HpkeSymmetricSuite> kDefaultSuites = {
+ {HpkeKdfHkdfSha256, HpkeAeadChaCha20Poly1305},
+ {HpkeKdfHkdfSha256, HpkeAeadAes128Gcm}};
+
+ GenerateEchConfig(kem_id, kDefaultSuites, "public.name", max_name_len, record,
+ pub, priv);
+ ASSERT_NE(0U, record.len());
+ SECStatus rv;
+ if (set_server_config) {
+ rv = SSL_SetServerEchConfigs(server->ssl_fd(), pub.get(), priv.get(),
+ record.data(), record.len());
+ ASSERT_EQ(SECSuccess, rv);
+ }
+ if (set_client_config) {
+ rv = SSL_SetClientEchConfigs(client->ssl_fd(), record.data(), record.len());
+ ASSERT_EQ(SECSuccess, rv);
+ }
+
+ /* Filter expect_ech, which typically defaults to true. Parameterized tests
+ * running DTLS or TLS < 1.3 should expect only a non-ECH result. */
+ bool expect = expect_ech && variant_ != ssl_variant_datagram &&
+ version_ >= SSL_LIBRARY_VERSION_TLS_1_3 && set_client_config &&
+ set_server_config;
+ client->ExpectEch(expect);
+ server->ExpectEch(expect);
+}
+
+void TlsConnectTestBase::Reset() {
+ // Take a copy of the names because they are about to disappear.
+ std::string server_name = server_->name();
+ std::string client_name = client_->name();
+ Reset(server_name, client_name);
+}
+
+void TlsConnectTestBase::Reset(const std::string& server_name,
+ const std::string& client_name) {
+ auto token = client_->GetResumptionToken();
+ client_.reset(new TlsAgent(client_name, TlsAgent::CLIENT, variant_));
+ client_->SetResumptionToken(token);
+ server_.reset(new TlsAgent(server_name, TlsAgent::SERVER, variant_));
+ if (skip_version_checks_) {
+ client_->SkipVersionChecks();
+ server_->SkipVersionChecks();
+ }
+
+ std::cerr << "Reset server:" << server_name << ", client:" << client_name
+ << std::endl;
+ Init();
+}
+
+void TlsConnectTestBase::MakeNewServer() {
+ auto replacement = std::make_shared<TlsAgent>(
+ server_->name(), TlsAgent::SERVER, server_->variant());
+ server_ = replacement;
+ if (version_) {
+ server_->SetVersionRange(version_, version_);
+ }
+ client_->SetPeer(server_);
+ server_->SetPeer(client_);
+ server_->StartConnect();
+}
+
+void TlsConnectTestBase::ExpectResumption(SessionResumptionMode expected,
+ uint8_t num_resumptions) {
+ expected_resumption_mode_ = expected;
+ if (expected != RESUME_NONE) {
+ client_->ExpectResumption();
+ server_->ExpectResumption();
+ expected_resumptions_ = num_resumptions;
+ }
+ EXPECT_EQ(expected_resumptions_ == 0, expected == RESUME_NONE);
+}
+
+void TlsConnectTestBase::EnsureTlsSetup() {
+ EXPECT_TRUE(server_->EnsureTlsSetup(
+ server_model_ ? server_model_->ssl_fd().get() : nullptr));
+ EXPECT_TRUE(client_->EnsureTlsSetup(
+ client_model_ ? client_model_->ssl_fd().get() : nullptr));
+ server_->SetAntiReplayContext(anti_replay_);
+ EXPECT_EQ(SECSuccess, SSL_SetTimeFunc(client_->ssl_fd(),
+ TlsConnectTestBase::TimeFunc, &now_));
+ EXPECT_EQ(SECSuccess, SSL_SetTimeFunc(server_->ssl_fd(),
+ TlsConnectTestBase::TimeFunc, &now_));
+}
+
+void TlsConnectTestBase::Handshake() {
+ client_->SetServerKeyBits(server_->server_key_bits());
+ client_->Handshake();
+ server_->Handshake();
+
+ ASSERT_TRUE_WAIT((client_->state() != TlsAgent::STATE_CONNECTING) &&
+ (server_->state() != TlsAgent::STATE_CONNECTING),
+ 5000);
+}
+
+void TlsConnectTestBase::EnableExtendedMasterSecret() {
+ client_->EnableExtendedMasterSecret();
+ server_->EnableExtendedMasterSecret();
+ ExpectExtendedMasterSecret(true);
+}
+
+void TlsConnectTestBase::Connect() {
+ StartConnect();
+ client_->MaybeSetResumptionToken();
+ Handshake();
+ CheckConnected();
+}
+
+void TlsConnectTestBase::StartConnect() {
+ EnsureTlsSetup();
+ server_->StartConnect();
+ client_->StartConnect();
+}
+
+void TlsConnectTestBase::ConnectWithCipherSuite(uint16_t cipher_suite) {
+ EnsureTlsSetup();
+ client_->EnableSingleCipher(cipher_suite);
+
+ Connect();
+ SendReceive();
+
+ // Check that we used the right cipher suite.
+ uint16_t actual;
+ EXPECT_TRUE(client_->cipher_suite(&actual));
+ EXPECT_EQ(cipher_suite, actual);
+ EXPECT_TRUE(server_->cipher_suite(&actual));
+ EXPECT_EQ(cipher_suite, actual);
+}
+
+void TlsConnectTestBase::CheckConnected() {
+ // Have the client read handshake twice to make sure we get the
+ // NST and the ACK.
+ if (client_->version() >= SSL_LIBRARY_VERSION_TLS_1_3 &&
+ variant_ == ssl_variant_datagram) {
+ client_->Handshake();
+ client_->Handshake();
+ auto suites = SSLInt_CountCipherSpecs(client_->ssl_fd());
+ // Verify that we dropped the client's retransmission cipher suites.
+ EXPECT_EQ(2, suites) << "Client has the wrong number of suites";
+ if (suites != 2) {
+ SSLInt_PrintCipherSpecs("client", client_->ssl_fd());
+ }
+ }
+ EXPECT_EQ(client_->version(), server_->version());
+ if (!skip_version_checks_) {
+ // Check the version is as expected
+ EXPECT_EQ(std::min(client_->max_version(), server_->max_version()),
+ client_->version());
+ }
+
+ EXPECT_EQ(TlsAgent::STATE_CONNECTED, client_->state());
+ EXPECT_EQ(TlsAgent::STATE_CONNECTED, server_->state());
+
+ uint16_t cipher_suite1, cipher_suite2;
+ ASSERT_TRUE(client_->cipher_suite(&cipher_suite1));
+ ASSERT_TRUE(server_->cipher_suite(&cipher_suite2));
+ EXPECT_EQ(cipher_suite1, cipher_suite2);
+
+ std::cerr << "Connected with version " << client_->version()
+ << " cipher suite " << client_->cipher_suite_name() << std::endl;
+
+ if (client_->version() < SSL_LIBRARY_VERSION_TLS_1_3) {
+ // Check and store session ids.
+ std::vector<uint8_t> sid_c1 = client_->session_id();
+ EXPECT_EQ(32U, sid_c1.size());
+ std::vector<uint8_t> sid_s1 = server_->session_id();
+ EXPECT_EQ(32U, sid_s1.size());
+ EXPECT_EQ(sid_c1, sid_s1);
+ session_ids_.push_back(sid_c1);
+ }
+
+ CheckExtendedMasterSecret();
+ CheckEarlyDataAccepted();
+ CheckResumption(expected_resumption_mode_);
+ client_->CheckSecretsDestroyed();
+ server_->CheckSecretsDestroyed();
+}
+
+void TlsConnectTestBase::CheckEarlyDataLimit(
+ const std::shared_ptr<TlsAgent>& agent, size_t expected_size) {
+ SSLPreliminaryChannelInfo preinfo;
+ SECStatus rv =
+ SSL_GetPreliminaryChannelInfo(agent->ssl_fd(), &preinfo, sizeof(preinfo));
+ EXPECT_EQ(SECSuccess, rv);
+ EXPECT_EQ(expected_size, static_cast<size_t>(preinfo.maxEarlyDataSize));
+}
+
+void TlsConnectTestBase::CheckKeys(SSLKEAType kea_type, SSLNamedGroup kea_group,
+ SSLAuthType auth_type,
+ SSLSignatureScheme sig_scheme) const {
+ if (kea_group != ssl_grp_none) {
+ client_->CheckKEA(kea_type, kea_group);
+ server_->CheckKEA(kea_type, kea_group);
+ }
+ server_->CheckAuthType(auth_type, sig_scheme);
+ client_->CheckAuthType(auth_type, sig_scheme);
+}
+
+void TlsConnectTestBase::CheckKeys(SSLKEAType kea_type,
+ SSLAuthType auth_type) const {
+ SSLNamedGroup group;
+ switch (kea_type) {
+ case ssl_kea_ecdh:
+ group = ssl_grp_ec_curve25519;
+ break;
+ case ssl_kea_dh:
+ group = ssl_grp_ffdhe_2048;
+ break;
+ case ssl_kea_rsa:
+ group = ssl_grp_none;
+ break;
+ default:
+ EXPECT_TRUE(false) << "unexpected KEA";
+ group = ssl_grp_none;
+ break;
+ }
+
+ SSLSignatureScheme scheme;
+ switch (auth_type) {
+ case ssl_auth_rsa_decrypt:
+ scheme = ssl_sig_none;
+ break;
+ case ssl_auth_rsa_sign:
+ if (version_ >= SSL_LIBRARY_VERSION_TLS_1_2) {
+ scheme = ssl_sig_rsa_pss_rsae_sha256;
+ } else {
+ scheme = ssl_sig_rsa_pkcs1_sha256;
+ }
+ break;
+ case ssl_auth_rsa_pss:
+ scheme = ssl_sig_rsa_pss_rsae_sha256;
+ break;
+ case ssl_auth_ecdsa:
+ scheme = ssl_sig_ecdsa_secp256r1_sha256;
+ break;
+ case ssl_auth_dsa:
+ scheme = ssl_sig_dsa_sha1;
+ break;
+ default:
+ EXPECT_TRUE(false) << "unexpected auth type";
+ scheme = static_cast<SSLSignatureScheme>(0x0100);
+ break;
+ }
+ CheckKeys(kea_type, group, auth_type, scheme);
+}
+
+void TlsConnectTestBase::CheckKeys() const {
+ CheckKeys(ssl_kea_ecdh, ssl_auth_rsa_sign);
+}
+
+void TlsConnectTestBase::CheckKeysResumption(SSLKEAType kea_type,
+ SSLNamedGroup kea_group,
+ SSLNamedGroup original_kea_group,
+ SSLAuthType auth_type,
+ SSLSignatureScheme sig_scheme) {
+ CheckKeys(kea_type, kea_group, auth_type, sig_scheme);
+ EXPECT_TRUE(expected_resumption_mode_ != RESUME_NONE);
+ client_->CheckOriginalKEA(original_kea_group);
+ server_->CheckOriginalKEA(original_kea_group);
+}
+
+void TlsConnectTestBase::ConnectExpectFail() {
+ StartConnect();
+ Handshake();
+ ASSERT_EQ(TlsAgent::STATE_ERROR, client_->state());
+ ASSERT_EQ(TlsAgent::STATE_ERROR, server_->state());
+}
+
+void TlsConnectTestBase::ExpectAlert(std::shared_ptr<TlsAgent>& sender,
+ uint8_t alert) {
+ EnsureTlsSetup();
+ auto receiver = (sender == client_) ? server_ : client_;
+ sender->ExpectSendAlert(alert);
+ receiver->ExpectReceiveAlert(alert);
+}
+
+void TlsConnectTestBase::ConnectExpectAlert(std::shared_ptr<TlsAgent>& sender,
+ uint8_t alert) {
+ ExpectAlert(sender, alert);
+ ConnectExpectFail();
+}
+
+void TlsConnectTestBase::ConnectExpectFailOneSide(TlsAgent::Role failing_side) {
+ StartConnect();
+ client_->SetServerKeyBits(server_->server_key_bits());
+ client_->Handshake();
+ server_->Handshake();
+
+ auto failing_agent = server_;
+ if (failing_side == TlsAgent::CLIENT) {
+ failing_agent = client_;
+ }
+ ASSERT_TRUE_WAIT(failing_agent->state() == TlsAgent::STATE_ERROR, 5000);
+}
+
+void TlsConnectTestBase::ConfigureVersion(uint16_t version) {
+ version_ = version;
+ client_->SetVersionRange(version, version);
+ server_->SetVersionRange(version, version);
+}
+
+void TlsConnectTestBase::SetExpectedVersion(uint16_t version) {
+ client_->SetExpectedVersion(version);
+ server_->SetExpectedVersion(version);
+}
+
+void TlsConnectTestBase::AddPsk(const ScopedPK11SymKey& psk, std::string label,
+ SSLHashType hash, uint16_t zeroRttSuite) {
+ client_->AddPsk(psk, label, hash, zeroRttSuite);
+ server_->AddPsk(psk, label, hash, zeroRttSuite);
+ client_->ExpectPsk();
+ server_->ExpectPsk();
+}
+
+void TlsConnectTestBase::DisableAllCiphers() {
+ EnsureTlsSetup();
+ client_->DisableAllCiphers();
+ server_->DisableAllCiphers();
+}
+
+void TlsConnectTestBase::EnableOnlyStaticRsaCiphers() {
+ DisableAllCiphers();
+
+ client_->EnableCiphersByKeyExchange(ssl_kea_rsa);
+ server_->EnableCiphersByKeyExchange(ssl_kea_rsa);
+}
+
+void TlsConnectTestBase::EnableOnlyDheCiphers() {
+ if (version_ < SSL_LIBRARY_VERSION_TLS_1_3) {
+ DisableAllCiphers();
+ client_->EnableCiphersByKeyExchange(ssl_kea_dh);
+ server_->EnableCiphersByKeyExchange(ssl_kea_dh);
+ } else {
+ client_->ConfigNamedGroups(kFFDHEGroups);
+ server_->ConfigNamedGroups(kFFDHEGroups);
+ }
+}
+
+void TlsConnectTestBase::EnableSomeEcdhCiphers() {
+ if (version_ < SSL_LIBRARY_VERSION_TLS_1_3) {
+ client_->EnableCiphersByAuthType(ssl_auth_ecdh_rsa);
+ client_->EnableCiphersByAuthType(ssl_auth_ecdh_ecdsa);
+ server_->EnableCiphersByAuthType(ssl_auth_ecdh_rsa);
+ server_->EnableCiphersByAuthType(ssl_auth_ecdh_ecdsa);
+ } else {
+ client_->ConfigNamedGroups(kECDHEGroups);
+ server_->ConfigNamedGroups(kECDHEGroups);
+ }
+}
+
+void TlsConnectTestBase::ConfigureSelfEncrypt() {
+ ScopedCERTCertificate cert;
+ ScopedSECKEYPrivateKey privKey;
+ ASSERT_TRUE(
+ TlsAgent::LoadCertificate(TlsAgent::kServerRsaDecrypt, &cert, &privKey));
+
+ ScopedSECKEYPublicKey pubKey(CERT_ExtractPublicKey(cert.get()));
+ ASSERT_TRUE(pubKey);
+
+ EXPECT_EQ(SECSuccess,
+ SSL_SetSessionTicketKeyPair(pubKey.get(), privKey.get()));
+}
+
+void TlsConnectTestBase::ConfigureSessionCache(SessionResumptionMode client,
+ SessionResumptionMode server) {
+ client_->ConfigureSessionCache(client);
+ server_->ConfigureSessionCache(server);
+ if ((server & RESUME_TICKET) != 0) {
+ ConfigureSelfEncrypt();
+ }
+}
+
+void TlsConnectTestBase::CheckResumption(SessionResumptionMode expected) {
+ EXPECT_NE(RESUME_BOTH, expected);
+
+ int resume_count = expected ? expected_resumptions_ : 0;
+ int stateless_count = (expected & RESUME_TICKET) ? expected_resumptions_ : 0;
+
+ // Note: hch == server counter; hsh == client counter.
+ SSL3Statistics* stats = SSL_GetStatistics();
+ EXPECT_EQ(resume_count, stats->hch_sid_cache_hits);
+ EXPECT_EQ(resume_count, stats->hsh_sid_cache_hits);
+
+ EXPECT_EQ(stateless_count, stats->hch_sid_stateless_resumes);
+ EXPECT_EQ(stateless_count, stats->hsh_sid_stateless_resumes);
+
+ if (expected != RESUME_NONE) {
+ if (client_->version() < SSL_LIBRARY_VERSION_TLS_1_3 &&
+ client_->GetResumptionToken().size() == 0) {
+ // Check that the last two session ids match.
+ ASSERT_EQ(1U + expected_resumptions_, session_ids_.size());
+ EXPECT_EQ(session_ids_[session_ids_.size() - 1],
+ session_ids_[session_ids_.size() - 2]);
+ } else {
+ // We've either chosen TLS 1.3 or are using an external resumption token,
+ // both of which only use tickets.
+ EXPECT_TRUE(expected & RESUME_TICKET);
+ }
+ }
+}
+
+static SECStatus NextProtoCallbackServer(void* arg, PRFileDesc* fd,
+ const unsigned char* protos,
+ unsigned int protos_len,
+ unsigned char* protoOut,
+ unsigned int* protoOutLen,
+ unsigned int protoMaxLen) {
+ EXPECT_EQ(protoMaxLen, 255U);
+ TlsAgent* agent = reinterpret_cast<TlsAgent*>(arg);
+ // Check that agent->alpn_value_to_use_ is in protos.
+ if (protos_len < 1) {
+ return SECFailure;
+ }
+ for (size_t i = 0; i < protos_len;) {
+ size_t l = protos[i];
+ EXPECT_LT(i + l, protos_len);
+ if (i + l >= protos_len) {
+ return SECFailure;
+ }
+ std::string protos_s(reinterpret_cast<const char*>(protos + i + 1), l);
+ if (protos_s == agent->alpn_value_to_use_) {
+ size_t s_len = agent->alpn_value_to_use_.size();
+ EXPECT_LE(s_len, 255U);
+ memcpy(protoOut, &agent->alpn_value_to_use_[0], s_len);
+ *protoOutLen = s_len;
+ return SECSuccess;
+ }
+ i += l + 1;
+ }
+ return SECFailure;
+}
+
+void TlsConnectTestBase::EnableAlpn() {
+ client_->EnableAlpn(alpn_dummy_val_, sizeof(alpn_dummy_val_));
+ server_->EnableAlpn(alpn_dummy_val_, sizeof(alpn_dummy_val_));
+}
+
+void TlsConnectTestBase::EnableAlpnWithCallback(
+ const std::vector<uint8_t>& client_vals, std::string server_choice) {
+ EnsureTlsSetup();
+ server_->alpn_value_to_use_ = server_choice;
+ EXPECT_EQ(SECSuccess,
+ SSL_SetNextProtoNego(client_->ssl_fd(), client_vals.data(),
+ client_vals.size()));
+ SECStatus rv = SSL_SetNextProtoCallback(
+ server_->ssl_fd(), NextProtoCallbackServer, server_.get());
+ EXPECT_EQ(SECSuccess, rv);
+}
+
+void TlsConnectTestBase::EnableAlpn(const std::vector<uint8_t>& vals) {
+ client_->EnableAlpn(vals.data(), vals.size());
+ server_->EnableAlpn(vals.data(), vals.size());
+}
+
+void TlsConnectTestBase::EnsureModelSockets() {
+ // Make sure models agents are available.
+ if (!client_model_) {
+ ASSERT_EQ(server_model_, nullptr);
+ client_model_.reset(
+ new TlsAgent(TlsAgent::kClient, TlsAgent::CLIENT, variant_));
+ server_model_.reset(
+ new TlsAgent(TlsAgent::kServerRsa, TlsAgent::SERVER, variant_));
+ if (skip_version_checks_) {
+ client_model_->SkipVersionChecks();
+ server_model_->SkipVersionChecks();
+ }
+ }
+}
+
+void TlsConnectTestBase::CheckAlpn(const std::string& val) {
+ client_->CheckAlpn(SSL_NEXT_PROTO_SELECTED, val);
+ server_->CheckAlpn(SSL_NEXT_PROTO_NEGOTIATED, val);
+}
+
+void TlsConnectTestBase::EnableSrtp() {
+ client_->EnableSrtp();
+ server_->EnableSrtp();
+}
+
+void TlsConnectTestBase::CheckSrtp() const {
+ client_->CheckSrtp();
+ server_->CheckSrtp();
+}
+
+void TlsConnectTestBase::SendReceive(size_t total) {
+ ASSERT_GT(total, client_->received_bytes());
+ ASSERT_GT(total, server_->received_bytes());
+ client_->SendData(total - server_->received_bytes());
+ server_->SendData(total - client_->received_bytes());
+ Receive(total); // Receive() is cumulative
+}
+
+// Do a first connection so we can do 0-RTT on the second one.
+void TlsConnectTestBase::SetupForZeroRtt() {
+ // Force rollover of the anti-replay window.
+ // If we don't do this, then all 0-RTT attempts will be rejected.
+ RolloverAntiReplay();
+
+ ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET);
+ ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3);
+ server_->Set0RttEnabled(true); // So we signal that we allow 0-RTT.
+ Connect();
+ SendReceive(); // Need to read so that we absorb the session ticket.
+ CheckKeys();
+
+ Reset();
+ StartConnect();
+}
+
+// Do a first connection so we can do resumption
+void TlsConnectTestBase::SetupForResume() {
+ EnsureTlsSetup();
+ ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET);
+ Connect();
+ SendReceive(); // Need to read so that we absorb the session ticket.
+ CheckKeys();
+
+ Reset();
+}
+
+void TlsConnectTestBase::ZeroRttSendReceive(
+ bool expect_writable, bool expect_readable,
+ std::function<bool()> post_clienthello_check) {
+ const char* k0RttData = "ABCDEF";
+ const PRInt32 k0RttDataLen = static_cast<PRInt32>(strlen(k0RttData));
+
+ client_->Handshake(); // Send ClientHello.
+ if (post_clienthello_check) {
+ if (!post_clienthello_check()) return;
+ }
+ PRInt32 rv =
+ PR_Write(client_->ssl_fd(), k0RttData, k0RttDataLen); // 0-RTT write.
+ if (expect_writable) {
+ EXPECT_EQ(k0RttDataLen, rv);
+ } else {
+ EXPECT_EQ(SECFailure, rv);
+ }
+ server_->Handshake(); // Consume ClientHello
+
+ std::vector<uint8_t> buf(k0RttDataLen);
+ rv = PR_Read(server_->ssl_fd(), buf.data(), k0RttDataLen); // 0-RTT read
+ if (expect_readable) {
+ std::cerr << "0-RTT read " << rv << " bytes\n";
+ EXPECT_EQ(k0RttDataLen, rv);
+ } else {
+ EXPECT_EQ(SECFailure, rv);
+ EXPECT_EQ(PR_WOULD_BLOCK_ERROR, PORT_GetError())
+ << "Unexpected error: " << PORT_ErrorToName(PORT_GetError());
+ }
+
+ // Do a second read. This should fail.
+ rv = PR_Read(server_->ssl_fd(), buf.data(), k0RttDataLen);
+ EXPECT_EQ(SECFailure, rv);
+ EXPECT_EQ(PR_WOULD_BLOCK_ERROR, PORT_GetError());
+}
+
+void TlsConnectTestBase::Receive(size_t amount) {
+ WAIT_(client_->received_bytes() == amount &&
+ server_->received_bytes() == amount,
+ 2000);
+ ASSERT_EQ(amount, client_->received_bytes());
+ ASSERT_EQ(amount, server_->received_bytes());
+}
+
+void TlsConnectTestBase::ExpectExtendedMasterSecret(bool expected) {
+ expect_extended_master_secret_ = expected;
+}
+
+void TlsConnectTestBase::CheckExtendedMasterSecret() {
+ client_->CheckExtendedMasterSecret(expect_extended_master_secret_);
+ server_->CheckExtendedMasterSecret(expect_extended_master_secret_);
+}
+
+void TlsConnectTestBase::ExpectEarlyDataAccepted(bool expected) {
+ expect_early_data_accepted_ = expected;
+}
+
+void TlsConnectTestBase::CheckEarlyDataAccepted() {
+ client_->CheckEarlyDataAccepted(expect_early_data_accepted_);
+ server_->CheckEarlyDataAccepted(expect_early_data_accepted_);
+}
+
+void TlsConnectTestBase::EnableECDHEServerKeyReuse() {
+ server_->EnableECDHEServerKeyReuse();
+}
+
+void TlsConnectTestBase::SkipVersionChecks() {
+ skip_version_checks_ = true;
+ client_->SkipVersionChecks();
+ server_->SkipVersionChecks();
+}
+
+// Shift the DTLS timers, to the minimum time necessary to let the next timer
+// run on either client or server. This allows tests to skip waiting without
+// having timers run out of order.
+void TlsConnectTestBase::ShiftDtlsTimers() {
+ PRIntervalTime time_shift = PR_INTERVAL_NO_TIMEOUT;
+ PRIntervalTime time;
+ SECStatus rv = DTLS_GetHandshakeTimeout(client_->ssl_fd(), &time);
+ if (rv == SECSuccess) {
+ time_shift = time;
+ }
+ rv = DTLS_GetHandshakeTimeout(server_->ssl_fd(), &time);
+ if (rv == SECSuccess &&
+ (time < time_shift || time_shift == PR_INTERVAL_NO_TIMEOUT)) {
+ time_shift = time;
+ }
+
+ if (time_shift != PR_INTERVAL_NO_TIMEOUT) {
+ AdvanceTime(PR_IntervalToMicroseconds(time_shift));
+ EXPECT_EQ(SECSuccess,
+ SSLInt_ShiftDtlsTimers(client_->ssl_fd(), time_shift));
+ EXPECT_EQ(SECSuccess,
+ SSLInt_ShiftDtlsTimers(server_->ssl_fd(), time_shift));
+ }
+}
+
+void TlsConnectTestBase::AdvanceTime(PRTime time_shift) { now_ += time_shift; }
+
+// Advance time by a full anti-replay window.
+void TlsConnectTestBase::RolloverAntiReplay() {
+ AdvanceTime(kAntiReplayWindow);
+}
+
+TlsConnectGeneric::TlsConnectGeneric()
+ : TlsConnectTestBase(std::get<0>(GetParam()), std::get<1>(GetParam())) {}
+
+TlsConnectPre12::TlsConnectPre12()
+ : TlsConnectTestBase(std::get<0>(GetParam()), std::get<1>(GetParam())) {}
+
+TlsConnectTls12::TlsConnectTls12()
+ : TlsConnectTestBase(GetParam(), SSL_LIBRARY_VERSION_TLS_1_2) {}
+
+TlsConnectTls12Plus::TlsConnectTls12Plus()
+ : TlsConnectTestBase(std::get<0>(GetParam()), std::get<1>(GetParam())) {}
+
+TlsConnectTls13::TlsConnectTls13()
+ : TlsConnectTestBase(GetParam(), SSL_LIBRARY_VERSION_TLS_1_3) {}
+
+TlsConnectGenericResumption::TlsConnectGenericResumption()
+ : TlsConnectTestBase(std::get<0>(GetParam()), std::get<1>(GetParam())),
+ external_cache_(std::get<2>(GetParam())) {}
+
+TlsConnectTls13ResumptionToken::TlsConnectTls13ResumptionToken()
+ : TlsConnectTestBase(GetParam(), SSL_LIBRARY_VERSION_TLS_1_3) {}
+
+TlsConnectGenericResumptionToken::TlsConnectGenericResumptionToken()
+ : TlsConnectTestBase(std::get<0>(GetParam()), std::get<1>(GetParam())) {}
+
+void TlsKeyExchangeTest::EnsureKeyShareSetup() {
+ EnsureTlsSetup();
+ groups_capture_ =
+ std::make_shared<TlsExtensionCapture>(client_, ssl_supported_groups_xtn);
+ shares_capture_ =
+ std::make_shared<TlsExtensionCapture>(client_, ssl_tls13_key_share_xtn);
+ shares_capture2_ = std::make_shared<TlsExtensionCapture>(
+ client_, ssl_tls13_key_share_xtn, true);
+ std::vector<std::shared_ptr<PacketFilter>> captures = {
+ groups_capture_, shares_capture_, shares_capture2_};
+ client_->SetFilter(std::make_shared<ChainedPacketFilter>(captures));
+ capture_hrr_ = MakeTlsFilter<TlsHandshakeRecorder>(
+ server_, kTlsHandshakeHelloRetryRequest);
+}
+
+void TlsKeyExchangeTest::ConfigNamedGroups(
+ const std::vector<SSLNamedGroup>& groups) {
+ client_->ConfigNamedGroups(groups);
+ server_->ConfigNamedGroups(groups);
+}
+
+std::vector<SSLNamedGroup> TlsKeyExchangeTest::GetGroupDetails(
+ const std::shared_ptr<TlsExtensionCapture>& capture) {
+ EXPECT_TRUE(capture->captured());
+ const DataBuffer& ext = capture->extension();
+
+ uint32_t tmp = 0;
+ EXPECT_TRUE(ext.Read(0, 2, &tmp));
+ EXPECT_EQ(ext.len() - 2, static_cast<size_t>(tmp));
+ EXPECT_TRUE(ext.len() % 2 == 0);
+
+ std::vector<SSLNamedGroup> groups;
+ for (size_t i = 1; i < ext.len() / 2; i += 1) {
+ EXPECT_TRUE(ext.Read(2 * i, 2, &tmp));
+ groups.push_back(static_cast<SSLNamedGroup>(tmp));
+ }
+ return groups;
+}
+
+std::vector<SSLNamedGroup> TlsKeyExchangeTest::GetShareDetails(
+ const std::shared_ptr<TlsExtensionCapture>& capture) {
+ EXPECT_TRUE(capture->captured());
+ const DataBuffer& ext = capture->extension();
+
+ uint32_t tmp = 0;
+ EXPECT_TRUE(ext.Read(0, 2, &tmp));
+ EXPECT_EQ(ext.len() - 2, static_cast<size_t>(tmp));
+
+ std::vector<SSLNamedGroup> shares;
+ size_t i = 2;
+ while (i < ext.len()) {
+ EXPECT_TRUE(ext.Read(i, 2, &tmp));
+ shares.push_back(static_cast<SSLNamedGroup>(tmp));
+ EXPECT_TRUE(ext.Read(i + 2, 2, &tmp));
+ i += 4 + tmp;
+ }
+ EXPECT_EQ(ext.len(), i);
+ return shares;
+}
+
+void TlsKeyExchangeTest::CheckKEXDetails(
+ const std::vector<SSLNamedGroup>& expected_groups,
+ const std::vector<SSLNamedGroup>& expected_shares, bool expect_hrr) {
+ std::vector<SSLNamedGroup> groups = GetGroupDetails(groups_capture_);
+ EXPECT_EQ(expected_groups, groups);
+
+ if (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) {
+ ASSERT_LT(0U, expected_shares.size());
+ std::vector<SSLNamedGroup> shares = GetShareDetails(shares_capture_);
+ EXPECT_EQ(expected_shares, shares);
+ } else {
+ EXPECT_FALSE(shares_capture_->captured());
+ }
+
+ EXPECT_EQ(expect_hrr, capture_hrr_->buffer().len() != 0);
+}
+
+void TlsKeyExchangeTest::CheckKEXDetails(
+ const std::vector<SSLNamedGroup>& expected_groups,
+ const std::vector<SSLNamedGroup>& expected_shares) {
+ CheckKEXDetails(expected_groups, expected_shares, false);
+}
+
+void TlsKeyExchangeTest::CheckKEXDetails(
+ const std::vector<SSLNamedGroup>& expected_groups,
+ const std::vector<SSLNamedGroup>& expected_shares,
+ SSLNamedGroup expected_share2) {
+ CheckKEXDetails(expected_groups, expected_shares, true);
+
+ for (auto it : expected_shares) {
+ EXPECT_NE(expected_share2, it);
+ }
+ std::vector<SSLNamedGroup> expected_shares2 = {expected_share2};
+ EXPECT_EQ(expected_shares2, GetShareDetails(shares_capture2_));
+}
+} // namespace nss_test
diff --git a/security/nss/gtests/ssl_gtest/tls_connect.h b/security/nss/gtests/ssl_gtest/tls_connect.h
new file mode 100644
index 0000000000..6a4795f83e
--- /dev/null
+++ b/security/nss/gtests/ssl_gtest/tls_connect.h
@@ -0,0 +1,390 @@
+/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
+/* vim: set ts=2 et sw=2 tw=80: */
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this file,
+ * You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+#ifndef tls_connect_h_
+#define tls_connect_h_
+
+#include <tuple>
+
+#include "sslproto.h"
+#include "sslt.h"
+#include "nss.h"
+
+#include "tls_agent.h"
+#include "tls_filter.h"
+
+#define GTEST_HAS_RTTI 0
+#include "gtest/gtest.h"
+
+namespace nss_test {
+
+extern std::string VersionString(uint16_t version);
+
+// A generic TLS connection test base.
+class TlsConnectTestBase : public ::testing::Test {
+ public:
+ static ::testing::internal::ParamGenerator<SSLProtocolVariant>
+ kTlsVariantsStream;
+ static ::testing::internal::ParamGenerator<SSLProtocolVariant>
+ kTlsVariantsDatagram;
+ static ::testing::internal::ParamGenerator<SSLProtocolVariant>
+ kTlsVariantsAll;
+ static ::testing::internal::ParamGenerator<uint16_t> kTlsV10;
+ static ::testing::internal::ParamGenerator<uint16_t> kTlsV11;
+ static ::testing::internal::ParamGenerator<uint16_t> kTlsV12;
+ static ::testing::internal::ParamGenerator<uint16_t> kTlsV10V11;
+ static ::testing::internal::ParamGenerator<uint16_t> kTlsV11V12;
+ static ::testing::internal::ParamGenerator<uint16_t> kTlsV10ToV12;
+ static ::testing::internal::ParamGenerator<uint16_t> kTlsV13;
+ static ::testing::internal::ParamGenerator<uint16_t> kTlsV11Plus;
+ static ::testing::internal::ParamGenerator<uint16_t> kTlsV12Plus;
+ static ::testing::internal::ParamGenerator<uint16_t> kTlsVAll;
+
+ TlsConnectTestBase(SSLProtocolVariant variant, uint16_t version);
+ virtual ~TlsConnectTestBase();
+
+ virtual void SetUp();
+ virtual void TearDown();
+
+ PRTime now() const { return now_; }
+
+ // Initialize client and server.
+ void Init();
+ // Clear the statistics.
+ void ClearStats();
+ // Clear the server session cache.
+ void ClearServerCache();
+ // Make sure TLS is configured for a connection.
+ virtual void EnsureTlsSetup();
+ // Reset and keep the same certificate names
+ void Reset();
+ // Reset, and update the certificate names on both peers
+ void Reset(const std::string& server_name,
+ const std::string& client_name = "client");
+ // Replace the server.
+ void MakeNewServer();
+
+ // Set up
+ void StartConnect();
+ // Run the handshake.
+ void Handshake();
+ // Connect and check that it works.
+ void Connect();
+ // Check that the connection was successfully established.
+ void CheckConnected();
+ // Connect and expect it to fail.
+ void ConnectExpectFail();
+ void ExpectAlert(std::shared_ptr<TlsAgent>& sender, uint8_t alert);
+ void ConnectExpectAlert(std::shared_ptr<TlsAgent>& sender, uint8_t alert);
+ void ConnectExpectFailOneSide(TlsAgent::Role failingSide);
+ void ConnectWithCipherSuite(uint16_t cipher_suite);
+ void CheckEarlyDataLimit(const std::shared_ptr<TlsAgent>& agent,
+ size_t expected_size);
+ // Check that the keys used in the handshake match expectations.
+ void CheckKeys(SSLKEAType kea_type, SSLNamedGroup kea_group,
+ SSLAuthType auth_type, SSLSignatureScheme sig_scheme) const;
+ // This version guesses some of the values.
+ void CheckKeys(SSLKEAType kea_type, SSLAuthType auth_type) const;
+ // This version assumes defaults.
+ void CheckKeys() const;
+ // Check that keys on resumed sessions.
+ void CheckKeysResumption(SSLKEAType kea_type, SSLNamedGroup kea_group,
+ SSLNamedGroup original_kea_group,
+ SSLAuthType auth_type,
+ SSLSignatureScheme sig_scheme);
+ void CheckGroups(const DataBuffer& groups,
+ std::function<void(SSLNamedGroup)> check_group);
+ void CheckShares(const DataBuffer& shares,
+ std::function<void(SSLNamedGroup)> check_group);
+ void CheckEpochs(uint16_t client_epoch, uint16_t server_epoch) const;
+
+ void ConfigureVersion(uint16_t version);
+ void SetExpectedVersion(uint16_t version);
+ // Expect resumption of a particular type.
+ void ExpectResumption(SessionResumptionMode expected,
+ uint8_t num_resumed = 1);
+ void DisableAllCiphers();
+ void EnableOnlyStaticRsaCiphers();
+ void EnableOnlyDheCiphers();
+ void EnableSomeEcdhCiphers();
+ void EnableExtendedMasterSecret();
+ void ConfigureSelfEncrypt();
+ void ConfigureSessionCache(SessionResumptionMode client,
+ SessionResumptionMode server);
+ void EnableAlpn();
+ void EnableAlpnWithCallback(const std::vector<uint8_t>& client,
+ std::string server_choice);
+ void EnableAlpn(const std::vector<uint8_t>& vals);
+ void EnsureModelSockets();
+ void CheckAlpn(const std::string& val);
+ void EnableSrtp();
+ void CheckSrtp() const;
+ void SendReceive(size_t total = 50);
+ void AddPsk(const ScopedPK11SymKey& psk, std::string label, SSLHashType hash,
+ uint16_t zeroRttSuite = TLS_NULL_WITH_NULL_NULL);
+ void RemovePsk(std::string label);
+ void SetupForZeroRtt();
+ void SetupForResume();
+ void ZeroRttSendReceive(
+ bool expect_writable, bool expect_readable,
+ std::function<bool()> post_clienthello_check = nullptr);
+ void Receive(size_t amount);
+ void ExpectExtendedMasterSecret(bool expected);
+ void ExpectEarlyDataAccepted(bool expected);
+ void EnableECDHEServerKeyReuse();
+ void SkipVersionChecks();
+
+ // Move the DTLS timers for both endpoints to pop the next timer.
+ void ShiftDtlsTimers();
+ void AdvanceTime(PRTime time_shift);
+
+ void ResetAntiReplay(PRTime window);
+ void RolloverAntiReplay();
+
+ void SaveAlgorithmPolicy();
+ void RestoreAlgorithmPolicy();
+
+ static ScopedSECItem MakeEcKeyParams(SSLNamedGroup group);
+ static void GenerateEchConfig(
+ HpkeKemId kem_id, const std::vector<HpkeSymmetricSuite>& cipher_suites,
+ const std::string& public_name, uint16_t max_name_len, DataBuffer& record,
+ ScopedSECKEYPublicKey& pubKey, ScopedSECKEYPrivateKey& privKey);
+ void SetupEch(std::shared_ptr<TlsAgent>& client,
+ std::shared_ptr<TlsAgent>& server,
+ HpkeKemId kem_id = HpkeDhKemX25519Sha256,
+ bool expect_ech = true, bool set_client_config = true,
+ bool set_server_config = true, int maxConfigSize = 100);
+
+ protected:
+ SSLProtocolVariant variant_;
+ std::shared_ptr<TlsAgent> client_;
+ std::shared_ptr<TlsAgent> server_;
+ std::unique_ptr<TlsAgent> client_model_;
+ std::unique_ptr<TlsAgent> server_model_;
+ uint16_t version_;
+ SessionResumptionMode expected_resumption_mode_;
+ uint8_t expected_resumptions_;
+ std::vector<std::vector<uint8_t>> session_ids_;
+ ScopedSSLAntiReplayContext anti_replay_;
+
+ // A simple value of "a", "b". Note that the preferred value of "a" is placed
+ // at the end, because the NSS API follows the now defunct NPN specification,
+ // which places the preferred (and default) entry at the end of the list.
+ // NSS will move this final entry to the front when used with ALPN.
+ const uint8_t alpn_dummy_val_[4] = {0x01, 0x62, 0x01, 0x61};
+
+ // A list of algorithm IDs whose policies need to be preserved
+ // around test cases. In particular, DSA is checked in
+ // ssl_extension_unittest.cc.
+ const std::vector<SECOidTag> algorithms_ = {SEC_OID_APPLY_SSL_POLICY,
+ SEC_OID_ANSIX9_DSA_SIGNATURE,
+ SEC_OID_CURVE25519, SEC_OID_SHA1};
+ std::vector<std::tuple<SECOidTag, uint32_t>> saved_policies_;
+ const std::vector<PRInt32> options_ = {
+ NSS_RSA_MIN_KEY_SIZE, NSS_DH_MIN_KEY_SIZE, NSS_DSA_MIN_KEY_SIZE,
+ NSS_TLS_VERSION_MIN_POLICY, NSS_TLS_VERSION_MAX_POLICY};
+ std::vector<std::tuple<PRInt32, uint32_t>> saved_options_;
+
+ private:
+ void CheckResumption(SessionResumptionMode expected);
+ void CheckExtendedMasterSecret();
+ void CheckEarlyDataAccepted();
+ static PRTime TimeFunc(void* arg);
+
+ bool expect_extended_master_secret_;
+ bool expect_early_data_accepted_;
+ bool skip_version_checks_;
+ PRTime now_;
+
+ // Track groups and make sure that there are no duplicates.
+ class DuplicateGroupChecker {
+ public:
+ void AddAndCheckGroup(SSLNamedGroup group) {
+ EXPECT_EQ(groups_.end(), groups_.find(group))
+ << "Group " << group << " should not be duplicated";
+ groups_.insert(group);
+ }
+
+ private:
+ std::set<SSLNamedGroup> groups_;
+ };
+};
+
+// A non-parametrized TLS test base.
+class TlsConnectTest : public TlsConnectTestBase {
+ public:
+ TlsConnectTest() : TlsConnectTestBase(ssl_variant_stream, 0) {}
+};
+
+// A non-parametrized DTLS-only test base.
+class DtlsConnectTest : public TlsConnectTestBase {
+ public:
+ DtlsConnectTest() : TlsConnectTestBase(ssl_variant_datagram, 0) {}
+};
+
+// A TLS-only test base.
+class TlsConnectStream : public TlsConnectTestBase,
+ public ::testing::WithParamInterface<uint16_t> {
+ public:
+ TlsConnectStream() : TlsConnectTestBase(ssl_variant_stream, GetParam()) {}
+};
+
+// A TLS-only test base for tests before 1.3
+class TlsConnectStreamPre13 : public TlsConnectStream {};
+
+// A DTLS-only test base.
+class TlsConnectDatagram : public TlsConnectTestBase,
+ public ::testing::WithParamInterface<uint16_t> {
+ public:
+ TlsConnectDatagram() : TlsConnectTestBase(ssl_variant_datagram, GetParam()) {}
+};
+
+// A generic test class that can be either stream or datagram and a single
+// version of TLS. This is configured in ssl_loopback_unittest.cc.
+class TlsConnectGeneric : public TlsConnectTestBase,
+ public ::testing::WithParamInterface<
+ std::tuple<SSLProtocolVariant, uint16_t>> {
+ public:
+ TlsConnectGeneric();
+};
+
+class TlsConnectGenericResumption
+ : public TlsConnectTestBase,
+ public ::testing::WithParamInterface<
+ std::tuple<SSLProtocolVariant, uint16_t, bool>> {
+ private:
+ bool external_cache_;
+
+ public:
+ TlsConnectGenericResumption();
+
+ virtual void EnsureTlsSetup() {
+ TlsConnectTestBase::EnsureTlsSetup();
+ // Enable external resumption token cache.
+ if (external_cache_) {
+ client_->SetResumptionTokenCallback();
+ }
+ }
+
+ bool use_external_cache() const { return external_cache_; }
+};
+
+class TlsConnectTls13ResumptionToken
+ : public TlsConnectTestBase,
+ public ::testing::WithParamInterface<SSLProtocolVariant> {
+ public:
+ TlsConnectTls13ResumptionToken();
+
+ virtual void EnsureTlsSetup() {
+ TlsConnectTestBase::EnsureTlsSetup();
+ client_->SetResumptionTokenCallback();
+ }
+};
+
+class TlsConnectGenericResumptionToken
+ : public TlsConnectTestBase,
+ public ::testing::WithParamInterface<
+ std::tuple<SSLProtocolVariant, uint16_t>> {
+ public:
+ TlsConnectGenericResumptionToken();
+
+ virtual void EnsureTlsSetup() {
+ TlsConnectTestBase::EnsureTlsSetup();
+ client_->SetResumptionTokenCallback();
+ }
+};
+
+// A Pre TLS 1.2 generic test.
+class TlsConnectPre12 : public TlsConnectTestBase,
+ public ::testing::WithParamInterface<
+ std::tuple<SSLProtocolVariant, uint16_t>> {
+ public:
+ TlsConnectPre12();
+};
+
+// A TLS 1.2 only generic test.
+class TlsConnectTls12
+ : public TlsConnectTestBase,
+ public ::testing::WithParamInterface<SSLProtocolVariant> {
+ public:
+ TlsConnectTls12();
+};
+
+// A TLS 1.2 only stream test.
+class TlsConnectStreamTls12 : public TlsConnectTestBase {
+ public:
+ TlsConnectStreamTls12()
+ : TlsConnectTestBase(ssl_variant_stream, SSL_LIBRARY_VERSION_TLS_1_2) {}
+};
+
+// A TLS 1.2+ generic test.
+class TlsConnectTls12Plus : public TlsConnectTestBase,
+ public ::testing::WithParamInterface<
+ std::tuple<SSLProtocolVariant, uint16_t>> {
+ public:
+ TlsConnectTls12Plus();
+};
+
+// A TLS 1.3 only generic test.
+class TlsConnectTls13
+ : public TlsConnectTestBase,
+ public ::testing::WithParamInterface<SSLProtocolVariant> {
+ public:
+ TlsConnectTls13();
+};
+
+// A TLS 1.3 only stream test.
+class TlsConnectStreamTls13 : public TlsConnectTestBase {
+ public:
+ TlsConnectStreamTls13()
+ : TlsConnectTestBase(ssl_variant_stream, SSL_LIBRARY_VERSION_TLS_1_3) {}
+};
+
+class TlsConnectDatagram13 : public TlsConnectTestBase {
+ public:
+ TlsConnectDatagram13()
+ : TlsConnectTestBase(ssl_variant_datagram, SSL_LIBRARY_VERSION_TLS_1_3) {}
+};
+
+class TlsConnectDatagramPre13 : public TlsConnectDatagram {
+ public:
+ TlsConnectDatagramPre13() {}
+};
+
+// A variant that is used only with Pre13.
+class TlsConnectGenericPre13 : public TlsConnectGeneric {};
+
+class TlsKeyExchangeTest : public TlsConnectGeneric {
+ protected:
+ std::shared_ptr<TlsExtensionCapture> groups_capture_;
+ std::shared_ptr<TlsExtensionCapture> shares_capture_;
+ std::shared_ptr<TlsExtensionCapture> shares_capture2_;
+ std::shared_ptr<TlsHandshakeRecorder> capture_hrr_;
+
+ void EnsureKeyShareSetup();
+ void ConfigNamedGroups(const std::vector<SSLNamedGroup>& groups);
+ std::vector<SSLNamedGroup> GetGroupDetails(
+ const std::shared_ptr<TlsExtensionCapture>& capture);
+ std::vector<SSLNamedGroup> GetShareDetails(
+ const std::shared_ptr<TlsExtensionCapture>& capture);
+ void CheckKEXDetails(const std::vector<SSLNamedGroup>& expectedGroups,
+ const std::vector<SSLNamedGroup>& expectedShares);
+ void CheckKEXDetails(const std::vector<SSLNamedGroup>& expectedGroups,
+ const std::vector<SSLNamedGroup>& expectedShares,
+ SSLNamedGroup expectedShare2);
+
+ private:
+ void CheckKEXDetails(const std::vector<SSLNamedGroup>& expectedGroups,
+ const std::vector<SSLNamedGroup>& expectedShares,
+ bool expect_hrr);
+};
+
+class TlsKeyExchangeTest13 : public TlsKeyExchangeTest {};
+class TlsKeyExchangeTestPre13 : public TlsKeyExchangeTest {};
+
+} // namespace nss_test
+
+#endif
diff --git a/security/nss/gtests/ssl_gtest/tls_ech_unittest.cc b/security/nss/gtests/ssl_gtest/tls_ech_unittest.cc
new file mode 100644
index 0000000000..1e70a6ee59
--- /dev/null
+++ b/security/nss/gtests/ssl_gtest/tls_ech_unittest.cc
@@ -0,0 +1,2913 @@
+/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
+/* vim: set ts=2 et sw=2 tw=80: */
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this file,
+ * You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+#include "secerr.h"
+#include "ssl.h"
+
+#include "gtest_utils.h"
+#include "pk11pub.h"
+#include "tls_agent.h"
+#include "tls_connect.h"
+#include "util.h"
+#include "tls13ech.h"
+
+namespace nss_test {
+
+class TlsAgentEchTest : public TlsAgentTestClient13 {
+ protected:
+ void InstallEchConfig(const DataBuffer& echconfig, PRErrorCode err = 0) {
+ SECStatus rv = SSL_SetClientEchConfigs(agent_->ssl_fd(), echconfig.data(),
+ echconfig.len());
+ if (err == 0) {
+ ASSERT_EQ(SECSuccess, rv);
+ } else {
+ ASSERT_EQ(SECFailure, rv);
+ ASSERT_EQ(err, PORT_GetError());
+ }
+ }
+};
+
+#include "cpputil.h" // Unused function error if included without HPKE.
+
+static std::string kPublicName("public.name");
+
+static const std::vector<HpkeSymmetricSuite> kDefaultSuites = {
+ {HpkeKdfHkdfSha256, HpkeAeadChaCha20Poly1305},
+ {HpkeKdfHkdfSha256, HpkeAeadAes128Gcm}};
+static const std::vector<HpkeSymmetricSuite> kSuiteChaCha = {
+ {HpkeKdfHkdfSha256, HpkeAeadChaCha20Poly1305}};
+static const std::vector<HpkeSymmetricSuite> kSuiteAes = {
+ {HpkeKdfHkdfSha256, HpkeAeadAes128Gcm}};
+std::vector<HpkeSymmetricSuite> kBogusSuite = {
+ {static_cast<HpkeKdfId>(0xfefe), static_cast<HpkeAeadId>(0xfefe)}};
+static const std::vector<HpkeSymmetricSuite> kUnknownFirstSuite = {
+ {static_cast<HpkeKdfId>(0xfefe), static_cast<HpkeAeadId>(0xfefe)},
+ {HpkeKdfHkdfSha256, HpkeAeadAes128Gcm}};
+
+class TlsConnectStreamTls13Ech : public TlsConnectTestBase {
+ public:
+ TlsConnectStreamTls13Ech()
+ : TlsConnectTestBase(ssl_variant_stream, SSL_LIBRARY_VERSION_TLS_1_3) {}
+
+ void ReplayChWithMalformedInner(const std::string& ch, uint8_t server_alert,
+ uint32_t server_code, uint32_t client_code) {
+ std::vector<uint8_t> ch_vec = hex_string_to_bytes(ch);
+ DataBuffer ch_buf;
+ ScopedSECKEYPublicKey pub;
+ ScopedSECKEYPrivateKey priv;
+ EnsureTlsSetup();
+ ImportFixedEchKeypair(pub, priv);
+ SetMutualEchConfigs(pub, priv);
+
+ TlsAgentTestBase::MakeRecord(variant_, ssl_ct_handshake,
+ SSL_LIBRARY_VERSION_TLS_1_3, ch_vec.data(),
+ ch_vec.size(), &ch_buf, 0);
+ StartConnect();
+ client_->SendDirect(ch_buf);
+ ExpectAlert(server_, server_alert);
+ server_->Handshake();
+ server_->CheckErrorCode(server_code);
+ client_->ExpectReceiveAlert(server_alert, kTlsAlertFatal);
+ client_->Handshake();
+ client_->CheckErrorCode(client_code);
+ }
+
+ // Setup Client/Server with mismatched AEADs
+ void SetupForEchRetry() {
+ ScopedSECKEYPublicKey server_pub;
+ ScopedSECKEYPrivateKey server_priv;
+ ScopedSECKEYPublicKey client_pub;
+ ScopedSECKEYPrivateKey client_priv;
+ DataBuffer server_rec;
+ DataBuffer client_rec;
+ TlsConnectTestBase::GenerateEchConfig(HpkeDhKemX25519Sha256, kSuiteChaCha,
+ kPublicName, 100, server_rec,
+ server_pub, server_priv);
+ ASSERT_EQ(SECSuccess,
+ SSL_SetServerEchConfigs(server_->ssl_fd(), server_pub.get(),
+ server_priv.get(), server_rec.data(),
+ server_rec.len()));
+
+ TlsConnectTestBase::GenerateEchConfig(HpkeDhKemX25519Sha256, kSuiteAes,
+ kPublicName, 100, client_rec,
+ client_pub, client_priv);
+ ASSERT_EQ(SECSuccess,
+ SSL_SetClientEchConfigs(client_->ssl_fd(), client_rec.data(),
+ client_rec.len()));
+ }
+
+ // Parse a captured SNI extension and validate the contained name.
+ void CheckSniExtension(const DataBuffer& data,
+ const std::string expected_name) {
+ TlsParser parser(data.data(), data.len());
+ uint32_t tmp;
+ ASSERT_TRUE(parser.Read(&tmp, 2));
+ ASSERT_EQ(parser.remaining(), tmp);
+ ASSERT_TRUE(parser.Read(&tmp, 1));
+ ASSERT_EQ(0U, tmp); /* sni_nametype_hostname */
+ DataBuffer name;
+ ASSERT_TRUE(parser.ReadVariable(&name, 2));
+ ASSERT_EQ(0U, parser.remaining());
+ // Manual comparison to silence coverity false-positives.
+ ASSERT_EQ(name.len(), kPublicName.length());
+ ASSERT_EQ(0,
+ memcmp(kPublicName.c_str(), name.data(), kPublicName.length()));
+ }
+
+ void DoEchRetry(const ScopedSECKEYPublicKey& server_pub,
+ const ScopedSECKEYPrivateKey& server_priv,
+ const DataBuffer& server_rec) {
+ StackSECItem retry_configs;
+ ASSERT_EQ(SECSuccess,
+ SSL_GetEchRetryConfigs(client_->ssl_fd(), &retry_configs));
+ ASSERT_NE(0U, retry_configs.len);
+
+ // Reset expectations for the TlsAgent dtor.
+ server_->ExpectReceiveAlert(kTlsAlertCloseNotify, kTlsAlertWarning);
+ Reset();
+ EnsureTlsSetup();
+ ASSERT_EQ(SECSuccess,
+ SSL_SetServerEchConfigs(server_->ssl_fd(), server_pub.get(),
+ server_priv.get(), server_rec.data(),
+ server_rec.len()));
+ ASSERT_EQ(SECSuccess,
+ SSL_SetClientEchConfigs(client_->ssl_fd(), retry_configs.data,
+ retry_configs.len));
+ client_->ExpectEch();
+ server_->ExpectEch();
+ Connect();
+ }
+
+ void ImportFixedEchKeypair(ScopedSECKEYPublicKey& pub,
+ ScopedSECKEYPrivateKey& priv) {
+ ScopedPK11SlotInfo slot(PK11_GetInternalSlot());
+ if (!slot) {
+ ADD_FAILURE() << "No slot";
+ return;
+ }
+ std::vector<uint8_t> pkcs8_r = hex_string_to_bytes(kFixedServerKey);
+ SECItem pkcs8_r_item = {siBuffer, toUcharPtr(pkcs8_r.data()),
+ static_cast<unsigned int>(pkcs8_r.size())};
+
+ SECKEYPrivateKey* tmp_priv = nullptr;
+ ASSERT_EQ(SECSuccess, PK11_ImportDERPrivateKeyInfoAndReturnKey(
+ slot.get(), &pkcs8_r_item, nullptr, nullptr,
+ false, false, KU_ALL, &tmp_priv, nullptr));
+ priv.reset(tmp_priv);
+ SECKEYPublicKey* tmp_pub = SECKEY_ConvertToPublicKey(tmp_priv);
+ pub.reset(tmp_pub);
+ ASSERT_NE(nullptr, tmp_pub);
+ }
+
+ void SetMutualEchConfigs(ScopedSECKEYPublicKey& pub,
+ ScopedSECKEYPrivateKey& priv) {
+ DataBuffer echconfig;
+ TlsConnectTestBase::GenerateEchConfig(HpkeDhKemX25519Sha256, kDefaultSuites,
+ kPublicName, 100, echconfig, pub,
+ priv);
+ ASSERT_EQ(SECSuccess,
+ SSL_SetServerEchConfigs(server_->ssl_fd(), pub.get(), priv.get(),
+ echconfig.data(), echconfig.len()));
+ ASSERT_EQ(SECSuccess,
+ SSL_SetClientEchConfigs(client_->ssl_fd(), echconfig.data(),
+ echconfig.len()));
+ }
+
+ void ValidatePublicNames(const std::vector<std::string>& names,
+ SECStatus expected) {
+ static const std::vector<HpkeSymmetricSuite> kSuites = {
+ {HpkeKdfHkdfSha256, HpkeAeadAes128Gcm}};
+
+ ScopedSECItem ecParams = MakeEcKeyParams(ssl_grp_ec_curve25519);
+ ScopedSECKEYPublicKey pub;
+ ScopedSECKEYPrivateKey priv;
+ SECKEYPublicKey* pub_p = nullptr;
+ SECKEYPrivateKey* priv_p =
+ SECKEY_CreateECPrivateKey(ecParams.get(), &pub_p, nullptr);
+ pub.reset(pub_p);
+ priv.reset(priv_p);
+ ASSERT_TRUE(!!pub);
+ ASSERT_TRUE(!!priv);
+
+ EnsureTlsSetup();
+
+ DataBuffer cfg;
+ SECStatus rv;
+ for (auto name : names) {
+ if (g_ssl_gtest_verbose) {
+ std::cout << ((expected == SECFailure) ? "in" : "")
+ << "valid public_name: " << name << std::endl;
+ }
+ GenerateEchConfig(HpkeDhKemX25519Sha256, kSuites, name, 100, cfg, pub,
+ priv);
+
+ rv = SSL_SetServerEchConfigs(server_->ssl_fd(), pub.get(), priv.get(),
+ cfg.data(), cfg.len());
+ EXPECT_EQ(expected, rv);
+
+ rv = SSL_SetClientEchConfigs(client_->ssl_fd(), cfg.data(), cfg.len());
+ EXPECT_EQ(expected, rv);
+ }
+ }
+
+ private:
+ // Testing certan invalid CHInner configurations is tricky, particularly
+ // since the CHOuter forms AAD and isn't available in filters. Instead of
+ // generating these inputs on the fly, use a fixed server keypair so that
+ // the input can be generated once (e.g. via a debugger) and replayed in
+ // each invocation of the test.
+ std::string kFixedServerKey =
+ "3067020100301406072a8648ce3d020106092b06010401da470f01044c304a"
+ "02010104205a8aa0d2476b28521588e0c704b14db82cdd4970d340d293a957"
+ "6deaee9ec1c7a1230321008756e2580c07c1d2ffcb662f5fadc6d6ff13da85"
+ "abd7adfecf984aaa102c1269";
+};
+
+static void CheckCertVerifyPublicName(TlsAgent* agent) {
+ agent->UpdatePreliminaryChannelInfo();
+ EXPECT_NE(0U, (agent->pre_info().valuesSet & ssl_preinfo_ech));
+ EXPECT_EQ(agent->GetEchExpected(), agent->pre_info().echAccepted);
+
+ // Check that echPublicName is only exposed in the rejection
+ // case, so that the application can use it for CertVerfiy.
+ if (agent->GetEchExpected()) {
+ EXPECT_EQ(nullptr, agent->pre_info().echPublicName);
+ } else {
+ EXPECT_NE(nullptr, agent->pre_info().echPublicName);
+ if (agent->pre_info().echPublicName) {
+ EXPECT_EQ(0,
+ strcmp(kPublicName.c_str(), agent->pre_info().echPublicName));
+ }
+ }
+}
+
+static SECStatus AuthCompleteFail(TlsAgent* agent, PRBool, PRBool) {
+ CheckCertVerifyPublicName(agent);
+ return SECFailure;
+}
+
+// Given two EchConfigList structures, e.g. from GenerateEchConfig, construct
+// a single list containing all entries.
+static DataBuffer MakeEchConfigList(DataBuffer config1, DataBuffer config2) {
+ DataBuffer sizedConfigListBuffer;
+
+ sizedConfigListBuffer.Write(2, config1.data() + 2, config1.len() - 2);
+ sizedConfigListBuffer.Write(sizedConfigListBuffer.len(), config2.data() + 2,
+ config2.len() - 2);
+ sizedConfigListBuffer.Write(0, sizedConfigListBuffer.len() - 2, 2);
+
+ PR_ASSERT(sizedConfigListBuffer.len() == config1.len() + config2.len() - 2);
+ return sizedConfigListBuffer;
+}
+
+TEST_P(TlsAgentEchTest, EchConfigsSupportedYesNo) {
+ if (variant_ == ssl_variant_datagram) {
+ GTEST_SKIP();
+ }
+ ScopedSECKEYPublicKey pub;
+ ScopedSECKEYPrivateKey priv;
+ // ECHConfig 2 cipher_suites are unsupported.
+ DataBuffer config1;
+ TlsConnectTestBase::GenerateEchConfig(HpkeDhKemX25519Sha256, kSuiteAes,
+ kPublicName, 100, config1, pub, priv);
+ DataBuffer config2;
+ TlsConnectTestBase::GenerateEchConfig(HpkeDhKemX25519Sha256, kBogusSuite,
+ kPublicName, 100, config2, pub, priv);
+ EnsureInit();
+ EXPECT_EQ(SECSuccess, SSL_EnableTls13GreaseEch(agent_->ssl_fd(),
+ PR_FALSE)); // Don't GREASE
+
+ DataBuffer sizedConfigListBuffer = MakeEchConfigList(config1, config2);
+ InstallEchConfig(sizedConfigListBuffer, 0);
+ auto filter = MakeTlsFilter<TlsExtensionCapture>(
+ agent_, ssl_tls13_encrypted_client_hello_xtn);
+ agent_->Handshake();
+ ASSERT_EQ(TlsAgent::STATE_CONNECTING, agent_->state());
+ ASSERT_TRUE(filter->captured());
+}
+
+TEST_P(TlsAgentEchTest, EchConfigsSupportedNoYes) {
+ if (variant_ == ssl_variant_datagram) {
+ GTEST_SKIP();
+ }
+ ScopedSECKEYPublicKey pub;
+ ScopedSECKEYPrivateKey priv;
+ DataBuffer config2;
+ TlsConnectTestBase::GenerateEchConfig(HpkeDhKemX25519Sha256, kSuiteAes,
+ kPublicName, 100, config2, pub, priv);
+ DataBuffer config1;
+ TlsConnectTestBase::GenerateEchConfig(HpkeDhKemX25519Sha256, kBogusSuite,
+ kPublicName, 100, config1, pub, priv);
+ // ECHConfig 1 cipher_suites are unsupported.
+ DataBuffer sizedConfigListBuffer = MakeEchConfigList(config1, config2);
+ EnsureInit();
+ EXPECT_EQ(SECSuccess, SSL_EnableTls13GreaseEch(agent_->ssl_fd(),
+ PR_FALSE)); // Don't GREASE
+ InstallEchConfig(sizedConfigListBuffer, 0);
+ auto filter = MakeTlsFilter<TlsExtensionCapture>(
+ agent_, ssl_tls13_encrypted_client_hello_xtn);
+ agent_->Handshake();
+ ASSERT_EQ(TlsAgent::STATE_CONNECTING, agent_->state());
+ ASSERT_TRUE(filter->captured());
+}
+
+TEST_P(TlsAgentEchTest, EchConfigsSupportedNoNo) {
+ if (variant_ == ssl_variant_datagram) {
+ GTEST_SKIP();
+ }
+
+ ScopedSECKEYPublicKey pub;
+ ScopedSECKEYPrivateKey priv;
+ DataBuffer config2;
+ TlsConnectTestBase::GenerateEchConfig(HpkeDhKemX25519Sha256, kBogusSuite,
+ kPublicName, 100, config2, pub, priv);
+ DataBuffer config1;
+ TlsConnectTestBase::GenerateEchConfig(HpkeDhKemX25519Sha256, kBogusSuite,
+ kPublicName, 100, config1, pub, priv);
+ // ECHConfig 1 and 2 cipher_suites are unsupported.
+ DataBuffer sizedConfigListBuffer = MakeEchConfigList(config1, config2);
+ EnsureInit();
+ EXPECT_EQ(SECSuccess, SSL_EnableTls13GreaseEch(agent_->ssl_fd(),
+ PR_FALSE)); // Don't GREASE
+ InstallEchConfig(sizedConfigListBuffer, SEC_ERROR_INVALID_ARGS);
+ auto filter = MakeTlsFilter<TlsExtensionCapture>(
+ agent_, ssl_tls13_encrypted_client_hello_xtn);
+ agent_->Handshake();
+ ASSERT_EQ(TlsAgent::STATE_CONNECTING, agent_->state());
+ ASSERT_FALSE(filter->captured());
+}
+
+TEST_P(TlsAgentEchTest, ShortEchConfig) {
+ EnsureInit();
+ ScopedSECKEYPublicKey pub;
+ ScopedSECKEYPrivateKey priv;
+ DataBuffer echconfig;
+ TlsConnectTestBase::GenerateEchConfig(HpkeDhKemX25519Sha256, kDefaultSuites,
+ kPublicName, 100, echconfig, pub, priv);
+ echconfig.Truncate(echconfig.len() - 1);
+ InstallEchConfig(echconfig, SEC_ERROR_BAD_DATA);
+ EXPECT_EQ(SECSuccess, SSL_EnableTls13GreaseEch(agent_->ssl_fd(),
+ PR_FALSE)); // Don't GREASE
+ auto filter = MakeTlsFilter<TlsExtensionCapture>(
+ agent_, ssl_tls13_encrypted_client_hello_xtn);
+ agent_->Handshake();
+ ASSERT_EQ(TlsAgent::STATE_CONNECTING, agent_->state());
+ ASSERT_FALSE(filter->captured());
+}
+
+TEST_P(TlsAgentEchTest, LongEchConfig) {
+ EnsureInit();
+ ScopedSECKEYPublicKey pub;
+ ScopedSECKEYPrivateKey priv;
+ DataBuffer echconfig;
+ TlsConnectTestBase::GenerateEchConfig(HpkeDhKemX25519Sha256, kDefaultSuites,
+ kPublicName, 100, echconfig, pub, priv);
+ echconfig.Write(echconfig.len(), 1, 1); // Append one byte
+ InstallEchConfig(echconfig, SEC_ERROR_BAD_DATA);
+ EXPECT_EQ(SECSuccess, SSL_EnableTls13GreaseEch(agent_->ssl_fd(),
+ PR_FALSE)); // Don't GREASE
+ auto filter = MakeTlsFilter<TlsExtensionCapture>(
+ agent_, ssl_tls13_encrypted_client_hello_xtn);
+ agent_->Handshake();
+ ASSERT_EQ(TlsAgent::STATE_CONNECTING, agent_->state());
+ ASSERT_FALSE(filter->captured());
+}
+
+TEST_P(TlsAgentEchTest, UnsupportedEchConfigVersion) {
+ EnsureInit();
+ ScopedSECKEYPublicKey pub;
+ ScopedSECKEYPrivateKey priv;
+ DataBuffer echconfig;
+ static const uint8_t bad_version[] = {0xff, 0xff};
+ DataBuffer bad_ver_buf(bad_version, sizeof(bad_version));
+ TlsConnectTestBase::GenerateEchConfig(HpkeDhKemX25519Sha256, kDefaultSuites,
+ kPublicName, 100, echconfig, pub, priv);
+ echconfig.Splice(bad_ver_buf, 2, 2);
+ InstallEchConfig(echconfig, SEC_ERROR_INVALID_ARGS);
+ EXPECT_EQ(SECSuccess, SSL_EnableTls13GreaseEch(agent_->ssl_fd(),
+ PR_FALSE)); // Don't GREASE
+ auto filter = MakeTlsFilter<TlsExtensionCapture>(
+ agent_, ssl_tls13_encrypted_client_hello_xtn);
+ agent_->Handshake();
+ ASSERT_EQ(TlsAgent::STATE_CONNECTING, agent_->state());
+ ASSERT_FALSE(filter->captured());
+}
+
+TEST_P(TlsAgentEchTest, UnsupportedHpkeKem) {
+ EnsureInit();
+ ScopedSECKEYPublicKey pub;
+ ScopedSECKEYPrivateKey priv;
+ DataBuffer echconfig;
+ // SSL_EncodeEchConfigId encodes without validation.
+ TlsConnectTestBase::GenerateEchConfig(static_cast<HpkeKemId>(0xff),
+ kDefaultSuites, kPublicName, 100,
+ echconfig, pub, priv);
+ InstallEchConfig(echconfig, SEC_ERROR_INVALID_ARGS);
+ EXPECT_EQ(SECSuccess, SSL_EnableTls13GreaseEch(agent_->ssl_fd(),
+ PR_FALSE)); // Don't GREASE
+ auto filter = MakeTlsFilter<TlsExtensionCapture>(
+ agent_, ssl_tls13_encrypted_client_hello_xtn);
+ agent_->Handshake();
+ ASSERT_EQ(TlsAgent::STATE_CONNECTING, agent_->state());
+ ASSERT_FALSE(filter->captured());
+}
+
+TEST_P(TlsAgentEchTest, EchRejectIgnoreAllUnknownSuites) {
+ EnsureInit();
+ ScopedSECKEYPublicKey pub;
+ ScopedSECKEYPrivateKey priv;
+ DataBuffer echconfig;
+ TlsConnectTestBase::GenerateEchConfig(HpkeDhKemX25519Sha256, kBogusSuite,
+ kPublicName, 100, echconfig, pub, priv);
+ InstallEchConfig(echconfig, SEC_ERROR_INVALID_ARGS);
+ EXPECT_EQ(SECSuccess, SSL_EnableTls13GreaseEch(agent_->ssl_fd(),
+ PR_FALSE)); // Don't GREASE
+ auto filter = MakeTlsFilter<TlsExtensionCapture>(
+ agent_, ssl_tls13_encrypted_client_hello_xtn);
+ agent_->Handshake();
+ ASSERT_FALSE(filter->captured());
+}
+
+TEST_P(TlsAgentEchTest, EchConfigRejectEmptyPublicName) {
+ EnsureInit();
+ ScopedSECKEYPublicKey pub;
+ ScopedSECKEYPrivateKey priv;
+ DataBuffer echconfig;
+ TlsConnectTestBase::GenerateEchConfig(HpkeDhKemX25519Sha256, kBogusSuite, "",
+ 100, echconfig, pub, priv);
+ InstallEchConfig(echconfig, SSL_ERROR_RX_MALFORMED_ECH_CONFIG);
+ EXPECT_EQ(SECSuccess, SSL_EnableTls13GreaseEch(agent_->ssl_fd(),
+ PR_FALSE)); // Don't GREASE
+ auto filter = MakeTlsFilter<TlsExtensionCapture>(
+ agent_, ssl_tls13_encrypted_client_hello_xtn);
+ agent_->Handshake();
+ ASSERT_FALSE(filter->captured());
+}
+
+TEST_F(TlsConnectStreamTls13, EchAcceptIgnoreSingleUnknownSuite) {
+ EnsureTlsSetup();
+ DataBuffer echconfig;
+ ScopedSECKEYPublicKey pub;
+ ScopedSECKEYPrivateKey priv;
+ TlsConnectTestBase::GenerateEchConfig(HpkeDhKemX25519Sha256,
+ kUnknownFirstSuite, kPublicName, 100,
+ echconfig, pub, priv);
+ ASSERT_EQ(SECSuccess,
+ SSL_SetClientEchConfigs(client_->ssl_fd(), echconfig.data(),
+ echconfig.len()));
+ ASSERT_EQ(SECSuccess,
+ SSL_SetServerEchConfigs(server_->ssl_fd(), pub.get(), priv.get(),
+ echconfig.data(), echconfig.len()));
+
+ client_->ExpectEch();
+ server_->ExpectEch();
+ Connect();
+}
+
+TEST_P(TlsAgentEchTest, ApiInvalidArgs) {
+ EnsureInit();
+ // SetClient
+ EXPECT_EQ(SECFailure, SSL_SetClientEchConfigs(agent_->ssl_fd(), nullptr, 1));
+
+ EXPECT_EQ(SECFailure,
+ SSL_SetClientEchConfigs(agent_->ssl_fd(),
+ reinterpret_cast<const uint8_t*>(1), 0));
+
+ // SetServer
+ EXPECT_EQ(SECFailure,
+ SSL_SetServerEchConfigs(agent_->ssl_fd(), nullptr,
+ reinterpret_cast<SECKEYPrivateKey*>(1),
+ reinterpret_cast<const uint8_t*>(1), 1));
+ EXPECT_EQ(SECFailure,
+ SSL_SetServerEchConfigs(
+ agent_->ssl_fd(), reinterpret_cast<SECKEYPublicKey*>(1),
+ nullptr, reinterpret_cast<const uint8_t*>(1), 1));
+ EXPECT_EQ(SECFailure,
+ SSL_SetServerEchConfigs(
+ agent_->ssl_fd(), reinterpret_cast<SECKEYPublicKey*>(1),
+ reinterpret_cast<SECKEYPrivateKey*>(1), nullptr, 1));
+ EXPECT_EQ(SECFailure,
+ SSL_SetServerEchConfigs(agent_->ssl_fd(),
+ reinterpret_cast<SECKEYPublicKey*>(1),
+ reinterpret_cast<SECKEYPrivateKey*>(1),
+ reinterpret_cast<const uint8_t*>(1), 0));
+
+ // GetRetries
+ EXPECT_EQ(SECFailure, SSL_GetEchRetryConfigs(agent_->ssl_fd(), nullptr));
+
+ // EncodeEchConfigId
+ EXPECT_EQ(SECFailure,
+ SSL_EncodeEchConfigId(0, nullptr, 1, static_cast<HpkeKemId>(1),
+ reinterpret_cast<SECKEYPublicKey*>(1),
+ reinterpret_cast<HpkeSymmetricSuite*>(1), 1,
+ reinterpret_cast<uint8_t*>(1),
+ reinterpret_cast<unsigned int*>(1), 1));
+
+ EXPECT_EQ(SECFailure,
+ SSL_EncodeEchConfigId(0, "name", 1, static_cast<HpkeKemId>(1),
+ reinterpret_cast<SECKEYPublicKey*>(1),
+ nullptr, 1, reinterpret_cast<uint8_t*>(1),
+ reinterpret_cast<unsigned int*>(1), 1));
+ EXPECT_EQ(SECFailure,
+ SSL_EncodeEchConfigId(0, "name", 1, static_cast<HpkeKemId>(1),
+ reinterpret_cast<SECKEYPublicKey*>(1),
+ reinterpret_cast<HpkeSymmetricSuite*>(1), 0,
+ reinterpret_cast<uint8_t*>(1),
+ reinterpret_cast<unsigned int*>(1), 1));
+
+ EXPECT_EQ(SECFailure, SSL_EncodeEchConfigId(
+ 0, "name", 1, static_cast<HpkeKemId>(1), nullptr,
+ reinterpret_cast<HpkeSymmetricSuite*>(1), 1,
+ reinterpret_cast<uint8_t*>(1),
+ reinterpret_cast<unsigned int*>(1), 1));
+
+ EXPECT_EQ(SECFailure,
+ SSL_EncodeEchConfigId(0, nullptr, 0, static_cast<HpkeKemId>(1),
+ reinterpret_cast<SECKEYPublicKey*>(1),
+ reinterpret_cast<HpkeSymmetricSuite*>(1), 1,
+ reinterpret_cast<uint8_t*>(1),
+ reinterpret_cast<unsigned int*>(1), 1));
+
+ EXPECT_EQ(SECFailure, SSL_EncodeEchConfigId(
+ 0, "name", 1, static_cast<HpkeKemId>(1),
+ reinterpret_cast<SECKEYPublicKey*>(1),
+ reinterpret_cast<HpkeSymmetricSuite*>(1), 1,
+ nullptr, reinterpret_cast<unsigned int*>(1), 1));
+
+ EXPECT_EQ(SECFailure,
+ SSL_EncodeEchConfigId(0, "name", 1, static_cast<HpkeKemId>(1),
+ reinterpret_cast<SECKEYPublicKey*>(1),
+ reinterpret_cast<HpkeSymmetricSuite*>(1), 1,
+ reinterpret_cast<uint8_t*>(1), nullptr, 1));
+}
+
+TEST_P(TlsAgentEchTest, NoEarlyRetryConfigs) {
+ EnsureInit();
+ StackSECItem retry_configs;
+ EXPECT_EQ(SECFailure,
+ SSL_GetEchRetryConfigs(agent_->ssl_fd(), &retry_configs));
+ EXPECT_EQ(SSL_ERROR_HANDSHAKE_NOT_COMPLETED, PORT_GetError());
+
+ ScopedSECKEYPublicKey pub;
+ ScopedSECKEYPrivateKey priv;
+ DataBuffer echconfig;
+ TlsConnectTestBase::GenerateEchConfig(HpkeDhKemX25519Sha256, kDefaultSuites,
+ kPublicName, 100, echconfig, pub, priv);
+ InstallEchConfig(echconfig, 0);
+
+ EXPECT_EQ(SECFailure,
+ SSL_GetEchRetryConfigs(agent_->ssl_fd(), &retry_configs));
+ EXPECT_EQ(SSL_ERROR_HANDSHAKE_NOT_COMPLETED, PORT_GetError());
+}
+
+TEST_P(TlsAgentEchTest, NoSniSoNoEch) {
+ EnsureInit();
+ ScopedSECKEYPublicKey pub;
+ ScopedSECKEYPrivateKey priv;
+ DataBuffer echconfig;
+ TlsConnectTestBase::GenerateEchConfig(HpkeDhKemX25519Sha256, kDefaultSuites,
+ kPublicName, 100, echconfig, pub, priv);
+ SSL_SetURL(agent_->ssl_fd(), "");
+ InstallEchConfig(echconfig, 0);
+ SSL_SetURL(agent_->ssl_fd(), "");
+ EXPECT_EQ(SECSuccess, SSL_EnableTls13GreaseEch(agent_->ssl_fd(),
+ PR_FALSE)); // Don't GREASE
+ auto filter = MakeTlsFilter<TlsExtensionCapture>(
+ agent_, ssl_tls13_encrypted_client_hello_xtn);
+ agent_->Handshake();
+ ASSERT_FALSE(filter->captured());
+}
+
+TEST_P(TlsAgentEchTest, NoEchConfigSoNoEch) {
+ EnsureInit();
+ ScopedSECKEYPublicKey pub;
+ ScopedSECKEYPrivateKey priv;
+ DataBuffer echconfig;
+ EXPECT_EQ(SECSuccess, SSL_EnableTls13GreaseEch(agent_->ssl_fd(),
+ PR_FALSE)); // Don't GREASE
+ auto filter = MakeTlsFilter<TlsExtensionCapture>(
+ agent_, ssl_tls13_encrypted_client_hello_xtn);
+ agent_->Handshake();
+ ASSERT_FALSE(filter->captured());
+}
+
+TEST_P(TlsAgentEchTest, EchConfigDuplicateExtensions) {
+ EnsureInit();
+ ScopedSECKEYPublicKey pub;
+ ScopedSECKEYPrivateKey priv;
+ DataBuffer echconfig;
+ TlsConnectTestBase::GenerateEchConfig(HpkeDhKemX25519Sha256, kDefaultSuites,
+ kPublicName, 100, echconfig, pub, priv);
+
+ static const uint8_t duped_xtn[] = {0x00, 0x08, 0x00, 0x01, 0x00,
+ 0x00, 0x00, 0x01, 0x00, 0x00};
+ DataBuffer buf(duped_xtn, sizeof(duped_xtn));
+ echconfig.Truncate(echconfig.len() - 2);
+ echconfig.Append(buf);
+ uint32_t len;
+ ASSERT_TRUE(echconfig.Read(0, 2, &len));
+ len += buf.len() - 2;
+ DataBuffer new_len;
+ ASSERT_TRUE(new_len.Write(0, len, 2));
+ echconfig.Splice(new_len, 0, 2);
+ new_len.Truncate(0);
+
+ ASSERT_TRUE(echconfig.Read(4, 2, &len));
+ len += buf.len() - 2;
+ ASSERT_TRUE(new_len.Write(0, len, 2));
+ echconfig.Splice(new_len, 4, 2);
+
+ InstallEchConfig(echconfig, SEC_ERROR_EXTENSION_VALUE_INVALID);
+ EXPECT_EQ(SECSuccess, SSL_EnableTls13GreaseEch(agent_->ssl_fd(),
+ PR_FALSE)); // Don't GREASE
+ auto filter = MakeTlsFilter<TlsExtensionCapture>(
+ agent_, ssl_tls13_encrypted_client_hello_xtn);
+ agent_->Handshake();
+ ASSERT_EQ(TlsAgent::STATE_CONNECTING, agent_->state());
+ ASSERT_FALSE(filter->captured());
+}
+
+TEST_F(TlsConnectStreamTls13Ech, EchFixedConfig) {
+ ScopedSECKEYPublicKey pub;
+ ScopedSECKEYPrivateKey priv;
+ EnsureTlsSetup();
+ ImportFixedEchKeypair(pub, priv);
+ SetMutualEchConfigs(pub, priv);
+
+ client_->ExpectEch();
+ server_->ExpectEch();
+ Connect();
+}
+
+// The next set of tests all use a fixed server key and a pre-built ClientHello.
+// This ClientHelo can be constructed using the above EchFixedConfig test,
+// modifying tls13_ConstructInnerExtensionsFromOuter as indicated. For this
+// small number of tests, these fixed values are easier to construct than
+// constructing ClientHello in the test that can be successfully decrypted.
+
+// Test an encoded ClientHelloInner containing an extra extensionType
+// in outer_extensions, for which there is no corresponding (uncompressed)
+// extension in ClientHelloOuter.
+TEST_F(TlsConnectStreamTls13Ech, EchOuterExtensionsReferencesMissing) {
+ // Construct this by prepending 0xabcd to ssl_tls13_outer_extensions_xtn.
+ std::string ch =
+ "010001fc030390901d039ca83262d9115a5f98f43ddb2553241a8de5c46d9f118c4c29c2"
+ "64bc000006130113031302010001cd00000010000e00000b7075626c69632e6e616d65ff"
+ "01000100000a00140012001d00170018001901000101010201030104003300260024001d"
+ "00206df5f908d1c02320e246694c765d5ec1c0f7d7aef2b1b00b17c36331623d332d002b"
+ "0003020304000d0018001604030503060302030804080508060401050106010201002d00"
+ "020101001c00024001fe0d00f900000100034d00209a4f67b0744d1fba23aa4bacfadb2a"
+ "c706562dae04d80a83ae668a6f2dd6ef2700cfab1671182341df246d66c3aca873e8c714"
+ "bc2b1c3b576653609533c486df0bdcf63ab4e4e7d0b67fadf4e3504eec96f72e6778b15d"
+ "69c9a9594a041348a7130f67a1a7cac796a0e6d6fca505438355278a9a8fd55e44218441"
+ "9927a1e084ac7d7adeb2f0c19faafba430876bf0cdf4d195b2d06428b3de13120f65748a"
+ "468f8997a2c3bf1dd7f3996a0f2c70dea6c88149df182b3c3b78a8da8bb709a9ed9d77c6"
+ "5dc09accdfeb66c90db26b99a35052a8cbaf7bb9307a1e17d90a7aa9f768f5f446559d08"
+ "69bccc83eda9d2b347a00015004200000000000000000000000000000000000000000000"
+ "000000000000000000000000000000000000000000000000000000000000000000000000"
+ "0000000000000000";
+ ReplayChWithMalformedInner(ch, kTlsAlertIllegalParameter,
+ SSL_ERROR_RX_MALFORMED_ECH_EXTENSION,
+ SSL_ERROR_ILLEGAL_PARAMETER_ALERT);
+}
+
+TEST_F(TlsConnectStreamTls13Ech, EchOuterExtensionsInsideInner) {
+ // Construct this by appending ssl_tls13_outer_extensions_xtn to the
+ // references in
+ // ssl_tls13_outer_extensions_xtn.
+ std::string ch =
+ "010001fc03035e2268bc7133079cd33eb088253393e561d80c5ee6f9a238aff022e1e10d"
+ "4c82000006130113031302010001cd00000010000e00000b7075626c69632e6e616d65ff"
+ "01000100000a00140012001d00170018001901000101010201030104003300260024001d"
+ "00200e071fd982854d50236ed0e4e7981460840f03d03fd84b44c409fe486203b252002b"
+ "0003020304000d0018001604030503060302030804080508060401050106010201002d00"
+ "020101001c00024001fe0d00f900000100034d002099a032502ea4fd3c85b858ae1c59df"
+ "6a374f3698ed6bca188cf75c432c78cf5a00cf28dde32de7ade40abb16d550c1eec3dad4"
+ "a03c85efb95ec605837deae92a419285116e5cb8223ea53cff2b605e66f28e96d37e9b4e"
+ "3035fb1cfa125fa053d6770091b5731c9fb03e872a82991dfdd24ad8399fcc76db7fadba"
+ "029e064beb02c1282684a93e777bcefbca3dd143dfc225d2e65c80dbf3819ebda288e32c"
+ "3a1f8a27bb3aa9480dee2a4307073da3e15ee03dba386223d9399ad796af80c646f85406"
+ "282c34fd9406d25752087f08140e1be834e8a149f0bebfc2b3db16ccba83c37051e2e75d"
+ "e8a4e999ef385c74c96d0015004200000000000000000000000000000000000000000000"
+ "000000000000000000000000000000000000000000000000000000000000000000000000"
+ "0000000000000000";
+ ReplayChWithMalformedInner(ch, kTlsAlertIllegalParameter,
+ SSL_ERROR_RX_MALFORMED_ECH_EXTENSION,
+ SSL_ERROR_ILLEGAL_PARAMETER_ALERT);
+}
+
+TEST_F(TlsConnectStreamTls13Ech, EchOuterExtensionsDuplicateReference) {
+ // Construct this by altering tls13_ConstructInnerExtensionsFromOuter to have
+ // each extension inserted in the reference list twice and running the
+ // EchFixedConfig test.
+ std::string ch =
+ "010001fc0303d8717df80286fcd8b4242ed846995c6473e290678231046bb1bfc7848460"
+ "b122000006130113031302010001cd00000010000e00000b7075626c69632e6e616d65ff"
+ "01000100000a00140012001d00170018001901000101010201030104003300260024001d"
+ "00206f21d5fdf7bf81943939a03656c1195ad347cec453bf7a16d0773ffef481d22f002b"
+ "0003020304000d0018001604030503060302030804080508060401050106010201002d00"
+ "020101001c00024001fe0d011900000100034d002027eb9b641ba8ffc3a4028d00d1f5bd"
+ "e190736b1ea5a79513dee0a551cc6fe55200efc2ed6bf501f100896eb91221ce512c20c3"
+ "c5c110e7be6a5d340854ff5ac0175312631b021fd5a5c9841549989f415c4041a4b384b1"
+ "dba1d6b4182cc48904f993a15eab6bf7787b267ca65acef51c019508e0c9b382086a71d8"
+ "517cf19644d66d396efc066a4d37916d67b0e5fe08d52dd94d068dd85b9a245aaffac4ff"
+ "66d9a5221fd5805473bb7584eb7f218357c00aff890d2f2edf1c092c648c888b5cba1ca6"
+ "26817fda7765fcedfbc418b90b1841d878ed443593cafb61fa8fb708c53977615b45f545"
+ "2a8236cab3ec121cdc91a2de6a79437cae9d09e781339fddcac005ce62fd65d50e33faa2"
+ "2366955a0374001500220000000000000000000000000000000000000000000000000000"
+ "0000000000000000";
+ ReplayChWithMalformedInner(ch, kTlsAlertIllegalParameter,
+ SSL_ERROR_RX_MALFORMED_ECH_EXTENSION,
+ SSL_ERROR_ILLEGAL_PARAMETER_ALERT);
+}
+
+TEST_F(TlsConnectStreamTls13Ech, EchOuterExtensionsOutOfOrder) {
+ // Construct this by altering tls13_ConstructInnerExtensionsFromOuter to leave
+ // a gap at the start and insert a 'late' extension there.
+ std::string ch =
+ "010001fc0303fabff6caf4d8b1eb1db5945c96badefec4b33188b97121e6a206e82b74bd"
+ "a855000006130113031302010001cd00000010000e00000b7075626c69632e6e616d65ff"
+ "01000100000a00140012001d00170018001901000101010201030104003300260024001d"
+ "00208fe970fc0c908f0c51734f18467e640df1d45a6ace2948b5c4bf73ee52ab3160002b"
+ "0003020304000d0018001604030503060302030804080508060401050106010201002d00"
+ "020101001c00024001fe0d00f900000100034d00203339239f8925c3f9b89f4ced17c3b3"
+ "1c649299d7e10b3cdbc115de2a57d90d2200cf006e62866516380e8a16763bee5b2a75a8"
+ "74e8698c459f474d0e952c2fd3300bef1decd6f259b8ac2912684ef69b7a7be2520fbf15"
+ "5e0c3f88998789976ca1fbcaa40616fc513e3353540db091da76ca98007532974550d3da"
+ "aaddb799baf60adbc5800df30e187251427fe9de707d18a270352ee44f6eb37f0d8c72a1"
+ "5f9ffb5dd4bbb6045473c8d99b7a5c2c8cc59027f346cbe6ef240d5cf1919f58a998d427"
+ "0f8c882d03d22ec4df4079e15a639452ea4c24023f6bcad89566ce6a32b1dad6ddf6b436"
+ "3e6759bd48bed1b30a840015004200000000000000000000000000000000000000000000"
+ "000000000000000000000000000000000000000000000000000000000000000000000000"
+ "0000000000000000";
+ ReplayChWithMalformedInner(ch, kTlsAlertIllegalParameter,
+ SSL_ERROR_RX_MALFORMED_ECH_EXTENSION,
+ SSL_ERROR_ILLEGAL_PARAMETER_ALERT);
+}
+
+// Drop supported_versions from CHInner, make sure we don't negotiate 1.2+ECH.
+TEST_F(TlsConnectStreamTls13Ech, EchVersion12Inner) {
+ // Construct this by removing ssl_tls13_supported_versions_xtn entirely.
+ std::string ch =
+ "010001fc030338e9ebcde2b87ef779c4d9a9b9870aef3978130b254fbf168a92644c97c1"
+ "c5cb000006130113031302010001cd00000010000e00000b7075626c69632e6e616d65ff"
+ "01000100000a00140012001d00170018001901000101010201030104003300260024001d"
+ "002081b3ea444fd631f9264e01276bcc1a6771aed3b5a8a396446467d1c820e52b25002b"
+ "0003020304000d0018001604030503060302030804080508060401050106010201002d00"
+ "020101001c00024001fe0d00f900000100034d00205864042b43f4d4d544558fbcba410f"
+ "ebfb78ddfc5528672a7f7d9e70abc3eb6300cf6ff3271da628139bddc4a58ee92db26170"
+ "7310dee54d88c8a96a8d998b8608d5f10260b7e201e5dc8cafa13917a3fdfdf399082959"
+ "8adf3c291decf640f696e64c4e22bafb81565587c50dd829ccad68bd00babeaba7d8a7a5"
+ "400ad3200dbae674c549953ca6d3298ed751a9bc215a33be444fe908bf1c6f374cc139f9"
+ "98339f58b8fd3510a670e4102e3f7de21586ebd70c3fb1df8bb6b9e5dbc0db147dbac6d0"
+ "72dfc6cdf17ecee5c019c311b37ef9f5ceabb7edbdf87a4a04041c4d8b512a16517c5380"
+ "e8d4f6e3b2412b4a6c030015004200000000000000000000000000000000000000000000"
+ "000000000000000000000000000000000000000000000000000000000000000000000000"
+ "0000000000000000";
+ ReplayChWithMalformedInner(ch, kTlsAlertIllegalParameter,
+ SSL_ERROR_UNSUPPORTED_VERSION,
+ SSL_ERROR_ILLEGAL_PARAMETER_ALERT);
+}
+
+// Use CHInner supported_versions to negotiate 1.2.
+TEST_F(TlsConnectStreamTls13Ech, EchVersion12InnerSupportedVersions) {
+ // Construct this by changing ssl_tls13_supported_versions_xtn to write
+ // TLS 1.2 instead of TLS 1.3.
+ std::string ch =
+ "010001fc0303f7146bdc88c399feb49c62b796db2f8b1330e25292a889edf7c65231d0be"
+ "b95f000006130113031302010001cd00000010000e00000b7075626c69632e6e616d65ff"
+ "01000100000a00140012001d00170018001901000101010201030104003300260024001d"
+ "0020d31f8eb204efba49dbdbf40bb046b1e0b90fa3f034260d60f351d4b15e614e7f002b"
+ "0003020304000d0018001604030503060302030804080508060401050106010201002d00"
+ "020101001c00024001fe0d00f900000100034d0020eaa25e92721e65fd405577bf2fd322"
+ "857e60f8766a595929fc404c9a01ef441200cf04992c693fbc8eac87726b336a11abc411"
+ "541ceff50d533d4cf4d6e1078479acb5446675b652f22d6db04daf0c3640ec2429ba4f51"
+ "99c00daa43e9a7d85bd6733041feeca0b38ee6ca07042c7e67d40cd3e236499f3f9d92ab"
+ "e4642e483c75d77c247b0228bc773c09551d15845c35663afd1805c5b3adb136ffa6d94f"
+ "b7cbfe93d5d33c894b2a6437ad9a2278d5863ed20db652a6084c9e95a8dfaf821d0b474a"
+ "7efc2839f110edb4a73376ecab629b26b1eea63304899c49a07157fbbee67c786686cb04"
+ "a53666a74e1e003aefc70015004200000000000000000000000000000000000000000000"
+ "000000000000000000000000000000000000000000000000000000000000000000000000"
+ "0000000000000000";
+ ReplayChWithMalformedInner(ch, kTlsAlertProtocolVersion,
+ SSL_ERROR_UNSUPPORTED_VERSION,
+ SSL_ERROR_PROTOCOL_VERSION_ALERT);
+}
+
+// Replay a CH for which CHInner lacks the required ech xtn of inner type
+TEST_F(TlsConnectStreamTls13Ech, EchInnerMissing) {
+ // Construct by omitting the ech inner extension
+ std::string ch =
+ "010001fc0303fa9cd9cf5b77bb4083f69a1d169d44b356faea0d6a0aee6d50412de6fef7"
+ "8d22000006130113031302010001cd00000010000e00000b7075626c69632e6e616d65ff"
+ "01000100000a00140012001d00170018001901000101010201030104003300260024001d"
+ "0020c329f1dde4d51b50f68c21053b545290b250af527b2832d3acf2c6af9b8b8d5c002b"
+ "0003020304000d0018001604030503060302030804080508060401050106010201002d00"
+ "020101001c00024001fe0d00f900000100034d00207e2a0397b7d2776ae468057d630243"
+ "b01388cf80680b074323adf4091aba7b4c00cff4b649fb5b3a0719c1e085c7006a95eaad"
+ "32375b717a42d009c075e6246342fdc1e847c528495f90378ff5b4912da5190f7e8bfa1c"
+ "c9744b50e9e469cd7cd12bcb5f6534b7d617459d2efa4d796ad244567c49f1d22feb08a5"
+ "8e8ebdce059c28883dd69ca401e189f3ef438c3f0bf3d377e6727a1f6abf3a8a8cc149ee"
+ "60a1aa5ba4a50e99d2519216762558e9613a238bd630b5822f549575d9402f8da066aaef"
+ "2e0e6a7a04583b041925e0ef4575107c4436f9af26e561c0ab733cd88bee6a20e6414128"
+ "ea0ba1c73612bb62c1e90015004200000000000000000000000000000000000000000000"
+ "000000000000000000000000000000000000000000000000000000000000000000000000"
+ "0000000000000000";
+ ReplayChWithMalformedInner(ch, kTlsAlertIllegalParameter,
+ SSL_ERROR_MISSING_ECH_EXTENSION,
+ SSL_ERROR_ILLEGAL_PARAMETER_ALERT);
+}
+
+TEST_F(TlsConnectStreamTls13Ech, EchInnerWrongSize) {
+ // Construct by including ech inner with wrong size
+ std::string ch =
+ "010001fc03035f8410dab9e49b0833d13390f3fe0b3c6321d842961c9cc46b59a0b5b8e1"
+ "4e0b000006130113031302010001cd00000010000e00000b7075626c69632e6e616d65ff"
+ "01000100000a00140012001d00170018001901000101010201030104003300260024001d"
+ "0020526a56087d685e574accb0e87d6781bc553612479e56460fe6a497fa1cd74e2e002b"
+ "0003020304000d0018001604030503060302030804080508060401050106010201002d00"
+ "020101001c00024001fe0d00f900000100034d00200d096bf6ac0c3bcb79d70677da0e0d"
+ "249b40bc5ba6b8727654619fe6567d0b0700cfd13e136d2d041e3cd993b252386d97e98d"
+ "c972d29d28e0281a210fa56156b95e4371a6610a0b3e65f1b842875fb456de9b9c0e03f8"
+ "aa4d1055057ac3e20e5fa45b837ccbb06ef3856c71f1f63e91b60bfb5f3415f26e9a0d3c"
+ "4d404d5d5aaa6dca8d57cf2e6b4aaf399fa7271b0c1eedbfdd85fbc9711b0446eb9c9535"
+ "a74f3e5a71e2e22dc8d89980f96233ec9b80fbe4f295ff7903bade407fc544c8d76df4fb"
+ "ce4b8d79cea0ff7e0b0736ecbeaf5a146a4f81a930e788ae144cf2219e90dc3594165a7e"
+ "2a0b64f6189a87a348840015004200000000000000000000000000000000000000000000"
+ "000000000000000000000000000000000000000000000000000000000000000000000000"
+ "0000000000000000";
+ ReplayChWithMalformedInner(ch, kTlsAlertDecodeError,
+ SSL_ERROR_RX_MALFORMED_ESNI_EXTENSION,
+ SSL_ERROR_DECODE_ERROR_ALERT);
+}
+
+TEST_F(TlsConnectStreamTls13Ech, InnerWithEchAndEchIsInner) {
+ // Construct by appending an empty ssl_tls13_encrypted_client_hello_xtn of
+ // type outer to
+ // CHInner.
+ std::string ch =
+ "010001fc0303527df5a8dbcf390c184c5274295283fdba78d05784170d8f3cb8c7d84747"
+ "afb5000006130113031302010001cd00000010000e00000b7075626c69632e6e616d65ff"
+ "01000100000a00140012001d00170018001901000101010201030104003300260024001d"
+ "002099461dcfcdc7804a0f34bf3ca49ac39776a7ef4d8edd30fab3599ff59b09f826002b"
+ "0003020304000d0018001604030503060302030804080508060401050106010201002d00"
+ "020101001c00024001fe0d00f900000100034d00201da1341e8ba21ff90e025d2438d4e5"
+ "b4e8b376befc57cf8c9afb484e6f051b2f00cff747491b810705e5cc8d8a1302468000d9"
+ "8660d659d8382a6fc23ca1a582def728eabb363771328035565048213b1d725b20f757be"
+ "63d6956cd861aa9d33adcc913de2443695f70e130af96fd2b078dd662478a29bd17a4479"
+ "715c949b5fc118456d0243c9d1819cecd0f5fbd1c78dadd6fcd09abe41ca97a00c97efb3"
+ "894c9d4bab60dcd150b55608f6260723a08e112e39e6a43f645f85a08085054f27f269bc"
+ "1acb9ff5007b04eaef3414767666472e4e24c2a2953f5dc68aeb5207d556f1b872a810b6"
+ "686cf83a09db8b474df70015004200000000000000000000000000000000000000000000"
+ "000000000000000000000000000000000000000000000000000000000000000000000000"
+ "0000000000000000";
+ ReplayChWithMalformedInner(ch, kTlsAlertIllegalParameter,
+ SSL_ERROR_RX_UNEXPECTED_EXTENSION,
+ SSL_ERROR_ILLEGAL_PARAMETER_ALERT);
+}
+
+TEST_F(TlsConnectStreamTls13, EchWithInnerExtNotSplit) {
+ static uint8_t type_val[1] = {1};
+ DataBuffer type_buffer(type_val, sizeof(type_val));
+
+ EnsureTlsSetup();
+ EXPECT_EQ(SECSuccess, SSL_EnableTls13GreaseEch(client_->ssl_fd(), PR_FALSE));
+ MakeTlsFilter<TlsExtensionAppender>(client_, kTlsHandshakeClientHello,
+ ssl_tls13_encrypted_client_hello_xtn,
+ type_buffer);
+ ConnectExpectAlert(server_, kTlsAlertIllegalParameter);
+ client_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT);
+ server_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_EXTENSION);
+}
+
+/* Parameters
+ * Length of SNI for first connection
+ * Length of SNI for second connection
+ * Use GREASE for first connection?
+ * Use GREASE for second connection?
+ * For both connections, SNI length to pad to.
+ */
+class EchCHPaddingTest : public TlsConnectStreamTls13,
+ public testing::WithParamInterface<
+ std::tuple<int, int, bool, bool, int>> {};
+
+TEST_P(EchCHPaddingTest, EchChPaddingEqual) {
+ auto parameters = GetParam();
+ std::string name_str1 = std::string(std::get<0>(parameters), 'a');
+ std::string name_str2 = std::string(std::get<1>(parameters), 'a');
+ const char* name1 = name_str1.c_str();
+ const char* name2 = name_str2.c_str();
+ bool grease_mode1 = std::get<2>(parameters);
+ bool grease_mode2 = std::get<3>(parameters);
+ uint8_t max_name_len = std::get<4>(parameters);
+
+ // Connection 1
+ EnsureTlsSetup();
+ SSL_SetURL(client_->ssl_fd(), name1);
+ if (grease_mode1) {
+ EXPECT_EQ(SECSuccess, SSL_EnableTls13GreaseEch(client_->ssl_fd(), PR_TRUE));
+ EXPECT_EQ(SECSuccess,
+ SSL_SetTls13GreaseEchSize(client_->ssl_fd(), max_name_len));
+ client_->ExpectEch(false);
+ server_->ExpectEch(false);
+ } else {
+ SetupEch(client_, server_, HpkeDhKemX25519Sha256, true, true, true,
+ max_name_len);
+ }
+ auto filter1 = MakeTlsFilter<TlsExtensionCapture>(
+ client_, ssl_tls13_encrypted_client_hello_xtn);
+ Connect();
+ size_t echXtnLen1 = filter1->extension().len();
+
+ Reset();
+
+ // Connection 2
+ EnsureTlsSetup();
+ SSL_SetURL(client_->ssl_fd(), name2);
+ if (grease_mode2) {
+ EXPECT_EQ(SECSuccess, SSL_EnableTls13GreaseEch(client_->ssl_fd(), PR_TRUE));
+ EXPECT_EQ(SECSuccess,
+ SSL_SetTls13GreaseEchSize(client_->ssl_fd(), max_name_len));
+ client_->ExpectEch(false);
+ server_->ExpectEch(false);
+ } else {
+ SetupEch(client_, server_, HpkeDhKemX25519Sha256, true, true, true,
+ max_name_len);
+ }
+ auto filter2 = MakeTlsFilter<TlsExtensionCapture>(
+ client_, ssl_tls13_encrypted_client_hello_xtn);
+ Connect();
+ size_t echXtnLen2 = filter2->extension().len();
+
+ // We always expect an ECH extension.
+ ASSERT_TRUE(echXtnLen2 > 0 && echXtnLen1 > 0);
+ // We expect the ECH extension to round to the same multiple of 32.
+ // Note: It will not be 0 % 32 because we pad the Payload, but have a number
+ // of extra bytes from the rest of the ECH extension (e.g. ciphersuite)
+ ASSERT_EQ(echXtnLen1 % 32, echXtnLen2 % 32);
+ // Both connections should have the same size after padding.
+ if (name_str1.size() <= max_name_len && name_str2.size() <= max_name_len) {
+ ASSERT_EQ(echXtnLen1, echXtnLen2);
+ }
+}
+
+#define ECH_PADDING_TEST_INSTANTIATE(name, values) \
+ INSTANTIATE_TEST_SUITE_P(name, EchCHPaddingTest, \
+ testing::Combine(values, values, testing::Bool(), \
+ testing::Bool(), values))
+
+const int kExtremalSNILengths[] = {1, 128, 255};
+const int kNormalSNILengths[] = {17, 24, 100};
+const int kLongSNILengths[] = {90, 167, 214};
+
+/* Each invocation with N lengths, results in 4N^3 test cases, so we test
+ * 3 lots of (4*3^3) rather than all permutations. */
+ECH_PADDING_TEST_INSTANTIATE(extremal, testing::ValuesIn(kExtremalSNILengths));
+ECH_PADDING_TEST_INSTANTIATE(normal, testing::ValuesIn(kNormalSNILengths));
+ECH_PADDING_TEST_INSTANTIATE(lengthy, testing::ValuesIn(kLongSNILengths));
+
+// Check the server rejects ClientHellos with bad padding
+TEST_F(TlsConnectStreamTls13Ech, EchChPaddingChecked) {
+ // Generate this string by changing the padding in
+ // tls13_GenPaddingClientHelloInner
+ std::string ch =
+ "010001fc03037473367a6eb6773391081b403908fc0c0026aac706889c59ca694d0c1188"
+ "c4b3000006130113031302010001cd00000010000e00000b7075626c69632e6e616d65ff"
+ "01000100000a00140012001d00170018001901000101010201030104003300260024001d"
+ "0020f7d8ad5fea0165e115e984e11c43f1d8f255bd8f772b893432d8d7721e91785a002b"
+ "0003020304000d0018001604030503060302030804080508060401050106010201002d00"
+ "020101001c00024001fe0d00f900000100034d00207e0ad8e83f8a9c89e1ae4fd65b8091"
+ "01e496bbb5f29ce20b299ce58937e2563300cff471a787585e15ae5aff5e4fee7ec988ba"
+ "72f8a95db41e793568b0301d553251f0826dc0c3ff658e4e029ef840ae86fa80af4b11b5"
+ "3a33fab99887bf8df18bc87abbb1f578f7964848d91a2023cbe7609fcc31bd721865009c"
+ "ad68c09e438d677f7c56af76e62c168bdb373bb88962471dacc4ddf654e435cd903f6555"
+ "4c9a93ffd2541cd7bce520e7215d15495184b781ca8c138cedd573fbdef1d40e5de82c33"
+ "5c9c43370102ecb0b66dd27efc719a9a54589b6e6b599b1b0146e121eae0ab5b2070c12f"
+ "4f4f2b099808294a459f0015004200000000000000000000000000000000000000000000"
+ "000000000000000000000000000000000000000000000000000000000000000000000000"
+ "0000000000000000";
+ ReplayChWithMalformedInner(ch, kTlsAlertIllegalParameter,
+ SSL_ERROR_RX_MALFORMED_ECH_EXTENSION,
+ SSL_ERROR_ILLEGAL_PARAMETER_ALERT);
+}
+
+TEST_F(TlsConnectStreamTls13Ech, EchConfigList) {
+ ScopedSECKEYPublicKey pub;
+ ScopedSECKEYPrivateKey priv;
+ EnsureTlsSetup();
+
+ DataBuffer config1;
+ TlsConnectTestBase::GenerateEchConfig(HpkeDhKemX25519Sha256, kSuiteAes,
+ kPublicName, 100, config1, pub, priv);
+ DataBuffer config2;
+ TlsConnectTestBase::GenerateEchConfig(HpkeDhKemX25519Sha256, kSuiteAes,
+ kPublicName, 100, config2, pub, priv);
+ DataBuffer configList = MakeEchConfigList(config1, config2);
+ SECStatus rv =
+ SSL_SetServerEchConfigs(server_->ssl_fd(), pub.get(), priv.get(),
+ configList.data(), configList.len());
+ printf("%u", rv);
+ ASSERT_EQ(rv, SECSuccess);
+}
+
+TEST_F(TlsConnectStreamTls13Ech, EchConfigsTrialDecrypt) {
+ // Apply two ECHConfigs on the server. They are identical with the exception
+ // of the public key: the first ECHConfig contains a public key for which we
+ // lack the private value. Use an SSLInt function to zero all the config_ids
+ // (client and server), then confirm that trial decryption works.
+ ScopedSECKEYPublicKey pub;
+ ScopedSECKEYPrivateKey priv;
+ EnsureTlsSetup();
+ ImportFixedEchKeypair(pub, priv);
+ ScopedSECKEYPublicKey pub2;
+ ScopedSECKEYPrivateKey priv2;
+ DataBuffer config2;
+ TlsConnectTestBase::GenerateEchConfig(HpkeDhKemX25519Sha256, kSuiteAes,
+ kPublicName, 100, config2, pub, priv);
+ DataBuffer config1;
+ TlsConnectTestBase::GenerateEchConfig(HpkeDhKemX25519Sha256, kSuiteAes,
+ kPublicName, 100, config1, pub2, priv2);
+ // Zero the config id for both, only public key differs.
+ config2.Write(7, (uint32_t)0, 1);
+ config1.Write(7, (uint32_t)0, 1);
+ // Server only knows private key for conf2
+ DataBuffer configList = MakeEchConfigList(config1, config2);
+ ASSERT_EQ(SECSuccess,
+ SSL_SetServerEchConfigs(server_->ssl_fd(), pub.get(), priv.get(),
+ configList.data(), configList.len()));
+ ASSERT_EQ(SECSuccess, SSL_SetClientEchConfigs(client_->ssl_fd(),
+ config2.data(), config2.len()));
+ client_->ExpectEch();
+ server_->ExpectEch();
+ Connect();
+}
+
+TEST_F(TlsConnectStreamTls13Ech, EchAcceptBasic) {
+ EnsureTlsSetup();
+ SetupEch(client_, server_);
+ auto c_filter_sni =
+ MakeTlsFilter<TlsExtensionCapture>(client_, ssl_server_name_xtn);
+ Connect();
+ ASSERT_TRUE(c_filter_sni->captured());
+ CheckSniExtension(c_filter_sni->extension(), kPublicName);
+}
+
+TEST_F(TlsConnectStreamTls13, EchAcceptWithResume) {
+ EnsureTlsSetup();
+ SetupEch(client_, server_);
+ ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET);
+ Connect();
+ SendReceive(); // Need to read so that we absorb the session ticket.
+ CheckKeys();
+
+ Reset();
+ EnsureTlsSetup();
+ SetupEch(client_, server_);
+ ExpectResumption(RESUME_TICKET);
+ auto filter =
+ MakeTlsFilter<TlsExtensionCapture>(client_, ssl_tls13_pre_shared_key_xtn);
+ StartConnect();
+ Handshake();
+ CheckConnected();
+ // Make sure that the PSK extension is only in CHInner.
+ ASSERT_TRUE(filter->captured());
+}
+
+TEST_F(TlsConnectStreamTls13, EchAcceptWithExternalPsk) {
+ static const std::string kPskId = "testing123";
+ EnsureTlsSetup();
+ SetupEch(client_, server_);
+
+ ScopedPK11SlotInfo slot(PK11_GetInternalSlot());
+ ASSERT_TRUE(!!slot);
+ ScopedPK11SymKey key(
+ PK11_KeyGen(slot.get(), CKM_HKDF_KEY_GEN, nullptr, 16, nullptr));
+ ASSERT_TRUE(!!key);
+ AddPsk(key, kPskId, ssl_hash_sha256);
+
+ // Not permitted in outer.
+ auto filter =
+ MakeTlsFilter<TlsExtensionCapture>(client_, ssl_tls13_pre_shared_key_xtn);
+ StartConnect();
+ Handshake();
+ CheckConnected();
+ SendReceive();
+ CheckKeys(ssl_kea_ecdh, ssl_grp_ec_curve25519, ssl_auth_psk, ssl_sig_none);
+ // The PSK extension is present in CHOuter.
+ ASSERT_TRUE(filter->captured());
+
+ // But the PSK in CHOuter is completely different.
+ // (Failure/collision chance means kPskId needs to be longish.)
+ uint32_t v = 0;
+ ASSERT_TRUE(filter->extension().Read(0, 2, &v));
+ ASSERT_EQ(v, kPskId.size() + 2 + 4) << "check size of identities";
+ ASSERT_TRUE(filter->extension().Read(2, 2, &v));
+ ASSERT_EQ(v, kPskId.size()) << "check size of identity";
+ bool different = false;
+ for (size_t i = 0; i < kPskId.size(); ++i) {
+ ASSERT_TRUE(filter->extension().Read(i + 4, 1, &v));
+ different |= v != static_cast<uint8_t>(kPskId[i]);
+ }
+ ASSERT_TRUE(different);
+}
+
+// If an earlier version is negotiated, False Start must be disabled.
+TEST_F(TlsConnectStreamTls13, EchDowngradeNoFalseStart) {
+ EnsureTlsSetup();
+ SetupEch(client_, server_, HpkeDhKemX25519Sha256, false, true, false);
+ MakeTlsFilter<TlsExtensionDropper>(client_,
+ ssl_tls13_encrypted_client_hello_xtn);
+ client_->EnableFalseStart();
+ client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2,
+ SSL_LIBRARY_VERSION_TLS_1_3);
+ server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2,
+ SSL_LIBRARY_VERSION_TLS_1_2);
+
+ StartConnect();
+ client_->Handshake();
+ server_->Handshake();
+ client_->Handshake();
+ EXPECT_FALSE(client_->can_falsestart_hook_called());
+
+ // Make sure the write is blocked.
+ client_->ExpectReadWriteError();
+ client_->SendData(10);
+}
+
+SSLHelloRetryRequestAction RetryEchHello(PRBool firstHello,
+ const PRUint8* clientToken,
+ unsigned int clientTokenLen,
+ PRUint8* appToken,
+ unsigned int* appTokenLen,
+ unsigned int appTokenMax, void* arg) {
+ auto* called = reinterpret_cast<size_t*>(arg);
+ ++*called;
+
+ EXPECT_EQ(0U, clientTokenLen);
+ return firstHello ? ssl_hello_retry_request : ssl_hello_retry_accept;
+}
+
+// Generate HRR on CH1 Inner
+TEST_F(TlsConnectStreamTls13, EchAcceptWithHrr) {
+ ScopedSECKEYPublicKey pub;
+ ScopedSECKEYPrivateKey priv;
+ DataBuffer echconfig;
+ ConfigureSelfEncrypt();
+ EnsureTlsSetup();
+ TlsConnectTestBase::GenerateEchConfig(HpkeDhKemX25519Sha256, kDefaultSuites,
+ kPublicName, 100, echconfig, pub, priv);
+ ASSERT_EQ(SECSuccess,
+ SSL_SetServerEchConfigs(server_->ssl_fd(), pub.get(), priv.get(),
+ echconfig.data(), echconfig.len()));
+ ASSERT_EQ(SECSuccess,
+ SSL_SetClientEchConfigs(client_->ssl_fd(), echconfig.data(),
+ echconfig.len()));
+ client_->ExpectEch();
+ server_->ExpectEch();
+
+ size_t cb_called = 0;
+ EXPECT_EQ(SECSuccess, SSL_HelloRetryRequestCallback(
+ server_->ssl_fd(), RetryEchHello, &cb_called));
+
+ auto server_hrr_ech_xtn = MakeTlsFilter<TlsExtensionCapture>(
+ server_, ssl_tls13_encrypted_client_hello_xtn);
+ // Start the handshake.
+ client_->StartConnect();
+ server_->StartConnect();
+ client_->Handshake();
+ server_->Handshake();
+ MakeNewServer();
+ ASSERT_EQ(SECSuccess,
+ SSL_SetServerEchConfigs(server_->ssl_fd(), pub.get(), priv.get(),
+ echconfig.data(), echconfig.len()));
+ client_->ExpectEch();
+ server_->ExpectEch();
+ Handshake();
+ ASSERT_TRUE(server_hrr_ech_xtn->captured());
+ EXPECT_EQ(1U, cb_called);
+ CheckConnected();
+ SendReceive();
+}
+
+TEST_F(TlsConnectStreamTls13Ech, EchGreaseSize) {
+ EnsureTlsSetup();
+ EXPECT_EQ(SECSuccess, SSL_EnableTls13GreaseEch(client_->ssl_fd(), PR_TRUE));
+
+ auto greased_ext = MakeTlsFilter<TlsExtensionCapture>(
+ client_, ssl_tls13_encrypted_client_hello_xtn);
+ Connect();
+ ASSERT_TRUE(greased_ext->captured());
+
+ Reset();
+ EnsureTlsSetup();
+
+ ScopedSECKEYPublicKey pub;
+ ScopedSECKEYPrivateKey priv;
+ ImportFixedEchKeypair(pub, priv);
+ SetMutualEchConfigs(pub, priv);
+
+ auto real_ext = MakeTlsFilter<TlsExtensionCapture>(
+ client_, ssl_tls13_encrypted_client_hello_xtn);
+ client_->ExpectEch();
+ server_->ExpectEch();
+ Connect();
+
+ ASSERT_TRUE(real_ext->captured());
+ ASSERT_EQ(real_ext->extension().len(), greased_ext->extension().len());
+}
+
+TEST_F(TlsConnectStreamTls13Ech, EchGreaseClientDisable) {
+ ScopedSECKEYPublicKey pub;
+ ScopedSECKEYPrivateKey priv;
+ DataBuffer echconfig;
+ EnsureTlsSetup();
+ TlsConnectTestBase::GenerateEchConfig(HpkeDhKemX25519Sha256, kDefaultSuites,
+ kPublicName, 100, echconfig, pub, priv);
+ ASSERT_EQ(SECSuccess,
+ SSL_SetServerEchConfigs(server_->ssl_fd(), pub.get(), priv.get(),
+ echconfig.data(), echconfig.len()));
+ EXPECT_EQ(SECSuccess, SSL_EnableTls13GreaseEch(client_->ssl_fd(), PR_FALSE));
+
+ auto c_filter_esni = MakeTlsFilter<TlsExtensionCapture>(
+ client_, ssl_tls13_encrypted_client_hello_xtn);
+
+ Connect();
+ ASSERT_TRUE(!c_filter_esni->captured());
+}
+
+TEST_F(TlsConnectStreamTls13Ech, EchHrrGreaseServerDisable) {
+ ScopedSECKEYPublicKey pub;
+ ScopedSECKEYPrivateKey priv;
+ DataBuffer echconfig;
+ ConfigureSelfEncrypt();
+ EnsureTlsSetup();
+ EXPECT_EQ(SECSuccess, SSL_EnableTls13GreaseEch(client_->ssl_fd(), PR_TRUE));
+ EXPECT_EQ(SECSuccess, SSL_EnableTls13GreaseEch(server_->ssl_fd(), PR_FALSE));
+ size_t cb_called = 0;
+ EXPECT_EQ(SECSuccess, SSL_HelloRetryRequestCallback(
+ server_->ssl_fd(), RetryEchHello, &cb_called));
+
+ auto server_hrr_ech_xtn = MakeTlsFilter<TlsExtensionCapture>(
+ server_, ssl_tls13_encrypted_client_hello_xtn);
+ // Start the handshake.
+ client_->StartConnect();
+ server_->StartConnect();
+ client_->Handshake();
+ server_->Handshake();
+ MakeNewServer();
+ EXPECT_EQ(SECSuccess, SSL_EnableTls13GreaseEch(server_->ssl_fd(), PR_FALSE));
+ Handshake();
+ ASSERT_TRUE(!server_hrr_ech_xtn->captured());
+ EXPECT_EQ(1U, cb_called);
+ CheckConnected();
+ SendReceive();
+}
+
+TEST_F(TlsConnectStreamTls13Ech, EchGreaseSizePsk) {
+ // Original connection without ECH
+ ConfigureSessionCache(RESUME_BOTH, RESUME_BOTH);
+ Connect();
+ SendReceive();
+
+ // Resumption with only GREASE
+ Reset();
+ ConfigureSessionCache(RESUME_BOTH, RESUME_BOTH);
+ ExpectResumption(RESUME_TICKET);
+ EXPECT_EQ(SECSuccess, SSL_EnableTls13GreaseEch(client_->ssl_fd(), PR_TRUE));
+
+ auto greased_ext = MakeTlsFilter<TlsExtensionCapture>(
+ client_, ssl_tls13_encrypted_client_hello_xtn);
+ Connect();
+ SendReceive();
+ ASSERT_TRUE(greased_ext->captured());
+
+ // Finally, resume with ECH enabled
+ // ECH state does not determine whether resumption succeeds
+ // or is attempted, so this should work fine.
+ Reset();
+ ConfigureSessionCache(RESUME_BOTH, RESUME_BOTH);
+ ExpectResumption(RESUME_TICKET, 2);
+
+ ScopedSECKEYPublicKey pub;
+ ScopedSECKEYPrivateKey priv;
+ ImportFixedEchKeypair(pub, priv);
+ SetMutualEchConfigs(pub, priv);
+
+ auto real_ext = MakeTlsFilter<TlsExtensionCapture>(
+ client_, ssl_tls13_encrypted_client_hello_xtn);
+ client_->ExpectEch();
+ server_->ExpectEch();
+ Connect();
+ ASSERT_TRUE(real_ext->captured());
+
+ ASSERT_EQ(real_ext->extension().len(), greased_ext->extension().len());
+}
+
+// Send GREASE ECH in CH1. CH2 must send exactly the same GREASE ECH contents.
+TEST_F(TlsConnectStreamTls13, GreaseEchHrrMatches) {
+ ConfigureSelfEncrypt();
+ EnsureTlsSetup();
+ size_t cb_called = 0;
+ EXPECT_EQ(SECSuccess, SSL_HelloRetryRequestCallback(
+ server_->ssl_fd(), RetryEchHello, &cb_called));
+
+ EXPECT_EQ(SECSuccess, SSL_EnableTls13GreaseEch(client_->ssl_fd(),
+ PR_TRUE)); // GREASE
+ auto capture = MakeTlsFilter<TlsExtensionCapture>(
+ client_, ssl_tls13_encrypted_client_hello_xtn);
+
+ // Start the handshake.
+ client_->StartConnect();
+ server_->StartConnect();
+ client_->Handshake(); // Send CH1
+ EXPECT_TRUE(capture->captured());
+ DataBuffer ch1_grease = capture->extension();
+
+ server_->Handshake();
+ MakeNewServer();
+ capture = MakeTlsFilter<TlsExtensionCapture>(
+ client_, ssl_tls13_encrypted_client_hello_xtn);
+
+ EXPECT_FALSE(capture->captured());
+ client_->Handshake(); // Send CH2
+ EXPECT_TRUE(capture->captured());
+ EXPECT_EQ(ch1_grease, capture->extension());
+
+ EXPECT_EQ(1U, cb_called);
+ server_->StartConnect();
+ Handshake();
+ CheckConnected();
+}
+
+TEST_F(TlsConnectStreamTls13Ech, EchRejectMisizedEchXtn) {
+ ScopedSECKEYPublicKey pub;
+ ScopedSECKEYPrivateKey priv;
+ DataBuffer echconfig;
+ ConfigureSelfEncrypt();
+ EnsureTlsSetup();
+ EXPECT_EQ(SECSuccess, SSL_EnableTls13GreaseEch(client_->ssl_fd(), PR_TRUE));
+ EXPECT_EQ(SECSuccess, SSL_EnableTls13GreaseEch(server_->ssl_fd(), PR_TRUE));
+ size_t cb_called = 0;
+ EXPECT_EQ(SECSuccess, SSL_HelloRetryRequestCallback(
+ server_->ssl_fd(), RetryEchHello, &cb_called));
+ auto server_hrr_ext_xtn_fake = MakeTlsFilter<TlsExtensionResizer>(
+ server_, ssl_tls13_encrypted_client_hello_xtn, 34);
+ // Start the handshake.
+ client_->StartConnect();
+ server_->StartConnect();
+ client_->Handshake();
+ server_->Handshake();
+ // Process the hello retry.
+ server_->ExpectReceiveAlert(kTlsAlertDecodeError, kTlsAlertFatal);
+ client_->ExpectSendAlert(kTlsAlertDecodeError);
+ Handshake();
+ client_->CheckErrorCode(SSL_ERROR_RX_MALFORMED_ECH_EXTENSION);
+ server_->CheckErrorCode(SSL_ERROR_DECODE_ERROR_ALERT);
+ EXPECT_EQ(1U, cb_called);
+}
+
+TEST_F(TlsConnectStreamTls13Ech, EchRejectDroppedEchXtn) {
+ ScopedSECKEYPublicKey pub;
+ ScopedSECKEYPrivateKey priv;
+ DataBuffer echconfig;
+ ConfigureSelfEncrypt();
+ EnsureTlsSetup();
+ TlsConnectTestBase::GenerateEchConfig(HpkeDhKemX25519Sha256, kDefaultSuites,
+ kPublicName, 100, echconfig, pub, priv);
+ ASSERT_EQ(SECSuccess,
+ SSL_SetServerEchConfigs(server_->ssl_fd(), pub.get(), priv.get(),
+ echconfig.data(), echconfig.len()));
+ ASSERT_EQ(SECSuccess,
+ SSL_SetClientEchConfigs(client_->ssl_fd(), echconfig.data(),
+ echconfig.len()));
+ size_t cb_called = 0;
+ EXPECT_EQ(SECSuccess, SSL_HelloRetryRequestCallback(
+ server_->ssl_fd(), RetryEchHello, &cb_called));
+ auto server_hrr_ext_xtn_fake = MakeTlsFilter<TlsExtensionDropper>(
+ server_, ssl_tls13_encrypted_client_hello_xtn);
+ // Start the handshake.
+ client_->StartConnect();
+ server_->StartConnect();
+ client_->Handshake();
+ server_->Handshake();
+ MakeNewServer();
+ ASSERT_EQ(SECSuccess,
+ SSL_SetServerEchConfigs(server_->ssl_fd(), pub.get(), priv.get(),
+ echconfig.data(), echconfig.len()));
+ // Process the hello retry.
+ server_->ExpectSendAlert(kTlsAlertBadRecordMac);
+ client_->ExpectSendAlert(kTlsAlertBadRecordMac);
+ Handshake();
+ client_->CheckErrorCode(SSL_ERROR_BAD_MAC_READ);
+ server_->CheckErrorCode(SSL_ERROR_BAD_MAC_READ);
+ EXPECT_EQ(1U, cb_called);
+}
+
+// Generate an HRR on CHInner. Mangle the Hrr Xtn causing client to reject ECH
+// which then causes a MAC mismatch.
+TEST_F(TlsConnectStreamTls13Ech, EchRejectMangledHrrXtn) {
+ ScopedSECKEYPublicKey pub;
+ ScopedSECKEYPrivateKey priv;
+ DataBuffer echconfig;
+ ConfigureSelfEncrypt();
+ EnsureTlsSetup();
+ TlsConnectTestBase::GenerateEchConfig(HpkeDhKemX25519Sha256, kDefaultSuites,
+ kPublicName, 100, echconfig, pub, priv);
+ ASSERT_EQ(SECSuccess,
+ SSL_SetServerEchConfigs(server_->ssl_fd(), pub.get(), priv.get(),
+ echconfig.data(), echconfig.len()));
+ ASSERT_EQ(SECSuccess,
+ SSL_SetClientEchConfigs(client_->ssl_fd(), echconfig.data(),
+ echconfig.len()));
+
+ size_t cb_called = 0;
+ EXPECT_EQ(SECSuccess, SSL_HelloRetryRequestCallback(
+ server_->ssl_fd(), RetryEchHello, &cb_called));
+ auto server_hrr_ech_xtn = MakeTlsFilter<TlsExtensionDamager>(
+ server_, ssl_tls13_encrypted_client_hello_xtn, 4);
+ // Start the handshake.
+ client_->StartConnect();
+ server_->StartConnect();
+ client_->Handshake();
+ server_->Handshake();
+ MakeNewServer();
+ ASSERT_EQ(SECSuccess,
+ SSL_SetServerEchConfigs(server_->ssl_fd(), pub.get(), priv.get(),
+ echconfig.data(), echconfig.len()));
+ client_->ExpectEch(false);
+ server_->ExpectEch(false);
+ server_->ExpectSendAlert(kTlsAlertBadRecordMac);
+ client_->ExpectSendAlert(kTlsAlertBadRecordMac);
+ Handshake();
+ client_->CheckErrorCode(SSL_ERROR_BAD_MAC_READ);
+ server_->CheckErrorCode(SSL_ERROR_BAD_MAC_READ);
+ EXPECT_EQ(1U, cb_called);
+}
+
+// First capture an ECH CH Xtn.
+// Start new connection, inject ECH CH Xtn.
+// Server will respond with ECH HRR Xtn.
+// Check Client correctly panics.
+TEST_F(TlsConnectStreamTls13Ech, EchClientRejectSpuriousHrrXtn) {
+ ScopedSECKEYPublicKey pub;
+ ScopedSECKEYPrivateKey priv;
+ DataBuffer echconfig;
+ ConfigureSelfEncrypt();
+ EnsureTlsSetup();
+ TlsConnectTestBase::GenerateEchConfig(HpkeDhKemX25519Sha256, kDefaultSuites,
+ kPublicName, 100, echconfig, pub, priv);
+ ASSERT_EQ(SECSuccess,
+ SSL_SetServerEchConfigs(server_->ssl_fd(), pub.get(), priv.get(),
+ echconfig.data(), echconfig.len()));
+ EXPECT_EQ(SECSuccess, SSL_EnableTls13GreaseEch(client_->ssl_fd(), PR_TRUE));
+ client_->ExpectEch(false);
+ server_->ExpectEch(false);
+ auto client_ech_xtn_capture = MakeTlsFilter<TlsExtensionCapture>(
+ client_, ssl_tls13_encrypted_client_hello_xtn);
+ Connect();
+ ASSERT_TRUE(client_ech_xtn_capture->captured());
+
+ // Now configure client without ECH. Server with ECH.
+ Reset();
+ EnsureTlsSetup();
+ EXPECT_EQ(SECSuccess, SSL_EnableTls13GreaseEch(client_->ssl_fd(), PR_FALSE));
+ ASSERT_EQ(SECSuccess,
+ SSL_SetServerEchConfigs(server_->ssl_fd(), pub.get(), priv.get(),
+ echconfig.data(), echconfig.len()));
+ client_->ExpectEch(false);
+ server_->ExpectEch(false);
+ size_t cb_called = 0;
+ EXPECT_EQ(SECSuccess, SSL_HelloRetryRequestCallback(
+ server_->ssl_fd(), RetryEchHello, &cb_called));
+
+ // Inject CH ECH Xtn into CH.
+ DataBuffer buff = DataBuffer(client_ech_xtn_capture->extension());
+ auto client_ech_xtn = MakeTlsFilter<TlsExtensionAppender>(
+ client_, kTlsHandshakeClientHello, ssl_tls13_encrypted_client_hello_xtn,
+ buff);
+
+ // Connect and check we see the HRR extension and alert.
+ auto server_hrr_ech_xtn = MakeTlsFilter<TlsExtensionCapture>(
+ server_, ssl_tls13_encrypted_client_hello_xtn);
+ server_hrr_ech_xtn->SetHandshakeTypes({kTlsHandshakeHelloRetryRequest});
+
+ ConnectExpectAlert(client_, kTlsAlertUnsupportedExtension);
+
+ client_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_EXTENSION);
+ server_->CheckErrorCode(SSL_ERROR_UNSUPPORTED_EXTENSION_ALERT);
+ ASSERT_TRUE(server_hrr_ech_xtn->captured());
+}
+
+// Fail to decrypt CH2. Unlike CH1, this generates an alert.
+TEST_F(TlsConnectStreamTls13, EchFailDecryptCH2) {
+ EnsureTlsSetup();
+ SetupEch(client_, server_);
+ size_t cb_called = 0;
+ EXPECT_EQ(SECSuccess, SSL_HelloRetryRequestCallback(
+ server_->ssl_fd(), RetryEchHello, &cb_called));
+
+ client_->StartConnect();
+ server_->StartConnect();
+ client_->Handshake();
+ server_->Handshake();
+ EXPECT_EQ(1U, cb_called);
+ // Stop the callback from being called in future handshakes.
+ EXPECT_EQ(SECSuccess,
+ SSL_HelloRetryRequestCallback(server_->ssl_fd(), nullptr, nullptr));
+
+ MakeTlsFilter<TlsExtensionDamager>(client_,
+ ssl_tls13_encrypted_client_hello_xtn, 80);
+ ExpectAlert(server_, kTlsAlertDecryptError);
+ Handshake();
+ client_->CheckErrorCode(SSL_ERROR_DECRYPT_ERROR_ALERT);
+ server_->CheckErrorCode(SSL_ERROR_RX_MALFORMED_ECH_EXTENSION);
+}
+
+// Change the ECH advertisement between CH1 and CH2. Use GREASE for simplicity.
+TEST_F(TlsConnectStreamTls13, EchHrrChangeCh2OfferingYN) {
+ ConfigureSelfEncrypt();
+ EnsureTlsSetup();
+ size_t cb_called = 0;
+ EXPECT_EQ(SECSuccess, SSL_HelloRetryRequestCallback(
+ server_->ssl_fd(), RetryEchHello, &cb_called));
+
+ // Start the handshake, send GREASE ECH.
+ EXPECT_EQ(SECSuccess, SSL_EnableTls13GreaseEch(client_->ssl_fd(),
+ PR_TRUE)); // GREASE
+ client_->StartConnect();
+ server_->StartConnect();
+ client_->Handshake();
+ server_->Handshake();
+ MakeNewServer();
+ EXPECT_EQ(SECSuccess, SSL_EnableTls13GreaseEch(client_->ssl_fd(),
+ PR_FALSE)); // Don't GREASE
+ ExpectAlert(server_, kTlsAlertMissingExtension);
+ Handshake();
+ client_->CheckErrorCode(SSL_ERROR_MISSING_EXTENSION_ALERT);
+ server_->CheckErrorCode(SSL_ERROR_BAD_2ND_CLIENT_HELLO);
+ EXPECT_EQ(1U, cb_called);
+}
+
+TEST_F(TlsConnectStreamTls13, EchHrrChangeCh2OfferingNY) {
+ ConfigureSelfEncrypt();
+ EnsureTlsSetup();
+ SetupEch(client_, server_);
+ size_t cb_called = 0;
+ EXPECT_EQ(SECSuccess, SSL_HelloRetryRequestCallback(
+ server_->ssl_fd(), RetryEchHello, &cb_called));
+
+ MakeTlsFilter<TlsExtensionDropper>(client_,
+ ssl_tls13_encrypted_client_hello_xtn);
+ // Start the handshake.
+ client_->StartConnect();
+ server_->StartConnect();
+ client_->Handshake();
+ server_->Handshake();
+ MakeNewServer();
+ client_->ClearFilter(); // Let the second ECH offering through.
+ ExpectAlert(server_, kTlsAlertIllegalParameter);
+ Handshake();
+ client_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT);
+ server_->CheckErrorCode(SSL_ERROR_BAD_2ND_CLIENT_HELLO);
+ EXPECT_EQ(1U, cb_called);
+}
+
+// Change the ECHCipherSuite between CH1 and CH2. Expect alert.
+TEST_F(TlsConnectStreamTls13, EchHrrChangeCipherSuite) {
+ ConfigureSelfEncrypt();
+ EnsureTlsSetup();
+ SetupEch(client_, server_);
+
+ size_t cb_called = 0;
+ EXPECT_EQ(SECSuccess, SSL_HelloRetryRequestCallback(
+ server_->ssl_fd(), RetryEchHello, &cb_called));
+ // Start the handshake and trigger HRR.
+ client_->StartConnect();
+ server_->StartConnect();
+ client_->Handshake();
+ server_->Handshake();
+ MakeNewServer();
+
+ // Damage the first byte of the ciphersuite (offset 1)
+ MakeTlsFilter<TlsExtensionDamager>(client_,
+ ssl_tls13_encrypted_client_hello_xtn, 1);
+
+ ExpectAlert(server_, kTlsAlertIllegalParameter);
+ Handshake();
+ client_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT);
+ server_->CheckErrorCode(SSL_ERROR_BAD_2ND_CLIENT_HELLO);
+ EXPECT_EQ(1U, cb_called);
+}
+
+// Configure an external PSK. Generate an HRR off CH1Inner (which contains
+// the PSK extension). Use the same PSK in CH2 and connect.
+TEST_F(TlsConnectStreamTls13, EchAcceptWithHrrAndPsk) {
+ ScopedSECKEYPublicKey pub;
+ ScopedSECKEYPrivateKey priv;
+ DataBuffer echconfig;
+ ConfigureSelfEncrypt();
+ EnsureTlsSetup();
+
+ TlsConnectTestBase::GenerateEchConfig(HpkeDhKemX25519Sha256, kDefaultSuites,
+ kPublicName, 100, echconfig, pub, priv);
+ ASSERT_EQ(SECSuccess,
+ SSL_SetServerEchConfigs(server_->ssl_fd(), pub.get(), priv.get(),
+ echconfig.data(), echconfig.len()));
+ ASSERT_EQ(SECSuccess,
+ SSL_SetClientEchConfigs(client_->ssl_fd(), echconfig.data(),
+ echconfig.len()));
+ client_->ExpectEch();
+ server_->ExpectEch();
+
+ size_t cb_called = 0;
+ EXPECT_EQ(SECSuccess, SSL_HelloRetryRequestCallback(
+ server_->ssl_fd(), RetryEchHello, &cb_called));
+
+ static const uint8_t key_buf[16] = {0};
+ SECItem key_item = {siBuffer, const_cast<uint8_t*>(&key_buf[0]),
+ sizeof(key_buf)};
+ const char* label = "foo";
+ ScopedPK11SlotInfo slot(PK11_GetInternalSlot());
+ ASSERT_TRUE(!!slot);
+ ScopedPK11SymKey key(PK11_ImportSymKey(slot.get(), CKM_HKDF_KEY_GEN,
+ PK11_OriginUnwrap, CKA_DERIVE,
+ &key_item, nullptr));
+ ASSERT_TRUE(!!key);
+ AddPsk(key, std::string(label), ssl_hash_sha256);
+
+ // Start the handshake.
+ client_->StartConnect();
+ server_->StartConnect();
+ client_->Handshake();
+ server_->Handshake();
+ MakeNewServer();
+ ASSERT_EQ(SECSuccess,
+ SSL_SetServerEchConfigs(server_->ssl_fd(), pub.get(), priv.get(),
+ echconfig.data(), echconfig.len()));
+ client_->ExpectEch();
+ server_->ExpectEch();
+ EXPECT_EQ(SECSuccess,
+ SSL_AddExternalPsk0Rtt(server_->ssl_fd(), key.get(),
+ reinterpret_cast<const uint8_t*>(label),
+ strlen(label), ssl_hash_sha256, 0, 1000));
+ server_->ExpectPsk();
+ Handshake();
+ EXPECT_EQ(1U, cb_called);
+ CheckConnected();
+ SendReceive();
+}
+
+// Generate an HRR on CHOuter. Reject ECH on the second CH.
+TEST_F(TlsConnectStreamTls13Ech, EchRejectWithHrr) {
+ ScopedSECKEYPublicKey pub;
+ ScopedSECKEYPrivateKey priv;
+ DataBuffer echconfig;
+ ConfigureSelfEncrypt();
+ EnsureTlsSetup();
+ SetupForEchRetry();
+
+ size_t cb_called = 0;
+ EXPECT_EQ(SECSuccess, SSL_HelloRetryRequestCallback(
+ server_->ssl_fd(), RetryEchHello, &cb_called));
+ EXPECT_EQ(SECSuccess, SSL_EnableTls13GreaseEch(server_->ssl_fd(), PR_TRUE));
+ auto server_hrr_ech_xtn = MakeTlsFilter<TlsExtensionCapture>(
+ server_, ssl_tls13_encrypted_client_hello_xtn);
+ // Start the handshake.
+ client_->StartConnect();
+ server_->StartConnect();
+ client_->Handshake();
+ server_->Handshake();
+ MakeNewServer();
+ EXPECT_EQ(SECSuccess, SSL_EnableTls13GreaseEch(server_->ssl_fd(), PR_TRUE));
+ client_->ExpectEch(false);
+ server_->ExpectEch(false);
+ ExpectAlert(client_, kTlsAlertEchRequired);
+ Handshake();
+ ASSERT_TRUE(server_hrr_ech_xtn->captured());
+ client_->CheckErrorCode(SSL_ERROR_ECH_RETRY_WITHOUT_ECH);
+ server_->ExpectReceiveAlert(kTlsAlertEchRequired, kTlsAlertFatal);
+ server_->Handshake();
+ EXPECT_EQ(1U, cb_called);
+}
+
+// Server can't change its mind on ECH after HRR. We change the confirmation
+// value and the server panics accordingly.
+TEST_F(TlsConnectStreamTls13Ech, EchHrrServerYN) {
+ ScopedSECKEYPublicKey pub;
+ ScopedSECKEYPrivateKey priv;
+ DataBuffer echconfig;
+ ConfigureSelfEncrypt();
+ EnsureTlsSetup();
+ TlsConnectTestBase::GenerateEchConfig(HpkeDhKemX25519Sha256, kDefaultSuites,
+ kPublicName, 100, echconfig, pub, priv);
+ ASSERT_EQ(SECSuccess,
+ SSL_SetServerEchConfigs(server_->ssl_fd(), pub.get(), priv.get(),
+ echconfig.data(), echconfig.len()));
+ ASSERT_EQ(SECSuccess,
+ SSL_SetClientEchConfigs(client_->ssl_fd(), echconfig.data(),
+ echconfig.len()));
+ client_->ExpectEch();
+ server_->ExpectEch();
+
+ size_t cb_called = 0;
+ EXPECT_EQ(SECSuccess, SSL_HelloRetryRequestCallback(
+ server_->ssl_fd(), RetryEchHello, &cb_called));
+
+ auto server_hrr_ech_xtn = MakeTlsFilter<TlsExtensionCapture>(
+ server_, ssl_tls13_encrypted_client_hello_xtn);
+ // Start the handshake.
+ client_->StartConnect();
+ server_->StartConnect();
+ client_->Handshake();
+ server_->Handshake();
+ MakeNewServer();
+ ASSERT_EQ(SECSuccess,
+ SSL_SetServerEchConfigs(server_->ssl_fd(), pub.get(), priv.get(),
+ echconfig.data(), echconfig.len()));
+ client_->ExpectEch();
+ server_->ExpectEch();
+ client_->ExpectSendAlert(kTlsAlertIllegalParameter);
+ server_->ExpectSendAlert(kTlsAlertUnexpectedMessage);
+ auto server_random_damager = MakeTlsFilter<ServerHelloRandomChanger>(server_);
+ Handshake();
+ client_->CheckErrorCode(SSL_ERROR_RX_MALFORMED_SERVER_HELLO);
+ ASSERT_TRUE(server_hrr_ech_xtn->captured());
+ EXPECT_EQ(1U, cb_called);
+}
+
+// Client sends GREASE'd ECH Xtn, server reponds with HRR in GREASE mode
+// Check HRR responses are present and differ.
+TEST_F(TlsConnectStreamTls13Ech, EchHrrServerGreaseChanges) {
+ ScopedSECKEYPublicKey pub;
+ ScopedSECKEYPrivateKey priv;
+ DataBuffer echconfig;
+ ConfigureSelfEncrypt();
+ EnsureTlsSetup();
+ EXPECT_EQ(SECSuccess, SSL_EnableTls13GreaseEch(client_->ssl_fd(), PR_TRUE));
+ EXPECT_EQ(SECSuccess, SSL_EnableTls13GreaseEch(server_->ssl_fd(), PR_TRUE));
+ size_t cb_called = 0;
+ EXPECT_EQ(SECSuccess, SSL_HelloRetryRequestCallback(
+ server_->ssl_fd(), RetryEchHello, &cb_called));
+
+ auto server_hrr_ech_xtn_1 = MakeTlsFilter<TlsExtensionCapture>(
+ server_, ssl_tls13_encrypted_client_hello_xtn);
+ // Start the handshake.
+ client_->StartConnect();
+ server_->StartConnect();
+ client_->Handshake();
+ server_->Handshake();
+ ASSERT_TRUE(server_hrr_ech_xtn_1->captured());
+ EXPECT_EQ(1U, cb_called);
+
+ /* Run the connection again */
+ Reset();
+ EnsureTlsSetup();
+ EXPECT_EQ(SECSuccess, SSL_EnableTls13GreaseEch(server_->ssl_fd(), PR_TRUE));
+ EXPECT_EQ(SECSuccess, SSL_EnableTls13GreaseEch(client_->ssl_fd(), PR_TRUE));
+ cb_called = 0;
+ EXPECT_EQ(SECSuccess, SSL_HelloRetryRequestCallback(
+ server_->ssl_fd(), RetryEchHello, &cb_called));
+
+ auto server_hrr_ech_xtn_2 = MakeTlsFilter<TlsExtensionCapture>(
+ server_, ssl_tls13_encrypted_client_hello_xtn);
+ // Start the handshake.
+ client_->StartConnect();
+ server_->StartConnect();
+ client_->Handshake();
+ server_->Handshake();
+ ASSERT_TRUE(server_hrr_ech_xtn_2->captured());
+ EXPECT_EQ(1U, cb_called);
+
+ ASSERT_TRUE(server_hrr_ech_xtn_1->extension().len() ==
+ server_hrr_ech_xtn_2->extension().len());
+ ASSERT_TRUE(memcmp(server_hrr_ech_xtn_1->extension().data(),
+ server_hrr_ech_xtn_2->extension().data(),
+ server_hrr_ech_xtn_1->extension().len()));
+}
+
+// Reject ECH on CH1 and CH2. PSKs are no longer allowed
+// in CHOuter, but we can still make sure the handshake succeeds.
+// This prompts an ech_required alert when the handshake completes.
+TEST_F(TlsConnectStreamTls13, EchRejectWithHrrAndPsk) {
+ ScopedSECKEYPublicKey pub;
+ ScopedSECKEYPrivateKey priv;
+ DataBuffer echconfig;
+ ConfigureSelfEncrypt();
+ EnsureTlsSetup();
+ TlsConnectTestBase::GenerateEchConfig(HpkeDhKemX25519Sha256, kDefaultSuites,
+ kPublicName, 100, echconfig, pub, priv);
+ ASSERT_EQ(SECSuccess,
+ SSL_SetClientEchConfigs(client_->ssl_fd(), echconfig.data(),
+ echconfig.len()));
+
+ size_t cb_called = 0;
+ EXPECT_EQ(SECSuccess, SSL_HelloRetryRequestCallback(
+ server_->ssl_fd(), RetryEchHello, &cb_called));
+
+ // Add a PSK to both endpoints.
+ static const uint8_t key_buf[16] = {0};
+ SECItem key_item = {siBuffer, const_cast<uint8_t*>(&key_buf[0]),
+ sizeof(key_buf)};
+ const char* label = "foo";
+ ScopedPK11SlotInfo slot(PK11_GetInternalSlot());
+ ASSERT_TRUE(!!slot);
+ ScopedPK11SymKey key(PK11_ImportSymKey(slot.get(), CKM_HKDF_KEY_GEN,
+ PK11_OriginUnwrap, CKA_DERIVE,
+ &key_item, nullptr));
+ ASSERT_TRUE(!!key);
+ AddPsk(key, std::string(label), ssl_hash_sha256);
+ client_->ExpectPsk(ssl_psk_none);
+
+ // Start the handshake.
+ client_->StartConnect();
+ server_->StartConnect();
+ client_->Handshake();
+ server_->Handshake();
+ MakeNewServer();
+ client_->ExpectEch(false);
+ server_->ExpectEch(false);
+ EXPECT_EQ(SECSuccess,
+ SSL_AddExternalPsk0Rtt(server_->ssl_fd(), key.get(),
+ reinterpret_cast<const uint8_t*>(label),
+ strlen(label), ssl_hash_sha256, 0, 1000));
+ // Don't call ExpectPsk
+ ExpectAlert(client_, kTlsAlertEchRequired);
+ Handshake();
+ client_->CheckErrorCode(SSL_ERROR_ECH_RETRY_WITHOUT_ECH);
+ server_->ExpectReceiveAlert(kTlsAlertEchRequired, kTlsAlertFatal);
+ server_->Handshake();
+ EXPECT_EQ(1U, cb_called);
+}
+
+// ECH (both connections), resumption rejected.
+TEST_F(TlsConnectStreamTls13, EchRejectResume) {
+ EnsureTlsSetup();
+ ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET);
+ SetupEch(client_, server_);
+ Connect();
+ SendReceive();
+
+ Reset();
+ ClearServerCache(); // Invalidate the ticket
+ ConfigureSessionCache(RESUME_BOTH, RESUME_NONE);
+ ExpectResumption(RESUME_NONE);
+ SetupEch(client_, server_);
+ Connect();
+ SendReceive();
+}
+
+// ECH (both connections) + 0-RTT
+TEST_F(TlsConnectStreamTls13, EchZeroRttBoth) {
+ EnsureTlsSetup();
+ SetupEch(client_, server_);
+ SetupForZeroRtt();
+ client_->Set0RttEnabled(true);
+ server_->Set0RttEnabled(true);
+ SetupEch(client_, server_);
+ ExpectResumption(RESUME_TICKET);
+ ZeroRttSendReceive(true, true);
+ Handshake();
+ ExpectEarlyDataAccepted(true);
+ CheckConnected();
+ SendReceive();
+}
+
+// ECH (first connection only) + 0-RTT
+TEST_F(TlsConnectStreamTls13, EchZeroRttFirst) {
+ EnsureTlsSetup();
+ SetupEch(client_, server_);
+ SetupForZeroRtt();
+ client_->Set0RttEnabled(true);
+ server_->Set0RttEnabled(true);
+ ExpectResumption(RESUME_TICKET);
+ ZeroRttSendReceive(true, true);
+ Handshake();
+ ExpectEarlyDataAccepted(true);
+ CheckConnected();
+ SendReceive();
+}
+
+// ECH (second connection only) + 0-RTT
+TEST_F(TlsConnectStreamTls13, EchZeroRttSecond) {
+ EnsureTlsSetup();
+ SetupForZeroRtt(); // Get a ticket
+ client_->Set0RttEnabled(true);
+ server_->Set0RttEnabled(true);
+ SetupEch(client_, server_);
+ ExpectResumption(RESUME_TICKET);
+ ZeroRttSendReceive(true, true);
+ Handshake();
+ ExpectEarlyDataAccepted(true);
+ CheckConnected();
+ SendReceive();
+}
+
+// ECH (first connection only, reject on second) + 0-RTT
+TEST_F(TlsConnectStreamTls13, EchZeroRttRejectSecond) {
+ EnsureTlsSetup();
+ SetupEch(client_, server_);
+ SetupForZeroRtt();
+ client_->Set0RttEnabled(true);
+ server_->Set0RttEnabled(true);
+
+ // Setup ECH only on the client.
+ SetupEch(client_, server_, HpkeDhKemX25519Sha256, false, true, false);
+
+ ExpectResumption(RESUME_NONE);
+ ExpectAlert(client_, kTlsAlertEchRequired);
+ ZeroRttSendReceive(true, false);
+ server_->Handshake();
+ client_->Handshake();
+ client_->CheckErrorCode(SSL_ERROR_ECH_RETRY_WITHOUT_ECH);
+
+ ExpectEarlyDataAccepted(false);
+ server_->ExpectReceiveAlert(kTlsAlertEchRequired, kTlsAlertFatal);
+ server_->Handshake();
+ // Reset expectations for the TlsAgent dtor.
+ server_->ExpectReceiveAlert(kTlsAlertCloseNotify, kTlsAlertWarning);
+}
+
+// Test a critical extension in ECHConfig
+TEST_F(TlsConnectStreamTls13, EchRejectUnknownCriticalExtension) {
+ EnsureTlsSetup();
+ ScopedSECKEYPublicKey pub;
+ ScopedSECKEYPrivateKey priv;
+ DataBuffer echconfig;
+ DataBuffer crit_rec;
+ DataBuffer len_buf;
+ uint64_t tmp;
+
+ static const uint8_t crit_extensions[] = {0x00, 0x04, 0xff, 0xff, 0x00, 0x00};
+ static const uint8_t extensions[] = {0x00, 0x04, 0x7f, 0xff, 0x00, 0x00};
+ DataBuffer crit_exts(crit_extensions, sizeof(crit_extensions));
+ DataBuffer non_crit_exts(extensions, sizeof(extensions));
+
+ TlsConnectTestBase::GenerateEchConfig(HpkeDhKemX25519Sha256, kSuiteChaCha,
+ kPublicName, 100, echconfig, pub, priv);
+ echconfig.Truncate(echconfig.len() - 2); // Eat the empty extensions.
+ crit_rec.Assign(echconfig);
+ ASSERT_TRUE(crit_rec.Read(0, 2, &tmp));
+ len_buf.Write(0, tmp + crit_exts.len() - 2, 2); // two bytes of length
+ crit_rec.Splice(len_buf, 0, 2);
+ len_buf.Truncate(0);
+
+ ASSERT_TRUE(crit_rec.Read(4, 2, &tmp));
+ len_buf.Write(0, tmp + crit_exts.len() - 2, 2); // two bytes of length
+ crit_rec.Append(crit_exts);
+ crit_rec.Splice(len_buf, 4, 2);
+ len_buf.Truncate(0);
+
+ ASSERT_TRUE(echconfig.Read(0, 2, &tmp));
+ len_buf.Write(0, tmp + non_crit_exts.len() - 2, 2);
+ echconfig.Append(non_crit_exts);
+ echconfig.Splice(len_buf, 0, 2);
+ ASSERT_TRUE(echconfig.Read(4, 2, &tmp));
+ len_buf.Write(0, tmp + non_crit_exts.len() - 2, 2);
+ echconfig.Splice(len_buf, 4, 2);
+
+ /* Expect that retry configs containing unsupported mandatory extensions can
+ * not be set and lead to SEC_ERROR_INVALID_ARGS. */
+ EXPECT_EQ(SECFailure,
+ SSL_SetClientEchConfigs(client_->ssl_fd(), crit_rec.data(),
+ crit_rec.len()));
+ EXPECT_EQ(SEC_ERROR_INVALID_ARGS, PORT_GetError());
+ EXPECT_EQ(SECSuccess, SSL_EnableTls13GreaseEch(client_->ssl_fd(),
+ PR_FALSE)); // Don't GREASE
+ auto filter = MakeTlsFilter<TlsExtensionCapture>(
+ client_, ssl_tls13_encrypted_client_hello_xtn);
+ StartConnect();
+ client_->Handshake();
+ ASSERT_EQ(TlsAgent::STATE_CONNECTING, client_->state());
+ ASSERT_FALSE(filter->captured());
+
+ // Now try a variant with non-critical extensions, it should work.
+ Reset();
+ EnsureTlsSetup();
+ EXPECT_EQ(SECSuccess,
+ SSL_SetClientEchConfigs(client_->ssl_fd(), echconfig.data(),
+ echconfig.len()));
+ filter = MakeTlsFilter<TlsExtensionCapture>(
+ client_, ssl_tls13_encrypted_client_hello_xtn);
+ StartConnect();
+ client_->Handshake();
+ ASSERT_EQ(TlsAgent::STATE_CONNECTING, client_->state());
+ ASSERT_TRUE(filter->captured());
+}
+
+// Secure disable without ECH
+TEST_F(TlsConnectStreamTls13, EchRejectAuthCertSuccessNoRetries) {
+ EnsureTlsSetup();
+ SetupEch(client_, server_, HpkeDhKemX25519Sha256, false, true, false);
+ ExpectAlert(client_, kTlsAlertEchRequired);
+ ConnectExpectFailOneSide(TlsAgent::CLIENT);
+ client_->CheckErrorCode(SSL_ERROR_ECH_RETRY_WITHOUT_ECH);
+ server_->ExpectReceiveAlert(kTlsAlertEchRequired, kTlsAlertFatal);
+ server_->Handshake();
+ // Reset expectations for the TlsAgent dtor.
+ server_->ExpectReceiveAlert(kTlsAlertCloseNotify, kTlsAlertWarning);
+}
+
+// When authenticating to the public name, the client MUST NOT
+// send a certificate in response to a certificate request.
+TEST_F(TlsConnectStreamTls13, EchRejectSuppressClientCert) {
+ EnsureTlsSetup();
+ SetupEch(client_, server_, HpkeDhKemX25519Sha256, false, true, false);
+ client_->SetupClientAuth();
+ server_->RequestClientAuth(true);
+ auto cert_capture =
+ MakeTlsFilter<TlsHandshakeRecorder>(client_, kTlsHandshakeCertificate);
+ cert_capture->EnableDecryption();
+
+ StartConnect();
+ client_->ExpectSendAlert(kTlsAlertEchRequired);
+ server_->ExpectSendAlert(kTlsAlertCertificateRequired);
+ ConnectExpectFail();
+
+ static const uint8_t empty_cert[4] = {0};
+ EXPECT_EQ(DataBuffer(empty_cert, sizeof(empty_cert)), cert_capture->buffer());
+}
+
+// Secure disable with incompatible ECHConfig
+TEST_F(TlsConnectStreamTls13, EchRejectAuthCertSuccessIncompatibleRetries) {
+ EnsureTlsSetup();
+ ScopedSECKEYPublicKey server_pub;
+ ScopedSECKEYPrivateKey server_priv;
+ ScopedSECKEYPublicKey client_pub;
+ ScopedSECKEYPrivateKey client_priv;
+ DataBuffer server_rec;
+ DataBuffer client_rec;
+
+ TlsConnectTestBase::GenerateEchConfig(HpkeDhKemX25519Sha256, kSuiteChaCha,
+ kPublicName, 100, server_rec,
+ server_pub, server_priv);
+ ASSERT_EQ(SECSuccess,
+ SSL_SetServerEchConfigs(server_->ssl_fd(), server_pub.get(),
+ server_priv.get(), server_rec.data(),
+ server_rec.len()));
+
+ TlsConnectTestBase::GenerateEchConfig(HpkeDhKemX25519Sha256, kSuiteAes,
+ kPublicName, 100, client_rec,
+ client_pub, client_priv);
+ ASSERT_EQ(SECSuccess,
+ SSL_SetClientEchConfigs(client_->ssl_fd(), client_rec.data(),
+ client_rec.len()));
+
+ // Change the first ECHConfig version to one we don't understand.
+ server_rec.Write(2, 0xfefe, 2);
+ // Skip the ECHConfigs length, the server sender will re-encode.
+ ASSERT_EQ(SECSuccess, SSLInt_SetRawEchConfigForRetry(server_->ssl_fd(),
+ &server_rec.data()[2],
+ server_rec.len() - 2));
+
+ ExpectAlert(client_, kTlsAlertEchRequired);
+ ConnectExpectFailOneSide(TlsAgent::CLIENT);
+ client_->CheckErrorCode(SSL_ERROR_ECH_RETRY_WITHOUT_ECH);
+ server_->ExpectReceiveAlert(kTlsAlertEchRequired, kTlsAlertFatal);
+ server_->Handshake();
+ // Reset expectations for the TlsAgent dtor.
+ server_->ExpectReceiveAlert(kTlsAlertCloseNotify, kTlsAlertWarning);
+}
+
+// Check that an otherwise-accepted ECH fails expectedly
+// with a bad certificate.
+TEST_F(TlsConnectStreamTls13, EchRejectAuthCertFail) {
+ EnsureTlsSetup();
+ SetupEch(client_, server_);
+ client_->SetAuthCertificateCallback(AuthCompleteFail);
+ ConnectExpectAlert(client_, kTlsAlertBadCertificate);
+ client_->CheckErrorCode(SSL_ERROR_BAD_CERTIFICATE);
+ server_->CheckErrorCode(SSL_ERROR_BAD_CERT_ALERT);
+ EXPECT_EQ(TlsAgent::STATE_ERROR, client_->state());
+}
+
+TEST_F(TlsConnectStreamTls13Ech, EchShortClientEncryptedCH) {
+ EnsureTlsSetup();
+ SetupForEchRetry();
+ auto filter = MakeTlsFilter<TlsExtensionResizer>(
+ client_, ssl_tls13_encrypted_client_hello_xtn, 1);
+ ConnectExpectAlert(server_, kTlsAlertDecodeError);
+ client_->CheckErrorCode(SSL_ERROR_DECODE_ERROR_ALERT);
+ server_->CheckErrorCode(SSL_ERROR_RX_MALFORMED_ECH_EXTENSION);
+}
+
+TEST_F(TlsConnectStreamTls13Ech, EchLongClientEncryptedCH) {
+ EnsureTlsSetup();
+ SetupForEchRetry();
+ auto filter = MakeTlsFilter<TlsExtensionResizer>(
+ client_, ssl_tls13_encrypted_client_hello_xtn, 1000);
+ ConnectExpectAlert(server_, kTlsAlertDecodeError);
+ client_->CheckErrorCode(SSL_ERROR_DECODE_ERROR_ALERT);
+ server_->CheckErrorCode(SSL_ERROR_RX_MALFORMED_ECH_EXTENSION);
+}
+
+TEST_F(TlsConnectStreamTls13Ech, EchShortServerEncryptedCH) {
+ EnsureTlsSetup();
+ SetupForEchRetry();
+ auto filter = MakeTlsFilter<TlsExtensionResizer>(
+ server_, ssl_tls13_encrypted_client_hello_xtn, 1);
+ filter->EnableDecryption();
+ ConnectExpectAlert(client_, kTlsAlertDecodeError);
+ client_->CheckErrorCode(SSL_ERROR_RX_MALFORMED_ECH_CONFIG);
+ server_->CheckErrorCode(SSL_ERROR_DECODE_ERROR_ALERT);
+}
+
+TEST_F(TlsConnectStreamTls13Ech, EchLongServerEncryptedCH) {
+ EnsureTlsSetup();
+ SetupForEchRetry();
+ auto filter = MakeTlsFilter<TlsExtensionResizer>(
+ server_, ssl_tls13_encrypted_client_hello_xtn, 1000);
+ filter->EnableDecryption();
+ ConnectExpectAlert(client_, kTlsAlertDecodeError);
+ client_->CheckErrorCode(SSL_ERROR_RX_MALFORMED_ECH_CONFIG);
+ server_->CheckErrorCode(SSL_ERROR_DECODE_ERROR_ALERT);
+}
+
+// Check that if authCertificate fails, retry_configs
+// are not available to the application.
+TEST_F(TlsConnectStreamTls13Ech, EchInsecureFallbackNoRetries) {
+ EnsureTlsSetup();
+ StackSECItem retry_configs;
+ SetupForEchRetry();
+
+ // Use the filter to make sure retry_configs are sent.
+ auto filter = MakeTlsFilter<TlsExtensionCapture>(
+ server_, ssl_tls13_encrypted_client_hello_xtn);
+ filter->EnableDecryption();
+
+ client_->SetAuthCertificateCallback(AuthCompleteFail);
+ ConnectExpectAlert(client_, kTlsAlertBadCertificate);
+ client_->CheckErrorCode(SSL_ERROR_BAD_CERTIFICATE);
+ server_->CheckErrorCode(SSL_ERROR_BAD_CERT_ALERT);
+ EXPECT_EQ(TlsAgent::STATE_ERROR, client_->state());
+ EXPECT_EQ(SECFailure,
+ SSL_GetEchRetryConfigs(client_->ssl_fd(), &retry_configs));
+ EXPECT_EQ(SSL_ERROR_HANDSHAKE_NOT_COMPLETED, PORT_GetError());
+ ASSERT_EQ(0U, retry_configs.len);
+ EXPECT_TRUE(filter->captured());
+}
+
+// Test that mismatched ECHConfigContents triggers a retry.
+TEST_F(TlsConnectStreamTls13Ech, EchMismatchHpkeCiphersRetry) {
+ EnsureTlsSetup();
+ ScopedSECKEYPublicKey server_pub;
+ ScopedSECKEYPrivateKey server_priv;
+ ScopedSECKEYPublicKey client_pub;
+ ScopedSECKEYPrivateKey client_priv;
+ DataBuffer server_rec;
+ DataBuffer client_rec;
+
+ TlsConnectTestBase::GenerateEchConfig(HpkeDhKemX25519Sha256, kSuiteChaCha,
+ kPublicName, 100, server_rec,
+ server_pub, server_priv);
+ TlsConnectTestBase::GenerateEchConfig(HpkeDhKemX25519Sha256, kSuiteAes,
+ kPublicName, 100, client_rec,
+ client_pub, client_priv);
+
+ ASSERT_EQ(SECSuccess,
+ SSL_SetServerEchConfigs(server_->ssl_fd(), server_pub.get(),
+ server_priv.get(), server_rec.data(),
+ server_rec.len()));
+ ASSERT_EQ(SECSuccess,
+ SSL_SetClientEchConfigs(client_->ssl_fd(), client_rec.data(),
+ client_rec.len()));
+
+ ExpectAlert(client_, kTlsAlertEchRequired);
+ ConnectExpectFailOneSide(TlsAgent::CLIENT);
+ client_->CheckErrorCode(SSL_ERROR_ECH_RETRY_WITH_ECH);
+ server_->ExpectReceiveAlert(kTlsAlertEchRequired, kTlsAlertFatal);
+ server_->Handshake();
+ DoEchRetry(server_pub, server_priv, server_rec);
+}
+
+// Test that mismatched ECH server keypair triggers a retry.
+TEST_F(TlsConnectStreamTls13Ech, EchMismatchKeysRetry) {
+ EnsureTlsSetup();
+ ScopedSECKEYPublicKey server_pub;
+ ScopedSECKEYPrivateKey server_priv;
+ ScopedSECKEYPublicKey client_pub;
+ ScopedSECKEYPrivateKey client_priv;
+ DataBuffer server_rec;
+ DataBuffer client_rec;
+ TlsConnectTestBase::GenerateEchConfig(HpkeDhKemX25519Sha256, kDefaultSuites,
+ kPublicName, 100, server_rec,
+ server_pub, server_priv);
+ TlsConnectTestBase::GenerateEchConfig(HpkeDhKemX25519Sha256, kDefaultSuites,
+ kPublicName, 100, client_rec,
+ client_pub, client_priv);
+ ASSERT_EQ(SECSuccess,
+ SSL_SetServerEchConfigs(server_->ssl_fd(), server_pub.get(),
+ server_priv.get(), server_rec.data(),
+ server_rec.len()));
+ ASSERT_EQ(SECSuccess,
+ SSL_SetClientEchConfigs(client_->ssl_fd(), client_rec.data(),
+ client_rec.len()));
+
+ client_->ExpectSendAlert(kTlsAlertEchRequired);
+ ConnectExpectFailOneSide(TlsAgent::CLIENT);
+ client_->CheckErrorCode(SSL_ERROR_ECH_RETRY_WITH_ECH);
+ server_->ExpectReceiveAlert(kTlsAlertEchRequired, kTlsAlertFatal);
+ server_->Handshake();
+ DoEchRetry(server_pub, server_priv, server_rec);
+}
+
+// Check that the client validates any server response to GREASE ECH
+TEST_F(TlsConnectStreamTls13, EchValidateGreaseResponse) {
+ EnsureTlsSetup();
+ ScopedSECKEYPublicKey server_pub;
+ ScopedSECKEYPrivateKey server_priv;
+ DataBuffer server_rec;
+ TlsConnectTestBase::GenerateEchConfig(HpkeDhKemX25519Sha256, kDefaultSuites,
+ kPublicName, 100, server_rec,
+ server_pub, server_priv);
+ ASSERT_EQ(SECSuccess,
+ SSL_SetServerEchConfigs(server_->ssl_fd(), server_pub.get(),
+ server_priv.get(), server_rec.data(),
+ server_rec.len()));
+
+ // Damage the length and expect an alert.
+ auto filter = MakeTlsFilter<TlsExtensionDamager>(
+ server_, ssl_tls13_encrypted_client_hello_xtn, 0);
+ filter->EnableDecryption();
+ EXPECT_EQ(SECSuccess, SSL_EnableTls13GreaseEch(client_->ssl_fd(),
+ PR_TRUE)); // GREASE
+ ConnectExpectAlert(client_, kTlsAlertDecodeError);
+ client_->CheckErrorCode(SSL_ERROR_RX_MALFORMED_ECH_CONFIG);
+ server_->CheckErrorCode(SSL_ERROR_DECODE_ERROR_ALERT);
+
+ // If the retry_config contains an unknown version, it should be ignored.
+ Reset();
+ EnsureTlsSetup();
+ ASSERT_EQ(SECSuccess,
+ SSL_SetServerEchConfigs(server_->ssl_fd(), server_pub.get(),
+ server_priv.get(), server_rec.data(),
+ server_rec.len()));
+ server_rec.Write(2, 0xfefe, 2);
+ // Skip the ECHConfigs length, the server sender will re-encode.
+ ASSERT_EQ(SECSuccess, SSLInt_SetRawEchConfigForRetry(server_->ssl_fd(),
+ &server_rec.data()[2],
+ server_rec.len() - 2));
+ EXPECT_EQ(SECSuccess, SSL_EnableTls13GreaseEch(client_->ssl_fd(),
+ PR_TRUE)); // GREASE
+ Connect();
+
+ // Lastly, if we DO support the retry_config, GREASE ECH should ignore it.
+ Reset();
+ EnsureTlsSetup();
+ server_rec.Write(2, ssl_tls13_encrypted_client_hello_xtn, 2);
+ ASSERT_EQ(SECSuccess,
+ SSL_SetServerEchConfigs(server_->ssl_fd(), server_pub.get(),
+ server_priv.get(), server_rec.data(),
+ server_rec.len()));
+ EXPECT_EQ(SECSuccess, SSL_EnableTls13GreaseEch(client_->ssl_fd(),
+ PR_TRUE)); // GREASE
+ Connect();
+}
+
+// Test a tampered CHInner (decrypt failure).
+// Expect negotiation on outer, which fails due to the tampered transcript.
+TEST_F(TlsConnectStreamTls13, EchBadCiphertext) {
+ EnsureTlsSetup();
+ SetupEch(client_, server_);
+ /* Target the payload:
+ struct {
+ ECHCipherSuite suite; // 4B
+ opaque config_id<0..255>; // 32B
+ opaque enc<1..2^16-1>; // 32B for X25519
+ opaque payload<1..2^16-1>;
+ } ClientEncryptedCH;
+ */
+ MakeTlsFilter<TlsExtensionDamager>(client_,
+ ssl_tls13_encrypted_client_hello_xtn, 80);
+ client_->ExpectSendAlert(kTlsAlertBadRecordMac);
+ server_->ExpectSendAlert(kTlsAlertBadRecordMac);
+ ConnectExpectFail();
+}
+
+// Test a tampered CHOuter (decrypt failure on AAD).
+// Expect negotiation on outer, which fails due to the tampered transcript.
+TEST_F(TlsConnectStreamTls13, EchOuterBinding) {
+ EnsureTlsSetup();
+ SetupEch(client_, server_);
+ client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2,
+ SSL_LIBRARY_VERSION_TLS_1_3);
+
+ static const uint8_t supported_vers_13[] = {0x02, 0x03, 0x04};
+ DataBuffer buf(supported_vers_13, sizeof(supported_vers_13));
+ MakeTlsFilter<TlsExtensionReplacer>(client_, ssl_tls13_supported_versions_xtn,
+ buf);
+ client_->ExpectSendAlert(kTlsAlertBadRecordMac);
+ server_->ExpectSendAlert(kTlsAlertBadRecordMac);
+ ConnectExpectFail();
+}
+
+// Altering the CH after the Ech Xtn should also cause a failure.
+TEST_F(TlsConnectStreamTls13, EchOuterBindingAfterXtn) {
+ EnsureTlsSetup();
+ SetupEch(client_, server_);
+ client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2,
+ SSL_LIBRARY_VERSION_TLS_1_3);
+
+ static const uint8_t supported_vers_13[] = {0x02, 0x03, 0x04};
+ DataBuffer buf(supported_vers_13, sizeof(supported_vers_13));
+ MakeTlsFilter<TlsExtensionAppender>(client_, kTlsHandshakeClientHello, 5044,
+ buf);
+ client_->ExpectSendAlert(kTlsAlertBadRecordMac);
+ server_->ExpectSendAlert(kTlsAlertBadRecordMac);
+ ConnectExpectFail();
+}
+
+// Test a bad (unknown) ECHCipherSuite.
+// Expect negotiation on outer, which fails due to the tampered transcript.
+TEST_F(TlsConnectStreamTls13, EchBadCiphersuite) {
+ EnsureTlsSetup();
+ SetupEch(client_, server_);
+ /* Make KDF unknown */
+ MakeTlsFilter<TlsExtensionDamager>(client_,
+ ssl_tls13_encrypted_client_hello_xtn, 1);
+ client_->ExpectSendAlert(kTlsAlertBadRecordMac);
+ server_->ExpectSendAlert(kTlsAlertBadRecordMac);
+ ConnectExpectFail();
+
+ Reset();
+ EnsureTlsSetup();
+ SetupEch(client_, server_);
+ /* Make AEAD unknown */
+ MakeTlsFilter<TlsExtensionDamager>(client_,
+ ssl_tls13_encrypted_client_hello_xtn, 4);
+ client_->ExpectSendAlert(kTlsAlertBadRecordMac);
+ server_->ExpectSendAlert(kTlsAlertBadRecordMac);
+ ConnectExpectFail();
+}
+
+/* ECH (configured) client connects to a 1.2 server, this MUST lead to an
+ * 'ech_required' alert being sent by the client when handling the handshake
+ * finished messages [draft-ietf-tls-esni-14, Section 6.1.6]. */
+TEST_F(TlsConnectStreamTls13, EchToTls12Server) {
+ EnsureTlsSetup();
+ SetupEch(client_, server_);
+ client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2,
+ SSL_LIBRARY_VERSION_TLS_1_3);
+ server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2,
+ SSL_LIBRARY_VERSION_TLS_1_2);
+
+ client_->ExpectEch(false);
+ server_->ExpectEch(false);
+
+ client_->ExpectSendAlert(kTlsAlertEchRequired, kTlsAlertFatal);
+ server_->ExpectReceiveAlert(kTlsAlertEchRequired, kTlsAlertFatal);
+ ConnectExpectFailOneSide(TlsAgent::CLIENT);
+ client_->CheckErrorCode(SSL_ERROR_ECH_RETRY_WITHOUT_ECH);
+
+ /* Reset expectations for the TlsAgent deconstructor. */
+ server_->ExpectReceiveAlert(kTlsAlertCloseNotify, kTlsAlertWarning);
+}
+
+TEST_F(TlsConnectStreamTls13, NoEchFromTls12Client) {
+ EnsureTlsSetup();
+ SetupEch(client_, server_);
+ client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2,
+ SSL_LIBRARY_VERSION_TLS_1_2);
+ server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2,
+ SSL_LIBRARY_VERSION_TLS_1_3);
+ auto filter = MakeTlsFilter<TlsExtensionCapture>(
+ client_, ssl_tls13_encrypted_client_hello_xtn);
+ client_->ExpectEch(false);
+ server_->ExpectEch(false);
+ SetExpectedVersion(SSL_LIBRARY_VERSION_TLS_1_2);
+ Connect();
+ ASSERT_FALSE(filter->captured());
+}
+
+TEST_F(TlsConnectStreamTls13, EchOuterWith12Max) {
+ EnsureTlsSetup();
+ SetupEch(client_, server_);
+ client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2,
+ SSL_LIBRARY_VERSION_TLS_1_3);
+ server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2,
+ SSL_LIBRARY_VERSION_TLS_1_3);
+
+ static const uint8_t supported_vers_12[] = {0x02, 0x03, 0x03};
+ DataBuffer buf(supported_vers_12, sizeof(supported_vers_12));
+
+ // The server will set the downgrade sentinel. The client needs
+ // to ignore it for this test.
+ client_->SetOption(SSL_ENABLE_HELLO_DOWNGRADE_CHECK, PR_FALSE);
+
+ StartConnect();
+ MakeTlsFilter<TlsExtensionReplacer>(client_, ssl_tls13_supported_versions_xtn,
+ buf);
+
+ // Server should ignore the extension if 1.2 is negotiated.
+ // Here the CHInner is not modified, so if Accepted we'd connect.
+ auto filter = MakeTlsFilter<TlsExtensionCapture>(
+ server_, ssl_tls13_encrypted_client_hello_xtn);
+ client_->ExpectEch(false);
+ server_->ExpectEch(false);
+ ConnectExpectAlert(server_, kTlsAlertDecryptError);
+ client_->CheckErrorCode(SSL_ERROR_DECRYPT_ERROR_ALERT);
+ server_->CheckErrorCode(SSL_ERROR_BAD_HANDSHAKE_HASH_VALUE);
+ ASSERT_FALSE(filter->captured());
+}
+
+TEST_F(TlsConnectStreamTls13, EchOuterExtensionsInCHOuter) {
+ EnsureTlsSetup();
+ uint8_t outer[2] = {0};
+ DataBuffer outer_buf(outer, sizeof(outer));
+ MakeTlsFilter<TlsExtensionAppender>(client_, kTlsHandshakeClientHello,
+ ssl_tls13_outer_extensions_xtn,
+ outer_buf);
+
+ ConnectExpectAlert(server_, kTlsAlertIllegalParameter);
+ client_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT);
+ server_->CheckErrorCode(SSL_ERROR_RX_MALFORMED_CLIENT_HELLO);
+}
+
+static SECStatus NoopExtensionHandler(PRFileDesc* fd, SSLHandshakeType message,
+ const PRUint8* data, unsigned int len,
+ SSLAlertDescription* alert, void* arg) {
+ return SECSuccess;
+}
+
+static PRBool EmptyExtensionWriter(PRFileDesc* fd, SSLHandshakeType message,
+ PRUint8* data, unsigned int* len,
+ unsigned int maxLen, void* arg) {
+ return true;
+}
+
+static PRBool LargeExtensionWriter(PRFileDesc* fd, SSLHandshakeType message,
+ PRUint8* data, unsigned int* len,
+ unsigned int maxLen, void* arg) {
+ unsigned int length = 1024;
+ PR_ASSERT(length <= maxLen);
+ memset(data, 0, length);
+ *len = length;
+ return true;
+}
+
+static PRBool OuterOnlyExtensionWriter(PRFileDesc* fd, SSLHandshakeType message,
+ PRUint8* data, unsigned int* len,
+ unsigned int maxLen, void* arg) {
+ if (message == ssl_hs_ech_outer_client_hello) {
+ return LargeExtensionWriter(fd, message, data, len, maxLen, arg);
+ }
+ return false;
+}
+
+static PRBool InnerOnlyExtensionWriter(PRFileDesc* fd, SSLHandshakeType message,
+ PRUint8* data, unsigned int* len,
+ unsigned int maxLen, void* arg) {
+ if (message == ssl_hs_client_hello) {
+ return LargeExtensionWriter(fd, message, data, len, maxLen, arg);
+ }
+ return false;
+}
+
+static PRBool InnerOuterDiffExtensionWriter(PRFileDesc* fd,
+ SSLHandshakeType message,
+ PRUint8* data, unsigned int* len,
+ unsigned int maxLen, void* arg) {
+ unsigned int length = 1024;
+ PR_ASSERT(length <= maxLen);
+ memset(data, (message == ssl_hs_client_hello) ? 1 : 0, length);
+ *len = length;
+ return true;
+}
+
+TEST_F(TlsConnectStreamTls13Ech, EchCustomExtensionWriter) {
+ EnsureTlsSetup();
+ SetupEch(client_, server_);
+
+ EXPECT_EQ(SECSuccess, SSL_InstallExtensionHooks(
+ client_->ssl_fd(), 62028, EmptyExtensionWriter,
+ nullptr, NoopExtensionHandler, nullptr));
+
+ client_->ExpectEch();
+ server_->ExpectEch();
+ Connect();
+}
+
+TEST_F(TlsConnectStreamTls13Ech, EchCustomExtensionWriterOuterOnly) {
+ EnsureTlsSetup();
+ SetupEch(client_, server_);
+
+ EXPECT_EQ(SECSuccess, SSL_InstallExtensionHooks(
+ client_->ssl_fd(), 62028, OuterOnlyExtensionWriter,
+ nullptr, NoopExtensionHandler, nullptr));
+ EXPECT_EQ(SECSuccess,
+ SSL_CallExtensionWriterOnEchInner(client_->ssl_fd(), true));
+
+ client_->ExpectEch();
+ server_->ExpectEch();
+ Connect();
+}
+
+TEST_F(TlsConnectStreamTls13Ech, EchCustomExtensionWriterInnerOnly) {
+ EnsureTlsSetup();
+ SetupEch(client_, server_);
+
+ EXPECT_EQ(SECSuccess, SSL_InstallExtensionHooks(
+ client_->ssl_fd(), 62028, InnerOnlyExtensionWriter,
+ nullptr, NoopExtensionHandler, nullptr));
+ EXPECT_EQ(SECSuccess,
+ SSL_CallExtensionWriterOnEchInner(client_->ssl_fd(), true));
+
+ client_->ExpectEch();
+ server_->ExpectEch();
+ Connect();
+}
+
+// Write different values to inner and outer CH.
+TEST_F(TlsConnectStreamTls13Ech, EchCustomExtensionWriterDifferent) {
+ EnsureTlsSetup();
+ SetupEch(client_, server_);
+
+ EXPECT_EQ(SECSuccess,
+ SSL_InstallExtensionHooks(client_->ssl_fd(), 62028,
+ InnerOuterDiffExtensionWriter, nullptr,
+ NoopExtensionHandler, nullptr));
+ EXPECT_EQ(SECSuccess,
+ SSL_CallExtensionWriterOnEchInner(client_->ssl_fd(), true));
+ auto filter = MakeTlsFilter<TlsExtensionCapture>(
+ client_, ssl_tls13_encrypted_client_hello_xtn);
+ client_->ExpectEch();
+ server_->ExpectEch();
+ Connect();
+ ASSERT_TRUE(filter->extension().len() > 1024);
+}
+
+// Test that basic compression works
+TEST_F(TlsConnectStreamTls13Ech, EchCustomExtensionWriterCompressionBasic) {
+ EnsureTlsSetup();
+ SetupEch(client_, server_);
+
+ // This will be compressed.
+ EXPECT_EQ(SECSuccess, SSL_InstallExtensionHooks(
+ client_->ssl_fd(), 62028, LargeExtensionWriter,
+ nullptr, NoopExtensionHandler, nullptr));
+ EXPECT_EQ(SECSuccess,
+ SSL_CallExtensionWriterOnEchInner(client_->ssl_fd(), true));
+ auto filter = MakeTlsFilter<TlsExtensionCapture>(
+ client_, ssl_tls13_encrypted_client_hello_xtn);
+ client_->ExpectEch();
+ server_->ExpectEch();
+ Connect();
+ size_t echXtnLen = filter->extension().len();
+ ASSERT_TRUE(echXtnLen > 0 && echXtnLen < 1024);
+}
+
+// Test that compression works when things change.
+TEST_F(TlsConnectStreamTls13Ech,
+ EchCustomExtensionWriterCompressSomeDifferent) {
+ EnsureTlsSetup();
+ SetupEch(client_, server_);
+
+ // This will be compressed.
+ EXPECT_EQ(SECSuccess, SSL_InstallExtensionHooks(
+ client_->ssl_fd(), 62028, LargeExtensionWriter,
+ nullptr, NoopExtensionHandler, nullptr));
+ // This can't be.
+ EXPECT_EQ(SECSuccess,
+ SSL_InstallExtensionHooks(client_->ssl_fd(), 62029,
+ InnerOuterDiffExtensionWriter, nullptr,
+ NoopExtensionHandler, nullptr));
+ // This will be compressed.
+ EXPECT_EQ(SECSuccess, SSL_InstallExtensionHooks(
+ client_->ssl_fd(), 62030, LargeExtensionWriter,
+ nullptr, NoopExtensionHandler, nullptr));
+ EXPECT_EQ(SECSuccess,
+ SSL_CallExtensionWriterOnEchInner(client_->ssl_fd(), true));
+ auto filter = MakeTlsFilter<TlsExtensionCapture>(
+ client_, ssl_tls13_encrypted_client_hello_xtn);
+ client_->ExpectEch();
+ server_->ExpectEch();
+ Connect();
+ auto echXtnLen = filter->extension().len();
+ /* Exactly one custom xtn plus change */
+ ASSERT_TRUE(echXtnLen > 1024 && echXtnLen < 2048);
+}
+
+// An outer-only extension stops compression.
+TEST_F(TlsConnectStreamTls13Ech,
+ EchCustomExtensionWriterCompressSomeOuterOnly) {
+ EnsureTlsSetup();
+ SetupEch(client_, server_);
+
+ // This will be compressed.
+ EXPECT_EQ(SECSuccess, SSL_InstallExtensionHooks(
+ client_->ssl_fd(), 62028, LargeExtensionWriter,
+ nullptr, NoopExtensionHandler, nullptr));
+ // This can't be as it appears in the outer only.
+ EXPECT_EQ(SECSuccess, SSL_InstallExtensionHooks(
+ client_->ssl_fd(), 62029, OuterOnlyExtensionWriter,
+ nullptr, NoopExtensionHandler, nullptr));
+ // This will be compressed
+ EXPECT_EQ(SECSuccess, SSL_InstallExtensionHooks(
+ client_->ssl_fd(), 62030, LargeExtensionWriter,
+ nullptr, NoopExtensionHandler, nullptr));
+ EXPECT_EQ(SECSuccess,
+ SSL_CallExtensionWriterOnEchInner(client_->ssl_fd(), true));
+ auto filter = MakeTlsFilter<TlsExtensionCapture>(
+ client_, ssl_tls13_encrypted_client_hello_xtn);
+ client_->ExpectEch();
+ server_->ExpectEch();
+ Connect();
+ size_t echXtnLen = filter->extension().len();
+ ASSERT_TRUE(echXtnLen > 0 && echXtnLen < 1024);
+}
+
+// An inner only extension does not stop compression.
+TEST_F(TlsConnectStreamTls13Ech, EchCustomExtensionWriterCompressAllInnerOnly) {
+ EnsureTlsSetup();
+ SetupEch(client_, server_);
+
+ // This will be compressed.
+ EXPECT_EQ(SECSuccess, SSL_InstallExtensionHooks(
+ client_->ssl_fd(), 62028, LargeExtensionWriter,
+ nullptr, NoopExtensionHandler, nullptr));
+ // This can't be as it appears in the inner only.
+ EXPECT_EQ(SECSuccess, SSL_InstallExtensionHooks(
+ client_->ssl_fd(), 62029, InnerOnlyExtensionWriter,
+ nullptr, NoopExtensionHandler, nullptr));
+ // This will be compressed.
+ EXPECT_EQ(SECSuccess, SSL_InstallExtensionHooks(
+ client_->ssl_fd(), 62030, LargeExtensionWriter,
+ nullptr, NoopExtensionHandler, nullptr));
+ EXPECT_EQ(SECSuccess,
+ SSL_CallExtensionWriterOnEchInner(client_->ssl_fd(), true));
+ auto filter = MakeTlsFilter<TlsExtensionCapture>(
+ client_, ssl_tls13_encrypted_client_hello_xtn);
+ client_->ExpectEch();
+ server_->ExpectEch();
+ Connect();
+ size_t echXtnLen = filter->extension().len();
+ ASSERT_TRUE(echXtnLen > 1024 && echXtnLen < 2048);
+}
+
+TEST_F(TlsConnectStreamTls13Ech, EchAcceptCustomXtn) {
+ EnsureTlsSetup();
+ SetupEch(client_, server_);
+
+ EXPECT_EQ(SECSuccess, SSL_InstallExtensionHooks(
+ client_->ssl_fd(), 62028, LargeExtensionWriter,
+ nullptr, NoopExtensionHandler, nullptr));
+
+ EXPECT_EQ(SECSuccess,
+ SSL_CallExtensionWriterOnEchInner(client_->ssl_fd(), true));
+
+ EXPECT_EQ(SECSuccess, SSL_InstallExtensionHooks(
+ server_->ssl_fd(), 62028, LargeExtensionWriter,
+ nullptr, NoopExtensionHandler, nullptr));
+ auto filter = MakeTlsFilter<TlsExtensionCapture>(server_, 62028);
+ client_->ExpectEch();
+ server_->ExpectEch();
+ Connect();
+}
+
+// Test that we reject Outer Xtn in SH if accepting ECH Inner
+TEST_F(TlsConnectStreamTls13Ech, EchRejectOuterXtnOnInner) {
+ EnsureTlsSetup();
+ SetupEch(client_, server_);
+
+ EXPECT_EQ(SECSuccess, SSL_InstallExtensionHooks(
+ client_->ssl_fd(), 62028, OuterOnlyExtensionWriter,
+ nullptr, NoopExtensionHandler, nullptr));
+
+ EXPECT_EQ(SECSuccess,
+ SSL_CallExtensionWriterOnEchInner(client_->ssl_fd(), true));
+
+ // Put the same extension on the Server Hello
+ EXPECT_EQ(SECSuccess, SSL_InstallExtensionHooks(
+ server_->ssl_fd(), 62028, LargeExtensionWriter,
+ nullptr, NoopExtensionHandler, nullptr));
+ auto filter = MakeTlsFilter<TlsExtensionCapture>(server_, 62028);
+ client_->ExpectEch(false);
+ server_->ExpectEch(false);
+ client_->ExpectSendAlert(kTlsAlertUnsupportedExtension);
+ // The server will be expecting an alert encrypted under a different key.
+ server_->ExpectSendAlert(kTlsAlertUnexpectedMessage);
+ ConnectExpectFail();
+ ASSERT_TRUE(filter->captured());
+ client_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_EXTENSION);
+}
+
+// Test that we reject Inner Xtn in SH if accepting ECH Outer
+TEST_F(TlsConnectStreamTls13Ech, EchRejectInnerXtnOnOuter) {
+ EnsureTlsSetup();
+
+ // Setup ECH only on the client
+ SetupEch(client_, server_, HpkeDhKemX25519Sha256, false, true, false);
+
+ EXPECT_EQ(SECSuccess, SSL_InstallExtensionHooks(
+ client_->ssl_fd(), 62028, InnerOnlyExtensionWriter,
+ nullptr, NoopExtensionHandler, nullptr));
+
+ EXPECT_EQ(SECSuccess,
+ SSL_CallExtensionWriterOnEchInner(client_->ssl_fd(), true));
+
+ // Put the same extension on the Server Hello
+ EXPECT_EQ(SECSuccess, SSL_InstallExtensionHooks(
+ server_->ssl_fd(), 62028, LargeExtensionWriter,
+ nullptr, NoopExtensionHandler, nullptr));
+ auto filter = MakeTlsFilter<TlsExtensionCapture>(server_, 62028);
+ client_->ExpectEch(false);
+ server_->ExpectEch(false);
+ client_->ExpectSendAlert(kTlsAlertUnsupportedExtension);
+ // The server will be expecting an alert encrypted under a different key.
+ server_->ExpectSendAlert(kTlsAlertUnexpectedMessage);
+ ConnectExpectFail();
+ ASSERT_TRUE(filter->captured());
+ client_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_EXTENSION);
+}
+
+// Test that we reject an Inner Xtn in SH, if accepting Ech Inner and
+// we didn't advertise it on SH Outer.
+TEST_F(TlsConnectStreamTls13Ech, EchRejectInnerXtnNotOnOuter) {
+ EnsureTlsSetup();
+
+ // Setup ECH only on the client
+ SetupEch(client_, server_);
+
+ EXPECT_EQ(SECSuccess, SSL_InstallExtensionHooks(
+ client_->ssl_fd(), 62028, InnerOnlyExtensionWriter,
+ nullptr, NoopExtensionHandler, nullptr));
+
+ EXPECT_EQ(SECSuccess,
+ SSL_CallExtensionWriterOnEchInner(client_->ssl_fd(), true));
+
+ // Put the same extension on the Server Hello
+ EXPECT_EQ(SECSuccess, SSL_InstallExtensionHooks(
+ server_->ssl_fd(), 62028, LargeExtensionWriter,
+ nullptr, NoopExtensionHandler, nullptr));
+ auto filter = MakeTlsFilter<TlsExtensionCapture>(server_, 62028);
+ client_->ExpectEch(false);
+ server_->ExpectEch(false);
+ client_->ExpectSendAlert(kTlsAlertUnsupportedExtension);
+ // The server will be expecting an alert encrypted under a different key.
+ server_->ExpectSendAlert(kTlsAlertUnexpectedMessage);
+ ConnectExpectFail();
+ ASSERT_TRUE(filter->captured());
+ client_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_EXTENSION);
+}
+
+// At draft-09: If a CH containing the ech_is_inner extension is received, the
+// server acts as backend server in split-mode by responding with the ECH
+// acceptance signal. The signal value itself depends on the handshake secret,
+// which we've broken by appending ech_is_inner. For now, just check that the
+// server negotiates ech_is_inner (which is what triggers sending the signal).
+TEST_F(TlsConnectStreamTls13, EchBackendAcceptance) {
+ DataBuffer ch_buf;
+ static uint8_t inner_value[1] = {1};
+ DataBuffer inner_buffer(inner_value, sizeof(inner_value));
+
+ EnsureTlsSetup();
+ StartConnect();
+ EXPECT_EQ(SECSuccess, SSL_EnableTls13GreaseEch(client_->ssl_fd(), PR_FALSE));
+ MakeTlsFilter<TlsExtensionAppender>(client_, kTlsHandshakeClientHello,
+ ssl_tls13_encrypted_client_hello_xtn,
+ inner_buffer);
+
+ EXPECT_EQ(SECSuccess, SSL_EnableTls13BackendEch(server_->ssl_fd(), PR_TRUE));
+ client_->Handshake();
+ server_->Handshake();
+
+ ExpectAlert(client_, kTlsAlertBadRecordMac);
+ client_->Handshake();
+ EXPECT_EQ(TlsAgent::STATE_ERROR, client_->state());
+ EXPECT_EQ(PR_TRUE,
+ SSLInt_ExtensionNegotiated(server_->ssl_fd(),
+ ssl_tls13_encrypted_client_hello_xtn));
+ server_->ExpectReceiveAlert(kTlsAlertCloseNotify, kTlsAlertWarning);
+}
+
+// A public_name that includes an IP address has to be rejected.
+TEST_F(TlsConnectStreamTls13Ech, EchPublicNameIp) {
+ static const std::vector<std::string> kIps = {
+ "0.0.0.0",
+ "1.1.1.1",
+ "255.255.255.255",
+ "255.255.65535",
+ "255.16777215",
+ "4294967295",
+ "0377.0377.0377.0377",
+ "0377.0377.0177777",
+ "0377.077777777",
+ "037777777777",
+ "00377.00377.00377.00377",
+ "00377.00377.00177777",
+ "00377.0077777777",
+ "0037777777777",
+ "0xff.0xff.0xff.0xff",
+ "0xff.0xff.0xffff",
+ "0xff.0xffffff",
+ "0xffffffff",
+ "0XFF.0XFF.0XFF.0XFF",
+ "0XFF.0XFF.0XFFFF",
+ "0XFF.0XFFFFFF",
+ "0XFFFFFFFF",
+ "0x0ff.0x0ff.0x0ff.0x0ff",
+ "0x0ff.0x0ff.0x0ffff",
+ "0x0ff.0x0ffffff",
+ "0x0ffffffff",
+ "00000000000000000000000000000000000000000",
+ "00000000000000000000000000000000000000001",
+ "127.0.0.1",
+ "127.0.1",
+ "127.1",
+ "2130706433",
+ "017700000001",
+ };
+ ValidatePublicNames(kIps, SECFailure);
+}
+
+// These are nearly IP addresses.
+TEST_F(TlsConnectStreamTls13Ech, EchPublicNameNotIp) {
+ static const std::vector<std::string> kNotIps = {
+ "0.0.0.0.0",
+ "1.2.3.4.5",
+ "999999999999999999999999999999999",
+ "07777777777777777777777777777777777777777",
+ "111111111100000000001111111111000000000011111111110000000000123",
+ "256.255.255.255",
+ "255.256.255.255",
+ "255.255.256.255",
+ "255.255.255.256",
+ "255.255.65536",
+ "255.16777216",
+ "4294967296",
+ "0400.0377.0377.0377",
+ "0377.0400.0377.0377",
+ "0377.0377.0400.0377",
+ "0377.0377.0377.0400",
+ "0377.0377.0200000",
+ "0377.0100000000",
+ "040000000000",
+ "0x100.0xff.0xff.0xff",
+ "0xff.0x100.0xff.0xff",
+ "0xff.0xff.0x100.0xff",
+ "0xff.0xff.0xff.0x100",
+ "0xff.0xff.0x10000",
+ "0xff.0x1000000",
+ "0x100000000",
+ "08",
+ "09",
+ "a",
+ "0xg",
+ "0XG",
+ "0x",
+ "0x.1.2.3",
+ "test-name",
+ "test-name.test",
+ "TEST-NAME",
+ "under_score",
+ "_under_score",
+ "under_score_",
+ };
+ ValidatePublicNames(kNotIps, SECSuccess);
+}
+
+TEST_F(TlsConnectStreamTls13Ech, EchPublicNameNotLdh) {
+ static const std::vector<std::string> kNotLdh = {
+ ".",
+ "name.",
+ ".name",
+ "test..name",
+ "1111111111000000000011111111110000000000111111111100000000001234",
+ "-name",
+ "name-",
+ "test-.name",
+ "!",
+ u8"\u2077",
+ };
+ ValidatePublicNames(kNotLdh, SECFailure);
+}
+
+TEST_F(TlsConnectStreamTls13, EchClientHelloExtensionPermutation) {
+ EnsureTlsSetup();
+ PR_ASSERT(SSL_OptionSet(client_->ssl_fd(),
+ SSL_ENABLE_CH_EXTENSION_PERMUTATION,
+ PR_TRUE) == SECSuccess);
+ SetupEch(client_, server_);
+
+ client_->ExpectEch();
+ server_->ExpectEch();
+ Connect();
+}
+
+TEST_F(TlsConnectStreamTls13, EchGreaseClientHelloExtensionPermutation) {
+ EnsureTlsSetup();
+ PR_ASSERT(SSL_OptionSet(client_->ssl_fd(),
+ SSL_ENABLE_CH_EXTENSION_PERMUTATION,
+ PR_TRUE) == SECSuccess);
+ PR_ASSERT(SSL_EnableTls13GreaseEch(client_->ssl_fd(), PR_FALSE) ==
+ SECSuccess);
+ Connect();
+}
+
+INSTANTIATE_TEST_SUITE_P(EchAgentTest, TlsAgentEchTest,
+ ::testing::Combine(TlsConnectTestBase::kTlsVariantsAll,
+ TlsConnectTestBase::kTlsV13));
+
+} // namespace nss_test
diff --git a/security/nss/gtests/ssl_gtest/tls_filter.cc b/security/nss/gtests/ssl_gtest/tls_filter.cc
new file mode 100644
index 0000000000..ab52a07e84
--- /dev/null
+++ b/security/nss/gtests/ssl_gtest/tls_filter.cc
@@ -0,0 +1,1293 @@
+/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
+/* vim: set ts=2 et sw=2 tw=80: */
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this file,
+ * You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+#include "tls_filter.h"
+#include "sslproto.h"
+
+extern "C" {
+// This is not something that should make you happy.
+#include "libssl_internals.h"
+}
+
+#include <cassert>
+#include <iostream>
+#include "gtest_utils.h"
+#include "tls_agent.h"
+#include "tls_filter.h"
+#include "tls_parser.h"
+#include "tls_protect.h"
+
+namespace nss_test {
+
+void TlsVersioned::WriteStream(std::ostream& stream) const {
+ stream << (is_dtls() ? "DTLS " : "TLS ");
+ switch (version()) {
+ case 0:
+ stream << "(no version)";
+ break;
+ case SSL_LIBRARY_VERSION_TLS_1_0:
+ stream << "1.0";
+ break;
+ case SSL_LIBRARY_VERSION_TLS_1_1:
+ stream << (is_dtls() ? "1.0" : "1.1");
+ break;
+ case SSL_LIBRARY_VERSION_TLS_1_2:
+ stream << "1.2";
+ break;
+ case SSL_LIBRARY_VERSION_TLS_1_3:
+ stream << "1.3";
+ break;
+ default:
+ stream << "Invalid version: " << version();
+ break;
+ }
+}
+
+TlsRecordFilter::TlsRecordFilter(const std::shared_ptr<TlsAgent>& a)
+ : agent_(a) {
+ cipher_specs_.emplace_back(a->variant() == ssl_variant_datagram, 0);
+}
+
+void TlsRecordFilter::EnableDecryption() {
+ EXPECT_EQ(SECSuccess,
+ SSL_SecretCallback(agent()->ssl_fd(), SecretCallback, this));
+ decrypting_ = true;
+}
+
+void TlsRecordFilter::SecretCallback(PRFileDesc* fd, PRUint16 epoch,
+ SSLSecretDirection dir, PK11SymKey* secret,
+ void* arg) {
+ TlsRecordFilter* self = static_cast<TlsRecordFilter*>(arg);
+ if (g_ssl_gtest_verbose) {
+ std::cerr << self->agent()->role_str() << ": " << dir
+ << " secret changed for epoch " << epoch << std::endl;
+ }
+
+ if (dir == ssl_secret_read) {
+ return;
+ }
+
+ for (auto& spec : self->cipher_specs_) {
+ ASSERT_NE(spec.epoch(), epoch) << "duplicate spec for epoch " << epoch;
+ }
+
+ SSLPreliminaryChannelInfo preinfo;
+ EXPECT_EQ(SECSuccess,
+ SSL_GetPreliminaryChannelInfo(self->agent()->ssl_fd(), &preinfo,
+ sizeof(preinfo)));
+ EXPECT_EQ(sizeof(preinfo), preinfo.length);
+
+ // Check the version.
+ if (preinfo.valuesSet & ssl_preinfo_version) {
+ EXPECT_EQ(SSL_LIBRARY_VERSION_TLS_1_3, preinfo.protocolVersion);
+ } else {
+ EXPECT_EQ(1U, epoch);
+ }
+
+ uint16_t suite;
+ if (epoch == 1) {
+ // 0-RTT
+ EXPECT_TRUE(preinfo.valuesSet & ssl_preinfo_0rtt_cipher_suite);
+ suite = preinfo.zeroRttCipherSuite;
+ } else {
+ EXPECT_TRUE(preinfo.valuesSet & ssl_preinfo_cipher_suite);
+ suite = preinfo.cipherSuite;
+ }
+
+ SSLCipherSuiteInfo cipherinfo;
+ EXPECT_EQ(SECSuccess,
+ SSL_GetCipherSuiteInfo(suite, &cipherinfo, sizeof(cipherinfo)));
+ EXPECT_EQ(sizeof(cipherinfo), cipherinfo.length);
+
+ self->cipher_specs_.emplace_back(self->is_dtls_agent(), epoch);
+ EXPECT_TRUE(self->cipher_specs_.back().SetKeys(&cipherinfo, secret));
+}
+
+bool TlsRecordFilter::is_dtls_agent() const {
+ return agent()->variant() == ssl_variant_datagram;
+}
+
+bool TlsRecordFilter::is_dtls13() const {
+ if (!is_dtls_agent()) {
+ return false;
+ }
+ if (agent()->state() == TlsAgent::STATE_CONNECTED) {
+ return agent()->version() >= SSL_LIBRARY_VERSION_TLS_1_3;
+ }
+ SSLPreliminaryChannelInfo info;
+ EXPECT_EQ(SECSuccess, SSL_GetPreliminaryChannelInfo(agent()->ssl_fd(), &info,
+ sizeof(info)));
+ return (info.protocolVersion >= SSL_LIBRARY_VERSION_TLS_1_3) ||
+ info.canSendEarlyData;
+}
+
+bool TlsRecordFilter::is_dtls13_ciphertext(uint8_t ct) const {
+ return is_dtls13() && (ct & kCtDtlsCiphertextMask) == kCtDtlsCiphertext;
+}
+
+// Gets the cipher spec that matches the specified epoch.
+TlsCipherSpec& TlsRecordFilter::spec(uint16_t write_epoch) {
+ for (auto& sp : cipher_specs_) {
+ if (sp.epoch() == write_epoch) {
+ return sp;
+ }
+ }
+
+ // If we aren't decrypting, provide a cipher spec that does nothing other than
+ // count sequence numbers.
+ EXPECT_FALSE(decrypting_) << "No spec available for epoch " << write_epoch;
+ ;
+ cipher_specs_.emplace_back(is_dtls_agent(), write_epoch);
+ return cipher_specs_.back();
+}
+
+PacketFilter::Action TlsRecordFilter::Filter(const DataBuffer& input,
+ DataBuffer* output) {
+ // Disable during shutdown.
+ if (!agent()) {
+ return KEEP;
+ }
+
+ bool changed = false;
+ size_t offset = 0U;
+
+ output->Allocate(input.len());
+ TlsParser parser(input);
+
+ // This uses the current write spec for the purposes of parsing the epoch and
+ // sequence number from the header. This might be wrong because we can
+ // receive records from older specs, but guessing is good enough:
+ // - In DTLS, parsing the sequence number corrects any errors.
+ // - In TLS, we don't use the sequence number unless decrypting, where we use
+ // trial decryption to get the right epoch.
+ uint16_t write_epoch = 0;
+ SECStatus rv = SSL_GetCurrentEpoch(agent()->ssl_fd(), nullptr, &write_epoch);
+ if (rv != SECSuccess) {
+ ADD_FAILURE() << "unable to read epoch";
+ return KEEP;
+ }
+ uint64_t guess_seqno = static_cast<uint64_t>(write_epoch) << 48;
+
+ while (parser.remaining()) {
+ TlsRecordHeader header;
+ DataBuffer record;
+ if (!header.Parse(is_dtls13(), guess_seqno, &parser, &record)) {
+ ADD_FAILURE() << "not a valid record";
+ return KEEP;
+ }
+
+ if (FilterRecord(header, record, &offset, output) != KEEP) {
+ changed = true;
+ } else {
+ offset = header.Write(output, offset, record);
+ }
+ }
+ output->Truncate(offset);
+
+ // Record how many packets we actually touched.
+ if (changed) {
+ ++count_;
+ return (offset == 0) ? DROP : CHANGE;
+ }
+
+ return KEEP;
+}
+
+PacketFilter::Action TlsRecordFilter::FilterRecord(
+ const TlsRecordHeader& header, const DataBuffer& record, size_t* offset,
+ DataBuffer* output) {
+ DataBuffer filtered;
+ uint8_t inner_content_type;
+ DataBuffer plaintext;
+ uint16_t protection_epoch = 0;
+ TlsRecordHeader out_header(header);
+
+ if (!Unprotect(header, record, &protection_epoch, &inner_content_type,
+ &plaintext, &out_header)) {
+ std::cerr << agent()->role_str() << ": unprotect failed: " << header << ":"
+ << record << std::endl;
+ return KEEP;
+ }
+
+ auto& protection_spec = spec(protection_epoch);
+ TlsRecordHeader real_header(out_header.variant(), out_header.version(),
+ inner_content_type, out_header.sequence_number());
+
+ PacketFilter::Action action = FilterRecord(real_header, plaintext, &filtered);
+ // In stream mode, even if something doesn't change we need to re-encrypt if
+ // previous packets were dropped.
+ if (action == KEEP) {
+ if (out_header.is_dtls() || !protection_spec.record_dropped()) {
+ // Count every outgoing packet.
+ protection_spec.RecordProtected();
+ return KEEP;
+ }
+ filtered = plaintext;
+ }
+
+ if (action == DROP) {
+ std::cerr << "record drop: " << out_header << ":" << record << std::endl;
+ protection_spec.RecordDropped();
+ return DROP;
+ }
+
+ EXPECT_GT(0x10000U, filtered.len());
+ if (action != KEEP) {
+ std::cerr << "record old: " << plaintext << std::endl;
+ std::cerr << "record new: " << filtered << std::endl;
+ }
+
+ uint64_t seq_num = protection_spec.next_out_seqno();
+ if (!decrypting_ && out_header.is_dtls()) {
+ // Copy over the epoch, which isn't tracked when not decrypting.
+ seq_num |= out_header.sequence_number() & (0xffffULL << 48);
+ }
+ out_header.sequence_number(seq_num);
+
+ DataBuffer ciphertext;
+ bool rv = Protect(protection_spec, out_header, inner_content_type, filtered,
+ &ciphertext, &out_header);
+ if (!rv) {
+ return KEEP;
+ }
+ *offset = out_header.Write(output, *offset, ciphertext);
+ return CHANGE;
+}
+
+size_t TlsRecordHeader::header_length() const {
+ // If we have a header, return it's length.
+ if (header_.len()) {
+ return header_.len();
+ }
+
+ // Otherwise make a dummy header and return the length.
+ DataBuffer buf;
+ return WriteHeader(&buf, 0, 0);
+}
+
+bool TlsRecordHeader::MaskSequenceNumber() {
+ return MaskSequenceNumber(sn_mask());
+}
+
+bool TlsRecordHeader::MaskSequenceNumber(const DataBuffer& mask_buf) {
+ if (mask_buf.empty()) {
+ return false;
+ }
+
+ DataBuffer mask;
+ if (is_dtls13_ciphertext()) {
+ uint64_t seqno = sequence_number();
+ uint8_t len = content_type() & kCtDtlsCiphertext16bSeqno ? 2 : 1;
+ uint16_t seqno_bitmask = (1 << len * 8) - 1;
+ DataBuffer val;
+ if (val.Write(0, seqno & seqno_bitmask, len) != len) {
+ return false;
+ }
+
+#ifdef UNSAFE_FUZZER_MODE
+ // Use a null mask.
+ mask.Allocate(mask_buf.len());
+#endif
+ mask.Append(mask_buf);
+ val.data()[0] ^= mask.data()[0];
+ if (len == 2 && mask.len() > 1) {
+ val.data()[1] ^= mask.data()[1];
+ }
+ uint32_t tmp;
+ if (!val.Read(0, len, &tmp)) {
+ return false;
+ }
+
+ seqno = (seqno & ~seqno_bitmask) | tmp;
+ seqno_is_masked_ = !seqno_is_masked_;
+ if (!seqno_is_masked_) {
+ seqno = ParseSequenceNumber(guess_seqno_, seqno, len * 8, 2);
+ }
+ sequence_number_ = seqno;
+
+ // Now update the header bytes
+ if (header_.len() > 1) {
+ header_.data()[1] ^= mask.data()[0];
+ if ((content_type() & kCtDtlsCiphertext16bSeqno) && header().len() > 2) {
+ header_.data()[2] ^= mask.data()[1];
+ }
+ }
+ }
+
+ sn_mask_ = mask;
+ return true;
+}
+
+uint64_t TlsRecordHeader::RecoverSequenceNumber(uint64_t guess_seqno,
+ uint32_t partial,
+ size_t partial_bits) {
+ EXPECT_GE(32U, partial_bits);
+ uint64_t mask = (1ULL << partial_bits) - 1;
+ // First we determine the highest possible value. This is half the
+ // expressible range above the expected value (|guess_seqno|), less 1.
+ //
+ // We subtract the extra 1 from the cap so that when given a choice between
+ // the equidistant expected+N and expected-N we want to chose the lower. With
+ // 0-RTT, we sometimes have to recover an epoch of 1 when we expect an epoch
+ // of 3 and with 2 partial bits, the alternative result of 5 is wrong.
+ uint64_t cap = guess_seqno + (1ULL << (partial_bits - 1)) - 1;
+ // Add the partial piece in. e.g., xxxx789a and 1234 becomes xxxx1234.
+ uint64_t seq_no = (cap & ~mask) | partial;
+ // If the partial value is higher than the same partial piece from the cap,
+ // then the real value has to be lower. e.g., xxxx1234 can't become xxxx5678.
+ if (partial > (cap & mask) && (seq_no >= (1ULL << partial_bits))) {
+ seq_no -= 1ULL << partial_bits;
+ }
+ return seq_no;
+}
+
+// Determine the full epoch and sequence number from an expected and raw value.
+// The expected, raw, and output values are packed as they are in DTLS 1.2 and
+// earlier: with 16 bits of epoch and 48 bits of sequence number. The raw value
+// is packed this way (even before recovery) so that we don't need to track a
+// moving value between two calls (one to recover the epoch, and one after
+// unmasking to recover the sequence number).
+uint64_t TlsRecordHeader::ParseSequenceNumber(uint64_t expected, uint64_t raw,
+ size_t seq_no_bits,
+ size_t epoch_bits) {
+ uint64_t epoch_mask = (1ULL << epoch_bits) - 1;
+ uint64_t ep = RecoverSequenceNumber(expected >> 48, (raw >> 48) & epoch_mask,
+ epoch_bits);
+ if (ep > (expected >> 48)) {
+ // If the epoch has changed, reset the expected sequence number.
+ expected = 0;
+ } else {
+ // Otherwise, retain just the sequence number part.
+ expected &= (1ULL << 48) - 1;
+ }
+ uint64_t seq_no_mask = (1ULL << seq_no_bits) - 1;
+ uint64_t seq_no = (raw & seq_no_mask);
+ if (!seqno_is_masked_) {
+ seq_no = RecoverSequenceNumber(expected, seq_no, seq_no_bits);
+ }
+
+ return (ep << 48) | seq_no;
+}
+
+bool TlsRecordHeader::Parse(bool is_dtls13, uint64_t seqno, TlsParser* parser,
+ DataBuffer* body) {
+ auto mark = parser->consumed();
+
+ if (!parser->Read(&content_type_)) {
+ return false;
+ }
+
+ if (is_dtls13) {
+ variant_ = ssl_variant_datagram;
+ version_ = SSL_LIBRARY_VERSION_TLS_1_3;
+
+#ifndef UNSAFE_FUZZER_MODE
+ // Deal with the DTLSCipherText header.
+ if (is_dtls13_ciphertext()) {
+ uint8_t seq_no_bytes =
+ (content_type_ & kCtDtlsCiphertext16bSeqno) ? 2 : 1;
+ uint32_t tmp;
+
+ if (!parser->Read(&tmp, seq_no_bytes)) {
+ return false;
+ }
+
+ // Store the guess if masked. If and when seqno_bytesenceNumber is called,
+ // the value will be unmasked and recovered. This assumes we only call
+ // Parse() on headers containing masked values.
+ seqno_is_masked_ = true;
+ guess_seqno_ = seqno;
+ uint64_t ep = content_type_ & 0x03;
+ sequence_number_ = (ep << 48) | tmp;
+
+ // Recover the full epoch. Note the sequence number portion holds the
+ // masked value until a call to Mask() reveals it (as indicated by
+ // |seqno_is_masked_|).
+ sequence_number_ =
+ ParseSequenceNumber(seqno, sequence_number_, seq_no_bytes * 8, 2);
+
+ uint32_t len_bytes =
+ (content_type_ & kCtDtlsCiphertextLengthPresent) ? 2 : 0;
+ if (len_bytes) {
+ if (!parser->Read(&tmp, 2)) {
+ return false;
+ }
+ }
+
+ if (!parser->ReadFromMark(&header_, parser->consumed() - mark, mark)) {
+ return false;
+ }
+
+ return len_bytes ? parser->Read(body, tmp)
+ : parser->Read(body, parser->remaining());
+ }
+
+ // The full DTLSPlainText header can only be used for a few types.
+ EXPECT_TRUE(content_type_ == ssl_ct_alert ||
+ content_type_ == ssl_ct_handshake ||
+ content_type_ == ssl_ct_ack);
+#endif
+ }
+
+ uint32_t ver;
+ if (!parser->Read(&ver, 2)) {
+ return false;
+ }
+ if (!is_dtls13) {
+ variant_ = IsDtls(ver) ? ssl_variant_datagram : ssl_variant_stream;
+ }
+ version_ = NormalizeTlsVersion(ver);
+
+ if (is_dtls()) {
+ // If this is DTLS, read the sequence number.
+ uint32_t tmp;
+ if (!parser->Read(&tmp, 4)) {
+ return false;
+ }
+ sequence_number_ = static_cast<uint64_t>(tmp) << 32;
+ if (!parser->Read(&tmp, 4)) {
+ return false;
+ }
+ sequence_number_ |= static_cast<uint64_t>(tmp);
+ } else {
+ sequence_number_ = seqno;
+ }
+ if (!parser->ReadFromMark(&header_, parser->consumed() + 2 - mark, mark)) {
+ return false;
+ }
+ return parser->ReadVariable(body, 2);
+}
+
+size_t TlsRecordHeader::WriteHeader(DataBuffer* buffer, size_t offset,
+ size_t body_len) const {
+ if (is_dtls13_ciphertext()) {
+ uint8_t seq_no_bytes = (content_type_ & kCtDtlsCiphertext16bSeqno) ? 2 : 1;
+ // application_data records in TLS 1.3 have a different header format.
+ uint32_t e = (sequence_number_ >> 48) & 0x3;
+ uint32_t seqno = sequence_number_ & ((1ULL << seq_no_bytes * 8) - 1);
+ uint8_t new_content_type_ = content_type_ | e;
+ offset = buffer->Write(offset, new_content_type_, 1);
+ offset = buffer->Write(offset, seqno, seq_no_bytes);
+
+ if (content_type_ & kCtDtlsCiphertextLengthPresent) {
+ offset = buffer->Write(offset, body_len, 2);
+ }
+ } else {
+ offset = buffer->Write(offset, content_type_, 1);
+ uint16_t v = is_dtls() ? TlsVersionToDtlsVersion(version_) : version_;
+ offset = buffer->Write(offset, v, 2);
+ if (is_dtls()) {
+ // write epoch (2 octet), and seqnum (6 octet)
+ offset = buffer->Write(offset, sequence_number_ >> 32, 4);
+ offset = buffer->Write(offset, sequence_number_ & 0xffffffff, 4);
+ }
+ offset = buffer->Write(offset, body_len, 2);
+ }
+
+ return offset;
+}
+
+size_t TlsRecordHeader::Write(DataBuffer* buffer, size_t offset,
+ const DataBuffer& body) const {
+ offset = WriteHeader(buffer, offset, body.len());
+ offset = buffer->Write(offset, body);
+ return offset;
+}
+
+bool TlsRecordFilter::Unprotect(const TlsRecordHeader& header,
+ const DataBuffer& ciphertext,
+ uint16_t* protection_epoch,
+ uint8_t* inner_content_type,
+ DataBuffer* plaintext,
+ TlsRecordHeader* out_header) {
+ if (!decrypting_ || !header.is_protected()) {
+ // Maintain the epoch and sequence number for plaintext records.
+ uint16_t ep = 0;
+ if (is_dtls_agent()) {
+ ep = static_cast<uint16_t>(header.sequence_number() >> 48);
+ }
+ spec(ep).RecordUnprotected(header.sequence_number());
+ *protection_epoch = ep;
+ *inner_content_type = header.content_type();
+ *plaintext = ciphertext;
+ return true;
+ }
+
+ uint16_t ep = 0;
+ if (is_dtls_agent()) {
+ ep = static_cast<uint16_t>(header.sequence_number() >> 48);
+ if (!spec(ep).Unprotect(header, ciphertext, plaintext, out_header)) {
+ return false;
+ }
+ } else {
+ // In TLS, records aren't clearly labelled with their epoch, and we
+ // can't just use the newest keys because the same flight of messages can
+ // contain multiple epochs. So... trial decrypt!
+ for (size_t i = cipher_specs_.size() - 1; i > 0; --i) {
+ if (cipher_specs_[i].Unprotect(header, ciphertext, plaintext,
+ out_header)) {
+ ep = cipher_specs_[i].epoch();
+ break;
+ }
+ }
+ if (!ep) {
+ return false;
+ }
+ }
+
+ size_t len = plaintext->len();
+ while (len > 0 && !plaintext->data()[len - 1]) {
+ --len;
+ }
+ if (!len) {
+ // Bogus padding.
+ return false;
+ }
+
+ *protection_epoch = ep;
+ *inner_content_type = plaintext->data()[len - 1];
+ plaintext->Truncate(len - 1);
+ if (g_ssl_gtest_verbose) {
+ std::cerr << agent()->role_str() << ": unprotect: epoch=" << ep
+ << " seq=" << std::hex << header.sequence_number() << std::dec
+ << " " << *plaintext << std::endl;
+ }
+
+ return true;
+}
+
+bool TlsRecordFilter::Protect(TlsCipherSpec& protection_spec,
+ const TlsRecordHeader& header,
+ uint8_t inner_content_type,
+ const DataBuffer& plaintext,
+ DataBuffer* ciphertext,
+ TlsRecordHeader* out_header, size_t padding) {
+ if (!protection_spec.is_protected()) {
+ // Not protected, just keep the sequence numbers updated.
+ protection_spec.RecordProtected();
+ *ciphertext = plaintext;
+ return true;
+ }
+
+ DataBuffer padded;
+ padded.Allocate(plaintext.len() + 1 + padding);
+ size_t offset = padded.Write(0, plaintext.data(), plaintext.len());
+ padded.Write(offset, inner_content_type, 1);
+
+ bool ok = protection_spec.Protect(header, padded, ciphertext, out_header);
+ if (!ok) {
+ ADD_FAILURE() << "protect fail";
+ } else if (g_ssl_gtest_verbose) {
+ std::cerr << agent()->role_str()
+ << ": protect: epoch=" << protection_spec.epoch()
+ << " seq=" << std::hex << header.sequence_number() << std::dec
+ << " " << *ciphertext << std::endl;
+ }
+ return ok;
+}
+
+bool IsHelloRetry(const DataBuffer& body) {
+ static const uint8_t ssl_hello_retry_random[] = {
+ 0xCF, 0x21, 0xAD, 0x74, 0xE5, 0x9A, 0x61, 0x11, 0xBE, 0x1D, 0x8C,
+ 0x02, 0x1E, 0x65, 0xB8, 0x91, 0xC2, 0xA2, 0x11, 0x16, 0x7A, 0xBB,
+ 0x8C, 0x5E, 0x07, 0x9E, 0x09, 0xE2, 0xC8, 0xA8, 0x33, 0x9C};
+ return memcmp(body.data() + 2, ssl_hello_retry_random,
+ sizeof(ssl_hello_retry_random)) == 0;
+}
+
+bool TlsHandshakeFilter::IsFilteredType(const HandshakeHeader& header,
+ const DataBuffer& body) {
+ if (handshake_types_.empty()) {
+ return true;
+ }
+
+ uint8_t type = header.handshake_type();
+ if (type == kTlsHandshakeServerHello) {
+ if (IsHelloRetry(body)) {
+ type = kTlsHandshakeHelloRetryRequest;
+ }
+ }
+ return handshake_types_.count(type) > 0U;
+}
+
+PacketFilter::Action TlsHandshakeFilter::FilterRecord(
+ const TlsRecordHeader& record_header, const DataBuffer& input,
+ DataBuffer* output) {
+ // Check that the first byte is as requested.
+ if (record_header.content_type() != ssl_ct_handshake) {
+ return KEEP;
+ }
+
+ bool changed = false;
+ size_t offset = 0U;
+ output->Allocate(input.len()); // Preallocate a little.
+
+ TlsParser parser(input);
+ while (parser.remaining()) {
+ HandshakeHeader header;
+ DataBuffer handshake;
+ bool complete = false;
+ if (!header.Parse(&parser, record_header, preceding_fragment_, &handshake,
+ &complete)) {
+ return KEEP;
+ }
+
+ if (!complete) {
+ EXPECT_TRUE(record_header.is_dtls());
+ // Save the fragment and drop it from this record. Fragments are
+ // coalesced with the last fragment of the handshake message.
+ changed = true;
+ preceding_fragment_.Assign(handshake);
+ continue;
+ }
+ preceding_fragment_.Truncate(0);
+
+ DataBuffer filtered;
+ PacketFilter::Action action;
+ if (!IsFilteredType(header, handshake)) {
+ action = KEEP;
+ } else {
+ action = FilterHandshake(header, handshake, &filtered);
+ }
+ if (action == DROP) {
+ changed = true;
+ std::cerr << "handshake drop: " << handshake << std::endl;
+ continue;
+ }
+
+ const DataBuffer* source = &handshake;
+ if (action == CHANGE) {
+ EXPECT_GT(0x1000000U, filtered.len());
+ changed = true;
+ std::cerr << "handshake old: " << handshake << std::endl;
+ std::cerr << "handshake new: " << filtered << std::endl;
+ source = &filtered;
+ } else if (preceding_fragment_.len()) {
+ changed = true;
+ }
+
+ offset = header.Write(output, offset, *source);
+ }
+ output->Truncate(offset);
+ return changed ? (offset ? CHANGE : DROP) : KEEP;
+}
+
+bool TlsHandshakeFilter::HandshakeHeader::ReadLength(
+ TlsParser* parser, const TlsRecordHeader& header, uint32_t expected_offset,
+ uint32_t* length, bool* last_fragment) {
+ uint32_t message_length;
+ if (!parser->Read(&message_length, 3)) {
+ return false; // malformed
+ }
+
+ if (!header.is_dtls()) {
+ *last_fragment = true;
+ *length = message_length;
+ return true; // nothing left to do
+ }
+
+ // Read and check DTLS parameters
+ uint32_t message_seq_tmp;
+ if (!parser->Read(&message_seq_tmp, 2)) { // sequence number
+ return false;
+ }
+ message_seq_ = message_seq_tmp;
+
+ uint32_t offset = 0;
+ if (!parser->Read(&offset, 3)) {
+ return false;
+ }
+ // We only parse if the fragments are all complete and in order.
+ if (offset != expected_offset) {
+ EXPECT_NE(0U, header.epoch())
+ << "Received out of order handshake fragment for epoch 0";
+ return false;
+ }
+
+ // For DTLS, we return the length of just this fragment.
+ if (!parser->Read(length, 3)) {
+ return false;
+ }
+
+ // It's a fragment if the entire message is longer than what we have.
+ *last_fragment = message_length == (*length + offset);
+ return true;
+}
+
+bool TlsHandshakeFilter::HandshakeHeader::Parse(
+ TlsParser* parser, const TlsRecordHeader& record_header,
+ const DataBuffer& preceding_fragment, DataBuffer* body, bool* complete) {
+ *complete = false;
+
+ variant_ = record_header.variant();
+ version_ = record_header.version();
+ if (!parser->Read(&handshake_type_)) {
+ return false; // malformed
+ }
+
+ uint32_t length;
+ if (!ReadLength(parser, record_header, preceding_fragment.len(), &length,
+ complete)) {
+ return false;
+ }
+
+ if (!parser->Read(body, length)) {
+ return false;
+ }
+ if (preceding_fragment.len()) {
+ body->Splice(preceding_fragment, 0);
+ }
+ return true;
+}
+
+size_t TlsHandshakeFilter::HandshakeHeader::WriteFragment(
+ DataBuffer* buffer, size_t offset, const DataBuffer& body,
+ size_t fragment_offset, size_t fragment_length) const {
+ EXPECT_TRUE(is_dtls());
+ EXPECT_GE(body.len(), fragment_offset + fragment_length);
+ offset = buffer->Write(offset, handshake_type(), 1);
+ offset = buffer->Write(offset, body.len(), 3);
+ offset = buffer->Write(offset, message_seq_, 2);
+ offset = buffer->Write(offset, fragment_offset, 3);
+ offset = buffer->Write(offset, fragment_length, 3);
+ offset =
+ buffer->Write(offset, body.data() + fragment_offset, fragment_length);
+ return offset;
+}
+
+size_t TlsHandshakeFilter::HandshakeHeader::Write(
+ DataBuffer* buffer, size_t offset, const DataBuffer& body) const {
+ if (is_dtls()) {
+ return WriteFragment(buffer, offset, body, 0U, body.len());
+ }
+ offset = buffer->Write(offset, handshake_type(), 1);
+ offset = buffer->Write(offset, body.len(), 3);
+ offset = buffer->Write(offset, body);
+ return offset;
+}
+
+PacketFilter::Action TlsHandshakeRecorder::FilterHandshake(
+ const HandshakeHeader& header, const DataBuffer& input,
+ DataBuffer* output) {
+ // Only do this once.
+ if (buffer_.len()) {
+ return KEEP;
+ }
+
+ buffer_ = input;
+ return KEEP;
+}
+
+PacketFilter::Action TlsInspectorReplaceHandshakeMessage::FilterHandshake(
+ const HandshakeHeader& header, const DataBuffer& input,
+ DataBuffer* output) {
+ *output = buffer_;
+ return CHANGE;
+}
+
+PacketFilter::Action TlsRecordRecorder::FilterRecord(
+ const TlsRecordHeader& header, const DataBuffer& input,
+ DataBuffer* output) {
+ if (!filter_ || (header.content_type() == ct_)) {
+ records_.push_back({header, input});
+ }
+ return KEEP;
+}
+
+PacketFilter::Action TlsConversationRecorder::FilterRecord(
+ const TlsRecordHeader& header, const DataBuffer& input,
+ DataBuffer* output) {
+ buffer_.Append(input);
+ return KEEP;
+}
+
+PacketFilter::Action TlsHeaderRecorder::FilterRecord(const TlsRecordHeader& hdr,
+ const DataBuffer& input,
+ DataBuffer* output) {
+ headers_.push_back(hdr);
+ return KEEP;
+}
+
+const TlsRecordHeader* TlsHeaderRecorder::header(size_t index) {
+ if (index > headers_.size() + 1) {
+ return nullptr;
+ }
+ return &headers_[index];
+}
+
+PacketFilter::Action ChainedPacketFilter::Filter(const DataBuffer& input,
+ DataBuffer* output) {
+ DataBuffer in(input);
+ bool changed = false;
+ for (auto it = filters_.begin(); it != filters_.end(); ++it) {
+ PacketFilter::Action action = (*it)->Process(in, output);
+ if (action == DROP) {
+ return DROP;
+ }
+
+ if (action == CHANGE) {
+ in = *output;
+ changed = true;
+ }
+ }
+ return changed ? CHANGE : KEEP;
+}
+
+bool FindClientHelloExtensions(TlsParser* parser, const TlsVersioned& header) {
+ if (!parser->Skip(2 + 32)) { // version + random
+ return false;
+ }
+ if (!parser->SkipVariable(1)) { // session ID
+ return false;
+ }
+ if (header.is_dtls() && !parser->SkipVariable(1)) { // DTLS cookie
+ return false;
+ }
+ if (!parser->SkipVariable(2)) { // cipher suites
+ return false;
+ }
+ if (!parser->SkipVariable(1)) { // compression methods
+ return false;
+ }
+ return true;
+}
+
+bool FindServerHelloExtensions(TlsParser* parser, const TlsVersioned& header) {
+ uint32_t vtmp;
+ if (!parser->Read(&vtmp, 2)) {
+ return false;
+ }
+ uint16_t version = static_cast<uint16_t>(vtmp);
+ if (!parser->Skip(32)) { // random
+ return false;
+ }
+ if (NormalizeTlsVersion(version) <= SSL_LIBRARY_VERSION_TLS_1_2) {
+ if (!parser->SkipVariable(1)) { // session ID
+ return false;
+ }
+ }
+ if (!parser->Skip(2)) { // cipher suite
+ return false;
+ }
+ if (NormalizeTlsVersion(version) <= SSL_LIBRARY_VERSION_TLS_1_2) {
+ if (!parser->Skip(1)) { // compression method
+ return false;
+ }
+ }
+ return true;
+}
+
+bool FindEncryptedExtensions(TlsParser* parser, const TlsVersioned& header) {
+ return true;
+}
+
+static bool FindCertReqExtensions(TlsParser* parser,
+ const TlsVersioned& header) {
+ if (!parser->SkipVariable(1)) { // request context
+ return false;
+ }
+ return true;
+}
+
+// Only look at the EE cert for this one.
+static bool FindCertificateExtensions(TlsParser* parser,
+ const TlsVersioned& header) {
+ if (!parser->SkipVariable(1)) { // request context
+ return false;
+ }
+ if (!parser->Skip(3)) { // length of certificate list
+ return false;
+ }
+ if (!parser->SkipVariable(3)) { // ASN1Cert
+ return false;
+ }
+ return true;
+}
+
+static bool FindNewSessionTicketExtensions(TlsParser* parser,
+ const TlsVersioned& header) {
+ if (!parser->Skip(8)) { // lifetime, age add
+ return false;
+ }
+ if (!parser->SkipVariable(1)) { // ticket_nonce
+ return false;
+ }
+ if (!parser->SkipVariable(2)) { // ticket
+ return false;
+ }
+ return true;
+}
+
+static const std::map<uint16_t, TlsExtensionFinder> kExtensionFinders = {
+ {kTlsHandshakeClientHello, FindClientHelloExtensions},
+ {kTlsHandshakeServerHello, FindServerHelloExtensions},
+ {kTlsHandshakeEncryptedExtensions, FindEncryptedExtensions},
+ {kTlsHandshakeCertificateRequest, FindCertReqExtensions},
+ {kTlsHandshakeCertificate, FindCertificateExtensions},
+ {kTlsHandshakeNewSessionTicket, FindNewSessionTicketExtensions}};
+
+bool TlsExtensionFilter::FindExtensions(TlsParser* parser,
+ const HandshakeHeader& header) {
+ auto it = kExtensionFinders.find(header.handshake_type());
+ if (it == kExtensionFinders.end()) {
+ return false;
+ }
+ return (it->second)(parser, header);
+}
+
+PacketFilter::Action TlsExtensionFilter::FilterHandshake(
+ const HandshakeHeader& header, const DataBuffer& input,
+ DataBuffer* output) {
+ TlsParser parser(input);
+ if (!FindExtensions(&parser, header)) {
+ return KEEP;
+ }
+ return FilterExtensions(&parser, input, output);
+}
+
+PacketFilter::Action TlsExtensionFilter::FilterExtensions(
+ TlsParser* parser, const DataBuffer& input, DataBuffer* output) {
+ size_t length_offset = parser->consumed();
+ uint32_t all_extensions;
+ if (!parser->Read(&all_extensions, 2)) {
+ return KEEP; // no extensions, odd but OK
+ }
+ if (all_extensions != parser->remaining()) {
+ return KEEP; // malformed
+ }
+
+ bool changed = false;
+
+ // Write out the start of the message.
+ output->Allocate(input.len());
+ size_t offset = output->Write(0, input.data(), parser->consumed());
+
+ while (parser->remaining()) {
+ uint32_t extension_type;
+ if (!parser->Read(&extension_type, 2)) {
+ return KEEP; // malformed
+ }
+
+ DataBuffer extension;
+ if (!parser->ReadVariable(&extension, 2)) {
+ return KEEP; // malformed
+ }
+
+ DataBuffer filtered;
+ PacketFilter::Action action =
+ FilterExtension(extension_type, extension, &filtered);
+ if (action == DROP) {
+ changed = true;
+ std::cerr << "extension drop: " << extension << std::endl;
+ continue;
+ }
+
+ const DataBuffer* source = &extension;
+ if (action == CHANGE) {
+ EXPECT_GT(0x10000U, filtered.len());
+ changed = true;
+ std::cerr << "extension old: " << extension << std::endl;
+ std::cerr << "extension new: " << filtered << std::endl;
+ source = &filtered;
+ }
+
+ // Write out extension.
+ offset = output->Write(offset, extension_type, 2);
+ offset = output->Write(offset, source->len(), 2);
+ if (source->len() > 0) {
+ offset = output->Write(offset, *source);
+ }
+ }
+ output->Truncate(offset);
+
+ if (changed) {
+ size_t newlen = output->len() - length_offset - 2;
+ EXPECT_GT(0x10000U, newlen);
+ if (newlen >= 0x10000) {
+ return KEEP; // bad: size increased too much
+ }
+ output->Write(length_offset, newlen, 2);
+ return CHANGE;
+ }
+ return KEEP;
+}
+
+PacketFilter::Action TlsExtensionOrderCapture::FilterExtension(
+ uint16_t extension_type, const DataBuffer& input, DataBuffer* output) {
+ order.push_back(extension_type);
+ return KEEP;
+}
+
+PacketFilter::Action TlsExtensionCapture::FilterExtension(
+ uint16_t extension_type, const DataBuffer& input, DataBuffer* output) {
+ if (extension_type == extension_ && (last_ || !captured_)) {
+ data_.Assign(input);
+ captured_ = true;
+ }
+ return KEEP;
+}
+
+PacketFilter::Action TlsExtensionReplacer::FilterExtension(
+ uint16_t extension_type, const DataBuffer& input, DataBuffer* output) {
+ if (extension_type != extension_) {
+ return KEEP;
+ }
+
+ *output = data_;
+ return CHANGE;
+}
+
+PacketFilter::Action TlsExtensionResizer::FilterExtension(
+ uint16_t extension_type, const DataBuffer& input, DataBuffer* output) {
+ if (extension_type != extension_) {
+ return KEEP;
+ }
+
+ if (input.len() <= length_) {
+ DataBuffer buf(length_ - input.len());
+ output->Append(buf);
+ return CHANGE;
+ }
+
+ output->Assign(input.data(), length_);
+ return CHANGE;
+}
+
+PacketFilter::Action TlsExtensionAppender::FilterHandshake(
+ const HandshakeHeader& header, const DataBuffer& input,
+ DataBuffer* output) {
+ TlsParser parser(input);
+ if (!TlsExtensionFilter::FindExtensions(&parser, header)) {
+ return KEEP;
+ }
+ *output = input;
+
+ // Increase the length of the extensions block.
+ if (!UpdateLength(output, parser.consumed(), 2)) {
+ return KEEP;
+ }
+
+ // Extensions in Certificate are nested twice. Increase the size of the
+ // certificate list.
+ if (header.handshake_type() == kTlsHandshakeCertificate) {
+ TlsParser p2(input);
+ if (!p2.SkipVariable(1)) {
+ ADD_FAILURE();
+ return KEEP;
+ }
+ if (!UpdateLength(output, p2.consumed(), 3)) {
+ return KEEP;
+ }
+ }
+
+ size_t offset = output->len();
+ offset = output->Write(offset, extension_, 2);
+ WriteVariable(output, offset, data_, 2);
+
+ return CHANGE;
+}
+
+bool TlsExtensionAppender::UpdateLength(DataBuffer* output, size_t offset,
+ size_t size) {
+ uint32_t len;
+ if (!output->Read(offset, size, &len)) {
+ ADD_FAILURE();
+ return false;
+ }
+
+ len += 4 + data_.len();
+ output->Write(offset, len, size);
+ return true;
+}
+
+PacketFilter::Action TlsExtensionDropper::FilterExtension(
+ uint16_t extension_type, const DataBuffer& input, DataBuffer* output) {
+ if (extension_type == extension_) {
+ return DROP;
+ }
+ return KEEP;
+}
+
+PacketFilter::Action TlsExtensionDamager::FilterExtension(
+ uint16_t extension_type, const DataBuffer& input, DataBuffer* output) {
+ if (extension_type != extension_) {
+ return KEEP;
+ }
+
+ *output = input;
+ output->data()[index_] += 73; // Increment selected for maximum damage
+ return CHANGE;
+}
+
+PacketFilter::Action TlsExtensionInjector::FilterHandshake(
+ const HandshakeHeader& header, const DataBuffer& input,
+ DataBuffer* output) {
+ TlsParser parser(input);
+ if (!TlsExtensionFilter::FindExtensions(&parser, header)) {
+ return KEEP;
+ }
+ size_t offset = parser.consumed();
+
+ *output = input;
+
+ // Increase the size of the extensions.
+ uint16_t ext_len;
+ memcpy(&ext_len, output->data() + offset, sizeof(ext_len));
+ ext_len = htons(ntohs(ext_len) + data_.len() + 4);
+ memcpy(output->data() + offset, &ext_len, sizeof(ext_len));
+
+ // Insert the extension type and length.
+ DataBuffer type_length;
+ type_length.Allocate(4);
+ type_length.Write(0, extension_, 2);
+ type_length.Write(2, data_.len(), 2);
+ output->Splice(type_length, offset + 2);
+
+ // Insert the payload.
+ if (data_.len() > 0) {
+ output->Splice(data_, offset + 6);
+ }
+
+ return CHANGE;
+}
+
+PacketFilter::Action AfterRecordN::FilterRecord(const TlsRecordHeader& header,
+ const DataBuffer& body,
+ DataBuffer* out) {
+ if (counter_++ == record_) {
+ DataBuffer buf;
+ header.Write(&buf, 0, body);
+ agent()->SendDirect(buf);
+ dest_.lock()->Handshake();
+ func_();
+ return DROP;
+ }
+
+ return KEEP;
+}
+
+PacketFilter::Action TlsClientHelloVersionChanger::FilterHandshake(
+ const HandshakeHeader& header, const DataBuffer& input,
+ DataBuffer* output) {
+ EXPECT_EQ(SECSuccess,
+ SSLInt_IncrementClientHandshakeVersion(server_.lock()->ssl_fd()));
+ return KEEP;
+}
+
+PacketFilter::Action SelectiveDropFilter::Filter(const DataBuffer& input,
+ DataBuffer* output) {
+ if (counter_ >= 32) {
+ return KEEP;
+ }
+ return ((1 << counter_++) & pattern_) ? DROP : KEEP;
+}
+
+PacketFilter::Action SelectiveRecordDropFilter::FilterRecord(
+ const TlsRecordHeader& header, const DataBuffer& data,
+ DataBuffer* changed) {
+ if (counter_ >= 32) {
+ return KEEP;
+ }
+ return ((1 << counter_++) & pattern_) ? DROP : KEEP;
+}
+
+/* static */ uint32_t SelectiveRecordDropFilter::ToPattern(
+ std::initializer_list<size_t> records) {
+ uint32_t pattern = 0;
+ for (auto it = records.begin(); it != records.end(); ++it) {
+ EXPECT_GT(32U, *it);
+ assert(*it < 32U);
+ pattern |= 1 << *it;
+ }
+ return pattern;
+}
+
+PacketFilter::Action TlsMessageVersionSetter::FilterHandshake(
+ const HandshakeHeader& header, const DataBuffer& input,
+ DataBuffer* output) {
+ *output = input;
+ output->Write(0, version_, 2);
+ return CHANGE;
+}
+
+PacketFilter::Action SelectedCipherSuiteReplacer::FilterHandshake(
+ const HandshakeHeader& header, const DataBuffer& input,
+ DataBuffer* output) {
+ *output = input;
+ uint32_t temp = 0;
+ EXPECT_TRUE(input.Read(0, 2, &temp));
+ EXPECT_EQ(header.version(), NormalizeTlsVersion(temp));
+ // Cipher suite is after version(2), random(32)
+ // and [legacy_]session_id(<0..32>).
+ size_t pos = 34;
+ EXPECT_TRUE(input.Read(pos, 1, &temp));
+ pos += 1 + temp;
+
+ output->Write(pos, static_cast<uint32_t>(cipher_suite_), 2);
+ return CHANGE;
+}
+
+PacketFilter::Action ServerHelloRandomChanger::FilterHandshake(
+ const HandshakeHeader& header, const DataBuffer& input,
+ DataBuffer* output) {
+ *output = input;
+ uint32_t temp = 0;
+ size_t pos = 30;
+ EXPECT_TRUE(input.Read(pos, 2, &temp));
+ output->Write(pos, (temp ^ 0xffff), 2);
+ return CHANGE;
+}
+
+PacketFilter::Action ClientHelloPreambleCapture::FilterHandshake(
+ const HandshakeHeader& header, const DataBuffer& input,
+ DataBuffer* output) {
+ EXPECT_TRUE(header.handshake_type() == kTlsHandshakeClientHello);
+
+ if (captured_) {
+ return KEEP;
+ }
+ captured_ = true;
+
+ DataBuffer temp;
+ TlsParser parser(input);
+ EXPECT_TRUE(parser.Read(&temp, 2 + 32)); // Version + Random
+ EXPECT_TRUE(parser.ReadVariable(&temp, 1)); // Session ID
+ if (is_dtls_agent()) {
+ EXPECT_TRUE(parser.ReadVariable(&temp, 1)); // Cookie
+ }
+ EXPECT_TRUE(parser.ReadVariable(&temp, 2)); // Ciphersuites
+ EXPECT_TRUE(parser.ReadVariable(&temp, 1)); // Compression
+
+ // Copy the preamble into a new buffer
+ data_ = input;
+ data_.Truncate(parser.consumed());
+
+ return KEEP;
+}
+
+PacketFilter::Action ClientHelloCiphersuiteCapture::FilterHandshake(
+ const HandshakeHeader& header, const DataBuffer& input,
+ DataBuffer* output) {
+ EXPECT_TRUE(header.handshake_type() == kTlsHandshakeClientHello);
+
+ if (captured_) {
+ return KEEP;
+ }
+ captured_ = true;
+
+ TlsParser parser(input);
+ EXPECT_TRUE(parser.Skip(2 + 32)); // Version + Random
+ EXPECT_TRUE(parser.SkipVariable(1)); // Session ID
+ if (is_dtls_agent()) {
+ EXPECT_TRUE(parser.SkipVariable(1)); // Cookie
+ }
+
+ EXPECT_TRUE(parser.ReadVariable(&data_, 2)); // Ciphersuites
+
+ return KEEP;
+}
+
+} // namespace nss_test
diff --git a/security/nss/gtests/ssl_gtest/tls_filter.h b/security/nss/gtests/ssl_gtest/tls_filter.h
new file mode 100644
index 0000000000..7c45aab12f
--- /dev/null
+++ b/security/nss/gtests/ssl_gtest/tls_filter.h
@@ -0,0 +1,1013 @@
+/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
+/* vim: set ts=2 et sw=2 tw=80: */
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this file,
+ * You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+#ifndef tls_filter_h_
+#define tls_filter_h_
+
+#include <functional>
+#include <memory>
+#include <set>
+#include <vector>
+#include "pk11pub.h"
+#include "sslt.h"
+#include "sslproto.h"
+#include "test_io.h"
+#include "tls_agent.h"
+#include "tls_parser.h"
+#include "tls_protect.h"
+
+extern "C" {
+#include "libssl_internals.h"
+}
+
+namespace nss_test {
+
+class TlsCipherSpec;
+
+class TlsSendCipherSpecCapturer {
+ public:
+ TlsSendCipherSpecCapturer(const std::shared_ptr<TlsAgent>& agent)
+ : agent_(agent), send_cipher_specs_() {
+ EXPECT_EQ(SECSuccess,
+ SSL_SecretCallback(agent_->ssl_fd(), SecretCallback, this));
+ }
+
+ std::shared_ptr<TlsCipherSpec> spec(size_t i) {
+ if (i >= send_cipher_specs_.size()) {
+ return nullptr;
+ }
+ return send_cipher_specs_[i];
+ }
+
+ private:
+ static void SecretCallback(PRFileDesc* fd, PRUint16 epoch,
+ SSLSecretDirection dir, PK11SymKey* secret,
+ void* arg) {
+ auto self = static_cast<TlsSendCipherSpecCapturer*>(arg);
+ std::cerr << self->agent_->role_str() << ": capture " << dir
+ << " secret for epoch " << epoch << std::endl;
+
+ if (dir == ssl_secret_read) {
+ return;
+ }
+
+ SSLPreliminaryChannelInfo preinfo;
+ EXPECT_EQ(SECSuccess,
+ SSL_GetPreliminaryChannelInfo(self->agent_->ssl_fd(), &preinfo,
+ sizeof(preinfo)));
+ EXPECT_EQ(sizeof(preinfo), preinfo.length);
+ EXPECT_TRUE(preinfo.valuesSet & ssl_preinfo_cipher_suite);
+
+ // Check the version:
+ EXPECT_TRUE(preinfo.valuesSet & ssl_preinfo_version);
+ ASSERT_GE(SSL_LIBRARY_VERSION_TLS_1_3, preinfo.protocolVersion);
+
+ SSLCipherSuiteInfo cipherinfo;
+ EXPECT_EQ(SECSuccess,
+ SSL_GetCipherSuiteInfo(preinfo.cipherSuite, &cipherinfo,
+ sizeof(cipherinfo)));
+ EXPECT_EQ(sizeof(cipherinfo), cipherinfo.length);
+
+ auto spec = std::make_shared<TlsCipherSpec>(true, epoch);
+ EXPECT_TRUE(spec->SetKeys(&cipherinfo, secret));
+ self->send_cipher_specs_.push_back(spec);
+ }
+
+ std::shared_ptr<TlsAgent> agent_;
+ std::vector<std::shared_ptr<TlsCipherSpec>> send_cipher_specs_;
+};
+
+class TlsVersioned {
+ public:
+ TlsVersioned() : variant_(ssl_variant_stream), version_(0) {}
+ TlsVersioned(SSLProtocolVariant var, uint16_t ver)
+ : variant_(var), version_(ver) {}
+
+ bool is_dtls() const { return variant_ == ssl_variant_datagram; }
+ SSLProtocolVariant variant() const { return variant_; }
+ uint16_t version() const { return version_; }
+
+ void WriteStream(std::ostream& stream) const;
+
+ protected:
+ SSLProtocolVariant variant_;
+ uint16_t version_;
+};
+
+class TlsRecordHeader : public TlsVersioned {
+ public:
+ TlsRecordHeader()
+ : TlsVersioned(),
+ content_type_(0),
+ guess_seqno_(0),
+ seqno_is_masked_(false),
+ sequence_number_(0),
+ header_() {}
+ TlsRecordHeader(SSLProtocolVariant var, uint16_t ver, uint8_t ct,
+ uint64_t seqno)
+ : TlsVersioned(var, ver),
+ content_type_(ct),
+ guess_seqno_(0),
+ seqno_is_masked_(false),
+ sequence_number_(seqno),
+ header_(),
+ sn_mask_() {}
+
+ bool is_protected() const {
+ // *TLS < 1.3
+ if (version() < SSL_LIBRARY_VERSION_TLS_1_3 &&
+ content_type() == ssl_ct_application_data) {
+ return true;
+ }
+
+ // TLS 1.3
+ if (!is_dtls() && version() >= SSL_LIBRARY_VERSION_TLS_1_3 &&
+ content_type() == ssl_ct_application_data) {
+ return true;
+ }
+
+ // DTLS 1.3
+ return is_dtls13_ciphertext();
+ }
+
+ uint8_t content_type() const { return content_type_; }
+ uint16_t epoch() const {
+ return static_cast<uint16_t>(sequence_number_ >> 48);
+ }
+ uint64_t sequence_number() const { return sequence_number_; }
+ void sequence_number(uint64_t seqno) { sequence_number_ = seqno; }
+ const DataBuffer& sn_mask() const { return sn_mask_; }
+ bool is_dtls13_ciphertext() const {
+ return is_dtls() && (version() >= SSL_LIBRARY_VERSION_TLS_1_3) &&
+ (content_type() & kCtDtlsCiphertextMask) == kCtDtlsCiphertext;
+ }
+
+ size_t header_length() const;
+ const DataBuffer& header() const { return header_; }
+
+ bool MaskSequenceNumber();
+ bool MaskSequenceNumber(const DataBuffer& mask_buf);
+
+ // Parse the header; return true if successful; body in an outparam if OK.
+ bool Parse(bool is_dtls13, uint64_t sequence_number, TlsParser* parser,
+ DataBuffer* body);
+ // Write the header and body to a buffer at the given offset.
+ // Return the offset of the end of the write.
+ size_t Write(DataBuffer* buffer, size_t offset, const DataBuffer& body) const;
+ size_t WriteHeader(DataBuffer* buffer, size_t offset, size_t body_len) const;
+
+ private:
+ static uint64_t RecoverSequenceNumber(uint64_t guess_seqno, uint32_t partial,
+ size_t partial_bits);
+ uint64_t ParseSequenceNumber(uint64_t expected, uint64_t raw,
+ size_t seq_no_bits, size_t epoch_bits);
+
+ uint8_t content_type_;
+ uint64_t guess_seqno_;
+ bool seqno_is_masked_;
+ uint64_t sequence_number_;
+ DataBuffer header_;
+ DataBuffer sn_mask_;
+};
+
+struct TlsRecord {
+ const TlsRecordHeader header;
+ const DataBuffer buffer;
+};
+
+// Make a filter and install it on a TlsAgent.
+template <class T, typename... Args>
+inline std::shared_ptr<T> MakeTlsFilter(const std::shared_ptr<TlsAgent>& agent,
+ Args&&... args) {
+ auto filter = std::make_shared<T>(agent, std::forward<Args>(args)...);
+ agent->SetFilter(filter);
+ return filter;
+}
+
+// Abstract filter that operates on entire (D)TLS records.
+class TlsRecordFilter : public PacketFilter {
+ public:
+ TlsRecordFilter(const std::shared_ptr<TlsAgent>& a);
+
+ std::shared_ptr<TlsAgent> agent() const { return agent_.lock(); }
+
+ // External interface. Overrides PacketFilter.
+ PacketFilter::Action Filter(const DataBuffer& input, DataBuffer* output);
+
+ // Report how many packets were altered by the filter.
+ size_t filtered_packets() const { return count_; }
+
+ // Enable decryption. This only works properly for TLS 1.3 and above.
+ // Enabling it for lower version tests will cause undefined
+ // behavior.
+ void EnableDecryption();
+ bool decrypting() const { return decrypting_; };
+ bool Unprotect(const TlsRecordHeader& header, const DataBuffer& cipherText,
+ uint16_t* protection_epoch, uint8_t* inner_content_type,
+ DataBuffer* plaintext, TlsRecordHeader* out_header);
+ bool Protect(TlsCipherSpec& protection_spec, const TlsRecordHeader& header,
+ uint8_t inner_content_type, const DataBuffer& plaintext,
+ DataBuffer* ciphertext, TlsRecordHeader* out_header,
+ size_t padding = 0);
+
+ protected:
+ // There are two filter functions which can be overriden. Both are
+ // called with the header and the record but the outer one is called
+ // with a raw pointer to let you write into the buffer and lets you
+ // do anything with this section of the stream. The inner one
+ // just lets you change the record contents. By default, the
+ // outer one calls the inner one, so if you override the outer
+ // one, the inner one is never called unless you call it yourself.
+ virtual PacketFilter::Action FilterRecord(const TlsRecordHeader& header,
+ const DataBuffer& record,
+ size_t* offset, DataBuffer* output);
+
+ // The record filter receives the record contentType, version and DTLS
+ // sequence number (which is zero for TLS), plus the existing record payload.
+ // It returns an action (KEEP, CHANGE, DROP). It writes to the `changed`
+ // outparam with the new record contents if it chooses to CHANGE the record.
+ virtual PacketFilter::Action FilterRecord(const TlsRecordHeader& header,
+ const DataBuffer& data,
+ DataBuffer* changed) {
+ return KEEP;
+ }
+
+ bool is_dtls_agent() const;
+ bool is_dtls13() const;
+ bool is_dtls13_ciphertext(uint8_t ct) const;
+ TlsCipherSpec& spec(uint16_t epoch);
+
+ private:
+ static void SecretCallback(PRFileDesc* fd, PRUint16 epoch,
+ SSLSecretDirection dir, PK11SymKey* secret,
+ void* arg);
+
+ std::weak_ptr<TlsAgent> agent_;
+ size_t count_ = 0;
+ std::vector<TlsCipherSpec> cipher_specs_;
+ bool decrypting_ = false;
+};
+
+inline std::ostream& operator<<(std::ostream& stream, const TlsVersioned& v) {
+ v.WriteStream(stream);
+ return stream;
+}
+
+inline std::ostream& operator<<(std::ostream& stream,
+ const TlsRecordHeader& hdr) {
+ hdr.WriteStream(stream);
+ stream << ' ';
+ switch (hdr.content_type()) {
+ case ssl_ct_change_cipher_spec:
+ stream << "CCS";
+ break;
+ case ssl_ct_alert:
+ stream << "Alert";
+ break;
+ case ssl_ct_handshake:
+ stream << "Handshake";
+ break;
+ case ssl_ct_application_data:
+ stream << "Data";
+ break;
+ case ssl_ct_ack:
+ stream << "ACK";
+ break;
+ default:
+ stream << '<' << static_cast<int>(hdr.content_type()) << '>';
+ break;
+ }
+ return stream << ' ' << std::hex << hdr.sequence_number() << std::dec;
+}
+
+// Abstract filter that operates on handshake messages rather than records.
+// This assumes that the handshake messages are written in a block as entire
+// records and that they don't span records or anything crazy like that.
+class TlsHandshakeFilter : public TlsRecordFilter {
+ public:
+ TlsHandshakeFilter(const std::shared_ptr<TlsAgent>& a)
+ : TlsRecordFilter(a), handshake_types_(), preceding_fragment_() {}
+ TlsHandshakeFilter(const std::shared_ptr<TlsAgent>& a,
+ const std::set<uint8_t>& types)
+ : TlsRecordFilter(a), handshake_types_(types), preceding_fragment_() {}
+
+ // This filter can be set to be selective based on handshake message type. If
+ // this function isn't used (or the set is empty), then all handshake messages
+ // will be filtered.
+ void SetHandshakeTypes(const std::set<uint8_t>& types) {
+ handshake_types_ = types;
+ }
+
+ class HandshakeHeader : public TlsVersioned {
+ public:
+ HandshakeHeader() : TlsVersioned(), handshake_type_(0), message_seq_(0) {}
+
+ uint8_t handshake_type() const { return handshake_type_; }
+ bool Parse(TlsParser* parser, const TlsRecordHeader& record_header,
+ const DataBuffer& preceding_fragment, DataBuffer* body,
+ bool* complete);
+ size_t Write(DataBuffer* buffer, size_t offset,
+ const DataBuffer& body) const;
+ size_t WriteFragment(DataBuffer* buffer, size_t offset,
+ const DataBuffer& body, size_t fragment_offset,
+ size_t fragment_length) const;
+
+ private:
+ // Reads the length from the record header.
+ // This also reads the DTLS fragment information and checks it.
+ bool ReadLength(TlsParser* parser, const TlsRecordHeader& header,
+ uint32_t expected_offset, uint32_t* length,
+ bool* last_fragment);
+
+ uint8_t handshake_type_;
+ uint16_t message_seq_;
+ // fragment_offset is always zero in these tests.
+ };
+
+ protected:
+ virtual PacketFilter::Action FilterRecord(const TlsRecordHeader& header,
+ const DataBuffer& input,
+ DataBuffer* output);
+ virtual PacketFilter::Action FilterHandshake(const HandshakeHeader& header,
+ const DataBuffer& input,
+ DataBuffer* output) = 0;
+
+ private:
+ bool IsFilteredType(const HandshakeHeader& header,
+ const DataBuffer& handshake);
+
+ std::set<uint8_t> handshake_types_;
+ DataBuffer preceding_fragment_;
+};
+
+// Make a copy of the first instance of a handshake message.
+class TlsHandshakeRecorder : public TlsHandshakeFilter {
+ public:
+ TlsHandshakeRecorder(const std::shared_ptr<TlsAgent>& a,
+ uint8_t handshake_type)
+ : TlsHandshakeFilter(a, {handshake_type}), buffer_() {}
+ TlsHandshakeRecorder(const std::shared_ptr<TlsAgent>& a,
+ const std::set<uint8_t>& handshake_types)
+ : TlsHandshakeFilter(a, handshake_types), buffer_() {}
+
+ virtual PacketFilter::Action FilterHandshake(const HandshakeHeader& header,
+ const DataBuffer& input,
+ DataBuffer* output);
+
+ void Reset() { buffer_.Truncate(0); }
+
+ const DataBuffer& buffer() const { return buffer_; }
+
+ private:
+ DataBuffer buffer_;
+};
+
+// Replace all instances of a handshake message.
+class TlsInspectorReplaceHandshakeMessage : public TlsHandshakeFilter {
+ public:
+ TlsInspectorReplaceHandshakeMessage(const std::shared_ptr<TlsAgent>& a,
+ uint8_t handshake_type,
+ const DataBuffer& replacement)
+ : TlsHandshakeFilter(a, {handshake_type}), buffer_(replacement) {}
+
+ virtual PacketFilter::Action FilterHandshake(const HandshakeHeader& header,
+ const DataBuffer& input,
+ DataBuffer* output);
+
+ private:
+ DataBuffer buffer_;
+};
+
+// Make a copy of each record of a given type.
+class TlsRecordRecorder : public TlsRecordFilter {
+ public:
+ TlsRecordRecorder(const std::shared_ptr<TlsAgent>& a, uint8_t ct)
+ : TlsRecordFilter(a), filter_(true), ct_(ct), records_() {}
+ TlsRecordRecorder(const std::shared_ptr<TlsAgent>& a)
+ : TlsRecordFilter(a),
+ filter_(false),
+ ct_(ssl_ct_handshake), // dummy (<optional> is C++14)
+ records_() {}
+ virtual PacketFilter::Action FilterRecord(const TlsRecordHeader& header,
+ const DataBuffer& input,
+ DataBuffer* output);
+
+ size_t count() const { return records_.size(); }
+ void Clear() { records_.clear(); }
+
+ const TlsRecord& record(size_t i) const { return records_[i]; }
+
+ private:
+ bool filter_;
+ uint8_t ct_;
+ std::vector<TlsRecord> records_;
+};
+
+// Make a copy of the complete conversation.
+class TlsConversationRecorder : public TlsRecordFilter {
+ public:
+ TlsConversationRecorder(const std::shared_ptr<TlsAgent>& a,
+ DataBuffer& buffer)
+ : TlsRecordFilter(a), buffer_(buffer) {}
+
+ virtual PacketFilter::Action FilterRecord(const TlsRecordHeader& header,
+ const DataBuffer& input,
+ DataBuffer* output);
+
+ private:
+ DataBuffer buffer_;
+};
+
+// Make a copy of the records
+class TlsHeaderRecorder : public TlsRecordFilter {
+ public:
+ TlsHeaderRecorder(const std::shared_ptr<TlsAgent>& a) : TlsRecordFilter(a) {}
+ virtual PacketFilter::Action FilterRecord(const TlsRecordHeader& header,
+ const DataBuffer& input,
+ DataBuffer* output);
+ const TlsRecordHeader* header(size_t index);
+
+ private:
+ std::vector<TlsRecordHeader> headers_;
+};
+
+typedef std::initializer_list<std::shared_ptr<PacketFilter>>
+ ChainedPacketFilterInit;
+
+// Runs multiple packet filters in series.
+class ChainedPacketFilter : public PacketFilter {
+ public:
+ ChainedPacketFilter() {}
+ ChainedPacketFilter(const std::vector<std::shared_ptr<PacketFilter>> filters)
+ : filters_(filters.begin(), filters.end()) {}
+ ChainedPacketFilter(ChainedPacketFilterInit il) : filters_(il) {}
+ virtual ~ChainedPacketFilter() {}
+
+ virtual PacketFilter::Action Filter(const DataBuffer& input,
+ DataBuffer* output);
+
+ // Takes ownership of the filter.
+ void Add(std::shared_ptr<PacketFilter> filter) { filters_.push_back(filter); }
+
+ private:
+ std::vector<std::shared_ptr<PacketFilter>> filters_;
+};
+
+typedef std::function<bool(TlsParser* parser, const TlsVersioned& header)>
+ TlsExtensionFinder;
+
+class TlsExtensionFilter : public TlsHandshakeFilter {
+ public:
+ TlsExtensionFilter(const std::shared_ptr<TlsAgent>& a)
+ : TlsHandshakeFilter(a,
+ {kTlsHandshakeClientHello, kTlsHandshakeServerHello,
+ kTlsHandshakeHelloRetryRequest,
+ kTlsHandshakeEncryptedExtensions}) {}
+
+ TlsExtensionFilter(const std::shared_ptr<TlsAgent>& a,
+ const std::set<uint8_t>& types)
+ : TlsHandshakeFilter(a, types) {}
+
+ static bool FindExtensions(TlsParser* parser, const HandshakeHeader& header);
+
+ protected:
+ PacketFilter::Action FilterHandshake(const HandshakeHeader& header,
+ const DataBuffer& input,
+ DataBuffer* output) override;
+
+ virtual PacketFilter::Action FilterExtension(uint16_t extension_type,
+ const DataBuffer& input,
+ DataBuffer* output) = 0;
+
+ private:
+ PacketFilter::Action FilterExtensions(TlsParser* parser,
+ const DataBuffer& input,
+ DataBuffer* output);
+};
+
+class TlsExtensionOrderCapture : public TlsExtensionFilter {
+ public:
+ TlsExtensionOrderCapture(const std::shared_ptr<TlsAgent>& a, uint8_t message)
+ : TlsExtensionFilter(a, {message}){};
+
+ std::vector<uint16_t> order;
+
+ protected:
+ PacketFilter::Action FilterExtension(uint16_t extension_type,
+ const DataBuffer& input,
+ DataBuffer* output) override;
+};
+
+class TlsExtensionCapture : public TlsExtensionFilter {
+ public:
+ TlsExtensionCapture(const std::shared_ptr<TlsAgent>& a, uint16_t ext,
+ bool last = false)
+ : TlsExtensionFilter(a),
+ extension_(ext),
+ captured_(false),
+ last_(last),
+ data_() {}
+
+ const DataBuffer& extension() const { return data_; }
+ bool captured() const { return captured_; }
+
+ protected:
+ PacketFilter::Action FilterExtension(uint16_t extension_type,
+ const DataBuffer& input,
+ DataBuffer* output) override;
+
+ private:
+ const uint16_t extension_;
+ bool captured_;
+ bool last_;
+ DataBuffer data_;
+};
+
+class TlsExtensionReplacer : public TlsExtensionFilter {
+ public:
+ TlsExtensionReplacer(const std::shared_ptr<TlsAgent>& a, uint16_t extension,
+ const DataBuffer& data)
+ : TlsExtensionFilter(a), extension_(extension), data_(data) {}
+ PacketFilter::Action FilterExtension(uint16_t extension_type,
+ const DataBuffer& input,
+ DataBuffer* output) override;
+
+ private:
+ const uint16_t extension_;
+ const DataBuffer data_;
+};
+
+class TlsExtensionResizer : public TlsExtensionFilter {
+ public:
+ TlsExtensionResizer(const std::shared_ptr<TlsAgent>& a, uint16_t extension,
+ size_t length)
+ : TlsExtensionFilter(a), extension_(extension), length_(length) {}
+ PacketFilter::Action FilterExtension(uint16_t extension_type,
+ const DataBuffer& input,
+ DataBuffer* output) override;
+
+ private:
+ uint16_t extension_;
+ size_t length_;
+};
+
+class TlsExtensionAppender : public TlsHandshakeFilter {
+ public:
+ TlsExtensionAppender(const std::shared_ptr<TlsAgent>& a,
+ uint8_t handshake_type, uint16_t ext, DataBuffer& data)
+ : TlsHandshakeFilter(a, {handshake_type}), extension_(ext), data_(data) {}
+
+ virtual PacketFilter::Action FilterHandshake(const HandshakeHeader& header,
+ const DataBuffer& input,
+ DataBuffer* output);
+
+ private:
+ bool UpdateLength(DataBuffer* output, size_t offset, size_t size);
+
+ const uint16_t extension_;
+ const DataBuffer data_;
+};
+
+class TlsExtensionDropper : public TlsExtensionFilter {
+ public:
+ TlsExtensionDropper(const std::shared_ptr<TlsAgent>& a, uint16_t extension)
+ : TlsExtensionFilter(a), extension_(extension) {}
+ PacketFilter::Action FilterExtension(uint16_t extension_type,
+ const DataBuffer&, DataBuffer*) override;
+
+ private:
+ uint16_t extension_;
+};
+
+class TlsHandshakeDropper : public TlsHandshakeFilter {
+ public:
+ TlsHandshakeDropper(const std::shared_ptr<TlsAgent>& a)
+ : TlsHandshakeFilter(a) {}
+
+ protected:
+ PacketFilter::Action FilterHandshake(const HandshakeHeader& header,
+ const DataBuffer& input,
+ DataBuffer* output) override {
+ return DROP;
+ }
+};
+
+class TlsEncryptedHandshakeMessageReplacer : public TlsRecordFilter {
+ public:
+ TlsEncryptedHandshakeMessageReplacer(const std::shared_ptr<TlsAgent>& a,
+ uint8_t old_ct, uint8_t new_ct)
+ : TlsRecordFilter(a), old_ct_(old_ct), new_ct_(new_ct) {}
+
+ protected:
+ PacketFilter::Action FilterRecord(const TlsRecordHeader& header,
+ const DataBuffer& record, size_t* offset,
+ DataBuffer* output) override {
+ if (header.content_type() != ssl_ct_application_data) {
+ return KEEP;
+ }
+
+ uint16_t protection_epoch = 0;
+ uint8_t inner_content_type;
+ DataBuffer plaintext;
+ TlsRecordHeader out_header;
+ if (!Unprotect(header, record, &protection_epoch, &inner_content_type,
+ &plaintext, &out_header) ||
+ !plaintext.len()) {
+ return KEEP;
+ }
+
+ if (inner_content_type != ssl_ct_handshake) {
+ return KEEP;
+ }
+
+ size_t off = 0;
+ uint32_t msg_len = 0;
+ uint32_t msg_type = 255; // Not a real message
+ do {
+ if (!plaintext.Read(off, 1, &msg_type) || msg_type == old_ct_) {
+ break;
+ }
+
+ // Increment and check next messages
+ if (!plaintext.Read(++off, 3, &msg_len)) {
+ break;
+ }
+ off += 3 + msg_len;
+ } while (msg_type != old_ct_);
+
+ if (msg_type == old_ct_) {
+ plaintext.Write(off, new_ct_, 1);
+ }
+
+ DataBuffer ciphertext;
+ bool ok = Protect(spec(protection_epoch), out_header, inner_content_type,
+ plaintext, &ciphertext, &out_header);
+ if (!ok) {
+ return KEEP;
+ }
+ *offset = out_header.Write(output, *offset, ciphertext);
+ return CHANGE;
+ }
+
+ private:
+ uint8_t old_ct_;
+ uint8_t new_ct_;
+};
+
+class TlsExtensionInjector : public TlsHandshakeFilter {
+ public:
+ TlsExtensionInjector(const std::shared_ptr<TlsAgent>& a, uint16_t ext,
+ const DataBuffer& data)
+ : TlsHandshakeFilter(a), extension_(ext), data_(data) {}
+
+ protected:
+ PacketFilter::Action FilterHandshake(const HandshakeHeader& header,
+ const DataBuffer& input,
+ DataBuffer* output) override;
+
+ private:
+ const uint16_t extension_;
+ const DataBuffer data_;
+};
+
+class TlsExtensionDamager : public TlsExtensionFilter {
+ public:
+ TlsExtensionDamager(const std::shared_ptr<TlsAgent>& a, uint16_t extension,
+ size_t index)
+ : TlsExtensionFilter(a), extension_(extension), index_(index) {}
+ virtual PacketFilter::Action FilterExtension(uint16_t extension_type,
+ const DataBuffer& input,
+ DataBuffer* output);
+
+ private:
+ uint16_t extension_;
+ size_t index_;
+};
+
+typedef std::function<void(void)> VoidFunction;
+
+class AfterRecordN : public TlsRecordFilter {
+ public:
+ AfterRecordN(const std::shared_ptr<TlsAgent>& src,
+ const std::shared_ptr<TlsAgent>& dest, unsigned int record,
+ VoidFunction func)
+ : TlsRecordFilter(src),
+ dest_(dest),
+ record_(record),
+ func_(func),
+ counter_(0) {}
+
+ virtual PacketFilter::Action FilterRecord(const TlsRecordHeader& header,
+ const DataBuffer& body,
+ DataBuffer* out) override;
+
+ private:
+ std::weak_ptr<TlsAgent> dest_;
+ unsigned int record_;
+ VoidFunction func_;
+ unsigned int counter_;
+};
+
+// When we see the ClientKeyExchange from |client|, increment the
+// ClientHelloVersion on |server|.
+class TlsClientHelloVersionChanger : public TlsHandshakeFilter {
+ public:
+ TlsClientHelloVersionChanger(const std::shared_ptr<TlsAgent>& client,
+ const std::shared_ptr<TlsAgent>& server)
+ : TlsHandshakeFilter(client, {kTlsHandshakeClientKeyExchange}),
+ server_(server) {}
+
+ virtual PacketFilter::Action FilterHandshake(const HandshakeHeader& header,
+ const DataBuffer& input,
+ DataBuffer* output);
+
+ private:
+ std::weak_ptr<TlsAgent> server_;
+};
+
+// Damage a record.
+class TlsRecordLastByteDamager : public TlsRecordFilter {
+ public:
+ TlsRecordLastByteDamager(const std::shared_ptr<TlsAgent>& a)
+ : TlsRecordFilter(a) {}
+
+ protected:
+ PacketFilter::Action FilterRecord(const TlsRecordHeader& header,
+ const DataBuffer& data,
+ DataBuffer* changed) override {
+ *changed = data;
+ changed->data()[changed->len() - 1]++;
+ return CHANGE;
+ }
+};
+
+// This class selectively drops complete writes. This relies on the fact that
+// writes in libssl are on record boundaries.
+class SelectiveDropFilter : public PacketFilter {
+ public:
+ SelectiveDropFilter(uint32_t pattern) : pattern_(pattern), counter_(0) {}
+
+ protected:
+ virtual PacketFilter::Action Filter(const DataBuffer& input,
+ DataBuffer* output) override;
+
+ private:
+ const uint32_t pattern_;
+ uint8_t counter_;
+};
+
+// This class selectively drops complete records. The difference from
+// SelectiveDropFilter is that if multiple DTLS records are in the same
+// datagram, we just drop one.
+class SelectiveRecordDropFilter : public TlsRecordFilter {
+ public:
+ SelectiveRecordDropFilter(const std::shared_ptr<TlsAgent>& a,
+ uint32_t pattern, bool on = true)
+ : TlsRecordFilter(a), pattern_(pattern), counter_(0) {
+ if (!on) {
+ Disable();
+ }
+ }
+ SelectiveRecordDropFilter(const std::shared_ptr<TlsAgent>& a,
+ std::initializer_list<size_t> records)
+ : SelectiveRecordDropFilter(a, ToPattern(records), true) {}
+
+ void Reset(uint32_t pattern) {
+ counter_ = 0;
+ PacketFilter::Enable();
+ pattern_ = pattern;
+ }
+
+ void Reset(std::initializer_list<size_t> records) {
+ Reset(ToPattern(records));
+ }
+
+ protected:
+ PacketFilter::Action FilterRecord(const TlsRecordHeader& header,
+ const DataBuffer& data,
+ DataBuffer* changed) override;
+
+ private:
+ static uint32_t ToPattern(std::initializer_list<size_t> records);
+
+ uint32_t pattern_;
+ uint8_t counter_;
+};
+
+// Set the version value in the ClientHello, ServerHello or HelloRetryRequest
+class TlsMessageVersionSetter : public TlsHandshakeFilter {
+ public:
+ TlsMessageVersionSetter(const std::shared_ptr<TlsAgent>& a, uint8_t message,
+ uint16_t version)
+ : TlsHandshakeFilter(a, {message}), version_(version) {
+ PR_ASSERT(message == kTlsHandshakeClientHello ||
+ message == kTlsHandshakeServerHello ||
+ message == kTlsHandshakeHelloRetryRequest);
+ }
+
+ virtual PacketFilter::Action FilterHandshake(const HandshakeHeader& header,
+ const DataBuffer& input,
+ DataBuffer* output);
+
+ private:
+ uint16_t version_;
+};
+
+// Damages the last byte of a handshake message.
+class TlsLastByteDamager : public TlsHandshakeFilter {
+ public:
+ TlsLastByteDamager(const std::shared_ptr<TlsAgent>& a, uint8_t type)
+ : TlsHandshakeFilter(a), type_(type) {}
+ PacketFilter::Action FilterHandshake(
+ const TlsHandshakeFilter::HandshakeHeader& header,
+ const DataBuffer& input, DataBuffer* output) override {
+ if (header.handshake_type() != type_) {
+ return KEEP;
+ }
+
+ *output = input;
+
+ output->data()[output->len() - 1]++;
+ return CHANGE;
+ }
+
+ private:
+ uint8_t type_;
+};
+
+class SelectedCipherSuiteReplacer : public TlsHandshakeFilter {
+ public:
+ SelectedCipherSuiteReplacer(const std::shared_ptr<TlsAgent>& a,
+ uint16_t suite)
+ : TlsHandshakeFilter(a, {kTlsHandshakeServerHello}),
+ cipher_suite_(suite) {}
+
+ protected:
+ PacketFilter::Action FilterHandshake(const HandshakeHeader& header,
+ const DataBuffer& input,
+ DataBuffer* output) override;
+
+ private:
+ uint16_t cipher_suite_;
+};
+
+class ClientHelloPreambleCapture : public TlsHandshakeFilter {
+ public:
+ ClientHelloPreambleCapture(const std::shared_ptr<TlsAgent>& a)
+ : TlsHandshakeFilter(a, {kTlsHandshakeClientHello}),
+ captured_(false),
+ data_() {}
+
+ const DataBuffer& contents() const { return data_; }
+ bool captured() const { return captured_; }
+
+ protected:
+ PacketFilter::Action FilterHandshake(const HandshakeHeader& header,
+ const DataBuffer& input,
+ DataBuffer* output) override;
+
+ private:
+ bool captured_;
+ DataBuffer data_;
+};
+
+class ClientHelloCiphersuiteCapture : public TlsHandshakeFilter {
+ public:
+ ClientHelloCiphersuiteCapture(const std::shared_ptr<TlsAgent>& a)
+ : TlsHandshakeFilter(a, {kTlsHandshakeClientHello}),
+ captured_(false),
+ data_() {}
+
+ const DataBuffer& contents() const { return data_; }
+ bool captured() const { return captured_; }
+
+ protected:
+ PacketFilter::Action FilterHandshake(const HandshakeHeader& header,
+ const DataBuffer& input,
+ DataBuffer* output) override;
+
+ private:
+ bool captured_;
+ DataBuffer data_;
+};
+
+class ServerHelloRandomChanger : public TlsHandshakeFilter {
+ public:
+ ServerHelloRandomChanger(const std::shared_ptr<TlsAgent>& a)
+ : TlsHandshakeFilter(a, {kTlsHandshakeServerHello}) {}
+
+ protected:
+ PacketFilter::Action FilterHandshake(const HandshakeHeader& header,
+ const DataBuffer& input,
+ DataBuffer* output) override;
+};
+
+// Replace SignatureAndHashAlgorithm of a SKE.
+class DHEServerKEXSigAlgReplacer : public TlsHandshakeFilter {
+ public:
+ DHEServerKEXSigAlgReplacer(const std::shared_ptr<TlsAgent>& server,
+ uint16_t sig_scheme)
+ : TlsHandshakeFilter(server, {kTlsHandshakeServerKeyExchange}),
+ sig_scheme_(sig_scheme) {}
+
+ protected:
+ virtual PacketFilter::Action FilterHandshake(const HandshakeHeader& header,
+ const DataBuffer& input,
+ DataBuffer* output) {
+ *output = input;
+
+ uint32_t len;
+ uint32_t idx = 0;
+ EXPECT_TRUE(output->Read(idx, 2, &len));
+ idx += 2 + len;
+ EXPECT_TRUE(output->Read(idx, 2, &len));
+ idx += 2 + len;
+ EXPECT_TRUE(output->Read(idx, 2, &len));
+ idx += 2 + len;
+ output->Write(idx, sig_scheme_, 2);
+
+ return CHANGE;
+ }
+
+ private:
+ uint16_t sig_scheme_;
+};
+
+// Replace SignatureAndHashAlgorithm of a SKE.
+class ECCServerKEXSigAlgReplacer : public TlsHandshakeFilter {
+ public:
+ ECCServerKEXSigAlgReplacer(const std::shared_ptr<TlsAgent>& server,
+ uint16_t sig_scheme)
+ : TlsHandshakeFilter(server, {kTlsHandshakeServerKeyExchange}),
+ sig_scheme_(sig_scheme) {}
+
+ protected:
+ virtual PacketFilter::Action FilterHandshake(const HandshakeHeader& header,
+ const DataBuffer& input,
+ DataBuffer* output) {
+ *output = input;
+
+ uint32_t point_len;
+ EXPECT_TRUE(output->Read(3, 1, &point_len));
+ output->Write(4 + point_len, sig_scheme_, 2);
+
+ return CHANGE;
+ }
+
+ private:
+ uint16_t sig_scheme_;
+};
+
+// Replace NamedCurve of a ECDHE SKE.
+class ECCServerKEXNamedCurveReplacer : public TlsHandshakeFilter {
+ public:
+ ECCServerKEXNamedCurveReplacer(const std::shared_ptr<TlsAgent>& server,
+ uint16_t curve_name)
+ : TlsHandshakeFilter(server, {kTlsHandshakeServerKeyExchange}),
+ curve_name_(curve_name) {}
+
+ protected:
+ virtual PacketFilter::Action FilterHandshake(const HandshakeHeader& header,
+ const DataBuffer& input,
+ DataBuffer* output) {
+ *output = input;
+
+ uint32_t curve_type;
+ EXPECT_TRUE(output->Read(0, 1, &curve_type));
+ EXPECT_EQ(curve_type, ec_type_named);
+ output->Write(1, curve_name_, 2);
+
+ return CHANGE;
+ }
+
+ private:
+ uint16_t curve_name_;
+};
+
+// Replaces the signature scheme in a CertificateVerify message.
+class TlsReplaceSignatureSchemeFilter : public TlsHandshakeFilter {
+ public:
+ TlsReplaceSignatureSchemeFilter(const std::shared_ptr<TlsAgent>& a,
+ uint16_t scheme)
+ : TlsHandshakeFilter(a, {kTlsHandshakeCertificateVerify}),
+ scheme_(scheme) {}
+
+ protected:
+ virtual PacketFilter::Action FilterHandshake(const HandshakeHeader& header,
+ const DataBuffer& input,
+ DataBuffer* output) {
+ *output = input;
+ output->Write(0, scheme_, 2);
+ return CHANGE;
+ }
+
+ private:
+ uint16_t scheme_;
+};
+
+} // namespace nss_test
+
+#endif
diff --git a/security/nss/gtests/ssl_gtest/tls_grease_unittest.cc b/security/nss/gtests/ssl_gtest/tls_grease_unittest.cc
new file mode 100644
index 0000000000..c89c41be04
--- /dev/null
+++ b/security/nss/gtests/ssl_gtest/tls_grease_unittest.cc
@@ -0,0 +1,878 @@
+/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
+/* vim: set ts=2 et sw=2 tw=80: */
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this file,
+ * You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+#include "secerr.h"
+#include "ssl.h"
+
+#include "gtest_utils.h"
+#include "tls_connect.h"
+#include "util.h"
+
+namespace nss_test {
+
+const uint8_t kTlsGreaseExtensionMessages[] = {kTlsHandshakeEncryptedExtensions,
+ kTlsHandshakeCertificate};
+
+const uint16_t kTlsGreaseValues[] = {
+ 0x0a0a, 0x1a1a, 0x2a2a, 0x3a3a, 0x4a4a, 0x5a5a, 0x6a6a, 0x7a7a,
+ 0x8a8a, 0x9a9a, 0xaaaa, 0xbaba, 0xcaca, 0xdada, 0xeaea, 0xfafa};
+
+const uint8_t kTlsGreasePskValues[] = {0x0B, 0x2A, 0x49, 0x68,
+ 0x87, 0xA6, 0xC5, 0xE4};
+
+size_t countGreaseInBuffer(const DataBuffer& list) {
+ if (!list.len()) {
+ return 0;
+ }
+ size_t occurrence = 0;
+ for (uint16_t greaseVal : kTlsGreaseValues) {
+ for (size_t i = 0; i < (list.len() - 1); i += 2) {
+ uint16_t sample = list.data()[i + 1] + (list.data()[i] << 8);
+ if (greaseVal == sample) {
+ occurrence++;
+ }
+ }
+ }
+ return occurrence;
+}
+
+class GreasePresenceAbsenceTestBase : public TlsConnectTestBase {
+ public:
+ GreasePresenceAbsenceTestBase(SSLProtocolVariant variant, uint16_t version,
+ bool shouldGrease)
+ : TlsConnectTestBase(variant, version), set_grease_(shouldGrease){};
+
+ void SetupGrease() {
+ EnsureTlsSetup();
+ ASSERT_EQ(SSL_OptionSet(client_->ssl_fd(), SSL_ENABLE_GREASE, set_grease_),
+ SECSuccess);
+ ASSERT_EQ(SSL_OptionSet(server_->ssl_fd(), SSL_ENABLE_GREASE, set_grease_),
+ SECSuccess);
+ }
+
+ bool expectGrease() {
+ return set_grease_ && version_ >= SSL_LIBRARY_VERSION_TLS_1_3;
+ }
+
+ void checkGreasePresence(const int ifEnabled, const int ifDisabled,
+ const DataBuffer& buffer) {
+ size_t expected = expectGrease() ? size_t(ifEnabled) : size_t(ifDisabled);
+ EXPECT_EQ(expected, countGreaseInBuffer(buffer));
+ }
+
+ private:
+ bool set_grease_;
+};
+
+class GreasePresenceAbsenceTestAllVersions
+ : public GreasePresenceAbsenceTestBase,
+ public ::testing::WithParamInterface<
+ std::tuple<SSLProtocolVariant, uint16_t, bool>> {
+ public:
+ GreasePresenceAbsenceTestAllVersions()
+ : GreasePresenceAbsenceTestBase(std::get<0>(GetParam()),
+ std::get<1>(GetParam()),
+ std::get<2>(GetParam())){};
+};
+
+// Varies stream/datagram, TLS Version and whether GREASE is enabled
+INSTANTIATE_TEST_SUITE_P(GreaseTests, GreasePresenceAbsenceTestAllVersions,
+ ::testing::Combine(TlsConnectTestBase::kTlsVariantsAll,
+ TlsConnectTestBase::kTlsV11Plus,
+ ::testing::Values(true, false)));
+
+// Varies whether GREASE is enabled for TLS13 only
+class GreasePresenceAbsenceTestTlsStream13
+ : public GreasePresenceAbsenceTestBase,
+ public ::testing::WithParamInterface<bool> {
+ public:
+ GreasePresenceAbsenceTestTlsStream13()
+ : GreasePresenceAbsenceTestBase(
+ ssl_variant_stream, SSL_LIBRARY_VERSION_TLS_1_3, GetParam()){};
+};
+
+INSTANTIATE_TEST_SUITE_P(GreaseTests, GreasePresenceAbsenceTestTlsStream13,
+ ::testing::Values(true, false));
+
+// These tests check for the presence / absence of GREASE values in the various
+// positions that we are permitted to add them. For positions which existed in
+// prior versions of TLS, we check that enabling GREASE is only effective when
+// negotiating TLS1.3 or higher and that disabling GREASE results in the absence
+// of any GREASE values.
+// For positions that specific to TLS1.3, we only check that enabling/disabling
+// GREASE results in the correct presence/absence of the GREASE value.
+
+TEST_P(GreasePresenceAbsenceTestAllVersions, ClientGreaseCiphersuites) {
+ SetupGrease();
+
+ auto ch1 = MakeTlsFilter<ClientHelloCiphersuiteCapture>(client_);
+ Connect();
+ EXPECT_TRUE(ch1->captured());
+
+ checkGreasePresence(1, 0, ch1->contents());
+}
+
+TEST_P(GreasePresenceAbsenceTestAllVersions, ClientGreaseNamedGroups) {
+ SetupGrease();
+
+ auto ch1 =
+ MakeTlsFilter<TlsExtensionCapture>(client_, ssl_supported_groups_xtn);
+ Connect();
+ EXPECT_TRUE(ch1->captured());
+
+ checkGreasePresence(1, 0, ch1->extension());
+}
+
+TEST_P(GreasePresenceAbsenceTestAllVersions, ClientGreaseKeyShare) {
+ SetupGrease();
+
+ auto ch1 =
+ MakeTlsFilter<TlsExtensionCapture>(client_, ssl_tls13_key_share_xtn);
+ Connect();
+ EXPECT_TRUE((version_ >= SSL_LIBRARY_VERSION_TLS_1_3) == ch1->captured());
+
+ checkGreasePresence(1, 0, ch1->extension());
+}
+
+TEST_P(GreasePresenceAbsenceTestAllVersions, ClientGreaseSigAlg) {
+ SetupGrease();
+
+ auto ch1 =
+ MakeTlsFilter<TlsExtensionCapture>(client_, ssl_signature_algorithms_xtn);
+ Connect();
+ EXPECT_TRUE((version_ >= SSL_LIBRARY_VERSION_TLS_1_2) == ch1->captured());
+
+ checkGreasePresence(1, 0, ch1->extension());
+}
+
+TEST_P(GreasePresenceAbsenceTestAllVersions, ClientGreaseSupportedVersions) {
+ SetupGrease();
+
+ auto ch1 = MakeTlsFilter<TlsExtensionCapture>(
+ client_, ssl_tls13_supported_versions_xtn);
+ Connect();
+ EXPECT_TRUE((version_ >= SSL_LIBRARY_VERSION_TLS_1_3) == ch1->captured());
+
+ // Supported Versions have a 1 byte length field.
+ TlsParser extParser(ch1->extension());
+ DataBuffer versions;
+ extParser.ReadVariable(&versions, 1);
+
+ checkGreasePresence(1, 0, versions);
+}
+
+TEST_P(GreasePresenceAbsenceTestTlsStream13, ClientGreasePskExchange) {
+ SetupGrease();
+
+ auto ch1 = MakeTlsFilter<TlsExtensionCapture>(
+ client_, ssl_tls13_psk_key_exchange_modes_xtn);
+ Connect();
+ EXPECT_TRUE(ch1->captured());
+
+ // PSK Exchange Modes have a 1 byte length field
+ TlsParser extParser(ch1->extension());
+ DataBuffer modes;
+ extParser.ReadVariable(&modes, 1);
+
+ // Scan for single byte GREASE PSK Values
+ size_t numGrease = 0;
+ for (uint8_t greaseVal : kTlsGreasePskValues) {
+ for (unsigned long i = 0; i < modes.len(); i++) {
+ if (greaseVal == modes.data()[i]) {
+ numGrease++;
+ }
+ }
+ }
+
+ EXPECT_EQ(expectGrease() ? size_t(1) : size_t(0), numGrease);
+}
+
+TEST_P(GreasePresenceAbsenceTestAllVersions, ClientGreaseAlpn) {
+ SetupGrease();
+ EnableAlpn();
+
+ auto ch1 =
+ MakeTlsFilter<TlsExtensionCapture>(client_, ssl_app_layer_protocol_xtn);
+ Connect();
+ EXPECT_TRUE((version_ >= SSL_LIBRARY_VERSION_TLS_1_1) == ch1->captured());
+
+ // ALPN Xtns have a redundant two-byte length
+ TlsParser alpnParser(ch1->extension());
+ alpnParser.Skip(2); // Skip the length
+ DataBuffer alpnEntry;
+
+ // Each ALPN entry has a single byte length prefixed.
+ size_t greaseAlpnEntrys = 0;
+ while (alpnParser.remaining()) {
+ alpnParser.ReadVariable(&alpnEntry, 1);
+ if (alpnEntry.len() == 2) {
+ greaseAlpnEntrys += countGreaseInBuffer(alpnEntry);
+ }
+ }
+
+ EXPECT_EQ(expectGrease() ? size_t(1) : size_t(0), greaseAlpnEntrys);
+}
+
+TEST_P(GreasePresenceAbsenceTestAllVersions, GreaseClientHelloExtension) {
+ SetupGrease();
+
+ auto ch1 =
+ MakeTlsFilter<TlsHandshakeRecorder>(client_, kTlsHandshakeClientHello);
+ Connect();
+ EXPECT_TRUE(ch1->buffer().len() > 0);
+
+ TlsParser extParser(ch1->buffer());
+ EXPECT_TRUE(extParser.Skip(2 + 32)); // Version + Random
+ EXPECT_TRUE(extParser.SkipVariable(1)); // Session ID
+ if (variant_ == ssl_variant_datagram) {
+ EXPECT_TRUE(extParser.SkipVariable(1)); // Cookie
+ }
+ EXPECT_TRUE(extParser.SkipVariable(2)); // Ciphersuites
+ EXPECT_TRUE(extParser.SkipVariable(1)); // Compression Methods
+ EXPECT_TRUE(extParser.Skip(2)); // Extension Lengths
+
+ // Scan for a 1-byte and a 0-byte extension.
+ uint32_t extType;
+ DataBuffer extBuf;
+ bool foundSmall = false;
+ bool foundLarge = false;
+ size_t numFound = 0;
+ while (extParser.remaining()) {
+ extParser.Read(&extType, 2);
+ extParser.ReadVariable(&extBuf, 2);
+ for (uint16_t greaseVal : kTlsGreaseValues) {
+ if (greaseVal == extType) {
+ numFound++;
+ foundSmall |= extBuf.len() == 0;
+ foundLarge |= extBuf.len() > 0;
+ }
+ }
+ }
+
+ EXPECT_EQ(foundSmall, expectGrease());
+ EXPECT_EQ(foundLarge, expectGrease());
+ EXPECT_EQ(numFound, expectGrease() ? size_t(2) : size_t(0));
+}
+
+TEST_P(GreasePresenceAbsenceTestTlsStream13, GreaseCertificateRequestSigAlg) {
+ SetupGrease();
+ client_->SetupClientAuth();
+ server_->RequestClientAuth(true);
+
+ auto cr =
+ MakeTlsFilter<TlsExtensionCapture>(server_, ssl_signature_algorithms_xtn);
+ cr->SetHandshakeTypes({kTlsHandshakeCertificateRequest});
+ cr->EnableDecryption();
+ Connect();
+ EXPECT_TRUE(cr->captured());
+
+ checkGreasePresence(1, 0, cr->extension());
+}
+
+TEST_P(GreasePresenceAbsenceTestTlsStream13,
+ GreaseCertificateRequestExtension) {
+ SetupGrease();
+ client_->SetupClientAuth();
+ server_->RequestClientAuth(true);
+
+ auto cr = MakeTlsFilter<TlsHandshakeRecorder>(
+ server_, kTlsHandshakeCertificateRequest);
+ cr->EnableDecryption();
+ Connect();
+ EXPECT_TRUE(cr->buffer().len() > 0);
+
+ TlsParser extParser(cr->buffer());
+ EXPECT_TRUE(extParser.SkipVariable(1)); // Context
+ EXPECT_TRUE(extParser.Skip(2)); // Extension Lengths
+
+ uint32_t extType;
+ DataBuffer extBuf;
+ bool found = false;
+ // Scan for a single, empty extension
+ while (extParser.remaining()) {
+ extParser.Read(&extType, 2);
+ extParser.ReadVariable(&extBuf, 2);
+ for (uint16_t greaseVal : kTlsGreaseValues) {
+ if (greaseVal == extType) {
+ EXPECT_TRUE(!found);
+ EXPECT_EQ(extBuf.len(), size_t(0));
+ found = true;
+ }
+ }
+ }
+
+ EXPECT_EQ(expectGrease(), found);
+}
+
+TEST_P(GreasePresenceAbsenceTestTlsStream13, GreaseNewSessionTicketExtension) {
+ SetupGrease();
+
+ auto nst = MakeTlsFilter<TlsHandshakeRecorder>(server_,
+ kTlsHandshakeNewSessionTicket);
+ nst->EnableDecryption();
+ Connect();
+ EXPECT_EQ(SECSuccess, SSL_SendSessionTicket(server_->ssl_fd(), nullptr, 0));
+ EXPECT_TRUE(nst->buffer().len() > 0);
+
+ TlsParser extParser(nst->buffer());
+ EXPECT_TRUE(extParser.Skip(4)); // lifetime
+ EXPECT_TRUE(extParser.Skip(4)); // age
+ EXPECT_TRUE(extParser.SkipVariable(1)); // Nonce
+ EXPECT_TRUE(extParser.SkipVariable(2)); // Ticket
+ EXPECT_TRUE(extParser.Skip(2)); // Extension Length
+
+ uint32_t extType;
+ DataBuffer extBuf;
+ bool found = false;
+ // Scan for a single, empty extension
+ while (extParser.remaining()) {
+ extParser.Read(&extType, 2);
+ extParser.ReadVariable(&extBuf, 2);
+ for (uint16_t greaseVal : kTlsGreaseValues) {
+ if (greaseVal == extType) {
+ EXPECT_TRUE(!found);
+ EXPECT_EQ(extBuf.len(), size_t(0));
+ found = true;
+ }
+ }
+ }
+
+ EXPECT_EQ(expectGrease(), found);
+}
+
+// Generic Client GREASE test
+TEST_P(TlsConnectGeneric, ClientGrease) {
+ EnsureTlsSetup();
+ ASSERT_EQ(SSL_OptionSet(client_->ssl_fd(), SSL_ENABLE_GREASE, PR_TRUE),
+ SECSuccess);
+ Connect();
+}
+
+// Generic Server GREASE test
+TEST_P(TlsConnectGeneric, ServerGrease) {
+ EnsureTlsSetup();
+ ASSERT_EQ(SSL_OptionSet(server_->ssl_fd(), SSL_ENABLE_GREASE, PR_TRUE),
+ SECSuccess);
+ Connect();
+}
+
+// Generic GREASE test
+TEST_P(TlsConnectGeneric, Grease) {
+ EnsureTlsSetup();
+ ASSERT_EQ(SSL_OptionSet(client_->ssl_fd(), SSL_ENABLE_GREASE, PR_TRUE),
+ SECSuccess);
+ ASSERT_EQ(SSL_OptionSet(server_->ssl_fd(), SSL_ENABLE_GREASE, PR_TRUE),
+ SECSuccess);
+ Connect();
+}
+
+// Check that GREASE values can be correctly reconstructed after HRR.
+TEST_P(TlsConnectGeneric, GreaseHRR) {
+ EnsureTlsSetup();
+ const std::vector<SSLNamedGroup> client_groups = {
+ ssl_grp_ec_curve25519, ssl_grp_ec_secp256r1, ssl_grp_ec_secp384r1};
+ const std::vector<SSLNamedGroup> server_groups = {
+ ssl_grp_ec_secp256r1, ssl_grp_ec_secp384r1, ssl_grp_ec_curve25519};
+ client_->ConfigNamedGroups(client_groups);
+ server_->ConfigNamedGroups(server_groups);
+ ASSERT_EQ(SSL_OptionSet(client_->ssl_fd(), SSL_ENABLE_GREASE, PR_TRUE),
+ SECSuccess);
+ ASSERT_EQ(SSL_OptionSet(server_->ssl_fd(), SSL_ENABLE_GREASE, PR_TRUE),
+ SECSuccess);
+ Connect();
+}
+
+// Check that GREASE additions interact correctly with psk-only handshake.
+TEST_F(TlsConnectStreamTls13, GreasePsk) {
+ EnsureTlsSetup();
+ ASSERT_EQ(SSL_OptionSet(client_->ssl_fd(), SSL_ENABLE_GREASE, PR_TRUE),
+ SECSuccess);
+ ASSERT_EQ(SSL_OptionSet(server_->ssl_fd(), SSL_ENABLE_GREASE, PR_TRUE),
+ SECSuccess);
+
+ ScopedPK11SlotInfo slot(PK11_GetInternalSlot());
+ const uint8_t kPskDummyVal_[16] = {0x01, 0x02, 0x03, 0x04, 0x05,
+ 0x06, 0x07, 0x08, 0x09, 0x0a,
+ 0x0b, 0x0c, 0x0d, 0x0e, 0x0f};
+ SECItem psk_item;
+ psk_item.type = siBuffer;
+ psk_item.len = sizeof(kPskDummyVal_);
+ psk_item.data = const_cast<uint8_t*>(kPskDummyVal_);
+ PK11SymKey* key =
+ PK11_ImportSymKey(slot.get(), CKM_HKDF_KEY_GEN, PK11_OriginUnwrap,
+ CKA_DERIVE, &psk_item, NULL);
+
+ ScopedPK11SymKey scoped_psk_(key);
+ const std::string kPskDummyLabel_ = "NSS PSK GTEST label";
+ const SSLHashType kPskHash_ = ssl_hash_sha384;
+ AddPsk(scoped_psk_, kPskDummyLabel_, kPskHash_);
+
+ Connect();
+ SendReceive();
+ CheckKeys(ssl_kea_ecdh, ssl_grp_ec_curve25519, ssl_auth_psk, ssl_sig_none);
+}
+
+// Test that ECH and GREASE work together successfully
+TEST_F(TlsConnectStreamTls13, GreaseAndECH) {
+ EnsureTlsSetup();
+ SetupEch(client_, server_);
+ ASSERT_EQ(SSL_OptionSet(client_->ssl_fd(), SSL_ENABLE_GREASE, PR_TRUE),
+ SECSuccess);
+ ASSERT_EQ(SSL_OptionSet(server_->ssl_fd(), SSL_ENABLE_GREASE, PR_TRUE),
+ SECSuccess);
+ Connect();
+}
+
+// Test that TLS12 Server handles Client GREASE correctly
+TEST_F(TlsConnectTest, GreaseTLS12Server) {
+ EnsureTlsSetup();
+ ASSERT_EQ(SSL_OptionSet(client_->ssl_fd(), SSL_ENABLE_GREASE, PR_TRUE),
+ SECSuccess);
+ server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2,
+ SSL_LIBRARY_VERSION_TLS_1_2);
+ client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2,
+ SSL_LIBRARY_VERSION_TLS_1_3);
+ Connect();
+}
+
+// Test that TLS12 Client handles Server GREASE correctly
+TEST_F(TlsConnectTest, GreaseTLS12Client) {
+ EnsureTlsSetup();
+ ASSERT_EQ(SSL_OptionSet(server_->ssl_fd(), SSL_ENABLE_GREASE, PR_TRUE),
+ SECSuccess);
+ server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2,
+ SSL_LIBRARY_VERSION_TLS_1_3);
+ client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2,
+ SSL_LIBRARY_VERSION_TLS_1_2);
+ Connect();
+}
+
+class GreaseOnlyTestStreamTls13 : public TlsConnectStreamTls13 {
+ public:
+ GreaseOnlyTestStreamTls13() : TlsConnectStreamTls13() {}
+
+ void ConnectWithCustomChExpectFail(const std::string& ch,
+ uint8_t server_alert, uint32_t server_code,
+ uint32_t client_code) {
+ std::vector<uint8_t> ch_vec = hex_string_to_bytes(ch);
+ DataBuffer ch_buf;
+ EnsureTlsSetup();
+
+ TlsAgentTestBase::MakeRecord(variant_, ssl_ct_handshake,
+ SSL_LIBRARY_VERSION_TLS_1_3, ch_vec.data(),
+ ch_vec.size(), &ch_buf, 0);
+ StartConnect();
+ client_->SendDirect(ch_buf);
+ ExpectAlert(server_, server_alert);
+ server_->Handshake();
+ server_->CheckErrorCode(server_code);
+ client_->ExpectReceiveAlert(server_alert, kTlsAlertFatal);
+ client_->Handshake();
+ client_->CheckErrorCode(client_code);
+ }
+};
+
+// Client: Offer only GREASE CipherSuite value
+TEST_F(GreaseOnlyTestStreamTls13, GreaseOnlyClientCipherSuite) {
+ // 0xdada
+ std::string ch =
+ "010000b003038afacda2963358e98f464f3ff0680ed3a9d382a8c3eac5e5604f5721add9"
+ "855c000002dada010000850000000b0009000006736572766572ff01000100000a001400"
+ "12001d00170018001901000101010201030104003300260024001d0020683668992de470"
+ "38660ee37bafc7392b05b8a94402ea1f3463ad3cfd7a694a46002b0003020304000d0018"
+ "001604030503060302030804080508060401050106010201002d00020101001c0002400"
+ "1";
+
+ ConnectWithCustomChExpectFail(ch, kTlsAlertHandshakeFailure,
+ SSL_ERROR_NO_CYPHER_OVERLAP,
+ SSL_ERROR_NO_CYPHER_OVERLAP);
+}
+
+// Client: Offer only GREASE SupportedGroups value
+TEST_F(GreaseOnlyTestStreamTls13, GreaseOnlyClientSupportedGroup) {
+ // 0x3a3a
+ std::string ch =
+ "010000a40303484a4e14f547404da6115d7f73bbb0f1c9d65e66ac073dee6c4a62f72de9"
+ "a36f000006130113031302010000750000000b0009000006736572766572ff0100010000"
+ "0a000400023a3a003300260024001d0020e75cb8e217c95176954e8b5fb95843882462ce"
+ "2cd3fcfe67cf31463a05ea3d57002b0003020304000d0018001604030503060302030804"
+ "080508060401050106010201002d00020101001c00024001";
+
+ ConnectWithCustomChExpectFail(ch, kTlsAlertHandshakeFailure,
+ SSL_ERROR_NO_CYPHER_OVERLAP,
+ SSL_ERROR_NO_CYPHER_OVERLAP);
+}
+
+// Client: Offer only GREASE SigAlgs value
+TEST_F(GreaseOnlyTestStreamTls13, GreaseOnlyClientSignatureAlgorithm) {
+ // 0x8a8a
+ std::string ch =
+ "010000a00303dfd8e2438a8d1b9f48d921dfc08959108807bd1105238bb3da2a2a8e3db0"
+ "6990000006130113031302010000710000000b0009000006736572766572ff0100010000"
+ "0a00140012001d00170018001901000101010201030104003300260024001d002074bb2c"
+ "94996d3ffc7ae5792f0c3c58676358a85ea304cd029fa3d6551013b333002b0003020304"
+ "000d000400028a8a002d00020101001c00024001";
+
+ ConnectWithCustomChExpectFail(ch, kTlsAlertHandshakeFailure,
+ SSL_ERROR_UNSUPPORTED_SIGNATURE_ALGORITHM,
+ SSL_ERROR_NO_CYPHER_OVERLAP);
+}
+
+// Client: Offer only GREASE SupportedVersions value
+TEST_F(GreaseOnlyTestStreamTls13, GreaseOnlyClientSupportedVersion) {
+ // 0xeaea
+ std::string ch =
+ "010000b203037e3618abae0dd0b3f06a504c47354551d1d5be36e9c3e1eac9c139c246b1"
+ "66da000006130113031302010000830000000b0009000006736572766572ff0100010000"
+ "0a00140012001d00170018001901000101010201030104003300260024001d00206b1816"
+ "577ff2e69d4d2661419150eaefa0328ffd396425cf1733ec06536b4e55002b000100000d"
+ "0018001604030503060302030804080508060401050106010201002d00020101001c0002"
+ "4001";
+
+ ConnectWithCustomChExpectFail(ch, kTlsAlertIllegalParameter,
+ SSL_ERROR_RX_MALFORMED_CLIENT_HELLO,
+ SSL_ERROR_ILLEGAL_PARAMETER_ALERT);
+}
+
+class GreaseTestStreamTls12
+ : public TlsConnectStreamTls12,
+ public ::testing::WithParamInterface<uint16_t /* GREASE */> {
+ public:
+ GreaseTestStreamTls12() : TlsConnectStreamTls12(), grease_(GetParam()){};
+
+ void ConnectExpectSigAlgFail() {
+ client_->ExpectSendAlert(kTlsAlertIllegalParameter);
+ server_->ExpectReceiveAlert(kTlsAlertIllegalParameter);
+ ConnectExpectFail();
+ client_->CheckErrorCode(SSL_ERROR_UNSUPPORTED_SIGNATURE_ALGORITHM);
+ server_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT);
+ }
+
+ protected:
+ uint16_t grease_;
+};
+
+class TlsCertificateRequestSigAlgSetterFilter : public TlsHandshakeFilter {
+ public:
+ TlsCertificateRequestSigAlgSetterFilter(const std::shared_ptr<TlsAgent>& a,
+ uint16_t sigAlg)
+ : TlsHandshakeFilter(a, {kTlsHandshakeCertificateRequest}),
+ sigAlg_(sigAlg) {}
+ virtual PacketFilter::Action FilterHandshake(
+ const TlsHandshakeFilter::HandshakeHeader& header,
+ const DataBuffer& input, DataBuffer* output) {
+ TlsParser parser(input);
+ DataBuffer cert_types;
+ if (!parser.ReadVariable(&cert_types, 1)) {
+ ADD_FAILURE();
+ return KEEP;
+ }
+
+ if (!parser.SkipVariable(2)) {
+ ADD_FAILURE();
+ return KEEP;
+ }
+
+ DataBuffer cas;
+ if (!parser.ReadVariable(&cas, 2)) {
+ ADD_FAILURE();
+ return KEEP;
+ }
+
+ size_t idx = 0;
+
+ // Write certificate types.
+ idx = output->Write(idx, cert_types.len(), 1);
+ idx = output->Write(idx, cert_types);
+
+ // Write signature algorithm.
+ idx = output->Write(idx, sizeof(sigAlg_), 2);
+ idx = output->Write(idx, sigAlg_, 2);
+
+ // Write certificate authorities.
+ idx = output->Write(idx, cas.len(), 2);
+ idx = output->Write(idx, cas);
+
+ return CHANGE;
+ }
+
+ private:
+ uint16_t sigAlg_;
+};
+
+// Server: Offer only GREASE CertificateRequest SigAlg value
+TEST_P(GreaseTestStreamTls12, GreaseOnlyServerTLS12CertificateRequestSigAlg) {
+ EnsureTlsSetup();
+ client_->SetupClientAuth();
+ server_->RequestClientAuth(true);
+ MakeTlsFilter<TlsCertificateRequestSigAlgSetterFilter>(server_, grease_);
+
+ client_->ExpectSendAlert(kTlsAlertHandshakeFailure);
+ server_->ExpectReceiveAlert(kTlsAlertHandshakeFailure);
+ ConnectExpectFail();
+ server_->CheckErrorCode(SSL_ERROR_HANDSHAKE_FAILURE_ALERT);
+ client_->CheckErrorCode(SSL_ERROR_UNSUPPORTED_SIGNATURE_ALGORITHM);
+}
+
+// Illegally GREASE ServerKeyExchange ECC SignatureAlgorithm
+TEST_P(GreaseTestStreamTls12, GreasedTLS12ServerKexEccSigAlg) {
+ MakeTlsFilter<ECCServerKEXSigAlgReplacer>(server_, grease_);
+ EnableSomeEcdhCiphers();
+
+ client_->ExpectSendAlert(kTlsAlertIllegalParameter);
+ server_->ExpectReceiveAlert(kTlsAlertIllegalParameter);
+ ConnectExpectFail();
+ client_->CheckErrorCode(SSL_ERROR_UNSUPPORTED_SIGNATURE_ALGORITHM);
+ server_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT);
+}
+
+// Illegally GREASE ServerKeyExchange DHE SignatureAlgorithm
+TEST_P(GreaseTestStreamTls12, GreasedTLS12ServerKexDheSigAlg) {
+ MakeTlsFilter<DHEServerKEXSigAlgReplacer>(server_, grease_);
+ EnableOnlyDheCiphers();
+
+ client_->ExpectSendAlert(kTlsAlertIllegalParameter);
+ server_->ExpectReceiveAlert(kTlsAlertIllegalParameter);
+ ConnectExpectFail();
+ client_->CheckErrorCode(SSL_ERROR_UNSUPPORTED_SIGNATURE_ALGORITHM);
+ server_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT);
+}
+
+// Illegally GREASE ServerKeyExchange ECDHE NamedCurve
+TEST_P(GreaseTestStreamTls12, GreasedTLS12ServerKexEcdheNamedCurve) {
+ MakeTlsFilter<ECCServerKEXNamedCurveReplacer>(server_, grease_);
+ EnableSomeEcdhCiphers();
+
+ client_->ExpectSendAlert(kTlsAlertHandshakeFailure);
+ server_->ExpectReceiveAlert(kTlsAlertHandshakeFailure);
+ ConnectExpectFail();
+ server_->CheckErrorCode(SSL_ERROR_HANDSHAKE_FAILURE_ALERT);
+ client_->CheckErrorCode(SEC_ERROR_UNSUPPORTED_ELLIPTIC_CURVE);
+}
+
+// Illegally GREASE TLS12 Client CertificateVerify SignatureAlgorithm
+TEST_P(GreaseTestStreamTls12, GreasedTLS12ClientCertificateVerifySigAlg) {
+ client_->SetupClientAuth();
+ server_->RequestClientAuth(true);
+ MakeTlsFilter<TlsReplaceSignatureSchemeFilter>(client_, grease_);
+
+ server_->ExpectSendAlert(kTlsAlertIllegalParameter);
+ client_->ExpectReceiveAlert(kTlsAlertIllegalParameter);
+ ConnectExpectFail();
+ client_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT);
+ server_->CheckErrorCode(SSL_ERROR_UNSUPPORTED_SIGNATURE_ALGORITHM);
+}
+
+class GreaseTestStreamTls13
+ : public TlsConnectStreamTls13,
+ public ::testing::WithParamInterface<uint16_t /* GREASE */> {
+ public:
+ GreaseTestStreamTls13() : grease_(GetParam()){};
+
+ protected:
+ uint16_t grease_;
+};
+
+// Illegally GREASE TLS13 Client CertificateVerify SignatureAlgorithm
+TEST_P(GreaseTestStreamTls13, GreasedTLS13ClientCertificateVerifySigAlg) {
+ client_->SetupClientAuth();
+ server_->RequestClientAuth(true);
+ auto filter =
+ MakeTlsFilter<TlsReplaceSignatureSchemeFilter>(client_, grease_);
+ filter->EnableDecryption();
+
+ server_->ExpectSendAlert(kTlsAlertIllegalParameter);
+ client_->ExpectReceiveAlert(kTlsAlertIllegalParameter);
+
+ // Manually trigger handshake to avoid race conditions
+ StartConnect();
+ client_->Handshake();
+ server_->Handshake();
+ client_->Handshake();
+ server_->Handshake();
+ client_->Handshake();
+
+ server_->CheckErrorCode(SSL_ERROR_RX_MALFORMED_CERT_VERIFY);
+ client_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT);
+}
+
+// Illegally GREASE TLS13 Server CertificateVerify SignatureAlgorithm
+TEST_P(GreaseTestStreamTls13, GreasedTLS13ServerCertificateVerifySigAlg) {
+ EnsureTlsSetup();
+ auto filter =
+ MakeTlsFilter<TlsReplaceSignatureSchemeFilter>(server_, grease_);
+ filter->EnableDecryption();
+
+ client_->ExpectSendAlert(kTlsAlertIllegalParameter);
+ server_->ExpectReceiveAlert(kTlsAlertIllegalParameter);
+ ConnectExpectFail();
+ client_->CheckErrorCode(SSL_ERROR_RX_MALFORMED_CERT_VERIFY);
+}
+
+// Illegally GREASE HelloRetryRequest version value
+TEST_P(GreaseTestStreamTls13, GreasedHelloRetryRequestVersion) {
+ EnsureTlsSetup();
+ // Trigger HelloRetryRequest
+ MakeTlsFilter<TlsExtensionDropper>(client_, ssl_tls13_key_share_xtn);
+ auto filter = MakeTlsFilter<TlsMessageVersionSetter>(
+ server_, kTlsHandshakeHelloRetryRequest, grease_);
+ filter->EnableDecryption();
+
+ client_->ExpectSendAlert(kTlsAlertIllegalParameter);
+ server_->ExpectReceiveAlert(kTlsAlertIllegalParameter);
+ ConnectExpectFail();
+ client_->CheckErrorCode(SSL_ERROR_RX_MALFORMED_SERVER_HELLO);
+ server_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT);
+}
+
+class GreaseTestStreamTls123
+ : public TlsConnectTestBase,
+ public ::testing::WithParamInterface<
+ std::tuple<uint16_t /* version */, uint16_t /* GREASE */>> {
+ public:
+ GreaseTestStreamTls123()
+ : TlsConnectTestBase(ssl_variant_stream, std::get<0>(GetParam())),
+ grease_(std::get<1>(GetParam())){};
+
+ void ConnectExpectIllegalGreaseFail() {
+ client_->ExpectSendAlert(kTlsAlertIllegalParameter);
+ if (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) {
+ // Server expects handshake but receives encrypted alert.
+ server_->ExpectSendAlert(kTlsAlertUnexpectedMessage);
+ } else {
+ server_->ExpectReceiveAlert(kTlsAlertIllegalParameter);
+ }
+ ConnectExpectFail();
+ }
+
+ protected:
+ uint16_t grease_;
+};
+
+// Illegally GREASE TLS12 and TLS13 ServerHello version value
+TEST_P(GreaseTestStreamTls123, GreasedServerHelloVersion) {
+ EnsureTlsSetup();
+ auto filter = MakeTlsFilter<TlsMessageVersionSetter>(
+ server_, kTlsHandshakeServerHello, grease_);
+ if (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) {
+ filter->EnableDecryption();
+ }
+ ConnectExpectIllegalGreaseFail();
+ client_->CheckErrorCode(SSL_ERROR_RX_MALFORMED_SERVER_HELLO);
+}
+
+// Illegally GREASE TLS12 and TLS13 selected CipherSuite value
+TEST_P(GreaseTestStreamTls123, GreasedServerHelloCipherSuite) {
+ EnsureTlsSetup();
+ auto filter = MakeTlsFilter<SelectedCipherSuiteReplacer>(server_, grease_);
+ if (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) {
+ filter->EnableDecryption();
+ }
+ ConnectExpectIllegalGreaseFail();
+ client_->CheckErrorCode(SSL_ERROR_NO_CYPHER_OVERLAP);
+}
+
+class GreaseExtensionTestStreamTls13
+ : public TlsConnectStreamTls13,
+ public ::testing::WithParamInterface<
+ std::tuple<uint8_t /* message */, uint16_t /* GREASE */>> {
+ public:
+ GreaseExtensionTestStreamTls13()
+ : TlsConnectStreamTls13(),
+ message_(std::get<0>(GetParam())),
+ grease_(std::get<1>(GetParam())){};
+
+ protected:
+ uint8_t message_;
+ uint16_t grease_;
+};
+
+// Illegally GREASE TLS13 Server EncryptedExtensions and Certificate Extensions
+// NSS currently allows offering unkown extensions in HelloRetryRequests!
+TEST_P(GreaseExtensionTestStreamTls13, GreasedServerExtensions) {
+ EnsureTlsSetup();
+ DataBuffer empty = DataBuffer(1);
+ auto filter =
+ MakeTlsFilter<TlsExtensionAppender>(server_, message_, grease_, empty);
+ filter->EnableDecryption();
+
+ server_->ExpectReceiveAlert(kTlsAlertUnsupportedExtension);
+ client_->ExpectSendAlert(kTlsAlertUnsupportedExtension);
+ ConnectExpectFail();
+ client_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_EXTENSION);
+ server_->CheckErrorCode(SSL_ERROR_UNSUPPORTED_EXTENSION_ALERT);
+}
+
+// Illegally GREASE TLS12 and TLS13 ServerHello Extensions
+TEST_P(GreaseTestStreamTls123, GreasedServerHelloExtensions) {
+ EnsureTlsSetup();
+ DataBuffer empty = DataBuffer(1);
+ auto filter = MakeTlsFilter<TlsExtensionAppender>(
+ server_, kTlsHandshakeServerHello, grease_, empty);
+
+ if (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) {
+ filter->EnableDecryption();
+ server_->ExpectSendAlert(kTlsAlertUnexpectedMessage);
+ } else {
+ server_->ExpectReceiveAlert(kTlsAlertUnsupportedExtension);
+ }
+ client_->ExpectSendAlert(kTlsAlertUnsupportedExtension);
+ ConnectExpectFail();
+ client_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_EXTENSION);
+}
+
+// Illegally GREASE TLS13 Client Certificate Extensions
+// Server ignores injected client extensions and fails on CertificateVerify
+TEST_P(GreaseTestStreamTls13, GreasedClientCertificateExtensions) {
+ client_->SetupClientAuth();
+ server_->RequestClientAuth(true);
+ DataBuffer empty = DataBuffer(1);
+ auto filter = MakeTlsFilter<TlsExtensionAppender>(
+ client_, kTlsHandshakeCertificate, grease_, empty);
+ filter->EnableDecryption();
+
+ server_->ExpectSendAlert(kTlsAlertDecryptError);
+ client_->ExpectReceiveAlert(kTlsAlertDecryptError);
+
+ // Manually trigger handshake to avoid race conditions
+ StartConnect();
+ client_->Handshake();
+ server_->Handshake();
+ client_->Handshake();
+ server_->Handshake();
+ client_->Handshake();
+
+ server_->CheckErrorCode(SEC_ERROR_BAD_SIGNATURE);
+ client_->CheckErrorCode(SSL_ERROR_DECRYPT_ERROR_ALERT);
+}
+
+TEST_F(TlsConnectStreamTls13, GreaseClientHelloExtensionPermutation) {
+ EnsureTlsSetup();
+ PR_ASSERT(SSL_OptionSet(client_->ssl_fd(),
+ SSL_ENABLE_CH_EXTENSION_PERMUTATION,
+ PR_TRUE) == SECSuccess);
+ PR_ASSERT(SSL_OptionSet(client_->ssl_fd(), SSL_ENABLE_GREASE, PR_TRUE) ==
+ SECSuccess);
+ Connect();
+}
+
+INSTANTIATE_TEST_SUITE_P(GreaseTestTls12, GreaseTestStreamTls12,
+ ::testing::ValuesIn(kTlsGreaseValues));
+
+INSTANTIATE_TEST_SUITE_P(GreaseTestTls13, GreaseTestStreamTls13,
+ ::testing::ValuesIn(kTlsGreaseValues));
+
+INSTANTIATE_TEST_SUITE_P(
+ GreaseTestTls123, GreaseTestStreamTls123,
+ ::testing::Combine(TlsConnectTestBase::kTlsV12Plus,
+ ::testing::ValuesIn(kTlsGreaseValues)));
+
+INSTANTIATE_TEST_SUITE_P(
+ GreaseExtensionTest, GreaseExtensionTestStreamTls13,
+ testing::Combine(testing::ValuesIn(kTlsGreaseExtensionMessages),
+ testing::ValuesIn(kTlsGreaseValues)));
+
+} // namespace nss_test
diff --git a/security/nss/gtests/ssl_gtest/tls_hkdf_unittest.cc b/security/nss/gtests/ssl_gtest/tls_hkdf_unittest.cc
new file mode 100644
index 0000000000..3e1e30bb86
--- /dev/null
+++ b/security/nss/gtests/ssl_gtest/tls_hkdf_unittest.cc
@@ -0,0 +1,433 @@
+/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
+/* vim: set ts=2 et sw=2 tw=80: */
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this file,
+ * You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+#include <memory>
+#include "nss.h"
+#include "pk11pub.h"
+#include "secerr.h"
+#include "sslproto.h"
+#include "sslexp.h"
+#include "tls13hkdf.h"
+
+#include "databuffer.h"
+#include "gtest_utils.h"
+#include "nss_scoped_ptrs.h"
+
+namespace nss_test {
+
+const uint8_t kKey1Data[] = {
+ 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b,
+ 0x0c, 0x0d, 0x0e, 0x0f, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17,
+ 0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f, 0x20, 0x21, 0x22, 0x23,
+ 0x24, 0x25, 0x26, 0x27, 0x28, 0x29, 0x2a, 0x2b, 0x2c, 0x2d, 0x2e, 0x2f};
+const DataBuffer kKey1(kKey1Data, sizeof(kKey1Data));
+
+// The same as key1 but with the first byte
+// 0x01.
+const uint8_t kKey2Data[] = {
+ 0x01, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b,
+ 0x0c, 0x0d, 0x0e, 0x0f, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17,
+ 0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f, 0x20, 0x21, 0x22, 0x23,
+ 0x24, 0x25, 0x26, 0x27, 0x28, 0x29, 0x2a, 0x2b, 0x2c, 0x2d, 0x2e, 0x2f};
+const DataBuffer kKey2(kKey2Data, sizeof(kKey2Data));
+
+const char kLabelMasterSecret[] = "master secret";
+
+const uint8_t kSessionHash[] = {
+ 0xf0, 0xf1, 0xf2, 0xf3, 0xf4, 0xf5, 0xf6, 0xf7, 0xf8, 0xf9, 0xfa, 0xfb,
+ 0xfc, 0xfd, 0xfe, 0xff, 0xe0, 0xe1, 0xe2, 0xe3, 0xe4, 0xe5, 0xe6, 0xe7,
+ 0xe8, 0xe9, 0xea, 0xeb, 0xec, 0xed, 0xee, 0xef, 0xd0, 0xd1, 0xd2, 0xd3,
+ 0xd4, 0xd5, 0xd6, 0xd7, 0xd8, 0xd9, 0xda, 0xdb, 0xdc, 0xdd, 0xde, 0xdf,
+ 0xe0, 0xe1, 0xe2, 0xe3, 0xe4, 0xe5, 0xe6, 0xe7, 0xe8, 0xe9, 0xea, 0xeb,
+ 0xec, 0xed, 0xee, 0xef, 0xf0, 0xf1, 0xf2, 0xf3, 0xf4, 0xf5, 0xf6, 0xf7,
+ 0xf8, 0xf9, 0xfa, 0xfb, 0xfc, 0xfd, 0xfe, 0xff, 0xe0, 0xe1, 0xe2, 0xe3,
+ 0xe4, 0xe5, 0xe6, 0xe7, 0xe8, 0xe9, 0xea, 0xeb, 0xec, 0xed, 0xee, 0xef,
+ 0xf0, 0xf1, 0xf2, 0xf3, 0xf4, 0xf5, 0xf6, 0xf7, 0xf8, 0xf9, 0xfa, 0xfb,
+ 0xfc, 0xfd, 0xfe, 0xff,
+};
+
+const size_t kHashLength[] = {
+ 0, /* ssl_hash_none */
+ 16, /* ssl_hash_md5 */
+ 20, /* ssl_hash_sha1 */
+ 28, /* ssl_hash_sha224 */
+ 32, /* ssl_hash_sha256 */
+ 48, /* ssl_hash_sha384 */
+ 64, /* ssl_hash_sha512 */
+};
+
+size_t GetHashLength(SSLHashType hash) {
+ size_t i = static_cast<size_t>(hash);
+ if (i < PR_ARRAY_SIZE(kHashLength)) {
+ return kHashLength[i];
+ }
+ ADD_FAILURE() << "Unknown hash: " << hash;
+ return 0;
+}
+
+PRUint16 GetSomeCipherSuiteForHash(SSLHashType hash) {
+ switch (hash) {
+ case ssl_hash_sha256:
+ return TLS_AES_128_GCM_SHA256;
+ case ssl_hash_sha384:
+ return TLS_AES_256_GCM_SHA384;
+ default:
+ ADD_FAILURE() << "Unknown hash: " << hash;
+ }
+ return 0;
+}
+
+const std::string kHashName[] = {"None", "MD5", "SHA-1", "SHA-224",
+ "SHA-256", "SHA-384", "SHA-512"};
+
+static void ImportKey(ScopedPK11SymKey* to, const DataBuffer& key,
+ SSLHashType hash_type, PK11SlotInfo* slot) {
+ ASSERT_LT(hash_type, sizeof(kHashLength));
+ ASSERT_LE(kHashLength[hash_type], key.len());
+ SECItem key_item = {siBuffer, const_cast<uint8_t*>(key.data()),
+ static_cast<unsigned int>(GetHashLength(hash_type))};
+
+ PK11SymKey* inner =
+ PK11_ImportSymKey(slot, CKM_SSL3_MASTER_KEY_DERIVE, PK11_OriginUnwrap,
+ CKA_DERIVE, &key_item, NULL);
+ ASSERT_NE(nullptr, inner);
+ to->reset(inner);
+}
+
+static void DumpData(const std::string& label, const uint8_t* buf, size_t len) {
+ DataBuffer d(buf, len);
+
+ std::cerr << label << ": " << d << std::endl;
+}
+
+void DumpKey(const std::string& label, ScopedPK11SymKey& key) {
+ SECStatus rv = PK11_ExtractKeyValue(key.get());
+ ASSERT_EQ(SECSuccess, rv);
+
+ SECItem* key_data = PK11_GetKeyData(key.get());
+ ASSERT_NE(nullptr, key_data);
+
+ DumpData(label, key_data->data, key_data->len);
+}
+
+extern "C" {
+extern char ssl_trace;
+extern FILE* ssl_trace_iob;
+}
+
+class TlsHkdfTest : public ::testing::Test,
+ public ::testing::WithParamInterface<SSLHashType> {
+ public:
+ TlsHkdfTest()
+ : k1_(), k2_(), hash_type_(GetParam()), slot_(PK11_GetInternalSlot()) {
+ EXPECT_NE(nullptr, slot_);
+ char* ev = getenv("SSLTRACE");
+ if (ev && ev[0]) {
+ ssl_trace = atoi(ev);
+ ssl_trace_iob = stderr;
+ }
+ }
+
+ void SetUp() {
+ ImportKey(&k1_, kKey1, hash_type_, slot_.get());
+ ImportKey(&k2_, kKey2, hash_type_, slot_.get());
+ }
+
+ void VerifyKey(const ScopedPK11SymKey& key, CK_MECHANISM_TYPE expected_mech,
+ const DataBuffer& expected_value) {
+ EXPECT_EQ(expected_mech, PK11_GetMechanism(key.get()));
+
+ SECStatus rv = PK11_ExtractKeyValue(key.get());
+ ASSERT_EQ(SECSuccess, rv);
+
+ SECItem* key_data = PK11_GetKeyData(key.get());
+ ASSERT_NE(nullptr, key_data);
+
+ EXPECT_EQ(expected_value.len(), key_data->len);
+ EXPECT_EQ(
+ 0, memcmp(expected_value.data(), key_data->data, expected_value.len()));
+ }
+
+ void HkdfExtract(const ScopedPK11SymKey& ikmk1, const ScopedPK11SymKey& ikmk2,
+ SSLHashType base_hash, const DataBuffer& expected) {
+ std::cerr << "Hash = " << kHashName[base_hash] << std::endl;
+
+ PK11SymKey* prk = nullptr;
+ SECStatus rv = tls13_HkdfExtract(ikmk1.get(), ikmk2.get(), base_hash, &prk);
+ ASSERT_EQ(SECSuccess, rv);
+ ScopedPK11SymKey prkk(prk);
+
+ DumpKey("Output", prkk);
+ VerifyKey(prkk, CKM_HKDF_DERIVE, expected);
+
+ // Now test the public wrapper.
+ PRUint16 cs = GetSomeCipherSuiteForHash(base_hash);
+ rv = SSL_HkdfExtract(SSL_LIBRARY_VERSION_TLS_1_3, cs, ikmk1.get(),
+ ikmk2.get(), &prk);
+ ASSERT_EQ(SECSuccess, rv);
+ ASSERT_NE(nullptr, prk);
+ VerifyKey(ScopedPK11SymKey(prk), CKM_HKDF_DERIVE, expected);
+ }
+
+ void HkdfExpandLabel(ScopedPK11SymKey* prk, SSLHashType base_hash,
+ const uint8_t* session_hash, size_t session_hash_len,
+ const char* label, size_t label_len,
+ const DataBuffer& expected) {
+ ASSERT_NE(nullptr, prk);
+ std::cerr << "Hash = " << kHashName[base_hash] << std::endl;
+
+ std::vector<uint8_t> output(expected.len());
+
+ SECStatus rv = tls13_HkdfExpandLabelRaw(
+ prk->get(), base_hash, session_hash, session_hash_len, label, label_len,
+ ssl_variant_stream, &output[0], output.size());
+ ASSERT_EQ(SECSuccess, rv);
+ DumpData("Output", &output[0], output.size());
+ EXPECT_EQ(0, memcmp(expected.data(), &output[0], expected.len()));
+
+ // Verify that the public API produces the same result.
+ PRUint16 cs = GetSomeCipherSuiteForHash(base_hash);
+ PK11SymKey* secret;
+ rv = SSL_HkdfExpandLabel(SSL_LIBRARY_VERSION_TLS_1_3, cs, prk->get(),
+ session_hash, session_hash_len, label, label_len,
+ &secret);
+ EXPECT_EQ(SECSuccess, rv);
+ ASSERT_NE(nullptr, secret);
+ VerifyKey(ScopedPK11SymKey(secret), CKM_HKDF_DERIVE, expected);
+
+ // Verify that a key can be created with a different key type and size.
+ rv = SSL_HkdfExpandLabelWithMech(
+ SSL_LIBRARY_VERSION_TLS_1_3, cs, prk->get(), session_hash,
+ session_hash_len, label, label_len, CKM_DES3_CBC_PAD, 24, &secret);
+ EXPECT_EQ(SECSuccess, rv);
+ ASSERT_NE(nullptr, secret);
+ ScopedPK11SymKey with_mech(secret);
+ EXPECT_EQ(static_cast<CK_MECHANISM_TYPE>(CKM_DES3_CBC_PAD),
+ PK11_GetMechanism(with_mech.get()));
+ // Just verify that the key is the right size.
+ rv = PK11_ExtractKeyValue(with_mech.get());
+ ASSERT_EQ(SECSuccess, rv);
+ SECItem* key_data = PK11_GetKeyData(with_mech.get());
+ ASSERT_NE(nullptr, key_data);
+ EXPECT_EQ(24U, key_data->len);
+ }
+
+ protected:
+ ScopedPK11SymKey k1_;
+ ScopedPK11SymKey k2_;
+ SSLHashType hash_type_;
+
+ private:
+ ScopedPK11SlotInfo slot_;
+};
+
+TEST_P(TlsHkdfTest, HkdfNullNull) {
+ const uint8_t tv[][48] = {
+ {/* ssl_hash_none */},
+ {/* ssl_hash_md5 */},
+ {/* ssl_hash_sha1 */},
+ {/* ssl_hash_sha224 */},
+ {0x33, 0xad, 0x0a, 0x1c, 0x60, 0x7e, 0xc0, 0x3b, 0x09, 0xe6, 0xcd,
+ 0x98, 0x93, 0x68, 0x0c, 0xe2, 0x10, 0xad, 0xf3, 0x00, 0xaa, 0x1f,
+ 0x26, 0x60, 0xe1, 0xb2, 0x2e, 0x10, 0xf1, 0x70, 0xf9, 0x2a},
+ {0x7e, 0xe8, 0x20, 0x6f, 0x55, 0x70, 0x02, 0x3e, 0x6d, 0xc7, 0x51, 0x9e,
+ 0xb1, 0x07, 0x3b, 0xc4, 0xe7, 0x91, 0xad, 0x37, 0xb5, 0xc3, 0x82, 0xaa,
+ 0x10, 0xba, 0x18, 0xe2, 0x35, 0x7e, 0x71, 0x69, 0x71, 0xf9, 0x36, 0x2f,
+ 0x2c, 0x2f, 0xe2, 0xa7, 0x6b, 0xfd, 0x78, 0xdf, 0xec, 0x4e, 0xa9, 0xb5}};
+
+ const DataBuffer expected_data(tv[hash_type_], GetHashLength(hash_type_));
+ HkdfExtract(nullptr, nullptr, hash_type_, expected_data);
+}
+
+TEST_P(TlsHkdfTest, HkdfKey1Only) {
+ const uint8_t tv[][48] = {
+ {/* ssl_hash_none */},
+ {/* ssl_hash_md5 */},
+ {/* ssl_hash_sha1 */},
+ {/* ssl_hash_sha224 */},
+ {0x41, 0x6c, 0x53, 0x92, 0xb9, 0xf3, 0x6d, 0xf1, 0x88, 0xe9, 0x0e,
+ 0xb1, 0x4d, 0x17, 0xbf, 0x0d, 0xa1, 0x90, 0xbf, 0xdb, 0x7f, 0x1f,
+ 0x49, 0x56, 0xe6, 0xe5, 0x66, 0xa5, 0x69, 0xc8, 0xb1, 0x5c},
+ {0x51, 0xb1, 0xd5, 0xb4, 0x59, 0x79, 0x79, 0x08, 0x4a, 0x15, 0xb2, 0xdb,
+ 0x84, 0xd3, 0xd6, 0xbc, 0xfc, 0x93, 0x45, 0xd9, 0xdc, 0x74, 0xda, 0x1a,
+ 0x57, 0xc2, 0x76, 0x9f, 0x3f, 0x83, 0x45, 0x2f, 0xf6, 0xf3, 0x56, 0x1f,
+ 0x58, 0x63, 0xdb, 0x88, 0xda, 0x40, 0xce, 0x63, 0x7d, 0x24, 0x37, 0xf3}};
+
+ const DataBuffer expected_data(tv[hash_type_], GetHashLength(hash_type_));
+ HkdfExtract(k1_, nullptr, hash_type_, expected_data);
+}
+
+TEST_P(TlsHkdfTest, HkdfKey2Only) {
+ const uint8_t tv[][48] = {
+ {/* ssl_hash_none */},
+ {/* ssl_hash_md5 */},
+ {/* ssl_hash_sha1 */},
+ {/* ssl_hash_sha224 */},
+ {0x16, 0xaf, 0x00, 0x54, 0x3a, 0x56, 0xc8, 0x26, 0xa2, 0xa7, 0xfc,
+ 0xb6, 0x34, 0x66, 0x8a, 0xfd, 0x36, 0xdc, 0x8e, 0xce, 0xc4, 0xd2,
+ 0x6c, 0x7a, 0xdc, 0xe3, 0x70, 0x36, 0x3d, 0x60, 0xfa, 0x0b},
+ {0x7b, 0x40, 0xf9, 0xef, 0x91, 0xff, 0xc9, 0xd1, 0x29, 0x24, 0x5c, 0xbf,
+ 0xf8, 0x82, 0x76, 0x68, 0xae, 0x4b, 0x63, 0xe8, 0x03, 0xdd, 0x39, 0xa8,
+ 0xd4, 0x6a, 0xf6, 0xe5, 0xec, 0xea, 0xf8, 0x7d, 0x91, 0x71, 0x81, 0xf1,
+ 0xdb, 0x3b, 0xaf, 0xbf, 0xde, 0x71, 0x61, 0x15, 0xeb, 0xb5, 0x5f, 0x68}};
+
+ const DataBuffer expected_data(tv[hash_type_], GetHashLength(hash_type_));
+ HkdfExtract(nullptr, k2_, hash_type_, expected_data);
+}
+
+TEST_P(TlsHkdfTest, HkdfKey1Key2) {
+ const uint8_t tv[][48] = {
+ {/* ssl_hash_none */},
+ {/* ssl_hash_md5 */},
+ {/* ssl_hash_sha1 */},
+ {/* ssl_hash_sha224 */},
+ {0xa5, 0x68, 0x02, 0x5a, 0x95, 0xc9, 0x7f, 0x55, 0x38, 0xbc, 0xf7,
+ 0x97, 0xcc, 0x0f, 0xd5, 0xf6, 0xa8, 0x8d, 0x15, 0xbc, 0x0e, 0x85,
+ 0x74, 0x70, 0x3c, 0xa3, 0x65, 0xbd, 0x76, 0xcf, 0x9f, 0xd3},
+ {0x01, 0x93, 0xc0, 0x07, 0x3f, 0x6a, 0x83, 0x0e, 0x2e, 0x4f, 0xb2, 0x58,
+ 0xe4, 0x00, 0x08, 0x5c, 0x68, 0x9c, 0x37, 0x32, 0x00, 0x37, 0xff, 0xc3,
+ 0x1c, 0x5b, 0x98, 0x0b, 0x02, 0x92, 0x3f, 0xfd, 0x73, 0x5a, 0x6f, 0x2a,
+ 0x95, 0xa3, 0xee, 0xf6, 0xd6, 0x8e, 0x6f, 0x86, 0xea, 0x63, 0xf8, 0x33}};
+
+ const DataBuffer expected_data(tv[hash_type_], GetHashLength(hash_type_));
+ HkdfExtract(k1_, k2_, hash_type_, expected_data);
+}
+
+TEST_P(TlsHkdfTest, HkdfExpandLabel) {
+ const uint8_t tv[][48] = {
+ {/* ssl_hash_none */},
+ {/* ssl_hash_md5 */},
+ {/* ssl_hash_sha1 */},
+ {/* ssl_hash_sha224 */},
+ {0x3e, 0x4e, 0x6e, 0xd0, 0xbc, 0xc4, 0xf4, 0xff, 0xf0, 0xf5, 0x69,
+ 0xd0, 0x6c, 0x1e, 0x0e, 0x10, 0x32, 0xaa, 0xd7, 0xa3, 0xef, 0xf6,
+ 0xa8, 0x65, 0x8e, 0xbe, 0xee, 0xc7, 0x1f, 0x01, 0x6d, 0x3c},
+ {0x41, 0xea, 0x77, 0x09, 0x8c, 0x90, 0x04, 0x10, 0xec, 0xbc, 0x37, 0xd8,
+ 0x5b, 0x54, 0xcd, 0x7b, 0x08, 0x15, 0x13, 0x20, 0xed, 0x1e, 0x3f, 0x54,
+ 0x74, 0xf7, 0x8b, 0x06, 0x38, 0x28, 0x06, 0x37, 0x75, 0x23, 0xa2, 0xb7,
+ 0x34, 0xb1, 0x72, 0x2e, 0x59, 0x6d, 0x5a, 0x31, 0xf5, 0x53, 0xab, 0x99}};
+
+ const DataBuffer expected_data(tv[hash_type_], GetHashLength(hash_type_));
+ HkdfExpandLabel(&k1_, hash_type_, kSessionHash, GetHashLength(hash_type_),
+ kLabelMasterSecret, strlen(kLabelMasterSecret),
+ expected_data);
+}
+
+TEST_P(TlsHkdfTest, HkdfExpandLabelNoHash) {
+ const uint8_t tv[][48] = {
+ {/* ssl_hash_none */},
+ {/* ssl_hash_md5 */},
+ {/* ssl_hash_sha1 */},
+ {/* ssl_hash_sha224 */},
+ {0xb7, 0x08, 0x00, 0xe3, 0x8e, 0x48, 0x68, 0x91, 0xb1, 0x0f, 0x5e,
+ 0x6f, 0x22, 0x53, 0x6b, 0x84, 0x69, 0x75, 0xaa, 0xa3, 0x2a, 0xe7,
+ 0xde, 0xaa, 0xc3, 0xd1, 0xb4, 0x05, 0x22, 0x5c, 0x68, 0xf5},
+ {0x13, 0xd3, 0x36, 0x9f, 0x3c, 0x78, 0xa0, 0x32, 0x40, 0xee, 0x16, 0xe9,
+ 0x11, 0x12, 0x66, 0xc7, 0x51, 0xad, 0xd8, 0x3c, 0xa1, 0xa3, 0x97, 0x74,
+ 0xd7, 0x45, 0xff, 0xa7, 0x88, 0x9e, 0x52, 0x17, 0x2e, 0xaa, 0x3a, 0xd2,
+ 0x35, 0xd8, 0xd5, 0x35, 0xfd, 0x65, 0x70, 0x9f, 0xa9, 0xf9, 0xfa, 0x23}};
+
+ const DataBuffer expected_data(tv[hash_type_], GetHashLength(hash_type_));
+ HkdfExpandLabel(&k1_, hash_type_, nullptr, 0, kLabelMasterSecret,
+ strlen(kLabelMasterSecret), expected_data);
+}
+
+TEST_P(TlsHkdfTest, BadExtractWrapperInput) {
+ PK11SymKey* key = nullptr;
+
+ // Bad version.
+ EXPECT_EQ(SECFailure,
+ SSL_HkdfExtract(SSL_LIBRARY_VERSION_TLS_1_2, TLS_AES_128_GCM_SHA256,
+ k1_.get(), k2_.get(), &key));
+ EXPECT_EQ(SEC_ERROR_INVALID_ARGS, PORT_GetError());
+
+ // Bad ciphersuite.
+ EXPECT_EQ(SECFailure,
+ SSL_HkdfExtract(SSL_LIBRARY_VERSION_TLS_1_3, TLS_RSA_WITH_NULL_SHA,
+ k1_.get(), k2_.get(), &key));
+ EXPECT_EQ(SEC_ERROR_INVALID_ARGS, PORT_GetError());
+
+ // Old ciphersuite.
+ EXPECT_EQ(SECFailure, SSL_HkdfExtract(SSL_LIBRARY_VERSION_TLS_1_3,
+ TLS_RSA_WITH_AES_128_CBC_SHA, k1_.get(),
+ k2_.get(), &key));
+ EXPECT_EQ(SEC_ERROR_INVALID_ARGS, PORT_GetError());
+
+ // NULL outparam..
+ EXPECT_EQ(SECFailure, SSL_HkdfExtract(SSL_LIBRARY_VERSION_TLS_1_3,
+ TLS_RSA_WITH_AES_128_CBC_SHA, k1_.get(),
+ k2_.get(), nullptr));
+ EXPECT_EQ(SEC_ERROR_INVALID_ARGS, PORT_GetError());
+
+ EXPECT_EQ(nullptr, key);
+}
+
+TEST_P(TlsHkdfTest, BadExpandLabelWrapperInput) {
+ PK11SymKey* key = nullptr;
+ static const char* kLabel = "label";
+
+ // Bad version.
+ EXPECT_EQ(
+ SECFailure,
+ SSL_HkdfExpandLabel(SSL_LIBRARY_VERSION_TLS_1_2, TLS_AES_128_GCM_SHA256,
+ k1_.get(), nullptr, 0, kLabel, strlen(kLabel), &key));
+ EXPECT_EQ(SEC_ERROR_INVALID_ARGS, PORT_GetError());
+
+ // Bad ciphersuite.
+ EXPECT_EQ(
+ SECFailure,
+ SSL_HkdfExpandLabel(SSL_LIBRARY_VERSION_TLS_1_3, TLS_RSA_WITH_NULL_MD5,
+ k1_.get(), nullptr, 0, kLabel, strlen(kLabel), &key));
+ EXPECT_EQ(SEC_ERROR_INVALID_ARGS, PORT_GetError());
+
+ // Old ciphersuite.
+ EXPECT_EQ(SECFailure,
+ SSL_HkdfExpandLabel(SSL_LIBRARY_VERSION_TLS_1_3,
+ TLS_RSA_WITH_AES_128_CBC_SHA, k1_.get(),
+ nullptr, 0, kLabel, strlen(kLabel), &key));
+ EXPECT_EQ(SEC_ERROR_INVALID_ARGS, PORT_GetError());
+
+ // Null PRK.
+ EXPECT_EQ(SECFailure, SSL_HkdfExpandLabel(
+ SSL_LIBRARY_VERSION_TLS_1_2, TLS_AES_128_GCM_SHA256,
+ nullptr, nullptr, 0, kLabel, strlen(kLabel), &key));
+ EXPECT_EQ(SEC_ERROR_INVALID_ARGS, PORT_GetError());
+
+ // Null, non-zero-length handshake hash.
+ EXPECT_EQ(
+ SECFailure,
+ SSL_HkdfExpandLabel(SSL_LIBRARY_VERSION_TLS_1_2, TLS_AES_128_GCM_SHA256,
+ k1_.get(), nullptr, 2, kLabel, strlen(kLabel), &key));
+
+ EXPECT_EQ(SEC_ERROR_INVALID_ARGS, PORT_GetError());
+ // Null, non-zero-length label.
+ EXPECT_EQ(SECFailure,
+ SSL_HkdfExpandLabel(SSL_LIBRARY_VERSION_TLS_1_3,
+ TLS_AES_128_GCM_SHA256, k1_.get(), nullptr, 0,
+ nullptr, strlen(kLabel), &key));
+ EXPECT_EQ(SEC_ERROR_INVALID_ARGS, PORT_GetError());
+
+ // Null, empty label.
+ EXPECT_EQ(SECFailure, SSL_HkdfExpandLabel(SSL_LIBRARY_VERSION_TLS_1_3,
+ TLS_AES_128_GCM_SHA256, k1_.get(),
+ nullptr, 0, nullptr, 0, &key));
+ EXPECT_EQ(SEC_ERROR_INVALID_ARGS, PORT_GetError());
+
+ // Null key pointer..
+ EXPECT_EQ(SECFailure,
+ SSL_HkdfExpandLabel(SSL_LIBRARY_VERSION_TLS_1_3,
+ TLS_AES_128_GCM_SHA256, k1_.get(), nullptr, 0,
+ kLabel, strlen(kLabel), nullptr));
+ EXPECT_EQ(SEC_ERROR_INVALID_ARGS, PORT_GetError());
+
+ EXPECT_EQ(nullptr, key);
+}
+
+static const SSLHashType kHashTypes[] = {ssl_hash_sha256, ssl_hash_sha384};
+INSTANTIATE_TEST_SUITE_P(AllHashFuncs, TlsHkdfTest,
+ ::testing::ValuesIn(kHashTypes));
+
+} // namespace nss_test
diff --git a/security/nss/gtests/ssl_gtest/tls_protect.cc b/security/nss/gtests/ssl_gtest/tls_protect.cc
new file mode 100644
index 0000000000..6187660a5c
--- /dev/null
+++ b/security/nss/gtests/ssl_gtest/tls_protect.cc
@@ -0,0 +1,148 @@
+/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
+/* vim: set ts=2 et sw=2 tw=80: */
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this file,
+ * You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+#include "tls_protect.h"
+#include "sslproto.h"
+#include "tls_filter.h"
+
+namespace nss_test {
+
+static uint64_t FirstSeqno(bool dtls, uint16_t epoc) {
+ if (dtls) {
+ return static_cast<uint64_t>(epoc) << 48;
+ }
+ return 0;
+}
+
+TlsCipherSpec::TlsCipherSpec(bool dtls, uint16_t epoc)
+ : dtls_(dtls),
+ epoch_(epoc),
+ in_seqno_(FirstSeqno(dtls, epoc)),
+ out_seqno_(FirstSeqno(dtls, epoc)) {}
+
+bool TlsCipherSpec::SetKeys(SSLCipherSuiteInfo* cipherinfo,
+ PK11SymKey* secret) {
+ SSLAeadContext* aead_ctx;
+ SSLProtocolVariant variant =
+ dtls_ ? ssl_variant_datagram : ssl_variant_stream;
+ SECStatus rv =
+ SSL_MakeVariantAead(SSL_LIBRARY_VERSION_TLS_1_3, cipherinfo->cipherSuite,
+ variant, secret, "", 0, // Use the default labels.
+ &aead_ctx);
+ if (rv != SECSuccess) {
+ return false;
+ }
+ aead_.reset(aead_ctx);
+
+ SSLMaskingContext* mask_ctx;
+ const char kHkdfPurposeSn[] = "sn";
+ rv = SSL_CreateVariantMaskingContext(
+ SSL_LIBRARY_VERSION_TLS_1_3, cipherinfo->cipherSuite, variant, secret,
+ kHkdfPurposeSn, strlen(kHkdfPurposeSn), &mask_ctx);
+ if (rv != SECSuccess) {
+ return false;
+ }
+ mask_.reset(mask_ctx);
+ return true;
+}
+
+bool TlsCipherSpec::Unprotect(const TlsRecordHeader& header,
+ const DataBuffer& ciphertext,
+ DataBuffer* plaintext,
+ TlsRecordHeader* out_header) {
+ if (!aead_ || !out_header) {
+ return false;
+ }
+ *out_header = header;
+
+ // Make space.
+ plaintext->Allocate(ciphertext.len());
+
+ unsigned int len;
+ uint64_t seqno = dtls_ ? header.sequence_number() : in_seqno_;
+ SECStatus rv;
+
+ if (header.is_dtls13_ciphertext()) {
+ if (!mask_ || !out_header) {
+ return false;
+ }
+ PORT_Assert(ciphertext.len() >= 16);
+ DataBuffer mask(2);
+ rv = SSL_CreateMask(mask_.get(), ciphertext.data(), ciphertext.len(),
+ mask.data(), mask.len());
+ if (rv != SECSuccess) {
+ return false;
+ }
+
+ if (!out_header->MaskSequenceNumber(mask)) {
+ return false;
+ }
+ seqno = out_header->sequence_number();
+ }
+
+ auto header_bytes = out_header->header();
+ rv = SSL_AeadDecrypt(aead_.get(), seqno, header_bytes.data(),
+ header_bytes.len(), ciphertext.data(), ciphertext.len(),
+ plaintext->data(), &len, plaintext->len());
+ if (rv != SECSuccess) {
+ return false;
+ }
+
+ RecordUnprotected(seqno);
+ plaintext->Truncate(static_cast<size_t>(len));
+
+ return true;
+}
+
+bool TlsCipherSpec::Protect(const TlsRecordHeader& header,
+ const DataBuffer& plaintext, DataBuffer* ciphertext,
+ TlsRecordHeader* out_header) {
+ if (!aead_ || !out_header) {
+ return false;
+ }
+
+ *out_header = header;
+
+ // Make a padded buffer.
+ ciphertext->Allocate(plaintext.len() +
+ 32); // Room for any plausible auth tag
+ unsigned int len;
+
+ DataBuffer header_bytes;
+ (void)header.WriteHeader(&header_bytes, 0, plaintext.len() + 16);
+ uint64_t seqno = dtls_ ? header.sequence_number() : out_seqno_;
+
+ SECStatus rv =
+ SSL_AeadEncrypt(aead_.get(), seqno, header_bytes.data(),
+ header_bytes.len(), plaintext.data(), plaintext.len(),
+ ciphertext->data(), &len, ciphertext->len());
+ if (rv != SECSuccess) {
+ return false;
+ }
+
+ if (header.is_dtls13_ciphertext()) {
+ if (!mask_ || !out_header) {
+ return false;
+ }
+ PORT_Assert(ciphertext->len() >= 16);
+ DataBuffer mask(2);
+ rv = SSL_CreateMask(mask_.get(), ciphertext->data(), ciphertext->len(),
+ mask.data(), mask.len());
+ if (rv != SECSuccess) {
+ return false;
+ }
+ if (!out_header->MaskSequenceNumber(mask)) {
+ return false;
+ }
+ }
+
+ RecordProtected();
+ ciphertext->Truncate(len);
+
+ return true;
+}
+
+} // namespace nss_test
diff --git a/security/nss/gtests/ssl_gtest/tls_protect.h b/security/nss/gtests/ssl_gtest/tls_protect.h
new file mode 100644
index 0000000000..d7ea2aa128
--- /dev/null
+++ b/security/nss/gtests/ssl_gtest/tls_protect.h
@@ -0,0 +1,60 @@
+/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
+/* vim: set ts=2 et sw=2 tw=80: */
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this file,
+ * You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+#ifndef tls_protection_h_
+#define tls_protection_h_
+
+#include <cstdint>
+#include <memory>
+
+#include "pk11pub.h"
+#include "sslt.h"
+#include "sslexp.h"
+
+#include "databuffer.h"
+#include "scoped_ptrs_ssl.h"
+
+namespace nss_test {
+class TlsRecordHeader;
+
+// Our analog of ssl3CipherSpec
+class TlsCipherSpec {
+ public:
+ TlsCipherSpec(bool dtls, uint16_t epoc);
+ bool SetKeys(SSLCipherSuiteInfo* cipherinfo, PK11SymKey* secret);
+
+ bool Protect(const TlsRecordHeader& header, const DataBuffer& plaintext,
+ DataBuffer* ciphertext, TlsRecordHeader* out_header);
+ bool Unprotect(const TlsRecordHeader& header, const DataBuffer& ciphertext,
+ DataBuffer* plaintext, TlsRecordHeader* out_header);
+
+ uint16_t epoch() const { return epoch_; }
+ uint64_t next_in_seqno() const { return in_seqno_; }
+ void RecordUnprotected(uint64_t seqno) {
+ // Reordering happens, so don't let this go backwards.
+ in_seqno_ = (std::max)(in_seqno_, seqno + 1);
+ }
+ uint64_t next_out_seqno() { return out_seqno_; }
+ void RecordProtected() { out_seqno_++; }
+
+ void RecordDropped() { record_dropped_ = true; }
+ bool record_dropped() const { return record_dropped_; }
+
+ bool is_protected() const { return aead_ != nullptr; }
+
+ private:
+ bool dtls_;
+ uint16_t epoch_;
+ uint64_t in_seqno_;
+ uint64_t out_seqno_;
+ bool record_dropped_ = false;
+ ScopedSSLAeadContext aead_;
+ ScopedSSLMaskingContext mask_;
+};
+
+} // namespace nss_test
+
+#endif
diff --git a/security/nss/gtests/ssl_gtest/tls_psk_unittest.cc b/security/nss/gtests/ssl_gtest/tls_psk_unittest.cc
new file mode 100644
index 0000000000..678a9ff585
--- /dev/null
+++ b/security/nss/gtests/ssl_gtest/tls_psk_unittest.cc
@@ -0,0 +1,515 @@
+/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
+/* vim: set ts=2 et sw=2 tw=80: */
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this file,
+ * You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+#include <functional>
+#include <memory>
+#include "secerr.h"
+#include "ssl.h"
+#include "sslerr.h"
+#include "sslproto.h"
+
+#include "gtest_utils.h"
+#include "tls_connect.h"
+
+namespace nss_test {
+
+class Tls13PskTest : public TlsConnectTestBase,
+ public ::testing::WithParamInterface<
+ std::tuple<SSLProtocolVariant, uint16_t>> {
+ public:
+ Tls13PskTest()
+ : TlsConnectTestBase(std::get<0>(GetParam()),
+ SSL_LIBRARY_VERSION_TLS_1_3),
+ suite_(std::get<1>(GetParam())) {}
+
+ void SetUp() override {
+ TlsConnectTestBase::SetUp();
+ scoped_psk_.reset(GetPsk());
+ ASSERT_TRUE(!!scoped_psk_);
+ }
+
+ private:
+ PK11SymKey* GetPsk() {
+ ScopedPK11SlotInfo slot(PK11_GetInternalSlot());
+ if (!slot) {
+ ADD_FAILURE();
+ return nullptr;
+ }
+
+ SECItem psk_item;
+ psk_item.type = siBuffer;
+ psk_item.len = sizeof(kPskDummyVal_);
+ psk_item.data = const_cast<uint8_t*>(kPskDummyVal_);
+
+ PK11SymKey* key =
+ PK11_ImportSymKey(slot.get(), CKM_HKDF_KEY_GEN, PK11_OriginUnwrap,
+ CKA_DERIVE, &psk_item, NULL);
+ if (!key) {
+ ADD_FAILURE();
+ }
+ return key;
+ }
+
+ protected:
+ ScopedPK11SymKey scoped_psk_;
+ const uint16_t suite_;
+ const uint8_t kPskDummyVal_[16] = {0x01, 0x02, 0x03, 0x04, 0x05,
+ 0x06, 0x07, 0x08, 0x09, 0x0a,
+ 0x0b, 0x0c, 0x0d, 0x0e, 0x0f};
+ const std::string kPskDummyLabel_ = "NSS PSK GTEST label";
+ const SSLHashType kPskHash_ = ssl_hash_sha384;
+};
+
+// TLS 1.3 PSK connection test.
+TEST_P(Tls13PskTest, NormalExternal) {
+ AddPsk(scoped_psk_, kPskDummyLabel_, kPskHash_);
+ Connect();
+ SendReceive();
+ CheckKeys(ssl_kea_ecdh, ssl_grp_ec_curve25519, ssl_auth_psk, ssl_sig_none);
+ client_->RemovePsk(kPskDummyLabel_);
+ server_->RemovePsk(kPskDummyLabel_);
+
+ // Removing it again should fail.
+ EXPECT_EQ(SECFailure, SSL_RemoveExternalPsk(client_->ssl_fd(),
+ reinterpret_cast<const uint8_t*>(
+ kPskDummyLabel_.data()),
+ kPskDummyLabel_.length()));
+ EXPECT_EQ(SECFailure, SSL_RemoveExternalPsk(server_->ssl_fd(),
+ reinterpret_cast<const uint8_t*>(
+ kPskDummyLabel_.data()),
+ kPskDummyLabel_.length()));
+}
+
+TEST_P(Tls13PskTest, KeyTooLarge) {
+ ScopedPK11SlotInfo slot(PK11_GetInternalSlot());
+ ASSERT_TRUE(!!slot);
+ ScopedPK11SymKey scoped_psk(PK11_KeyGen(
+ slot.get(), CKM_GENERIC_SECRET_KEY_GEN, nullptr, 128, nullptr));
+ AddPsk(scoped_psk_, kPskDummyLabel_, kPskHash_);
+ Connect();
+ SendReceive();
+ CheckKeys(ssl_kea_ecdh, ssl_grp_ec_curve25519, ssl_auth_psk, ssl_sig_none);
+}
+
+// Attempt to use a PSK with the wrong PRF hash.
+// "Clients MUST verify that...the server selected a cipher suite
+// indicating a Hash associated with the PSK"
+TEST_P(Tls13PskTest, ClientVerifyHashType) {
+ AddPsk(scoped_psk_, kPskDummyLabel_, kPskHash_);
+ MakeTlsFilter<SelectedCipherSuiteReplacer>(server_,
+ TLS_CHACHA20_POLY1305_SHA256);
+ client_->ExpectSendAlert(kTlsAlertIllegalParameter);
+ if (variant_ == ssl_variant_stream) {
+ server_->ExpectSendAlert(kTlsAlertUnexpectedMessage);
+ ConnectExpectFail();
+ EXPECT_EQ(SSL_ERROR_RX_UNEXPECTED_RECORD_TYPE, server_->error_code());
+ } else {
+ ConnectExpectFailOneSide(TlsAgent::CLIENT);
+ }
+ EXPECT_EQ(SSL_ERROR_RX_MALFORMED_SERVER_HELLO, client_->error_code());
+}
+
+// Different EPSKs (by label) on each endpoint. Expect cert auth.
+TEST_P(Tls13PskTest, LabelMismatch) {
+ client_->AddPsk(scoped_psk_, std::string("foo"), kPskHash_);
+ server_->AddPsk(scoped_psk_, std::string("bar"), kPskHash_);
+ Connect();
+ CheckKeys(ssl_kea_ecdh, ssl_auth_rsa_sign);
+}
+
+SSLHelloRetryRequestAction RetryFirstHello(
+ PRBool firstHello, const PRUint8* clientToken, unsigned int clientTokenLen,
+ PRUint8* appToken, unsigned int* appTokenLen, unsigned int appTokenMax,
+ void* arg) {
+ auto* called = reinterpret_cast<size_t*>(arg);
+ ++*called;
+ EXPECT_EQ(0U, clientTokenLen);
+ EXPECT_EQ(*called, firstHello ? 1U : 2U);
+ return firstHello ? ssl_hello_retry_request : ssl_hello_retry_accept;
+}
+
+// Test resumption PSK with HRR.
+TEST_P(Tls13PskTest, ResPskRetryStateless) {
+ ConfigureSelfEncrypt();
+ ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET);
+ Connect();
+ SendReceive(); // Need to read so that we absorb the session ticket.
+ CheckKeys();
+
+ Reset();
+ StartConnect();
+ size_t cb_called = 0;
+ EXPECT_EQ(SECSuccess, SSL_HelloRetryRequestCallback(
+ server_->ssl_fd(), RetryFirstHello, &cb_called));
+ ExpectResumption(RESUME_TICKET);
+ Handshake();
+ CheckConnected();
+ EXPECT_EQ(2U, cb_called);
+ CheckKeys(ssl_kea_ecdh, ssl_auth_rsa_sign);
+ SendReceive();
+}
+
+// Test external PSK with HRR.
+TEST_P(Tls13PskTest, ExtPskRetryStateless) {
+ AddPsk(scoped_psk_, kPskDummyLabel_, kPskHash_);
+ size_t cb_called = 0;
+ EXPECT_EQ(SECSuccess, SSL_HelloRetryRequestCallback(
+ server_->ssl_fd(), RetryFirstHello, &cb_called));
+ StartConnect();
+ client_->Handshake();
+ server_->Handshake();
+ EXPECT_EQ(1U, cb_called);
+ auto replacement = std::make_shared<TlsAgent>(
+ server_->name(), TlsAgent::SERVER, server_->variant());
+ server_ = replacement;
+ server_->SetVersionRange(version_, version_);
+ client_->SetPeer(server_);
+ server_->SetPeer(client_);
+ server_->AddPsk(scoped_psk_, kPskDummyLabel_, kPskHash_);
+ server_->ExpectPsk();
+ server_->StartConnect();
+ Handshake();
+ CheckConnected();
+ SendReceive();
+ CheckKeys(ssl_kea_ecdh, ssl_grp_ec_curve25519, ssl_auth_psk, ssl_sig_none);
+}
+
+// Server not configured with PSK and sends a certificate instead of
+// a selected_identity. Client should attempt certificate authentication.
+TEST_P(Tls13PskTest, ClientOnly) {
+ client_->AddPsk(scoped_psk_, kPskDummyLabel_, kPskHash_);
+ Connect();
+ CheckKeys(ssl_kea_ecdh, ssl_auth_rsa_sign);
+}
+
+// Set a PSK, remove psk_key_exchange_modes.
+TEST_P(Tls13PskTest, DropKexModes) {
+ AddPsk(scoped_psk_, kPskDummyLabel_, kPskHash_);
+ StartConnect();
+ MakeTlsFilter<TlsExtensionDropper>(client_,
+ ssl_tls13_psk_key_exchange_modes_xtn);
+ ConnectExpectAlert(server_, kTlsAlertMissingExtension);
+ client_->CheckErrorCode(SSL_ERROR_MISSING_EXTENSION_ALERT);
+ server_->CheckErrorCode(SSL_ERROR_MISSING_PSK_KEY_EXCHANGE_MODES);
+}
+
+// "Clients MUST verify that...a server "key_share" extension is present
+// if required by the ClientHello "psk_key_exchange_modes" extension."
+// As we don't support PSK without DH, it is always required.
+TEST_P(Tls13PskTest, DropRequiredKeyShare) {
+ AddPsk(scoped_psk_, kPskDummyLabel_, kPskHash_);
+ StartConnect();
+ MakeTlsFilter<TlsExtensionDropper>(server_, ssl_tls13_key_share_xtn);
+ client_->ExpectSendAlert(kTlsAlertMissingExtension);
+ if (variant_ == ssl_variant_stream) {
+ server_->ExpectSendAlert(kTlsAlertUnexpectedMessage);
+ ConnectExpectFail();
+ } else {
+ ConnectExpectFailOneSide(TlsAgent::CLIENT);
+ }
+ client_->CheckErrorCode(SSL_ERROR_MISSING_KEY_SHARE);
+}
+
+// "Clients MUST verify that...the server's selected_identity is
+// within the range supplied by the client". We send one OfferedPsk.
+TEST_P(Tls13PskTest, InvalidSelectedIdentity) {
+ static const uint8_t selected_identity[] = {0x00, 0x01};
+ DataBuffer buf(selected_identity, sizeof(selected_identity));
+ AddPsk(scoped_psk_, kPskDummyLabel_, kPskHash_);
+ StartConnect();
+ MakeTlsFilter<TlsExtensionReplacer>(server_, ssl_tls13_pre_shared_key_xtn,
+ buf);
+ client_->ExpectSendAlert(kTlsAlertIllegalParameter);
+ if (variant_ == ssl_variant_stream) {
+ server_->ExpectSendAlert(kTlsAlertUnexpectedMessage);
+ ConnectExpectFail();
+ } else {
+ ConnectExpectFailOneSide(TlsAgent::CLIENT);
+ }
+ client_->CheckErrorCode(SSL_ERROR_MALFORMED_PRE_SHARED_KEY);
+}
+
+// Resume-eligible reconnect with an EPSK configured.
+// Expect the EPSK to be used.
+TEST_P(Tls13PskTest, PreferEpsk) {
+ ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET);
+ Connect();
+ SendReceive(); // Need to read so that we absorb the session ticket.
+ CheckKeys();
+
+ Reset();
+ AddPsk(scoped_psk_, kPskDummyLabel_, kPskHash_);
+ ExpectResumption(RESUME_NONE);
+ StartConnect();
+ Handshake();
+ CheckConnected();
+ SendReceive();
+ CheckKeys(ssl_kea_ecdh, ssl_grp_ec_curve25519, ssl_auth_psk, ssl_sig_none);
+}
+
+// Enable resumption, but connect (initially) with an EPSK.
+// Expect no session ticket.
+TEST_P(Tls13PskTest, SuppressNewSessionTicket) {
+ AddPsk(scoped_psk_, kPskDummyLabel_, kPskHash_);
+ ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET);
+ auto nst_capture =
+ MakeTlsFilter<TlsHandshakeRecorder>(server_, ssl_hs_new_session_ticket);
+ nst_capture->EnableDecryption();
+ Connect();
+ SendReceive();
+ CheckKeys(ssl_kea_ecdh, ssl_grp_ec_curve25519, ssl_auth_psk, ssl_sig_none);
+ EXPECT_EQ(SECFailure, SSL_SendSessionTicket(server_->ssl_fd(), nullptr, 0));
+ EXPECT_EQ(0U, nst_capture->buffer().len());
+ if (variant_ == ssl_variant_stream) {
+ EXPECT_EQ(SSL_ERROR_FEATURE_DISABLED, PORT_GetError());
+ } else {
+ EXPECT_EQ(SSL_ERROR_FEATURE_NOT_SUPPORTED_FOR_VERSION, PORT_GetError());
+ }
+
+ Reset();
+ ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET);
+ AddPsk(scoped_psk_, kPskDummyLabel_, kPskHash_);
+ ExpectResumption(RESUME_NONE);
+ Connect();
+ SendReceive();
+ CheckKeys(ssl_kea_ecdh, ssl_grp_ec_curve25519, ssl_auth_psk, ssl_sig_none);
+}
+
+TEST_P(Tls13PskTest, BadConfigValues) {
+ EXPECT_TRUE(client_->EnsureTlsSetup());
+ std::vector<uint8_t> label{'L', 'A', 'B', 'E', 'L'};
+ EXPECT_EQ(SECFailure,
+ SSL_AddExternalPsk(client_->ssl_fd(), nullptr, label.data(),
+ label.size(), kPskHash_));
+ EXPECT_EQ(SECFailure, SSL_AddExternalPsk(client_->ssl_fd(), scoped_psk_.get(),
+ nullptr, label.size(), kPskHash_));
+
+ EXPECT_EQ(SECFailure, SSL_AddExternalPsk(client_->ssl_fd(), scoped_psk_.get(),
+ label.data(), 0, kPskHash_));
+ EXPECT_EQ(SECSuccess,
+ SSL_AddExternalPsk(client_->ssl_fd(), scoped_psk_.get(),
+ label.data(), label.size(), ssl_hash_sha256));
+
+ EXPECT_EQ(SECFailure,
+ SSL_RemoveExternalPsk(client_->ssl_fd(), nullptr, label.size()));
+
+ EXPECT_EQ(SECFailure,
+ SSL_RemoveExternalPsk(client_->ssl_fd(), label.data(), 0));
+
+ EXPECT_EQ(SECSuccess, SSL_RemoveExternalPsk(client_->ssl_fd(), label.data(),
+ label.size()));
+}
+
+// If the server has an EPSK configured with a ciphersuite not supported
+// by the client, it should use certificate authentication.
+TEST_P(Tls13PskTest, FallbackUnsupportedCiphersuite) {
+ client_->AddPsk(scoped_psk_, kPskDummyLabel_, ssl_hash_sha256,
+ TLS_AES_128_GCM_SHA256);
+ server_->AddPsk(scoped_psk_, kPskDummyLabel_, ssl_hash_sha256,
+ TLS_CHACHA20_POLY1305_SHA256);
+
+ client_->EnableSingleCipher(TLS_AES_128_GCM_SHA256);
+ Connect();
+ SendReceive();
+ CheckKeys(ssl_kea_ecdh, ssl_auth_rsa_sign);
+}
+
+// That fallback should not occur if there is no cipher overlap.
+TEST_P(Tls13PskTest, ExplicitSuiteNoOverlap) {
+ client_->AddPsk(scoped_psk_, kPskDummyLabel_, ssl_hash_sha256,
+ TLS_AES_128_GCM_SHA256);
+ server_->AddPsk(scoped_psk_, kPskDummyLabel_, ssl_hash_sha256,
+ TLS_CHACHA20_POLY1305_SHA256);
+
+ client_->EnableSingleCipher(TLS_AES_128_GCM_SHA256);
+ server_->EnableSingleCipher(TLS_CHACHA20_POLY1305_SHA256);
+ ConnectExpectAlert(server_, kTlsAlertHandshakeFailure);
+ server_->CheckErrorCode(SSL_ERROR_NO_CYPHER_OVERLAP);
+ client_->CheckErrorCode(SSL_ERROR_NO_CYPHER_OVERLAP);
+}
+
+TEST_P(Tls13PskTest, SuppressHandshakeCertReq) {
+ AddPsk(scoped_psk_, kPskDummyLabel_, kPskHash_);
+ server_->SetOption(SSL_REQUEST_CERTIFICATE, PR_TRUE);
+ server_->SetOption(SSL_REQUIRE_CERTIFICATE, PR_TRUE);
+ const std::set<uint8_t> hs_types = {ssl_hs_certificate,
+ ssl_hs_certificate_request};
+ auto cr_cert_capture = MakeTlsFilter<TlsHandshakeRecorder>(server_, hs_types);
+ cr_cert_capture->EnableDecryption();
+
+ Connect();
+ SendReceive();
+ CheckKeys(ssl_kea_ecdh, ssl_grp_ec_curve25519, ssl_auth_psk, ssl_sig_none);
+ EXPECT_EQ(0U, cr_cert_capture->buffer().len());
+}
+
+TEST_P(Tls13PskTest, DisallowClientConfigWithoutServerCert) {
+ AddPsk(scoped_psk_, kPskDummyLabel_, kPskHash_);
+ server_->SetOption(SSL_REQUEST_CERTIFICATE, PR_TRUE);
+ server_->SetOption(SSL_REQUIRE_CERTIFICATE, PR_TRUE);
+ const std::set<uint8_t> hs_types = {ssl_hs_certificate,
+ ssl_hs_certificate_request};
+ auto cr_cert_capture = MakeTlsFilter<TlsHandshakeRecorder>(server_, hs_types);
+ cr_cert_capture->EnableDecryption();
+
+ EXPECT_EQ(SECSuccess, SSLInt_RemoveServerCertificates(server_->ssl_fd()));
+
+ ConnectExpectAlert(server_, kTlsAlertHandshakeFailure);
+ server_->CheckErrorCode(SSL_ERROR_NO_CERTIFICATE);
+ client_->CheckErrorCode(SSL_ERROR_NO_CYPHER_OVERLAP);
+ EXPECT_EQ(0U, cr_cert_capture->buffer().len());
+}
+
+TEST_F(TlsConnectStreamTls13, ClientRejectHandshakeCertReq) {
+ // Stream only, as the filter doesn't support DTLS 1.3 yet.
+ ScopedPK11SlotInfo slot(PK11_GetInternalSlot());
+ ASSERT_TRUE(!!slot);
+ ScopedPK11SymKey scoped_psk(PK11_KeyGen(
+ slot.get(), CKM_GENERIC_SECRET_KEY_GEN, nullptr, 32, nullptr));
+ AddPsk(scoped_psk, std::string("foo"), ssl_hash_sha256);
+ // Inject a CR after EE. This would be legal if not for ssl_auth_psk.
+ auto filter = MakeTlsFilter<TlsEncryptedHandshakeMessageReplacer>(
+ server_, kTlsHandshakeFinished, kTlsHandshakeCertificateRequest);
+ filter->EnableDecryption();
+
+ ExpectAlert(client_, kTlsAlertUnexpectedMessage);
+ ConnectExpectFail();
+ client_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_CERT_REQUEST);
+ server_->CheckErrorCode(SSL_ERROR_HANDSHAKE_UNEXPECTED_ALERT);
+}
+
+TEST_F(TlsConnectStreamTls13, RejectPha) {
+ // Stream only, as the filter doesn't support DTLS 1.3 yet.
+ ScopedPK11SlotInfo slot(PK11_GetInternalSlot());
+ ASSERT_TRUE(!!slot);
+ ScopedPK11SymKey scoped_psk(PK11_KeyGen(
+ slot.get(), CKM_GENERIC_SECRET_KEY_GEN, nullptr, 32, nullptr));
+ AddPsk(scoped_psk, std::string("foo"), ssl_hash_sha256);
+ server_->SetOption(SSL_ENABLE_POST_HANDSHAKE_AUTH, PR_TRUE);
+ auto kuToCr = MakeTlsFilter<TlsEncryptedHandshakeMessageReplacer>(
+ server_, kTlsHandshakeKeyUpdate, kTlsHandshakeCertificateRequest);
+ kuToCr->EnableDecryption();
+ Connect();
+
+ // Make sure the direct path is blocked.
+ EXPECT_EQ(SECFailure, SSL_SendCertificateRequest(server_->ssl_fd()));
+ EXPECT_EQ(SSL_ERROR_FEATURE_DISABLED, PORT_GetError());
+
+ // Inject a PHA CR. Since this is not allowed, send KeyUpdate
+ // and change the message type.
+ EXPECT_EQ(SECSuccess, SSL_KeyUpdate(server_->ssl_fd(), PR_TRUE));
+ ExpectAlert(client_, kTlsAlertUnexpectedMessage);
+ client_->Handshake(); // Eat the CR.
+ server_->Handshake();
+ client_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_CERT_REQUEST);
+ server_->CheckErrorCode(SSL_ERROR_HANDSHAKE_UNEXPECTED_ALERT);
+}
+
+class Tls13PskTestWithCiphers : public Tls13PskTest {};
+
+TEST_P(Tls13PskTestWithCiphers, 0RttCiphers) {
+ RolloverAntiReplay();
+ AddPsk(scoped_psk_, kPskDummyLabel_, tls13_GetHashForCipherSuite(suite_),
+ suite_);
+ StartConnect();
+ client_->Set0RttEnabled(true);
+ server_->Set0RttEnabled(true);
+ ZeroRttSendReceive(true, true);
+ Handshake();
+ ExpectEarlyDataAccepted(true);
+ CheckConnected();
+ SendReceive();
+ CheckKeys(ssl_kea_ecdh, ssl_grp_ec_curve25519, ssl_auth_psk, ssl_sig_none);
+}
+
+TEST_P(Tls13PskTestWithCiphers, 0RttMaxEarlyData) {
+ EnsureTlsSetup();
+ RolloverAntiReplay();
+ const char* big_message = "0123456789abcdef";
+ const size_t short_size = strlen(big_message) - 1;
+ const PRInt32 short_length = static_cast<PRInt32>(short_size);
+
+ // Set up the PSK
+ EXPECT_EQ(SECSuccess,
+ SSL_AddExternalPsk0Rtt(
+ client_->ssl_fd(), scoped_psk_.get(),
+ reinterpret_cast<const uint8_t*>(kPskDummyLabel_.data()),
+ kPskDummyLabel_.length(), tls13_GetHashForCipherSuite(suite_),
+ suite_, short_length));
+ EXPECT_EQ(SECSuccess,
+ SSL_AddExternalPsk0Rtt(
+ server_->ssl_fd(), scoped_psk_.get(),
+ reinterpret_cast<const uint8_t*>(kPskDummyLabel_.data()),
+ kPskDummyLabel_.length(), tls13_GetHashForCipherSuite(suite_),
+ suite_, short_length));
+ client_->ExpectPsk();
+ server_->ExpectPsk();
+ client_->expected_cipher_suite(suite_);
+ server_->expected_cipher_suite(suite_);
+ StartConnect();
+ client_->Set0RttEnabled(true);
+ server_->Set0RttEnabled(true);
+ client_->Handshake();
+ CheckEarlyDataLimit(client_, short_size);
+
+ PRInt32 sent;
+ // Writing more than the limit will succeed in TLS, but fail in DTLS.
+ if (variant_ == ssl_variant_stream) {
+ sent = PR_Write(client_->ssl_fd(), big_message,
+ static_cast<PRInt32>(strlen(big_message)));
+ } else {
+ sent = PR_Write(client_->ssl_fd(), big_message,
+ static_cast<PRInt32>(strlen(big_message)));
+ EXPECT_GE(0, sent);
+ EXPECT_EQ(PR_WOULD_BLOCK_ERROR, PORT_GetError());
+
+ // Try an exact-sized write now.
+ sent = PR_Write(client_->ssl_fd(), big_message, short_length);
+ }
+ EXPECT_EQ(short_length, sent);
+
+ // Even a single octet write should now fail.
+ sent = PR_Write(client_->ssl_fd(), big_message, 1);
+ EXPECT_GE(0, sent);
+ EXPECT_EQ(PR_WOULD_BLOCK_ERROR, PORT_GetError());
+
+ // Process the ClientHello and read 0-RTT.
+ server_->Handshake();
+ CheckEarlyDataLimit(server_, short_size);
+
+ std::vector<uint8_t> buf(short_size + 1);
+ PRInt32 read = PR_Read(server_->ssl_fd(), buf.data(), buf.capacity());
+ EXPECT_EQ(short_length, read);
+ EXPECT_EQ(0, memcmp(big_message, buf.data(), short_size));
+
+ // Second read fails.
+ read = PR_Read(server_->ssl_fd(), buf.data(), buf.capacity());
+ EXPECT_EQ(SECFailure, read);
+ EXPECT_EQ(PR_WOULD_BLOCK_ERROR, PORT_GetError());
+
+ Handshake();
+ ExpectEarlyDataAccepted(true);
+ CheckConnected();
+ SendReceive();
+}
+
+static const uint16_t k0RttCipherDefs[] = {TLS_CHACHA20_POLY1305_SHA256,
+ TLS_AES_128_GCM_SHA256,
+ TLS_AES_256_GCM_SHA384};
+
+static const uint16_t kDefaultSuite[] = {TLS_CHACHA20_POLY1305_SHA256};
+
+INSTANTIATE_TEST_SUITE_P(
+ Tls13PskTest, Tls13PskTest,
+ ::testing::Combine(TlsConnectTestBase::kTlsVariantsAll,
+ ::testing::ValuesIn(kDefaultSuite)));
+
+INSTANTIATE_TEST_SUITE_P(
+ Tls13PskTestWithCiphers, Tls13PskTestWithCiphers,
+ ::testing::Combine(TlsConnectTestBase::kTlsVariantsAll,
+ ::testing::ValuesIn(k0RttCipherDefs)));
+
+} // namespace nss_test
diff --git a/security/nss/gtests/ssl_gtest/tls_subcerts_unittest.cc b/security/nss/gtests/ssl_gtest/tls_subcerts_unittest.cc
new file mode 100644
index 0000000000..5e01dee518
--- /dev/null
+++ b/security/nss/gtests/ssl_gtest/tls_subcerts_unittest.cc
@@ -0,0 +1,723 @@
+/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
+/* vim: set ts=2 et sw=2 tw=80: */
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this file,
+ * You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+#include <ctime>
+
+#include "prtime.h"
+#include "secerr.h"
+#include "ssl.h"
+#include "nss.h"
+#include "blapit.h"
+
+#include "gtest_utils.h"
+#include "tls_agent.h"
+#include "tls_connect.h"
+
+namespace nss_test {
+
+const std::string kEcdsaDelegatorId = TlsAgent::kDelegatorEcdsa256;
+const std::string kRsaeDelegatorId = TlsAgent::kDelegatorRsae2048;
+const std::string kPssDelegatorId = TlsAgent::kDelegatorRsaPss2048;
+const std::string kDCId = TlsAgent::kServerEcdsa256;
+const SSLSignatureScheme kDCScheme = ssl_sig_ecdsa_secp256r1_sha256;
+const PRUint32 kDCValidFor = 60 * 60 * 24 * 7 /* 1 week (seconds) */;
+
+static void CheckPreliminaryPeerDelegCred(
+ const std::shared_ptr<TlsAgent>& client, bool expected,
+ PRUint32 key_bits = 0, SSLSignatureScheme sig_scheme = ssl_sig_none) {
+ EXPECT_NE(0U, (client->pre_info().valuesSet & ssl_preinfo_peer_auth));
+ EXPECT_EQ(expected, client->pre_info().peerDelegCred);
+ if (expected) {
+ EXPECT_EQ(key_bits, client->pre_info().authKeyBits);
+ EXPECT_EQ(sig_scheme, client->pre_info().signatureScheme);
+ }
+}
+
+static void CheckPeerDelegCred(const std::shared_ptr<TlsAgent>& client,
+ bool expected, PRUint32 key_bits = 0) {
+ EXPECT_EQ(expected, client->info().peerDelegCred);
+ EXPECT_EQ(expected, client->pre_info().peerDelegCred);
+ if (expected) {
+ EXPECT_EQ(key_bits, client->info().authKeyBits);
+ EXPECT_EQ(key_bits, client->pre_info().authKeyBits);
+ EXPECT_EQ(client->info().signatureScheme,
+ client->pre_info().signatureScheme);
+ }
+}
+
+// AuthCertificate callbacks to simulate DC validation
+static SECStatus CheckPreliminaryDC(TlsAgent* agent, bool checksig,
+ bool isServer) {
+ agent->UpdatePreliminaryChannelInfo();
+ EXPECT_EQ(PR_TRUE, agent->pre_info().peerDelegCred);
+ EXPECT_EQ(256U, agent->pre_info().authKeyBits);
+ EXPECT_EQ(ssl_sig_ecdsa_secp256r1_sha256, agent->pre_info().signatureScheme);
+ return SECSuccess;
+}
+
+static SECStatus CheckPreliminaryNoDC(TlsAgent* agent, bool checksig,
+ bool isServer) {
+ agent->UpdatePreliminaryChannelInfo();
+ EXPECT_EQ(PR_FALSE, agent->pre_info().peerDelegCred);
+ return SECSuccess;
+}
+
+// AuthCertificate callbacks for modifying DC attributes.
+// This allows testing tls13_CertificateVerify for rejection
+// of DC attributes that have changed since AuthCertificateHook
+// may have handled them.
+static SECStatus ModifyDCAuthKeyBits(TlsAgent* agent, bool checksig,
+ bool isServer) {
+ return SSLInt_TweakChannelInfoForDC(agent->ssl_fd(),
+ PR_TRUE, // Change authKeyBits
+ PR_FALSE); // Change scheme
+}
+
+static SECStatus ModifyDCScheme(TlsAgent* agent, bool checksig, bool isServer) {
+ return SSLInt_TweakChannelInfoForDC(agent->ssl_fd(),
+ PR_FALSE, // Change authKeyBits
+ PR_TRUE); // Change scheme
+}
+
+// Attempt to configure a DC when either the DC or DC private key is missing.
+TEST_P(TlsConnectTls13, DCNotConfigured) {
+ // Load and delegate the credential.
+ ScopedSECKEYPublicKey pub;
+ ScopedSECKEYPrivateKey priv;
+ EXPECT_TRUE(TlsAgent::LoadKeyPairFromCert(kDCId, &pub, &priv));
+
+ StackSECItem dc;
+ TlsAgent::DelegateCredential(kEcdsaDelegatorId, pub, kDCScheme, kDCValidFor,
+ now(), &dc);
+
+ // Attempt to install the certificate and DC with a missing DC private key.
+ EnsureTlsSetup();
+ SSLExtraServerCertData extra_data_missing_dc_priv_key = {
+ ssl_auth_null, nullptr, nullptr, nullptr, &dc, nullptr};
+ EXPECT_FALSE(server_->ConfigServerCert(kEcdsaDelegatorId, true,
+ &extra_data_missing_dc_priv_key));
+
+ // Attempt to install the certificate and with only the DC private key.
+ EnsureTlsSetup();
+ SSLExtraServerCertData extra_data_missing_dc = {
+ ssl_auth_null, nullptr, nullptr, nullptr, nullptr, priv.get()};
+ EXPECT_FALSE(server_->ConfigServerCert(kEcdsaDelegatorId, true,
+ &extra_data_missing_dc));
+}
+
+// Connected with ECDSA-P256.
+TEST_P(TlsConnectTls13, DCConnectEcdsaP256) {
+ Reset(kEcdsaDelegatorId);
+ client_->EnableDelegatedCredentials();
+ server_->AddDelegatedCredential(TlsAgent::kServerEcdsa256,
+ ssl_sig_ecdsa_secp256r1_sha256, kDCValidFor,
+ now());
+
+ auto cfilter = MakeTlsFilter<TlsExtensionCapture>(
+ client_, ssl_delegated_credentials_xtn);
+ Connect();
+
+ EXPECT_TRUE(cfilter->captured());
+ CheckPeerDelegCred(client_, true, 256);
+ EXPECT_EQ(ssl_sig_ecdsa_secp256r1_sha256, client_->info().signatureScheme);
+}
+
+// Connected with ECDSA-P384.
+TEST_P(TlsConnectTls13, DCConnectEcdsaP483) {
+ Reset(kEcdsaDelegatorId);
+ client_->EnableDelegatedCredentials();
+ server_->AddDelegatedCredential(TlsAgent::kServerEcdsa384,
+ ssl_sig_ecdsa_secp384r1_sha384, kDCValidFor,
+ now());
+
+ auto cfilter = MakeTlsFilter<TlsExtensionCapture>(
+ client_, ssl_delegated_credentials_xtn);
+ Connect();
+
+ EXPECT_TRUE(cfilter->captured());
+ CheckPeerDelegCred(client_, true, 384);
+ EXPECT_EQ(ssl_sig_ecdsa_secp384r1_sha384, client_->info().signatureScheme);
+}
+
+// Connected with ECDSA-P521.
+TEST_P(TlsConnectTls13, DCConnectEcdsaP521) {
+ Reset(kEcdsaDelegatorId);
+ client_->EnableDelegatedCredentials();
+ server_->AddDelegatedCredential(TlsAgent::kServerEcdsa521,
+ ssl_sig_ecdsa_secp521r1_sha512, kDCValidFor,
+ now());
+ client_->EnableDelegatedCredentials();
+
+ auto cfilter = MakeTlsFilter<TlsExtensionCapture>(
+ client_, ssl_delegated_credentials_xtn);
+ Connect();
+
+ EXPECT_TRUE(cfilter->captured());
+ CheckPeerDelegCred(client_, true, 521);
+ EXPECT_EQ(ssl_sig_ecdsa_secp521r1_sha512, client_->info().signatureScheme);
+}
+
+// Connected with RSA-PSS, using a PSS SPKI and ECDSA delegation cert.
+TEST_P(TlsConnectTls13, DCConnectRsaPssEcdsa) {
+ Reset(kEcdsaDelegatorId);
+
+ // Need to enable PSS-PSS, which is not on by default.
+ static const SSLSignatureScheme kSchemes[] = {ssl_sig_ecdsa_secp256r1_sha256,
+ ssl_sig_rsa_pss_pss_sha256};
+ client_->SetSignatureSchemes(kSchemes, PR_ARRAY_SIZE(kSchemes));
+ server_->SetSignatureSchemes(kSchemes, PR_ARRAY_SIZE(kSchemes));
+
+ client_->EnableDelegatedCredentials();
+ server_->AddDelegatedCredential(
+ TlsAgent::kServerRsaPss, ssl_sig_rsa_pss_pss_sha256, kDCValidFor, now());
+
+ auto cfilter = MakeTlsFilter<TlsExtensionCapture>(
+ client_, ssl_delegated_credentials_xtn);
+ Connect();
+
+ EXPECT_TRUE(cfilter->captured());
+ CheckPeerDelegCred(client_, true, 1024);
+ EXPECT_EQ(ssl_sig_rsa_pss_pss_sha256, client_->info().signatureScheme);
+}
+
+// Connected with RSA-PSS, using a PSS SPKI and PSS delegation cert.
+TEST_P(TlsConnectTls13, DCConnectRsaPssRsaPss) {
+ Reset(kPssDelegatorId);
+
+ // Need to enable PSS-PSS, which is not on by default.
+ static const SSLSignatureScheme kSchemes[] = {ssl_sig_ecdsa_secp256r1_sha256,
+ ssl_sig_rsa_pss_pss_sha256};
+ client_->SetSignatureSchemes(kSchemes, PR_ARRAY_SIZE(kSchemes));
+ server_->SetSignatureSchemes(kSchemes, PR_ARRAY_SIZE(kSchemes));
+
+ client_->EnableDelegatedCredentials();
+ server_->AddDelegatedCredential(
+ TlsAgent::kServerRsaPss, ssl_sig_rsa_pss_pss_sha256, kDCValidFor, now());
+
+ auto cfilter = MakeTlsFilter<TlsExtensionCapture>(
+ client_, ssl_delegated_credentials_xtn);
+ Connect();
+
+ EXPECT_TRUE(cfilter->captured());
+ CheckPeerDelegCred(client_, true, 1024);
+ EXPECT_EQ(ssl_sig_rsa_pss_pss_sha256, client_->info().signatureScheme);
+}
+
+// Connected with ECDSA-P256 using a PSS delegation cert.
+TEST_P(TlsConnectTls13, DCConnectEcdsaP256RsaPss) {
+ Reset(kPssDelegatorId);
+
+ // Need to enable PSS-PSS, which is not on by default.
+ static const SSLSignatureScheme kSchemes[] = {ssl_sig_ecdsa_secp256r1_sha256,
+ ssl_sig_rsa_pss_pss_sha256};
+ client_->SetSignatureSchemes(kSchemes, PR_ARRAY_SIZE(kSchemes));
+ server_->SetSignatureSchemes(kSchemes, PR_ARRAY_SIZE(kSchemes));
+
+ client_->EnableDelegatedCredentials();
+ server_->AddDelegatedCredential(TlsAgent::kServerEcdsa256,
+ ssl_sig_ecdsa_secp256r1_sha256, kDCValidFor,
+ now());
+
+ auto cfilter = MakeTlsFilter<TlsExtensionCapture>(
+ client_, ssl_delegated_credentials_xtn);
+ Connect();
+
+ EXPECT_TRUE(cfilter->captured());
+ CheckPeerDelegCred(client_, true, 256);
+ EXPECT_EQ(ssl_sig_ecdsa_secp256r1_sha256, client_->info().signatureScheme);
+}
+
+// Simulate the client receiving a DC containing algorithms not advertised.
+// Do this by tweaking the client's supported sigSchemes after the CH.
+TEST_P(TlsConnectTls13, DCReceiveUnadvertisedScheme) {
+ Reset(kEcdsaDelegatorId);
+ static const SSLSignatureScheme kClientSchemes[] = {
+ ssl_sig_ecdsa_secp256r1_sha256, ssl_sig_ecdsa_secp384r1_sha384};
+ static const SSLSignatureScheme kServerSchemes[] = {
+ ssl_sig_ecdsa_secp384r1_sha384, ssl_sig_ecdsa_secp256r1_sha256};
+ static const SSLSignatureScheme kEcdsaP256Only[] = {
+ ssl_sig_ecdsa_secp256r1_sha256};
+ client_->SetSignatureSchemes(kClientSchemes, PR_ARRAY_SIZE(kClientSchemes));
+ server_->SetSignatureSchemes(kServerSchemes, PR_ARRAY_SIZE(kServerSchemes));
+ client_->EnableDelegatedCredentials();
+ server_->AddDelegatedCredential(TlsAgent::kServerEcdsa384,
+ ssl_sig_ecdsa_secp384r1_sha384, kDCValidFor,
+ now());
+ StartConnect();
+ client_->Handshake(); // CH with P256/P384.
+ server_->Handshake(); // Respond with P384 DC.
+ // Tell the client it only advertised P256.
+ SECStatus rv = SSLInt_SetDCAdvertisedSigSchemes(
+ client_->ssl_fd(), kEcdsaP256Only, PR_ARRAY_SIZE(kEcdsaP256Only));
+ EXPECT_EQ(SECSuccess, rv);
+ ExpectAlert(client_, kTlsAlertIllegalParameter);
+ Handshake();
+ client_->CheckErrorCode(SSL_ERROR_UNSUPPORTED_SIGNATURE_ALGORITHM);
+ server_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT);
+}
+
+// Server schemes includes only RSAE schemes. Connection should succeed
+// without delegation.
+TEST_P(TlsConnectTls13, DCConnectServerRsaeOnly) {
+ Reset(kRsaeDelegatorId);
+ static const SSLSignatureScheme kClientSchemes[] = {
+ ssl_sig_rsa_pss_rsae_sha256, ssl_sig_rsa_pss_pss_sha256};
+ static const SSLSignatureScheme kServerSchemes[] = {
+ ssl_sig_rsa_pss_rsae_sha256};
+ client_->SetSignatureSchemes(kClientSchemes, PR_ARRAY_SIZE(kClientSchemes));
+ server_->SetSignatureSchemes(kServerSchemes, PR_ARRAY_SIZE(kServerSchemes));
+ client_->EnableDelegatedCredentials();
+ Connect();
+
+ CheckPeerDelegCred(client_, false);
+}
+
+// Connect with an RSA-PSS DC SPKI, and an RSAE Delegator SPKI.
+TEST_P(TlsConnectTls13, DCConnectRsaeDelegator) {
+ Reset(kRsaeDelegatorId);
+
+ static const SSLSignatureScheme kSchemes[] = {ssl_sig_rsa_pss_rsae_sha256,
+ ssl_sig_rsa_pss_pss_sha256};
+ client_->SetSignatureSchemes(kSchemes, PR_ARRAY_SIZE(kSchemes));
+ server_->SetSignatureSchemes(kSchemes, PR_ARRAY_SIZE(kSchemes));
+
+ client_->EnableDelegatedCredentials();
+ server_->AddDelegatedCredential(
+ TlsAgent::kServerRsaPss, ssl_sig_rsa_pss_pss_sha256, kDCValidFor, now());
+ ConnectExpectAlert(client_, kTlsAlertIllegalParameter);
+ server_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT);
+ client_->CheckErrorCode(SSL_ERROR_UNSUPPORTED_SIGNATURE_ALGORITHM);
+}
+
+// Client schemes includes only RSAE schemes. Connection should succeed
+// without delegation, and no DC extension should be present in the CH.
+TEST_P(TlsConnectTls13, DCConnectClientRsaeOnly) {
+ Reset(kRsaeDelegatorId);
+ static const SSLSignatureScheme kClientSchemes[] = {
+ ssl_sig_rsa_pss_rsae_sha256};
+ static const SSLSignatureScheme kServerSchemes[] = {
+ ssl_sig_rsa_pss_rsae_sha256, ssl_sig_rsa_pss_pss_sha256};
+ client_->SetSignatureSchemes(kClientSchemes, PR_ARRAY_SIZE(kClientSchemes));
+ server_->SetSignatureSchemes(kServerSchemes, PR_ARRAY_SIZE(kServerSchemes));
+ client_->EnableDelegatedCredentials();
+ auto cfilter = MakeTlsFilter<TlsExtensionCapture>(
+ client_, ssl_delegated_credentials_xtn);
+ Connect();
+ EXPECT_FALSE(cfilter->captured());
+ CheckPeerDelegCred(client_, false);
+}
+
+// Test fallback. DC extension will not advertise RSAE schemes.
+// The server will attempt to set one, but decline to after seeing
+// the client-advertised schemes does not include it. Expect non-
+// delegated success.
+TEST_P(TlsConnectTls13, DCConnectRsaeDcSpki) {
+ Reset(kRsaeDelegatorId);
+
+ static const SSLSignatureScheme kSchemes[] = {ssl_sig_rsa_pss_rsae_sha256,
+ ssl_sig_rsa_pss_pss_sha256};
+ client_->SetSignatureSchemes(kSchemes, PR_ARRAY_SIZE(kSchemes));
+ server_->SetSignatureSchemes(kSchemes, PR_ARRAY_SIZE(kSchemes));
+ client_->EnableDelegatedCredentials();
+
+ EnsureTlsSetup();
+ ScopedSECKEYPublicKey pub;
+ ScopedSECKEYPrivateKey priv;
+ EXPECT_TRUE(
+ TlsAgent::LoadKeyPairFromCert(TlsAgent::kDelegatorRsae2048, &pub, &priv));
+
+ StackSECItem dc;
+ server_->DelegateCredential(server_->name(), pub, ssl_sig_rsa_pss_rsae_sha256,
+ kDCValidFor, now(), &dc);
+
+ SSLExtraServerCertData extra_data = {ssl_auth_null, nullptr, nullptr,
+ nullptr, &dc, priv.get()};
+ EXPECT_TRUE(server_->ConfigServerCert(server_->name(), true, &extra_data));
+ auto sfilter = MakeTlsFilter<TlsExtensionCapture>(
+ server_, ssl_delegated_credentials_xtn);
+ Connect();
+ EXPECT_FALSE(sfilter->captured());
+ CheckPeerDelegCred(client_, false);
+}
+
+// Generate a weak key. We can't do this in the fixture because certutil
+// won't sign with such a tiny key. That's OK, because this is fast(ish).
+static void GenerateWeakRsaKey(ScopedSECKEYPrivateKey& priv,
+ ScopedSECKEYPublicKey& pub) {
+ ScopedPK11SlotInfo slot(PK11_GetInternalSlot());
+ ASSERT_TRUE(slot);
+ PK11RSAGenParams rsaparams;
+// The absolute minimum size of RSA key that we can use with SHA-256 is
+// 256bit (hash) + 256bit (salt) + 8 (start byte) + 8 (end byte) = 528.
+#define RSA_WEAK_KEY 528
+#if RSA_MIN_MODULUS_BITS < RSA_WEAK_KEY
+ rsaparams.keySizeInBits = 528;
+#else
+ rsaparams.keySizeInBits = RSA_MIN_MODULUS_BITS + 1;
+#endif
+ rsaparams.pe = 65537;
+
+ SECKEYPublicKey* p_pub = nullptr;
+ priv.reset(PK11_GenerateKeyPair(slot.get(), CKM_RSA_PKCS_KEY_PAIR_GEN,
+ &rsaparams, &p_pub, false, false, nullptr));
+ pub.reset(p_pub);
+ PR_ASSERT(priv);
+ return;
+}
+
+// Fail to connect with a weak RSA key.
+TEST_P(TlsConnectTls13, DCWeakKey) {
+ Reset(kPssDelegatorId);
+ EnsureTlsSetup();
+ static const SSLSignatureScheme kSchemes[] = {ssl_sig_rsa_pss_rsae_sha256,
+ ssl_sig_rsa_pss_pss_sha256};
+ client_->SetSignatureSchemes(kSchemes, PR_ARRAY_SIZE(kSchemes));
+ server_->SetSignatureSchemes(kSchemes, PR_ARRAY_SIZE(kSchemes));
+#if RSA_MIN_MODULUS_BITS > RSA_WEAK_KEY
+ // save the MIN POLICY length.
+ PRInt32 minRsa;
+
+ ASSERT_EQ(SECSuccess, NSS_OptionGet(NSS_RSA_MIN_KEY_SIZE, &minRsa));
+#if RSA_MIN_MODULUS_BITS >= 2048
+ ASSERT_EQ(SECSuccess,
+ NSS_OptionSet(NSS_RSA_MIN_KEY_SIZE, RSA_MIN_MODULUS_BITS + 1024));
+#else
+ ASSERT_EQ(SECSuccess, NSS_OptionSet(NSS_RSA_MIN_KEY_SIZE, 2048));
+#endif
+#endif
+
+ ScopedSECKEYPrivateKey dc_priv;
+ ScopedSECKEYPublicKey dc_pub;
+ GenerateWeakRsaKey(dc_priv, dc_pub);
+ ASSERT_TRUE(dc_priv);
+
+ // Construct a DC.
+ StackSECItem dc;
+ TlsAgent::DelegateCredential(kPssDelegatorId, dc_pub,
+ ssl_sig_rsa_pss_pss_sha256, kDCValidFor, now(),
+ &dc);
+
+ // Configure the DC on the server.
+ SSLExtraServerCertData extra_data = {ssl_auth_null, nullptr, nullptr,
+ nullptr, &dc, dc_priv.get()};
+ EXPECT_TRUE(server_->ConfigServerCert(kPssDelegatorId, true, &extra_data));
+
+ client_->EnableDelegatedCredentials();
+
+ auto cfilter = MakeTlsFilter<TlsExtensionCapture>(
+ client_, ssl_delegated_credentials_xtn);
+ ConnectExpectAlert(client_, kTlsAlertInsufficientSecurity);
+#if RSA_MIN_MODULUS_BITS > RSA_WEAK_KEY
+ ASSERT_EQ(SECSuccess, NSS_OptionSet(NSS_RSA_MIN_KEY_SIZE, minRsa));
+#endif
+}
+
+class ReplaceDCSigScheme : public TlsHandshakeFilter {
+ public:
+ ReplaceDCSigScheme(const std::shared_ptr<TlsAgent>& a)
+ : TlsHandshakeFilter(a, {ssl_hs_certificate_verify}) {}
+
+ protected:
+ PacketFilter::Action FilterHandshake(const HandshakeHeader& header,
+ const DataBuffer& input,
+ DataBuffer* output) override {
+ *output = input;
+ output->Write(0, ssl_sig_ecdsa_secp384r1_sha384, 2);
+ return CHANGE;
+ }
+};
+
+// Aborted because of incorrect DC signature algorithm indication.
+TEST_P(TlsConnectTls13, DCAbortBadExpectedCertVerifyAlg) {
+ Reset(kEcdsaDelegatorId);
+ client_->EnableDelegatedCredentials();
+ server_->AddDelegatedCredential(TlsAgent::kServerEcdsa256,
+ ssl_sig_ecdsa_secp256r1_sha256, kDCValidFor,
+ now());
+ auto filter = MakeTlsFilter<ReplaceDCSigScheme>(server_);
+ filter->EnableDecryption();
+ ConnectExpectAlert(client_, kTlsAlertIllegalParameter);
+ client_->CheckErrorCode(SSL_ERROR_DC_CERT_VERIFY_ALG_MISMATCH);
+ server_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT);
+}
+
+// Aborted because of invalid DC signature.
+TEST_P(TlsConnectTls13, DCAbortBadSignature) {
+ Reset(kEcdsaDelegatorId);
+ EnsureTlsSetup();
+ client_->EnableDelegatedCredentials();
+
+ ScopedSECKEYPublicKey pub;
+ ScopedSECKEYPrivateKey priv;
+ EXPECT_TRUE(TlsAgent::LoadKeyPairFromCert(kDCId, &pub, &priv));
+
+ StackSECItem dc;
+ TlsAgent::DelegateCredential(kEcdsaDelegatorId, pub, kDCScheme, kDCValidFor,
+ now(), &dc);
+ ASSERT_TRUE(dc.data != nullptr);
+
+ // Flip the last bit of the DC so that the signature is invalid.
+ dc.data[dc.len - 1] ^= 0x01;
+
+ SSLExtraServerCertData extra_data = {ssl_auth_null, nullptr, nullptr,
+ nullptr, &dc, priv.get()};
+ EXPECT_TRUE(server_->ConfigServerCert(kEcdsaDelegatorId, true, &extra_data));
+
+ ConnectExpectAlert(client_, kTlsAlertIllegalParameter);
+ client_->CheckErrorCode(SSL_ERROR_DC_BAD_SIGNATURE);
+ server_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT);
+}
+
+// Aborted because of expired DC.
+TEST_P(TlsConnectTls13, DCAbortExpired) {
+ Reset(kEcdsaDelegatorId);
+ server_->AddDelegatedCredential(kDCId, kDCScheme, kDCValidFor, now());
+ client_->EnableDelegatedCredentials();
+ // When the client checks the time, it will be at least one second after the
+ // DC expired.
+ AdvanceTime((static_cast<PRTime>(kDCValidFor) + 1) * PR_USEC_PER_SEC);
+ ConnectExpectAlert(client_, kTlsAlertIllegalParameter);
+ client_->CheckErrorCode(SSL_ERROR_DC_EXPIRED);
+ server_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT);
+}
+
+// Aborted due to remaining TTL > max validity period.
+TEST_P(TlsConnectTls13, DCAbortExcessiveTTL) {
+ Reset(kEcdsaDelegatorId);
+ server_->AddDelegatedCredential(kDCId, kDCScheme,
+ kDCValidFor + 1 /* seconds */, now());
+ client_->EnableDelegatedCredentials();
+ ConnectExpectAlert(client_, kTlsAlertIllegalParameter);
+ client_->CheckErrorCode(SSL_ERROR_DC_INAPPROPRIATE_VALIDITY_PERIOD);
+ server_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT);
+}
+
+// Aborted because of invalid key usage.
+TEST_P(TlsConnectTls13, DCAbortBadKeyUsage) {
+ // The sever does not have the delegationUsage extension.
+ Reset(TlsAgent::kServerEcdsa256);
+ client_->EnableDelegatedCredentials();
+ server_->AddDelegatedCredential(kDCId, kDCScheme, kDCValidFor, now());
+ ConnectExpectAlert(client_, kTlsAlertIllegalParameter);
+}
+
+// Connected without DC because of no client indication.
+TEST_P(TlsConnectTls13, DCConnectNoClientSupport) {
+ Reset(kEcdsaDelegatorId);
+ server_->AddDelegatedCredential(kDCId, kDCScheme, kDCValidFor, now());
+
+ auto cfilter = MakeTlsFilter<TlsExtensionCapture>(
+ client_, ssl_delegated_credentials_xtn);
+ Connect();
+
+ EXPECT_FALSE(cfilter->captured());
+ CheckPeerDelegCred(client_, false);
+}
+
+// Connected without DC because of no server DC.
+TEST_P(TlsConnectTls13, DCConnectNoServerSupport) {
+ Reset(kEcdsaDelegatorId);
+ client_->EnableDelegatedCredentials();
+
+ auto cfilter = MakeTlsFilter<TlsExtensionCapture>(
+ client_, ssl_delegated_credentials_xtn);
+ Connect();
+
+ EXPECT_TRUE(cfilter->captured());
+ CheckPeerDelegCred(client_, false);
+}
+
+// Connected without DC because client doesn't support TLS 1.3.
+TEST_P(TlsConnectTls13, DCConnectClientNoTls13) {
+ Reset(kEcdsaDelegatorId);
+ client_->EnableDelegatedCredentials();
+ server_->AddDelegatedCredential(kDCId, kDCScheme, kDCValidFor, now());
+
+ client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2,
+ SSL_LIBRARY_VERSION_TLS_1_2);
+ server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2,
+ SSL_LIBRARY_VERSION_TLS_1_3);
+
+ auto cfilter = MakeTlsFilter<TlsExtensionCapture>(
+ client_, ssl_delegated_credentials_xtn);
+ Connect();
+
+ // Should fallback to TLS 1.2 and not negotiate a DC.
+ EXPECT_FALSE(cfilter->captured());
+ CheckPeerDelegCred(client_, false);
+}
+
+// Connected without DC because server doesn't support TLS 1.3.
+TEST_P(TlsConnectTls13, DCConnectServerNoTls13) {
+ Reset(kEcdsaDelegatorId);
+ client_->EnableDelegatedCredentials();
+ server_->AddDelegatedCredential(kDCId, kDCScheme, kDCValidFor, now());
+
+ client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2,
+ SSL_LIBRARY_VERSION_TLS_1_3);
+ server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2,
+ SSL_LIBRARY_VERSION_TLS_1_2);
+
+ auto cfilter = MakeTlsFilter<TlsExtensionCapture>(
+ client_, ssl_delegated_credentials_xtn);
+ Connect();
+
+ // Should fallback to TLS 1.2 and not negotiate a DC. The client will still
+ // send the indication because it supports 1.3.
+ EXPECT_TRUE(cfilter->captured());
+ CheckPeerDelegCred(client_, false);
+}
+
+// Connected without DC because client doesn't support the signature scheme.
+TEST_P(TlsConnectTls13, DCConnectExpectedCertVerifyAlgNotSupported) {
+ Reset(kEcdsaDelegatorId);
+ client_->EnableDelegatedCredentials();
+ static const SSLSignatureScheme kClientSchemes[] = {
+ ssl_sig_ecdsa_secp256r1_sha256,
+ };
+ client_->SetSignatureSchemes(kClientSchemes, PR_ARRAY_SIZE(kClientSchemes));
+
+ server_->AddDelegatedCredential(TlsAgent::kServerEcdsa521,
+ ssl_sig_ecdsa_secp521r1_sha512, kDCValidFor,
+ now());
+
+ auto cfilter = MakeTlsFilter<TlsExtensionCapture>(
+ client_, ssl_delegated_credentials_xtn);
+ Connect();
+
+ // Client sends indication, but the server doesn't send a DC.
+ EXPECT_TRUE(cfilter->captured());
+ CheckPeerDelegCred(client_, false);
+}
+
+// Check that preliminary channel info properly reflects the DC.
+TEST_P(TlsConnectTls13, DCCheckPreliminaryInfo) {
+ Reset(kEcdsaDelegatorId);
+ EnsureTlsSetup();
+ client_->EnableDelegatedCredentials();
+ server_->AddDelegatedCredential(TlsAgent::kServerEcdsa256,
+ ssl_sig_ecdsa_secp256r1_sha256, kDCValidFor,
+ now());
+
+ auto filter = MakeTlsFilter<TlsHandshakeDropper>(server_);
+ filter->SetHandshakeTypes(
+ {kTlsHandshakeCertificateVerify, kTlsHandshakeFinished});
+ filter->EnableDecryption();
+ StartConnect();
+ client_->Handshake(); // Send ClientHello
+ server_->Handshake(); // Send ServerHello
+
+ client_->SetAuthCertificateCallback(CheckPreliminaryDC);
+ client_->Handshake(); // Process response
+
+ client_->UpdatePreliminaryChannelInfo();
+ CheckPreliminaryPeerDelegCred(client_, true, 256,
+ ssl_sig_ecdsa_secp256r1_sha256);
+}
+
+// Check that preliminary channel info properly reflects a lack of DC.
+TEST_P(TlsConnectTls13, DCCheckPreliminaryInfoNoDC) {
+ Reset(kEcdsaDelegatorId);
+ EnsureTlsSetup();
+ client_->EnableDelegatedCredentials();
+ auto filter = MakeTlsFilter<TlsHandshakeDropper>(server_);
+ filter->SetHandshakeTypes(
+ {kTlsHandshakeCertificateVerify, kTlsHandshakeFinished});
+ filter->EnableDecryption();
+ StartConnect();
+ client_->Handshake(); // Send ClientHello
+ server_->Handshake(); // Send ServerHello
+
+ client_->SetAuthCertificateCallback(CheckPreliminaryNoDC);
+ client_->Handshake(); // Process response
+
+ client_->UpdatePreliminaryChannelInfo();
+ CheckPreliminaryPeerDelegCred(client_, false);
+}
+
+// Tweak the scheme in between |Cert| and |CertVerify|.
+TEST_P(TlsConnectTls13, DCRejectModifiedDCScheme) {
+ Reset(kEcdsaDelegatorId);
+ client_->EnableDelegatedCredentials();
+ client_->SetAuthCertificateCallback(ModifyDCScheme);
+ server_->AddDelegatedCredential(TlsAgent::kServerEcdsa521,
+ ssl_sig_ecdsa_secp521r1_sha512, kDCValidFor,
+ now());
+ ConnectExpectAlert(client_, kTlsAlertIllegalParameter);
+ server_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT);
+ client_->CheckErrorCode(SSL_ERROR_DC_CERT_VERIFY_ALG_MISMATCH);
+}
+
+// Tweak the authKeyBits in between |Cert| and |CertVerify|.
+TEST_P(TlsConnectTls13, DCRejectModifiedDCAuthKeyBits) {
+ Reset(kEcdsaDelegatorId);
+ client_->EnableDelegatedCredentials();
+ client_->SetAuthCertificateCallback(ModifyDCAuthKeyBits);
+ server_->AddDelegatedCredential(TlsAgent::kServerEcdsa521,
+ ssl_sig_ecdsa_secp521r1_sha512, kDCValidFor,
+ now());
+ ConnectExpectAlert(client_, kTlsAlertIllegalParameter);
+ server_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT);
+ client_->CheckErrorCode(SSL_ERROR_DC_CERT_VERIFY_ALG_MISMATCH);
+}
+
+class DCDelegation : public ::testing::Test {};
+
+TEST_F(DCDelegation, DCDelegations) {
+ PRTime now = PR_Now();
+ ScopedCERTCertificate cert;
+ ScopedSECKEYPrivateKey priv;
+ ASSERT_TRUE(TlsAgent::LoadCertificate(kEcdsaDelegatorId, &cert, &priv));
+
+ ScopedSECKEYPublicKey pub_rsa;
+ ScopedSECKEYPrivateKey priv_rsa;
+ ASSERT_TRUE(
+ TlsAgent::LoadKeyPairFromCert(TlsAgent::kServerRsa, &pub_rsa, &priv_rsa));
+
+ StackSECItem dc;
+ EXPECT_EQ(SECFailure,
+ SSL_DelegateCredential(cert.get(), priv.get(), pub_rsa.get(),
+ ssl_sig_ecdsa_secp256r1_sha256, kDCValidFor,
+ now, &dc));
+ EXPECT_EQ(SSL_ERROR_INCORRECT_SIGNATURE_ALGORITHM, PORT_GetError());
+
+ // Using different PSS hashes should be OK.
+ EXPECT_EQ(SECSuccess, SSL_DelegateCredential(
+ cert.get(), priv.get(), pub_rsa.get(),
+ ssl_sig_rsa_pss_pss_sha256, kDCValidFor, now, &dc));
+ // Make sure to reset |dc| after each success.
+ dc.Reset();
+ EXPECT_EQ(SECSuccess, SSL_DelegateCredential(
+ cert.get(), priv.get(), pub_rsa.get(),
+ ssl_sig_rsa_pss_pss_sha384, kDCValidFor, now, &dc));
+ dc.Reset();
+ EXPECT_EQ(SECSuccess, SSL_DelegateCredential(
+ cert.get(), priv.get(), pub_rsa.get(),
+ ssl_sig_rsa_pss_pss_sha512, kDCValidFor, now, &dc));
+ dc.Reset();
+
+ ScopedSECKEYPublicKey pub_ecdsa;
+ ScopedSECKEYPrivateKey priv_ecdsa;
+ ASSERT_TRUE(TlsAgent::LoadKeyPairFromCert(TlsAgent::kServerEcdsa256,
+ &pub_ecdsa, &priv_ecdsa));
+
+ EXPECT_EQ(SECFailure,
+ SSL_DelegateCredential(cert.get(), priv.get(), pub_ecdsa.get(),
+ ssl_sig_rsa_pss_rsae_sha256, kDCValidFor,
+ now, &dc));
+ EXPECT_EQ(SSL_ERROR_INCORRECT_SIGNATURE_ALGORITHM, PORT_GetError());
+ EXPECT_EQ(SECFailure, SSL_DelegateCredential(
+ cert.get(), priv.get(), pub_ecdsa.get(),
+ ssl_sig_rsa_pss_pss_sha256, kDCValidFor, now, &dc));
+ EXPECT_EQ(SSL_ERROR_INCORRECT_SIGNATURE_ALGORITHM, PORT_GetError());
+ EXPECT_EQ(SECFailure,
+ SSL_DelegateCredential(cert.get(), priv.get(), pub_ecdsa.get(),
+ ssl_sig_ecdsa_secp384r1_sha384, kDCValidFor,
+ now, &dc));
+ EXPECT_EQ(SSL_ERROR_INCORRECT_SIGNATURE_ALGORITHM, PORT_GetError());
+}
+
+} // namespace nss_test