/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ /* vim: set ts=2 et sw=2 tw=80: */ /* This Source Code Form is subject to the terms of the Mozilla Public * License, v. 2.0. If a copy of the MPL was not distributed with this file, * You can obtain one at http://mozilla.org/MPL/2.0/. */ #include #include #include #include "secerr.h" #include "ssl.h" #include "sslerr.h" #include "sslproto.h" extern "C" { // This is not something that should make you happy. #include "libssl_internals.h" } #include "gtest_utils.h" #include "nss_scoped_ptrs.h" #include "tls_connect.h" #include "tls_filter.h" #include "tls_parser.h" namespace nss_test { TEST_P(TlsConnectGeneric, SetupOnly) {} TEST_P(TlsConnectGeneric, Connect) { SetExpectedVersion(std::get<1>(GetParam())); Connect(); CheckKeys(); } TEST_P(TlsConnectGeneric, ConnectEcdsa) { SetExpectedVersion(std::get<1>(GetParam())); Reset(TlsAgent::kServerEcdsa256); Connect(); CheckKeys(ssl_kea_ecdh, ssl_auth_ecdsa); } TEST_P(TlsConnectGeneric, CipherSuiteMismatch) { EnsureTlsSetup(); if (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) { client_->EnableSingleCipher(TLS_AES_128_GCM_SHA256); server_->EnableSingleCipher(TLS_AES_256_GCM_SHA384); } else { client_->EnableSingleCipher(TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA); server_->EnableSingleCipher(TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA); } ConnectExpectAlert(server_, kTlsAlertHandshakeFailure); client_->CheckErrorCode(SSL_ERROR_NO_CYPHER_OVERLAP); server_->CheckErrorCode(SSL_ERROR_NO_CYPHER_OVERLAP); } class TlsAlertRecorder : public TlsRecordFilter { public: TlsAlertRecorder(const std::shared_ptr& a) : TlsRecordFilter(a), level_(255), description_(255) {} PacketFilter::Action FilterRecord(const TlsRecordHeader& header, const DataBuffer& input, DataBuffer* output) override { if (level_ != 255) { // Already captured. return KEEP; } if (header.content_type() != ssl_ct_alert) { return KEEP; } std::cerr << "Alert: " << input << std::endl; TlsParser parser(input); EXPECT_TRUE(parser.Read(&level_)); EXPECT_TRUE(parser.Read(&description_)); return KEEP; } uint8_t level() const { return level_; } uint8_t description() const { return description_; } private: uint8_t level_; uint8_t description_; }; class HelloTruncator : public TlsHandshakeFilter { public: HelloTruncator(const std::shared_ptr& a) : TlsHandshakeFilter( a, {kTlsHandshakeClientHello, kTlsHandshakeServerHello}) {} PacketFilter::Action FilterHandshake(const HandshakeHeader& header, const DataBuffer& input, DataBuffer* output) override { output->Assign(input.data(), input.len() - 1); return CHANGE; } }; // Verify that when NSS reports that an alert is sent, it is actually sent. TEST_P(TlsConnectGeneric, CaptureAlertServer) { MakeTlsFilter(client_); auto alert_recorder = MakeTlsFilter(server_); ConnectExpectAlert(server_, kTlsAlertDecodeError); EXPECT_EQ(kTlsAlertFatal, alert_recorder->level()); EXPECT_EQ(kTlsAlertDecodeError, alert_recorder->description()); } TEST_P(TlsConnectGenericPre13, CaptureAlertClient) { MakeTlsFilter(server_); auto alert_recorder = MakeTlsFilter(client_); ConnectExpectAlert(client_, kTlsAlertDecodeError); EXPECT_EQ(kTlsAlertFatal, alert_recorder->level()); EXPECT_EQ(kTlsAlertDecodeError, alert_recorder->description()); } // In TLS 1.3, the server can't read the client alert. TEST_P(TlsConnectTls13, CaptureAlertClient) { MakeTlsFilter(server_); auto alert_recorder = MakeTlsFilter(client_); StartConnect(); client_->Handshake(); client_->ExpectSendAlert(kTlsAlertDecodeError); server_->Handshake(); client_->Handshake(); if (variant_ == ssl_variant_stream) { // DTLS just drops the alert it can't decrypt. server_->ExpectSendAlert(kTlsAlertUnexpectedMessage); } server_->Handshake(); EXPECT_EQ(kTlsAlertFatal, alert_recorder->level()); EXPECT_EQ(kTlsAlertDecodeError, alert_recorder->description()); } TEST_P(TlsConnectGenericPre13, ConnectFalseStart) { client_->EnableFalseStart(); Connect(); SendReceive(); } TEST_P(TlsConnectGeneric, ConnectAlpn) { EnableAlpn(); Connect(); CheckAlpn("a"); } TEST_P(TlsConnectGeneric, ConnectAlpnPriorityA) { // "alpn" "npn" // alpn is the fallback here. npn has the highest priority and should be // picked. const std::vector alpn = {0x04, 0x61, 0x6c, 0x70, 0x6e, 0x03, 0x6e, 0x70, 0x6e}; EnableAlpn(alpn); Connect(); CheckAlpn("npn"); } TEST_P(TlsConnectGeneric, ConnectAlpnPriorityB) { // "alpn" "npn" "http" // npn has the highest priority and should be picked. const std::vector alpn = {0x04, 0x61, 0x6c, 0x70, 0x6e, 0x03, 0x6e, 0x70, 0x6e, 0x04, 0x68, 0x74, 0x74, 0x70}; EnableAlpn(alpn); Connect(); CheckAlpn("npn"); } TEST_P(TlsConnectGeneric, ConnectAlpnClone) { EnsureModelSockets(); client_model_->EnableAlpn(alpn_dummy_val_, sizeof(alpn_dummy_val_)); server_model_->EnableAlpn(alpn_dummy_val_, sizeof(alpn_dummy_val_)); Connect(); CheckAlpn("a"); } TEST_P(TlsConnectGeneric, ConnectAlpnWithCustomCallbackA) { // "ab" "alpn" const std::vector client_alpn = {0x02, 0x61, 0x62, 0x04, 0x61, 0x6c, 0x70, 0x6e}; EnableAlpnWithCallback(client_alpn, "alpn"); Connect(); CheckAlpn("alpn"); } TEST_P(TlsConnectGeneric, ConnectAlpnWithCustomCallbackB) { // "ab" "alpn" const std::vector client_alpn = {0x02, 0x61, 0x62, 0x04, 0x61, 0x6c, 0x70, 0x6e}; EnableAlpnWithCallback(client_alpn, "ab"); Connect(); CheckAlpn("ab"); } TEST_P(TlsConnectGeneric, ConnectAlpnWithCustomCallbackC) { // "cd" "npn" "alpn" const std::vector client_alpn = {0x02, 0x63, 0x64, 0x03, 0x6e, 0x70, 0x6e, 0x04, 0x61, 0x6c, 0x70, 0x6e}; EnableAlpnWithCallback(client_alpn, "npn"); Connect(); CheckAlpn("npn"); } TEST_P(TlsConnectDatagram, ConnectSrtp) { EnableSrtp(); Connect(); CheckSrtp(); SendReceive(); } TEST_P(TlsConnectGeneric, ConnectSendReceive) { Connect(); SendReceive(); } class SaveTlsRecord : public TlsRecordFilter { public: SaveTlsRecord(const std::shared_ptr& a, size_t index) : TlsRecordFilter(a), index_(index), count_(0), contents_() {} const DataBuffer& contents() const { return contents_; } protected: PacketFilter::Action FilterRecord(const TlsRecordHeader& header, const DataBuffer& data, DataBuffer* changed) override { if (count_++ == index_) { contents_ = data; } return KEEP; } private: const size_t index_; size_t count_; DataBuffer contents_; }; // Check that decrypting filters work and can read any record. // This test (currently) only works in TLS 1.3 where we can decrypt. TEST_F(TlsConnectStreamTls13, DecryptRecordClient) { EnsureTlsSetup(); // 0 = ClientHello, 1 = Finished, 2 = SendReceive, 3 = SendBuffer auto saved = MakeTlsFilter(client_, 3); saved->EnableDecryption(); Connect(); SendReceive(); static const uint8_t data[] = {0xde, 0xad, 0xdc}; DataBuffer buf(data, sizeof(data)); client_->SendBuffer(buf); EXPECT_EQ(buf, saved->contents()); } TEST_F(TlsConnectStreamTls13, DecryptRecordServer) { EnsureTlsSetup(); // Disable tickets so that we are sure to not get NewSessionTicket. EXPECT_EQ(SECSuccess, SSL_OptionSet(server_->ssl_fd(), SSL_ENABLE_SESSION_TICKETS, PR_FALSE)); // 0 = ServerHello, 1 = other handshake, 2 = SendReceive, 3 = SendBuffer auto saved = MakeTlsFilter(server_, 3); saved->EnableDecryption(); Connect(); SendReceive(); static const uint8_t data[] = {0xde, 0xad, 0xd5}; DataBuffer buf(data, sizeof(data)); server_->SendBuffer(buf); EXPECT_EQ(buf, saved->contents()); } class DropTlsRecord : public TlsRecordFilter { public: DropTlsRecord(const std::shared_ptr& a, size_t index) : TlsRecordFilter(a), index_(index), count_(0) {} protected: PacketFilter::Action FilterRecord(const TlsRecordHeader& header, const DataBuffer& data, DataBuffer* changed) override { if (count_++ == index_) { return DROP; } return KEEP; } private: const size_t index_; size_t count_; }; // Test that decrypting filters work correctly and are able to drop records. TEST_F(TlsConnectStreamTls13, DropRecordServer) { EnsureTlsSetup(); // Disable session tickets so that the server doesn't send an extra record. EXPECT_EQ(SECSuccess, SSL_OptionSet(server_->ssl_fd(), SSL_ENABLE_SESSION_TICKETS, PR_FALSE)); // 0 = ServerHello, 1 = other handshake, 2 = first write auto filter = MakeTlsFilter(server_, 2); filter->EnableDecryption(); Connect(); server_->SendData(23, 23); // This should be dropped, so it won't be counted. server_->ResetSentBytes(); SendReceive(); } TEST_F(TlsConnectStreamTls13, DropRecordClient) { EnsureTlsSetup(); // 0 = ClientHello, 1 = Finished, 2 = first write auto filter = MakeTlsFilter(client_, 2); filter->EnableDecryption(); Connect(); client_->SendData(26, 26); // This should be dropped, so it won't be counted. client_->ResetSentBytes(); SendReceive(); } // Check that a server can use 0.5 RTT if client authentication isn't enabled. TEST_P(TlsConnectTls13, WriteBeforeClientFinished) { EnsureTlsSetup(); StartConnect(); client_->Handshake(); // ClientHello server_->Handshake(); // ServerHello server_->SendData(10); client_->ReadBytes(10); // Client should emit the Finished as a side-effect. server_->Handshake(); // Server consumes the Finished. CheckConnected(); } // We don't allow 0.5 RTT if client authentication is requested. TEST_P(TlsConnectTls13, WriteBeforeClientFinishedClientAuth) { client_->SetupClientAuth(); server_->RequestClientAuth(false); StartConnect(); client_->Handshake(); // ClientHello server_->Handshake(); // ServerHello static const uint8_t data[] = {1, 2, 3}; EXPECT_GT(0, PR_Write(server_->ssl_fd(), data, sizeof(data))); EXPECT_EQ(PR_WOULD_BLOCK_ERROR, PORT_GetError()); Handshake(); CheckConnected(); SendReceive(); } // 0.5 RTT should fail with client authentication required. TEST_P(TlsConnectTls13, WriteBeforeClientFinishedClientAuthRequired) { client_->SetupClientAuth(); server_->RequestClientAuth(true); StartConnect(); client_->Handshake(); // ClientHello server_->Handshake(); // ServerHello static const uint8_t data[] = {1, 2, 3}; EXPECT_GT(0, PR_Write(server_->ssl_fd(), data, sizeof(data))); EXPECT_EQ(PR_WOULD_BLOCK_ERROR, PORT_GetError()); Handshake(); CheckConnected(); SendReceive(); } // The next two tests takes advantage of the fact that we // automatically read the first 1024 bytes, so if // we provide 1200 bytes, they overrun the read buffer // provided by the calling test. // DTLS should return an error. TEST_P(TlsConnectDatagram, ShortRead) { Connect(); client_->ExpectReadWriteError(); server_->SendData(50, 50); client_->ReadBytes(20); EXPECT_EQ(0U, client_->received_bytes()); EXPECT_EQ(SSL_ERROR_RX_SHORT_DTLS_READ, PORT_GetError()); // Now send and receive another packet. server_->ResetSentBytes(); // Reset the counter. SendReceive(); } // TLS should get the write in two chunks. TEST_P(TlsConnectStream, ShortRead) { // This test behaves oddly with TLS 1.0 because of 1/n+1 splitting, // so skip in that case. if (version_ < SSL_LIBRARY_VERSION_TLS_1_1) GTEST_SKIP(); Connect(); server_->SendData(50, 50); // Read the first tranche. client_->ReadBytes(20); ASSERT_EQ(20U, client_->received_bytes()); // The second tranche should now immediately be available. client_->ReadBytes(); ASSERT_EQ(50U, client_->received_bytes()); } // We enable compression via the API but it's disabled internally, // so we should never get it. TEST_P(TlsConnectGeneric, ConnectWithCompressionEnabled) { EnsureTlsSetup(); client_->SetOption(SSL_ENABLE_DEFLATE, PR_TRUE); server_->SetOption(SSL_ENABLE_DEFLATE, PR_TRUE); Connect(); EXPECT_FALSE(client_->is_compressed()); SendReceive(); } class TlsHolddownTest : public TlsConnectDatagram { protected: // This causes all timers to run to completion. It advances the clock and // handshakes on both peers until both peers have no more timers pending, // which should happen at the end of a handshake. This is necessary to ensure // that the relatively long holddown timer expires, but that any other timers // also expire and run correctly. void RunAllTimersDown() { while (true) { PRIntervalTime time; SECStatus rv = DTLS_GetHandshakeTimeout(client_->ssl_fd(), &time); if (rv != SECSuccess) { rv = DTLS_GetHandshakeTimeout(server_->ssl_fd(), &time); if (rv != SECSuccess) { break; // Neither peer has an outstanding timer. } } if (g_ssl_gtest_verbose) { std::cerr << "Shifting timers" << std::endl; } ShiftDtlsTimers(); Handshake(); } } }; TEST_P(TlsHolddownTest, TestDtlsHolddownExpiry) { Connect(); std::cerr << "Expiring holddown timer" << std::endl; RunAllTimersDown(); SendReceive(); if (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) { // One for send, one for receive. EXPECT_EQ(2, SSLInt_CountCipherSpecs(client_->ssl_fd())); } } TEST_P(TlsHolddownTest, TestDtlsHolddownExpiryResumption) { ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET); Connect(); SendReceive(); Reset(); ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET); ExpectResumption(RESUME_TICKET); Connect(); RunAllTimersDown(); SendReceive(); // One for send, one for receive. EXPECT_EQ(2, SSLInt_CountCipherSpecs(client_->ssl_fd())); } class TlsPreCCSHeaderInjector : public TlsRecordFilter { public: TlsPreCCSHeaderInjector(const std::shared_ptr& a) : TlsRecordFilter(a) {} virtual PacketFilter::Action FilterRecord( const TlsRecordHeader& record_header, const DataBuffer& input, size_t* offset, DataBuffer* output) override { if (record_header.content_type() != ssl_ct_change_cipher_spec) { return KEEP; } std::cerr << "Injecting Finished header before CCS\n"; const uint8_t hhdr[] = {kTlsHandshakeFinished, 0x00, 0x00, 0x0c}; DataBuffer hhdr_buf(hhdr, sizeof(hhdr)); TlsRecordHeader nhdr(record_header.variant(), record_header.version(), ssl_ct_handshake, 0); *offset = nhdr.Write(output, *offset, hhdr_buf); *offset = record_header.Write(output, *offset, input); return CHANGE; } }; TEST_P(TlsConnectStreamPre13, ClientFinishedHeaderBeforeCCS) { MakeTlsFilter(client_); ConnectExpectAlert(server_, kTlsAlertUnexpectedMessage); client_->CheckErrorCode(SSL_ERROR_HANDSHAKE_UNEXPECTED_ALERT); server_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_CHANGE_CIPHER); } TEST_P(TlsConnectStreamPre13, ServerFinishedHeaderBeforeCCS) { MakeTlsFilter(server_); StartConnect(); ExpectAlert(client_, kTlsAlertUnexpectedMessage); Handshake(); EXPECT_EQ(TlsAgent::STATE_ERROR, client_->state()); client_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_CHANGE_CIPHER); EXPECT_EQ(TlsAgent::STATE_CONNECTED, server_->state()); server_->Handshake(); // Make sure alert is consumed. } TEST_P(TlsConnectTls13, UnknownAlert) { Connect(); server_->ExpectSendAlert(0xff, kTlsAlertWarning); client_->ExpectReceiveAlert(0xff, kTlsAlertWarning); SSLInt_SendAlert(server_->ssl_fd(), kTlsAlertWarning, 0xff); // Unknown value. client_->ExpectReadWriteError(); client_->WaitForErrorCode(SSL_ERROR_RX_UNKNOWN_ALERT, 2000); } TEST_P(TlsConnectTls13, AlertWrongLevel) { Connect(); server_->ExpectSendAlert(kTlsAlertUnexpectedMessage, kTlsAlertWarning); client_->ExpectReceiveAlert(kTlsAlertUnexpectedMessage, kTlsAlertWarning); SSLInt_SendAlert(server_->ssl_fd(), kTlsAlertWarning, kTlsAlertUnexpectedMessage); client_->ExpectReadWriteError(); client_->WaitForErrorCode(SSL_ERROR_HANDSHAKE_UNEXPECTED_ALERT, 2000); } TEST_P(TlsConnectTls13, UnknownRecord) { static const uint8_t kUknownRecord[] = { 0xff, SSL_LIBRARY_VERSION_TLS_1_2 >> 8, SSL_LIBRARY_VERSION_TLS_1_2 & 0xff, 0, 0}; Connect(); if (variant_ == ssl_variant_stream) { // DTLS just drops the record with an invalid type. server_->ExpectSendAlert(kTlsAlertUnexpectedMessage); } client_->SendDirect(DataBuffer(kUknownRecord, sizeof(kUknownRecord))); server_->ExpectReadWriteError(); server_->ReadBytes(); if (variant_ == ssl_variant_stream) { EXPECT_EQ(SSL_ERROR_RX_UNEXPECTED_RECORD_TYPE, server_->error_code()); } else { EXPECT_EQ(SSL_ERROR_RX_UNKNOWN_RECORD_TYPE, server_->error_code()); } } TEST_F(TlsConnectStreamTls13, Tls13FailedWriteSecondFlight) { EnsureTlsSetup(); StartConnect(); client_->Handshake(); server_->Handshake(); // Send first flight. client_->adapter()->SetWriteError(PR_IO_ERROR); client_->Handshake(); // This will get an error, but shouldn't crash. client_->CheckErrorCode(SSL_ERROR_SOCKET_WRITE_FAILURE); } TEST_P(TlsConnectDatagram, BlockedWrite) { Connect(); // Mark the socket as blocked. client_->adapter()->SetWriteError(PR_WOULD_BLOCK_ERROR); static const uint8_t data[] = {1, 2, 3}; int32_t rv = PR_Write(client_->ssl_fd(), data, sizeof(data)); EXPECT_GT(0, rv); EXPECT_EQ(PR_WOULD_BLOCK_ERROR, PORT_GetError()); // Remove the write error and though the previous write failed, future reads // and writes should just work as if it never happened. client_->adapter()->SetWriteError(0); SendReceive(); } TEST_F(TlsConnectTest, ConnectSSLv3) { ConfigureVersion(SSL_LIBRARY_VERSION_3_0); EnableOnlyStaticRsaCiphers(); Connect(); CheckKeys(ssl_kea_rsa, ssl_grp_none, ssl_auth_rsa_decrypt, ssl_sig_none); } TEST_F(TlsConnectTest, ConnectSSLv3ClientAuth) { ConfigureVersion(SSL_LIBRARY_VERSION_3_0); EnableOnlyStaticRsaCiphers(); client_->SetupClientAuth(); server_->RequestClientAuth(true); Connect(); CheckKeys(ssl_kea_rsa, ssl_grp_none, ssl_auth_rsa_decrypt, ssl_sig_none); } static size_t ExpectedCbcLen(size_t in, size_t hmac = 20, size_t block = 16) { // MAC-then-Encrypt expansion formula: return ((in + hmac + (block - 1)) / block) * block; } TEST_F(TlsConnectTest, OneNRecordSplitting) { ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_0); EnsureTlsSetup(); ConnectWithCipherSuite(TLS_RSA_WITH_AES_128_CBC_SHA); auto records = MakeTlsFilter(server_); // This should be split into 1, 16384 and 20. DataBuffer big_buffer; big_buffer.Allocate(1 + 16384 + 20); server_->SendBuffer(big_buffer); ASSERT_EQ(3U, records->count()); EXPECT_EQ(ExpectedCbcLen(1), records->record(0).buffer.len()); EXPECT_EQ(ExpectedCbcLen(16384), records->record(1).buffer.len()); EXPECT_EQ(ExpectedCbcLen(20), records->record(2).buffer.len()); } // We can't test for randomness easily here, but we can test that we don't // produce a zero value, or produce the same value twice. There are 5 values // here: two ClientHello.random, two ServerHello.random, and one zero value. // Matrix them and fail if any are the same. TEST_P(TlsConnectGeneric, CheckRandoms) { ConfigureSessionCache(RESUME_NONE, RESUME_NONE); static const size_t random_len = 32; uint8_t crandom1[random_len], srandom1[random_len]; uint8_t z[random_len] = {0}; auto ch = MakeTlsFilter(client_, ssl_hs_client_hello); auto sh = MakeTlsFilter(server_, ssl_hs_server_hello); Connect(); ASSERT_TRUE(ch->buffer().len() > (random_len + 2)); ASSERT_TRUE(sh->buffer().len() > (random_len + 2)); memcpy(crandom1, ch->buffer().data() + 2, random_len); memcpy(srandom1, sh->buffer().data() + 2, random_len); EXPECT_NE(0, memcmp(crandom1, srandom1, random_len)); EXPECT_NE(0, memcmp(crandom1, z, random_len)); EXPECT_NE(0, memcmp(srandom1, z, random_len)); Reset(); ch = MakeTlsFilter(client_, ssl_hs_client_hello); sh = MakeTlsFilter(server_, ssl_hs_server_hello); Connect(); ASSERT_TRUE(ch->buffer().len() > (random_len + 2)); ASSERT_TRUE(sh->buffer().len() > (random_len + 2)); const uint8_t* crandom2 = ch->buffer().data() + 2; const uint8_t* srandom2 = sh->buffer().data() + 2; EXPECT_NE(0, memcmp(crandom2, srandom2, random_len)); EXPECT_NE(0, memcmp(crandom2, z, random_len)); EXPECT_NE(0, memcmp(srandom2, z, random_len)); EXPECT_NE(0, memcmp(crandom1, crandom2, random_len)); EXPECT_NE(0, memcmp(crandom1, srandom2, random_len)); EXPECT_NE(0, memcmp(srandom1, crandom2, random_len)); EXPECT_NE(0, memcmp(srandom1, srandom2, random_len)); } void FailOnCloseNotify(const PRFileDesc* fd, void* arg, const SSLAlert* alert) { ADD_FAILURE() << "received alert " << alert->description; } void CheckCloseNotify(const PRFileDesc* fd, void* arg, const SSLAlert* alert) { *reinterpret_cast(arg) = true; EXPECT_EQ(close_notify, alert->description); EXPECT_EQ(alert_warning, alert->level); } TEST_P(TlsConnectGeneric, ShutdownOneSide) { Connect(); // Setup to check alerts. EXPECT_EQ(SECSuccess, SSL_AlertSentCallback(server_->ssl_fd(), FailOnCloseNotify, nullptr)); EXPECT_EQ(SECSuccess, SSL_AlertReceivedCallback(client_->ssl_fd(), FailOnCloseNotify, nullptr)); bool client_sent = false; EXPECT_EQ(SECSuccess, SSL_AlertSentCallback(client_->ssl_fd(), CheckCloseNotify, &client_sent)); bool server_received = false; EXPECT_EQ(SECSuccess, SSL_AlertReceivedCallback(server_->ssl_fd(), CheckCloseNotify, &server_received)); EXPECT_EQ(PR_SUCCESS, PR_Shutdown(client_->ssl_fd(), PR_SHUTDOWN_SEND)); // Make sure that the server reads out the close_notify. uint8_t buf[10]; EXPECT_EQ(0, PR_Read(server_->ssl_fd(), buf, sizeof(buf))); // Reading and writing should still work in the one open direction. EXPECT_TRUE(client_sent); EXPECT_TRUE(server_received); server_->SendData(10, 10); client_->ReadBytes(10); // Now close the other side and do the same checks. bool server_sent = false; EXPECT_EQ(SECSuccess, SSL_AlertSentCallback(server_->ssl_fd(), CheckCloseNotify, &server_sent)); bool client_received = false; EXPECT_EQ(SECSuccess, SSL_AlertReceivedCallback(client_->ssl_fd(), CheckCloseNotify, &client_received)); EXPECT_EQ(PR_SUCCESS, PR_Shutdown(server_->ssl_fd(), PR_SHUTDOWN_SEND)); EXPECT_EQ(0, PR_Read(client_->ssl_fd(), buf, sizeof(buf))); EXPECT_TRUE(server_sent); EXPECT_TRUE(client_received); } TEST_P(TlsConnectGeneric, ShutdownOneSideThenCloseTcp) { Connect(); bool client_sent = false; EXPECT_EQ(SECSuccess, SSL_AlertSentCallback(client_->ssl_fd(), CheckCloseNotify, &client_sent)); bool server_received = false; EXPECT_EQ(SECSuccess, SSL_AlertReceivedCallback(server_->ssl_fd(), CheckCloseNotify, &server_received)); EXPECT_EQ(PR_SUCCESS, PR_Shutdown(client_->ssl_fd(), PR_SHUTDOWN_SEND)); // Make sure that the server reads out the close_notify. uint8_t buf[10]; EXPECT_EQ(0, PR_Read(server_->ssl_fd(), buf, sizeof(buf))); // Now simulate the underlying connection closing. client_->adapter()->Reset(); // Now close the other side and see that things don't explode. EXPECT_EQ(PR_SUCCESS, PR_Shutdown(server_->ssl_fd(), PR_SHUTDOWN_SEND)); EXPECT_GT(0, PR_Read(client_->ssl_fd(), buf, sizeof(buf))); EXPECT_EQ(PR_NOT_CONNECTED_ERROR, PR_GetError()); } INSTANTIATE_TEST_SUITE_P( GenericStream, TlsConnectGeneric, ::testing::Combine(TlsConnectTestBase::kTlsVariantsStream, TlsConnectTestBase::kTlsVAll)); INSTANTIATE_TEST_SUITE_P( GenericDatagram, TlsConnectGeneric, ::testing::Combine(TlsConnectTestBase::kTlsVariantsDatagram, TlsConnectTestBase::kTlsV11Plus)); INSTANTIATE_TEST_SUITE_P(StreamOnly, TlsConnectStream, TlsConnectTestBase::kTlsVAll); INSTANTIATE_TEST_SUITE_P(DatagramOnly, TlsConnectDatagram, TlsConnectTestBase::kTlsV11Plus); INSTANTIATE_TEST_SUITE_P(DatagramHolddown, TlsHolddownTest, TlsConnectTestBase::kTlsV11Plus); INSTANTIATE_TEST_SUITE_P( Pre12Stream, TlsConnectPre12, ::testing::Combine(TlsConnectTestBase::kTlsVariantsStream, TlsConnectTestBase::kTlsV10V11)); INSTANTIATE_TEST_SUITE_P( Pre12Datagram, TlsConnectPre12, ::testing::Combine(TlsConnectTestBase::kTlsVariantsDatagram, TlsConnectTestBase::kTlsV11)); INSTANTIATE_TEST_SUITE_P(Version12Only, TlsConnectTls12, TlsConnectTestBase::kTlsVariantsAll); #ifndef NSS_DISABLE_TLS_1_3 INSTANTIATE_TEST_SUITE_P(Version13Only, TlsConnectTls13, TlsConnectTestBase::kTlsVariantsAll); #endif INSTANTIATE_TEST_SUITE_P( Pre13Stream, TlsConnectGenericPre13, ::testing::Combine(TlsConnectTestBase::kTlsVariantsStream, TlsConnectTestBase::kTlsV10ToV12)); INSTANTIATE_TEST_SUITE_P( Pre13Datagram, TlsConnectGenericPre13, ::testing::Combine(TlsConnectTestBase::kTlsVariantsDatagram, TlsConnectTestBase::kTlsV11V12)); INSTANTIATE_TEST_SUITE_P(Pre13StreamOnly, TlsConnectStreamPre13, TlsConnectTestBase::kTlsV10ToV12); INSTANTIATE_TEST_SUITE_P(Version12Plus, TlsConnectTls12Plus, ::testing::Combine(TlsConnectTestBase::kTlsVariantsAll, TlsConnectTestBase::kTlsV12Plus)); INSTANTIATE_TEST_SUITE_P( GenericStream, TlsConnectGenericResumption, ::testing::Combine(TlsConnectTestBase::kTlsVariantsStream, TlsConnectTestBase::kTlsVAll, ::testing::Values(true, false))); INSTANTIATE_TEST_SUITE_P( GenericDatagram, TlsConnectGenericResumption, ::testing::Combine(TlsConnectTestBase::kTlsVariantsDatagram, TlsConnectTestBase::kTlsV11Plus, ::testing::Values(true, false))); INSTANTIATE_TEST_SUITE_P( GenericStream, TlsConnectGenericResumptionToken, ::testing::Combine(TlsConnectTestBase::kTlsVariantsStream, TlsConnectTestBase::kTlsVAll)); INSTANTIATE_TEST_SUITE_P( GenericDatagram, TlsConnectGenericResumptionToken, ::testing::Combine(TlsConnectTestBase::kTlsVariantsDatagram, TlsConnectTestBase::kTlsV11Plus)); INSTANTIATE_TEST_SUITE_P(GenericDatagram, TlsConnectTls13ResumptionToken, TlsConnectTestBase::kTlsVariantsAll); } // namespace nss_test