diff options
Diffstat (limited to 'src/shrpx_connection_handler.cc')
-rw-r--r-- | src/shrpx_connection_handler.cc | 1321 |
1 files changed, 1321 insertions, 0 deletions
diff --git a/src/shrpx_connection_handler.cc b/src/shrpx_connection_handler.cc new file mode 100644 index 0000000..be6645d --- /dev/null +++ b/src/shrpx_connection_handler.cc @@ -0,0 +1,1321 @@ +/* + * nghttp2 - HTTP/2 C Library + * + * Copyright (c) 2012 Tatsuhiro Tsujikawa + * + * Permission is hereby granted, free of charge, to any person obtaining + * a copy of this software and associated documentation files (the + * "Software"), to deal in the Software without restriction, including + * without limitation the rights to use, copy, modify, merge, publish, + * distribute, sublicense, and/or sell copies of the Software, and to + * permit persons to whom the Software is furnished to do so, subject to + * the following conditions: + * + * The above copyright notice and this permission notice shall be + * included in all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, + * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND + * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE + * LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION + * OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION + * WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + */ +#include "shrpx_connection_handler.h" + +#ifdef HAVE_UNISTD_H +# include <unistd.h> +#endif // HAVE_UNISTD_H +#include <sys/types.h> +#include <sys/wait.h> + +#include <cerrno> +#include <thread> +#include <random> + +#include "shrpx_client_handler.h" +#include "shrpx_tls.h" +#include "shrpx_worker.h" +#include "shrpx_config.h" +#include "shrpx_http2_session.h" +#include "shrpx_connect_blocker.h" +#include "shrpx_downstream_connection.h" +#include "shrpx_accept_handler.h" +#include "shrpx_memcached_dispatcher.h" +#include "shrpx_signal.h" +#include "shrpx_log.h" +#include "xsi_strerror.h" +#include "util.h" +#include "template.h" + +using namespace nghttp2; + +namespace shrpx { + +namespace { +void acceptor_disable_cb(struct ev_loop *loop, ev_timer *w, int revent) { + auto h = static_cast<ConnectionHandler *>(w->data); + + // If we are in graceful shutdown period, we must not enable + // acceptors again. + if (h->get_graceful_shutdown()) { + return; + } + + h->enable_acceptor(); +} +} // namespace + +namespace { +void ocsp_cb(struct ev_loop *loop, ev_timer *w, int revent) { + auto h = static_cast<ConnectionHandler *>(w->data); + + // If we are in graceful shutdown period, we won't do ocsp query. + if (h->get_graceful_shutdown()) { + return; + } + + LOG(NOTICE) << "Start ocsp update"; + + h->proceed_next_cert_ocsp(); +} +} // namespace + +namespace { +void ocsp_read_cb(struct ev_loop *loop, ev_io *w, int revent) { + auto h = static_cast<ConnectionHandler *>(w->data); + + h->read_ocsp_chunk(); +} +} // namespace + +namespace { +void ocsp_chld_cb(struct ev_loop *loop, ev_child *w, int revent) { + auto h = static_cast<ConnectionHandler *>(w->data); + + h->handle_ocsp_complete(); +} +} // namespace + +namespace { +void thread_join_async_cb(struct ev_loop *loop, ev_async *w, int revent) { + ev_break(loop); +} +} // namespace + +namespace { +void serial_event_async_cb(struct ev_loop *loop, ev_async *w, int revent) { + auto h = static_cast<ConnectionHandler *>(w->data); + + h->handle_serial_event(); +} +} // namespace + +ConnectionHandler::ConnectionHandler(struct ev_loop *loop, std::mt19937 &gen) + : +#ifdef ENABLE_HTTP3 + quic_ipc_fd_(-1), +#endif // ENABLE_HTTP3 + gen_(gen), + single_worker_(nullptr), + loop_(loop), +#ifdef HAVE_NEVERBLEED + nb_(nullptr), +#endif // HAVE_NEVERBLEED + tls_ticket_key_memcached_get_retry_count_(0), + tls_ticket_key_memcached_fail_count_(0), + worker_round_robin_cnt_(get_config()->api.enabled ? 1 : 0), + graceful_shutdown_(false), + enable_acceptor_on_ocsp_completion_(false) { + ev_timer_init(&disable_acceptor_timer_, acceptor_disable_cb, 0., 0.); + disable_acceptor_timer_.data = this; + + ev_timer_init(&ocsp_timer_, ocsp_cb, 0., 0.); + ocsp_timer_.data = this; + + ev_io_init(&ocsp_.rev, ocsp_read_cb, -1, EV_READ); + ocsp_.rev.data = this; + + ev_async_init(&thread_join_asyncev_, thread_join_async_cb); + + ev_async_init(&serial_event_asyncev_, serial_event_async_cb); + serial_event_asyncev_.data = this; + + ev_async_start(loop_, &serial_event_asyncev_); + + ev_child_init(&ocsp_.chldev, ocsp_chld_cb, 0, 0); + ocsp_.chldev.data = this; + + ocsp_.next = 0; + ocsp_.proc.rfd = -1; + + reset_ocsp(); +} + +ConnectionHandler::~ConnectionHandler() { + ev_child_stop(loop_, &ocsp_.chldev); + ev_async_stop(loop_, &serial_event_asyncev_); + ev_async_stop(loop_, &thread_join_asyncev_); + ev_io_stop(loop_, &ocsp_.rev); + ev_timer_stop(loop_, &ocsp_timer_); + ev_timer_stop(loop_, &disable_acceptor_timer_); + +#ifdef ENABLE_HTTP3 + for (auto ssl_ctx : quic_all_ssl_ctx_) { + if (ssl_ctx == nullptr) { + continue; + } + + auto tls_ctx_data = + static_cast<tls::TLSContextData *>(SSL_CTX_get_app_data(ssl_ctx)); + delete tls_ctx_data; + SSL_CTX_free(ssl_ctx); + } +#endif // ENABLE_HTTP3 + + for (auto ssl_ctx : all_ssl_ctx_) { + auto tls_ctx_data = + static_cast<tls::TLSContextData *>(SSL_CTX_get_app_data(ssl_ctx)); + delete tls_ctx_data; + SSL_CTX_free(ssl_ctx); + } + + // Free workers before destroying ev_loop + workers_.clear(); + + for (auto loop : worker_loops_) { + ev_loop_destroy(loop); + } +} + +void ConnectionHandler::set_ticket_keys_to_worker( + const std::shared_ptr<TicketKeys> &ticket_keys) { + for (auto &worker : workers_) { + worker->set_ticket_keys(ticket_keys); + } +} + +void ConnectionHandler::worker_reopen_log_files() { + for (auto &worker : workers_) { + WorkerEvent wev{}; + + wev.type = WorkerEventType::REOPEN_LOG; + + worker->send(std::move(wev)); + } +} + +void ConnectionHandler::worker_replace_downstream( + std::shared_ptr<DownstreamConfig> downstreamconf) { + for (auto &worker : workers_) { + WorkerEvent wev{}; + + wev.type = WorkerEventType::REPLACE_DOWNSTREAM; + wev.downstreamconf = downstreamconf; + + worker->send(std::move(wev)); + } +} + +int ConnectionHandler::create_single_worker() { + cert_tree_ = tls::create_cert_lookup_tree(); + auto sv_ssl_ctx = tls::setup_server_ssl_context( + all_ssl_ctx_, indexed_ssl_ctx_, cert_tree_.get() +#ifdef HAVE_NEVERBLEED + , + nb_ +#endif // HAVE_NEVERBLEED + ); + +#ifdef ENABLE_HTTP3 + quic_cert_tree_ = tls::create_cert_lookup_tree(); + auto quic_sv_ssl_ctx = tls::setup_quic_server_ssl_context( + quic_all_ssl_ctx_, quic_indexed_ssl_ctx_, quic_cert_tree_.get() +# ifdef HAVE_NEVERBLEED + , + nb_ +# endif // HAVE_NEVERBLEED + ); +#endif // ENABLE_HTTP3 + + auto cl_ssl_ctx = tls::setup_downstream_client_ssl_context( +#ifdef HAVE_NEVERBLEED + nb_ +#endif // HAVE_NEVERBLEED + ); + + if (cl_ssl_ctx) { + all_ssl_ctx_.push_back(cl_ssl_ctx); +#ifdef ENABLE_HTTP3 + quic_all_ssl_ctx_.push_back(nullptr); +#endif // ENABLE_HTTP3 + } + + auto config = get_config(); + auto &tlsconf = config->tls; + + SSL_CTX *session_cache_ssl_ctx = nullptr; + { + auto &memcachedconf = config->tls.session_cache.memcached; + if (memcachedconf.tls) { + session_cache_ssl_ctx = tls::create_ssl_client_context( +#ifdef HAVE_NEVERBLEED + nb_, +#endif // HAVE_NEVERBLEED + tlsconf.cacert, memcachedconf.cert_file, + memcachedconf.private_key_file, nullptr); + all_ssl_ctx_.push_back(session_cache_ssl_ctx); +#ifdef ENABLE_HTTP3 + quic_all_ssl_ctx_.push_back(nullptr); +#endif // ENABLE_HTTP3 + } + } + +#if defined(ENABLE_HTTP3) && defined(HAVE_LIBBPF) + quic_bpf_refs_.resize(config->conn.quic_listener.addrs.size()); +#endif // ENABLE_HTTP3 && HAVE_LIBBPF + +#ifdef ENABLE_HTTP3 + assert(cid_prefixes_.size() == 1); + const auto &cid_prefix = cid_prefixes_[0]; +#endif // ENABLE_HTTP3 + + single_worker_ = std::make_unique<Worker>( + loop_, sv_ssl_ctx, cl_ssl_ctx, session_cache_ssl_ctx, cert_tree_.get(), +#ifdef ENABLE_HTTP3 + quic_sv_ssl_ctx, quic_cert_tree_.get(), cid_prefix.data(), + cid_prefix.size(), +# ifdef HAVE_LIBBPF + /* index = */ 0, +# endif // HAVE_LIBBPF +#endif // ENABLE_HTTP3 + ticket_keys_, this, config->conn.downstream); +#ifdef HAVE_MRUBY + if (single_worker_->create_mruby_context() != 0) { + return -1; + } +#endif // HAVE_MRUBY + +#ifdef ENABLE_HTTP3 + if (single_worker_->setup_quic_server_socket() != 0) { + return -1; + } +#endif // ENABLE_HTTP3 + + return 0; +} + +int ConnectionHandler::create_worker_thread(size_t num) { +#ifndef NOTHREADS + assert(workers_.size() == 0); + + cert_tree_ = tls::create_cert_lookup_tree(); + auto sv_ssl_ctx = tls::setup_server_ssl_context( + all_ssl_ctx_, indexed_ssl_ctx_, cert_tree_.get() +# ifdef HAVE_NEVERBLEED + , + nb_ +# endif // HAVE_NEVERBLEED + ); + +# ifdef ENABLE_HTTP3 + quic_cert_tree_ = tls::create_cert_lookup_tree(); + auto quic_sv_ssl_ctx = tls::setup_quic_server_ssl_context( + quic_all_ssl_ctx_, quic_indexed_ssl_ctx_, quic_cert_tree_.get() +# ifdef HAVE_NEVERBLEED + , + nb_ +# endif // HAVE_NEVERBLEED + ); +# endif // ENABLE_HTTP3 + + auto cl_ssl_ctx = tls::setup_downstream_client_ssl_context( +# ifdef HAVE_NEVERBLEED + nb_ +# endif // HAVE_NEVERBLEED + ); + + if (cl_ssl_ctx) { + all_ssl_ctx_.push_back(cl_ssl_ctx); +# ifdef ENABLE_HTTP3 + quic_all_ssl_ctx_.push_back(nullptr); +# endif // ENABLE_HTTP3 + } + + auto config = get_config(); + auto &tlsconf = config->tls; + auto &apiconf = config->api; + +# if defined(ENABLE_HTTP3) && defined(HAVE_LIBBPF) + quic_bpf_refs_.resize(config->conn.quic_listener.addrs.size()); +# endif // ENABLE_HTTP3 && HAVE_LIBBPF + + // We have dedicated worker for API request processing. + if (apiconf.enabled) { + ++num; + } + + SSL_CTX *session_cache_ssl_ctx = nullptr; + { + auto &memcachedconf = config->tls.session_cache.memcached; + + if (memcachedconf.tls) { + session_cache_ssl_ctx = tls::create_ssl_client_context( +# ifdef HAVE_NEVERBLEED + nb_, +# endif // HAVE_NEVERBLEED + tlsconf.cacert, memcachedconf.cert_file, + memcachedconf.private_key_file, nullptr); + all_ssl_ctx_.push_back(session_cache_ssl_ctx); +# ifdef ENABLE_HTTP3 + quic_all_ssl_ctx_.push_back(nullptr); +# endif // ENABLE_HTTP3 + } + } + +# ifdef ENABLE_HTTP3 + assert(cid_prefixes_.size() == num); +# endif // ENABLE_HTTP3 + + for (size_t i = 0; i < num; ++i) { + auto loop = ev_loop_new(config->ev_loop_flags); + +# ifdef ENABLE_HTTP3 + const auto &cid_prefix = cid_prefixes_[i]; +# endif // ENABLE_HTTP3 + + auto worker = std::make_unique<Worker>( + loop, sv_ssl_ctx, cl_ssl_ctx, session_cache_ssl_ctx, cert_tree_.get(), +# ifdef ENABLE_HTTP3 + quic_sv_ssl_ctx, quic_cert_tree_.get(), cid_prefix.data(), + cid_prefix.size(), +# ifdef HAVE_LIBBPF + i, +# endif // HAVE_LIBBPF +# endif // ENABLE_HTTP3 + ticket_keys_, this, config->conn.downstream); +# ifdef HAVE_MRUBY + if (worker->create_mruby_context() != 0) { + return -1; + } +# endif // HAVE_MRUBY + +# ifdef ENABLE_HTTP3 + if ((!apiconf.enabled || i != 0) && + worker->setup_quic_server_socket() != 0) { + return -1; + } +# endif // ENABLE_HTTP3 + + workers_.push_back(std::move(worker)); + worker_loops_.push_back(loop); + + LLOG(NOTICE, this) << "Created worker thread #" << workers_.size() - 1; + } + + for (auto &worker : workers_) { + worker->run_async(); + } + +#endif // NOTHREADS + + return 0; +} + +void ConnectionHandler::join_worker() { +#ifndef NOTHREADS + int n = 0; + + if (LOG_ENABLED(INFO)) { + LLOG(INFO, this) << "Waiting for worker thread to join: n=" + << workers_.size(); + } + + for (auto &worker : workers_) { + worker->wait(); + if (LOG_ENABLED(INFO)) { + LLOG(INFO, this) << "Thread #" << n << " joined"; + } + ++n; + } +#endif // NOTHREADS +} + +void ConnectionHandler::graceful_shutdown_worker() { + if (single_worker_) { + return; + } + + if (LOG_ENABLED(INFO)) { + LLOG(INFO, this) << "Sending graceful shutdown signal to worker"; + } + + for (auto &worker : workers_) { + WorkerEvent wev{}; + wev.type = WorkerEventType::GRACEFUL_SHUTDOWN; + + worker->send(std::move(wev)); + } + +#ifndef NOTHREADS + ev_async_start(loop_, &thread_join_asyncev_); + + thread_join_fut_ = std::async(std::launch::async, [this]() { + (void)reopen_log_files(get_config()->logging); + join_worker(); + ev_async_send(get_loop(), &thread_join_asyncev_); + delete_log_config(); + }); +#endif // NOTHREADS +} + +int ConnectionHandler::handle_connection(int fd, sockaddr *addr, int addrlen, + const UpstreamAddr *faddr) { + if (LOG_ENABLED(INFO)) { + LLOG(INFO, this) << "Accepted connection from " + << util::numeric_name(addr, addrlen) << ", fd=" << fd; + } + + auto config = get_config(); + + if (single_worker_) { + auto &upstreamconf = config->conn.upstream; + if (single_worker_->get_worker_stat()->num_connections >= + upstreamconf.worker_connections) { + + if (LOG_ENABLED(INFO)) { + LLOG(INFO, this) << "Too many connections >=" + << upstreamconf.worker_connections; + } + + close(fd); + return -1; + } + + auto client = + tls::accept_connection(single_worker_.get(), fd, addr, addrlen, faddr); + if (!client) { + LLOG(ERROR, this) << "ClientHandler creation failed"; + + close(fd); + return -1; + } + + return 0; + } + + Worker *worker; + + if (faddr->alt_mode == UpstreamAltMode::API) { + worker = workers_[0].get(); + + if (LOG_ENABLED(INFO)) { + LOG(INFO) << "Dispatch connection to API worker #0"; + } + } else { + worker = workers_[worker_round_robin_cnt_].get(); + + if (LOG_ENABLED(INFO)) { + LOG(INFO) << "Dispatch connection to worker #" << worker_round_robin_cnt_; + } + + if (++worker_round_robin_cnt_ == workers_.size()) { + auto &apiconf = config->api; + + if (apiconf.enabled) { + worker_round_robin_cnt_ = 1; + } else { + worker_round_robin_cnt_ = 0; + } + } + } + + WorkerEvent wev{}; + wev.type = WorkerEventType::NEW_CONNECTION; + wev.client_fd = fd; + memcpy(&wev.client_addr, addr, addrlen); + wev.client_addrlen = addrlen; + wev.faddr = faddr; + + worker->send(std::move(wev)); + + return 0; +} + +struct ev_loop *ConnectionHandler::get_loop() const { + return loop_; +} + +Worker *ConnectionHandler::get_single_worker() const { + return single_worker_.get(); +} + +void ConnectionHandler::add_acceptor(std::unique_ptr<AcceptHandler> h) { + acceptors_.push_back(std::move(h)); +} + +void ConnectionHandler::delete_acceptor() { acceptors_.clear(); } + +void ConnectionHandler::enable_acceptor() { + for (auto &a : acceptors_) { + a->enable(); + } +} + +void ConnectionHandler::disable_acceptor() { + for (auto &a : acceptors_) { + a->disable(); + } +} + +void ConnectionHandler::sleep_acceptor(ev_tstamp t) { + if (t == 0. || ev_is_active(&disable_acceptor_timer_)) { + return; + } + + disable_acceptor(); + + ev_timer_set(&disable_acceptor_timer_, t, 0.); + ev_timer_start(loop_, &disable_acceptor_timer_); +} + +void ConnectionHandler::accept_pending_connection() { + for (auto &a : acceptors_) { + a->accept_connection(); + } +} + +void ConnectionHandler::set_ticket_keys( + std::shared_ptr<TicketKeys> ticket_keys) { + ticket_keys_ = std::move(ticket_keys); + if (single_worker_) { + single_worker_->set_ticket_keys(ticket_keys_); + } +} + +const std::shared_ptr<TicketKeys> &ConnectionHandler::get_ticket_keys() const { + return ticket_keys_; +} + +void ConnectionHandler::set_graceful_shutdown(bool f) { + graceful_shutdown_ = f; + if (single_worker_) { + single_worker_->set_graceful_shutdown(f); + } +} + +bool ConnectionHandler::get_graceful_shutdown() const { + return graceful_shutdown_; +} + +void ConnectionHandler::cancel_ocsp_update() { + enable_acceptor_on_ocsp_completion_ = false; + ev_timer_stop(loop_, &ocsp_timer_); + + if (ocsp_.proc.pid == 0) { + return; + } + + int rv; + + rv = kill(ocsp_.proc.pid, SIGTERM); + if (rv != 0) { + auto error = errno; + LOG(ERROR) << "Could not send signal to OCSP query process: errno=" + << error; + } + + while ((rv = waitpid(ocsp_.proc.pid, nullptr, 0)) == -1 && errno == EINTR) + ; + if (rv == -1) { + auto error = errno; + LOG(ERROR) << "Error occurred while we were waiting for the completion of " + "OCSP query process: errno=" + << error; + } +} + +// inspired by h2o_read_command function from h2o project: +// https://github.com/h2o/h2o +int ConnectionHandler::start_ocsp_update(const char *cert_file) { + int rv; + + if (LOG_ENABLED(INFO)) { + LOG(INFO) << "Start ocsp update for " << cert_file; + } + + assert(!ev_is_active(&ocsp_.rev)); + assert(!ev_is_active(&ocsp_.chldev)); + + char *const argv[] = { + const_cast<char *>( + get_config()->tls.ocsp.fetch_ocsp_response_file.c_str()), + const_cast<char *>(cert_file), nullptr}; + + Process proc; + rv = exec_read_command(proc, argv); + if (rv != 0) { + return -1; + } + + ocsp_.proc = proc; + + ev_io_set(&ocsp_.rev, ocsp_.proc.rfd, EV_READ); + ev_io_start(loop_, &ocsp_.rev); + + ev_child_set(&ocsp_.chldev, ocsp_.proc.pid, 0); + ev_child_start(loop_, &ocsp_.chldev); + + return 0; +} + +void ConnectionHandler::read_ocsp_chunk() { + std::array<uint8_t, 4_k> buf; + for (;;) { + ssize_t n; + while ((n = read(ocsp_.proc.rfd, buf.data(), buf.size())) == -1 && + errno == EINTR) + ; + + if (n == -1) { + if (errno == EAGAIN || errno == EWOULDBLOCK) { + return; + } + auto error = errno; + LOG(WARN) << "Reading from ocsp query command failed: errno=" << error; + ocsp_.error = error; + + break; + } + + if (n == 0) { + break; + } + + std::copy_n(std::begin(buf), n, std::back_inserter(ocsp_.resp)); + } + + ev_io_stop(loop_, &ocsp_.rev); +} + +void ConnectionHandler::handle_ocsp_complete() { + ev_io_stop(loop_, &ocsp_.rev); + ev_child_stop(loop_, &ocsp_.chldev); + + assert(ocsp_.next < all_ssl_ctx_.size()); +#ifdef ENABLE_HTTP3 + assert(all_ssl_ctx_.size() == quic_all_ssl_ctx_.size()); +#endif // ENABLE_HTTP3 + + auto ssl_ctx = all_ssl_ctx_[ocsp_.next]; + auto tls_ctx_data = + static_cast<tls::TLSContextData *>(SSL_CTX_get_app_data(ssl_ctx)); + + auto rstatus = ocsp_.chldev.rstatus; + auto status = WEXITSTATUS(rstatus); + if (ocsp_.error || !WIFEXITED(rstatus) || status != 0) { + LOG(WARN) << "ocsp query command for " << tls_ctx_data->cert_file + << " failed: error=" << ocsp_.error << ", rstatus=" << log::hex + << rstatus << log::dec << ", status=" << status; + ++ocsp_.next; + proceed_next_cert_ocsp(); + return; + } + + if (LOG_ENABLED(INFO)) { + LOG(INFO) << "ocsp update for " << tls_ctx_data->cert_file + << " finished successfully"; + } + + auto config = get_config(); + auto &tlsconf = config->tls; + + if (tlsconf.ocsp.no_verify || + tls::verify_ocsp_response(ssl_ctx, ocsp_.resp.data(), + ocsp_.resp.size()) == 0) { +#ifdef ENABLE_HTTP3 + // We have list of SSL_CTX with the same certificate in + // quic_all_ssl_ctx_ as well. Some SSL_CTXs are missing there in + // that case we get nullptr. + auto quic_ssl_ctx = quic_all_ssl_ctx_[ocsp_.next]; + if (quic_ssl_ctx) { +# ifndef OPENSSL_IS_BORINGSSL + auto quic_tls_ctx_data = static_cast<tls::TLSContextData *>( + SSL_CTX_get_app_data(quic_ssl_ctx)); +# ifdef HAVE_ATOMIC_STD_SHARED_PTR + std::atomic_store_explicit( + &quic_tls_ctx_data->ocsp_data, + std::make_shared<std::vector<uint8_t>>(ocsp_.resp), + std::memory_order_release); +# else // !HAVE_ATOMIC_STD_SHARED_PTR + std::lock_guard<std::mutex> g(quic_tls_ctx_data->mu); + quic_tls_ctx_data->ocsp_data = + std::make_shared<std::vector<uint8_t>>(ocsp_.resp); +# endif // !HAVE_ATOMIC_STD_SHARED_PTR +# else // OPENSSL_IS_BORINGSSL + SSL_CTX_set_ocsp_response(quic_ssl_ctx, ocsp_.resp.data(), + ocsp_.resp.size()); +# endif // OPENSSL_IS_BORINGSSL + } +#endif // ENABLE_HTTP3 + +#ifndef OPENSSL_IS_BORINGSSL +# ifdef HAVE_ATOMIC_STD_SHARED_PTR + std::atomic_store_explicit( + &tls_ctx_data->ocsp_data, + std::make_shared<std::vector<uint8_t>>(std::move(ocsp_.resp)), + std::memory_order_release); +# else // !HAVE_ATOMIC_STD_SHARED_PTR + std::lock_guard<std::mutex> g(tls_ctx_data->mu); + tls_ctx_data->ocsp_data = + std::make_shared<std::vector<uint8_t>>(std::move(ocsp_.resp)); +# endif // !HAVE_ATOMIC_STD_SHARED_PTR +#else // OPENSSL_IS_BORINGSSL + SSL_CTX_set_ocsp_response(ssl_ctx, ocsp_.resp.data(), ocsp_.resp.size()); +#endif // OPENSSL_IS_BORINGSSL + } + + ++ocsp_.next; + proceed_next_cert_ocsp(); +} + +void ConnectionHandler::reset_ocsp() { + if (ocsp_.proc.rfd != -1) { + close(ocsp_.proc.rfd); + } + + ocsp_.proc.rfd = -1; + ocsp_.proc.pid = 0; + ocsp_.error = 0; + ocsp_.resp = std::vector<uint8_t>(); +} + +void ConnectionHandler::proceed_next_cert_ocsp() { + for (;;) { + reset_ocsp(); + if (ocsp_.next == all_ssl_ctx_.size()) { + ocsp_.next = 0; + // We have updated all ocsp response, and schedule next update. + ev_timer_set(&ocsp_timer_, get_config()->tls.ocsp.update_interval, 0.); + ev_timer_start(loop_, &ocsp_timer_); + + if (enable_acceptor_on_ocsp_completion_) { + enable_acceptor_on_ocsp_completion_ = false; + enable_acceptor(); + } + + return; + } + + auto ssl_ctx = all_ssl_ctx_[ocsp_.next]; + auto tls_ctx_data = + static_cast<tls::TLSContextData *>(SSL_CTX_get_app_data(ssl_ctx)); + + // client SSL_CTX is also included in all_ssl_ctx_, but has no + // tls_ctx_data. + if (!tls_ctx_data) { + ++ocsp_.next; + continue; + } + + auto cert_file = tls_ctx_data->cert_file; + + if (start_ocsp_update(cert_file) != 0) { + ++ocsp_.next; + continue; + } + + break; + } +} + +void ConnectionHandler::set_tls_ticket_key_memcached_dispatcher( + std::unique_ptr<MemcachedDispatcher> dispatcher) { + tls_ticket_key_memcached_dispatcher_ = std::move(dispatcher); +} + +MemcachedDispatcher * +ConnectionHandler::get_tls_ticket_key_memcached_dispatcher() const { + return tls_ticket_key_memcached_dispatcher_.get(); +} + +// Use the similar backoff algorithm described in +// https://github.com/grpc/grpc/blob/master/doc/connection-backoff.md +namespace { +constexpr size_t MAX_BACKOFF_EXP = 10; +constexpr auto MULTIPLIER = 3.2; +constexpr auto JITTER = 0.2; +} // namespace + +void ConnectionHandler::on_tls_ticket_key_network_error(ev_timer *w) { + if (++tls_ticket_key_memcached_get_retry_count_ >= + get_config()->tls.ticket.memcached.max_retry) { + LOG(WARN) << "Memcached: tls ticket get retry all failed " + << tls_ticket_key_memcached_get_retry_count_ << " times."; + + on_tls_ticket_key_not_found(w); + return; + } + + auto base_backoff = util::int_pow( + MULTIPLIER, + std::min(MAX_BACKOFF_EXP, tls_ticket_key_memcached_get_retry_count_)); + auto dist = std::uniform_real_distribution<>(-JITTER * base_backoff, + JITTER * base_backoff); + + auto backoff = base_backoff + dist(gen_); + + LOG(WARN) + << "Memcached: tls ticket get failed due to network error, retrying in " + << backoff << " seconds"; + + ev_timer_set(w, backoff, 0.); + ev_timer_start(loop_, w); +} + +void ConnectionHandler::on_tls_ticket_key_not_found(ev_timer *w) { + tls_ticket_key_memcached_get_retry_count_ = 0; + + if (++tls_ticket_key_memcached_fail_count_ >= + get_config()->tls.ticket.memcached.max_fail) { + LOG(WARN) << "Memcached: could not get tls ticket; disable tls ticket"; + + tls_ticket_key_memcached_fail_count_ = 0; + + set_ticket_keys(nullptr); + set_ticket_keys_to_worker(nullptr); + } + + LOG(WARN) << "Memcached: tls ticket get failed, schedule next"; + schedule_next_tls_ticket_key_memcached_get(w); +} + +void ConnectionHandler::on_tls_ticket_key_get_success( + const std::shared_ptr<TicketKeys> &ticket_keys, ev_timer *w) { + LOG(NOTICE) << "Memcached: tls ticket get success"; + + tls_ticket_key_memcached_get_retry_count_ = 0; + tls_ticket_key_memcached_fail_count_ = 0; + + schedule_next_tls_ticket_key_memcached_get(w); + + if (!ticket_keys || ticket_keys->keys.empty()) { + LOG(WARN) << "Memcached: tls ticket keys are empty; tls ticket disabled"; + set_ticket_keys(nullptr); + set_ticket_keys_to_worker(nullptr); + return; + } + + if (LOG_ENABLED(INFO)) { + LOG(INFO) << "ticket keys get done"; + LOG(INFO) << 0 << " enc+dec: " + << util::format_hex(ticket_keys->keys[0].data.name); + for (size_t i = 1; i < ticket_keys->keys.size(); ++i) { + auto &key = ticket_keys->keys[i]; + LOG(INFO) << i << " dec: " << util::format_hex(key.data.name); + } + } + + set_ticket_keys(ticket_keys); + set_ticket_keys_to_worker(ticket_keys); +} + +void ConnectionHandler::schedule_next_tls_ticket_key_memcached_get( + ev_timer *w) { + ev_timer_set(w, get_config()->tls.ticket.memcached.interval, 0.); + ev_timer_start(loop_, w); +} + +SSL_CTX *ConnectionHandler::create_tls_ticket_key_memcached_ssl_ctx() { + auto config = get_config(); + auto &tlsconf = config->tls; + auto &memcachedconf = config->tls.ticket.memcached; + + auto ssl_ctx = tls::create_ssl_client_context( +#ifdef HAVE_NEVERBLEED + nb_, +#endif // HAVE_NEVERBLEED + tlsconf.cacert, memcachedconf.cert_file, memcachedconf.private_key_file, + nullptr); + + all_ssl_ctx_.push_back(ssl_ctx); +#ifdef ENABLE_HTTP3 + quic_all_ssl_ctx_.push_back(nullptr); +#endif // ENABLE_HTTP3 + + return ssl_ctx; +} + +#ifdef HAVE_NEVERBLEED +void ConnectionHandler::set_neverbleed(neverbleed_t *nb) { nb_ = nb; } +#endif // HAVE_NEVERBLEED + +void ConnectionHandler::handle_serial_event() { + std::vector<SerialEvent> q; + { + std::lock_guard<std::mutex> g(serial_event_mu_); + q.swap(serial_events_); + } + + for (auto &sev : q) { + switch (sev.type) { + case SerialEventType::REPLACE_DOWNSTREAM: + // Mmake sure that none of worker uses + // get_config()->conn.downstream + mod_config()->conn.downstream = sev.downstreamconf; + + if (single_worker_) { + single_worker_->replace_downstream_config(sev.downstreamconf); + + break; + } + + worker_replace_downstream(sev.downstreamconf); + + break; + default: + break; + } + } +} + +void ConnectionHandler::send_replace_downstream( + const std::shared_ptr<DownstreamConfig> &downstreamconf) { + send_serial_event( + SerialEvent(SerialEventType::REPLACE_DOWNSTREAM, downstreamconf)); +} + +void ConnectionHandler::send_serial_event(SerialEvent ev) { + { + std::lock_guard<std::mutex> g(serial_event_mu_); + + serial_events_.push_back(std::move(ev)); + } + + ev_async_send(loop_, &serial_event_asyncev_); +} + +SSL_CTX *ConnectionHandler::get_ssl_ctx(size_t idx) const { + return all_ssl_ctx_[idx]; +} + +const std::vector<SSL_CTX *> & +ConnectionHandler::get_indexed_ssl_ctx(size_t idx) const { + return indexed_ssl_ctx_[idx]; +} + +#ifdef ENABLE_HTTP3 +const std::vector<SSL_CTX *> & +ConnectionHandler::get_quic_indexed_ssl_ctx(size_t idx) const { + return quic_indexed_ssl_ctx_[idx]; +} +#endif // ENABLE_HTTP3 + +void ConnectionHandler::set_enable_acceptor_on_ocsp_completion(bool f) { + enable_acceptor_on_ocsp_completion_ = f; +} + +#ifdef ENABLE_HTTP3 +int ConnectionHandler::forward_quic_packet( + const UpstreamAddr *faddr, const Address &remote_addr, + const Address &local_addr, const ngtcp2_pkt_info &pi, + const uint8_t *cid_prefix, const uint8_t *data, size_t datalen) { + assert(!get_config()->single_thread); + + for (auto &worker : workers_) { + if (!std::equal(cid_prefix, cid_prefix + SHRPX_QUIC_CID_PREFIXLEN, + worker->get_cid_prefix())) { + continue; + } + + WorkerEvent wev{}; + wev.type = WorkerEventType::QUIC_PKT_FORWARD; + wev.quic_pkt = std::make_unique<QUICPacket>(faddr->index, remote_addr, + local_addr, pi, data, datalen); + + worker->send(std::move(wev)); + + return 0; + } + + return -1; +} + +void ConnectionHandler::set_quic_keying_materials( + std::shared_ptr<QUICKeyingMaterials> qkms) { + quic_keying_materials_ = std::move(qkms); +} + +const std::shared_ptr<QUICKeyingMaterials> & +ConnectionHandler::get_quic_keying_materials() const { + return quic_keying_materials_; +} + +void ConnectionHandler::set_cid_prefixes( + const std::vector<std::array<uint8_t, SHRPX_QUIC_CID_PREFIXLEN>> + &cid_prefixes) { + cid_prefixes_ = cid_prefixes; +} + +QUICLingeringWorkerProcess * +ConnectionHandler::match_quic_lingering_worker_process_cid_prefix( + const uint8_t *dcid, size_t dcidlen) { + assert(dcidlen >= SHRPX_QUIC_CID_PREFIXLEN); + + for (auto &lwps : quic_lingering_worker_processes_) { + for (auto &cid_prefix : lwps.cid_prefixes) { + if (std::equal(std::begin(cid_prefix), std::end(cid_prefix), dcid)) { + return &lwps; + } + } + } + + return nullptr; +} + +# ifdef HAVE_LIBBPF +std::vector<BPFRef> &ConnectionHandler::get_quic_bpf_refs() { + return quic_bpf_refs_; +} + +void ConnectionHandler::unload_bpf_objects() { + LOG(NOTICE) << "Unloading BPF objects"; + + for (auto &ref : quic_bpf_refs_) { + if (ref.obj == nullptr) { + continue; + } + + bpf_object__close(ref.obj); + + ref.obj = nullptr; + } +} +# endif // HAVE_LIBBPF + +void ConnectionHandler::set_quic_ipc_fd(int fd) { quic_ipc_fd_ = fd; } + +void ConnectionHandler::set_quic_lingering_worker_processes( + const std::vector<QUICLingeringWorkerProcess> &quic_lwps) { + quic_lingering_worker_processes_ = quic_lwps; +} + +int ConnectionHandler::forward_quic_packet_to_lingering_worker_process( + QUICLingeringWorkerProcess *quic_lwp, const Address &remote_addr, + const Address &local_addr, const ngtcp2_pkt_info &pi, const uint8_t *data, + size_t datalen) { + std::array<uint8_t, 512> header; + + assert(header.size() >= 1 + 1 + 1 + 1 + sizeof(sockaddr_storage) * 2); + assert(remote_addr.len > 0); + assert(local_addr.len > 0); + + auto p = header.data(); + + *p++ = static_cast<uint8_t>(QUICIPCType::DGRAM_FORWARD); + *p++ = static_cast<uint8_t>(remote_addr.len - 1); + p = std::copy_n(reinterpret_cast<const uint8_t *>(&remote_addr.su), + remote_addr.len, p); + *p++ = static_cast<uint8_t>(local_addr.len - 1); + p = std::copy_n(reinterpret_cast<const uint8_t *>(&local_addr.su), + local_addr.len, p); + *p++ = pi.ecn; + + iovec msg_iov[] = { + { + .iov_base = header.data(), + .iov_len = static_cast<size_t>(p - header.data()), + }, + { + .iov_base = const_cast<uint8_t *>(data), + .iov_len = datalen, + }, + }; + + msghdr msg{}; + msg.msg_iov = msg_iov; + msg.msg_iovlen = array_size(msg_iov); + + ssize_t nwrite; + + while ((nwrite = sendmsg(quic_lwp->quic_ipc_fd, &msg, 0)) == -1 && + errno == EINTR) + ; + + if (nwrite == -1) { + std::array<char, STRERROR_BUFSIZE> errbuf; + + auto error = errno; + LOG(ERROR) << "Failed to send QUIC IPC message: " + << xsi_strerror(error, errbuf.data(), errbuf.size()); + + return -1; + } + + return 0; +} + +int ConnectionHandler::quic_ipc_read() { + std::array<uint8_t, 65536> buf; + + ssize_t nread; + + while ((nread = recv(quic_ipc_fd_, buf.data(), buf.size(), 0)) == -1 && + errno == EINTR) + ; + + if (nread == -1) { + std::array<char, STRERROR_BUFSIZE> errbuf; + + auto error = errno; + LOG(ERROR) << "Failed to read data from QUIC IPC channel: " + << xsi_strerror(error, errbuf.data(), errbuf.size()); + + return -1; + } + + if (nread == 0) { + return 0; + } + + size_t len = 1 + 1 + 1 + 1; + + // Wire format: + // TYPE(1) REMOTE_ADDRLEN(1) REMOTE_ADDR(N) LOCAL_ADDRLEN(1) LOCAL_ADDR(N) + // ECN(1) DGRAM_PAYLOAD(N) + // + // When encoding, REMOTE_ADDRLEN and LOCAL_ADDRLEN are decremented + // by 1. + if (static_cast<size_t>(nread) < len) { + return 0; + } + + auto p = buf.data(); + if (*p != static_cast<uint8_t>(QUICIPCType::DGRAM_FORWARD)) { + LOG(ERROR) << "Unknown QUICIPCType: " << static_cast<uint32_t>(*p); + + return -1; + } + + ++p; + + auto pkt = std::make_unique<QUICPacket>(); + + auto remote_addrlen = static_cast<size_t>(*p++) + 1; + if (remote_addrlen > sizeof(sockaddr_storage)) { + LOG(ERROR) << "The length of remote address is too large: " + << remote_addrlen; + + return -1; + } + + len += remote_addrlen; + + if (static_cast<size_t>(nread) < len) { + LOG(ERROR) << "Insufficient QUIC IPC message length"; + + return -1; + } + + pkt->remote_addr.len = remote_addrlen; + memcpy(&pkt->remote_addr.su, p, remote_addrlen); + + p += remote_addrlen; + + auto local_addrlen = static_cast<size_t>(*p++) + 1; + if (local_addrlen > sizeof(sockaddr_storage)) { + LOG(ERROR) << "The length of local address is too large: " << local_addrlen; + + return -1; + } + + len += local_addrlen; + + if (static_cast<size_t>(nread) < len) { + LOG(ERROR) << "Insufficient QUIC IPC message length"; + + return -1; + } + + pkt->local_addr.len = local_addrlen; + memcpy(&pkt->local_addr.su, p, local_addrlen); + + p += local_addrlen; + + pkt->pi.ecn = *p++; + + auto datalen = nread - (p - buf.data()); + + pkt->data.assign(p, p + datalen); + + // At the moment, UpstreamAddr index is unknown. + pkt->upstream_addr_index = static_cast<size_t>(-1); + + ngtcp2_version_cid vc; + + auto rv = ngtcp2_pkt_decode_version_cid(&vc, p, datalen, SHRPX_QUIC_SCIDLEN); + if (rv < 0) { + LOG(ERROR) << "ngtcp2_pkt_decode_version_cid: " << ngtcp2_strerror(rv); + + return -1; + } + + if (vc.dcidlen != SHRPX_QUIC_SCIDLEN) { + LOG(ERROR) << "DCID length is invalid"; + return -1; + } + + if (single_worker_) { + auto faddr = single_worker_->find_quic_upstream_addr(pkt->local_addr); + if (faddr == nullptr) { + LOG(ERROR) << "No suitable upstream address found"; + + return 0; + } + + auto quic_conn_handler = single_worker_->get_quic_connection_handler(); + + // Ignore return value + quic_conn_handler->handle_packet(faddr, pkt->remote_addr, pkt->local_addr, + pkt->pi, pkt->data.data(), + pkt->data.size()); + + return 0; + } + + auto &qkm = quic_keying_materials_->keying_materials.front(); + + std::array<uint8_t, SHRPX_QUIC_DECRYPTED_DCIDLEN> decrypted_dcid; + + if (decrypt_quic_connection_id(decrypted_dcid.data(), + vc.dcid + SHRPX_QUIC_CID_PREFIX_OFFSET, + qkm.cid_encryption_key.data()) != 0) { + return -1; + } + + for (auto &worker : workers_) { + if (!std::equal(std::begin(decrypted_dcid), + std::begin(decrypted_dcid) + SHRPX_QUIC_CID_PREFIXLEN, + worker->get_cid_prefix())) { + continue; + } + + WorkerEvent wev{ + .type = WorkerEventType::QUIC_PKT_FORWARD, + .quic_pkt = std::move(pkt), + }; + worker->send(std::move(wev)); + + return 0; + } + + if (LOG_ENABLED(INFO)) { + LOG(INFO) << "No worker to match CID prefix"; + } + + return 0; +} +#endif // ENABLE_HTTP3 + +} // namespace shrpx |