use crate::{Body, SizeHint}; use bytes::Buf; use http::HeaderMap; use pin_project_lite::pin_project; use std::error::Error; use std::fmt; use std::pin::Pin; use std::task::{Context, Poll}; pin_project! { /// A length limited body. /// /// This body will return an error if more than the configured number /// of bytes are returned on polling the wrapped body. #[derive(Clone, Copy, Debug)] pub struct Limited { remaining: usize, #[pin] inner: B, } } impl Limited { /// Create a new `Limited`. pub fn new(inner: B, limit: usize) -> Self { Self { remaining: limit, inner, } } } impl Body for Limited where B: Body, B::Error: Into>, { type Data = B::Data; type Error = Box; fn poll_data( self: Pin<&mut Self>, cx: &mut Context<'_>, ) -> Poll>> { let this = self.project(); let res = match this.inner.poll_data(cx) { Poll::Pending => return Poll::Pending, Poll::Ready(None) => None, Poll::Ready(Some(Ok(data))) => { if data.remaining() > *this.remaining { *this.remaining = 0; Some(Err(LengthLimitError.into())) } else { *this.remaining -= data.remaining(); Some(Ok(data)) } } Poll::Ready(Some(Err(err))) => Some(Err(err.into())), }; Poll::Ready(res) } fn poll_trailers( self: Pin<&mut Self>, cx: &mut Context<'_>, ) -> Poll, Self::Error>> { let this = self.project(); let res = match this.inner.poll_trailers(cx) { Poll::Pending => return Poll::Pending, Poll::Ready(Ok(data)) => Ok(data), Poll::Ready(Err(err)) => Err(err.into()), }; Poll::Ready(res) } fn is_end_stream(&self) -> bool { self.inner.is_end_stream() } fn size_hint(&self) -> SizeHint { use std::convert::TryFrom; match u64::try_from(self.remaining) { Ok(n) => { let mut hint = self.inner.size_hint(); if hint.lower() >= n { hint.set_exact(n) } else if let Some(max) = hint.upper() { hint.set_upper(n.min(max)) } else { hint.set_upper(n) } hint } Err(_) => self.inner.size_hint(), } } } /// An error returned when body length exceeds the configured limit. 
#[derive(Debug)]
// `non_exhaustive` keeps other crates from constructing this unit struct
// directly, leaving room to add fields later.
#[non_exhaustive]
pub struct LengthLimitError;

impl fmt::Display for LengthLimitError {
    /// Writes the fixed, human-readable message for this error.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        f.write_str("length limit exceeded")
    }
}

