summaryrefslogtreecommitdiffstats
path: root/security/nss/gtests/ssl_gtest/ssl_loopback_unittest.cc
diff options
context:
space:
mode:
Diffstat (limited to 'security/nss/gtests/ssl_gtest/ssl_loopback_unittest.cc')
-rw-r--r--security/nss/gtests/ssl_gtest/ssl_loopback_unittest.cc801
1 files changed, 801 insertions, 0 deletions
diff --git a/security/nss/gtests/ssl_gtest/ssl_loopback_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_loopback_unittest.cc
new file mode 100644
index 0000000000..491f50921f
--- /dev/null
+++ b/security/nss/gtests/ssl_gtest/ssl_loopback_unittest.cc
@@ -0,0 +1,801 @@
+/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
+/* vim: set ts=2 et sw=2 tw=80: */
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this file,
+ * You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+#include <functional>
+#include <memory>
+#include <vector>
+#include "secerr.h"
+#include "ssl.h"
+#include "sslerr.h"
+#include "sslproto.h"
+
+extern "C" {
+// This is not something that should make you happy.
+#include "libssl_internals.h"
+}
+
+#include "gtest_utils.h"
+#include "nss_scoped_ptrs.h"
+#include "tls_connect.h"
+#include "tls_filter.h"
+#include "tls_parser.h"
+
+namespace nss_test {
+
+TEST_P(TlsConnectGeneric, SetupOnly) {}
+
+TEST_P(TlsConnectGeneric, Connect) {
+ SetExpectedVersion(std::get<1>(GetParam()));
+ Connect();
+ CheckKeys();
+}
+
+TEST_P(TlsConnectGeneric, ConnectEcdsa) {
+ SetExpectedVersion(std::get<1>(GetParam()));
+ Reset(TlsAgent::kServerEcdsa256);
+ Connect();
+ CheckKeys(ssl_kea_ecdh, ssl_auth_ecdsa);
+}
+
+TEST_P(TlsConnectGeneric, CipherSuiteMismatch) {
+ EnsureTlsSetup();
+ if (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) {
+ client_->EnableSingleCipher(TLS_AES_128_GCM_SHA256);
+ server_->EnableSingleCipher(TLS_AES_256_GCM_SHA384);
+ } else {
+ client_->EnableSingleCipher(TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA);
+ server_->EnableSingleCipher(TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA);
+ }
+ ConnectExpectAlert(server_, kTlsAlertHandshakeFailure);
+ client_->CheckErrorCode(SSL_ERROR_NO_CYPHER_OVERLAP);
+ server_->CheckErrorCode(SSL_ERROR_NO_CYPHER_OVERLAP);
+}
+
+class TlsAlertRecorder : public TlsRecordFilter {
+ public:
+ TlsAlertRecorder(const std::shared_ptr<TlsAgent>& a)
+ : TlsRecordFilter(a), level_(255), description_(255) {}
+
+ PacketFilter::Action FilterRecord(const TlsRecordHeader& header,
+ const DataBuffer& input,
+ DataBuffer* output) override {
+ if (level_ != 255) { // Already captured.
+ return KEEP;
+ }
+ if (header.content_type() != ssl_ct_alert) {
+ return KEEP;
+ }
+
+ std::cerr << "Alert: " << input << std::endl;
+
+ TlsParser parser(input);
+ EXPECT_TRUE(parser.Read(&level_));
+ EXPECT_TRUE(parser.Read(&description_));
+ return KEEP;
+ }
+
+ uint8_t level() const { return level_; }
+ uint8_t description() const { return description_; }
+
+ private:
+ uint8_t level_;
+ uint8_t description_;
+};
+
+class HelloTruncator : public TlsHandshakeFilter {
+ public:
+ HelloTruncator(const std::shared_ptr<TlsAgent>& a)
+ : TlsHandshakeFilter(
+ a, {kTlsHandshakeClientHello, kTlsHandshakeServerHello}) {}
+ PacketFilter::Action FilterHandshake(const HandshakeHeader& header,
+ const DataBuffer& input,
+ DataBuffer* output) override {
+ output->Assign(input.data(), input.len() - 1);
+ return CHANGE;
+ }
+};
+
+// Verify that when NSS reports that an alert is sent, it is actually sent.
+TEST_P(TlsConnectGeneric, CaptureAlertServer) {
+ MakeTlsFilter<HelloTruncator>(client_);
+ auto alert_recorder = MakeTlsFilter<TlsAlertRecorder>(server_);
+
+ ConnectExpectAlert(server_, kTlsAlertDecodeError);
+ EXPECT_EQ(kTlsAlertFatal, alert_recorder->level());
+ EXPECT_EQ(kTlsAlertDecodeError, alert_recorder->description());
+}
+
+TEST_P(TlsConnectGenericPre13, CaptureAlertClient) {
+ MakeTlsFilter<HelloTruncator>(server_);
+ auto alert_recorder = MakeTlsFilter<TlsAlertRecorder>(client_);
+
+ ConnectExpectAlert(client_, kTlsAlertDecodeError);
+ EXPECT_EQ(kTlsAlertFatal, alert_recorder->level());
+ EXPECT_EQ(kTlsAlertDecodeError, alert_recorder->description());
+}
+
+// In TLS 1.3, the server can't read the client alert.
+TEST_P(TlsConnectTls13, CaptureAlertClient) {
+ MakeTlsFilter<HelloTruncator>(server_);
+ auto alert_recorder = MakeTlsFilter<TlsAlertRecorder>(client_);
+
+ StartConnect();
+
+ client_->Handshake();
+ client_->ExpectSendAlert(kTlsAlertDecodeError);
+ server_->Handshake();
+ client_->Handshake();
+ if (variant_ == ssl_variant_stream) {
+ // DTLS just drops the alert it can't decrypt.
+ server_->ExpectSendAlert(kTlsAlertUnexpectedMessage);
+ }
+ server_->Handshake();
+ EXPECT_EQ(kTlsAlertFatal, alert_recorder->level());
+ EXPECT_EQ(kTlsAlertDecodeError, alert_recorder->description());
+}
+
+TEST_P(TlsConnectGenericPre13, ConnectFalseStart) {
+ client_->EnableFalseStart();
+ Connect();
+ SendReceive();
+}
+
+TEST_P(TlsConnectGeneric, ConnectAlpn) {
+ EnableAlpn();
+ Connect();
+ CheckAlpn("a");
+}
+
+TEST_P(TlsConnectGeneric, ConnectAlpnPriorityA) {
+ // "alpn" "npn"
+ // alpn is the fallback here. npn has the highest priority and should be
+ // picked.
+ const std::vector<uint8_t> alpn = {0x04, 0x61, 0x6c, 0x70, 0x6e,
+ 0x03, 0x6e, 0x70, 0x6e};
+ EnableAlpn(alpn);
+ Connect();
+ CheckAlpn("npn");
+}
+
+TEST_P(TlsConnectGeneric, ConnectAlpnPriorityB) {
+ // "alpn" "npn" "http"
+ // npn has the highest priority and should be picked.
+ const std::vector<uint8_t> alpn = {0x04, 0x61, 0x6c, 0x70, 0x6e, 0x03, 0x6e,
+ 0x70, 0x6e, 0x04, 0x68, 0x74, 0x74, 0x70};
+ EnableAlpn(alpn);
+ Connect();
+ CheckAlpn("npn");
+}
+
+TEST_P(TlsConnectGeneric, ConnectAlpnClone) {
+ EnsureModelSockets();
+ client_model_->EnableAlpn(alpn_dummy_val_, sizeof(alpn_dummy_val_));
+ server_model_->EnableAlpn(alpn_dummy_val_, sizeof(alpn_dummy_val_));
+ Connect();
+ CheckAlpn("a");
+}
+
+TEST_P(TlsConnectGeneric, ConnectAlpnWithCustomCallbackA) {
+ // "ab" "alpn"
+ const std::vector<uint8_t> client_alpn = {0x02, 0x61, 0x62, 0x04,
+ 0x61, 0x6c, 0x70, 0x6e};
+ EnableAlpnWithCallback(client_alpn, "alpn");
+ Connect();
+ CheckAlpn("alpn");
+}
+
+TEST_P(TlsConnectGeneric, ConnectAlpnWithCustomCallbackB) {
+ // "ab" "alpn"
+ const std::vector<uint8_t> client_alpn = {0x02, 0x61, 0x62, 0x04,
+ 0x61, 0x6c, 0x70, 0x6e};
+ EnableAlpnWithCallback(client_alpn, "ab");
+ Connect();
+ CheckAlpn("ab");
+}
+
+TEST_P(TlsConnectGeneric, ConnectAlpnWithCustomCallbackC) {
+ // "cd" "npn" "alpn"
+ const std::vector<uint8_t> client_alpn = {0x02, 0x63, 0x64, 0x03, 0x6e, 0x70,
+ 0x6e, 0x04, 0x61, 0x6c, 0x70, 0x6e};
+ EnableAlpnWithCallback(client_alpn, "npn");
+ Connect();
+ CheckAlpn("npn");
+}
+
+TEST_P(TlsConnectDatagram, ConnectSrtp) {
+ EnableSrtp();
+ Connect();
+ CheckSrtp();
+ SendReceive();
+}
+
+TEST_P(TlsConnectGeneric, ConnectSendReceive) {
+ Connect();
+ SendReceive();
+}
+
+class SaveTlsRecord : public TlsRecordFilter {
+ public:
+ SaveTlsRecord(const std::shared_ptr<TlsAgent>& a, size_t index)
+ : TlsRecordFilter(a), index_(index), count_(0), contents_() {}
+
+ const DataBuffer& contents() const { return contents_; }
+
+ protected:
+ PacketFilter::Action FilterRecord(const TlsRecordHeader& header,
+ const DataBuffer& data,
+ DataBuffer* changed) override {
+ if (count_++ == index_) {
+ contents_ = data;
+ }
+ return KEEP;
+ }
+
+ private:
+ const size_t index_;
+ size_t count_;
+ DataBuffer contents_;
+};
+
+// Check that decrypting filters work and can read any record.
+// This test (currently) only works in TLS 1.3 where we can decrypt.
+TEST_F(TlsConnectStreamTls13, DecryptRecordClient) {
+ EnsureTlsSetup();
+ // 0 = ClientHello, 1 = Finished, 2 = SendReceive, 3 = SendBuffer
+ auto saved = MakeTlsFilter<SaveTlsRecord>(client_, 3);
+ saved->EnableDecryption();
+ Connect();
+ SendReceive();
+
+ static const uint8_t data[] = {0xde, 0xad, 0xdc};
+ DataBuffer buf(data, sizeof(data));
+ client_->SendBuffer(buf);
+ EXPECT_EQ(buf, saved->contents());
+}
+
+TEST_F(TlsConnectStreamTls13, DecryptRecordServer) {
+ EnsureTlsSetup();
+ // Disable tickets so that we are sure to not get NewSessionTicket.
+ EXPECT_EQ(SECSuccess, SSL_OptionSet(server_->ssl_fd(),
+ SSL_ENABLE_SESSION_TICKETS, PR_FALSE));
+ // 0 = ServerHello, 1 = other handshake, 2 = SendReceive, 3 = SendBuffer
+ auto saved = MakeTlsFilter<SaveTlsRecord>(server_, 3);
+ saved->EnableDecryption();
+ Connect();
+ SendReceive();
+
+ static const uint8_t data[] = {0xde, 0xad, 0xd5};
+ DataBuffer buf(data, sizeof(data));
+ server_->SendBuffer(buf);
+ EXPECT_EQ(buf, saved->contents());
+}
+
+class DropTlsRecord : public TlsRecordFilter {
+ public:
+ DropTlsRecord(const std::shared_ptr<TlsAgent>& a, size_t index)
+ : TlsRecordFilter(a), index_(index), count_(0) {}
+
+ protected:
+ PacketFilter::Action FilterRecord(const TlsRecordHeader& header,
+ const DataBuffer& data,
+ DataBuffer* changed) override {
+ if (count_++ == index_) {
+ return DROP;
+ }
+ return KEEP;
+ }
+
+ private:
+ const size_t index_;
+ size_t count_;
+};
+
+// Test that decrypting filters work correctly and are able to drop records.
+TEST_F(TlsConnectStreamTls13, DropRecordServer) {
+ EnsureTlsSetup();
+ // Disable session tickets so that the server doesn't send an extra record.
+ EXPECT_EQ(SECSuccess, SSL_OptionSet(server_->ssl_fd(),
+ SSL_ENABLE_SESSION_TICKETS, PR_FALSE));
+
+ // 0 = ServerHello, 1 = other handshake, 2 = first write
+ auto filter = MakeTlsFilter<DropTlsRecord>(server_, 2);
+ filter->EnableDecryption();
+ Connect();
+ server_->SendData(23, 23); // This should be dropped, so it won't be counted.
+ server_->ResetSentBytes();
+ SendReceive();
+}
+
+TEST_F(TlsConnectStreamTls13, DropRecordClient) {
+ EnsureTlsSetup();
+ // 0 = ClientHello, 1 = Finished, 2 = first write
+ auto filter = MakeTlsFilter<DropTlsRecord>(client_, 2);
+ filter->EnableDecryption();
+ Connect();
+ client_->SendData(26, 26); // This should be dropped, so it won't be counted.
+ client_->ResetSentBytes();
+ SendReceive();
+}
+
+// Check that a server can use 0.5 RTT if client authentication isn't enabled.
+TEST_P(TlsConnectTls13, WriteBeforeClientFinished) {
+ EnsureTlsSetup();
+ StartConnect();
+ client_->Handshake(); // ClientHello
+ server_->Handshake(); // ServerHello
+
+ server_->SendData(10);
+ client_->ReadBytes(10); // Client should emit the Finished as a side-effect.
+ server_->Handshake(); // Server consumes the Finished.
+ CheckConnected();
+}
+
+// We don't allow 0.5 RTT if client authentication is requested.
+TEST_P(TlsConnectTls13, WriteBeforeClientFinishedClientAuth) {
+ client_->SetupClientAuth();
+ server_->RequestClientAuth(false);
+ StartConnect();
+ client_->Handshake(); // ClientHello
+ server_->Handshake(); // ServerHello
+
+ static const uint8_t data[] = {1, 2, 3};
+ EXPECT_GT(0, PR_Write(server_->ssl_fd(), data, sizeof(data)));
+ EXPECT_EQ(PR_WOULD_BLOCK_ERROR, PORT_GetError());
+
+ Handshake();
+ CheckConnected();
+ SendReceive();
+}
+
+// 0.5 RTT should fail with client authentication required.
+TEST_P(TlsConnectTls13, WriteBeforeClientFinishedClientAuthRequired) {
+ client_->SetupClientAuth();
+ server_->RequestClientAuth(true);
+ StartConnect();
+ client_->Handshake(); // ClientHello
+ server_->Handshake(); // ServerHello
+
+ static const uint8_t data[] = {1, 2, 3};
+ EXPECT_GT(0, PR_Write(server_->ssl_fd(), data, sizeof(data)));
+ EXPECT_EQ(PR_WOULD_BLOCK_ERROR, PORT_GetError());
+
+ Handshake();
+ CheckConnected();
+ SendReceive();
+}
+
+// The next two tests takes advantage of the fact that we
+// automatically read the first 1024 bytes, so if
+// we provide 1200 bytes, they overrun the read buffer
+// provided by the calling test.
+
+// DTLS should return an error.
+TEST_P(TlsConnectDatagram, ShortRead) {
+ Connect();
+ client_->ExpectReadWriteError();
+ server_->SendData(50, 50);
+ client_->ReadBytes(20);
+ EXPECT_EQ(0U, client_->received_bytes());
+ EXPECT_EQ(SSL_ERROR_RX_SHORT_DTLS_READ, PORT_GetError());
+
+ // Now send and receive another packet.
+ server_->ResetSentBytes(); // Reset the counter.
+ SendReceive();
+}
+
+// TLS should get the write in two chunks.
+TEST_P(TlsConnectStream, ShortRead) {
+ // This test behaves oddly with TLS 1.0 because of 1/n+1 splitting,
+ // so skip in that case.
+ if (version_ < SSL_LIBRARY_VERSION_TLS_1_1) GTEST_SKIP();
+
+ Connect();
+ server_->SendData(50, 50);
+ // Read the first tranche.
+ client_->ReadBytes(20);
+ ASSERT_EQ(20U, client_->received_bytes());
+ // The second tranche should now immediately be available.
+ client_->ReadBytes();
+ ASSERT_EQ(50U, client_->received_bytes());
+}
+
+// We enable compression via the API but it's disabled internally,
+// so we should never get it.
+TEST_P(TlsConnectGeneric, ConnectWithCompressionEnabled) {
+ EnsureTlsSetup();
+ client_->SetOption(SSL_ENABLE_DEFLATE, PR_TRUE);
+ server_->SetOption(SSL_ENABLE_DEFLATE, PR_TRUE);
+ Connect();
+ EXPECT_FALSE(client_->is_compressed());
+ SendReceive();
+}
+
+class TlsHolddownTest : public TlsConnectDatagram {
+ protected:
+ // This causes all timers to run to completion. It advances the clock and
+ // handshakes on both peers until both peers have no more timers pending,
+ // which should happen at the end of a handshake. This is necessary to ensure
+ // that the relatively long holddown timer expires, but that any other timers
+ // also expire and run correctly.
+ void RunAllTimersDown() {
+ while (true) {
+ PRIntervalTime time;
+ SECStatus rv = DTLS_GetHandshakeTimeout(client_->ssl_fd(), &time);
+ if (rv != SECSuccess) {
+ rv = DTLS_GetHandshakeTimeout(server_->ssl_fd(), &time);
+ if (rv != SECSuccess) {
+ break; // Neither peer has an outstanding timer.
+ }
+ }
+
+ if (g_ssl_gtest_verbose) {
+ std::cerr << "Shifting timers" << std::endl;
+ }
+ ShiftDtlsTimers();
+ Handshake();
+ }
+ }
+};
+
+TEST_P(TlsHolddownTest, TestDtlsHolddownExpiry) {
+ Connect();
+ std::cerr << "Expiring holddown timer" << std::endl;
+ RunAllTimersDown();
+ SendReceive();
+ if (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) {
+ // One for send, one for receive.
+ EXPECT_EQ(2, SSLInt_CountCipherSpecs(client_->ssl_fd()));
+ }
+}
+
+TEST_P(TlsHolddownTest, TestDtlsHolddownExpiryResumption) {
+ ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET);
+ Connect();
+ SendReceive();
+
+ Reset();
+ ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET);
+ ExpectResumption(RESUME_TICKET);
+ Connect();
+ RunAllTimersDown();
+ SendReceive();
+ // One for send, one for receive.
+ EXPECT_EQ(2, SSLInt_CountCipherSpecs(client_->ssl_fd()));
+}
+
+class TlsPreCCSHeaderInjector : public TlsRecordFilter {
+ public:
+ TlsPreCCSHeaderInjector(const std::shared_ptr<TlsAgent>& a)
+ : TlsRecordFilter(a) {}
+ virtual PacketFilter::Action FilterRecord(
+ const TlsRecordHeader& record_header, const DataBuffer& input,
+ size_t* offset, DataBuffer* output) override {
+ if (record_header.content_type() != ssl_ct_change_cipher_spec) {
+ return KEEP;
+ }
+
+ std::cerr << "Injecting Finished header before CCS\n";
+ const uint8_t hhdr[] = {kTlsHandshakeFinished, 0x00, 0x00, 0x0c};
+ DataBuffer hhdr_buf(hhdr, sizeof(hhdr));
+ TlsRecordHeader nhdr(record_header.variant(), record_header.version(),
+ ssl_ct_handshake, 0);
+ *offset = nhdr.Write(output, *offset, hhdr_buf);
+ *offset = record_header.Write(output, *offset, input);
+ return CHANGE;
+ }
+};
+
+TEST_P(TlsConnectStreamPre13, ClientFinishedHeaderBeforeCCS) {
+ MakeTlsFilter<TlsPreCCSHeaderInjector>(client_);
+ ConnectExpectAlert(server_, kTlsAlertUnexpectedMessage);
+ client_->CheckErrorCode(SSL_ERROR_HANDSHAKE_UNEXPECTED_ALERT);
+ server_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_CHANGE_CIPHER);
+}
+
+TEST_P(TlsConnectStreamPre13, ServerFinishedHeaderBeforeCCS) {
+ MakeTlsFilter<TlsPreCCSHeaderInjector>(server_);
+ StartConnect();
+ ExpectAlert(client_, kTlsAlertUnexpectedMessage);
+ Handshake();
+ EXPECT_EQ(TlsAgent::STATE_ERROR, client_->state());
+ client_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_CHANGE_CIPHER);
+ EXPECT_EQ(TlsAgent::STATE_CONNECTED, server_->state());
+ server_->Handshake(); // Make sure alert is consumed.
+}
+
+TEST_P(TlsConnectTls13, UnknownAlert) {
+ Connect();
+ server_->ExpectSendAlert(0xff, kTlsAlertWarning);
+ client_->ExpectReceiveAlert(0xff, kTlsAlertWarning);
+ SSLInt_SendAlert(server_->ssl_fd(), kTlsAlertWarning,
+ 0xff); // Unknown value.
+ client_->ExpectReadWriteError();
+ client_->WaitForErrorCode(SSL_ERROR_RX_UNKNOWN_ALERT, 2000);
+}
+
+TEST_P(TlsConnectTls13, AlertWrongLevel) {
+ Connect();
+ server_->ExpectSendAlert(kTlsAlertUnexpectedMessage, kTlsAlertWarning);
+ client_->ExpectReceiveAlert(kTlsAlertUnexpectedMessage, kTlsAlertWarning);
+ SSLInt_SendAlert(server_->ssl_fd(), kTlsAlertWarning,
+ kTlsAlertUnexpectedMessage);
+ client_->ExpectReadWriteError();
+ client_->WaitForErrorCode(SSL_ERROR_HANDSHAKE_UNEXPECTED_ALERT, 2000);
+}
+
+TEST_P(TlsConnectTls13, UnknownRecord) {
+ static const uint8_t kUknownRecord[] = {
+ 0xff, SSL_LIBRARY_VERSION_TLS_1_2 >> 8,
+ SSL_LIBRARY_VERSION_TLS_1_2 & 0xff, 0, 0};
+
+ Connect();
+ if (variant_ == ssl_variant_stream) {
+ // DTLS just drops the record with an invalid type.
+ server_->ExpectSendAlert(kTlsAlertUnexpectedMessage);
+ }
+ client_->SendDirect(DataBuffer(kUknownRecord, sizeof(kUknownRecord)));
+ server_->ExpectReadWriteError();
+ server_->ReadBytes();
+ if (variant_ == ssl_variant_stream) {
+ EXPECT_EQ(SSL_ERROR_RX_UNEXPECTED_RECORD_TYPE, server_->error_code());
+ } else {
+ EXPECT_EQ(SSL_ERROR_RX_UNKNOWN_RECORD_TYPE, server_->error_code());
+ }
+}
+
+TEST_F(TlsConnectStreamTls13, Tls13FailedWriteSecondFlight) {
+ EnsureTlsSetup();
+ StartConnect();
+ client_->Handshake();
+ server_->Handshake(); // Send first flight.
+ client_->adapter()->SetWriteError(PR_IO_ERROR);
+ client_->Handshake(); // This will get an error, but shouldn't crash.
+ client_->CheckErrorCode(SSL_ERROR_SOCKET_WRITE_FAILURE);
+}
+
+TEST_P(TlsConnectDatagram, BlockedWrite) {
+ Connect();
+
+ // Mark the socket as blocked.
+ client_->adapter()->SetWriteError(PR_WOULD_BLOCK_ERROR);
+ static const uint8_t data[] = {1, 2, 3};
+ int32_t rv = PR_Write(client_->ssl_fd(), data, sizeof(data));
+ EXPECT_GT(0, rv);
+ EXPECT_EQ(PR_WOULD_BLOCK_ERROR, PORT_GetError());
+
+ // Remove the write error and though the previous write failed, future reads
+ // and writes should just work as if it never happened.
+ client_->adapter()->SetWriteError(0);
+ SendReceive();
+}
+
+TEST_F(TlsConnectTest, ConnectSSLv3) {
+ ConfigureVersion(SSL_LIBRARY_VERSION_3_0);
+ EnableOnlyStaticRsaCiphers();
+ Connect();
+ CheckKeys(ssl_kea_rsa, ssl_grp_none, ssl_auth_rsa_decrypt, ssl_sig_none);
+}
+
+TEST_F(TlsConnectTest, ConnectSSLv3ClientAuth) {
+ ConfigureVersion(SSL_LIBRARY_VERSION_3_0);
+ EnableOnlyStaticRsaCiphers();
+ client_->SetupClientAuth();
+ server_->RequestClientAuth(true);
+ Connect();
+ CheckKeys(ssl_kea_rsa, ssl_grp_none, ssl_auth_rsa_decrypt, ssl_sig_none);
+}
+
+static size_t ExpectedCbcLen(size_t in, size_t hmac = 20, size_t block = 16) {
+ // MAC-then-Encrypt expansion formula:
+ return ((in + hmac + (block - 1)) / block) * block;
+}
+
+TEST_F(TlsConnectTest, OneNRecordSplitting) {
+ ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_0);
+ EnsureTlsSetup();
+ ConnectWithCipherSuite(TLS_RSA_WITH_AES_128_CBC_SHA);
+ auto records = MakeTlsFilter<TlsRecordRecorder>(server_);
+ // This should be split into 1, 16384 and 20.
+ DataBuffer big_buffer;
+ big_buffer.Allocate(1 + 16384 + 20);
+ server_->SendBuffer(big_buffer);
+ ASSERT_EQ(3U, records->count());
+ EXPECT_EQ(ExpectedCbcLen(1), records->record(0).buffer.len());
+ EXPECT_EQ(ExpectedCbcLen(16384), records->record(1).buffer.len());
+ EXPECT_EQ(ExpectedCbcLen(20), records->record(2).buffer.len());
+}
+
+// We can't test for randomness easily here, but we can test that we don't
+// produce a zero value, or produce the same value twice. There are 5 values
+// here: two ClientHello.random, two ServerHello.random, and one zero value.
+// Matrix them and fail if any are the same.
+TEST_P(TlsConnectGeneric, CheckRandoms) {
+ ConfigureSessionCache(RESUME_NONE, RESUME_NONE);
+
+ static const size_t random_len = 32;
+ uint8_t crandom1[random_len], srandom1[random_len];
+ uint8_t z[random_len] = {0};
+
+ auto ch = MakeTlsFilter<TlsHandshakeRecorder>(client_, ssl_hs_client_hello);
+ auto sh = MakeTlsFilter<TlsHandshakeRecorder>(server_, ssl_hs_server_hello);
+ Connect();
+ ASSERT_TRUE(ch->buffer().len() > (random_len + 2));
+ ASSERT_TRUE(sh->buffer().len() > (random_len + 2));
+ memcpy(crandom1, ch->buffer().data() + 2, random_len);
+ memcpy(srandom1, sh->buffer().data() + 2, random_len);
+ EXPECT_NE(0, memcmp(crandom1, srandom1, random_len));
+ EXPECT_NE(0, memcmp(crandom1, z, random_len));
+ EXPECT_NE(0, memcmp(srandom1, z, random_len));
+
+ Reset();
+ ch = MakeTlsFilter<TlsHandshakeRecorder>(client_, ssl_hs_client_hello);
+ sh = MakeTlsFilter<TlsHandshakeRecorder>(server_, ssl_hs_server_hello);
+ Connect();
+ ASSERT_TRUE(ch->buffer().len() > (random_len + 2));
+ ASSERT_TRUE(sh->buffer().len() > (random_len + 2));
+ const uint8_t* crandom2 = ch->buffer().data() + 2;
+ const uint8_t* srandom2 = sh->buffer().data() + 2;
+
+ EXPECT_NE(0, memcmp(crandom2, srandom2, random_len));
+ EXPECT_NE(0, memcmp(crandom2, z, random_len));
+ EXPECT_NE(0, memcmp(srandom2, z, random_len));
+
+ EXPECT_NE(0, memcmp(crandom1, crandom2, random_len));
+ EXPECT_NE(0, memcmp(crandom1, srandom2, random_len));
+ EXPECT_NE(0, memcmp(srandom1, crandom2, random_len));
+ EXPECT_NE(0, memcmp(srandom1, srandom2, random_len));
+}
+
+void FailOnCloseNotify(const PRFileDesc* fd, void* arg, const SSLAlert* alert) {
+ ADD_FAILURE() << "received alert " << alert->description;
+}
+
+void CheckCloseNotify(const PRFileDesc* fd, void* arg, const SSLAlert* alert) {
+ *reinterpret_cast<bool*>(arg) = true;
+ EXPECT_EQ(close_notify, alert->description);
+ EXPECT_EQ(alert_warning, alert->level);
+}
+
+TEST_P(TlsConnectGeneric, ShutdownOneSide) {
+ Connect();
+
+ // Setup to check alerts.
+ EXPECT_EQ(SECSuccess, SSL_AlertSentCallback(server_->ssl_fd(),
+ FailOnCloseNotify, nullptr));
+ EXPECT_EQ(SECSuccess, SSL_AlertReceivedCallback(client_->ssl_fd(),
+ FailOnCloseNotify, nullptr));
+
+ bool client_sent = false;
+ EXPECT_EQ(SECSuccess, SSL_AlertSentCallback(client_->ssl_fd(),
+ CheckCloseNotify, &client_sent));
+ bool server_received = false;
+ EXPECT_EQ(SECSuccess,
+ SSL_AlertReceivedCallback(server_->ssl_fd(), CheckCloseNotify,
+ &server_received));
+ EXPECT_EQ(PR_SUCCESS, PR_Shutdown(client_->ssl_fd(), PR_SHUTDOWN_SEND));
+
+ // Make sure that the server reads out the close_notify.
+ uint8_t buf[10];
+ EXPECT_EQ(0, PR_Read(server_->ssl_fd(), buf, sizeof(buf)));
+
+ // Reading and writing should still work in the one open direction.
+ EXPECT_TRUE(client_sent);
+ EXPECT_TRUE(server_received);
+ server_->SendData(10, 10);
+ client_->ReadBytes(10);
+
+ // Now close the other side and do the same checks.
+ bool server_sent = false;
+ EXPECT_EQ(SECSuccess, SSL_AlertSentCallback(server_->ssl_fd(),
+ CheckCloseNotify, &server_sent));
+ bool client_received = false;
+ EXPECT_EQ(SECSuccess,
+ SSL_AlertReceivedCallback(client_->ssl_fd(), CheckCloseNotify,
+ &client_received));
+ EXPECT_EQ(PR_SUCCESS, PR_Shutdown(server_->ssl_fd(), PR_SHUTDOWN_SEND));
+
+ EXPECT_EQ(0, PR_Read(client_->ssl_fd(), buf, sizeof(buf)));
+ EXPECT_TRUE(server_sent);
+ EXPECT_TRUE(client_received);
+}
+
+TEST_P(TlsConnectGeneric, ShutdownOneSideThenCloseTcp) {
+ Connect();
+
+ bool client_sent = false;
+ EXPECT_EQ(SECSuccess, SSL_AlertSentCallback(client_->ssl_fd(),
+ CheckCloseNotify, &client_sent));
+ bool server_received = false;
+ EXPECT_EQ(SECSuccess,
+ SSL_AlertReceivedCallback(server_->ssl_fd(), CheckCloseNotify,
+ &server_received));
+ EXPECT_EQ(PR_SUCCESS, PR_Shutdown(client_->ssl_fd(), PR_SHUTDOWN_SEND));
+
+ // Make sure that the server reads out the close_notify.
+ uint8_t buf[10];
+ EXPECT_EQ(0, PR_Read(server_->ssl_fd(), buf, sizeof(buf)));
+
+ // Now simulate the underlying connection closing.
+ client_->adapter()->Reset();
+
+ // Now close the other side and see that things don't explode.
+ EXPECT_EQ(PR_SUCCESS, PR_Shutdown(server_->ssl_fd(), PR_SHUTDOWN_SEND));
+
+ EXPECT_GT(0, PR_Read(client_->ssl_fd(), buf, sizeof(buf)));
+ EXPECT_EQ(PR_NOT_CONNECTED_ERROR, PR_GetError());
+}
+
+INSTANTIATE_TEST_SUITE_P(
+ GenericStream, TlsConnectGeneric,
+ ::testing::Combine(TlsConnectTestBase::kTlsVariantsStream,
+ TlsConnectTestBase::kTlsVAll));
+INSTANTIATE_TEST_SUITE_P(
+ GenericDatagram, TlsConnectGeneric,
+ ::testing::Combine(TlsConnectTestBase::kTlsVariantsDatagram,
+ TlsConnectTestBase::kTlsV11Plus));
+
+INSTANTIATE_TEST_SUITE_P(StreamOnly, TlsConnectStream,
+ TlsConnectTestBase::kTlsVAll);
+INSTANTIATE_TEST_SUITE_P(DatagramOnly, TlsConnectDatagram,
+ TlsConnectTestBase::kTlsV11Plus);
+INSTANTIATE_TEST_SUITE_P(DatagramHolddown, TlsHolddownTest,
+ TlsConnectTestBase::kTlsV11Plus);
+
+INSTANTIATE_TEST_SUITE_P(
+ Pre12Stream, TlsConnectPre12,
+ ::testing::Combine(TlsConnectTestBase::kTlsVariantsStream,
+ TlsConnectTestBase::kTlsV10V11));
+INSTANTIATE_TEST_SUITE_P(
+ Pre12Datagram, TlsConnectPre12,
+ ::testing::Combine(TlsConnectTestBase::kTlsVariantsDatagram,
+ TlsConnectTestBase::kTlsV11));
+
+INSTANTIATE_TEST_SUITE_P(Version12Only, TlsConnectTls12,
+ TlsConnectTestBase::kTlsVariantsAll);
+#ifndef NSS_DISABLE_TLS_1_3
+INSTANTIATE_TEST_SUITE_P(Version13Only, TlsConnectTls13,
+ TlsConnectTestBase::kTlsVariantsAll);
+#endif
+
+INSTANTIATE_TEST_SUITE_P(
+ Pre13Stream, TlsConnectGenericPre13,
+ ::testing::Combine(TlsConnectTestBase::kTlsVariantsStream,
+ TlsConnectTestBase::kTlsV10ToV12));
+INSTANTIATE_TEST_SUITE_P(
+ Pre13Datagram, TlsConnectGenericPre13,
+ ::testing::Combine(TlsConnectTestBase::kTlsVariantsDatagram,
+ TlsConnectTestBase::kTlsV11V12));
+INSTANTIATE_TEST_SUITE_P(Pre13StreamOnly, TlsConnectStreamPre13,
+ TlsConnectTestBase::kTlsV10ToV12);
+
+INSTANTIATE_TEST_SUITE_P(Version12Plus, TlsConnectTls12Plus,
+ ::testing::Combine(TlsConnectTestBase::kTlsVariantsAll,
+ TlsConnectTestBase::kTlsV12Plus));
+
+INSTANTIATE_TEST_SUITE_P(
+ GenericStream, TlsConnectGenericResumption,
+ ::testing::Combine(TlsConnectTestBase::kTlsVariantsStream,
+ TlsConnectTestBase::kTlsVAll,
+ ::testing::Values(true, false)));
+INSTANTIATE_TEST_SUITE_P(
+ GenericDatagram, TlsConnectGenericResumption,
+ ::testing::Combine(TlsConnectTestBase::kTlsVariantsDatagram,
+ TlsConnectTestBase::kTlsV11Plus,
+ ::testing::Values(true, false)));
+
+INSTANTIATE_TEST_SUITE_P(
+ GenericStream, TlsConnectGenericResumptionToken,
+ ::testing::Combine(TlsConnectTestBase::kTlsVariantsStream,
+ TlsConnectTestBase::kTlsVAll));
+INSTANTIATE_TEST_SUITE_P(
+ GenericDatagram, TlsConnectGenericResumptionToken,
+ ::testing::Combine(TlsConnectTestBase::kTlsVariantsDatagram,
+ TlsConnectTestBase::kTlsV11Plus));
+
+INSTANTIATE_TEST_SUITE_P(GenericDatagram, TlsConnectTls13ResumptionToken,
+ TlsConnectTestBase::kTlsVariantsAll);
+
+} // namespace nss_test