#include "config.h"
#include "dolog.hh"
#include "iputils.hh"
#include "lock.hh"
#include "tcpiohandler.hh"

const bool TCPIOHandler::s_disableConnectForUnitTests = false;

#ifdef HAVE_LIBSODIUM
#include
#endif /* HAVE_LIBSODIUM */

#ifdef HAVE_DNS_OVER_TLS
#ifdef HAVE_LIBSSL

#include
#include
#include
#include
#include

#include "libssl.hh"

/* Server-side (frontend) OpenSSL state shared by every connection accepted on a
   given frontend: the SSL_CTX, the tickets keys ring, the OCSP responses loaded
   for stapling and the optional key log file. */
class OpenSSLFrontendContext
{
public:
  OpenSSLFrontendContext(const ComboAddress& addr, const TLSConfig& tlsConfig): d_ticketKeys(tlsConfig.d_numberOfTicketsKeys)
  {
    registerOpenSSLUser();

    try {
      d_tlsCtx = libssl_init_server_context(tlsConfig, d_ocspResponses);
      if (!d_tlsCtx) {
        ERR_print_errors_fp(stderr);
        throw std::runtime_error("Error creating TLS context on " + addr.toStringWithPort());
      }
    }
    catch (...) {
      /* if construction fails this object never exists, so neither cleanup() nor
         the owning context's destructor will ever balance the registerOpenSSLUser()
         call above: do it here before propagating the error */
      unregisterOpenSSLUser();
      throw;
    }
  }

  /* Release the SSL context first, then drop our OpenSSL user registration. */
  void cleanup()
  {
    d_tlsCtx.reset();

    unregisterOpenSSLUser();
  }

  OpenSSLTLSTicketKeysRing d_ticketKeys;
  std::map d_ocspResponses;
  std::unique_ptr d_tlsCtx{nullptr, SSL_CTX_free};
  std::unique_ptr d_keyLogFile{nullptr, fclose};
};

/* Wrapper around a native OpenSSL session (ticket), handed out by
   OpenSSLTLSConnection::getSessions() and accepted back by setSession(). */
class OpenSSLSession : public TLSSession
{
public:
  OpenSSLSession(std::unique_ptr&& sess): d_sess(std::move(sess))
  {
  }

  virtual ~OpenSSLSession()
  {
  }

  /* Transfers ownership of the native session to the caller; this wrapper is
     left holding nothing afterwards. */
  std::unique_ptr getNative()
  {
    return std::move(d_sess);
  }

private:
  std::unique_ptr d_sess;
};

/* A single TLS connection over an already-connected socket, either accepted on a
   frontend (server side, d_feContext set) or opened towards a backend (client
   side, d_tlsCtx set). */
class OpenSSLTLSConnection: public TLSConnection
{
public:
  /* server side connection */
  OpenSSLTLSConnection(int socket, const struct timeval& timeout, std::shared_ptr feContext): d_feContext(std::move(feContext)), d_conn(std::unique_ptr(SSL_new(d_feContext->d_tlsCtx.get()), SSL_free)), d_timeout(timeout)
  {
    d_socket = socket;

    setupConnection();

    SSL_set_ex_data(d_conn.get(), s_tlsConnIndex, this);
  }

  /* client-side connection */
  OpenSSLTLSConnection(const std::string& hostname, int socket, const struct timeval& timeout, std::shared_ptr& tlsCtx): d_tlsCtx(tlsCtx), d_conn(std::unique_ptr(SSL_new(tlsCtx.get()), SSL_free)), d_hostname(hostname), d_timeout(timeout)
  {
    d_socket = socket;

    setupConnection();

    /* set outgoing Server Name Indication */
    if (!d_hostname.empty() && SSL_set_tlsext_host_name(d_conn.get(), d_hostname.c_str()) != 1) {
      throw std::runtime_error("Error setting TLS SNI to " + d_hostname);
    }

#if (OPENSSL_VERSION_NUMBER >= 0x1010000fL) && HAVE_SSL_SET_HOSTFLAGS // grrr libressl
    SSL_set_hostflags(d_conn.get(), X509_CHECK_FLAG_NO_PARTIAL_WILDCARDS);
    if (SSL_set1_host(d_conn.get(), d_hostname.c_str()) != 1) {
      throw std::runtime_error("Error setting TLS hostname for certificate validation");
    }
#elif (OPENSSL_VERSION_NUMBER >= 0x10002000L)
    X509_VERIFY_PARAM *param = SSL_get0_param(d_conn.get());
    /* Enable automatic hostname checks */
    X509_VERIFY_PARAM_set_hostflags(param, X509_CHECK_FLAG_NO_PARTIAL_WILDCARDS);
    if (X509_VERIFY_PARAM_set1_host(param, d_hostname.c_str(), d_hostname.size()) != 1) {
      throw std::runtime_error("Error setting TLS hostname for certificate validation");
    }
#else
    /* no hostname validation for you, see https://wiki.openssl.org/index.php/Hostname_validation */
#endif

    SSL_set_ex_data(d_conn.get(), s_tlsConnIndex, this);
  }

  /* Map the return value of an SSL_* I/O call to the I/O the caller has to wait
     for (NeedRead / NeedWrite), throwing std::runtime_error on any other outcome. */
  IOState convertIORequestToIOState(int res) const
  {
    int error = SSL_get_error(d_conn.get(), res);
    if (error == SSL_ERROR_WANT_READ) {
      return IOState::NeedRead;
    }
    else if (error == SSL_ERROR_WANT_WRITE) {
      return IOState::NeedWrite;
    }
    else if (error == SSL_ERROR_SYSCALL) {
      /* capture errno before calling into OpenSSL again */
      const int savedErrno = errno;
      /* the error queue may hold details about this failure; leaving them there
         would confuse the next SSL_get_error() call made on this thread */
      ERR_clear_error();
      if (savedErrno == 0) {
        throw std::runtime_error("TLS connection closed by remote end");
      }
      else {
        throw std::runtime_error("Syscall error while processing TLS connection: " + std::string(strerror(savedErrno)));
      }
    }
    else if (error == SSL_ERROR_ZERO_RETURN) {
      throw std::runtime_error("TLS connection closed by remote end");
    }
    else {
      std::string message = "Error while processing TLS connection: ";
      if (g_verbose) {
        message += "(" + std::to_string(error) + ") " + libssl_get_error_string();
      }
      else {
        message += std::to_string(error);
      }
      /* make sure no stale entry is left in this thread's error queue, whichever
         branch was taken above */
      ERR_clear_error();
      throw std::runtime_error(message);
    }
  }

  /* Blocking helper: wait (up to timeout) for the socket to become readable or
     writable as requested by OpenSSL, throwing on timeout or error. */
  void handleIORequest(int res, const struct timeval& timeout)
  {
    auto state = convertIORequestToIOState(res);
    if (state == IOState::NeedRead) {
      res = waitForData(d_socket, timeout.tv_sec, timeout.tv_usec);
      if (res == 0) {
        throw std::runtime_error("Timeout while reading from TLS connection");
      }
      else if (res < 0) {
        throw std::runtime_error("Error waiting to read from TLS connection");
      }
    }
    else if (state == IOState::NeedWrite) {
      res = waitForRWData(d_socket, false, timeout.tv_sec, timeout.tv_usec);
      if (res == 0) {
        throw std::runtime_error("Timeout while writing to TLS connection");
      }
      else if (res < 0) {
        throw std::runtime_error("Error waiting to write to TLS connection");
      }
    }
  }

  IOState tryConnect(bool fastOpen, const ComboAddress& remote) override
  {
    /* sorry */
    (void) fastOpen;
    (void) remote;

    int res = SSL_connect(d_conn.get());
    if (res == 1) {
      return IOState::Done;
    }
    else if (res < 0) {
      return convertIORequestToIOState(res);
    }

    throw std::runtime_error("Error establishing a TLS connection");
  }

