summaryrefslogtreecommitdiffstats
path: root/security/nss/gtests/ssl_gtest/ssl_skip_unittest.cc
diff options
context:
space:
mode:
Diffstat (limited to 'security/nss/gtests/ssl_gtest/ssl_skip_unittest.cc')
-rw-r--r--security/nss/gtests/ssl_gtest/ssl_skip_unittest.cc246
1 files changed, 246 insertions, 0 deletions
diff --git a/security/nss/gtests/ssl_gtest/ssl_skip_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_skip_unittest.cc
new file mode 100644
index 0000000000..606e731033
--- /dev/null
+++ b/security/nss/gtests/ssl_gtest/ssl_skip_unittest.cc
@@ -0,0 +1,246 @@
+/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
+/* vim: set ts=2 et sw=2 tw=80: */
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this file,
+ * You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+#include "sslerr.h"
+
+#include "tls_connect.h"
+#include "tls_filter.h"
+#include "tls_parser.h"
+
+/*
+ * The tests in this file test that the TLS state machine is robust against
+ * attacks that alter the order of handshake messages.
+ *
+ * See <https://www.smacktls.com/smack.pdf> for a description of the problems
+ * that this sort of attack can enable.
+ */
+namespace nss_test {
+
+class TlsHandshakeSkipFilter : public TlsRecordFilter {
+ public:
+ // A TLS record filter that skips handshake messages of the identified type.
+ TlsHandshakeSkipFilter(const std::shared_ptr<TlsAgent>& a,
+ uint8_t handshake_type)
+ : TlsRecordFilter(a), handshake_type_(handshake_type), skipped_(false) {}
+
+ protected:
+ // Takes a record; if it is a handshake record, it removes the first handshake
+ // message that is of handshake_type_ type.
+ virtual PacketFilter::Action FilterRecord(
+ const TlsRecordHeader& record_header, const DataBuffer& input,
+ DataBuffer* output) {
+ if (record_header.content_type() != ssl_ct_handshake) {
+ return KEEP;
+ }
+
+ size_t output_offset = 0U;
+ output->Allocate(input.len());
+
+ TlsParser parser(input);
+ while (parser.remaining()) {
+ size_t start = parser.consumed();
+ TlsHandshakeFilter::HandshakeHeader header;
+ DataBuffer ignored;
+ bool complete = false;
+ if (!header.Parse(&parser, record_header, DataBuffer(), &ignored,
+ &complete)) {
+ ADD_FAILURE() << "Error parsing handshake header";
+ return KEEP;
+ }
+ if (!complete) {
+ ADD_FAILURE() << "Don't want to deal with fragmented input";
+ return KEEP;
+ }
+
+ if (skipped_ || header.handshake_type() != handshake_type_) {
+ size_t entire_length = parser.consumed() - start;
+ output->Write(output_offset, input.data() + start, entire_length);
+ // DTLS sequence numbers need to be rewritten
+ if (skipped_ && header.is_dtls()) {
+ output->data()[start + 5] -= 1;
+ }
+ output_offset += entire_length;
+ } else {
+ std::cerr << "Dropping handshake: "
+ << static_cast<unsigned>(handshake_type_) << std::endl;
+ // We only need to report that the output contains changed data if we
+ // drop a handshake message. But once we've skipped one message, we
+ // have to modify all subsequent handshake messages so that they include
+ // the correct DTLS sequence numbers.
+ skipped_ = true;
+ }
+ }
+ output->Truncate(output_offset);
+ return skipped_ ? CHANGE : KEEP;
+ }
+
+ private:
+ // The type of handshake message to drop.
+ uint8_t handshake_type_;
+ // Whether this filter has ever skipped a handshake message. Track this so
+ // that sequence numbers on DTLS handshake messages can be rewritten in
+ // subsequent calls.
+ bool skipped_;
+};
+
+class TlsSkipTest : public TlsConnectTestBase,
+ public ::testing::WithParamInterface<
+ std::tuple<SSLProtocolVariant, uint16_t>> {
+ protected:
+ TlsSkipTest()
+ : TlsConnectTestBase(std::get<0>(GetParam()), std::get<1>(GetParam())) {}
+
+ void SetUp() override {
+ TlsConnectTestBase::SetUp();
+ EnsureTlsSetup();
+ }
+
+ void ServerSkipTest(std::shared_ptr<PacketFilter> filter,
+ uint8_t alert = kTlsAlertUnexpectedMessage) {
+ server_->SetFilter(filter);
+ ConnectExpectAlert(client_, alert);
+ }
+};
+
+class Tls13SkipTest : public TlsConnectTestBase,
+ public ::testing::WithParamInterface<SSLProtocolVariant> {
+ protected:
+ Tls13SkipTest()
+ : TlsConnectTestBase(GetParam(), SSL_LIBRARY_VERSION_TLS_1_3) {}
+
+ void SetUp() override {
+ TlsConnectTestBase::SetUp();
+ EnsureTlsSetup();
+ }
+
+ void ServerSkipTest(std::shared_ptr<TlsRecordFilter> filter, int32_t error) {
+ filter->EnableDecryption();
+ server_->SetFilter(filter);
+ ExpectAlert(client_, kTlsAlertUnexpectedMessage);
+ ConnectExpectFail();
+ client_->CheckErrorCode(error);
+ server_->CheckErrorCode(SSL_ERROR_HANDSHAKE_UNEXPECTED_ALERT);
+ }
+
+ void ClientSkipTest(std::shared_ptr<TlsRecordFilter> filter, int32_t error) {
+ filter->EnableDecryption();
+ client_->SetFilter(filter);
+ server_->ExpectSendAlert(kTlsAlertUnexpectedMessage);
+ ConnectExpectFailOneSide(TlsAgent::SERVER);
+
+ server_->CheckErrorCode(error);
+ ASSERT_EQ(TlsAgent::STATE_CONNECTED, client_->state());
+
+ client_->Handshake(); // Make sure to consume the alert the server sends.
+ }
+};
+
+TEST_P(TlsSkipTest, SkipCertificateRsa) {
+ EnableOnlyStaticRsaCiphers();
+ ServerSkipTest(std::make_shared<TlsHandshakeSkipFilter>(
+ server_, kTlsHandshakeCertificate));
+ client_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_HELLO_DONE);
+}
+
+TEST_P(TlsSkipTest, SkipCertificateDhe) {
+ ServerSkipTest(std::make_shared<TlsHandshakeSkipFilter>(
+ server_, kTlsHandshakeCertificate));
+ client_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_SERVER_KEY_EXCH);
+}
+
+TEST_P(TlsSkipTest, SkipCertificateEcdhe) {
+ ServerSkipTest(std::make_shared<TlsHandshakeSkipFilter>(
+ server_, kTlsHandshakeCertificate));
+ client_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_SERVER_KEY_EXCH);
+}
+
+TEST_P(TlsSkipTest, SkipCertificateEcdsa) {
+ Reset(TlsAgent::kServerEcdsa256);
+ ServerSkipTest(std::make_shared<TlsHandshakeSkipFilter>(
+ server_, kTlsHandshakeCertificate));
+ client_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_SERVER_KEY_EXCH);
+}
+
+TEST_P(TlsSkipTest, SkipServerKeyExchange) {
+ ServerSkipTest(std::make_shared<TlsHandshakeSkipFilter>(
+ server_, kTlsHandshakeServerKeyExchange));
+ client_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_HELLO_DONE);
+}
+
+TEST_P(TlsSkipTest, SkipServerKeyExchangeEcdsa) {
+ Reset(TlsAgent::kServerEcdsa256);
+ ServerSkipTest(std::make_shared<TlsHandshakeSkipFilter>(
+ server_, kTlsHandshakeServerKeyExchange));
+ client_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_HELLO_DONE);
+}
+
+TEST_P(TlsSkipTest, SkipCertAndKeyExch) {
+ auto chain = std::make_shared<ChainedPacketFilter>(
+ ChainedPacketFilterInit{std::make_shared<TlsHandshakeSkipFilter>(
+ server_, kTlsHandshakeCertificate),
+ std::make_shared<TlsHandshakeSkipFilter>(
+ server_, kTlsHandshakeServerKeyExchange)});
+ ServerSkipTest(chain);
+ client_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_HELLO_DONE);
+}
+
+TEST_P(TlsSkipTest, SkipCertAndKeyExchEcdsa) {
+ Reset(TlsAgent::kServerEcdsa256);
+ auto chain = std::make_shared<ChainedPacketFilter>();
+ chain->Add(std::make_shared<TlsHandshakeSkipFilter>(
+ server_, kTlsHandshakeCertificate));
+ chain->Add(std::make_shared<TlsHandshakeSkipFilter>(
+ server_, kTlsHandshakeServerKeyExchange));
+ ServerSkipTest(chain);
+ client_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_HELLO_DONE);
+}
+
+TEST_P(Tls13SkipTest, SkipEncryptedExtensions) {
+ ServerSkipTest(std::make_shared<TlsHandshakeSkipFilter>(
+ server_, kTlsHandshakeEncryptedExtensions),
+ SSL_ERROR_RX_UNEXPECTED_CERTIFICATE);
+}
+
+TEST_P(Tls13SkipTest, SkipServerCertificate) {
+ ServerSkipTest(std::make_shared<TlsHandshakeSkipFilter>(
+ server_, kTlsHandshakeCertificate),
+ SSL_ERROR_RX_UNEXPECTED_CERT_VERIFY);
+}
+
+TEST_P(Tls13SkipTest, SkipServerCertificateVerify) {
+ ServerSkipTest(std::make_shared<TlsHandshakeSkipFilter>(
+ server_, kTlsHandshakeCertificateVerify),
+ SSL_ERROR_RX_UNEXPECTED_FINISHED);
+}
+
+TEST_P(Tls13SkipTest, SkipClientCertificate) {
+ client_->SetupClientAuth();
+ server_->RequestClientAuth(true);
+ client_->ExpectReceiveAlert(kTlsAlertUnexpectedMessage);
+ ClientSkipTest(std::make_shared<TlsHandshakeSkipFilter>(
+ client_, kTlsHandshakeCertificate),
+ SSL_ERROR_RX_UNEXPECTED_CERT_VERIFY);
+}
+
+TEST_P(Tls13SkipTest, SkipClientCertificateVerify) {
+ client_->SetupClientAuth();
+ server_->RequestClientAuth(true);
+ client_->ExpectReceiveAlert(kTlsAlertUnexpectedMessage);
+ ClientSkipTest(std::make_shared<TlsHandshakeSkipFilter>(
+ client_, kTlsHandshakeCertificateVerify),
+ SSL_ERROR_RX_UNEXPECTED_FINISHED);
+}
+
+INSTANTIATE_TEST_SUITE_P(
+ SkipTls10, TlsSkipTest,
+ ::testing::Combine(TlsConnectTestBase::kTlsVariantsStream,
+ TlsConnectTestBase::kTlsV10));
+INSTANTIATE_TEST_SUITE_P(SkipVariants, TlsSkipTest,
+ ::testing::Combine(TlsConnectTestBase::kTlsVariantsAll,
+ TlsConnectTestBase::kTlsV11V12));
+INSTANTIATE_TEST_SUITE_P(Skip13Variants, Tls13SkipTest,
+ TlsConnectTestBase::kTlsVariantsAll);
+} // namespace nss_test