diff options
Diffstat (limited to 'third_party/rust/hyper/src/upgrade.rs')
-rw-r--r-- | third_party/rust/hyper/src/upgrade.rs | 365 |
1 files changed, 365 insertions, 0 deletions
diff --git a/third_party/rust/hyper/src/upgrade.rs b/third_party/rust/hyper/src/upgrade.rs new file mode 100644 index 0000000000..55f390431f --- /dev/null +++ b/third_party/rust/hyper/src/upgrade.rs @@ -0,0 +1,365 @@ +//! HTTP Upgrades +//! +//! See [this example][example] showing how upgrades work with both +//! Clients and Servers. +//! +//! [example]: https://github.com/hyperium/hyper/blob/master/examples/upgrades.rs + +use std::any::TypeId; +use std::error::Error as StdError; +use std::fmt; +use std::io; +use std::marker::Unpin; + +use bytes::{Buf, Bytes}; +use tokio::io::{AsyncRead, AsyncWrite}; +use tokio::sync::oneshot; + +use crate::common::io::Rewind; +use crate::common::{task, Future, Pin, Poll}; + +/// An upgraded HTTP connection. +/// +/// This type holds a trait object internally of the original IO that +/// was used to speak HTTP before the upgrade. It can be used directly +/// as a `Read` or `Write` for convenience. +/// +/// Alternatively, if the exact type is known, this can be deconstructed +/// into its parts. +pub struct Upgraded { + io: Rewind<Box<dyn Io + Send>>, +} + +/// A future for a possible HTTP upgrade. +/// +/// If no upgrade was available, or it doesn't succeed, yields an `Error`. +pub struct OnUpgrade { + rx: Option<oneshot::Receiver<crate::Result<Upgraded>>>, +} + +/// The deconstructed parts of an [`Upgraded`](Upgraded) type. +/// +/// Includes the original IO type, and a read buffer of bytes that the +/// HTTP state machine may have already read before completing an upgrade. +#[derive(Debug)] +pub struct Parts<T> { + /// The original IO object used before the upgrade. + pub io: T, + /// A buffer of bytes that have been read but not processed as HTTP. + /// + /// For instance, if the `Connection` is used for an HTTP upgrade request, + /// it is possible the server sent back the first bytes of the new protocol + /// along with the response upgrade. + /// + /// You will want to check for any existing bytes if you plan to continue + /// communicating on the IO object. + pub read_buf: Bytes, + _inner: (), +} + +pub(crate) struct Pending { + tx: oneshot::Sender<crate::Result<Upgraded>>, +} + +/// Error cause returned when an upgrade was expected but canceled +/// for whatever reason. +/// +/// This likely means the actual `Conn` future wasn't polled and upgraded. +#[derive(Debug)] +struct UpgradeExpected(()); + +pub(crate) fn pending() -> (Pending, OnUpgrade) { + let (tx, rx) = oneshot::channel(); + (Pending { tx }, OnUpgrade { rx: Some(rx) }) +} + +// ===== impl Upgraded ===== + +impl Upgraded { + pub(crate) fn new<T>(io: T, read_buf: Bytes) -> Self + where + T: AsyncRead + AsyncWrite + Unpin + Send + 'static, + { + Upgraded { + io: Rewind::new_buffered(Box::new(ForwardsWriteBuf(io)), read_buf), + } + } + + /// Tries to downcast the internal trait object to the type passed. + /// + /// On success, returns the downcasted parts. On error, returns the + /// `Upgraded` back. + pub fn downcast<T: AsyncRead + AsyncWrite + Unpin + 'static>(self) -> Result<Parts<T>, Self> { + let (io, buf) = self.io.into_inner(); + match io.__hyper_downcast::<ForwardsWriteBuf<T>>() { + Ok(t) => Ok(Parts { + io: t.0, + read_buf: buf, + _inner: (), + }), + Err(io) => Err(Upgraded { + io: Rewind::new_buffered(io, buf), + }), + } + } +} + +impl AsyncRead for Upgraded { + unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [std::mem::MaybeUninit<u8>]) -> bool { + self.io.prepare_uninitialized_buffer(buf) + } + + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut task::Context<'_>, + buf: &mut [u8], + ) -> Poll<io::Result<usize>> { + Pin::new(&mut self.io).poll_read(cx, buf) + } +} + +impl AsyncWrite for Upgraded { + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut task::Context<'_>, + buf: &[u8], + ) -> Poll<io::Result<usize>> { + Pin::new(&mut self.io).poll_write(cx, buf) + } + + fn poll_write_buf<B: Buf>( + mut self: Pin<&mut Self>, + cx: &mut task::Context<'_>, + buf: &mut B, + ) -> Poll<io::Result<usize>> { + Pin::new(self.io.get_mut()).poll_write_dyn_buf(cx, buf) + } + + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<io::Result<()>> { + Pin::new(&mut self.io).poll_flush(cx) + } + + fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<io::Result<()>> { + Pin::new(&mut self.io).poll_shutdown(cx) + } +} + +impl fmt::Debug for Upgraded { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("Upgraded").finish() + } +} + +// ===== impl OnUpgrade ===== + +impl OnUpgrade { + pub(crate) fn none() -> Self { + OnUpgrade { rx: None } + } + + pub(crate) fn is_none(&self) -> bool { + self.rx.is_none() + } +} + +impl Future for OnUpgrade { + type Output = Result<Upgraded, crate::Error>; + + fn poll(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> { + match self.rx { + Some(ref mut rx) => Pin::new(rx).poll(cx).map(|res| match res { + Ok(Ok(upgraded)) => Ok(upgraded), + Ok(Err(err)) => Err(err), + Err(_oneshot_canceled) => { + Err(crate::Error::new_canceled().with(UpgradeExpected(()))) + } + }), + None => Poll::Ready(Err(crate::Error::new_user_no_upgrade())), + } + } +} + +impl fmt::Debug for OnUpgrade { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("OnUpgrade").finish() + } +} + +// ===== impl Pending ===== + +impl Pending { + pub(crate) fn fulfill(self, upgraded: Upgraded) { + trace!("pending upgrade fulfill"); + let _ = self.tx.send(Ok(upgraded)); + } + + /// Don't fulfill the pending Upgrade, but instead signal that + /// upgrades are handled manually. + pub(crate) fn manual(self) { + trace!("pending upgrade handled manually"); + let _ = self.tx.send(Err(crate::Error::new_user_manual_upgrade())); + } +} + +// ===== impl UpgradeExpected ===== + +impl fmt::Display for UpgradeExpected { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "upgrade expected but not completed") + } +} + +impl StdError for UpgradeExpected {} + +// ===== impl Io ===== + +struct ForwardsWriteBuf<T>(T); + +pub(crate) trait Io: AsyncRead + AsyncWrite + Unpin + 'static { + fn poll_write_dyn_buf( + &mut self, + cx: &mut task::Context<'_>, + buf: &mut dyn Buf, + ) -> Poll<io::Result<usize>>; + + fn __hyper_type_id(&self) -> TypeId { + TypeId::of::<Self>() + } +} + +impl dyn Io + Send { + fn __hyper_is<T: Io>(&self) -> bool { + let t = TypeId::of::<T>(); + self.__hyper_type_id() == t + } + + fn __hyper_downcast<T: Io>(self: Box<Self>) -> Result<Box<T>, Box<Self>> { + if self.__hyper_is::<T>() { + // Taken from `std::error::Error::downcast()`. + unsafe { + let raw: *mut dyn Io = Box::into_raw(self); + Ok(Box::from_raw(raw as *mut T)) + } + } else { + Err(self) + } + } +} + +impl<T: AsyncRead + Unpin> AsyncRead for ForwardsWriteBuf<T> { + unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [std::mem::MaybeUninit<u8>]) -> bool { + self.0.prepare_uninitialized_buffer(buf) + } + + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut task::Context<'_>, + buf: &mut [u8], + ) -> Poll<io::Result<usize>> { + Pin::new(&mut self.0).poll_read(cx, buf) + } +} + +impl<T: AsyncWrite + Unpin> AsyncWrite for ForwardsWriteBuf<T> { + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut task::Context<'_>, + buf: &[u8], + ) -> Poll<io::Result<usize>> { + Pin::new(&mut self.0).poll_write(cx, buf) + } + + fn poll_write_buf<B: Buf>( + mut self: Pin<&mut Self>, + cx: &mut task::Context<'_>, + buf: &mut B, + ) -> Poll<io::Result<usize>> { + Pin::new(&mut self.0).poll_write_buf(cx, buf) + } + + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<io::Result<()>> { + Pin::new(&mut self.0).poll_flush(cx) + } + + fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<io::Result<()>> { + Pin::new(&mut self.0).poll_shutdown(cx) + } +} + +impl<T: AsyncRead + AsyncWrite + Unpin + 'static> Io for ForwardsWriteBuf<T> { + fn poll_write_dyn_buf( + &mut self, + cx: &mut task::Context<'_>, + mut buf: &mut dyn Buf, + ) -> Poll<io::Result<usize>> { + Pin::new(&mut self.0).poll_write_buf(cx, &mut buf) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use tokio::io::AsyncWriteExt; + + #[test] + fn upgraded_downcast() { + let upgraded = Upgraded::new(Mock, Bytes::new()); + + let upgraded = upgraded.downcast::<std::io::Cursor<Vec<u8>>>().unwrap_err(); + + upgraded.downcast::<Mock>().unwrap(); + } + + #[tokio::test] + async fn upgraded_forwards_write_buf() { + // sanity check that the underlying IO implements write_buf + Mock.write_buf(&mut "hello".as_bytes()).await.unwrap(); + + let mut upgraded = Upgraded::new(Mock, Bytes::new()); + upgraded.write_buf(&mut "hello".as_bytes()).await.unwrap(); + } + + // TODO: replace with tokio_test::io when it can test write_buf + struct Mock; + + impl AsyncRead for Mock { + fn poll_read( + self: Pin<&mut Self>, + _cx: &mut task::Context<'_>, + _buf: &mut [u8], + ) -> Poll<io::Result<usize>> { + unreachable!("Mock::poll_read") + } + } + + impl AsyncWrite for Mock { + fn poll_write( + self: Pin<&mut Self>, + _cx: &mut task::Context<'_>, + _buf: &[u8], + ) -> Poll<io::Result<usize>> { + panic!("poll_write shouldn't be called"); + } + + fn poll_write_buf<B: Buf>( + self: Pin<&mut Self>, + _cx: &mut task::Context<'_>, + buf: &mut B, + ) -> Poll<io::Result<usize>> { + let n = buf.remaining(); + buf.advance(n); + Poll::Ready(Ok(n)) + } + + fn poll_flush(self: Pin<&mut Self>, _cx: &mut task::Context<'_>) -> Poll<io::Result<()>> { + unreachable!("Mock::poll_flush") + } + + fn poll_shutdown( + self: Pin<&mut Self>, + _cx: &mut task::Context<'_>, + ) -> Poll<io::Result<()>> { + unreachable!("Mock::poll_shutdown") + } + } +} |