  /* Blocking TLS connect, bounded by 'timeout' when it is non-zero. */
  void connect(bool fastOpen, const ComboAddress& remote, const struct timeval &timeout) override
  {
    /* sorry */
    (void) fastOpen;
    (void) remote;

    struct timeval start{0,0};
    struct timeval remainingTime = timeout;
    if (timeout.tv_sec != 0 || timeout.tv_usec != 0) {
      gettimeofday(&start, nullptr);
    }

    int res = 0;
    do {
      res = SSL_connect(d_conn.get());
      if (res == 0) {
        /* the handshake failed and was shut down in a controlled way: calling
           SSL_connect() again cannot succeed, report it like tryConnect() does */
        throw std::runtime_error("Error establishing a TLS connection");
      }
      if (res < 0) {
        handleIORequest(res, remainingTime);
      }

      if (timeout.tv_sec != 0 || timeout.tv_usec != 0) {
        struct timeval now;
        gettimeofday(&now, nullptr);
        struct timeval elapsed = now - start;
        if (now < start || remainingTime < elapsed) {
          throw std::runtime_error("Timeout while establishing TLS connection");
        }
        start = now;
        remainingTime = remainingTime - elapsed;
      }
    }
    while (res != 1);
  }

  IOState tryHandshake() override
  {
    if (!d_feContext) {
      /* In client mode, the handshake is initiated by the call to SSL_connect() done from connect()/tryConnect().
         In blocking mode it does not return before the handshake has been finished, and in non-blocking mode calling
         SSL_connect() once is enough for SSL_write() and SSL_read() to transparently continue to negotiate the connection
         after that (equivalent to doing SSL_set_connect_state() plus trying to write). */
      return IOState::Done;
    }

    /* As explained above in the client-mode block, we only need to call SSL_accept() once for SSL_write() and SSL_read()
       to transparently continue to negotiate the connection after that. It is equivalent to calling SSL_set_accept_state()
       plus trying to read. */
    int res = SSL_accept(d_conn.get());
    if (res == 1) {
      return IOState::Done;
    }
    else if (res < 0) {
      return convertIORequestToIOState(res);
    }

    throw std::runtime_error("Error accepting TLS connection");
  }

  void doHandshake() override
  {
    if (!d_feContext) {
      /* we are a client, nothing to do, see the non-blocking version */
      return;
    }

    int res = 0;
    do {
      res = SSL_accept(d_conn.get());
      if (res < 0) {
        handleIORequest(res, d_timeout);
      }
    }
    while (res < 0);

    if (res != 1) {
      throw std::runtime_error("Error accepting TLS connection");
    }
  }

  /* Non-blocking write of buffer[pos, toWrite), advancing 'pos'. Returns Done once
     everything has been written (immediately when nothing is left to write). */
  IOState tryWrite(const PacketBuffer& buffer, size_t& pos, size_t toWrite) override
  {
    while (pos < toWrite) {
      int res = SSL_write(d_conn.get(), reinterpret_cast(&buffer.at(pos)), static_cast(toWrite - pos));
      if (res <= 0) {
        return convertIORequestToIOState(res);
      }
      pos += static_cast(res);
    }
    return IOState::Done;
  }

  /* Non-blocking read into buffer[pos, toRead), advancing 'pos'. With
     allowIncomplete, returns Done after the first successful read. */
  IOState tryRead(PacketBuffer& buffer, size_t& pos, size_t toRead, bool allowIncomplete) override
  {
    while (pos < toRead) {
      int res = SSL_read(d_conn.get(), reinterpret_cast(&buffer.at(pos)), static_cast(toRead - pos));
      if (res <= 0) {
        return convertIORequestToIOState(res);
      }
      pos += static_cast(res);
      if (allowIncomplete) {
        break;
      }
    }
    return IOState::Done;
  }

  /* Blocking read: each wait is bounded by readTimeout, the whole call by
     totalTimeout when it is non-zero. Returns the number of bytes read. */
  size_t read(void* buffer, size_t bufferSize, const struct timeval& readTimeout, const struct timeval& totalTimeout, bool allowIncomplete) override
  {
    size_t got = 0;
    struct timeval start = {0, 0};
    struct timeval remainingTime = totalTimeout;
    if (totalTimeout.tv_sec != 0 || totalTimeout.tv_usec != 0) {
      gettimeofday(&start, nullptr);
    }

    while (got < bufferSize) {
      int res = SSL_read(d_conn.get(), (reinterpret_cast(buffer) + got), static_cast(bufferSize - got));
      if (res <= 0) {
        handleIORequest(res, readTimeout);
      }
      else {
        got += static_cast(res);
        if (allowIncomplete) {
          break;
        }
      }

      if (totalTimeout.tv_sec != 0 || totalTimeout.tv_usec != 0) {
        struct timeval now;
        gettimeofday(&now, nullptr);
        struct timeval elapsed = now - start;
        if (now < start || remainingTime < elapsed) {
          throw std::runtime_error("Timeout while reading data");
        }
        start = now;
        remainingTime = remainingTime - elapsed;
      }
    }

    return got;
  }

  /* Blocking write of the whole buffer, each wait bounded by writeTimeout. */
  size_t write(const void* buffer, size_t bufferSize, const struct timeval& writeTimeout) override
  {
    size_t got = 0;
    while (got < bufferSize) {
      int res = SSL_write(d_conn.get(), (reinterpret_cast(buffer) + got), static_cast(bufferSize - got));
      if (res <= 0) {
        handleIORequest(res, writeTimeout);
      }
      else {
        got += static_cast(res);
      }
    }

    return got;
  }

  bool hasBufferedData() const override
  {
    if (d_conn) {
      return SSL_pending(d_conn.get()) > 0;
    }

    return false;
  }

  /* Peek one byte: the connection is usable if data is available or OpenSSL only
     asks us to wait, and unusable if the peek reports any error. */
  bool isUsable() const override
  {
    if (!d_conn) {
      return false;
    }

    char buf;
    int res = SSL_peek(d_conn.get(), &buf, sizeof(buf));
    if (res > 0) {
      return true;
    }
    try {
      convertIORequestToIOState(res);
      return true;
    }
    catch (...) {
      return false;
    }
  }

  void close() override
  {
    if (d_conn) {
      SSL_shutdown(d_conn.get());
    }
  }

  std::string getServerNameIndication() const override
  {
    if (d_conn) {
      const char* value = SSL_get_servername(d_conn.get(), TLSEXT_NAMETYPE_host_name);
      if (value) {
        return std::string(value);
      }
    }
    return std::string();
  }

  /* Protocol negotiated via NPN, falling back to ALPN; empty if none. */
  std::vector getNextProtocol() const override
  {
    std::vector result;
    if (!d_conn) {
      return result;
    }

    const unsigned char* alpn = nullptr;
    unsigned int alpnLen = 0;
#ifdef HAVE_SSL_GET0_NEXT_PROTO_NEGOTIATED
    SSL_get0_next_proto_negotiated(d_conn.get(), &alpn, &alpnLen);
#endif
#ifdef HAVE_SSL_GET0_ALPN_SELECTED
    if (alpn == nullptr) {
      SSL_get0_alpn_selected(d_conn.get(), &alpn, &alpnLen);
    }
#endif
    if (alpn != nullptr && alpnLen > 0) {
      result.insert(result.end(), alpn, alpn + alpnLen);
    }
    return result;
  }

  LibsslTLSVersion getTLSVersion() const override
  {
    auto proto = SSL_version(d_conn.get());
    switch (proto) {
    case TLS1_VERSION:
      return LibsslTLSVersion::TLS10;
    case TLS1_1_VERSION:
      return LibsslTLSVersion::TLS11;
    case TLS1_2_VERSION:
      return LibsslTLSVersion::TLS12;
#ifdef TLS1_3_VERSION
    case TLS1_3_VERSION:
      return LibsslTLSVersion::TLS13;
#endif /* TLS1_3_VERSION */
    default:
      return LibsslTLSVersion::Unknown;
    }
  }

  bool hasSessionBeenResumed() const override
  {
    if (d_conn) {
      return SSL_session_reused(d_conn.get()) != 0;
    }
    return false;
  }

  /* Hands over the sessions (tickets) received from the server so far. */
  std::vector> getSessions() override
  {
    return std::move(d_tlsSessions);
  }

  /* Offer a previously received session for resumption. 'session' must hold an
     OpenSSLSession; it is reset on success. */
  void setSession(std::unique_ptr& session) override
  {
    auto sess = dynamic_cast(session.get());
    if (!sess) {
      throw std::runtime_error("Unable to convert OpenSSL session");
    }

    auto native = sess->getNative();
    auto ret = SSL_set_session(d_conn.get(), native.get());
    if (ret != 1) {
      throw std::runtime_error("Error setting up session: " + libssl_get_error_string());
    }
    session.reset();
  }

