diff options
Diffstat (limited to 'security/nss/gtests/nss_bogo_shim/nss_bogo_shim.cc')
-rw-r--r-- | security/nss/gtests/nss_bogo_shim/nss_bogo_shim.cc | 1266 |
1 files changed, 1266 insertions, 0 deletions
diff --git a/security/nss/gtests/nss_bogo_shim/nss_bogo_shim.cc b/security/nss/gtests/nss_bogo_shim/nss_bogo_shim.cc new file mode 100644 index 0000000000..12adcc5de0 --- /dev/null +++ b/security/nss/gtests/nss_bogo_shim/nss_bogo_shim.cc @@ -0,0 +1,1266 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set ts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this file, + * You can obtain one at http://mozilla.org/MPL/2.0/. */ +#include "config.h" + +#include <algorithm> +#include <cstdlib> +#include <iostream> +#include <memory> +#include "nspr.h" +#include "nss.h" +#include "prio.h" +#include "prnetdb.h" +#include "secerr.h" +#include "ssl.h" +#include "ssl3prot.h" +#include "sslerr.h" +#include "sslproto.h" +#include "nss_scoped_ptrs.h" +#include "sslimpl.h" +#include "tls13ech.h" +#include "base64.h" + +#include "nsskeys.h" + +static const char* kVersionDisableFlags[] = {"no-ssl3", "no-tls1", "no-tls11", + "no-tls12", "no-tls13"}; + +/* Default EarlyData dummy data determined by Bogo implementation. */ +const unsigned char kBogoDummyData[] = {'h', 'e', 'l', 'l', 'o'}; + +bool exitCodeUnimplemented = false; + +std::string FormatError(PRErrorCode code) { + return std::string(":") + PORT_ErrorToName(code) + ":" + ":" + + PORT_ErrorToString(code); +} + +static void StringRemoveNewlines(std::string& str) { + str.erase(std::remove(str.begin(), str.end(), '\n'), str.cend()); + str.erase(std::remove(str.begin(), str.end(), '\r'), str.cend()); +} + +class TestAgent { + public: + TestAgent(const Config& cfg) : cfg_(cfg) {} + + ~TestAgent() {} + + static std::unique_ptr<TestAgent> Create(const Config& cfg) { + std::unique_ptr<TestAgent> agent(new TestAgent(cfg)); + + if (!agent->Init()) return nullptr; + + return agent; + } + + bool Init() { + if (!ConnectTcp()) { + return false; + } + + if (!SetupKeys()) { + std::cerr << "Couldn't set up keys/certs\n"; + return false; + } + + if (!SetupOptions()) { + std::cerr << "Couldn't configure socket\n"; + return false; + } + + SECStatus rv = SSL_ResetHandshake(ssl_fd_.get(), cfg_.get<bool>("server")); + if (rv != SECSuccess) return false; + + return true; + } + + bool ConnectTcp() { + if (!(cfg_.get<bool>("ipv6") && OpenConnection("::1")) && + !OpenConnection("127.0.0.1")) { + return false; + } + + ssl_fd_ = ScopedPRFileDesc(SSL_ImportFD(NULL, pr_fd_.get())); + if (!ssl_fd_) { + return false; + } + pr_fd_.release(); + + return true; + } + + bool OpenConnection(const char* ip) { + PRStatus prv; + PRNetAddr addr; + + prv = PR_StringToNetAddr(ip, &addr); + + if (prv != PR_SUCCESS) { + return false; + } + + addr.inet.port = PR_htons(cfg_.get<int>("port")); + + pr_fd_ = ScopedPRFileDesc(PR_OpenTCPSocket(addr.raw.family)); + if (!pr_fd_) return false; + + prv = PR_Connect(pr_fd_.get(), &addr, PR_INTERVAL_NO_TIMEOUT); + if (prv != PR_SUCCESS) { + return false; + } + + uint64_t shim_id = cfg_.get<int>("shim-id"); + uint8_t buf[8] = {0}; + for (size_t i = 0; i < 8; i++) { + buf[i] = shim_id & 0xff; + shim_id >>= 8; + } + int sent = PR_Write(pr_fd_.get(), buf, sizeof(buf)); + if (sent != sizeof(buf)) { + return false; + } + + return true; + } + + bool SetupKeys() { + SECStatus rv; + + if (cfg_.get<std::string>("key-file") != "") { + key_ = ScopedSECKEYPrivateKey( + ReadPrivateKey(cfg_.get<std::string>("key-file"))); + if (!key_) return false; + } + if (cfg_.get<std::string>("cert-file") != "") { + cert_ = ScopedCERTCertificate( + ReadCertificate(cfg_.get<std::string>("cert-file"))); + if (!cert_) return false; + } + + // Needed because certs are not entirely valid. + rv = SSL_AuthCertificateHook(ssl_fd_.get(), AuthCertificateHook, this); + if (rv != SECSuccess) return false; + + if (cfg_.get<bool>("server")) { + // Server + rv = SSL_ConfigServerCert(ssl_fd_.get(), cert_.get(), key_.get(), nullptr, + 0); + if (rv != SECSuccess) { + std::cerr << "Couldn't configure server cert\n"; + return false; + } + + } else if (key_ && cert_) { + // Client. + rv = + SSL_GetClientAuthDataHook(ssl_fd_.get(), GetClientAuthDataHook, this); + if (rv != SECSuccess) return false; + } + + return true; + } + + static bool ConvertFromWireVersion(SSLProtocolVariant variant, + int wire_version, uint16_t* lib_version) { + // These default values are used when {min,max}-version isn't given. + if (wire_version == 0 || wire_version == 0xffff) { + *lib_version = static_cast<uint16_t>(wire_version); + return true; + } + +#ifdef TLS_1_3_DRAFT_VERSION + if (wire_version == (0x7f00 | TLS_1_3_DRAFT_VERSION)) { + // N.B. SSL_LIBRARY_VERSION_DTLS_1_3_WIRE == SSL_LIBRARY_VERSION_TLS_1_3 + wire_version = SSL_LIBRARY_VERSION_TLS_1_3; + } +#endif + + if (variant == ssl_variant_datagram) { + switch (wire_version) { + case SSL_LIBRARY_VERSION_DTLS_1_0_WIRE: + *lib_version = SSL_LIBRARY_VERSION_DTLS_1_0; + break; + case SSL_LIBRARY_VERSION_DTLS_1_2_WIRE: + *lib_version = SSL_LIBRARY_VERSION_DTLS_1_2; + break; + case SSL_LIBRARY_VERSION_DTLS_1_3_WIRE: + *lib_version = SSL_LIBRARY_VERSION_DTLS_1_3; + break; + default: + std::cerr << "Unrecognized DTLS version " << wire_version << ".\n"; + return false; + } + } else { + if (wire_version < SSL_LIBRARY_VERSION_3_0 || + wire_version > SSL_LIBRARY_VERSION_TLS_1_3) { + std::cerr << "Unrecognized TLS version " << wire_version << ".\n"; + return false; + } + *lib_version = static_cast<uint16_t>(wire_version); + } + return true; + } + + bool GetVersionRange(SSLVersionRange* range_out, SSLProtocolVariant variant) { + SSLVersionRange supported; + if (SSL_VersionRangeGetSupported(variant, &supported) != SECSuccess) { + return false; + } + + uint16_t min_allowed; + uint16_t max_allowed; + if (!ConvertFromWireVersion(variant, cfg_.get<int>("min-version"), + &min_allowed)) { + return false; + } + if (!ConvertFromWireVersion(variant, cfg_.get<int>("max-version"), + &max_allowed)) { + return false; + } + + min_allowed = std::max(min_allowed, supported.min); + max_allowed = std::min(max_allowed, supported.max); + + bool found_min = false; + bool found_max = false; + // Ignore -no-ssl3, because SSLv3 is never supported. + for (size_t i = 1; i < PR_ARRAY_SIZE(kVersionDisableFlags); ++i) { + auto version = + static_cast<uint16_t>(SSL_LIBRARY_VERSION_TLS_1_0 + (i - 1)); + if (variant == ssl_variant_datagram) { + // In DTLS mode, the -no-tlsN flags refer to DTLS versions, + // but NSS wants the corresponding TLS versions. + if (version == SSL_LIBRARY_VERSION_TLS_1_1) { + // DTLS 1.1 doesn't exist. + continue; + } + if (version == SSL_LIBRARY_VERSION_TLS_1_0) { + version = SSL_LIBRARY_VERSION_DTLS_1_0; + } + } + + if (version < min_allowed) { + continue; + } + if (version > max_allowed) { + break; + } + + const bool allowed = !cfg_.get<bool>(kVersionDisableFlags[i]); + + if (!found_min && allowed) { + found_min = true; + range_out->min = version; + } + if (found_min && !found_max) { + if (allowed) { + range_out->max = version; + } else { + found_max = true; + } + } + if (found_max && allowed) { + std::cerr << "Discontiguous version range.\n"; + return false; + } + } + + if (!found_min) { + std::cerr << "All versions disabled.\n"; + } + return found_min; + } + + bool SetupOptions() { + SECStatus rv = + SSL_OptionSet(ssl_fd_.get(), SSL_ENABLE_TLS13_COMPAT_MODE, PR_TRUE); + if (rv != SECSuccess) return false; + + rv = SSL_OptionSet(ssl_fd_.get(), SSL_ENABLE_SESSION_TICKETS, PR_TRUE); + if (rv != SECSuccess) return false; + + SSLVersionRange vrange; + if (!GetVersionRange(&vrange, ssl_variant_stream)) return false; + + rv = SSL_VersionRangeSet(ssl_fd_.get(), &vrange); + if (rv != SECSuccess) return false; + + SSLVersionRange verify_vrange; + rv = SSL_VersionRangeGet(ssl_fd_.get(), &verify_vrange); + if (rv != SECSuccess) return false; + if (vrange.min != verify_vrange.min || vrange.max != verify_vrange.max) + return false; + + rv = SSL_OptionSet(ssl_fd_.get(), SSL_NO_CACHE, false); + if (rv != SECSuccess) return false; + + auto alpn = cfg_.get<std::string>("advertise-alpn"); + if (!alpn.empty()) { + assert(!cfg_.get<bool>("server")); + + rv = SSL_OptionSet(ssl_fd_.get(), SSL_ENABLE_ALPN, PR_TRUE); + if (rv != SECSuccess) return false; + + rv = SSL_SetNextProtoNego( + ssl_fd_.get(), reinterpret_cast<const unsigned char*>(alpn.c_str()), + alpn.size()); + if (rv != SECSuccess) return false; + } + + // Set supported signature schemes. + auto sign_prefs = cfg_.get<std::vector<int>>("signing-prefs"); + auto verify_prefs = cfg_.get<std::vector<int>>("verify-prefs"); + if (sign_prefs.empty()) { + sign_prefs = verify_prefs; + } else if (!verify_prefs.empty()) { + return false; // Both shouldn't be set. + } + if (!sign_prefs.empty()) { + std::vector<SSLSignatureScheme> sig_schemes; + std::transform( + sign_prefs.begin(), sign_prefs.end(), std::back_inserter(sig_schemes), + [](int scheme) { return static_cast<SSLSignatureScheme>(scheme); }); + + rv = SSL_SignatureSchemePrefSet( + ssl_fd_.get(), sig_schemes.data(), + static_cast<unsigned int>(sig_schemes.size())); + if (rv != SECSuccess) return false; + } + + if (cfg_.get<bool>("fallback-scsv")) { + rv = SSL_OptionSet(ssl_fd_.get(), SSL_ENABLE_FALLBACK_SCSV, PR_TRUE); + if (rv != SECSuccess) return false; + } + + if (cfg_.get<bool>("false-start")) { + rv = SSL_OptionSet(ssl_fd_.get(), SSL_ENABLE_FALSE_START, PR_TRUE); + if (rv != SECSuccess) return false; + } + + if (cfg_.get<bool>("enable-ocsp-stapling")) { + rv = SSL_OptionSet(ssl_fd_.get(), SSL_ENABLE_OCSP_STAPLING, PR_TRUE); + if (rv != SECSuccess) return false; + } + + bool requireClientCert = cfg_.get<bool>("require-any-client-certificate"); + if (requireClientCert || cfg_.get<bool>("verify-peer")) { + assert(cfg_.get<bool>("server")); + + rv = SSL_OptionSet(ssl_fd_.get(), SSL_REQUEST_CERTIFICATE, PR_TRUE); + if (rv != SECSuccess) return false; + + rv = SSL_OptionSet( + ssl_fd_.get(), SSL_REQUIRE_CERTIFICATE, + requireClientCert ? SSL_REQUIRE_ALWAYS : SSL_REQUIRE_NO_ERROR); + if (rv != SECSuccess) return false; + } + + if (!cfg_.get<bool>("server")) { + auto hostname = cfg_.get<std::string>("host-name"); + if (!hostname.empty()) { + rv = SSL_SetURL(ssl_fd_.get(), hostname.c_str()); + } else { + // Needed to make resumption work. + rv = SSL_SetURL(ssl_fd_.get(), "server"); + } + if (rv != SECSuccess) return false; + + // Setup ECH configs on client if provided + auto echConfigList = cfg_.get<std::string>("ech-config-list"); + if (!echConfigList.empty()) { + unsigned int binLen; + auto bin = ATOB_AsciiToData(echConfigList.c_str(), &binLen); + rv = SSLExp_SetClientEchConfigs(ssl_fd_.get(), bin, binLen); + if (rv != SECSuccess) return false; + free(bin); + } + + if (cfg_.get<bool>("enable-grease")) { + rv = SSL_OptionSet(ssl_fd_.get(), SSL_ENABLE_GREASE, PR_TRUE); + if (rv != SECSuccess) return false; + } + + if (cfg_.get<bool>("permute-extensions")) { + rv = SSL_OptionSet(ssl_fd_.get(), SSL_ENABLE_CH_EXTENSION_PERMUTATION, + PR_TRUE); + if (rv != SECSuccess) return false; + } + + } else { + // GREASE - BoGo expects servers to enable GREASE by default + rv = SSL_OptionSet(ssl_fd_.get(), SSL_ENABLE_GREASE, PR_TRUE); + if (rv != SECSuccess) return false; + } + + rv = SSL_OptionSet(ssl_fd_.get(), SSL_ENABLE_EXTENDED_MASTER_SECRET, + PR_TRUE); + if (rv != SECSuccess) return false; + + if (cfg_.get<bool>("server")) { + // BoGo expects servers to enable ECH (backend) by default + rv = SSLExp_EnableTls13BackendEch(ssl_fd_.get(), true); + if (rv != SECSuccess) return false; + } + + if (cfg_.get<bool>("enable-ech-grease")) { + rv = SSLExp_EnableTls13GreaseEch(ssl_fd_.get(), true); + if (rv != SECSuccess) return false; + } + + if (cfg_.get<bool>("enable-early-data")) { + rv = SSL_OptionSet(ssl_fd_.get(), SSL_ENABLE_0RTT_DATA, PR_TRUE); + if (rv != SECSuccess) return false; + } + + if (!ConfigureGroups()) return false; + + if (!ConfigureCiphers()) return false; + + return true; + } + + bool ConfigureGroups() { + auto curves = cfg_.get<std::vector<int>>("curves"); + if (curves.size() > 0) { + std::vector<SSLNamedGroup> groups; + std::transform( + curves.begin(), curves.end(), std::back_inserter(groups), + [](int curve) { return static_cast<SSLNamedGroup>(curve); }); + SECStatus rv = + SSL_NamedGroupConfig(ssl_fd_.get(), &groups[0], groups.size()); + if (rv != SECSuccess) { + return false; + } + // Xyber768 is disabled by policy by default, so if it's requested + // we need to update the policy flags as well. + for (auto group : groups) { + if (group == ssl_grp_kem_xyber768d00) { + NSS_SetAlgorithmPolicy(SEC_OID_XYBER768D00, NSS_USE_ALG_IN_SSL_KX, 0); + } + } + } + + return true; + } + + bool ConfigureCiphers() { + auto cipherList = cfg_.get<std::string>("nss-cipher"); + + if (cipherList.empty()) { + return EnableNonExportCiphers(); + } + + for (size_t i = 0; i < SSL_NumImplementedCiphers; ++i) { + SSLCipherSuiteInfo csinfo; + std::string::size_type n; + SECStatus rv = SSL_GetCipherSuiteInfo(SSL_ImplementedCiphers[i], &csinfo, + sizeof(csinfo)); + if (rv != SECSuccess) { + return false; + } + + // Check if cipherList contains the name of the Cipher Suite and + // enable/disable accordingly. + n = cipherList.find(csinfo.cipherSuiteName, 0); + if (std::string::npos == n) { + rv = SSL_CipherPrefSet(ssl_fd_.get(), SSL_ImplementedCiphers[i], + PR_FALSE); + } else { + rv = SSL_CipherPrefSet(ssl_fd_.get(), SSL_ImplementedCiphers[i], + PR_TRUE); + } + if (rv != SECSuccess) { + return false; + } + } + return true; + } + + bool EnableNonExportCiphers() { + for (size_t i = 0; i < SSL_NumImplementedCiphers; ++i) { + SSLCipherSuiteInfo csinfo; + + SECStatus rv = SSL_GetCipherSuiteInfo(SSL_ImplementedCiphers[i], &csinfo, + sizeof(csinfo)); + if (rv != SECSuccess) { + return false; + } + + rv = SSL_CipherPrefSet(ssl_fd_.get(), SSL_ImplementedCiphers[i], PR_TRUE); + if (rv != SECSuccess) { + return false; + } + } + return true; + } + + // Dummy auth certificate hook. + static SECStatus AuthCertificateHook(void* arg, PRFileDesc* fd, + PRBool checksig, PRBool isServer) { + return SECSuccess; + } + + static SECStatus GetClientAuthDataHook(void* self, PRFileDesc* fd, + CERTDistNames* caNames, + CERTCertificate** cert, + SECKEYPrivateKey** privKey) { + TestAgent* a = static_cast<TestAgent*>(self); + *cert = CERT_DupCertificate(a->cert_.get()); + *privKey = SECKEY_CopyPrivateKey(a->key_.get()); + return SECSuccess; + } + + SECStatus Handshake() { return SSL_ForceHandshake(ssl_fd_.get()); } + + // Implement a trivial echo client/server. Read bytes from the other side, + // flip all the bits, and send them back. + SECStatus ReadWrite() { + for (;;) { + uint8_t block[512]; + int32_t rv = PR_Read(ssl_fd_.get(), block, sizeof(block)); + if (rv < 0) { + std::cerr << "Failure reading\n"; + return SECFailure; + } + if (rv == 0) return SECSuccess; + + int32_t len = rv; + for (int32_t i = 0; i < len; ++i) { + block[i] ^= 0xff; + } + + rv = PR_Write(ssl_fd_.get(), block, len); + if (rv != len) { + std::cerr << "Write failure\n"; + PORT_SetError(SEC_ERROR_OUTPUT_LEN); + return SECFailure; + } + } + } + + // Write bytes to the other side then read them back and check + // that they were correctly XORed as in ReadWrite. + SECStatus WriteRead() { + static const uint8_t ch = 'E'; + + // We do 600-byte blocks to provide mis-alignment of the + // reader and writer. + uint8_t block[600]; + memset(block, ch, sizeof(block)); + int32_t rv = PR_Write(ssl_fd_.get(), block, sizeof(block)); + if (rv != sizeof(block)) { + std::cerr << "Write failure\n"; + PORT_SetError(SEC_ERROR_OUTPUT_LEN); + return SECFailure; + } + + size_t left = sizeof(block); + while (left) { + rv = PR_Read(ssl_fd_.get(), block, left); + if (rv < 0) { + std::cerr << "Failure reading\n"; + return SECFailure; + } + if (rv == 0) { + PORT_SetError(SEC_ERROR_INPUT_LEN); + return SECFailure; + } + + int32_t len = rv; + for (int32_t i = 0; i < len; ++i) { + if (block[i] != (ch ^ 0xff)) { + PORT_SetError(SEC_ERROR_BAD_DATA); + return SECFailure; + } + } + left -= len; + } + return SECSuccess; + } + + SECStatus CheckALPN(std::string expectedALPN) { + SECStatus rv; + SSLNextProtoState state; + char chosen[256]; + unsigned int chosen_len; + + rv = SSL_GetNextProto(ssl_fd_.get(), &state, + reinterpret_cast<unsigned char*>(chosen), &chosen_len, + sizeof(chosen)); + if (rv != SECSuccess) { + PRErrorCode err = PR_GetError(); + std::cerr << "SSL_GetNextProto failed with error=" << FormatError(err) + << std::endl; + return SECFailure; + } + + assert(chosen_len <= sizeof(chosen)); + if (std::string(chosen, chosen_len) != expectedALPN) { + std::cerr << "Expexted ALPN (" << expectedALPN << ") != Choosen ALPN (" + << std::string(chosen, chosen_len) << ")" << std::endl; + return SECFailure; + } + + return SECSuccess; + } + + SECStatus AdvertiseALPN(std::string alpn) { + return SSL_SetNextProtoNego( + ssl_fd_.get(), reinterpret_cast<const unsigned char*>(alpn.c_str()), + alpn.size()); + } + + /* Certificate Encoding/Decoding Shrinking functions + * See + * https://boringssl.googlesource.com/boringssl/+/master/ssl/test/runner/runner.go#16168 + */ + static SECStatus certCompressionShrinkEncode(const SECItem* input, + SECItem* output) { + if (input == NULL || input->data == NULL) { + PR_SetError(SEC_ERROR_INVALID_ARGS, 0); + return SECFailure; + } + + if (input->len < 2) { + std::cerr << "Certificate is too short. " << std::endl; + PR_SetError(SEC_ERROR_INVALID_ARGS, 0); + return SECFailure; + } + + SECITEM_AllocItem(NULL, output, input->len - 2); + if (output == NULL || output->data == NULL) { + return SECFailure; + } + + /* The shrinking encoding primitive expects the first two bytes of a + * certificate to be equal to 0. */ + if (input->data[0] != 0 || input->data[1] != 0) { + std::cerr << "Cannot compress certificate message." << std::endl; + return SECFailure; + } + + for (size_t i = 0; i < output->len; i++) { + output->data[i] = input->data[i + 2]; + } + return SECSuccess; + } + + static SECStatus certCompressionShrinkDecode( + const SECItem* input, SECItem* output, + size_t expectedLenDecodedCertificate) { + if (input == NULL || input->data == NULL) { + PR_SetError(SEC_ERROR_INVALID_ARGS, 0); + return SECFailure; + } + + SECITEM_AllocItem(NULL, output, input->len + 2); + if (output == NULL || output->data == NULL) { + return SECFailure; + } + + if (expectedLenDecodedCertificate != output->len) { + std::cerr << "Cannot decompress certificate message." << std::endl; + return SECFailure; + } + + output->data[0] = 0; + output->data[1] = 0; + for (size_t i = 0; i < input->len; i++) { + output->data[i + 2] = input->data[i]; + } + + return SECSuccess; + } + + /* Certificate Encoding/Decoding Expanding functions + * See + * https://boringssl.googlesource.com/boringssl/+/master/ssl/test/runner/runner.go#16186 + */ + static SECStatus certCompressionExpandEncode(const SECItem* input, + SECItem* output) { + if (input == NULL || input->data == NULL) { + PR_SetError(SEC_ERROR_INVALID_ARGS, 0); + return SECFailure; + } + + SECITEM_AllocItem(NULL, output, input->len + 4); + + if (output == NULL || output->data == NULL) { + return SECFailure; + } + + output->data[0] = 1; + output->data[1] = 2; + output->data[2] = 3; + output->data[3] = 4; + for (size_t i = 0; i < input->len; i++) { + output->data[i + 4] = input->data[i]; + } + + return SECSuccess; + } + + static SECStatus certCompressionExpandDecode( + const SECItem* input, SECItem* output, + size_t expectedLenDecodedCertificate) { + if (input == NULL || input->data == NULL) { + PR_SetError(SEC_ERROR_INVALID_ARGS, 0); + return SECFailure; + } + + if (input->len < 4) { + PR_SetError(SEC_ERROR_INVALID_ARGS, 0); + std::cerr << "Certificate is too short. " << std::endl; + return SECFailure; + } + + SECITEM_AllocItem(NULL, output, input->len - 4); + + if (output == NULL || output->data == NULL) { + return SECFailure; + } + + /* See the corresponding compression function. */ + if (input->data[0] != 1 || input->data[1] != 2 || input->data[2] != 3 || + input->data[3] != 4) { + std::cerr << "Cannot decompress certificate message." << std::endl; + return SECFailure; + } + + if (expectedLenDecodedCertificate != output->len) { + std::cerr << "Cannot decompress certificate message." << std::endl; + return SECFailure; + } + + for (size_t i = 0; i < output->len; i++) { + output->data[i] = input->data[i + 4]; + } + return SECSuccess; + } + + /* Certificate Encoding/Decoding Random functions + * See + * https://boringssl.googlesource.com/boringssl/+/master/ssl/test/runner/runner.go#16201 + */ + static SECStatus certCompressionRandomEncode(const SECItem* input, + SECItem* output) { + if (input == NULL || input->data == NULL) { + PR_SetError(SEC_ERROR_INVALID_ARGS, 0); + return SECFailure; + } + + SECITEM_AllocItem(NULL, output, input->len + 1); + + if (output == NULL || output->data == NULL) { + return SECFailure; + } + + SECStatus rv = PK11_GenerateRandom(output->data, 1); + + if (rv != SECSuccess) { + std::cerr << "Failed to generate randomness. " << std::endl; + return SECFailure; + } + + for (size_t i = 0; i < input->len; i++) { + output->data[i + 1] = input->data[i]; + } + return SECSuccess; + } + + static SECStatus certCompressionRandomDecode( + const SECItem* input, SECItem* output, + size_t expectedLenDecodedCertificate) { + if (input == NULL || input->data == NULL) { + PR_SetError(SEC_ERROR_INVALID_ARGS, 0); + return SECFailure; + } + + if (input->len < 1) { + PR_SetError(SEC_ERROR_INVALID_ARGS, 0); + std::cerr << "Certificate is too short. " << std::endl; + return SECFailure; + } + SECITEM_AllocItem(NULL, output, input->len - 1); + + if (output == NULL || output->data == NULL) { + return SECFailure; + } + + if (expectedLenDecodedCertificate != output->len) { + std::cerr << "Cannot decompress certificate message." << std::endl; + return SECFailure; + } + + for (size_t i = 0; i < output->len; i++) { + output->data[i] = input->data[i + 1]; + } + return SECSuccess; + } + + SECStatus DoExchange(bool resuming) { + SECStatus rv; + int earlyDataSent = 0; + std::string str; + sslSocket* ss = ssl_FindSocket(ssl_fd_.get()); + if (!ss) { + return SECFailure; + } + if (cfg_.get<bool>("install-cert-compression-algs")) { + SSLCertificateCompressionAlgorithm t = { + (SSLCertificateCompressionAlgorithmID)0xff01, + "shrinkingCompressionAlg", certCompressionShrinkEncode, + certCompressionShrinkDecode}; + + SSLCertificateCompressionAlgorithm t1 = { + (SSLCertificateCompressionAlgorithmID)0xff02, + "expandingCompressionAlg", certCompressionExpandEncode, + certCompressionExpandDecode}; + + SSLCertificateCompressionAlgorithm t2 = { + (SSLCertificateCompressionAlgorithmID)0xff03, "randomCompressionAlg", + certCompressionRandomEncode, certCompressionRandomDecode}; + + SSLExp_SetCertificateCompressionAlgorithm(ssl_fd_.get(), t); + SSLExp_SetCertificateCompressionAlgorithm(ssl_fd_.get(), t1); + SSLExp_SetCertificateCompressionAlgorithm(ssl_fd_.get(), t2); + } + + /* Apply resumption SSL options (if any). */ + if (resuming) { + /* Client options */ + if (!cfg_.get<bool>("server")) { + auto resumeEchConfigList = + cfg_.get<std::string>("on-resume-ech-config-list"); + if (!resumeEchConfigList.empty()) { + unsigned int binLen; + auto bin = ATOB_AsciiToData(resumeEchConfigList.c_str(), &binLen); + rv = SSLExp_SetClientEchConfigs(ssl_fd_.get(), bin, binLen); + if (rv != SECSuccess) { + PRErrorCode err = PR_GetError(); + std::cerr << "Setting up resumption ECH configs failed with error=" + << err << FormatError(err) << std::endl; + } + free(bin); + } + + str = cfg_.get<std::string>("on-resume-advertise-alpn"); + if (!str.empty()) { + if (AdvertiseALPN(str) != SECSuccess) { + PRErrorCode err = PR_GetError(); + std::cerr << "Setting up resumption ALPN failed with error=" << err + << FormatError(err) << std::endl; + } + } + } + + } else { /* Explicitly not on resume (on initial) */ + /* Client options */ + if (!cfg_.get<bool>("server")) { + str = cfg_.get<std::string>("on-initial-advertise-alpn"); + if (!str.empty()) { + if (AdvertiseALPN(str) != SECSuccess) { + PRErrorCode err = PR_GetError(); + std::cerr << "Setting up initial ALPN failed with error=" << err + << FormatError(err) << std::endl; + } + } + } + } + + /* If client send ClientHello. */ + if (!cfg_.get<bool>("server")) { + ssl_Get1stHandshakeLock(ss); + rv = ssl_BeginClientHandshake(ss); + ssl_Release1stHandshakeLock(ss); + if (rv != SECSuccess) { + PRErrorCode err = PR_GetError(); + std::cerr << "Handshake failed with error=" << err << FormatError(err) + << std::endl; + return SECFailure; + } + + /* If the client is resuming. */ + if (ss->statelessResume) { + SSLPreliminaryChannelInfo pinfo; + rv = SSL_GetPreliminaryChannelInfo(ssl_fd_.get(), &pinfo, + sizeof(SSLPreliminaryChannelInfo)); + if (rv != SECSuccess) { + PRErrorCode err = PR_GetError(); + std::cerr << "SSL_GetPreliminaryChannelInfo failed with " << err + << std::endl; + return SECFailure; + } + + /* Check that the used ticket supports early data. */ + if (cfg_.get<bool>("expect-ticket-supports-early-data")) { + if (!pinfo.ticketSupportsEarlyData) { + std::cerr << "Expected ticket to support EarlyData" << std::endl; + return SECFailure; + } + } + + /* If the client should send EarlyData. */ + if (cfg_.get<bool>("on-resume-shim-writes-first")) { + earlyDataSent = + ssl_SecureWrite(ss, kBogoDummyData, sizeof(kBogoDummyData)); + if (earlyDataSent < 0) { + std::cerr << "Sending of EarlyData failed" << std::endl; + return SECFailure; + } + } + + if (cfg_.get<bool>("expect-no-offer-early-data")) { + if (earlyDataSent) { + std::cerr << "Unexpectedly offered EarlyData" << std::endl; + return SECFailure; + } + } + } + } + + /* As server start, as client continue handshake. */ + rv = Handshake(); + + /* Retry config evaluation must be done before error handling since + * handshake failure is intended on ech_required tests. */ + if (cfg_.get<bool>("expect-no-ech-retry-configs")) { + if (ss->xtnData.ech && ss->xtnData.ech->retryConfigsValid) { + std::cerr << "Unexpectedly received ECH retry configs" << std::endl; + return SECFailure; + } + } + + /* If given, verify received retry configs before error handling. */ + std::string expectedRCs64 = + cfg_.get<std::string>("expect-ech-retry-configs"); + if (!expectedRCs64.empty()) { + SECItem receivedRCs; + + /* Get received RetryConfigs. */ + if (SSLExp_GetEchRetryConfigs(ssl_fd_.get(), &receivedRCs) != + SECSuccess) { + std::cerr << "Failed to get ECH retry configs." << std::endl; + return SECFailure; + } + + /* (Re-)Encode received configs to compare with expected ASCII string. */ + std::string receivedRCs64( + BTOA_DataToAscii(receivedRCs.data, receivedRCs.len)); + /* Remove newlines (for unknown reasons) added during b64 encoding. */ + StringRemoveNewlines(receivedRCs64); + + if (receivedRCs64 != expectedRCs64) { + std::cerr << "Received ECH retry configs did not match expected retry " + "configs." + << std::endl; + return SECFailure; + } + } + + /* Check if handshake succeeded. */ + if (rv != SECSuccess) { + PRErrorCode err = PR_GetError(); + std::cerr << "Handshake failed with error=" << err << FormatError(err) + << std::endl; + return SECFailure; + } + + /* If parts of data was sent as EarlyData make sure to send possibly + * unsent rest. This is required to pass bogo resumption tests. */ + if (earlyDataSent && earlyDataSent < int(sizeof(kBogoDummyData))) { + int toSend = sizeof(kBogoDummyData) - earlyDataSent; + earlyDataSent = + ssl_SecureWrite(ss, &kBogoDummyData[earlyDataSent], toSend); + if (earlyDataSent != toSend) { + std::cerr + << "Could not send rest of EarlyData after handshake completion" + << std::endl; + return SECFailure; + } + } + + if (cfg_.get<bool>("write-then-read")) { + rv = WriteRead(); + if (rv != SECSuccess) { + PRErrorCode err = PR_GetError(); + std::cerr << "WriteRead failed with error=" << FormatError(err) + << std::endl; + return SECFailure; + } + } else { + rv = ReadWrite(); + if (rv != SECSuccess) { + PRErrorCode err = PR_GetError(); + std::cerr << "ReadWrite failed with error=" << FormatError(err) + << std::endl; + return SECFailure; + } + } + + SSLChannelInfo info; + rv = SSL_GetChannelInfo(ssl_fd_.get(), &info, sizeof(info)); + if (rv != SECSuccess) { + PRErrorCode err = PR_GetError(); + std::cerr << "SSL_GetChannelInfo failed with error=" << FormatError(err) + << std::endl; + return SECFailure; + } + + auto sig_alg = cfg_.get<int>("expect-peer-signature-algorithm"); + if (sig_alg) { + auto expected = static_cast<SSLSignatureScheme>(sig_alg); + if (info.signatureScheme != expected) { + std::cerr << "Unexpected signature scheme" << std::endl; + return SECFailure; + } + } + + auto curve_id = cfg_.get<int>("expect-curve-id"); + if (curve_id) { + auto expected = static_cast<SSLNamedGroup>(curve_id); + if (info.keaGroup != expected && !(info.keaGroup == ssl_grp_none && + info.originalKeaGroup == expected)) { + std::cerr << "Unexpected named group" << std::endl; + return SECFailure; + } + } + + if (cfg_.get<bool>("expect-ech-accept")) { + if (!info.echAccepted) { + std::cerr << "Expected ECH" << std::endl; + return SECFailure; + } + } + + if (cfg_.get<bool>("expect-hrr")) { + if (!ss->ssl3.hs.helloRetry) { + std::cerr << "Expected HRR" << std::endl; + return SECFailure; + } + } + + str = cfg_.get<std::string>("expect-alpn"); + if (!str.empty()) { + if (CheckALPN(str) != SECSuccess) { + std::cerr << "Unexpected ALPN" << std::endl; + return SECFailure; + } + } + + /* if resumed */ + if (info.resumed) { + if (cfg_.get<bool>("expect-session-miss")) { + std::cerr << "Expected reject Resume" << std::endl; + return SECFailure; + } + + if (cfg_.get<bool>("on-resume-expect-ech-accept")) { + if (!info.echAccepted) { + std::cerr << "Expected ECH on Resume" << std::endl; + return SECFailure; + } + } + + if (cfg_.get<bool>("on-resume-expect-reject-early-data")) { + if (info.earlyDataAccepted) { + std::cerr << "Expected reject EarlyData" << std::endl; + return SECFailure; + } + } + if (cfg_.get<bool>("on-resume-expect-accept-early-data")) { + if (!info.earlyDataAccepted) { + std::cerr << "Expected accept EarlyData" << std::endl; + return SECFailure; + } + } + + /* On successfully resumed connection. */ + if (info.earlyDataAccepted) { + str = cfg_.get<std::string>("on-resume-expect-alpn"); + if (!str.empty()) { + if (CheckALPN(str) != SECSuccess) { + std::cerr << "Unexpected ALPN on Resume" << std::endl; + return SECFailure; + } + } else { /* No real resume but new handshake on EarlyData rejection. */ + /* On Retry... */ + str = cfg_.get<std::string>("on-retry-expect-alpn"); + if (!str.empty()) { + if (CheckALPN(str) != SECSuccess) { + std::cerr << "Unexpected ALPN on HRR" << std::endl; + return SECFailure; + } + } + } + } + + } else { /* Explicitly not on resume */ + if (cfg_.get<bool>("on-initial-expect-ech-accept")) { + if (!info.echAccepted) { + std::cerr << "Expected ECH accept on initial connection" << std::endl; + return SECFailure; + } + } + + str = cfg_.get<std::string>("on-initial-expect-alpn"); + if (!str.empty()) { + if (CheckALPN(str) != SECSuccess) { + std::cerr << "Unexpected ALPN on Initial" << std::endl; + return SECFailure; + } + } + } + + return SECSuccess; + } + + private: + const Config& cfg_; + ScopedPRFileDesc pr_fd_; + ScopedPRFileDesc ssl_fd_; + ScopedCERTCertificate cert_; + ScopedSECKEYPrivateKey key_; +}; + +std::unique_ptr<const Config> ReadConfig(int argc, char** argv) { + std::unique_ptr<Config> cfg(new Config()); + + cfg->AddEntry<int>("port", 0); + cfg->AddEntry<bool>("ipv6", false); + cfg->AddEntry<int>("shim-id", 0); + cfg->AddEntry<bool>("server", false); + cfg->AddEntry<int>("resume-count", 0); + cfg->AddEntry<std::string>("key-file", ""); + cfg->AddEntry<std::string>("cert-file", ""); + cfg->AddEntry<int>("min-version", 0); + cfg->AddEntry<int>("max-version", 0xffff); + for (auto flag : kVersionDisableFlags) { + cfg->AddEntry<bool>(flag, false); + } + cfg->AddEntry<bool>("fallback-scsv", false); + cfg->AddEntry<bool>("false-start", false); + cfg->AddEntry<bool>("enable-ocsp-stapling", false); + cfg->AddEntry<bool>("write-then-read", false); + cfg->AddEntry<bool>("require-any-client-certificate", false); + cfg->AddEntry<bool>("verify-peer", false); + cfg->AddEntry<bool>("is-handshaker-supported", false); + cfg->AddEntry<std::string>("handshaker-path", ""); // Ignore this + cfg->AddEntry<std::string>("advertise-alpn", ""); + cfg->AddEntry<std::string>("on-initial-advertise-alpn", ""); + cfg->AddEntry<std::string>("on-resume-advertise-alpn", ""); + cfg->AddEntry<std::string>("expect-alpn", ""); + cfg->AddEntry<std::string>("on-initial-expect-alpn", ""); + cfg->AddEntry<std::string>("on-resume-expect-alpn", ""); + cfg->AddEntry<std::string>("on-retry-expect-alpn", ""); + cfg->AddEntry<std::vector<int>>("signing-prefs", std::vector<int>()); + cfg->AddEntry<std::vector<int>>("verify-prefs", std::vector<int>()); + cfg->AddEntry<int>("expect-peer-signature-algorithm", 0); + cfg->AddEntry<std::string>("nss-cipher", ""); + cfg->AddEntry<std::string>("host-name", ""); + cfg->AddEntry<std::string>("ech-config-list", ""); + cfg->AddEntry<std::string>("on-resume-ech-config-list", ""); + cfg->AddEntry<bool>("expect-ech-accept", false); + cfg->AddEntry<bool>("expect-hrr", false); + cfg->AddEntry<bool>("enable-ech-grease", false); + cfg->AddEntry<bool>("enable-early-data", false); + cfg->AddEntry<bool>("enable-grease", false); + cfg->AddEntry<bool>("permute-extensions", false); + cfg->AddEntry<bool>("on-resume-expect-reject-early-data", false); + cfg->AddEntry<bool>("on-resume-expect-accept-early-data", false); + cfg->AddEntry<bool>("expect-ticket-supports-early-data", false); + cfg->AddEntry<bool>("on-resume-shim-writes-first", + false); // Always means 0Rtt write + cfg->AddEntry<bool>("shim-writes-first", + false); // Unimplemented since not required so far + cfg->AddEntry<bool>("expect-session-miss", false); + cfg->AddEntry<std::string>("expect-ech-retry-configs", ""); + cfg->AddEntry<bool>("expect-no-ech-retry-configs", false); + cfg->AddEntry<bool>("on-initial-expect-ech-accept", false); + cfg->AddEntry<bool>("on-resume-expect-ech-accept", false); + cfg->AddEntry<bool>("expect-no-offer-early-data", false); + /* NSS does not support earlydata rejection reason logging => Ignore. */ + cfg->AddEntry<std::string>("on-resume-expect-early-data-reason", "none"); + cfg->AddEntry<std::string>("on-retry-expect-early-data-reason", "none"); + cfg->AddEntry<std::vector<int>>("curves", std::vector<int>()); + cfg->AddEntry<int>("expect-curve-id", 0); + cfg->AddEntry<bool>("install-cert-compression-algs", false); + + auto rv = cfg->ParseArgs(argc, argv); + switch (rv) { + case Config::kOK: + break; + case Config::kUnknownFlag: + exitCodeUnimplemented = true; + default: + return nullptr; + } + + // Needed to change to std::unique_ptr<const Config> + return std::move(cfg); +} + +bool RunCycle(std::unique_ptr<const Config>& cfg, bool resuming = false) { + std::unique_ptr<TestAgent> agent(TestAgent::Create(*cfg)); + return agent && agent->DoExchange(resuming) == SECSuccess; +} + +int GetExitCode(bool success) { + if (exitCodeUnimplemented) { + return 89; + } + + if (success) { + return 0; + } + + return 1; +} + +int main(int argc, char** argv) { + std::unique_ptr<const Config> cfg = ReadConfig(argc, argv); + if (!cfg) { + return GetExitCode(false); + } + + if (cfg->get<bool>("is-handshaker-supported")) { + std::cout << "No\n"; + return 0; + } + + if (cfg->get<bool>("server")) { + if (SSL_ConfigServerSessionIDCache(1024, 0, 0, ".") != SECSuccess) { + std::cerr << "Couldn't configure session cache\n"; + return 1; + } + } + + if (NSS_NoDB_Init(nullptr) != SECSuccess) { + return 1; + } + + // Run a single test cycle. + bool success = RunCycle(cfg); + + int resume_count = cfg->get<int>("resume-count"); + while (success && resume_count-- > 0) { + std::cout << "Resuming" << std::endl; + success = RunCycle(cfg, true); + } + + SSL_ClearSessionCache(); + + if (cfg->get<bool>("server")) { + SSL_ShutdownServerSessionIDCache(); + } + + if (NSS_Shutdown() != SECSuccess) { + success = false; + } + + return GetExitCode(success); +} |