impl Error for LengthLimitError {}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::Full;
    use bytes::Bytes;
    use std::convert::Infallible;

    #[tokio::test]
    async fn read_for_body_under_limit_returns_data() {
        const DATA: &[u8] = b"testing";
        let inner = Full::new(Bytes::from(DATA));
        let body = &mut Limited::new(inner, 8);

        // The 7-byte body is below the 8-byte limit, so the hint reflects the body.
        let mut hint = SizeHint::new();
        hint.set_upper(7);
        assert_eq!(body.size_hint().upper(), hint.upper());

        let data = body.data().await.unwrap().unwrap();
        assert_eq!(data, DATA);
        hint.set_upper(0);
        assert_eq!(body.size_hint().upper(), hint.upper());

        assert!(matches!(body.data().await, None));
    }

    #[tokio::test]
    async fn read_for_body_over_limit_returns_error() {
        const DATA: &[u8] = b"testing a string that is too long";
        let inner = Full::new(Bytes::from(DATA));
        let body = &mut Limited::new(inner, 8);

        let mut hint = SizeHint::new();
        hint.set_upper(8);
        assert_eq!(body.size_hint().upper(), hint.upper());

        let error = body.data().await.unwrap().unwrap_err();
        assert!(matches!(error.downcast_ref(), Some(LengthLimitError)));
    }

    /// Test body that yields its static slices one chunk per poll.
    struct Chunky(&'static [&'static [u8]]);

    impl Body for Chunky {
        type Data = &'static [u8];
        type Error = Infallible;

        fn poll_data(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
        ) -> Poll<Option<Result<Self::Data, Self::Error>>> {
            let mut this = self;
            // Pop the first chunk and keep the tail for the next poll.
            match this.0.split_first().map(|(&head, tail)| (Ok(head), tail)) {
                Some((data, new_tail)) => {
                    this.0 = new_tail;

                    Poll::Ready(Some(data))
                }
                None => Poll::Ready(None),
            }
        }

        fn poll_trailers(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
        ) -> Poll<Result<Option<HeaderMap>, Self::Error>> {
            Poll::Ready(Ok(Some(HeaderMap::new())))
        }
    }

    #[tokio::test]
    async fn read_for_chunked_body_around_limit_returns_first_chunk_but_returns_error_on_over_limit_chunk(
    ) {
        const DATA: &[&[u8]] = &[b"testing ", b"a string that is too long"];
        let inner = Chunky(DATA);
        let body = &mut Limited::new(inner, 8);

        let mut hint = SizeHint::new();
        hint.set_upper(8);
        assert_eq!(body.size_hint().upper(), hint.upper());

        // The first chunk is exactly 8 bytes and uses up the whole budget.
        let data = body.data().await.unwrap().unwrap();
        assert_eq!(data, DATA[0]);
        hint.set_upper(0);
        assert_eq!(body.size_hint().upper(), hint.upper());

        let error = body.data().await.unwrap().unwrap_err();
        assert!(matches!(error.downcast_ref(), Some(LengthLimitError)));
    }

    #[tokio::test]
    async fn read_for_chunked_body_over_limit_on_first_chunk_returns_error() {
        const DATA: &[&[u8]] = &[b"testing a string", b" that is too long"];
        let inner = Chunky(DATA);
        let body = &mut Limited::new(inner, 8);

        let mut hint = SizeHint::new();
        hint.set_upper(8);
        assert_eq!(body.size_hint().upper(), hint.upper());

        let error = body.data().await.unwrap().unwrap_err();
        assert!(matches!(error.downcast_ref(), Some(LengthLimitError)));
    }

    #[tokio::test]
    async fn read_for_chunked_body_under_limit_is_okay() {
        const DATA: &[&[u8]] = &[b"test", b"ing!"];
        let inner = Chunky(DATA);
        let body = &mut Limited::new(inner, 8);

        let mut hint = SizeHint::new();
        hint.set_upper(8);
        assert_eq!(body.size_hint().upper(), hint.upper());

        let data = body.data().await.unwrap().unwrap();
        assert_eq!(data, DATA[0]);
        hint.set_upper(4);
        assert_eq!(body.size_hint().upper(), hint.upper());

        let data = body.data().await.unwrap().unwrap();
        assert_eq!(data, DATA[1]);
        hint.set_upper(0);
        assert_eq!(body.size_hint().upper(), hint.upper());

        assert!(matches!(body.data().await, None));
    }

    #[tokio::test]
    async fn read_for_trailers_propagates_inner_trailers() {
        const DATA: &[&[u8]] = &[b"test", b"ing!"];
        let inner = Chunky(DATA);
        let body = &mut Limited::new(inner, 8);
        let trailers = body.trailers().await.unwrap();
        assert_eq!(trailers, Some(HeaderMap::new()))
    }

    /// Distinguishes which poll method of `ErrorBody` produced the error.
    #[derive(Debug)]
    enum ErrorBodyError {
        Data,
        Trailers,
    }

    impl fmt::Display for ErrorBodyError {
        // The tests never inspect the message, so nothing is written.
        fn fmt(&self, _f: &mut fmt::Formatter) -> fmt::Result {
            Ok(())
        }
    }

    impl Error for ErrorBodyError {}

    /// Test body whose data and trailer polls always fail.
    struct ErrorBody;

    impl Body for ErrorBody {
        type Data = &'static [u8];
        type Error = ErrorBodyError;

        fn poll_data(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
        ) -> Poll<Option<Result<Self::Data, Self::Error>>> {
            Poll::Ready(Some(Err(ErrorBodyError::Data)))
        }

        fn poll_trailers(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
        ) -> Poll<Result<Option<HeaderMap>, Self::Error>> {
            Poll::Ready(Err(ErrorBodyError::Trailers))
        }
    }

    #[tokio::test]
    async fn read_for_body_returning_error_propagates_error() {
        let body = &mut Limited::new(ErrorBody, 8);
        let error = body.data().await.unwrap().unwrap_err();
        assert!(matches!(error.downcast_ref(), Some(ErrorBodyError::Data)));
    }

    #[tokio::test]
    async fn trailers_for_body_returning_error_propagates_error() {
        let body = &mut Limited::new(ErrorBody, 8);
        let error = body.trailers().await.unwrap_err();
        assert!(matches!(
            error.downcast_ref(),
            Some(ErrorBodyError::Trailers)
        ));
    }
}