  /* Takes ownership of 'session' (freed with SSL_SESSION_free). */
  void addNewTicket(SSL_SESSION* session)
  {
    d_tlsSessions.push_back(std::make_unique(std::unique_ptr(session, SSL_SESSION_free)));
  }

  /* ex_data index used by the OpenSSL callbacks to get back from an SSL object to
     its OpenSSLTLSConnection; -1 until the first connection is created. */
  static int s_tlsConnIndex;

private:
  /* Allocate s_tlsConnIndex exactly once per process. Initialisation of a
     function-local static is thread-safe: concurrent callers wait until it has
     completed instead of reading a not-yet-set index, and if it throws it is
     attempted again on the next call. */
  static void initTLSConnIndex()
  {
    static const bool s_indexAllocated = []() {
      s_tlsConnIndex = SSL_get_ex_new_index(0, nullptr, nullptr, nullptr, nullptr);
      if (s_tlsConnIndex == -1) {
        throw std::runtime_error("Error getting an index for TLS connection data");
      }
      return true;
    }();
    (void) s_indexAllocated;
  }

  /* Setup shared by both constructors: make sure the ex_data index exists, check
     that SSL_new() succeeded and attach d_socket to the SSL object. */
  void setupConnection()
  {
    initTLSConnIndex();

    if (!d_conn) {
      vinfolog("Error creating TLS object");
      if (g_verbose) {
        ERR_print_errors_fp(stderr);
      }
      throw std::runtime_error("Error creating TLS object");
    }

    if (!SSL_set_fd(d_conn.get(), d_socket)) {
      throw std::runtime_error("Error assigning socket");
    }
  }

  std::vector> d_tlsSessions;
  /* server context */
  std::shared_ptr d_feContext;
  /* client context */
  std::shared_ptr d_tlsCtx;
  std::unique_ptr d_conn;
  std::string d_hostname;
  struct timeval d_timeout;
};

int OpenSSLTLSConnection::s_tlsConnIndex = -1;

/* TLSCtx implementation backed by OpenSSL. A server-side instance (built from a
   TLSFrontend) owns an OpenSSLFrontendContext shared with its connections, a
   client-side instance (built from TLSContextParameters) owns its own SSL_CTX. */
class OpenSSLTLSIOCtx: public TLSCtx
{
public:
  /* server side context */
  OpenSSLTLSIOCtx(TLSFrontend& fe): d_feContext(std::make_shared(fe.d_addr, fe.d_tlsConfig))
  {
    d_ticketsKeyRotationDelay = fe.d_tlsConfig.d_ticketsKeyRotationDelay;

    try {
      if (fe.d_tlsConfig.d_enableTickets && fe.d_tlsConfig.d_numberOfTicketsKeys > 0) {
        /* use our own ticket keys handler so we can rotate them */
        SSL_CTX_set_tlsext_ticket_key_cb(d_feContext->d_tlsCtx.get(), &OpenSSLTLSIOCtx::ticketKeyCb);
        libssl_set_ticket_key_callback_data(d_feContext->d_tlsCtx.get(), d_feContext.get());
      }

      if (!d_feContext->d_ocspResponses.empty()) {
        SSL_CTX_set_tlsext_status_cb(d_feContext->d_tlsCtx.get(), &OpenSSLTLSIOCtx::ocspStaplingCb);
        SSL_CTX_set_tlsext_status_arg(d_feContext->d_tlsCtx.get(), &d_feContext->d_ocspResponses);
      }

      libssl_set_error_counters_callback(d_feContext->d_tlsCtx, &fe.d_tlsCounters);

      if (!fe.d_tlsConfig.d_keyLogFile.empty()) {
        d_feContext->d_keyLogFile = libssl_set_key_log_file(d_feContext->d_tlsCtx, fe.d_tlsConfig.d_keyLogFile);
      }

      if (fe.d_tlsConfig.d_ticketKeyFile.empty()) {
        handleTicketsKeyRotation(time(nullptr));
      }
      else {
        OpenSSLTLSIOCtx::loadTicketsKeys(fe.d_tlsConfig.d_ticketKeyFile);
      }
    }
    catch (...) {
      /* ~OpenSSLTLSIOCtx() will not run since construction failed: release the
         frontend context and balance the registerOpenSSLUser() call made by
         OpenSSLFrontendContext, as the destructor would have done */
      d_feContext.reset();
      unregisterOpenSSLUser();
      throw;
    }
  }

  /* client side context */
  OpenSSLTLSIOCtx(const TLSContextParameters& params)
  {
    int sslOptions =
      SSL_OP_NO_SSLv2 |
      SSL_OP_NO_SSLv3 |
      SSL_OP_NO_COMPRESSION |
      SSL_OP_NO_SESSION_RESUMPTION_ON_RENEGOTIATION |
      SSL_OP_SINGLE_DH_USE |
      SSL_OP_SINGLE_ECDH_USE |
      SSL_OP_CIPHER_SERVER_PREFERENCE;
    if (!params.d_enableRenegotiation) {
#ifdef SSL_OP_NO_RENEGOTIATION
      sslOptions |= SSL_OP_NO_RENEGOTIATION;
#elif defined(SSL_OP_NO_CLIENT_RENEGOTIATION)
      sslOptions |= SSL_OP_NO_CLIENT_RENEGOTIATION;
#endif
    }

    registerOpenSSLUser();

    try {
#ifdef HAVE_TLS_CLIENT_METHOD
      d_tlsCtx = std::shared_ptr(SSL_CTX_new(TLS_client_method()), SSL_CTX_free);
#else
      d_tlsCtx = std::shared_ptr(SSL_CTX_new(SSLv23_client_method()), SSL_CTX_free);
#endif
      if (!d_tlsCtx) {
        ERR_print_errors_fp(stderr);
        throw std::runtime_error("Error creating TLS context");
      }

      SSL_CTX_set_options(d_tlsCtx.get(), sslOptions);
#if defined(SSL_CTX_set_ecdh_auto)
      SSL_CTX_set_ecdh_auto(d_tlsCtx.get(), 1);
#endif

      if (!params.d_ciphers.empty()) {
        if (SSL_CTX_set_cipher_list(d_tlsCtx.get(), params.d_ciphers.c_str()) != 1) {
          ERR_print_errors_fp(stderr);
          throw std::runtime_error("Error setting the cipher list to '" + params.d_ciphers + "' for the TLS context");
        }
      }
#ifdef HAVE_SSL_CTX_SET_CIPHERSUITES
      if (!params.d_ciphers13.empty()) {
        if (SSL_CTX_set_ciphersuites(d_tlsCtx.get(), params.d_ciphers13.c_str()) != 1) {
          ERR_print_errors_fp(stderr);
          throw std::runtime_error("Error setting the TLS 1.3 cipher list to '" + params.d_ciphers13 + "' for the TLS context");
        }
      }
#endif /* HAVE_SSL_CTX_SET_CIPHERSUITES */

      if (params.d_validateCertificates) {
        if (params.d_caStore.empty()) {
          if (SSL_CTX_set_default_verify_paths(d_tlsCtx.get()) != 1) {
            throw std::runtime_error("Error adding the system's default trusted CAs");
          }
        }
        else {
          if (SSL_CTX_load_verify_locations(d_tlsCtx.get(), params.d_caStore.c_str(), nullptr) != 1) {
            throw std::runtime_error("Error adding the trusted CAs file " + params.d_caStore);
          }
        }

        SSL_CTX_set_verify(d_tlsCtx.get(), SSL_VERIFY_PEER, nullptr);
#if (OPENSSL_VERSION_NUMBER < 0x10002000L)
        warnlog("TLS hostname validation requested but not supported for OpenSSL < 1.0.2");
#endif
      }

      /* we need to set SSL_SESS_CACHE_CLIENT for the "new ticket" callback (below) to be called,
         but we don't want OpenSSL to cache the session itself so we set SSL_SESS_CACHE_NO_INTERNAL_STORE as well */
      SSL_CTX_set_session_cache_mode(d_tlsCtx.get(), SSL_SESS_CACHE_CLIENT | SSL_SESS_CACHE_NO_INTERNAL_STORE);
      SSL_CTX_sess_set_new_cb(d_tlsCtx.get(), &OpenSSLTLSIOCtx::newTicketFromServerCb);

#ifdef SSL_MODE_RELEASE_BUFFERS
      if (params.d_releaseBuffers) {
        SSL_CTX_set_mode(d_tlsCtx.get(), SSL_MODE_RELEASE_BUFFERS);
      }
#endif
    }
    catch (...) {
      /* ~OpenSSLTLSIOCtx() will not run since construction failed: do its job
         (free the context, then balance registerOpenSSLUser()) before rethrowing */
      d_tlsCtx.reset();
      unregisterOpenSSLUser();
      throw;
    }
  }

