use std::io; use std::io::ErrorKind::WouldBlock; #[cfg(any(feature = "ssl", feature = "nativetls"))] use std::mem::replace; use std::net::SocketAddr; use bytes::{Buf, BufMut}; use mio::tcp::TcpStream; #[cfg(feature = "nativetls")] use native_tls::{ HandshakeError, MidHandshakeTlsStream as MidHandshakeSslStream, TlsStream as SslStream, }; #[cfg(feature = "ssl")] use openssl::ssl::{ErrorCode as SslErrorCode, HandshakeError, MidHandshakeSslStream, SslStream}; use result::{Error, Kind, Result}; fn map_non_block(res: io::Result) -> io::Result> { match res { Ok(value) => Ok(Some(value)), Err(err) => { if let WouldBlock = err.kind() { Ok(None) } else { Err(err) } } } } pub trait TryReadBuf: io::Read { fn try_read_buf(&mut self, buf: &mut B) -> io::Result> where Self: Sized, { // Reads the length of the slice supplied by buf.mut_bytes into the buffer // This is not guaranteed to consume an entire datagram or segment. // If your protocol is msg based (instead of continuous stream) you should // ensure that your buffer is large enough to hold an entire segment (1532 bytes if not jumbo // frames) let res = map_non_block(self.read(unsafe { buf.bytes_mut() })); if let Ok(Some(cnt)) = res { unsafe { buf.advance_mut(cnt); } } res } } pub trait TryWriteBuf: io::Write { fn try_write_buf(&mut self, buf: &mut B) -> io::Result> where Self: Sized, { let res = map_non_block(self.write(buf.bytes())); if let Ok(Some(cnt)) = res { buf.advance(cnt); } res } } impl TryReadBuf for T {} impl TryWriteBuf for T {} use self::Stream::*; pub enum Stream { Tcp(TcpStream), #[cfg(any(feature = "ssl", feature = "nativetls"))] Tls(TlsStream), } impl Stream { pub fn tcp(stream: TcpStream) -> Stream { Tcp(stream) } #[cfg(any(feature = "ssl", feature = "nativetls"))] pub fn tls(stream: MidHandshakeSslStream) -> Stream { Tls(TlsStream::Handshake { sock: stream, negotiating: false, }) } #[cfg(any(feature = "ssl", feature = "nativetls"))] pub fn tls_live(stream: SslStream) -> Stream { Tls(TlsStream::Live(stream)) } #[cfg(any(feature = "ssl", feature = "nativetls"))] pub fn is_tls(&self) -> bool { match *self { Tcp(_) => false, Tls(_) => true, } } pub fn evented(&self) -> &TcpStream { match *self { Tcp(ref sock) => sock, #[cfg(any(feature = "ssl", feature = "nativetls"))] Tls(ref inner) => inner.evented(), } } pub fn is_negotiating(&self) -> bool { match *self { Tcp(_) => false, #[cfg(any(feature = "ssl", feature = "nativetls"))] Tls(ref inner) => inner.is_negotiating(), } } pub fn clear_negotiating(&mut self) -> Result<()> { match *self { Tcp(_) => Err(Error::new( Kind::Internal, "Attempted to clear negotiating flag on non ssl connection.", )), #[cfg(any(feature = "ssl", feature = "nativetls"))] Tls(ref mut inner) => inner.clear_negotiating(), } } pub fn peer_addr(&self) -> io::Result { match *self { Tcp(ref sock) => sock.peer_addr(), #[cfg(any(feature = "ssl", feature = "nativetls"))] Tls(ref inner) => inner.peer_addr(), } } pub fn local_addr(&self) -> io::Result { match *self { Tcp(ref sock) => sock.local_addr(), #[cfg(any(feature = "ssl", feature = "nativetls"))] Tls(ref inner) => inner.local_addr(), } } } impl io::Read for Stream { fn read(&mut self, buf: &mut [u8]) -> io::Result { match *self { Tcp(ref mut sock) => sock.read(buf), #[cfg(any(feature = "ssl", feature = "nativetls"))] Tls(TlsStream::Live(ref mut sock)) => sock.read(buf), #[cfg(any(feature = "ssl", feature = "nativetls"))] Tls(ref mut tls_stream) => { trace!("Attempting to read ssl handshake."); match replace(tls_stream, TlsStream::Upgrading) { TlsStream::Live(_) | TlsStream::Upgrading => unreachable!(), TlsStream::Handshake { sock, mut negotiating, } => match sock.handshake() { Ok(mut sock) => { trace!("Completed SSL Handshake"); let res = sock.read(buf); *tls_stream = TlsStream::Live(sock); res } #[cfg(feature = "ssl")] Err(HandshakeError::SetupFailure(err)) => { Err(io::Error::new(io::ErrorKind::Other, err)) } #[cfg(feature = "ssl")] Err(HandshakeError::Failure(mid)) | Err(HandshakeError::WouldBlock(mid)) => { if mid.error().code() == SslErrorCode::WANT_READ { negotiating = true; } let err = if let Some(io_error) = mid.error().io_error() { Err(io::Error::new( io_error.kind(), format!("{:?}", io_error.get_ref()), )) } else { Err(io::Error::new( io::ErrorKind::Other, format!("{}", mid.error()), )) }; *tls_stream = TlsStream::Handshake { sock: mid, negotiating, }; err } #[cfg(feature = "nativetls")] Err(HandshakeError::WouldBlock(mid)) => { negotiating = true; *tls_stream = TlsStream::Handshake { sock: mid, negotiating: negotiating, }; Err(io::Error::new(io::ErrorKind::WouldBlock, "SSL would block")) } #[cfg(feature = "nativetls")] Err(HandshakeError::Failure(err)) => { Err(io::Error::new(io::ErrorKind::Other, format!("{}", err))) } }, } } } } } impl io::Write for Stream { fn write(&mut self, buf: &[u8]) -> io::Result { match *self { Tcp(ref mut sock) => sock.write(buf), #[cfg(any(feature = "ssl", feature = "nativetls"))] Tls(TlsStream::Live(ref mut sock)) => sock.write(buf), #[cfg(any(feature = "ssl", feature = "nativetls"))] Tls(ref mut tls_stream) => { trace!("Attempting to write ssl handshake."); match replace(tls_stream, TlsStream::Upgrading) { TlsStream::Live(_) | TlsStream::Upgrading => unreachable!(), TlsStream::Handshake { sock, mut negotiating, } => match sock.handshake() { Ok(mut sock) => { trace!("Completed SSL Handshake"); let res = sock.write(buf); *tls_stream = TlsStream::Live(sock); res } #[cfg(feature = "ssl")] Err(HandshakeError::SetupFailure(err)) => { Err(io::Error::new(io::ErrorKind::Other, err)) } #[cfg(feature = "ssl")] Err(HandshakeError::Failure(mid)) | Err(HandshakeError::WouldBlock(mid)) => { if mid.error().code() == SslErrorCode::WANT_READ { negotiating = true; } else { negotiating = false; } let err = if let Some(io_error) = mid.error().io_error() { Err(io::Error::new( io_error.kind(), format!("{:?}", io_error.get_ref()), )) } else { Err(io::Error::new( io::ErrorKind::Other, format!("{}", mid.error()), )) }; *tls_stream = TlsStream::Handshake { sock: mid, negotiating, }; err } #[cfg(feature = "nativetls")] Err(HandshakeError::WouldBlock(mid)) => { negotiating = true; *tls_stream = TlsStream::Handshake { sock: mid, negotiating: negotiating, }; Err(io::Error::new(io::ErrorKind::WouldBlock, "SSL would block")) } #[cfg(feature = "nativetls")] Err(HandshakeError::Failure(err)) => { Err(io::Error::new(io::ErrorKind::Other, format!("{}", err))) } }, } } } } fn flush(&mut self) -> io::Result<()> { match *self { Tcp(ref mut sock) => sock.flush(), #[cfg(any(feature = "ssl", feature = "nativetls"))] Tls(TlsStream::Live(ref mut sock)) => sock.flush(), #[cfg(any(feature = "ssl", feature = "nativetls"))] Tls(TlsStream::Handshake { ref mut sock, .. }) => sock.get_mut().flush(), #[cfg(any(feature = "ssl", feature = "nativetls"))] Tls(TlsStream::Upgrading) => panic!("Tried to access actively upgrading TlsStream"), } } } #[cfg(any(feature = "ssl", feature = "nativetls"))] pub enum TlsStream { Live(SslStream), Handshake { sock: MidHandshakeSslStream, negotiating: bool, }, Upgrading, } #[cfg(any(feature = "ssl", feature = "nativetls"))] impl TlsStream { pub fn evented(&self) -> &TcpStream { match *self { TlsStream::Live(ref sock) => sock.get_ref(), TlsStream::Handshake { ref sock, .. } => sock.get_ref(), TlsStream::Upgrading => panic!("Tried to access actively upgrading TlsStream"), } } pub fn is_negotiating(&self) -> bool { match *self { TlsStream::Live(_) => false, TlsStream::Handshake { sock: _, negotiating, } => negotiating, TlsStream::Upgrading => panic!("Tried to access actively upgrading TlsStream"), } } pub fn clear_negotiating(&mut self) -> Result<()> { match *self { TlsStream::Live(_) => Err(Error::new( Kind::Internal, "Attempted to clear negotiating flag on live ssl connection.", )), TlsStream::Handshake { sock: _, ref mut negotiating, } => Ok(*negotiating = false), TlsStream::Upgrading => panic!("Tried to access actively upgrading TlsStream"), } } pub fn peer_addr(&self) -> io::Result { match *self { TlsStream::Live(ref sock) => sock.get_ref().peer_addr(), TlsStream::Handshake { ref sock, .. } => sock.get_ref().peer_addr(), TlsStream::Upgrading => panic!("Tried to access actively upgrading TlsStream"), } } pub fn local_addr(&self) -> io::Result { match *self { TlsStream::Live(ref sock) => sock.get_ref().local_addr(), TlsStream::Handshake { ref sock, .. } => sock.get_ref().local_addr(), TlsStream::Upgrading => panic!("Tried to access actively upgrading TlsStream"), } } }