diff options
Diffstat (limited to 'tcpiohandler.cc')
-rw-r--r-- | tcpiohandler.cc | 177 |
1 files changed, 90 insertions, 87 deletions
diff --git a/tcpiohandler.cc b/tcpiohandler.cc index 72c149b..cf82471 100644 --- a/tcpiohandler.cc +++ b/tcpiohandler.cc @@ -11,7 +11,7 @@ const bool TCPIOHandler::s_disableConnectForUnitTests = false; #include <sodium.h> #endif /* HAVE_LIBSODIUM */ -#ifdef HAVE_DNS_OVER_TLS +#if defined(HAVE_DNS_OVER_TLS) || defined(HAVE_DNS_OVER_HTTPS) #ifdef HAVE_LIBSSL #include <openssl/conf.h> @@ -52,7 +52,7 @@ public: OpenSSLTLSTicketKeysRing d_ticketKeys; std::map<int, std::string> d_ocspResponses; std::unique_ptr<SSL_CTX, void(*)(SSL_CTX*)> d_tlsCtx{nullptr, SSL_CTX_free}; - std::unique_ptr<FILE, int(*)(FILE*)> d_keyLogFile{nullptr, fclose}; + pdns::UniqueFilePtr d_keyLogFile{nullptr}; }; class OpenSSLSession : public TLSSession @@ -62,10 +62,6 @@ public: { } - virtual ~OpenSSLSession() - { - } - std::unique_ptr<SSL_SESSION, void(*)(SSL_SESSION*)> getNative() { return std::move(d_sess); @@ -79,7 +75,7 @@ class OpenSSLTLSConnection: public TLSConnection { public: /* server side connection */ - OpenSSLTLSConnection(int socket, const struct timeval& timeout, std::shared_ptr<OpenSSLFrontendContext> feContext): d_feContext(feContext), d_conn(std::unique_ptr<SSL, void(*)(SSL*)>(SSL_new(d_feContext->d_tlsCtx.get()), SSL_free)), d_timeout(timeout) + OpenSSLTLSConnection(int socket, const struct timeval& timeout, std::shared_ptr<OpenSSLFrontendContext> feContext): d_feContext(std::move(feContext)), d_conn(std::unique_ptr<SSL, void(*)(SSL*)>(SSL_new(d_feContext->d_tlsCtx.get()), SSL_free)), d_timeout(timeout) { d_socket = socket; @@ -133,7 +129,7 @@ public: #endif } else { -#if (OPENSSL_VERSION_NUMBER >= 0x1010000fL) && HAVE_SSL_SET_HOSTFLAGS // grrr libressl +#if (OPENSSL_VERSION_NUMBER >= 0x1010000fL) && defined(HAVE_SSL_SET_HOSTFLAGS) // grrr libressl SSL_set_hostflags(d_conn.get(), X509_CHECK_FLAG_NO_PARTIAL_WILDCARDS); if (SSL_set1_host(d_conn.get(), d_hostname.c_str()) != 1) { throw std::runtime_error("Error setting TLS hostname for certificate validation"); @@ -432,15 +428,6 @@ public: return got; } - bool hasBufferedData() const override - { - if (d_conn) { - return SSL_pending(d_conn.get()) > 0; - } - - return false; - } - bool isUsable() const override { if (!d_conn) { @@ -629,6 +616,10 @@ public: } #endif /* DISABLE_OCSP_STAPLING */ + if (fe.d_tlsConfig.d_readAhead) { + SSL_CTX_set_read_ahead(d_feContext->d_tlsCtx.get(), 1); + } + libssl_set_error_counters_callback(d_feContext->d_tlsCtx, &fe.d_tlsCounters); if (!fe.d_tlsConfig.d_keyLogFile.empty()) { @@ -822,7 +813,7 @@ public: } } - void loadTicketsKeys(const std::string& keyFile) override final + void loadTicketsKeys(const std::string& keyFile) final { d_feContext->d_ticketKeys.loadTicketsKeys(keyFile); @@ -866,7 +857,7 @@ public: private: /* called in a client context, if the client advertised more than one ALPN values and the server returned more than one as well, to select the one to use. */ #ifndef DISABLE_NPN - static int npnSelectCallback(SSL* s, unsigned char** out, unsigned char* outlen, const unsigned char* in, unsigned int inlen, void* arg) + static int npnSelectCallback(SSL* /* s */, unsigned char** out, unsigned char* outlen, const unsigned char* in, unsigned int inlen, void* arg) { if (!arg) { return SSL_TLSEXT_ERR_ALERT_WARNING; @@ -885,25 +876,28 @@ private: if (!arg) { return SSL_TLSEXT_ERR_ALERT_WARNING; } + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast): OpenSSL's API OpenSSLTLSIOCtx* obj = reinterpret_cast<OpenSSLTLSIOCtx*>(arg); - size_t pos = 0; - while (pos < inlen) { - size_t protoLen = in[pos]; - pos++; - if (protoLen > (inlen - pos)) { - /* something is very wrong */ - return SSL_TLSEXT_ERR_ALERT_WARNING; - } + const pdns::views::UnsignedCharView inView(in, inlen); + // Server preference algorithm as per RFC 7301 section 3.2 + for (const auto& tentative : obj->d_alpnProtos) { + size_t pos = 0; + while (pos < inView.size()) { + size_t protoLen = inView.at(pos); + pos++; + if (protoLen > (inlen - pos)) { + /* something is very wrong */ + return SSL_TLSEXT_ERR_ALERT_WARNING; + } - for (const auto& tentative : obj->d_alpnProtos) { - if (tentative.size() == protoLen && memcmp(in + pos, tentative.data(), tentative.size()) == 0) { - *out = in + pos; + if (tentative.size() == protoLen && memcmp(&inView.at(pos), tentative.data(), tentative.size()) == 0) { + *out = &inView.at(pos); *outlen = protoLen; return SSL_TLSEXT_ERR_OK; } + pos += protoLen; } - pos += protoLen; } return SSL_TLSEXT_ERR_NOACK; @@ -1020,7 +1014,7 @@ public: sess.size = 0; } - virtual ~GnuTLSSession() + ~GnuTLSSession() override { if (d_sess.data != nullptr && d_sess.size > 0) { safe_memory_release(d_sess.data, d_sess.size); @@ -1115,7 +1109,7 @@ public: gnutls_handshake_set_timeout(d_conn.get(), timeout.tv_sec * 1000 + timeout.tv_usec / 1000); gnutls_record_set_timeout(d_conn.get(), timeout.tv_sec * 1000 + timeout.tv_usec / 1000); -#if HAVE_GNUTLS_SESSION_SET_VERIFY_CERT +#ifdef HAVE_GNUTLS_SESSION_SET_VERIFY_CERT if (validateCerts && !d_host.empty()) { gnutls_session_set_verify_cert(d_conn.get(), d_host.c_str(), GNUTLS_VERIFY_ALLOW_UNSORTED_CHAIN); rc = gnutls_server_name_set(d_conn.get(), GNUTLS_NAME_DNS, d_host.c_str(), d_host.size()); @@ -1134,9 +1128,9 @@ public: /* The callback prototype changed in 3.4.0. */ #if GNUTLS_VERSION_NUMBER >= 0x030400 - static int newTicketFromServerCb(gnutls_session_t session, unsigned int htype, unsigned post, unsigned int incoming, const gnutls_datum_t* msg) + static int newTicketFromServerCb(gnutls_session_t session, unsigned int htype, unsigned post, unsigned int /* incoming */, const gnutls_datum_t* /* msg */) #else - static int newTicketFromServerCb(gnutls_session_t session, unsigned int htype, unsigned post, unsigned int incoming) + static int newTicketFromServerCb(gnutls_session_t session, unsigned int htype, unsigned post, unsigned int /* incoming */) #endif /* GNUTLS_VERSION_NUMBER >= 0x030400 */ { if (htype != GNUTLS_HANDSHAKE_NEW_SESSION_TICKET || post != GNUTLS_HOOK_POST || session == nullptr) { @@ -1158,7 +1152,7 @@ public: return 0; } - IOState tryConnect(bool fastOpen, const ComboAddress& remote) override + IOState tryConnect(bool fastOpen, [[maybe_unused]] const ComboAddress& remote) override { int ret = 0; @@ -1263,7 +1257,7 @@ public: else if (gnutls_error_is_fatal(ret) || ret == GNUTLS_E_WARNING_ALERT_RECEIVED) { if (d_client) { std::string error; -#if HAVE_GNUTLS_SESSION_GET_VERIFY_CERT_STATUS +#ifdef HAVE_GNUTLS_SESSION_GET_VERIFY_CERT_STATUS if (ret == GNUTLS_E_CERTIFICATE_VERIFICATION_ERROR) { gnutls_datum_t out; if (gnutls_certificate_verification_status_print(gnutls_session_get_verify_cert_status(d_conn.get()), gnutls_certificate_type_get(d_conn.get()), &out, 0) == 0) { @@ -1314,7 +1308,7 @@ public: else if (res == GNUTLS_E_AGAIN) { return IOState::NeedWrite; } - warnlog("Warning, non-fatal error while writing to TLS connection: %s", gnutls_strerror(res)); + vinfolog("Warning, non-fatal error while writing to TLS connection: %s", gnutls_strerror(res)); } } while (pos < toWrite); @@ -1350,7 +1344,7 @@ public: else if (res == GNUTLS_E_AGAIN) { return IOState::NeedRead; } - warnlog("Warning, non-fatal error while writing to TLS connection: %s", gnutls_strerror(res)); + vinfolog("Warning, non-fatal error while writing to TLS connection: %s", gnutls_strerror(res)); } } while (pos < toRead); @@ -1440,15 +1434,6 @@ public: return got; } - bool hasBufferedData() const override - { - if (d_conn) { - return gnutls_record_check_pending(d_conn.get()) > 0; - } - - return false; - } - bool isUsable() const override { if (!d_conn) { @@ -1678,7 +1663,7 @@ public: } } - virtual ~GnuTLSIOCtx() override + ~GnuTLSIOCtx() override { d_creds.reset(); @@ -1754,7 +1739,7 @@ public: auto newKey = std::make_shared<GnuTLSTicketsKey>(); { - *(d_ticketsKey.write_lock()) = newKey; + *(d_ticketsKey.write_lock()) = std::move(newKey); } if (d_ticketsKeyRotationDelay > 0) { @@ -1762,7 +1747,7 @@ public: } } - void loadTicketsKeys(const std::string& file) override final + void loadTicketsKeys(const std::string& file) final { if (!d_enableTickets) { return; @@ -1770,7 +1755,7 @@ public: auto newKey = std::make_shared<GnuTLSTicketsKey>(file); { - *(d_ticketsKey.write_lock()) = newKey; + *(d_ticketsKey.write_lock()) = std::move(newKey); } if (d_ticketsKeyRotationDelay > 0) { @@ -1811,7 +1796,7 @@ private: #endif /* HAVE_GNUTLS */ -#endif /* HAVE_DNS_OVER_TLS */ +#endif /* HAVE_DNS_OVER_TLS || HAVE_DNS_OVER_HTTPS */ bool setupDoTProtocolNegotiation(std::shared_ptr<TLSCtx>& ctx) { @@ -1824,67 +1809,85 @@ bool setupDoTProtocolNegotiation(std::shared_ptr<TLSCtx>& ctx) return true; } +bool setupDoHProtocolNegotiation(std::shared_ptr<TLSCtx>& ctx) +{ + if (ctx == nullptr) { + return false; + } + /* This code is only called for incoming/server TLS contexts (not outgoing/client), + and h2o sets it own ALPN values. + We want to set the ALPN for DoH: + - HTTP/1.1 so that the OpenSSL callback ALPN accepts it, letting us later return a static response + - HTTP/2 + */ + const std::vector<std::vector<uint8_t>> dohAlpns{{'h', '2'},{'h', 't', 't', 'p', '/', '1', '.', '1'}}; + ctx->setALPNProtos(dohAlpns); + + return true; +} + bool TLSFrontend::setupTLS() { -#ifdef HAVE_DNS_OVER_TLS +#if defined(HAVE_DNS_OVER_TLS) || defined(HAVE_DNS_OVER_HTTPS) std::shared_ptr<TLSCtx> newCtx{nullptr}; /* get the "best" available provider */ - if (!d_provider.empty()) { -#ifdef HAVE_GNUTLS - if (d_provider == "gnutls") { - newCtx = std::make_shared<GnuTLSIOCtx>(*this); - setupDoTProtocolNegotiation(newCtx); - std::atomic_store_explicit(&d_ctx, newCtx, std::memory_order_release); - return true; - } -#endif /* HAVE_GNUTLS */ -#ifdef HAVE_LIBSSL - if (d_provider == "openssl") { - newCtx = std::make_shared<OpenSSLTLSIOCtx>(*this); - setupDoTProtocolNegotiation(newCtx); - std::atomic_store_explicit(&d_ctx, newCtx, std::memory_order_release); - return true; - } -#endif /* HAVE_LIBSSL */ +#if defined(HAVE_GNUTLS) + if (d_provider == "gnutls") { + newCtx = std::make_shared<GnuTLSIOCtx>(*this); } -#ifdef HAVE_LIBSSL - newCtx = std::make_shared<OpenSSLTLSIOCtx>(*this); -#else /* HAVE_LIBSSL */ -#ifdef HAVE_GNUTLS - newCtx = std::make_shared<GnuTLSIOCtx>(*this); #endif /* HAVE_GNUTLS */ +#if defined(HAVE_LIBSSL) + if (d_provider == "openssl") { + newCtx = std::make_shared<OpenSSLTLSIOCtx>(*this); + } #endif /* HAVE_LIBSSL */ - setupDoTProtocolNegotiation(newCtx); - std::atomic_store_explicit(&d_ctx, newCtx, std::memory_order_release); -#endif /* HAVE_DNS_OVER_TLS */ + if (!newCtx) { +#if defined(HAVE_LIBSSL) + newCtx = std::make_shared<OpenSSLTLSIOCtx>(*this); +#elif defined(HAVE_GNUTLS) + newCtx = std::make_shared<GnuTLSIOCtx>(*this); +#else +#error "TLS support needed but neither libssl nor GnuTLS were selected" +#endif + } + + if (d_alpn == ALPN::DoT) { + setupDoTProtocolNegotiation(newCtx); + } + else if (d_alpn == ALPN::DoH) { + setupDoHProtocolNegotiation(newCtx); + } + + std::atomic_store_explicit(&d_ctx, std::move(newCtx), std::memory_order_release); +#endif /* HAVE_DNS_OVER_TLS || HAVE_DNS_OVER_HTTPS */ return true; } -std::shared_ptr<TLSCtx> getTLSContext(const TLSContextParameters& params) +std::shared_ptr<TLSCtx> getTLSContext([[maybe_unused]] const TLSContextParameters& params) { #ifdef HAVE_DNS_OVER_TLS /* get the "best" available provider */ if (!params.d_provider.empty()) { -#ifdef HAVE_GNUTLS +#if defined(HAVE_GNUTLS) if (params.d_provider == "gnutls") { return std::make_shared<GnuTLSIOCtx>(params); } #endif /* HAVE_GNUTLS */ -#ifdef HAVE_LIBSSL +#if defined(HAVE_LIBSSL) if (params.d_provider == "openssl") { return std::make_shared<OpenSSLTLSIOCtx>(params); } #endif /* HAVE_LIBSSL */ } -#ifdef HAVE_LIBSSL +#if defined(HAVE_LIBSSL) return std::make_shared<OpenSSLTLSIOCtx>(params); -#else /* HAVE_LIBSSL */ -#ifdef HAVE_GNUTLS +#elif defined(HAVE_GNUTLS) return std::make_shared<GnuTLSIOCtx>(params); -#endif /* HAVE_GNUTLS */ -#endif /* HAVE_LIBSSL */ +#else +#error "DNS over TLS support needed but neither libssl nor GnuTLS were selected" +#endif #endif /* HAVE_DNS_OVER_TLS */ return nullptr; |