  ~OpenSSLTLSIOCtx() override
  {
    d_tlsCtx.reset();

    unregisterOpenSSLUser();
  }

  /* Server-side session ticket key callback: delegates to the frontend's keys
     ring and records on the connection whether the ticket used an unknown key
     (ret == 0) or an inactive one (ret == 2) when decrypting (enc == 0). */
  static int ticketKeyCb(SSL *s, unsigned char keyName[TLS_TICKETS_KEY_NAME_SIZE], unsigned char *iv, EVP_CIPHER_CTX *ectx, HMAC_CTX *hctx, int enc)
  {
    OpenSSLFrontendContext* ctx = reinterpret_cast(libssl_get_ticket_key_callback_data(s));
    if (ctx == nullptr) {
      return -1;
    }

    int ret = libssl_ticket_key_callback(s, ctx->d_ticketKeys, keyName, iv, ectx, hctx, enc);
    if (enc == 0) {
      if (ret == 0 || ret == 2) {
        OpenSSLTLSConnection* conn = reinterpret_cast(SSL_get_ex_data(s, OpenSSLTLSConnection::s_tlsConnIndex));
        if (conn) {
          if (ret == 0) {
            conn->setUnknownTicketKey();
          }
          else if (ret == 2) {
            conn->setResumedFromInactiveTicketKey();
          }
        }
      }
    }

    return ret;
  }

  /* OCSP stapling callback; 'arg' is the frontend's d_ocspResponses map. */
  static int ocspStaplingCb(SSL* ssl, void* arg)
  {
    if (ssl == nullptr || arg == nullptr) {
      return SSL_TLSEXT_ERR_NOACK;
    }
    const auto ocspMap = reinterpret_cast*>(arg);
    return libssl_ocsp_stapling_callback(ssl, *ocspMap);
  }

  /* Client-side "new session" callback: returning 1 tells OpenSSL that the
     connection (via addNewTicket()) now owns 'session'. */
  static int newTicketFromServerCb(SSL* ssl, SSL_SESSION* session)
  {
    OpenSSLTLSConnection* conn = reinterpret_cast(SSL_get_ex_data(ssl, OpenSSLTLSConnection::s_tlsConnIndex));
    if (session == nullptr || conn == nullptr) {
      return 0;
    }

    conn->addNewTicket(session);
    return 1;
  }

  std::unique_ptr getConnection(int socket, const struct timeval& timeout, time_t now) override
  {
    handleTicketsKeyRotation(now);

    return std::make_unique(socket, timeout, d_feContext);
  }

  std::unique_ptr getClientConnection(const std::string& host, int socket, const struct timeval& timeout) override
  {
    return std::make_unique(host, socket, timeout, d_tlsCtx);
  }

  void rotateTicketsKey(time_t now) override
  {
    d_feContext->d_ticketKeys.rotateTicketsKey(now);

    if (d_ticketsKeyRotationDelay > 0) {
      d_ticketsKeyNextRotation = now + d_ticketsKeyRotationDelay;
    }
  }

  void loadTicketsKeys(const std::string& keyFile) override final
  {
    d_feContext->d_ticketKeys.loadTicketsKeys(keyFile);

    if (d_ticketsKeyRotationDelay > 0) {
      d_ticketsKeyNextRotation = time(nullptr) + d_ticketsKeyRotationDelay;
    }
  }

  size_t getTicketsKeysCount() override
  {
    return d_feContext->d_ticketKeys.getKeysCount();
  }

  std::string getName() const override
  {
    return "openssl";
  }

  /* Server side: remember the protocols and select among the client's offer.
     Client side: advertise the protocols. Returns false without any context. */
  bool setALPNProtos(const std::vector>& protos) override
  {
    if (d_feContext && d_feContext->d_tlsCtx) {
      d_alpnProtos = protos;
      libssl_set_alpn_select_callback(d_feContext->d_tlsCtx.get(), alpnServerSelectCallback, this);
      return true;
    }
    if (d_tlsCtx) {
      return libssl_set_alpn_protos(d_tlsCtx.get(), protos);
    }
    return false;
  }

  /* Register the NPN selection callback on the client context. Returns false when
     there is no client context (server-side instance) instead of handing a null
     SSL_CTX to OpenSSL. */
  bool setNextProtocolSelectCallback(bool(*cb)(unsigned char** out, unsigned char* outlen, const unsigned char* in, unsigned int inlen)) override
  {
    if (!d_tlsCtx) {
      return false;
    }
    d_nextProtocolSelectCallback = cb;
    libssl_set_npn_select_callback(d_tlsCtx.get(), npnSelectCallback, this);
    return true;
  }

private:
  /* NPN selection callback, registered on the client context: picks one of the
     protocols offered by the server by delegating to d_nextProtocolSelectCallback
     when one has been set. */
  static int npnSelectCallback(SSL* s, unsigned char** out, unsigned char* outlen, const unsigned char* in, unsigned int inlen, void* arg)
  {
    if (!arg) {
      return SSL_TLSEXT_ERR_ALERT_WARNING;
    }
    OpenSSLTLSIOCtx* obj = reinterpret_cast(arg);
    if (obj->d_nextProtocolSelectCallback) {
      return (*obj->d_nextProtocolSelectCallback)(out, outlen, in, inlen) ? SSL_TLSEXT_ERR_OK : SSL_TLSEXT_ERR_ALERT_WARNING;
    }

    return SSL_TLSEXT_ERR_OK;
  }

  /* ALPN selection callback (server side): walks the client's length-prefixed
     protocol list and picks the first entry that is also in d_alpnProtos. */
  static int alpnServerSelectCallback(SSL*, const unsigned char** out, unsigned char* outlen, const unsigned char* in, unsigned int inlen, void* arg)
  {
    if (!arg) {
      return SSL_TLSEXT_ERR_ALERT_WARNING;
    }
    OpenSSLTLSIOCtx* obj = reinterpret_cast(arg);

    size_t pos = 0;
    while (pos < inlen) {
      size_t protoLen = in[pos];
      pos++;
      if (protoLen > (inlen - pos)) {
        /* something is very wrong */
        return SSL_TLSEXT_ERR_ALERT_WARNING;
      }

      for (const auto& tentative : obj->d_alpnProtos) {
        if (tentative.size() == protoLen && memcmp(in + pos, tentative.data(), tentative.size()) == 0) {
          *out = in + pos;
          *outlen = protoLen;
          return SSL_TLSEXT_ERR_OK;
        }
      }
      pos += protoLen;
    }

    return SSL_TLSEXT_ERR_NOACK;
  }

  std::vector> d_alpnProtos; // store the supported ALPN protocols, so that the server can select based on what the client sent
  std::shared_ptr d_feContext{nullptr};
  std::shared_ptr d_tlsCtx{nullptr}; // client context, on a server-side the context is stored in d_feContext->d_tlsCtx
  bool (*d_nextProtocolSelectCallback)(unsigned char** out, unsigned char* outlen, const unsigned char* in, unsigned int inlen){nullptr};
};

#endif /* HAVE_LIBSSL */

#ifdef HAVE_GNUTLS
#include
#include

/* Lock the memory holding key material with sodium_mlock() when built with
   libsodium; no-op otherwise. */
static void safe_memory_lock(void* data, size_t size)
{
#ifdef HAVE_LIBSODIUM
  sodium_mlock(data, size);
#else
  (void) data;
  (void) size;
#endif
}

/* Wipe memory that held key material, using the first available primitive meant
   not to be optimised away (sodium_munlock() also undoes safe_memory_lock()). */
