//! HTTP Upgrades //! //! This module deals with managing [HTTP Upgrades][mdn] in hyper. Since //! several concepts in HTTP allow for first talking HTTP, and then converting //! to a different protocol, this module conflates them into a single API. //! Those include: //! //! - HTTP/1.1 Upgrades //! - HTTP `CONNECT` //! //! You are responsible for any other pre-requisites to establish an upgrade, //! such as sending the appropriate headers, methods, and status codes. You can //! then use [`on`][] to grab a `Future` which will resolve to the upgraded //! connection object, or an error if the upgrade fails. //! //! [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Protocol_upgrade_mechanism //! //! # Client //! //! Sending an HTTP upgrade from the [`client`](super::client) involves setting //! either the appropriate method, if wanting to `CONNECT`, or headers such as //! `Upgrade` and `Connection`, on the `http::Request`. Once receiving the //! `http::Response` back, you must check for the specific information that the //! upgrade is agreed upon by the server (such as a `101` status code), and then //! get the `Future` from the `Response`. //! //! # Server //! //! Receiving upgrade requests in a server requires you to check the relevant //! headers in a `Request`, and if an upgrade should be done, you then send the //! corresponding headers in a response. To then wait for hyper to finish the //! upgrade, you call `on()` with the `Request`, and then can spawn a task //! awaiting it. //! //! # Example //! //! See [this example][example] showing how upgrades work with both //! Clients and Servers. //! //! [example]: https://github.com/hyperium/hyper/blob/master/examples/upgrades.rs use std::any::TypeId; use std::error::Error as StdError; use std::fmt; use std::io; use std::marker::Unpin; use bytes::Bytes; use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; use tokio::sync::oneshot; #[cfg(any(feature = "http1", feature = "http2"))] use tracing::trace; use crate::common::io::Rewind; use crate::common::{task, Future, Pin, Poll}; /// An upgraded HTTP connection. /// /// This type holds a trait object internally of the original IO that /// was used to speak HTTP before the upgrade. It can be used directly /// as a `Read` or `Write` for convenience. /// /// Alternatively, if the exact type is known, this can be deconstructed /// into its parts. pub struct Upgraded { io: Rewind>, } /// A future for a possible HTTP upgrade. /// /// If no upgrade was available, or it doesn't succeed, yields an `Error`. pub struct OnUpgrade { rx: Option>>, } /// The deconstructed parts of an [`Upgraded`](Upgraded) type. /// /// Includes the original IO type, and a read buffer of bytes that the /// HTTP state machine may have already read before completing an upgrade. #[derive(Debug)] pub struct Parts { /// The original IO object used before the upgrade. pub io: T, /// A buffer of bytes that have been read but not processed as HTTP. /// /// For instance, if the `Connection` is used for an HTTP upgrade request, /// it is possible the server sent back the first bytes of the new protocol /// along with the response upgrade. /// /// You will want to check for any existing bytes if you plan to continue /// communicating on the IO object. pub read_buf: Bytes, _inner: (), } /// Gets a pending HTTP upgrade from this message. /// /// This can be called on the following types: /// /// - `http::Request` /// - `http::Response` /// - `&mut http::Request` /// - `&mut http::Response` pub fn on(msg: T) -> OnUpgrade { msg.on_upgrade() } #[cfg(any(feature = "http1", feature = "http2"))] pub(super) struct Pending { tx: oneshot::Sender>, } #[cfg(any(feature = "http1", feature = "http2"))] pub(super) fn pending() -> (Pending, OnUpgrade) { let (tx, rx) = oneshot::channel(); (Pending { tx }, OnUpgrade { rx: Some(rx) }) } // ===== impl Upgraded ===== impl Upgraded { #[cfg(any(feature = "http1", feature = "http2", test))] pub(super) fn new(io: T, read_buf: Bytes) -> Self where T: AsyncRead + AsyncWrite + Unpin + Send + 'static, { Upgraded { io: Rewind::new_buffered(Box::new(io), read_buf), } } /// Tries to downcast the internal trait object to the type passed. /// /// On success, returns the downcasted parts. On error, returns the /// `Upgraded` back. pub fn downcast(self) -> Result, Self> { let (io, buf) = self.io.into_inner(); match io.__hyper_downcast() { Ok(t) => Ok(Parts { io: *t, read_buf: buf, _inner: (), }), Err(io) => Err(Upgraded { io: Rewind::new_buffered(io, buf), }), } } } impl AsyncRead for Upgraded { fn poll_read( mut self: Pin<&mut Self>, cx: &mut task::Context<'_>, buf: &mut ReadBuf<'_>, ) -> Poll> { Pin::new(&mut self.io).poll_read(cx, buf) } } impl AsyncWrite for Upgraded { fn poll_write( mut self: Pin<&mut Self>, cx: &mut task::Context<'_>, buf: &[u8], ) -> Poll> { Pin::new(&mut self.io).poll_write(cx, buf) } fn poll_write_vectored( mut self: Pin<&mut Self>, cx: &mut task::Context<'_>, bufs: &[io::IoSlice<'_>], ) -> Poll> { Pin::new(&mut self.io).poll_write_vectored(cx, bufs) } fn poll_flush(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll> { Pin::new(&mut self.io).poll_flush(cx) } fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll> { Pin::new(&mut self.io).poll_shutdown(cx) } fn is_write_vectored(&self) -> bool { self.io.is_write_vectored() } } impl fmt::Debug for Upgraded { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("Upgraded").finish() } } // ===== impl OnUpgrade ===== impl OnUpgrade { pub(super) fn none() -> Self { OnUpgrade { rx: None } } #[cfg(feature = "http1")] pub(super) fn is_none(&self) -> bool { self.rx.is_none() } } impl Future for OnUpgrade { type Output = Result; fn poll(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll { match self.rx { Some(ref mut rx) => Pin::new(rx).poll(cx).map(|res| match res { Ok(Ok(upgraded)) => Ok(upgraded), Ok(Err(err)) => Err(err), Err(_oneshot_canceled) => Err(crate::Error::new_canceled().with(UpgradeExpected)), }), None => Poll::Ready(Err(crate::Error::new_user_no_upgrade())), } } } impl fmt::Debug for OnUpgrade { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("OnUpgrade").finish() } } // ===== impl Pending ===== #[cfg(any(feature = "http1", feature = "http2"))] impl Pending { pub(super) fn fulfill(self, upgraded: Upgraded) { trace!("pending upgrade fulfill"); let _ = self.tx.send(Ok(upgraded)); } #[cfg(feature = "http1")] /// Don't fulfill the pending Upgrade, but instead signal that /// upgrades are handled manually. pub(super) fn manual(self) { trace!("pending upgrade handled manually"); let _ = self.tx.send(Err(crate::Error::new_user_manual_upgrade())); } } // ===== impl UpgradeExpected ===== /// Error cause returned when an upgrade was expected but canceled /// for whatever reason. /// /// This likely means the actual `Conn` future wasn't polled and upgraded. #[derive(Debug)] struct UpgradeExpected; impl fmt::Display for UpgradeExpected { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.write_str("upgrade expected but not completed") } } impl StdError for UpgradeExpected {} // ===== impl Io ===== pub(super) trait Io: AsyncRead + AsyncWrite + Unpin + 'static { fn __hyper_type_id(&self) -> TypeId { TypeId::of::() } } impl Io for T {} impl dyn Io + Send { fn __hyper_is(&self) -> bool { let t = TypeId::of::(); self.__hyper_type_id() == t } fn __hyper_downcast(self: Box) -> Result, Box> { if self.__hyper_is::() { // Taken from `std::error::Error::downcast()`. unsafe { let raw: *mut dyn Io = Box::into_raw(self); Ok(Box::from_raw(raw as *mut T)) } } else { Err(self) } } } mod sealed { use super::OnUpgrade; pub trait CanUpgrade { fn on_upgrade(self) -> OnUpgrade; } impl CanUpgrade for http::Request { fn on_upgrade(mut self) -> OnUpgrade { self.extensions_mut() .remove::() .unwrap_or_else(OnUpgrade::none) } } impl CanUpgrade for &'_ mut http::Request { fn on_upgrade(self) -> OnUpgrade { self.extensions_mut() .remove::() .unwrap_or_else(OnUpgrade::none) } } impl CanUpgrade for http::Response { fn on_upgrade(mut self) -> OnUpgrade { self.extensions_mut() .remove::() .unwrap_or_else(OnUpgrade::none) } } impl CanUpgrade for &'_ mut http::Response { fn on_upgrade(self) -> OnUpgrade { self.extensions_mut() .remove::() .unwrap_or_else(OnUpgrade::none) } } } #[cfg(test)] mod tests { use super::*; #[test] fn upgraded_downcast() { let upgraded = Upgraded::new(Mock, Bytes::new()); let upgraded = upgraded.downcast::>>().unwrap_err(); upgraded.downcast::().unwrap(); } // TODO: replace with tokio_test::io when it can test write_buf struct Mock; impl AsyncRead for Mock { fn poll_read( self: Pin<&mut Self>, _cx: &mut task::Context<'_>, _buf: &mut ReadBuf<'_>, ) -> Poll> { unreachable!("Mock::poll_read") } } impl AsyncWrite for Mock { fn poll_write( self: Pin<&mut Self>, _: &mut task::Context<'_>, buf: &[u8], ) -> Poll> { // panic!("poll_write shouldn't be called"); Poll::Ready(Ok(buf.len())) } fn poll_flush(self: Pin<&mut Self>, _cx: &mut task::Context<'_>) -> Poll> { unreachable!("Mock::poll_flush") } fn poll_shutdown( self: Pin<&mut Self>, _cx: &mut task::Context<'_>, ) -> Poll> { unreachable!("Mock::poll_shutdown") } } }