use std::fmt;
use std::fs::File;
use std::future::Future;
use std::io::{self, BufReader, Cursor, Read};
use std::net::SocketAddr;
use std::path::{Path, PathBuf};
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};

use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};

use futures_util::ready;
use hyper::server::accept::Accept;
use hyper::server::conn::{AddrIncoming, AddrStream};

use crate::transport::Transport;
use tokio_rustls::rustls::{
    server::{AllowAnyAnonymousOrAuthenticatedClient, AllowAnyAuthenticatedClient, NoClientAuth},
    Certificate, Error as TlsError, PrivateKey, RootCertStore, ServerConfig,
};

/// Errors that can occur while building the TLS server configuration.
#[derive(Debug)]
pub(crate) enum TlsConfigError {
    /// An underlying I/O error; its own message is used as this error's message.
    Io(io::Error),
    /// The certificate input could not be parsed.
    CertParseError,
    /// The key input could not be parsed as PKCS#8.
    Pkcs8ParseError,
    /// The key input could not be parsed as an RSA key.
    RsaParseError,
    /// The key input contained no private key.
    EmptyKey,
    /// rustls rejected the key; carries the rustls error.
    InvalidKey(TlsError),
}

impl fmt::Display for TlsConfigError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            // Delegate to the wrapped I/O error's own Display output.
            Self::Io(inner) => fmt::Display::fmt(inner, f),
            Self::InvalidKey(inner) => write!(f, "key contains an invalid key, {}", inner),
            Self::CertParseError => f.write_str("certificate parse error"),
            Self::Pkcs8ParseError => f.write_str("pkcs8 parse error"),
            Self::RsaParseError => f.write_str("rsa parse error"),
            Self::EmptyKey => f.write_str("key contains no private key"),
        }
    }
}

impl std::error::Error for TlsConfigError {}

/// How the TLS server authenticates clients.
pub(crate) enum TlsClientAuth {
    /// Client certificates are not requested.
    Off,
    /// Anonymous clients and clients authenticated against the given trust anchor are accepted.
    Optional(Box<dyn Read + Send + Sync>),
    /// Only clients authenticated against the given trust anchor are accepted.
    Required(Box<dyn Read + Send + Sync>),
}

