summaryrefslogtreecommitdiffstats
path: root/tcpiohandler.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tcpiohandler.cc')
-rw-r--r--tcpiohandler.cc177
1 files changed, 90 insertions, 87 deletions
diff --git a/tcpiohandler.cc b/tcpiohandler.cc
index 72c149b..cf82471 100644
--- a/tcpiohandler.cc
+++ b/tcpiohandler.cc
@@ -11,7 +11,7 @@ const bool TCPIOHandler::s_disableConnectForUnitTests = false;
#include <sodium.h>
#endif /* HAVE_LIBSODIUM */
-#ifdef HAVE_DNS_OVER_TLS
+#if defined(HAVE_DNS_OVER_TLS) || defined(HAVE_DNS_OVER_HTTPS)
#ifdef HAVE_LIBSSL
#include <openssl/conf.h>
@@ -52,7 +52,7 @@ public:
OpenSSLTLSTicketKeysRing d_ticketKeys;
std::map<int, std::string> d_ocspResponses;
std::unique_ptr<SSL_CTX, void(*)(SSL_CTX*)> d_tlsCtx{nullptr, SSL_CTX_free};
- std::unique_ptr<FILE, int(*)(FILE*)> d_keyLogFile{nullptr, fclose};
+ pdns::UniqueFilePtr d_keyLogFile{nullptr};
};
class OpenSSLSession : public TLSSession
@@ -62,10 +62,6 @@ public:
{
}
- virtual ~OpenSSLSession()
- {
- }
-
std::unique_ptr<SSL_SESSION, void(*)(SSL_SESSION*)> getNative()
{
return std::move(d_sess);
@@ -79,7 +75,7 @@ class OpenSSLTLSConnection: public TLSConnection
{
public:
/* server side connection */
- OpenSSLTLSConnection(int socket, const struct timeval& timeout, std::shared_ptr<OpenSSLFrontendContext> feContext): d_feContext(feContext), d_conn(std::unique_ptr<SSL, void(*)(SSL*)>(SSL_new(d_feContext->d_tlsCtx.get()), SSL_free)), d_timeout(timeout)
+ OpenSSLTLSConnection(int socket, const struct timeval& timeout, std::shared_ptr<OpenSSLFrontendContext> feContext): d_feContext(std::move(feContext)), d_conn(std::unique_ptr<SSL, void(*)(SSL*)>(SSL_new(d_feContext->d_tlsCtx.get()), SSL_free)), d_timeout(timeout)
{
d_socket = socket;
@@ -133,7 +129,7 @@ public:
#endif
}
else {
-#if (OPENSSL_VERSION_NUMBER >= 0x1010000fL) && HAVE_SSL_SET_HOSTFLAGS // grrr libressl
+#if (OPENSSL_VERSION_NUMBER >= 0x1010000fL) && defined(HAVE_SSL_SET_HOSTFLAGS) // grrr libressl
SSL_set_hostflags(d_conn.get(), X509_CHECK_FLAG_NO_PARTIAL_WILDCARDS);
if (SSL_set1_host(d_conn.get(), d_hostname.c_str()) != 1) {
throw std::runtime_error("Error setting TLS hostname for certificate validation");
@@ -432,15 +428,6 @@ public:
return got;
}
- bool hasBufferedData() const override
- {
- if (d_conn) {
- return SSL_pending(d_conn.get()) > 0;
- }
-
- return false;
- }
-
bool isUsable() const override
{
if (!d_conn) {
@@ -629,6 +616,10 @@ public:
}
#endif /* DISABLE_OCSP_STAPLING */
+ if (fe.d_tlsConfig.d_readAhead) {
+ SSL_CTX_set_read_ahead(d_feContext->d_tlsCtx.get(), 1);
+ }
+
libssl_set_error_counters_callback(d_feContext->d_tlsCtx, &fe.d_tlsCounters);
if (!fe.d_tlsConfig.d_keyLogFile.empty()) {
@@ -822,7 +813,7 @@ public:
}
}
- void loadTicketsKeys(const std::string& keyFile) override final
+ void loadTicketsKeys(const std::string& keyFile) final
{
d_feContext->d_ticketKeys.loadTicketsKeys(keyFile);
@@ -866,7 +857,7 @@ public:
private:
/* called in a client context, if the client advertised more than one ALPN values and the server returned more than one as well, to select the one to use. */
#ifndef DISABLE_NPN
- static int npnSelectCallback(SSL* s, unsigned char** out, unsigned char* outlen, const unsigned char* in, unsigned int inlen, void* arg)
+ static int npnSelectCallback(SSL* /* s */, unsigned char** out, unsigned char* outlen, const unsigned char* in, unsigned int inlen, void* arg)
{
if (!arg) {
return SSL_TLSEXT_ERR_ALERT_WARNING;
@@ -885,25 +876,28 @@ private:
if (!arg) {
return SSL_TLSEXT_ERR_ALERT_WARNING;
}
+ // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast): OpenSSL's API
OpenSSLTLSIOCtx* obj = reinterpret_cast<OpenSSLTLSIOCtx*>(arg);
- size_t pos = 0;
- while (pos < inlen) {
- size_t protoLen = in[pos];
- pos++;
- if (protoLen > (inlen - pos)) {
- /* something is very wrong */
- return SSL_TLSEXT_ERR_ALERT_WARNING;
- }
+ const pdns::views::UnsignedCharView inView(in, inlen);
+ // Server preference algorithm as per RFC 7301 section 3.2
+ for (const auto& tentative : obj->d_alpnProtos) {
+ size_t pos = 0;
+ while (pos < inView.size()) {
+ size_t protoLen = inView.at(pos);
+ pos++;
+ if (protoLen > (inlen - pos)) {
+ /* something is very wrong */
+ return SSL_TLSEXT_ERR_ALERT_WARNING;
+ }
- for (const auto& tentative : obj->d_alpnProtos) {
- if (tentative.size() == protoLen && memcmp(in + pos, tentative.data(), tentative.size()) == 0) {
- *out = in + pos;
+ if (tentative.size() == protoLen && memcmp(&inView.at(pos), tentative.data(), tentative.size()) == 0) {
+ *out = &inView.at(pos);
*outlen = protoLen;
return SSL_TLSEXT_ERR_OK;
}
+ pos += protoLen;
}
- pos += protoLen;
}
return SSL_TLSEXT_ERR_NOACK;
@@ -1020,7 +1014,7 @@ public:
sess.size = 0;
}
- virtual ~GnuTLSSession()
+ ~GnuTLSSession() override
{
if (d_sess.data != nullptr && d_sess.size > 0) {
safe_memory_release(d_sess.data, d_sess.size);
@@ -1115,7 +1109,7 @@ public:
gnutls_handshake_set_timeout(d_conn.get(), timeout.tv_sec * 1000 + timeout.tv_usec / 1000);
gnutls_record_set_timeout(d_conn.get(), timeout.tv_sec * 1000 + timeout.tv_usec / 1000);
-#if HAVE_GNUTLS_SESSION_SET_VERIFY_CERT
+#ifdef HAVE_GNUTLS_SESSION_SET_VERIFY_CERT
if (validateCerts && !d_host.empty()) {
gnutls_session_set_verify_cert(d_conn.get(), d_host.c_str(), GNUTLS_VERIFY_ALLOW_UNSORTED_CHAIN);
rc = gnutls_server_name_set(d_conn.get(), GNUTLS_NAME_DNS, d_host.c_str(), d_host.size());
@@ -1134,9 +1128,9 @@ public:
/* The callback prototype changed in 3.4.0. */
#if GNUTLS_VERSION_NUMBER >= 0x030400
- static int newTicketFromServerCb(gnutls_session_t session, unsigned int htype, unsigned post, unsigned int incoming, const gnutls_datum_t* msg)
+ static int newTicketFromServerCb(gnutls_session_t session, unsigned int htype, unsigned post, unsigned int /* incoming */, const gnutls_datum_t* /* msg */)
#else
- static int newTicketFromServerCb(gnutls_session_t session, unsigned int htype, unsigned post, unsigned int incoming)
+ static int newTicketFromServerCb(gnutls_session_t session, unsigned int htype, unsigned post, unsigned int /* incoming */)
#endif /* GNUTLS_VERSION_NUMBER >= 0x030400 */
{
if (htype != GNUTLS_HANDSHAKE_NEW_SESSION_TICKET || post != GNUTLS_HOOK_POST || session == nullptr) {
@@ -1158,7 +1152,7 @@ public:
return 0;
}
- IOState tryConnect(bool fastOpen, const ComboAddress& remote) override
+ IOState tryConnect(bool fastOpen, [[maybe_unused]] const ComboAddress& remote) override
{
int ret = 0;
@@ -1263,7 +1257,7 @@ public:
else if (gnutls_error_is_fatal(ret) || ret == GNUTLS_E_WARNING_ALERT_RECEIVED) {
if (d_client) {
std::string error;
-#if HAVE_GNUTLS_SESSION_GET_VERIFY_CERT_STATUS
+#ifdef HAVE_GNUTLS_SESSION_GET_VERIFY_CERT_STATUS
if (ret == GNUTLS_E_CERTIFICATE_VERIFICATION_ERROR) {
gnutls_datum_t out;
if (gnutls_certificate_verification_status_print(gnutls_session_get_verify_cert_status(d_conn.get()), gnutls_certificate_type_get(d_conn.get()), &out, 0) == 0) {
@@ -1314,7 +1308,7 @@ public:
else if (res == GNUTLS_E_AGAIN) {
return IOState::NeedWrite;
}
- warnlog("Warning, non-fatal error while writing to TLS connection: %s", gnutls_strerror(res));
+ vinfolog("Warning, non-fatal error while writing to TLS connection: %s", gnutls_strerror(res));
}
}
while (pos < toWrite);
@@ -1350,7 +1344,7 @@ public:
else if (res == GNUTLS_E_AGAIN) {
return IOState::NeedRead;
}
- warnlog("Warning, non-fatal error while writing to TLS connection: %s", gnutls_strerror(res));
+ vinfolog("Warning, non-fatal error while writing to TLS connection: %s", gnutls_strerror(res));
}
}
while (pos < toRead);
@@ -1440,15 +1434,6 @@ public:
return got;
}
- bool hasBufferedData() const override
- {
- if (d_conn) {
- return gnutls_record_check_pending(d_conn.get()) > 0;
- }
-
- return false;
- }
-
bool isUsable() const override
{
if (!d_conn) {
@@ -1678,7 +1663,7 @@ public:
}
}
- virtual ~GnuTLSIOCtx() override
+ ~GnuTLSIOCtx() override
{
d_creds.reset();
@@ -1754,7 +1739,7 @@ public:
auto newKey = std::make_shared<GnuTLSTicketsKey>();
{
- *(d_ticketsKey.write_lock()) = newKey;
+ *(d_ticketsKey.write_lock()) = std::move(newKey);
}
if (d_ticketsKeyRotationDelay > 0) {
@@ -1762,7 +1747,7 @@ public:
}
}
- void loadTicketsKeys(const std::string& file) override final
+ void loadTicketsKeys(const std::string& file) final
{
if (!d_enableTickets) {
return;
@@ -1770,7 +1755,7 @@ public:
auto newKey = std::make_shared<GnuTLSTicketsKey>(file);
{
- *(d_ticketsKey.write_lock()) = newKey;
+ *(d_ticketsKey.write_lock()) = std::move(newKey);
}
if (d_ticketsKeyRotationDelay > 0) {
@@ -1811,7 +1796,7 @@ private:
#endif /* HAVE_GNUTLS */
-#endif /* HAVE_DNS_OVER_TLS */
+#endif /* HAVE_DNS_OVER_TLS || HAVE_DNS_OVER_HTTPS */
bool setupDoTProtocolNegotiation(std::shared_ptr<TLSCtx>& ctx)
{
@@ -1824,67 +1809,85 @@ bool setupDoTProtocolNegotiation(std::shared_ptr<TLSCtx>& ctx)
return true;
}
+bool setupDoHProtocolNegotiation(std::shared_ptr<TLSCtx>& ctx)
+{
+ if (ctx == nullptr) {
+ return false;
+ }
+ /* This code is only called for incoming/server TLS contexts (not outgoing/client),
+ and h2o sets it own ALPN values.
+ We want to set the ALPN for DoH:
+ - HTTP/1.1 so that the OpenSSL callback ALPN accepts it, letting us later return a static response
+ - HTTP/2
+ */
+ const std::vector<std::vector<uint8_t>> dohAlpns{{'h', '2'},{'h', 't', 't', 'p', '/', '1', '.', '1'}};
+ ctx->setALPNProtos(dohAlpns);
+
+ return true;
+}
+
bool TLSFrontend::setupTLS()
{
-#ifdef HAVE_DNS_OVER_TLS
+#if defined(HAVE_DNS_OVER_TLS) || defined(HAVE_DNS_OVER_HTTPS)
std::shared_ptr<TLSCtx> newCtx{nullptr};
/* get the "best" available provider */
- if (!d_provider.empty()) {
-#ifdef HAVE_GNUTLS
- if (d_provider == "gnutls") {
- newCtx = std::make_shared<GnuTLSIOCtx>(*this);
- setupDoTProtocolNegotiation(newCtx);
- std::atomic_store_explicit(&d_ctx, newCtx, std::memory_order_release);
- return true;
- }
-#endif /* HAVE_GNUTLS */
-#ifdef HAVE_LIBSSL
- if (d_provider == "openssl") {
- newCtx = std::make_shared<OpenSSLTLSIOCtx>(*this);
- setupDoTProtocolNegotiation(newCtx);
- std::atomic_store_explicit(&d_ctx, newCtx, std::memory_order_release);
- return true;
- }
-#endif /* HAVE_LIBSSL */
+#if defined(HAVE_GNUTLS)
+ if (d_provider == "gnutls") {
+ newCtx = std::make_shared<GnuTLSIOCtx>(*this);
}
-#ifdef HAVE_LIBSSL
- newCtx = std::make_shared<OpenSSLTLSIOCtx>(*this);
-#else /* HAVE_LIBSSL */
-#ifdef HAVE_GNUTLS
- newCtx = std::make_shared<GnuTLSIOCtx>(*this);
#endif /* HAVE_GNUTLS */
+#if defined(HAVE_LIBSSL)
+ if (d_provider == "openssl") {
+ newCtx = std::make_shared<OpenSSLTLSIOCtx>(*this);
+ }
#endif /* HAVE_LIBSSL */
- setupDoTProtocolNegotiation(newCtx);
- std::atomic_store_explicit(&d_ctx, newCtx, std::memory_order_release);
-#endif /* HAVE_DNS_OVER_TLS */
+ if (!newCtx) {
+#if defined(HAVE_LIBSSL)
+ newCtx = std::make_shared<OpenSSLTLSIOCtx>(*this);
+#elif defined(HAVE_GNUTLS)
+ newCtx = std::make_shared<GnuTLSIOCtx>(*this);
+#else
+#error "TLS support needed but neither libssl nor GnuTLS were selected"
+#endif
+ }
+
+ if (d_alpn == ALPN::DoT) {
+ setupDoTProtocolNegotiation(newCtx);
+ }
+ else if (d_alpn == ALPN::DoH) {
+ setupDoHProtocolNegotiation(newCtx);
+ }
+
+ std::atomic_store_explicit(&d_ctx, std::move(newCtx), std::memory_order_release);
+#endif /* HAVE_DNS_OVER_TLS || HAVE_DNS_OVER_HTTPS */
return true;
}
-std::shared_ptr<TLSCtx> getTLSContext(const TLSContextParameters& params)
+std::shared_ptr<TLSCtx> getTLSContext([[maybe_unused]] const TLSContextParameters& params)
{
#ifdef HAVE_DNS_OVER_TLS
/* get the "best" available provider */
if (!params.d_provider.empty()) {
-#ifdef HAVE_GNUTLS
+#if defined(HAVE_GNUTLS)
if (params.d_provider == "gnutls") {
return std::make_shared<GnuTLSIOCtx>(params);
}
#endif /* HAVE_GNUTLS */
-#ifdef HAVE_LIBSSL
+#if defined(HAVE_LIBSSL)
if (params.d_provider == "openssl") {
return std::make_shared<OpenSSLTLSIOCtx>(params);
}
#endif /* HAVE_LIBSSL */
}
-#ifdef HAVE_LIBSSL
+#if defined(HAVE_LIBSSL)
return std::make_shared<OpenSSLTLSIOCtx>(params);
-#else /* HAVE_LIBSSL */
-#ifdef HAVE_GNUTLS
+#elif defined(HAVE_GNUTLS)
return std::make_shared<GnuTLSIOCtx>(params);
-#endif /* HAVE_GNUTLS */
-#endif /* HAVE_LIBSSL */
+#else
+#error "DNS over TLS support needed but neither libssl nor GnuTLS were selected"
+#endif
#endif /* HAVE_DNS_OVER_TLS */
return nullptr;