#![doc(html_root_url = "https://docs.rs/tokio-native-tls/0.3.0")] #![warn( missing_debug_implementations, missing_docs, rust_2018_idioms, unreachable_pub )] #![deny(rustdoc::broken_intra_doc_links)] #![doc(test( no_crate_inject, attr(deny(warnings, rust_2018_idioms), allow(dead_code, unused_variables)) ))] //! Async TLS streams //! //! This library is an implementation of TLS streams using the most appropriate //! system library by default for negotiating the connection. That is, on //! Windows this library uses SChannel, on OSX it uses SecureTransport, and on //! other platforms it uses OpenSSL. //! //! Each TLS stream implements the `Read` and `Write` traits to interact and //! interoperate with the rest of the futures I/O ecosystem. Client connections //! initiated from this crate verify hostnames automatically and by default. //! //! This crate primarily exports this ability through two newtypes, //! `TlsConnector` and `TlsAcceptor`. These newtypes augment the //! functionality provided by the `native-tls` crate, on which this crate is //! built. Configuration of TLS parameters is still primarily done through the //! `native-tls` crate. use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; use crate::native_tls::{Error, HandshakeError, MidHandshakeTlsStream}; use std::fmt; use std::future::Future; use std::io::{self, Read, Write}; use std::marker::Unpin; #[cfg(unix)] use std::os::unix::io::{AsRawFd, RawFd}; #[cfg(windows)] use std::os::windows::io::{AsRawSocket, RawSocket}; use std::pin::Pin; use std::ptr::null_mut; use std::task::{Context, Poll}; /// An intermediate wrapper for the inner stream `S`. #[derive(Debug)] pub struct AllowStd { inner: S, context: *mut (), } impl AllowStd { /// Returns a shared reference to the inner stream. pub fn get_ref(&self) -> &S { &self.inner } /// Returns a mutable reference to the inner stream. pub fn get_mut(&mut self) -> &mut S { &mut self.inner } } /// A wrapper around an underlying raw stream which implements the TLS or SSL /// protocol. /// /// A `TlsStream` represents a handshake that has been completed successfully /// and both the server and the client are ready for receiving and sending /// data. Bytes read from a `TlsStream` are decrypted from `S` and bytes written /// to a `TlsStream` are encrypted when passing through to `S`. #[derive(Debug)] pub struct TlsStream(native_tls::TlsStream>); /// A wrapper around a `native_tls::TlsConnector`, providing an async `connect` /// method. #[derive(Clone)] pub struct TlsConnector(native_tls::TlsConnector); /// A wrapper around a `native_tls::TlsAcceptor`, providing an async `accept` /// method. #[derive(Clone)] pub struct TlsAcceptor(native_tls::TlsAcceptor); struct MidHandshake(Option>>); enum StartedHandshake { Done(TlsStream), Mid(MidHandshakeTlsStream>), } struct StartedHandshakeFuture(Option>); struct StartedHandshakeFutureInner { f: F, stream: S, } struct Guard<'a, S>(&'a mut TlsStream) where AllowStd: Read + Write; impl Drop for Guard<'_, S> where AllowStd: Read + Write, { fn drop(&mut self) { (self.0).0.get_mut().context = null_mut(); } } // *mut () context is neither Send nor Sync unsafe impl Send for AllowStd {} unsafe impl Sync for AllowStd {} impl AllowStd where S: Unpin, { fn with_context(&mut self, f: F) -> io::Result where F: FnOnce(&mut Context<'_>, Pin<&mut S>) -> Poll>, { unsafe { assert!(!self.context.is_null()); let waker = &mut *(self.context as *mut _); match f(waker, Pin::new(&mut self.inner)) { Poll::Ready(r) => r, Poll::Pending => Err(io::Error::from(io::ErrorKind::WouldBlock)), } } } } impl Read for AllowStd where S: AsyncRead + Unpin, { fn read(&mut self, buf: &mut [u8]) -> io::Result { let mut buf = ReadBuf::new(buf); self.with_context(|ctx, stream| stream.poll_read(ctx, &mut buf))?; Ok(buf.filled().len()) } } impl Write for AllowStd where S: AsyncWrite + Unpin, { fn write(&mut self, buf: &[u8]) -> io::Result { self.with_context(|ctx, stream| stream.poll_write(ctx, buf)) } fn flush(&mut self) -> io::Result<()> { self.with_context(|ctx, stream| stream.poll_flush(ctx)) } } impl TlsStream { fn with_context(&mut self, ctx: &mut Context<'_>, f: F) -> Poll> where F: FnOnce(&mut native_tls::TlsStream>) -> io::Result, AllowStd: Read + Write, { self.0.get_mut().context = ctx as *mut _ as *mut (); let g = Guard(self); match f(&mut (g.0).0) { Ok(v) => Poll::Ready(Ok(v)), Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => Poll::Pending, Err(e) => Poll::Ready(Err(e)), } } /// Returns a shared reference to the inner stream. pub fn get_ref(&self) -> &native_tls::TlsStream> { &self.0 } /// Returns a mutable reference to the inner stream. pub fn get_mut(&mut self) -> &mut native_tls::TlsStream> { &mut self.0 } } impl AsyncRead for TlsStream where S: AsyncRead + AsyncWrite + Unpin, { fn poll_read( mut self: Pin<&mut Self>, ctx: &mut Context<'_>, buf: &mut ReadBuf<'_>, ) -> Poll> { self.with_context(ctx, |s| { let n = s.read(buf.initialize_unfilled())?; buf.advance(n); Ok(()) }) } } impl AsyncWrite for TlsStream where S: AsyncRead + AsyncWrite + Unpin, { fn poll_write( mut self: Pin<&mut Self>, ctx: &mut Context<'_>, buf: &[u8], ) -> Poll> { self.with_context(ctx, |s| s.write(buf)) } fn poll_flush(mut self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll> { self.with_context(ctx, |s| s.flush()) } fn poll_shutdown(mut self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll> { self.with_context(ctx, |s| s.shutdown()) } } #[cfg(unix)] impl AsRawFd for TlsStream where S: AsRawFd, { fn as_raw_fd(&self) -> RawFd { self.get_ref().get_ref().get_ref().as_raw_fd() } } #[cfg(windows)] impl AsRawSocket for TlsStream where S: AsRawSocket, { fn as_raw_socket(&self) -> RawSocket { self.get_ref().get_ref().get_ref().as_raw_socket() } } async fn handshake(f: F, stream: S) -> Result, Error> where F: FnOnce( AllowStd, ) -> Result>, HandshakeError>> + Unpin, S: AsyncRead + AsyncWrite + Unpin, { let start = StartedHandshakeFuture(Some(StartedHandshakeFutureInner { f, stream })); match start.await { Err(e) => Err(e), Ok(StartedHandshake::Done(s)) => Ok(s), Ok(StartedHandshake::Mid(s)) => MidHandshake(Some(s)).await, } } impl Future for StartedHandshakeFuture where F: FnOnce( AllowStd, ) -> Result>, HandshakeError>> + Unpin, S: Unpin, AllowStd: Read + Write, { type Output = Result, Error>; fn poll( mut self: Pin<&mut Self>, ctx: &mut Context<'_>, ) -> Poll, Error>> { let inner = self.0.take().expect("future polled after completion"); let stream = AllowStd { inner: inner.stream, context: ctx as *mut _ as *mut (), }; match (inner.f)(stream) { Ok(mut s) => { s.get_mut().context = null_mut(); Poll::Ready(Ok(StartedHandshake::Done(TlsStream(s)))) } Err(HandshakeError::WouldBlock(mut s)) => { s.get_mut().context = null_mut(); Poll::Ready(Ok(StartedHandshake::Mid(s))) } Err(HandshakeError::Failure(e)) => Poll::Ready(Err(e)), } } } impl TlsConnector { /// Connects the provided stream with this connector, assuming the provided /// domain. /// /// This function will internally call `TlsConnector::connect` to connect /// the stream and returns a future representing the resolution of the /// connection operation. The returned future will resolve to either /// `TlsStream` or `Error` depending if it's successful or not. /// /// This is typically used for clients who have already established, for /// example, a TCP connection to a remote server. That stream is then /// provided here to perform the client half of a connection to a /// TLS-powered server. pub async fn connect(&self, domain: &str, stream: S) -> Result, Error> where S: AsyncRead + AsyncWrite + Unpin, { handshake(move |s| self.0.connect(domain, s), stream).await } } impl fmt::Debug for TlsConnector { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("TlsConnector").finish() } } impl From for TlsConnector { fn from(inner: native_tls::TlsConnector) -> TlsConnector { TlsConnector(inner) } } impl TlsAcceptor { /// Accepts a new client connection with the provided stream. /// /// This function will internally call `TlsAcceptor::accept` to connect /// the stream and returns a future representing the resolution of the /// connection operation. The returned future will resolve to either /// `TlsStream` or `Error` depending if it's successful or not. /// /// This is typically used after a new socket has been accepted from a /// `TcpListener`. That socket is then passed to this function to perform /// the server half of accepting a client connection. pub async fn accept(&self, stream: S) -> Result, Error> where S: AsyncRead + AsyncWrite + Unpin, { handshake(move |s| self.0.accept(s), stream).await } } impl fmt::Debug for TlsAcceptor { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("TlsAcceptor").finish() } } impl From for TlsAcceptor { fn from(inner: native_tls::TlsAcceptor) -> TlsAcceptor { TlsAcceptor(inner) } } impl Future for MidHandshake { type Output = Result, Error>; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { let mut_self = self.get_mut(); let mut s = mut_self.0.take().expect("future polled after completion"); s.get_mut().context = cx as *mut _ as *mut (); match s.handshake() { Ok(mut s) => { s.get_mut().context = null_mut(); Poll::Ready(Ok(TlsStream(s))) } Err(HandshakeError::WouldBlock(mut s)) => { s.get_mut().context = null_mut(); mut_self.0 = Some(s); Poll::Pending } Err(HandshakeError::Failure(e)) => Poll::Ready(Err(e)), } } } /// re-export native_tls pub mod native_tls { pub use native_tls::*; }