/// Builder to set the configuration for the Tls server.
pub(crate) struct TlsConfigBuilder { cert: Box, key: Box, client_auth: TlsClientAuth, ocsp_resp: Vec, } impl fmt::Debug for TlsConfigBuilder { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("TlsConfigBuilder").finish() } } impl TlsConfigBuilder { /// Create a new TlsConfigBuilder pub(crate) fn new() -> TlsConfigBuilder { TlsConfigBuilder { key: Box::new(io::empty()), cert: Box::new(io::empty()), client_auth: TlsClientAuth::Off, ocsp_resp: Vec::new(), } } /// sets the Tls key via File Path, returns `TlsConfigError::IoError` if the file cannot be open pub(crate) fn key_path(mut self, path: impl AsRef) -> Self { self.key = Box::new(LazyFile { path: path.as_ref().into(), file: None, }); self } /// sets the Tls key via bytes slice pub(crate) fn key(mut self, key: &[u8]) -> Self { self.key = Box::new(Cursor::new(Vec::from(key))); self } /// Specify the file path for the TLS certificate to use. pub(crate) fn cert_path(mut self, path: impl AsRef) -> Self { self.cert = Box::new(LazyFile { path: path.as_ref().into(), file: None, }); self } /// sets the Tls certificate via bytes slice pub(crate) fn cert(mut self, cert: &[u8]) -> Self { self.cert = Box::new(Cursor::new(Vec::from(cert))); self } /// Sets the trust anchor for optional Tls client authentication via file path. /// /// Anonymous and authenticated clients will be accepted. If no trust anchor is provided by any /// of the `client_auth_` methods, then client authentication is disabled by default. pub(crate) fn client_auth_optional_path(mut self, path: impl AsRef) -> Self { let file = Box::new(LazyFile { path: path.as_ref().into(), file: None, }); self.client_auth = TlsClientAuth::Optional(file); self } /// Sets the trust anchor for optional Tls client authentication via bytes slice. /// /// Anonymous and authenticated clients will be accepted. If no trust anchor is provided by any /// of the `client_auth_` methods, then client authentication is disabled by default. 
pub(crate) fn client_auth_optional(mut self, trust_anchor: &[u8]) -> Self { let cursor = Box::new(Cursor::new(Vec::from(trust_anchor))); self.client_auth = TlsClientAuth::Optional(cursor); self } /// Sets the trust anchor for required Tls client authentication via file path. /// /// Only authenticated clients will be accepted. If no trust anchor is provided by any of the /// `client_auth_` methods, then client authentication is disabled by default. pub(crate) fn client_auth_required_path(mut self, path: impl AsRef) -> Self { let file = Box::new(LazyFile { path: path.as_ref().into(), file: None, }); self.client_auth = TlsClientAuth::Required(file); self } /// Sets the trust anchor for required Tls client authentication via bytes slice. /// /// Only authenticated clients will be accepted. If no trust anchor is provided by any of the /// `client_auth_` methods, then client authentication is disabled by default. pub(crate) fn client_auth_required(mut self, trust_anchor: &[u8]) -> Self { let cursor = Box::new(Cursor::new(Vec::from(trust_anchor))); self.client_auth = TlsClientAuth::Required(cursor); self } /// sets the DER-encoded OCSP response pub(crate) fn ocsp_resp(mut self, ocsp_resp: &[u8]) -> Self { self.ocsp_resp = Vec::from(ocsp_resp); self } pub(crate) fn build(mut self) -> Result { let mut cert_rdr = BufReader::new(self.cert); let cert = rustls_pemfile::certs(&mut cert_rdr) .map_err(|_e| TlsConfigError::CertParseError)? 
.into_iter() .map(Certificate) .collect(); let key = { // convert it to Vec to allow reading it again if key is RSA let mut key_vec = Vec::new(); self.key .read_to_end(&mut key_vec) .map_err(TlsConfigError::Io)?; if key_vec.is_empty() { return Err(TlsConfigError::EmptyKey); } let mut pkcs8 = rustls_pemfile::pkcs8_private_keys(&mut key_vec.as_slice()) .map_err(|_e| TlsConfigError::Pkcs8ParseError)?; if !pkcs8.is_empty() { PrivateKey(pkcs8.remove(0)) } else { let mut rsa = rustls_pemfile::rsa_private_keys(&mut key_vec.as_slice()) .map_err(|_e| TlsConfigError::RsaParseError)?; if !rsa.is_empty() { PrivateKey(rsa.remove(0)) } else { return Err(TlsConfigError::EmptyKey); } } }; fn read_trust_anchor( trust_anchor: Box, ) -> Result { let trust_anchors = { let mut reader = BufReader::new(trust_anchor); rustls_pemfile::certs(&mut reader).map_err(TlsConfigError::Io)? }; let mut store = RootCertStore::empty(); let (added, _skipped) = store.add_parsable_certificates(&trust_anchors); if added == 0 { return Err(TlsConfigError::CertParseError); } Ok(store) } let client_auth = match self.client_auth { TlsClientAuth::Off => NoClientAuth::new(), TlsClientAuth::Optional(trust_anchor) => { AllowAnyAnonymousOrAuthenticatedClient::new(read_trust_anchor(trust_anchor)?) } TlsClientAuth::Required(trust_anchor) => { AllowAnyAuthenticatedClient::new(read_trust_anchor(trust_anchor)?) 
} }; let mut config = ServerConfig::builder() .with_safe_defaults() .with_client_cert_verifier(client_auth.into()) .with_single_cert_with_ocsp_and_sct(cert, key, self.ocsp_resp, Vec::new()) .map_err(TlsConfigError::InvalidKey)?; config.alpn_protocols = vec!["h2".into(), "http/1.1".into()]; Ok(config) } } struct LazyFile { path: PathBuf, file: Option, } impl LazyFile { fn lazy_read(&mut self, buf: &mut [u8]) -> io::Result { if self.file.is_none() { self.file = Some(File::open(&self.path)?); } self.file.as_mut().unwrap().read(buf) } } impl Read for LazyFile { fn read(&mut self, buf: &mut [u8]) -> io::Result { self.lazy_read(buf).map_err(|err| { let kind = err.kind(); io::Error::new( kind, format!("error reading file ({:?}): {}", self.path.display(), err), ) }) } } impl Transport for TlsStream { fn remote_addr(&self) -> Option { Some(self.remote_addr) } } enum State { Handshaking(tokio_rustls::Accept), Streaming(tokio_rustls::server::TlsStream), } // tokio_rustls::server::TlsStream doesn't expose constructor methods, // so we have to TlsAcceptor::accept and handshake to have access to it // TlsStream implements AsyncRead/AsyncWrite handshaking tokio_rustls::Accept first pub(crate) struct TlsStream { state: State, remote_addr: SocketAddr, } impl TlsStream { fn new(stream: AddrStream, config: Arc) -> TlsStream { let remote_addr = stream.remote_addr(); let accept = tokio_rustls::TlsAcceptor::from(config).accept(stream); TlsStream { state: State::Handshaking(accept), remote_addr, } } } impl AsyncRead for TlsStream { fn poll_read( self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>, ) -> Poll> { let pin = self.get_mut(); match pin.state { State::Handshaking(ref mut accept) => match ready!(Pin::new(accept).poll(cx)) { Ok(mut stream) => { let result = Pin::new(&mut stream).poll_read(cx, buf); pin.state = State::Streaming(stream); result } Err(err) => Poll::Ready(Err(err)), }, State::Streaming(ref mut stream) => Pin::new(stream).poll_read(cx, buf), } } } impl 
AsyncWrite for TlsStream { fn poll_write( self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8], ) -> Poll> { let pin = self.get_mut(); match pin.state { State::Handshaking(ref mut accept) => match ready!(Pin::new(accept).poll(cx)) { Ok(mut stream) => { let result = Pin::new(&mut stream).poll_write(cx, buf); pin.state = State::Streaming(stream); result } Err(err) => Poll::Ready(Err(err)), }, State::Streaming(ref mut stream) => Pin::new(stream).poll_write(cx, buf), } } fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { match self.state { State::Handshaking(_) => Poll::Ready(Ok(())), State::Streaming(ref mut stream) => Pin::new(stream).poll_flush(cx), } } fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { match self.state { State::Handshaking(_) => Poll::Ready(Ok(())), State::Streaming(ref mut stream) => Pin::new(stream).poll_shutdown(cx), } } } pub(crate) struct TlsAcceptor { config: Arc, incoming: AddrIncoming, } impl TlsAcceptor { pub(crate) fn new(config: ServerConfig, incoming: AddrIncoming) -> TlsAcceptor { TlsAcceptor { config: Arc::new(config), incoming, } } } impl Accept for TlsAcceptor { type Conn = TlsStream; type Error = io::Error; fn poll_accept( self: Pin<&mut Self>, cx: &mut Context<'_>, ) -> Poll>> { let pin = self.get_mut(); match ready!(Pin::new(&mut pin.incoming).poll_accept(cx)) { Some(Ok(sock)) => Poll::Ready(Some(Ok(TlsStream::new(sock, pin.config.clone())))), Some(Err(e)) => Poll::Ready(Some(Err(e))), None => Poll::Ready(None), } } } #[cfg(test)] mod tests { use super::*; #[test] fn file_cert_key() { TlsConfigBuilder::new() .key_path("examples/tls/key.rsa") .cert_path("examples/tls/cert.pem") .build() .unwrap(); } #[test] fn bytes_cert_key() { let key = include_str!("../examples/tls/key.rsa"); let cert = include_str!("../examples/tls/cert.pem"); TlsConfigBuilder::new() .key(key.as_bytes()) .cert(cert.as_bytes()) .build() .unwrap(); } }