diff options
author | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-04-07 19:33:14 +0000 |
---|---|---|
committer | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-04-07 19:33:14 +0000 |
commit | 36d22d82aa202bb199967e9512281e9a53db42c9 (patch) | |
tree | 105e8c98ddea1c1e4784a60a5a6410fa416be2de /third_party/rust/warp/src | |
parent | Initial commit. (diff) | |
download | firefox-esr-36d22d82aa202bb199967e9512281e9a53db42c9.tar.xz firefox-esr-36d22d82aa202bb199967e9512281e9a53db42c9.zip |
Adding upstream version 115.7.0esr.upstream/115.7.0esr
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'third_party/rust/warp/src')
46 files changed, 10857 insertions, 0 deletions
diff --git a/third_party/rust/warp/src/error.rs b/third_party/rust/warp/src/error.rs new file mode 100644 index 0000000000..64220b633e --- /dev/null +++ b/third_party/rust/warp/src/error.rs @@ -0,0 +1,79 @@ +use std::convert::Infallible; +use std::error::Error as StdError; +use std::fmt; + +type BoxError = Box<dyn std::error::Error + Send + Sync>; + +/// Errors that can happen inside warp. +pub struct Error { + inner: BoxError, +} + +impl Error { + pub(crate) fn new<E: Into<BoxError>>(err: E) -> Error { + Error { inner: err.into() } + } +} + +impl fmt::Debug for Error { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + // Skip showing worthless `Error { .. }` wrapper. + fmt::Debug::fmt(&self.inner, f) + } +} + +impl fmt::Display for Error { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt::Display::fmt(&self.inner, f) + } +} + +impl StdError for Error { + fn source(&self) -> Option<&(dyn StdError + 'static)> { + Some(self.inner.as_ref()) + } +} + +impl From<Infallible> for Error { + fn from(infallible: Infallible) -> Error { + match infallible {} + } +} + +#[test] +fn error_size_of() { + assert_eq!( + ::std::mem::size_of::<Error>(), + ::std::mem::size_of::<usize>() * 2 + ); +} + +#[test] +fn error_source() { + let e = Error::new(std::fmt::Error {}); + assert!(e.source().unwrap().is::<std::fmt::Error>()); +} + +macro_rules! unit_error { + ( + $(#[$docs:meta])* + $pub:vis $typ:ident: $display:literal + ) => ( + $(#[$docs])* + $pub struct $typ { _p: (), } + + impl ::std::fmt::Debug for $typ { + fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result { + f.debug_struct(stringify!($typ)).finish() + } + } + + impl ::std::fmt::Display for $typ { + fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result { + f.write_str($display) + } + } + + impl ::std::error::Error for $typ {} + ) +} diff --git a/third_party/rust/warp/src/filter/and.rs b/third_party/rust/warp/src/filter/and.rs new file mode 100644 index 0000000000..5edd90fed8 --- /dev/null +++ b/third_party/rust/warp/src/filter/and.rs @@ -0,0 +1,97 @@ +use std::future::Future; +use std::pin::Pin; +use std::task::{Context, Poll}; + +use futures_util::ready; +use pin_project::pin_project; + +use super::{Combine, Filter, FilterBase, Internal, Tuple}; +use crate::generic::CombinedTuples; +use crate::reject::CombineRejection; + +#[derive(Clone, Copy, Debug)] +pub struct And<T, U> { + pub(super) first: T, + pub(super) second: U, +} + +impl<T, U> FilterBase for And<T, U> +where + T: Filter, + T::Extract: Send, + U: Filter + Clone + Send, + <T::Extract as Tuple>::HList: Combine<<U::Extract as Tuple>::HList> + Send, + CombinedTuples<T::Extract, U::Extract>: Send, + U::Error: CombineRejection<T::Error>, +{ + type Extract = CombinedTuples<T::Extract, U::Extract>; + type Error = <U::Error as CombineRejection<T::Error>>::One; + type Future = AndFuture<T, U>; + + fn filter(&self, _: Internal) -> Self::Future { + AndFuture { + state: State::First(self.first.filter(Internal), self.second.clone()), + } + } +} + +#[allow(missing_debug_implementations)] +#[pin_project] +pub struct AndFuture<T: Filter, U: Filter> { + #[pin] + state: State<T::Future, T::Extract, U>, +} + +#[pin_project(project = StateProj)] +enum State<T, TE, U: Filter> { + First(#[pin] T, U), + Second(Option<TE>, #[pin] U::Future), + Done, +} + +impl<T, U> Future for AndFuture<T, U> +where + T: Filter, + U: Filter, + <T::Extract as Tuple>::HList: Combine<<U::Extract as Tuple>::HList> + Send, + U::Error: CombineRejection<T::Error>, +{ + type Output = Result< + CombinedTuples<T::Extract, U::Extract>, + <U::Error as CombineRejection<T::Error>>::One, + >; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { + self.project().state.poll(cx) + } +} + +impl<T, TE, U, E> Future for State<T, TE, U> +where + T: Future<Output = Result<TE, E>>, + U: Filter, + TE: Tuple, + TE::HList: Combine<<U::Extract as Tuple>::HList> + Send, + U::Error: CombineRejection<E>, +{ + type Output = Result<CombinedTuples<TE, U::Extract>, <U::Error as CombineRejection<E>>::One>; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { + loop { + match self.as_mut().project() { + StateProj::First(first, second) => { + let ex1 = ready!(first.poll(cx))?; + let fut2 = second.filter(Internal); + self.set(State::Second(Some(ex1), fut2)); + } + StateProj::Second(ex1, second) => { + let ex2 = ready!(second.poll(cx))?; + let ex3 = ex1.take().unwrap().combine(ex2); + self.set(State::Done); + return Poll::Ready(Ok(ex3)); + } + StateProj::Done => panic!("polled after complete"), + } + } + } +} diff --git a/third_party/rust/warp/src/filter/and_then.rs b/third_party/rust/warp/src/filter/and_then.rs new file mode 100644 index 0000000000..efed5fe8cf --- /dev/null +++ b/third_party/rust/warp/src/filter/and_then.rs @@ -0,0 +1,110 @@ +use std::future::Future; +use std::pin::Pin; +use std::task::{Context, Poll}; + +use futures_util::{ready, TryFuture}; +use pin_project::pin_project; + +use super::{Filter, FilterBase, Func, Internal}; +use crate::reject::CombineRejection; + +#[derive(Clone, Copy, Debug)] +pub struct AndThen<T, F> { + pub(super) filter: T, + pub(super) callback: F, +} + +impl<T, F> FilterBase for AndThen<T, F> +where + T: Filter, + F: Func<T::Extract> + Clone + Send, + F::Output: TryFuture + Send, + <F::Output as TryFuture>::Error: CombineRejection<T::Error>, +{ + type Extract = (<F::Output as TryFuture>::Ok,); + type Error = <<F::Output as TryFuture>::Error as CombineRejection<T::Error>>::One; + type Future = AndThenFuture<T, F>; + #[inline] + fn filter(&self, _: Internal) -> Self::Future { + AndThenFuture { + state: State::First(self.filter.filter(Internal), self.callback.clone()), + } + } +} + +#[allow(missing_debug_implementations)] +#[pin_project] +pub struct AndThenFuture<T, F> +where + T: Filter, + F: Func<T::Extract>, + F::Output: TryFuture + Send, + <F::Output as TryFuture>::Error: CombineRejection<T::Error>, +{ + #[pin] + state: State<T::Future, F>, +} + +#[pin_project(project = StateProj)] +enum State<T, F> +where + T: TryFuture, + F: Func<T::Ok>, + F::Output: TryFuture + Send, + <F::Output as TryFuture>::Error: CombineRejection<T::Error>, +{ + First(#[pin] T, F), + Second(#[pin] F::Output), + Done, +} + +impl<T, F> Future for AndThenFuture<T, F> +where + T: Filter, + F: Func<T::Extract>, + F::Output: TryFuture + Send, + <F::Output as TryFuture>::Error: CombineRejection<T::Error>, +{ + type Output = Result< + (<F::Output as TryFuture>::Ok,), + <<F::Output as TryFuture>::Error as CombineRejection<T::Error>>::One, + >; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { + self.project().state.poll(cx) + } +} + +impl<T, F> Future for State<T, F> +where + T: TryFuture, + F: Func<T::Ok>, + F::Output: TryFuture + Send, + <F::Output as TryFuture>::Error: CombineRejection<T::Error>, +{ + type Output = Result< + (<F::Output as TryFuture>::Ok,), + <<F::Output as TryFuture>::Error as CombineRejection<T::Error>>::One, + >; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { + loop { + match self.as_mut().project() { + StateProj::First(first, second) => { + let ex1 = ready!(first.try_poll(cx))?; + let fut2 = second.call(ex1); + self.set(State::Second(fut2)); + } + StateProj::Second(second) => { + let ex2 = match ready!(second.try_poll(cx)) { + Ok(item) => Ok((item,)), + Err(err) => Err(From::from(err)), + }; + self.set(State::Done); + return Poll::Ready(ex2); + } + StateProj::Done => panic!("polled after complete"), + } + } + } +} diff --git a/third_party/rust/warp/src/filter/boxed.rs b/third_party/rust/warp/src/filter/boxed.rs new file mode 100644 index 0000000000..1ed4a7d800 --- /dev/null +++ b/third_party/rust/warp/src/filter/boxed.rs @@ -0,0 +1,100 @@ +use std::fmt; +use std::future::Future; +use std::pin::Pin; +use std::sync::Arc; + +use futures_util::TryFutureExt; + +use super::{Filter, FilterBase, Internal, Tuple}; +use crate::reject::Rejection; + +/// A type representing a boxed `Filter` trait object. +/// +/// The filter inside is a dynamic trait object. The purpose of this type is +/// to ease returning `Filter`s from other functions. +/// +/// To create one, call `Filter::boxed` on any filter. +/// +/// # Examples +/// +/// ``` +/// use warp::{Filter, filters::BoxedFilter, Reply}; +/// +/// pub fn assets_filter() -> BoxedFilter<(impl Reply,)> { +/// warp::path("assets") +/// .and(warp::fs::dir("./assets")) +/// .boxed() +/// } +/// ``` +/// +pub struct BoxedFilter<T: Tuple> { + filter: Arc< + dyn Filter< + Extract = T, + Error = Rejection, + Future = Pin<Box<dyn Future<Output = Result<T, Rejection>> + Send>>, + > + Send + + Sync, + >, +} + +impl<T: Tuple + Send> BoxedFilter<T> { + pub(super) fn new<F>(filter: F) -> BoxedFilter<T> + where + F: Filter<Extract = T> + Send + Sync + 'static, + F::Error: Into<Rejection>, + { + BoxedFilter { + filter: Arc::new(BoxingFilter { + filter: filter.map_err(super::Internal, Into::into), + }), + } + } +} + +impl<T: Tuple> Clone for BoxedFilter<T> { + fn clone(&self) -> BoxedFilter<T> { + BoxedFilter { + filter: self.filter.clone(), + } + } +} + +impl<T: Tuple> fmt::Debug for BoxedFilter<T> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("BoxedFilter").finish() + } +} + +fn _assert_send() { + fn _assert<T: Send>() {} + _assert::<BoxedFilter<()>>(); +} + +impl<T: Tuple + Send> FilterBase for BoxedFilter<T> { + type Extract = T; + type Error = Rejection; + type Future = Pin<Box<dyn Future<Output = Result<T, Rejection>> + Send>>; + + fn filter(&self, _: Internal) -> Self::Future { + self.filter.filter(Internal) + } +} + +struct BoxingFilter<F> { + filter: F, +} + +impl<F> FilterBase for BoxingFilter<F> +where + F: Filter, + F::Future: Send + 'static, +{ + type Extract = F::Extract; + type Error = F::Error; + type Future = Pin<Box<dyn Future<Output = Result<Self::Extract, Self::Error>> + Send>>; + + fn filter(&self, _: Internal) -> Self::Future { + Box::pin(self.filter.filter(Internal).into_future()) + } +} diff --git a/third_party/rust/warp/src/filter/map.rs b/third_party/rust/warp/src/filter/map.rs new file mode 100644 index 0000000000..ec3173994e --- /dev/null +++ b/third_party/rust/warp/src/filter/map.rs @@ -0,0 +1,59 @@ +use std::future::Future; +use std::pin::Pin; +use std::task::{Context, Poll}; + +use futures_util::{ready, TryFuture}; +use pin_project::pin_project; + +use super::{Filter, FilterBase, Func, Internal}; + +#[derive(Clone, Copy, Debug)] +pub struct Map<T, F> { + pub(super) filter: T, + pub(super) callback: F, +} + +impl<T, F> FilterBase for Map<T, F> +where + T: Filter, + F: Func<T::Extract> + Clone + Send, +{ + type Extract = (F::Output,); + type Error = T::Error; + type Future = MapFuture<T, F>; + #[inline] + fn filter(&self, _: Internal) -> Self::Future { + MapFuture { + extract: self.filter.filter(Internal), + callback: self.callback.clone(), + } + } +} + +#[allow(missing_debug_implementations)] +#[pin_project] +pub struct MapFuture<T: Filter, F> { + #[pin] + extract: T::Future, + callback: F, +} + +impl<T, F> Future for MapFuture<T, F> +where + T: Filter, + F: Func<T::Extract>, +{ + type Output = Result<(F::Output,), T::Error>; + + #[inline] + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { + let pin = self.project(); + match ready!(pin.extract.try_poll(cx)) { + Ok(ex) => { + let ex = (pin.callback.call(ex),); + Poll::Ready(Ok(ex)) + } + Err(err) => Poll::Ready(Err(err)), + } + } +} diff --git a/third_party/rust/warp/src/filter/map_err.rs b/third_party/rust/warp/src/filter/map_err.rs new file mode 100644 index 0000000000..f4659e6ff6 --- /dev/null +++ b/third_party/rust/warp/src/filter/map_err.rs @@ -0,0 +1,58 @@ +use std::future::Future; +use std::pin::Pin; +use std::task::{Context, Poll}; + +use futures_util::TryFuture; +use pin_project::pin_project; + +use super::{Filter, FilterBase, Internal}; +use crate::reject::IsReject; + +#[derive(Clone, Copy, Debug)] +pub struct MapErr<T, F> { + pub(super) filter: T, + pub(super) callback: F, +} + +impl<T, F, E> FilterBase for MapErr<T, F> +where + T: Filter, + F: Fn(T::Error) -> E + Clone + Send, + E: IsReject, +{ + type Extract = T::Extract; + type Error = E; + type Future = MapErrFuture<T, F>; + #[inline] + fn filter(&self, _: Internal) -> Self::Future { + MapErrFuture { + extract: self.filter.filter(Internal), + callback: self.callback.clone(), + } + } +} + +#[allow(missing_debug_implementations)] +#[pin_project] +pub struct MapErrFuture<T: Filter, F> { + #[pin] + extract: T::Future, + callback: F, +} + +impl<T, F, E> Future for MapErrFuture<T, F> +where + T: Filter, + F: Fn(T::Error) -> E, +{ + type Output = Result<T::Extract, E>; + + #[inline] + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { + self.as_mut() + .project() + .extract + .try_poll(cx) + .map_err(|err| (self.callback)(err)) + } +} diff --git a/third_party/rust/warp/src/filter/mod.rs b/third_party/rust/warp/src/filter/mod.rs new file mode 100644 index 0000000000..263650768d --- /dev/null +++ b/third_party/rust/warp/src/filter/mod.rs @@ -0,0 +1,492 @@ +mod and; +mod and_then; +mod boxed; +mod map; +mod map_err; +mod or; +mod or_else; +mod recover; +pub(crate) mod service; +mod then; +mod unify; +mod untuple_one; +mod wrap; + +use std::future::Future; + +use futures_util::{future, TryFuture, TryFutureExt}; + +pub(crate) use crate::generic::{one, Combine, Either, Func, One, Tuple}; +use crate::reject::{CombineRejection, IsReject, Rejection}; +use crate::route::{self, Route}; + +pub(crate) use self::and::And; +use self::and_then::AndThen; +pub use self::boxed::BoxedFilter; +pub(crate) use self::map::Map; +pub(crate) use self::map_err::MapErr; +pub(crate) use self::or::Or; +use self::or_else::OrElse; +use self::recover::Recover; +use self::then::Then; +use self::unify::Unify; +use self::untuple_one::UntupleOne; +pub use self::wrap::wrap_fn; +pub(crate) use self::wrap::{Wrap, WrapSealed}; + +// A crate-private base trait, allowing the actual `filter` method to change +// signatures without it being a breaking change. +pub trait FilterBase { + type Extract: Tuple; // + Send; + type Error: IsReject; + type Future: Future<Output = Result<Self::Extract, Self::Error>> + Send; + + fn filter(&self, internal: Internal) -> Self::Future; + + fn map_err<F, E>(self, _internal: Internal, fun: F) -> MapErr<Self, F> + where + Self: Sized, + F: Fn(Self::Error) -> E + Clone, + E: ::std::fmt::Debug + Send, + { + MapErr { + filter: self, + callback: fun, + } + } +} + +// A crate-private argument to prevent users from calling methods on +// the `FilterBase` trait. +// +// For instance, this innocent user code could otherwise call `filter`: +// +// ``` +// async fn with_filter<F: Filter>(f: F) -> Result<F::Extract, F::Error> { +// f.filter().await +// } +// ``` +#[allow(missing_debug_implementations)] +pub struct Internal; + +/// Composable request filters. +/// +/// A `Filter` can optionally extract some data from a request, combine +/// it with others, mutate it, and return back some value as a reply. The +/// power of `Filter`s come from being able to isolate small subsets, and then +/// chain and reuse them in various parts of your app. +/// +/// # Extracting Tuples +/// +/// You may notice that several of these filters extract some tuple, often +/// times a tuple of just 1 item! Why? +/// +/// If a filter extracts a `(String,)`, that simply means that it +/// extracts a `String`. If you were to `map` the filter, the argument type +/// would be exactly that, just a `String`. +/// +/// What is it? It's just some type magic that allows for automatic combining +/// and flattening of tuples. Without it, combining two filters together with +/// `and`, where one extracted `()`, and another `String`, would mean the +/// `map` would be given a single argument of `((), String,)`, which is just +/// no fun. +pub trait Filter: FilterBase { + /// Composes a new `Filter` that requires both this and the other to filter a request. + /// + /// Additionally, this will join together the extracted values of both + /// filters, so that `map` and `and_then` receive them as separate arguments. + /// + /// If a `Filter` extracts nothing (so, `()`), combining with any other + /// filter will simply discard the `()`. If a `Filter` extracts one or + /// more items, combining will mean it extracts the values of itself + /// combined with the other. + /// + /// # Example + /// + /// ``` + /// use warp::Filter; + /// + /// // Match `/hello/:name`... + /// warp::path("hello") + /// .and(warp::path::param::<String>()); + /// ``` + fn and<F>(self, other: F) -> And<Self, F> + where + Self: Sized, + <Self::Extract as Tuple>::HList: Combine<<F::Extract as Tuple>::HList>, + F: Filter + Clone, + F::Error: CombineRejection<Self::Error>, + { + And { + first: self, + second: other, + } + } + + /// Composes a new `Filter` of either this or the other filter. + /// + /// # Example + /// + /// ``` + /// use std::net::SocketAddr; + /// use warp::Filter; + /// + /// // Match either `/:u32` or `/:socketaddr` + /// warp::path::param::<u32>() + /// .or(warp::path::param::<SocketAddr>()); + /// ``` + fn or<F>(self, other: F) -> Or<Self, F> + where + Self: Filter<Error = Rejection> + Sized, + F: Filter, + F::Error: CombineRejection<Self::Error>, + { + Or { + first: self, + second: other, + } + } + + /// Composes this `Filter` with a function receiving the extracted value. + /// + /// + /// # Example + /// + /// ``` + /// use warp::Filter; + /// + /// // Map `/:id` + /// warp::path::param().map(|id: u64| { + /// format!("Hello #{}", id) + /// }); + /// ``` + /// + /// # `Func` + /// + /// The generic `Func` trait is implemented for any function that receives + /// the same arguments as this `Filter` extracts. In practice, this + /// shouldn't ever bother you, and simply makes things feel more natural. + /// + /// For example, if three `Filter`s were combined together, suppose one + /// extracts nothing (so `()`), and the other two extract two integers, + /// a function that accepts exactly two integer arguments is allowed. + /// Specifically, any `Fn(u32, u32)`. + /// + /// Without `Product` and `Func`, this would be a lot messier. First of + /// all, the `()`s couldn't be discarded, and the tuples would be nested. + /// So, instead, you'd need to pass an `Fn(((), (u32, u32)))`. That's just + /// a single argument. Bleck! + /// + /// Even worse, the tuples would shuffle the types around depending on + /// the exact invocation of `and`s. So, `unit.and(int).and(int)` would + /// result in a different extracted type from `unit.and(int.and(int))`, + /// or from `int.and(unit).and(int)`. If you changed around the order + /// of filters, while still having them be semantically equivalent, you'd + /// need to update all your `map`s as well. + /// + /// `Product`, `HList`, and `Func` do all the heavy work so that none of + /// this is a bother to you. What's more, the types are enforced at + /// compile-time, and tuple flattening is optimized away to nothing by + /// LLVM. + fn map<F>(self, fun: F) -> Map<Self, F> + where + Self: Sized, + F: Func<Self::Extract> + Clone, + { + Map { + filter: self, + callback: fun, + } + } + + /// Composes this `Filter` with an async function receiving + /// the extracted value. + /// + /// The function should return some `Future` type. + /// + /// # Example + /// + /// ``` + /// use warp::Filter; + /// + /// // Map `/:id` + /// warp::path::param().then(|id: u64| async move { + /// format!("Hello #{}", id) + /// }); + /// ``` + fn then<F>(self, fun: F) -> Then<Self, F> + where + Self: Sized, + F: Func<Self::Extract> + Clone, + F::Output: Future + Send, + { + Then { + filter: self, + callback: fun, + } + } + + /// Composes this `Filter` with a fallible async function receiving + /// the extracted value. + /// + /// The function should return some `TryFuture` type. + /// + /// The `Error` type of the return `Future` needs be a `Rejection`, which + /// means most futures will need to have their error mapped into one. + /// + /// Rejections are meant to say "this filter didn't accept the request, + /// maybe another can". So for application-level errors, consider using + /// [`Filter::then`] instead. + /// + /// # Example + /// + /// ``` + /// use warp::Filter; + /// + /// // Validate after `/:id` + /// warp::path::param().and_then(|id: u64| async move { + /// if id != 0 { + /// Ok(format!("Hello #{}", id)) + /// } else { + /// Err(warp::reject::not_found()) + /// } + /// }); + /// ``` + fn and_then<F>(self, fun: F) -> AndThen<Self, F> + where + Self: Sized, + F: Func<Self::Extract> + Clone, + F::Output: TryFuture + Send, + <F::Output as TryFuture>::Error: CombineRejection<Self::Error>, + { + AndThen { + filter: self, + callback: fun, + } + } + + /// Compose this `Filter` with a function receiving an error. + /// + /// The function should return some `TryFuture` type yielding the + /// same item and error types. + fn or_else<F>(self, fun: F) -> OrElse<Self, F> + where + Self: Filter<Error = Rejection> + Sized, + F: Func<Rejection>, + F::Output: TryFuture<Ok = Self::Extract> + Send, + <F::Output as TryFuture>::Error: IsReject, + { + OrElse { + filter: self, + callback: fun, + } + } + + /// Compose this `Filter` with a function receiving an error and + /// returning a *new* type, instead of the *same* type. + /// + /// This is useful for "customizing" rejections into new response types. + /// See also the [rejections example][ex]. + /// + /// [ex]: https://github.com/seanmonstar/warp/blob/master/examples/rejections.rs + fn recover<F>(self, fun: F) -> Recover<Self, F> + where + Self: Filter<Error = Rejection> + Sized, + F: Func<Rejection>, + F::Output: TryFuture + Send, + <F::Output as TryFuture>::Error: IsReject, + { + Recover { + filter: self, + callback: fun, + } + } + + /// Unifies the extracted value of `Filter`s composed with `or`. + /// + /// When a `Filter` extracts some `Either<T, T>`, where both sides + /// are the same type, this combinator can be used to grab the + /// inner value, regardless of which side of `Either` it was. This + /// is useful for values that could be extracted from multiple parts + /// of a request, and the exact place isn't important. + /// + /// # Example + /// + /// ```rust + /// use std::net::SocketAddr; + /// use warp::Filter; + /// + /// let client_ip = warp::header("x-real-ip") + /// .or(warp::header("x-forwarded-for")) + /// .unify() + /// .map(|ip: SocketAddr| { + /// // Get the IP from either header, + /// // and unify into the inner type. + /// }); + /// ``` + fn unify<T>(self) -> Unify<Self> + where + Self: Filter<Extract = (Either<T, T>,)> + Sized, + T: Tuple, + { + Unify { filter: self } + } + + /// Convenience method to remove one layer of tupling. + /// + /// This is useful for when things like `map` don't return a new value, + /// but just `()`, since warp will wrap it up into a `((),)`. + /// + /// # Example + /// + /// ``` + /// use warp::Filter; + /// + /// let route = warp::path::param() + /// .map(|num: u64| { + /// println!("just logging: {}", num); + /// // returning "nothing" + /// }) + /// .untuple_one() + /// .map(|| { + /// println!("the ((),) was removed"); + /// warp::reply() + /// }); + /// ``` + /// + /// ``` + /// use warp::Filter; + /// + /// let route = warp::any() + /// .map(|| { + /// // wanting to return a tuple + /// (true, 33) + /// }) + /// .untuple_one() + /// .map(|is_enabled: bool, count: i32| { + /// println!("untupled: ({}, {})", is_enabled, count); + /// }); + /// ``` + fn untuple_one<T>(self) -> UntupleOne<Self> + where + Self: Filter<Extract = (T,)> + Sized, + T: Tuple, + { + UntupleOne { filter: self } + } + + /// Wraps the current filter with some wrapper. + /// + /// The wrapper may do some preparation work before starting this filter, + /// and may do post-processing after the filter completes. + /// + /// # Example + /// + /// ``` + /// use warp::Filter; + /// + /// let route = warp::any() + /// .map(warp::reply); + /// + /// // Wrap the route with a log wrapper. + /// let route = route.with(warp::log("example")); + /// ``` + fn with<W>(self, wrapper: W) -> W::Wrapped + where + Self: Sized, + W: Wrap<Self>, + { + wrapper.wrap(self) + } + + /// Boxes this filter into a trait object, making it easier to name the type. + /// + /// # Example + /// + /// ``` + /// use warp::Filter; + /// + /// fn impl_reply() -> warp::filters::BoxedFilter<(impl warp::Reply,)> { + /// warp::any() + /// .map(warp::reply) + /// .boxed() + /// } + /// + /// fn named_i32() -> warp::filters::BoxedFilter<(i32,)> { + /// warp::path::param::<i32>() + /// .boxed() + /// } + /// + /// fn named_and() -> warp::filters::BoxedFilter<(i32, String)> { + /// warp::path::param::<i32>() + /// .and(warp::header::<String>("host")) + /// .boxed() + /// } + /// ``` + fn boxed(self) -> BoxedFilter<Self::Extract> + where + Self: Sized + Send + Sync + 'static, + Self::Extract: Send, + Self::Error: Into<Rejection>, + { + BoxedFilter::new(self) + } +} + +impl<T: FilterBase> Filter for T {} + +pub trait FilterClone: Filter + Clone {} + +impl<T: Filter + Clone> FilterClone for T {} + +fn _assert_object_safe() { + fn _assert(_f: &dyn Filter<Extract = (), Error = (), Future = future::Ready<()>>) {} +} + +// ===== FilterFn ===== + +pub(crate) fn filter_fn<F, U>(func: F) -> FilterFn<F> +where + F: Fn(&mut Route) -> U, + U: TryFuture, + U::Ok: Tuple, + U::Error: IsReject, +{ + FilterFn { func } +} + +pub(crate) fn filter_fn_one<F, U>( + func: F, +) -> impl Filter<Extract = (U::Ok,), Error = U::Error> + Copy +where + F: Fn(&mut Route) -> U + Copy, + U: TryFuture + Send + 'static, + U::Ok: Send, + U::Error: IsReject, +{ + filter_fn(move |route| func(route).map_ok(|item| (item,))) +} + +#[derive(Copy, Clone)] +#[allow(missing_debug_implementations)] +pub(crate) struct FilterFn<F> { + // TODO: could include a `debug_str: &'static str` to be used in Debug impl + func: F, +} + +impl<F, U> FilterBase for FilterFn<F> +where + F: Fn(&mut Route) -> U, + U: TryFuture + Send + 'static, + U::Ok: Tuple + Send, + U::Error: IsReject, +{ + type Extract = U::Ok; + type Error = U::Error; + type Future = future::IntoFuture<U>; + + #[inline] + fn filter(&self, _: Internal) -> Self::Future { + route::with(|route| (self.func)(route)).into_future() + } +} diff --git a/third_party/rust/warp/src/filter/or.rs b/third_party/rust/warp/src/filter/or.rs new file mode 100644 index 0000000000..774067fb59 --- /dev/null +++ b/third_party/rust/warp/src/filter/or.rs @@ -0,0 +1,110 @@ +use std::future::Future; +use std::pin::Pin; +use std::task::{Context, Poll}; + +use futures_util::{ready, TryFuture}; +use pin_project::pin_project; + +use super::{Filter, FilterBase, Internal}; +use crate::generic::Either; +use crate::reject::CombineRejection; +use crate::route; + +type Combined<E1, E2> = <E1 as CombineRejection<E2>>::Combined; + +#[derive(Clone, Copy, Debug)] +pub struct Or<T, U> { + pub(super) first: T, + pub(super) second: U, +} + +impl<T, U> FilterBase for Or<T, U> +where + T: Filter, + U: Filter + Clone + Send, + U::Error: CombineRejection<T::Error>, +{ + type Extract = (Either<T::Extract, U::Extract>,); + //type Error = <U::Error as CombineRejection<T::Error>>::Combined; + type Error = Combined<U::Error, T::Error>; + type Future = EitherFuture<T, U>; + + fn filter(&self, _: Internal) -> Self::Future { + let idx = route::with(|route| route.matched_path_index()); + EitherFuture { + state: State::First(self.first.filter(Internal), self.second.clone()), + original_path_index: PathIndex(idx), + } + } +} + +#[allow(missing_debug_implementations)] +#[pin_project] +pub struct EitherFuture<T: Filter, U: Filter> { + #[pin] + state: State<T, U>, + original_path_index: PathIndex, +} + +#[pin_project(project = StateProj)] +enum State<T: Filter, U: Filter> { + First(#[pin] T::Future, U), + Second(Option<T::Error>, #[pin] U::Future), + Done, +} + +#[derive(Copy, Clone)] +struct PathIndex(usize); + +impl PathIndex { + fn reset_path(&self) { + route::with(|route| route.reset_matched_path_index(self.0)); + } +} + +impl<T, U> Future for EitherFuture<T, U> +where + T: Filter, + U: Filter, + U::Error: CombineRejection<T::Error>, +{ + type Output = Result<(Either<T::Extract, U::Extract>,), Combined<U::Error, T::Error>>; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { + loop { + let pin = self.as_mut().project(); + let (err1, fut2) = match pin.state.project() { + StateProj::First(first, second) => match ready!(first.try_poll(cx)) { + Ok(ex1) => { + return Poll::Ready(Ok((Either::A(ex1),))); + } + Err(e) => { + pin.original_path_index.reset_path(); + (e, second.filter(Internal)) + } + }, + StateProj::Second(err1, second) => { + let ex2 = match ready!(second.try_poll(cx)) { + Ok(ex2) => Ok((Either::B(ex2),)), + Err(e) => { + pin.original_path_index.reset_path(); + let err1 = err1.take().expect("polled after complete"); + Err(e.combine(err1)) + } + }; + self.set(EitherFuture { + state: State::Done, + ..*self + }); + return Poll::Ready(ex2); + } + StateProj::Done => panic!("polled after complete"), + }; + + self.set(EitherFuture { + state: State::Second(Some(err1), fut2), + ..*self + }); + } + } +} diff --git a/third_party/rust/warp/src/filter/or_else.rs b/third_party/rust/warp/src/filter/or_else.rs new file mode 100644 index 0000000000..aaf23243ec --- /dev/null +++ b/third_party/rust/warp/src/filter/or_else.rs @@ -0,0 +1,107 @@ +use std::future::Future; +use std::pin::Pin; +use std::task::{Context, Poll}; + +use futures_util::{ready, TryFuture}; +use pin_project::pin_project; + +use super::{Filter, FilterBase, Func, Internal}; +use crate::reject::IsReject; +use crate::route; + +#[derive(Clone, Copy, Debug)] +pub struct OrElse<T, F> { + pub(super) filter: T, + pub(super) callback: F, +} + +impl<T, F> FilterBase for OrElse<T, F> +where + T: Filter, + F: Func<T::Error> + Clone + Send, + F::Output: TryFuture<Ok = T::Extract> + Send, + <F::Output as TryFuture>::Error: IsReject, +{ + type Extract = <F::Output as TryFuture>::Ok; + type Error = <F::Output as TryFuture>::Error; + type Future = OrElseFuture<T, F>; + #[inline] + fn filter(&self, _: Internal) -> Self::Future { + let idx = route::with(|route| route.matched_path_index()); + OrElseFuture { + state: State::First(self.filter.filter(Internal), self.callback.clone()), + original_path_index: PathIndex(idx), + } + } +} + +#[allow(missing_debug_implementations)] +#[pin_project] +pub struct OrElseFuture<T: Filter, F> +where + T: Filter, + F: Func<T::Error>, + F::Output: TryFuture<Ok = T::Extract> + Send, +{ + #[pin] + state: State<T, F>, + original_path_index: PathIndex, +} + +#[pin_project(project = StateProj)] +enum State<T, F> +where + T: Filter, + F: Func<T::Error>, + F::Output: TryFuture<Ok = T::Extract> + Send, +{ + First(#[pin] T::Future, F), + Second(#[pin] F::Output), + Done, +} + +#[derive(Copy, Clone)] +struct PathIndex(usize); + +impl PathIndex { + fn reset_path(&self) { + route::with(|route| route.reset_matched_path_index(self.0)); + } +} + +impl<T, F> Future for OrElseFuture<T, F> +where + T: Filter, + F: Func<T::Error>, + F::Output: TryFuture<Ok = T::Extract> + Send, +{ + type Output = Result<<F::Output as TryFuture>::Ok, <F::Output as TryFuture>::Error>; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { + loop { + let pin = self.as_mut().project(); + let (err, second) = match pin.state.project() { + StateProj::First(first, second) => match ready!(first.try_poll(cx)) { + Ok(ex) => return Poll::Ready(Ok(ex)), + Err(err) => (err, second), + }, + StateProj::Second(second) => { + let ex2 = ready!(second.try_poll(cx)); + self.set(OrElseFuture { + state: State::Done, + ..*self + }); + return Poll::Ready(ex2); + } + StateProj::Done => panic!("polled after complete"), + }; + + pin.original_path_index.reset_path(); + let fut2 = second.call(err); + self.set(OrElseFuture { + state: State::Second(fut2), + ..*self + }); + } + } +} diff --git a/third_party/rust/warp/src/filter/recover.rs b/third_party/rust/warp/src/filter/recover.rs new file mode 100644 index 0000000000..100e9398c2 --- /dev/null +++ b/third_party/rust/warp/src/filter/recover.rs @@ -0,0 +1,117 @@ +use std::future::Future; +use std::pin::Pin; +use std::task::{Context, Poll}; + +use futures_util::{ready, TryFuture}; +use pin_project::pin_project; + +use super::{Filter, FilterBase, Func, Internal}; +use crate::generic::Either; +use crate::reject::IsReject; +use crate::route; + +#[derive(Clone, Copy, Debug)] +pub struct Recover<T, F> { + pub(super) filter: T, + pub(super) callback: F, +} + +impl<T, F> FilterBase for Recover<T, F> +where + T: Filter, + F: Func<T::Error> + Clone + Send, + F::Output: TryFuture + Send, + <F::Output as TryFuture>::Error: IsReject, +{ + type Extract = (Either<T::Extract, (<F::Output as TryFuture>::Ok,)>,); + type Error = <F::Output as TryFuture>::Error; + type Future = RecoverFuture<T, F>; + #[inline] + fn filter(&self, _: Internal) -> Self::Future { + let idx = route::with(|route| route.matched_path_index()); + RecoverFuture { + state: State::First(self.filter.filter(Internal), self.callback.clone()), + original_path_index: PathIndex(idx), + } + } +} + +#[allow(missing_debug_implementations)] +#[pin_project] +pub struct RecoverFuture<T: Filter, F> +where + T: Filter, + F: Func<T::Error>, + F::Output: TryFuture + Send, + <F::Output as TryFuture>::Error: IsReject, +{ + #[pin] + state: State<T, F>, + original_path_index: PathIndex, +} + +#[pin_project(project = StateProj)] +enum State<T, F> +where + T: Filter, + F: Func<T::Error>, + F::Output: TryFuture + Send, + <F::Output as TryFuture>::Error: IsReject, +{ + First(#[pin] T::Future, F), + Second(#[pin] F::Output), + Done, +} + +#[derive(Copy, Clone)] +struct PathIndex(usize); + +impl PathIndex { + fn reset_path(&self) { + route::with(|route| route.reset_matched_path_index(self.0)); + } +} + +impl<T, F> Future for RecoverFuture<T, F> +where + T: Filter, + F: Func<T::Error>, + F::Output: TryFuture + Send, + <F::Output as TryFuture>::Error: IsReject, +{ + type Output = Result< + (Either<T::Extract, (<F::Output as TryFuture>::Ok,)>,), + <F::Output as TryFuture>::Error, + >; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { + loop { + let pin = self.as_mut().project(); + let (err, second) = match pin.state.project() { + StateProj::First(first, second) => match ready!(first.try_poll(cx)) { + Ok(ex) => return Poll::Ready(Ok((Either::A(ex),))), + Err(err) => (err, second), + }, + StateProj::Second(second) => { + let ex2 = match ready!(second.try_poll(cx)) { + Ok(ex2) => Ok((Either::B((ex2,)),)), + Err(e) => Err(e), + }; + self.set(RecoverFuture { + state: State::Done, + ..*self + }); + return Poll::Ready(ex2); + } + StateProj::Done => panic!("polled after complete"), + }; + + pin.original_path_index.reset_path(); + let fut2 = second.call(err); + self.set(RecoverFuture { + state: State::Second(fut2), + ..*self + }); + } + } +} diff --git a/third_party/rust/warp/src/filter/service.rs b/third_party/rust/warp/src/filter/service.rs new file mode 100644 index 0000000000..3de12a02ed --- /dev/null +++ b/third_party/rust/warp/src/filter/service.rs @@ -0,0 +1,137 @@ +use std::convert::Infallible; +use std::future::Future; +use std::net::SocketAddr; +use std::pin::Pin; +use std::task::{Context, Poll}; + +use futures_util::future::TryFuture; +use hyper::service::Service; +use pin_project::pin_project; + +use crate::reject::IsReject; +use crate::reply::{Reply, Response}; +use crate::route::{self, Route}; +use crate::{Filter, Request}; + +/// Convert a `Filter` into a `Service`. +/// +/// Filters are normally what APIs are built on in warp. However, it can be +/// useful to convert a `Filter` into a [`Service`][Service], such as if +/// further customizing a `hyper::Service`, or if wanting to make use of +/// the greater [Tower][tower] set of middleware. +/// +/// # Example +/// +/// Running a `warp::Filter` on a regular `hyper::Server`: +/// +/// ``` +/// # async fn run() -> Result<(), Box<dyn std::error::Error>> { +/// use std::convert::Infallible; +/// use warp::Filter; +/// +/// // Our Filter... +/// let route = warp::any().map(|| "Hello From Warp!"); +/// +/// // Convert it into a `Service`... +/// let svc = warp::service(route); +/// +/// // Typical hyper setup... +/// let make_svc = hyper::service::make_service_fn(move |_| async move { +/// Ok::<_, Infallible>(svc) +/// }); +/// +/// hyper::Server::bind(&([127, 0, 0, 1], 3030).into()) +/// .serve(make_svc) +/// .await?; +/// # Ok(()) +/// # } +/// ``` +/// +/// [Service]: https://docs.rs/hyper/0.13.*/hyper/service/trait.Service.html +/// [tower]: https://docs.rs/tower +pub fn service<F>(filter: F) -> FilteredService<F> +where + F: Filter, + <F::Future as TryFuture>::Ok: Reply, + <F::Future as TryFuture>::Error: IsReject, +{ + FilteredService { filter } +} + +#[derive(Copy, Clone, Debug)] +pub struct FilteredService<F> { + filter: F, +} + +impl<F> FilteredService<F> +where + F: Filter, + <F::Future as TryFuture>::Ok: Reply, + <F::Future as TryFuture>::Error: IsReject, +{ + #[inline] + pub(crate) fn call_with_addr( + &self, + req: Request, + remote_addr: Option<SocketAddr>, + ) -> FilteredFuture<F::Future> { + debug_assert!(!route::is_set(), "nested route::set calls"); + + let route = Route::new(req, remote_addr); + let fut = route::set(&route, || self.filter.filter(super::Internal)); + FilteredFuture { future: fut, route } + } +} + +impl<F> Service<Request> for FilteredService<F> +where + F: Filter, + <F::Future as TryFuture>::Ok: Reply, + <F::Future as TryFuture>::Error: IsReject, +{ + type Response = Response; + type Error = Infallible; + type Future = FilteredFuture<F::Future>; + + fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { + Poll::Ready(Ok(())) + } + + #[inline] + fn call(&mut self, req: Request) -> Self::Future { + self.call_with_addr(req, None) + } +} + +#[pin_project] +#[derive(Debug)] +pub struct FilteredFuture<F> { + #[pin] + future: F, + route: ::std::cell::RefCell<Route>, +} + +impl<F> Future for FilteredFuture<F> +where + F: TryFuture, + F::Ok: Reply, + F::Error: IsReject, +{ + type Output = Result<Response, Infallible>; + + #[inline] + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { + debug_assert!(!route::is_set(), "nested route::set calls"); + + let pin = self.project(); + let fut = pin.future; + match route::set(pin.route, || fut.try_poll(cx)) { + Poll::Ready(Ok(ok)) => Poll::Ready(Ok(ok.into_response())), + Poll::Pending => Poll::Pending, + Poll::Ready(Err(err)) => { + tracing::debug!("rejected: {:?}", err); + Poll::Ready(Ok(err.into_response())) + } + } + } +} diff --git a/third_party/rust/warp/src/filter/then.rs b/third_party/rust/warp/src/filter/then.rs new file mode 100644 index 0000000000..543a22669a --- /dev/null +++ b/third_party/rust/warp/src/filter/then.rs @@ -0,0 +1,95 @@ +use std::future::Future; +use std::pin::Pin; +use std::task::{Context, Poll}; + +use futures_util::{ready, TryFuture}; +use pin_project::pin_project; + +use super::{Filter, FilterBase, Func, Internal}; + +#[derive(Clone, Copy, Debug)] +pub struct Then<T, F> { + pub(super) filter: T, + pub(super) callback: F, +} + +impl<T, F> FilterBase for Then<T, F> +where + T: Filter, + F: Func<T::Extract> + Clone + Send, + F::Output: Future + Send, +{ + type Extract = (<F::Output as Future>::Output,); + type Error = T::Error; + type Future = ThenFuture<T, F>; + #[inline] + fn filter(&self, _: Internal) -> Self::Future { + ThenFuture { + state: State::First(self.filter.filter(Internal), self.callback.clone()), + } + } +} + +#[allow(missing_debug_implementations)] +#[pin_project] +pub struct ThenFuture<T, F> +where + T: Filter, + F: Func<T::Extract>, + F::Output: Future + Send, +{ + #[pin] + state: State<T::Future, F>, +} + +#[pin_project(project = StateProj)] +enum State<T, F> +where + T: TryFuture, + F: Func<T::Ok>, + F::Output: Future + Send, +{ + First(#[pin] T, F), + Second(#[pin] F::Output), + Done, +} + +impl<T, F> Future for ThenFuture<T, F> +where + T: Filter, + F: Func<T::Extract>, + F::Output: Future + Send, +{ + type Output = Result<(<F::Output as Future>::Output,), T::Error>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { + self.project().state.poll(cx) + } +} + +impl<T, F> Future for State<T, F> +where + T: TryFuture, + F: Func<T::Ok>, + F::Output: Future + Send, +{ + type Output = Result<(<F::Output as Future>::Output,), T::Error>; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { + loop { + match self.as_mut().project() { + StateProj::First(first, second) => { + let ex1 = ready!(first.try_poll(cx))?; + let fut2 = second.call(ex1); + self.set(State::Second(fut2)); + } + StateProj::Second(second) => { + let ex2 = (ready!(second.poll(cx)),); + self.set(State::Done); + return Poll::Ready(Ok(ex2)); + } + StateProj::Done => panic!("polled after complete"), + } + } + } +} diff --git a/third_party/rust/warp/src/filter/unify.rs b/third_party/rust/warp/src/filter/unify.rs new file mode 100644 index 0000000000..0cf670daa2 --- /dev/null +++ b/third_party/rust/warp/src/filter/unify.rs @@ -0,0 +1,50 @@ +use std::future::Future; +use std::pin::Pin; +use std::task::{Context, Poll}; + +use futures_util::{ready, TryFuture}; +use pin_project::pin_project; + +use super::{Either, Filter, FilterBase, Internal, Tuple}; + +#[derive(Clone, Copy, Debug)] +pub struct Unify<F> { + pub(super) filter: F, +} + +impl<F, T> FilterBase for Unify<F> +where + F: Filter<Extract = (Either<T, T>,)>, + T: Tuple, +{ + type Extract = T; + type Error = F::Error; + type Future = UnifyFuture<F::Future>; + #[inline] + fn filter(&self, _: Internal) -> Self::Future { + UnifyFuture { + inner: self.filter.filter(Internal), + } + } +} + +#[allow(missing_debug_implementations)] +#[pin_project] +pub struct UnifyFuture<F> { + #[pin] + inner: F, +} + +impl<F, T> Future for UnifyFuture<F> +where + F: TryFuture<Ok = (Either<T, T>,)>, +{ + type Output = Result<T, F::Error>; + + #[inline] + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { + Poll::Ready(match ready!(self.project().inner.try_poll(cx))? { + (Either::A(x),) | (Either::B(x),) => Ok(x), + }) + } +} diff --git a/third_party/rust/warp/src/filter/untuple_one.rs b/third_party/rust/warp/src/filter/untuple_one.rs new file mode 100644 index 0000000000..0fb0de6748 --- /dev/null +++ b/third_party/rust/warp/src/filter/untuple_one.rs @@ -0,0 +1,52 @@ +use std::future::Future; +use std::pin::Pin; +use std::task::{Context, Poll}; + +use futures_util::{ready, TryFuture}; +use pin_project::pin_project; + +use super::{Filter, FilterBase, Internal, Tuple}; + +#[derive(Clone, Copy, Debug)] +pub struct UntupleOne<F> { + pub(super) filter: F, +} + +impl<F, T> FilterBase for UntupleOne<F> +where + F: Filter<Extract = (T,)>, + T: Tuple, +{ + type Extract = T; + type Error = F::Error; + type Future = UntupleOneFuture<F>; + #[inline] + fn filter(&self, _: Internal) -> Self::Future { + UntupleOneFuture { + extract: self.filter.filter(Internal), + } + } +} + +#[allow(missing_debug_implementations)] +#[pin_project] +pub struct UntupleOneFuture<F: Filter> { + #[pin] + extract: F::Future, +} + +impl<F, T> Future for UntupleOneFuture<F> +where + F: Filter<Extract = (T,)>, + T: Tuple, +{ + type Output = Result<T, F::Error>; + + #[inline] + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { + match ready!(self.project().extract.try_poll(cx)) { + Ok((t,)) => Poll::Ready(Ok(t)), + Err(err) => Poll::Ready(Err(err)), + } + } +} diff --git a/third_party/rust/warp/src/filter/wrap.rs b/third_party/rust/warp/src/filter/wrap.rs new file mode 100644 index 0000000000..faceb705f3 --- /dev/null +++ b/third_party/rust/warp/src/filter/wrap.rs @@ -0,0 +1,67 @@ +use super::Filter; + +pub trait WrapSealed<F: Filter> { + type Wrapped: Filter; + + fn wrap(&self, filter: F) -> Self::Wrapped; +} + +impl<'a, T, F> WrapSealed<F> for &'a T +where + T: WrapSealed<F>, + F: Filter, +{ + type Wrapped = T::Wrapped; + fn wrap(&self, filter: F) -> Self::Wrapped { + (*self).wrap(filter) + } +} + +pub trait Wrap<F: Filter>: WrapSealed<F> {} + +impl<T, F> Wrap<F> for T +where + T: WrapSealed<F>, + F: Filter, +{ +} + +/// Combines received filter with pre and after filters +/// +/// # Example +/// +/// ``` +/// use crate::warp::Filter; +/// +/// let route = warp::any() +/// .map(|| "hello world") +/// .with(warp::wrap_fn(|filter| filter)); +/// ``` +/// +/// You can find the full example in the [usage example](https://github.com/seanmonstar/warp/blob/master/examples/wrapping.rs). +pub fn wrap_fn<F, T, U>(func: F) -> WrapFn<F> +where + F: Fn(T) -> U, + T: Filter, + U: Filter, +{ + WrapFn { func } +} + +#[derive(Debug)] +pub struct WrapFn<F> { + func: F, +} + +impl<F, T, U> WrapSealed<T> for WrapFn<F> +where + F: Fn(T) -> U, + T: Filter, + U: Filter, +{ + type Wrapped = U; + + fn wrap(&self, filter: T) -> Self::Wrapped { + (self.func)(filter) + } +} diff --git a/third_party/rust/warp/src/filters/addr.rs b/third_party/rust/warp/src/filters/addr.rs new file mode 100644 index 0000000000..3d630705a1 --- /dev/null +++ b/third_party/rust/warp/src/filters/addr.rs @@ -0,0 +1,26 @@ +//! Socket Address filters. + +use std::convert::Infallible; +use std::net::SocketAddr; + +use crate::filter::{filter_fn_one, Filter}; + +/// Creates a `Filter` to get the remote address of the connection. +/// +/// If the underlying transport doesn't use socket addresses, this will yield +/// `None`. +/// +/// # Example +/// +/// ``` +/// use std::net::SocketAddr; +/// use warp::Filter; +/// +/// let route = warp::addr::remote() +/// .map(|addr: Option<SocketAddr>| { +/// println!("remote address = {:?}", addr); +/// }); +/// ``` +pub fn remote() -> impl Filter<Extract = (Option<SocketAddr>,), Error = Infallible> + Copy { + filter_fn_one(|route| futures_util::future::ok(route.remote_addr())) +} diff --git a/third_party/rust/warp/src/filters/any.rs b/third_party/rust/warp/src/filters/any.rs new file mode 100644 index 0000000000..707fd82dd4 --- /dev/null +++ b/third_party/rust/warp/src/filters/any.rs @@ -0,0 +1,76 @@ +//! A filter that matches any route. +use std::convert::Infallible; +use std::future::Future; +use std::pin::Pin; +use std::task::{Context, Poll}; + +use crate::filter::{Filter, FilterBase, Internal}; + +/// A filter that matches any route. +/// +/// This can be a useful building block to build new filters from, +/// since [`Filter`](crate::Filter) is otherwise a sealed trait. +/// +/// # Example +/// +/// ``` +/// use warp::Filter; +/// +/// let route = warp::any() +/// .map(|| { +/// "I always return this string!" +/// }); +/// ``` +/// +/// This could allow creating a single `impl Filter` returning a specific +/// reply, that can then be used as the end of several different filter +/// chains. +/// +/// Another use case is turning some clone-able resource into a `Filter`, +/// thus allowing to easily `and` it together with others. +/// +/// ``` +/// use std::sync::Arc; +/// use warp::Filter; +/// +/// let state = Arc::new(vec![33, 41]); +/// let with_state = warp::any().map(move || state.clone()); +/// +/// // Now we could `and` with any other filter: +/// +/// let route = warp::path::param() +/// .and(with_state) +/// .map(|param_id: u32, db: Arc<Vec<u32>>| { +/// db.contains(¶m_id) +/// }); +/// ``` +pub fn any() -> impl Filter<Extract = (), Error = Infallible> + Copy { + Any +} + +#[derive(Copy, Clone)] +#[allow(missing_debug_implementations)] +struct Any; + +impl FilterBase for Any { + type Extract = (); + type Error = Infallible; + type Future = AnyFut; + + #[inline] + fn filter(&self, _: Internal) -> Self::Future { + AnyFut + } +} + +#[allow(missing_debug_implementations)] +struct AnyFut; + +impl Future for AnyFut { + type Output = Result<(), Infallible>; + + #[inline] + fn poll(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Self::Output> { + Poll::Ready(Ok(())) + } +} diff --git a/third_party/rust/warp/src/filters/body.rs b/third_party/rust/warp/src/filters/body.rs new file mode 100644 index 0000000000..3bb08d2b4c --- /dev/null +++ b/third_party/rust/warp/src/filters/body.rs @@ -0,0 +1,345 @@ +//! Body filters +//! +//! Filters that extract a body for a route. + +use std::error::Error as StdError; +use std::fmt; +use std::pin::Pin; +use std::task::{Context, Poll}; + +use bytes::{Buf, Bytes}; +use futures_util::{future, ready, Stream, TryFutureExt}; +use headers::ContentLength; +use http::header::CONTENT_TYPE; +use hyper::Body; +use mime; +use serde::de::DeserializeOwned; +use serde_json; +use serde_urlencoded; + +use crate::filter::{filter_fn, filter_fn_one, Filter, FilterBase}; +use crate::reject::{self, Rejection}; + +type BoxError = Box<dyn StdError + Send + Sync>; + +// Extracts the `Body` Stream from the route. +// +// Does not consume any of it. +pub(crate) fn body() -> impl Filter<Extract = (Body,), Error = Rejection> + Copy { + filter_fn_one(|route| { + future::ready(route.take_body().ok_or_else(|| { + tracing::error!("request body already taken in previous filter"); + reject::known(BodyConsumedMultipleTimes { _p: () }) + })) + }) +} + +/// Require a `content-length` header to have a value no greater than some limit. +/// +/// Rejects if `content-length` header is missing, is invalid, or has a number +/// larger than the limit provided. +/// +/// # Example +/// +/// ``` +/// use warp::Filter; +/// +/// // Limit the upload to 4kb... +/// let upload = warp::body::content_length_limit(4096) +/// .and(warp::body::aggregate()); +/// ``` +pub fn content_length_limit(limit: u64) -> impl Filter<Extract = (), Error = Rejection> + Copy { + crate::filters::header::header2() + .map_err(crate::filter::Internal, |_| { + tracing::debug!("content-length missing"); + reject::length_required() + }) + .and_then(move |ContentLength(length)| { + if length <= limit { + future::ok(()) + } else { + tracing::debug!("content-length: {} is over limit {}", length, limit); + future::err(reject::payload_too_large()) + } + }) + .untuple_one() +} + +/// Create a `Filter` that extracts the request body as a `futures::Stream`. +/// +/// If other filters have already extracted the body, this filter will reject +/// with a `500 Internal Server Error`. +/// +/// # Warning +/// +/// This does not have a default size limit, it would be wise to use one to +/// prevent a overly large request from using too much memory. +pub fn stream( +) -> impl Filter<Extract = (impl Stream<Item = Result<impl Buf, crate::Error>>,), Error = Rejection> + Copy +{ + body().map(|body: Body| BodyStream { body }) +} + +/// Returns a `Filter` that matches any request and extracts a `Future` of a +/// concatenated body. +/// +/// The contents of the body will be flattened into a single contiguous +/// `Bytes`, which may require memory copies. If you don't require a +/// contiguous buffer, using `aggregate` can be give better performance. +/// +/// # Warning +/// +/// This does not have a default size limit, it would be wise to use one to +/// prevent a overly large request from using too much memory. +/// +/// # Example +/// +/// ``` +/// use warp::{Buf, Filter}; +/// +/// let route = warp::body::content_length_limit(1024 * 32) +/// .and(warp::body::bytes()) +/// .map(|bytes: bytes::Bytes| { +/// println!("bytes = {:?}", bytes); +/// }); +/// ``` +pub fn bytes() -> impl Filter<Extract = (Bytes,), Error = Rejection> + Copy { + body().and_then(|body: hyper::Body| { + hyper::body::to_bytes(body).map_err(|err| { + tracing::debug!("to_bytes error: {}", err); + reject::known(BodyReadError(err)) + }) + }) +} + +/// Returns a `Filter` that matches any request and extracts a `Future` of an +/// aggregated body. +/// +/// The `Buf` may contain multiple, non-contiguous buffers. This can be more +/// performant (by reducing copies) when receiving large bodies. +/// +/// # Warning +/// +/// This does not have a default size limit, it would be wise to use one to +/// prevent a overly large request from using too much memory. +/// +/// # Example +/// +/// ``` +/// use warp::{Buf, Filter}; +/// +/// fn full_body(mut body: impl Buf) { +/// // It could have several non-contiguous slices of memory... +/// while body.has_remaining() { +/// println!("slice = {:?}", body.chunk()); +/// let cnt = body.chunk().len(); +/// body.advance(cnt); +/// } +/// } +/// +/// let route = warp::body::content_length_limit(1024 * 32) +/// .and(warp::body::aggregate()) +/// .map(full_body); +/// ``` +pub fn aggregate() -> impl Filter<Extract = (impl Buf,), Error = Rejection> + Copy { + body().and_then(|body: ::hyper::Body| { + hyper::body::aggregate(body).map_err(|err| { + tracing::debug!("aggregate error: {}", err); + reject::known(BodyReadError(err)) + }) + }) +} + +/// Returns a `Filter` that matches any request and extracts a `Future` of a +/// JSON-decoded body. +/// +/// # Warning +/// +/// This does not have a default size limit, it would be wise to use one to +/// prevent a overly large request from using too much memory. +/// +/// # Example +/// +/// ``` +/// use std::collections::HashMap; +/// use warp::Filter; +/// +/// let route = warp::body::content_length_limit(1024 * 32) +/// .and(warp::body::json()) +/// .map(|simple_map: HashMap<String, String>| { +/// "Got a JSON body!" +/// }); +/// ``` +pub fn json<T: DeserializeOwned + Send>() -> impl Filter<Extract = (T,), Error = Rejection> + Copy { + is_content_type::<Json>() + .and(bytes()) + .and_then(|buf| async move { + Json::decode(buf).map_err(|err| { + tracing::debug!("request json body error: {}", err); + reject::known(BodyDeserializeError { cause: err }) + }) + }) +} + +/// Returns a `Filter` that matches any request and extracts a +/// `Future` of a form encoded body. +/// +/// # Note +/// +/// This filter is for the simpler `application/x-www-form-urlencoded` format, +/// not `multipart/form-data`. +/// +/// # Warning +/// +/// This does not have a default size limit, it would be wise to use one to +/// prevent a overly large request from using too much memory. +/// +/// +/// ``` +/// use std::collections::HashMap; +/// use warp::Filter; +/// +/// let route = warp::body::content_length_limit(1024 * 32) +/// .and(warp::body::form()) +/// .map(|simple_map: HashMap<String, String>| { +/// "Got a urlencoded body!" +/// }); +/// ``` +pub fn form<T: DeserializeOwned + Send>() -> impl Filter<Extract = (T,), Error = Rejection> + Copy { + is_content_type::<Form>() + .and(aggregate()) + .and_then(|buf| async move { + Form::decode(buf).map_err(|err| { + tracing::debug!("request form body error: {}", err); + reject::known(BodyDeserializeError { cause: err }) + }) + }) +} + +// ===== Decoders ===== + +trait Decode { + const MIME: (mime::Name<'static>, mime::Name<'static>); + const WITH_NO_CONTENT_TYPE: bool; + + fn decode<B: Buf, T: DeserializeOwned>(buf: B) -> Result<T, BoxError>; +} + +struct Json; + +impl Decode for Json { + const MIME: (mime::Name<'static>, mime::Name<'static>) = (mime::APPLICATION, mime::JSON); + const WITH_NO_CONTENT_TYPE: bool = true; + + fn decode<B: Buf, T: DeserializeOwned>(mut buf: B) -> Result<T, BoxError> { + serde_json::from_slice(&buf.copy_to_bytes(buf.remaining())).map_err(Into::into) + } +} + +struct Form; + +impl Decode for Form { + const MIME: (mime::Name<'static>, mime::Name<'static>) = + (mime::APPLICATION, mime::WWW_FORM_URLENCODED); + const WITH_NO_CONTENT_TYPE: bool = true; + + fn decode<B: Buf, T: DeserializeOwned>(buf: B) -> Result<T, BoxError> { + serde_urlencoded::from_reader(buf.reader()).map_err(Into::into) + } +} + +// Require the `content-type` header to be this type (or, if there's no `content-type` +// header at all, optimistically hope it's the right type). +fn is_content_type<D: Decode>() -> impl Filter<Extract = (), Error = Rejection> + Copy { + filter_fn(move |route| { + let (type_, subtype) = D::MIME; + if let Some(value) = route.headers().get(CONTENT_TYPE) { + tracing::trace!("is_content_type {}/{}? {:?}", type_, subtype, value); + let ct = value + .to_str() + .ok() + .and_then(|s| s.parse::<mime::Mime>().ok()); + if let Some(ct) = ct { + if ct.type_() == type_ && ct.subtype() == subtype { + future::ok(()) + } else { + tracing::debug!( + "content-type {:?} doesn't match {}/{}", + value, + type_, + subtype + ); + future::err(reject::unsupported_media_type()) + } + } else { + tracing::debug!("content-type {:?} couldn't be parsed", value); + future::err(reject::unsupported_media_type()) + } + } else if D::WITH_NO_CONTENT_TYPE { + // Optimistically assume its correct! + tracing::trace!("no content-type header, assuming {}/{}", type_, subtype); + future::ok(()) + } else { + tracing::debug!("no content-type found"); + future::err(reject::unsupported_media_type()) + } + }) +} + +// ===== BodyStream ===== + +struct BodyStream { + body: Body, +} + +impl Stream for BodyStream { + type Item = Result<Bytes, crate::Error>; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> { + let opt_item = ready!(Pin::new(&mut self.get_mut().body).poll_next(cx)); + + match opt_item { + None => Poll::Ready(None), + Some(item) => { + let stream_buf = item.map_err(crate::Error::new); + + Poll::Ready(Some(stream_buf)) + } + } + } +} + +// ===== Rejections ===== + +/// An error used in rejections when deserializing a request body fails. +#[derive(Debug)] +pub struct BodyDeserializeError { + cause: BoxError, +} + +impl fmt::Display for BodyDeserializeError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "Request body deserialize error: {}", self.cause) + } +} + +impl StdError for BodyDeserializeError { + fn source(&self) -> Option<&(dyn StdError + 'static)> { + Some(self.cause.as_ref()) + } +} + +#[derive(Debug)] +pub(crate) struct BodyReadError(::hyper::Error); + +impl fmt::Display for BodyReadError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "Request body read error: {}", self.0) + } +} + +impl StdError for BodyReadError {} + +unit_error! { + pub(crate) BodyConsumedMultipleTimes: "Request body consumed multiple times" +} diff --git a/third_party/rust/warp/src/filters/compression.rs b/third_party/rust/warp/src/filters/compression.rs new file mode 100644 index 0000000000..244e768356 --- /dev/null +++ b/third_party/rust/warp/src/filters/compression.rs @@ -0,0 +1,292 @@ +//! Compression Filters +//! +//! Filters that compress the body of a response. + +#[cfg(feature = "compression-brotli")] +use async_compression::tokio::bufread::BrotliEncoder; + +#[cfg(feature = "compression-gzip")] +use async_compression::tokio::bufread::{DeflateEncoder, GzipEncoder}; + +use http::header::HeaderValue; +use hyper::{ + header::{CONTENT_ENCODING, CONTENT_LENGTH}, + Body, +}; +use tokio_util::io::{ReaderStream, StreamReader}; + +use crate::filter::{Filter, WrapSealed}; +use crate::reject::IsReject; +use crate::reply::{Reply, Response}; + +use self::internal::{CompressionProps, WithCompression}; + +enum CompressionAlgo { + #[cfg(feature = "compression-brotli")] + BR, + #[cfg(feature = "compression-gzip")] + DEFLATE, + #[cfg(feature = "compression-gzip")] + GZIP, +} + +impl From<CompressionAlgo> for HeaderValue { + #[inline] + fn from(algo: CompressionAlgo) -> Self { + HeaderValue::from_static(match algo { + #[cfg(feature = "compression-brotli")] + CompressionAlgo::BR => "br", + #[cfg(feature = "compression-gzip")] + CompressionAlgo::DEFLATE => "deflate", + #[cfg(feature = "compression-gzip")] + CompressionAlgo::GZIP => "gzip", + }) + } +} + +/// Compression +#[derive(Clone, Copy, Debug)] +pub struct Compression<F> { + func: F, +} + +// TODO: The implementation of `gzip()`, `deflate()`, and `brotli()` could be replaced with +// generics or a macro + +/// Create a wrapping filter that compresses the Body of a [`Response`](crate::reply::Response) +/// using gzip, adding `content-encoding: gzip` to the Response's [`HeaderMap`](hyper::HeaderMap) +/// +/// # Example +/// +/// ``` +/// use warp::Filter; +/// +/// let route = warp::get() +/// .and(warp::path::end()) +/// .and(warp::fs::file("./README.md")) +/// .with(warp::compression::gzip()); +/// ``` +#[cfg(feature = "compression-gzip")] +pub fn gzip() -> Compression<impl Fn(CompressionProps) -> Response + Copy> { + let func = move |mut props: CompressionProps| { + let body = Body::wrap_stream(ReaderStream::new(GzipEncoder::new(StreamReader::new( + props.body, + )))); + props + .head + .headers + .append(CONTENT_ENCODING, CompressionAlgo::GZIP.into()); + props.head.headers.remove(CONTENT_LENGTH); + Response::from_parts(props.head, body) + }; + Compression { func } +} + +/// Create a wrapping filter that compresses the Body of a [`Response`](crate::reply::Response) +/// using deflate, adding `content-encoding: deflate` to the Response's [`HeaderMap`](hyper::HeaderMap) +/// +/// # Example +/// +/// ``` +/// use warp::Filter; +/// +/// let route = warp::get() +/// .and(warp::path::end()) +/// .and(warp::fs::file("./README.md")) +/// .with(warp::compression::deflate()); +/// ``` +#[cfg(feature = "compression-gzip")] +pub fn deflate() -> Compression<impl Fn(CompressionProps) -> Response + Copy> { + let func = move |mut props: CompressionProps| { + let body = Body::wrap_stream(ReaderStream::new(DeflateEncoder::new(StreamReader::new( + props.body, + )))); + props + .head + .headers + .append(CONTENT_ENCODING, CompressionAlgo::DEFLATE.into()); + props.head.headers.remove(CONTENT_LENGTH); + Response::from_parts(props.head, body) + }; + Compression { func } +} + +/// Create a wrapping filter that compresses the Body of a [`Response`](crate::reply::Response) +/// using brotli, adding `content-encoding: br` to the Response's [`HeaderMap`](hyper::HeaderMap) +/// +/// # Example +/// +/// ``` +/// use warp::Filter; +/// +/// let route = warp::get() +/// .and(warp::path::end()) +/// .and(warp::fs::file("./README.md")) +/// .with(warp::compression::brotli()); +/// ``` +#[cfg(feature = "compression-brotli")] +pub fn brotli() -> Compression<impl Fn(CompressionProps) -> Response + Copy> { + let func = move |mut props: CompressionProps| { + let body = Body::wrap_stream(ReaderStream::new(BrotliEncoder::new(StreamReader::new( + props.body, + )))); + props + .head + .headers + .append(CONTENT_ENCODING, CompressionAlgo::BR.into()); + props.head.headers.remove(CONTENT_LENGTH); + Response::from_parts(props.head, body) + }; + Compression { func } +} + +impl<FN, F> WrapSealed<F> for Compression<FN> +where + FN: Fn(CompressionProps) -> Response + Clone + Send, + F: Filter + Clone + Send, + F::Extract: Reply, + F::Error: IsReject, +{ + type Wrapped = WithCompression<FN, F>; + + fn wrap(&self, filter: F) -> Self::Wrapped { + WithCompression { + filter, + compress: self.clone(), + } + } +} + +mod internal { + use std::future::Future; + use std::pin::Pin; + use std::task::{Context, Poll}; + + use bytes::Bytes; + use futures_util::{ready, Stream, TryFuture}; + use hyper::Body; + use pin_project::pin_project; + + use crate::filter::{Filter, FilterBase, Internal}; + use crate::reject::IsReject; + use crate::reply::{Reply, Response}; + + use super::Compression; + + /// A wrapper around any type that implements [`Stream`](futures::Stream) to be + /// compatible with async_compression's Stream based encoders + #[pin_project] + #[derive(Debug)] + pub struct CompressableBody<S, E> + where + E: std::error::Error, + S: Stream<Item = Result<Bytes, E>>, + { + #[pin] + body: S, + } + + impl<S, E> Stream for CompressableBody<S, E> + where + E: std::error::Error, + S: Stream<Item = Result<Bytes, E>>, + { + type Item = std::io::Result<Bytes>; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> { + use std::io::{Error, ErrorKind}; + + let pin = self.project(); + S::poll_next(pin.body, cx).map_err(|_| Error::from(ErrorKind::InvalidData)) + } + } + + impl From<Body> for CompressableBody<Body, hyper::Error> { + fn from(body: Body) -> Self { + CompressableBody { body } + } + } + + /// Compression Props + #[derive(Debug)] + pub struct CompressionProps { + pub(super) body: CompressableBody<Body, hyper::Error>, + pub(super) head: http::response::Parts, + } + + impl From<http::Response<Body>> for CompressionProps { + fn from(resp: http::Response<Body>) -> Self { + let (head, body) = resp.into_parts(); + CompressionProps { + body: body.into(), + head, + } + } + } + + #[allow(missing_debug_implementations)] + pub struct Compressed(pub(super) Response); + + impl Reply for Compressed { + #[inline] + fn into_response(self) -> Response { + self.0 + } + } + + #[allow(missing_debug_implementations)] + #[derive(Clone, Copy)] + pub struct WithCompression<FN, F> { + pub(super) compress: Compression<FN>, + pub(super) filter: F, + } + + impl<FN, F> FilterBase for WithCompression<FN, F> + where + FN: Fn(CompressionProps) -> Response + Clone + Send, + F: Filter + Clone + Send, + F::Extract: Reply, + F::Error: IsReject, + { + type Extract = (Compressed,); + type Error = F::Error; + type Future = WithCompressionFuture<FN, F::Future>; + + fn filter(&self, _: Internal) -> Self::Future { + WithCompressionFuture { + compress: self.compress.clone(), + future: self.filter.filter(Internal), + } + } + } + + #[allow(missing_debug_implementations)] + #[pin_project] + pub struct WithCompressionFuture<FN, F> { + compress: Compression<FN>, + #[pin] + future: F, + } + + impl<FN, F> Future for WithCompressionFuture<FN, F> + where + FN: Fn(CompressionProps) -> Response, + F: TryFuture, + F::Ok: Reply, + F::Error: IsReject, + { + type Output = Result<(Compressed,), F::Error>; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { + let pin = self.as_mut().project(); + let result = ready!(pin.future.try_poll(cx)); + match result { + Ok(reply) => { + let resp = (self.compress.func)(reply.into_response().into()); + Poll::Ready(Ok((Compressed(resp),))) + } + Err(reject) => Poll::Ready(Err(reject)), + } + } + } +} diff --git a/third_party/rust/warp/src/filters/cookie.rs b/third_party/rust/warp/src/filters/cookie.rs new file mode 100644 index 0000000000..53695e85c8 --- /dev/null +++ b/third_party/rust/warp/src/filters/cookie.rs @@ -0,0 +1,46 @@ +//! Cookie Filters + +use futures_util::future; +use headers::Cookie; + +use super::header; +use crate::filter::{Filter, One}; +use crate::reject::Rejection; +use std::convert::Infallible; +use std::str::FromStr; + +/// Creates a `Filter` that requires a cookie by name. +/// +/// If found, extracts the value of the cookie, otherwise rejects. +pub fn cookie<T>(name: &'static str) -> impl Filter<Extract = One<T>, Error = Rejection> + Copy +where + T: FromStr + Send + 'static, +{ + header::header2().and_then(move |cookie: Cookie| { + let cookie = cookie + .get(name) + .ok_or_else(|| crate::reject::missing_cookie(name)) + .and_then(|s| T::from_str(s).map_err(|_| crate::reject::missing_cookie(name))); + future::ready(cookie) + }) +} + +/// Creates a `Filter` that looks for an optional cookie by name. +/// +/// If found, extracts the value of the cookie, otherwise continues +/// the request, extracting `None`. +pub fn optional<T>( + name: &'static str, +) -> impl Filter<Extract = One<Option<T>>, Error = Infallible> + Copy +where + T: FromStr + Send + 'static, +{ + header::optional2().map(move |opt: Option<Cookie>| { + let cookie = opt.and_then(|cookie| cookie.get(name).map(|x| T::from_str(x))); + match cookie { + Some(Ok(t)) => Some(t), + Some(Err(_)) => None, + None => None, + } + }) +} diff --git a/third_party/rust/warp/src/filters/cors.rs b/third_party/rust/warp/src/filters/cors.rs new file mode 100644 index 0000000000..fa28893c20 --- /dev/null +++ b/third_party/rust/warp/src/filters/cors.rs @@ -0,0 +1,626 @@ +//! CORS Filters + +use std::collections::HashSet; +use std::convert::TryFrom; +use std::error::Error as StdError; +use std::fmt; +use std::sync::Arc; + +use headers::{ + AccessControlAllowHeaders, AccessControlAllowMethods, AccessControlExposeHeaders, HeaderMapExt, +}; +use http::{ + self, + header::{self, HeaderName, HeaderValue}, +}; + +use crate::filter::{Filter, WrapSealed}; +use crate::reject::{CombineRejection, Rejection}; +use crate::reply::Reply; + +use self::internal::{CorsFilter, IntoOrigin, Seconds}; + +/// Create a wrapping filter that exposes [CORS][] behavior for a wrapped +/// filter. +/// +/// [CORS]: https://developer.mozilla.org/en-US/docs/Web/HTTP/CORS +/// +/// # Example +/// +/// ``` +/// use warp::Filter; +/// +/// let cors = warp::cors() +/// .allow_origin("https://hyper.rs") +/// .allow_methods(vec!["GET", "POST", "DELETE"]); +/// +/// let route = warp::any() +/// .map(warp::reply) +/// .with(cors); +/// ``` +/// If you want to allow any route: +/// ``` +/// use warp::Filter; +/// let cors = warp::cors() +/// .allow_any_origin(); +/// ``` +/// You can find more usage examples [here](https://github.com/seanmonstar/warp/blob/7fa54eaecd0fe12687137372791ff22fc7995766/tests/cors.rs). +pub fn cors() -> Builder { + Builder { + credentials: false, + allowed_headers: HashSet::new(), + exposed_headers: HashSet::new(), + max_age: None, + methods: HashSet::new(), + origins: None, + } +} + +/// A wrapping filter constructed via `warp::cors()`. +#[derive(Clone, Debug)] +pub struct Cors { + config: Arc<Configured>, +} + +/// A constructed via `warp::cors()`. +#[derive(Clone, Debug)] +pub struct Builder { + credentials: bool, + allowed_headers: HashSet<HeaderName>, + exposed_headers: HashSet<HeaderName>, + max_age: Option<u64>, + methods: HashSet<http::Method>, + origins: Option<HashSet<HeaderValue>>, +} + +impl Builder { + /// Sets whether to add the `Access-Control-Allow-Credentials` header. + pub fn allow_credentials(mut self, allow: bool) -> Self { + self.credentials = allow; + self + } + + /// Adds a method to the existing list of allowed request methods. + /// + /// # Panics + /// + /// Panics if the provided argument is not a valid `http::Method`. + pub fn allow_method<M>(mut self, method: M) -> Self + where + http::Method: TryFrom<M>, + { + let method = match TryFrom::try_from(method) { + Ok(m) => m, + Err(_) => panic!("illegal Method"), + }; + self.methods.insert(method); + self + } + + /// Adds multiple methods to the existing list of allowed request methods. + /// + /// # Panics + /// + /// Panics if the provided argument is not a valid `http::Method`. + pub fn allow_methods<I>(mut self, methods: I) -> Self + where + I: IntoIterator, + http::Method: TryFrom<I::Item>, + { + let iter = methods.into_iter().map(|m| match TryFrom::try_from(m) { + Ok(m) => m, + Err(_) => panic!("illegal Method"), + }); + self.methods.extend(iter); + self + } + + /// Adds a header to the list of allowed request headers. + /// + /// **Note**: These should match the values the browser sends via `Access-Control-Request-Headers`, e.g. `content-type`. + /// + /// # Panics + /// + /// Panics if the provided argument is not a valid `http::header::HeaderName`. + pub fn allow_header<H>(mut self, header: H) -> Self + where + HeaderName: TryFrom<H>, + { + let header = match TryFrom::try_from(header) { + Ok(m) => m, + Err(_) => panic!("illegal Header"), + }; + self.allowed_headers.insert(header); + self + } + + /// Adds multiple headers to the list of allowed request headers. + /// + /// **Note**: These should match the values the browser sends via `Access-Control-Request-Headers`, e.g.`content-type`. + /// + /// # Panics + /// + /// Panics if any of the headers are not a valid `http::header::HeaderName`. + pub fn allow_headers<I>(mut self, headers: I) -> Self + where + I: IntoIterator, + HeaderName: TryFrom<I::Item>, + { + let iter = headers.into_iter().map(|h| match TryFrom::try_from(h) { + Ok(h) => h, + Err(_) => panic!("illegal Header"), + }); + self.allowed_headers.extend(iter); + self + } + + /// Adds a header to the list of exposed headers. + /// + /// # Panics + /// + /// Panics if the provided argument is not a valid `http::header::HeaderName`. + pub fn expose_header<H>(mut self, header: H) -> Self + where + HeaderName: TryFrom<H>, + { + let header = match TryFrom::try_from(header) { + Ok(m) => m, + Err(_) => panic!("illegal Header"), + }; + self.exposed_headers.insert(header); + self + } + + /// Adds multiple headers to the list of exposed headers. + /// + /// # Panics + /// + /// Panics if any of the headers are not a valid `http::header::HeaderName`. + pub fn expose_headers<I>(mut self, headers: I) -> Self + where + I: IntoIterator, + HeaderName: TryFrom<I::Item>, + { + let iter = headers.into_iter().map(|h| match TryFrom::try_from(h) { + Ok(h) => h, + Err(_) => panic!("illegal Header"), + }); + self.exposed_headers.extend(iter); + self + } + + /// Sets that *any* `Origin` header is allowed. + /// + /// # Warning + /// + /// This can allow websites you didn't intend to access this resource, + /// it is usually better to set an explicit list. + pub fn allow_any_origin(mut self) -> Self { + self.origins = None; + self + } + + /// Add an origin to the existing list of allowed `Origin`s. + /// + /// # Panics + /// + /// Panics if the provided argument is not a valid `Origin`. + pub fn allow_origin(self, origin: impl IntoOrigin) -> Self { + self.allow_origins(Some(origin)) + } + + /// Add multiple origins to the existing list of allowed `Origin`s. + /// + /// # Panics + /// + /// Panics if the provided argument is not a valid `Origin`. + pub fn allow_origins<I>(mut self, origins: I) -> Self + where + I: IntoIterator, + I::Item: IntoOrigin, + { + let iter = origins + .into_iter() + .map(IntoOrigin::into_origin) + .map(|origin| { + origin + .to_string() + .parse() + .expect("Origin is always a valid HeaderValue") + }); + + self.origins.get_or_insert_with(HashSet::new).extend(iter); + + self + } + + /// Sets the `Access-Control-Max-Age` header. + /// + /// # Example + /// + /// + /// ``` + /// use std::time::Duration; + /// use warp::Filter; + /// + /// let cors = warp::cors() + /// .max_age(30) // 30u32 seconds + /// .max_age(Duration::from_secs(30)); // or a Duration + /// ``` + pub fn max_age(mut self, seconds: impl Seconds) -> Self { + self.max_age = Some(seconds.seconds()); + self + } + + /// Builds the `Cors` wrapper from the configured settings. + /// + /// This step isn't *required*, as the `Builder` itself can be passed + /// to `Filter::with`. This just allows constructing once, thus not needing + /// to pay the cost of "building" every time. + pub fn build(self) -> Cors { + let expose_headers_header = if self.exposed_headers.is_empty() { + None + } else { + Some(self.exposed_headers.iter().cloned().collect()) + }; + let allowed_headers_header = self.allowed_headers.iter().cloned().collect(); + let methods_header = self.methods.iter().cloned().collect(); + + let config = Arc::new(Configured { + cors: self, + allowed_headers_header, + expose_headers_header, + methods_header, + }); + + Cors { config } + } +} + +impl<F> WrapSealed<F> for Builder +where + F: Filter + Clone + Send + Sync + 'static, + F::Extract: Reply, + F::Error: CombineRejection<Rejection>, + <F::Error as CombineRejection<Rejection>>::One: CombineRejection<Rejection>, +{ + type Wrapped = CorsFilter<F>; + + fn wrap(&self, inner: F) -> Self::Wrapped { + let Cors { config } = self.clone().build(); + + CorsFilter { config, inner } + } +} + +impl<F> WrapSealed<F> for Cors +where + F: Filter + Clone + Send + Sync + 'static, + F::Extract: Reply, + F::Error: CombineRejection<Rejection>, + <F::Error as CombineRejection<Rejection>>::One: CombineRejection<Rejection>, +{ + type Wrapped = CorsFilter<F>; + + fn wrap(&self, inner: F) -> Self::Wrapped { + let config = self.config.clone(); + + CorsFilter { config, inner } + } +} + +/// An error used to reject requests that are forbidden by a `cors` filter. +pub struct CorsForbidden { + kind: Forbidden, +} + +#[derive(Debug)] +enum Forbidden { + OriginNotAllowed, + MethodNotAllowed, + HeaderNotAllowed, +} + +impl fmt::Debug for CorsForbidden { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_tuple("CorsForbidden").field(&self.kind).finish() + } +} + +impl fmt::Display for CorsForbidden { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let detail = match self.kind { + Forbidden::OriginNotAllowed => "origin not allowed", + Forbidden::MethodNotAllowed => "request-method not allowed", + Forbidden::HeaderNotAllowed => "header not allowed", + }; + write!(f, "CORS request forbidden: {}", detail) + } +} + +impl StdError for CorsForbidden {} + +#[derive(Clone, Debug)] +struct Configured { + cors: Builder, + allowed_headers_header: AccessControlAllowHeaders, + expose_headers_header: Option<AccessControlExposeHeaders>, + methods_header: AccessControlAllowMethods, +} + +enum Validated { + Preflight(HeaderValue), + Simple(HeaderValue), + NotCors, +} + +impl Configured { + fn check_request( + &self, + method: &http::Method, + headers: &http::HeaderMap, + ) -> Result<Validated, Forbidden> { + match (headers.get(header::ORIGIN), method) { + (Some(origin), &http::Method::OPTIONS) => { + // OPTIONS requests are preflight CORS requests... + + if !self.is_origin_allowed(origin) { + return Err(Forbidden::OriginNotAllowed); + } + + if let Some(req_method) = headers.get(header::ACCESS_CONTROL_REQUEST_METHOD) { + if !self.is_method_allowed(req_method) { + return Err(Forbidden::MethodNotAllowed); + } + } else { + tracing::trace!( + "preflight request missing access-control-request-method header" + ); + return Err(Forbidden::MethodNotAllowed); + } + + if let Some(req_headers) = headers.get(header::ACCESS_CONTROL_REQUEST_HEADERS) { + let headers = req_headers + .to_str() + .map_err(|_| Forbidden::HeaderNotAllowed)?; + for header in headers.split(',') { + if !self.is_header_allowed(header.trim()) { + return Err(Forbidden::HeaderNotAllowed); + } + } + } + + Ok(Validated::Preflight(origin.clone())) + } + (Some(origin), _) => { + // Any other method, simply check for a valid origin... + + tracing::trace!("origin header: {:?}", origin); + if self.is_origin_allowed(origin) { + Ok(Validated::Simple(origin.clone())) + } else { + Err(Forbidden::OriginNotAllowed) + } + } + (None, _) => { + // No `ORIGIN` header means this isn't CORS! + Ok(Validated::NotCors) + } + } + } + + fn is_method_allowed(&self, header: &HeaderValue) -> bool { + http::Method::from_bytes(header.as_bytes()) + .map(|method| self.cors.methods.contains(&method)) + .unwrap_or(false) + } + + fn is_header_allowed(&self, header: &str) -> bool { + HeaderName::from_bytes(header.as_bytes()) + .map(|header| self.cors.allowed_headers.contains(&header)) + .unwrap_or(false) + } + + fn is_origin_allowed(&self, origin: &HeaderValue) -> bool { + if let Some(ref allowed) = self.cors.origins { + allowed.contains(origin) + } else { + true + } + } + + fn append_preflight_headers(&self, headers: &mut http::HeaderMap) { + self.append_common_headers(headers); + + headers.typed_insert(self.allowed_headers_header.clone()); + headers.typed_insert(self.methods_header.clone()); + + if let Some(max_age) = self.cors.max_age { + headers.insert(header::ACCESS_CONTROL_MAX_AGE, max_age.into()); + } + } + + fn append_common_headers(&self, headers: &mut http::HeaderMap) { + if self.cors.credentials { + headers.insert( + header::ACCESS_CONTROL_ALLOW_CREDENTIALS, + HeaderValue::from_static("true"), + ); + } + if let Some(expose_headers_header) = &self.expose_headers_header { + headers.typed_insert(expose_headers_header.clone()) + } + } +} + +mod internal { + use std::future::Future; + use std::pin::Pin; + use std::sync::Arc; + use std::task::{Context, Poll}; + + use futures_util::{future, ready, TryFuture}; + use headers::Origin; + use http::header; + use pin_project::pin_project; + + use super::{Configured, CorsForbidden, Validated}; + use crate::filter::{Filter, FilterBase, Internal, One}; + use crate::generic::Either; + use crate::reject::{CombineRejection, Rejection}; + use crate::route; + + #[derive(Clone, Debug)] + pub struct CorsFilter<F> { + pub(super) config: Arc<Configured>, + pub(super) inner: F, + } + + impl<F> FilterBase for CorsFilter<F> + where + F: Filter, + F::Extract: Send, + F::Future: Future, + F::Error: CombineRejection<Rejection>, + { + type Extract = + One<Either<One<Preflight>, One<Either<One<Wrapped<F::Extract>>, F::Extract>>>>; + type Error = <F::Error as CombineRejection<Rejection>>::One; + type Future = future::Either< + future::Ready<Result<Self::Extract, Self::Error>>, + WrappedFuture<F::Future>, + >; + + fn filter(&self, _: Internal) -> Self::Future { + let validated = + route::with(|route| self.config.check_request(route.method(), route.headers())); + + match validated { + Ok(Validated::Preflight(origin)) => { + let preflight = Preflight { + config: self.config.clone(), + origin, + }; + future::Either::Left(future::ok((Either::A((preflight,)),))) + } + Ok(Validated::Simple(origin)) => future::Either::Right(WrappedFuture { + inner: self.inner.filter(Internal), + wrapped: Some((self.config.clone(), origin)), + }), + Ok(Validated::NotCors) => future::Either::Right(WrappedFuture { + inner: self.inner.filter(Internal), + wrapped: None, + }), + Err(err) => { + let rejection = crate::reject::known(CorsForbidden { kind: err }); + future::Either::Left(future::err(rejection.into())) + } + } + } + } + + #[derive(Debug)] + pub struct Preflight { + config: Arc<Configured>, + origin: header::HeaderValue, + } + + impl crate::reply::Reply for Preflight { + fn into_response(self) -> crate::reply::Response { + let mut res = crate::reply::Response::default(); + self.config.append_preflight_headers(res.headers_mut()); + res.headers_mut() + .insert(header::ACCESS_CONTROL_ALLOW_ORIGIN, self.origin); + res + } + } + + #[derive(Debug)] + pub struct Wrapped<R> { + config: Arc<Configured>, + inner: R, + origin: header::HeaderValue, + } + + impl<R> crate::reply::Reply for Wrapped<R> + where + R: crate::reply::Reply, + { + fn into_response(self) -> crate::reply::Response { + let mut res = self.inner.into_response(); + self.config.append_common_headers(res.headers_mut()); + res.headers_mut() + .insert(header::ACCESS_CONTROL_ALLOW_ORIGIN, self.origin); + res + } + } + + #[pin_project] + #[derive(Debug)] + pub struct WrappedFuture<F> { + #[pin] + inner: F, + wrapped: Option<(Arc<Configured>, header::HeaderValue)>, + } + + impl<F> Future for WrappedFuture<F> + where + F: TryFuture, + F::Error: CombineRejection<Rejection>, + { + type Output = Result< + One<Either<One<Preflight>, One<Either<One<Wrapped<F::Ok>>, F::Ok>>>>, + <F::Error as CombineRejection<Rejection>>::One, + >; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { + let pin = self.project(); + match ready!(pin.inner.try_poll(cx)) { + Ok(inner) => { + let item = if let Some((config, origin)) = pin.wrapped.take() { + (Either::A((Wrapped { + config, + inner, + origin, + },)),) + } else { + (Either::B(inner),) + }; + let item = (Either::B(item),); + Poll::Ready(Ok(item)) + } + Err(err) => Poll::Ready(Err(err.into())), + } + } + } + + pub trait Seconds { + fn seconds(self) -> u64; + } + + impl Seconds for u32 { + fn seconds(self) -> u64 { + self.into() + } + } + + impl Seconds for ::std::time::Duration { + fn seconds(self) -> u64 { + self.as_secs() + } + } + + pub trait IntoOrigin { + fn into_origin(self) -> Origin; + } + + impl<'a> IntoOrigin for &'a str { + fn into_origin(self) -> Origin { + let mut parts = self.splitn(2, "://"); + let scheme = parts.next().expect("missing scheme"); + let rest = parts.next().expect("missing scheme"); + + Origin::try_from_parts(scheme, rest, None).expect("invalid Origin") + } + } +} diff --git a/third_party/rust/warp/src/filters/ext.rs b/third_party/rust/warp/src/filters/ext.rs new file mode 100644 index 0000000000..985bbfb61c --- /dev/null +++ b/third_party/rust/warp/src/filters/ext.rs @@ -0,0 +1,36 @@ +//! Request Extensions + +use std::convert::Infallible; + +use futures_util::future; + +use crate::filter::{filter_fn_one, Filter}; +use crate::reject::{self, Rejection}; + +/// Get a previously set extension of the current route. +/// +/// If the extension doesn't exist, this rejects with a `MissingExtension`. +pub fn get<T: Clone + Send + Sync + 'static>( +) -> impl Filter<Extract = (T,), Error = Rejection> + Copy { + filter_fn_one(|route| { + let route = route + .extensions() + .get::<T>() + .cloned() + .ok_or_else(|| reject::known(MissingExtension { _p: () })); + future::ready(route) + }) +} + +/// Get a previously set extension of the current route. +/// +/// If the extension doesn't exist, it yields `None`. +pub fn optional<T: Clone + Send + Sync + 'static>( +) -> impl Filter<Extract = (Option<T>,), Error = Infallible> + Copy { + filter_fn_one(|route| future::ok(route.extensions().get::<T>().cloned())) +} + +unit_error! { + /// An error used to reject if `get` cannot find the extension. + pub MissingExtension: "Missing request extension" +} diff --git a/third_party/rust/warp/src/filters/fs.rs b/third_party/rust/warp/src/filters/fs.rs new file mode 100644 index 0000000000..0949b66ecb --- /dev/null +++ b/third_party/rust/warp/src/filters/fs.rs @@ -0,0 +1,539 @@ +//! File System Filters + +use std::cmp; +use std::convert::Infallible; +use std::fs::Metadata; +use std::future::Future; +use std::io; +use std::path::{Path, PathBuf}; +use std::pin::Pin; +use std::sync::Arc; +use std::task::Poll; + +use bytes::{Bytes, BytesMut}; +use futures_util::future::Either; +use futures_util::{future, ready, stream, FutureExt, Stream, StreamExt, TryFutureExt}; +use headers::{ + AcceptRanges, ContentLength, ContentRange, ContentType, HeaderMapExt, IfModifiedSince, IfRange, + IfUnmodifiedSince, LastModified, Range, +}; +use http::StatusCode; +use hyper::Body; +use mime_guess; +use percent_encoding::percent_decode_str; +use tokio::fs::File as TkFile; +use tokio::io::AsyncSeekExt; +use tokio_util::io::poll_read_buf; + +use crate::filter::{Filter, FilterClone, One}; +use crate::reject::{self, Rejection}; +use crate::reply::{Reply, Response}; + +/// Creates a `Filter` that serves a File at the `path`. +/// +/// Does not filter out based on any information of the request. Always serves +/// the file at the exact `path` provided. Thus, this can be used to serve a +/// single file with `GET`s, but could also be used in combination with other +/// filters, such as after validating in `POST` request, wanting to return a +/// specific file as the body. +/// +/// For serving a directory, see [dir](dir). +/// +/// # Example +/// +/// ``` +/// // Always serves this file from the file system. +/// let route = warp::fs::file("/www/static/app.js"); +/// ``` +pub fn file(path: impl Into<PathBuf>) -> impl FilterClone<Extract = One<File>, Error = Rejection> { + let path = Arc::new(path.into()); + crate::any() + .map(move || { + tracing::trace!("file: {:?}", path); + ArcPath(path.clone()) + }) + .and(conditionals()) + .and_then(file_reply) +} + +/// Creates a `Filter` that serves a directory at the base `path` joined +/// by the request path. +/// +/// This can be used to serve "static files" from a directory. By far the most +/// common pattern of serving static files is for `GET` requests, so this +/// filter automatically includes a `GET` check. +/// +/// # Example +/// +/// ``` +/// use warp::Filter; +/// +/// // Matches requests that start with `/static`, +/// // and then uses the rest of that path to lookup +/// // and serve a file from `/www/static`. +/// let route = warp::path("static") +/// .and(warp::fs::dir("/www/static")); +/// +/// // For example: +/// // - `GET /static/app.js` would serve the file `/www/static/app.js` +/// // - `GET /static/css/app.css` would serve the file `/www/static/css/app.css` +/// ``` +pub fn dir(path: impl Into<PathBuf>) -> impl FilterClone<Extract = One<File>, Error = Rejection> { + let base = Arc::new(path.into()); + crate::get() + .or(crate::head()) + .unify() + .and(path_from_tail(base)) + .and(conditionals()) + .and_then(file_reply) +} + +fn path_from_tail( + base: Arc<PathBuf>, +) -> impl FilterClone<Extract = One<ArcPath>, Error = Rejection> { + crate::path::tail().and_then(move |tail: crate::path::Tail| { + future::ready(sanitize_path(base.as_ref(), tail.as_str())).and_then(|mut buf| async { + let is_dir = tokio::fs::metadata(buf.clone()) + .await + .map(|m| m.is_dir()) + .unwrap_or(false); + + if is_dir { + tracing::debug!("dir: appending index.html to directory path"); + buf.push("index.html"); + } + tracing::trace!("dir: {:?}", buf); + Ok(ArcPath(Arc::new(buf))) + }) + }) +} + +fn sanitize_path(base: impl AsRef<Path>, tail: &str) -> Result<PathBuf, Rejection> { + let mut buf = PathBuf::from(base.as_ref()); + let p = match percent_decode_str(tail).decode_utf8() { + Ok(p) => p, + Err(err) => { + tracing::debug!("dir: failed to decode route={:?}: {:?}", tail, err); + return Err(reject::not_found()); + } + }; + tracing::trace!("dir? base={:?}, route={:?}", base.as_ref(), p); + for seg in p.split('/') { + if seg.starts_with("..") { + tracing::warn!("dir: rejecting segment starting with '..'"); + return Err(reject::not_found()); + } else if seg.contains('\\') { + tracing::warn!("dir: rejecting segment containing backslash (\\)"); + return Err(reject::not_found()); + } else if cfg!(windows) && seg.contains(':') { + tracing::warn!("dir: rejecting segment containing colon (:)"); + return Err(reject::not_found()); + } else { + buf.push(seg); + } + } + Ok(buf) +} + +#[derive(Debug)] +struct Conditionals { + if_modified_since: Option<IfModifiedSince>, + if_unmodified_since: Option<IfUnmodifiedSince>, + if_range: Option<IfRange>, + range: Option<Range>, +} + +enum Cond { + NoBody(Response), + WithBody(Option<Range>), +} + +impl Conditionals { + fn check(self, last_modified: Option<LastModified>) -> Cond { + if let Some(since) = self.if_unmodified_since { + let precondition = last_modified + .map(|time| since.precondition_passes(time.into())) + .unwrap_or(false); + + tracing::trace!( + "if-unmodified-since? {:?} vs {:?} = {}", + since, + last_modified, + precondition + ); + if !precondition { + let mut res = Response::new(Body::empty()); + *res.status_mut() = StatusCode::PRECONDITION_FAILED; + return Cond::NoBody(res); + } + } + + if let Some(since) = self.if_modified_since { + tracing::trace!( + "if-modified-since? header = {:?}, file = {:?}", + since, + last_modified + ); + let unmodified = last_modified + .map(|time| !since.is_modified(time.into())) + // no last_modified means its always modified + .unwrap_or(false); + if unmodified { + let mut res = Response::new(Body::empty()); + *res.status_mut() = StatusCode::NOT_MODIFIED; + return Cond::NoBody(res); + } + } + + if let Some(if_range) = self.if_range { + tracing::trace!("if-range? {:?} vs {:?}", if_range, last_modified); + let can_range = !if_range.is_modified(None, last_modified.as_ref()); + + if !can_range { + return Cond::WithBody(None); + } + } + + Cond::WithBody(self.range) + } +} + +fn conditionals() -> impl Filter<Extract = One<Conditionals>, Error = Infallible> + Copy { + crate::header::optional2() + .and(crate::header::optional2()) + .and(crate::header::optional2()) + .and(crate::header::optional2()) + .map( + |if_modified_since, if_unmodified_since, if_range, range| Conditionals { + if_modified_since, + if_unmodified_since, + if_range, + range, + }, + ) +} + +/// A file response. +#[derive(Debug)] +pub struct File { + resp: Response, + path: ArcPath, +} + +impl File { + /// Extract the `&Path` of the file this `Response` delivers. + /// + /// # Example + /// + /// The example below changes the Content-Type response header for every file called `video.mp4`. + /// + /// ``` + /// use warp::{Filter, reply::Reply}; + /// + /// let route = warp::path("static") + /// .and(warp::fs::dir("/www/static")) + /// .map(|reply: warp::filters::fs::File| { + /// if reply.path().ends_with("video.mp4") { + /// warp::reply::with_header(reply, "Content-Type", "video/mp4").into_response() + /// } else { + /// reply.into_response() + /// } + /// }); + /// ``` + pub fn path(&self) -> &Path { + self.path.as_ref() + } +} + +// Silly wrapper since Arc<PathBuf> doesn't implement AsRef<Path> ;_; +#[derive(Clone, Debug)] +struct ArcPath(Arc<PathBuf>); + +impl AsRef<Path> for ArcPath { + fn as_ref(&self) -> &Path { + (*self.0).as_ref() + } +} + +impl Reply for File { + fn into_response(self) -> Response { + self.resp + } +} + +fn file_reply( + path: ArcPath, + conditionals: Conditionals, +) -> impl Future<Output = Result<File, Rejection>> + Send { + TkFile::open(path.clone()).then(move |res| match res { + Ok(f) => Either::Left(file_conditional(f, path, conditionals)), + Err(err) => { + let rej = match err.kind() { + io::ErrorKind::NotFound => { + tracing::debug!("file not found: {:?}", path.as_ref().display()); + reject::not_found() + } + io::ErrorKind::PermissionDenied => { + tracing::warn!("file permission denied: {:?}", path.as_ref().display()); + reject::known(FilePermissionError { _p: () }) + } + _ => { + tracing::error!( + "file open error (path={:?}): {} ", + path.as_ref().display(), + err + ); + reject::known(FileOpenError { _p: () }) + } + }; + Either::Right(future::err(rej)) + } + }) +} + +async fn file_metadata(f: TkFile) -> Result<(TkFile, Metadata), Rejection> { + match f.metadata().await { + Ok(meta) => Ok((f, meta)), + Err(err) => { + tracing::debug!("file metadata error: {}", err); + Err(reject::not_found()) + } + } +} + +fn file_conditional( + f: TkFile, + path: ArcPath, + conditionals: Conditionals, +) -> impl Future<Output = Result<File, Rejection>> + Send { + file_metadata(f).map_ok(move |(file, meta)| { + let mut len = meta.len(); + let modified = meta.modified().ok().map(LastModified::from); + + let resp = match conditionals.check(modified) { + Cond::NoBody(resp) => resp, + Cond::WithBody(range) => { + bytes_range(range, len) + .map(|(start, end)| { + let sub_len = end - start; + let buf_size = optimal_buf_size(&meta); + let stream = file_stream(file, buf_size, (start, end)); + let body = Body::wrap_stream(stream); + + let mut resp = Response::new(body); + + if sub_len != len { + *resp.status_mut() = StatusCode::PARTIAL_CONTENT; + resp.headers_mut().typed_insert( + ContentRange::bytes(start..end, len).expect("valid ContentRange"), + ); + + len = sub_len; + } + + let mime = mime_guess::from_path(path.as_ref()).first_or_octet_stream(); + + resp.headers_mut().typed_insert(ContentLength(len)); + resp.headers_mut().typed_insert(ContentType::from(mime)); + resp.headers_mut().typed_insert(AcceptRanges::bytes()); + + if let Some(last_modified) = modified { + resp.headers_mut().typed_insert(last_modified); + } + + resp + }) + .unwrap_or_else(|BadRange| { + // bad byte range + let mut resp = Response::new(Body::empty()); + *resp.status_mut() = StatusCode::RANGE_NOT_SATISFIABLE; + resp.headers_mut() + .typed_insert(ContentRange::unsatisfied_bytes(len)); + resp + }) + } + }; + + File { resp, path } + }) +} + +struct BadRange; + +fn bytes_range(range: Option<Range>, max_len: u64) -> Result<(u64, u64), BadRange> { + use std::ops::Bound; + + let range = if let Some(range) = range { + range + } else { + return Ok((0, max_len)); + }; + + let ret = range + .iter() + .map(|(start, end)| { + let start = match start { + Bound::Unbounded => 0, + Bound::Included(s) => s, + Bound::Excluded(s) => s + 1, + }; + + let end = match end { + Bound::Unbounded => max_len, + Bound::Included(s) => { + // For the special case where s == the file size + if s == max_len { + s + } else { + s + 1 + } + } + Bound::Excluded(s) => s, + }; + + if start < end && end <= max_len { + Ok((start, end)) + } else { + tracing::trace!("unsatisfiable byte range: {}-{}/{}", start, end, max_len); + Err(BadRange) + } + }) + .next() + .unwrap_or(Ok((0, max_len))); + ret +} + +fn file_stream( + mut file: TkFile, + buf_size: usize, + (start, end): (u64, u64), +) -> impl Stream<Item = Result<Bytes, io::Error>> + Send { + use std::io::SeekFrom; + + let seek = async move { + if start != 0 { + file.seek(SeekFrom::Start(start)).await?; + } + Ok(file) + }; + + seek.into_stream() + .map(move |result| { + let mut buf = BytesMut::new(); + let mut len = end - start; + let mut f = match result { + Ok(f) => f, + Err(f) => return Either::Left(stream::once(future::err(f))), + }; + + Either::Right(stream::poll_fn(move |cx| { + if len == 0 { + return Poll::Ready(None); + } + reserve_at_least(&mut buf, buf_size); + + let n = match ready!(poll_read_buf(Pin::new(&mut f), cx, &mut buf)) { + Ok(n) => n as u64, + Err(err) => { + tracing::debug!("file read error: {}", err); + return Poll::Ready(Some(Err(err))); + } + }; + + if n == 0 { + tracing::debug!("file read found EOF before expected length"); + return Poll::Ready(None); + } + + let mut chunk = buf.split().freeze(); + if n > len { + chunk = chunk.split_to(len as usize); + len = 0; + } else { + len -= n; + } + + Poll::Ready(Some(Ok(chunk))) + })) + }) + .flatten() +} + +fn reserve_at_least(buf: &mut BytesMut, cap: usize) { + if buf.capacity() - buf.len() < cap { + buf.reserve(cap); + } +} + +const DEFAULT_READ_BUF_SIZE: usize = 8_192; + +fn optimal_buf_size(metadata: &Metadata) -> usize { + let block_size = get_block_size(metadata); + + // If file length is smaller than block size, don't waste space + // reserving a bigger-than-needed buffer. + cmp::min(block_size as u64, metadata.len()) as usize +} + +#[cfg(unix)] +fn get_block_size(metadata: &Metadata) -> usize { + use std::os::unix::fs::MetadataExt; + //TODO: blksize() returns u64, should handle bad cast... + //(really, a block size bigger than 4gb?) + + // Use device blocksize unless it's really small. + cmp::max(metadata.blksize() as usize, DEFAULT_READ_BUF_SIZE) +} + +#[cfg(not(unix))] +fn get_block_size(_metadata: &Metadata) -> usize { + DEFAULT_READ_BUF_SIZE +} + +// ===== Rejections ===== + +unit_error! { + pub(crate) FileOpenError: "file open error" +} + +unit_error! { + pub(crate) FilePermissionError: "file perimission error" +} + +#[cfg(test)] +mod tests { + use super::sanitize_path; + use bytes::BytesMut; + + #[test] + fn test_sanitize_path() { + let base = "/var/www"; + + fn p(s: &str) -> &::std::path::Path { + s.as_ref() + } + + assert_eq!( + sanitize_path(base, "/foo.html").unwrap(), + p("/var/www/foo.html") + ); + + // bad paths + sanitize_path(base, "/../foo.html").expect_err("dot dot"); + + sanitize_path(base, "/C:\\/foo.html").expect_err("C:\\"); + } + + #[test] + fn test_reserve_at_least() { + let mut buf = BytesMut::new(); + let cap = 8_192; + + assert_eq!(buf.len(), 0); + assert_eq!(buf.capacity(), 0); + + super::reserve_at_least(&mut buf, cap); + assert_eq!(buf.len(), 0); + assert_eq!(buf.capacity(), cap); + } +} diff --git a/third_party/rust/warp/src/filters/header.rs b/third_party/rust/warp/src/filters/header.rs new file mode 100644 index 0000000000..0c535a38b5 --- /dev/null +++ b/third_party/rust/warp/src/filters/header.rs @@ -0,0 +1,230 @@ +//! Header Filters +//! +//! These filters are used to interact with the Request HTTP headers. Some +//! of them, like `exact` and `exact_ignore_case`, are just predicates, +//! they don't extract any values. The `header` filter allows parsing +//! a type from any header. +use std::convert::Infallible; +use std::str::FromStr; + +use futures_util::future; +use headers::{Header, HeaderMapExt}; +use http::header::HeaderValue; +use http::HeaderMap; + +use crate::filter::{filter_fn, filter_fn_one, Filter, One}; +use crate::reject::{self, Rejection}; + +/// Create a `Filter` that tries to parse the specified header. +/// +/// This `Filter` will look for a header with supplied name, and try to +/// parse to a `T`, otherwise rejects the request. +/// +/// # Example +/// +/// ``` +/// use std::net::SocketAddr; +/// +/// // Parse `content-length: 100` as a `u64` +/// let content_length = warp::header::<u64>("content-length"); +/// +/// // Parse `host: 127.0.0.1:8080` as a `SocketAddr +/// let local_host = warp::header::<SocketAddr>("host"); +/// +/// // Parse `foo: bar` into a `String` +/// let foo = warp::header::<String>("foo"); +/// ``` +pub fn header<T: FromStr + Send + 'static>( + name: &'static str, +) -> impl Filter<Extract = One<T>, Error = Rejection> + Copy { + filter_fn_one(move |route| { + tracing::trace!("header({:?})", name); + let route = route + .headers() + .get(name) + .ok_or_else(|| reject::missing_header(name)) + .and_then(|value| value.to_str().map_err(|_| reject::invalid_header(name))) + .and_then(|s| T::from_str(s).map_err(|_| reject::invalid_header(name))); + future::ready(route) + }) +} + +pub(crate) fn header2<T: Header + Send + 'static>( +) -> impl Filter<Extract = One<T>, Error = Rejection> + Copy { + filter_fn_one(move |route| { + tracing::trace!("header2({:?})", T::name()); + let route = route + .headers() + .typed_get() + .ok_or_else(|| reject::invalid_header(T::name().as_str())); + future::ready(route) + }) +} + +/// Create a `Filter` that tries to parse the specified header, if it exists. +/// +/// If the header does not exist, it yields `None`. Otherwise, it will try to +/// parse as a `T`, and if it fails, a invalid header rejection is return. If +/// successful, the filter yields `Some(T)`. +/// +/// # Example +/// +/// ``` +/// // Grab the `authorization` header if it exists. +/// let opt_auth = warp::header::optional::<String>("authorization"); +/// ``` +pub fn optional<T>( + name: &'static str, +) -> impl Filter<Extract = One<Option<T>>, Error = Rejection> + Copy +where + T: FromStr + Send + 'static, +{ + filter_fn_one(move |route| { + tracing::trace!("optional({:?})", name); + let result = route.headers().get(name).map(|value| { + value + .to_str() + .map_err(|_| reject::invalid_header(name))? + .parse::<T>() + .map_err(|_| reject::invalid_header(name)) + }); + + match result { + Some(Ok(t)) => future::ok(Some(t)), + Some(Err(e)) => future::err(e), + None => future::ok(None), + } + }) +} + +pub(crate) fn optional2<T>() -> impl Filter<Extract = One<Option<T>>, Error = Infallible> + Copy +where + T: Header + Send + 'static, +{ + filter_fn_one(move |route| future::ready(Ok(route.headers().typed_get()))) +} + +/* TODO +pub fn exact2<T>(header: T) -> impl FilterClone<Extract=(), Error=Rejection> +where + T: Header + PartialEq + Clone + Send, +{ + filter_fn(move |route| { + tracing::trace!("exact2({:?})", T::NAME); + route.headers() + .typed_get::<T>() + .and_then(|val| if val == header { + Some(()) + } else { + None + }) + .ok_or_else(|| reject::bad_request()) + }) +} +*/ + +/// Create a `Filter` that requires a header to match the value exactly. +/// +/// This `Filter` will look for a header with supplied name and the exact +/// value, otherwise rejects the request. +/// +/// # Example +/// +/// ``` +/// // Require `dnt: 1` header to be set. +/// let must_dnt = warp::header::exact("dnt", "1"); +/// ``` +pub fn exact( + name: &'static str, + value: &'static str, +) -> impl Filter<Extract = (), Error = Rejection> + Copy { + filter_fn(move |route| { + tracing::trace!("exact?({:?}, {:?})", name, value); + let route = route + .headers() + .get(name) + .ok_or_else(|| reject::missing_header(name)) + .and_then(|val| { + if val == value { + Ok(()) + } else { + Err(reject::invalid_header(name)) + } + }); + future::ready(route) + }) +} + +/// Create a `Filter` that requires a header to match the value exactly. +/// +/// This `Filter` will look for a header with supplied name and the exact +/// value, ignoring ASCII case, otherwise rejects the request. +/// +/// # Example +/// +/// ``` +/// // Require `connection: keep-alive` header to be set. +/// let keep_alive = warp::header::exact_ignore_case("connection", "keep-alive"); +/// ``` +pub fn exact_ignore_case( + name: &'static str, + value: &'static str, +) -> impl Filter<Extract = (), Error = Rejection> + Copy { + filter_fn(move |route| { + tracing::trace!("exact_ignore_case({:?}, {:?})", name, value); + let route = route + .headers() + .get(name) + .ok_or_else(|| reject::missing_header(name)) + .and_then(|val| { + if val.as_bytes().eq_ignore_ascii_case(value.as_bytes()) { + Ok(()) + } else { + Err(reject::invalid_header(name)) + } + }); + future::ready(route) + }) +} + +/// Create a `Filter` that gets a `HeaderValue` for the name. +/// +/// # Example +/// +/// ``` +/// use warp::{Filter, http::header::HeaderValue}; +/// +/// let filter = warp::header::value("x-token") +/// .map(|value: HeaderValue| { +/// format!("header value bytes: {:?}", value) +/// }); +/// ``` +pub fn value( + name: &'static str, +) -> impl Filter<Extract = One<HeaderValue>, Error = Rejection> + Copy { + filter_fn_one(move |route| { + tracing::trace!("value({:?})", name); + let route = route + .headers() + .get(name) + .cloned() + .ok_or_else(|| reject::missing_header(name)); + future::ready(route) + }) +} + +/// Create a `Filter` that returns a clone of the request's `HeaderMap`. +/// +/// # Example +/// +/// ``` +/// use warp::{Filter, http::HeaderMap}; +/// +/// let headers = warp::header::headers_cloned() +/// .map(|headers: HeaderMap| { +/// format!("header count: {}", headers.len()) +/// }); +/// ``` +pub fn headers_cloned() -> impl Filter<Extract = One<HeaderMap>, Error = Infallible> + Copy { + filter_fn_one(|route| future::ok(route.headers().clone())) +} diff --git a/third_party/rust/warp/src/filters/host.rs b/third_party/rust/warp/src/filters/host.rs new file mode 100644 index 0000000000..9aae039c19 --- /dev/null +++ b/third_party/rust/warp/src/filters/host.rs @@ -0,0 +1,96 @@ +//! Host ("authority") filter +//! +use crate::filter::{filter_fn_one, Filter, One}; +use crate::reject::{self, Rejection}; +use futures_util::future; +pub use http::uri::Authority; +use std::str::FromStr; + +/// Creates a `Filter` that requires a specific authority (target server's +/// host and port) in the request. +/// +/// Authority is specified either in the `Host` header or in the target URI. +/// +/// # Example +/// +/// ``` +/// use warp::Filter; +/// +/// let multihost = +/// warp::host::exact("foo.com").map(|| "you've reached foo.com") +/// .or(warp::host::exact("bar.com").map(|| "you've reached bar.com")); +/// ``` +pub fn exact(expected: &str) -> impl Filter<Extract = (), Error = Rejection> + Clone { + let expected = Authority::from_str(expected).expect("invalid host/authority"); + optional() + .and_then(move |option: Option<Authority>| match option { + Some(authority) if authority == expected => future::ok(()), + _ => future::err(reject::not_found()), + }) + .untuple_one() +} + +/// Creates a `Filter` that looks for an authority (target server's host +/// and port) in the request. +/// +/// Authority is specified either in the `Host` header or in the target URI. +/// +/// If found, extracts the `Authority`, otherwise continues the request, +/// extracting `None`. +/// +/// Rejects with `400 Bad Request` if the `Host` header is malformed or if there +/// is a mismatch between the `Host` header and the target URI. +/// +/// # Example +/// +/// ``` +/// use warp::{Filter, host::Authority}; +/// +/// let host = warp::host::optional() +/// .map(|authority: Option<Authority>| { +/// if let Some(a) = authority { +/// format!("{} is currently not at home", a.host()) +/// } else { +/// "please state who you're trying to reach".to_owned() +/// } +/// }); +/// ``` +pub fn optional() -> impl Filter<Extract = One<Option<Authority>>, Error = Rejection> + Copy { + filter_fn_one(move |route| { + // The authority can be sent by clients in various ways: + // + // 1) in the "target URI" + // a) serialized in the start line (HTTP/1.1 proxy requests) + // b) serialized in `:authority` pseudo-header (HTTP/2 generated - "SHOULD") + // 2) in the `Host` header (HTTP/1.1 origin requests, HTTP/2 converted) + // + // Hyper transparently handles 1a/1b, but not 2, so we must look at both. + + let from_uri = route.uri().authority(); + + let name = "host"; + let from_header = route.headers() + .get(name) + .map(|value| + // Header present, parse it + value.to_str().map_err(|_| reject::invalid_header(name)) + .and_then(|value| Authority::from_str(value).map_err(|_| reject::invalid_header(name))) + ); + + future::ready(match (from_uri, from_header) { + // no authority in the request (HTTP/1.0 or non-conforming) + (None, None) => Ok(None), + + // authority specified in either or both matching + (Some(a), None) => Ok(Some(a.clone())), + (None, Some(Ok(a))) => Ok(Some(a)), + (Some(a), Some(Ok(b))) if *a == b => Ok(Some(b)), + + // mismatch + (Some(_), Some(Ok(_))) => Err(reject::invalid_header(name)), + + // parse error + (_, Some(Err(r))) => Err(r), + }) + }) +} diff --git a/third_party/rust/warp/src/filters/log.rs b/third_party/rust/warp/src/filters/log.rs new file mode 100644 index 0000000000..eae5974024 --- /dev/null +++ b/third_party/rust/warp/src/filters/log.rs @@ -0,0 +1,280 @@ +//! Logger Filters + +use std::fmt; +use std::net::SocketAddr; +use std::time::{Duration, Instant}; + +use http::{self, header, StatusCode}; + +use crate::filter::{Filter, WrapSealed}; +use crate::reject::IsReject; +use crate::reply::Reply; +use crate::route::Route; + +use self::internal::WithLog; + +/// Create a wrapping filter with the specified `name` as the `target`. +/// +/// This uses the default access logging format, and log records produced +/// will have their `target` set to `name`. +/// +/// # Example +/// +/// ``` +/// use warp::Filter; +/// +/// // If using something like `pretty_env_logger`, +/// // view logs by setting `RUST_LOG=example::api`. +/// let log = warp::log("example::api"); +/// let route = warp::any() +/// .map(warp::reply) +/// .with(log); +/// ``` +pub fn log(name: &'static str) -> Log<impl Fn(Info<'_>) + Copy> { + let func = move |info: Info<'_>| { + // TODO? + // - response content length? + log::info!( + target: name, + "{} \"{} {} {:?}\" {} \"{}\" \"{}\" {:?}", + OptFmt(info.route.remote_addr()), + info.method(), + info.path(), + info.route.version(), + info.status().as_u16(), + OptFmt(info.referer()), + OptFmt(info.user_agent()), + info.elapsed(), + ); + }; + Log { func } +} + +/// Create a wrapping filter that receives `warp::log::Info`. +/// +/// # Example +/// +/// ``` +/// use warp::Filter; +/// +/// let log = warp::log::custom(|info| { +/// // Use a log macro, or slog, or println, or whatever! +/// eprintln!( +/// "{} {} {}", +/// info.method(), +/// info.path(), +/// info.status(), +/// ); +/// }); +/// let route = warp::any() +/// .map(warp::reply) +/// .with(log); +/// ``` +pub fn custom<F>(func: F) -> Log<F> +where + F: Fn(Info<'_>), +{ + Log { func } +} + +/// Decorates a [`Filter`](crate::Filter) to log requests and responses. +#[derive(Clone, Copy, Debug)] +pub struct Log<F> { + func: F, +} + +/// Information about the request/response that can be used to prepare log lines. +#[allow(missing_debug_implementations)] +pub struct Info<'a> { + route: &'a Route, + start: Instant, + status: StatusCode, +} + +impl<FN, F> WrapSealed<F> for Log<FN> +where + FN: Fn(Info<'_>) + Clone + Send, + F: Filter + Clone + Send, + F::Extract: Reply, + F::Error: IsReject, +{ + type Wrapped = WithLog<FN, F>; + + fn wrap(&self, filter: F) -> Self::Wrapped { + WithLog { + filter, + log: self.clone(), + } + } +} + +impl<'a> Info<'a> { + /// View the remote `SocketAddr` of the request. + pub fn remote_addr(&self) -> Option<SocketAddr> { + self.route.remote_addr() + } + + /// View the `http::Method` of the request. + pub fn method(&self) -> &http::Method { + self.route.method() + } + + /// View the URI path of the request. + pub fn path(&self) -> &str { + self.route.full_path() + } + + /// View the `http::Version` of the request. + pub fn version(&self) -> http::Version { + self.route.version() + } + + /// View the `http::StatusCode` of the response. + pub fn status(&self) -> http::StatusCode { + self.status + } + + /// View the referer of the request. + pub fn referer(&self) -> Option<&str> { + self.route + .headers() + .get(header::REFERER) + .and_then(|v| v.to_str().ok()) + } + + /// View the user agent of the request. + pub fn user_agent(&self) -> Option<&str> { + self.route + .headers() + .get(header::USER_AGENT) + .and_then(|v| v.to_str().ok()) + } + + /// View the `Duration` that elapsed for the request. + pub fn elapsed(&self) -> Duration { + tokio::time::Instant::now().into_std() - self.start + } + + /// View the host of the request + pub fn host(&self) -> Option<&str> { + self.route + .headers() + .get(header::HOST) + .and_then(|v| v.to_str().ok()) + } + + /// Access the full headers of the request + pub fn request_headers(&self) -> &http::HeaderMap { + self.route.headers() + } +} + +struct OptFmt<T>(Option<T>); + +impl<T: fmt::Display> fmt::Display for OptFmt<T> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + if let Some(ref t) = self.0 { + fmt::Display::fmt(t, f) + } else { + f.write_str("-") + } + } +} + +mod internal { + use std::future::Future; + use std::pin::Pin; + use std::task::{Context, Poll}; + use std::time::Instant; + + use futures_util::{ready, TryFuture}; + use pin_project::pin_project; + + use super::{Info, Log}; + use crate::filter::{Filter, FilterBase, Internal}; + use crate::reject::IsReject; + use crate::reply::{Reply, Response}; + use crate::route; + + #[allow(missing_debug_implementations)] + pub struct Logged(pub(super) Response); + + impl Reply for Logged { + #[inline] + fn into_response(self) -> Response { + self.0 + } + } + + #[allow(missing_debug_implementations)] + #[derive(Clone, Copy)] + pub struct WithLog<FN, F> { + pub(super) filter: F, + pub(super) log: Log<FN>, + } + + impl<FN, F> FilterBase for WithLog<FN, F> + where + FN: Fn(Info<'_>) + Clone + Send, + F: Filter + Clone + Send, + F::Extract: Reply, + F::Error: IsReject, + { + type Extract = (Logged,); + type Error = F::Error; + type Future = WithLogFuture<FN, F::Future>; + + fn filter(&self, _: Internal) -> Self::Future { + let started = tokio::time::Instant::now().into_std(); + WithLogFuture { + log: self.log.clone(), + future: self.filter.filter(Internal), + started, + } + } + } + + #[allow(missing_debug_implementations)] + #[pin_project] + pub struct WithLogFuture<FN, F> { + log: Log<FN>, + #[pin] + future: F, + started: Instant, + } + + impl<FN, F> Future for WithLogFuture<FN, F> + where + FN: Fn(Info<'_>), + F: TryFuture, + F::Ok: Reply, + F::Error: IsReject, + { + type Output = Result<(Logged,), F::Error>; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { + let pin = self.as_mut().project(); + let (result, status) = match ready!(pin.future.try_poll(cx)) { + Ok(reply) => { + let resp = reply.into_response(); + let status = resp.status(); + (Poll::Ready(Ok((Logged(resp),))), status) + } + Err(reject) => { + let status = reject.status(); + (Poll::Ready(Err(reject)), status) + } + }; + + route::with(|route| { + (self.log.func)(Info { + route, + start: self.started, + status, + }); + }); + + result + } + } +} diff --git a/third_party/rust/warp/src/filters/method.rs b/third_party/rust/warp/src/filters/method.rs new file mode 100644 index 0000000000..c4d7462720 --- /dev/null +++ b/third_party/rust/warp/src/filters/method.rs @@ -0,0 +1,150 @@ +//! HTTP Method filters. +//! +//! The filters deal with the HTTP Method part of a request. Several here will +//! match the request `Method`, and if not matched, will reject the request +//! with a `405 Method Not Allowed`. +//! +//! There is also [`warp::method()`](method), which never rejects +//! a request, and just extracts the method to be used in your filter chains. +use futures_util::future; +use http::Method; + +use crate::filter::{filter_fn, filter_fn_one, Filter, One}; +use crate::reject::Rejection; +use std::convert::Infallible; + +/// Create a `Filter` that requires the request method to be `GET`. +/// +/// # Example +/// +/// ``` +/// use warp::Filter; +/// +/// let get_only = warp::get().map(warp::reply); +/// ``` +pub fn get() -> impl Filter<Extract = (), Error = Rejection> + Copy { + method_is(|| &Method::GET) +} + +/// Create a `Filter` that requires the request method to be `POST`. +/// +/// # Example +/// +/// ``` +/// use warp::Filter; +/// +/// let post_only = warp::post().map(warp::reply); +/// ``` +pub fn post() -> impl Filter<Extract = (), Error = Rejection> + Copy { + method_is(|| &Method::POST) +} + +/// Create a `Filter` that requires the request method to be `PUT`. +/// +/// # Example +/// +/// ``` +/// use warp::Filter; +/// +/// let put_only = warp::put().map(warp::reply); +/// ``` +pub fn put() -> impl Filter<Extract = (), Error = Rejection> + Copy { + method_is(|| &Method::PUT) +} + +/// Create a `Filter` that requires the request method to be `DELETE`. +/// +/// # Example +/// +/// ``` +/// use warp::Filter; +/// +/// let delete_only = warp::delete().map(warp::reply); +/// ``` +pub fn delete() -> impl Filter<Extract = (), Error = Rejection> + Copy { + method_is(|| &Method::DELETE) +} + +/// Create a `Filter` that requires the request method to be `HEAD`. +/// +/// # Example +/// +/// ``` +/// use warp::Filter; +/// +/// let head_only = warp::head().map(warp::reply); +/// ``` +pub fn head() -> impl Filter<Extract = (), Error = Rejection> + Copy { + method_is(|| &Method::HEAD) +} + +/// Create a `Filter` that requires the request method to be `OPTIONS`. +/// +/// # Example +/// +/// ``` +/// use warp::Filter; +/// +/// let options_only = warp::options().map(warp::reply); +/// ``` +pub fn options() -> impl Filter<Extract = (), Error = Rejection> + Copy { + method_is(|| &Method::OPTIONS) +} + +/// Create a `Filter` that requires the request method to be `PATCH`. +/// +/// # Example +/// +/// ``` +/// use warp::Filter; +/// +/// let patch_only = warp::patch().map(warp::reply); +/// ``` +pub fn patch() -> impl Filter<Extract = (), Error = Rejection> + Copy { + method_is(|| &Method::PATCH) +} + +/// Extract the `Method` from the request. +/// +/// This never rejects a request. +/// +/// # Example +/// +/// ``` +/// use warp::Filter; +/// +/// let route = warp::method() +/// .map(|method| { +/// format!("You sent a {} request!", method) +/// }); +/// ``` +pub fn method() -> impl Filter<Extract = One<Method>, Error = Infallible> + Copy { + filter_fn_one(|route| future::ok::<_, Infallible>(route.method().clone())) +} + +// NOTE: This takes a static function instead of `&'static Method` directly +// so that the `impl Filter` can be zero-sized. Moving it around should be +// cheaper than holding a single static pointer (which would make it 1 word). +fn method_is<F>(func: F) -> impl Filter<Extract = (), Error = Rejection> + Copy +where + F: Fn() -> &'static Method + Copy, +{ + filter_fn(move |route| { + let method = func(); + tracing::trace!("method::{:?}?: {:?}", method, route.method()); + if route.method() == method { + future::ok(()) + } else { + future::err(crate::reject::method_not_allowed()) + } + }) +} + +#[cfg(test)] +mod tests { + #[test] + fn method_size_of() { + // See comment on `method_is` function. + assert_eq!(std::mem::size_of_val(&super::get()), 0,); + } +} diff --git a/third_party/rust/warp/src/filters/mod.rs b/third_party/rust/warp/src/filters/mod.rs new file mode 100644 index 0000000000..bd1c48718c --- /dev/null +++ b/third_party/rust/warp/src/filters/mod.rs @@ -0,0 +1,29 @@ +//! Built-in Filters +//! +//! This module mostly serves as documentation to group together the list of +//! built-in filters. Most of these are available at more convenient paths. + +pub mod addr; +pub mod any; +pub mod body; +#[cfg(any(feature = "compression-brotli", feature = "compression-gzip"))] +pub mod compression; +pub mod cookie; +pub mod cors; +pub mod ext; +pub mod fs; +pub mod header; +pub mod host; +pub mod log; +pub mod method; +#[cfg(feature = "multipart")] +pub mod multipart; +pub mod path; +pub mod query; +pub mod reply; +pub mod sse; +pub mod trace; +#[cfg(feature = "websocket")] +pub mod ws; + +pub use crate::filter::BoxedFilter; diff --git a/third_party/rust/warp/src/filters/multipart.rs b/third_party/rust/warp/src/filters/multipart.rs new file mode 100644 index 0000000000..ef2ec92682 --- /dev/null +++ b/third_party/rust/warp/src/filters/multipart.rs @@ -0,0 +1,190 @@ +//! Multipart body filters +//! +//! Filters that extract a multipart body for a route. + +use std::fmt; +use std::future::Future; +use std::io::{Cursor, Read}; +use std::pin::Pin; +use std::task::{Context, Poll}; + +use bytes::{Buf, Bytes}; +use futures_util::{future, Stream}; +use headers::ContentType; +use mime::Mime; +use multipart::server::Multipart; + +use crate::filter::{Filter, FilterBase, Internal}; +use crate::reject::{self, Rejection}; + +// If not otherwise configured, default to 2MB. +const DEFAULT_FORM_DATA_MAX_LENGTH: u64 = 1024 * 1024 * 2; + +/// A `Filter` to extract a `multipart/form-data` body from a request. +/// +/// Create with the `warp::multipart::form()` function. +#[derive(Debug, Clone)] +pub struct FormOptions { + max_length: u64, +} + +/// A `Stream` of multipart/form-data `Part`s. +/// +/// Extracted with a `warp::multipart::form` filter. +pub struct FormData { + inner: Multipart<Cursor<::bytes::Bytes>>, +} + +/// A single "part" of a multipart/form-data body. +/// +/// Yielded from the `FormData` stream. +pub struct Part { + name: String, + filename: Option<String>, + content_type: Option<String>, + data: Option<Vec<u8>>, +} + +/// Create a `Filter` to extract a `multipart/form-data` body from a request. +/// +/// The extracted `FormData` type is a `Stream` of `Part`s, and each `Part` +/// in turn is a `Stream` of bytes. +pub fn form() -> FormOptions { + FormOptions { + max_length: DEFAULT_FORM_DATA_MAX_LENGTH, + } +} + +// ===== impl Form ===== + +impl FormOptions { + /// Set the maximum byte length allowed for this body. + /// + /// Defaults to 2MB. + pub fn max_length(mut self, max: u64) -> Self { + self.max_length = max; + self + } +} + +type FormFut = Pin<Box<dyn Future<Output = Result<(FormData,), Rejection>> + Send>>; + +impl FilterBase for FormOptions { + type Extract = (FormData,); + type Error = Rejection; + type Future = FormFut; + + fn filter(&self, _: Internal) -> Self::Future { + let boundary = super::header::header2::<ContentType>().and_then(|ct| { + let mime = Mime::from(ct); + let mime = mime + .get_param("boundary") + .map(|v| v.to_string()) + .ok_or_else(|| reject::invalid_header("content-type")); + future::ready(mime) + }); + + let filt = super::body::content_length_limit(self.max_length) + .and(boundary) + .and(super::body::bytes()) + .map(|boundary, body| FormData { + inner: Multipart::with_body(Cursor::new(body), boundary), + }); + + let fut = filt.filter(Internal); + + Box::pin(fut) + } +} + +// ===== impl FormData ===== + +impl fmt::Debug for FormData { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("FormData").finish() + } +} + +impl Stream for FormData { + type Item = Result<Part, crate::Error>; + + fn poll_next(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Option<Self::Item>> { + match (*self).inner.read_entry() { + Ok(Some(mut field)) => { + let mut data = Vec::new(); + field + .data + .read_to_end(&mut data) + .map_err(crate::Error::new)?; + Poll::Ready(Some(Ok(Part { + name: field.headers.name.to_string(), + filename: field.headers.filename, + content_type: field.headers.content_type.map(|m| m.to_string()), + data: Some(data), + }))) + } + Ok(None) => Poll::Ready(None), + Err(e) => Poll::Ready(Some(Err(crate::Error::new(e)))), + } + } +} + +// ===== impl Part ===== + +impl Part { + /// Get the name of this part. + pub fn name(&self) -> &str { + &self.name + } + + /// Get the filename of this part, if present. + pub fn filename(&self) -> Option<&str> { + self.filename.as_deref() + } + + /// Get the content-type of this part, if present. + pub fn content_type(&self) -> Option<&str> { + self.content_type.as_deref() + } + + /// Asynchronously get some of the data for this `Part`. + pub async fn data(&mut self) -> Option<Result<impl Buf, crate::Error>> { + self.take_data() + } + + /// Convert this `Part` into a `Stream` of `Buf`s. + pub fn stream(self) -> impl Stream<Item = Result<impl Buf, crate::Error>> { + PartStream(self) + } + + fn take_data(&mut self) -> Option<Result<Bytes, crate::Error>> { + self.data.take().map(|vec| Ok(vec.into())) + } +} + +impl fmt::Debug for Part { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let mut builder = f.debug_struct("Part"); + builder.field("name", &self.name); + + if let Some(ref filename) = self.filename { + builder.field("filename", filename); + } + + if let Some(ref mime) = self.content_type { + builder.field("content_type", mime); + } + + builder.finish() + } +} + +struct PartStream(Part); + +impl Stream for PartStream { + type Item = Result<Bytes, crate::Error>; + + fn poll_next(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Option<Self::Item>> { + Poll::Ready(self.0.take_data()) + } +} diff --git a/third_party/rust/warp/src/filters/path.rs b/third_party/rust/warp/src/filters/path.rs new file mode 100644 index 0000000000..179a8d1c91 --- /dev/null +++ b/third_party/rust/warp/src/filters/path.rs @@ -0,0 +1,652 @@ +//! Path Filters +//! +//! The filters here work on the "path" of requests. +//! +//! - [`path`](./fn.path.html) matches a specific segment, like `/foo`. +//! - [`param`](./fn.param.html) tries to parse a segment into a type, like `/:u16`. +//! - [`end`](./fn.end.html) matches when the path end is found. +//! - [`path!`](../../macro.path.html) eases combining multiple `path` and `param` filters. +//! +//! # Routing +//! +//! Routing in warp is simple yet powerful. +//! +//! First up, matching a single segment: +//! +//! ``` +//! use warp::Filter; +//! +//! // GET /hi +//! let hi = warp::path("hi").map(|| { +//! "Hello, World!" +//! }); +//! ``` +//! +//! How about multiple segments? It's easiest with the `path!` macro: +//! +//! ``` +//! # use warp::Filter; +//! // GET /hello/from/warp +//! let hello_from_warp = warp::path!("hello" / "from" / "warp").map(|| { +//! "Hello from warp!" +//! }); +//! ``` +//! +//! Neat! But how do I handle **parameters** in paths? +//! +//! ``` +//! # use warp::Filter; +//! // GET /sum/:u32/:u32 +//! let sum = warp::path!("sum" / u32 / u32).map(|a, b| { +//! format!("{} + {} = {}", a, b, a + b) +//! }); +//! ``` +//! +//! In fact, any type that implements `FromStr` can be used, in any order: +//! +//! ``` +//! # use warp::Filter; +//! // GET /:u16/times/:u16 +//! let times = warp::path!(u16 / "times" / u16).map(|a, b| { +//! format!("{} times {} = {}", a, b, a * b) +//! }); +//! ``` +//! +//! Oh shoot, those math routes should be **mounted** at a different path, +//! is that possible? Yep! +//! +//! ``` +//! # use warp::Filter; +//! # let sum = warp::any().map(warp::reply); +//! # let times = sum.clone(); +//! // GET /math/sum/:u32/:u32 +//! // GET /math/:u16/times/:u16 +//! let math = warp::path("math"); +//! let math_sum = math.and(sum); +//! let math_times = math.and(times); +//! ``` +//! +//! What! `and`? What's that do? +//! +//! It combines the filters in a sort of "this and then that" order. In fact, +//! it's exactly what the `path!` macro has been doing internally. +//! +//! ``` +//! # use warp::Filter; +//! // GET /bye/:string +//! let bye = warp::path("bye") +//! .and(warp::path::param()) +//! .map(|name: String| { +//! format!("Good bye, {}!", name) +//! }); +//! ``` +//! +//! Ah, so, can filters do things besides `and`? +//! +//! Why, yes they can! They can also `or`! As you might expect, `or` creates a +//! "this or else that" chain of filters. If the first doesn't succeed, then +//! it tries the other. +//! +//! So, those `math` routes could have been **mounted** all as one, with `or`. +//! +//! +//! ``` +//! # use warp::Filter; +//! # let sum = warp::path("sum"); +//! # let times = warp::path("times"); +//! // GET /math/sum/:u32/:u32 +//! // GET /math/:u16/times/:u16 +//! let math = warp::path("math") +//! .and(sum.or(times)); +//! ``` +//! +//! It turns out, using `or` is how you combine everything together into a +//! single API. +//! +//! ``` +//! # use warp::Filter; +//! # let hi = warp::path("hi"); +//! # let hello_from_warp = hi.clone(); +//! # let bye = hi.clone(); +//! # let math = hi.clone(); +//! // GET /hi +//! // GET /hello/from/warp +//! // GET /bye/:string +//! // GET /math/sum/:u32/:u32 +//! // GET /math/:u16/times/:u16 +//! let routes = hi +//! .or(hello_from_warp) +//! .or(bye) +//! .or(math); +//! ``` +//! +//! Note that you will generally want path filters to come **before** other filters +//! like `body` or `headers`. If a different type of filter comes first, a request +//! with an invalid body for route `/right-path-wrong-body` may try matching against `/wrong-path` +//! and return the error from `/wrong-path` instead of the correct body-related error. + +use std::convert::Infallible; +use std::fmt; +use std::str::FromStr; + +use futures_util::future; +use http::uri::PathAndQuery; + +use self::internal::Opaque; +use crate::filter::{filter_fn, one, Filter, FilterBase, Internal, One, Tuple}; +use crate::reject::{self, Rejection}; +use crate::route::{self, Route}; + +/// Create an exact match path segment `Filter`. +/// +/// This will try to match exactly to the current request path segment. +/// +/// # Note +/// +/// - [`end()`](./fn.end.html) should be used to match the end of a path to avoid having +/// filters for shorter paths like `/math` unintentionally match a longer +/// path such as `/math/sum` +/// - Path-related filters should generally come **before** other types of filters, such +/// as those checking headers or body types. Including those other filters before +/// the path checks may result in strange errors being returned because a given request +/// does not match the parameters for a completely separate route. +/// +/// # Panics +/// +/// Exact path filters cannot be empty, or contain slashes. +/// +/// # Example +/// +/// ``` +/// use warp::Filter; +/// +/// // Matches '/hello' +/// let hello = warp::path("hello") +/// .map(|| "Hello, World!"); +/// ``` +pub fn path<P>(p: P) -> Exact<Opaque<P>> +where + P: AsRef<str>, +{ + let s = p.as_ref(); + assert!(!s.is_empty(), "exact path segments should not be empty"); + assert!( + !s.contains('/'), + "exact path segments should not contain a slash: {:?}", + s + ); + + Exact(Opaque(p)) + /* + segment(move |seg| { + tracing::trace!("{:?}?: {:?}", p, seg); + if seg == p { + Ok(()) + } else { + Err(reject::not_found()) + } + }) + */ +} + +/// A `Filter` matching an exact path segment. +/// +/// Constructed from `path()` or `path!()`. +#[allow(missing_debug_implementations)] +#[derive(Clone, Copy)] +pub struct Exact<P>(P); + +impl<P> FilterBase for Exact<P> +where + P: AsRef<str>, +{ + type Extract = (); + type Error = Rejection; + type Future = future::Ready<Result<Self::Extract, Self::Error>>; + + #[inline] + fn filter(&self, _: Internal) -> Self::Future { + route::with(|route| { + let p = self.0.as_ref(); + future::ready(with_segment(route, |seg| { + tracing::trace!("{:?}?: {:?}", p, seg); + + if seg == p { + Ok(()) + } else { + Err(reject::not_found()) + } + })) + }) + } +} + +/// Matches the end of a route. +/// +/// Note that _not_ including `end()` may result in shorter paths like +/// `/math` unintentionally matching `/math/sum`. +/// +/// # Example +/// +/// ``` +/// use warp::Filter; +/// +/// // Matches '/' +/// let hello = warp::path::end() +/// .map(|| "Hello, World!"); +/// ``` +pub fn end() -> impl Filter<Extract = (), Error = Rejection> + Copy { + filter_fn(move |route| { + if route.path().is_empty() { + future::ok(()) + } else { + future::err(reject::not_found()) + } + }) +} + +/// Extract a parameter from a path segment. +/// +/// This will try to parse a value from the current request path +/// segment, and if successful, the value is returned as the `Filter`'s +/// "extracted" value. +/// +/// If the value could not be parsed, rejects with a `404 Not Found`. +/// +/// # Example +/// +/// ``` +/// use warp::Filter; +/// +/// let route = warp::path::param() +/// .map(|id: u32| { +/// format!("You asked for /{}", id) +/// }); +/// ``` +pub fn param<T: FromStr + Send + 'static>( +) -> impl Filter<Extract = One<T>, Error = Rejection> + Copy { + filter_segment(|seg| { + tracing::trace!("param?: {:?}", seg); + if seg.is_empty() { + return Err(reject::not_found()); + } + T::from_str(seg).map(one).map_err(|_| reject::not_found()) + }) +} + +/// Extract the unmatched tail of the path. +/// +/// This will return a `Tail`, which allows access to the rest of the path +/// that previous filters have not already matched. +/// +/// # Example +/// +/// ``` +/// use warp::Filter; +/// +/// let route = warp::path("foo") +/// .and(warp::path::tail()) +/// .map(|tail| { +/// // GET /foo/bar/baz would return "bar/baz". +/// format!("The tail after foo is {:?}", tail) +/// }); +/// ``` +pub fn tail() -> impl Filter<Extract = One<Tail>, Error = Infallible> + Copy { + filter_fn(move |route| { + let path = path_and_query(route); + let idx = route.matched_path_index(); + + // Giving the user the full tail means we assume the full path + // has been matched now. + let end = path.path().len() - idx; + route.set_unmatched_path(end); + + future::ok(one(Tail { + path, + start_index: idx, + })) + }) +} + +/// Represents the tail part of a request path, returned by the [`tail()`] filter. +pub struct Tail { + path: PathAndQuery, + start_index: usize, +} + +impl Tail { + /// Get the `&str` representation of the remaining path. + pub fn as_str(&self) -> &str { + &self.path.path()[self.start_index..] + } +} + +impl fmt::Debug for Tail { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt::Debug::fmt(self.as_str(), f) + } +} + +/// Peek at the unmatched tail of the path, without affecting the matched path. +/// +/// This will return a `Peek`, which allows access to the rest of the path +/// that previous filters have not already matched. This differs from `tail` +/// in that `peek` will **not** set the entire path as matched. +/// +/// # Example +/// +/// ``` +/// use warp::Filter; +/// +/// let route = warp::path("foo") +/// .and(warp::path::peek()) +/// .map(|peek| { +/// // GET /foo/bar/baz would return "bar/baz". +/// format!("The path after foo is {:?}", peek) +/// }); +/// ``` +pub fn peek() -> impl Filter<Extract = One<Peek>, Error = Infallible> + Copy { + filter_fn(move |route| { + let path = path_and_query(route); + let idx = route.matched_path_index(); + + future::ok(one(Peek { + path, + start_index: idx, + })) + }) +} + +/// Represents the tail part of a request path, returned by the [`peek()`] filter. +pub struct Peek { + path: PathAndQuery, + start_index: usize, +} + +impl Peek { + /// Get the `&str` representation of the remaining path. + pub fn as_str(&self) -> &str { + &self.path.path()[self.start_index..] + } + + /// Get an iterator over the segments of the peeked path. + pub fn segments(&self) -> impl Iterator<Item = &str> { + self.as_str().split('/').filter(|seg| !seg.is_empty()) + } +} + +impl fmt::Debug for Peek { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt::Debug::fmt(self.as_str(), f) + } +} + +/// Returns the full request path, irrespective of other filters. +/// +/// This will return a `FullPath`, which can be stringified to return the +/// full path of the request. +/// +/// This is more useful in generic pre/post-processing filters, and should +/// probably not be used for request matching/routing. +/// +/// # Example +/// +/// ``` +/// use warp::{Filter, path::FullPath}; +/// use std::{collections::HashMap, sync::{Arc, Mutex}}; +/// +/// let counts = Arc::new(Mutex::new(HashMap::new())); +/// let access_counter = warp::path::full() +/// .map(move |path: FullPath| { +/// let mut counts = counts.lock().unwrap(); +/// +/// *counts.entry(path.as_str().to_string()) +/// .and_modify(|c| *c += 1) +/// .or_insert(0) +/// }); +/// +/// let route = warp::path("foo") +/// .and(warp::path("bar")) +/// .and(access_counter) +/// .map(|count| { +/// format!("This is the {}th visit to this URL!", count) +/// }); +/// ``` +pub fn full() -> impl Filter<Extract = One<FullPath>, Error = Infallible> + Copy { + filter_fn(move |route| future::ok(one(FullPath(path_and_query(route))))) +} + +/// Represents the full request path, returned by the [`full()`] filter. +pub struct FullPath(PathAndQuery); + +impl FullPath { + /// Get the `&str` representation of the request path. + pub fn as_str(&self) -> &str { + self.0.path() + } +} + +impl fmt::Debug for FullPath { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt::Debug::fmt(self.as_str(), f) + } +} + +fn filter_segment<F, U>(func: F) -> impl Filter<Extract = U, Error = Rejection> + Copy +where + F: Fn(&str) -> Result<U, Rejection> + Copy, + U: Tuple + Send + 'static, +{ + filter_fn(move |route| future::ready(with_segment(route, func))) +} + +fn with_segment<F, U>(route: &mut Route, func: F) -> Result<U, Rejection> +where + F: Fn(&str) -> Result<U, Rejection>, +{ + let seg = segment(route); + let ret = func(seg); + if ret.is_ok() { + let idx = seg.len(); + route.set_unmatched_path(idx); + } + ret +} + +fn segment(route: &Route) -> &str { + route + .path() + .splitn(2, '/') + .next() + .expect("split always has at least 1") +} + +fn path_and_query(route: &Route) -> PathAndQuery { + route + .uri() + .path_and_query() + .cloned() + .unwrap_or_else(|| PathAndQuery::from_static("")) +} + +/// Convenient way to chain multiple path filters together. +/// +/// Any number of either type identifiers or string expressions can be passed, +/// each separated by a forward slash (`/`). Strings will be used to match +/// path segments exactly, and type identifiers are used just like +/// [`param`](crate::path::param) filters. +/// +/// # Example +/// +/// ``` +/// use warp::Filter; +/// +/// // Match `/sum/:a/:b` +/// let route = warp::path!("sum" / u32 / u32) +/// .map(|a, b| { +/// format!("{} + {} = {}", a, b, a + b) +/// }); +/// ``` +/// +/// The equivalent filter chain without using the `path!` macro looks this: +/// +/// ``` +/// use warp::Filter; +/// +/// let route = warp::path("sum") +/// .and(warp::path::param::<u32>()) +/// .and(warp::path::param::<u32>()) +/// .and(warp::path::end()) +/// .map(|a, b| { +/// format!("{} + {} = {}", a, b, a + b) +/// }); +/// ``` +/// +/// # Path Prefixes +/// +/// The `path!` macro automatically assumes the path should include an `end()` +/// filter. To build up a path filter *prefix*, such that the `end()` isn't +/// included, use the `/ ..` syntax. +/// +/// +/// ``` +/// use warp::Filter; +/// +/// let prefix = warp::path!("math" / "sum" / ..); +/// +/// let sum = warp::path!(u32 / u32) +/// .map(|a, b| { +/// format!("{} + {} = {}", a, b, a + b) +/// }); +/// +/// let help = warp::path::end() +/// .map(|| "This API returns the sum of two u32's"); +/// +/// let api = prefix.and(sum.or(help)); +/// ``` +#[macro_export] +macro_rules! path { + ($($pieces:tt)*) => ({ + $crate::__internal_path!(@start $($pieces)*) + }); +} + +#[doc(hidden)] +#[macro_export] +// not public API +macro_rules! __internal_path { + (@start) => ( + $crate::path::end() + ); + (@start ..) => ({ + compile_error!("'..' cannot be the only segment") + }); + (@start $first:tt $(/ $tail:tt)*) => ({ + $crate::__internal_path!(@munch $crate::any(); [$first] [$(/ $tail)*]) + }); + + (@munch $sum:expr; [$cur:tt] [/ $next:tt $(/ $tail:tt)*]) => ({ + $crate::__internal_path!(@munch $crate::Filter::and($sum, $crate::__internal_path!(@segment $cur)); [$next] [$(/ $tail)*]) + }); + (@munch $sum:expr; [$cur:tt] []) => ({ + $crate::__internal_path!(@last $sum; $cur) + }); + + (@last $sum:expr; ..) => ( + $sum + ); + (@last $sum:expr; $end:tt) => ( + $crate::Filter::and( + $crate::Filter::and($sum, $crate::__internal_path!(@segment $end)), + $crate::path::end() + ) + ); + + (@segment ..) => ( + compile_error!("'..' must be the last segment") + ); + (@segment $param:ty) => ( + $crate::path::param::<$param>() + ); + // Constructs a unique ZST so the &'static str pointer doesn't need to + // be carried around. + (@segment $s:literal) => ({ + #[derive(Clone, Copy)] + struct __StaticPath; + impl ::std::convert::AsRef<str> for __StaticPath { + fn as_ref(&self) -> &str { + static S: &str = $s; + S + } + } + $crate::path(__StaticPath) + }); +} + +// path! compile fail tests + +/// ```compile_fail +/// warp::path!("foo" / .. / "bar"); +/// ``` +/// +/// ```compile_fail +/// warp::path!(.. / "bar"); +/// ``` +/// +/// ```compile_fail +/// warp::path!("foo" ..); +/// ``` +/// +/// ```compile_fail +/// warp::path!("foo" / .. /); +/// ``` +/// +/// ```compile_fail +/// warp::path!(..); +/// ``` +fn _path_macro_compile_fail() {} + +mod internal { + // Used to prevent users from naming this type. + // + // For instance, `Exact<Opaque<String>>` means a user cannot depend + // on it being `Exact<String>`. + #[allow(missing_debug_implementations)] + #[derive(Clone, Copy)] + pub struct Opaque<T>(pub(super) T); + + impl<T: AsRef<str>> AsRef<str> for Opaque<T> { + #[inline] + fn as_ref(&self) -> &str { + self.0.as_ref() + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_path_exact_size() { + use std::mem::{size_of, size_of_val}; + + assert_eq!( + size_of_val(&path("hello")), + size_of::<&str>(), + "exact(&str) is size of &str" + ); + + assert_eq!( + size_of_val(&path(String::from("world"))), + size_of::<String>(), + "exact(String) is size of String" + ); + + assert_eq!( + size_of_val(&path!("zst")), + size_of::<()>(), + "path!(&str) is ZST" + ); + } +} diff --git a/third_party/rust/warp/src/filters/query.rs b/third_party/rust/warp/src/filters/query.rs new file mode 100644 index 0000000000..7aee569484 --- /dev/null +++ b/third_party/rust/warp/src/filters/query.rs @@ -0,0 +1,91 @@ +//! Query Filters + +use futures_util::future; +use serde::de::DeserializeOwned; +use serde_urlencoded; + +use crate::filter::{filter_fn_one, Filter, One}; +use crate::reject::{self, Rejection}; + +/// Creates a `Filter` that decodes query parameters to the type `T`. +/// +/// If cannot decode into a `T`, the request is rejected with a `400 Bad Request`. +/// +/// # Example +/// +/// ``` +/// use std::collections::HashMap; +/// use warp::{ +/// http::Response, +/// Filter, +/// }; +/// +/// let route = warp::any() +/// .and(warp::query::<HashMap<String, String>>()) +/// .map(|map: HashMap<String, String>| { +/// let mut response: Vec<String> = Vec::new(); +/// for (key, value) in map.into_iter() { +/// response.push(format!("{}={}", key, value)) +/// } +/// Response::builder().body(response.join(";")) +/// }); +/// ``` +/// +/// You can define your custom query object and deserialize with [Serde][Serde]. Ensure to include +/// the crate in your dependencies before usage. +/// +/// ``` +/// use serde_derive::{Deserialize, Serialize}; +/// use std::collections::HashMap; +/// use warp::{ +/// http::Response, +/// Filter, +/// }; +/// +/// #[derive(Serialize, Deserialize)] +/// struct FooQuery { +/// foo: Option<String>, +/// bar: u8, +/// } +/// +/// let route = warp::any() +/// .and(warp::query::<FooQuery>()) +/// .map(|q: FooQuery| { +/// if let Some(foo) = q.foo { +/// Response::builder().body(format!("foo={}", foo)) +/// } else { +/// Response::builder().body(format!("bar={}", q.bar)) +/// } +/// }); +/// ``` +/// +/// For more examples, please take a look at [examples/query_string.rs](https://github.com/seanmonstar/warp/blob/master/examples/query_string.rs). +/// +/// [Serde]: https://docs.rs/serde +pub fn query<T: DeserializeOwned + Send + 'static>( +) -> impl Filter<Extract = One<T>, Error = Rejection> + Copy { + filter_fn_one(|route| { + let query_string = route.query().unwrap_or_else(|| { + tracing::debug!("route was called without a query string, defaulting to empty"); + "" + }); + + let query_encoded = serde_urlencoded::from_str(query_string).map_err(|e| { + tracing::debug!("failed to decode query string '{}': {:?}", query_string, e); + reject::invalid_query() + }); + future::ready(query_encoded) + }) +} + +/// Creates a `Filter` that returns the raw query string as type String. +pub fn raw() -> impl Filter<Extract = One<String>, Error = Rejection> + Copy { + filter_fn_one(|route| { + let route = route + .query() + .map(|q| q.to_owned()) + .map(Ok) + .unwrap_or_else(|| Err(reject::invalid_query())); + future::ready(route) + }) +} diff --git a/third_party/rust/warp/src/filters/reply.rs b/third_party/rust/warp/src/filters/reply.rs new file mode 100644 index 0000000000..c4c37bbd80 --- /dev/null +++ b/third_party/rust/warp/src/filters/reply.rs @@ -0,0 +1,257 @@ +//! Reply Filters +//! +//! These "filters" behave a little differently than the rest. Instead of +//! being used directly on requests, these filters "wrap" other filters. +//! +//! +//! ## Wrapping a `Filter` (`with`) +//! +//! ``` +//! use warp::Filter; +//! +//! let with_server = warp::reply::with::header("server", "warp"); +//! +//! let route = warp::any() +//! .map(warp::reply) +//! .with(with_server); +//! ``` +//! +//! Wrapping allows adding in conditional logic *before* the request enters +//! the inner filter (though the `with::header` wrapper does not). + +use std::convert::TryFrom; +use std::sync::Arc; + +use http::header::{HeaderMap, HeaderName, HeaderValue}; + +use self::sealed::{WithDefaultHeader_, WithHeader_, WithHeaders_}; +use crate::filter::{Filter, Map, WrapSealed}; +use crate::reply::Reply; + +/// Wrap a [`Filter`](crate::Filter) that adds a header to the reply. +/// +/// # Note +/// +/// This **only** adds a header if the underlying filter is successful, and +/// returns a [`Reply`](Reply). If the underlying filter was rejected, the +/// header is not added. +/// +/// # Example +/// +/// ``` +/// use warp::Filter; +/// +/// // Always set `foo: bar` header. +/// let route = warp::any() +/// .map(warp::reply) +/// .with(warp::reply::with::header("foo", "bar")); +/// ``` +pub fn header<K, V>(name: K, value: V) -> WithHeader +where + HeaderName: TryFrom<K>, + <HeaderName as TryFrom<K>>::Error: Into<http::Error>, + HeaderValue: TryFrom<V>, + <HeaderValue as TryFrom<V>>::Error: Into<http::Error>, +{ + let (name, value) = assert_name_and_value(name, value); + WithHeader { name, value } +} + +/// Wrap a [`Filter`](crate::Filter) that adds multiple headers to the reply. +/// +/// # Note +/// +/// This **only** adds a header if the underlying filter is successful, and +/// returns a [`Reply`](Reply). If the underlying filter was rejected, the +/// header is not added. +/// +/// # Example +/// +/// ``` +/// use warp::http::header::{HeaderMap, HeaderValue}; +/// use warp::Filter; +/// +/// let mut headers = HeaderMap::new(); +/// headers.insert("server", HeaderValue::from_static("wee/0")); +/// headers.insert("foo", HeaderValue::from_static("bar")); +/// +/// // Always set `server: wee/0` and `foo: bar` headers. +/// let route = warp::any() +/// .map(warp::reply) +/// .with(warp::reply::with::headers(headers)); +/// ``` +pub fn headers(headers: HeaderMap) -> WithHeaders { + WithHeaders { + headers: Arc::new(headers), + } +} + +// pub fn headers? + +/// Wrap a [`Filter`](crate::Filter) that adds a header to the reply, if they +/// aren't already set. +/// +/// # Note +/// +/// This **only** adds a header if the underlying filter is successful, and +/// returns a [`Reply`](Reply). If the underlying filter was rejected, the +/// header is not added. +/// +/// # Example +/// +/// ``` +/// use warp::Filter; +/// +/// // Set `server: warp` if not already set. +/// let route = warp::any() +/// .map(warp::reply) +/// .with(warp::reply::with::default_header("server", "warp")); +/// ``` +pub fn default_header<K, V>(name: K, value: V) -> WithDefaultHeader +where + HeaderName: TryFrom<K>, + <HeaderName as TryFrom<K>>::Error: Into<http::Error>, + HeaderValue: TryFrom<V>, + <HeaderValue as TryFrom<V>>::Error: Into<http::Error>, +{ + let (name, value) = assert_name_and_value(name, value); + WithDefaultHeader { name, value } +} + +/// Wrap a `Filter` to always set a header. +#[derive(Clone, Debug)] +pub struct WithHeader { + name: HeaderName, + value: HeaderValue, +} + +impl<F, R> WrapSealed<F> for WithHeader +where + F: Filter<Extract = (R,)>, + R: Reply, +{ + type Wrapped = Map<F, WithHeader_>; + + fn wrap(&self, filter: F) -> Self::Wrapped { + let with = WithHeader_ { with: self.clone() }; + filter.map(with) + } +} + +/// Wrap a `Filter` to always set multiple headers. +#[derive(Clone, Debug)] +pub struct WithHeaders { + headers: Arc<HeaderMap>, +} + +impl<F, R> WrapSealed<F> for WithHeaders +where + F: Filter<Extract = (R,)>, + R: Reply, +{ + type Wrapped = Map<F, WithHeaders_>; + + fn wrap(&self, filter: F) -> Self::Wrapped { + let with = WithHeaders_ { with: self.clone() }; + filter.map(with) + } +} + +/// Wrap a `Filter` to set a header if it is not already set. +#[derive(Clone, Debug)] +pub struct WithDefaultHeader { + name: HeaderName, + value: HeaderValue, +} + +impl<F, R> WrapSealed<F> for WithDefaultHeader +where + F: Filter<Extract = (R,)>, + R: Reply, +{ + type Wrapped = Map<F, WithDefaultHeader_>; + + fn wrap(&self, filter: F) -> Self::Wrapped { + let with = WithDefaultHeader_ { with: self.clone() }; + filter.map(with) + } +} + +fn assert_name_and_value<K, V>(name: K, value: V) -> (HeaderName, HeaderValue) +where + HeaderName: TryFrom<K>, + <HeaderName as TryFrom<K>>::Error: Into<http::Error>, + HeaderValue: TryFrom<V>, + <HeaderValue as TryFrom<V>>::Error: Into<http::Error>, +{ + let name = <HeaderName as TryFrom<K>>::try_from(name) + .map_err(Into::into) + .unwrap_or_else(|_| panic!("invalid header name")); + + let value = <HeaderValue as TryFrom<V>>::try_from(value) + .map_err(Into::into) + .unwrap_or_else(|_| panic!("invalid header value")); + + (name, value) +} + +mod sealed { + use super::{WithDefaultHeader, WithHeader, WithHeaders}; + use crate::generic::{Func, One}; + use crate::reply::{Reply, Reply_}; + + #[derive(Clone)] + #[allow(missing_debug_implementations)] + pub struct WithHeader_ { + pub(super) with: WithHeader, + } + + impl<R: Reply> Func<One<R>> for WithHeader_ { + type Output = Reply_; + + fn call(&self, args: One<R>) -> Self::Output { + let mut resp = args.0.into_response(); + // Use "insert" to replace any set header... + resp.headers_mut() + .insert(&self.with.name, self.with.value.clone()); + Reply_(resp) + } + } + + #[derive(Clone)] + #[allow(missing_debug_implementations)] + pub struct WithHeaders_ { + pub(super) with: WithHeaders, + } + + impl<R: Reply> Func<One<R>> for WithHeaders_ { + type Output = Reply_; + + fn call(&self, args: One<R>) -> Self::Output { + let mut resp = args.0.into_response(); + for (name, value) in &*self.with.headers { + resp.headers_mut().insert(name, value.clone()); + } + Reply_(resp) + } + } + + #[derive(Clone)] + #[allow(missing_debug_implementations)] + pub struct WithDefaultHeader_ { + pub(super) with: WithDefaultHeader, + } + + impl<R: Reply> Func<One<R>> for WithDefaultHeader_ { + type Output = Reply_; + + fn call(&self, args: One<R>) -> Self::Output { + let mut resp = args.0.into_response(); + resp.headers_mut() + .entry(&self.with.name) + .or_insert_with(|| self.with.value.clone()); + + Reply_(resp) + } + } +} diff --git a/third_party/rust/warp/src/filters/sse.rs b/third_party/rust/warp/src/filters/sse.rs new file mode 100644 index 0000000000..2fb9b7ef11 --- /dev/null +++ b/third_party/rust/warp/src/filters/sse.rs @@ -0,0 +1,511 @@ +//! Server-Sent Events (SSE) +//! +//! # Example +//! +//! ``` +//! +//! use std::time::Duration; +//! use std::convert::Infallible; +//! use warp::{Filter, sse::Event}; +//! use futures_util::{stream::iter, Stream}; +//! +//! fn sse_events() -> impl Stream<Item = Result<Event, Infallible>> { +//! iter(vec![ +//! Ok(Event::default().data("unnamed event")), +//! Ok( +//! Event::default().event("chat") +//! .data("chat message") +//! ), +//! Ok( +//! Event::default().id(13.to_string()) +//! .event("chat") +//! .data("other chat message\nwith next line") +//! .retry(Duration::from_millis(5000)) +//! ) +//! ]) +//! } +//! +//! let app = warp::path("push-notifications") +//! .and(warp::get()) +//! .map(|| { +//! warp::sse::reply(warp::sse::keep_alive().stream(sse_events())) +//! }); +//! ``` +//! +//! Each field already is event which can be sent to client. +//! The events with multiple fields can be created by combining fields using tuples. +//! +//! See also the [EventSource](https://developer.mozilla.org/en-US/docs/Web/API/EventSource) API, +//! which specifies the expected behavior of Server Sent Events. +//! + +use serde::Serialize; +use std::borrow::Cow; +use std::error::Error as StdError; +use std::fmt::{self, Write}; +use std::future::Future; +use std::pin::Pin; +use std::str::FromStr; +use std::task::{Context, Poll}; +use std::time::Duration; + +use futures_util::{future, Stream, TryStream, TryStreamExt}; +use http::header::{HeaderValue, CACHE_CONTROL, CONTENT_TYPE}; +use hyper::Body; +use pin_project::pin_project; +use serde_json::{self, Error}; +use tokio::time::{self, Sleep}; + +use self::sealed::SseError; +use super::header; +use crate::filter::One; +use crate::reply::Response; +use crate::{Filter, Rejection, Reply}; + +// Server-sent event data type +#[derive(Debug)] +enum DataType { + Text(String), + Json(String), +} + +/// Server-sent event +#[derive(Default, Debug)] +pub struct Event { + id: Option<String>, + data: Option<DataType>, + event: Option<String>, + comment: Option<String>, + retry: Option<Duration>, +} + +impl Event { + /// Set Server-sent event data + /// data field(s) ("data:<content>") + pub fn data<T: Into<String>>(mut self, data: T) -> Event { + self.data = Some(DataType::Text(data.into())); + self + } + + /// Set Server-sent event data + /// data field(s) ("data:<content>") + pub fn json_data<T: Serialize>(mut self, data: T) -> Result<Event, Error> { + self.data = Some(DataType::Json(serde_json::to_string(&data)?)); + Ok(self) + } + + /// Set Server-sent event comment + /// Comment field (":<comment-text>") + pub fn comment<T: Into<String>>(mut self, comment: T) -> Event { + self.comment = Some(comment.into()); + self + } + + /// Set Server-sent event event + /// Event name field ("event:<event-name>") + pub fn event<T: Into<String>>(mut self, event: T) -> Event { + self.event = Some(event.into()); + self + } + + /// Set Server-sent event retry + /// Retry timeout field ("retry:<timeout>") + pub fn retry(mut self, duration: Duration) -> Event { + self.retry = Some(duration); + self + } + + /// Set Server-sent event id + /// Identifier field ("id:<identifier>") + pub fn id<T: Into<String>>(mut self, id: T) -> Event { + self.id = Some(id.into()); + self + } +} + +impl fmt::Display for Event { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + if let Some(ref comment) = &self.comment { + ":".fmt(f)?; + comment.fmt(f)?; + f.write_char('\n')?; + } + + if let Some(ref event) = &self.event { + "event:".fmt(f)?; + event.fmt(f)?; + f.write_char('\n')?; + } + + match self.data { + Some(DataType::Text(ref data)) => { + for line in data.split('\n') { + "data:".fmt(f)?; + line.fmt(f)?; + f.write_char('\n')?; + } + } + Some(DataType::Json(ref data)) => { + "data:".fmt(f)?; + data.fmt(f)?; + f.write_char('\n')?; + } + None => {} + } + + if let Some(ref id) = &self.id { + "id:".fmt(f)?; + id.fmt(f)?; + f.write_char('\n')?; + } + + if let Some(ref duration) = &self.retry { + "retry:".fmt(f)?; + + let secs = duration.as_secs(); + let millis = duration.subsec_millis(); + + if secs > 0 { + // format seconds + secs.fmt(f)?; + + // pad milliseconds + if millis < 10 { + f.write_str("00")?; + } else if millis < 100 { + f.write_char('0')?; + } + } + + // format milliseconds + millis.fmt(f)?; + + f.write_char('\n')?; + } + + f.write_char('\n')?; + Ok(()) + } +} + +/// Gets the optional last event id from request. +/// Typically this identifier represented as number or string. +/// +/// ``` +/// let app = warp::sse::last_event_id::<u32>(); +/// +/// // The identifier is present +/// async { +/// assert_eq!( +/// warp::test::request() +/// .header("Last-Event-ID", "12") +/// .filter(&app) +/// .await +/// .unwrap(), +/// Some(12) +/// ); +/// +/// // The identifier is missing +/// assert_eq!( +/// warp::test::request() +/// .filter(&app) +/// .await +/// .unwrap(), +/// None +/// ); +/// +/// // The identifier is not a valid +/// assert!( +/// warp::test::request() +/// .header("Last-Event-ID", "abc") +/// .filter(&app) +/// .await +/// .is_err(), +/// ); +///}; +/// ``` +pub fn last_event_id<T>() -> impl Filter<Extract = One<Option<T>>, Error = Rejection> + Copy +where + T: FromStr + Send + Sync + 'static, +{ + header::optional("last-event-id") +} + +/// Server-sent events reply +/// +/// This function converts stream of server events into a `Reply` with: +/// +/// - Status of `200 OK` +/// - Header `content-type: text/event-stream` +/// - Header `cache-control: no-cache`. +/// +/// # Example +/// +/// ``` +/// +/// use std::time::Duration; +/// use futures_util::Stream; +/// use futures_util::stream::iter; +/// use std::convert::Infallible; +/// use warp::{Filter, sse::Event}; +/// use serde_derive::Serialize; +/// +/// #[derive(Serialize)] +/// struct Msg { +/// from: u32, +/// text: String, +/// } +/// +/// fn event_stream() -> impl Stream<Item = Result<Event, Infallible>> { +/// iter(vec![ +/// // Unnamed event with data only +/// Ok(Event::default().data("payload")), +/// // Named event with ID and retry timeout +/// Ok( +/// Event::default().data("other message\nwith next line") +/// .event("chat") +/// .id(1.to_string()) +/// .retry(Duration::from_millis(15000)) +/// ), +/// // Event with JSON data +/// Ok( +/// Event::default().id(2.to_string()) +/// .json_data(Msg { +/// from: 2, +/// text: "hello".into(), +/// }).unwrap(), +/// ) +/// ]) +/// } +/// +/// async { +/// let app = warp::path("sse").and(warp::get()).map(|| { +/// warp::sse::reply(event_stream()) +/// }); +/// +/// let res = warp::test::request() +/// .method("GET") +/// .header("Connection", "Keep-Alive") +/// .path("/sse") +/// .reply(&app) +/// .await +/// .into_body(); +/// +/// assert_eq!( +/// res, +/// r#"data:payload +/// +/// event:chat +/// data:other message +/// data:with next line +/// id:1 +/// retry:15000 +/// +/// data:{"from":2,"text":"hello"} +/// id:2 +/// +/// "# +/// ); +/// }; +/// ``` +pub fn reply<S>(event_stream: S) -> impl Reply +where + S: TryStream<Ok = Event> + Send + 'static, + S::Error: StdError + Send + Sync + 'static, +{ + SseReply { event_stream } +} + +#[allow(missing_debug_implementations)] +struct SseReply<S> { + event_stream: S, +} + +impl<S> Reply for SseReply<S> +where + S: TryStream<Ok = Event> + Send + 'static, + S::Error: StdError + Send + Sync + 'static, +{ + #[inline] + fn into_response(self) -> Response { + let body_stream = self + .event_stream + .map_err(|error| { + // FIXME: error logging + log::error!("sse stream error: {}", error); + SseError + }) + .into_stream() + .and_then(|event| future::ready(Ok(event.to_string()))); + + let mut res = Response::new(Body::wrap_stream(body_stream)); + // Set appropriate content type + res.headers_mut() + .insert(CONTENT_TYPE, HeaderValue::from_static("text/event-stream")); + // Disable response body caching + res.headers_mut() + .insert(CACHE_CONTROL, HeaderValue::from_static("no-cache")); + res + } +} + +/// Configure the interval between keep-alive messages, the content +/// of each message, and the associated stream. +#[derive(Debug)] +pub struct KeepAlive { + comment_text: Cow<'static, str>, + max_interval: Duration, +} + +impl KeepAlive { + /// Customize the interval between keep-alive messages. + /// + /// Default is 15 seconds. + pub fn interval(mut self, time: Duration) -> Self { + self.max_interval = time; + self + } + + /// Customize the text of the keep-alive message. + /// + /// Default is an empty comment. + pub fn text(mut self, text: impl Into<Cow<'static, str>>) -> Self { + self.comment_text = text.into(); + self + } + + /// Wrap an event stream with keep-alive functionality. + /// + /// See [`keep_alive`](keep_alive) for more. + pub fn stream<S>( + self, + event_stream: S, + ) -> impl TryStream<Ok = Event, Error = impl StdError + Send + Sync + 'static> + Send + 'static + where + S: TryStream<Ok = Event> + Send + 'static, + S::Error: StdError + Send + Sync + 'static, + { + let alive_timer = time::sleep(self.max_interval); + SseKeepAlive { + event_stream, + comment_text: self.comment_text, + max_interval: self.max_interval, + alive_timer, + } + } +} + +#[allow(missing_debug_implementations)] +#[pin_project] +struct SseKeepAlive<S> { + #[pin] + event_stream: S, + comment_text: Cow<'static, str>, + max_interval: Duration, + #[pin] + alive_timer: Sleep, +} + +/// Keeps event source connection alive when no events sent over a some time. +/// +/// Some proxy servers may drop HTTP connection after a some timeout of inactivity. +/// This function helps to prevent such behavior by sending comment events every +/// `keep_interval` of inactivity. +/// +/// By default the comment is `:` (an empty comment) and the time interval between +/// events is 15 seconds. Both may be customized using the builder pattern +/// as shown below. +/// +/// ``` +/// use std::time::Duration; +/// use std::convert::Infallible; +/// use futures_util::StreamExt; +/// use tokio::time::interval; +/// use tokio_stream::wrappers::IntervalStream; +/// use warp::{Filter, Stream, sse::Event}; +/// +/// // create server-sent event +/// fn sse_counter(counter: u64) -> Result<Event, Infallible> { +/// Ok(Event::default().data(counter.to_string())) +/// } +/// +/// fn main() { +/// let routes = warp::path("ticks") +/// .and(warp::get()) +/// .map(|| { +/// let mut counter: u64 = 0; +/// let interval = interval(Duration::from_secs(15)); +/// let stream = IntervalStream::new(interval); +/// let event_stream = stream.map(move |_| { +/// counter += 1; +/// sse_counter(counter) +/// }); +/// // reply using server-sent events +/// let stream = warp::sse::keep_alive() +/// .interval(Duration::from_secs(5)) +/// .text("thump".to_string()) +/// .stream(event_stream); +/// warp::sse::reply(stream) +/// }); +/// } +/// ``` +/// +/// See [notes](https://www.w3.org/TR/2009/WD-eventsource-20090421/#notes). +pub fn keep_alive() -> KeepAlive { + KeepAlive { + comment_text: Cow::Borrowed(""), + max_interval: Duration::from_secs(15), + } +} + +impl<S> Stream for SseKeepAlive<S> +where + S: TryStream<Ok = Event> + Send + 'static, + S::Error: StdError + Send + Sync + 'static, +{ + type Item = Result<Event, SseError>; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> { + let mut pin = self.project(); + match pin.event_stream.try_poll_next(cx) { + Poll::Pending => match Pin::new(&mut pin.alive_timer).poll(cx) { + Poll::Pending => Poll::Pending, + Poll::Ready(_) => { + // restart timer + pin.alive_timer + .reset(tokio::time::Instant::now() + *pin.max_interval); + let comment_str = pin.comment_text.clone(); + let event = Event::default().comment(comment_str); + Poll::Ready(Some(Ok(event))) + } + }, + Poll::Ready(Some(Ok(event))) => { + // restart timer + pin.alive_timer + .reset(tokio::time::Instant::now() + *pin.max_interval); + Poll::Ready(Some(Ok(event))) + } + Poll::Ready(None) => Poll::Ready(None), + Poll::Ready(Some(Err(error))) => { + log::error!("sse::keep error: {}", error); + Poll::Ready(Some(Err(SseError))) + } + } + } +} + +mod sealed { + use super::*; + + /// SSE error type + #[derive(Debug)] + pub struct SseError; + + impl fmt::Display for SseError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "sse error") + } + } + + impl StdError for SseError {} +} diff --git a/third_party/rust/warp/src/filters/trace.rs b/third_party/rust/warp/src/filters/trace.rs new file mode 100644 index 0000000000..5ca4e9df96 --- /dev/null +++ b/third_party/rust/warp/src/filters/trace.rs @@ -0,0 +1,309 @@ +//! [`tracing`] filters. +//! +//! [`tracing`] is a framework for instrumenting Rust programs to +//! collect scoped, structured, and async-aware diagnostics. This module +//! provides a set of filters for instrumenting Warp applications with `tracing` +//! spans. [`Spans`] can be used to associate individual events with a request, +//! and track contexts through the application. +//! +//! [`tracing`]: https://crates.io/crates/tracing +//! [`Spans`]: https://docs.rs/tracing/latest/tracing/#spans +use tracing::Span; + +use std::net::SocketAddr; + +use http::{self, header}; + +use crate::filter::{Filter, WrapSealed}; +use crate::reject::IsReject; +use crate::reply::Reply; +use crate::route::Route; + +use self::internal::WithTrace; + +/// Create a wrapping filter that instruments every request with a `tracing` +/// [`Span`] at the [`INFO`] level, containing a summary of the request. +/// Additionally, if the [`DEBUG`] level is enabled, the span will contain an +/// event recording the request's headers. +/// +/// # Example +/// +/// ``` +/// use warp::Filter; +/// +/// let route = warp::any() +/// .map(warp::reply) +/// .with(warp::trace::request()); +/// ``` +/// +/// [`Span`]: https://docs.rs/tracing/latest/tracing/#spans +/// [`INFO`]: https://docs.rs/tracing/0.1.16/tracing/struct.Level.html#associatedconstant.INFO +/// [`DEBUG`]: https://docs.rs/tracing/0.1.16/tracing/struct.Level.html#associatedconstant.DEBUG +pub fn request() -> Trace<impl Fn(Info<'_>) -> Span + Clone> { + use tracing::field::{display, Empty}; + trace(|info: Info<'_>| { + let span = tracing::info_span!( + "request", + remote.addr = Empty, + method = %info.method(), + path = %info.path(), + version = ?info.route.version(), + referer = Empty, + ); + + // Record optional fields. + if let Some(remote_addr) = info.remote_addr() { + span.record("remote.addr", &display(remote_addr)); + } + + if let Some(referer) = info.referer() { + span.record("referer", &display(referer)); + } + + tracing::debug!(parent: &span, "received request"); + + span + }) +} + +/// Create a wrapping filter that instruments every request with a custom +/// `tracing` [`Span`] provided by a function. +/// +/// +/// # Example +/// +/// ``` +/// use warp::Filter; +/// +/// let route = warp::any() +/// .map(warp::reply) +/// .with(warp::trace(|info| { +/// // Create a span using tracing macros +/// tracing::info_span!( +/// "request", +/// method = %info.method(), +/// path = %info.path(), +/// ) +/// })); +/// ``` +/// +/// [`Span`]: https://docs.rs/tracing/latest/tracing/#spans +pub fn trace<F>(func: F) -> Trace<F> +where + F: Fn(Info<'_>) -> Span + Clone, +{ + Trace { func } +} + +/// Create a wrapping filter that instruments every request with a `tracing` +/// [`Span`] at the [`DEBUG`] level representing a named context. +/// +/// This can be used to instrument multiple routes with their own sub-spans in a +/// per-request trace. +/// +/// # Example +/// +/// ``` +/// use warp::Filter; +/// +/// let hello = warp::path("hello") +/// .map(warp::reply) +/// .with(warp::trace::named("hello")); +/// +/// let goodbye = warp::path("goodbye") +/// .map(warp::reply) +/// .with(warp::trace::named("goodbye")); +/// +/// let routes = hello.or(goodbye); +/// ``` +/// +/// [`Span`]: https://docs.rs/tracing/latest/tracing/#spans +/// [`DEBUG`]: https://docs.rs/tracing/0.1.16/tracing/struct.Level.html#associatedconstant.DEBUG +pub fn named(name: &'static str) -> Trace<impl Fn(Info<'_>) -> Span + Copy> { + trace(move |_| tracing::debug_span!("context", "{}", name,)) +} + +/// Decorates a [`Filter`](crate::Filter) to create a [`tracing`] [span] for +/// requests and responses. +/// +/// [`tracing`]: https://crates.io/crates/tracing +/// [span]: https://docs.rs/tracing/latest/tracing/#spans +#[derive(Clone, Copy, Debug)] +pub struct Trace<F> { + func: F, +} + +/// Information about the request/response that can be used to prepare log lines. +#[allow(missing_debug_implementations)] +pub struct Info<'a> { + route: &'a Route, +} + +impl<FN, F> WrapSealed<F> for Trace<FN> +where + FN: Fn(Info<'_>) -> Span + Clone + Send, + F: Filter + Clone + Send, + F::Extract: Reply, + F::Error: IsReject, +{ + type Wrapped = WithTrace<FN, F>; + + fn wrap(&self, filter: F) -> Self::Wrapped { + WithTrace { + filter, + trace: self.clone(), + } + } +} + +impl<'a> Info<'a> { + /// View the remote `SocketAddr` of the request. + pub fn remote_addr(&self) -> Option<SocketAddr> { + self.route.remote_addr() + } + + /// View the `http::Method` of the request. + pub fn method(&self) -> &http::Method { + self.route.method() + } + + /// View the URI path of the request. + pub fn path(&self) -> &str { + self.route.full_path() + } + + /// View the `http::Version` of the request. + pub fn version(&self) -> http::Version { + self.route.version() + } + + /// View the referer of the request. + pub fn referer(&self) -> Option<&str> { + self.route + .headers() + .get(header::REFERER) + .and_then(|v| v.to_str().ok()) + } + + /// View the user agent of the request. + pub fn user_agent(&self) -> Option<&str> { + self.route + .headers() + .get(header::USER_AGENT) + .and_then(|v| v.to_str().ok()) + } + + /// View the host of the request + pub fn host(&self) -> Option<&str> { + self.route + .headers() + .get(header::HOST) + .and_then(|v| v.to_str().ok()) + } + + /// View the request headers. + pub fn request_headers(&self) -> &http::HeaderMap { + self.route.headers() + } +} + +mod internal { + use futures_util::{future::Inspect, future::MapOk, FutureExt, TryFutureExt}; + + use super::{Info, Trace}; + use crate::filter::{Filter, FilterBase, Internal}; + use crate::reject::IsReject; + use crate::reply::Reply; + use crate::reply::Response; + use crate::route; + + #[allow(missing_debug_implementations)] + pub struct Traced(pub(super) Response); + + impl Reply for Traced { + #[inline] + fn into_response(self) -> Response { + self.0 + } + } + + #[allow(missing_debug_implementations)] + #[derive(Clone, Copy)] + pub struct WithTrace<FN, F> { + pub(super) filter: F, + pub(super) trace: Trace<FN>, + } + + use tracing::instrument::{Instrument, Instrumented}; + use tracing::Span; + + fn finished_logger<E: IsReject>(reply: &Result<(Traced,), E>) { + let (status, error) = match reply { + Ok((Traced(resp),)) => (resp.status(), None), + Err(error) => (error.status(), Some(error)), + }; + + if status.is_success() { + tracing::info!( + target: "warp::filters::trace", + status = status.as_u16(), + "finished processing with success" + ); + } else if status.is_server_error() { + tracing::error!( + target: "warp::filters::trace", + status = status.as_u16(), + error = ?error, + "unable to process request (internal error)" + ); + } else if status.is_client_error() { + tracing::warn!( + target: "warp::filters::trace", + status = status.as_u16(), + error = ?error, + "unable to serve request (client error)" + ); + } else { + // Either informational or redirect + tracing::info!( + target: "warp::filters::trace", + status = status.as_u16(), + error = ?error, + "finished processing with status" + ); + } + } + + fn convert_reply<R: Reply>(reply: R) -> (Traced,) { + (Traced(reply.into_response()),) + } + + impl<FN, F> FilterBase for WithTrace<FN, F> + where + FN: Fn(Info<'_>) -> Span + Clone + Send, + F: Filter + Clone + Send, + F::Extract: Reply, + F::Error: IsReject, + { + type Extract = (Traced,); + type Error = F::Error; + type Future = Instrumented< + Inspect< + MapOk<F::Future, fn(F::Extract) -> Self::Extract>, + fn(&Result<Self::Extract, F::Error>), + >, + >; + + fn filter(&self, _: Internal) -> Self::Future { + let span = route::with(|route| (self.trace.func)(Info { route })); + let _entered = span.enter(); + + tracing::info!(target: "warp::filters::trace", "processing request"); + self.filter + .filter(Internal) + .map_ok(convert_reply as fn(F::Extract) -> Self::Extract) + .inspect(finished_logger as fn(&Result<Self::Extract, F::Error>)) + .instrument(span.clone()) + } + } +} diff --git a/third_party/rust/warp/src/filters/ws.rs b/third_party/rust/warp/src/filters/ws.rs new file mode 100644 index 0000000000..1e953c7a35 --- /dev/null +++ b/third_party/rust/warp/src/filters/ws.rs @@ -0,0 +1,410 @@ +//! Websockets Filters + +use std::borrow::Cow; +use std::fmt; +use std::future::Future; +use std::pin::Pin; +use std::task::{Context, Poll}; + +use super::header; +use crate::filter::{filter_fn_one, Filter, One}; +use crate::reject::Rejection; +use crate::reply::{Reply, Response}; +use futures_util::{future, ready, FutureExt, Sink, Stream, TryFutureExt}; +use headers::{Connection, HeaderMapExt, SecWebsocketAccept, SecWebsocketKey, Upgrade}; +use http; +use hyper::upgrade::OnUpgrade; +use tokio_tungstenite::{ + tungstenite::protocol::{self, WebSocketConfig}, + WebSocketStream, +}; + +/// Creates a Websocket Filter. +/// +/// The yielded `Ws` is used to finish the websocket upgrade. +/// +/// # Note +/// +/// This filter combines multiple filters internally, so you don't need them: +/// +/// - Method must be `GET` +/// - Header `connection` must be `upgrade` +/// - Header `upgrade` must be `websocket` +/// - Header `sec-websocket-version` must be `13` +/// - Header `sec-websocket-key` must be set. +/// +/// If the filters are met, yields a `Ws`. Calling `Ws::on_upgrade` will +/// return a reply with: +/// +/// - Status of `101 Switching Protocols` +/// - Header `connection: upgrade` +/// - Header `upgrade: websocket` +/// - Header `sec-websocket-accept` with the hash value of the received key. +pub fn ws() -> impl Filter<Extract = One<Ws>, Error = Rejection> + Copy { + let connection_has_upgrade = header::header2() + .and_then(|conn: ::headers::Connection| { + if conn.contains("upgrade") { + future::ok(()) + } else { + future::err(crate::reject::known(MissingConnectionUpgrade)) + } + }) + .untuple_one(); + + crate::get() + .and(connection_has_upgrade) + .and(header::exact_ignore_case("upgrade", "websocket")) + .and(header::exact("sec-websocket-version", "13")) + //.and(header::exact2(Upgrade::websocket())) + //.and(header::exact2(SecWebsocketVersion::V13)) + .and(header::header2::<SecWebsocketKey>()) + .and(on_upgrade()) + .map( + move |key: SecWebsocketKey, on_upgrade: Option<OnUpgrade>| Ws { + config: None, + key, + on_upgrade, + }, + ) +} + +/// Extracted by the [`ws`](ws) filter, and used to finish an upgrade. +pub struct Ws { + config: Option<WebSocketConfig>, + key: SecWebsocketKey, + on_upgrade: Option<OnUpgrade>, +} + +impl Ws { + /// Finish the upgrade, passing a function to handle the `WebSocket`. + /// + /// The passed function must return a `Future`. + pub fn on_upgrade<F, U>(self, func: F) -> impl Reply + where + F: FnOnce(WebSocket) -> U + Send + 'static, + U: Future<Output = ()> + Send + 'static, + { + WsReply { + ws: self, + on_upgrade: func, + } + } + + // config + + /// Set the size of the internal message send queue. + pub fn max_send_queue(mut self, max: usize) -> Self { + self.config + .get_or_insert_with(WebSocketConfig::default) + .max_send_queue = Some(max); + self + } + + /// Set the maximum message size (defaults to 64 megabytes) + pub fn max_message_size(mut self, max: usize) -> Self { + self.config + .get_or_insert_with(WebSocketConfig::default) + .max_message_size = Some(max); + self + } + + /// Set the maximum frame size (defaults to 16 megabytes) + pub fn max_frame_size(mut self, max: usize) -> Self { + self.config + .get_or_insert_with(WebSocketConfig::default) + .max_frame_size = Some(max); + self + } +} + +impl fmt::Debug for Ws { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("Ws").finish() + } +} + +#[allow(missing_debug_implementations)] +struct WsReply<F> { + ws: Ws, + on_upgrade: F, +} + +impl<F, U> Reply for WsReply<F> +where + F: FnOnce(WebSocket) -> U + Send + 'static, + U: Future<Output = ()> + Send + 'static, +{ + fn into_response(self) -> Response { + if let Some(on_upgrade) = self.ws.on_upgrade { + let on_upgrade_cb = self.on_upgrade; + let config = self.ws.config; + let fut = on_upgrade + .and_then(move |upgraded| { + tracing::trace!("websocket upgrade complete"); + WebSocket::from_raw_socket(upgraded, protocol::Role::Server, config).map(Ok) + }) + .and_then(move |socket| on_upgrade_cb(socket).map(Ok)) + .map(|result| { + if let Err(err) = result { + tracing::debug!("ws upgrade error: {}", err); + } + }); + ::tokio::task::spawn(fut); + } else { + tracing::debug!("ws couldn't be upgraded since no upgrade state was present"); + } + + let mut res = http::Response::default(); + + *res.status_mut() = http::StatusCode::SWITCHING_PROTOCOLS; + + res.headers_mut().typed_insert(Connection::upgrade()); + res.headers_mut().typed_insert(Upgrade::websocket()); + res.headers_mut() + .typed_insert(SecWebsocketAccept::from(self.ws.key)); + + res + } +} + +// Extracts OnUpgrade state from the route. +fn on_upgrade() -> impl Filter<Extract = (Option<OnUpgrade>,), Error = Rejection> + Copy { + filter_fn_one(|route| future::ready(Ok(route.extensions_mut().remove::<OnUpgrade>()))) +} + +/// A websocket `Stream` and `Sink`, provided to `ws` filters. +/// +/// Ping messages sent from the client will be handled internally by replying with a Pong message. +/// Close messages need to be handled explicitly: usually by closing the `Sink` end of the +/// `WebSocket`. +/// +/// **Note!** +/// Due to rust futures nature, pings won't be handled until read part of `WebSocket` is polled + +pub struct WebSocket { + inner: WebSocketStream<hyper::upgrade::Upgraded>, +} + +impl WebSocket { + pub(crate) async fn from_raw_socket( + upgraded: hyper::upgrade::Upgraded, + role: protocol::Role, + config: Option<protocol::WebSocketConfig>, + ) -> Self { + WebSocketStream::from_raw_socket(upgraded, role, config) + .map(|inner| WebSocket { inner }) + .await + } + + /// Gracefully close this websocket. + pub async fn close(mut self) -> Result<(), crate::Error> { + future::poll_fn(|cx| Pin::new(&mut self).poll_close(cx)).await + } +} + +impl Stream for WebSocket { + type Item = Result<Message, crate::Error>; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> { + match ready!(Pin::new(&mut self.inner).poll_next(cx)) { + Some(Ok(item)) => Poll::Ready(Some(Ok(Message { inner: item }))), + Some(Err(e)) => { + tracing::debug!("websocket poll error: {}", e); + Poll::Ready(Some(Err(crate::Error::new(e)))) + } + None => { + tracing::trace!("websocket closed"); + Poll::Ready(None) + } + } + } +} + +impl Sink<Message> for WebSocket { + type Error = crate::Error; + + fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { + match ready!(Pin::new(&mut self.inner).poll_ready(cx)) { + Ok(()) => Poll::Ready(Ok(())), + Err(e) => Poll::Ready(Err(crate::Error::new(e))), + } + } + + fn start_send(mut self: Pin<&mut Self>, item: Message) -> Result<(), Self::Error> { + match Pin::new(&mut self.inner).start_send(item.inner) { + Ok(()) => Ok(()), + Err(e) => { + tracing::debug!("websocket start_send error: {}", e); + Err(crate::Error::new(e)) + } + } + } + + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { + match ready!(Pin::new(&mut self.inner).poll_flush(cx)) { + Ok(()) => Poll::Ready(Ok(())), + Err(e) => Poll::Ready(Err(crate::Error::new(e))), + } + } + + fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { + match ready!(Pin::new(&mut self.inner).poll_close(cx)) { + Ok(()) => Poll::Ready(Ok(())), + Err(err) => { + tracing::debug!("websocket close error: {}", err); + Poll::Ready(Err(crate::Error::new(err))) + } + } + } +} + +impl fmt::Debug for WebSocket { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("WebSocket").finish() + } +} + +/// A WebSocket message. +/// +/// This will likely become a `non-exhaustive` enum in the future, once that +/// language feature has stabilized. +#[derive(Eq, PartialEq, Clone)] +pub struct Message { + inner: protocol::Message, +} + +impl Message { + /// Construct a new Text `Message`. + pub fn text<S: Into<String>>(s: S) -> Message { + Message { + inner: protocol::Message::text(s), + } + } + + /// Construct a new Binary `Message`. + pub fn binary<V: Into<Vec<u8>>>(v: V) -> Message { + Message { + inner: protocol::Message::binary(v), + } + } + + /// Construct a new Ping `Message`. + pub fn ping<V: Into<Vec<u8>>>(v: V) -> Message { + Message { + inner: protocol::Message::Ping(v.into()), + } + } + + /// Construct a new Pong `Message`. + /// + /// Note that one rarely needs to manually construct a Pong message because the underlying tungstenite socket + /// automatically responds to the Ping messages it receives. Manual construction might still be useful in some cases + /// like in tests or to send unidirectional heartbeats. + pub fn pong<V: Into<Vec<u8>>>(v: V) -> Message { + Message { + inner: protocol::Message::Pong(v.into()), + } + } + + /// Construct the default Close `Message`. + pub fn close() -> Message { + Message { + inner: protocol::Message::Close(None), + } + } + + /// Construct a Close `Message` with a code and reason. + pub fn close_with(code: impl Into<u16>, reason: impl Into<Cow<'static, str>>) -> Message { + Message { + inner: protocol::Message::Close(Some(protocol::frame::CloseFrame { + code: protocol::frame::coding::CloseCode::from(code.into()), + reason: reason.into(), + })), + } + } + + /// Returns true if this message is a Text message. + pub fn is_text(&self) -> bool { + self.inner.is_text() + } + + /// Returns true if this message is a Binary message. + pub fn is_binary(&self) -> bool { + self.inner.is_binary() + } + + /// Returns true if this message a is a Close message. + pub fn is_close(&self) -> bool { + self.inner.is_close() + } + + /// Returns true if this message is a Ping message. + pub fn is_ping(&self) -> bool { + self.inner.is_ping() + } + + /// Returns true if this message is a Pong message. + pub fn is_pong(&self) -> bool { + self.inner.is_pong() + } + + /// Try to get the close frame (close code and reason) + pub fn close_frame(&self) -> Option<(u16, &str)> { + if let protocol::Message::Close(Some(ref close_frame)) = self.inner { + Some((close_frame.code.into(), close_frame.reason.as_ref())) + } else { + None + } + } + + /// Try to get a reference to the string text, if this is a Text message. + pub fn to_str(&self) -> Result<&str, ()> { + match self.inner { + protocol::Message::Text(ref s) => Ok(s), + _ => Err(()), + } + } + + /// Return the bytes of this message, if the message can contain data. + pub fn as_bytes(&self) -> &[u8] { + match self.inner { + protocol::Message::Text(ref s) => s.as_bytes(), + protocol::Message::Binary(ref v) => v, + protocol::Message::Ping(ref v) => v, + protocol::Message::Pong(ref v) => v, + protocol::Message::Close(_) => &[], + protocol::Message::Frame(ref frame) => frame.payload(), + } + } + + /// Destructure this message into binary data. + pub fn into_bytes(self) -> Vec<u8> { + self.inner.into_data() + } +} + +impl fmt::Debug for Message { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt::Debug::fmt(&self.inner, f) + } +} + +impl From<Message> for Vec<u8> { + fn from(m: Message) -> Self { + m.into_bytes() + } +} + +// ===== Rejections ===== + +/// Connection header did not include 'upgrade' +#[derive(Debug)] +pub struct MissingConnectionUpgrade; + +impl fmt::Display for MissingConnectionUpgrade { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "Connection header did not include 'upgrade'") + } +} + +impl ::std::error::Error for MissingConnectionUpgrade {} diff --git a/third_party/rust/warp/src/generic.rs b/third_party/rust/warp/src/generic.rs new file mode 100644 index 0000000000..5350d22156 --- /dev/null +++ b/third_party/rust/warp/src/generic.rs @@ -0,0 +1,250 @@ +#[derive(Debug)] +pub struct Product<H, T: HList>(pub(crate) H, pub(crate) T); + +pub type One<T> = (T,); + +#[inline] +pub(crate) fn one<T>(val: T) -> One<T> { + (val,) +} + +#[derive(Debug)] +pub enum Either<T, U> { + A(T), + B(U), +} + +// Converts Product (and ()) into tuples. +pub trait HList: Sized { + type Tuple: Tuple<HList = Self>; + + fn flatten(self) -> Self::Tuple; +} + +// Typeclass that tuples can be converted into a Product (or unit ()). +pub trait Tuple: Sized { + type HList: HList<Tuple = Self>; + + fn hlist(self) -> Self::HList; + + #[inline] + fn combine<T>(self, other: T) -> CombinedTuples<Self, T> + where + Self: Sized, + T: Tuple, + Self::HList: Combine<T::HList>, + { + self.hlist().combine(other.hlist()).flatten() + } +} + +pub type CombinedTuples<T, U> = + <<<T as Tuple>::HList as Combine<<U as Tuple>::HList>>::Output as HList>::Tuple; + +// Combines Product together. +pub trait Combine<T: HList> { + type Output: HList; + + fn combine(self, other: T) -> Self::Output; +} + +pub trait Func<Args> { + type Output; + + fn call(&self, args: Args) -> Self::Output; +} + +// ===== impl Combine ===== + +impl<T: HList> Combine<T> for () { + type Output = T; + #[inline] + fn combine(self, other: T) -> Self::Output { + other + } +} + +impl<H, T: HList, U: HList> Combine<U> for Product<H, T> +where + T: Combine<U>, + Product<H, <T as Combine<U>>::Output>: HList, +{ + type Output = Product<H, <T as Combine<U>>::Output>; + + #[inline] + fn combine(self, other: U) -> Self::Output { + Product(self.0, self.1.combine(other)) + } +} + +impl HList for () { + type Tuple = (); + #[inline] + fn flatten(self) -> Self::Tuple {} +} + +impl Tuple for () { + type HList = (); + + #[inline] + fn hlist(self) -> Self::HList {} +} + +impl<F, R> Func<()> for F +where + F: Fn() -> R, +{ + type Output = R; + + #[inline] + fn call(&self, _args: ()) -> Self::Output { + (*self)() + } +} + +impl<F, R> Func<crate::Rejection> for F +where + F: Fn(crate::Rejection) -> R, +{ + type Output = R; + + #[inline] + fn call(&self, arg: crate::Rejection) -> Self::Output { + (*self)(arg) + } +} + +macro_rules! product { + ($H:expr) => { Product($H, ()) }; + ($H:expr, $($T:expr),*) => { Product($H, product!($($T),*)) }; +} + +macro_rules! Product { + ($H:ty) => { Product<$H, ()> }; + ($H:ty, $($T:ty),*) => { Product<$H, Product!($($T),*)> }; +} + +macro_rules! product_pat { + ($H:pat) => { Product($H, ()) }; + ($H:pat, $($T:pat),*) => { Product($H, product_pat!($($T),*)) }; +} + +macro_rules! generics { + ($type:ident) => { + impl<$type> HList for Product!($type) { + type Tuple = ($type,); + + #[inline] + fn flatten(self) -> Self::Tuple { + (self.0,) + } + } + + impl<$type> Tuple for ($type,) { + type HList = Product!($type); + #[inline] + fn hlist(self) -> Self::HList { + product!(self.0) + } + } + + impl<F, R, $type> Func<Product!($type)> for F + where + F: Fn($type) -> R, + { + type Output = R; + + #[inline] + fn call(&self, args: Product!($type)) -> Self::Output { + (*self)(args.0) + } + + } + + impl<F, R, $type> Func<($type,)> for F + where + F: Fn($type) -> R, + { + type Output = R; + + #[inline] + fn call(&self, args: ($type,)) -> Self::Output { + (*self)(args.0) + } + } + + }; + + ($type1:ident, $( $type:ident ),*) => { + generics!($( $type ),*); + + impl<$type1, $( $type ),*> HList for Product!($type1, $($type),*) { + type Tuple = ($type1, $( $type ),*); + + #[inline] + fn flatten(self) -> Self::Tuple { + #[allow(non_snake_case)] + let product_pat!($type1, $( $type ),*) = self; + ($type1, $( $type ),*) + } + } + + impl<$type1, $( $type ),*> Tuple for ($type1, $($type),*) { + type HList = Product!($type1, $( $type ),*); + + #[inline] + fn hlist(self) -> Self::HList { + #[allow(non_snake_case)] + let ($type1, $( $type ),*) = self; + product!($type1, $( $type ),*) + } + } + + impl<F, R, $type1, $( $type ),*> Func<Product!($type1, $($type),*)> for F + where + F: Fn($type1, $( $type ),*) -> R, + { + type Output = R; + + #[inline] + fn call(&self, args: Product!($type1, $($type),*)) -> Self::Output { + #[allow(non_snake_case)] + let product_pat!($type1, $( $type ),*) = args; + (*self)($type1, $( $type ),*) + } + } + + impl<F, R, $type1, $( $type ),*> Func<($type1, $($type),*)> for F + where + F: Fn($type1, $( $type ),*) -> R, + { + type Output = R; + + #[inline] + fn call(&self, args: ($type1, $($type),*)) -> Self::Output { + #[allow(non_snake_case)] + let ($type1, $( $type ),*) = args; + (*self)($type1, $( $type ),*) + } + } + }; +} + +generics! { + T1, + T2, + T3, + T4, + T5, + T6, + T7, + T8, + T9, + T10, + T11, + T12, + T13, + T14, + T15, + T16 +} diff --git a/third_party/rust/warp/src/lib.rs b/third_party/rust/warp/src/lib.rs new file mode 100644 index 0000000000..a965b1d1b8 --- /dev/null +++ b/third_party/rust/warp/src/lib.rs @@ -0,0 +1,179 @@ +#![doc(html_root_url = "https://docs.rs/warp/0.3.3")] +#![deny(missing_docs)] +#![deny(missing_debug_implementations)] +#![deny(rust_2018_idioms)] +#![cfg_attr(test, deny(warnings))] + +//! # warp +//! +//! warp is a super-easy, composable, web server framework for warp speeds. +//! +//! Thanks to its [`Filter`][Filter] system, warp provides these out of the box: +//! +//! - Path routing and parameter extraction +//! - Header requirements and extraction +//! - Query string deserialization +//! - JSON and Form bodies +//! - Multipart form data +//! - Static Files and Directories +//! - Websockets +//! - Access logging +//! - Etc +//! +//! Since it builds on top of [hyper](https://hyper.rs), you automatically get: +//! +//! - HTTP/1 +//! - HTTP/2 +//! - Asynchronous +//! - One of the fastest HTTP implementations +//! - Tested and **correct** +//! +//! ## Filters +//! +//! The main concept in warp is the [`Filter`][Filter], which allows composition +//! to describe various endpoints in your web service. Besides this powerful +//! trait, warp comes with several built in [filters](filters/index.html), which +//! can be combined for your specific needs. +//! +//! As a small example, consider an endpoint that has path and header requirements: +//! +//! ``` +//! use warp::Filter; +//! +//! let hi = warp::path("hello") +//! .and(warp::path::param()) +//! .and(warp::header("user-agent")) +//! .map(|param: String, agent: String| { +//! format!("Hello {}, whose agent is {}", param, agent) +//! }); +//! ``` +//! +//! This example composes several [`Filter`s][Filter] together using `and`: +//! +//! - A path prefix of "hello" +//! - A path parameter of a `String` +//! - The `user-agent` header parsed as a `String` +//! +//! These specific filters will [`reject`][reject] requests that don't match +//! their requirements. +//! +//! This ends up matching requests like: +//! +//! ```notrust +//! GET /hello/sean HTTP/1.1 +//! Host: hyper.rs +//! User-Agent: reqwest/v0.8.6 +//! +//! ``` +//! And it returns a response similar to this: +//! +//! ```notrust +//! HTTP/1.1 200 OK +//! Content-Length: 41 +//! Date: ... +//! +//! Hello sean, whose agent is reqwest/v0.8.6 +//! ``` +//! +//! Take a look at the full list of [`filters`](filters/index.html) to see what +//! you can build. +//! +//! ## Testing +//! +//! Testing your web services easily is extremely important, and warp provides +//! a [`test`](self::test) module to help send mocked requests through your service. +//! +//! [Filter]: trait.Filter.html +//! [reject]: reject/index.html + +#[macro_use] +mod error; +mod filter; +pub mod filters; +mod generic; +pub mod redirect; +pub mod reject; +pub mod reply; +mod route; +mod server; +mod service; +pub mod test; +#[cfg(feature = "tls")] +mod tls; +mod transport; + +pub use self::error::Error; +pub use self::filter::Filter; +// This otherwise shows a big dump of re-exports in the doc homepage, +// with zero context, so just hide it from the docs. Doc examples +// on each can show that a convenient import exists. +#[cfg(feature = "compression")] +#[doc(hidden)] +pub use self::filters::compression; +#[cfg(feature = "multipart")] +#[doc(hidden)] +pub use self::filters::multipart; +#[cfg(feature = "websocket")] +#[doc(hidden)] +pub use self::filters::ws; +#[doc(hidden)] +pub use self::filters::{ + addr, + // any() function + any::any, + body, + cookie, + // cookie() function + cookie::cookie, + cors, + // cors() function + cors::cors, + ext, + fs, + header, + // header() function + header::header, + host, + log, + // log() function + log::log, + method::{delete, get, head, method, options, patch, post, put}, + path, + // path() function and macro + path::path, + query, + // query() function + query::query, + sse, + trace, + // trace() function + trace::trace, +}; +// ws() function +pub use self::filter::wrap_fn; +#[cfg(feature = "websocket")] +#[doc(hidden)] +pub use self::filters::ws::ws; +#[doc(hidden)] +pub use self::redirect::redirect; +#[doc(hidden)] +#[allow(deprecated)] +pub use self::reject::{reject, Rejection}; +#[doc(hidden)] +pub use self::reply::{reply, Reply}; +#[cfg(feature = "tls")] +pub use self::server::TlsServer; +pub use self::server::{serve, Server}; +pub use self::service::service; +#[doc(hidden)] +pub use http; +#[doc(hidden)] +pub use hyper; + +#[doc(hidden)] +pub use bytes::Buf; +#[doc(hidden)] +pub use futures_util::{Future, Sink, Stream}; +#[doc(hidden)] + +pub(crate) type Request = http::Request<hyper::Body>; diff --git a/third_party/rust/warp/src/redirect.rs b/third_party/rust/warp/src/redirect.rs new file mode 100644 index 0000000000..ee2c3ad798 --- /dev/null +++ b/third_party/rust/warp/src/redirect.rs @@ -0,0 +1,132 @@ +//! Redirect requests to a new location. +//! +//! The types in this module are helpers that implement [`Reply`](Reply), and easy +//! to use in order to setup redirects. + +use http::{header, StatusCode}; + +pub use self::sealed::AsLocation; +use crate::reply::{self, Reply}; + +/// A simple `301` permanent redirect to a different location. +/// +/// # Example +/// +/// ``` +/// use warp::{http::Uri, Filter}; +/// +/// let route = warp::path("v1") +/// .map(|| { +/// warp::redirect(Uri::from_static("/v2")) +/// }); +/// ``` +pub fn redirect(uri: impl AsLocation) -> impl Reply { + reply::with_header( + StatusCode::MOVED_PERMANENTLY, + header::LOCATION, + uri.header_value(), + ) +} + +/// A simple `302` found redirect to a different location +/// +/// # Example +/// +/// ``` +/// use warp::{http::Uri, Filter}; +/// +/// let route = warp::path("v1") +/// .map(|| { +/// warp::redirect::found(Uri::from_static("/v2")) +/// }); +/// ``` +pub fn found(uri: impl AsLocation) -> impl Reply { + reply::with_header(StatusCode::FOUND, header::LOCATION, uri.header_value()) +} + +/// A simple `303` redirect to a different location. +/// +/// The HTTP method of the request to the new location will always be `GET`. +/// +/// # Example +/// +/// ``` +/// use warp::{http::Uri, Filter}; +/// +/// let route = warp::path("v1") +/// .map(|| { +/// warp::redirect::see_other(Uri::from_static("/v2")) +/// }); +/// ``` +pub fn see_other(uri: impl AsLocation) -> impl Reply { + reply::with_header(StatusCode::SEE_OTHER, header::LOCATION, uri.header_value()) +} + +/// A simple `307` temporary redirect to a different location. +/// +/// This is similar to [`see_other`](fn@see_other) but the HTTP method and the body of the request +/// to the new location will be the same as the method and body of the current request. +/// +/// # Example +/// +/// ``` +/// use warp::{http::Uri, Filter}; +/// +/// let route = warp::path("v1") +/// .map(|| { +/// warp::redirect::temporary(Uri::from_static("/v2")) +/// }); +/// ``` +pub fn temporary(uri: impl AsLocation) -> impl Reply { + reply::with_header( + StatusCode::TEMPORARY_REDIRECT, + header::LOCATION, + uri.header_value(), + ) +} + +/// A simple `308` permanent redirect to a different location. +/// +/// This is similar to [`redirect`](fn@redirect) but the HTTP method of the request to the new +/// location will be the same as the method of the current request. +/// +/// # Example +/// +/// ``` +/// use warp::{http::Uri, Filter}; +/// +/// let route = warp::path("v1") +/// .map(|| { +/// warp::redirect::permanent(Uri::from_static("/v2")) +/// }); +/// ``` +pub fn permanent(uri: impl AsLocation) -> impl Reply { + reply::with_header( + StatusCode::PERMANENT_REDIRECT, + header::LOCATION, + uri.header_value(), + ) +} + +mod sealed { + use bytes::Bytes; + use http::{header::HeaderValue, Uri}; + + /// Trait for redirect locations. Currently only a `Uri` can be used in + /// redirect. + /// This sealed trait exists to allow adding possibly new impls so other + /// arguments could be accepted, like maybe just `warp::redirect("/v2")`. + pub trait AsLocation: Sealed {} + pub trait Sealed { + fn header_value(self) -> HeaderValue; + } + + impl AsLocation for Uri {} + + impl Sealed for Uri { + fn header_value(self) -> HeaderValue { + let bytes = Bytes::from(self.to_string()); + HeaderValue::from_maybe_shared(bytes).expect("Uri is a valid HeaderValue") + } + } +} diff --git a/third_party/rust/warp/src/reject.rs b/third_party/rust/warp/src/reject.rs new file mode 100644 index 0000000000..fd09188412 --- /dev/null +++ b/third_party/rust/warp/src/reject.rs @@ -0,0 +1,844 @@ +//! Rejections +//! +//! Part of the power of the [`Filter`](../trait.Filter.html) system is being able to +//! reject a request from a filter chain. This allows for filters to be +//! combined with `or`, so that if one side of the chain finds that a request +//! doesn't fulfill its requirements, the other side can try to process +//! the request. +//! +//! Many of the built-in [`filters`](../filters) will automatically reject +//! the request with an appropriate rejection. However, you can also build +//! new custom [`Filter`](../trait.Filter.html)s and still want other routes to be +//! matchable in the case a predicate doesn't hold. +//! +//! As a request is processed by a Filter chain, the rejections are accumulated into +//! a list contained by the [`Rejection`](struct.Rejection.html) type. Rejections from +//! filters can be handled using [`Filter::recover`](../trait.Filter.html#method.recover). +//! This is a convenient way to map rejections into a [`Reply`](../reply/trait.Reply.html). +//! +//! For a more complete example see the +//! [Rejection Example](https://github.com/seanmonstar/warp/blob/master/examples/rejections.rs) +//! from the repository. +//! +//! # Example +//! +//! ``` +//! use warp::{reply, Reply, Filter, reject, Rejection, http::StatusCode}; +//! +//! #[derive(Debug)] +//! struct InvalidParameter; +//! +//! impl reject::Reject for InvalidParameter {} +//! +//! // Custom rejection handler that maps rejections into responses. +//! async fn handle_rejection(err: Rejection) -> Result<impl Reply, std::convert::Infallible> { +//! if err.is_not_found() { +//! Ok(reply::with_status("NOT_FOUND", StatusCode::NOT_FOUND)) +//! } else if let Some(e) = err.find::<InvalidParameter>() { +//! Ok(reply::with_status("BAD_REQUEST", StatusCode::BAD_REQUEST)) +//! } else { +//! eprintln!("unhandled rejection: {:?}", err); +//! Ok(reply::with_status("INTERNAL_SERVER_ERROR", StatusCode::INTERNAL_SERVER_ERROR)) +//! } +//! } +//! +//! +//! // Filter on `/:id`, but reject with InvalidParameter if the `id` is `0`. +//! // Recover from this rejection using a custom rejection handler. +//! let route = warp::path::param() +//! .and_then(|id: u32| async move { +//! if id == 0 { +//! Err(warp::reject::custom(InvalidParameter)) +//! } else { +//! Ok("id is valid") +//! } +//! }) +//! .recover(handle_rejection); +//! ``` + +use std::any::Any; +use std::convert::Infallible; +use std::error::Error as StdError; +use std::fmt; + +use http::{ + self, + header::{HeaderValue, CONTENT_TYPE}, + StatusCode, +}; +use hyper::Body; + +pub(crate) use self::sealed::{CombineRejection, IsReject}; + +/// Rejects a request with `404 Not Found`. +#[inline] +pub fn reject() -> Rejection { + not_found() +} + +/// Rejects a request with `404 Not Found`. +#[inline] +pub fn not_found() -> Rejection { + Rejection { + reason: Reason::NotFound, + } +} + +// 400 Bad Request +#[inline] +pub(crate) fn invalid_query() -> Rejection { + known(InvalidQuery { _p: () }) +} + +// 400 Bad Request +#[inline] +pub(crate) fn missing_header(name: &'static str) -> Rejection { + known(MissingHeader { name }) +} + +// 400 Bad Request +#[inline] +pub(crate) fn invalid_header(name: &'static str) -> Rejection { + known(InvalidHeader { name }) +} + +// 400 Bad Request +#[inline] +pub(crate) fn missing_cookie(name: &'static str) -> Rejection { + known(MissingCookie { name }) +} + +// 405 Method Not Allowed +#[inline] +pub(crate) fn method_not_allowed() -> Rejection { + known(MethodNotAllowed { _p: () }) +} + +// 411 Length Required +#[inline] +pub(crate) fn length_required() -> Rejection { + known(LengthRequired { _p: () }) +} + +// 413 Payload Too Large +#[inline] +pub(crate) fn payload_too_large() -> Rejection { + known(PayloadTooLarge { _p: () }) +} + +// 415 Unsupported Media Type +// +// Used by the body filters if the request payload content-type doesn't match +// what can be deserialized. +#[inline] +pub(crate) fn unsupported_media_type() -> Rejection { + known(UnsupportedMediaType { _p: () }) +} + +/// Rejects a request with a custom cause. +/// +/// A [`recover`][] filter should convert this `Rejection` into a `Reply`, +/// or else this will be returned as a `500 Internal Server Error`. +/// +/// [`recover`]: ../trait.Filter.html#method.recover +pub fn custom<T: Reject>(err: T) -> Rejection { + Rejection::custom(Box::new(err)) +} + +/// Protect against re-rejecting a rejection. +/// +/// ```compile_fail +/// fn with(r: warp::Rejection) { +/// let _wat = warp::reject::custom(r); +/// } +/// ``` +fn __reject_custom_compilefail() {} + +/// A marker trait to ensure proper types are used for custom rejections. +/// +/// Can be converted into Rejection. +/// +/// # Example +/// +/// ``` +/// use warp::{Filter, reject::Reject}; +/// +/// #[derive(Debug)] +/// struct RateLimited; +/// +/// impl Reject for RateLimited {} +/// +/// let route = warp::any().and_then(|| async { +/// Err::<(), _>(warp::reject::custom(RateLimited)) +/// }); +/// ``` +// Require `Sized` for now to prevent passing a `Box<dyn Reject>`, since we +// would be double-boxing it, and the downcasting wouldn't work as expected. +pub trait Reject: fmt::Debug + Sized + Send + Sync + 'static {} + +trait Cause: fmt::Debug + Send + Sync + 'static { + fn as_any(&self) -> &dyn Any; +} + +impl<T> Cause for T +where + T: fmt::Debug + Send + Sync + 'static, +{ + fn as_any(&self) -> &dyn Any { + self + } +} + +impl dyn Cause { + fn downcast_ref<T: Any>(&self) -> Option<&T> { + self.as_any().downcast_ref::<T>() + } +} + +pub(crate) fn known<T: Into<Known>>(err: T) -> Rejection { + Rejection::known(err.into()) +} + +/// Rejection of a request by a [`Filter`](crate::Filter). +/// +/// See the [`reject`](module@crate::reject) documentation for more. +pub struct Rejection { + reason: Reason, +} + +enum Reason { + NotFound, + Other(Box<Rejections>), +} + +enum Rejections { + Known(Known), + Custom(Box<dyn Cause>), + Combined(Box<Rejections>, Box<Rejections>), +} + +macro_rules! enum_known { + ($($(#[$attr:meta])* $var:ident($ty:path),)+) => ( + pub(crate) enum Known { + $( + $(#[$attr])* + $var($ty), + )+ + } + + impl Known { + fn inner_as_any(&self) -> &dyn Any { + match *self { + $( + $(#[$attr])* + Known::$var(ref t) => t, + )+ + } + } + } + + impl fmt::Debug for Known { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match *self { + $( + $(#[$attr])* + Known::$var(ref t) => t.fmt(f), + )+ + } + } + } + + impl fmt::Display for Known { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match *self { + $( + $(#[$attr])* + Known::$var(ref t) => t.fmt(f), + )+ + } + } + } + + $( + #[doc(hidden)] + $(#[$attr])* + impl From<$ty> for Known { + fn from(ty: $ty) -> Known { + Known::$var(ty) + } + } + )+ + ); +} + +enum_known! { + MethodNotAllowed(MethodNotAllowed), + InvalidHeader(InvalidHeader), + MissingHeader(MissingHeader), + MissingCookie(MissingCookie), + InvalidQuery(InvalidQuery), + LengthRequired(LengthRequired), + PayloadTooLarge(PayloadTooLarge), + UnsupportedMediaType(UnsupportedMediaType), + FileOpenError(crate::fs::FileOpenError), + FilePermissionError(crate::fs::FilePermissionError), + BodyReadError(crate::body::BodyReadError), + BodyDeserializeError(crate::body::BodyDeserializeError), + CorsForbidden(crate::cors::CorsForbidden), + #[cfg(feature = "websocket")] + MissingConnectionUpgrade(crate::ws::MissingConnectionUpgrade), + MissingExtension(crate::ext::MissingExtension), + BodyConsumedMultipleTimes(crate::body::BodyConsumedMultipleTimes), +} + +impl Rejection { + fn known(known: Known) -> Self { + Rejection { + reason: Reason::Other(Box::new(Rejections::Known(known))), + } + } + + fn custom(other: Box<dyn Cause>) -> Self { + Rejection { + reason: Reason::Other(Box::new(Rejections::Custom(other))), + } + } + + /// Searches this `Rejection` for a specific cause. + /// + /// A `Rejection` will accumulate causes over a `Filter` chain. This method + /// can search through them and return the first cause of this type. + /// + /// # Example + /// + /// ``` + /// #[derive(Debug)] + /// struct Nope; + /// + /// impl warp::reject::Reject for Nope {} + /// + /// let reject = warp::reject::custom(Nope); + /// + /// if let Some(nope) = reject.find::<Nope>() { + /// println!("found it: {:?}", nope); + /// } + /// ``` + pub fn find<T: 'static>(&self) -> Option<&T> { + if let Reason::Other(ref rejections) = self.reason { + return rejections.find(); + } + None + } + + /// Returns true if this Rejection was made via `warp::reject::not_found`. + /// + /// # Example + /// + /// ``` + /// let rejection = warp::reject(); + /// + /// assert!(rejection.is_not_found()); + /// ``` + pub fn is_not_found(&self) -> bool { + matches!(self.reason, Reason::NotFound) + } +} + +impl<T: Reject> From<T> for Rejection { + #[inline] + fn from(err: T) -> Rejection { + custom(err) + } +} + +impl From<Infallible> for Rejection { + #[inline] + fn from(infallible: Infallible) -> Rejection { + match infallible {} + } +} + +impl IsReject for Infallible { + fn status(&self) -> StatusCode { + match *self {} + } + + fn into_response(&self) -> crate::reply::Response { + match *self {} + } +} + +impl IsReject for Rejection { + fn status(&self) -> StatusCode { + match self.reason { + Reason::NotFound => StatusCode::NOT_FOUND, + Reason::Other(ref other) => other.status(), + } + } + + fn into_response(&self) -> crate::reply::Response { + match self.reason { + Reason::NotFound => { + let mut res = http::Response::default(); + *res.status_mut() = StatusCode::NOT_FOUND; + res + } + Reason::Other(ref other) => other.into_response(), + } + } +} + +impl fmt::Debug for Rejection { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_tuple("Rejection").field(&self.reason).finish() + } +} + +impl fmt::Debug for Reason { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match *self { + Reason::NotFound => f.write_str("NotFound"), + Reason::Other(ref other) => match **other { + Rejections::Known(ref e) => fmt::Debug::fmt(e, f), + Rejections::Custom(ref e) => fmt::Debug::fmt(e, f), + Rejections::Combined(ref a, ref b) => { + let mut list = f.debug_list(); + a.debug_list(&mut list); + b.debug_list(&mut list); + list.finish() + } + }, + } + } +} + +// ===== Rejections ===== + +impl Rejections { + fn status(&self) -> StatusCode { + match *self { + Rejections::Known(ref k) => match *k { + Known::MethodNotAllowed(_) => StatusCode::METHOD_NOT_ALLOWED, + Known::InvalidHeader(_) + | Known::MissingHeader(_) + | Known::MissingCookie(_) + | Known::InvalidQuery(_) + | Known::BodyReadError(_) + | Known::BodyDeserializeError(_) => StatusCode::BAD_REQUEST, + #[cfg(feature = "websocket")] + Known::MissingConnectionUpgrade(_) => StatusCode::BAD_REQUEST, + Known::LengthRequired(_) => StatusCode::LENGTH_REQUIRED, + Known::PayloadTooLarge(_) => StatusCode::PAYLOAD_TOO_LARGE, + Known::UnsupportedMediaType(_) => StatusCode::UNSUPPORTED_MEDIA_TYPE, + Known::FilePermissionError(_) | Known::CorsForbidden(_) => StatusCode::FORBIDDEN, + Known::FileOpenError(_) + | Known::MissingExtension(_) + | Known::BodyConsumedMultipleTimes(_) => StatusCode::INTERNAL_SERVER_ERROR, + }, + Rejections::Custom(..) => StatusCode::INTERNAL_SERVER_ERROR, + Rejections::Combined(ref a, ref b) => preferred(a, b).status(), + } + } + + fn into_response(&self) -> crate::reply::Response { + match *self { + Rejections::Known(ref e) => { + let mut res = http::Response::new(Body::from(e.to_string())); + *res.status_mut() = self.status(); + res.headers_mut().insert( + CONTENT_TYPE, + HeaderValue::from_static("text/plain; charset=utf-8"), + ); + res + } + Rejections::Custom(ref e) => { + tracing::error!( + "unhandled custom rejection, returning 500 response: {:?}", + e + ); + let body = format!("Unhandled rejection: {:?}", e); + let mut res = http::Response::new(Body::from(body)); + *res.status_mut() = self.status(); + res.headers_mut().insert( + CONTENT_TYPE, + HeaderValue::from_static("text/plain; charset=utf-8"), + ); + res + } + Rejections::Combined(ref a, ref b) => preferred(a, b).into_response(), + } + } + + fn find<T: 'static>(&self) -> Option<&T> { + match *self { + Rejections::Known(ref e) => e.inner_as_any().downcast_ref(), + Rejections::Custom(ref e) => e.downcast_ref(), + Rejections::Combined(ref a, ref b) => a.find().or_else(|| b.find()), + } + } + + fn debug_list(&self, f: &mut fmt::DebugList<'_, '_>) { + match *self { + Rejections::Known(ref e) => { + f.entry(e); + } + Rejections::Custom(ref e) => { + f.entry(e); + } + Rejections::Combined(ref a, ref b) => { + a.debug_list(f); + b.debug_list(f); + } + } + } +} + +fn preferred<'a>(a: &'a Rejections, b: &'a Rejections) -> &'a Rejections { + // Compare status codes, with this priority: + // - NOT_FOUND is lowest + // - METHOD_NOT_ALLOWED is second + // - if one status code is greater than the other + // - otherwise, prefer A... + match (a.status(), b.status()) { + (_, StatusCode::NOT_FOUND) => a, + (StatusCode::NOT_FOUND, _) => b, + (_, StatusCode::METHOD_NOT_ALLOWED) => a, + (StatusCode::METHOD_NOT_ALLOWED, _) => b, + (sa, sb) if sa < sb => b, + _ => a, + } +} + +unit_error! { + /// Invalid query + pub InvalidQuery: "Invalid query string" +} + +unit_error! { + /// HTTP method not allowed + pub MethodNotAllowed: "HTTP method not allowed" +} + +unit_error! { + /// A content-length header is required + pub LengthRequired: "A content-length header is required" +} + +unit_error! { + /// The request payload is too large + pub PayloadTooLarge: "The request payload is too large" +} + +unit_error! { + /// The request's content-type is not supported + pub UnsupportedMediaType: "The request's content-type is not supported" +} + +/// Missing request header +#[derive(Debug)] +pub struct MissingHeader { + name: &'static str, +} + +impl MissingHeader { + /// Retrieve the name of the header that was missing + pub fn name(&self) -> &str { + self.name + } +} + +impl fmt::Display for MissingHeader { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "Missing request header {:?}", self.name) + } +} + +impl StdError for MissingHeader {} + +/// Invalid request header +#[derive(Debug)] +pub struct InvalidHeader { + name: &'static str, +} + +impl InvalidHeader { + /// Retrieve the name of the header that was invalid + pub fn name(&self) -> &str { + self.name + } +} + +impl fmt::Display for InvalidHeader { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "Invalid request header {:?}", self.name) + } +} + +impl StdError for InvalidHeader {} + +/// Missing cookie +#[derive(Debug)] +pub struct MissingCookie { + name: &'static str, +} + +impl MissingCookie { + /// Retrieve the name of the cookie that was missing + pub fn name(&self) -> &str { + self.name + } +} + +impl fmt::Display for MissingCookie { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "Missing request cookie {:?}", self.name) + } +} + +impl StdError for MissingCookie {} + +mod sealed { + use super::{Reason, Rejection, Rejections}; + use http::StatusCode; + use std::convert::Infallible; + use std::fmt; + + // This sealed trait exists to allow Filters to return either `Rejection` + // or `!`. There are no other types that make sense, and so it is sealed. + pub trait IsReject: fmt::Debug + Send + Sync { + fn status(&self) -> StatusCode; + fn into_response(&self) -> crate::reply::Response; + } + + fn _assert_object_safe() { + fn _assert(_: &dyn IsReject) {} + } + + // This weird trait is to allow optimizations of propagating when a + // rejection can *never* happen (currently with the `Never` type, + // eventually to be replaced with `!`). + // + // Using this trait means the `Never` gets propagated to chained filters, + // allowing LLVM to eliminate more code paths. Without it, such as just + // requiring that `Rejection::from(Never)` were used in those filters, + // would mean that links later in the chain may assume a rejection *could* + // happen, and no longer eliminate those branches. + pub trait CombineRejection<E>: Send + Sized { + /// The type that should be returned when only 1 of the two + /// "rejections" occurs. + /// + /// # For example: + /// + /// `warp::any().and(warp::path("foo"))` has the following steps: + /// + /// 1. Since this is `and`, only **one** of the rejections will occur, + /// and as soon as it does, it will be returned. + /// 2. `warp::any()` rejects with `Never`. So, it will never return `Never`. + /// 3. `warp::path()` rejects with `Rejection`. It may return `Rejection`. + /// + /// Thus, if the above filter rejects, it will definitely be `Rejection`. + type One: IsReject + From<Self> + From<E> + Into<Rejection>; + + /// The type that should be returned when both rejections occur, + /// and need to be combined. + type Combined: IsReject; + + fn combine(self, other: E) -> Self::Combined; + } + + impl CombineRejection<Rejection> for Rejection { + type One = Rejection; + type Combined = Rejection; + + fn combine(self, other: Rejection) -> Self::Combined { + let reason = match (self.reason, other.reason) { + (Reason::Other(left), Reason::Other(right)) => { + Reason::Other(Box::new(Rejections::Combined(left, right))) + } + (Reason::Other(other), Reason::NotFound) + | (Reason::NotFound, Reason::Other(other)) => { + // ignore the NotFound + Reason::Other(other) + } + (Reason::NotFound, Reason::NotFound) => Reason::NotFound, + }; + + Rejection { reason } + } + } + + impl CombineRejection<Infallible> for Rejection { + type One = Rejection; + type Combined = Infallible; + + fn combine(self, other: Infallible) -> Self::Combined { + match other {} + } + } + + impl CombineRejection<Rejection> for Infallible { + type One = Rejection; + type Combined = Infallible; + + fn combine(self, _: Rejection) -> Self::Combined { + match self {} + } + } + + impl CombineRejection<Infallible> for Infallible { + type One = Infallible; + type Combined = Infallible; + + fn combine(self, _: Infallible) -> Self::Combined { + match self {} + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use http::StatusCode; + + #[derive(Debug, PartialEq)] + struct Left; + + #[derive(Debug, PartialEq)] + struct Right; + + impl Reject for Left {} + impl Reject for Right {} + + #[test] + fn rejection_status() { + assert_eq!(not_found().status(), StatusCode::NOT_FOUND); + assert_eq!( + method_not_allowed().status(), + StatusCode::METHOD_NOT_ALLOWED + ); + assert_eq!(length_required().status(), StatusCode::LENGTH_REQUIRED); + assert_eq!(payload_too_large().status(), StatusCode::PAYLOAD_TOO_LARGE); + assert_eq!( + unsupported_media_type().status(), + StatusCode::UNSUPPORTED_MEDIA_TYPE + ); + assert_eq!(custom(Left).status(), StatusCode::INTERNAL_SERVER_ERROR); + } + + #[tokio::test] + async fn combine_rejection_causes_with_some_left_and_none_right() { + let left = custom(Left); + let right = not_found(); + let reject = left.combine(right); + let resp = reject.into_response(); + + assert_eq!(resp.status(), StatusCode::INTERNAL_SERVER_ERROR); + assert_eq!( + response_body_string(resp).await, + "Unhandled rejection: Left" + ) + } + + #[tokio::test] + async fn combine_rejection_causes_with_none_left_and_some_right() { + let left = not_found(); + let right = custom(Right); + let reject = left.combine(right); + let resp = reject.into_response(); + + assert_eq!(resp.status(), StatusCode::INTERNAL_SERVER_ERROR); + assert_eq!( + response_body_string(resp).await, + "Unhandled rejection: Right" + ) + } + + #[tokio::test] + async fn unhandled_customs() { + let reject = not_found().combine(custom(Right)); + + let resp = reject.into_response(); + assert_eq!(resp.status(), StatusCode::INTERNAL_SERVER_ERROR); + assert_eq!( + response_body_string(resp).await, + "Unhandled rejection: Right" + ); + + // There's no real way to determine which is worse, since both are a 500, + // so pick the first one. + let reject = custom(Left).combine(custom(Right)); + + let resp = reject.into_response(); + assert_eq!(resp.status(), StatusCode::INTERNAL_SERVER_ERROR); + assert_eq!( + response_body_string(resp).await, + "Unhandled rejection: Left" + ); + + // With many rejections, custom still is top priority. + let reject = not_found() + .combine(not_found()) + .combine(not_found()) + .combine(custom(Right)) + .combine(not_found()); + + let resp = reject.into_response(); + assert_eq!(resp.status(), StatusCode::INTERNAL_SERVER_ERROR); + assert_eq!( + response_body_string(resp).await, + "Unhandled rejection: Right" + ); + } + + async fn response_body_string(resp: crate::reply::Response) -> String { + let (_, body) = resp.into_parts(); + let body_bytes = hyper::body::to_bytes(body).await.expect("failed concat"); + String::from_utf8_lossy(&body_bytes).to_string() + } + + #[test] + fn find_cause() { + let rej = custom(Left); + + assert_eq!(rej.find::<Left>(), Some(&Left)); + + let rej = rej.combine(method_not_allowed()); + + assert_eq!(rej.find::<Left>(), Some(&Left)); + assert!(rej.find::<MethodNotAllowed>().is_some(), "MethodNotAllowed"); + } + + #[test] + fn size_of_rejection() { + assert_eq!( + ::std::mem::size_of::<Rejection>(), + ::std::mem::size_of::<usize>(), + ); + } + + #[derive(Debug)] + struct X(u32); + impl Reject for X {} + + fn combine_n<F, R>(n: u32, new_reject: F) -> Rejection + where + F: Fn(u32) -> R, + R: Reject, + { + let mut rej = not_found(); + + for i in 0..n { + rej = rej.combine(custom(new_reject(i))); + } + + rej + } + + #[test] + fn test_debug() { + let rej = combine_n(3, X); + + let s = format!("{:?}", rej); + assert_eq!(s, "Rejection([X(0), X(1), X(2)])"); + } +} diff --git a/third_party/rust/warp/src/reply.rs b/third_party/rust/warp/src/reply.rs new file mode 100644 index 0000000000..74dee278de --- /dev/null +++ b/third_party/rust/warp/src/reply.rs @@ -0,0 +1,584 @@ +//! Reply to requests. +//! +//! A [`Reply`](./trait.Reply.html) is a type that can be converted into an HTTP +//! response to be sent to the client. These are typically the successful +//! counterpart to a [rejection](../reject). +//! +//! The functions in this module are helpers for quickly creating a reply. +//! Besides them, you can return a type that implements [`Reply`](./trait.Reply.html). This +//! could be any of the following: +//! +//! - [`http::Response<impl Into<hyper::Body>>`](https://docs.rs/http) +//! - `String` +//! - `&'static str` +//! - `http::StatusCode` +//! +//! # Example +//! +//! ``` +//! use warp::{Filter, http::Response}; +//! +//! // Returns an empty `200 OK` response. +//! let empty_200 = warp::any().map(warp::reply); +//! +//! // Returns a `200 OK` response with custom header and body. +//! let custom = warp::any().map(|| { +//! Response::builder() +//! .header("my-custom-header", "some-value") +//! .body("and a custom body") +//! }); +//! +//! // GET requests return the empty 200, POST return the custom. +//! let routes = warp::get().and(empty_200) +//! .or(warp::post().and(custom)); +//! ``` + +use std::borrow::Cow; +use std::convert::TryFrom; +use std::error::Error as StdError; +use std::fmt; + +use crate::generic::{Either, One}; +use http::header::{HeaderName, HeaderValue, CONTENT_TYPE}; +use http::StatusCode; +use hyper::Body; +use serde::Serialize; +use serde_json; + +// This re-export just looks weird in docs... +pub(crate) use self::sealed::Reply_; +use self::sealed::{BoxedReply, Internal}; +#[doc(hidden)] +pub use crate::filters::reply as with; + +/// Response type into which types implementing the `Reply` trait are convertable. +pub type Response = ::http::Response<Body>; + +/// Returns an empty `Reply` with status code `200 OK`. +/// +/// # Example +/// +/// ``` +/// use warp::Filter; +/// +/// // GET /just-ok returns an empty `200 OK`. +/// let route = warp::path("just-ok") +/// .map(|| { +/// println!("got a /just-ok request!"); +/// warp::reply() +/// }); +/// ``` +#[inline] +pub fn reply() -> impl Reply { + StatusCode::OK +} + +/// Convert the value into a `Reply` with the value encoded as JSON. +/// +/// The passed value must implement [`Serialize`][ser]. Many +/// collections do, and custom domain types can have `Serialize` derived. +/// +/// [ser]: https://serde.rs +/// +/// # Example +/// +/// ``` +/// use warp::Filter; +/// +/// // GET /ids returns a `200 OK` with a JSON array of ids: +/// // `[1, 3, 7, 13]` +/// let route = warp::path("ids") +/// .map(|| { +/// let our_ids = vec![1, 3, 7, 13]; +/// warp::reply::json(&our_ids) +/// }); +/// ``` +/// +/// # Note +/// +/// If a type fails to be serialized into JSON, the error is logged at the +/// `error` level, and the returned `impl Reply` will be an empty +/// `500 Internal Server Error` response. +pub fn json<T>(val: &T) -> Json +where + T: Serialize, +{ + Json { + inner: serde_json::to_vec(val).map_err(|err| { + tracing::error!("reply::json error: {}", err); + }), + } +} + +/// A JSON formatted reply. +#[allow(missing_debug_implementations)] +pub struct Json { + inner: Result<Vec<u8>, ()>, +} + +impl Reply for Json { + #[inline] + fn into_response(self) -> Response { + match self.inner { + Ok(body) => { + let mut res = Response::new(body.into()); + res.headers_mut() + .insert(CONTENT_TYPE, HeaderValue::from_static("application/json")); + res + } + Err(()) => StatusCode::INTERNAL_SERVER_ERROR.into_response(), + } + } +} + +#[derive(Debug)] +pub(crate) struct ReplyJsonError; + +impl fmt::Display for ReplyJsonError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str("warp::reply::json() failed") + } +} + +impl StdError for ReplyJsonError {} + +/// Reply with a body and `content-type` set to `text/html; charset=utf-8`. +/// +/// # Example +/// +/// ``` +/// use warp::Filter; +/// +/// let body = r#" +/// <html> +/// <head> +/// <title>HTML with warp!</title> +/// </head> +/// <body> +/// <h1>warp + HTML = ♥</h1> +/// </body> +/// </html> +/// "#; +/// +/// let route = warp::any() +/// .map(move || { +/// warp::reply::html(body) +/// }); +/// ``` +pub fn html<T>(body: T) -> Html<T> +where + Body: From<T>, + T: Send, +{ + Html { body } +} + +/// An HTML reply. +#[allow(missing_debug_implementations)] +pub struct Html<T> { + body: T, +} + +impl<T> Reply for Html<T> +where + Body: From<T>, + T: Send, +{ + #[inline] + fn into_response(self) -> Response { + let mut res = Response::new(Body::from(self.body)); + res.headers_mut().insert( + CONTENT_TYPE, + HeaderValue::from_static("text/html; charset=utf-8"), + ); + res + } +} + +/// Types that can be converted into a `Response`. +/// +/// This trait is implemented for the following: +/// +/// - `http::StatusCode` +/// - `http::Response<impl Into<hyper::Body>>` +/// - `String` +/// - `&'static str` +/// +/// # Example +/// +/// ```rust +/// use warp::{Filter, http::Response}; +/// +/// struct Message { +/// msg: String +/// } +/// +/// impl warp::Reply for Message { +/// fn into_response(self) -> warp::reply::Response { +/// Response::new(format!("message: {}", self.msg).into()) +/// } +/// } +/// +/// fn handler() -> Message { +/// Message { msg: "Hello".to_string() } +/// } +/// +/// let route = warp::any().map(handler); +/// ``` +pub trait Reply: BoxedReply + Send { + /// Converts the given value into a [`Response`]. + /// + /// [`Response`]: type.Response.html + fn into_response(self) -> Response; + + /* + TODO: Currently unsure about having trait methods here, as it + requires returning an exact type, which I'd rather not commit to. + Additionally, it doesn't work great with `Box<Reply>`. + + A possible alternative is to have wrappers, like + + - `WithStatus<R: Reply>(StatusCode, R)` + + + /// Change the status code of this `Reply`. + fn with_status(self, status: StatusCode) -> Reply_ + where + Self: Sized, + { + let mut res = self.into_response(); + *res.status_mut() = status; + Reply_(res) + } + + /// Add a header to this `Reply`. + /// + /// # Example + /// + /// ```rust + /// use warp::Reply; + /// + /// let reply = warp::reply() + /// .with_header("x-foo", "bar"); + /// ``` + fn with_header<K, V>(self, name: K, value: V) -> Reply_ + where + Self: Sized, + HeaderName: TryFrom<K>, + HeaderValue: TryFrom<V>, + { + match <HeaderName as TryFrom<K>>::try_from(name) { + Ok(name) => match <HeaderValue as TryFrom<V>>::try_from(value) { + Ok(value) => { + let mut res = self.into_response(); + res.headers_mut().append(name, value); + Reply_(res) + }, + Err(err) => { + tracing::error!("with_header value error: {}", err.into()); + Reply_(::reject::server_error() + .into_response()) + } + }, + Err(err) => { + tracing::error!("with_header name error: {}", err.into()); + Reply_(::reject::server_error() + .into_response()) + } + } + } + */ +} + +impl<T: Reply + ?Sized> Reply for Box<T> { + fn into_response(self) -> Response { + self.boxed_into_response(Internal) + } +} + +fn _assert_object_safe() { + fn _assert(_: &dyn Reply) {} +} + +/// Wrap an `impl Reply` to change its `StatusCode`. +/// +/// # Example +/// +/// ``` +/// use warp::Filter; +/// +/// let route = warp::any() +/// .map(warp::reply) +/// .map(|reply| { +/// warp::reply::with_status(reply, warp::http::StatusCode::CREATED) +/// }); +/// ``` +pub fn with_status<T: Reply>(reply: T, status: StatusCode) -> WithStatus<T> { + WithStatus { reply, status } +} + +/// Wrap an `impl Reply` to change its `StatusCode`. +/// +/// Returned by `warp::reply::with_status`. +#[derive(Debug)] +pub struct WithStatus<T> { + reply: T, + status: StatusCode, +} + +impl<T: Reply> Reply for WithStatus<T> { + fn into_response(self) -> Response { + let mut res = self.reply.into_response(); + *res.status_mut() = self.status; + res + } +} + +/// Wrap an `impl Reply` to add a header when rendering. +/// +/// # Example +/// +/// ``` +/// use warp::Filter; +/// +/// let route = warp::any() +/// .map(warp::reply) +/// .map(|reply| { +/// warp::reply::with_header(reply, "server", "warp") +/// }); +/// ``` +pub fn with_header<T: Reply, K, V>(reply: T, name: K, value: V) -> WithHeader<T> +where + HeaderName: TryFrom<K>, + <HeaderName as TryFrom<K>>::Error: Into<http::Error>, + HeaderValue: TryFrom<V>, + <HeaderValue as TryFrom<V>>::Error: Into<http::Error>, +{ + let header = match <HeaderName as TryFrom<K>>::try_from(name) { + Ok(name) => match <HeaderValue as TryFrom<V>>::try_from(value) { + Ok(value) => Some((name, value)), + Err(err) => { + let err = err.into(); + tracing::error!("with_header value error: {}", err); + None + } + }, + Err(err) => { + let err = err.into(); + tracing::error!("with_header name error: {}", err); + None + } + }; + + WithHeader { header, reply } +} + +/// Wraps an `impl Reply` and adds a header when rendering. +/// +/// Returned by `warp::reply::with_header`. +#[derive(Debug)] +pub struct WithHeader<T> { + header: Option<(HeaderName, HeaderValue)>, + reply: T, +} + +impl<T: Reply> Reply for WithHeader<T> { + fn into_response(self) -> Response { + let mut res = self.reply.into_response(); + if let Some((name, value)) = self.header { + res.headers_mut().insert(name, value); + } + res + } +} + +impl<T: Send> Reply for ::http::Response<T> +where + Body: From<T>, +{ + #[inline] + fn into_response(self) -> Response { + self.map(Body::from) + } +} + +impl Reply for ::http::StatusCode { + #[inline] + fn into_response(self) -> Response { + let mut res = Response::default(); + *res.status_mut() = self; + res + } +} + +impl<T> Reply for Result<T, ::http::Error> +where + T: Reply + Send, +{ + #[inline] + fn into_response(self) -> Response { + match self { + Ok(t) => t.into_response(), + Err(e) => { + tracing::error!("reply error: {:?}", e); + StatusCode::INTERNAL_SERVER_ERROR.into_response() + } + } + } +} + +fn text_plain<T: Into<Body>>(body: T) -> Response { + let mut response = ::http::Response::new(body.into()); + response.headers_mut().insert( + CONTENT_TYPE, + HeaderValue::from_static("text/plain; charset=utf-8"), + ); + response +} + +impl Reply for String { + #[inline] + fn into_response(self) -> Response { + text_plain(self) + } +} + +impl Reply for Vec<u8> { + #[inline] + fn into_response(self) -> Response { + ::http::Response::builder() + .header( + CONTENT_TYPE, + HeaderValue::from_static("application/octet-stream"), + ) + .body(Body::from(self)) + .unwrap() + } +} + +impl Reply for &'static str { + #[inline] + fn into_response(self) -> Response { + text_plain(self) + } +} + +impl Reply for Cow<'static, str> { + #[inline] + fn into_response(self) -> Response { + match self { + Cow::Borrowed(s) => s.into_response(), + Cow::Owned(s) => s.into_response(), + } + } +} + +impl Reply for &'static [u8] { + #[inline] + fn into_response(self) -> Response { + ::http::Response::builder() + .header( + CONTENT_TYPE, + HeaderValue::from_static("application/octet-stream"), + ) + .body(Body::from(self)) + .unwrap() + } +} + +impl<T, U> Reply for Either<T, U> +where + T: Reply, + U: Reply, +{ + #[inline] + fn into_response(self) -> Response { + match self { + Either::A(a) => a.into_response(), + Either::B(b) => b.into_response(), + } + } +} + +impl<T> Reply for One<T> +where + T: Reply, +{ + #[inline] + fn into_response(self) -> Response { + self.0.into_response() + } +} + +impl Reply for std::convert::Infallible { + #[inline(always)] + fn into_response(self) -> Response { + match self {} + } +} + +mod sealed { + use super::{Reply, Response}; + + // An opaque type to return `impl Reply` from trait methods. + #[allow(missing_debug_implementations)] + pub struct Reply_(pub(crate) Response); + + impl Reply for Reply_ { + #[inline] + fn into_response(self) -> Response { + self.0 + } + } + + #[allow(missing_debug_implementations)] + pub struct Internal; + + // Implemented for all types that implement `Reply`. + // + // A user doesn't need to worry about this, it's just trait + // hackery to get `Box<dyn Reply>` working. + pub trait BoxedReply { + fn boxed_into_response(self: Box<Self>, internal: Internal) -> Response; + } + + impl<T: Reply> BoxedReply for T { + fn boxed_into_response(self: Box<Self>, _: Internal) -> Response { + (*self).into_response() + } + } +} + +#[cfg(test)] +mod tests { + use std::collections::HashMap; + + use super::*; + + #[test] + fn json_serde_error() { + // a HashMap<Vec, _> cannot be serialized to JSON + let mut map = HashMap::new(); + map.insert(vec![1, 2], 45); + + let res = json(&map).into_response(); + assert_eq!(res.status(), 500); + } + + #[test] + fn response_builder_error() { + let res = ::http::Response::builder() + .status(1337) + .body("woops") + .into_response(); + + assert_eq!(res.status(), 500); + } + + #[test] + fn boxed_reply() { + let r: Box<dyn Reply> = Box::new(reply()); + let resp = r.into_response(); + assert_eq!(resp.status(), 200); + } +} diff --git a/third_party/rust/warp/src/route.rs b/third_party/rust/warp/src/route.rs new file mode 100644 index 0000000000..afbac4d8ba --- /dev/null +++ b/third_party/rust/warp/src/route.rs @@ -0,0 +1,140 @@ +use scoped_tls::scoped_thread_local; +use std::cell::RefCell; +use std::mem; +use std::net::SocketAddr; + +use hyper::Body; + +use crate::Request; + +scoped_thread_local!(static ROUTE: RefCell<Route>); + +pub(crate) fn set<F, U>(r: &RefCell<Route>, func: F) -> U +where + F: FnOnce() -> U, +{ + ROUTE.set(r, func) +} + +pub(crate) fn is_set() -> bool { + ROUTE.is_set() +} + +pub(crate) fn with<F, R>(func: F) -> R +where + F: FnOnce(&mut Route) -> R, +{ + ROUTE.with(move |route| func(&mut *route.borrow_mut())) +} + +#[derive(Debug)] +pub(crate) struct Route { + body: BodyState, + remote_addr: Option<SocketAddr>, + req: Request, + segments_index: usize, +} + +#[derive(Debug)] +enum BodyState { + Ready, + Taken, +} + +impl Route { + pub(crate) fn new(req: Request, remote_addr: Option<SocketAddr>) -> RefCell<Route> { + let segments_index = if req.uri().path().starts_with('/') { + // Skip the beginning slash. + 1 + } else { + 0 + }; + + RefCell::new(Route { + body: BodyState::Ready, + remote_addr, + req, + segments_index, + }) + } + + pub(crate) fn method(&self) -> &http::Method { + self.req.method() + } + + pub(crate) fn headers(&self) -> &http::HeaderMap { + self.req.headers() + } + + pub(crate) fn version(&self) -> http::Version { + self.req.version() + } + + pub(crate) fn extensions(&self) -> &http::Extensions { + self.req.extensions() + } + + #[cfg(feature = "websocket")] + pub(crate) fn extensions_mut(&mut self) -> &mut http::Extensions { + self.req.extensions_mut() + } + + pub(crate) fn uri(&self) -> &http::Uri { + self.req.uri() + } + + pub(crate) fn path(&self) -> &str { + &self.req.uri().path()[self.segments_index..] + } + + pub(crate) fn full_path(&self) -> &str { + self.req.uri().path() + } + + pub(crate) fn set_unmatched_path(&mut self, index: usize) { + let index = self.segments_index + index; + let path = self.req.uri().path(); + if path.is_empty() { + // malformed path + return; + } else if path.len() == index { + self.segments_index = index; + } else { + debug_assert_eq!(path.as_bytes()[index], b'/'); + self.segments_index = index + 1; + } + } + + pub(crate) fn query(&self) -> Option<&str> { + self.req.uri().query() + } + + pub(crate) fn matched_path_index(&self) -> usize { + self.segments_index + } + + pub(crate) fn reset_matched_path_index(&mut self, index: usize) { + debug_assert!( + index <= self.segments_index, + "reset_match_path_index should not be bigger: current={}, arg={}", + self.segments_index, + index, + ); + self.segments_index = index; + } + + pub(crate) fn remote_addr(&self) -> Option<SocketAddr> { + self.remote_addr + } + + pub(crate) fn take_body(&mut self) -> Option<Body> { + match self.body { + BodyState::Ready => { + let body = mem::replace(self.req.body_mut(), Body::empty()); + self.body = BodyState::Taken; + Some(body) + } + BodyState::Taken => None, + } + } +} diff --git a/third_party/rust/warp/src/server.rs b/third_party/rust/warp/src/server.rs new file mode 100644 index 0000000000..f1eb33b952 --- /dev/null +++ b/third_party/rust/warp/src/server.rs @@ -0,0 +1,576 @@ +#[cfg(feature = "tls")] +use crate::tls::TlsConfigBuilder; +use std::convert::Infallible; +use std::error::Error as StdError; +use std::future::Future; +use std::net::SocketAddr; +#[cfg(feature = "tls")] +use std::path::Path; + +use futures_util::{future, FutureExt, TryFuture, TryStream, TryStreamExt}; +use hyper::server::conn::AddrIncoming; +use hyper::service::{make_service_fn, service_fn}; +use hyper::Server as HyperServer; +use tokio::io::{AsyncRead, AsyncWrite}; +use tracing::Instrument; + +use crate::filter::Filter; +use crate::reject::IsReject; +use crate::reply::Reply; +use crate::transport::Transport; + +/// Create a `Server` with the provided `Filter`. +pub fn serve<F>(filter: F) -> Server<F> +where + F: Filter + Clone + Send + Sync + 'static, + F::Extract: Reply, + F::Error: IsReject, +{ + Server { + pipeline: false, + filter, + } +} + +/// A Warp Server ready to filter requests. +#[derive(Debug)] +pub struct Server<F> { + pipeline: bool, + filter: F, +} + +/// A Warp Server ready to filter requests over TLS. +/// +/// *This type requires the `"tls"` feature.* +#[cfg(feature = "tls")] +pub struct TlsServer<F> { + server: Server<F>, + tls: TlsConfigBuilder, +} + +// Getting all various generic bounds to make this a re-usable method is +// very complicated, so instead this is just a macro. +macro_rules! into_service { + ($into:expr) => {{ + let inner = crate::service($into); + make_service_fn(move |transport| { + let inner = inner.clone(); + let remote_addr = Transport::remote_addr(transport); + future::ok::<_, Infallible>(service_fn(move |req| { + inner.call_with_addr(req, remote_addr) + })) + }) + }}; +} + +macro_rules! addr_incoming { + ($addr:expr) => {{ + let mut incoming = AddrIncoming::bind($addr)?; + incoming.set_nodelay(true); + let addr = incoming.local_addr(); + (addr, incoming) + }}; +} + +macro_rules! bind_inner { + ($this:ident, $addr:expr) => {{ + let service = into_service!($this.filter); + let (addr, incoming) = addr_incoming!($addr); + let srv = HyperServer::builder(incoming) + .http1_pipeline_flush($this.pipeline) + .serve(service); + Ok::<_, hyper::Error>((addr, srv)) + }}; + + (tls: $this:ident, $addr:expr) => {{ + let service = into_service!($this.server.filter); + let (addr, incoming) = addr_incoming!($addr); + let tls = $this.tls.build()?; + let srv = HyperServer::builder(crate::tls::TlsAcceptor::new(tls, incoming)) + .http1_pipeline_flush($this.server.pipeline) + .serve(service); + Ok::<_, Box<dyn std::error::Error + Send + Sync>>((addr, srv)) + }}; +} + +macro_rules! bind { + ($this:ident, $addr:expr) => {{ + let addr = $addr.into(); + (|addr| bind_inner!($this, addr))(&addr).unwrap_or_else(|e| { + panic!("error binding to {}: {}", addr, e); + }) + }}; + + (tls: $this:ident, $addr:expr) => {{ + let addr = $addr.into(); + (|addr| bind_inner!(tls: $this, addr))(&addr).unwrap_or_else(|e| { + panic!("error binding to {}: {}", addr, e); + }) + }}; +} + +macro_rules! try_bind { + ($this:ident, $addr:expr) => {{ + (|addr| bind_inner!($this, addr))($addr) + }}; + + (tls: $this:ident, $addr:expr) => {{ + (|addr| bind_inner!(tls: $this, addr))($addr) + }}; +} + +// ===== impl Server ===== + +impl<F> Server<F> +where + F: Filter + Clone + Send + Sync + 'static, + <F::Future as TryFuture>::Ok: Reply, + <F::Future as TryFuture>::Error: IsReject, +{ + /// Run this `Server` forever on the current thread. + pub async fn run(self, addr: impl Into<SocketAddr>) { + let (addr, fut) = self.bind_ephemeral(addr); + let span = tracing::info_span!("Server::run", ?addr); + tracing::info!(parent: &span, "listening on http://{}", addr); + + fut.instrument(span).await; + } + + /// Run this `Server` forever on the current thread with a specific stream + /// of incoming connections. + /// + /// This can be used for Unix Domain Sockets, or TLS, etc. + pub async fn run_incoming<I>(self, incoming: I) + where + I: TryStream + Send, + I::Ok: AsyncRead + AsyncWrite + Send + 'static + Unpin, + I::Error: Into<Box<dyn StdError + Send + Sync>>, + { + self.run_incoming2(incoming.map_ok(crate::transport::LiftIo).into_stream()) + .instrument(tracing::info_span!("Server::run_incoming")) + .await; + } + + async fn run_incoming2<I>(self, incoming: I) + where + I: TryStream + Send, + I::Ok: Transport + Send + 'static + Unpin, + I::Error: Into<Box<dyn StdError + Send + Sync>>, + { + let fut = self.serve_incoming2(incoming); + + tracing::info!("listening with custom incoming"); + + fut.await; + } + + /// Bind to a socket address, returning a `Future` that can be + /// executed on the current runtime. + /// + /// # Panics + /// + /// Panics if we are unable to bind to the provided address. + pub fn bind(self, addr: impl Into<SocketAddr> + 'static) -> impl Future<Output = ()> + 'static { + let (_, fut) = self.bind_ephemeral(addr); + fut + } + + /// Bind to a socket address, returning a `Future` that can be + /// executed on any runtime. + /// + /// In case we are unable to bind to the specified address, resolves to an + /// error and logs the reason. + pub async fn try_bind(self, addr: impl Into<SocketAddr>) { + let addr = addr.into(); + let srv = match try_bind!(self, &addr) { + Ok((_, srv)) => srv, + Err(err) => { + tracing::error!("error binding to {}: {}", addr, err); + return; + } + }; + + srv.map(|result| { + if let Err(err) = result { + tracing::error!("server error: {}", err) + } + }) + .await; + } + + /// Bind to a possibly ephemeral socket address. + /// + /// Returns the bound address and a `Future` that can be executed on + /// the current runtime. + /// + /// # Panics + /// + /// Panics if we are unable to bind to the provided address. + pub fn bind_ephemeral( + self, + addr: impl Into<SocketAddr>, + ) -> (SocketAddr, impl Future<Output = ()> + 'static) { + let (addr, srv) = bind!(self, addr); + let srv = srv.map(|result| { + if let Err(err) = result { + tracing::error!("server error: {}", err) + } + }); + + (addr, srv) + } + + /// Tried to bind a possibly ephemeral socket address. + /// + /// Returns a `Result` which fails in case we are unable to bind with the + /// underlying error. + /// + /// Returns the bound address and a `Future` that can be executed on + /// the current runtime. + pub fn try_bind_ephemeral( + self, + addr: impl Into<SocketAddr>, + ) -> Result<(SocketAddr, impl Future<Output = ()> + 'static), crate::Error> { + let addr = addr.into(); + let (addr, srv) = try_bind!(self, &addr).map_err(crate::Error::new)?; + let srv = srv.map(|result| { + if let Err(err) = result { + tracing::error!("server error: {}", err) + } + }); + + Ok((addr, srv)) + } + + /// Create a server with graceful shutdown signal. + /// + /// When the signal completes, the server will start the graceful shutdown + /// process. + /// + /// Returns the bound address and a `Future` that can be executed on + /// the current runtime. + /// + /// # Example + /// + /// ```no_run + /// use warp::Filter; + /// use futures_util::future::TryFutureExt; + /// use tokio::sync::oneshot; + /// + /// # fn main() { + /// let routes = warp::any() + /// .map(|| "Hello, World!"); + /// + /// let (tx, rx) = oneshot::channel(); + /// + /// let (addr, server) = warp::serve(routes) + /// .bind_with_graceful_shutdown(([127, 0, 0, 1], 3030), async { + /// rx.await.ok(); + /// }); + /// + /// // Spawn the server into a runtime + /// tokio::task::spawn(server); + /// + /// // Later, start the shutdown... + /// let _ = tx.send(()); + /// # } + /// ``` + pub fn bind_with_graceful_shutdown( + self, + addr: impl Into<SocketAddr> + 'static, + signal: impl Future<Output = ()> + Send + 'static, + ) -> (SocketAddr, impl Future<Output = ()> + 'static) { + let (addr, srv) = bind!(self, addr); + let fut = srv.with_graceful_shutdown(signal).map(|result| { + if let Err(err) = result { + tracing::error!("server error: {}", err) + } + }); + (addr, fut) + } + + /// Create a server with graceful shutdown signal. + /// + /// When the signal completes, the server will start the graceful shutdown + /// process. + pub fn try_bind_with_graceful_shutdown( + self, + addr: impl Into<SocketAddr> + 'static, + signal: impl Future<Output = ()> + Send + 'static, + ) -> Result<(SocketAddr, impl Future<Output = ()> + 'static), crate::Error> { + let addr = addr.into(); + let (addr, srv) = try_bind!(self, &addr).map_err(crate::Error::new)?; + let srv = srv.with_graceful_shutdown(signal).map(|result| { + if let Err(err) = result { + tracing::error!("server error: {}", err) + } + }); + + Ok((addr, srv)) + } + + /// Setup this `Server` with a specific stream of incoming connections. + /// + /// This can be used for Unix Domain Sockets, or TLS, etc. + /// + /// Returns a `Future` that can be executed on the current runtime. + pub fn serve_incoming<I>(self, incoming: I) -> impl Future<Output = ()> + where + I: TryStream + Send, + I::Ok: AsyncRead + AsyncWrite + Send + 'static + Unpin, + I::Error: Into<Box<dyn StdError + Send + Sync>>, + { + let incoming = incoming.map_ok(crate::transport::LiftIo); + self.serve_incoming2(incoming) + .instrument(tracing::info_span!("Server::serve_incoming")) + } + + /// Setup this `Server` with a specific stream of incoming connections and a + /// signal to initiate graceful shutdown. + /// + /// This can be used for Unix Domain Sockets, or TLS, etc. + /// + /// When the signal completes, the server will start the graceful shutdown + /// process. + /// + /// Returns a `Future` that can be executed on the current runtime. + pub fn serve_incoming_with_graceful_shutdown<I>( + self, + incoming: I, + signal: impl Future<Output = ()> + Send + 'static, + ) -> impl Future<Output = ()> + where + I: TryStream + Send, + I::Ok: AsyncRead + AsyncWrite + Send + 'static + Unpin, + I::Error: Into<Box<dyn StdError + Send + Sync>>, + { + let incoming = incoming.map_ok(crate::transport::LiftIo); + let service = into_service!(self.filter); + let pipeline = self.pipeline; + + async move { + let srv = + HyperServer::builder(hyper::server::accept::from_stream(incoming.into_stream())) + .http1_pipeline_flush(pipeline) + .serve(service) + .with_graceful_shutdown(signal) + .await; + + if let Err(err) = srv { + tracing::error!("server error: {}", err); + } + } + .instrument(tracing::info_span!( + "Server::serve_incoming_with_graceful_shutdown" + )) + } + + async fn serve_incoming2<I>(self, incoming: I) + where + I: TryStream + Send, + I::Ok: Transport + Send + 'static + Unpin, + I::Error: Into<Box<dyn StdError + Send + Sync>>, + { + let service = into_service!(self.filter); + + let srv = HyperServer::builder(hyper::server::accept::from_stream(incoming.into_stream())) + .http1_pipeline_flush(self.pipeline) + .serve(service) + .await; + + if let Err(err) = srv { + tracing::error!("server error: {}", err); + } + } + + // Generally shouldn't be used, as it can slow down non-pipelined responses. + // + // It's only real use is to make silly pipeline benchmarks look better. + #[doc(hidden)] + pub fn unstable_pipeline(mut self) -> Self { + self.pipeline = true; + self + } + + /// Configure a server to use TLS. + /// + /// *This function requires the `"tls"` feature.* + #[cfg(feature = "tls")] + pub fn tls(self) -> TlsServer<F> { + TlsServer { + server: self, + tls: TlsConfigBuilder::new(), + } + } +} + +// // ===== impl TlsServer ===== + +#[cfg(feature = "tls")] +impl<F> TlsServer<F> +where + F: Filter + Clone + Send + Sync + 'static, + <F::Future as TryFuture>::Ok: Reply, + <F::Future as TryFuture>::Error: IsReject, +{ + // TLS config methods + + /// Specify the file path to read the private key. + /// + /// *This function requires the `"tls"` feature.* + pub fn key_path(self, path: impl AsRef<Path>) -> Self { + self.with_tls(|tls| tls.key_path(path)) + } + + /// Specify the file path to read the certificate. + /// + /// *This function requires the `"tls"` feature.* + pub fn cert_path(self, path: impl AsRef<Path>) -> Self { + self.with_tls(|tls| tls.cert_path(path)) + } + + /// Specify the file path to read the trust anchor for optional client authentication. + /// + /// Anonymous and authenticated clients will be accepted. If no trust anchor is provided by any + /// of the `client_auth_` methods, then client authentication is disabled by default. + /// + /// *This function requires the `"tls"` feature.* + pub fn client_auth_optional_path(self, path: impl AsRef<Path>) -> Self { + self.with_tls(|tls| tls.client_auth_optional_path(path)) + } + + /// Specify the file path to read the trust anchor for required client authentication. + /// + /// Only authenticated clients will be accepted. If no trust anchor is provided by any of the + /// `client_auth_` methods, then client authentication is disabled by default. + /// + /// *This function requires the `"tls"` feature.* + pub fn client_auth_required_path(self, path: impl AsRef<Path>) -> Self { + self.with_tls(|tls| tls.client_auth_required_path(path)) + } + + /// Specify the in-memory contents of the private key. + /// + /// *This function requires the `"tls"` feature.* + pub fn key(self, key: impl AsRef<[u8]>) -> Self { + self.with_tls(|tls| tls.key(key.as_ref())) + } + + /// Specify the in-memory contents of the certificate. + /// + /// *This function requires the `"tls"` feature.* + pub fn cert(self, cert: impl AsRef<[u8]>) -> Self { + self.with_tls(|tls| tls.cert(cert.as_ref())) + } + + /// Specify the in-memory contents of the trust anchor for optional client authentication. + /// + /// Anonymous and authenticated clients will be accepted. If no trust anchor is provided by any + /// of the `client_auth_` methods, then client authentication is disabled by default. + /// + /// *This function requires the `"tls"` feature.* + pub fn client_auth_optional(self, trust_anchor: impl AsRef<[u8]>) -> Self { + self.with_tls(|tls| tls.client_auth_optional(trust_anchor.as_ref())) + } + + /// Specify the in-memory contents of the trust anchor for required client authentication. + /// + /// Only authenticated clients will be accepted. If no trust anchor is provided by any of the + /// `client_auth_` methods, then client authentication is disabled by default. + /// + /// *This function requires the `"tls"` feature.* + pub fn client_auth_required(self, trust_anchor: impl AsRef<[u8]>) -> Self { + self.with_tls(|tls| tls.client_auth_required(trust_anchor.as_ref())) + } + + /// Specify the DER-encoded OCSP response. + /// + /// *This function requires the `"tls"` feature.* + pub fn ocsp_resp(self, resp: impl AsRef<[u8]>) -> Self { + self.with_tls(|tls| tls.ocsp_resp(resp.as_ref())) + } + + fn with_tls<Func>(self, func: Func) -> Self + where + Func: FnOnce(TlsConfigBuilder) -> TlsConfigBuilder, + { + let TlsServer { server, tls } = self; + let tls = func(tls); + TlsServer { server, tls } + } + + // Server run methods + + /// Run this `TlsServer` forever on the current thread. + /// + /// *This function requires the `"tls"` feature.* + pub async fn run(self, addr: impl Into<SocketAddr>) { + let (addr, fut) = self.bind_ephemeral(addr); + let span = tracing::info_span!("TlsServer::run", %addr); + tracing::info!(parent: &span, "listening on https://{}", addr); + + fut.instrument(span).await; + } + + /// Bind to a socket address, returning a `Future` that can be + /// executed on a runtime. + /// + /// *This function requires the `"tls"` feature.* + pub async fn bind(self, addr: impl Into<SocketAddr>) { + let (_, fut) = self.bind_ephemeral(addr); + fut.await; + } + + /// Bind to a possibly ephemeral socket address. + /// + /// Returns the bound address and a `Future` that can be executed on + /// the current runtime. + /// + /// *This function requires the `"tls"` feature.* + pub fn bind_ephemeral( + self, + addr: impl Into<SocketAddr>, + ) -> (SocketAddr, impl Future<Output = ()> + 'static) { + let (addr, srv) = bind!(tls: self, addr); + let srv = srv.map(|result| { + if let Err(err) = result { + tracing::error!("server error: {}", err) + } + }); + + (addr, srv) + } + + /// Create a server with graceful shutdown signal. + /// + /// When the signal completes, the server will start the graceful shutdown + /// process. + /// + /// *This function requires the `"tls"` feature.* + pub fn bind_with_graceful_shutdown( + self, + addr: impl Into<SocketAddr> + 'static, + signal: impl Future<Output = ()> + Send + 'static, + ) -> (SocketAddr, impl Future<Output = ()> + 'static) { + let (addr, srv) = bind!(tls: self, addr); + + let fut = srv.with_graceful_shutdown(signal).map(|result| { + if let Err(err) = result { + tracing::error!("server error: {}", err) + } + }); + (addr, fut) + } +} + +#[cfg(feature = "tls")] +impl<F> ::std::fmt::Debug for TlsServer<F> +where + F: ::std::fmt::Debug, +{ + fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result { + f.debug_struct("TlsServer") + .field("server", &self.server) + .finish() + } +} diff --git a/third_party/rust/warp/src/service.rs b/third_party/rust/warp/src/service.rs new file mode 100644 index 0000000000..4f93809c4e --- /dev/null +++ b/third_party/rust/warp/src/service.rs @@ -0,0 +1,3 @@ +//! Convert `Filter`s into `Service`s + +pub use crate::filter::service::service; diff --git a/third_party/rust/warp/src/test.rs b/third_party/rust/warp/src/test.rs new file mode 100644 index 0000000000..3c67c34d84 --- /dev/null +++ b/third_party/rust/warp/src/test.rs @@ -0,0 +1,764 @@ +//! Test utilities to test your filters. +//! +//! [`Filter`](../trait.Filter.html)s can be easily tested without starting up an HTTP +//! server, by making use of the [`RequestBuilder`](./struct.RequestBuilder.html) in this +//! module. +//! +//! # Testing Filters +//! +//! It's easy to test filters, especially if smaller filters are used to build +//! up your full set. Consider these example filters: +//! +//! ``` +//! use warp::Filter; +//! +//! fn sum() -> impl Filter<Extract = (u32,), Error = warp::Rejection> + Copy { +//! warp::path::param() +//! .and(warp::path::param()) +//! .map(|x: u32, y: u32| { +//! x + y +//! }) +//! } +//! +//! fn math() -> impl Filter<Extract = (String,), Error = warp::Rejection> + Copy { +//! warp::post() +//! .and(sum()) +//! .map(|z: u32| { +//! format!("Sum = {}", z) +//! }) +//! } +//! ``` +//! +//! We can test some requests against the `sum` filter like this: +//! +//! ``` +//! # use warp::Filter; +//! #[tokio::test] +//! async fn test_sum() { +//! # let sum = || warp::any().map(|| 3); +//! let filter = sum(); +//! +//! // Execute `sum` and get the `Extract` back. +//! let value = warp::test::request() +//! .path("/1/2") +//! .filter(&filter) +//! .await +//! .unwrap(); +//! assert_eq!(value, 3); +//! +//! // Or simply test if a request matches (doesn't reject). +//! assert!( +//! warp::test::request() +//! .path("/1/-5") +//! .matches(&filter) +//! .await +//! ); +//! } +//! ``` +//! +//! If the filter returns something that implements `Reply`, and thus can be +//! turned into a response sent back to the client, we can test what exact +//! response is returned. The `math` filter uses the `sum` filter, but returns +//! a `String` that can be turned into a response. +//! +//! ``` +//! # use warp::Filter; +//! #[test] +//! fn test_math() { +//! # let math = || warp::any().map(warp::reply); +//! let filter = math(); +//! +//! let res = warp::test::request() +//! .path("/1/2") +//! .reply(&filter); +//! assert_eq!(res.status(), 405, "GET is not allowed"); +//! +//! let res = warp::test::request() +//! .method("POST") +//! .path("/1/2") +//! .reply(&filter); +//! assert_eq!(res.status(), 200); +//! assert_eq!(res.body(), "Sum is 3"); +//! } +//! ``` +use std::convert::TryFrom; +use std::error::Error as StdError; +use std::fmt; +use std::future::Future; +use std::net::SocketAddr; +#[cfg(feature = "websocket")] +use std::pin::Pin; +#[cfg(feature = "websocket")] +use std::task::Context; +#[cfg(feature = "websocket")] +use std::task::{self, Poll}; + +use bytes::Bytes; +#[cfg(feature = "websocket")] +use futures_channel::mpsc; +#[cfg(feature = "websocket")] +use futures_util::StreamExt; +use futures_util::{future, FutureExt, TryFutureExt}; +use http::{ + header::{HeaderName, HeaderValue}, + Response, +}; +use serde::Serialize; +use serde_json; +#[cfg(feature = "websocket")] +use tokio::sync::oneshot; + +use crate::filter::Filter; +#[cfg(feature = "websocket")] +use crate::filters::ws::Message; +use crate::reject::IsReject; +use crate::reply::Reply; +use crate::route::{self, Route}; +use crate::Request; +#[cfg(feature = "websocket")] +use crate::{Sink, Stream}; + +use self::inner::OneOrTuple; + +/// Starts a new test `RequestBuilder`. +pub fn request() -> RequestBuilder { + RequestBuilder { + remote_addr: None, + req: Request::default(), + } +} + +/// Starts a new test `WsBuilder`. +#[cfg(feature = "websocket")] +pub fn ws() -> WsBuilder { + WsBuilder { req: request() } +} + +/// A request builder for testing filters. +/// +/// See [module documentation](crate::test) for an overview. +#[must_use = "RequestBuilder does nothing on its own"] +#[derive(Debug)] +pub struct RequestBuilder { + remote_addr: Option<SocketAddr>, + req: Request, +} + +/// A Websocket builder for testing filters. +/// +/// See [module documentation](crate::test) for an overview. +#[cfg(feature = "websocket")] +#[must_use = "WsBuilder does nothing on its own"] +#[derive(Debug)] +pub struct WsBuilder { + req: RequestBuilder, +} + +/// A test client for Websocket filters. +#[cfg(feature = "websocket")] +pub struct WsClient { + tx: mpsc::UnboundedSender<crate::ws::Message>, + rx: mpsc::UnboundedReceiver<Result<crate::ws::Message, crate::error::Error>>, +} + +/// An error from Websocket filter tests. +#[derive(Debug)] +pub struct WsError { + cause: Box<dyn StdError + Send + Sync>, +} + +impl RequestBuilder { + /// Sets the method of this builder. + /// + /// The default if not set is `GET`. + /// + /// # Example + /// + /// ``` + /// let req = warp::test::request() + /// .method("POST"); + /// ``` + /// + /// # Panic + /// + /// This panics if the passed string is not able to be parsed as a valid + /// `Method`. + pub fn method(mut self, method: &str) -> Self { + *self.req.method_mut() = method.parse().expect("valid method"); + self + } + + /// Sets the request path of this builder. + /// + /// The default is not set is `/`. + /// + /// # Example + /// + /// ``` + /// let req = warp::test::request() + /// .path("/todos/33"); + /// ``` + /// + /// # Panic + /// + /// This panics if the passed string is not able to be parsed as a valid + /// `Uri`. + pub fn path(mut self, p: &str) -> Self { + let uri = p.parse().expect("test request path invalid"); + *self.req.uri_mut() = uri; + self + } + + /// Set a header for this request. + /// + /// # Example + /// + /// ``` + /// let req = warp::test::request() + /// .header("accept", "application/json"); + /// ``` + /// + /// # Panic + /// + /// This panics if the passed strings are not able to be parsed as a valid + /// `HeaderName` and `HeaderValue`. + pub fn header<K, V>(mut self, key: K, value: V) -> Self + where + HeaderName: TryFrom<K>, + HeaderValue: TryFrom<V>, + { + let name: HeaderName = TryFrom::try_from(key) + .map_err(|_| ()) + .expect("invalid header name"); + let value = TryFrom::try_from(value) + .map_err(|_| ()) + .expect("invalid header value"); + self.req.headers_mut().insert(name, value); + self + } + + /// Set the remote address of this request + /// + /// Default is no remote address. + /// + /// # Example + /// ``` + /// use std::net::{IpAddr, Ipv4Addr, SocketAddr}; + /// + /// let req = warp::test::request() + /// .remote_addr(SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 8080)); + /// ``` + pub fn remote_addr(mut self, addr: SocketAddr) -> Self { + self.remote_addr = Some(addr); + self + } + + /// Add a type to the request's `http::Extensions`. + pub fn extension<T>(mut self, ext: T) -> Self + where + T: Send + Sync + 'static, + { + self.req.extensions_mut().insert(ext); + self + } + + /// Set the bytes of this request body. + /// + /// Default is an empty body. + /// + /// # Example + /// + /// ``` + /// let req = warp::test::request() + /// .body("foo=bar&baz=quux"); + /// ``` + pub fn body(mut self, body: impl AsRef<[u8]>) -> Self { + let body = body.as_ref().to_vec(); + let len = body.len(); + *self.req.body_mut() = body.into(); + self.header("content-length", len.to_string()) + } + + /// Set the bytes of this request body by serializing a value into JSON. + /// + /// # Example + /// + /// ``` + /// let req = warp::test::request() + /// .json(&true); + /// ``` + pub fn json(mut self, val: &impl Serialize) -> Self { + let vec = serde_json::to_vec(val).expect("json() must serialize to JSON"); + let len = vec.len(); + *self.req.body_mut() = vec.into(); + self.header("content-length", len.to_string()) + .header("content-type", "application/json") + } + + /// Tries to apply the `Filter` on this request. + /// + /// # Example + /// + /// ```no_run + /// async { + /// let param = warp::path::param::<u32>(); + /// + /// let ex = warp::test::request() + /// .path("/41") + /// .filter(¶m) + /// .await + /// .unwrap(); + /// + /// assert_eq!(ex, 41); + /// + /// assert!( + /// warp::test::request() + /// .path("/foo") + /// .filter(¶m) + /// .await + /// .is_err() + /// ); + ///}; + /// ``` + pub async fn filter<F>(self, f: &F) -> Result<<F::Extract as OneOrTuple>::Output, F::Error> + where + F: Filter, + F::Future: Send + 'static, + F::Extract: OneOrTuple + Send + 'static, + F::Error: Send + 'static, + { + self.apply_filter(f).await.map(|ex| ex.one_or_tuple()) + } + + /// Returns whether the `Filter` matches this request, or rejects it. + /// + /// # Example + /// + /// ```no_run + /// async { + /// let get = warp::get(); + /// let post = warp::post(); + /// + /// assert!( + /// warp::test::request() + /// .method("GET") + /// .matches(&get) + /// .await + /// ); + /// + /// assert!( + /// !warp::test::request() + /// .method("GET") + /// .matches(&post) + /// .await + /// ); + ///}; + /// ``` + pub async fn matches<F>(self, f: &F) -> bool + where + F: Filter, + F::Future: Send + 'static, + F::Extract: Send + 'static, + F::Error: Send + 'static, + { + self.apply_filter(f).await.is_ok() + } + + /// Returns `Response` provided by applying the `Filter`. + /// + /// This requires that the supplied `Filter` return a [`Reply`](Reply). + pub async fn reply<F>(self, f: &F) -> Response<Bytes> + where + F: Filter + 'static, + F::Extract: Reply + Send, + F::Error: IsReject + Send, + { + // TODO: de-duplicate this and apply_filter() + assert!(!route::is_set(), "nested test filter calls"); + + let route = Route::new(self.req, self.remote_addr); + let mut fut = Box::pin( + route::set(&route, move || f.filter(crate::filter::Internal)).then(|result| { + let res = match result { + Ok(rep) => rep.into_response(), + Err(rej) => { + tracing::debug!("rejected: {:?}", rej); + rej.into_response() + } + }; + let (parts, body) = res.into_parts(); + hyper::body::to_bytes(body).map_ok(|chunk| Response::from_parts(parts, chunk)) + }), + ); + + let fut = future::poll_fn(move |cx| route::set(&route, || fut.as_mut().poll(cx))); + + fut.await.expect("reply shouldn't fail") + } + + fn apply_filter<F>(self, f: &F) -> impl Future<Output = Result<F::Extract, F::Error>> + where + F: Filter, + F::Future: Send + 'static, + F::Extract: Send + 'static, + F::Error: Send + 'static, + { + assert!(!route::is_set(), "nested test filter calls"); + + let route = Route::new(self.req, self.remote_addr); + let mut fut = Box::pin(route::set(&route, move || { + f.filter(crate::filter::Internal) + })); + future::poll_fn(move |cx| route::set(&route, || fut.as_mut().poll(cx))) + } +} + +#[cfg(feature = "websocket")] +impl WsBuilder { + /// Sets the request path of this builder. + /// + /// The default is not set is `/`. + /// + /// # Example + /// + /// ``` + /// let req = warp::test::ws() + /// .path("/chat"); + /// ``` + /// + /// # Panic + /// + /// This panics if the passed string is not able to be parsed as a valid + /// `Uri`. + pub fn path(self, p: &str) -> Self { + WsBuilder { + req: self.req.path(p), + } + } + + /// Set a header for this request. + /// + /// # Example + /// + /// ``` + /// let req = warp::test::ws() + /// .header("foo", "bar"); + /// ``` + /// + /// # Panic + /// + /// This panics if the passed strings are not able to be parsed as a valid + /// `HeaderName` and `HeaderValue`. + pub fn header<K, V>(self, key: K, value: V) -> Self + where + HeaderName: TryFrom<K>, + HeaderValue: TryFrom<V>, + { + WsBuilder { + req: self.req.header(key, value), + } + } + + /// Execute this Websocket request against the provided filter. + /// + /// If the handshake succeeds, returns a `WsClient`. + /// + /// # Example + /// + /// ```no_run + /// use futures_util::future; + /// use warp::Filter; + /// #[tokio::main] + /// # async fn main() { + /// + /// // Some route that accepts websockets (but drops them immediately). + /// let route = warp::ws() + /// .map(|ws: warp::ws::Ws| { + /// ws.on_upgrade(|_| future::ready(())) + /// }); + /// + /// let client = warp::test::ws() + /// .handshake(route) + /// .await + /// .expect("handshake"); + /// # } + /// ``` + pub async fn handshake<F>(self, f: F) -> Result<WsClient, WsError> + where + F: Filter + Clone + Send + Sync + 'static, + F::Extract: Reply + Send, + F::Error: IsReject + Send, + { + let (upgraded_tx, upgraded_rx) = oneshot::channel(); + let (wr_tx, wr_rx) = mpsc::unbounded(); + let (rd_tx, rd_rx) = mpsc::unbounded(); + + tokio::spawn(async move { + use tokio_tungstenite::tungstenite::protocol; + + let (addr, srv) = crate::serve(f).bind_ephemeral(([127, 0, 0, 1], 0)); + + let mut req = self + .req + .header("connection", "upgrade") + .header("upgrade", "websocket") + .header("sec-websocket-version", "13") + .header("sec-websocket-key", "dGhlIHNhbXBsZSBub25jZQ==") + .req; + + let query_string = match req.uri().query() { + Some(q) => format!("?{}", q), + None => String::from(""), + }; + + let uri = format!("http://{}{}{}", addr, req.uri().path(), query_string) + .parse() + .expect("addr + path is valid URI"); + + *req.uri_mut() = uri; + + // let mut rt = current_thread::Runtime::new().unwrap(); + tokio::spawn(srv); + + let upgrade = ::hyper::Client::builder() + .build(AddrConnect(addr)) + .request(req) + .and_then(hyper::upgrade::on); + + let upgraded = match upgrade.await { + Ok(up) => { + let _ = upgraded_tx.send(Ok(())); + up + } + Err(err) => { + let _ = upgraded_tx.send(Err(err)); + return; + } + }; + let ws = crate::ws::WebSocket::from_raw_socket( + upgraded, + protocol::Role::Client, + Default::default(), + ) + .await; + + let (tx, rx) = ws.split(); + let write = wr_rx.map(Ok).forward(tx).map(|_| ()); + + let read = rx + .take_while(|result| match result { + Err(_) => future::ready(false), + Ok(m) => future::ready(!m.is_close()), + }) + .for_each(move |item| { + rd_tx.unbounded_send(item).expect("ws receive error"); + future::ready(()) + }); + + future::join(write, read).await; + }); + + match upgraded_rx.await { + Ok(Ok(())) => Ok(WsClient { + tx: wr_tx, + rx: rd_rx, + }), + Ok(Err(err)) => Err(WsError::new(err)), + Err(_canceled) => panic!("websocket handshake thread panicked"), + } + } +} + +#[cfg(feature = "websocket")] +impl WsClient { + /// Send a "text" websocket message to the server. + pub async fn send_text(&mut self, text: impl Into<String>) { + self.send(crate::ws::Message::text(text)).await; + } + + /// Send a websocket message to the server. + pub async fn send(&mut self, msg: crate::ws::Message) { + self.tx.unbounded_send(msg).unwrap(); + } + + /// Receive a websocket message from the server. + pub async fn recv(&mut self) -> Result<crate::filters::ws::Message, WsError> { + self.rx + .next() + .await + .map(|result| result.map_err(WsError::new)) + .unwrap_or_else(|| { + // websocket is closed + Err(WsError::new("closed")) + }) + } + + /// Assert the server has closed the connection. + pub async fn recv_closed(&mut self) -> Result<(), WsError> { + self.rx + .next() + .await + .map(|result| match result { + Ok(msg) => Err(WsError::new(format!("received message: {:?}", msg))), + Err(err) => Err(WsError::new(err)), + }) + .unwrap_or_else(|| { + // closed successfully + Ok(()) + }) + } + + fn pinned_tx(self: Pin<&mut Self>) -> Pin<&mut mpsc::UnboundedSender<crate::ws::Message>> { + let this = Pin::into_inner(self); + Pin::new(&mut this.tx) + } +} + +#[cfg(feature = "websocket")] +impl fmt::Debug for WsClient { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("WsClient").finish() + } +} + +#[cfg(feature = "websocket")] +impl Sink<crate::ws::Message> for WsClient { + type Error = WsError; + + fn poll_ready( + self: Pin<&mut Self>, + context: &mut Context<'_>, + ) -> Poll<Result<(), Self::Error>> { + self.pinned_tx().poll_ready(context).map_err(WsError::new) + } + + fn start_send(self: Pin<&mut Self>, message: Message) -> Result<(), Self::Error> { + self.pinned_tx().start_send(message).map_err(WsError::new) + } + + fn poll_flush( + self: Pin<&mut Self>, + context: &mut Context<'_>, + ) -> Poll<Result<(), Self::Error>> { + self.pinned_tx().poll_flush(context).map_err(WsError::new) + } + + fn poll_close( + self: Pin<&mut Self>, + context: &mut Context<'_>, + ) -> Poll<Result<(), Self::Error>> { + self.pinned_tx().poll_close(context).map_err(WsError::new) + } +} + +#[cfg(feature = "websocket")] +impl Stream for WsClient { + type Item = Result<crate::ws::Message, WsError>; + + fn poll_next(self: Pin<&mut Self>, context: &mut Context<'_>) -> Poll<Option<Self::Item>> { + let this = Pin::into_inner(self); + let rx = Pin::new(&mut this.rx); + match rx.poll_next(context) { + Poll::Ready(Some(result)) => Poll::Ready(Some(result.map_err(WsError::new))), + Poll::Ready(None) => Poll::Ready(None), + Poll::Pending => Poll::Pending, + } + } +} + +// ===== impl WsError ===== + +#[cfg(feature = "websocket")] +impl WsError { + fn new<E: Into<Box<dyn StdError + Send + Sync>>>(cause: E) -> Self { + WsError { + cause: cause.into(), + } + } +} + +impl fmt::Display for WsError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "websocket error: {}", self.cause) + } +} + +impl StdError for WsError { + fn description(&self) -> &str { + "websocket error" + } +} + +// ===== impl AddrConnect ===== + +#[cfg(feature = "websocket")] +#[derive(Clone)] +struct AddrConnect(SocketAddr); + +#[cfg(feature = "websocket")] +impl tower_service::Service<::http::Uri> for AddrConnect { + type Response = ::tokio::net::TcpStream; + type Error = ::std::io::Error; + type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>; + + fn poll_ready(&mut self, _cx: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>> { + Poll::Ready(Ok(())) + } + + fn call(&mut self, _: ::http::Uri) -> Self::Future { + Box::pin(tokio::net::TcpStream::connect(self.0)) + } +} + +mod inner { + pub trait OneOrTuple { + type Output; + + fn one_or_tuple(self) -> Self::Output; + } + + impl OneOrTuple for () { + type Output = (); + fn one_or_tuple(self) -> Self::Output {} + } + + macro_rules! one_or_tuple { + ($type1:ident) => { + impl<$type1> OneOrTuple for ($type1,) { + type Output = $type1; + fn one_or_tuple(self) -> Self::Output { + self.0 + } + } + }; + ($type1:ident, $( $type:ident ),*) => { + one_or_tuple!($( $type ),*); + + impl<$type1, $($type),*> OneOrTuple for ($type1, $($type),*) { + type Output = Self; + fn one_or_tuple(self) -> Self::Output { + self + } + } + } + } + + one_or_tuple! { + T1, + T2, + T3, + T4, + T5, + T6, + T7, + T8, + T9, + T10, + T11, + T12, + T13, + T14, + T15, + T16 + } +} diff --git a/third_party/rust/warp/src/tls.rs b/third_party/rust/warp/src/tls.rs new file mode 100644 index 0000000000..1f81a6bd21 --- /dev/null +++ b/third_party/rust/warp/src/tls.rs @@ -0,0 +1,411 @@ +use std::fmt; +use std::fs::File; +use std::future::Future; +use std::io::{self, BufReader, Cursor, Read}; +use std::net::SocketAddr; +use std::path::{Path, PathBuf}; +use std::pin::Pin; +use std::sync::Arc; +use std::task::{Context, Poll}; +use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; + +use futures_util::ready; +use hyper::server::accept::Accept; +use hyper::server::conn::{AddrIncoming, AddrStream}; + +use crate::transport::Transport; +use tokio_rustls::rustls::{ + server::{AllowAnyAnonymousOrAuthenticatedClient, AllowAnyAuthenticatedClient, NoClientAuth}, + Certificate, Error as TlsError, PrivateKey, RootCertStore, ServerConfig, +}; + +/// Represents errors that can occur building the TlsConfig +#[derive(Debug)] +pub(crate) enum TlsConfigError { + Io(io::Error), + /// An Error parsing the Certificate + CertParseError, + /// An Error parsing a Pkcs8 key + Pkcs8ParseError, + /// An Error parsing a Rsa key + RsaParseError, + /// An error from an empty key + EmptyKey, + /// An error from an invalid key + InvalidKey(TlsError), +} + +impl fmt::Display for TlsConfigError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + TlsConfigError::Io(err) => err.fmt(f), + TlsConfigError::CertParseError => write!(f, "certificate parse error"), + TlsConfigError::Pkcs8ParseError => write!(f, "pkcs8 parse error"), + TlsConfigError::RsaParseError => write!(f, "rsa parse error"), + TlsConfigError::EmptyKey => write!(f, "key contains no private key"), + TlsConfigError::InvalidKey(err) => write!(f, "key contains an invalid key, {}", err), + } + } +} + +impl std::error::Error for TlsConfigError {} + +/// Tls client authentication configuration. +pub(crate) enum TlsClientAuth { + /// No client auth. + Off, + /// Allow any anonymous or authenticated client. + Optional(Box<dyn Read + Send + Sync>), + /// Allow any authenticated client. + Required(Box<dyn Read + Send + Sync>), +} + +/// Builder to set the configuration for the Tls server. +pub(crate) struct TlsConfigBuilder { + cert: Box<dyn Read + Send + Sync>, + key: Box<dyn Read + Send + Sync>, + client_auth: TlsClientAuth, + ocsp_resp: Vec<u8>, +} + +impl fmt::Debug for TlsConfigBuilder { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("TlsConfigBuilder").finish() + } +} + +impl TlsConfigBuilder { + /// Create a new TlsConfigBuilder + pub(crate) fn new() -> TlsConfigBuilder { + TlsConfigBuilder { + key: Box::new(io::empty()), + cert: Box::new(io::empty()), + client_auth: TlsClientAuth::Off, + ocsp_resp: Vec::new(), + } + } + + /// sets the Tls key via File Path, returns `TlsConfigError::IoError` if the file cannot be open + pub(crate) fn key_path(mut self, path: impl AsRef<Path>) -> Self { + self.key = Box::new(LazyFile { + path: path.as_ref().into(), + file: None, + }); + self + } + + /// sets the Tls key via bytes slice + pub(crate) fn key(mut self, key: &[u8]) -> Self { + self.key = Box::new(Cursor::new(Vec::from(key))); + self + } + + /// Specify the file path for the TLS certificate to use. + pub(crate) fn cert_path(mut self, path: impl AsRef<Path>) -> Self { + self.cert = Box::new(LazyFile { + path: path.as_ref().into(), + file: None, + }); + self + } + + /// sets the Tls certificate via bytes slice + pub(crate) fn cert(mut self, cert: &[u8]) -> Self { + self.cert = Box::new(Cursor::new(Vec::from(cert))); + self + } + + /// Sets the trust anchor for optional Tls client authentication via file path. + /// + /// Anonymous and authenticated clients will be accepted. If no trust anchor is provided by any + /// of the `client_auth_` methods, then client authentication is disabled by default. + pub(crate) fn client_auth_optional_path(mut self, path: impl AsRef<Path>) -> Self { + let file = Box::new(LazyFile { + path: path.as_ref().into(), + file: None, + }); + self.client_auth = TlsClientAuth::Optional(file); + self + } + + /// Sets the trust anchor for optional Tls client authentication via bytes slice. + /// + /// Anonymous and authenticated clients will be accepted. If no trust anchor is provided by any + /// of the `client_auth_` methods, then client authentication is disabled by default. + pub(crate) fn client_auth_optional(mut self, trust_anchor: &[u8]) -> Self { + let cursor = Box::new(Cursor::new(Vec::from(trust_anchor))); + self.client_auth = TlsClientAuth::Optional(cursor); + self + } + + /// Sets the trust anchor for required Tls client authentication via file path. + /// + /// Only authenticated clients will be accepted. If no trust anchor is provided by any of the + /// `client_auth_` methods, then client authentication is disabled by default. + pub(crate) fn client_auth_required_path(mut self, path: impl AsRef<Path>) -> Self { + let file = Box::new(LazyFile { + path: path.as_ref().into(), + file: None, + }); + self.client_auth = TlsClientAuth::Required(file); + self + } + + /// Sets the trust anchor for required Tls client authentication via bytes slice. + /// + /// Only authenticated clients will be accepted. If no trust anchor is provided by any of the + /// `client_auth_` methods, then client authentication is disabled by default. + pub(crate) fn client_auth_required(mut self, trust_anchor: &[u8]) -> Self { + let cursor = Box::new(Cursor::new(Vec::from(trust_anchor))); + self.client_auth = TlsClientAuth::Required(cursor); + self + } + + /// sets the DER-encoded OCSP response + pub(crate) fn ocsp_resp(mut self, ocsp_resp: &[u8]) -> Self { + self.ocsp_resp = Vec::from(ocsp_resp); + self + } + + pub(crate) fn build(mut self) -> Result<ServerConfig, TlsConfigError> { + let mut cert_rdr = BufReader::new(self.cert); + let cert = rustls_pemfile::certs(&mut cert_rdr) + .map_err(|_e| TlsConfigError::CertParseError)? + .into_iter() + .map(Certificate) + .collect(); + + let key = { + // convert it to Vec<u8> to allow reading it again if key is RSA + let mut key_vec = Vec::new(); + self.key + .read_to_end(&mut key_vec) + .map_err(TlsConfigError::Io)?; + + if key_vec.is_empty() { + return Err(TlsConfigError::EmptyKey); + } + + let mut pkcs8 = rustls_pemfile::pkcs8_private_keys(&mut key_vec.as_slice()) + .map_err(|_e| TlsConfigError::Pkcs8ParseError)?; + + if !pkcs8.is_empty() { + PrivateKey(pkcs8.remove(0)) + } else { + let mut rsa = rustls_pemfile::rsa_private_keys(&mut key_vec.as_slice()) + .map_err(|_e| TlsConfigError::RsaParseError)?; + + if !rsa.is_empty() { + PrivateKey(rsa.remove(0)) + } else { + return Err(TlsConfigError::EmptyKey); + } + } + }; + + fn read_trust_anchor( + trust_anchor: Box<dyn Read + Send + Sync>, + ) -> Result<RootCertStore, TlsConfigError> { + let trust_anchors = { + let mut reader = BufReader::new(trust_anchor); + rustls_pemfile::certs(&mut reader).map_err(TlsConfigError::Io)? + }; + + let mut store = RootCertStore::empty(); + let (added, _skipped) = store.add_parsable_certificates(&trust_anchors); + if added == 0 { + return Err(TlsConfigError::CertParseError); + } + + Ok(store) + } + + let client_auth = match self.client_auth { + TlsClientAuth::Off => NoClientAuth::new(), + TlsClientAuth::Optional(trust_anchor) => { + AllowAnyAnonymousOrAuthenticatedClient::new(read_trust_anchor(trust_anchor)?) + } + TlsClientAuth::Required(trust_anchor) => { + AllowAnyAuthenticatedClient::new(read_trust_anchor(trust_anchor)?) + } + }; + + let mut config = ServerConfig::builder() + .with_safe_defaults() + .with_client_cert_verifier(client_auth.into()) + .with_single_cert_with_ocsp_and_sct(cert, key, self.ocsp_resp, Vec::new()) + .map_err(TlsConfigError::InvalidKey)?; + config.alpn_protocols = vec!["h2".into(), "http/1.1".into()]; + Ok(config) + } +} + +struct LazyFile { + path: PathBuf, + file: Option<File>, +} + +impl LazyFile { + fn lazy_read(&mut self, buf: &mut [u8]) -> io::Result<usize> { + if self.file.is_none() { + self.file = Some(File::open(&self.path)?); + } + + self.file.as_mut().unwrap().read(buf) + } +} + +impl Read for LazyFile { + fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> { + self.lazy_read(buf).map_err(|err| { + let kind = err.kind(); + io::Error::new( + kind, + format!("error reading file ({:?}): {}", self.path.display(), err), + ) + }) + } +} + +impl Transport for TlsStream { + fn remote_addr(&self) -> Option<SocketAddr> { + Some(self.remote_addr) + } +} + +enum State { + Handshaking(tokio_rustls::Accept<AddrStream>), + Streaming(tokio_rustls::server::TlsStream<AddrStream>), +} + +// tokio_rustls::server::TlsStream doesn't expose constructor methods, +// so we have to TlsAcceptor::accept and handshake to have access to it +// TlsStream implements AsyncRead/AsyncWrite handshaking tokio_rustls::Accept first +pub(crate) struct TlsStream { + state: State, + remote_addr: SocketAddr, +} + +impl TlsStream { + fn new(stream: AddrStream, config: Arc<ServerConfig>) -> TlsStream { + let remote_addr = stream.remote_addr(); + let accept = tokio_rustls::TlsAcceptor::from(config).accept(stream); + TlsStream { + state: State::Handshaking(accept), + remote_addr, + } + } +} + +impl AsyncRead for TlsStream { + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll<io::Result<()>> { + let pin = self.get_mut(); + match pin.state { + State::Handshaking(ref mut accept) => match ready!(Pin::new(accept).poll(cx)) { + Ok(mut stream) => { + let result = Pin::new(&mut stream).poll_read(cx, buf); + pin.state = State::Streaming(stream); + result + } + Err(err) => Poll::Ready(Err(err)), + }, + State::Streaming(ref mut stream) => Pin::new(stream).poll_read(cx, buf), + } + } +} + +impl AsyncWrite for TlsStream { + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll<io::Result<usize>> { + let pin = self.get_mut(); + match pin.state { + State::Handshaking(ref mut accept) => match ready!(Pin::new(accept).poll(cx)) { + Ok(mut stream) => { + let result = Pin::new(&mut stream).poll_write(cx, buf); + pin.state = State::Streaming(stream); + result + } + Err(err) => Poll::Ready(Err(err)), + }, + State::Streaming(ref mut stream) => Pin::new(stream).poll_write(cx, buf), + } + } + + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> { + match self.state { + State::Handshaking(_) => Poll::Ready(Ok(())), + State::Streaming(ref mut stream) => Pin::new(stream).poll_flush(cx), + } + } + + fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> { + match self.state { + State::Handshaking(_) => Poll::Ready(Ok(())), + State::Streaming(ref mut stream) => Pin::new(stream).poll_shutdown(cx), + } + } +} + +pub(crate) struct TlsAcceptor { + config: Arc<ServerConfig>, + incoming: AddrIncoming, +} + +impl TlsAcceptor { + pub(crate) fn new(config: ServerConfig, incoming: AddrIncoming) -> TlsAcceptor { + TlsAcceptor { + config: Arc::new(config), + incoming, + } + } +} + +impl Accept for TlsAcceptor { + type Conn = TlsStream; + type Error = io::Error; + + fn poll_accept( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll<Option<Result<Self::Conn, Self::Error>>> { + let pin = self.get_mut(); + match ready!(Pin::new(&mut pin.incoming).poll_accept(cx)) { + Some(Ok(sock)) => Poll::Ready(Some(Ok(TlsStream::new(sock, pin.config.clone())))), + Some(Err(e)) => Poll::Ready(Some(Err(e))), + None => Poll::Ready(None), + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn file_cert_key() { + TlsConfigBuilder::new() + .key_path("examples/tls/key.rsa") + .cert_path("examples/tls/cert.pem") + .build() + .unwrap(); + } + + #[test] + fn bytes_cert_key() { + let key = include_str!("../examples/tls/key.rsa"); + let cert = include_str!("../examples/tls/cert.pem"); + + TlsConfigBuilder::new() + .key(key.as_bytes()) + .cert(cert.as_bytes()) + .build() + .unwrap(); + } +} diff --git a/third_party/rust/warp/src/transport.rs b/third_party/rust/warp/src/transport.rs new file mode 100644 index 0000000000..be553e706e --- /dev/null +++ b/third_party/rust/warp/src/transport.rs @@ -0,0 +1,53 @@ +use std::io; +use std::net::SocketAddr; +use std::pin::Pin; +use std::task::{Context, Poll}; + +use hyper::server::conn::AddrStream; +use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; + +pub trait Transport: AsyncRead + AsyncWrite { + fn remote_addr(&self) -> Option<SocketAddr>; +} + +impl Transport for AddrStream { + fn remote_addr(&self) -> Option<SocketAddr> { + Some(self.remote_addr()) + } +} + +pub(crate) struct LiftIo<T>(pub(crate) T); + +impl<T: AsyncRead + Unpin> AsyncRead for LiftIo<T> { + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll<io::Result<()>> { + Pin::new(&mut self.get_mut().0).poll_read(cx, buf) + } +} + +impl<T: AsyncWrite + Unpin> AsyncWrite for LiftIo<T> { + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll<io::Result<usize>> { + Pin::new(&mut self.get_mut().0).poll_write(cx, buf) + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> { + Pin::new(&mut self.get_mut().0).poll_flush(cx) + } + + fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> { + Pin::new(&mut self.get_mut().0).poll_shutdown(cx) + } +} + +impl<T: AsyncRead + AsyncWrite + Unpin> Transport for LiftIo<T> { + fn remote_addr(&self) -> Option<SocketAddr> { + None + } +} |