/* This Source Code Form is subject to the terms of the Mozilla Public
 * License, v. 2.0. If a copy of the MPL was not distributed with this
 * file, You can obtain one at http://mozilla.org/MPL/2.0/. */
|
|
|
|
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

#include "nspr.h"
#include "ScopedNSSTypes.h"
#include "ssl.h"
#include "ssl3prot.h"
#include "sslexp.h"
#include "sslimpl.h"
#include "TLSServer.h"

#include "mozilla/Sprintf.h"
|
|
|
|
using namespace mozilla;
|
|
using namespace mozilla::test;
|
|
|
|
// Kind of misbehaviour attached to a host entry. DoSNISocketConfig switches
// on this value to decide how the connection is configured.
enum FaultType {
  // No fault is injected; the connection is only given its certificate.
  None = 0,
  // SecretCallbackFailZeroRtt is installed as the secret callback.
  ZeroRtt,
  // The SNI is rejected with SSL_SNI_SEND_ALERT.
  UnknownSNI,
  // SecretCallbackFailMlkem768x25519 is installed and the named groups are
  // restricted to mlkem768x25519 and curve25519.
  Mlkem768x25519,
};
|
|
|
|
struct FaultyServerHost {
|
|
const char* mHostName;
|
|
const char* mCertName;
|
|
FaultType mFaultType;
|
|
};
|
|
|
|
// Hostname that is served without any fault.
const char* kHostOk = "ok.example.com";
// Hostname whose SNI is rejected.
const char* kHostUnknown = "unknown.example.com";

// Hostnames handled by the 0-RTT fault callback.
const char* kHostZeroRttAlertBadMac = "0rtt-alert-bad-mac.example.com";
const char* kHostZeroRttAlertVersion =
    "0rtt-alert-protocol-version.example.com";
const char* kHostZeroRttAlertUnexpected = "0rtt-alert-unexpected.example.com";
const char* kHostZeroRttAlertDowngrade = "0rtt-alert-downgrade.example.com";

// Hostnames handled by the mlkem768x25519 fault callback.
const char* kHostMlkem768x25519NetInterrupt =
    "mlkem768x25519-net-interrupt.example.com";
const char* kHostMlkem768x25519AlertAfterServerHello =
    "mlkem768x25519-alert-after-server-hello.example.com";

