use std::fmt;
use std::io;
use std::net::{SocketAddr, TcpListener as StdTcpListener};
use std::time::Duration;

use socket2::TcpKeepalive;
use tokio::net::TcpListener;
use tokio::time::Sleep;
use tracing::{debug, error, trace};

use crate::common::{task, Future, Pin, Poll};

#[allow(unreachable_pub)] // https://github.com/rust-lang/rust/issues/57411
pub use self::addr_stream::AddrStream;

use super::accept::Accept;

#[derive(Default, Debug, Clone, Copy)]
struct TcpKeepaliveConfig {
    time: Option<Duration>,
    interval: Option<Duration>,
    retries: Option<u32>,
}

impl TcpKeepaliveConfig {
    /// Converts into a `socket2::TcpKeepalive` if there is any keepalive configuration.
    fn into_socket2(self) -> Option<TcpKeepalive> {
        let mut dirty = false;
        let mut ka = TcpKeepalive::new();
        if let Some(time) = self.time {
            ka = ka.with_time(time);
            dirty = true
        }
        if let Some(interval) = self.interval {
            ka = Self::ka_with_interval(ka, interval, &mut dirty)
        };
        if let Some(retries) = self.retries {
            ka = Self::ka_with_retries(ka, retries, &mut dirty)
        };
        if dirty {
            Some(ka)
        } else {
            None
        }
    }

    #[cfg(any(
        target_os = "android",
        target_os = "dragonfly",
        target_os = "freebsd",
        target_os = "fuchsia",
        target_os = "illumos",
        target_os = "linux",
        target_os = "netbsd",
        target_vendor = "apple",
        windows,
    ))]
    fn ka_with_interval(ka: TcpKeepalive, interval: Duration, dirty: &mut bool) -> TcpKeepalive {
        *dirty = true;
        ka.with_interval(interval)
    }

    #[cfg(not(any(
        target_os = "android",
        target_os = "dragonfly",
        target_os = "freebsd",
        target_os = "fuchsia",
        target_os = "illumos",
        target_os = "linux",
        target_os = "netbsd",
        target_vendor = "apple",
        windows,
    )))]
    fn ka_with_interval(ka: TcpKeepalive, _: Duration, _: &mut bool) -> TcpKeepalive {
        ka // no-op as keepalive interval is not supported on this platform
    }

    #[cfg(any(
        target_os = "android",
        target_os = "dragonfly",
        target_os = "freebsd",
        target_os = "fuchsia",
        target_os = "illumos",
        target_os = "linux",
        target_os = "netbsd",
        target_vendor = "apple",
    ))]
    fn ka_with_retries(ka: TcpKeepalive, retries: u32, dirty: &mut bool) -> TcpKeepalive {
        *dirty = true;
        ka.with_retries(retries)
    }

    #[cfg(not(any(
        target_os = "android",
        target_os = "dragonfly",
        target_os = "freebsd",
        target_os = "fuchsia",
        target_os = "illumos",
        target_os = "linux",
        target_os = "netbsd",
        target_vendor = "apple",
    )))]
    fn ka_with_retries(ka: TcpKeepalive, _: u32, _: &mut bool) -> TcpKeepalive {
        ka // no-op as keepalive retries is not supported on this platform
    }
}
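// Worked example (illustrative only, not part of the upstream module): the
// three fields correspond to the TCP_KEEPIDLE / TCP_KEEPINTVL / TCP_KEEPCNT
// socket options where the platform supports them. With the values sketched
// below, an idle connection is first probed after 60s, re-probed every 10s,
// and torn down after 5 unanswered probes, i.e. roughly 60s + 5 * 10s of
// silence before the OS reports the peer as gone.
//
// let config = TcpKeepaliveConfig {
//     time: Some(Duration::from_secs(60)),
//     interval: Some(Duration::from_secs(10)),
//     retries: Some(5),
// };
// let ka: Option<socket2::TcpKeepalive> = config.into_socket2();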
/// A stream of connections from binding to an address.
#[must_use = "streams do nothing unless polled"]
pub struct AddrIncoming {
    addr: SocketAddr,
    listener: TcpListener,
    sleep_on_errors: bool,
    tcp_keepalive_config: TcpKeepaliveConfig,
    tcp_nodelay: bool,
    timeout: Option<Pin<Box<Sleep>>>,
}

impl AddrIncoming {
    pub(super) fn new(addr: &SocketAddr) -> crate::Result<Self> {
        let std_listener = StdTcpListener::bind(addr).map_err(crate::Error::new_listen)?;

        AddrIncoming::from_std(std_listener)
    }

    pub(super) fn from_std(std_listener: StdTcpListener) -> crate::Result<Self> {
        // TcpListener::from_std doesn't set O_NONBLOCK
        std_listener
            .set_nonblocking(true)
            .map_err(crate::Error::new_listen)?;
        let listener = TcpListener::from_std(std_listener).map_err(crate::Error::new_listen)?;
        AddrIncoming::from_listener(listener)
    }

    /// Creates a new `AddrIncoming` binding to the provided socket address.
    pub fn bind(addr: &SocketAddr) -> crate::Result<Self> {
        AddrIncoming::new(addr)
    }

    /// Creates a new `AddrIncoming` from an existing `tokio::net::TcpListener`.
    pub fn from_listener(listener: TcpListener) -> crate::Result<Self> {
        let addr = listener.local_addr().map_err(crate::Error::new_listen)?;
        Ok(AddrIncoming {
            listener,
            addr,
            sleep_on_errors: true,
            tcp_keepalive_config: TcpKeepaliveConfig::default(),
            tcp_nodelay: false,
            timeout: None,
        })
    }

    /// Get the local address bound to this listener.
    pub fn local_addr(&self) -> SocketAddr {
        self.addr
    }

    /// Set the duration to remain idle before sending TCP keepalive probes.
    ///
    /// If `None` is specified, keepalive is disabled.
    pub fn set_keepalive(&mut self, time: Option<Duration>) -> &mut Self {
        self.tcp_keepalive_config.time = time;
        self
    }

    /// Set the duration between two successive TCP keepalive retransmissions,
    /// if acknowledgement to the previous keepalive transmission is not received.
    pub fn set_keepalive_interval(&mut self, interval: Option<Duration>) -> &mut Self {
        self.tcp_keepalive_config.interval = interval;
        self
    }

    /// Set the number of retransmissions to be carried out before declaring
    /// that the remote end is not available.
    pub fn set_keepalive_retries(&mut self, retries: Option<u32>) -> &mut Self {
        self.tcp_keepalive_config.retries = retries;
        self
    }

    /// Set the value of `TCP_NODELAY` option for accepted connections.
    pub fn set_nodelay(&mut self, enabled: bool) -> &mut Self {
        self.tcp_nodelay = enabled;
        self
    }

    /// Set whether to sleep on accept errors.
    ///
    /// A possible scenario is that the process has hit the max open files
    /// allowed, and so trying to accept a new connection will fail with
    /// `EMFILE`. In some cases, it's preferable to just wait for some time, if
    /// the application will likely close some files (or connections), and try
    /// to accept the connection again. If this option is `true`, the error
    /// will be logged at the `error` level, since it is still a big deal,
    /// and then the listener will sleep for 1 second.
    ///
    /// In other cases, hitting the max open files should be treated similarly
    /// to being out-of-memory, and simply error (and shutdown). Setting
    /// this option to `false` will allow that.
    ///
    /// Default is `true`.
    pub fn set_sleep_on_errors(&mut self, val: bool) {
        self.sleep_on_errors = val;
    }
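    // Usage sketch (illustrative only): the keepalive and nodelay setters
    // above return `&mut Self`, so a caller typically tunes the listener in
    // one chain before handing it to a server. The `addr` binding and the
    // durations below are hypothetical example values.
    //
    // let mut incoming = AddrIncoming::bind(&addr)?;
    // incoming
    //     .set_nodelay(true)
    //     .set_keepalive(Some(Duration::from_secs(60)))
    //     .set_keepalive_interval(Some(Duration::from_secs(10)))
    //     .set_keepalive_retries(Some(5));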
    fn poll_next_(&mut self, cx: &mut task::Context<'_>) -> Poll<io::Result<AddrStream>> {
        // Check if a previous timeout is active that was set by IO errors.
        if let Some(ref mut to) = self.timeout {
            ready!(Pin::new(to).poll(cx));
        }
        self.timeout = None;

        loop {
            match ready!(self.listener.poll_accept(cx)) {
                Ok((socket, remote_addr)) => {
                    if let Some(tcp_keepalive) = &self.tcp_keepalive_config.into_socket2() {
                        let sock_ref = socket2::SockRef::from(&socket);
                        if let Err(e) = sock_ref.set_tcp_keepalive(tcp_keepalive) {
                            trace!("error trying to set TCP keepalive: {}", e);
                        }
                    }
                    if let Err(e) = socket.set_nodelay(self.tcp_nodelay) {
                        trace!("error trying to set TCP nodelay: {}", e);
                    }
                    let local_addr = socket.local_addr()?;
                    return Poll::Ready(Ok(AddrStream::new(socket, remote_addr, local_addr)));
                }
                Err(e) => {
                    // Connection errors can be ignored directly, continue by
                    // accepting the next request.
                    if is_connection_error(&e) {
                        debug!("accepted connection already errored: {}", e);
                        continue;
                    }

                    if self.sleep_on_errors {
                        error!("accept error: {}", e);

                        // Sleep 1s.
                        let mut timeout = Box::pin(tokio::time::sleep(Duration::from_secs(1)));

                        match timeout.as_mut().poll(cx) {
                            Poll::Ready(()) => {
                                // Wow, it's been a second already? Ok then...
                                continue;
                            }
                            Poll::Pending => {
                                self.timeout = Some(timeout);
                                return Poll::Pending;
                            }
                        }
                    } else {
                        return Poll::Ready(Err(e));
                    }
                }
            }
        }
    }
}

impl Accept for AddrIncoming {
    type Conn = AddrStream;
    type Error = io::Error;

    fn poll_accept(
        mut self: Pin<&mut Self>,
        cx: &mut task::Context<'_>,
    ) -> Poll<Option<Result<Self::Conn, Self::Error>>> {
        let result = ready!(self.poll_next_(cx));
        Poll::Ready(Some(result))
    }
}

/// This function defines errors that are per-connection: if we get one of
/// these errors from the `accept()` system call, the next connection might
/// still be ready to be accepted.
///
/// All other errors will incur a timeout before the next `accept()` is
/// performed. The timeout is useful to handle resource exhaustion errors
/// like ENFILE and EMFILE. Otherwise, we could enter a tight loop.
fn is_connection_error(e: &io::Error) -> bool {
    matches!(
        e.kind(),
        io::ErrorKind::ConnectionRefused
            | io::ErrorKind::ConnectionAborted
            | io::ErrorKind::ConnectionReset
    )
}

impl fmt::Debug for AddrIncoming {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("AddrIncoming")
            .field("addr", &self.addr)
            .field("sleep_on_errors", &self.sleep_on_errors)
            .field("tcp_keepalive_config", &self.tcp_keepalive_config)
            .field("tcp_nodelay", &self.tcp_nodelay)
            .finish()
    }
}

mod addr_stream {
    use std::io;
    use std::net::SocketAddr;
    #[cfg(unix)]
    use std::os::unix::io::{AsRawFd, RawFd};
    use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
    use tokio::net::TcpStream;

    use crate::common::{task, Pin, Poll};

    pin_project_lite::pin_project! {
        /// A transport yielded by `AddrIncoming`.
        #[derive(Debug)]
        pub struct AddrStream {
            #[pin]
            inner: TcpStream,
            pub(super) remote_addr: SocketAddr,
            pub(super) local_addr: SocketAddr
        }
    }

    impl AddrStream {
        pub(super) fn new(
            tcp: TcpStream,
            remote_addr: SocketAddr,
            local_addr: SocketAddr,
        ) -> AddrStream {
            AddrStream {
                inner: tcp,
                remote_addr,
                local_addr,
            }
        }

        /// Returns the remote (peer) address of this connection.
        #[inline]
        pub fn remote_addr(&self) -> SocketAddr {
            self.remote_addr
        }

        /// Returns the local address of this connection.
        #[inline]
        pub fn local_addr(&self) -> SocketAddr {
            self.local_addr
        }

        /// Consumes the `AddrStream` and returns the underlying IO object.
        #[inline]
        pub fn into_inner(self) -> TcpStream {
            self.inner
        }

        /// Attempt to receive data on the socket, without removing that data
        /// from the queue, registering the current task for wakeup if data is
        /// not yet available.
        pub fn poll_peek(
            &mut self,
            cx: &mut task::Context<'_>,
            buf: &mut tokio::io::ReadBuf<'_>,
        ) -> Poll<io::Result<usize>> {
            self.inner.poll_peek(cx, buf)
        }
    }

    impl AsyncRead for AddrStream {
        #[inline]
        fn poll_read(
            self: Pin<&mut Self>,
            cx: &mut task::Context<'_>,
            buf: &mut ReadBuf<'_>,
        ) -> Poll<io::Result<()>> {
            self.project().inner.poll_read(cx, buf)
        }
    }

    impl AsyncWrite for AddrStream {
        #[inline]
        fn poll_write(
            self: Pin<&mut Self>,
            cx: &mut task::Context<'_>,
            buf: &[u8],
        ) -> Poll<io::Result<usize>> {
            self.project().inner.poll_write(cx, buf)
        }

        #[inline]
        fn poll_write_vectored(
            self: Pin<&mut Self>,
            cx: &mut task::Context<'_>,
            bufs: &[io::IoSlice<'_>],
        ) -> Poll<io::Result<usize>> {
            self.project().inner.poll_write_vectored(cx, bufs)
        }

        #[inline]
        fn poll_flush(self: Pin<&mut Self>, _cx: &mut task::Context<'_>) -> Poll<io::Result<()>> {
            // TCP flush is a noop
            Poll::Ready(Ok(()))
        }

        #[inline]
        fn poll_shutdown(
            self: Pin<&mut Self>,
            cx: &mut task::Context<'_>,
        ) -> Poll<io::Result<()>> {
            self.project().inner.poll_shutdown(cx)
        }

        #[inline]
        fn is_write_vectored(&self) -> bool {
            // Note that since `self.inner` is a `TcpStream`, this could
            // *probably* be hard-coded to return `true`...but it seems more
            // correct to ask it anyway (maybe we're on some platform without
            // scatter-gather IO?)
            self.inner.is_write_vectored()
        }
    }

    #[cfg(unix)]
    impl AsRawFd for AddrStream {
        fn as_raw_fd(&self) -> RawFd {
            self.inner.as_raw_fd()
        }
    }
}

#[cfg(test)]
mod tests {
    use std::time::Duration;

    use crate::server::tcp::TcpKeepaliveConfig;

    #[test]
    fn no_tcp_keepalive_config() {
        assert!(TcpKeepaliveConfig::default().into_socket2().is_none());
    }

    #[test]
    fn tcp_keepalive_time_config() {
        let mut kac = TcpKeepaliveConfig::default();
        kac.time = Some(Duration::from_secs(60));
        if let Some(tcp_keepalive) = kac.into_socket2() {
            assert!(format!("{tcp_keepalive:?}").contains("time: Some(60s)"));
        } else {
            panic!("test failed");
        }
    }

    #[cfg(any(
        target_os = "android",
        target_os = "dragonfly",
        target_os = "freebsd",
        target_os = "fuchsia",
        target_os = "illumos",
        target_os = "linux",
        target_os = "netbsd",
        target_vendor = "apple",
        windows,
    ))]
    #[test]
    fn tcp_keepalive_interval_config() {
        let mut kac = TcpKeepaliveConfig::default();
        kac.interval = Some(Duration::from_secs(1));
        if let Some(tcp_keepalive) = kac.into_socket2() {
            assert!(format!("{tcp_keepalive:?}").contains("interval: Some(1s)"));
        } else {
            panic!("test failed");
        }
    }

    #[cfg(any(
        target_os = "android",
        target_os = "dragonfly",
        target_os = "freebsd",
        target_os = "fuchsia",
        target_os = "illumos",
        target_os = "linux",
        target_os = "netbsd",
        target_vendor = "apple",
    ))]
    #[test]
    fn tcp_keepalive_retries_config() {
        let mut kac = TcpKeepaliveConfig::default();
        kac.retries = Some(3);
        if let Some(tcp_keepalive) = kac.into_socket2() {
            assert!(format!("{tcp_keepalive:?}").contains("retries: Some(3)"));
        } else {
            panic!("test failed");
        }
    }
}
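
// The tests below are illustrative additions, not part of the upstream test
// suite. They document the accept-error classification used by `poll_next_`:
// per-connection errors are skipped, while every other error goes through the
// sleep-on-errors path. The "too many open files" message is a hypothetical
// stand-in for an EMFILE-style error.
#[cfg(test)]
mod accept_error_tests {
    use std::io;

    use super::is_connection_error;

    #[test]
    fn connection_errors_are_skipped() {
        // These kinds mean the already-accepted connection failed; accepting
        // the next one immediately is safe.
        for kind in [
            io::ErrorKind::ConnectionRefused,
            io::ErrorKind::ConnectionAborted,
            io::ErrorKind::ConnectionReset,
        ] {
            assert!(is_connection_error(&io::Error::from(kind)));
        }
    }

    #[test]
    fn other_errors_trigger_the_sleep_path() {
        // Resource-exhaustion errors are not per-connection, so the listener
        // either sleeps for a second or surfaces the error to the caller.
        let emfile_like = io::Error::new(io::ErrorKind::Other, "too many open files");
        assert!(!is_connection_error(&emfile_like));
    }
}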