diff options
Diffstat (limited to 'third_party/rust/ws/src')
-rw-r--r-- | third_party/rust/ws/src/communication.rs | 249 | ||||
-rw-r--r-- | third_party/rust/ws/src/connection.rs | 1230 | ||||
-rw-r--r-- | third_party/rust/ws/src/deflate/context.rs | 268 | ||||
-rw-r--r-- | third_party/rust/ws/src/deflate/extension.rs | 565 | ||||
-rw-r--r-- | third_party/rust/ws/src/deflate/mod.rs | 9 | ||||
-rw-r--r-- | third_party/rust/ws/src/factory.rs | 188 | ||||
-rw-r--r-- | third_party/rust/ws/src/frame.rs | 495 | ||||
-rw-r--r-- | third_party/rust/ws/src/handler.rs | 423 | ||||
-rw-r--r-- | third_party/rust/ws/src/handshake.rs | 740 | ||||
-rw-r--r-- | third_party/rust/ws/src/io.rs | 985 | ||||
-rw-r--r-- | third_party/rust/ws/src/lib.rs | 391 | ||||
-rw-r--r-- | third_party/rust/ws/src/message.rs | 173 | ||||
-rw-r--r-- | third_party/rust/ws/src/protocol.rs | 227 | ||||
-rw-r--r-- | third_party/rust/ws/src/result.rs | 204 | ||||
-rw-r--r-- | third_party/rust/ws/src/stream.rs | 358 | ||||
-rw-r--r-- | third_party/rust/ws/src/util.rs | 9 |
16 files changed, 6514 insertions, 0 deletions
diff --git a/third_party/rust/ws/src/communication.rs b/third_party/rust/ws/src/communication.rs new file mode 100644 index 0000000000..2b2822f03c --- /dev/null +++ b/third_party/rust/ws/src/communication.rs @@ -0,0 +1,249 @@ +use std::borrow::Cow; +use std::convert::Into; + +use mio; +use mio::Token; +use mio_extras::timer::Timeout; +use url; + +use io::ALL; +use message; +use protocol::CloseCode; +use result::{Error, Result}; +use std::cmp::PartialEq; +use std::hash::{Hash, Hasher}; +use std::fmt; + +#[derive(Debug, Clone)] +pub enum Signal { + Message(message::Message), + Close(CloseCode, Cow<'static, str>), + Ping(Vec<u8>), + Pong(Vec<u8>), + Connect(url::Url), + Shutdown, + Timeout { delay: u64, token: Token }, + Cancel(Timeout), +} + +#[derive(Debug, Clone)] +pub struct Command { + token: Token, + signal: Signal, + connection_id: u32, +} + +impl Command { + pub fn token(&self) -> Token { + self.token + } + + pub fn into_signal(self) -> Signal { + self.signal + } + + pub fn connection_id(&self) -> u32 { + self.connection_id + } +} + +/// A representation of the output of the WebSocket connection. Use this to send messages to the +/// other endpoint. +#[derive(Clone)] +pub struct Sender { + token: Token, + channel: mio::channel::SyncSender<Command>, + connection_id: u32, +} + +impl fmt::Debug for Sender { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, + "Sender {{ token: {:?}, channel: mio::channel::SyncSender<Command>, connection_id: {:?} }}", + self.token, self.connection_id) + } +} + +impl PartialEq for Sender { + fn eq(&self, other: &Sender) -> bool { + self.token == other.token && self.connection_id == other.connection_id + } +} + +impl Eq for Sender { } + +impl Hash for Sender { + fn hash<H: Hasher>(&self, state: &mut H) { + self.connection_id.hash(state); + self.token.hash(state); + } +} + + +impl Sender { + #[doc(hidden)] + #[inline] + pub fn new( + token: Token, + channel: mio::channel::SyncSender<Command>, + connection_id: u32, + ) -> Sender { + Sender { + token, + channel, + connection_id, + } + } + + /// A Token identifying this sender within the WebSocket. + #[inline] + pub fn token(&self) -> Token { + self.token + } + + /// A connection_id identifying this sender within the WebSocket. + #[inline] + pub fn connection_id(&self) -> u32 { + self.connection_id + } + + /// Send a message over the connection. + #[inline] + pub fn send<M>(&self, msg: M) -> Result<()> + where + M: Into<message::Message>, + { + self.channel + .send(Command { + token: self.token, + signal: Signal::Message(msg.into()), + connection_id: self.connection_id, + }) + .map_err(Error::from) + } + + /// Send a message to the endpoints of all connections. + /// + /// Be careful with this method. It does not discriminate between client and server connections. + /// If your WebSocket is only functioning as a server, then usage is simple, this method will + /// send a copy of the message to each connected client. However, if you have a WebSocket that + /// is listening for connections and is also connected to another WebSocket, this method will + /// broadcast a copy of the message to all the clients connected and to that WebSocket server. + #[inline] + pub fn broadcast<M>(&self, msg: M) -> Result<()> + where + M: Into<message::Message>, + { + self.channel + .send(Command { + token: ALL, + signal: Signal::Message(msg.into()), + connection_id: self.connection_id, + }) + .map_err(Error::from) + } + + /// Send a close code to the other endpoint. + #[inline] + pub fn close(&self, code: CloseCode) -> Result<()> { + self.channel + .send(Command { + token: self.token, + signal: Signal::Close(code, "".into()), + connection_id: self.connection_id, + }) + .map_err(Error::from) + } + + /// Send a close code and provide a descriptive reason for closing. + #[inline] + pub fn close_with_reason<S>(&self, code: CloseCode, reason: S) -> Result<()> + where + S: Into<Cow<'static, str>>, + { + self.channel + .send(Command { + token: self.token, + signal: Signal::Close(code, reason.into()), + connection_id: self.connection_id, + }) + .map_err(Error::from) + } + + /// Send a ping to the other endpoint with the given test data. + #[inline] + pub fn ping(&self, data: Vec<u8>) -> Result<()> { + self.channel + .send(Command { + token: self.token, + signal: Signal::Ping(data), + connection_id: self.connection_id, + }) + .map_err(Error::from) + } + + /// Send a pong to the other endpoint responding with the given test data. + #[inline] + pub fn pong(&self, data: Vec<u8>) -> Result<()> { + self.channel + .send(Command { + token: self.token, + signal: Signal::Pong(data), + connection_id: self.connection_id, + }) + .map_err(Error::from) + } + + /// Queue a new connection on this WebSocket to the specified URL. + #[inline] + pub fn connect(&self, url: url::Url) -> Result<()> { + self.channel + .send(Command { + token: self.token, + signal: Signal::Connect(url), + connection_id: self.connection_id, + }) + .map_err(Error::from) + } + + /// Request that all connections terminate and that the WebSocket stop running. + #[inline] + pub fn shutdown(&self) -> Result<()> { + self.channel + .send(Command { + token: self.token, + signal: Signal::Shutdown, + connection_id: self.connection_id, + }) + .map_err(Error::from) + } + + /// Schedule a `token` to be sent to the WebSocket Handler's `on_timeout` method + /// after `ms` milliseconds + #[inline] + pub fn timeout(&self, ms: u64, token: Token) -> Result<()> { + self.channel + .send(Command { + token: self.token, + signal: Signal::Timeout { delay: ms, token }, + connection_id: self.connection_id, + }) + .map_err(Error::from) + } + + /// Queue the cancellation of a previously scheduled timeout. + /// + /// This method is not guaranteed to prevent the timeout from occurring, because it is + /// possible to call this method after a timeout has already occurred. It is still necessary to + /// handle spurious timeouts. + #[inline] + pub fn cancel(&self, timeout: Timeout) -> Result<()> { + self.channel + .send(Command { + token: self.token, + signal: Signal::Cancel(timeout), + connection_id: self.connection_id, + }) + .map_err(Error::from) + } +} diff --git a/third_party/rust/ws/src/connection.rs b/third_party/rust/ws/src/connection.rs new file mode 100644 index 0000000000..b639695e5c --- /dev/null +++ b/third_party/rust/ws/src/connection.rs @@ -0,0 +1,1230 @@ +use std::borrow::Borrow; +use std::collections::VecDeque; +use std::io::{Cursor, Read, Seek, SeekFrom, Write}; +use std::mem::replace; +use std::net::SocketAddr; +use std::str::from_utf8; + +use mio::tcp::TcpStream; +use mio::{Ready, Token}; +use mio_extras::timer::Timeout; +use url; + +#[cfg(feature = "nativetls")] +use native_tls::HandshakeError; +#[cfg(feature = "ssl")] +use openssl::ssl::HandshakeError; + +use frame::Frame; +use handler::Handler; +use handshake::{Handshake, Request, Response}; +use message::Message; +use protocol::{CloseCode, OpCode}; +use result::{Error, Kind, Result}; +use stream::{Stream, TryReadBuf, TryWriteBuf}; + +use self::Endpoint::*; +use self::State::*; + +use super::Settings; + +#[derive(Debug)] +pub enum State { + // Tcp connection accepted, waiting for handshake to complete + Connecting(Cursor<Vec<u8>>, Cursor<Vec<u8>>), + // Ready to send/receive messages + Open, + AwaitingClose, + RespondingClose, + FinishedClose, +} + +/// A little more semantic than a boolean +#[derive(Debug, Eq, PartialEq, Clone)] +pub enum Endpoint { + /// Will mask outgoing frames + Client(url::Url), + /// Won't mask outgoing frames + Server, +} + +impl State { + #[inline] + pub fn is_connecting(&self) -> bool { + match *self { + State::Connecting(..) => true, + _ => false, + } + } + + #[allow(dead_code)] + #[inline] + pub fn is_open(&self) -> bool { + match *self { + State::Open => true, + _ => false, + } + } + + #[inline] + pub fn is_closing(&self) -> bool { + match *self { + State::AwaitingClose | State::FinishedClose => true, + _ => false, + } + } +} + +pub struct Connection<H> +where + H: Handler, +{ + token: Token, + socket: Stream, + state: State, + endpoint: Endpoint, + events: Ready, + + fragments: VecDeque<Frame>, + + in_buffer: Cursor<Vec<u8>>, + out_buffer: Cursor<Vec<u8>>, + + handler: H, + + addresses: Vec<SocketAddr>, + + settings: Settings, + connection_id: u32, +} + +impl<H> Connection<H> +where + H: Handler, +{ + pub fn new( + tok: Token, + sock: TcpStream, + handler: H, + settings: Settings, + connection_id: u32, + ) -> Connection<H> { + Connection { + token: tok, + socket: Stream::tcp(sock), + state: Connecting( + Cursor::new(Vec::with_capacity(2048)), + Cursor::new(Vec::with_capacity(2048)), + ), + endpoint: Endpoint::Server, + events: Ready::empty(), + fragments: VecDeque::with_capacity(settings.fragments_capacity), + in_buffer: Cursor::new(Vec::with_capacity(settings.in_buffer_capacity)), + out_buffer: Cursor::new(Vec::with_capacity(settings.out_buffer_capacity)), + handler, + addresses: Vec::new(), + settings, + connection_id, + } + } + + pub fn as_server(&mut self) -> Result<()> { + self.events.insert(Ready::readable()); + Ok(()) + } + + pub fn as_client(&mut self, url: url::Url, addrs: Vec<SocketAddr>) -> Result<()> { + if let Connecting(ref mut req_buf, _) = self.state { + let req = self.handler.build_request(&url)?; + self.addresses = addrs; + self.events.insert(Ready::writable()); + self.endpoint = Endpoint::Client(url); + req.format(req_buf.get_mut()) + } else { + Err(Error::new( + Kind::Internal, + "Tried to set connection to client while not connecting.", + )) + } + } + + #[cfg(any(feature = "ssl", feature = "nativetls"))] + pub fn encrypt(&mut self) -> Result<()> { + let sock = self.socket().try_clone()?; + let ssl_stream = match self.endpoint { + Server => self.handler.upgrade_ssl_server(sock), + Client(ref url) => self.handler.upgrade_ssl_client(sock, url), + }; + + match ssl_stream { + Ok(stream) => { + self.socket = Stream::tls_live(stream); + Ok(()) + } + #[cfg(feature = "ssl")] + Err(Error { + kind: Kind::SslHandshake(handshake_err), + details, + }) => match handshake_err { + HandshakeError::SetupFailure(_) => { + Err(Error::new(Kind::SslHandshake(handshake_err), details)) + } + HandshakeError::Failure(mid) | HandshakeError::WouldBlock(mid) => { + self.socket = Stream::tls(mid); + Ok(()) + } + }, + #[cfg(feature = "nativetls")] + Err(Error { + kind: Kind::SslHandshake(handshake_err), + details, + }) => match handshake_err { + HandshakeError::Failure(_) => { + Err(Error::new(Kind::SslHandshake(handshake_err), details)) + } + HandshakeError::WouldBlock(mid) => { + self.socket = Stream::tls(mid); + Ok(()) + } + }, + Err(e) => Err(e), + } + } + + pub fn token(&self) -> Token { + self.token + } + + pub fn socket(&self) -> &TcpStream { + self.socket.evented() + } + + pub fn connection_id(&self) -> u32 { + self.connection_id + } + + fn peer_addr(&self) -> String { + if let Ok(addr) = self.socket.peer_addr() { + addr.to_string() + } else { + "UNKNOWN".into() + } + } + + // Resetting may be necessary in order to try all possible addresses for a server + #[cfg(any(feature = "ssl", feature = "nativetls"))] + pub fn reset(&mut self) -> Result<()> { + // if self.is_client() { + if let Client(ref url) = self.endpoint { + if let Connecting(ref mut req, ref mut res) = self.state { + req.set_position(0); + res.set_position(0); + self.events.remove(Ready::readable()); + self.events.insert(Ready::writable()); + + if let Some(ref addr) = self.addresses.pop() { + let sock = TcpStream::connect(addr)?; + if self.socket.is_tls() { + let ssl_stream = self.handler.upgrade_ssl_client(sock, url); + match ssl_stream { + Ok(stream) => { + self.socket = Stream::tls_live(stream); + Ok(()) + } + #[cfg(feature = "ssl")] + Err(Error { + kind: Kind::SslHandshake(handshake_err), + details, + }) => match handshake_err { + HandshakeError::SetupFailure(_) => { + Err(Error::new(Kind::SslHandshake(handshake_err), details)) + } + HandshakeError::Failure(mid) | HandshakeError::WouldBlock(mid) => { + self.socket = Stream::tls(mid); + Ok(()) + } + }, + #[cfg(feature = "nativetls")] + Err(Error { + kind: Kind::SslHandshake(handshake_err), + details, + }) => match handshake_err { + HandshakeError::Failure(_) => { + Err(Error::new(Kind::SslHandshake(handshake_err), details)) + } + HandshakeError::WouldBlock(mid) => { + self.socket = Stream::tls(mid); + Ok(()) + } + }, + Err(e) => Err(e), + } + } else { + self.socket = Stream::tcp(sock); + Ok(()) + } + } else { + if self.settings.panic_on_new_connection { + panic!("Unable to connect to server."); + } + Err(Error::new(Kind::Internal, "Exhausted possible addresses.")) + } + } else { + Err(Error::new( + Kind::Internal, + "Unable to reset client connection because it is active.", + )) + } + } else { + Err(Error::new( + Kind::Internal, + "Server connections cannot be reset.", + )) + } + } + + #[cfg(not(any(feature = "ssl", feature = "nativetls")))] + pub fn reset(&mut self) -> Result<()> { + if self.is_client() { + if let Connecting(ref mut req, ref mut res) = self.state { + req.set_position(0); + res.set_position(0); + self.events.remove(Ready::readable()); + self.events.insert(Ready::writable()); + + if let Some(ref addr) = self.addresses.pop() { + let sock = TcpStream::connect(addr)?; + self.socket = Stream::tcp(sock); + Ok(()) + } else { + if self.settings.panic_on_new_connection { + panic!("Unable to connect to server."); + } + Err(Error::new(Kind::Internal, "Exhausted possible addresses.")) + } + } else { + Err(Error::new( + Kind::Internal, + "Unable to reset client connection because it is active.", + )) + } + } else { + Err(Error::new( + Kind::Internal, + "Server connections cannot be reset.", + )) + } + } + + pub fn events(&self) -> Ready { + self.events + } + + pub fn is_client(&self) -> bool { + match self.endpoint { + Client(_) => true, + Server => false, + } + } + + pub fn is_server(&self) -> bool { + match self.endpoint { + Client(_) => false, + Server => true, + } + } + + pub fn shutdown(&mut self) { + self.handler.on_shutdown(); + if let Err(err) = self.send_close(CloseCode::Away, "Shutting down.") { + self.handler.on_error(err); + self.disconnect() + } + } + + #[inline] + pub fn new_timeout(&mut self, event: Token, timeout: Timeout) -> Result<()> { + self.handler.on_new_timeout(event, timeout) + } + + #[inline] + pub fn timeout_triggered(&mut self, event: Token) -> Result<()> { + self.handler.on_timeout(event) + } + + pub fn error(&mut self, err: Error) { + match self.state { + Connecting(_, ref mut res) => match err.kind { + #[cfg(feature = "ssl")] + Kind::Ssl(_) => { + self.handler.on_error(err); + self.events = Ready::empty(); + } + Kind::Io(_) => { + self.handler.on_error(err); + self.events = Ready::empty(); + } + Kind::Protocol => { + let msg = err.to_string(); + self.handler.on_error(err); + if let Server = self.endpoint { + res.get_mut().clear(); + if let Err(err) = + write!(res.get_mut(), "HTTP/1.1 400 Bad Request\r\n\r\n{}", msg) + { + self.handler.on_error(Error::from(err)); + self.events = Ready::empty(); + } else { + self.events.remove(Ready::readable()); + self.events.insert(Ready::writable()); + } + } else { + self.events = Ready::empty(); + } + } + _ => { + let msg = err.to_string(); + self.handler.on_error(err); + if let Server = self.endpoint { + res.get_mut().clear(); + if let Err(err) = write!( + res.get_mut(), + "HTTP/1.1 500 Internal Server Error\r\n\r\n{}", + msg + ) { + self.handler.on_error(Error::from(err)); + self.events = Ready::empty(); + } else { + self.events.remove(Ready::readable()); + self.events.insert(Ready::writable()); + } + } else { + self.events = Ready::empty(); + } + } + }, + _ => { + match err.kind { + Kind::Internal => { + if self.settings.panic_on_internal { + panic!("Panicking on internal error -- {}", err); + } + let reason = format!("{}", err); + + self.handler.on_error(err); + if let Err(err) = self.send_close(CloseCode::Error, reason) { + self.handler.on_error(err); + self.disconnect() + } + } + Kind::Capacity => { + if self.settings.panic_on_capacity { + panic!("Panicking on capacity error -- {}", err); + } + let reason = format!("{}", err); + + self.handler.on_error(err); + if let Err(err) = self.send_close(CloseCode::Size, reason) { + self.handler.on_error(err); + self.disconnect() + } + } + Kind::Protocol => { + if self.settings.panic_on_protocol { + panic!("Panicking on protocol error -- {}", err); + } + let reason = format!("{}", err); + + self.handler.on_error(err); + if let Err(err) = self.send_close(CloseCode::Protocol, reason) { + self.handler.on_error(err); + self.disconnect() + } + } + Kind::Encoding(_) => { + if self.settings.panic_on_encoding { + panic!("Panicking on encoding error -- {}", err); + } + let reason = format!("{}", err); + + self.handler.on_error(err); + if let Err(err) = self.send_close(CloseCode::Invalid, reason) { + self.handler.on_error(err); + self.disconnect() + } + } + Kind::Http(_) => { + // This may happen if some handler writes a bad response + self.handler.on_error(err); + error!("Disconnecting WebSocket."); + self.disconnect() + } + Kind::Custom(_) => { + self.handler.on_error(err); + } + Kind::Queue(_) => { + if self.settings.panic_on_queue { + panic!("Panicking on queue error -- {}", err); + } + self.handler.on_error(err); + } + _ => { + if self.settings.panic_on_io { + panic!("Panicking on io error -- {}", err); + } + self.handler.on_error(err); + self.disconnect() + } + } + } + } + } + + pub fn disconnect(&mut self) { + match self.state { + RespondingClose | FinishedClose | Connecting(_, _) => (), + _ => { + self.handler.on_close(CloseCode::Abnormal, ""); + } + } + self.events = Ready::empty() + } + + pub fn consume(self) -> H { + self.handler + } + + fn write_handshake(&mut self) -> Result<()> { + if let Connecting(ref mut req, ref mut res) = self.state { + match self.endpoint { + Server => { + let mut done = false; + if self.socket.try_write_buf(res)?.is_some() { + if res.position() as usize == res.get_ref().len() { + done = true + } + } + if !done { + return Ok(()); + } + } + Client(_) => { + if self.socket.try_write_buf(req)?.is_some() { + if req.position() as usize == req.get_ref().len() { + trace!( + "Finished writing handshake request to {}", + self.socket + .peer_addr() + .map(|addr| addr.to_string()) + .unwrap_or_else(|_| "UNKNOWN".into()) + ); + self.events.insert(Ready::readable()); + self.events.remove(Ready::writable()); + } + } + return Ok(()); + } + } + } + + if let Connecting(ref req, ref res) = replace(&mut self.state, Open) { + trace!( + "Finished writing handshake response to {}", + self.peer_addr() + ); + + let request = match Request::parse(req.get_ref()) { + Ok(Some(req)) => req, + _ => { + // An error should already have been sent for the first time it failed to + // parse. We don't call disconnect here because `on_open` hasn't been called yet. + self.state = FinishedClose; + self.events = Ready::empty(); + return Ok(()); + } + }; + + let response = Response::parse(res.get_ref())?.ok_or_else(|| { + Error::new( + Kind::Internal, + "Failed to parse response after handshake is complete.", + ) + })?; + + if response.status() != 101 { + self.events = Ready::empty(); + return Ok(()); + } else { + self.handler.on_open(Handshake { + request, + response, + peer_addr: self.socket.peer_addr().ok(), + local_addr: self.socket.local_addr().ok(), + })?; + debug!("Connection to {} is now open.", self.peer_addr()); + self.events.insert(Ready::readable()); + self.check_events(); + return Ok(()); + } + } else { + Err(Error::new( + Kind::Internal, + "Tried to write WebSocket handshake while not in connecting state!", + )) + } + } + + fn read_handshake(&mut self) -> Result<()> { + if let Connecting(ref mut req, ref mut res) = self.state { + match self.endpoint { + Server => { + if let Some(read) = self.socket.try_read_buf(req.get_mut())? { + if read == 0 { + self.events = Ready::empty(); + return Ok(()); + } + if let Some(ref request) = Request::parse(req.get_ref())? { + trace!("Handshake request received: \n{}", request); + let response = self.handler.on_request(request)?; + response.format(res.get_mut())?; + self.events.remove(Ready::readable()); + self.events.insert(Ready::writable()); + } + } + return Ok(()); + } + Client(_) => { + if self.socket.try_read_buf(res.get_mut())?.is_some() { + // TODO: see if this can be optimized with drain + let end = { + let data = res.get_ref(); + let end = data.iter() + .enumerate() + .take_while(|&(ind, _)| !data[..ind].ends_with(b"\r\n\r\n")) + .count(); + if !data[..end].ends_with(b"\r\n\r\n") { + return Ok(()); + } + self.in_buffer.get_mut().extend(&data[end..]); + end + }; + res.get_mut().truncate(end); + } else { + // NOTE: wait to be polled again; response not ready. + return Ok(()); + } + } + } + } + + if let Connecting(ref req, ref res) = replace(&mut self.state, Open) { + trace!( + "Finished reading handshake response from {}", + self.peer_addr() + ); + + let request = Request::parse(req.get_ref())?.ok_or_else(|| { + Error::new( + Kind::Internal, + "Failed to parse request after handshake is complete.", + ) + })?; + + let response = Response::parse(res.get_ref())?.ok_or_else(|| { + Error::new( + Kind::Internal, + "Failed to parse response after handshake is complete.", + ) + })?; + + trace!("Handshake response received: \n{}", response); + + if response.status() != 101 { + if response.status() != 301 && response.status() != 302 { + return Err(Error::new(Kind::Protocol, "Handshake failed.")); + } else { + return Ok(()); + } + } + + if self.settings.key_strict { + let req_key = request.hashed_key()?; + let res_key = from_utf8(response.key()?)?; + if req_key != res_key { + return Err(Error::new( + Kind::Protocol, + format!( + "Received incorrect WebSocket Accept key: {} vs {}", + req_key, res_key + ), + )); + } + } + + self.handler.on_response(&response)?; + self.handler.on_open(Handshake { + request, + response, + peer_addr: self.socket.peer_addr().ok(), + local_addr: self.socket.local_addr().ok(), + })?; + + // check to see if there is anything to read already + if !self.in_buffer.get_ref().is_empty() { + self.read_frames()?; + } + + self.check_events(); + return Ok(()); + } + Err(Error::new( + Kind::Internal, + "Tried to read WebSocket handshake while not in connecting state!", + )) + } + + pub fn read(&mut self) -> Result<()> { + if self.socket.is_negotiating() { + trace!("Performing TLS negotiation on {}.", self.peer_addr()); + self.socket.clear_negotiating()?; + self.write() + } else { + let res = if self.state.is_connecting() { + trace!("Ready to read handshake from {}.", self.peer_addr()); + self.read_handshake() + } else { + trace!("Ready to read messages from {}.", self.peer_addr()); + while let Some(len) = self.buffer_in()? { + self.read_frames()?; + if len == 0 { + if self.events.is_writable() { + self.events.remove(Ready::readable()); + } else { + self.disconnect() + } + break; + } + } + Ok(()) + }; + + if self.socket.is_negotiating() && res.is_ok() { + self.events.remove(Ready::readable()); + self.events.insert(Ready::writable()); + } + res + } + } + + fn read_frames(&mut self) -> Result<()> { + let max_size = self.settings.max_fragment_size as u64; + while let Some(mut frame) = Frame::parse(&mut self.in_buffer, max_size)? { + match self.state { + // Ignore data received after receiving close frame + RespondingClose | FinishedClose => continue, + _ => (), + } + + if self.settings.masking_strict { + if frame.is_masked() { + if self.is_client() { + return Err(Error::new( + Kind::Protocol, + "Received masked frame from a server endpoint.", + )); + } + } else { + if self.is_server() { + return Err(Error::new( + Kind::Protocol, + "Received unmasked frame from a client endpoint.", + )); + } + } + } + + // This is safe whether or not a frame is masked. + frame.remove_mask(); + + if let Some(frame) = self.handler.on_frame(frame)? { + if frame.is_final() { + match frame.opcode() { + // singleton data frames + OpCode::Text => { + trace!("Received text frame {:?}", frame); + // since we are going to handle this, there can't be an ongoing + // message + if !self.fragments.is_empty() { + return Err(Error::new(Kind::Protocol, "Received unfragmented text frame while processing fragmented message.")); + } + let msg = Message::text(String::from_utf8(frame.into_data()) + .map_err(|err| err.utf8_error())?); + self.handler.on_message(msg)?; + } + OpCode::Binary => { + trace!("Received binary frame {:?}", frame); + // since we are going to handle this, there can't be an ongoing + // message + if !self.fragments.is_empty() { + return Err(Error::new(Kind::Protocol, "Received unfragmented binary frame while processing fragmented message.")); + } + let data = frame.into_data(); + self.handler.on_message(Message::binary(data))?; + } + // control frames + OpCode::Close => { + trace!("Received close frame {:?}", frame); + // Closing handshake + if self.state.is_closing() { + if self.is_server() { + // Finished handshake, disconnect server side + self.events = Ready::empty() + } else { + // We are a client, so we wait for the server to close the + // connection + } + } else { + // Starting handshake, will send the responding close frame + self.state = RespondingClose; + } + + let mut close_code = [0u8; 2]; + let mut data = Cursor::new(frame.into_data()); + if let 2 = data.read(&mut close_code)? { + let raw_code: u16 = + (u16::from(close_code[0]) << 8) | (u16::from(close_code[1])); + trace!( + "Connection to {} received raw close code: {:?}, {:?}", + self.peer_addr(), + raw_code, + close_code + ); + let named = CloseCode::from(raw_code); + if let CloseCode::Other(code) = named { + if code < 1000 || + code >= 5000 || + code == 1004 || + code == 1014 || + code == 1016 || // these below are here to pass the autobahn test suite + code == 1100 || // we shouldn't need them later + code == 2000 + || code == 2999 + { + return Err(Error::new( + Kind::Protocol, + format!( + "Received invalid close code from endpoint: {}", + code + ), + )); + } + } + let has_reason = { + if let Ok(reason) = from_utf8(&data.get_ref()[2..]) { + self.handler.on_close(named, reason); // note reason may be an empty string + true + } else { + self.handler.on_close(named, ""); + false + } + }; + + if let CloseCode::Abnormal = named { + return Err(Error::new( + Kind::Protocol, + "Received abnormal close code from endpoint.", + )); + } else if let CloseCode::Status = named { + return Err(Error::new( + Kind::Protocol, + "Received no status close code from endpoint.", + )); + } else if let CloseCode::Restart = named { + return Err(Error::new( + Kind::Protocol, + "Restart close code is not supported.", + )); + } else if let CloseCode::Again = named { + return Err(Error::new( + Kind::Protocol, + "Try again later close code is not supported.", + )); + } else if let CloseCode::Tls = named { + return Err(Error::new( + Kind::Protocol, + "Received TLS close code outside of TLS handshake.", + )); + } else { + if !self.state.is_closing() { + if has_reason { + self.send_close(named, "")?; // note this drops any extra close data + } else { + self.send_close(CloseCode::Invalid, "")?; + } + } else { + self.state = FinishedClose; + } + } + } else { + // This is not an error. It is allowed behavior in the + // protocol, so we don't trigger an error. + // "If there is no such data in the Close control frame, + // _The WebSocket Connection Close Reason_ is the empty string." + self.handler.on_close(CloseCode::Status, ""); + if !self.state.is_closing() { + self.send_close(CloseCode::Empty, "")?; + } else { + self.state = FinishedClose; + } + } + } + OpCode::Ping => { + trace!("Received ping frame {:?}", frame); + self.send_pong(frame.into_data())?; + } + OpCode::Pong => { + trace!("Received pong frame {:?}", frame); + // no ping validation for now + } + // last fragment + OpCode::Continue => { + trace!("Received final fragment {:?}", frame); + if let Some(first) = self.fragments.pop_front() { + let size = self.fragments.iter().fold( + first.payload().len() + frame.payload().len(), + |len, frame| len + frame.payload().len(), + ); + match first.opcode() { + OpCode::Text => { + trace!("Constructing text message from fragments: {:?} -> {:?} -> {:?}", first, self.fragments.iter().collect::<Vec<&Frame>>(), frame); + let mut data = Vec::with_capacity(size); + data.extend(first.into_data()); + while let Some(frame) = self.fragments.pop_front() { + data.extend(frame.into_data()); + } + data.extend(frame.into_data()); + + let string = String::from_utf8(data) + .map_err(|err| err.utf8_error())?; + + trace!( + "Calling handler with constructed message: {:?}", + string + ); + self.handler.on_message(Message::text(string))?; + } + OpCode::Binary => { + trace!("Constructing binary message from fragments: {:?} -> {:?} -> {:?}", first, self.fragments.iter().collect::<Vec<&Frame>>(), frame); + let mut data = Vec::with_capacity(size); + data.extend(first.into_data()); + + while let Some(frame) = self.fragments.pop_front() { + data.extend(frame.into_data()); + } + + data.extend(frame.into_data()); + + trace!( + "Calling handler with constructed message: {:?}", + data + ); + self.handler.on_message(Message::binary(data))?; + } + _ => { + return Err(Error::new( + Kind::Protocol, + "Encounted fragmented control frame.", + )) + } + } + } else { + return Err(Error::new( + Kind::Protocol, + "Unable to reconstruct fragmented message. No first frame.", + )); + } + } + _ => return Err(Error::new(Kind::Protocol, "Encountered invalid opcode.")), + } + } else { + if frame.is_control() { + return Err(Error::new( + Kind::Protocol, + "Encounted fragmented control frame.", + )); + } else { + trace!("Received non-final fragment frame {:?}", frame); + if !self.settings.fragments_grow + && self.settings.fragments_capacity == self.fragments.len() + { + return Err(Error::new(Kind::Capacity, "Exceeded max fragments.")); + } else { + self.fragments.push_back(frame) + } + } + } + } + } + Ok(()) + } + + pub fn write(&mut self) -> Result<()> { + if self.socket.is_negotiating() { + trace!("Performing TLS negotiation on {}.", self.peer_addr()); + self.socket.clear_negotiating()?; + self.read() + } else { + let res = if self.state.is_connecting() { + trace!("Ready to write handshake to {}.", self.peer_addr()); + self.write_handshake() + } else { + trace!("Ready to write messages to {}.", self.peer_addr()); + + // Start out assuming that this write will clear the whole buffer + self.events.remove(Ready::writable()); + + if let Some(len) = self.socket.try_write_buf(&mut self.out_buffer)? { + trace!("Wrote {} bytes to {}", len, self.peer_addr()); + let finished = len == 0 + || self.out_buffer.position() == self.out_buffer.get_ref().len() as u64; + if finished { + match self.state { + // we are are a server that is closing and just wrote out our confirming + // close frame, let's disconnect + FinishedClose if self.is_server() => { + self.events = Ready::empty(); + return Ok(()); + } + _ => (), + } + } + } + + // Check if there is more to write so that the connection will be rescheduled + self.check_events(); + Ok(()) + }; + + if self.socket.is_negotiating() && res.is_ok() { + self.events.remove(Ready::writable()); + self.events.insert(Ready::readable()); + } + res + } + } + + pub fn send_message(&mut self, msg: Message) -> Result<()> { + if self.state.is_closing() { + trace!( + "Connection is closing. Ignoring request to send message {:?} to {}.", + msg, + self.peer_addr() + ); + return Ok(()); + } + + let opcode = msg.opcode(); + trace!("Message opcode {:?}", opcode); + let data = msg.into_data(); + + if let Some(frame) = self.handler + .on_send_frame(Frame::message(data, opcode, true))? + { + if frame.payload().len() > self.settings.fragment_size { + trace!("Chunking at {:?}.", self.settings.fragment_size); + // note this copies the data, so it's actually somewhat expensive to fragment + let mut chunks = frame + .payload() + .chunks(self.settings.fragment_size) + .peekable(); + let chunk = chunks.next().expect("Unable to get initial chunk!"); + + let mut first = Frame::message(Vec::from(chunk), opcode, false); + + // Match reserved bits from original to keep extension status intact + first.set_rsv1(frame.has_rsv1()); + first.set_rsv2(frame.has_rsv2()); + first.set_rsv3(frame.has_rsv3()); + + self.buffer_frame(first)?; + + while let Some(chunk) = chunks.next() { + if chunks.peek().is_some() { + self.buffer_frame(Frame::message( + Vec::from(chunk), + OpCode::Continue, + false, + ))?; + } else { + self.buffer_frame(Frame::message( + Vec::from(chunk), + OpCode::Continue, + true, + ))?; + } + } + } else { + trace!("Sending unfragmented message frame."); + // true means that the message is done + self.buffer_frame(frame)?; + } + } + self.check_events(); + Ok(()) + } + + #[inline] + pub fn send_ping(&mut self, data: Vec<u8>) -> Result<()> { + if self.state.is_closing() { + trace!( + "Connection is closing. Ignoring request to send ping {:?} to {}.", + data, + self.peer_addr() + ); + return Ok(()); + } + trace!("Sending ping to {}.", self.peer_addr()); + + if let Some(frame) = self.handler.on_send_frame(Frame::ping(data))? { + self.buffer_frame(frame)?; + } + self.check_events(); + Ok(()) + } + + #[inline] + pub fn send_pong(&mut self, data: Vec<u8>) -> Result<()> { + if self.state.is_closing() { + trace!( + "Connection is closing. Ignoring request to send pong {:?} to {}.", + data, + self.peer_addr() + ); + return Ok(()); + } + trace!("Sending pong to {}.", self.peer_addr()); + + if let Some(frame) = self.handler.on_send_frame(Frame::pong(data))? { + self.buffer_frame(frame)?; + } + self.check_events(); + Ok(()) + } + + #[inline] + pub fn send_close<R>(&mut self, code: CloseCode, reason: R) -> Result<()> + where + R: Borrow<str>, + { + match self.state { + // We are responding to a close frame the other endpoint, when this frame goes out, we + // are done. + RespondingClose => self.state = FinishedClose, + // Multiple close frames are being sent from our end, ignore the later frames + AwaitingClose | FinishedClose => { + trace!( + "Connection is already closing. Ignoring close {:?} -- {:?} to {}.", + code, + reason.borrow(), + self.peer_addr() + ); + self.check_events(); + return Ok(()); + } + // We are initiating a closing handshake. + Open => self.state = AwaitingClose, + Connecting(_, _) => { + debug_assert!(false, "Attempted to close connection while not yet open.") + } + } + + trace!( + "Sending close {:?} -- {:?} to {}.", + code, + reason.borrow(), + self.peer_addr() + ); + + if let Some(frame) = self.handler + .on_send_frame(Frame::close(code, reason.borrow()))? + { + self.buffer_frame(frame)?; + } + + trace!("Connection to {} is now closing.", self.peer_addr()); + + self.check_events(); + Ok(()) + } + + fn check_events(&mut self) { + if !self.state.is_connecting() { + self.events.insert(Ready::readable()); + if self.out_buffer.position() < self.out_buffer.get_ref().len() as u64 { + self.events.insert(Ready::writable()); + } + } + } + + fn buffer_frame(&mut self, mut frame: Frame) -> Result<()> { + self.check_buffer_out(&frame)?; + + if self.is_client() { + frame.set_mask(); + } + + trace!("Buffering frame to {}:\n{}", self.peer_addr(), frame); + + let pos = self.out_buffer.position(); + self.out_buffer.seek(SeekFrom::End(0))?; + frame.format(&mut self.out_buffer)?; + self.out_buffer.seek(SeekFrom::Start(pos))?; + Ok(()) + } + + fn check_buffer_out(&mut self, frame: &Frame) -> Result<()> { + if self.out_buffer.get_ref().capacity() <= self.out_buffer.get_ref().len() + frame.len() { + // extend + let mut new = Vec::with_capacity(self.out_buffer.get_ref().capacity()); + new.extend(&self.out_buffer.get_ref()[self.out_buffer.position() as usize..]); + if new.len() == new.capacity() { + if self.settings.out_buffer_grow { + new.reserve(self.settings.out_buffer_capacity) + } else { + return Err(Error::new( + Kind::Capacity, + "Maxed out output buffer for connection.", + )); + } + } + self.out_buffer = Cursor::new(new); + } + Ok(()) + } + + fn buffer_in(&mut self) -> Result<Option<usize>> { + trace!("Reading buffer for connection to {}.", self.peer_addr()); + if let Some(len) = self.socket.try_read_buf(self.in_buffer.get_mut())? { + trace!("Buffered {}.", len); + if self.in_buffer.get_ref().len() == self.in_buffer.get_ref().capacity() { + // extend + let mut new = Vec::with_capacity(self.in_buffer.get_ref().capacity()); + new.extend(&self.in_buffer.get_ref()[self.in_buffer.position() as usize..]); + if new.len() == new.capacity() { + if self.settings.in_buffer_grow { + new.reserve(self.settings.in_buffer_capacity); + } else { + return Err(Error::new( + Kind::Capacity, + "Maxed out input buffer for connection.", + )); + } + } + self.in_buffer = Cursor::new(new); + } + Ok(Some(len)) + } else { + Ok(None) + } + } +} diff --git a/third_party/rust/ws/src/deflate/context.rs b/third_party/rust/ws/src/deflate/context.rs new file mode 100644 index 0000000000..2fa2e23056 --- /dev/null +++ b/third_party/rust/ws/src/deflate/context.rs @@ -0,0 +1,268 @@ +use std::mem; +use std::slice; + +use super::ffi; +use super::libc::{c_char, c_int, c_uint}; + +use result::{Error, Kind, Result}; + +const ZLIB_VERSION: &'static str = "1.2.8\0"; + +trait Context { + fn stream(&mut self) -> &mut ffi::z_stream; + + fn stream_apply<F>(&mut self, input: &[u8], output: &mut Vec<u8>, each: F) -> Result<()> + where + F: Fn(&mut ffi::z_stream) -> Option<Result<()>>, + { + debug_assert!(output.len() == 0, "Output vector is not empty."); + + let stream = self.stream(); + + stream.next_in = input.as_ptr() as *mut _; + stream.avail_in = input.len() as c_uint; + + let mut output_size; + + loop { + output_size = output.len(); + + if output_size == output.capacity() { + output.reserve(input.len()) + } + + let out_slice = unsafe { + slice::from_raw_parts_mut( + output.as_mut_ptr().offset(output_size as isize), + output.capacity() - output_size, + ) + }; + + stream.next_out = out_slice.as_mut_ptr(); + stream.avail_out = out_slice.len() as c_uint; + + let before = stream.total_out; + let cont = each(stream); + + unsafe { + output.set_len((stream.total_out - before) as usize + output_size); + } + + if let Some(result) = cont { + return result; + } + } + } +} + +pub struct Compressor { + // Box the z_stream to ensure it isn't moved. Moving the z_stream + // causes zlib to fail, because it maintains internal pointers. + stream: Box<ffi::z_stream>, +} + +impl Compressor { + pub fn new(window_bits: i8) -> Compressor { + debug_assert!(window_bits >= 9, "Received too small window size."); + debug_assert!(window_bits <= 15, "Received too large window size."); + + unsafe { + let mut stream: Box<ffi::z_stream> = Box::new(mem::zeroed()); + let result = ffi::deflateInit2_( + stream.as_mut(), + 9, + ffi::Z_DEFLATED, + -window_bits as c_int, + 9, + ffi::Z_DEFAULT_STRATEGY, + ZLIB_VERSION.as_ptr() as *const c_char, + mem::size_of::<ffi::z_stream>() as c_int, + ); + assert!(result == ffi::Z_OK, "Failed to initialize compresser."); + Compressor { stream: stream } + } + } + + pub fn compress(&mut self, input: &[u8], output: &mut Vec<u8>) -> Result<()> { + self.stream_apply(input, output, |stream| unsafe { + match ffi::deflate(stream, ffi::Z_SYNC_FLUSH) { + ffi::Z_OK | ffi::Z_BUF_ERROR => { + if stream.avail_in == 0 && stream.avail_out > 0 { + Some(Ok(())) + } else { + None + } + } + code => Some(Err(Error::new( + Kind::Protocol, + format!("Failed to perform compression: {}", code), + ))), + } + }) + } + + pub fn reset(&mut self) -> Result<()> { + match unsafe { ffi::deflateReset(self.stream.as_mut()) } { + ffi::Z_OK => Ok(()), + code => Err(Error::new( + Kind::Protocol, + format!("Failed to reset compression context: {}", code), + )), + } + } +} + +impl Context for Compressor { + fn stream(&mut self) -> &mut ffi::z_stream { + self.stream.as_mut() + } +} + +impl Drop for Compressor { + fn drop(&mut self) { + match unsafe { ffi::deflateEnd(self.stream.as_mut()) } { + ffi::Z_STREAM_ERROR => error!("Compression stream encountered bad state."), + // Ignore discarded data error because we are raw + ffi::Z_OK | ffi::Z_DATA_ERROR => trace!("Deallocated compression context."), + code => error!("Bad zlib status encountered: {}", code), + } + } +} + +pub struct Decompressor { + stream: Box<ffi::z_stream>, +} + +impl Decompressor { + pub fn new(window_bits: i8) -> Decompressor { + debug_assert!(window_bits >= 8, "Received too small window size."); + debug_assert!(window_bits <= 15, "Received too large window size."); + + unsafe { + let mut stream: Box<ffi::z_stream> = Box::new(mem::zeroed()); + let result = ffi::inflateInit2_( + stream.as_mut(), + -window_bits as c_int, + ZLIB_VERSION.as_ptr() as *const c_char, + mem::size_of::<ffi::z_stream>() as c_int, + ); + assert!(result == ffi::Z_OK, "Failed to initialize decompresser."); + Decompressor { stream: stream } + } + } + + pub fn decompress(&mut self, input: &[u8], output: &mut Vec<u8>) -> Result<()> { + self.stream_apply(input, output, |stream| unsafe { + match ffi::inflate(stream, ffi::Z_SYNC_FLUSH) { + ffi::Z_OK | ffi::Z_BUF_ERROR => { + if stream.avail_in == 0 && stream.avail_out > 0 { + Some(Ok(())) + } else { + None + } + } + code => Some(Err(Error::new( + Kind::Protocol, + format!("Failed to perform decompression: {}", code), + ))), + } + }) + } + + pub fn reset(&mut self) -> Result<()> { + match unsafe { ffi::inflateReset(self.stream.as_mut()) } { + ffi::Z_OK => Ok(()), + code => Err(Error::new( + Kind::Protocol, + format!("Failed to reset compression context: {}", code), + )), + } + } +} + +impl Context for Decompressor { + fn stream(&mut self) -> &mut ffi::z_stream { + self.stream.as_mut() + } +} + +impl Drop for Decompressor { + fn drop(&mut self) { + match unsafe { ffi::inflateEnd(self.stream.as_mut()) } { + ffi::Z_STREAM_ERROR => error!("Decompression stream encountered bad state."), + ffi::Z_OK => trace!("Deallocated decompression context."), + code => error!("Bad zlib status encountered: {}", code), + } + } +} + +mod test { + #![allow(unused_imports, unused_variables, dead_code)] + use super::*; + + fn as_hex(s: &[u8]) { + for byte in s { + print!("0x{:x} ", byte); + } + print!("\n"); + } + + #[test] + fn round_trip() { + for i in 9..16 { + let data = "HI THERE THIS IS some data. これはデータだよ。".as_bytes(); + let mut compressed = Vec::with_capacity(data.len()); + let mut decompressed = Vec::with_capacity(data.len()); + + let com = Compressor::new(i); + let mut moved_com = com; + + moved_com + .compress(&data, &mut compressed) + .expect("Failed to compress data."); + + let dec = Decompressor::new(i); + let mut moved_dec = dec; + + moved_dec + .decompress(&compressed, &mut decompressed) + .expect("Failed to decompress data."); + + assert_eq!(data, &decompressed[..]); + } + } + + #[test] + fn reset() { + let data1 = "HI THERE 直子さん".as_bytes(); + let data2 = "HI THERE 人太郎".as_bytes(); + let mut compressed1 = Vec::with_capacity(data1.len()); + let mut compressed2 = Vec::with_capacity(data2.len()); + let mut compressed2_ind = Vec::with_capacity(data2.len()); + + let mut decompressed1 = Vec::with_capacity(data1.len()); + let mut decompressed2 = Vec::with_capacity(data2.len()); + let mut decompressed2_ind = Vec::with_capacity(data2.len()); + + let mut com = Compressor::new(9); + + com.compress(&data1, &mut compressed1).unwrap(); + com.compress(&data2, &mut compressed2).unwrap(); + com.reset().unwrap(); + com.compress(&data2, &mut compressed2_ind).unwrap(); + + let mut dec = Decompressor::new(9); + + dec.decompress(&compressed1, &mut decompressed1).unwrap(); + dec.decompress(&compressed2, &mut decompressed2).unwrap(); + dec.reset().unwrap(); + dec.decompress(&compressed2_ind, &mut decompressed2_ind) + .unwrap(); + + assert_eq!(data1, &decompressed1[..]); + assert_eq!(data2, &decompressed2[..]); + assert_eq!(data2, &decompressed2_ind[..]); + assert!(compressed2 != compressed2_ind); + assert!(compressed2.len() < compressed2_ind.len()); + } +} diff --git a/third_party/rust/ws/src/deflate/extension.rs b/third_party/rust/ws/src/deflate/extension.rs new file mode 100644 index 0000000000..712e11fb8e --- /dev/null +++ b/third_party/rust/ws/src/deflate/extension.rs @@ -0,0 +1,565 @@ +use std::mem::replace; + +#[cfg(feature = "ssl")] +use openssl::ssl::SslStream; +#[cfg(feature = "nativetls")] +use native_tls::TlsStream as SslStream; +use url; + +use frame::Frame; +use handler::Handler; +use handshake::{Handshake, Request, Response}; +use message::Message; +use protocol::{CloseCode, OpCode}; +use result::{Error, Kind, Result}; +#[cfg(any(feature = "ssl", feature = "nativetls"))] +use util::TcpStream; +use util::{Timeout, Token}; + +use super::context::{Compressor, Decompressor}; + +/// Deflate Extension Handler Settings +#[derive(Debug, Clone, Copy)] +pub struct DeflateSettings { + /// The max size of the sliding window. If the other endpoint selects a smaller size, that size + /// will be used instead. This must be an integer between 9 and 15 inclusive. + /// Default: 15 + pub max_window_bits: u8, + /// Indicates whether to ask the other endpoint to reset the sliding window for each message. + /// Default: false + pub request_no_context_takeover: bool, + /// Indicates whether this endpoint will agree to reset the sliding window for each message it + /// compresses. If this endpoint won't agree to reset the sliding window, then the handshake + /// will fail if this endpoint is a client and the server requests no context takeover. + /// Default: true + pub accept_no_context_takeover: bool, + /// The number of WebSocket frames to store when defragmenting an incoming fragmented + /// compressed message. + /// This setting may be different from the `fragments_capacity` setting of the WebSocket in order to + /// allow for differences between compressed and uncompressed messages. + /// Default: 10 + pub fragments_capacity: usize, + /// Indicates whether the extension handler will reallocate if the `fragments_capacity` is + /// exceeded. If this is not true, a capacity error will be triggered instead. + /// Default: true + pub fragments_grow: bool, +} + +impl Default for DeflateSettings { + fn default() -> DeflateSettings { + DeflateSettings { + max_window_bits: 15, + request_no_context_takeover: false, + accept_no_context_takeover: true, + fragments_capacity: 10, + fragments_grow: true, + } + } +} + +/// Utility for applying the permessage-deflate extension to a handler with particular deflate +/// settings. +#[derive(Debug, Clone, Copy)] +pub struct DeflateBuilder { + settings: DeflateSettings, +} + +impl DeflateBuilder { + /// Create a new DeflateBuilder with the default settings. + pub fn new() -> DeflateBuilder { + DeflateBuilder { + settings: DeflateSettings::default(), + } + } + + /// Configure the DeflateBuilder with the given deflate settings. + pub fn with_settings(&mut self, settings: DeflateSettings) -> &mut DeflateBuilder { + self.settings = settings; + self + } + + /// Wrap another handler in with a deflate handler as configured. + pub fn build<H: Handler>(&self, handler: H) -> DeflateHandler<H> { + DeflateHandler { + com: Compressor::new(self.settings.max_window_bits as i8), + dec: Decompressor::new(self.settings.max_window_bits as i8), + fragments: Vec::with_capacity(self.settings.fragments_capacity), + compress_reset: false, + decompress_reset: false, + pass: false, + settings: self.settings, + inner: handler, + } + } +} + +/// A WebSocket handler that implements the permessage-deflate extension. +/// +/// This handler wraps a child handler and proxies all handler methods to it. The handler will +/// decompress incoming WebSocket message frames in their reserved bits match the +/// permessage-deflate specification and pass them to the child handler. Message frames sent from +/// the child handler will be compressed and sent to the other endpoint using deflate compression. +pub struct DeflateHandler<H: Handler> { + com: Compressor, + dec: Decompressor, + fragments: Vec<Frame>, + compress_reset: bool, + decompress_reset: bool, + pass: bool, + settings: DeflateSettings, + inner: H, +} + +impl<H: Handler> DeflateHandler<H> { + /// Wrap a child handler to provide the permessage-deflate extension. + pub fn new(handler: H) -> DeflateHandler<H> { + trace!("Using permessage-deflate handler."); + let settings = DeflateSettings::default(); + DeflateHandler { + com: Compressor::new(settings.max_window_bits as i8), + dec: Decompressor::new(settings.max_window_bits as i8), + fragments: Vec::with_capacity(settings.fragments_capacity), + compress_reset: false, + decompress_reset: false, + pass: false, + settings: settings, + inner: handler, + } + } + + #[doc(hidden)] + #[inline] + fn decline(&mut self, mut res: Response) -> Result<Response> { + trace!("Declined permessage-deflate offer"); + self.pass = true; + res.remove_extension("permessage-deflate"); + Ok(res) + } +} + +impl<H: Handler> Handler for DeflateHandler<H> { + fn build_request(&mut self, url: &url::Url) -> Result<Request> { + let mut req = self.inner.build_request(url)?; + let mut req_ext = String::with_capacity(100); + req_ext.push_str("permessage-deflate"); + if self.settings.max_window_bits < 15 { + req_ext.push_str(&format!( + "; client_max_window_bits={}; server_max_window_bits={}", + self.settings.max_window_bits, self.settings.max_window_bits + )) + } else { + req_ext.push_str("; client_max_window_bits") + } + if self.settings.request_no_context_takeover { + req_ext.push_str("; server_no_context_takeover") + } + req.add_extension(&req_ext); + Ok(req) + } + + fn on_request(&mut self, req: &Request) -> Result<Response> { + let mut res = self.inner.on_request(req)?; + + 'ext: for req_ext in req.extensions()? + .iter() + .filter(|&&ext| ext.contains("permessage-deflate")) + { + let mut res_ext = String::with_capacity(req_ext.len()); + let mut s_takeover = false; + let mut c_takeover = false; + let mut s_max = false; + let mut c_max = false; + + for param in req_ext.split(';') { + match param.trim() { + "permessage-deflate" => res_ext.push_str("permessage-deflate"), + "server_no_context_takeover" => { + if s_takeover { + return self.decline(res); + } else { + s_takeover = true; + if self.settings.accept_no_context_takeover { + self.compress_reset = true; + res_ext.push_str("; server_no_context_takeover"); + } else { + continue 'ext; + } + } + } + "client_no_context_takeover" => { + if c_takeover { + return self.decline(res); + } else { + c_takeover = true; + self.decompress_reset = true; + res_ext.push_str("; client_no_context_takeover"); + } + } + param if param.starts_with("server_max_window_bits") => { + if s_max { + return self.decline(res); + } else { + s_max = true; + let mut param_iter = param.split('='); + param_iter.next(); // we already know the name + if let Some(window_bits_str) = param_iter.next() { + if let Ok(window_bits) = window_bits_str.trim().parse() { + if window_bits >= 9 && window_bits <= 15 { + if window_bits < self.settings.max_window_bits as i8 { + self.com = Compressor::new(window_bits); + res_ext.push_str("; "); + res_ext.push_str(param) + } + } else { + return self.decline(res); + } + } else { + return self.decline(res); + } + } + } + } + param if param.starts_with("client_max_window_bits") => { + if c_max { + return self.decline(res); + } else { + c_max = true; + let mut param_iter = param.split('='); + param_iter.next(); // we already know the name + if let Some(window_bits_str) = param_iter.next() { + if let Ok(window_bits) = window_bits_str.trim().parse() { + if window_bits >= 9 && window_bits <= 15 { + if window_bits < self.settings.max_window_bits as i8 { + self.dec = Decompressor::new(window_bits); + res_ext.push_str("; "); + res_ext.push_str(param); + continue; + } + } else { + return self.decline(res); + } + } else { + return self.decline(res); + } + } + res_ext.push_str("; "); + res_ext.push_str(&format!( + "client_max_window_bits={}", + self.settings.max_window_bits + )) + } + } + _ => { + // decline all extension offers because we got a bad parameter + return self.decline(res); + } + } + } + + if !res_ext.contains("client_no_context_takeover") + && self.settings.request_no_context_takeover + { + self.decompress_reset = true; + res_ext.push_str("; client_no_context_takeover"); + } + + if !res_ext.contains("server_max_window_bits") { + res_ext.push_str("; "); + res_ext.push_str(&format!( + "server_max_window_bits={}", + self.settings.max_window_bits + )) + } + + if !res_ext.contains("client_max_window_bits") && self.settings.max_window_bits < 15 { + continue; + } + + res.add_extension(&res_ext); + return Ok(res); + } + self.decline(res) + } + + fn on_response(&mut self, res: &Response) -> Result<()> { + if let Some(res_ext) = res.extensions()? + .iter() + .find(|&&ext| ext.contains("permessage-deflate")) + { + let mut name = false; + let mut s_takeover = false; + let mut c_takeover = false; + let mut s_max = false; + let mut c_max = false; + + for param in res_ext.split(';') { + match param.trim() { + "permessage-deflate" => { + if name { + return Err(Error::new( + Kind::Protocol, + format!("Duplicate extension name permessage-deflate"), + )); + } else { + name = true; + } + } + "server_no_context_takeover" => { + if s_takeover { + return Err(Error::new( + Kind::Protocol, + format!("Duplicate extension parameter server_no_context_takeover"), + )); + } else { + s_takeover = true; + self.decompress_reset = true; + } + } + "client_no_context_takeover" => { + if c_takeover { + return Err(Error::new( + Kind::Protocol, + format!("Duplicate extension parameter client_no_context_takeover"), + )); + } else { + c_takeover = true; + if self.settings.accept_no_context_takeover { + self.compress_reset = true; + } else { + return Err(Error::new( + Kind::Protocol, + format!("The client requires context takeover."), + )); + } + } + } + param if param.starts_with("server_max_window_bits") => { + if s_max { + return Err(Error::new( + Kind::Protocol, + format!("Duplicate extension parameter server_max_window_bits"), + )); + } else { + s_max = true; + let mut param_iter = param.split('='); + param_iter.next(); // we already know the name + if let Some(window_bits_str) = param_iter.next() { + if let Ok(window_bits) = window_bits_str.trim().parse() { + if window_bits >= 9 && window_bits <= 15 { + if window_bits as u8 != self.settings.max_window_bits { + self.dec = Decompressor::new(window_bits); + } + } else { + return Err(Error::new( + Kind::Protocol, + format!( + "Invalid server_max_window_bits parameter: {}", + window_bits + ), + )); + } + } else { + return Err(Error::new( + Kind::Protocol, + format!( + "Invalid server_max_window_bits parameter: {}", + window_bits_str + ), + )); + } + } + } + } + param if param.starts_with("client_max_window_bits") => { + if c_max { + return Err(Error::new( + Kind::Protocol, + format!("Duplicate extension parameter client_max_window_bits"), + )); + } else { + c_max = true; + let mut param_iter = param.split('='); + param_iter.next(); // we already know the name + if let Some(window_bits_str) = param_iter.next() { + if let Ok(window_bits) = window_bits_str.trim().parse() { + if window_bits >= 9 && window_bits <= 15 { + if window_bits as u8 != self.settings.max_window_bits { + self.com = Compressor::new(window_bits); + } + } else { + return Err(Error::new( + Kind::Protocol, + format!( + "Invalid client_max_window_bits parameter: {}", + window_bits + ), + )); + } + } else { + return Err(Error::new( + Kind::Protocol, + format!( + "Invalid client_max_window_bits parameter: {}", + window_bits_str + ), + )); + } + } + } + } + param => { + // fail the connection because we got a bad parameter + return Err(Error::new( + Kind::Protocol, + format!("Bad extension parameter: {}", param), + )); + } + } + } + } else { + self.pass = true + } + + Ok(()) + } + + fn on_frame(&mut self, mut frame: Frame) -> Result<Option<Frame>> { + if !self.pass && !frame.is_control() { + if !self.fragments.is_empty() || frame.has_rsv1() { + frame.set_rsv1(false); + + if !frame.is_final() { + self.fragments.push(frame); + return Ok(None); + } else { + if frame.opcode() == OpCode::Continue { + if self.fragments.is_empty() { + return Err(Error::new( + Kind::Protocol, + "Unable to reconstruct fragmented message. No first frame.", + )); + } else { + if !self.settings.fragments_grow + && self.settings.fragments_capacity == self.fragments.len() + { + return Err(Error::new(Kind::Capacity, "Exceeded max fragments.")); + } else { + self.fragments.push(frame); + } + + // it's safe to unwrap because of the above check for empty + let opcode = self.fragments.first().unwrap().opcode(); + let size = self.fragments + .iter() + .fold(0, |len, frame| len + frame.payload().len()); + let mut compressed = Vec::with_capacity(size); + let mut decompressed = Vec::with_capacity(size * 2); + for frag in replace( + &mut self.fragments, + Vec::with_capacity(self.settings.fragments_capacity), + ) { + compressed.extend(frag.into_data()) + } + + compressed.extend(&[0, 0, 255, 255]); + self.dec.decompress(&compressed, &mut decompressed)?; + frame = Frame::message(decompressed, opcode, true); + } + } else { + let mut decompressed = Vec::with_capacity(frame.payload().len() * 2); + frame.payload_mut().extend(&[0, 0, 255, 255]); + + self.dec.decompress(frame.payload(), &mut decompressed)?; + + *frame.payload_mut() = decompressed; + } + + if self.decompress_reset { + self.dec.reset()? + } + } + } + } + self.inner.on_frame(frame) + } + + fn on_send_frame(&mut self, frame: Frame) -> Result<Option<Frame>> { + if let Some(mut frame) = self.inner.on_send_frame(frame)? { + if !self.pass && !frame.is_control() { + debug_assert!( + frame.is_final(), + "Received non-final frame from upstream handler!" + ); + debug_assert!( + frame.opcode() != OpCode::Continue, + "Received continue frame from upstream handler!" + ); + + frame.set_rsv1(true); + let mut compressed = Vec::with_capacity(frame.payload().len()); + self.com.compress(frame.payload(), &mut compressed)?; + let len = compressed.len(); + compressed.truncate(len - 4); + *frame.payload_mut() = compressed; + + if self.compress_reset { + self.com.reset()? + } + } + Ok(Some(frame)) + } else { + Ok(None) + } + } + + #[inline] + fn on_shutdown(&mut self) { + self.inner.on_shutdown() + } + + #[inline] + fn on_open(&mut self, shake: Handshake) -> Result<()> { + self.inner.on_open(shake) + } + + #[inline] + fn on_message(&mut self, msg: Message) -> Result<()> { + self.inner.on_message(msg) + } + + #[inline] + fn on_close(&mut self, code: CloseCode, reason: &str) { + self.inner.on_close(code, reason) + } + + #[inline] + fn on_error(&mut self, err: Error) { + self.inner.on_error(err) + } + + #[inline] + fn on_timeout(&mut self, event: Token) -> Result<()> { + self.inner.on_timeout(event) + } + + #[inline] + fn on_new_timeout(&mut self, tok: Token, timeout: Timeout) -> Result<()> { + self.inner.on_new_timeout(tok, timeout) + } + + #[inline] + #[cfg(any(feature = "ssl", feature = "nativetls"))] + fn upgrade_ssl_client( + &mut self, + stream: TcpStream, + url: &url::Url, + ) -> Result<SslStream<TcpStream>> { + self.inner.upgrade_ssl_client(stream, url) + } + + #[inline] + #[cfg(any(feature = "ssl", feature = "nativetls"))] + fn upgrade_ssl_server(&mut self, stream: TcpStream) -> Result<SslStream<TcpStream>> { + self.inner.upgrade_ssl_server(stream) + } +} diff --git a/third_party/rust/ws/src/deflate/mod.rs b/third_party/rust/ws/src/deflate/mod.rs new file mode 100644 index 0000000000..8d79012e73 --- /dev/null +++ b/third_party/rust/ws/src/deflate/mod.rs @@ -0,0 +1,9 @@ +//! The deflate module provides tools for applying the permessage-deflate extension. + +extern crate libc; +extern crate libz_sys as ffi; + +mod context; +mod extension; + +pub use self::extension::{DeflateBuilder, DeflateHandler, DeflateSettings}; diff --git a/third_party/rust/ws/src/factory.rs b/third_party/rust/ws/src/factory.rs new file mode 100644 index 0000000000..048ac78110 --- /dev/null +++ b/third_party/rust/ws/src/factory.rs @@ -0,0 +1,188 @@ +use communication::Sender; +use handler::Handler; + +/// A trait for creating new WebSocket handlers. +pub trait Factory { + type Handler: Handler; + + /// Called when a TCP connection is made. + fn connection_made(&mut self, _: Sender) -> Self::Handler; + + /// Called when the WebSocket is shutting down. + #[inline] + fn on_shutdown(&mut self) { + debug!("Factory received WebSocket shutdown request."); + } + + /// Called when a new connection is established for a client endpoint. + /// This method can be used to differentiate a client aspect for a handler. + /// + /// ``` + /// use ws::{Sender, Factory, Handler}; + /// + /// struct MyHandler { + /// ws: Sender, + /// is_client: bool, + /// } + /// + /// impl Handler for MyHandler {} + /// + /// struct MyFactory; + /// + /// impl Factory for MyFactory { + /// type Handler = MyHandler; + /// + /// fn connection_made(&mut self, ws: Sender) -> MyHandler { + /// MyHandler { + /// ws: ws, + /// // default to server + /// is_client: false, + /// } + /// } + /// + /// fn client_connected(&mut self, ws: Sender) -> MyHandler { + /// MyHandler { + /// ws: ws, + /// is_client: true, + /// } + /// } + /// } + /// ``` + #[inline] + fn client_connected(&mut self, ws: Sender) -> Self::Handler { + self.connection_made(ws) + } + + /// Called when a new connection is established for a server endpoint. + /// This method can be used to differentiate a server aspect for a handler. + /// + /// ``` + /// use ws::{Sender, Factory, Handler}; + /// + /// struct MyHandler { + /// ws: Sender, + /// is_server: bool, + /// } + /// + /// impl Handler for MyHandler {} + /// + /// struct MyFactory; + /// + /// impl Factory for MyFactory { + /// type Handler = MyHandler; + /// + /// fn connection_made(&mut self, ws: Sender) -> MyHandler { + /// MyHandler { + /// ws: ws, + /// // default to client + /// is_server: false, + /// } + /// } + /// + /// fn server_connected(&mut self, ws: Sender) -> MyHandler { + /// MyHandler { + /// ws: ws, + /// is_server: true, + /// } + /// } + /// } + #[inline] + fn server_connected(&mut self, ws: Sender) -> Self::Handler { + self.connection_made(ws) + } + + /// Called when a TCP connection is lost with the handler that was + /// setup for that connection. + /// + /// The default implementation is a noop that simply drops the handler. + /// You can use this to track connections being destroyed or to finalize + /// state that was not internally tracked by the handler. + #[inline] + fn connection_lost(&mut self, _: Self::Handler) {} +} + +impl<F, H> Factory for F +where + H: Handler, + F: FnMut(Sender) -> H, +{ + type Handler = H; + + fn connection_made(&mut self, out: Sender) -> H { + self(out) + } +} + +mod test { + #![allow(unused_imports, unused_variables, dead_code)] + use super::*; + use communication::{Command, Sender}; + use frame; + use handler::Handler; + use handshake::{Handshake, Request, Response}; + use message; + use mio; + use protocol::CloseCode; + use result::Result; + + #[derive(Debug, Eq, PartialEq)] + struct M; + impl Handler for M { + fn on_message(&mut self, _: message::Message) -> Result<()> { + println!("test"); + Ok(()) + } + + fn on_frame(&mut self, f: frame::Frame) -> Result<Option<frame::Frame>> { + Ok(None) + } + } + + #[test] + fn impl_factory() { + struct X; + + impl Factory for X { + type Handler = M; + fn connection_made(&mut self, _: Sender) -> M { + M + } + } + + let (chn, _) = mio::channel::sync_channel(42); + + let mut x = X; + let m = x.connection_made(Sender::new(mio::Token(0), chn, 0)); + assert_eq!(m, M); + } + + #[test] + fn closure_factory() { + let (chn, _) = mio::channel::sync_channel(42); + + let mut factory = |_| |_| Ok(()); + + factory.connection_made(Sender::new(mio::Token(0), chn, 0)); + } + + #[test] + fn connection_lost() { + struct X; + + impl Factory for X { + type Handler = M; + fn connection_made(&mut self, _: Sender) -> M { + M + } + fn connection_lost(&mut self, handler: M) { + assert_eq!(handler, M); + } + } + + let (chn, _) = mio::channel::sync_channel(42); + + let mut x = X; + let m = x.connection_made(Sender::new(mio::Token(0), chn, 0)); + x.connection_lost(m); + } +} diff --git a/third_party/rust/ws/src/frame.rs b/third_party/rust/ws/src/frame.rs new file mode 100644 index 0000000000..154816c7ad --- /dev/null +++ b/third_party/rust/ws/src/frame.rs @@ -0,0 +1,495 @@ +use std::default::Default; +use std::fmt; +use std::io::{Cursor, ErrorKind, Read, Write}; + +use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt}; +use rand; + +use protocol::{CloseCode, OpCode}; +use result::{Error, Kind, Result}; +use stream::TryReadBuf; + +fn apply_mask(buf: &mut [u8], mask: &[u8; 4]) { + let iter = buf.iter_mut().zip(mask.iter().cycle()); + for (byte, &key) in iter { + *byte ^= key + } +} + +/// A struct representing a WebSocket frame. +#[derive(Debug, Clone)] +pub struct Frame { + finished: bool, + rsv1: bool, + rsv2: bool, + rsv3: bool, + opcode: OpCode, + + mask: Option<[u8; 4]>, + + payload: Vec<u8>, +} + +impl Frame { + /// Get the length of the frame. + /// This is the length of the header + the length of the payload. + #[inline] + pub fn len(&self) -> usize { + let mut header_length = 2; + let payload_len = self.payload().len(); + if payload_len > 125 { + if payload_len <= u16::max_value() as usize { + header_length += 2; + } else { + header_length += 8; + } + } + + if self.is_masked() { + header_length += 4; + } + + header_length + payload_len + } + + /// Return `false`: a frame is never empty since it has a header. + #[inline] + pub fn is_empty(&self) -> bool { + false + } + + /// Test whether the frame is a final frame. + #[inline] + pub fn is_final(&self) -> bool { + self.finished + } + + /// Test whether the first reserved bit is set. + #[inline] + pub fn has_rsv1(&self) -> bool { + self.rsv1 + } + + /// Test whether the second reserved bit is set. + #[inline] + pub fn has_rsv2(&self) -> bool { + self.rsv2 + } + + /// Test whether the third reserved bit is set. + #[inline] + pub fn has_rsv3(&self) -> bool { + self.rsv3 + } + + /// Get the OpCode of the frame. + #[inline] + pub fn opcode(&self) -> OpCode { + self.opcode + } + + /// Test whether this is a control frame. + #[inline] + pub fn is_control(&self) -> bool { + self.opcode.is_control() + } + + /// Get a reference to the frame's payload. + #[inline] + pub fn payload(&self) -> &Vec<u8> { + &self.payload + } + + // Test whether the frame is masked. + #[doc(hidden)] + #[inline] + pub fn is_masked(&self) -> bool { + self.mask.is_some() + } + + // Get an optional reference to the frame's mask. + #[doc(hidden)] + #[allow(dead_code)] + #[inline] + pub fn mask(&self) -> Option<&[u8; 4]> { + self.mask.as_ref() + } + + /// Make this frame a final frame. + #[allow(dead_code)] + #[inline] + pub fn set_final(&mut self, is_final: bool) -> &mut Frame { + self.finished = is_final; + self + } + + /// Set the first reserved bit. + #[inline] + pub fn set_rsv1(&mut self, has_rsv1: bool) -> &mut Frame { + self.rsv1 = has_rsv1; + self + } + + /// Set the second reserved bit. + #[inline] + pub fn set_rsv2(&mut self, has_rsv2: bool) -> &mut Frame { + self.rsv2 = has_rsv2; + self + } + + /// Set the third reserved bit. + #[inline] + pub fn set_rsv3(&mut self, has_rsv3: bool) -> &mut Frame { + self.rsv3 = has_rsv3; + self + } + + /// Set the OpCode. + #[allow(dead_code)] + #[inline] + pub fn set_opcode(&mut self, opcode: OpCode) -> &mut Frame { + self.opcode = opcode; + self + } + + /// Edit the frame's payload. + #[allow(dead_code)] + #[inline] + pub fn payload_mut(&mut self) -> &mut Vec<u8> { + &mut self.payload + } + + // Generate a new mask for this frame. + // + // This method simply generates and stores the mask. It does not change the payload data. + // Instead, the payload data will be masked with the generated mask when the frame is sent + // to the other endpoint. + #[doc(hidden)] + #[inline] + pub fn set_mask(&mut self) -> &mut Frame { + self.mask = Some(rand::random()); + self + } + + // This method unmasks the payload and should only be called on frames that are actually + // masked. In other words, those frames that have just been received from a client endpoint. + #[doc(hidden)] + #[inline] + pub fn remove_mask(&mut self) -> &mut Frame { + self.mask + .take() + .map(|mask| apply_mask(&mut self.payload, &mask)); + self + } + + /// Consume the frame into its payload. + pub fn into_data(self) -> Vec<u8> { + self.payload + } + + /// Create a new data frame. + #[inline] + pub fn message(data: Vec<u8>, code: OpCode, finished: bool) -> Frame { + debug_assert!( + match code { + OpCode::Text | OpCode::Binary | OpCode::Continue => true, + _ => false, + }, + "Invalid opcode for data frame." + ); + + Frame { + finished, + opcode: code, + payload: data, + ..Frame::default() + } + } + + /// Create a new Pong control frame. + #[inline] + pub fn pong(data: Vec<u8>) -> Frame { + Frame { + opcode: OpCode::Pong, + payload: data, + ..Frame::default() + } + } + + /// Create a new Ping control frame. + #[inline] + pub fn ping(data: Vec<u8>) -> Frame { + Frame { + opcode: OpCode::Ping, + payload: data, + ..Frame::default() + } + } + + /// Create a new Close control frame. + #[inline] + pub fn close(code: CloseCode, reason: &str) -> Frame { + let payload = if let CloseCode::Empty = code { + Vec::new() + } else { + let u: u16 = code.into(); + let raw = [(u >> 8) as u8, u as u8]; + [&raw, reason.as_bytes()].concat() + }; + + Frame { + payload, + ..Frame::default() + } + } + + /// Parse the input stream into a frame. + pub fn parse(cursor: &mut Cursor<Vec<u8>>, max_payload_length: u64) -> Result<Option<Frame>> { + let size = cursor.get_ref().len() as u64 - cursor.position(); + let initial = cursor.position(); + trace!("Position in buffer {}", initial); + + let mut head = [0u8; 2]; + if cursor.read(&mut head)? != 2 { + cursor.set_position(initial); + return Ok(None); + } + + trace!("Parsed headers {:?}", head); + + let first = head[0]; + let second = head[1]; + trace!("First: {:b}", first); + trace!("Second: {:b}", second); + + let finished = first & 0x80 != 0; + + let rsv1 = first & 0x40 != 0; + let rsv2 = first & 0x20 != 0; + let rsv3 = first & 0x10 != 0; + + let opcode = OpCode::from(first & 0x0F); + trace!("Opcode: {:?}", opcode); + + let masked = second & 0x80 != 0; + trace!("Masked: {:?}", masked); + + let mut header_length = 2; + + let mut length = u64::from(second & 0x7F); + + if let Some(length_nbytes) = match length { + 126 => Some(2), + 127 => Some(8), + _ => None, + } { + match cursor.read_uint::<BigEndian>(length_nbytes) { + Err(ref err) if err.kind() == ErrorKind::UnexpectedEof => { + cursor.set_position(initial); + return Ok(None); + } + Err(err) => { + return Err(Error::from(err)); + } + Ok(read) => { + length = read; + } + }; + header_length += length_nbytes as u64; + } + trace!("Payload length: {}", length); + + if length > max_payload_length { + return Err(Error::new( + Kind::Protocol, + format!( + "Rejected frame with payload length exceeding defined max: {}.", + max_payload_length + ), + )); + } + + let mask = if masked { + let mut mask_bytes = [0u8; 4]; + if cursor.read(&mut mask_bytes)? != 4 { + cursor.set_position(initial); + return Ok(None); + } else { + header_length += 4; + Some(mask_bytes) + } + } else { + None + }; + + match length.checked_add(header_length) { + Some(l) if size < l => { + cursor.set_position(initial); + return Ok(None); + } + Some(_) => (), + None => return Ok(None), + }; + + let mut data = Vec::with_capacity(length as usize); + if length > 0 { + if let Some(read) = cursor.try_read_buf(&mut data)? { + debug_assert!(read == length as usize, "Read incorrect payload length!"); + } + } + + // Disallow bad opcode + if let OpCode::Bad = opcode { + return Err(Error::new( + Kind::Protocol, + format!("Encountered invalid opcode: {}", first & 0x0F), + )); + } + + // control frames must have length <= 125 + match opcode { + OpCode::Ping | OpCode::Pong if length > 125 => { + return Err(Error::new( + Kind::Protocol, + format!( + "Rejected WebSocket handshake.Received control frame with length: {}.", + length + ), + )) + } + OpCode::Close if length > 125 => { + debug!("Received close frame with payload length exceeding 125. Morphing to protocol close frame."); + return Ok(Some(Frame::close( + CloseCode::Protocol, + "Received close frame with payload length exceeding 125.", + ))); + } + _ => (), + } + + let frame = Frame { + finished, + rsv1, + rsv2, + rsv3, + opcode, + mask, + payload: data, + }; + + Ok(Some(frame)) + } + + /// Write a frame out to a buffer + pub fn format<W>(&mut self, w: &mut W) -> Result<()> + where + W: Write, + { + let mut one = 0u8; + let code: u8 = self.opcode.into(); + if self.is_final() { + one |= 0x80; + } + if self.has_rsv1() { + one |= 0x40; + } + if self.has_rsv2() { + one |= 0x20; + } + if self.has_rsv3() { + one |= 0x10; + } + one |= code; + + let mut two = 0u8; + if self.is_masked() { + two |= 0x80; + } + + match self.payload.len() { + len if len < 126 => { + two |= len as u8; + } + len if len <= 65535 => { + two |= 126; + } + _ => { + two |= 127; + } + } + w.write_all(&[one, two])?; + + if let Some(length_bytes) = match self.payload.len() { + len if len < 126 => None, + len if len <= 65535 => Some(2), + _ => Some(8), + } { + w.write_uint::<BigEndian>(self.payload.len() as u64, length_bytes)?; + } + + if self.is_masked() { + let mask = self.mask.take().unwrap(); + apply_mask(&mut self.payload, &mask); + w.write_all(&mask)?; + } + + w.write_all(&self.payload)?; + Ok(()) + } +} + +impl Default for Frame { + fn default() -> Frame { + Frame { + finished: true, + rsv1: false, + rsv2: false, + rsv3: false, + opcode: OpCode::Close, + mask: None, + payload: Vec::new(), + } + } +} + +impl fmt::Display for Frame { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!( + f, + " +<FRAME> +final: {} +reserved: {} {} {} +opcode: {} +length: {} +payload length: {} +payload: 0x{} + ", + self.finished, + self.rsv1, + self.rsv2, + self.rsv3, + self.opcode, + // self.mask.map(|mask| format!("{:?}", mask)).unwrap_or("NONE".into()), + self.len(), + self.payload.len(), + self.payload + .iter() + .map(|byte| format!("{:x}", byte)) + .collect::<String>() + ) + } +} + +mod test { + #![allow(unused_imports, unused_variables, dead_code)] + use super::*; + use protocol::OpCode; + + #[test] + fn display_frame() { + let f = Frame::message("hi there".into(), OpCode::Text, true); + let view = format!("{}", f); + view.contains("payload:"); + } +} diff --git a/third_party/rust/ws/src/handler.rs b/third_party/rust/ws/src/handler.rs new file mode 100644 index 0000000000..7c1ef21b4b --- /dev/null +++ b/third_party/rust/ws/src/handler.rs @@ -0,0 +1,423 @@ +use log::Level::Error as ErrorLevel; +#[cfg(feature = "nativetls")] +use native_tls::{TlsConnector, TlsStream as SslStream}; +#[cfg(feature = "ssl")] +use openssl::ssl::{SslConnector, SslMethod, SslStream}; +use url; + +use frame::Frame; +use handshake::{Handshake, Request, Response}; +use message::Message; +use protocol::CloseCode; +use result::{Error, Kind, Result}; +use util::{Timeout, Token}; + +#[cfg(any(feature = "ssl", feature = "nativetls"))] +use util::TcpStream; + +/// The core trait of this library. +/// Implementing this trait provides the business logic of the WebSocket application. +pub trait Handler { + // general + + /// Called when a request to shutdown all connections has been received. + #[inline] + fn on_shutdown(&mut self) { + debug!("Handler received WebSocket shutdown request."); + } + + // WebSocket events + + /// Called when the WebSocket handshake is successful and the connection is open for sending + /// and receiving messages. + fn on_open(&mut self, shake: Handshake) -> Result<()> { + if let Some(addr) = shake.remote_addr()? { + debug!("Connection with {} now open", addr); + } + Ok(()) + } + + /// Called on incoming messages. + fn on_message(&mut self, msg: Message) -> Result<()> { + debug!("Received message {:?}", msg); + Ok(()) + } + + /// Called any time this endpoint receives a close control frame. + /// This may be because the other endpoint is initiating a closing handshake, + /// or it may be the other endpoint confirming the handshake initiated by this endpoint. + fn on_close(&mut self, code: CloseCode, reason: &str) { + debug!("Connection closing due to ({:?}) {}", code, reason); + } + + /// Called when an error occurs on the WebSocket. + fn on_error(&mut self, err: Error) { + // Ignore connection reset errors by default, but allow library clients to see them by + // overriding this method if they want + if let Kind::Io(ref err) = err.kind { + if let Some(104) = err.raw_os_error() { + return; + } + } + + error!("{:?}", err); + if !log_enabled!(ErrorLevel) { + println!( + "Encountered an error: {}\nEnable a logger to see more information.", + err + ); + } + } + + // handshake events + + /// A method for handling the low-level workings of the request portion of the WebSocket + /// handshake. + /// + /// Implementors should select a WebSocket protocol and extensions where they are supported. + /// + /// Implementors can inspect the Request and must return a Response or an error + /// indicating that the handshake failed. The default implementation provides conformance with + /// the WebSocket protocol, and implementors should use the `Response::from_request` method and + /// then modify the resulting response as necessary in order to maintain conformance. + /// + /// This method will not be called when the handler represents a client endpoint. Use + /// `build_request` to provide an initial handshake request. + /// + /// # Examples + /// + /// ```ignore + /// let mut res = try!(Response::from_request(req)); + /// if try!(req.extensions()).iter().find(|&&ext| ext.contains("myextension-name")).is_some() { + /// res.add_extension("myextension-name") + /// } + /// Ok(res) + /// ``` + #[inline] + fn on_request(&mut self, req: &Request) -> Result<Response> { + debug!("Handler received request:\n{}", req); + Response::from_request(req) + } + + /// A method for handling the low-level workings of the response portion of the WebSocket + /// handshake. + /// + /// Implementors can inspect the Response and choose to fail the connection by + /// returning an error. This method will not be called when the handler represents a server + /// endpoint. The response should indicate which WebSocket protocol and extensions the server + /// has agreed to if any. + #[inline] + fn on_response(&mut self, res: &Response) -> Result<()> { + debug!("Handler received response:\n{}", res); + Ok(()) + } + + // timeout events + + /// Called when a timeout is triggered. + /// + /// This method will be called when the eventloop encounters a timeout on the specified + /// token. To schedule a timeout with your specific token use the `Sender::timeout` method. + /// + /// # Examples + /// + /// ```ignore + /// const GRATI: Token = Token(1); + /// + /// ... Handler + /// + /// fn on_open(&mut self, _: Handshake) -> Result<()> { + /// // schedule a timeout to send a gratuitous pong every 5 seconds + /// self.ws.timeout(5_000, GRATI) + /// } + /// + /// fn on_timeout(&mut self, event: Token) -> Result<()> { + /// if event == GRATI { + /// // send gratuitous pong + /// try!(self.ws.pong(vec![])) + /// // reschedule the timeout + /// self.ws.timeout(5_000, GRATI) + /// } else { + /// Err(Error::new(ErrorKind::Internal, "Invalid timeout token encountered!")) + /// } + /// } + /// ``` + #[inline] + fn on_timeout(&mut self, event: Token) -> Result<()> { + debug!("Handler received timeout token: {:?}", event); + Ok(()) + } + + /// Called when a timeout has been scheduled on the eventloop. + /// + /// This method is the hook for obtaining a Timeout object that may be used to cancel a + /// timeout. This is a noop by default. + /// + /// # Examples + /// + /// ```ignore + /// const PING: Token = Token(1); + /// const EXPIRE: Token = Token(2); + /// + /// ... Handler + /// + /// fn on_open(&mut self, _: Handshake) -> Result<()> { + /// // schedule a timeout to send a ping every 5 seconds + /// try!(self.ws.timeout(5_000, PING)); + /// // schedule a timeout to close the connection if there is no activity for 30 seconds + /// self.ws.timeout(30_000, EXPIRE) + /// } + /// + /// fn on_timeout(&mut self, event: Token) -> Result<()> { + /// match event { + /// PING => { + /// self.ws.ping(vec![]); + /// self.ws.timeout(5_000, PING) + /// } + /// EXPIRE => self.ws.close(CloseCode::Away), + /// _ => Err(Error::new(ErrorKind::Internal, "Invalid timeout token encountered!")), + /// } + /// } + /// + /// fn on_new_timeout(&mut self, event: Token, timeout: Timeout) -> Result<()> { + /// if event == EXPIRE { + /// if let Some(t) = self.timeout.take() { + /// try!(self.ws.cancel(t)) + /// } + /// self.timeout = Some(timeout) + /// } + /// Ok(()) + /// } + /// + /// fn on_frame(&mut self, frame: Frame) -> Result<Option<Frame>> { + /// // some activity has occurred, let's reset the expiration + /// try!(self.ws.timeout(30_000, EXPIRE)); + /// Ok(Some(frame)) + /// } + /// ``` + #[inline] + fn on_new_timeout(&mut self, _: Token, _: Timeout) -> Result<()> { + // default implementation discards the timeout handle + Ok(()) + } + + // frame events + + /// A method for handling incoming frames. + /// + /// This method provides very low-level access to the details of the WebSocket protocol. It may + /// be necessary to implement this method in order to provide a particular extension, but + /// incorrect implementation may cause the other endpoint to fail the connection. + /// + /// Returning `Ok(None)` will cause the connection to forget about a particular frame. This is + /// useful if you want ot filter out a frame or if you don't want any of the default handler + /// methods to run. + /// + /// By default this method simply ensures that no reserved bits are set. + #[inline] + fn on_frame(&mut self, frame: Frame) -> Result<Option<Frame>> { + debug!("Handler received: {}", frame); + // default implementation doesn't allow for reserved bits to be set + if frame.has_rsv1() || frame.has_rsv2() || frame.has_rsv3() { + Err(Error::new( + Kind::Protocol, + "Encountered frame with reserved bits set.", + )) + } else { + Ok(Some(frame)) + } + } + + /// A method for handling outgoing frames. + /// + /// This method provides very low-level access to the details of the WebSocket protocol. It may + /// be necessary to implement this method in order to provide a particular extension, but + /// incorrect implementation may cause the other endpoint to fail the connection. + /// + /// Returning `Ok(None)` will cause the connection to forget about a particular frame, meaning + /// that it will not be sent. You can use this approach to merge multiple frames into a single + /// frame before sending the message. + /// + /// For messages, this method will be called with a single complete, final frame before any + /// fragmentation is performed. Automatic fragmentation will be performed on the returned + /// frame, if any, based on the `fragment_size` setting. + /// + /// By default this method simply ensures that no reserved bits are set. + #[inline] + fn on_send_frame(&mut self, frame: Frame) -> Result<Option<Frame>> { + trace!("Handler will send: {}", frame); + // default implementation doesn't allow for reserved bits to be set + if frame.has_rsv1() || frame.has_rsv2() || frame.has_rsv3() { + Err(Error::new( + Kind::Protocol, + "Encountered frame with reserved bits set.", + )) + } else { + Ok(Some(frame)) + } + } + + // constructors + + /// A method for creating the initial handshake request for WebSocket clients. + /// + /// The default implementation provides conformance with the WebSocket protocol, but this + /// method may be overridden. In order to facilitate conformance, + /// implementors should use the `Request::from_url` method and then modify the resulting + /// request as necessary. + /// + /// Implementors should indicate any available WebSocket extensions here. + /// + /// # Examples + /// ```ignore + /// let mut req = try!(Request::from_url(url)); + /// req.add_extension("permessage-deflate; client_max_window_bits"); + /// Ok(req) + /// ``` + #[inline] + fn build_request(&mut self, url: &url::Url) -> Result<Request> { + trace!("Handler is building request to {}.", url); + Request::from_url(url) + } + + /// A method for wrapping a client TcpStream with Ssl Authentication machinery + /// + /// Override this method to customize how the connection is encrypted. By default + /// this will use the Server Name Indication extension in conformance with RFC6455. + #[inline] + #[cfg(feature = "ssl")] + fn upgrade_ssl_client( + &mut self, + stream: TcpStream, + url: &url::Url, + ) -> Result<SslStream<TcpStream>> { + let domain = url.domain().ok_or(Error::new( + Kind::Protocol, + format!("Unable to parse domain from {}. Needed for SSL.", url), + ))?; + let connector = SslConnector::builder(SslMethod::tls()) + .map_err(|e| { + Error::new( + Kind::Internal, + format!("Failed to upgrade client to SSL: {}", e), + ) + })? + .build(); + connector.connect(domain, stream).map_err(Error::from) + } + + #[inline] + #[cfg(feature = "nativetls")] + fn upgrade_ssl_client( + &mut self, + stream: TcpStream, + url: &url::Url, + ) -> Result<SslStream<TcpStream>> { + let domain = url.domain().ok_or(Error::new( + Kind::Protocol, + format!("Unable to parse domain from {}. Needed for SSL.", url), + ))?; + + let connector = TlsConnector::new().map_err(|e| { + Error::new( + Kind::Internal, + format!("Failed to upgrade client to SSL: {}", e), + ) + })?; + + connector.connect(domain, stream).map_err(Error::from) + } + /// A method for wrapping a server TcpStream with Ssl Authentication machinery + /// + /// Override this method to customize how the connection is encrypted. By default + /// this method is not implemented. + #[inline] + #[cfg(any(feature = "ssl", feature = "nativetls"))] + fn upgrade_ssl_server(&mut self, _: TcpStream) -> Result<SslStream<TcpStream>> { + unimplemented!() + } +} + +impl<F> Handler for F +where + F: Fn(Message) -> Result<()>, +{ + fn on_message(&mut self, msg: Message) -> Result<()> { + self(msg) + } +} + +mod test { + #![allow(unused_imports, unused_variables, dead_code)] + use super::*; + use frame; + use handshake::{Handshake, Request, Response}; + use message; + use mio; + use protocol::CloseCode; + use result::Result; + use url; + + #[derive(Debug, Eq, PartialEq)] + struct M; + impl Handler for M { + fn on_message(&mut self, _: message::Message) -> Result<()> { + println!("test"); + Ok(()) + } + + fn on_frame(&mut self, f: frame::Frame) -> Result<Option<frame::Frame>> { + Ok(None) + } + } + + #[test] + fn handler() { + struct H; + + impl Handler for H { + fn on_open(&mut self, shake: Handshake) -> Result<()> { + assert!(shake.request.key().is_ok()); + assert!(shake.response.key().is_ok()); + Ok(()) + } + + fn on_message(&mut self, msg: message::Message) -> Result<()> { + Ok(assert_eq!( + msg, + message::Message::Text(String::from("testme")) + )) + } + + fn on_close(&mut self, code: CloseCode, _: &str) { + assert_eq!(code, CloseCode::Normal) + } + } + + let mut h = H; + let url = url::Url::parse("wss://127.0.0.1:3012").unwrap(); + let req = Request::from_url(&url).unwrap(); + let res = Response::from_request(&req).unwrap(); + h.on_open(Handshake { + request: req, + response: res, + peer_addr: None, + local_addr: None, + }).unwrap(); + h.on_message(message::Message::Text("testme".to_owned())) + .unwrap(); + h.on_close(CloseCode::Normal, ""); + } + + #[test] + fn closure_handler() { + let mut close = |msg| { + assert_eq!(msg, message::Message::Binary(vec![1, 2, 3])); + Ok(()) + }; + + close + .on_message(message::Message::Binary(vec![1, 2, 3])) + .unwrap(); + } +} diff --git a/third_party/rust/ws/src/handshake.rs b/third_party/rust/ws/src/handshake.rs new file mode 100644 index 0000000000..b7520bde56 --- /dev/null +++ b/third_party/rust/ws/src/handshake.rs @@ -0,0 +1,740 @@ +use std::fmt; +use std::io::Write; +use std::net::SocketAddr; +use std::str::from_utf8; + +use httparse; +use rand; +use sha1::{self, Digest}; +use url; + +use result::{Error, Kind, Result}; + +static WS_GUID: &'static str = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"; +static BASE64: &'static [u8] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"; +const MAX_HEADERS: usize = 124; + +fn generate_key() -> String { + let key: [u8; 16] = rand::random(); + encode_base64(&key) +} + +pub fn hash_key(key: &[u8]) -> String { + let mut hasher = sha1::Sha1::new(); + + hasher.input(key); + hasher.input(WS_GUID.as_bytes()); + + encode_base64(&hasher.result()) +} + +// This code is based on rustc_serialize base64 STANDARD +fn encode_base64(data: &[u8]) -> String { + let len = data.len(); + let mod_len = len % 3; + + let mut encoded = vec![b'='; (len + 2) / 3 * 4]; + { + let mut in_iter = data[..len - mod_len].iter().map(|&c| u32::from(c)); + let mut out_iter = encoded.iter_mut(); + + let enc = |val| BASE64[val as usize]; + let mut write = |val| *out_iter.next().unwrap() = val; + + while let (Some(one), Some(two), Some(three)) = + (in_iter.next(), in_iter.next(), in_iter.next()) + { + let g24 = one << 16 | two << 8 | three; + write(enc((g24 >> 18) & 63)); + write(enc((g24 >> 12) & 63)); + write(enc((g24 >> 6) & 63)); + write(enc(g24 & 63)); + } + + match mod_len { + 1 => { + let pad = (u32::from(data[len - 1])) << 16; + write(enc((pad >> 18) & 63)); + write(enc((pad >> 12) & 63)); + } + 2 => { + let pad = (u32::from(data[len - 2])) << 16 | (u32::from(data[len - 1])) << 8; + write(enc((pad >> 18) & 63)); + write(enc((pad >> 12) & 63)); + write(enc((pad >> 6) & 63)); + } + _ => (), + } + } + + String::from_utf8(encoded).unwrap() +} + +/// A struct representing the two halves of the WebSocket handshake. +#[derive(Debug)] +pub struct Handshake { + /// The HTTP request sent to begin the handshake. + pub request: Request, + /// The HTTP response from the server confirming the handshake. + pub response: Response, + /// The socket address of the other endpoint. This address may + /// be an intermediary such as a proxy server. + pub peer_addr: Option<SocketAddr>, + /// The socket address of this endpoint. + pub local_addr: Option<SocketAddr>, +} + +impl Handshake { + /// Get the IP address of the remote connection. + /// + /// This is the preferred method of obtaining the client's IP address. + /// It will attempt to retrieve the most likely IP address based on request + /// headers, falling back to the address of the peer. + /// + /// # Note + /// This assumes that the peer is a client. If you are implementing a + /// WebSocket client and want to obtain the address of the server, use + /// `Handshake::peer_addr` instead. + /// + /// This method does not ensure that the address is a valid IP address. + #[allow(dead_code)] + pub fn remote_addr(&self) -> Result<Option<String>> { + Ok(self.request.client_addr()?.map(String::from).or_else(|| { + if let Some(addr) = self.peer_addr { + Some(addr.ip().to_string()) + } else { + None + } + })) + } +} + +/// The handshake request. +#[derive(Debug)] +pub struct Request { + path: String, + method: String, + headers: Vec<(String, Vec<u8>)>, +} + +impl Request { + /// Get the value of the first instance of an HTTP header. + pub fn header(&self, header: &str) -> Option<&Vec<u8>> { + self.headers + .iter() + .find(|&&(ref key, _)| key.to_lowercase() == header.to_lowercase()) + .map(|&(_, ref val)| val) + } + + /// Edit the value of the first instance of an HTTP header. + pub fn header_mut(&mut self, header: &str) -> Option<&mut Vec<u8>> { + self.headers + .iter_mut() + .find(|&&mut (ref key, _)| key.to_lowercase() == header.to_lowercase()) + .map(|&mut (_, ref mut val)| val) + } + + /// Access the request headers. + #[allow(dead_code)] + #[inline] + pub fn headers(&self) -> &Vec<(String, Vec<u8>)> { + &self.headers + } + + /// Edit the request headers. + #[allow(dead_code)] + #[inline] + pub fn headers_mut(&mut self) -> &mut Vec<(String, Vec<u8>)> { + &mut self.headers + } + + /// Get the origin of the request if it comes from a browser. + #[allow(dead_code)] + pub fn origin(&self) -> Result<Option<&str>> { + if let Some(origin) = self.header("origin") { + Ok(Some(from_utf8(origin)?)) + } else { + Ok(None) + } + } + + /// Get the unhashed WebSocket key sent in the request. + pub fn key(&self) -> Result<&Vec<u8>> { + self.header("sec-websocket-key") + .ok_or_else(|| Error::new(Kind::Protocol, "Unable to parse WebSocket key.")) + } + + /// Get the hashed WebSocket key from this request. + pub fn hashed_key(&self) -> Result<String> { + Ok(hash_key(self.key()?)) + } + + /// Get the WebSocket protocol version from the request (should be 13). + #[allow(dead_code)] + pub fn version(&self) -> Result<&str> { + if let Some(version) = self.header("sec-websocket-version") { + from_utf8(version).map_err(Error::from) + } else { + Err(Error::new( + Kind::Protocol, + "The Sec-WebSocket-Version header is missing.", + )) + } + } + + /// Get the request method. + #[inline] + pub fn method(&self) -> &str { + &self.method + } + + /// Get the path of the request. + #[allow(dead_code)] + #[inline] + pub fn resource(&self) -> &str { + &self.path + } + + /// Get the possible protocols for the WebSocket connection. + #[allow(dead_code)] + pub fn protocols(&self) -> Result<Vec<&str>> { + if let Some(protos) = self.header("sec-websocket-protocol") { + Ok(from_utf8(protos)? + .split(',') + .map(|proto| proto.trim()) + .collect()) + } else { + Ok(Vec::new()) + } + } + + /// Add a possible protocol to this request. + /// This may result in duplicate protocols listed. + #[allow(dead_code)] + pub fn add_protocol(&mut self, protocol: &str) { + if let Some(protos) = self.header_mut("sec-websocket-protocol") { + protos.push(b","[0]); + protos.extend(protocol.as_bytes()); + return; + } + self.headers_mut() + .push(("Sec-WebSocket-Protocol".into(), protocol.into())) + } + + /// Remove a possible protocol from this request. + #[allow(dead_code)] + pub fn remove_protocol(&mut self, protocol: &str) { + if let Some(protos) = self.header_mut("sec-websocket-protocol") { + let mut new_protos = Vec::with_capacity(protos.len()); + + if let Ok(protos_str) = from_utf8(protos) { + new_protos = protos_str + .split(',') + .filter(|proto| proto.trim() == protocol) + .collect::<Vec<&str>>() + .join(",") + .into(); + } + if new_protos.len() < protos.len() { + *protos = new_protos + } + } + } + + /// Get the possible extensions for the WebSocket connection. + #[allow(dead_code)] + pub fn extensions(&self) -> Result<Vec<&str>> { + if let Some(exts) = self.header("sec-websocket-extensions") { + Ok(from_utf8(exts)?.split(',').map(|ext| ext.trim()).collect()) + } else { + Ok(Vec::new()) + } + } + + /// Add a possible extension to this request. + /// This may result in duplicate extensions listed. Also, the order of extensions + /// indicates preference, so if the preference matters, consider using the + /// `Sec-WebSocket-Protocol` header directly. + #[allow(dead_code)] + pub fn add_extension(&mut self, ext: &str) { + if let Some(exts) = self.header_mut("sec-websocket-extensions") { + exts.push(b","[0]); + exts.extend(ext.as_bytes()); + return; + } + self.headers_mut() + .push(("Sec-WebSocket-Extensions".into(), ext.into())) + } + + /// Remove a possible extension from this request. + /// This will remove all configurations of the extension. + #[allow(dead_code)] + pub fn remove_extension(&mut self, ext: &str) { + if let Some(exts) = self.header_mut("sec-websocket-extensions") { + let mut new_exts = Vec::with_capacity(exts.len()); + + if let Ok(exts_str) = from_utf8(exts) { + new_exts = exts_str + .split(',') + .filter(|e| e.trim().starts_with(ext)) + .collect::<Vec<&str>>() + .join(",") + .into(); + } + if new_exts.len() < exts.len() { + *exts = new_exts + } + } + } + + /// Get the IP address of the client. + /// + /// This method will attempt to retrieve the most likely IP address of the requester + /// in the following manner: + /// + /// If the `X-Forwarded-For` header exists, this method will return the left most + /// address in the list. + /// + /// If the [Forwarded HTTP Header Field](https://tools.ietf.org/html/rfc7239) exits, + /// this method will return the left most address indicated by the `for` parameter, + /// if it exists. + /// + /// # Note + /// This method does not ensure that the address is a valid IP address. + #[allow(dead_code)] + pub fn client_addr(&self) -> Result<Option<&str>> { + if let Some(x_forward) = self.header("x-forwarded-for") { + return Ok(from_utf8(x_forward)?.split(',').next()); + } + + // We only care about the first forwarded header, so header is ok + if let Some(forward) = self.header("forwarded") { + if let Some(_for) = from_utf8(forward)? + .split(';') + .find(|f| f.trim().starts_with("for")) + { + if let Some(_for_eq) = _for.trim().split(',').next() { + let mut it = _for_eq.split('='); + it.next(); + return Ok(it.next()); + } + } + } + Ok(None) + } + + /// Attempt to parse an HTTP request from a buffer. If the buffer does not contain a complete + /// request, this will return `Ok(None)`. + pub fn parse(buf: &[u8]) -> Result<Option<Request>> { + let mut headers = [httparse::EMPTY_HEADER; MAX_HEADERS]; + let mut req = httparse::Request::new(&mut headers); + let parsed = req.parse(buf)?; + if !parsed.is_partial() { + Ok(Some(Request { + path: req.path.unwrap().into(), + method: req.method.unwrap().into(), + headers: req.headers + .iter() + .map(|h| (h.name.into(), h.value.into())) + .collect(), + })) + } else { + Ok(None) + } + } + + /// Construct a new WebSocket handshake HTTP request from a url. + pub fn from_url(url: &url::Url) -> Result<Request> { + let query = if let Some(q) = url.query() { + format!("?{}", q) + } else { + "".into() + }; + + let mut headers = vec![ + ("Connection".into(), "Upgrade".into()), + ( + "Host".into(), + format!( + "{}:{}", + url.host_str().ok_or_else(|| Error::new( + Kind::Internal, + "No host passed for WebSocket connection.", + ))?, + url.port_or_known_default().unwrap_or(80) + ).into(), + ), + ("Sec-WebSocket-Version".into(), "13".into()), + ("Sec-WebSocket-Key".into(), generate_key().into()), + ("Upgrade".into(), "websocket".into()), + ]; + + if url.password().is_some() || url.username() != "" { + let basic = encode_base64(format!("{}:{}", url.username(), url.password().unwrap_or("")).as_bytes()); + headers.push(("Authorization".into(), format!("Basic {}", basic).into())) + } + + let req = Request { + path: format!("{}{}", url.path(), query), + method: "GET".to_owned(), + headers: headers, + }; + + debug!("Built request from URL:\n{}", req); + + Ok(req) + } + + /// Write a request out to a buffer + pub fn format<W>(&self, w: &mut W) -> Result<()> + where + W: Write, + { + write!(w, "{} {} HTTP/1.1\r\n", self.method, self.path)?; + for &(ref key, ref val) in &self.headers { + write!(w, "{}: ", key)?; + w.write_all(val)?; + write!(w, "\r\n")?; + } + write!(w, "\r\n")?; + Ok(()) + } +} + +impl fmt::Display for Request { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + let mut s = Vec::with_capacity(2048); + self.format(&mut s).map_err(|err| { + error!("{:?}", err); + fmt::Error + })?; + write!( + f, + "{}", + from_utf8(&s).map_err(|err| { + error!("Unable to format request as utf8: {:?}", err); + fmt::Error + })? + ) + } +} + +/// The handshake response. +#[derive(Debug)] +pub struct Response { + status: u16, + reason: String, + headers: Vec<(String, Vec<u8>)>, + body: Vec<u8>, +} + +impl Response { + // TODO: resolve the overlap with Request + + /// Construct a generic HTTP response with a body. + pub fn new<R>(status: u16, reason: R, body: Vec<u8>) -> Response + where + R: Into<String>, + { + Response { + status, + reason: reason.into(), + headers: vec![("Content-Length".into(), body.len().to_string().into())], + body, + } + } + + /// Get the response body. + #[inline] + pub fn body(&self) -> &[u8] { + &self.body + } + + /// Get the value of the first instance of an HTTP header. + fn header(&self, header: &str) -> Option<&Vec<u8>> { + self.headers + .iter() + .find(|&&(ref key, _)| key.to_lowercase() == header.to_lowercase()) + .map(|&(_, ref val)| val) + } + /// Edit the value of the first instance of an HTTP header. + pub fn header_mut(&mut self, header: &str) -> Option<&mut Vec<u8>> { + self.headers + .iter_mut() + .find(|&&mut (ref key, _)| key.to_lowercase() == header.to_lowercase()) + .map(|&mut (_, ref mut val)| val) + } + + /// Access the request headers. + #[allow(dead_code)] + #[inline] + pub fn headers(&self) -> &Vec<(String, Vec<u8>)> { + &self.headers + } + + /// Edit the request headers. + #[allow(dead_code)] + #[inline] + pub fn headers_mut(&mut self) -> &mut Vec<(String, Vec<u8>)> { + &mut self.headers + } + + /// Get the HTTP status code. + #[allow(dead_code)] + #[inline] + pub fn status(&self) -> u16 { + self.status + } + + /// Set the HTTP status code. + #[allow(dead_code)] + #[inline] + pub fn set_status(&mut self, status: u16) { + self.status = status + } + + /// Get the HTTP status reason. + #[allow(dead_code)] + #[inline] + pub fn reason(&self) -> &str { + &self.reason + } + + /// Set the HTTP status reason. + #[allow(dead_code)] + #[inline] + pub fn set_reason<R>(&mut self, reason: R) + where + R: Into<String>, + { + self.reason = reason.into() + } + + /// Get the hashed WebSocket key. + pub fn key(&self) -> Result<&Vec<u8>> { + self.header("sec-websocket-accept") + .ok_or_else(|| Error::new(Kind::Protocol, "Unable to parse WebSocket key.")) + } + + /// Get the protocol that the server has decided to use. + #[allow(dead_code)] + pub fn protocol(&self) -> Result<Option<&str>> { + if let Some(proto) = self.header("sec-websocket-protocol") { + Ok(Some(from_utf8(proto)?)) + } else { + Ok(None) + } + } + + /// Set the protocol that the server has decided to use. + #[allow(dead_code)] + pub fn set_protocol(&mut self, protocol: &str) { + if let Some(proto) = self.header_mut("sec-websocket-protocol") { + *proto = protocol.into(); + return; + } + self.headers_mut() + .push(("Sec-WebSocket-Protocol".into(), protocol.into())) + } + + /// Get the extensions that the server has decided to use. If these are unacceptable, it is + /// appropriate to send an Extension close code. + #[allow(dead_code)] + pub fn extensions(&self) -> Result<Vec<&str>> { + if let Some(exts) = self.header("sec-websocket-extensions") { + Ok(from_utf8(exts)? + .split(',') + .map(|proto| proto.trim()) + .collect()) + } else { + Ok(Vec::new()) + } + } + + /// Add an accepted extension to this response. + /// This may result in duplicate extensions listed. + #[allow(dead_code)] + pub fn add_extension(&mut self, ext: &str) { + if let Some(exts) = self.header_mut("sec-websocket-extensions") { + exts.push(b","[0]); + exts.extend(ext.as_bytes()); + return; + } + self.headers_mut() + .push(("Sec-WebSocket-Extensions".into(), ext.into())) + } + + /// Remove an accepted extension from this response. + /// This will remove all configurations of the extension. + #[allow(dead_code)] + pub fn remove_extension(&mut self, ext: &str) { + if let Some(exts) = self.header_mut("sec-websocket-extensions") { + let mut new_exts = Vec::with_capacity(exts.len()); + + if let Ok(exts_str) = from_utf8(exts) { + new_exts = exts_str + .split(',') + .filter(|e| e.trim().starts_with(ext)) + .collect::<Vec<&str>>() + .join(",") + .into(); + } + if new_exts.len() < exts.len() { + *exts = new_exts + } + } + } + + /// Attempt to parse an HTTP response from a buffer. If the buffer does not contain a complete + /// response, thiw will return `Ok(None)`. + pub fn parse(buf: &[u8]) -> Result<Option<Response>> { + let mut headers = [httparse::EMPTY_HEADER; MAX_HEADERS]; + let mut res = httparse::Response::new(&mut headers); + + let parsed = res.parse(buf)?; + if !parsed.is_partial() { + Ok(Some(Response { + status: res.code.unwrap(), + reason: res.reason.unwrap().into(), + headers: res.headers + .iter() + .map(|h| (h.name.into(), h.value.into())) + .collect(), + body: Vec::new(), + })) + } else { + Ok(None) + } + } + + /// Construct a new WebSocket handshake HTTP response from a request. + /// This will create a response that ignores protocols and extensions. Edit this response to + /// accept a protocol and extensions as necessary. + pub fn from_request(req: &Request) -> Result<Response> { + let res = Response { + status: 101, + reason: "Switching Protocols".into(), + headers: vec![ + ("Connection".into(), "Upgrade".into()), + ("Sec-WebSocket-Accept".into(), req.hashed_key()?.into()), + ("Upgrade".into(), "websocket".into()), + ], + body: Vec::new(), + }; + + debug!("Built response from request:\n{}", res); + Ok(res) + } + + /// Write a response out to a buffer + pub fn format<W>(&self, w: &mut W) -> Result<()> + where + W: Write, + { + write!(w, "HTTP/1.1 {} {}\r\n", self.status, self.reason)?; + for &(ref key, ref val) in &self.headers { + write!(w, "{}: ", key)?; + w.write_all(val)?; + write!(w, "\r\n")?; + } + write!(w, "\r\n")?; + w.write_all(&self.body)?; + Ok(()) + } +} + +impl fmt::Display for Response { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + let mut s = Vec::with_capacity(2048); + self.format(&mut s).map_err(|err| { + error!("{:?}", err); + fmt::Error + })?; + write!( + f, + "{}", + from_utf8(&s).map_err(|err| { + error!("Unable to format response as utf8: {:?}", err); + fmt::Error + })? + ) + } +} + +mod test { + #![allow(unused_imports, unused_variables, dead_code)] + use super::*; + use std::io::Write; + use std::net::SocketAddr; + use std::str::FromStr; + + #[test] + fn remote_addr() { + let mut buf = Vec::with_capacity(2048); + write!( + &mut buf, + "GET / HTTP/1.1\r\n\ + Connection: Upgrade\r\n\ + Upgrade: websocket\r\n\ + Sec-WebSocket-Version: 13\r\n\ + Sec-WebSocket-Key: q16eN37NCfVwUChPvBdk4g==\r\n\r\n" + ).unwrap(); + + let req = Request::parse(&buf).unwrap().unwrap(); + let res = Response::from_request(&req).unwrap(); + let shake = Handshake { + request: req, + response: res, + peer_addr: Some(SocketAddr::from_str("127.0.0.1:8888").unwrap()), + local_addr: None, + }; + assert_eq!(shake.remote_addr().unwrap().unwrap(), "127.0.0.1"); + } + + #[test] + fn remote_addr_x_forwarded_for() { + let mut buf = Vec::with_capacity(2048); + write!( + &mut buf, + "GET / HTTP/1.1\r\n\ + Connection: Upgrade\r\n\ + Upgrade: websocket\r\n\ + X-Forwarded-For: 192.168.1.1, 192.168.1.2, 192.168.1.3\r\n\ + Sec-WebSocket-Version: 13\r\n\ + Sec-WebSocket-Key: q16eN37NCfVwUChPvBdk4g==\r\n\r\n" + ).unwrap(); + + let req = Request::parse(&buf).unwrap().unwrap(); + let res = Response::from_request(&req).unwrap(); + let shake = Handshake { + request: req, + response: res, + peer_addr: None, + local_addr: None, + }; + assert_eq!(shake.remote_addr().unwrap().unwrap(), "192.168.1.1"); + } + + #[test] + fn remote_addr_forwarded() { + let mut buf = Vec::with_capacity(2048); + write!( + &mut buf, + "GET / HTTP/1.1\r\n\ + Connection: Upgrade\r\n\ + Upgrade: websocket\r\n\ + Forwarded: by=192.168.1.1; for=192.0.2.43, for=\"[2001:db8:cafe::17]\", for=unknown\r\n\ + Sec-WebSocket-Version: 13\r\n\ + Sec-WebSocket-Key: q16eN37NCfVwUChPvBdk4g==\r\n\r\n") + .unwrap(); + let req = Request::parse(&buf).unwrap().unwrap(); + let res = Response::from_request(&req).unwrap(); + let shake = Handshake { + request: req, + response: res, + peer_addr: None, + local_addr: None, + }; + assert_eq!(shake.remote_addr().unwrap().unwrap(), "192.0.2.43"); + } +} diff --git a/third_party/rust/ws/src/io.rs b/third_party/rust/ws/src/io.rs new file mode 100644 index 0000000000..7739cdc472 --- /dev/null +++ b/third_party/rust/ws/src/io.rs @@ -0,0 +1,985 @@ +use std::borrow::Borrow; +use std::io::{Error as IoError, ErrorKind}; +use std::net::{SocketAddr, ToSocketAddrs}; +use std::time::Duration; +use std::usize; + +use mio; +use mio::tcp::{TcpListener, TcpStream}; +use mio::{Poll, PollOpt, Ready, Token}; +use mio_extras; + +use url::Url; + +#[cfg(feature = "native_tls")] +use native_tls::Error as SslError; + +use super::Settings; +use communication::{Command, Sender, Signal}; +use connection::Connection; +use factory::Factory; +use slab::Slab; +use result::{Error, Kind, Result}; + + +const QUEUE: Token = Token(usize::MAX - 3); +const TIMER: Token = Token(usize::MAX - 4); +pub const ALL: Token = Token(usize::MAX - 5); +const SYSTEM: Token = Token(usize::MAX - 6); + +type Conn<F> = Connection<<F as Factory>::Handler>; + +const MAX_EVENTS: usize = 1024; +const MESSAGES_PER_TICK: usize = 256; +const TIMER_TICK_MILLIS: u64 = 100; +const TIMER_WHEEL_SIZE: usize = 1024; +const TIMER_CAPACITY: usize = 65_536; + +#[cfg(not(windows))] +const CONNECTION_REFUSED: i32 = 111; +#[cfg(windows)] +const CONNECTION_REFUSED: i32 = 61; + +fn url_to_addrs(url: &Url) -> Result<Vec<SocketAddr>> { + let host = url.host_str(); + if host.is_none() || (url.scheme() != "ws" && url.scheme() != "wss") { + return Err(Error::new( + Kind::Internal, + format!("Not a valid websocket url: {}", url), + )); + } + let host = host.unwrap(); + + let port = url.port_or_known_default().unwrap_or(80); + let mut addrs = (&host[..], port) + .to_socket_addrs()? + .collect::<Vec<SocketAddr>>(); + addrs.dedup(); + Ok(addrs) +} + +enum State { + Active, + Inactive, +} + +impl State { + fn is_active(&self) -> bool { + match *self { + State::Active => true, + State::Inactive => false, + } + } +} + +#[derive(Debug, Clone, Copy)] +pub struct Timeout { + connection: Token, + event: Token, +} + +pub struct Handler<F> +where + F: Factory, +{ + listener: Option<TcpListener>, + connections: Slab<Conn<F>>, + factory: F, + settings: Settings, + state: State, + queue_tx: mio::channel::SyncSender<Command>, + queue_rx: mio::channel::Receiver<Command>, + timer: mio_extras::timer::Timer<Timeout>, + next_connection_id: u32, +} + +impl<F> Handler<F> +where + F: Factory, +{ + pub fn new(factory: F, settings: Settings) -> Handler<F> { + let (tx, rx) = mio::channel::sync_channel(settings.max_connections * settings.queue_size); + let timer = mio_extras::timer::Builder::default() + .tick_duration(Duration::from_millis(TIMER_TICK_MILLIS)) + .num_slots(TIMER_WHEEL_SIZE) + .capacity(TIMER_CAPACITY) + .build(); + Handler { + listener: None, + connections: Slab::with_capacity(settings.max_connections), + factory, + settings, + state: State::Inactive, + queue_tx: tx, + queue_rx: rx, + timer, + next_connection_id: 0, + } + } + + pub fn sender(&self) -> Sender { + Sender::new(ALL, self.queue_tx.clone(), 0) + } + + pub fn listen(&mut self, poll: &mut Poll, addr: &SocketAddr) -> Result<&mut Handler<F>> { + debug_assert!( + self.listener.is_none(), + "Attempted to listen for connections from two addresses on the same websocket." + ); + + let tcp = TcpListener::bind(addr)?; + // TODO: consider net2 in order to set reuse_addr + poll.register(&tcp, ALL, Ready::readable(), PollOpt::level())?; + self.listener = Some(tcp); + Ok(self) + } + + pub fn local_addr(&self) -> ::std::io::Result<SocketAddr> { + if let Some(ref listener) = self.listener { + listener.local_addr() + } else { + Err(IoError::new(ErrorKind::NotFound, "Not a listening socket")) + } + } + + #[cfg(any(feature = "ssl", feature = "nativetls"))] + pub fn connect(&mut self, poll: &mut Poll, url: Url) -> Result<()> { + let settings = self.settings; + + let (tok, addresses) = { + let (tok, entry, connection_id, handler) = + if self.connections.len() < settings.max_connections { + let entry = self.connections.vacant_entry(); + let tok = Token(entry.key()); + let connection_id = self.next_connection_id; + self.next_connection_id = self.next_connection_id.wrapping_add(1); + ( + tok, + entry, + connection_id, + self.factory.client_connected(Sender::new( + tok, + self.queue_tx.clone(), + connection_id, + )), + ) + } else { + return Err(Error::new( + Kind::Capacity, + "Unable to add another connection to the event loop.", + )); + }; + + let mut addresses = match url_to_addrs(&url) { + Ok(addresses) => addresses, + Err(err) => { + self.factory.connection_lost(handler); + return Err(err); + } + }; + + loop { + if let Some(addr) = addresses.pop() { + if let Ok(sock) = TcpStream::connect(&addr) { + if settings.tcp_nodelay { + sock.set_nodelay(true)? + } + addresses.push(addr); // Replace the first addr in case ssl fails and we fallback + entry.insert(Connection::new(tok, sock, handler, settings, connection_id)); + break; + } + } else { + self.factory.connection_lost(handler); + return Err(Error::new( + Kind::Internal, + format!("Unable to obtain any socket address for {}", url), + )); + } + } + + (tok, addresses) + }; + + let will_encrypt = url.scheme() == "wss"; + + if let Err(error) = self.connections[tok.into()].as_client(url, addresses) { + let handler = self.connections.remove(tok.into()).consume(); + self.factory.connection_lost(handler); + return Err(error); + } + + if will_encrypt { + while let Err(ssl_error) = self.connections[tok.into()].encrypt() { + match ssl_error.kind { + #[cfg(feature = "ssl")] + Kind::Ssl(ref inner_ssl_error) => { + if let Some(io_error) = inner_ssl_error.io_error() { + if let Some(errno) = io_error.raw_os_error() { + if errno == CONNECTION_REFUSED { + if let Err(reset_error) = self.connections[tok.into()].reset() { + trace!( + "Encountered error while trying to reset connection: {:?}", + reset_error + ); + } else { + continue; + } + } + } + } + } + #[cfg(feature = "nativetls")] + Kind::Ssl(_) => { + if let Err(reset_error) = self.connections[tok.into()].reset() { + trace!( + "Encountered error while trying to reset connection: {:?}", + reset_error + ); + } else { + continue; + } + } + _ => (), + } + self.connections[tok.into()].error(ssl_error); + // Allow socket to be registered anyway to await hangup + break; + } + } + + poll.register( + self.connections[tok.into()].socket(), + self.connections[tok.into()].token(), + self.connections[tok.into()].events(), + PollOpt::edge() | PollOpt::oneshot(), + ).map_err(Error::from) + .or_else(|err| { + error!( + "Encountered error while trying to build WebSocket connection: {}", + err + ); + let handler = self.connections.remove(tok.into()).consume(); + self.factory.connection_lost(handler); + Err(err) + }) + } + + #[cfg(not(any(feature = "ssl", feature = "nativetls")))] + pub fn connect(&mut self, poll: &mut Poll, url: Url) -> Result<()> { + let settings = self.settings; + + let (tok, addresses) = { + let (tok, entry, connection_id, handler) = + if self.connections.len() < settings.max_connections { + let entry = self.connections.vacant_entry(); + let tok = Token(entry.key()); + let connection_id = self.next_connection_id; + self.next_connection_id = self.next_connection_id.wrapping_add(1); + ( + tok, + entry, + connection_id, + self.factory.client_connected(Sender::new( + tok, + self.queue_tx.clone(), + connection_id, + )), + ) + } else { + return Err(Error::new( + Kind::Capacity, + "Unable to add another connection to the event loop.", + )); + }; + + let mut addresses = match url_to_addrs(&url) { + Ok(addresses) => addresses, + Err(err) => { + self.factory.connection_lost(handler); + return Err(err); + } + }; + + loop { + if let Some(addr) = addresses.pop() { + if let Ok(sock) = TcpStream::connect(&addr) { + if settings.tcp_nodelay { + sock.set_nodelay(true)? + } + entry.insert(Connection::new(tok, sock, handler, settings, connection_id)); + break; + } + } else { + self.factory.connection_lost(handler); + return Err(Error::new( + Kind::Internal, + format!("Unable to obtain any socket address for {}", url), + )); + } + } + + (tok, addresses) + }; + + if url.scheme() == "wss" { + let error = Error::new( + Kind::Protocol, + "The ssl feature is not enabled. Please enable it to use wss urls.", + ); + let handler = self.connections.remove(tok.into()).consume(); + self.factory.connection_lost(handler); + return Err(error); + } + + if let Err(error) = self.connections[tok.into()].as_client(url, addresses) { + let handler = self.connections.remove(tok.into()).consume(); + self.factory.connection_lost(handler); + return Err(error); + } + + poll.register( + self.connections[tok.into()].socket(), + self.connections[tok.into()].token(), + self.connections[tok.into()].events(), + PollOpt::edge() | PollOpt::oneshot(), + ).map_err(Error::from) + .or_else(|err| { + error!( + "Encountered error while trying to build WebSocket connection: {}", + err + ); + let handler = self.connections.remove(tok.into()).consume(); + self.factory.connection_lost(handler); + Err(err) + }) + } + + #[cfg(any(feature = "ssl", feature = "nativetls"))] + pub fn accept(&mut self, poll: &mut Poll, sock: TcpStream) -> Result<()> { + let factory = &mut self.factory; + let settings = self.settings; + + if settings.tcp_nodelay { + sock.set_nodelay(true)? + } + + let tok = { + if self.connections.len() < settings.max_connections { + let entry = self.connections.vacant_entry(); + let tok = Token(entry.key()); + let connection_id = self.next_connection_id; + self.next_connection_id = self.next_connection_id.wrapping_add(1); + let handler = factory.server_connected(Sender::new( + tok, + self.queue_tx.clone(), + connection_id, + )); + entry.insert(Connection::new(tok, sock, handler, settings, connection_id)); + tok + } else { + return Err(Error::new( + Kind::Capacity, + "Unable to add another connection to the event loop.", + )); + } + }; + + let conn = &mut self.connections[tok.into()]; + + conn.as_server()?; + if settings.encrypt_server { + conn.encrypt()? + } + + poll.register( + conn.socket(), + conn.token(), + conn.events(), + PollOpt::edge() | PollOpt::oneshot(), + ).map_err(Error::from) + .or_else(|err| { + error!( + "Encountered error while trying to build WebSocket connection: {}", + err + ); + conn.error(err); + if settings.panic_on_new_connection { + panic!("Encountered error while trying to build WebSocket connection."); + } + Ok(()) + }) + } + + #[cfg(not(any(feature = "ssl", feature = "nativetls")))] + pub fn accept(&mut self, poll: &mut Poll, sock: TcpStream) -> Result<()> { + let factory = &mut self.factory; + let settings = self.settings; + + if settings.tcp_nodelay { + sock.set_nodelay(true)? + } + + let tok = { + if self.connections.len() < settings.max_connections { + let entry = self.connections.vacant_entry(); + let tok = Token(entry.key()); + let connection_id = self.next_connection_id; + self.next_connection_id = self.next_connection_id.wrapping_add(1); + let handler = factory.server_connected(Sender::new( + tok, + self.queue_tx.clone(), + connection_id, + )); + entry.insert(Connection::new(tok, sock, handler, settings, connection_id)); + tok + } else { + return Err(Error::new( + Kind::Capacity, + "Unable to add another connection to the event loop.", + )); + } + }; + + let conn = &mut self.connections[tok.into()]; + + conn.as_server()?; + if settings.encrypt_server { + return Err(Error::new( + Kind::Protocol, + "The ssl feature is not enabled. Please enable it to use wss urls.", + )); + } + + poll.register( + conn.socket(), + conn.token(), + conn.events(), + PollOpt::edge() | PollOpt::oneshot(), + ).map_err(Error::from) + .or_else(|err| { + error!( + "Encountered error while trying to build WebSocket connection: {}", + err + ); + conn.error(err); + if settings.panic_on_new_connection { + panic!("Encountered error while trying to build WebSocket connection."); + } + Ok(()) + }) + } + + pub fn run(&mut self, poll: &mut Poll) -> Result<()> { + trace!("Running event loop"); + poll.register( + &self.queue_rx, + QUEUE, + Ready::readable(), + PollOpt::edge() | PollOpt::oneshot(), + )?; + poll.register(&self.timer, TIMER, Ready::readable(), PollOpt::edge())?; + + self.state = State::Active; + let result = self.event_loop(poll); + self.state = State::Inactive; + + result + .and(poll.deregister(&self.timer).map_err(Error::from)) + .and(poll.deregister(&self.queue_rx).map_err(Error::from)) + } + + #[inline] + fn event_loop(&mut self, poll: &mut Poll) -> Result<()> { + let mut events = mio::Events::with_capacity(MAX_EVENTS); + while self.state.is_active() { + trace!("Waiting for event"); + let nevents = match poll.poll(&mut events, None) { + Ok(nevents) => nevents, + Err(err) => { + if err.kind() == ErrorKind::Interrupted { + if self.settings.shutdown_on_interrupt { + error!("Websocket shutting down for interrupt."); + self.state = State::Inactive; + } else { + error!("Websocket received interrupt."); + } + 0 + } else { + return Err(Error::from(err)); + } + } + }; + trace!("Processing {} events", nevents); + + for i in 0..nevents { + let evt = events.get(i).unwrap(); + self.handle_event(poll, evt.token(), evt.kind()); + } + + self.check_count(); + } + Ok(()) + } + + #[inline] + fn schedule(&self, poll: &mut Poll, conn: &Conn<F>) -> Result<()> { + trace!( + "Scheduling connection to {} as {:?}", + conn.socket() + .peer_addr() + .map(|addr| addr.to_string()) + .unwrap_or_else(|_| "UNKNOWN".into()), + conn.events() + ); + poll.reregister( + conn.socket(), + conn.token(), + conn.events(), + PollOpt::edge() | PollOpt::oneshot(), + )?; + Ok(()) + } + + fn shutdown(&mut self) { + debug!("Received shutdown signal. WebSocket is attempting to shut down."); + for (_, conn) in self.connections.iter_mut() { + conn.shutdown(); + } + self.factory.on_shutdown(); + self.state = State::Inactive; + if self.settings.panic_on_shutdown { + panic!("Panicking on shutdown as per setting.") + } + } + + #[inline] + fn check_active(&mut self, poll: &mut Poll, active: bool, token: Token) { + // NOTE: Closing state only applies after a ws connection was successfully + // established. It's possible that we may go inactive while in a connecting + // state if the handshake fails. + if !active { + if let Ok(addr) = self.connections[token.into()].socket().peer_addr() { + debug!("WebSocket connection to {} disconnected.", addr); + } else { + trace!("WebSocket connection to token={:?} disconnected.", token); + } + let handler = self.connections.remove(token.into()).consume(); + self.factory.connection_lost(handler); + } else { + self.schedule(poll, &self.connections[token.into()]) + .or_else(|err| { + // This will be an io error, so disconnect will already be called + self.connections[token.into()].error(err); + let handler = self.connections.remove(token.into()).consume(); + self.factory.connection_lost(handler); + Ok::<(), Error>(()) + }) + .unwrap() + } + } + + #[inline] + fn is_client(&self) -> bool { + self.listener.is_none() + } + + #[inline] + fn check_count(&mut self) { + trace!("Active connections {:?}", self.connections.len()); + if self.connections.is_empty() { + if !self.state.is_active() { + debug!("Shutting down websocket server."); + } else if self.is_client() { + debug!("Shutting down websocket client."); + self.factory.on_shutdown(); + self.state = State::Inactive; + } + } + } + + fn handle_event(&mut self, poll: &mut Poll, token: Token, events: Ready) { + match token { + SYSTEM => { + debug_assert!(false, "System token used for io event. This is a bug!"); + error!("System token used for io event. This is a bug!"); + } + ALL => { + if events.is_readable() { + match self.listener + .as_ref() + .expect("No listener provided for server websocket connections") + .accept() + { + Ok((sock, addr)) => { + info!("Accepted a new tcp connection from {}.", addr); + if let Err(err) = self.accept(poll, sock) { + error!("Unable to build WebSocket connection {:?}", err); + if self.settings.panic_on_new_connection { + panic!("Unable to build WebSocket connection {:?}", err); + } + } + } + Err(err) => error!( + "Encountered an error {:?} while accepting tcp connection.", + err + ), + } + } + } + TIMER => while let Some(t) = self.timer.poll() { + self.handle_timeout(poll, t); + }, + QUEUE => { + for _ in 0..MESSAGES_PER_TICK { + match self.queue_rx.try_recv() { + Ok(cmd) => self.handle_queue(poll, cmd), + _ => break, + } + } + let _ = poll.reregister( + &self.queue_rx, + QUEUE, + Ready::readable(), + PollOpt::edge() | PollOpt::oneshot(), + ); + } + _ => { + let active = { + let conn_events = self.connections[token.into()].events(); + + if (events & conn_events).is_readable() { + if let Err(err) = self.connections[token.into()].read() { + trace!("Encountered error while reading: {}", err); + if let Kind::Io(ref err) = err.kind { + if let Some(errno) = err.raw_os_error() { + if errno == CONNECTION_REFUSED { + match self.connections[token.into()].reset() { + Ok(_) => { + poll.register( + self.connections[token.into()].socket(), + self.connections[token.into()].token(), + self.connections[token.into()].events(), + PollOpt::edge() | PollOpt::oneshot(), + ).or_else(|err| { + self.connections[token.into()] + .error(Error::from(err)); + let handler = self.connections + .remove(token.into()) + .consume(); + self.factory.connection_lost(handler); + Ok::<(), Error>(()) + }) + .unwrap(); + return; + } + Err(err) => { + trace!("Encountered error while trying to reset connection: {:?}", err); + } + } + } + } + } + // This will trigger disconnect if the connection is open + self.connections[token.into()].error(err) + } + } + + let conn_events = self.connections[token.into()].events(); + + if (events & conn_events).is_writable() { + if let Err(err) = self.connections[token.into()].write() { + trace!("Encountered error while writing: {}", err); + if let Kind::Io(ref err) = err.kind { + if let Some(errno) = err.raw_os_error() { + if errno == CONNECTION_REFUSED { + match self.connections[token.into()].reset() { + Ok(_) => { + poll.register( + self.connections[token.into()].socket(), + self.connections[token.into()].token(), + self.connections[token.into()].events(), + PollOpt::edge() | PollOpt::oneshot(), + ).or_else(|err| { + self.connections[token.into()] + .error(Error::from(err)); + let handler = self.connections + .remove(token.into()) + .consume(); + self.factory.connection_lost(handler); + Ok::<(), Error>(()) + }) + .unwrap(); + return; + } + Err(err) => { + trace!("Encountered error while trying to reset connection: {:?}", err); + } + } + } + } + } + // This will trigger disconnect if the connection is open + self.connections[token.into()].error(err) + } + } + + // connection events may have changed + self.connections[token.into()].events().is_readable() + || self.connections[token.into()].events().is_writable() + }; + + self.check_active(poll, active, token) + } + } + } + + fn handle_queue(&mut self, poll: &mut Poll, cmd: Command) { + match cmd.token() { + SYSTEM => { + // Scaffolding for system events such as internal timeouts + } + ALL => { + let mut dead = Vec::with_capacity(self.connections.len()); + + match cmd.into_signal() { + Signal::Message(msg) => { + trace!("Broadcasting message: {:?}", msg); + for (_, conn) in self.connections.iter_mut() { + if let Err(err) = conn.send_message(msg.clone()) { + dead.push((conn.token(), err)) + } + } + } + Signal::Close(code, reason) => { + trace!("Broadcasting close: {:?} - {}", code, reason); + for (_, conn) in self.connections.iter_mut() { + if let Err(err) = conn.send_close(code, reason.borrow()) { + dead.push((conn.token(), err)) + } + } + } + Signal::Ping(data) => { + trace!("Broadcasting ping"); + for (_, conn) in self.connections.iter_mut() { + if let Err(err) = conn.send_ping(data.clone()) { + dead.push((conn.token(), err)) + } + } + } + Signal::Pong(data) => { + trace!("Broadcasting pong"); + for (_, conn) in self.connections.iter_mut() { + if let Err(err) = conn.send_pong(data.clone()) { + dead.push((conn.token(), err)) + } + } + } + Signal::Connect(url) => { + if let Err(err) = self.connect(poll, url.clone()) { + if self.settings.panic_on_new_connection { + panic!("Unable to establish connection to {}: {:?}", url, err); + } + error!("Unable to establish connection to {}: {:?}", url, err); + } + return; + } + Signal::Shutdown => self.shutdown(), + Signal::Timeout { + delay, + token: event, + } => { + let timeout = self.timer.set_timeout( + Duration::from_millis(delay), + Timeout { + connection: ALL, + event, + }, + ); + for (_, conn) in self.connections.iter_mut() { + if let Err(err) = conn.new_timeout(event, timeout.clone()) { + conn.error(err); + } + } + return; + } + Signal::Cancel(timeout) => { + self.timer.cancel_timeout(&timeout); + return; + } + } + + for (_, conn) in self.connections.iter() { + if let Err(err) = self.schedule(poll, conn) { + dead.push((conn.token(), err)) + } + } + for (token, err) in dead { + // note the same connection may be called twice + self.connections[token.into()].error(err) + } + } + token => { + let connection_id = cmd.connection_id(); + match cmd.into_signal() { + Signal::Message(msg) => { + if let Some(conn) = self.connections.get_mut(token.into()) { + if conn.connection_id() == connection_id { + if let Err(err) = conn.send_message(msg) { + conn.error(err) + } + } else { + trace!("Connection disconnected while a message was waiting in the queue.") + } + } else { + trace!( + "Connection disconnected while a message was waiting in the queue." + ) + } + } + Signal::Close(code, reason) => { + if let Some(conn) = self.connections.get_mut(token.into()) { + if conn.connection_id() == connection_id { + if let Err(err) = conn.send_close(code, reason) { + conn.error(err) + } + } else { + trace!("Connection disconnected while close signal was waiting in the queue.") + } + } else { + trace!("Connection disconnected while close signal was waiting in the queue.") + } + } + Signal::Ping(data) => { + if let Some(conn) = self.connections.get_mut(token.into()) { + if conn.connection_id() == connection_id { + if let Err(err) = conn.send_ping(data) { + conn.error(err) + } + } else { + trace!("Connection disconnected while ping signal was waiting in the queue.") + } + } else { + trace!("Connection disconnected while ping signal was waiting in the queue.") + } + } + Signal::Pong(data) => { + if let Some(conn) = self.connections.get_mut(token.into()) { + if conn.connection_id() == connection_id { + if let Err(err) = conn.send_pong(data) { + conn.error(err) + } + } else { + trace!("Connection disconnected while pong signal was waiting in the queue.") + } + } else { + trace!("Connection disconnected while pong signal was waiting in the queue.") + } + } + Signal::Connect(url) => { + if let Err(err) = self.connect(poll, url.clone()) { + if let Some(conn) = self.connections.get_mut(token.into()) { + conn.error(err) + } else { + if self.settings.panic_on_new_connection { + panic!("Unable to establish connection to {}: {:?}", url, err); + } + error!("Unable to establish connection to {}: {:?}", url, err); + } + } + return; + } + Signal::Shutdown => self.shutdown(), + Signal::Timeout { + delay, + token: event, + } => { + let timeout = self.timer.set_timeout( + Duration::from_millis(delay), + Timeout { + connection: token, + event, + }, + ); + if let Some(conn) = self.connections.get_mut(token.into()) { + if let Err(err) = conn.new_timeout(event, timeout) { + conn.error(err) + } + } else { + trace!("Connection disconnected while pong signal was waiting in the queue.") + } + return; + } + Signal::Cancel(timeout) => { + self.timer.cancel_timeout(&timeout); + return; + } + } + + if self.connections.get(token.into()).is_some() { + if let Err(err) = self.schedule(poll, &self.connections[token.into()]) { + self.connections[token.into()].error(err) + } + } + } + } + } + + fn handle_timeout(&mut self, poll: &mut Poll, Timeout { connection, event }: Timeout) { + let active = { + if let Some(conn) = self.connections.get_mut(connection.into()) { + if let Err(err) = conn.timeout_triggered(event) { + conn.error(err) + } + + conn.events().is_readable() || conn.events().is_writable() + } else { + trace!("Connection disconnected while timeout was waiting."); + return; + } + }; + self.check_active(poll, active, connection); + } +} + +mod test { + #![allow(unused_imports, unused_variables, dead_code)] + use std::str::FromStr; + + use url::Url; + + use super::url_to_addrs; + use super::*; + use result::{Error, Kind}; + + #[test] + fn test_url_to_addrs() { + let ws_url = Url::from_str("ws://example.com?query=me").unwrap(); + let wss_url = Url::from_str("wss://example.com/suburl#fragment").unwrap(); + let bad_url = Url::from_str("http://howdy.bad.com").unwrap(); + let no_resolve = Url::from_str("ws://bad.elucitrans.com").unwrap(); + + assert!(url_to_addrs(&ws_url).is_ok()); + assert!(url_to_addrs(&ws_url).unwrap().len() > 0); + assert!(url_to_addrs(&wss_url).is_ok()); + assert!(url_to_addrs(&wss_url).unwrap().len() > 0); + + match url_to_addrs(&bad_url) { + Ok(_) => panic!("url_to_addrs accepts http urls."), + Err(Error { + kind: Kind::Internal, + details: _, + }) => (), // pass + err => panic!("{:?}", err), + } + + match url_to_addrs(&no_resolve) { + Ok(_) => panic!("url_to_addrs creates addresses for non-existent domains."), + Err(Error { + kind: Kind::Io(_), + details: _, + }) => (), // pass + err => panic!("{:?}", err), + } + } + +} diff --git a/third_party/rust/ws/src/lib.rs b/third_party/rust/ws/src/lib.rs new file mode 100644 index 0000000000..ea9f1a54e4 --- /dev/null +++ b/third_party/rust/ws/src/lib.rs @@ -0,0 +1,391 @@ +//! Lightweight, event-driven WebSockets for Rust. +#![allow(deprecated)] +#![deny(missing_copy_implementations, trivial_casts, trivial_numeric_casts, unstable_features, + unused_import_braces)] + +extern crate byteorder; +extern crate bytes; +extern crate httparse; +extern crate mio; +extern crate mio_extras; +#[cfg(feature = "ssl")] +extern crate openssl; +#[cfg(feature = "nativetls")] +extern crate native_tls; +extern crate rand; +extern crate sha1; +extern crate slab; +extern crate url; +#[macro_use] +extern crate log; + +mod communication; +mod connection; +mod factory; +mod frame; +mod handler; +mod handshake; +mod io; +mod message; +mod protocol; +mod result; +mod stream; + +#[cfg(feature = "permessage-deflate")] +pub mod deflate; + +pub mod util; + +pub use factory::Factory; +pub use handler::Handler; + +pub use communication::Sender; +pub use frame::Frame; +pub use handshake::{Handshake, Request, Response}; +pub use message::Message; +pub use protocol::{CloseCode, OpCode}; +pub use result::Kind as ErrorKind; +pub use result::{Error, Result}; + +use std::borrow::Borrow; +use std::default::Default; +use std::fmt; +use std::net::{SocketAddr, ToSocketAddrs}; + +use mio::Poll; + +/// A utility function for setting up a WebSocket server. +/// +/// # Safety +/// +/// This function blocks until the event loop finishes running. Avoid calling this method within +/// another WebSocket handler. +/// +/// # Examples +/// +/// ```no_run +/// use ws::listen; +/// +/// listen("127.0.0.1:3012", |out| { +/// move |msg| { +/// out.send(msg) +/// } +/// }).unwrap() +/// ``` +/// +pub fn listen<A, F, H>(addr: A, factory: F) -> Result<()> +where + A: ToSocketAddrs + fmt::Debug, + F: FnMut(Sender) -> H, + H: Handler, +{ + let ws = WebSocket::new(factory)?; + ws.listen(addr)?; + Ok(()) +} + +/// A utility function for setting up a WebSocket client. +/// +/// # Safety +/// +/// This function blocks until the event loop finishes running. Avoid calling this method within +/// another WebSocket handler. If you need to establish a connection from inside of a handler, +/// use the `connect` method on the Sender. +/// +/// # Examples +/// +/// ```no_run +/// use ws::{connect, CloseCode}; +/// +/// connect("ws://127.0.0.1:3012", |out| { +/// out.send("Hello WebSocket").unwrap(); +/// +/// move |msg| { +/// println!("Got message: {}", msg); +/// out.close(CloseCode::Normal) +/// } +/// }).unwrap() +/// ``` +/// +pub fn connect<U, F, H>(url: U, factory: F) -> Result<()> +where + U: Borrow<str>, + F: FnMut(Sender) -> H, + H: Handler, +{ + let mut ws = WebSocket::new(factory)?; + let parsed = url::Url::parse(url.borrow()).map_err(|err| { + Error::new( + ErrorKind::Internal, + format!("Unable to parse {} as url due to {:?}", url.borrow(), err), + ) + })?; + ws.connect(parsed)?; + ws.run()?; + Ok(()) +} + +/// WebSocket settings +#[derive(Debug, Clone, Copy)] +pub struct Settings { + /// The maximum number of connections that this WebSocket will support. + /// The default setting is low and should be increased when expecting more + /// connections because this is a hard limit and no new connections beyond + /// this limit can be made until an old connection is dropped. + /// Default: 100 + pub max_connections: usize, + /// The number of events anticipated per connection. The event loop queue size will + /// be `queue_size` * `max_connections`. In order to avoid an overflow error, + /// `queue_size` * `max_connections` must be less than or equal to `usize::max_value()`. + /// The queue is shared between connections, which means that a connection may schedule + /// more events than `queue_size` provided that another connection is using less than + /// `queue_size`. However, if the queue is maxed out a Queue error will occur. + /// Default: 5 + pub queue_size: usize, + /// Whether to panic when unable to establish a new TCP connection. + /// Default: false + pub panic_on_new_connection: bool, + /// Whether to panic when a shutdown of the WebSocket is requested. + /// Default: false + pub panic_on_shutdown: bool, + /// The maximum number of fragments the connection can handle without reallocating. + /// Default: 10 + pub fragments_capacity: usize, + /// Whether to reallocate when `fragments_capacity` is reached. If this is false, + /// a Capacity error will be triggered instead. + /// Default: true + pub fragments_grow: bool, + /// The maximum length of outgoing frames. Messages longer than this will be fragmented. + /// Default: 65,535 + pub fragment_size: usize, + /// The maximum length of acceptable incoming frames. Messages longer than this will be rejected. + /// Default: unlimited + pub max_fragment_size: usize, + /// The size of the incoming buffer. A larger buffer uses more memory but will allow for fewer + /// reallocations. + /// Default: 2048 + pub in_buffer_capacity: usize, + /// Whether to reallocate the incoming buffer when `in_buffer_capacity` is reached. If this is + /// false, a Capacity error will be triggered instead. + /// Default: true + pub in_buffer_grow: bool, + /// The size of the outgoing buffer. A larger buffer uses more memory but will allow for fewer + /// reallocations. + /// Default: 2048 + pub out_buffer_capacity: usize, + /// Whether to reallocate the incoming buffer when `out_buffer_capacity` is reached. If this is + /// false, a Capacity error will be triggered instead. + /// Default: true + pub out_buffer_grow: bool, + /// Whether to panic when an Internal error is encountered. Internal errors should generally + /// not occur, so this setting defaults to true as a debug measure, whereas production + /// applications should consider setting it to false. + /// Default: true + pub panic_on_internal: bool, + /// Whether to panic when a Capacity error is encountered. + /// Default: false + pub panic_on_capacity: bool, + /// Whether to panic when a Protocol error is encountered. + /// Default: false + pub panic_on_protocol: bool, + /// Whether to panic when an Encoding error is encountered. + /// Default: false + pub panic_on_encoding: bool, + /// Whether to panic when a Queue error is encountered. + /// Default: false + pub panic_on_queue: bool, + /// Whether to panic when an Io error is encountered. + /// Default: false + pub panic_on_io: bool, + /// Whether to panic when a Timer error is encountered. + /// Default: false + pub panic_on_timeout: bool, + /// Whether to shutdown the eventloop when an interrupt is received. + /// Default: true + pub shutdown_on_interrupt: bool, + /// The WebSocket protocol requires frames sent from client endpoints to be masked as a + /// security and sanity precaution. Enforcing this requirement, which may be removed at some + /// point may cause incompatibilities. If you need the extra security, set this to true. + /// Default: false + pub masking_strict: bool, + /// The WebSocket protocol requires clients to verify the key returned by a server to ensure + /// that the server and all intermediaries can perform the protocol. Verifying the key will + /// consume processing time and other resources with the benefit that we can fail the + /// connection early. The default in WS-RS is to accept any key from the server and instead + /// fail late if a protocol error occurs. Change this setting to enable key verification. + /// Default: false + pub key_strict: bool, + /// The WebSocket protocol requires clients to perform an opening handshake using the HTTP + /// GET method for the request. However, since only WebSockets are supported on the connection, + /// verifying the method of handshake requests is not always necessary. To enforce the + /// requirement that handshakes begin with a GET method, set this to true. + /// Default: false + pub method_strict: bool, + /// Indicate whether server connections should use ssl encryption when accepting connections. + /// Setting this to true means that clients should use the `wss` scheme to connect to this + /// server. Note that using this flag will in general necessitate overriding the + /// `Handler::upgrade_ssl_server` method in order to provide the details of the ssl context. It may be + /// simpler for most users to use a reverse proxy such as nginx to provide server side + /// encryption. + /// + /// Default: false + pub encrypt_server: bool, + /// Disables Nagle's algorithm. + /// Usually tcp socket tries to accumulate packets to send them all together (every 200ms). + /// When enabled socket will try to send packet as fast as possible. + /// + /// Default: false + pub tcp_nodelay: bool, +} + +impl Default for Settings { + fn default() -> Settings { + Settings { + max_connections: 100, + queue_size: 5, + panic_on_new_connection: false, + panic_on_shutdown: false, + fragments_capacity: 10, + fragments_grow: true, + fragment_size: u16::max_value() as usize, + max_fragment_size: usize::max_value(), + in_buffer_capacity: 2048, + in_buffer_grow: true, + out_buffer_capacity: 2048, + out_buffer_grow: true, + panic_on_internal: true, + panic_on_capacity: false, + panic_on_protocol: false, + panic_on_encoding: false, + panic_on_queue: false, + panic_on_io: false, + panic_on_timeout: false, + shutdown_on_interrupt: true, + masking_strict: false, + key_strict: false, + method_strict: false, + encrypt_server: false, + tcp_nodelay: false, + } + } +} + +/// The WebSocket struct. A WebSocket can support multiple incoming and outgoing connections. +pub struct WebSocket<F> +where + F: Factory, +{ + poll: Poll, + handler: io::Handler<F>, +} + +impl<F> WebSocket<F> +where + F: Factory, +{ + /// Create a new WebSocket using the given Factory to create handlers. + pub fn new(factory: F) -> Result<WebSocket<F>> { + Builder::new().build(factory) + } + + /// Consume the WebSocket and bind to the specified address. + /// If the `addr_spec` yields multiple addresses this will return after the + /// first successful bind. `local_addr` can be called to determine which + /// address it ended up binding to. + /// After the server is successfully bound you should start it using `run`. + pub fn bind<A>(mut self, addr_spec: A) -> Result<WebSocket<F>> + where + A: ToSocketAddrs, + { + let mut last_error = Error::new(ErrorKind::Internal, "No address given"); + + for addr in addr_spec.to_socket_addrs()? { + if let Err(e) = self.handler.listen(&mut self.poll, &addr) { + error!("Unable to listen on {}", addr); + last_error = e; + } else { + let actual_addr = self.handler.local_addr().unwrap_or(addr); + info!("Listening for new connections on {}.", actual_addr); + return Ok(self); + } + } + + Err(last_error) + } + + /// Consume the WebSocket and listen for new connections on the specified address. + /// + /// # Safety + /// + /// This method will block until the event loop finishes running. + pub fn listen<A>(self, addr_spec: A) -> Result<WebSocket<F>> + where + A: ToSocketAddrs, + { + self.bind(addr_spec).and_then(|server| server.run()) + } + + /// Queue an outgoing connection on this WebSocket. This method may be called multiple times, + /// but the actual connections will not be established until `run` is called. + pub fn connect(&mut self, url: url::Url) -> Result<&mut WebSocket<F>> { + let sender = self.handler.sender(); + info!("Queuing connection to {}", url); + sender.connect(url)?; + Ok(self) + } + + /// Run the WebSocket. This will run the encapsulated event loop blocking the calling thread until + /// the WebSocket is shutdown. + pub fn run(mut self) -> Result<WebSocket<F>> { + self.handler.run(&mut self.poll)?; + Ok(self) + } + + /// Get a Sender that can be used to send messages on all connections. + /// Calling `send` on this Sender is equivalent to calling `broadcast`. + /// Calling `shutdown` on this Sender will shutdown the WebSocket even if no connections have + /// been established. + #[inline] + pub fn broadcaster(&self) -> Sender { + self.handler.sender() + } + + /// Get the local socket address this socket is bound to. Will return an error + /// if the backend returns an error. Will return a `NotFound` error if + /// this WebSocket is not a listening socket. + pub fn local_addr(&self) -> ::std::io::Result<SocketAddr> { + self.handler.local_addr() + } +} + +/// Utility for constructing a WebSocket from various settings. +#[derive(Debug, Default, Clone, Copy)] +pub struct Builder { + settings: Settings, +} + +// TODO: add convenience methods for each setting +impl Builder { + /// Create a new Builder with default settings. + pub fn new() -> Builder { + Builder::default() + } + + /// Build a WebSocket using this builder and a factory. + /// It is possible to use the same builder to create multiple WebSockets. + pub fn build<F>(&self, factory: F) -> Result<WebSocket<F>> + where + F: Factory, + { + Ok(WebSocket { + poll: Poll::new()?, + handler: io::Handler::new(factory, self.settings), + }) + } + + /// Set the WebSocket settings to use. + pub fn with_settings(&mut self, settings: Settings) -> &mut Builder { + self.settings = settings; + self + } +} diff --git a/third_party/rust/ws/src/message.rs b/third_party/rust/ws/src/message.rs new file mode 100644 index 0000000000..08509a5e1d --- /dev/null +++ b/third_party/rust/ws/src/message.rs @@ -0,0 +1,173 @@ +use std::convert::{From, Into}; +use std::fmt; +use std::result::Result as StdResult; +use std::str::from_utf8; + +use protocol::OpCode; +use result::Result; + +use self::Message::*; + +/// An enum representing the various forms of a WebSocket message. +#[derive(Debug, Eq, PartialEq, Clone)] +pub enum Message { + /// A text WebSocket message + Text(String), + /// A binary WebSocket message + Binary(Vec<u8>), +} + +impl Message { + /// Create a new text WebSocket message from a stringable. + pub fn text<S>(string: S) -> Message + where + S: Into<String>, + { + Message::Text(string.into()) + } + + /// Create a new binary WebSocket message by converting to Vec<u8>. + pub fn binary<B>(bin: B) -> Message + where + B: Into<Vec<u8>>, + { + Message::Binary(bin.into()) + } + + /// Indicates whether a message is a text message. + pub fn is_text(&self) -> bool { + match *self { + Text(_) => true, + Binary(_) => false, + } + } + + /// Indicates whether a message is a binary message. + pub fn is_binary(&self) -> bool { + match *self { + Text(_) => false, + Binary(_) => true, + } + } + + /// Get the length of the WebSocket message. + pub fn len(&self) -> usize { + match *self { + Text(ref string) => string.len(), + Binary(ref data) => data.len(), + } + } + + /// Returns true if the WebSocket message has no content. + /// For example, if the other side of the connection sent an empty string. + pub fn is_empty(&self) -> bool { + match *self { + Text(ref string) => string.is_empty(), + Binary(ref data) => data.is_empty(), + } + } + + #[doc(hidden)] + pub fn opcode(&self) -> OpCode { + match *self { + Text(_) => OpCode::Text, + Binary(_) => OpCode::Binary, + } + } + + /// Consume the WebSocket and return it as binary data. + pub fn into_data(self) -> Vec<u8> { + match self { + Text(string) => string.into_bytes(), + Binary(data) => data, + } + } + + /// Attempt to consume the WebSocket message and convert it to a String. + pub fn into_text(self) -> Result<String> { + match self { + Text(string) => Ok(string), + Binary(data) => Ok(String::from_utf8(data).map_err(|err| err.utf8_error())?), + } + } + + /// Attempt to get a &str from the WebSocket message, + /// this will try to convert binary data to utf8. + pub fn as_text(&self) -> Result<&str> { + match *self { + Text(ref string) => Ok(string), + Binary(ref data) => Ok(from_utf8(data)?), + } + } +} + +impl From<String> for Message { + fn from(string: String) -> Message { + Message::text(string) + } +} + +impl<'s> From<&'s str> for Message { + fn from(string: &'s str) -> Message { + Message::text(string) + } +} + +impl<'b> From<&'b [u8]> for Message { + fn from(data: &'b [u8]) -> Message { + Message::binary(data) + } +} + +impl From<Vec<u8>> for Message { + fn from(data: Vec<u8>) -> Message { + Message::binary(data) + } +} + +impl fmt::Display for Message { + fn fmt(&self, f: &mut fmt::Formatter) -> StdResult<(), fmt::Error> { + if let Ok(string) = self.as_text() { + write!(f, "{}", string) + } else { + write!(f, "Binary Data<length={}>", self.len()) + } + } +} + +mod test { + #![allow(unused_imports, unused_variables, dead_code)] + use super::*; + + #[test] + fn display() { + let t = Message::text(format!("test")); + assert_eq!(t.to_string(), "test".to_owned()); + + let bin = Message::binary(vec![0, 1, 3, 4, 241]); + assert_eq!(bin.to_string(), "Binary Data<length=5>".to_owned()); + } + + #[test] + fn binary_convert() { + let bin = [6u8, 7, 8, 9, 10, 241]; + let msg = Message::from(&bin[..]); + assert!(msg.is_binary()); + assert!(msg.into_text().is_err()); + } + + #[test] + fn binary_convert_vec() { + let bin = vec![6u8, 7, 8, 9, 10, 241]; + let msg = Message::from(bin); + assert!(msg.is_binary()); + assert!(msg.into_text().is_err()); + } + + #[test] + fn text_convert() { + let s = "kiwotsukete"; + let msg = Message::from(s); + assert!(msg.is_text()); + } +} diff --git a/third_party/rust/ws/src/protocol.rs b/third_party/rust/ws/src/protocol.rs new file mode 100644 index 0000000000..d4c1f2e326 --- /dev/null +++ b/third_party/rust/ws/src/protocol.rs @@ -0,0 +1,227 @@ +use std::convert::{From, Into}; +use std::fmt; + +use self::OpCode::*; +/// Operation codes as part of rfc6455. +#[derive(Debug, Eq, PartialEq, Clone, Copy)] +pub enum OpCode { + /// Indicates a continuation frame of a fragmented message. + Continue, + /// Indicates a text data frame. + Text, + /// Indicates a binary data frame. + Binary, + /// Indicates a close control frame. + Close, + /// Indicates a ping control frame. + Ping, + /// Indicates a pong control frame. + Pong, + /// Indicates an invalid opcode was received. + Bad, +} + +impl OpCode { + /// Test whether the opcode indicates a control frame. + pub fn is_control(&self) -> bool { + match *self { + Text | Binary | Continue => false, + _ => true, + } + } +} + +impl fmt::Display for OpCode { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match *self { + Continue => write!(f, "CONTINUE"), + Text => write!(f, "TEXT"), + Binary => write!(f, "BINARY"), + Close => write!(f, "CLOSE"), + Ping => write!(f, "PING"), + Pong => write!(f, "PONG"), + Bad => write!(f, "BAD"), + } + } +} + +impl Into<u8> for OpCode { + fn into(self) -> u8 { + match self { + Continue => 0, + Text => 1, + Binary => 2, + Close => 8, + Ping => 9, + Pong => 10, + Bad => { + debug_assert!( + false, + "Attempted to convert invalid opcode to u8. This is a bug." + ); + 8 // if this somehow happens, a close frame will help us tear down quickly + } + } + } +} + +impl From<u8> for OpCode { + fn from(byte: u8) -> OpCode { + match byte { + 0 => Continue, + 1 => Text, + 2 => Binary, + 8 => Close, + 9 => Ping, + 10 => Pong, + _ => Bad, + } + } +} + +use self::CloseCode::*; +/// Status code used to indicate why an endpoint is closing the WebSocket connection. +#[derive(Debug, Eq, PartialEq, Clone, Copy)] +pub enum CloseCode { + /// Indicates a normal closure, meaning that the purpose for + /// which the connection was established has been fulfilled. + Normal, + /// Indicates that an endpoint is "going away", such as a server + /// going down or a browser having navigated away from a page. + Away, + /// Indicates that an endpoint is terminating the connection due + /// to a protocol error. + Protocol, + /// Indicates that an endpoint is terminating the connection + /// because it has received a type of data it cannot accept (e.g., an + /// endpoint that understands only text data MAY send this if it + /// receives a binary message). + Unsupported, + /// Indicates that no status code was included in a closing frame. This + /// close code makes it possible to use a single method, `on_close` to + /// handle even cases where no close code was provided. + Status, + /// Indicates an abnormal closure. If the abnormal closure was due to an + /// error, this close code will not be used. Instead, the `on_error` method + /// of the handler will be called with the error. However, if the connection + /// is simply dropped, without an error, this close code will be sent to the + /// handler. + Abnormal, + /// Indicates that an endpoint is terminating the connection + /// because it has received data within a message that was not + /// consistent with the type of the message (e.g., non-UTF-8 [RFC3629] + /// data within a text message). + Invalid, + /// Indicates that an endpoint is terminating the connection + /// because it has received a message that violates its policy. This + /// is a generic status code that can be returned when there is no + /// other more suitable status code (e.g., Unsupported or Size) or if there + /// is a need to hide specific details about the policy. + Policy, + /// Indicates that an endpoint is terminating the connection + /// because it has received a message that is too big for it to + /// process. + Size, + /// Indicates that an endpoint (client) is terminating the + /// connection because it has expected the server to negotiate one or + /// more extension, but the server didn't return them in the response + /// message of the WebSocket handshake. The list of extensions that + /// are needed should be given as the reason for closing. + /// Note that this status code is not used by the server, because it + /// can fail the WebSocket handshake instead. + Extension, + /// Indicates that a server is terminating the connection because + /// it encountered an unexpected condition that prevented it from + /// fulfilling the request. + Error, + /// Indicates that the server is restarting. A client may choose to reconnect, + /// and if it does, it should use a randomized delay of 5-30 seconds between attempts. + Restart, + /// Indicates that the server is overloaded and the client should either connect + /// to a different IP (when multiple targets exist), or reconnect to the same IP + /// when a user has performed an action. + Again, + #[doc(hidden)] + Tls, + #[doc(hidden)] + Empty, + #[doc(hidden)] + Other(u16), +} + +impl Into<u16> for CloseCode { + fn into(self) -> u16 { + match self { + Normal => 1000, + Away => 1001, + Protocol => 1002, + Unsupported => 1003, + Status => 1005, + Abnormal => 1006, + Invalid => 1007, + Policy => 1008, + Size => 1009, + Extension => 1010, + Error => 1011, + Restart => 1012, + Again => 1013, + Tls => 1015, + Empty => 0, + Other(code) => code, + } + } +} + +impl From<u16> for CloseCode { + fn from(code: u16) -> CloseCode { + match code { + 1000 => Normal, + 1001 => Away, + 1002 => Protocol, + 1003 => Unsupported, + 1005 => Status, + 1006 => Abnormal, + 1007 => Invalid, + 1008 => Policy, + 1009 => Size, + 1010 => Extension, + 1011 => Error, + 1012 => Restart, + 1013 => Again, + 1015 => Tls, + 0 => Empty, + _ => Other(code), + } + } +} + +mod test { + #![allow(unused_imports, unused_variables, dead_code)] + use super::*; + + #[test] + fn opcode_from_u8() { + let byte = 2u8; + assert_eq!(OpCode::from(byte), OpCode::Binary); + } + + #[test] + fn opcode_into_u8() { + let text = OpCode::Text; + let byte: u8 = text.into(); + assert_eq!(byte, 1u8); + } + + #[test] + fn closecode_from_u16() { + let byte = 1008u16; + assert_eq!(CloseCode::from(byte), CloseCode::Policy); + } + + #[test] + fn closecode_into_u16() { + let text = CloseCode::Away; + let byte: u16 = text.into(); + assert_eq!(byte, 1001u16); + } +} diff --git a/third_party/rust/ws/src/result.rs b/third_party/rust/ws/src/result.rs new file mode 100644 index 0000000000..eb3c151813 --- /dev/null +++ b/third_party/rust/ws/src/result.rs @@ -0,0 +1,204 @@ +use std::borrow::Cow; +use std::convert::{From, Into}; +use std::error::Error as StdError; +use std::fmt; +use std::io; +use std::result::Result as StdResult; +use std::str::Utf8Error; + +use httparse; +use mio; +#[cfg(feature = "ssl")] +use openssl::ssl::{Error as SslError, HandshakeError as SslHandshakeError}; +#[cfg(feature = "nativetls")] +use native_tls::{Error as SslError, HandshakeError as SslHandshakeError}; +#[cfg(any(feature = "ssl", feature = "nativetls"))] +type HandshakeError = SslHandshakeError<mio::tcp::TcpStream>; + +use communication::Command; + +pub type Result<T> = StdResult<T, Error>; + +/// The type of an error, which may indicate other kinds of errors as the underlying cause. +#[derive(Debug)] +pub enum Kind { + /// Indicates an internal application error. + /// If panic_on_internal is true, which is the default, then the application will panic. + /// Otherwise the WebSocket will automatically attempt to send an Error (1011) close code. + Internal, + /// Indicates a state where some size limit has been exceeded, such as an inability to accept + /// any more new connections. + /// If a Connection is active, the WebSocket will automatically attempt to send + /// a Size (1009) close code. + Capacity, + /// Indicates a violation of the WebSocket protocol. + /// The WebSocket will automatically attempt to send a Protocol (1002) close code, or if + /// this error occurs during a handshake, an HTTP 400 response will be generated. + Protocol, + /// Indicates that the WebSocket received data that should be utf8 encoded but was not. + /// The WebSocket will automatically attempt to send a Invalid Frame Payload Data (1007) close + /// code. + Encoding(Utf8Error), + /// Indicates an underlying IO Error. + /// This kind of error will result in a WebSocket Connection disconnecting. + Io(io::Error), + /// Indicates a failure to parse an HTTP message. + /// This kind of error should only occur during a WebSocket Handshake, and a HTTP 500 response + /// will be generated. + Http(httparse::Error), + /// Indicates a failure to send a signal on the internal EventLoop channel. This means that + /// the WebSocket is overloaded. In order to avoid this error, it is important to set + /// `Settings::max_connections` and `Settings:queue_size` high enough to handle the load. + /// If encountered, retuning from a handler method and waiting for the EventLoop to consume + /// the queue may relieve the situation. + Queue(mio::channel::SendError<Command>), + /// Indicates a failure to perform SSL encryption. + #[cfg(any(feature = "ssl", feature = "nativetls"))] + Ssl(SslError), + /// Indicates a failure to perform SSL encryption. + #[cfg(any(feature = "ssl", feature = "nativetls"))] + SslHandshake(HandshakeError), + /// A custom error kind for use by applications. This error kind involves extra overhead + /// because it will allocate the memory on the heap. The WebSocket ignores such errors by + /// default, simply passing them to the Connection Handler. + Custom(Box<dyn StdError + Send + Sync>), +} + +/// A struct indicating the kind of error that has occurred and any precise details of that error. +pub struct Error { + pub kind: Kind, + pub details: Cow<'static, str>, +} + +impl Error { + pub fn new<I>(kind: Kind, details: I) -> Error + where + I: Into<Cow<'static, str>>, + { + Error { + kind, + details: details.into(), + } + } + + pub fn into_box(self) -> Box<dyn StdError> { + match self.kind { + Kind::Custom(err) => err, + _ => Box::new(self), + } + } +} + +impl fmt::Debug for Error { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + if self.details.len() > 0 { + write!(f, "WS Error <{:?}>: {}", self.kind, self.details) + } else { + write!(f, "WS Error <{:?}>", self.kind) + } + } +} + +impl fmt::Display for Error { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + if self.details.len() > 0 { + write!(f, "{}: {}", self.description(), self.details) + } else { + write!(f, "{}", self.description()) + } + } +} + +impl StdError for Error { + fn description(&self) -> &str { + match self.kind { + Kind::Internal => "Internal Application Error", + Kind::Capacity => "WebSocket at Capacity", + Kind::Protocol => "WebSocket Protocol Error", + Kind::Encoding(ref err) => err.description(), + Kind::Io(ref err) => err.description(), + Kind::Http(_) => "Unable to parse HTTP", + #[cfg(any(feature = "ssl", feature = "nativetls"))] + Kind::Ssl(ref err) => err.description(), + #[cfg(any(feature = "ssl", feature = "nativetls"))] + Kind::SslHandshake(ref err) => err.description(), + Kind::Queue(_) => "Unable to send signal on event loop", + Kind::Custom(ref err) => err.description(), + } + } + + fn cause(&self) -> Option<&dyn StdError> { + match self.kind { + Kind::Encoding(ref err) => Some(err), + Kind::Io(ref err) => Some(err), + #[cfg(any(feature = "ssl", feature = "nativetls"))] + Kind::Ssl(ref err) => Some(err), + #[cfg(any(feature = "ssl", feature = "nativetls"))] + Kind::SslHandshake(ref err) => err.cause(), + Kind::Custom(ref err) => Some(err.as_ref()), + _ => None, + } + } +} + +impl From<io::Error> for Error { + fn from(err: io::Error) -> Error { + Error::new(Kind::Io(err), "") + } +} + +impl From<httparse::Error> for Error { + fn from(err: httparse::Error) -> Error { + let details = match err { + httparse::Error::HeaderName => "Invalid byte in header name.", + httparse::Error::HeaderValue => "Invalid byte in header value.", + httparse::Error::NewLine => "Invalid byte in new line.", + httparse::Error::Status => "Invalid byte in Response status.", + httparse::Error::Token => "Invalid byte where token is required.", + httparse::Error::TooManyHeaders => { + "Parsed more headers than provided buffer can contain." + } + httparse::Error::Version => "Invalid byte in HTTP version.", + }; + + Error::new(Kind::Http(err), details) + } +} + +impl From<mio::channel::SendError<Command>> for Error { + fn from(err: mio::channel::SendError<Command>) -> Error { + match err { + mio::channel::SendError::Io(err) => Error::from(err), + _ => Error::new(Kind::Queue(err), ""), + } + } +} + +impl From<Utf8Error> for Error { + fn from(err: Utf8Error) -> Error { + Error::new(Kind::Encoding(err), "") + } +} + +#[cfg(any(feature = "ssl", feature = "nativetls"))] +impl From<SslError> for Error { + fn from(err: SslError) -> Error { + Error::new(Kind::Ssl(err), "") + } +} + +#[cfg(any(feature = "ssl", feature = "nativetls"))] +impl From<HandshakeError> for Error { + fn from(err: HandshakeError) -> Error { + Error::new(Kind::SslHandshake(err), "") + } +} + +impl<B> From<Box<B>> for Error +where + B: StdError + Send + Sync + 'static, +{ + fn from(err: Box<B>) -> Error { + Error::new(Kind::Custom(err), "") + } +} diff --git a/third_party/rust/ws/src/stream.rs b/third_party/rust/ws/src/stream.rs new file mode 100644 index 0000000000..3b8d9c441b --- /dev/null +++ b/third_party/rust/ws/src/stream.rs @@ -0,0 +1,358 @@ +use std::io; +use std::io::ErrorKind::WouldBlock; +#[cfg(any(feature = "ssl", feature = "nativetls"))] +use std::mem::replace; +use std::net::SocketAddr; + +use bytes::{Buf, BufMut}; +use mio::tcp::TcpStream; +#[cfg(feature = "nativetls")] +use native_tls::{ + HandshakeError, MidHandshakeTlsStream as MidHandshakeSslStream, TlsStream as SslStream, +}; +#[cfg(feature = "ssl")] +use openssl::ssl::{ErrorCode as SslErrorCode, HandshakeError, MidHandshakeSslStream, SslStream}; + +use result::{Error, Kind, Result}; + +fn map_non_block<T>(res: io::Result<T>) -> io::Result<Option<T>> { + match res { + Ok(value) => Ok(Some(value)), + Err(err) => { + if let WouldBlock = err.kind() { + Ok(None) + } else { + Err(err) + } + } + } +} + +pub trait TryReadBuf: io::Read { + fn try_read_buf<B: BufMut>(&mut self, buf: &mut B) -> io::Result<Option<usize>> + where + Self: Sized, + { + // Reads the length of the slice supplied by buf.mut_bytes into the buffer + // This is not guaranteed to consume an entire datagram or segment. + // If your protocol is msg based (instead of continuous stream) you should + // ensure that your buffer is large enough to hold an entire segment (1532 bytes if not jumbo + // frames) + let res = map_non_block(self.read(unsafe { buf.bytes_mut() })); + + if let Ok(Some(cnt)) = res { + unsafe { + buf.advance_mut(cnt); + } + } + + res + } +} + +pub trait TryWriteBuf: io::Write { + fn try_write_buf<B: Buf>(&mut self, buf: &mut B) -> io::Result<Option<usize>> + where + Self: Sized, + { + let res = map_non_block(self.write(buf.bytes())); + + if let Ok(Some(cnt)) = res { + buf.advance(cnt); + } + + res + } +} + +impl<T: io::Read> TryReadBuf for T {} +impl<T: io::Write> TryWriteBuf for T {} + +use self::Stream::*; +pub enum Stream { + Tcp(TcpStream), + #[cfg(any(feature = "ssl", feature = "nativetls"))] + Tls(TlsStream), +} + +impl Stream { + pub fn tcp(stream: TcpStream) -> Stream { + Tcp(stream) + } + + #[cfg(any(feature = "ssl", feature = "nativetls"))] + pub fn tls(stream: MidHandshakeSslStream<TcpStream>) -> Stream { + Tls(TlsStream::Handshake { + sock: stream, + negotiating: false, + }) + } + + #[cfg(any(feature = "ssl", feature = "nativetls"))] + pub fn tls_live(stream: SslStream<TcpStream>) -> Stream { + Tls(TlsStream::Live(stream)) + } + + #[cfg(any(feature = "ssl", feature = "nativetls"))] + pub fn is_tls(&self) -> bool { + match *self { + Tcp(_) => false, + Tls(_) => true, + } + } + + pub fn evented(&self) -> &TcpStream { + match *self { + Tcp(ref sock) => sock, + #[cfg(any(feature = "ssl", feature = "nativetls"))] + Tls(ref inner) => inner.evented(), + } + } + + pub fn is_negotiating(&self) -> bool { + match *self { + Tcp(_) => false, + #[cfg(any(feature = "ssl", feature = "nativetls"))] + Tls(ref inner) => inner.is_negotiating(), + } + } + + pub fn clear_negotiating(&mut self) -> Result<()> { + match *self { + Tcp(_) => Err(Error::new( + Kind::Internal, + "Attempted to clear negotiating flag on non ssl connection.", + )), + #[cfg(any(feature = "ssl", feature = "nativetls"))] + Tls(ref mut inner) => inner.clear_negotiating(), + } + } + + pub fn peer_addr(&self) -> io::Result<SocketAddr> { + match *self { + Tcp(ref sock) => sock.peer_addr(), + #[cfg(any(feature = "ssl", feature = "nativetls"))] + Tls(ref inner) => inner.peer_addr(), + } + } + + pub fn local_addr(&self) -> io::Result<SocketAddr> { + match *self { + Tcp(ref sock) => sock.local_addr(), + #[cfg(any(feature = "ssl", feature = "nativetls"))] + Tls(ref inner) => inner.local_addr(), + } + } +} + +impl io::Read for Stream { + fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> { + match *self { + Tcp(ref mut sock) => sock.read(buf), + #[cfg(any(feature = "ssl", feature = "nativetls"))] + Tls(TlsStream::Live(ref mut sock)) => sock.read(buf), + #[cfg(any(feature = "ssl", feature = "nativetls"))] + Tls(ref mut tls_stream) => { + trace!("Attempting to read ssl handshake."); + match replace(tls_stream, TlsStream::Upgrading) { + TlsStream::Live(_) | TlsStream::Upgrading => unreachable!(), + TlsStream::Handshake { + sock, + mut negotiating, + } => match sock.handshake() { + Ok(mut sock) => { + trace!("Completed SSL Handshake"); + let res = sock.read(buf); + *tls_stream = TlsStream::Live(sock); + res + } + #[cfg(feature = "ssl")] + Err(HandshakeError::SetupFailure(err)) => { + Err(io::Error::new(io::ErrorKind::Other, err)) + } + #[cfg(feature = "ssl")] + Err(HandshakeError::Failure(mid)) + | Err(HandshakeError::WouldBlock(mid)) => { + if mid.error().code() == SslErrorCode::WANT_READ { + negotiating = true; + } + let err = if let Some(io_error) = mid.error().io_error() { + Err(io::Error::new( + io_error.kind(), + format!("{:?}", io_error.get_ref()), + )) + } else { + Err(io::Error::new( + io::ErrorKind::Other, + format!("{}", mid.error()), + )) + }; + *tls_stream = TlsStream::Handshake { + sock: mid, + negotiating, + }; + err + } + #[cfg(feature = "nativetls")] + Err(HandshakeError::WouldBlock(mid)) => { + negotiating = true; + *tls_stream = TlsStream::Handshake { + sock: mid, + negotiating: negotiating, + }; + Err(io::Error::new(io::ErrorKind::WouldBlock, "SSL would block")) + } + #[cfg(feature = "nativetls")] + Err(HandshakeError::Failure(err)) => { + Err(io::Error::new(io::ErrorKind::Other, format!("{}", err))) + } + }, + } + } + } + } +} + +impl io::Write for Stream { + fn write(&mut self, buf: &[u8]) -> io::Result<usize> { + match *self { + Tcp(ref mut sock) => sock.write(buf), + #[cfg(any(feature = "ssl", feature = "nativetls"))] + Tls(TlsStream::Live(ref mut sock)) => sock.write(buf), + #[cfg(any(feature = "ssl", feature = "nativetls"))] + Tls(ref mut tls_stream) => { + trace!("Attempting to write ssl handshake."); + match replace(tls_stream, TlsStream::Upgrading) { + TlsStream::Live(_) | TlsStream::Upgrading => unreachable!(), + TlsStream::Handshake { + sock, + mut negotiating, + } => match sock.handshake() { + Ok(mut sock) => { + trace!("Completed SSL Handshake"); + let res = sock.write(buf); + *tls_stream = TlsStream::Live(sock); + res + } + #[cfg(feature = "ssl")] + Err(HandshakeError::SetupFailure(err)) => { + Err(io::Error::new(io::ErrorKind::Other, err)) + } + #[cfg(feature = "ssl")] + Err(HandshakeError::Failure(mid)) + | Err(HandshakeError::WouldBlock(mid)) => { + if mid.error().code() == SslErrorCode::WANT_READ { + negotiating = true; + } else { + negotiating = false; + } + let err = if let Some(io_error) = mid.error().io_error() { + Err(io::Error::new( + io_error.kind(), + format!("{:?}", io_error.get_ref()), + )) + } else { + Err(io::Error::new( + io::ErrorKind::Other, + format!("{}", mid.error()), + )) + }; + *tls_stream = TlsStream::Handshake { + sock: mid, + negotiating, + }; + err + } + #[cfg(feature = "nativetls")] + Err(HandshakeError::WouldBlock(mid)) => { + negotiating = true; + *tls_stream = TlsStream::Handshake { + sock: mid, + negotiating: negotiating, + }; + Err(io::Error::new(io::ErrorKind::WouldBlock, "SSL would block")) + } + #[cfg(feature = "nativetls")] + Err(HandshakeError::Failure(err)) => { + Err(io::Error::new(io::ErrorKind::Other, format!("{}", err))) + } + }, + } + } + } + } + + fn flush(&mut self) -> io::Result<()> { + match *self { + Tcp(ref mut sock) => sock.flush(), + #[cfg(any(feature = "ssl", feature = "nativetls"))] + Tls(TlsStream::Live(ref mut sock)) => sock.flush(), + #[cfg(any(feature = "ssl", feature = "nativetls"))] + Tls(TlsStream::Handshake { ref mut sock, .. }) => sock.get_mut().flush(), + #[cfg(any(feature = "ssl", feature = "nativetls"))] + Tls(TlsStream::Upgrading) => panic!("Tried to access actively upgrading TlsStream"), + } + } +} + +#[cfg(any(feature = "ssl", feature = "nativetls"))] +pub enum TlsStream { + Live(SslStream<TcpStream>), + Handshake { + sock: MidHandshakeSslStream<TcpStream>, + negotiating: bool, + }, + Upgrading, +} + +#[cfg(any(feature = "ssl", feature = "nativetls"))] +impl TlsStream { + pub fn evented(&self) -> &TcpStream { + match *self { + TlsStream::Live(ref sock) => sock.get_ref(), + TlsStream::Handshake { ref sock, .. } => sock.get_ref(), + TlsStream::Upgrading => panic!("Tried to access actively upgrading TlsStream"), + } + } + + pub fn is_negotiating(&self) -> bool { + match *self { + TlsStream::Live(_) => false, + TlsStream::Handshake { + sock: _, + negotiating, + } => negotiating, + TlsStream::Upgrading => panic!("Tried to access actively upgrading TlsStream"), + } + } + + pub fn clear_negotiating(&mut self) -> Result<()> { + match *self { + TlsStream::Live(_) => Err(Error::new( + Kind::Internal, + "Attempted to clear negotiating flag on live ssl connection.", + )), + TlsStream::Handshake { + sock: _, + ref mut negotiating, + } => Ok(*negotiating = false), + TlsStream::Upgrading => panic!("Tried to access actively upgrading TlsStream"), + } + } + + pub fn peer_addr(&self) -> io::Result<SocketAddr> { + match *self { + TlsStream::Live(ref sock) => sock.get_ref().peer_addr(), + TlsStream::Handshake { ref sock, .. } => sock.get_ref().peer_addr(), + TlsStream::Upgrading => panic!("Tried to access actively upgrading TlsStream"), + } + } + + pub fn local_addr(&self) -> io::Result<SocketAddr> { + match *self { + TlsStream::Live(ref sock) => sock.get_ref().local_addr(), + TlsStream::Handshake { ref sock, .. } => sock.get_ref().local_addr(), + TlsStream::Upgrading => panic!("Tried to access actively upgrading TlsStream"), + } + } +} diff --git a/third_party/rust/ws/src/util.rs b/third_party/rust/ws/src/util.rs new file mode 100644 index 0000000000..fc66394a74 --- /dev/null +++ b/third_party/rust/ws/src/util.rs @@ -0,0 +1,9 @@ +//! The util module rexports some tools from mio in order to facilitate handling timeouts. + +/// Used to identify some timed-out event. +pub use mio::Token; +/// A handle to a specific timeout. +pub use mio_extras::timer::Timeout; +#[cfg(any(feature = "ssl", feature = "nativetls"))] +/// TcpStream underlying the WebSocket +pub use mio::tcp::TcpStream; |