diff options
Diffstat (limited to 'security/manager/ssl/tests/unit/tlsserver/lib')
5 files changed, 1048 insertions, 0 deletions
diff --git a/security/manager/ssl/tests/unit/tlsserver/lib/OCSPCommon.cpp b/security/manager/ssl/tests/unit/tlsserver/lib/OCSPCommon.cpp new file mode 100644 index 0000000000..be9a9af9b1 --- /dev/null +++ b/security/manager/ssl/tests/unit/tlsserver/lib/OCSPCommon.cpp @@ -0,0 +1,204 @@ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#include "OCSPCommon.h" + +#include <stdio.h> + +#include "mozpkix/test/pkixtestutil.h" +#include "mozpkix/test/pkixtestnss.h" +#include "TLSServer.h" +#include "secder.h" +#include "secerr.h" + +using namespace mozilla; +using namespace mozilla::pkix; +using namespace mozilla::pkix::test; +using namespace mozilla::test; + +static TestKeyPair* CreateTestKeyPairFromCert( + const UniqueCERTCertificate& cert) { + ScopedSECKEYPrivateKey privateKey(PK11_FindKeyByAnyCert(cert.get(), nullptr)); + if (!privateKey) { + return nullptr; + } + ScopedSECKEYPublicKey publicKey(CERT_ExtractPublicKey(cert.get())); + if (!publicKey) { + return nullptr; + } + return CreateTestKeyPair(RSA_PKCS1(), publicKey, privateKey); +} + +SECItemArray* GetOCSPResponseForType(OCSPResponseType aORT, + const UniqueCERTCertificate& aCert, + const UniquePLArenaPool& aArena, + const char* aAdditionalCertName, + time_t aThisUpdateSkew) { + MOZ_ASSERT(aArena); + MOZ_ASSERT(aCert); + // Note: |aAdditionalCertName| may or may not need to be non-null depending + // on the |aORT| value given. + + if (aORT == ORTNone) { + if (gDebugLevel >= DEBUG_WARNINGS) { + fprintf(stderr, + "GetOCSPResponseForType called with type ORTNone, " + "which makes no sense.\n"); + } + return nullptr; + } + + if (aORT == ORTEmpty) { + SECItemArray* arr = SECITEM_AllocArray(aArena.get(), nullptr, 1); + arr->items[0].data = nullptr; + arr->items[0].len = 0; + return arr; + } + + time_t now = time(nullptr) + aThisUpdateSkew; + time_t oldNow = now - (8 * Time::ONE_DAY_IN_SECONDS); + + mozilla::UniqueCERTCertificate cert(CERT_DupCertificate(aCert.get())); + + if (aORT == ORTGoodOtherCert) { + cert.reset(PK11_FindCertFromNickname(aAdditionalCertName, nullptr)); + if (!cert) { + PrintPRError("PK11_FindCertFromNickname failed"); + return nullptr; + } + } + // XXX CERT_FindCertIssuer uses the old, deprecated path-building logic + mozilla::UniqueCERTCertificate issuerCert( + CERT_FindCertIssuer(aCert.get(), PR_Now(), certUsageSSLCA)); + if (!issuerCert) { + PrintPRError("CERT_FindCertIssuer failed"); + return nullptr; + } + Input issuer; + if (issuer.Init(cert->derIssuer.data, cert->derIssuer.len) != Success) { + return nullptr; + } + Input issuerPublicKey; + if (issuerPublicKey.Init(issuerCert->derPublicKey.data, + issuerCert->derPublicKey.len) != Success) { + return nullptr; + } + Input serialNumber; + if (serialNumber.Init(cert->serialNumber.data, cert->serialNumber.len) != + Success) { + return nullptr; + } + CertID certID(issuer, issuerPublicKey, serialNumber); + OCSPResponseContext context(certID, now); + + mozilla::UniqueCERTCertificate signerCert; + if (aORT == ORTGoodOtherCA || aORT == ORTDelegatedIncluded || + aORT == ORTDelegatedIncludedLast || aORT == ORTDelegatedMissing || + aORT == ORTDelegatedMissingMultiple) { + signerCert.reset(PK11_FindCertFromNickname(aAdditionalCertName, nullptr)); + if (!signerCert) { + PrintPRError("PK11_FindCertFromNickname failed"); + return nullptr; + } + } + + ByteString certs[5]; + + if (aORT == ORTDelegatedIncluded) { + certs[0].assign(signerCert->derCert.data, signerCert->derCert.len); + context.certs = certs; + } + if (aORT == ORTDelegatedIncludedLast || aORT == ORTDelegatedMissingMultiple) { + certs[0].assign(issuerCert->derCert.data, issuerCert->derCert.len); + certs[1].assign(cert->derCert.data, cert->derCert.len); + certs[2].assign(issuerCert->derCert.data, issuerCert->derCert.len); + if (aORT != ORTDelegatedMissingMultiple) { + certs[3].assign(signerCert->derCert.data, signerCert->derCert.len); + } + context.certs = certs; + } + + switch (aORT) { + case ORTMalformed: + context.responseStatus = 1; + break; + case ORTSrverr: + context.responseStatus = 2; + break; + case ORTTryLater: + context.responseStatus = 3; + break; + case ORTNeedsSig: + context.responseStatus = 5; + break; + case ORTUnauthorized: + context.responseStatus = 6; + break; + default: + // context.responseStatus is 0 in all other cases, and it has + // already been initialized in the constructor. + break; + } + if (aORT == ORTSkipResponseBytes) { + context.skipResponseBytes = true; + } + if (aORT == ORTExpired || aORT == ORTExpiredFreshCA || + aORT == ORTRevokedOld || aORT == ORTUnknownOld) { + context.thisUpdate = oldNow; + context.nextUpdate = oldNow + Time::ONE_DAY_IN_SECONDS; + } + if (aORT == ORTLongValidityAlmostExpired) { + context.thisUpdate = now - (320 * Time::ONE_DAY_IN_SECONDS); + } + if (aORT == ORTAncientAlmostExpired) { + context.thisUpdate = now - (640 * Time::ONE_DAY_IN_SECONDS); + } + if (aORT == ORTRevoked || aORT == ORTRevokedOld) { + context.certStatus = 1; + } + if (aORT == ORTUnknown || aORT == ORTUnknownOld) { + context.certStatus = 2; + } + if (aORT == ORTBadSignature) { + context.badSignature = true; + } + OCSPResponseExtension extension; + if (aORT == ORTCriticalExtension || aORT == ORTNoncriticalExtension) { + // python DottedOIDToCode.py --tlv + // some-Mozilla-OID 1.3.6.1.4.1.13769.666.666.666.1.500.9.2 + static const uint8_t tlv_some_Mozilla_OID[] = { + 0x06, 0x12, 0x2b, 0x06, 0x01, 0x04, 0x01, 0xeb, 0x49, 0x85, + 0x1a, 0x85, 0x1a, 0x85, 0x1a, 0x01, 0x83, 0x74, 0x09, 0x02}; + + extension.id.assign(tlv_some_Mozilla_OID, sizeof(tlv_some_Mozilla_OID)); + extension.critical = (aORT == ORTCriticalExtension); + extension.value.push_back(0x05); // tag: NULL + extension.value.push_back(0x00); // length: 0 + extension.next = nullptr; + context.responseExtensions = &extension; + } + if (aORT == ORTEmptyExtensions) { + context.includeEmptyExtensions = true; + } + + if (!signerCert) { + signerCert.reset(CERT_DupCertificate(issuerCert.get())); + } + context.signerKeyPair.reset(CreateTestKeyPairFromCert(signerCert)); + if (!context.signerKeyPair) { + PrintPRError("PK11_FindKeyByAnyCert failed"); + return nullptr; + } + + ByteString response(CreateEncodedOCSPResponse(context)); + if (ENCODING_FAILED(response)) { + PrintPRError("CreateEncodedOCSPResponse failed"); + return nullptr; + } + + SECItem item = {siBuffer, const_cast<uint8_t*>(response.data()), + static_cast<unsigned int>(response.length())}; + SECItemArray arr = {&item, 1}; + return SECITEM_DupArray(aArena.get(), &arr); +} diff --git a/security/manager/ssl/tests/unit/tlsserver/lib/OCSPCommon.h b/security/manager/ssl/tests/unit/tlsserver/lib/OCSPCommon.h new file mode 100644 index 0000000000..c72eae6a8e --- /dev/null +++ b/security/manager/ssl/tests/unit/tlsserver/lib/OCSPCommon.h @@ -0,0 +1,66 @@ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ + +// Implements generating OCSP responses of various types. Used by the +// programs in tlsserver/cmd. + +#ifndef OCSPCommon_h +#define OCSPCommon_h + +#include "ScopedNSSTypes.h" +#include "certt.h" +#include "seccomon.h" + +enum OCSPResponseType { + ORTNull = 0, + ORTGood, // the certificate is good + ORTRevoked, // the certificate has been revoked + ORTRevokedOld, // same, but the response is old + ORTUnknown, // the responder doesn't know if the cert is good + ORTUnknownOld, // same, but the response is old + ORTGoodOtherCert, // the response references a different certificate + ORTGoodOtherCA, // the wrong CA has signed the response + ORTExpired, // the signature on the response has expired + ORTExpiredFreshCA, // fresh signature, but old validity period + ORTNone, // no stapled response + ORTEmpty, // an empty stapled response + ORTMalformed, // the response from the responder was malformed + ORTSrverr, // the response indicates there was a server error + ORTTryLater, // the responder replied with "try again later" + ORTNeedsSig, // the response needs a signature + ORTUnauthorized, // the responder is not authorized for this certificate + ORTBadSignature, // the response has a signature that does not verify + ORTSkipResponseBytes, // the response does not include responseBytes + ORTCriticalExtension, // the response includes a critical extension + ORTNoncriticalExtension, // the response includes an extension that is not + // critical + ORTEmptyExtensions, // the response includes a SEQUENCE OF Extension that is + // empty + ORTDelegatedIncluded, // the response is signed by an included delegated + // responder + ORTDelegatedIncludedLast, // same, but multiple other certificates are + // included + ORTDelegatedMissing, // the response is signed by a not included delegated + // responder + ORTDelegatedMissingMultiple, // same, but multiple other certificates are + // included + ORTLongValidityAlmostExpired, // a good response, but that was generated a + // almost a year ago + ORTAncientAlmostExpired, // a good response, with a validity of almost two + // years almost expiring +}; + +struct OCSPHost { + const char* mHostName; + OCSPResponseType mORT; + const char* mAdditionalCertName; // useful for ORTGoodOtherCert, etc. + const char* mServerCertName; +}; + +SECItemArray* GetOCSPResponseForType( + OCSPResponseType aORT, const mozilla::UniqueCERTCertificate& aCert, + const mozilla::UniquePLArenaPool& aArena, const char* aAdditionalCertName, + time_t aThisUpdateSkew); + +#endif // OCSPCommon_h diff --git a/security/manager/ssl/tests/unit/tlsserver/lib/TLSServer.cpp b/security/manager/ssl/tests/unit/tlsserver/lib/TLSServer.cpp new file mode 100644 index 0000000000..463f865411 --- /dev/null +++ b/security/manager/ssl/tests/unit/tlsserver/lib/TLSServer.cpp @@ -0,0 +1,669 @@ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#include "TLSServer.h" + +#include <stdio.h> +#include <string> +#include <thread> +#include <vector> +#include <fstream> +#include <iostream> +#ifdef XP_WIN +# include <windows.h> +#else +# include <unistd.h> +#endif + +#include <utility> + +#include "base64.h" +#include "mozilla/Sprintf.h" +#include "nspr.h" +#include "nss.h" +#include "plarenas.h" +#include "prenv.h" +#include "prerror.h" +#include "prnetdb.h" +#include "prtime.h" +#include "ssl.h" +#include "sslproto.h" + +namespace mozilla { +namespace test { + +static const uint16_t LISTEN_PORT = 8443; + +DebugLevel gDebugLevel = DEBUG_ERRORS; +uint16_t gCallbackPort = 0; + +const std::string kPEMBegin = "-----BEGIN "; +const std::string kPEMEnd = "-----END "; +const char DEFAULT_CERT_NICKNAME[] = "default-ee"; + +struct Connection { + PRFileDesc* mSocket; + char mByte; + + explicit Connection(PRFileDesc* aSocket); + ~Connection(); +}; + +Connection::Connection(PRFileDesc* aSocket) : mSocket(aSocket), mByte(0) {} + +Connection::~Connection() { + if (mSocket) { + PR_Close(mSocket); + } +} + +void PrintPRError(const char* aPrefix) { + const char* err = PR_ErrorToName(PR_GetError()); + if (err) { + if (gDebugLevel >= DEBUG_ERRORS) { + fprintf(stderr, "%s: %s\n", aPrefix, err); + } + } else { + if (gDebugLevel >= DEBUG_ERRORS) { + fprintf(stderr, "%s\n", aPrefix); + } + } +} + +// This decodes a PEM file into `item`. The line endings need to be +// UNIX-style, or there will be cross-platform issues. +static bool DecodePEMFile(const std::string& filename, SECItem* item) { + std::ifstream in(filename); + if (in.bad()) { + return false; + } + + char buf[1024]; + in.getline(buf, sizeof(buf)); + if (in.bad()) { + return false; + } + + if (strncmp(buf, kPEMBegin.c_str(), kPEMBegin.size()) != 0) { + return false; + } + + std::string value; + for (;;) { + in.getline(buf, sizeof(buf)); + if (in.bad()) { + return false; + } + + if (strncmp(buf, kPEMEnd.c_str(), kPEMEnd.size()) == 0) { + break; + } + + value += buf; + } + + unsigned int binLength; + UniquePORTString bin(BitwiseCast<char*, unsigned char*>( + ATOB_AsciiToData(value.c_str(), &binLength))); + if (!bin || binLength == 0) { + PrintPRError("ATOB_AsciiToData failed"); + return false; + } + + if (SECITEM_AllocItem(nullptr, item, binLength) == nullptr) { + return false; + } + + PORT_Memcpy(item->data, bin.get(), binLength); + return true; +} + +static SECStatus AddKeyFromFile(const std::string& path, + const std::string& filename) { + ScopedAutoSECItem item; + + std::string file = path + "/" + filename; + if (!DecodePEMFile(file, &item)) { + return SECFailure; + } + + UniquePK11SlotInfo slot(PK11_GetInternalKeySlot()); + if (!slot) { + PrintPRError("PK11_GetInternalKeySlot failed"); + return SECFailure; + } + + if (PK11_NeedUserInit(slot.get())) { + if (PK11_InitPin(slot.get(), nullptr, nullptr) != SECSuccess) { + PrintPRError("PK11_InitPin failed"); + return SECFailure; + } + } + + SECKEYPrivateKey* privateKey = nullptr; + SECItem nick = {siBuffer, + BitwiseCast<unsigned char*, const char*>(filename.data()), + static_cast<unsigned int>(filename.size())}; + if (PK11_ImportDERPrivateKeyInfoAndReturnKey( + slot.get(), &item, &nick, nullptr, true, false, KU_ALL, &privateKey, + nullptr) != SECSuccess) { + PrintPRError("PK11_ImportDERPrivateKeyInfoAndReturnKey failed"); + return SECFailure; + } + + SECKEY_DestroyPrivateKey(privateKey); + return SECSuccess; +} + +static SECStatus AddCertificateFromFile(const std::string& path, + const std::string& filename) { + ScopedAutoSECItem item; + + std::string file = path + "/" + filename; + if (!DecodePEMFile(file, &item)) { + return SECFailure; + } + + UniqueCERTCertificate cert(CERT_NewTempCertificate( + CERT_GetDefaultCertDB(), &item, nullptr, false, true)); + if (!cert) { + PrintPRError("CERT_NewTempCertificate failed"); + return SECFailure; + } + + UniquePK11SlotInfo slot(PK11_GetInternalKeySlot()); + if (!slot) { + PrintPRError("PK11_GetInternalKeySlot failed"); + return SECFailure; + } + // The nickname is the filename without '.pem'. + std::string nickname = filename.substr(0, filename.length() - 4); + SECStatus rv = PK11_ImportCert(slot.get(), cert.get(), CK_INVALID_HANDLE, + nickname.c_str(), false); + if (rv != SECSuccess) { + PrintPRError("PK11_ImportCert failed"); + return rv; + } + + return SECSuccess; +} + +SECStatus LoadCertificatesAndKeys(const char* basePath) { + // The NSS cert DB path could have been specified as "sql:path". Trim off + // the leading "sql:" if so. + if (strncmp(basePath, "sql:", 4) == 0) { + basePath = basePath + 4; + } + + UniquePRDir fdDir(PR_OpenDir(basePath)); + if (!fdDir) { + PrintPRError("PR_OpenDir failed"); + return SECFailure; + } + // On the B2G ICS emulator, operations taken in AddCertificateFromFile + // appear to interact poorly with readdir (more specifically, something is + // causing readdir to never return null - it indefinitely loops through every + // file in the directory, which causes timeouts). Rather than waste more time + // chasing this down, loading certificates and keys happens in two phases: + // filename collection and then loading. (This is probably a good + // idea anyway because readdir isn't reentrant. Something could change later + // such that it gets called as a result of calling AddCertificateFromFile or + // AddKeyFromFile.) + std::vector<std::string> certificates; + std::vector<std::string> keys; + for (PRDirEntry* dirEntry = PR_ReadDir(fdDir.get(), PR_SKIP_BOTH); dirEntry; + dirEntry = PR_ReadDir(fdDir.get(), PR_SKIP_BOTH)) { + size_t nameLength = strlen(dirEntry->name); + if (nameLength > 4) { + if (strncmp(dirEntry->name + nameLength - 4, ".pem", 4) == 0) { + certificates.push_back(dirEntry->name); + } else if (strncmp(dirEntry->name + nameLength - 4, ".key", 4) == 0) { + keys.push_back(dirEntry->name); + } + } + } + SECStatus rv; + for (std::string& certificate : certificates) { + rv = AddCertificateFromFile(basePath, certificate.c_str()); + if (rv != SECSuccess) { + return rv; + } + } + for (std::string& key : keys) { + rv = AddKeyFromFile(basePath, key.c_str()); + if (rv != SECSuccess) { + return rv; + } + } + return SECSuccess; +} + +SECStatus InitializeNSS(const char* nssCertDBDir) { + // Try initializing an existing DB. + if (NSS_Init(nssCertDBDir) == SECSuccess) { + return SECSuccess; + } + + // Create a new DB if there is none... + SECStatus rv = NSS_Initialize(nssCertDBDir, nullptr, nullptr, nullptr, 0); + if (rv != SECSuccess) { + return rv; + } + + // ...and load all certificates into it. + return LoadCertificatesAndKeys(nssCertDBDir); +} + +nsresult SendAll(PRFileDesc* aSocket, const char* aData, size_t aDataLen) { + if (gDebugLevel >= DEBUG_VERBOSE) { + fprintf(stderr, "sending '%s'\n", aData); + } + + while (aDataLen > 0) { + int32_t bytesSent = + PR_Send(aSocket, aData, aDataLen, 0, PR_INTERVAL_NO_TIMEOUT); + if (bytesSent == -1) { + PrintPRError("PR_Send failed"); + return NS_ERROR_FAILURE; + } + + aDataLen -= bytesSent; + aData += bytesSent; + } + + return NS_OK; +} + +nsresult ReplyToRequest(Connection* aConn) { + // For debugging purposes, SendAll can print out what it's sending. + // So, any strings we give to it to send need to be null-terminated. + char buf[2] = {aConn->mByte, 0}; + return SendAll(aConn->mSocket, buf, 1); +} + +nsresult SetupTLS(Connection* aConn, PRFileDesc* aModelSocket) { + PRFileDesc* sslSocket = SSL_ImportFD(aModelSocket, aConn->mSocket); + if (!sslSocket) { + PrintPRError("SSL_ImportFD failed"); + return NS_ERROR_FAILURE; + } + aConn->mSocket = sslSocket; + + SSL_OptionSet(sslSocket, SSL_SECURITY, true); + SSL_OptionSet(sslSocket, SSL_HANDSHAKE_AS_CLIENT, false); + SSL_OptionSet(sslSocket, SSL_HANDSHAKE_AS_SERVER, true); + + SSL_ResetHandshake(sslSocket, /* asServer */ 1); + + return NS_OK; +} + +nsresult ReadRequest(Connection* aConn) { + int32_t bytesRead = + PR_Recv(aConn->mSocket, &aConn->mByte, 1, 0, PR_INTERVAL_NO_TIMEOUT); + if (bytesRead < 0) { + PrintPRError("PR_Recv failed"); + return NS_ERROR_FAILURE; + } else if (bytesRead == 0) { + PR_SetError(PR_IO_ERROR, 0); + PrintPRError("PR_Recv EOF in ReadRequest"); + return NS_ERROR_FAILURE; + } else { + if (gDebugLevel >= DEBUG_VERBOSE) { + fprintf(stderr, "read '0x%hhx'\n", aConn->mByte); + } + } + return NS_OK; +} + +void HandleConnection(PRFileDesc* aSocket, + const UniquePRFileDesc& aModelSocket) { + Connection conn(aSocket); + nsresult rv = SetupTLS(&conn, aModelSocket.get()); + if (NS_FAILED(rv)) { + PR_SetError(PR_INVALID_STATE_ERROR, 0); + PrintPRError("PR_Recv failed"); + exit(1); + } + + // TODO: On tests that are expected to fail (e.g. due to a revoked + // certificate), the client will close the connection wtihout sending us the + // request byte. In those cases, we should keep going. But, in the cases + // where the connection is supposed to suceed, we should verify that we + // successfully receive the request and send the response. + rv = ReadRequest(&conn); + if (NS_SUCCEEDED(rv)) { + rv = ReplyToRequest(&conn); + } +} + +// returns 0 on success, non-zero on error +int DoCallback() { + UniquePRFileDesc socket(PR_NewTCPSocket()); + if (!socket) { + PrintPRError("PR_NewTCPSocket failed"); + return 1; + } + + PRNetAddr addr; + PR_InitializeNetAddr(PR_IpAddrLoopback, gCallbackPort, &addr); + if (PR_Connect(socket.get(), &addr, PR_INTERVAL_NO_TIMEOUT) != PR_SUCCESS) { + PrintPRError("PR_Connect failed"); + return 1; + } + + const char* request = "GET / HTTP/1.0\r\n\r\n"; + SendAll(socket.get(), request, strlen(request)); + char buf[4096]; + memset(buf, 0, sizeof(buf)); + int32_t bytesRead = + PR_Recv(socket.get(), buf, sizeof(buf) - 1, 0, PR_INTERVAL_NO_TIMEOUT); + if (bytesRead < 0) { + PrintPRError("PR_Recv failed 1"); + return 1; + } + if (bytesRead == 0) { + fprintf(stderr, "PR_Recv eof 1\n"); + return 1; + } + fprintf(stderr, "%s\n", buf); + return 0; +} + +SECStatus ConfigSecureServerWithNamedCert( + PRFileDesc* fd, const char* certName, + /*optional*/ UniqueCERTCertificate* certOut, + /*optional*/ SSLKEAType* keaOut, + /*optional*/ SSLExtraServerCertData* extraData) { + UniqueCERTCertificate cert(PK11_FindCertFromNickname(certName, nullptr)); + if (!cert) { + PrintPRError("PK11_FindCertFromNickname failed"); + return SECFailure; + } + // If an intermediate certificate issued the server certificate (rather than + // directly by a trust anchor), we want to send it along in the handshake so + // we don't encounter unknown issuer errors when that's not what we're + // testing. + UniqueCERTCertificateList certList; + UniqueCERTCertificate issuerCert( + CERT_FindCertByName(CERT_GetDefaultCertDB(), &cert->derIssuer)); + // If we can't find the issuer cert, continue without it. + if (issuerCert) { + // Sadly, CERTCertificateList does not have a CERT_NewCertificateList + // utility function, so we must create it ourselves. This consists + // of creating an arena, allocating space for the CERTCertificateList, + // and then transferring ownership of the arena to that list. + UniquePLArenaPool arena(PORT_NewArena(DER_DEFAULT_CHUNKSIZE)); + if (!arena) { + PrintPRError("PORT_NewArena failed"); + return SECFailure; + } + certList.reset(static_cast<CERTCertificateList*>( + PORT_ArenaAlloc(arena.get(), sizeof(CERTCertificateList)))); + if (!certList) { + PrintPRError("PORT_ArenaAlloc failed"); + return SECFailure; + } + certList->arena = arena.release(); + // We also have to manually copy the certificates we care about to the + // list, because there aren't any utility functions for that either. + certList->certs = static_cast<SECItem*>( + PORT_ArenaAlloc(certList->arena, 2 * sizeof(SECItem))); + if (SECITEM_CopyItem(certList->arena, certList->certs, &cert->derCert) != + SECSuccess) { + PrintPRError("SECITEM_CopyItem failed"); + return SECFailure; + } + if (SECITEM_CopyItem(certList->arena, certList->certs + 1, + &issuerCert->derCert) != SECSuccess) { + PrintPRError("SECITEM_CopyItem failed"); + return SECFailure; + } + certList->len = 2; + } + + UniquePK11SlotInfo slot(PK11_GetInternalKeySlot()); + if (!slot) { + PrintPRError("PK11_GetInternalKeySlot failed"); + return SECFailure; + } + UniqueSECKEYPrivateKey key( + PK11_FindKeyByDERCert(slot.get(), cert.get(), nullptr)); + if (!key) { + PrintPRError("PK11_FindKeyByDERCert failed"); + return SECFailure; + } + + if (extraData) { + SSLExtraServerCertData dataCopy = {ssl_auth_null, nullptr, nullptr, + nullptr, nullptr, nullptr}; + memcpy(&dataCopy, extraData, sizeof(dataCopy)); + dataCopy.certChain = certList.get(); + + if (SSL_ConfigServerCert(fd, cert.get(), key.get(), &dataCopy, + sizeof(dataCopy)) != SECSuccess) { + PrintPRError("SSL_ConfigServerCert failed"); + return SECFailure; + } + + } else { + // This is the deprecated setup mechanism, to be cleaned up in Bug 1569222 + SSLKEAType certKEA = NSS_FindCertKEAType(cert.get()); + if (SSL_ConfigSecureServerWithCertChain(fd, cert.get(), certList.get(), + key.get(), certKEA) != SECSuccess) { + PrintPRError("SSL_ConfigSecureServer failed"); + return SECFailure; + } + + if (keaOut) { + *keaOut = certKEA; + } + } + + if (certOut) { + *certOut = std::move(cert); + } + + SSL_OptionSet(fd, SSL_NO_CACHE, false); + SSL_OptionSet(fd, SSL_ENABLE_SESSION_TICKETS, true); + + return SECSuccess; +} + +#ifdef XP_WIN +using PidType = DWORD; +constexpr bool IsValidPid(long long pid) { + // Excluding `(DWORD)-1` because it is not a valid process ID. + // See https://devblogs.microsoft.com/oldnewthing/20040223-00/?p=40503 + return pid > 0 && pid < std::numeric_limits<PidType>::max(); +} +#else +using PidType = pid_t; +constexpr bool IsValidPid(long long pid) { + return pid > 0 && pid <= std::numeric_limits<PidType>::max(); +} +#endif + +PidType ConvertPid(const char* pidStr) { + long long pid = strtoll(pidStr, nullptr, 10); + if (!IsValidPid(pid)) { + return 0; + } + return static_cast<PidType>(pid); +} + +int StartServer(int argc, char* argv[], SSLSNISocketConfig sniSocketConfig, + void* sniSocketConfigArg, ServerConfigFunc configFunc) { + if (argc != 3) { + fprintf(stderr, "usage: %s <NSS DB directory> <ppid>\n", argv[0]); + return 1; + } + const char* nssCertDBDir = argv[1]; + PidType ppid = ConvertPid(argv[2]); + + const char* debugLevel = PR_GetEnv("MOZ_TLS_SERVER_DEBUG_LEVEL"); + if (debugLevel) { + int level = atoi(debugLevel); + switch (level) { + case DEBUG_ERRORS: + gDebugLevel = DEBUG_ERRORS; + break; + case DEBUG_WARNINGS: + gDebugLevel = DEBUG_WARNINGS; + break; + case DEBUG_VERBOSE: + gDebugLevel = DEBUG_VERBOSE; + break; + default: + PrintPRError("invalid MOZ_TLS_SERVER_DEBUG_LEVEL"); + return 1; + } + } + + const char* callbackPort = PR_GetEnv("MOZ_TLS_SERVER_CALLBACK_PORT"); + if (callbackPort) { + gCallbackPort = atoi(callbackPort); + } + + if (InitializeNSS(nssCertDBDir) != SECSuccess) { + PR_fprintf(PR_STDERR, "InitializeNSS failed"); + return 1; + } + + if (NSS_SetDomesticPolicy() != SECSuccess) { + PrintPRError("NSS_SetDomesticPolicy failed"); + return 1; + } + + if (SSL_ConfigServerSessionIDCache(0, 0, 0, nullptr) != SECSuccess) { + PrintPRError("SSL_ConfigServerSessionIDCache failed"); + return 1; + } + + UniquePRFileDesc serverSocket(PR_NewTCPSocket()); + if (!serverSocket) { + PrintPRError("PR_NewTCPSocket failed"); + return 1; + } + + PRSocketOptionData socketOption; + socketOption.option = PR_SockOpt_Reuseaddr; + socketOption.value.reuse_addr = true; + PR_SetSocketOption(serverSocket.get(), &socketOption); + + PRNetAddr serverAddr; + PR_InitializeNetAddr(PR_IpAddrLoopback, LISTEN_PORT, &serverAddr); + if (PR_Bind(serverSocket.get(), &serverAddr) != PR_SUCCESS) { + PrintPRError("PR_Bind failed"); + return 1; + } + + if (PR_Listen(serverSocket.get(), 1) != PR_SUCCESS) { + PrintPRError("PR_Listen failed"); + return 1; + } + + UniquePRFileDesc rawModelSocket(PR_NewTCPSocket()); + if (!rawModelSocket) { + PrintPRError("PR_NewTCPSocket failed for rawModelSocket"); + return 1; + } + + UniquePRFileDesc modelSocket(SSL_ImportFD(nullptr, rawModelSocket.release())); + if (!modelSocket) { + PrintPRError("SSL_ImportFD of rawModelSocket failed"); + return 1; + } + + SSLVersionRange range = {0, 0}; + if (SSL_VersionRangeGet(modelSocket.get(), &range) != SECSuccess) { + PrintPRError("SSL_VersionRangeGet failed"); + return 1; + } + + if (range.max < SSL_LIBRARY_VERSION_TLS_1_3) { + range.max = SSL_LIBRARY_VERSION_TLS_1_3; + if (SSL_VersionRangeSet(modelSocket.get(), &range) != SECSuccess) { + PrintPRError("SSL_VersionRangeSet failed"); + return 1; + } + } + + if (SSL_SNISocketConfigHook(modelSocket.get(), sniSocketConfig, + sniSocketConfigArg) != SECSuccess) { + PrintPRError("SSL_SNISocketConfigHook failed"); + return 1; + } + + // We have to configure the server with a certificate, but it's not one + // we're actually going to end up using. In the SNI callback, we pick + // the right certificate for the connection. + // + // Provide an empty |extra_data| to force config via SSL_ConfigServerCert. + // This is a temporary mechanism to work around inconsistent setting of + // |authType| in the deprecated API (preventing the default cert from + // being removed in favor of the SNI-selected cert). This may be removed + // after Bug 1569222 removes the deprecated mechanism. + SSLExtraServerCertData extra_data = {ssl_auth_null, nullptr, nullptr, + nullptr, nullptr, nullptr}; + if (ConfigSecureServerWithNamedCert(modelSocket.get(), DEFAULT_CERT_NICKNAME, + nullptr, nullptr, + &extra_data) != SECSuccess) { + return 1; + } + + // Call back to implementation-defined configuration func, if provided. + if (configFunc) { + if (((configFunc)(modelSocket.get())) != SECSuccess) { + PrintPRError("configFunc failed"); + return 1; + } + } + + if (gCallbackPort != 0) { + if (DoCallback()) { + return 1; + } + } + + std::thread([ppid] { + if (!ppid) { + if (gDebugLevel >= DEBUG_ERRORS) { + fprintf(stderr, "invalid ppid\n"); + } + return; + } +#ifdef XP_WIN + HANDLE parent = OpenProcess(SYNCHRONIZE, false, ppid); + if (!parent) { + if (gDebugLevel >= DEBUG_ERRORS) { + fprintf(stderr, "OpenProcess failed\n"); + } + return; + } + WaitForSingleObject(parent, INFINITE); + CloseHandle(parent); +#else + while (getppid() == ppid) { + sleep(1); + } +#endif + if (gDebugLevel >= DEBUG_ERRORS) { + fprintf(stderr, "Parent process crashed\n"); + } + exit(1); + }).detach(); + + while (true) { + PRNetAddr clientAddr; + PRFileDesc* clientSocket = + PR_Accept(serverSocket.get(), &clientAddr, PR_INTERVAL_NO_TIMEOUT); + HandleConnection(clientSocket, modelSocket); + } + + return 0; +} + +} // namespace test +} // namespace mozilla diff --git a/security/manager/ssl/tests/unit/tlsserver/lib/TLSServer.h b/security/manager/ssl/tests/unit/tlsserver/lib/TLSServer.h new file mode 100644 index 0000000000..3927b3e541 --- /dev/null +++ b/security/manager/ssl/tests/unit/tlsserver/lib/TLSServer.h @@ -0,0 +1,93 @@ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#ifndef TLSServer_h +#define TLSServer_h + +// This is a standalone server for testing SSL features of Gecko. +// The client is expected to connect and initiate an SSL handshake (with SNI +// to indicate which "server" to connect to). If all is good, the client then +// sends one encrypted byte and receives that same byte back. +// This server also has the ability to "call back" another process waiting on +// it. That is, when the server is all set up and ready to receive connections, +// it will connect to a specified port and issue a simple HTTP request. + +#include <stdint.h> + +#include "ScopedNSSTypes.h" +#include "mozilla/Casting.h" +#include "prio.h" +#include "secerr.h" +#include "ssl.h" + +namespace mozilla { + +MOZ_TYPE_SPECIFIC_UNIQUE_PTR_TEMPLATE(UniquePRDir, PRDir, PR_CloseDir); + +} // namespace mozilla + +namespace mozilla { +namespace test { + +typedef SECStatus (*ServerConfigFunc)(PRFileDesc* fd); + +enum DebugLevel { DEBUG_ERRORS = 1, DEBUG_WARNINGS = 2, DEBUG_VERBOSE = 3 }; + +extern DebugLevel gDebugLevel; + +void PrintPRError(const char* aPrefix); + +// The default certificate is trusted for localhost and *.example.com +extern const char DEFAULT_CERT_NICKNAME[]; + +// ConfigSecureServerWithNamedCert sets up the hostname name provided. If the +// extraData parameter is presented, extraData->certChain will be automatically +// filled in using database information. +// Pass DEFAULT_CERT_NICKNAME as certName unless you need a specific +// certificate. +SECStatus ConfigSecureServerWithNamedCert( + PRFileDesc* fd, const char* certName, + /*optional*/ UniqueCERTCertificate* cert, + /*optional*/ SSLKEAType* kea, + /*optional*/ SSLExtraServerCertData* extraData); + +SECStatus InitializeNSS(const char* nssCertDBDir); + +// StartServer initializes NSS, sockets, the SNI callback, and a default +// certificate. configFunc (optional) is a pointer to an implementation- +// defined configuration function, which is called on the model socket +// prior to handling any connections. +int StartServer(int argc, char* argv[], SSLSNISocketConfig sniSocketConfig, + void* sniSocketConfigArg, + ServerConfigFunc configFunc = nullptr); + +template <typename Host> +inline const Host* GetHostForSNI(const SECItem* aSrvNameArr, + uint32_t aSrvNameArrSize, const Host* hosts) { + for (uint32_t i = 0; i < aSrvNameArrSize; i++) { + for (const Host* host = hosts; host->mHostName; ++host) { + SECItem hostName; + hostName.data = BitwiseCast<unsigned char*, const char*>(host->mHostName); + hostName.len = strlen(host->mHostName); + if (SECITEM_ItemsAreEqual(&hostName, &aSrvNameArr[i])) { + if (gDebugLevel >= DEBUG_VERBOSE) { + fprintf(stderr, "found pre-defined host '%s'\n", host->mHostName); + } + return host; + } + } + } + + if (gDebugLevel >= DEBUG_VERBOSE) { + fprintf(stderr, "could not find host info from SNI\n"); + } + + PR_SetError(SEC_ERROR_INVALID_ARGS, 0); + return nullptr; +} + +} // namespace test +} // namespace mozilla + +#endif // TLSServer_h diff --git a/security/manager/ssl/tests/unit/tlsserver/lib/moz.build b/security/manager/ssl/tests/unit/tlsserver/lib/moz.build new file mode 100644 index 0000000000..1fd92d724d --- /dev/null +++ b/security/manager/ssl/tests/unit/tlsserver/lib/moz.build @@ -0,0 +1,16 @@ +# -*- Mode: python; indent-tabs-mode: nil; tab-width: 40 -*- +# vim: set filetype=python: +# This Source Code Form is subject to the terms of the Mozilla Public +# License, v. 2.0. If a copy of the MPL was not distributed with this +# file, You can obtain one at http://mozilla.org/MPL/2.0/. + +UNIFIED_SOURCES += [ + "OCSPCommon.cpp", + "TLSServer.cpp", +] + +USE_LIBS += [ + "mozpkix-testlib", +] + +Library("tlsserver") |