diff options
Diffstat (limited to 'third_party/rust/ws/src/stream.rs')
-rw-r--r-- | third_party/rust/ws/src/stream.rs | 358 |
1 files changed, 358 insertions, 0 deletions
diff --git a/third_party/rust/ws/src/stream.rs b/third_party/rust/ws/src/stream.rs new file mode 100644 index 0000000000..3b8d9c441b --- /dev/null +++ b/third_party/rust/ws/src/stream.rs @@ -0,0 +1,358 @@ +use std::io; +use std::io::ErrorKind::WouldBlock; +#[cfg(any(feature = "ssl", feature = "nativetls"))] +use std::mem::replace; +use std::net::SocketAddr; + +use bytes::{Buf, BufMut}; +use mio::tcp::TcpStream; +#[cfg(feature = "nativetls")] +use native_tls::{ + HandshakeError, MidHandshakeTlsStream as MidHandshakeSslStream, TlsStream as SslStream, +}; +#[cfg(feature = "ssl")] +use openssl::ssl::{ErrorCode as SslErrorCode, HandshakeError, MidHandshakeSslStream, SslStream}; + +use result::{Error, Kind, Result}; + +fn map_non_block<T>(res: io::Result<T>) -> io::Result<Option<T>> { + match res { + Ok(value) => Ok(Some(value)), + Err(err) => { + if let WouldBlock = err.kind() { + Ok(None) + } else { + Err(err) + } + } + } +} + +pub trait TryReadBuf: io::Read { + fn try_read_buf<B: BufMut>(&mut self, buf: &mut B) -> io::Result<Option<usize>> + where + Self: Sized, + { + // Reads the length of the slice supplied by buf.mut_bytes into the buffer + // This is not guaranteed to consume an entire datagram or segment. + // If your protocol is msg based (instead of continuous stream) you should + // ensure that your buffer is large enough to hold an entire segment (1532 bytes if not jumbo + // frames) + let res = map_non_block(self.read(unsafe { buf.bytes_mut() })); + + if let Ok(Some(cnt)) = res { + unsafe { + buf.advance_mut(cnt); + } + } + + res + } +} + +pub trait TryWriteBuf: io::Write { + fn try_write_buf<B: Buf>(&mut self, buf: &mut B) -> io::Result<Option<usize>> + where + Self: Sized, + { + let res = map_non_block(self.write(buf.bytes())); + + if let Ok(Some(cnt)) = res { + buf.advance(cnt); + } + + res + } +} + +impl<T: io::Read> TryReadBuf for T {} +impl<T: io::Write> TryWriteBuf for T {} + +use self::Stream::*; +pub enum Stream { + Tcp(TcpStream), + #[cfg(any(feature = "ssl", feature = "nativetls"))] + Tls(TlsStream), +} + +impl Stream { + pub fn tcp(stream: TcpStream) -> Stream { + Tcp(stream) + } + + #[cfg(any(feature = "ssl", feature = "nativetls"))] + pub fn tls(stream: MidHandshakeSslStream<TcpStream>) -> Stream { + Tls(TlsStream::Handshake { + sock: stream, + negotiating: false, + }) + } + + #[cfg(any(feature = "ssl", feature = "nativetls"))] + pub fn tls_live(stream: SslStream<TcpStream>) -> Stream { + Tls(TlsStream::Live(stream)) + } + + #[cfg(any(feature = "ssl", feature = "nativetls"))] + pub fn is_tls(&self) -> bool { + match *self { + Tcp(_) => false, + Tls(_) => true, + } + } + + pub fn evented(&self) -> &TcpStream { + match *self { + Tcp(ref sock) => sock, + #[cfg(any(feature = "ssl", feature = "nativetls"))] + Tls(ref inner) => inner.evented(), + } + } + + pub fn is_negotiating(&self) -> bool { + match *self { + Tcp(_) => false, + #[cfg(any(feature = "ssl", feature = "nativetls"))] + Tls(ref inner) => inner.is_negotiating(), + } + } + + pub fn clear_negotiating(&mut self) -> Result<()> { + match *self { + Tcp(_) => Err(Error::new( + Kind::Internal, + "Attempted to clear negotiating flag on non ssl connection.", + )), + #[cfg(any(feature = "ssl", feature = "nativetls"))] + Tls(ref mut inner) => inner.clear_negotiating(), + } + } + + pub fn peer_addr(&self) -> io::Result<SocketAddr> { + match *self { + Tcp(ref sock) => sock.peer_addr(), + #[cfg(any(feature = "ssl", feature = "nativetls"))] + Tls(ref inner) => inner.peer_addr(), + } + } + + pub fn local_addr(&self) -> io::Result<SocketAddr> { + match *self { + Tcp(ref sock) => sock.local_addr(), + #[cfg(any(feature = "ssl", feature = "nativetls"))] + Tls(ref inner) => inner.local_addr(), + } + } +} + +impl io::Read for Stream { + fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> { + match *self { + Tcp(ref mut sock) => sock.read(buf), + #[cfg(any(feature = "ssl", feature = "nativetls"))] + Tls(TlsStream::Live(ref mut sock)) => sock.read(buf), + #[cfg(any(feature = "ssl", feature = "nativetls"))] + Tls(ref mut tls_stream) => { + trace!("Attempting to read ssl handshake."); + match replace(tls_stream, TlsStream::Upgrading) { + TlsStream::Live(_) | TlsStream::Upgrading => unreachable!(), + TlsStream::Handshake { + sock, + mut negotiating, + } => match sock.handshake() { + Ok(mut sock) => { + trace!("Completed SSL Handshake"); + let res = sock.read(buf); + *tls_stream = TlsStream::Live(sock); + res + } + #[cfg(feature = "ssl")] + Err(HandshakeError::SetupFailure(err)) => { + Err(io::Error::new(io::ErrorKind::Other, err)) + } + #[cfg(feature = "ssl")] + Err(HandshakeError::Failure(mid)) + | Err(HandshakeError::WouldBlock(mid)) => { + if mid.error().code() == SslErrorCode::WANT_READ { + negotiating = true; + } + let err = if let Some(io_error) = mid.error().io_error() { + Err(io::Error::new( + io_error.kind(), + format!("{:?}", io_error.get_ref()), + )) + } else { + Err(io::Error::new( + io::ErrorKind::Other, + format!("{}", mid.error()), + )) + }; + *tls_stream = TlsStream::Handshake { + sock: mid, + negotiating, + }; + err + } + #[cfg(feature = "nativetls")] + Err(HandshakeError::WouldBlock(mid)) => { + negotiating = true; + *tls_stream = TlsStream::Handshake { + sock: mid, + negotiating: negotiating, + }; + Err(io::Error::new(io::ErrorKind::WouldBlock, "SSL would block")) + } + #[cfg(feature = "nativetls")] + Err(HandshakeError::Failure(err)) => { + Err(io::Error::new(io::ErrorKind::Other, format!("{}", err))) + } + }, + } + } + } + } +} + +impl io::Write for Stream { + fn write(&mut self, buf: &[u8]) -> io::Result<usize> { + match *self { + Tcp(ref mut sock) => sock.write(buf), + #[cfg(any(feature = "ssl", feature = "nativetls"))] + Tls(TlsStream::Live(ref mut sock)) => sock.write(buf), + #[cfg(any(feature = "ssl", feature = "nativetls"))] + Tls(ref mut tls_stream) => { + trace!("Attempting to write ssl handshake."); + match replace(tls_stream, TlsStream::Upgrading) { + TlsStream::Live(_) | TlsStream::Upgrading => unreachable!(), + TlsStream::Handshake { + sock, + mut negotiating, + } => match sock.handshake() { + Ok(mut sock) => { + trace!("Completed SSL Handshake"); + let res = sock.write(buf); + *tls_stream = TlsStream::Live(sock); + res + } + #[cfg(feature = "ssl")] + Err(HandshakeError::SetupFailure(err)) => { + Err(io::Error::new(io::ErrorKind::Other, err)) + } + #[cfg(feature = "ssl")] + Err(HandshakeError::Failure(mid)) + | Err(HandshakeError::WouldBlock(mid)) => { + if mid.error().code() == SslErrorCode::WANT_READ { + negotiating = true; + } else { + negotiating = false; + } + let err = if let Some(io_error) = mid.error().io_error() { + Err(io::Error::new( + io_error.kind(), + format!("{:?}", io_error.get_ref()), + )) + } else { + Err(io::Error::new( + io::ErrorKind::Other, + format!("{}", mid.error()), + )) + }; + *tls_stream = TlsStream::Handshake { + sock: mid, + negotiating, + }; + err + } + #[cfg(feature = "nativetls")] + Err(HandshakeError::WouldBlock(mid)) => { + negotiating = true; + *tls_stream = TlsStream::Handshake { + sock: mid, + negotiating: negotiating, + }; + Err(io::Error::new(io::ErrorKind::WouldBlock, "SSL would block")) + } + #[cfg(feature = "nativetls")] + Err(HandshakeError::Failure(err)) => { + Err(io::Error::new(io::ErrorKind::Other, format!("{}", err))) + } + }, + } + } + } + } + + fn flush(&mut self) -> io::Result<()> { + match *self { + Tcp(ref mut sock) => sock.flush(), + #[cfg(any(feature = "ssl", feature = "nativetls"))] + Tls(TlsStream::Live(ref mut sock)) => sock.flush(), + #[cfg(any(feature = "ssl", feature = "nativetls"))] + Tls(TlsStream::Handshake { ref mut sock, .. }) => sock.get_mut().flush(), + #[cfg(any(feature = "ssl", feature = "nativetls"))] + Tls(TlsStream::Upgrading) => panic!("Tried to access actively upgrading TlsStream"), + } + } +} + +#[cfg(any(feature = "ssl", feature = "nativetls"))] +pub enum TlsStream { + Live(SslStream<TcpStream>), + Handshake { + sock: MidHandshakeSslStream<TcpStream>, + negotiating: bool, + }, + Upgrading, +} + +#[cfg(any(feature = "ssl", feature = "nativetls"))] +impl TlsStream { + pub fn evented(&self) -> &TcpStream { + match *self { + TlsStream::Live(ref sock) => sock.get_ref(), + TlsStream::Handshake { ref sock, .. } => sock.get_ref(), + TlsStream::Upgrading => panic!("Tried to access actively upgrading TlsStream"), + } + } + + pub fn is_negotiating(&self) -> bool { + match *self { + TlsStream::Live(_) => false, + TlsStream::Handshake { + sock: _, + negotiating, + } => negotiating, + TlsStream::Upgrading => panic!("Tried to access actively upgrading TlsStream"), + } + } + + pub fn clear_negotiating(&mut self) -> Result<()> { + match *self { + TlsStream::Live(_) => Err(Error::new( + Kind::Internal, + "Attempted to clear negotiating flag on live ssl connection.", + )), + TlsStream::Handshake { + sock: _, + ref mut negotiating, + } => Ok(*negotiating = false), + TlsStream::Upgrading => panic!("Tried to access actively upgrading TlsStream"), + } + } + + pub fn peer_addr(&self) -> io::Result<SocketAddr> { + match *self { + TlsStream::Live(ref sock) => sock.get_ref().peer_addr(), + TlsStream::Handshake { ref sock, .. } => sock.get_ref().peer_addr(), + TlsStream::Upgrading => panic!("Tried to access actively upgrading TlsStream"), + } + } + + pub fn local_addr(&self) -> io::Result<SocketAddr> { + match *self { + TlsStream::Live(ref sock) => sock.get_ref().local_addr(), + TlsStream::Handshake { ref sock, .. } => sock.get_ref().local_addr(), + TlsStream::Upgrading => panic!("Tried to access actively upgrading TlsStream"), + } + } +} |