//! Test utilities to test your filters. //! //! [`Filter`](../trait.Filter.html)s can be easily tested without starting up an HTTP //! server, by making use of the [`RequestBuilder`](./struct.RequestBuilder.html) in this //! module. //! //! # Testing Filters //! //! It's easy to test filters, especially if smaller filters are used to build //! up your full set. Consider these example filters: //! //! ``` //! use warp::Filter; //! //! fn sum() -> impl Filter + Copy { //! warp::path::param() //! .and(warp::path::param()) //! .map(|x: u32, y: u32| { //! x + y //! }) //! } //! //! fn math() -> impl Filter + Copy { //! warp::post() //! .and(sum()) //! .map(|z: u32| { //! format!("Sum = {}", z) //! }) //! } //! ``` //! //! We can test some requests against the `sum` filter like this: //! //! ``` //! # use warp::Filter; //! #[tokio::test] //! async fn test_sum() { //! # let sum = || warp::any().map(|| 3); //! let filter = sum(); //! //! // Execute `sum` and get the `Extract` back. //! let value = warp::test::request() //! .path("/1/2") //! .filter(&filter) //! .await //! .unwrap(); //! assert_eq!(value, 3); //! //! // Or simply test if a request matches (doesn't reject). //! assert!( //! warp::test::request() //! .path("/1/-5") //! .matches(&filter) //! .await //! ); //! } //! ``` //! //! If the filter returns something that implements `Reply`, and thus can be //! turned into a response sent back to the client, we can test what exact //! response is returned. The `math` filter uses the `sum` filter, but returns //! a `String` that can be turned into a response. //! //! ``` //! # use warp::Filter; //! #[test] //! fn test_math() { //! # let math = || warp::any().map(warp::reply); //! let filter = math(); //! //! let res = warp::test::request() //! .path("/1/2") //! .reply(&filter); //! assert_eq!(res.status(), 405, "GET is not allowed"); //! //! let res = warp::test::request() //! .method("POST") //! .path("/1/2") //! .reply(&filter); //! assert_eq!(res.status(), 200); //! assert_eq!(res.body(), "Sum is 3"); //! } //! ``` use std::convert::TryFrom; use std::error::Error as StdError; use std::fmt; use std::future::Future; use std::net::SocketAddr; #[cfg(feature = "websocket")] use std::pin::Pin; #[cfg(feature = "websocket")] use std::task::Context; #[cfg(feature = "websocket")] use std::task::{self, Poll}; use bytes::Bytes; #[cfg(feature = "websocket")] use futures_channel::mpsc; #[cfg(feature = "websocket")] use futures_util::StreamExt; use futures_util::{future, FutureExt, TryFutureExt}; use http::{ header::{HeaderName, HeaderValue}, Response, }; use serde::Serialize; use serde_json; #[cfg(feature = "websocket")] use tokio::sync::oneshot; use crate::filter::Filter; #[cfg(feature = "websocket")] use crate::filters::ws::Message; use crate::reject::IsReject; use crate::reply::Reply; use crate::route::{self, Route}; use crate::Request; #[cfg(feature = "websocket")] use crate::{Sink, Stream}; use self::inner::OneOrTuple; /// Starts a new test `RequestBuilder`. pub fn request() -> RequestBuilder { RequestBuilder { remote_addr: None, req: Request::default(), } } /// Starts a new test `WsBuilder`. #[cfg(feature = "websocket")] pub fn ws() -> WsBuilder { WsBuilder { req: request() } } /// A request builder for testing filters. /// /// See [module documentation](crate::test) for an overview. #[must_use = "RequestBuilder does nothing on its own"] #[derive(Debug)] pub struct RequestBuilder { remote_addr: Option, req: Request, } /// A Websocket builder for testing filters. /// /// See [module documentation](crate::test) for an overview. #[cfg(feature = "websocket")] #[must_use = "WsBuilder does nothing on its own"] #[derive(Debug)] pub struct WsBuilder { req: RequestBuilder, } /// A test client for Websocket filters. #[cfg(feature = "websocket")] pub struct WsClient { tx: mpsc::UnboundedSender, rx: mpsc::UnboundedReceiver>, } /// An error from Websocket filter tests. #[derive(Debug)] pub struct WsError { cause: Box, } impl RequestBuilder { /// Sets the method of this builder. /// /// The default if not set is `GET`. /// /// # Example /// /// ``` /// let req = warp::test::request() /// .method("POST"); /// ``` /// /// # Panic /// /// This panics if the passed string is not able to be parsed as a valid /// `Method`. pub fn method(mut self, method: &str) -> Self { *self.req.method_mut() = method.parse().expect("valid method"); self } /// Sets the request path of this builder. /// /// The default is not set is `/`. /// /// # Example /// /// ``` /// let req = warp::test::request() /// .path("/todos/33"); /// ``` /// /// # Panic /// /// This panics if the passed string is not able to be parsed as a valid /// `Uri`. pub fn path(mut self, p: &str) -> Self { let uri = p.parse().expect("test request path invalid"); *self.req.uri_mut() = uri; self } /// Set a header for this request. /// /// # Example /// /// ``` /// let req = warp::test::request() /// .header("accept", "application/json"); /// ``` /// /// # Panic /// /// This panics if the passed strings are not able to be parsed as a valid /// `HeaderName` and `HeaderValue`. pub fn header(mut self, key: K, value: V) -> Self where HeaderName: TryFrom, HeaderValue: TryFrom, { let name: HeaderName = TryFrom::try_from(key) .map_err(|_| ()) .expect("invalid header name"); let value = TryFrom::try_from(value) .map_err(|_| ()) .expect("invalid header value"); self.req.headers_mut().insert(name, value); self } /// Set the remote address of this request /// /// Default is no remote address. /// /// # Example /// ``` /// use std::net::{IpAddr, Ipv4Addr, SocketAddr}; /// /// let req = warp::test::request() /// .remote_addr(SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 8080)); /// ``` pub fn remote_addr(mut self, addr: SocketAddr) -> Self { self.remote_addr = Some(addr); self } /// Add a type to the request's `http::Extensions`. pub fn extension(mut self, ext: T) -> Self where T: Send + Sync + 'static, { self.req.extensions_mut().insert(ext); self } /// Set the bytes of this request body. /// /// Default is an empty body. /// /// # Example /// /// ``` /// let req = warp::test::request() /// .body("foo=bar&baz=quux"); /// ``` pub fn body(mut self, body: impl AsRef<[u8]>) -> Self { let body = body.as_ref().to_vec(); let len = body.len(); *self.req.body_mut() = body.into(); self.header("content-length", len.to_string()) } /// Set the bytes of this request body by serializing a value into JSON. /// /// # Example /// /// ``` /// let req = warp::test::request() /// .json(&true); /// ``` pub fn json(mut self, val: &impl Serialize) -> Self { let vec = serde_json::to_vec(val).expect("json() must serialize to JSON"); let len = vec.len(); *self.req.body_mut() = vec.into(); self.header("content-length", len.to_string()) .header("content-type", "application/json") } /// Tries to apply the `Filter` on this request. /// /// # Example /// /// ```no_run /// async { /// let param = warp::path::param::(); /// /// let ex = warp::test::request() /// .path("/41") /// .filter(¶m) /// .await /// .unwrap(); /// /// assert_eq!(ex, 41); /// /// assert!( /// warp::test::request() /// .path("/foo") /// .filter(¶m) /// .await /// .is_err() /// ); ///}; /// ``` pub async fn filter(self, f: &F) -> Result<::Output, F::Error> where F: Filter, F::Future: Send + 'static, F::Extract: OneOrTuple + Send + 'static, F::Error: Send + 'static, { self.apply_filter(f).await.map(|ex| ex.one_or_tuple()) } /// Returns whether the `Filter` matches this request, or rejects it. /// /// # Example /// /// ```no_run /// async { /// let get = warp::get(); /// let post = warp::post(); /// /// assert!( /// warp::test::request() /// .method("GET") /// .matches(&get) /// .await /// ); /// /// assert!( /// !warp::test::request() /// .method("GET") /// .matches(&post) /// .await /// ); ///}; /// ``` pub async fn matches(self, f: &F) -> bool where F: Filter, F::Future: Send + 'static, F::Extract: Send + 'static, F::Error: Send + 'static, { self.apply_filter(f).await.is_ok() } /// Returns `Response` provided by applying the `Filter`. /// /// This requires that the supplied `Filter` return a [`Reply`](Reply). pub async fn reply(self, f: &F) -> Response where F: Filter + 'static, F::Extract: Reply + Send, F::Error: IsReject + Send, { // TODO: de-duplicate this and apply_filter() assert!(!route::is_set(), "nested test filter calls"); let route = Route::new(self.req, self.remote_addr); let mut fut = Box::pin( route::set(&route, move || f.filter(crate::filter::Internal)).then(|result| { let res = match result { Ok(rep) => rep.into_response(), Err(rej) => { tracing::debug!("rejected: {:?}", rej); rej.into_response() } }; let (parts, body) = res.into_parts(); hyper::body::to_bytes(body).map_ok(|chunk| Response::from_parts(parts, chunk)) }), ); let fut = future::poll_fn(move |cx| route::set(&route, || fut.as_mut().poll(cx))); fut.await.expect("reply shouldn't fail") } fn apply_filter(self, f: &F) -> impl Future> where F: Filter, F::Future: Send + 'static, F::Extract: Send + 'static, F::Error: Send + 'static, { assert!(!route::is_set(), "nested test filter calls"); let route = Route::new(self.req, self.remote_addr); let mut fut = Box::pin(route::set(&route, move || { f.filter(crate::filter::Internal) })); future::poll_fn(move |cx| route::set(&route, || fut.as_mut().poll(cx))) } } #[cfg(feature = "websocket")] impl WsBuilder { /// Sets the request path of this builder. /// /// The default is not set is `/`. /// /// # Example /// /// ``` /// let req = warp::test::ws() /// .path("/chat"); /// ``` /// /// # Panic /// /// This panics if the passed string is not able to be parsed as a valid /// `Uri`. pub fn path(self, p: &str) -> Self { WsBuilder { req: self.req.path(p), } } /// Set a header for this request. /// /// # Example /// /// ``` /// let req = warp::test::ws() /// .header("foo", "bar"); /// ``` /// /// # Panic /// /// This panics if the passed strings are not able to be parsed as a valid /// `HeaderName` and `HeaderValue`. pub fn header(self, key: K, value: V) -> Self where HeaderName: TryFrom, HeaderValue: TryFrom, { WsBuilder { req: self.req.header(key, value), } } /// Execute this Websocket request against the provided filter. /// /// If the handshake succeeds, returns a `WsClient`. /// /// # Example /// /// ```no_run /// use futures_util::future; /// use warp::Filter; /// #[tokio::main] /// # async fn main() { /// /// // Some route that accepts websockets (but drops them immediately). /// let route = warp::ws() /// .map(|ws: warp::ws::Ws| { /// ws.on_upgrade(|_| future::ready(())) /// }); /// /// let client = warp::test::ws() /// .handshake(route) /// .await /// .expect("handshake"); /// # } /// ``` pub async fn handshake(self, f: F) -> Result where F: Filter + Clone + Send + Sync + 'static, F::Extract: Reply + Send, F::Error: IsReject + Send, { let (upgraded_tx, upgraded_rx) = oneshot::channel(); let (wr_tx, wr_rx) = mpsc::unbounded(); let (rd_tx, rd_rx) = mpsc::unbounded(); tokio::spawn(async move { use tokio_tungstenite::tungstenite::protocol; let (addr, srv) = crate::serve(f).bind_ephemeral(([127, 0, 0, 1], 0)); let mut req = self .req .header("connection", "upgrade") .header("upgrade", "websocket") .header("sec-websocket-version", "13") .header("sec-websocket-key", "dGhlIHNhbXBsZSBub25jZQ==") .req; let query_string = match req.uri().query() { Some(q) => format!("?{}", q), None => String::from(""), }; let uri = format!("http://{}{}{}", addr, req.uri().path(), query_string) .parse() .expect("addr + path is valid URI"); *req.uri_mut() = uri; // let mut rt = current_thread::Runtime::new().unwrap(); tokio::spawn(srv); let upgrade = ::hyper::Client::builder() .build(AddrConnect(addr)) .request(req) .and_then(hyper::upgrade::on); let upgraded = match upgrade.await { Ok(up) => { let _ = upgraded_tx.send(Ok(())); up } Err(err) => { let _ = upgraded_tx.send(Err(err)); return; } }; let ws = crate::ws::WebSocket::from_raw_socket( upgraded, protocol::Role::Client, Default::default(), ) .await; let (tx, rx) = ws.split(); let write = wr_rx.map(Ok).forward(tx).map(|_| ()); let read = rx .take_while(|result| match result { Err(_) => future::ready(false), Ok(m) => future::ready(!m.is_close()), }) .for_each(move |item| { rd_tx.unbounded_send(item).expect("ws receive error"); future::ready(()) }); future::join(write, read).await; }); match upgraded_rx.await { Ok(Ok(())) => Ok(WsClient { tx: wr_tx, rx: rd_rx, }), Ok(Err(err)) => Err(WsError::new(err)), Err(_canceled) => panic!("websocket handshake thread panicked"), } } } #[cfg(feature = "websocket")] impl WsClient { /// Send a "text" websocket message to the server. pub async fn send_text(&mut self, text: impl Into) { self.send(crate::ws::Message::text(text)).await; } /// Send a websocket message to the server. pub async fn send(&mut self, msg: crate::ws::Message) { self.tx.unbounded_send(msg).unwrap(); } /// Receive a websocket message from the server. pub async fn recv(&mut self) -> Result { self.rx .next() .await .map(|result| result.map_err(WsError::new)) .unwrap_or_else(|| { // websocket is closed Err(WsError::new("closed")) }) } /// Assert the server has closed the connection. pub async fn recv_closed(&mut self) -> Result<(), WsError> { self.rx .next() .await .map(|result| match result { Ok(msg) => Err(WsError::new(format!("received message: {:?}", msg))), Err(err) => Err(WsError::new(err)), }) .unwrap_or_else(|| { // closed successfully Ok(()) }) } fn pinned_tx(self: Pin<&mut Self>) -> Pin<&mut mpsc::UnboundedSender> { let this = Pin::into_inner(self); Pin::new(&mut this.tx) } } #[cfg(feature = "websocket")] impl fmt::Debug for WsClient { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("WsClient").finish() } } #[cfg(feature = "websocket")] impl Sink for WsClient { type Error = WsError; fn poll_ready( self: Pin<&mut Self>, context: &mut Context<'_>, ) -> Poll> { self.pinned_tx().poll_ready(context).map_err(WsError::new) } fn start_send(self: Pin<&mut Self>, message: Message) -> Result<(), Self::Error> { self.pinned_tx().start_send(message).map_err(WsError::new) } fn poll_flush( self: Pin<&mut Self>, context: &mut Context<'_>, ) -> Poll> { self.pinned_tx().poll_flush(context).map_err(WsError::new) } fn poll_close( self: Pin<&mut Self>, context: &mut Context<'_>, ) -> Poll> { self.pinned_tx().poll_close(context).map_err(WsError::new) } } #[cfg(feature = "websocket")] impl Stream for WsClient { type Item = Result; fn poll_next(self: Pin<&mut Self>, context: &mut Context<'_>) -> Poll> { let this = Pin::into_inner(self); let rx = Pin::new(&mut this.rx); match rx.poll_next(context) { Poll::Ready(Some(result)) => Poll::Ready(Some(result.map_err(WsError::new))), Poll::Ready(None) => Poll::Ready(None), Poll::Pending => Poll::Pending, } } } // ===== impl WsError ===== #[cfg(feature = "websocket")] impl WsError { fn new>>(cause: E) -> Self { WsError { cause: cause.into(), } } } impl fmt::Display for WsError { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!(f, "websocket error: {}", self.cause) } } impl StdError for WsError { fn description(&self) -> &str { "websocket error" } } // ===== impl AddrConnect ===== #[cfg(feature = "websocket")] #[derive(Clone)] struct AddrConnect(SocketAddr); #[cfg(feature = "websocket")] impl tower_service::Service<::http::Uri> for AddrConnect { type Response = ::tokio::net::TcpStream; type Error = ::std::io::Error; type Future = Pin> + Send>>; fn poll_ready(&mut self, _cx: &mut task::Context<'_>) -> Poll> { Poll::Ready(Ok(())) } fn call(&mut self, _: ::http::Uri) -> Self::Future { Box::pin(tokio::net::TcpStream::connect(self.0)) } } mod inner { pub trait OneOrTuple { type Output; fn one_or_tuple(self) -> Self::Output; } impl OneOrTuple for () { type Output = (); fn one_or_tuple(self) -> Self::Output {} } macro_rules! one_or_tuple { ($type1:ident) => { impl<$type1> OneOrTuple for ($type1,) { type Output = $type1; fn one_or_tuple(self) -> Self::Output { self.0 } } }; ($type1:ident, $( $type:ident ),*) => { one_or_tuple!($( $type ),*); impl<$type1, $($type),*> OneOrTuple for ($type1, $($type),*) { type Output = Self; fn one_or_tuple(self) -> Self::Output { self } } } } one_or_tuple! { T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16 } }