static void safe_memory_release(void* data, size_t size)
{
#ifdef HAVE_LIBSODIUM
  sodium_munlock(data, size);
#elif defined(HAVE_EXPLICIT_BZERO)
  explicit_bzero(data, size);
#elif defined(HAVE_EXPLICIT_MEMSET)
  explicit_memset(data, 0, size);
#elif defined(HAVE_GNUTLS_MEMSET)
  gnutls_memset(data, 0, size);
#else
  /* shamelessly taken from Dovecot's src/lib/safe-memset.c */
  volatile unsigned int volatile_zero_idx = 0;
  volatile unsigned char *p = reinterpret_cast(data);

  if (size == 0) {
    return;
  }

  do {
    memset(data, 0, size);
  } while (p[volatile_zero_idx] != 0);
#endif
}

/* GnuTLS session tickets key, either freshly generated or loaded from a file.
   Owns d_key.data (gnutls_free'd on destruction), hence not copyable. */
class GnuTLSTicketsKey
{
public:
  GnuTLSTicketsKey()
  {
    if (gnutls_session_ticket_key_generate(&d_key) != GNUTLS_E_SUCCESS) {
      throw std::runtime_error("Error generating tickets key for TLS context");
    }

    safe_memory_lock(d_key.data, d_key.size);
  }

  GnuTLSTicketsKey(const std::string& keyFile)
  {
    /* to be sure we are loading the correct amount of data, which
       may change between versions, let's generate a correct key first */
    if (gnutls_session_ticket_key_generate(&d_key) != GNUTLS_E_SUCCESS) {
      throw std::runtime_error("Error generating tickets key (before parsing key file) for TLS context");
    }

    safe_memory_lock(d_key.data, d_key.size);

    try {
      ifstream file(keyFile);
      file.read(reinterpret_cast(d_key.data), d_key.size);

      if (file.fail()) {
        file.close();
        throw std::runtime_error("Invalid GnuTLS tickets key file " + keyFile);
      }

      file.close();
    }
    catch (...) {
      /* the destructor will not run: wipe and free the key before propagating */
      safe_memory_release(d_key.data, d_key.size);
      gnutls_free(d_key.data);
      d_key.data = nullptr;
      throw;
    }
  }

  /* copying would lead to d_key.data being released twice */
  GnuTLSTicketsKey(const GnuTLSTicketsKey&) = delete;
  GnuTLSTicketsKey& operator=(const GnuTLSTicketsKey&) = delete;

  ~GnuTLSTicketsKey()
  {
    if (d_key.data != nullptr && d_key.size > 0) {
      safe_memory_release(d_key.data, d_key.size);
    }
    gnutls_free(d_key.data);
    d_key.data = nullptr;
  }

  const gnutls_datum_t& getKey() const
  {
    return d_key;
  }

private:
  gnutls_datum_t d_key{nullptr, 0};
};

/* GnuTLS session data. Takes ownership of the datum passed to the constructor
   (which is reset to empty) and wipes/frees it on destruction, hence not copyable. */
class GnuTLSSession : public TLSSession
{
public:
  GnuTLSSession(gnutls_datum_t& sess): d_sess(sess)
  {
    sess.data = nullptr;
    sess.size = 0;
  }

  /* copying would lead to d_sess.data being released twice */
  GnuTLSSession(const GnuTLSSession&) = delete;
  GnuTLSSession& operator=(const GnuTLSSession&) = delete;

  virtual ~GnuTLSSession()
  {
    if (d_sess.data != nullptr && d_sess.size > 0) {
      safe_memory_release(d_sess.data, d_sess.size);
    }
    gnutls_free(d_sess.data);
    d_sess.data = nullptr;
  }

  const gnutls_datum_t& getNative()
  {
    return d_sess;
  }

private:
  gnutls_datum_t d_sess{nullptr, 0};
};

