diff options
author | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-04-07 19:33:14 +0000 |
---|---|---|
committer | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-04-07 19:33:14 +0000 |
commit | 36d22d82aa202bb199967e9512281e9a53db42c9 (patch) | |
tree | 105e8c98ddea1c1e4784a60a5a6410fa416be2de /security/nss/gtests/ssl_gtest/ssl_loopback_unittest.cc | |
parent | Initial commit. (diff) | |
download | firefox-esr-36d22d82aa202bb199967e9512281e9a53db42c9.tar.xz firefox-esr-36d22d82aa202bb199967e9512281e9a53db42c9.zip |
Adding upstream version 115.7.0esr.upstream/115.7.0esr
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'security/nss/gtests/ssl_gtest/ssl_loopback_unittest.cc')
-rw-r--r-- | security/nss/gtests/ssl_gtest/ssl_loopback_unittest.cc | 801 |
1 files changed, 801 insertions, 0 deletions
diff --git a/security/nss/gtests/ssl_gtest/ssl_loopback_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_loopback_unittest.cc new file mode 100644 index 0000000000..491f50921f --- /dev/null +++ b/security/nss/gtests/ssl_gtest/ssl_loopback_unittest.cc @@ -0,0 +1,801 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set ts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this file, + * You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#include <functional> +#include <memory> +#include <vector> +#include "secerr.h" +#include "ssl.h" +#include "sslerr.h" +#include "sslproto.h" + +extern "C" { +// This is not something that should make you happy. +#include "libssl_internals.h" +} + +#include "gtest_utils.h" +#include "nss_scoped_ptrs.h" +#include "tls_connect.h" +#include "tls_filter.h" +#include "tls_parser.h" + +namespace nss_test { + +TEST_P(TlsConnectGeneric, SetupOnly) {} + +TEST_P(TlsConnectGeneric, Connect) { + SetExpectedVersion(std::get<1>(GetParam())); + Connect(); + CheckKeys(); +} + +TEST_P(TlsConnectGeneric, ConnectEcdsa) { + SetExpectedVersion(std::get<1>(GetParam())); + Reset(TlsAgent::kServerEcdsa256); + Connect(); + CheckKeys(ssl_kea_ecdh, ssl_auth_ecdsa); +} + +TEST_P(TlsConnectGeneric, CipherSuiteMismatch) { + EnsureTlsSetup(); + if (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) { + client_->EnableSingleCipher(TLS_AES_128_GCM_SHA256); + server_->EnableSingleCipher(TLS_AES_256_GCM_SHA384); + } else { + client_->EnableSingleCipher(TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA); + server_->EnableSingleCipher(TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA); + } + ConnectExpectAlert(server_, kTlsAlertHandshakeFailure); + client_->CheckErrorCode(SSL_ERROR_NO_CYPHER_OVERLAP); + server_->CheckErrorCode(SSL_ERROR_NO_CYPHER_OVERLAP); +} + +class TlsAlertRecorder : public TlsRecordFilter { + public: + TlsAlertRecorder(const std::shared_ptr<TlsAgent>& a) + : TlsRecordFilter(a), level_(255), description_(255) {} + + PacketFilter::Action FilterRecord(const TlsRecordHeader& header, + const DataBuffer& input, + DataBuffer* output) override { + if (level_ != 255) { // Already captured. + return KEEP; + } + if (header.content_type() != ssl_ct_alert) { + return KEEP; + } + + std::cerr << "Alert: " << input << std::endl; + + TlsParser parser(input); + EXPECT_TRUE(parser.Read(&level_)); + EXPECT_TRUE(parser.Read(&description_)); + return KEEP; + } + + uint8_t level() const { return level_; } + uint8_t description() const { return description_; } + + private: + uint8_t level_; + uint8_t description_; +}; + +class HelloTruncator : public TlsHandshakeFilter { + public: + HelloTruncator(const std::shared_ptr<TlsAgent>& a) + : TlsHandshakeFilter( + a, {kTlsHandshakeClientHello, kTlsHandshakeServerHello}) {} + PacketFilter::Action FilterHandshake(const HandshakeHeader& header, + const DataBuffer& input, + DataBuffer* output) override { + output->Assign(input.data(), input.len() - 1); + return CHANGE; + } +}; + +// Verify that when NSS reports that an alert is sent, it is actually sent. +TEST_P(TlsConnectGeneric, CaptureAlertServer) { + MakeTlsFilter<HelloTruncator>(client_); + auto alert_recorder = MakeTlsFilter<TlsAlertRecorder>(server_); + + ConnectExpectAlert(server_, kTlsAlertDecodeError); + EXPECT_EQ(kTlsAlertFatal, alert_recorder->level()); + EXPECT_EQ(kTlsAlertDecodeError, alert_recorder->description()); +} + +TEST_P(TlsConnectGenericPre13, CaptureAlertClient) { + MakeTlsFilter<HelloTruncator>(server_); + auto alert_recorder = MakeTlsFilter<TlsAlertRecorder>(client_); + + ConnectExpectAlert(client_, kTlsAlertDecodeError); + EXPECT_EQ(kTlsAlertFatal, alert_recorder->level()); + EXPECT_EQ(kTlsAlertDecodeError, alert_recorder->description()); +} + +// In TLS 1.3, the server can't read the client alert. +TEST_P(TlsConnectTls13, CaptureAlertClient) { + MakeTlsFilter<HelloTruncator>(server_); + auto alert_recorder = MakeTlsFilter<TlsAlertRecorder>(client_); + + StartConnect(); + + client_->Handshake(); + client_->ExpectSendAlert(kTlsAlertDecodeError); + server_->Handshake(); + client_->Handshake(); + if (variant_ == ssl_variant_stream) { + // DTLS just drops the alert it can't decrypt. + server_->ExpectSendAlert(kTlsAlertUnexpectedMessage); + } + server_->Handshake(); + EXPECT_EQ(kTlsAlertFatal, alert_recorder->level()); + EXPECT_EQ(kTlsAlertDecodeError, alert_recorder->description()); +} + +TEST_P(TlsConnectGenericPre13, ConnectFalseStart) { + client_->EnableFalseStart(); + Connect(); + SendReceive(); +} + +TEST_P(TlsConnectGeneric, ConnectAlpn) { + EnableAlpn(); + Connect(); + CheckAlpn("a"); +} + +TEST_P(TlsConnectGeneric, ConnectAlpnPriorityA) { + // "alpn" "npn" + // alpn is the fallback here. npn has the highest priority and should be + // picked. + const std::vector<uint8_t> alpn = {0x04, 0x61, 0x6c, 0x70, 0x6e, + 0x03, 0x6e, 0x70, 0x6e}; + EnableAlpn(alpn); + Connect(); + CheckAlpn("npn"); +} + +TEST_P(TlsConnectGeneric, ConnectAlpnPriorityB) { + // "alpn" "npn" "http" + // npn has the highest priority and should be picked. + const std::vector<uint8_t> alpn = {0x04, 0x61, 0x6c, 0x70, 0x6e, 0x03, 0x6e, + 0x70, 0x6e, 0x04, 0x68, 0x74, 0x74, 0x70}; + EnableAlpn(alpn); + Connect(); + CheckAlpn("npn"); +} + +TEST_P(TlsConnectGeneric, ConnectAlpnClone) { + EnsureModelSockets(); + client_model_->EnableAlpn(alpn_dummy_val_, sizeof(alpn_dummy_val_)); + server_model_->EnableAlpn(alpn_dummy_val_, sizeof(alpn_dummy_val_)); + Connect(); + CheckAlpn("a"); +} + +TEST_P(TlsConnectGeneric, ConnectAlpnWithCustomCallbackA) { + // "ab" "alpn" + const std::vector<uint8_t> client_alpn = {0x02, 0x61, 0x62, 0x04, + 0x61, 0x6c, 0x70, 0x6e}; + EnableAlpnWithCallback(client_alpn, "alpn"); + Connect(); + CheckAlpn("alpn"); +} + +TEST_P(TlsConnectGeneric, ConnectAlpnWithCustomCallbackB) { + // "ab" "alpn" + const std::vector<uint8_t> client_alpn = {0x02, 0x61, 0x62, 0x04, + 0x61, 0x6c, 0x70, 0x6e}; + EnableAlpnWithCallback(client_alpn, "ab"); + Connect(); + CheckAlpn("ab"); +} + +TEST_P(TlsConnectGeneric, ConnectAlpnWithCustomCallbackC) { + // "cd" "npn" "alpn" + const std::vector<uint8_t> client_alpn = {0x02, 0x63, 0x64, 0x03, 0x6e, 0x70, + 0x6e, 0x04, 0x61, 0x6c, 0x70, 0x6e}; + EnableAlpnWithCallback(client_alpn, "npn"); + Connect(); + CheckAlpn("npn"); +} + +TEST_P(TlsConnectDatagram, ConnectSrtp) { + EnableSrtp(); + Connect(); + CheckSrtp(); + SendReceive(); +} + +TEST_P(TlsConnectGeneric, ConnectSendReceive) { + Connect(); + SendReceive(); +} + +class SaveTlsRecord : public TlsRecordFilter { + public: + SaveTlsRecord(const std::shared_ptr<TlsAgent>& a, size_t index) + : TlsRecordFilter(a), index_(index), count_(0), contents_() {} + + const DataBuffer& contents() const { return contents_; } + + protected: + PacketFilter::Action FilterRecord(const TlsRecordHeader& header, + const DataBuffer& data, + DataBuffer* changed) override { + if (count_++ == index_) { + contents_ = data; + } + return KEEP; + } + + private: + const size_t index_; + size_t count_; + DataBuffer contents_; +}; + +// Check that decrypting filters work and can read any record. +// This test (currently) only works in TLS 1.3 where we can decrypt. +TEST_F(TlsConnectStreamTls13, DecryptRecordClient) { + EnsureTlsSetup(); + // 0 = ClientHello, 1 = Finished, 2 = SendReceive, 3 = SendBuffer + auto saved = MakeTlsFilter<SaveTlsRecord>(client_, 3); + saved->EnableDecryption(); + Connect(); + SendReceive(); + + static const uint8_t data[] = {0xde, 0xad, 0xdc}; + DataBuffer buf(data, sizeof(data)); + client_->SendBuffer(buf); + EXPECT_EQ(buf, saved->contents()); +} + +TEST_F(TlsConnectStreamTls13, DecryptRecordServer) { + EnsureTlsSetup(); + // Disable tickets so that we are sure to not get NewSessionTicket. + EXPECT_EQ(SECSuccess, SSL_OptionSet(server_->ssl_fd(), + SSL_ENABLE_SESSION_TICKETS, PR_FALSE)); + // 0 = ServerHello, 1 = other handshake, 2 = SendReceive, 3 = SendBuffer + auto saved = MakeTlsFilter<SaveTlsRecord>(server_, 3); + saved->EnableDecryption(); + Connect(); + SendReceive(); + + static const uint8_t data[] = {0xde, 0xad, 0xd5}; + DataBuffer buf(data, sizeof(data)); + server_->SendBuffer(buf); + EXPECT_EQ(buf, saved->contents()); +} + +class DropTlsRecord : public TlsRecordFilter { + public: + DropTlsRecord(const std::shared_ptr<TlsAgent>& a, size_t index) + : TlsRecordFilter(a), index_(index), count_(0) {} + + protected: + PacketFilter::Action FilterRecord(const TlsRecordHeader& header, + const DataBuffer& data, + DataBuffer* changed) override { + if (count_++ == index_) { + return DROP; + } + return KEEP; + } + + private: + const size_t index_; + size_t count_; +}; + +// Test that decrypting filters work correctly and are able to drop records. +TEST_F(TlsConnectStreamTls13, DropRecordServer) { + EnsureTlsSetup(); + // Disable session tickets so that the server doesn't send an extra record. + EXPECT_EQ(SECSuccess, SSL_OptionSet(server_->ssl_fd(), + SSL_ENABLE_SESSION_TICKETS, PR_FALSE)); + + // 0 = ServerHello, 1 = other handshake, 2 = first write + auto filter = MakeTlsFilter<DropTlsRecord>(server_, 2); + filter->EnableDecryption(); + Connect(); + server_->SendData(23, 23); // This should be dropped, so it won't be counted. + server_->ResetSentBytes(); + SendReceive(); +} + +TEST_F(TlsConnectStreamTls13, DropRecordClient) { + EnsureTlsSetup(); + // 0 = ClientHello, 1 = Finished, 2 = first write + auto filter = MakeTlsFilter<DropTlsRecord>(client_, 2); + filter->EnableDecryption(); + Connect(); + client_->SendData(26, 26); // This should be dropped, so it won't be counted. + client_->ResetSentBytes(); + SendReceive(); +} + +// Check that a server can use 0.5 RTT if client authentication isn't enabled. +TEST_P(TlsConnectTls13, WriteBeforeClientFinished) { + EnsureTlsSetup(); + StartConnect(); + client_->Handshake(); // ClientHello + server_->Handshake(); // ServerHello + + server_->SendData(10); + client_->ReadBytes(10); // Client should emit the Finished as a side-effect. + server_->Handshake(); // Server consumes the Finished. + CheckConnected(); +} + +// We don't allow 0.5 RTT if client authentication is requested. +TEST_P(TlsConnectTls13, WriteBeforeClientFinishedClientAuth) { + client_->SetupClientAuth(); + server_->RequestClientAuth(false); + StartConnect(); + client_->Handshake(); // ClientHello + server_->Handshake(); // ServerHello + + static const uint8_t data[] = {1, 2, 3}; + EXPECT_GT(0, PR_Write(server_->ssl_fd(), data, sizeof(data))); + EXPECT_EQ(PR_WOULD_BLOCK_ERROR, PORT_GetError()); + + Handshake(); + CheckConnected(); + SendReceive(); +} + +// 0.5 RTT should fail with client authentication required. +TEST_P(TlsConnectTls13, WriteBeforeClientFinishedClientAuthRequired) { + client_->SetupClientAuth(); + server_->RequestClientAuth(true); + StartConnect(); + client_->Handshake(); // ClientHello + server_->Handshake(); // ServerHello + + static const uint8_t data[] = {1, 2, 3}; + EXPECT_GT(0, PR_Write(server_->ssl_fd(), data, sizeof(data))); + EXPECT_EQ(PR_WOULD_BLOCK_ERROR, PORT_GetError()); + + Handshake(); + CheckConnected(); + SendReceive(); +} + +// The next two tests takes advantage of the fact that we +// automatically read the first 1024 bytes, so if +// we provide 1200 bytes, they overrun the read buffer +// provided by the calling test. + +// DTLS should return an error. +TEST_P(TlsConnectDatagram, ShortRead) { + Connect(); + client_->ExpectReadWriteError(); + server_->SendData(50, 50); + client_->ReadBytes(20); + EXPECT_EQ(0U, client_->received_bytes()); + EXPECT_EQ(SSL_ERROR_RX_SHORT_DTLS_READ, PORT_GetError()); + + // Now send and receive another packet. + server_->ResetSentBytes(); // Reset the counter. + SendReceive(); +} + +// TLS should get the write in two chunks. +TEST_P(TlsConnectStream, ShortRead) { + // This test behaves oddly with TLS 1.0 because of 1/n+1 splitting, + // so skip in that case. + if (version_ < SSL_LIBRARY_VERSION_TLS_1_1) GTEST_SKIP(); + + Connect(); + server_->SendData(50, 50); + // Read the first tranche. + client_->ReadBytes(20); + ASSERT_EQ(20U, client_->received_bytes()); + // The second tranche should now immediately be available. + client_->ReadBytes(); + ASSERT_EQ(50U, client_->received_bytes()); +} + +// We enable compression via the API but it's disabled internally, +// so we should never get it. +TEST_P(TlsConnectGeneric, ConnectWithCompressionEnabled) { + EnsureTlsSetup(); + client_->SetOption(SSL_ENABLE_DEFLATE, PR_TRUE); + server_->SetOption(SSL_ENABLE_DEFLATE, PR_TRUE); + Connect(); + EXPECT_FALSE(client_->is_compressed()); + SendReceive(); +} + +class TlsHolddownTest : public TlsConnectDatagram { + protected: + // This causes all timers to run to completion. It advances the clock and + // handshakes on both peers until both peers have no more timers pending, + // which should happen at the end of a handshake. This is necessary to ensure + // that the relatively long holddown timer expires, but that any other timers + // also expire and run correctly. + void RunAllTimersDown() { + while (true) { + PRIntervalTime time; + SECStatus rv = DTLS_GetHandshakeTimeout(client_->ssl_fd(), &time); + if (rv != SECSuccess) { + rv = DTLS_GetHandshakeTimeout(server_->ssl_fd(), &time); + if (rv != SECSuccess) { + break; // Neither peer has an outstanding timer. + } + } + + if (g_ssl_gtest_verbose) { + std::cerr << "Shifting timers" << std::endl; + } + ShiftDtlsTimers(); + Handshake(); + } + } +}; + +TEST_P(TlsHolddownTest, TestDtlsHolddownExpiry) { + Connect(); + std::cerr << "Expiring holddown timer" << std::endl; + RunAllTimersDown(); + SendReceive(); + if (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) { + // One for send, one for receive. + EXPECT_EQ(2, SSLInt_CountCipherSpecs(client_->ssl_fd())); + } +} + +TEST_P(TlsHolddownTest, TestDtlsHolddownExpiryResumption) { + ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET); + Connect(); + SendReceive(); + + Reset(); + ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET); + ExpectResumption(RESUME_TICKET); + Connect(); + RunAllTimersDown(); + SendReceive(); + // One for send, one for receive. + EXPECT_EQ(2, SSLInt_CountCipherSpecs(client_->ssl_fd())); +} + +class TlsPreCCSHeaderInjector : public TlsRecordFilter { + public: + TlsPreCCSHeaderInjector(const std::shared_ptr<TlsAgent>& a) + : TlsRecordFilter(a) {} + virtual PacketFilter::Action FilterRecord( + const TlsRecordHeader& record_header, const DataBuffer& input, + size_t* offset, DataBuffer* output) override { + if (record_header.content_type() != ssl_ct_change_cipher_spec) { + return KEEP; + } + + std::cerr << "Injecting Finished header before CCS\n"; + const uint8_t hhdr[] = {kTlsHandshakeFinished, 0x00, 0x00, 0x0c}; + DataBuffer hhdr_buf(hhdr, sizeof(hhdr)); + TlsRecordHeader nhdr(record_header.variant(), record_header.version(), + ssl_ct_handshake, 0); + *offset = nhdr.Write(output, *offset, hhdr_buf); + *offset = record_header.Write(output, *offset, input); + return CHANGE; + } +}; + +TEST_P(TlsConnectStreamPre13, ClientFinishedHeaderBeforeCCS) { + MakeTlsFilter<TlsPreCCSHeaderInjector>(client_); + ConnectExpectAlert(server_, kTlsAlertUnexpectedMessage); + client_->CheckErrorCode(SSL_ERROR_HANDSHAKE_UNEXPECTED_ALERT); + server_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_CHANGE_CIPHER); +} + +TEST_P(TlsConnectStreamPre13, ServerFinishedHeaderBeforeCCS) { + MakeTlsFilter<TlsPreCCSHeaderInjector>(server_); + StartConnect(); + ExpectAlert(client_, kTlsAlertUnexpectedMessage); + Handshake(); + EXPECT_EQ(TlsAgent::STATE_ERROR, client_->state()); + client_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_CHANGE_CIPHER); + EXPECT_EQ(TlsAgent::STATE_CONNECTED, server_->state()); + server_->Handshake(); // Make sure alert is consumed. +} + +TEST_P(TlsConnectTls13, UnknownAlert) { + Connect(); + server_->ExpectSendAlert(0xff, kTlsAlertWarning); + client_->ExpectReceiveAlert(0xff, kTlsAlertWarning); + SSLInt_SendAlert(server_->ssl_fd(), kTlsAlertWarning, + 0xff); // Unknown value. + client_->ExpectReadWriteError(); + client_->WaitForErrorCode(SSL_ERROR_RX_UNKNOWN_ALERT, 2000); +} + +TEST_P(TlsConnectTls13, AlertWrongLevel) { + Connect(); + server_->ExpectSendAlert(kTlsAlertUnexpectedMessage, kTlsAlertWarning); + client_->ExpectReceiveAlert(kTlsAlertUnexpectedMessage, kTlsAlertWarning); + SSLInt_SendAlert(server_->ssl_fd(), kTlsAlertWarning, + kTlsAlertUnexpectedMessage); + client_->ExpectReadWriteError(); + client_->WaitForErrorCode(SSL_ERROR_HANDSHAKE_UNEXPECTED_ALERT, 2000); +} + +TEST_P(TlsConnectTls13, UnknownRecord) { + static const uint8_t kUknownRecord[] = { + 0xff, SSL_LIBRARY_VERSION_TLS_1_2 >> 8, + SSL_LIBRARY_VERSION_TLS_1_2 & 0xff, 0, 0}; + + Connect(); + if (variant_ == ssl_variant_stream) { + // DTLS just drops the record with an invalid type. + server_->ExpectSendAlert(kTlsAlertUnexpectedMessage); + } + client_->SendDirect(DataBuffer(kUknownRecord, sizeof(kUknownRecord))); + server_->ExpectReadWriteError(); + server_->ReadBytes(); + if (variant_ == ssl_variant_stream) { + EXPECT_EQ(SSL_ERROR_RX_UNEXPECTED_RECORD_TYPE, server_->error_code()); + } else { + EXPECT_EQ(SSL_ERROR_RX_UNKNOWN_RECORD_TYPE, server_->error_code()); + } +} + +TEST_F(TlsConnectStreamTls13, Tls13FailedWriteSecondFlight) { + EnsureTlsSetup(); + StartConnect(); + client_->Handshake(); + server_->Handshake(); // Send first flight. + client_->adapter()->SetWriteError(PR_IO_ERROR); + client_->Handshake(); // This will get an error, but shouldn't crash. + client_->CheckErrorCode(SSL_ERROR_SOCKET_WRITE_FAILURE); +} + +TEST_P(TlsConnectDatagram, BlockedWrite) { + Connect(); + + // Mark the socket as blocked. + client_->adapter()->SetWriteError(PR_WOULD_BLOCK_ERROR); + static const uint8_t data[] = {1, 2, 3}; + int32_t rv = PR_Write(client_->ssl_fd(), data, sizeof(data)); + EXPECT_GT(0, rv); + EXPECT_EQ(PR_WOULD_BLOCK_ERROR, PORT_GetError()); + + // Remove the write error and though the previous write failed, future reads + // and writes should just work as if it never happened. + client_->adapter()->SetWriteError(0); + SendReceive(); +} + +TEST_F(TlsConnectTest, ConnectSSLv3) { + ConfigureVersion(SSL_LIBRARY_VERSION_3_0); + EnableOnlyStaticRsaCiphers(); + Connect(); + CheckKeys(ssl_kea_rsa, ssl_grp_none, ssl_auth_rsa_decrypt, ssl_sig_none); +} + +TEST_F(TlsConnectTest, ConnectSSLv3ClientAuth) { + ConfigureVersion(SSL_LIBRARY_VERSION_3_0); + EnableOnlyStaticRsaCiphers(); + client_->SetupClientAuth(); + server_->RequestClientAuth(true); + Connect(); + CheckKeys(ssl_kea_rsa, ssl_grp_none, ssl_auth_rsa_decrypt, ssl_sig_none); +} + +static size_t ExpectedCbcLen(size_t in, size_t hmac = 20, size_t block = 16) { + // MAC-then-Encrypt expansion formula: + return ((in + hmac + (block - 1)) / block) * block; +} + +TEST_F(TlsConnectTest, OneNRecordSplitting) { + ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_0); + EnsureTlsSetup(); + ConnectWithCipherSuite(TLS_RSA_WITH_AES_128_CBC_SHA); + auto records = MakeTlsFilter<TlsRecordRecorder>(server_); + // This should be split into 1, 16384 and 20. + DataBuffer big_buffer; + big_buffer.Allocate(1 + 16384 + 20); + server_->SendBuffer(big_buffer); + ASSERT_EQ(3U, records->count()); + EXPECT_EQ(ExpectedCbcLen(1), records->record(0).buffer.len()); + EXPECT_EQ(ExpectedCbcLen(16384), records->record(1).buffer.len()); + EXPECT_EQ(ExpectedCbcLen(20), records->record(2).buffer.len()); +} + +// We can't test for randomness easily here, but we can test that we don't +// produce a zero value, or produce the same value twice. There are 5 values +// here: two ClientHello.random, two ServerHello.random, and one zero value. +// Matrix them and fail if any are the same. +TEST_P(TlsConnectGeneric, CheckRandoms) { + ConfigureSessionCache(RESUME_NONE, RESUME_NONE); + + static const size_t random_len = 32; + uint8_t crandom1[random_len], srandom1[random_len]; + uint8_t z[random_len] = {0}; + + auto ch = MakeTlsFilter<TlsHandshakeRecorder>(client_, ssl_hs_client_hello); + auto sh = MakeTlsFilter<TlsHandshakeRecorder>(server_, ssl_hs_server_hello); + Connect(); + ASSERT_TRUE(ch->buffer().len() > (random_len + 2)); + ASSERT_TRUE(sh->buffer().len() > (random_len + 2)); + memcpy(crandom1, ch->buffer().data() + 2, random_len); + memcpy(srandom1, sh->buffer().data() + 2, random_len); + EXPECT_NE(0, memcmp(crandom1, srandom1, random_len)); + EXPECT_NE(0, memcmp(crandom1, z, random_len)); + EXPECT_NE(0, memcmp(srandom1, z, random_len)); + + Reset(); + ch = MakeTlsFilter<TlsHandshakeRecorder>(client_, ssl_hs_client_hello); + sh = MakeTlsFilter<TlsHandshakeRecorder>(server_, ssl_hs_server_hello); + Connect(); + ASSERT_TRUE(ch->buffer().len() > (random_len + 2)); + ASSERT_TRUE(sh->buffer().len() > (random_len + 2)); + const uint8_t* crandom2 = ch->buffer().data() + 2; + const uint8_t* srandom2 = sh->buffer().data() + 2; + + EXPECT_NE(0, memcmp(crandom2, srandom2, random_len)); + EXPECT_NE(0, memcmp(crandom2, z, random_len)); + EXPECT_NE(0, memcmp(srandom2, z, random_len)); + + EXPECT_NE(0, memcmp(crandom1, crandom2, random_len)); + EXPECT_NE(0, memcmp(crandom1, srandom2, random_len)); + EXPECT_NE(0, memcmp(srandom1, crandom2, random_len)); + EXPECT_NE(0, memcmp(srandom1, srandom2, random_len)); +} + +void FailOnCloseNotify(const PRFileDesc* fd, void* arg, const SSLAlert* alert) { + ADD_FAILURE() << "received alert " << alert->description; +} + +void CheckCloseNotify(const PRFileDesc* fd, void* arg, const SSLAlert* alert) { + *reinterpret_cast<bool*>(arg) = true; + EXPECT_EQ(close_notify, alert->description); + EXPECT_EQ(alert_warning, alert->level); +} + +TEST_P(TlsConnectGeneric, ShutdownOneSide) { + Connect(); + + // Setup to check alerts. + EXPECT_EQ(SECSuccess, SSL_AlertSentCallback(server_->ssl_fd(), + FailOnCloseNotify, nullptr)); + EXPECT_EQ(SECSuccess, SSL_AlertReceivedCallback(client_->ssl_fd(), + FailOnCloseNotify, nullptr)); + + bool client_sent = false; + EXPECT_EQ(SECSuccess, SSL_AlertSentCallback(client_->ssl_fd(), + CheckCloseNotify, &client_sent)); + bool server_received = false; + EXPECT_EQ(SECSuccess, + SSL_AlertReceivedCallback(server_->ssl_fd(), CheckCloseNotify, + &server_received)); + EXPECT_EQ(PR_SUCCESS, PR_Shutdown(client_->ssl_fd(), PR_SHUTDOWN_SEND)); + + // Make sure that the server reads out the close_notify. + uint8_t buf[10]; + EXPECT_EQ(0, PR_Read(server_->ssl_fd(), buf, sizeof(buf))); + + // Reading and writing should still work in the one open direction. + EXPECT_TRUE(client_sent); + EXPECT_TRUE(server_received); + server_->SendData(10, 10); + client_->ReadBytes(10); + + // Now close the other side and do the same checks. + bool server_sent = false; + EXPECT_EQ(SECSuccess, SSL_AlertSentCallback(server_->ssl_fd(), + CheckCloseNotify, &server_sent)); + bool client_received = false; + EXPECT_EQ(SECSuccess, + SSL_AlertReceivedCallback(client_->ssl_fd(), CheckCloseNotify, + &client_received)); + EXPECT_EQ(PR_SUCCESS, PR_Shutdown(server_->ssl_fd(), PR_SHUTDOWN_SEND)); + + EXPECT_EQ(0, PR_Read(client_->ssl_fd(), buf, sizeof(buf))); + EXPECT_TRUE(server_sent); + EXPECT_TRUE(client_received); +} + +TEST_P(TlsConnectGeneric, ShutdownOneSideThenCloseTcp) { + Connect(); + + bool client_sent = false; + EXPECT_EQ(SECSuccess, SSL_AlertSentCallback(client_->ssl_fd(), + CheckCloseNotify, &client_sent)); + bool server_received = false; + EXPECT_EQ(SECSuccess, + SSL_AlertReceivedCallback(server_->ssl_fd(), CheckCloseNotify, + &server_received)); + EXPECT_EQ(PR_SUCCESS, PR_Shutdown(client_->ssl_fd(), PR_SHUTDOWN_SEND)); + + // Make sure that the server reads out the close_notify. + uint8_t buf[10]; + EXPECT_EQ(0, PR_Read(server_->ssl_fd(), buf, sizeof(buf))); + + // Now simulate the underlying connection closing. + client_->adapter()->Reset(); + + // Now close the other side and see that things don't explode. + EXPECT_EQ(PR_SUCCESS, PR_Shutdown(server_->ssl_fd(), PR_SHUTDOWN_SEND)); + + EXPECT_GT(0, PR_Read(client_->ssl_fd(), buf, sizeof(buf))); + EXPECT_EQ(PR_NOT_CONNECTED_ERROR, PR_GetError()); +} + +INSTANTIATE_TEST_SUITE_P( + GenericStream, TlsConnectGeneric, + ::testing::Combine(TlsConnectTestBase::kTlsVariantsStream, + TlsConnectTestBase::kTlsVAll)); +INSTANTIATE_TEST_SUITE_P( + GenericDatagram, TlsConnectGeneric, + ::testing::Combine(TlsConnectTestBase::kTlsVariantsDatagram, + TlsConnectTestBase::kTlsV11Plus)); + +INSTANTIATE_TEST_SUITE_P(StreamOnly, TlsConnectStream, + TlsConnectTestBase::kTlsVAll); +INSTANTIATE_TEST_SUITE_P(DatagramOnly, TlsConnectDatagram, + TlsConnectTestBase::kTlsV11Plus); +INSTANTIATE_TEST_SUITE_P(DatagramHolddown, TlsHolddownTest, + TlsConnectTestBase::kTlsV11Plus); + +INSTANTIATE_TEST_SUITE_P( + Pre12Stream, TlsConnectPre12, + ::testing::Combine(TlsConnectTestBase::kTlsVariantsStream, + TlsConnectTestBase::kTlsV10V11)); +INSTANTIATE_TEST_SUITE_P( + Pre12Datagram, TlsConnectPre12, + ::testing::Combine(TlsConnectTestBase::kTlsVariantsDatagram, + TlsConnectTestBase::kTlsV11)); + +INSTANTIATE_TEST_SUITE_P(Version12Only, TlsConnectTls12, + TlsConnectTestBase::kTlsVariantsAll); +#ifndef NSS_DISABLE_TLS_1_3 +INSTANTIATE_TEST_SUITE_P(Version13Only, TlsConnectTls13, + TlsConnectTestBase::kTlsVariantsAll); +#endif + +INSTANTIATE_TEST_SUITE_P( + Pre13Stream, TlsConnectGenericPre13, + ::testing::Combine(TlsConnectTestBase::kTlsVariantsStream, + TlsConnectTestBase::kTlsV10ToV12)); +INSTANTIATE_TEST_SUITE_P( + Pre13Datagram, TlsConnectGenericPre13, + ::testing::Combine(TlsConnectTestBase::kTlsVariantsDatagram, + TlsConnectTestBase::kTlsV11V12)); +INSTANTIATE_TEST_SUITE_P(Pre13StreamOnly, TlsConnectStreamPre13, + TlsConnectTestBase::kTlsV10ToV12); + +INSTANTIATE_TEST_SUITE_P(Version12Plus, TlsConnectTls12Plus, + ::testing::Combine(TlsConnectTestBase::kTlsVariantsAll, + TlsConnectTestBase::kTlsV12Plus)); + +INSTANTIATE_TEST_SUITE_P( + GenericStream, TlsConnectGenericResumption, + ::testing::Combine(TlsConnectTestBase::kTlsVariantsStream, + TlsConnectTestBase::kTlsVAll, + ::testing::Values(true, false))); +INSTANTIATE_TEST_SUITE_P( + GenericDatagram, TlsConnectGenericResumption, + ::testing::Combine(TlsConnectTestBase::kTlsVariantsDatagram, + TlsConnectTestBase::kTlsV11Plus, + ::testing::Values(true, false))); + +INSTANTIATE_TEST_SUITE_P( + GenericStream, TlsConnectGenericResumptionToken, + ::testing::Combine(TlsConnectTestBase::kTlsVariantsStream, + TlsConnectTestBase::kTlsVAll)); +INSTANTIATE_TEST_SUITE_P( + GenericDatagram, TlsConnectGenericResumptionToken, + ::testing::Combine(TlsConnectTestBase::kTlsVariantsDatagram, + TlsConnectTestBase::kTlsV11Plus)); + +INSTANTIATE_TEST_SUITE_P(GenericDatagram, TlsConnectTls13ResumptionToken, + TlsConnectTestBase::kTlsVariantsAll); + +} // namespace nss_test |