// Certificate name used by every entry of the host table.
const char* kCertWildcard = "default-ee";
|
|
|
|
/* Each type of failure gets a different SNI.
|
|
* the "default-ee" cert has a SAN for *.example.com
|
|
* the "no-san-ee" cert is signed by the test-ca, but it doesn't have any SANs.
|
|
*/
|
|
MOZ_RUNINIT const FaultyServerHost sFaultyServerHosts[]{
|
|
{kHostOk, kCertWildcard, None},
|
|
{kHostUnknown, kCertWildcard, UnknownSNI},
|
|
{kHostZeroRttAlertBadMac, kCertWildcard, ZeroRtt},
|
|
{kHostZeroRttAlertVersion, kCertWildcard, ZeroRtt},
|
|
{kHostZeroRttAlertUnexpected, kCertWildcard, ZeroRtt},
|
|
{kHostZeroRttAlertDowngrade, kCertWildcard, ZeroRtt},
|
|
{kHostMlkem768x25519NetInterrupt, kCertWildcard, Mlkem768x25519},
|
|
{kHostMlkem768x25519AlertAfterServerHello, kCertWildcard, Mlkem768x25519},
|
|
{nullptr, nullptr},
|
|
};
|
|
|
|
// Writes the first aDataLen bytes of aData to aSocket, calling PR_Send
// repeatedly until everything has been accepted.
//
// Returns NS_OK once all bytes were sent, or NS_ERROR_FAILURE (after printing
// the PR error) if PR_Send fails or makes no progress.
nsresult SendAll(PRFileDesc* aSocket, const char* aData, size_t aDataLen) {
  if (gDebugLevel >= DEBUG_VERBOSE) {
    // The length is passed separately, so do not rely on aData being
    // NUL-terminated: print at most aDataLen bytes.
    fprintf(stderr, "sending '%.*s'\n", static_cast<int>(aDataLen), aData);
  }

  int32_t remaining = static_cast<int32_t>(aDataLen);
  while (remaining > 0) {
    const int32_t bytesSent =
        PR_Send(aSocket, aData, remaining, 0, PR_INTERVAL_NO_TIMEOUT);
    // A negative result is PR_Send's failure value. A zero-byte result is
    // treated as a failure too, because it would otherwise keep this loop
    // spinning forever without making progress.
    if (bytesSent <= 0) {
      PrintPRError("PR_Send failed");
      return NS_ERROR_FAILURE;
    }

    remaining -= bytesSent;
    aData += bytesSent;
  }

  return NS_OK;
}
|
|
|
|
// returns 0 on success, non-zero on error
|
|
int DoCallback(const char* path) {
|
|
UniquePRFileDesc socket(PR_NewTCPSocket());
|
|
if (!socket) {
|
|
PrintPRError("PR_NewTCPSocket failed");
|
|
return 1;
|
|
}
|
|
|
|
uint32_t port = 0;
|
|
const char* callbackPort = PR_GetEnv("FAULTY_SERVER_CALLBACK_PORT");
|
|
if (callbackPort) {
|
|
port = atoi(callbackPort);
|
|
}
|
|
if (!port) {
|
|
return 0;
|
|
}
|
|
|
|
PRNetAddr addr;
|
|
PR_InitializeNetAddr(PR_IpAddrLoopback, port, &addr);
|
|
if (PR_Connect(socket.get(), &addr, PR_INTERVAL_NO_TIMEOUT) != PR_SUCCESS) {
|
|
PrintPRError("PR_Connect failed");
|
|
return 1;
|
|
}
|
|
|
|
char request[512];
|
|
SprintfLiteral(request, "GET %s HTTP/1.0\r\n\r\n", path);
|
|
SendAll(socket.get(), request, strlen(request));
|
|
char buf[4096];
|
|
memset(buf, 0, sizeof(buf));
|
|
int32_t bytesRead =
|
|
PR_Recv(socket.get(), buf, sizeof(buf) - 1, 0, PR_INTERVAL_NO_TIMEOUT);
|
|
if (bytesRead < 0) {
|
|
PrintPRError("PR_Recv failed 1");
|
|
return 1;
|
|
}
|
|
if (bytesRead == 0) {
|
|
fprintf(stderr, "PR_Recv eof 1\n");
|
|
return 1;
|
|
}
|
|
// fprintf(stderr, "%s\n", buf);
|
|
return 0;
|
|
}
|
|
|
|
/* These are very rough examples. In practice the `arg` parameter to a callback
|
|
* might need to be an object that holds some state, like the various traffic
|
|
* secrets. */
|
|
|
|
/* An SSLSecretCallback is called after every key derivation step in the TLS
|
|
* 1.3 key schedule.
|
|
*
|
|
* Epoch 1 is for the early traffic secret.
|
|
* Epoch 2 is for the handshake traffic secrets.
|
|
* Epoch 3 is for the application traffic secrets.
|
|
*/
|
|
// Secret callback installed for ZeroRtt hosts. `arg` is the FaultyServerHost
// entry selected for the connection.
//
// Every invocation is logged. Only the epoch 1 read secret triggers the
// fault: the callback endpoint is notified with "/callback/<epoch>" and then a
// fatal alert chosen by the host name is sent.
void SecretCallbackFailZeroRtt(PRFileDesc* fd, PRUint16 epoch,
                               SSLSecretDirection dir, PK11SymKey* secret,
                               void* arg) {
  fprintf(stderr, "0RTT handler epoch=%d dir=%d\n", epoch,
          static_cast<uint32_t>(dir));
  const FaultyServerHost* faultyHost =
      static_cast<const FaultyServerHost*>(arg);

  const bool isEarlyReadSecret = (epoch == 1) && (dir == ssl_secret_read);
  if (!isEarlyReadSecret) {
    return;
  }

  sslSocket* sslSock = ssl_FindSocket(fd);
  if (sslSock == nullptr) {
    fprintf(stderr, "0RTT handler, no ss!\n");
    return;
  }

  char callbackPath[256];
  SprintfLiteral(callbackPath, "/callback/%d", epoch);
  DoCallback(callbackPath);

  fprintf(stderr, "0RTT handler, configuring alert\n");
  const char* hostName = faultyHost->mHostName;
  // NOTE(review): kHostZeroRttAlertDowngrade is a ZeroRtt host too, but no
  // branch below matches it, so no alert is sent for it - confirm intended.
  if (strcmp(hostName, kHostZeroRttAlertBadMac) == 0) {
    SSL3_SendAlert(sslSock, alert_fatal, bad_record_mac);
  } else if (strcmp(hostName, kHostZeroRttAlertVersion) == 0) {
    SSL3_SendAlert(sslSock, alert_fatal, protocol_version);
  } else if (strcmp(hostName, kHostZeroRttAlertUnexpected) == 0) {
    SSL3_SendAlert(sslSock, alert_fatal, unexpected_message);
  }
}
|
|
|
|
// Record write callback that ignores all of its arguments and always reports
// SECFailure. SecretCallbackFailMlkem768x25519 installs it on
// ss->recordWriteCallback to make the next record write fail.
SECStatus FailingWriteCallback(PRFileDesc* fd, PRUint16 epoch,
                               SSLContentType contentType, const PRUint8* data,
                               unsigned int len, void* arg) {
  const SECStatus alwaysFail = SECFailure;
  return alwaysFail;
}
|
|
|
|
// Secret callback installed for Mlkem768x25519 hosts. `arg` is the
// FaultyServerHost entry selected for the connection.
//
// Every invocation is logged. On the epoch 2 write secret the callback
// endpoint is notified with "/callback/<named group>". If the negotiated group
// is mlkem768x25519, the fault selected by the host name is then applied.
void SecretCallbackFailMlkem768x25519(PRFileDesc* fd, PRUint16 epoch,
                                      SSLSecretDirection dir,
                                      PK11SymKey* secret, void* arg) {
  fprintf(stderr, "Mlkem768x25519 handler epoch=%d dir=%d\n", epoch,
          static_cast<uint32_t>(dir));
  const FaultyServerHost* faultyHost =
      static_cast<const FaultyServerHost*>(arg);

  const bool isHandshakeWriteSecret =
      (epoch == 2) && (dir == ssl_secret_write);
  if (!isHandshakeWriteSecret) {
    return;
  }

  sslSocket* sslSock = ssl_FindSocket(fd);
  if (sslSock == nullptr) {
    fprintf(stderr, "Mlkem768x25519 handler, no ss!\n");
    return;
  }

  const auto* group = sslSock->sec.keaGroup;
  if (group == nullptr) {
    fprintf(stderr, "Mlkem768x25519 handler, no ss->sec.keaGroup!\n");
    return;
  }

  // The callback is made for whichever group was chosen, before the group is
  // checked.
  char callbackPath[256];
  SprintfLiteral(callbackPath, "/callback/%u", group->name);
  DoCallback(callbackPath);

  if (group->name != ssl_grp_kem_mlkem768x25519) {
    return;
  }

  fprintf(stderr, "Mlkem768x25519 handler, configuring alert\n");
  const char* hostName = faultyHost->mHostName;
  if (strcmp(hostName, kHostMlkem768x25519NetInterrupt) == 0) {
    // Install a record write callback that causes the next write to fail.
    // The client will see this as a PR_END_OF_FILE / NS_ERROR_NET_INTERRUPT
    // error.
    sslSock->recordWriteCallback = FailingWriteCallback;
  } else if (strcmp(hostName, kHostMlkem768x25519AlertAfterServerHello) == 0) {
    SSL3_SendAlert(sslSock, alert_fatal, close_notify);
  }
}
|
|
|
|
// SNI callback: looks up the requested server name in sFaultyServerHosts,
// installs the fault handling configured for that host on aFd and then
// configures the host's certificate.
//
// Returns 0 on success. Returns SSL_SNI_SEND_ALERT if the name is not in the
// table, is marked UnknownSNI, or if any configuration step fails.
// aArg is unused.
int32_t DoSNISocketConfig(PRFileDesc* aFd, const SECItem* aSrvNameArr,
                          uint32_t aSrvNameArrSize, void* aArg) {
  const FaultyServerHost* host =
      GetHostForSNI(aSrvNameArr, aSrvNameArrSize, sFaultyServerHosts);
  if (!host || host->mFaultType == UnknownSNI) {
    PrintPRError("No cert found for hostname");
    return SSL_SNI_SEND_ALERT;
  }

  if (gDebugLevel >= DEBUG_VERBOSE) {
    fprintf(stderr, "found pre-defined host '%s'\n", host->mHostName);
  }

  const SSLNamedGroup mlkemTestNamedGroups[] = {ssl_grp_kem_mlkem768x25519,
                                                ssl_grp_ec_curve25519};

  // The secret callback API takes a mutable void*. The callbacks in this file
  // only read the table entry.
  void* hostArg = const_cast<FaultyServerHost*>(host);

  // A fault that fails to install would leave the server looking healthy and
  // confuse the tests, so any failure rejects the connection instead.
  switch (host->mFaultType) {
    case ZeroRtt:
      if (SSL_SecretCallback(aFd, &SecretCallbackFailZeroRtt, hostArg) !=
          SECSuccess) {
        PrintPRError("SSL_SecretCallback failed");
        return SSL_SNI_SEND_ALERT;
      }
      break;
    case Mlkem768x25519:
      if (SSL_SecretCallback(aFd, &SecretCallbackFailMlkem768x25519,
                             hostArg) != SECSuccess) {
        PrintPRError("SSL_SecretCallback failed");
        return SSL_SNI_SEND_ALERT;
      }
      if (SSL_NamedGroupConfig(aFd, mlkemTestNamedGroups,
                               std::size(mlkemTestNamedGroups)) !=
          SECSuccess) {
        PrintPRError("SSL_NamedGroupConfig failed");
        return SSL_SNI_SEND_ALERT;
      }
      break;
    case None:
      break;
    default:
      break;
  }

  UniqueCERTCertificate cert;
  SSLKEAType certKEA;
  if (SECSuccess != ConfigSecureServerWithNamedCert(aFd, host->mCertName, &cert,
                                                    &certKEA, nullptr)) {
    return SSL_SNI_SEND_ALERT;
  }

  return 0;
}
|
|
|
|
// Server configuration hook handed to StartServer. It performs no
// configuration on aFd and always reports success.
SECStatus ConfigureServer(PRFileDesc* aFd) {
  return SECSuccess;
}
|
|
|
|
int main(int argc, char* argv[]) {
|
|
int rv = StartServer(argc, argv, DoSNISocketConfig, nullptr, ConfigureServer);
|
|
if (rv < 0) {
|
|
return rv;
|
|
}
|
|
}
|