summaryrefslogtreecommitdiffstats
path: root/third_party/rust/warp/src/server.rs
diff options
context:
space:
mode:
Diffstat (limited to 'third_party/rust/warp/src/server.rs')
-rw-r--r--third_party/rust/warp/src/server.rs576
1 files changed, 576 insertions, 0 deletions
diff --git a/third_party/rust/warp/src/server.rs b/third_party/rust/warp/src/server.rs
new file mode 100644
index 0000000000..f1eb33b952
--- /dev/null
+++ b/third_party/rust/warp/src/server.rs
@@ -0,0 +1,576 @@
+#[cfg(feature = "tls")]
+use crate::tls::TlsConfigBuilder;
+use std::convert::Infallible;
+use std::error::Error as StdError;
+use std::future::Future;
+use std::net::SocketAddr;
+#[cfg(feature = "tls")]
+use std::path::Path;
+
+use futures_util::{future, FutureExt, TryFuture, TryStream, TryStreamExt};
+use hyper::server::conn::AddrIncoming;
+use hyper::service::{make_service_fn, service_fn};
+use hyper::Server as HyperServer;
+use tokio::io::{AsyncRead, AsyncWrite};
+use tracing::Instrument;
+
+use crate::filter::Filter;
+use crate::reject::IsReject;
+use crate::reply::Reply;
+use crate::transport::Transport;
+
+/// Create a `Server` with the provided `Filter`.
+pub fn serve<F>(filter: F) -> Server<F>
+where
+ F: Filter + Clone + Send + Sync + 'static,
+ F::Extract: Reply,
+ F::Error: IsReject,
+{
+ Server {
+ pipeline: false,
+ filter,
+ }
+}
+
+/// A Warp Server ready to filter requests.
+#[derive(Debug)]
+pub struct Server<F> {
+ pipeline: bool,
+ filter: F,
+}
+
+/// A Warp Server ready to filter requests over TLS.
+///
+/// *This type requires the `"tls"` feature.*
+#[cfg(feature = "tls")]
+pub struct TlsServer<F> {
+ server: Server<F>,
+ tls: TlsConfigBuilder,
+}
+
+// Getting all various generic bounds to make this a re-usable method is
+// very complicated, so instead this is just a macro.
+macro_rules! into_service {
+ ($into:expr) => {{
+ let inner = crate::service($into);
+ make_service_fn(move |transport| {
+ let inner = inner.clone();
+ let remote_addr = Transport::remote_addr(transport);
+ future::ok::<_, Infallible>(service_fn(move |req| {
+ inner.call_with_addr(req, remote_addr)
+ }))
+ })
+ }};
+}
+
+macro_rules! addr_incoming {
+ ($addr:expr) => {{
+ let mut incoming = AddrIncoming::bind($addr)?;
+ incoming.set_nodelay(true);
+ let addr = incoming.local_addr();
+ (addr, incoming)
+ }};
+}
+
+macro_rules! bind_inner {
+ ($this:ident, $addr:expr) => {{
+ let service = into_service!($this.filter);
+ let (addr, incoming) = addr_incoming!($addr);
+ let srv = HyperServer::builder(incoming)
+ .http1_pipeline_flush($this.pipeline)
+ .serve(service);
+ Ok::<_, hyper::Error>((addr, srv))
+ }};
+
+ (tls: $this:ident, $addr:expr) => {{
+ let service = into_service!($this.server.filter);
+ let (addr, incoming) = addr_incoming!($addr);
+ let tls = $this.tls.build()?;
+ let srv = HyperServer::builder(crate::tls::TlsAcceptor::new(tls, incoming))
+ .http1_pipeline_flush($this.server.pipeline)
+ .serve(service);
+ Ok::<_, Box<dyn std::error::Error + Send + Sync>>((addr, srv))
+ }};
+}
+
+macro_rules! bind {
+ ($this:ident, $addr:expr) => {{
+ let addr = $addr.into();
+ (|addr| bind_inner!($this, addr))(&addr).unwrap_or_else(|e| {
+ panic!("error binding to {}: {}", addr, e);
+ })
+ }};
+
+ (tls: $this:ident, $addr:expr) => {{
+ let addr = $addr.into();
+ (|addr| bind_inner!(tls: $this, addr))(&addr).unwrap_or_else(|e| {
+ panic!("error binding to {}: {}", addr, e);
+ })
+ }};
+}
+
+macro_rules! try_bind {
+ ($this:ident, $addr:expr) => {{
+ (|addr| bind_inner!($this, addr))($addr)
+ }};
+
+ (tls: $this:ident, $addr:expr) => {{
+ (|addr| bind_inner!(tls: $this, addr))($addr)
+ }};
+}
+
+// ===== impl Server =====
+
+impl<F> Server<F>
+where
+ F: Filter + Clone + Send + Sync + 'static,
+ <F::Future as TryFuture>::Ok: Reply,
+ <F::Future as TryFuture>::Error: IsReject,
+{
+ /// Run this `Server` forever on the current thread.
+ pub async fn run(self, addr: impl Into<SocketAddr>) {
+ let (addr, fut) = self.bind_ephemeral(addr);
+ let span = tracing::info_span!("Server::run", ?addr);
+ tracing::info!(parent: &span, "listening on http://{}", addr);
+
+ fut.instrument(span).await;
+ }
+
+ /// Run this `Server` forever on the current thread with a specific stream
+ /// of incoming connections.
+ ///
+ /// This can be used for Unix Domain Sockets, or TLS, etc.
+ pub async fn run_incoming<I>(self, incoming: I)
+ where
+ I: TryStream + Send,
+ I::Ok: AsyncRead + AsyncWrite + Send + 'static + Unpin,
+ I::Error: Into<Box<dyn StdError + Send + Sync>>,
+ {
+ self.run_incoming2(incoming.map_ok(crate::transport::LiftIo).into_stream())
+ .instrument(tracing::info_span!("Server::run_incoming"))
+ .await;
+ }
+
+ async fn run_incoming2<I>(self, incoming: I)
+ where
+ I: TryStream + Send,
+ I::Ok: Transport + Send + 'static + Unpin,
+ I::Error: Into<Box<dyn StdError + Send + Sync>>,
+ {
+ let fut = self.serve_incoming2(incoming);
+
+ tracing::info!("listening with custom incoming");
+
+ fut.await;
+ }
+
+ /// Bind to a socket address, returning a `Future` that can be
+ /// executed on the current runtime.
+ ///
+ /// # Panics
+ ///
+ /// Panics if we are unable to bind to the provided address.
+ pub fn bind(self, addr: impl Into<SocketAddr> + 'static) -> impl Future<Output = ()> + 'static {
+ let (_, fut) = self.bind_ephemeral(addr);
+ fut
+ }
+
+ /// Bind to a socket address, returning a `Future` that can be
+ /// executed on any runtime.
+ ///
+ /// In case we are unable to bind to the specified address, resolves to an
+ /// error and logs the reason.
+ pub async fn try_bind(self, addr: impl Into<SocketAddr>) {
+ let addr = addr.into();
+ let srv = match try_bind!(self, &addr) {
+ Ok((_, srv)) => srv,
+ Err(err) => {
+ tracing::error!("error binding to {}: {}", addr, err);
+ return;
+ }
+ };
+
+ srv.map(|result| {
+ if let Err(err) = result {
+ tracing::error!("server error: {}", err)
+ }
+ })
+ .await;
+ }
+
+ /// Bind to a possibly ephemeral socket address.
+ ///
+ /// Returns the bound address and a `Future` that can be executed on
+ /// the current runtime.
+ ///
+ /// # Panics
+ ///
+ /// Panics if we are unable to bind to the provided address.
+ pub fn bind_ephemeral(
+ self,
+ addr: impl Into<SocketAddr>,
+ ) -> (SocketAddr, impl Future<Output = ()> + 'static) {
+ let (addr, srv) = bind!(self, addr);
+ let srv = srv.map(|result| {
+ if let Err(err) = result {
+ tracing::error!("server error: {}", err)
+ }
+ });
+
+ (addr, srv)
+ }
+
+ /// Tried to bind a possibly ephemeral socket address.
+ ///
+ /// Returns a `Result` which fails in case we are unable to bind with the
+ /// underlying error.
+ ///
+ /// Returns the bound address and a `Future` that can be executed on
+ /// the current runtime.
+ pub fn try_bind_ephemeral(
+ self,
+ addr: impl Into<SocketAddr>,
+ ) -> Result<(SocketAddr, impl Future<Output = ()> + 'static), crate::Error> {
+ let addr = addr.into();
+ let (addr, srv) = try_bind!(self, &addr).map_err(crate::Error::new)?;
+ let srv = srv.map(|result| {
+ if let Err(err) = result {
+ tracing::error!("server error: {}", err)
+ }
+ });
+
+ Ok((addr, srv))
+ }
+
+ /// Create a server with graceful shutdown signal.
+ ///
+ /// When the signal completes, the server will start the graceful shutdown
+ /// process.
+ ///
+ /// Returns the bound address and a `Future` that can be executed on
+ /// the current runtime.
+ ///
+ /// # Example
+ ///
+ /// ```no_run
+ /// use warp::Filter;
+ /// use futures_util::future::TryFutureExt;
+ /// use tokio::sync::oneshot;
+ ///
+ /// # fn main() {
+ /// let routes = warp::any()
+ /// .map(|| "Hello, World!");
+ ///
+ /// let (tx, rx) = oneshot::channel();
+ ///
+ /// let (addr, server) = warp::serve(routes)
+ /// .bind_with_graceful_shutdown(([127, 0, 0, 1], 3030), async {
+ /// rx.await.ok();
+ /// });
+ ///
+ /// // Spawn the server into a runtime
+ /// tokio::task::spawn(server);
+ ///
+ /// // Later, start the shutdown...
+ /// let _ = tx.send(());
+ /// # }
+ /// ```
+ pub fn bind_with_graceful_shutdown(
+ self,
+ addr: impl Into<SocketAddr> + 'static,
+ signal: impl Future<Output = ()> + Send + 'static,
+ ) -> (SocketAddr, impl Future<Output = ()> + 'static) {
+ let (addr, srv) = bind!(self, addr);
+ let fut = srv.with_graceful_shutdown(signal).map(|result| {
+ if let Err(err) = result {
+ tracing::error!("server error: {}", err)
+ }
+ });
+ (addr, fut)
+ }
+
+ /// Create a server with graceful shutdown signal.
+ ///
+ /// When the signal completes, the server will start the graceful shutdown
+ /// process.
+ pub fn try_bind_with_graceful_shutdown(
+ self,
+ addr: impl Into<SocketAddr> + 'static,
+ signal: impl Future<Output = ()> + Send + 'static,
+ ) -> Result<(SocketAddr, impl Future<Output = ()> + 'static), crate::Error> {
+ let addr = addr.into();
+ let (addr, srv) = try_bind!(self, &addr).map_err(crate::Error::new)?;
+ let srv = srv.with_graceful_shutdown(signal).map(|result| {
+ if let Err(err) = result {
+ tracing::error!("server error: {}", err)
+ }
+ });
+
+ Ok((addr, srv))
+ }
+
+ /// Setup this `Server` with a specific stream of incoming connections.
+ ///
+ /// This can be used for Unix Domain Sockets, or TLS, etc.
+ ///
+ /// Returns a `Future` that can be executed on the current runtime.
+ pub fn serve_incoming<I>(self, incoming: I) -> impl Future<Output = ()>
+ where
+ I: TryStream + Send,
+ I::Ok: AsyncRead + AsyncWrite + Send + 'static + Unpin,
+ I::Error: Into<Box<dyn StdError + Send + Sync>>,
+ {
+ let incoming = incoming.map_ok(crate::transport::LiftIo);
+ self.serve_incoming2(incoming)
+ .instrument(tracing::info_span!("Server::serve_incoming"))
+ }
+
+ /// Setup this `Server` with a specific stream of incoming connections and a
+ /// signal to initiate graceful shutdown.
+ ///
+ /// This can be used for Unix Domain Sockets, or TLS, etc.
+ ///
+ /// When the signal completes, the server will start the graceful shutdown
+ /// process.
+ ///
+ /// Returns a `Future` that can be executed on the current runtime.
+ pub fn serve_incoming_with_graceful_shutdown<I>(
+ self,
+ incoming: I,
+ signal: impl Future<Output = ()> + Send + 'static,
+ ) -> impl Future<Output = ()>
+ where
+ I: TryStream + Send,
+ I::Ok: AsyncRead + AsyncWrite + Send + 'static + Unpin,
+ I::Error: Into<Box<dyn StdError + Send + Sync>>,
+ {
+ let incoming = incoming.map_ok(crate::transport::LiftIo);
+ let service = into_service!(self.filter);
+ let pipeline = self.pipeline;
+
+ async move {
+ let srv =
+ HyperServer::builder(hyper::server::accept::from_stream(incoming.into_stream()))
+ .http1_pipeline_flush(pipeline)
+ .serve(service)
+ .with_graceful_shutdown(signal)
+ .await;
+
+ if let Err(err) = srv {
+ tracing::error!("server error: {}", err);
+ }
+ }
+ .instrument(tracing::info_span!(
+ "Server::serve_incoming_with_graceful_shutdown"
+ ))
+ }
+
+ async fn serve_incoming2<I>(self, incoming: I)
+ where
+ I: TryStream + Send,
+ I::Ok: Transport + Send + 'static + Unpin,
+ I::Error: Into<Box<dyn StdError + Send + Sync>>,
+ {
+ let service = into_service!(self.filter);
+
+ let srv = HyperServer::builder(hyper::server::accept::from_stream(incoming.into_stream()))
+ .http1_pipeline_flush(self.pipeline)
+ .serve(service)
+ .await;
+
+ if let Err(err) = srv {
+ tracing::error!("server error: {}", err);
+ }
+ }
+
+ // Generally shouldn't be used, as it can slow down non-pipelined responses.
+ //
+ // It's only real use is to make silly pipeline benchmarks look better.
+ #[doc(hidden)]
+ pub fn unstable_pipeline(mut self) -> Self {
+ self.pipeline = true;
+ self
+ }
+
+ /// Configure a server to use TLS.
+ ///
+ /// *This function requires the `"tls"` feature.*
+ #[cfg(feature = "tls")]
+ pub fn tls(self) -> TlsServer<F> {
+ TlsServer {
+ server: self,
+ tls: TlsConfigBuilder::new(),
+ }
+ }
+}
+
+// // ===== impl TlsServer =====
+
+#[cfg(feature = "tls")]
+impl<F> TlsServer<F>
+where
+ F: Filter + Clone + Send + Sync + 'static,
+ <F::Future as TryFuture>::Ok: Reply,
+ <F::Future as TryFuture>::Error: IsReject,
+{
+ // TLS config methods
+
+ /// Specify the file path to read the private key.
+ ///
+ /// *This function requires the `"tls"` feature.*
+ pub fn key_path(self, path: impl AsRef<Path>) -> Self {
+ self.with_tls(|tls| tls.key_path(path))
+ }
+
+ /// Specify the file path to read the certificate.
+ ///
+ /// *This function requires the `"tls"` feature.*
+ pub fn cert_path(self, path: impl AsRef<Path>) -> Self {
+ self.with_tls(|tls| tls.cert_path(path))
+ }
+
+ /// Specify the file path to read the trust anchor for optional client authentication.
+ ///
+ /// Anonymous and authenticated clients will be accepted. If no trust anchor is provided by any
+ /// of the `client_auth_` methods, then client authentication is disabled by default.
+ ///
+ /// *This function requires the `"tls"` feature.*
+ pub fn client_auth_optional_path(self, path: impl AsRef<Path>) -> Self {
+ self.with_tls(|tls| tls.client_auth_optional_path(path))
+ }
+
+ /// Specify the file path to read the trust anchor for required client authentication.
+ ///
+ /// Only authenticated clients will be accepted. If no trust anchor is provided by any of the
+ /// `client_auth_` methods, then client authentication is disabled by default.
+ ///
+ /// *This function requires the `"tls"` feature.*
+ pub fn client_auth_required_path(self, path: impl AsRef<Path>) -> Self {
+ self.with_tls(|tls| tls.client_auth_required_path(path))
+ }
+
+ /// Specify the in-memory contents of the private key.
+ ///
+ /// *This function requires the `"tls"` feature.*
+ pub fn key(self, key: impl AsRef<[u8]>) -> Self {
+ self.with_tls(|tls| tls.key(key.as_ref()))
+ }
+
+ /// Specify the in-memory contents of the certificate.
+ ///
+ /// *This function requires the `"tls"` feature.*
+ pub fn cert(self, cert: impl AsRef<[u8]>) -> Self {
+ self.with_tls(|tls| tls.cert(cert.as_ref()))
+ }
+
+ /// Specify the in-memory contents of the trust anchor for optional client authentication.
+ ///
+ /// Anonymous and authenticated clients will be accepted. If no trust anchor is provided by any
+ /// of the `client_auth_` methods, then client authentication is disabled by default.
+ ///
+ /// *This function requires the `"tls"` feature.*
+ pub fn client_auth_optional(self, trust_anchor: impl AsRef<[u8]>) -> Self {
+ self.with_tls(|tls| tls.client_auth_optional(trust_anchor.as_ref()))
+ }
+
+ /// Specify the in-memory contents of the trust anchor for required client authentication.
+ ///
+ /// Only authenticated clients will be accepted. If no trust anchor is provided by any of the
+ /// `client_auth_` methods, then client authentication is disabled by default.
+ ///
+ /// *This function requires the `"tls"` feature.*
+ pub fn client_auth_required(self, trust_anchor: impl AsRef<[u8]>) -> Self {
+ self.with_tls(|tls| tls.client_auth_required(trust_anchor.as_ref()))
+ }
+
+ /// Specify the DER-encoded OCSP response.
+ ///
+ /// *This function requires the `"tls"` feature.*
+ pub fn ocsp_resp(self, resp: impl AsRef<[u8]>) -> Self {
+ self.with_tls(|tls| tls.ocsp_resp(resp.as_ref()))
+ }
+
+ fn with_tls<Func>(self, func: Func) -> Self
+ where
+ Func: FnOnce(TlsConfigBuilder) -> TlsConfigBuilder,
+ {
+ let TlsServer { server, tls } = self;
+ let tls = func(tls);
+ TlsServer { server, tls }
+ }
+
+ // Server run methods
+
+ /// Run this `TlsServer` forever on the current thread.
+ ///
+ /// *This function requires the `"tls"` feature.*
+ pub async fn run(self, addr: impl Into<SocketAddr>) {
+ let (addr, fut) = self.bind_ephemeral(addr);
+ let span = tracing::info_span!("TlsServer::run", %addr);
+ tracing::info!(parent: &span, "listening on https://{}", addr);
+
+ fut.instrument(span).await;
+ }
+
+ /// Bind to a socket address, returning a `Future` that can be
+ /// executed on a runtime.
+ ///
+ /// *This function requires the `"tls"` feature.*
+ pub async fn bind(self, addr: impl Into<SocketAddr>) {
+ let (_, fut) = self.bind_ephemeral(addr);
+ fut.await;
+ }
+
+ /// Bind to a possibly ephemeral socket address.
+ ///
+ /// Returns the bound address and a `Future` that can be executed on
+ /// the current runtime.
+ ///
+ /// *This function requires the `"tls"` feature.*
+ pub fn bind_ephemeral(
+ self,
+ addr: impl Into<SocketAddr>,
+ ) -> (SocketAddr, impl Future<Output = ()> + 'static) {
+ let (addr, srv) = bind!(tls: self, addr);
+ let srv = srv.map(|result| {
+ if let Err(err) = result {
+ tracing::error!("server error: {}", err)
+ }
+ });
+
+ (addr, srv)
+ }
+
+ /// Create a server with graceful shutdown signal.
+ ///
+ /// When the signal completes, the server will start the graceful shutdown
+ /// process.
+ ///
+ /// *This function requires the `"tls"` feature.*
+ pub fn bind_with_graceful_shutdown(
+ self,
+ addr: impl Into<SocketAddr> + 'static,
+ signal: impl Future<Output = ()> + Send + 'static,
+ ) -> (SocketAddr, impl Future<Output = ()> + 'static) {
+ let (addr, srv) = bind!(tls: self, addr);
+
+ let fut = srv.with_graceful_shutdown(signal).map(|result| {
+ if let Err(err) = result {
+ tracing::error!("server error: {}", err)
+ }
+ });
+ (addr, fut)
+ }
+}
+
+#[cfg(feature = "tls")]
+impl<F> ::std::fmt::Debug for TlsServer<F>
+where
+ F: ::std::fmt::Debug,
+{
+ fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result {
+ f.debug_struct("TlsServer")
+ .field("server", &self.server)
+ .finish()
+ }
+}