/* GnuTLS-backed TLS connection */
class
GnuTLSConnection: public TLSConnection { public: /* server side connection */ GnuTLSConnection(int socket, const struct timeval& timeout, std::shared_ptr& creds, const gnutls_priority_t priorityCache, std::shared_ptr& ticketsKey, bool enableTickets): d_creds(creds), d_ticketsKey(ticketsKey), d_conn(std::unique_ptr(nullptr, gnutls_deinit)) { unsigned int sslOptions = GNUTLS_SERVER | GNUTLS_NONBLOCK; #ifdef GNUTLS_NO_SIGNAL sslOptions |= GNUTLS_NO_SIGNAL; #endif d_socket = socket; gnutls_session_t conn; if (gnutls_init(&conn, sslOptions) != GNUTLS_E_SUCCESS) { throw std::runtime_error("Error creating TLS connection"); } d_conn = std::unique_ptr(conn, gnutls_deinit); conn = nullptr; if (gnutls_credentials_set(d_conn.get(), GNUTLS_CRD_CERTIFICATE, d_creds.get()) != GNUTLS_E_SUCCESS) { throw std::runtime_error("Error setting certificate and key to TLS connection"); } if (gnutls_priority_set(d_conn.get(), priorityCache) != GNUTLS_E_SUCCESS) { throw std::runtime_error("Error setting ciphers to TLS connection"); } if (enableTickets && d_ticketsKey) { const gnutls_datum_t& key = d_ticketsKey->getKey(); if (gnutls_session_ticket_enable_server(d_conn.get(), &key) != GNUTLS_E_SUCCESS) { throw std::runtime_error("Error setting the tickets key to TLS connection"); } } gnutls_transport_set_int(d_conn.get(), d_socket); /* timeouts are in milliseconds */ gnutls_handshake_set_timeout(d_conn.get(), timeout.tv_sec * 1000 + timeout.tv_usec / 1000); gnutls_record_set_timeout(d_conn.get(), timeout.tv_sec * 1000 + timeout.tv_usec / 1000); } /* client-side connection */ GnuTLSConnection(const std::string& host, int socket, const struct timeval& timeout, std::shared_ptr& creds, const gnutls_priority_t priorityCache, bool validateCerts): d_creds(creds), d_conn(std::unique_ptr(nullptr, gnutls_deinit)), d_host(host), d_client(true) { unsigned int sslOptions = GNUTLS_CLIENT | GNUTLS_NONBLOCK; #ifdef GNUTLS_NO_SIGNAL sslOptions |= GNUTLS_NO_SIGNAL; #endif d_socket = socket; gnutls_session_t 
conn; if (gnutls_init(&conn, sslOptions) != GNUTLS_E_SUCCESS) { throw std::runtime_error("Error creating TLS connection"); } d_conn = std::unique_ptr(conn, gnutls_deinit); conn = nullptr; int rc = gnutls_credentials_set(d_conn.get(), GNUTLS_CRD_CERTIFICATE, d_creds.get()); if (rc != GNUTLS_E_SUCCESS) { throw std::runtime_error("Error setting certificate and key to TLS connection: " + std::string(gnutls_strerror(rc))); } rc = gnutls_priority_set(d_conn.get(), priorityCache); if (rc != GNUTLS_E_SUCCESS) { throw std::runtime_error("Error setting ciphers to TLS connection: " + std::string(gnutls_strerror(rc))); } gnutls_transport_set_int(d_conn.get(), d_socket); /* timeouts are in milliseconds */ gnutls_handshake_set_timeout(d_conn.get(), timeout.tv_sec * 1000 + timeout.tv_usec / 1000); gnutls_record_set_timeout(d_conn.get(), timeout.tv_sec * 1000 + timeout.tv_usec / 1000); #if HAVE_GNUTLS_SESSION_SET_VERIFY_CERT if (validateCerts && !d_host.empty()) { gnutls_session_set_verify_cert(d_conn.get(), d_host.c_str(), GNUTLS_VERIFY_ALLOW_UNSORTED_CHAIN); rc = gnutls_server_name_set(d_conn.get(), GNUTLS_NAME_DNS, d_host.c_str(), d_host.size()); if (rc != GNUTLS_E_SUCCESS) { throw std::runtime_error("Error setting the SNI value to '" + d_host + "' on TLS connection: " + std::string(gnutls_strerror(rc))); } } #else /* no hostname validation for you */ #endif /* allow access to our data in the callbacks */ gnutls_session_set_ptr(d_conn.get(), this); gnutls_handshake_set_hook_function(d_conn.get(), GNUTLS_HANDSHAKE_NEW_SESSION_TICKET, GNUTLS_HOOK_POST, newTicketFromServerCb); } /* The callback prototype changed in 3.4.0. 
*/ #if GNUTLS_VERSION_NUMBER >= 0x030400 static int newTicketFromServerCb(gnutls_session_t session, unsigned int htype, unsigned post, unsigned int incoming, const gnutls_datum_t* msg) #else static int newTicketFromServerCb(gnutls_session_t session, unsigned int htype, unsigned post, unsigned int incoming) #endif /* GNUTLS_VERSION_NUMBER >= 0x030400 */ { if (htype != GNUTLS_HANDSHAKE_NEW_SESSION_TICKET || post != GNUTLS_HOOK_POST || session == nullptr) { return 0; } GnuTLSConnection* conn = reinterpret_cast(gnutls_session_get_ptr(session)); if (conn == nullptr) { return 0; } gnutls_datum_t sess{nullptr, 0}; auto ret = gnutls_session_get_data2(session, &sess); /* GnuTLS returns a 'fake' ticket of 4 bytes set to zero when there is no ticket available */ if (ret != GNUTLS_E_SUCCESS || sess.size <= 4) { throw std::runtime_error("Error getting GnuTLSSession: " + std::string(gnutls_strerror(ret))); } conn->d_tlsSessions.push_back(std::make_unique(sess)); return 0; } IOState tryConnect(bool fastOpen, const ComboAddress& remote) override { int ret = 0; if (fastOpen) { #ifdef HAVE_GNUTLS_TRANSPORT_SET_FASTOPEN gnutls_transport_set_fastopen(d_conn.get(), d_socket, const_cast(reinterpret_cast(&remote)), remote.getSocklen(), 0); #endif } do { ret = gnutls_handshake(d_conn.get()); if (ret == GNUTLS_E_SUCCESS) { d_handshakeDone = true; return IOState::Done; } else if (ret == GNUTLS_E_AGAIN) { int direction = gnutls_record_get_direction(d_conn.get()); return direction == 0 ? 
IOState::NeedRead : IOState::NeedWrite; } else if (gnutls_error_is_fatal(ret) || ret == GNUTLS_E_WARNING_ALERT_RECEIVED) { throw std::runtime_error("Error establishing a new connection: " + std::string(gnutls_strerror(ret))); } } while (ret == GNUTLS_E_INTERRUPTED); throw std::runtime_error("Error establishing a new connection: " + std::string(gnutls_strerror(ret))); } void connect(bool fastOpen, const ComboAddress& remote, const struct timeval& timeout) override { struct timeval start = {0, 0}; struct timeval remainingTime = timeout; if (timeout.tv_sec != 0 || timeout.tv_usec != 0) { gettimeofday(&start, nullptr); } IOState state; do { state = tryConnect(fastOpen, remote); if (state == IOState::Done) { return; } else if (state == IOState::NeedRead) { int result = waitForData(d_socket, remainingTime.tv_sec, remainingTime.tv_usec); if (result <= 0) { throw std::runtime_error("Error reading from TLS connection: " + std::to_string(result)); } } else if (state == IOState::NeedWrite) { int result = waitForRWData(d_socket, false, remainingTime.tv_sec, remainingTime.tv_usec); if (result <= 0) { throw std::runtime_error("Error reading from TLS connection: " + std::to_string(result)); } } if (timeout.tv_sec != 0 || timeout.tv_usec != 0) { struct timeval now; gettimeofday(&now, nullptr); struct timeval elapsed = now - start; if (now < start || remainingTime < elapsed) { throw runtime_error("Timeout while establishing TLS connection"); } start = now; remainingTime = remainingTime - elapsed; } } while (state != IOState::Done); } void doHandshake() override { int ret = 0; do { ret = gnutls_handshake(d_conn.get()); if (gnutls_error_is_fatal(ret) || ret == GNUTLS_E_WARNING_ALERT_RECEIVED) { if (d_client) { throw std::runtime_error("Error establishing a new connection: " + std::string(gnutls_strerror(ret))); } else { throw std::runtime_error("Error accepting a new connection: " + std::string(gnutls_strerror(ret))); } } } while (ret != GNUTLS_E_SUCCESS && ret == 
GNUTLS_E_INTERRUPTED); d_handshakeDone = true; } IOState tryHandshake() override { int ret = 0; do { ret = gnutls_handshake(d_conn.get()); if (ret == GNUTLS_E_SUCCESS) { d_handshakeDone = true; return IOState::Done; } else if (ret == GNUTLS_E_AGAIN) { int direction = gnutls_record_get_direction(d_conn.get()); return direction == 0 ? IOState::NeedRead : IOState::NeedWrite; } else if (gnutls_error_is_fatal(ret) || ret == GNUTLS_E_WARNING_ALERT_RECEIVED) { if (d_client) { std::string error; #if HAVE_GNUTLS_SESSION_GET_VERIFY_CERT_STATUS if (ret == GNUTLS_E_CERTIFICATE_VERIFICATION_ERROR) { gnutls_datum_t out; if (gnutls_certificate_verification_status_print(gnutls_session_get_verify_cert_status(d_conn.get()), gnutls_certificate_type_get(d_conn.get()), &out, 0) == 0) { error = " (" + std::string(reinterpret_cast(out.data)) + ")"; gnutls_free(out.data); } } #endif /* HAVE_GNUTLS_SESSION_GET_VERIFY_CERT_STATUS */ throw std::runtime_error("Error accepting a new connection: " + std::string(gnutls_strerror(ret)) + error); } else { throw std::runtime_error("Error establishing a new connection: " + std::string(gnutls_strerror(ret))); } } } while (ret == GNUTLS_E_INTERRUPTED); if (d_client) { throw std::runtime_error("Error establishinging a new connection: " + std::string(gnutls_strerror(ret))); } else { throw std::runtime_error("Error accepting a new connection: " + std::string(gnutls_strerror(ret))); } } IOState tryWrite(const PacketBuffer& buffer, size_t& pos, size_t toWrite) override { if (!d_handshakeDone) { /* As opposed to OpenSSL, GnuTLS will not transparently finish the handshake for us, we need to keep calling gnutls_handshake() until the handshake has been finished. 
*/ auto state = tryHandshake(); if (state != IOState::Done) { return state; } } do { ssize_t res = gnutls_record_send(d_conn.get(), reinterpret_cast(&buffer.at(pos)), toWrite - pos); if (res == 0) { throw std::runtime_error("Error writing to TLS connection"); } else if (res > 0) { pos += static_cast(res); } else if (res < 0) { if (gnutls_error_is_fatal(res)) { throw std::runtime_error("Fatal error writing to TLS connection: " + std::string(gnutls_strerror(res))); } else if (res == GNUTLS_E_AGAIN) { return IOState::NeedWrite; } warnlog("Warning, non-fatal error while writing to TLS connection: %s", gnutls_strerror(res)); } } while (pos < toWrite); return IOState::Done; } IOState tryRead(PacketBuffer& buffer, size_t& pos, size_t toRead, bool allowIncomplete) override { if (!d_handshakeDone) { /* As opposed to OpenSSL, GnuTLS will not transparently finish the handshake for us, we need to keep calling gnutls_handshake() until the handshake has been finished. */ auto state = tryHandshake(); if (state != IOState::Done) { return state; } } do { ssize_t res = gnutls_record_recv(d_conn.get(), reinterpret_cast(&buffer.at(pos)), toRead - pos); if (res == 0) { throw std::runtime_error("EOF while reading from TLS connection"); } else if (res > 0) { pos += static_cast(res); if (allowIncomplete) { break; } } else if (res < 0) { if (gnutls_error_is_fatal(res)) { throw std::runtime_error("Fatal error reading from TLS connection: " + std::string(gnutls_strerror(res))); } else if (res == GNUTLS_E_AGAIN) { return IOState::NeedRead; } warnlog("Warning, non-fatal error while writing to TLS connection: %s", gnutls_strerror(res)); } } while (pos < toRead); return IOState::Done; } size_t read(void* buffer, size_t bufferSize, const struct timeval& readTimeout, const struct timeval& totalTimeout, bool allowIncomplete) override { size_t got = 0; struct timeval start{0,0}; struct timeval remainingTime = totalTimeout; if (totalTimeout.tv_sec != 0 || totalTimeout.tv_usec != 0) { 
gettimeofday(&start, nullptr); } do { ssize_t res = gnutls_record_recv(d_conn.get(), (reinterpret_cast(buffer) + got), bufferSize - got); if (res == 0) { throw std::runtime_error("EOF while reading from TLS connection"); } else if (res > 0) { got += static_cast(res); if (allowIncomplete) { break; } } else if (res < 0) { if (gnutls_error_is_fatal(res)) { throw std::runtime_error("Fatal error reading from TLS connection: " + std::string(gnutls_strerror(res))); } else if (res == GNUTLS_E_AGAIN) { int result = waitForData(d_socket, readTimeout.tv_sec, readTimeout.tv_usec); if (result <= 0) { throw std::runtime_error("Error while waiting to read from TLS connection: " + std::to_string(result)); } } else { vinfolog("Non-fatal error while reading from TLS connection: %s", gnutls_strerror(res)); } } if (totalTimeout.tv_sec != 0 || totalTimeout.tv_usec != 0) { struct timeval now; gettimeofday(&now, nullptr); struct timeval elapsed = now - start; if (now < start || remainingTime < elapsed) { throw runtime_error("Timeout while reading data"); } start = now; remainingTime = remainingTime - elapsed; } } while (got < bufferSize); return got; } size_t write(const void* buffer, size_t bufferSize, const struct timeval& writeTimeout) override { size_t got = 0; do { ssize_t res = gnutls_record_send(d_conn.get(), (reinterpret_cast(buffer) + got), bufferSize - got); if (res == 0) { throw std::runtime_error("Error writing to TLS connection"); } else if (res > 0) { got += static_cast(res); } else if (res < 0) { if (gnutls_error_is_fatal(res)) { throw std::runtime_error("Fatal error writing to TLS connection: " + std::string(gnutls_strerror(res))); } else if (res == GNUTLS_E_AGAIN) { int result = waitForRWData(d_socket, false, writeTimeout.tv_sec, writeTimeout.tv_usec); if (result <= 0) { throw std::runtime_error("Error waiting to write to TLS connection: " + std::to_string(result)); } } else { vinfolog("Non-fatal error while writing to TLS connection: %s", gnutls_strerror(res)); } } } 
while (got < bufferSize); return got; } bool hasBufferedData() const override { if (d_conn) { return gnutls_record_check_pending(d_conn.get()) > 0; } return false; } bool isUsable() const override { if (!d_conn) { return false; } /* as far as I can tell we can't peek so we cannot do better */ return isTCPSocketUsable(d_socket); } std::string getServerNameIndication() const override { if (d_conn) { unsigned int type; size_t name_len = 256; std::string sni; sni.resize(name_len); int res = gnutls_server_name_get(d_conn.get(), const_cast(sni.c_str()), &name_len, &type, 0); if (res == GNUTLS_E_SUCCESS) { sni.resize(name_len); return sni; } } return std::string(); } std::vector getNextProtocol() const override { std::vector result; if (!d_conn) { return result; } gnutls_datum_t next; if (gnutls_alpn_get_selected_protocol(d_conn.get(), &next) != GNUTLS_E_SUCCESS) { return result; } result.insert(result.end(), next.data, next.data + next.size); return result; } LibsslTLSVersion getTLSVersion() const override { auto proto = gnutls_protocol_get_version(d_conn.get()); switch (proto) { case GNUTLS_TLS1_0: return LibsslTLSVersion::TLS10; case GNUTLS_TLS1_1: return LibsslTLSVersion::TLS11; case GNUTLS_TLS1_2: return LibsslTLSVersion::TLS12; #if GNUTLS_VERSION_NUMBER >= 0x030603 case GNUTLS_TLS1_3: return LibsslTLSVersion::TLS13; #endif /* GNUTLS_VERSION_NUMBER >= 0x030603 */ default: return LibsslTLSVersion::Unknown; } } bool hasSessionBeenResumed() const override { if (d_conn) { return gnutls_session_is_resumed(d_conn.get()) != 0; } return false; } std::vector> getSessions() override { return std::move(d_tlsSessions); } void setSession(std::unique_ptr& session) override { auto sess = dynamic_cast(session.get()); if (!sess) { throw std::runtime_error("Unable to convert GnuTLS session"); } auto native = sess->getNative(); auto ret = gnutls_session_set_data(d_conn.get(), native.data, native.size); if (ret != GNUTLS_E_SUCCESS) { throw std::runtime_error("Error setting up GnuTLS 
session: " + std::string(gnutls_strerror(ret))); } session.reset(); } void close() override { if (d_conn) { gnutls_bye(d_conn.get(), GNUTLS_SHUT_RDWR); } } bool setALPNProtos(const std::vector>& protos) { std::vector values; values.reserve(protos.size()); for (const auto& proto : protos) { gnutls_datum_t value; value.data = const_cast(proto.data()); value.size = proto.size(); values.push_back(value); } unsigned int flags = 0; #if GNUTLS_VERSION_NUMBER >= 0x030500 flags |= GNUTLS_ALPN_MANDATORY; #elif defined(GNUTLS_ALPN_MAND) flags |= GNUTLS_ALPN_MAND; #endif return gnutls_alpn_set_protocols(d_conn.get(), values.data(), values.size(), flags); } private: std::shared_ptr d_creds; std::shared_ptr d_ticketsKey; std::unique_ptr d_conn; std::vector> d_tlsSessions; std::string d_host; bool d_client{false}; bool d_handshakeDone{false}; }; class GnuTLSIOCtx: public TLSCtx { public: /* server side context */ GnuTLSIOCtx(TLSFrontend& fe): d_enableTickets(fe.d_tlsConfig.d_enableTickets) { int rc = 0; d_ticketsKeyRotationDelay = fe.d_tlsConfig.d_ticketsKeyRotationDelay; gnutls_certificate_credentials_t creds; rc = gnutls_certificate_allocate_credentials(&creds); if (rc != GNUTLS_E_SUCCESS) { throw std::runtime_error("Error allocating credentials for TLS context on " + fe.d_addr.toStringWithPort() + ": " + gnutls_strerror(rc)); } d_creds = std::shared_ptr(creds, gnutls_certificate_free_credentials); creds = nullptr; for (const auto& pair : fe.d_tlsConfig.d_certKeyPairs) { rc = gnutls_certificate_set_x509_key_file(d_creds.get(), pair.first.c_str(), pair.second.c_str(), GNUTLS_X509_FMT_PEM); if (rc != GNUTLS_E_SUCCESS) { throw std::runtime_error("Error loading certificate ('" + pair.first + "') and key ('" + pair.second + "') for TLS context on " + fe.d_addr.toStringWithPort() + ": " + gnutls_strerror(rc)); } } size_t count = 0; for (const auto& file : fe.d_tlsConfig.d_ocspFiles) { rc = gnutls_certificate_set_ocsp_status_request_file(d_creds.get(), file.c_str(), count); if (rc != 
GNUTLS_E_SUCCESS) { throw std::runtime_error("Error loading OCSP response from file '" + file + "' for certificate ('" + fe.d_tlsConfig.d_certKeyPairs.at(count).first + "') and key ('" + fe.d_tlsConfig.d_certKeyPairs.at(count).second + "') for TLS context on " + fe.d_addr.toStringWithPort() + ": " + gnutls_strerror(rc)); } ++count; } #if GNUTLS_VERSION_NUMBER >= 0x030600 rc = gnutls_certificate_set_known_dh_params(d_creds.get(), GNUTLS_SEC_PARAM_HIGH); if (rc != GNUTLS_E_SUCCESS) { throw std::runtime_error("Error setting DH params for TLS context on " + fe.d_addr.toStringWithPort() + ": " + gnutls_strerror(rc)); } #endif rc = gnutls_priority_init(&d_priorityCache, fe.d_tlsConfig.d_ciphers.empty() ? "NORMAL" : fe.d_tlsConfig.d_ciphers.c_str(), nullptr); if (rc != GNUTLS_E_SUCCESS) { throw std::runtime_error("Error setting up TLS cipher preferences to '" + fe.d_tlsConfig.d_ciphers + "' (" + gnutls_strerror(rc) + ") on " + fe.d_addr.toStringWithPort()); } try { if (fe.d_tlsConfig.d_ticketKeyFile.empty()) { handleTicketsKeyRotation(time(nullptr)); } else { GnuTLSIOCtx::loadTicketsKeys(fe.d_tlsConfig.d_ticketKeyFile); } } catch(const std::runtime_error& e) { throw std::runtime_error("Error generating tickets key for TLS context on " + fe.d_addr.toStringWithPort() + ": " + e.what()); } } /* client side context */ GnuTLSIOCtx(const TLSContextParameters& params): d_contextParameters(std::make_unique(params)), d_enableTickets(true), d_validateCerts(params.d_validateCertificates) { int rc = 0; gnutls_certificate_credentials_t creds; rc = gnutls_certificate_allocate_credentials(&creds); if (rc != GNUTLS_E_SUCCESS) { throw std::runtime_error("Error allocating credentials for TLS context: " + std::string(gnutls_strerror(rc))); } d_creds = std::shared_ptr(creds, gnutls_certificate_free_credentials); creds = nullptr; if (params.d_validateCertificates) { if (params.d_caStore.empty()) { #if GNUTLS_VERSION_NUMBER >= 0x030700 && GNUTLS_VERSION_NUMBER < 0x030703 /* see 
https://gitlab.com/gnutls/gnutls/-/issues/1277 */ std::cerr<<"Warning: GnuTLS 3.7.0 - 3.7.2 have a memory leak when validating server certificates in some configurations (PKCS11 support enabled, and a default PKCS11 trust store), please consider upgrading GnuTLS, using the OpenSSL provider for outgoing connections, or explicitly setting a CA store"<= 0x030700 && GNUTLS_VERSION_NUMBER < 0x030703 */ rc = gnutls_certificate_set_x509_system_trust(d_creds.get()); if (rc < 0) { throw std::runtime_error("Error adding the system's default trusted CAs: " + std::string(gnutls_strerror(rc))); } } else { rc = gnutls_certificate_set_x509_trust_file(d_creds.get(), params.d_caStore.c_str(), GNUTLS_X509_FMT_PEM); if (rc < 0) { throw std::runtime_error("Error adding '" + params.d_caStore + "' to the trusted CAs: " + std::string(gnutls_strerror(rc))); } } } rc = gnutls_priority_init(&d_priorityCache, params.d_ciphers.empty() ? "NORMAL" : params.d_ciphers.c_str(), nullptr); if (rc != GNUTLS_E_SUCCESS) { throw std::runtime_error("Error setting up TLS cipher preferences to 'NORMAL' (" + std::string(gnutls_strerror(rc)) + ")"); } } virtual ~GnuTLSIOCtx() override { d_creds.reset(); if (d_priorityCache) { gnutls_priority_deinit(d_priorityCache); } } std::unique_ptr getConnection(int socket, const struct timeval& timeout, time_t now) override { handleTicketsKeyRotation(now); std::shared_ptr ticketsKey; { ticketsKey = *(d_ticketsKey.read_lock()); } auto connection = std::make_unique(socket, timeout, d_creds, d_priorityCache, ticketsKey, d_enableTickets); if (!d_protos.empty()) { connection->setALPNProtos(d_protos); } return connection; } static std::shared_ptr getPerThreadCredentials(bool validate, const std::string& caStore) { static thread_local std::map, std::shared_ptr> t_credentials; auto& entry = t_credentials[{validate, caStore}]; if (!entry) { gnutls_certificate_credentials_t creds; int rc = gnutls_certificate_allocate_credentials(&creds); if (rc != GNUTLS_E_SUCCESS) { throw 
std::runtime_error("Error allocating credentials for TLS context: " + std::string(gnutls_strerror(rc))); } entry = std::shared_ptr(creds, gnutls_certificate_free_credentials); creds = nullptr; if (validate) { if (caStore.empty()) { rc = gnutls_certificate_set_x509_system_trust(entry.get()); if (rc < 0) { throw std::runtime_error("Error adding the system's default trusted CAs: " + std::string(gnutls_strerror(rc))); } } else { rc = gnutls_certificate_set_x509_trust_file(entry.get(), caStore.c_str(), GNUTLS_X509_FMT_PEM); if (rc < 0) { throw std::runtime_error("Error adding '" + caStore + "' to the trusted CAs: " + std::string(gnutls_strerror(rc))); } } } } return entry; } std::unique_ptr getClientConnection(const std::string& host, int socket, const struct timeval& timeout) override { auto creds = getPerThreadCredentials(d_contextParameters->d_validateCertificates, d_contextParameters->d_caStore); auto connection = std::make_unique(host, socket, timeout, creds, d_priorityCache, d_validateCerts); if (!d_protos.empty()) { connection->setALPNProtos(d_protos); } return connection; } void rotateTicketsKey(time_t now) override { if (!d_enableTickets) { return; } auto newKey = std::make_shared(); { *(d_ticketsKey.write_lock()) = newKey; } if (d_ticketsKeyRotationDelay > 0) { d_ticketsKeyNextRotation = now + d_ticketsKeyRotationDelay; } } void loadTicketsKeys(const std::string& file) override final { if (!d_enableTickets) { return; } auto newKey = std::make_shared(file); { *(d_ticketsKey.write_lock()) = newKey; } if (d_ticketsKeyRotationDelay > 0) { d_ticketsKeyNextRotation = time(nullptr) + d_ticketsKeyRotationDelay; } } size_t getTicketsKeysCount() override { return *(d_ticketsKey.read_lock()) != nullptr ? 
1 : 0; } std::string getName() const override { return "gnutls"; } bool setALPNProtos(const std::vector>& protos) override { #ifdef HAVE_GNUTLS_ALPN_SET_PROTOCOLS d_protos = protos; return true; #else return false; #endif } private: /* client context parameters */ std::unique_ptr d_contextParameters{nullptr}; std::shared_ptr d_creds; std::vector> d_protos; gnutls_priority_t d_priorityCache{nullptr}; SharedLockGuarded> d_ticketsKey{nullptr}; bool d_enableTickets{true}; bool d_validateCerts{true}; }; #endif /* HAVE_GNUTLS */ #endif /* HAVE_DNS_OVER_TLS */ bool setupDoTProtocolNegotiation(std::shared_ptr& ctx) { if (ctx == nullptr) { return false; } /* we want to set the ALPN to dot (RFC7858), if only to mitigate the ALPACA attack */ const std::vector> dotAlpns = {{'d', 'o', 't'}}; ctx->setALPNProtos(dotAlpns); return true; } bool TLSFrontend::setupTLS() { #ifdef HAVE_DNS_OVER_TLS std::shared_ptr newCtx{nullptr}; /* get the "best" available provider */ if (!d_provider.empty()) { #ifdef HAVE_GNUTLS if (d_provider == "gnutls") { newCtx = std::make_shared(*this); setupDoTProtocolNegotiation(newCtx); std::atomic_store_explicit(&d_ctx, newCtx, std::memory_order_release); return true; } #endif /* HAVE_GNUTLS */ #ifdef HAVE_LIBSSL if (d_provider == "openssl") { newCtx = std::make_shared(*this); setupDoTProtocolNegotiation(newCtx); std::atomic_store_explicit(&d_ctx, newCtx, std::memory_order_release); return true; } #endif /* HAVE_LIBSSL */ } #ifdef HAVE_LIBSSL newCtx = std::make_shared(*this); #else /* HAVE_LIBSSL */ #ifdef HAVE_GNUTLS newCtx = std::make_shared(*this); #endif /* HAVE_GNUTLS */ #endif /* HAVE_LIBSSL */ setupDoTProtocolNegotiation(newCtx); std::atomic_store_explicit(&d_ctx, newCtx, std::memory_order_release); #endif /* HAVE_DNS_OVER_TLS */ return true; } std::shared_ptr getTLSContext(const TLSContextParameters& params) { #ifdef HAVE_DNS_OVER_TLS /* get the "best" available provider */ if (!params.d_provider.empty()) { #ifdef HAVE_GNUTLS if (params.d_provider 
== "gnutls") { return std::make_shared(params); } #endif /* HAVE_GNUTLS */ #ifdef HAVE_LIBSSL if (params.d_provider == "openssl") { return std::make_shared(params); } #endif /* HAVE_LIBSSL */ } #ifdef HAVE_GNUTLS return std::make_shared(params); #else /* HAVE_GNUTLS */ #ifdef HAVE_LIBSSL return std::make_shared(params); #endif /* HAVE_LIBSSL */ #endif /* HAVE_GNUTLS */ #endif /* HAVE_DNS_OVER_TLS */ return nullptr; }