diff options
Diffstat (limited to '')
-rw-r--r-- | third_party/rust/hyper/src/upgrade.rs | 382 |
1 files changed, 382 insertions, 0 deletions
diff --git a/third_party/rust/hyper/src/upgrade.rs b/third_party/rust/hyper/src/upgrade.rs new file mode 100644 index 0000000000..1c7b5b01cd --- /dev/null +++ b/third_party/rust/hyper/src/upgrade.rs @@ -0,0 +1,382 @@ +//! HTTP Upgrades +//! +//! This module deals with managing [HTTP Upgrades][mdn] in hyper. Since +//! several concepts in HTTP allow for first talking HTTP, and then converting +//! to a different protocol, this module conflates them into a single API. +//! Those include: +//! +//! - HTTP/1.1 Upgrades +//! - HTTP `CONNECT` +//! +//! You are responsible for any other pre-requisites to establish an upgrade, +//! such as sending the appropriate headers, methods, and status codes. You can +//! then use [`on`][] to grab a `Future` which will resolve to the upgraded +//! connection object, or an error if the upgrade fails. +//! +//! [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Protocol_upgrade_mechanism +//! +//! # Client +//! +//! Sending an HTTP upgrade from the [`client`](super::client) involves setting +//! either the appropriate method, if wanting to `CONNECT`, or headers such as +//! `Upgrade` and `Connection`, on the `http::Request`. Once receiving the +//! `http::Response` back, you must check for the specific information that the +//! upgrade is agreed upon by the server (such as a `101` status code), and then +//! get the `Future` from the `Response`. +//! +//! # Server +//! +//! Receiving upgrade requests in a server requires you to check the relevant +//! headers in a `Request`, and if an upgrade should be done, you then send the +//! corresponding headers in a response. To then wait for hyper to finish the +//! upgrade, you call `on()` with the `Request`, and then can spawn a task +//! awaiting it. +//! +//! # Example +//! +//! See [this example][example] showing how upgrades work with both +//! Clients and Servers. +//! +//! [example]: https://github.com/hyperium/hyper/blob/master/examples/upgrades.rs + +use std::any::TypeId; +use std::error::Error as StdError; +use std::fmt; +use std::io; +use std::marker::Unpin; + +use bytes::Bytes; +use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; +use tokio::sync::oneshot; +#[cfg(any(feature = "http1", feature = "http2"))] +use tracing::trace; + +use crate::common::io::Rewind; +use crate::common::{task, Future, Pin, Poll}; + +/// An upgraded HTTP connection. +/// +/// This type holds a trait object internally of the original IO that +/// was used to speak HTTP before the upgrade. It can be used directly +/// as a `Read` or `Write` for convenience. +/// +/// Alternatively, if the exact type is known, this can be deconstructed +/// into its parts. +pub struct Upgraded { + io: Rewind<Box<dyn Io + Send>>, +} + +/// A future for a possible HTTP upgrade. +/// +/// If no upgrade was available, or it doesn't succeed, yields an `Error`. +pub struct OnUpgrade { + rx: Option<oneshot::Receiver<crate::Result<Upgraded>>>, +} + +/// The deconstructed parts of an [`Upgraded`](Upgraded) type. +/// +/// Includes the original IO type, and a read buffer of bytes that the +/// HTTP state machine may have already read before completing an upgrade. +#[derive(Debug)] +pub struct Parts<T> { + /// The original IO object used before the upgrade. + pub io: T, + /// A buffer of bytes that have been read but not processed as HTTP. + /// + /// For instance, if the `Connection` is used for an HTTP upgrade request, + /// it is possible the server sent back the first bytes of the new protocol + /// along with the response upgrade. + /// + /// You will want to check for any existing bytes if you plan to continue + /// communicating on the IO object. + pub read_buf: Bytes, + _inner: (), +} + +/// Gets a pending HTTP upgrade from this message. +/// +/// This can be called on the following types: +/// +/// - `http::Request<B>` +/// - `http::Response<B>` +/// - `&mut http::Request<B>` +/// - `&mut http::Response<B>` +pub fn on<T: sealed::CanUpgrade>(msg: T) -> OnUpgrade { + msg.on_upgrade() +} + +#[cfg(any(feature = "http1", feature = "http2"))] +pub(super) struct Pending { + tx: oneshot::Sender<crate::Result<Upgraded>>, +} + +#[cfg(any(feature = "http1", feature = "http2"))] +pub(super) fn pending() -> (Pending, OnUpgrade) { + let (tx, rx) = oneshot::channel(); + (Pending { tx }, OnUpgrade { rx: Some(rx) }) +} + +// ===== impl Upgraded ===== + +impl Upgraded { + #[cfg(any(feature = "http1", feature = "http2", test))] + pub(super) fn new<T>(io: T, read_buf: Bytes) -> Self + where + T: AsyncRead + AsyncWrite + Unpin + Send + 'static, + { + Upgraded { + io: Rewind::new_buffered(Box::new(io), read_buf), + } + } + + /// Tries to downcast the internal trait object to the type passed. + /// + /// On success, returns the downcasted parts. On error, returns the + /// `Upgraded` back. + pub fn downcast<T: AsyncRead + AsyncWrite + Unpin + 'static>(self) -> Result<Parts<T>, Self> { + let (io, buf) = self.io.into_inner(); + match io.__hyper_downcast() { + Ok(t) => Ok(Parts { + io: *t, + read_buf: buf, + _inner: (), + }), + Err(io) => Err(Upgraded { + io: Rewind::new_buffered(io, buf), + }), + } + } +} + +impl AsyncRead for Upgraded { + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut task::Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll<io::Result<()>> { + Pin::new(&mut self.io).poll_read(cx, buf) + } +} + +impl AsyncWrite for Upgraded { + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut task::Context<'_>, + buf: &[u8], + ) -> Poll<io::Result<usize>> { + Pin::new(&mut self.io).poll_write(cx, buf) + } + + fn poll_write_vectored( + mut self: Pin<&mut Self>, + cx: &mut task::Context<'_>, + bufs: &[io::IoSlice<'_>], + ) -> Poll<io::Result<usize>> { + Pin::new(&mut self.io).poll_write_vectored(cx, bufs) + } + + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<io::Result<()>> { + Pin::new(&mut self.io).poll_flush(cx) + } + + fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<io::Result<()>> { + Pin::new(&mut self.io).poll_shutdown(cx) + } + + fn is_write_vectored(&self) -> bool { + self.io.is_write_vectored() + } +} + +impl fmt::Debug for Upgraded { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("Upgraded").finish() + } +} + +// ===== impl OnUpgrade ===== + +impl OnUpgrade { + pub(super) fn none() -> Self { + OnUpgrade { rx: None } + } + + #[cfg(feature = "http1")] + pub(super) fn is_none(&self) -> bool { + self.rx.is_none() + } +} + +impl Future for OnUpgrade { + type Output = Result<Upgraded, crate::Error>; + + fn poll(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> { + match self.rx { + Some(ref mut rx) => Pin::new(rx).poll(cx).map(|res| match res { + Ok(Ok(upgraded)) => Ok(upgraded), + Ok(Err(err)) => Err(err), + Err(_oneshot_canceled) => Err(crate::Error::new_canceled().with(UpgradeExpected)), + }), + None => Poll::Ready(Err(crate::Error::new_user_no_upgrade())), + } + } +} + +impl fmt::Debug for OnUpgrade { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("OnUpgrade").finish() + } +} + +// ===== impl Pending ===== + +#[cfg(any(feature = "http1", feature = "http2"))] +impl Pending { + pub(super) fn fulfill(self, upgraded: Upgraded) { + trace!("pending upgrade fulfill"); + let _ = self.tx.send(Ok(upgraded)); + } + + #[cfg(feature = "http1")] + /// Don't fulfill the pending Upgrade, but instead signal that + /// upgrades are handled manually. + pub(super) fn manual(self) { + trace!("pending upgrade handled manually"); + let _ = self.tx.send(Err(crate::Error::new_user_manual_upgrade())); + } +} + +// ===== impl UpgradeExpected ===== + +/// Error cause returned when an upgrade was expected but canceled +/// for whatever reason. +/// +/// This likely means the actual `Conn` future wasn't polled and upgraded. +#[derive(Debug)] +struct UpgradeExpected; + +impl fmt::Display for UpgradeExpected { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str("upgrade expected but not completed") + } +} + +impl StdError for UpgradeExpected {} + +// ===== impl Io ===== + +pub(super) trait Io: AsyncRead + AsyncWrite + Unpin + 'static { + fn __hyper_type_id(&self) -> TypeId { + TypeId::of::<Self>() + } +} + +impl<T: AsyncRead + AsyncWrite + Unpin + 'static> Io for T {} + +impl dyn Io + Send { + fn __hyper_is<T: Io>(&self) -> bool { + let t = TypeId::of::<T>(); + self.__hyper_type_id() == t + } + + fn __hyper_downcast<T: Io>(self: Box<Self>) -> Result<Box<T>, Box<Self>> { + if self.__hyper_is::<T>() { + // Taken from `std::error::Error::downcast()`. + unsafe { + let raw: *mut dyn Io = Box::into_raw(self); + Ok(Box::from_raw(raw as *mut T)) + } + } else { + Err(self) + } + } +} + +mod sealed { + use super::OnUpgrade; + + pub trait CanUpgrade { + fn on_upgrade(self) -> OnUpgrade; + } + + impl<B> CanUpgrade for http::Request<B> { + fn on_upgrade(mut self) -> OnUpgrade { + self.extensions_mut() + .remove::<OnUpgrade>() + .unwrap_or_else(OnUpgrade::none) + } + } + + impl<B> CanUpgrade for &'_ mut http::Request<B> { + fn on_upgrade(self) -> OnUpgrade { + self.extensions_mut() + .remove::<OnUpgrade>() + .unwrap_or_else(OnUpgrade::none) + } + } + + impl<B> CanUpgrade for http::Response<B> { + fn on_upgrade(mut self) -> OnUpgrade { + self.extensions_mut() + .remove::<OnUpgrade>() + .unwrap_or_else(OnUpgrade::none) + } + } + + impl<B> CanUpgrade for &'_ mut http::Response<B> { + fn on_upgrade(self) -> OnUpgrade { + self.extensions_mut() + .remove::<OnUpgrade>() + .unwrap_or_else(OnUpgrade::none) + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn upgraded_downcast() { + let upgraded = Upgraded::new(Mock, Bytes::new()); + + let upgraded = upgraded.downcast::<std::io::Cursor<Vec<u8>>>().unwrap_err(); + + upgraded.downcast::<Mock>().unwrap(); + } + + // TODO: replace with tokio_test::io when it can test write_buf + struct Mock; + + impl AsyncRead for Mock { + fn poll_read( + self: Pin<&mut Self>, + _cx: &mut task::Context<'_>, + _buf: &mut ReadBuf<'_>, + ) -> Poll<io::Result<()>> { + unreachable!("Mock::poll_read") + } + } + + impl AsyncWrite for Mock { + fn poll_write( + self: Pin<&mut Self>, + _: &mut task::Context<'_>, + buf: &[u8], + ) -> Poll<io::Result<usize>> { + // panic!("poll_write shouldn't be called"); + Poll::Ready(Ok(buf.len())) + } + + fn poll_flush(self: Pin<&mut Self>, _cx: &mut task::Context<'_>) -> Poll<io::Result<()>> { + unreachable!("Mock::poll_flush") + } + + fn poll_shutdown( + self: Pin<&mut Self>, + _cx: &mut task::Context<'_>, + ) -> Poll<io::Result<()>> { + unreachable!("Mock::poll_shutdown") + } + } +} |