summaryrefslogtreecommitdiffstats
path: root/third_party/rust/ws/src
diff options
context:
space:
mode:
Diffstat (limited to 'third_party/rust/ws/src')
-rw-r--r--third_party/rust/ws/src/communication.rs249
-rw-r--r--third_party/rust/ws/src/connection.rs1230
-rw-r--r--third_party/rust/ws/src/deflate/context.rs268
-rw-r--r--third_party/rust/ws/src/deflate/extension.rs565
-rw-r--r--third_party/rust/ws/src/deflate/mod.rs9
-rw-r--r--third_party/rust/ws/src/factory.rs188
-rw-r--r--third_party/rust/ws/src/frame.rs495
-rw-r--r--third_party/rust/ws/src/handler.rs423
-rw-r--r--third_party/rust/ws/src/handshake.rs740
-rw-r--r--third_party/rust/ws/src/io.rs985
-rw-r--r--third_party/rust/ws/src/lib.rs391
-rw-r--r--third_party/rust/ws/src/message.rs173
-rw-r--r--third_party/rust/ws/src/protocol.rs227
-rw-r--r--third_party/rust/ws/src/result.rs204
-rw-r--r--third_party/rust/ws/src/stream.rs358
-rw-r--r--third_party/rust/ws/src/util.rs9
16 files changed, 6514 insertions, 0 deletions
diff --git a/third_party/rust/ws/src/communication.rs b/third_party/rust/ws/src/communication.rs
new file mode 100644
index 0000000000..2b2822f03c
--- /dev/null
+++ b/third_party/rust/ws/src/communication.rs
@@ -0,0 +1,249 @@
+use std::borrow::Cow;
+use std::convert::Into;
+
+use mio;
+use mio::Token;
+use mio_extras::timer::Timeout;
+use url;
+
+use io::ALL;
+use message;
+use protocol::CloseCode;
+use result::{Error, Result};
+use std::cmp::PartialEq;
+use std::hash::{Hash, Hasher};
+use std::fmt;
+
+#[derive(Debug, Clone)]
+pub enum Signal {
+ Message(message::Message),
+ Close(CloseCode, Cow<'static, str>),
+ Ping(Vec<u8>),
+ Pong(Vec<u8>),
+ Connect(url::Url),
+ Shutdown,
+ Timeout { delay: u64, token: Token },
+ Cancel(Timeout),
+}
+
+#[derive(Debug, Clone)]
+pub struct Command {
+ token: Token,
+ signal: Signal,
+ connection_id: u32,
+}
+
+impl Command {
+ pub fn token(&self) -> Token {
+ self.token
+ }
+
+ pub fn into_signal(self) -> Signal {
+ self.signal
+ }
+
+ pub fn connection_id(&self) -> u32 {
+ self.connection_id
+ }
+}
+
+/// A representation of the output of the WebSocket connection. Use this to send messages to the
+/// other endpoint.
+#[derive(Clone)]
+pub struct Sender {
+ token: Token,
+ channel: mio::channel::SyncSender<Command>,
+ connection_id: u32,
+}
+
+impl fmt::Debug for Sender {
+ fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+ write!(f,
+ "Sender {{ token: {:?}, channel: mio::channel::SyncSender<Command>, connection_id: {:?} }}",
+ self.token, self.connection_id)
+ }
+}
+
+impl PartialEq for Sender {
+ fn eq(&self, other: &Sender) -> bool {
+ self.token == other.token && self.connection_id == other.connection_id
+ }
+}
+
+impl Eq for Sender { }
+
+impl Hash for Sender {
+ fn hash<H: Hasher>(&self, state: &mut H) {
+ self.connection_id.hash(state);
+ self.token.hash(state);
+ }
+}
+
+
+impl Sender {
+ #[doc(hidden)]
+ #[inline]
+ pub fn new(
+ token: Token,
+ channel: mio::channel::SyncSender<Command>,
+ connection_id: u32,
+ ) -> Sender {
+ Sender {
+ token,
+ channel,
+ connection_id,
+ }
+ }
+
+ /// A Token identifying this sender within the WebSocket.
+ #[inline]
+ pub fn token(&self) -> Token {
+ self.token
+ }
+
+ /// A connection_id identifying this sender within the WebSocket.
+ #[inline]
+ pub fn connection_id(&self) -> u32 {
+ self.connection_id
+ }
+
+ /// Send a message over the connection.
+ #[inline]
+ pub fn send<M>(&self, msg: M) -> Result<()>
+ where
+ M: Into<message::Message>,
+ {
+ self.channel
+ .send(Command {
+ token: self.token,
+ signal: Signal::Message(msg.into()),
+ connection_id: self.connection_id,
+ })
+ .map_err(Error::from)
+ }
+
+ /// Send a message to the endpoints of all connections.
+ ///
+ /// Be careful with this method. It does not discriminate between client and server connections.
+ /// If your WebSocket is only functioning as a server, then usage is simple, this method will
+ /// send a copy of the message to each connected client. However, if you have a WebSocket that
+ /// is listening for connections and is also connected to another WebSocket, this method will
+ /// broadcast a copy of the message to all the clients connected and to that WebSocket server.
+ #[inline]
+ pub fn broadcast<M>(&self, msg: M) -> Result<()>
+ where
+ M: Into<message::Message>,
+ {
+ self.channel
+ .send(Command {
+ token: ALL,
+ signal: Signal::Message(msg.into()),
+ connection_id: self.connection_id,
+ })
+ .map_err(Error::from)
+ }
+
+ /// Send a close code to the other endpoint.
+ #[inline]
+ pub fn close(&self, code: CloseCode) -> Result<()> {
+ self.channel
+ .send(Command {
+ token: self.token,
+ signal: Signal::Close(code, "".into()),
+ connection_id: self.connection_id,
+ })
+ .map_err(Error::from)
+ }
+
+ /// Send a close code and provide a descriptive reason for closing.
+ #[inline]
+ pub fn close_with_reason<S>(&self, code: CloseCode, reason: S) -> Result<()>
+ where
+ S: Into<Cow<'static, str>>,
+ {
+ self.channel
+ .send(Command {
+ token: self.token,
+ signal: Signal::Close(code, reason.into()),
+ connection_id: self.connection_id,
+ })
+ .map_err(Error::from)
+ }
+
+ /// Send a ping to the other endpoint with the given test data.
+ #[inline]
+ pub fn ping(&self, data: Vec<u8>) -> Result<()> {
+ self.channel
+ .send(Command {
+ token: self.token,
+ signal: Signal::Ping(data),
+ connection_id: self.connection_id,
+ })
+ .map_err(Error::from)
+ }
+
+ /// Send a pong to the other endpoint responding with the given test data.
+ #[inline]
+ pub fn pong(&self, data: Vec<u8>) -> Result<()> {
+ self.channel
+ .send(Command {
+ token: self.token,
+ signal: Signal::Pong(data),
+ connection_id: self.connection_id,
+ })
+ .map_err(Error::from)
+ }
+
+ /// Queue a new connection on this WebSocket to the specified URL.
+ #[inline]
+ pub fn connect(&self, url: url::Url) -> Result<()> {
+ self.channel
+ .send(Command {
+ token: self.token,
+ signal: Signal::Connect(url),
+ connection_id: self.connection_id,
+ })
+ .map_err(Error::from)
+ }
+
+ /// Request that all connections terminate and that the WebSocket stop running.
+ #[inline]
+ pub fn shutdown(&self) -> Result<()> {
+ self.channel
+ .send(Command {
+ token: self.token,
+ signal: Signal::Shutdown,
+ connection_id: self.connection_id,
+ })
+ .map_err(Error::from)
+ }
+
+ /// Schedule a `token` to be sent to the WebSocket Handler's `on_timeout` method
+ /// after `ms` milliseconds
+ #[inline]
+ pub fn timeout(&self, ms: u64, token: Token) -> Result<()> {
+ self.channel
+ .send(Command {
+ token: self.token,
+ signal: Signal::Timeout { delay: ms, token },
+ connection_id: self.connection_id,
+ })
+ .map_err(Error::from)
+ }
+
+ /// Queue the cancellation of a previously scheduled timeout.
+ ///
+ /// This method is not guaranteed to prevent the timeout from occurring, because it is
+ /// possible to call this method after a timeout has already occurred. It is still necessary to
+ /// handle spurious timeouts.
+ #[inline]
+ pub fn cancel(&self, timeout: Timeout) -> Result<()> {
+ self.channel
+ .send(Command {
+ token: self.token,
+ signal: Signal::Cancel(timeout),
+ connection_id: self.connection_id,
+ })
+ .map_err(Error::from)
+ }
+}
diff --git a/third_party/rust/ws/src/connection.rs b/third_party/rust/ws/src/connection.rs
new file mode 100644
index 0000000000..b639695e5c
--- /dev/null
+++ b/third_party/rust/ws/src/connection.rs
@@ -0,0 +1,1230 @@
+use std::borrow::Borrow;
+use std::collections::VecDeque;
+use std::io::{Cursor, Read, Seek, SeekFrom, Write};
+use std::mem::replace;
+use std::net::SocketAddr;
+use std::str::from_utf8;
+
+use mio::tcp::TcpStream;
+use mio::{Ready, Token};
+use mio_extras::timer::Timeout;
+use url;
+
+#[cfg(feature = "nativetls")]
+use native_tls::HandshakeError;
+#[cfg(feature = "ssl")]
+use openssl::ssl::HandshakeError;
+
+use frame::Frame;
+use handler::Handler;
+use handshake::{Handshake, Request, Response};
+use message::Message;
+use protocol::{CloseCode, OpCode};
+use result::{Error, Kind, Result};
+use stream::{Stream, TryReadBuf, TryWriteBuf};
+
+use self::Endpoint::*;
+use self::State::*;
+
+use super::Settings;
+
+#[derive(Debug)]
+pub enum State {
+ // Tcp connection accepted, waiting for handshake to complete
+ Connecting(Cursor<Vec<u8>>, Cursor<Vec<u8>>),
+ // Ready to send/receive messages
+ Open,
+ AwaitingClose,
+ RespondingClose,
+ FinishedClose,
+}
+
+/// A little more semantic than a boolean
+#[derive(Debug, Eq, PartialEq, Clone)]
+pub enum Endpoint {
+ /// Will mask outgoing frames
+ Client(url::Url),
+ /// Won't mask outgoing frames
+ Server,
+}
+
+impl State {
+ #[inline]
+ pub fn is_connecting(&self) -> bool {
+ match *self {
+ State::Connecting(..) => true,
+ _ => false,
+ }
+ }
+
+ #[allow(dead_code)]
+ #[inline]
+ pub fn is_open(&self) -> bool {
+ match *self {
+ State::Open => true,
+ _ => false,
+ }
+ }
+
+ #[inline]
+ pub fn is_closing(&self) -> bool {
+ match *self {
+ State::AwaitingClose | State::FinishedClose => true,
+ _ => false,
+ }
+ }
+}
+
+pub struct Connection<H>
+where
+ H: Handler,
+{
+ token: Token,
+ socket: Stream,
+ state: State,
+ endpoint: Endpoint,
+ events: Ready,
+
+ fragments: VecDeque<Frame>,
+
+ in_buffer: Cursor<Vec<u8>>,
+ out_buffer: Cursor<Vec<u8>>,
+
+ handler: H,
+
+ addresses: Vec<SocketAddr>,
+
+ settings: Settings,
+ connection_id: u32,
+}
+
+impl<H> Connection<H>
+where
+ H: Handler,
+{
+ pub fn new(
+ tok: Token,
+ sock: TcpStream,
+ handler: H,
+ settings: Settings,
+ connection_id: u32,
+ ) -> Connection<H> {
+ Connection {
+ token: tok,
+ socket: Stream::tcp(sock),
+ state: Connecting(
+ Cursor::new(Vec::with_capacity(2048)),
+ Cursor::new(Vec::with_capacity(2048)),
+ ),
+ endpoint: Endpoint::Server,
+ events: Ready::empty(),
+ fragments: VecDeque::with_capacity(settings.fragments_capacity),
+ in_buffer: Cursor::new(Vec::with_capacity(settings.in_buffer_capacity)),
+ out_buffer: Cursor::new(Vec::with_capacity(settings.out_buffer_capacity)),
+ handler,
+ addresses: Vec::new(),
+ settings,
+ connection_id,
+ }
+ }
+
+ pub fn as_server(&mut self) -> Result<()> {
+ self.events.insert(Ready::readable());
+ Ok(())
+ }
+
+ pub fn as_client(&mut self, url: url::Url, addrs: Vec<SocketAddr>) -> Result<()> {
+ if let Connecting(ref mut req_buf, _) = self.state {
+ let req = self.handler.build_request(&url)?;
+ self.addresses = addrs;
+ self.events.insert(Ready::writable());
+ self.endpoint = Endpoint::Client(url);
+ req.format(req_buf.get_mut())
+ } else {
+ Err(Error::new(
+ Kind::Internal,
+ "Tried to set connection to client while not connecting.",
+ ))
+ }
+ }
+
+ #[cfg(any(feature = "ssl", feature = "nativetls"))]
+ pub fn encrypt(&mut self) -> Result<()> {
+ let sock = self.socket().try_clone()?;
+ let ssl_stream = match self.endpoint {
+ Server => self.handler.upgrade_ssl_server(sock),
+ Client(ref url) => self.handler.upgrade_ssl_client(sock, url),
+ };
+
+ match ssl_stream {
+ Ok(stream) => {
+ self.socket = Stream::tls_live(stream);
+ Ok(())
+ }
+ #[cfg(feature = "ssl")]
+ Err(Error {
+ kind: Kind::SslHandshake(handshake_err),
+ details,
+ }) => match handshake_err {
+ HandshakeError::SetupFailure(_) => {
+ Err(Error::new(Kind::SslHandshake(handshake_err), details))
+ }
+ HandshakeError::Failure(mid) | HandshakeError::WouldBlock(mid) => {
+ self.socket = Stream::tls(mid);
+ Ok(())
+ }
+ },
+ #[cfg(feature = "nativetls")]
+ Err(Error {
+ kind: Kind::SslHandshake(handshake_err),
+ details,
+ }) => match handshake_err {
+ HandshakeError::Failure(_) => {
+ Err(Error::new(Kind::SslHandshake(handshake_err), details))
+ }
+ HandshakeError::WouldBlock(mid) => {
+ self.socket = Stream::tls(mid);
+ Ok(())
+ }
+ },
+ Err(e) => Err(e),
+ }
+ }
+
+ pub fn token(&self) -> Token {
+ self.token
+ }
+
+ pub fn socket(&self) -> &TcpStream {
+ self.socket.evented()
+ }
+
+ pub fn connection_id(&self) -> u32 {
+ self.connection_id
+ }
+
+ fn peer_addr(&self) -> String {
+ if let Ok(addr) = self.socket.peer_addr() {
+ addr.to_string()
+ } else {
+ "UNKNOWN".into()
+ }
+ }
+
+ // Resetting may be necessary in order to try all possible addresses for a server
+ #[cfg(any(feature = "ssl", feature = "nativetls"))]
+ pub fn reset(&mut self) -> Result<()> {
+ // if self.is_client() {
+ if let Client(ref url) = self.endpoint {
+ if let Connecting(ref mut req, ref mut res) = self.state {
+ req.set_position(0);
+ res.set_position(0);
+ self.events.remove(Ready::readable());
+ self.events.insert(Ready::writable());
+
+ if let Some(ref addr) = self.addresses.pop() {
+ let sock = TcpStream::connect(addr)?;
+ if self.socket.is_tls() {
+ let ssl_stream = self.handler.upgrade_ssl_client(sock, url);
+ match ssl_stream {
+ Ok(stream) => {
+ self.socket = Stream::tls_live(stream);
+ Ok(())
+ }
+ #[cfg(feature = "ssl")]
+ Err(Error {
+ kind: Kind::SslHandshake(handshake_err),
+ details,
+ }) => match handshake_err {
+ HandshakeError::SetupFailure(_) => {
+ Err(Error::new(Kind::SslHandshake(handshake_err), details))
+ }
+ HandshakeError::Failure(mid) | HandshakeError::WouldBlock(mid) => {
+ self.socket = Stream::tls(mid);
+ Ok(())
+ }
+ },
+ #[cfg(feature = "nativetls")]
+ Err(Error {
+ kind: Kind::SslHandshake(handshake_err),
+ details,
+ }) => match handshake_err {
+ HandshakeError::Failure(_) => {
+ Err(Error::new(Kind::SslHandshake(handshake_err), details))
+ }
+ HandshakeError::WouldBlock(mid) => {
+ self.socket = Stream::tls(mid);
+ Ok(())
+ }
+ },
+ Err(e) => Err(e),
+ }
+ } else {
+ self.socket = Stream::tcp(sock);
+ Ok(())
+ }
+ } else {
+ if self.settings.panic_on_new_connection {
+ panic!("Unable to connect to server.");
+ }
+ Err(Error::new(Kind::Internal, "Exhausted possible addresses."))
+ }
+ } else {
+ Err(Error::new(
+ Kind::Internal,
+ "Unable to reset client connection because it is active.",
+ ))
+ }
+ } else {
+ Err(Error::new(
+ Kind::Internal,
+ "Server connections cannot be reset.",
+ ))
+ }
+ }
+
+ #[cfg(not(any(feature = "ssl", feature = "nativetls")))]
+ pub fn reset(&mut self) -> Result<()> {
+ if self.is_client() {
+ if let Connecting(ref mut req, ref mut res) = self.state {
+ req.set_position(0);
+ res.set_position(0);
+ self.events.remove(Ready::readable());
+ self.events.insert(Ready::writable());
+
+ if let Some(ref addr) = self.addresses.pop() {
+ let sock = TcpStream::connect(addr)?;
+ self.socket = Stream::tcp(sock);
+ Ok(())
+ } else {
+ if self.settings.panic_on_new_connection {
+ panic!("Unable to connect to server.");
+ }
+ Err(Error::new(Kind::Internal, "Exhausted possible addresses."))
+ }
+ } else {
+ Err(Error::new(
+ Kind::Internal,
+ "Unable to reset client connection because it is active.",
+ ))
+ }
+ } else {
+ Err(Error::new(
+ Kind::Internal,
+ "Server connections cannot be reset.",
+ ))
+ }
+ }
+
+ pub fn events(&self) -> Ready {
+ self.events
+ }
+
+ pub fn is_client(&self) -> bool {
+ match self.endpoint {
+ Client(_) => true,
+ Server => false,
+ }
+ }
+
+ pub fn is_server(&self) -> bool {
+ match self.endpoint {
+ Client(_) => false,
+ Server => true,
+ }
+ }
+
+ pub fn shutdown(&mut self) {
+ self.handler.on_shutdown();
+ if let Err(err) = self.send_close(CloseCode::Away, "Shutting down.") {
+ self.handler.on_error(err);
+ self.disconnect()
+ }
+ }
+
+ #[inline]
+ pub fn new_timeout(&mut self, event: Token, timeout: Timeout) -> Result<()> {
+ self.handler.on_new_timeout(event, timeout)
+ }
+
+ #[inline]
+ pub fn timeout_triggered(&mut self, event: Token) -> Result<()> {
+ self.handler.on_timeout(event)
+ }
+
+ pub fn error(&mut self, err: Error) {
+ match self.state {
+ Connecting(_, ref mut res) => match err.kind {
+ #[cfg(feature = "ssl")]
+ Kind::Ssl(_) => {
+ self.handler.on_error(err);
+ self.events = Ready::empty();
+ }
+ Kind::Io(_) => {
+ self.handler.on_error(err);
+ self.events = Ready::empty();
+ }
+ Kind::Protocol => {
+ let msg = err.to_string();
+ self.handler.on_error(err);
+ if let Server = self.endpoint {
+ res.get_mut().clear();
+ if let Err(err) =
+ write!(res.get_mut(), "HTTP/1.1 400 Bad Request\r\n\r\n{}", msg)
+ {
+ self.handler.on_error(Error::from(err));
+ self.events = Ready::empty();
+ } else {
+ self.events.remove(Ready::readable());
+ self.events.insert(Ready::writable());
+ }
+ } else {
+ self.events = Ready::empty();
+ }
+ }
+ _ => {
+ let msg = err.to_string();
+ self.handler.on_error(err);
+ if let Server = self.endpoint {
+ res.get_mut().clear();
+ if let Err(err) = write!(
+ res.get_mut(),
+ "HTTP/1.1 500 Internal Server Error\r\n\r\n{}",
+ msg
+ ) {
+ self.handler.on_error(Error::from(err));
+ self.events = Ready::empty();
+ } else {
+ self.events.remove(Ready::readable());
+ self.events.insert(Ready::writable());
+ }
+ } else {
+ self.events = Ready::empty();
+ }
+ }
+ },
+ _ => {
+ match err.kind {
+ Kind::Internal => {
+ if self.settings.panic_on_internal {
+ panic!("Panicking on internal error -- {}", err);
+ }
+ let reason = format!("{}", err);
+
+ self.handler.on_error(err);
+ if let Err(err) = self.send_close(CloseCode::Error, reason) {
+ self.handler.on_error(err);
+ self.disconnect()
+ }
+ }
+ Kind::Capacity => {
+ if self.settings.panic_on_capacity {
+ panic!("Panicking on capacity error -- {}", err);
+ }
+ let reason = format!("{}", err);
+
+ self.handler.on_error(err);
+ if let Err(err) = self.send_close(CloseCode::Size, reason) {
+ self.handler.on_error(err);
+ self.disconnect()
+ }
+ }
+ Kind::Protocol => {
+ if self.settings.panic_on_protocol {
+ panic!("Panicking on protocol error -- {}", err);
+ }
+ let reason = format!("{}", err);
+
+ self.handler.on_error(err);
+ if let Err(err) = self.send_close(CloseCode::Protocol, reason) {
+ self.handler.on_error(err);
+ self.disconnect()
+ }
+ }
+ Kind::Encoding(_) => {
+ if self.settings.panic_on_encoding {
+ panic!("Panicking on encoding error -- {}", err);
+ }
+ let reason = format!("{}", err);
+
+ self.handler.on_error(err);
+ if let Err(err) = self.send_close(CloseCode::Invalid, reason) {
+ self.handler.on_error(err);
+ self.disconnect()
+ }
+ }
+ Kind::Http(_) => {
+ // This may happen if some handler writes a bad response
+ self.handler.on_error(err);
+ error!("Disconnecting WebSocket.");
+ self.disconnect()
+ }
+ Kind::Custom(_) => {
+ self.handler.on_error(err);
+ }
+ Kind::Queue(_) => {
+ if self.settings.panic_on_queue {
+ panic!("Panicking on queue error -- {}", err);
+ }
+ self.handler.on_error(err);
+ }
+ _ => {
+ if self.settings.panic_on_io {
+ panic!("Panicking on io error -- {}", err);
+ }
+ self.handler.on_error(err);
+ self.disconnect()
+ }
+ }
+ }
+ }
+ }
+
+ pub fn disconnect(&mut self) {
+ match self.state {
+ RespondingClose | FinishedClose | Connecting(_, _) => (),
+ _ => {
+ self.handler.on_close(CloseCode::Abnormal, "");
+ }
+ }
+ self.events = Ready::empty()
+ }
+
+ pub fn consume(self) -> H {
+ self.handler
+ }
+
+ fn write_handshake(&mut self) -> Result<()> {
+ if let Connecting(ref mut req, ref mut res) = self.state {
+ match self.endpoint {
+ Server => {
+ let mut done = false;
+ if self.socket.try_write_buf(res)?.is_some() {
+ if res.position() as usize == res.get_ref().len() {
+ done = true
+ }
+ }
+ if !done {
+ return Ok(());
+ }
+ }
+ Client(_) => {
+ if self.socket.try_write_buf(req)?.is_some() {
+ if req.position() as usize == req.get_ref().len() {
+ trace!(
+ "Finished writing handshake request to {}",
+ self.socket
+ .peer_addr()
+ .map(|addr| addr.to_string())
+ .unwrap_or_else(|_| "UNKNOWN".into())
+ );
+ self.events.insert(Ready::readable());
+ self.events.remove(Ready::writable());
+ }
+ }
+ return Ok(());
+ }
+ }
+ }
+
+ if let Connecting(ref req, ref res) = replace(&mut self.state, Open) {
+ trace!(
+ "Finished writing handshake response to {}",
+ self.peer_addr()
+ );
+
+ let request = match Request::parse(req.get_ref()) {
+ Ok(Some(req)) => req,
+ _ => {
+ // An error should already have been sent for the first time it failed to
+ // parse. We don't call disconnect here because `on_open` hasn't been called yet.
+ self.state = FinishedClose;
+ self.events = Ready::empty();
+ return Ok(());
+ }
+ };
+
+ let response = Response::parse(res.get_ref())?.ok_or_else(|| {
+ Error::new(
+ Kind::Internal,
+ "Failed to parse response after handshake is complete.",
+ )
+ })?;
+
+ if response.status() != 101 {
+ self.events = Ready::empty();
+ return Ok(());
+ } else {
+ self.handler.on_open(Handshake {
+ request,
+ response,
+ peer_addr: self.socket.peer_addr().ok(),
+ local_addr: self.socket.local_addr().ok(),
+ })?;
+ debug!("Connection to {} is now open.", self.peer_addr());
+ self.events.insert(Ready::readable());
+ self.check_events();
+ return Ok(());
+ }
+ } else {
+ Err(Error::new(
+ Kind::Internal,
+ "Tried to write WebSocket handshake while not in connecting state!",
+ ))
+ }
+ }
+
+ fn read_handshake(&mut self) -> Result<()> {
+ if let Connecting(ref mut req, ref mut res) = self.state {
+ match self.endpoint {
+ Server => {
+ if let Some(read) = self.socket.try_read_buf(req.get_mut())? {
+ if read == 0 {
+ self.events = Ready::empty();
+ return Ok(());
+ }
+ if let Some(ref request) = Request::parse(req.get_ref())? {
+ trace!("Handshake request received: \n{}", request);
+ let response = self.handler.on_request(request)?;
+ response.format(res.get_mut())?;
+ self.events.remove(Ready::readable());
+ self.events.insert(Ready::writable());
+ }
+ }
+ return Ok(());
+ }
+ Client(_) => {
+ if self.socket.try_read_buf(res.get_mut())?.is_some() {
+ // TODO: see if this can be optimized with drain
+ let end = {
+ let data = res.get_ref();
+ let end = data.iter()
+ .enumerate()
+ .take_while(|&(ind, _)| !data[..ind].ends_with(b"\r\n\r\n"))
+ .count();
+ if !data[..end].ends_with(b"\r\n\r\n") {
+ return Ok(());
+ }
+ self.in_buffer.get_mut().extend(&data[end..]);
+ end
+ };
+ res.get_mut().truncate(end);
+ } else {
+ // NOTE: wait to be polled again; response not ready.
+ return Ok(());
+ }
+ }
+ }
+ }
+
+ if let Connecting(ref req, ref res) = replace(&mut self.state, Open) {
+ trace!(
+ "Finished reading handshake response from {}",
+ self.peer_addr()
+ );
+
+ let request = Request::parse(req.get_ref())?.ok_or_else(|| {
+ Error::new(
+ Kind::Internal,
+ "Failed to parse request after handshake is complete.",
+ )
+ })?;
+
+ let response = Response::parse(res.get_ref())?.ok_or_else(|| {
+ Error::new(
+ Kind::Internal,
+ "Failed to parse response after handshake is complete.",
+ )
+ })?;
+
+ trace!("Handshake response received: \n{}", response);
+
+ if response.status() != 101 {
+ if response.status() != 301 && response.status() != 302 {
+ return Err(Error::new(Kind::Protocol, "Handshake failed."));
+ } else {
+ return Ok(());
+ }
+ }
+
+ if self.settings.key_strict {
+ let req_key = request.hashed_key()?;
+ let res_key = from_utf8(response.key()?)?;
+ if req_key != res_key {
+ return Err(Error::new(
+ Kind::Protocol,
+ format!(
+ "Received incorrect WebSocket Accept key: {} vs {}",
+ req_key, res_key
+ ),
+ ));
+ }
+ }
+
+ self.handler.on_response(&response)?;
+ self.handler.on_open(Handshake {
+ request,
+ response,
+ peer_addr: self.socket.peer_addr().ok(),
+ local_addr: self.socket.local_addr().ok(),
+ })?;
+
+ // check to see if there is anything to read already
+ if !self.in_buffer.get_ref().is_empty() {
+ self.read_frames()?;
+ }
+
+ self.check_events();
+ return Ok(());
+ }
+ Err(Error::new(
+ Kind::Internal,
+ "Tried to read WebSocket handshake while not in connecting state!",
+ ))
+ }
+
+ pub fn read(&mut self) -> Result<()> {
+ if self.socket.is_negotiating() {
+ trace!("Performing TLS negotiation on {}.", self.peer_addr());
+ self.socket.clear_negotiating()?;
+ self.write()
+ } else {
+ let res = if self.state.is_connecting() {
+ trace!("Ready to read handshake from {}.", self.peer_addr());
+ self.read_handshake()
+ } else {
+ trace!("Ready to read messages from {}.", self.peer_addr());
+ while let Some(len) = self.buffer_in()? {
+ self.read_frames()?;
+ if len == 0 {
+ if self.events.is_writable() {
+ self.events.remove(Ready::readable());
+ } else {
+ self.disconnect()
+ }
+ break;
+ }
+ }
+ Ok(())
+ };
+
+ if self.socket.is_negotiating() && res.is_ok() {
+ self.events.remove(Ready::readable());
+ self.events.insert(Ready::writable());
+ }
+ res
+ }
+ }
+
+ fn read_frames(&mut self) -> Result<()> {
+ let max_size = self.settings.max_fragment_size as u64;
+ while let Some(mut frame) = Frame::parse(&mut self.in_buffer, max_size)? {
+ match self.state {
+ // Ignore data received after receiving close frame
+ RespondingClose | FinishedClose => continue,
+ _ => (),
+ }
+
+ if self.settings.masking_strict {
+ if frame.is_masked() {
+ if self.is_client() {
+ return Err(Error::new(
+ Kind::Protocol,
+ "Received masked frame from a server endpoint.",
+ ));
+ }
+ } else {
+ if self.is_server() {
+ return Err(Error::new(
+ Kind::Protocol,
+ "Received unmasked frame from a client endpoint.",
+ ));
+ }
+ }
+ }
+
+ // This is safe whether or not a frame is masked.
+ frame.remove_mask();
+
+ if let Some(frame) = self.handler.on_frame(frame)? {
+ if frame.is_final() {
+ match frame.opcode() {
+ // singleton data frames
+ OpCode::Text => {
+ trace!("Received text frame {:?}", frame);
+ // since we are going to handle this, there can't be an ongoing
+ // message
+ if !self.fragments.is_empty() {
+ return Err(Error::new(Kind::Protocol, "Received unfragmented text frame while processing fragmented message."));
+ }
+ let msg = Message::text(String::from_utf8(frame.into_data())
+ .map_err(|err| err.utf8_error())?);
+ self.handler.on_message(msg)?;
+ }
+ OpCode::Binary => {
+ trace!("Received binary frame {:?}", frame);
+ // since we are going to handle this, there can't be an ongoing
+ // message
+ if !self.fragments.is_empty() {
+ return Err(Error::new(Kind::Protocol, "Received unfragmented binary frame while processing fragmented message."));
+ }
+ let data = frame.into_data();
+ self.handler.on_message(Message::binary(data))?;
+ }
+ // control frames
+ OpCode::Close => {
+ trace!("Received close frame {:?}", frame);
+ // Closing handshake
+ if self.state.is_closing() {
+ if self.is_server() {
+ // Finished handshake, disconnect server side
+ self.events = Ready::empty()
+ } else {
+ // We are a client, so we wait for the server to close the
+ // connection
+ }
+ } else {
+ // Starting handshake, will send the responding close frame
+ self.state = RespondingClose;
+ }
+
+ let mut close_code = [0u8; 2];
+ let mut data = Cursor::new(frame.into_data());
+ if let 2 = data.read(&mut close_code)? {
+ let raw_code: u16 =
+ (u16::from(close_code[0]) << 8) | (u16::from(close_code[1]));
+ trace!(
+ "Connection to {} received raw close code: {:?}, {:?}",
+ self.peer_addr(),
+ raw_code,
+ close_code
+ );
+ let named = CloseCode::from(raw_code);
+ if let CloseCode::Other(code) = named {
+ if code < 1000 ||
+ code >= 5000 ||
+ code == 1004 ||
+ code == 1014 ||
+ code == 1016 || // these below are here to pass the autobahn test suite
+ code == 1100 || // we shouldn't need them later
+ code == 2000
+ || code == 2999
+ {
+ return Err(Error::new(
+ Kind::Protocol,
+ format!(
+ "Received invalid close code from endpoint: {}",
+ code
+ ),
+ ));
+ }
+ }
+ let has_reason = {
+ if let Ok(reason) = from_utf8(&data.get_ref()[2..]) {
+ self.handler.on_close(named, reason); // note reason may be an empty string
+ true
+ } else {
+ self.handler.on_close(named, "");
+ false
+ }
+ };
+
+ if let CloseCode::Abnormal = named {
+ return Err(Error::new(
+ Kind::Protocol,
+ "Received abnormal close code from endpoint.",
+ ));
+ } else if let CloseCode::Status = named {
+ return Err(Error::new(
+ Kind::Protocol,
+ "Received no status close code from endpoint.",
+ ));
+ } else if let CloseCode::Restart = named {
+ return Err(Error::new(
+ Kind::Protocol,
+ "Restart close code is not supported.",
+ ));
+ } else if let CloseCode::Again = named {
+ return Err(Error::new(
+ Kind::Protocol,
+ "Try again later close code is not supported.",
+ ));
+ } else if let CloseCode::Tls = named {
+ return Err(Error::new(
+ Kind::Protocol,
+ "Received TLS close code outside of TLS handshake.",
+ ));
+ } else {
+ if !self.state.is_closing() {
+ if has_reason {
+ self.send_close(named, "")?; // note this drops any extra close data
+ } else {
+ self.send_close(CloseCode::Invalid, "")?;
+ }
+ } else {
+ self.state = FinishedClose;
+ }
+ }
+ } else {
+ // This is not an error. It is allowed behavior in the
+ // protocol, so we don't trigger an error.
+ // "If there is no such data in the Close control frame,
+ // _The WebSocket Connection Close Reason_ is the empty string."
+ self.handler.on_close(CloseCode::Status, "");
+ if !self.state.is_closing() {
+ self.send_close(CloseCode::Empty, "")?;
+ } else {
+ self.state = FinishedClose;
+ }
+ }
+ }
+ OpCode::Ping => {
+ trace!("Received ping frame {:?}", frame);
+ self.send_pong(frame.into_data())?;
+ }
+ OpCode::Pong => {
+ trace!("Received pong frame {:?}", frame);
+ // no ping validation for now
+ }
+ // last fragment
+ OpCode::Continue => {
+ trace!("Received final fragment {:?}", frame);
+ if let Some(first) = self.fragments.pop_front() {
+ let size = self.fragments.iter().fold(
+ first.payload().len() + frame.payload().len(),
+ |len, frame| len + frame.payload().len(),
+ );
+ match first.opcode() {
+ OpCode::Text => {
+ trace!("Constructing text message from fragments: {:?} -> {:?} -> {:?}", first, self.fragments.iter().collect::<Vec<&Frame>>(), frame);
+ let mut data = Vec::with_capacity(size);
+ data.extend(first.into_data());
+ while let Some(frame) = self.fragments.pop_front() {
+ data.extend(frame.into_data());
+ }
+ data.extend(frame.into_data());
+
+ let string = String::from_utf8(data)
+ .map_err(|err| err.utf8_error())?;
+
+ trace!(
+ "Calling handler with constructed message: {:?}",
+ string
+ );
+ self.handler.on_message(Message::text(string))?;
+ }
+ OpCode::Binary => {
+ trace!("Constructing binary message from fragments: {:?} -> {:?} -> {:?}", first, self.fragments.iter().collect::<Vec<&Frame>>(), frame);
+ let mut data = Vec::with_capacity(size);
+ data.extend(first.into_data());
+
+ while let Some(frame) = self.fragments.pop_front() {
+ data.extend(frame.into_data());
+ }
+
+ data.extend(frame.into_data());
+
+ trace!(
+ "Calling handler with constructed message: {:?}",
+ data
+ );
+ self.handler.on_message(Message::binary(data))?;
+ }
+ _ => {
+ return Err(Error::new(
+ Kind::Protocol,
+ "Encounted fragmented control frame.",
+ ))
+ }
+ }
+ } else {
+ return Err(Error::new(
+ Kind::Protocol,
+ "Unable to reconstruct fragmented message. No first frame.",
+ ));
+ }
+ }
+ _ => return Err(Error::new(Kind::Protocol, "Encountered invalid opcode.")),
+ }
+ } else {
+ if frame.is_control() {
+ return Err(Error::new(
+ Kind::Protocol,
+ "Encounted fragmented control frame.",
+ ));
+ } else {
+ trace!("Received non-final fragment frame {:?}", frame);
+ if !self.settings.fragments_grow
+ && self.settings.fragments_capacity == self.fragments.len()
+ {
+ return Err(Error::new(Kind::Capacity, "Exceeded max fragments."));
+ } else {
+ self.fragments.push_back(frame)
+ }
+ }
+ }
+ }
+ }
+ Ok(())
+ }
+
+ pub fn write(&mut self) -> Result<()> {
+ if self.socket.is_negotiating() {
+ trace!("Performing TLS negotiation on {}.", self.peer_addr());
+ self.socket.clear_negotiating()?;
+ self.read()
+ } else {
+ let res = if self.state.is_connecting() {
+ trace!("Ready to write handshake to {}.", self.peer_addr());
+ self.write_handshake()
+ } else {
+ trace!("Ready to write messages to {}.", self.peer_addr());
+
+ // Start out assuming that this write will clear the whole buffer
+ self.events.remove(Ready::writable());
+
+ if let Some(len) = self.socket.try_write_buf(&mut self.out_buffer)? {
+ trace!("Wrote {} bytes to {}", len, self.peer_addr());
+ let finished = len == 0
+ || self.out_buffer.position() == self.out_buffer.get_ref().len() as u64;
+ if finished {
+ match self.state {
+ // we are are a server that is closing and just wrote out our confirming
+ // close frame, let's disconnect
+ FinishedClose if self.is_server() => {
+ self.events = Ready::empty();
+ return Ok(());
+ }
+ _ => (),
+ }
+ }
+ }
+
+ // Check if there is more to write so that the connection will be rescheduled
+ self.check_events();
+ Ok(())
+ };
+
+ if self.socket.is_negotiating() && res.is_ok() {
+ self.events.remove(Ready::writable());
+ self.events.insert(Ready::readable());
+ }
+ res
+ }
+ }
+
+ pub fn send_message(&mut self, msg: Message) -> Result<()> {
+ if self.state.is_closing() {
+ trace!(
+ "Connection is closing. Ignoring request to send message {:?} to {}.",
+ msg,
+ self.peer_addr()
+ );
+ return Ok(());
+ }
+
+ let opcode = msg.opcode();
+ trace!("Message opcode {:?}", opcode);
+ let data = msg.into_data();
+
+ if let Some(frame) = self.handler
+ .on_send_frame(Frame::message(data, opcode, true))?
+ {
+ if frame.payload().len() > self.settings.fragment_size {
+ trace!("Chunking at {:?}.", self.settings.fragment_size);
+ // note this copies the data, so it's actually somewhat expensive to fragment
+ let mut chunks = frame
+ .payload()
+ .chunks(self.settings.fragment_size)
+ .peekable();
+ let chunk = chunks.next().expect("Unable to get initial chunk!");
+
+ let mut first = Frame::message(Vec::from(chunk), opcode, false);
+
+ // Match reserved bits from original to keep extension status intact
+ first.set_rsv1(frame.has_rsv1());
+ first.set_rsv2(frame.has_rsv2());
+ first.set_rsv3(frame.has_rsv3());
+
+ self.buffer_frame(first)?;
+
+ while let Some(chunk) = chunks.next() {
+ if chunks.peek().is_some() {
+ self.buffer_frame(Frame::message(
+ Vec::from(chunk),
+ OpCode::Continue,
+ false,
+ ))?;
+ } else {
+ self.buffer_frame(Frame::message(
+ Vec::from(chunk),
+ OpCode::Continue,
+ true,
+ ))?;
+ }
+ }
+ } else {
+ trace!("Sending unfragmented message frame.");
+ // true means that the message is done
+ self.buffer_frame(frame)?;
+ }
+ }
+ self.check_events();
+ Ok(())
+ }
+
+ #[inline]
+ pub fn send_ping(&mut self, data: Vec<u8>) -> Result<()> {
+ if self.state.is_closing() {
+ trace!(
+ "Connection is closing. Ignoring request to send ping {:?} to {}.",
+ data,
+ self.peer_addr()
+ );
+ return Ok(());
+ }
+ trace!("Sending ping to {}.", self.peer_addr());
+
+ if let Some(frame) = self.handler.on_send_frame(Frame::ping(data))? {
+ self.buffer_frame(frame)?;
+ }
+ self.check_events();
+ Ok(())
+ }
+
+ #[inline]
+ pub fn send_pong(&mut self, data: Vec<u8>) -> Result<()> {
+ if self.state.is_closing() {
+ trace!(
+ "Connection is closing. Ignoring request to send pong {:?} to {}.",
+ data,
+ self.peer_addr()
+ );
+ return Ok(());
+ }
+ trace!("Sending pong to {}.", self.peer_addr());
+
+ if let Some(frame) = self.handler.on_send_frame(Frame::pong(data))? {
+ self.buffer_frame(frame)?;
+ }
+ self.check_events();
+ Ok(())
+ }
+
+ #[inline]
+ pub fn send_close<R>(&mut self, code: CloseCode, reason: R) -> Result<()>
+ where
+ R: Borrow<str>,
+ {
+ match self.state {
+ // We are responding to a close frame the other endpoint, when this frame goes out, we
+ // are done.
+ RespondingClose => self.state = FinishedClose,
+ // Multiple close frames are being sent from our end, ignore the later frames
+ AwaitingClose | FinishedClose => {
+ trace!(
+ "Connection is already closing. Ignoring close {:?} -- {:?} to {}.",
+ code,
+ reason.borrow(),
+ self.peer_addr()
+ );
+ self.check_events();
+ return Ok(());
+ }
+ // We are initiating a closing handshake.
+ Open => self.state = AwaitingClose,
+ Connecting(_, _) => {
+ debug_assert!(false, "Attempted to close connection while not yet open.")
+ }
+ }
+
+ trace!(
+ "Sending close {:?} -- {:?} to {}.",
+ code,
+ reason.borrow(),
+ self.peer_addr()
+ );
+
+ if let Some(frame) = self.handler
+ .on_send_frame(Frame::close(code, reason.borrow()))?
+ {
+ self.buffer_frame(frame)?;
+ }
+
+ trace!("Connection to {} is now closing.", self.peer_addr());
+
+ self.check_events();
+ Ok(())
+ }
+
+ fn check_events(&mut self) {
+ if !self.state.is_connecting() {
+ self.events.insert(Ready::readable());
+ if self.out_buffer.position() < self.out_buffer.get_ref().len() as u64 {
+ self.events.insert(Ready::writable());
+ }
+ }
+ }
+
+ fn buffer_frame(&mut self, mut frame: Frame) -> Result<()> {
+ self.check_buffer_out(&frame)?;
+
+ if self.is_client() {
+ frame.set_mask();
+ }
+
+ trace!("Buffering frame to {}:\n{}", self.peer_addr(), frame);
+
+ let pos = self.out_buffer.position();
+ self.out_buffer.seek(SeekFrom::End(0))?;
+ frame.format(&mut self.out_buffer)?;
+ self.out_buffer.seek(SeekFrom::Start(pos))?;
+ Ok(())
+ }
+
+ fn check_buffer_out(&mut self, frame: &Frame) -> Result<()> {
+ if self.out_buffer.get_ref().capacity() <= self.out_buffer.get_ref().len() + frame.len() {
+ // extend
+ let mut new = Vec::with_capacity(self.out_buffer.get_ref().capacity());
+ new.extend(&self.out_buffer.get_ref()[self.out_buffer.position() as usize..]);
+ if new.len() == new.capacity() {
+ if self.settings.out_buffer_grow {
+ new.reserve(self.settings.out_buffer_capacity)
+ } else {
+ return Err(Error::new(
+ Kind::Capacity,
+ "Maxed out output buffer for connection.",
+ ));
+ }
+ }
+ self.out_buffer = Cursor::new(new);
+ }
+ Ok(())
+ }
+
+ fn buffer_in(&mut self) -> Result<Option<usize>> {
+ trace!("Reading buffer for connection to {}.", self.peer_addr());
+ if let Some(len) = self.socket.try_read_buf(self.in_buffer.get_mut())? {
+ trace!("Buffered {}.", len);
+ if self.in_buffer.get_ref().len() == self.in_buffer.get_ref().capacity() {
+ // extend
+ let mut new = Vec::with_capacity(self.in_buffer.get_ref().capacity());
+ new.extend(&self.in_buffer.get_ref()[self.in_buffer.position() as usize..]);
+ if new.len() == new.capacity() {
+ if self.settings.in_buffer_grow {
+ new.reserve(self.settings.in_buffer_capacity);
+ } else {
+ return Err(Error::new(
+ Kind::Capacity,
+ "Maxed out input buffer for connection.",
+ ));
+ }
+ }
+ self.in_buffer = Cursor::new(new);
+ }
+ Ok(Some(len))
+ } else {
+ Ok(None)
+ }
+ }
+}
diff --git a/third_party/rust/ws/src/deflate/context.rs b/third_party/rust/ws/src/deflate/context.rs
new file mode 100644
index 0000000000..2fa2e23056
--- /dev/null
+++ b/third_party/rust/ws/src/deflate/context.rs
@@ -0,0 +1,268 @@
+use std::mem;
+use std::slice;
+
+use super::ffi;
+use super::libc::{c_char, c_int, c_uint};
+
+use result::{Error, Kind, Result};
+
+const ZLIB_VERSION: &'static str = "1.2.8\0";
+
+trait Context {
+ fn stream(&mut self) -> &mut ffi::z_stream;
+
+ fn stream_apply<F>(&mut self, input: &[u8], output: &mut Vec<u8>, each: F) -> Result<()>
+ where
+ F: Fn(&mut ffi::z_stream) -> Option<Result<()>>,
+ {
+ debug_assert!(output.len() == 0, "Output vector is not empty.");
+
+ let stream = self.stream();
+
+ stream.next_in = input.as_ptr() as *mut _;
+ stream.avail_in = input.len() as c_uint;
+
+ let mut output_size;
+
+ loop {
+ output_size = output.len();
+
+ if output_size == output.capacity() {
+ output.reserve(input.len())
+ }
+
+ let out_slice = unsafe {
+ slice::from_raw_parts_mut(
+ output.as_mut_ptr().offset(output_size as isize),
+ output.capacity() - output_size,
+ )
+ };
+
+ stream.next_out = out_slice.as_mut_ptr();
+ stream.avail_out = out_slice.len() as c_uint;
+
+ let before = stream.total_out;
+ let cont = each(stream);
+
+ unsafe {
+ output.set_len((stream.total_out - before) as usize + output_size);
+ }
+
+ if let Some(result) = cont {
+ return result;
+ }
+ }
+ }
+}
+
+pub struct Compressor {
+ // Box the z_stream to ensure it isn't moved. Moving the z_stream
+ // causes zlib to fail, because it maintains internal pointers.
+ stream: Box<ffi::z_stream>,
+}
+
+impl Compressor {
+ pub fn new(window_bits: i8) -> Compressor {
+ debug_assert!(window_bits >= 9, "Received too small window size.");
+ debug_assert!(window_bits <= 15, "Received too large window size.");
+
+ unsafe {
+ let mut stream: Box<ffi::z_stream> = Box::new(mem::zeroed());
+ let result = ffi::deflateInit2_(
+ stream.as_mut(),
+ 9,
+ ffi::Z_DEFLATED,
+ -window_bits as c_int,
+ 9,
+ ffi::Z_DEFAULT_STRATEGY,
+ ZLIB_VERSION.as_ptr() as *const c_char,
+ mem::size_of::<ffi::z_stream>() as c_int,
+ );
+ assert!(result == ffi::Z_OK, "Failed to initialize compresser.");
+ Compressor { stream: stream }
+ }
+ }
+
+ pub fn compress(&mut self, input: &[u8], output: &mut Vec<u8>) -> Result<()> {
+ self.stream_apply(input, output, |stream| unsafe {
+ match ffi::deflate(stream, ffi::Z_SYNC_FLUSH) {
+ ffi::Z_OK | ffi::Z_BUF_ERROR => {
+ if stream.avail_in == 0 && stream.avail_out > 0 {
+ Some(Ok(()))
+ } else {
+ None
+ }
+ }
+ code => Some(Err(Error::new(
+ Kind::Protocol,
+ format!("Failed to perform compression: {}", code),
+ ))),
+ }
+ })
+ }
+
+ pub fn reset(&mut self) -> Result<()> {
+ match unsafe { ffi::deflateReset(self.stream.as_mut()) } {
+ ffi::Z_OK => Ok(()),
+ code => Err(Error::new(
+ Kind::Protocol,
+ format!("Failed to reset compression context: {}", code),
+ )),
+ }
+ }
+}
+
+impl Context for Compressor {
+ fn stream(&mut self) -> &mut ffi::z_stream {
+ self.stream.as_mut()
+ }
+}
+
+impl Drop for Compressor {
+ fn drop(&mut self) {
+ match unsafe { ffi::deflateEnd(self.stream.as_mut()) } {
+ ffi::Z_STREAM_ERROR => error!("Compression stream encountered bad state."),
+ // Ignore discarded data error because we are raw
+ ffi::Z_OK | ffi::Z_DATA_ERROR => trace!("Deallocated compression context."),
+ code => error!("Bad zlib status encountered: {}", code),
+ }
+ }
+}
+
+pub struct Decompressor {
+ stream: Box<ffi::z_stream>,
+}
+
+impl Decompressor {
+ pub fn new(window_bits: i8) -> Decompressor {
+ debug_assert!(window_bits >= 8, "Received too small window size.");
+ debug_assert!(window_bits <= 15, "Received too large window size.");
+
+ unsafe {
+ let mut stream: Box<ffi::z_stream> = Box::new(mem::zeroed());
+ let result = ffi::inflateInit2_(
+ stream.as_mut(),
+ -window_bits as c_int,
+ ZLIB_VERSION.as_ptr() as *const c_char,
+ mem::size_of::<ffi::z_stream>() as c_int,
+ );
+ assert!(result == ffi::Z_OK, "Failed to initialize decompresser.");
+ Decompressor { stream: stream }
+ }
+ }
+
+ pub fn decompress(&mut self, input: &[u8], output: &mut Vec<u8>) -> Result<()> {
+ self.stream_apply(input, output, |stream| unsafe {
+ match ffi::inflate(stream, ffi::Z_SYNC_FLUSH) {
+ ffi::Z_OK | ffi::Z_BUF_ERROR => {
+ if stream.avail_in == 0 && stream.avail_out > 0 {
+ Some(Ok(()))
+ } else {
+ None
+ }
+ }
+ code => Some(Err(Error::new(
+ Kind::Protocol,
+ format!("Failed to perform decompression: {}", code),
+ ))),
+ }
+ })
+ }
+
+ pub fn reset(&mut self) -> Result<()> {
+ match unsafe { ffi::inflateReset(self.stream.as_mut()) } {
+ ffi::Z_OK => Ok(()),
+ code => Err(Error::new(
+ Kind::Protocol,
+ format!("Failed to reset compression context: {}", code),
+ )),
+ }
+ }
+}
+
+impl Context for Decompressor {
+ fn stream(&mut self) -> &mut ffi::z_stream {
+ self.stream.as_mut()
+ }
+}
+
+impl Drop for Decompressor {
+ fn drop(&mut self) {
+ match unsafe { ffi::inflateEnd(self.stream.as_mut()) } {
+ ffi::Z_STREAM_ERROR => error!("Decompression stream encountered bad state."),
+ ffi::Z_OK => trace!("Deallocated decompression context."),
+ code => error!("Bad zlib status encountered: {}", code),
+ }
+ }
+}
+
+mod test {
+ #![allow(unused_imports, unused_variables, dead_code)]
+ use super::*;
+
+ fn as_hex(s: &[u8]) {
+ for byte in s {
+ print!("0x{:x} ", byte);
+ }
+ print!("\n");
+ }
+
+ #[test]
+ fn round_trip() {
+ for i in 9..16 {
+ let data = "HI THERE THIS IS some data. これはデータだよ。".as_bytes();
+ let mut compressed = Vec::with_capacity(data.len());
+ let mut decompressed = Vec::with_capacity(data.len());
+
+ let com = Compressor::new(i);
+ let mut moved_com = com;
+
+ moved_com
+ .compress(&data, &mut compressed)
+ .expect("Failed to compress data.");
+
+ let dec = Decompressor::new(i);
+ let mut moved_dec = dec;
+
+ moved_dec
+ .decompress(&compressed, &mut decompressed)
+ .expect("Failed to decompress data.");
+
+ assert_eq!(data, &decompressed[..]);
+ }
+ }
+
+ #[test]
+ fn reset() {
+ let data1 = "HI THERE 直子さん".as_bytes();
+ let data2 = "HI THERE 人太郎".as_bytes();
+ let mut compressed1 = Vec::with_capacity(data1.len());
+ let mut compressed2 = Vec::with_capacity(data2.len());
+ let mut compressed2_ind = Vec::with_capacity(data2.len());
+
+ let mut decompressed1 = Vec::with_capacity(data1.len());
+ let mut decompressed2 = Vec::with_capacity(data2.len());
+ let mut decompressed2_ind = Vec::with_capacity(data2.len());
+
+ let mut com = Compressor::new(9);
+
+ com.compress(&data1, &mut compressed1).unwrap();
+ com.compress(&data2, &mut compressed2).unwrap();
+ com.reset().unwrap();
+ com.compress(&data2, &mut compressed2_ind).unwrap();
+
+ let mut dec = Decompressor::new(9);
+
+ dec.decompress(&compressed1, &mut decompressed1).unwrap();
+ dec.decompress(&compressed2, &mut decompressed2).unwrap();
+ dec.reset().unwrap();
+ dec.decompress(&compressed2_ind, &mut decompressed2_ind)
+ .unwrap();
+
+ assert_eq!(data1, &decompressed1[..]);
+ assert_eq!(data2, &decompressed2[..]);
+ assert_eq!(data2, &decompressed2_ind[..]);
+ assert!(compressed2 != compressed2_ind);
+ assert!(compressed2.len() < compressed2_ind.len());
+ }
+}
diff --git a/third_party/rust/ws/src/deflate/extension.rs b/third_party/rust/ws/src/deflate/extension.rs
new file mode 100644
index 0000000000..712e11fb8e
--- /dev/null
+++ b/third_party/rust/ws/src/deflate/extension.rs
@@ -0,0 +1,565 @@
+use std::mem::replace;
+
+#[cfg(feature = "ssl")]
+use openssl::ssl::SslStream;
+#[cfg(feature = "nativetls")]
+use native_tls::TlsStream as SslStream;
+use url;
+
+use frame::Frame;
+use handler::Handler;
+use handshake::{Handshake, Request, Response};
+use message::Message;
+use protocol::{CloseCode, OpCode};
+use result::{Error, Kind, Result};
+#[cfg(any(feature = "ssl", feature = "nativetls"))]
+use util::TcpStream;
+use util::{Timeout, Token};
+
+use super::context::{Compressor, Decompressor};
+
+/// Deflate Extension Handler Settings
+#[derive(Debug, Clone, Copy)]
+pub struct DeflateSettings {
+ /// The max size of the sliding window. If the other endpoint selects a smaller size, that size
+ /// will be used instead. This must be an integer between 9 and 15 inclusive.
+ /// Default: 15
+ pub max_window_bits: u8,
+ /// Indicates whether to ask the other endpoint to reset the sliding window for each message.
+ /// Default: false
+ pub request_no_context_takeover: bool,
+ /// Indicates whether this endpoint will agree to reset the sliding window for each message it
+ /// compresses. If this endpoint won't agree to reset the sliding window, then the handshake
+ /// will fail if this endpoint is a client and the server requests no context takeover.
+ /// Default: true
+ pub accept_no_context_takeover: bool,
+ /// The number of WebSocket frames to store when defragmenting an incoming fragmented
+ /// compressed message.
+ /// This setting may be different from the `fragments_capacity` setting of the WebSocket in order to
+ /// allow for differences between compressed and uncompressed messages.
+ /// Default: 10
+ pub fragments_capacity: usize,
+ /// Indicates whether the extension handler will reallocate if the `fragments_capacity` is
+ /// exceeded. If this is not true, a capacity error will be triggered instead.
+ /// Default: true
+ pub fragments_grow: bool,
+}
+
+impl Default for DeflateSettings {
+ fn default() -> DeflateSettings {
+ DeflateSettings {
+ max_window_bits: 15,
+ request_no_context_takeover: false,
+ accept_no_context_takeover: true,
+ fragments_capacity: 10,
+ fragments_grow: true,
+ }
+ }
+}
+
+/// Utility for applying the permessage-deflate extension to a handler with particular deflate
+/// settings.
+#[derive(Debug, Clone, Copy)]
+pub struct DeflateBuilder {
+ settings: DeflateSettings,
+}
+
+impl DeflateBuilder {
+ /// Create a new DeflateBuilder with the default settings.
+ pub fn new() -> DeflateBuilder {
+ DeflateBuilder {
+ settings: DeflateSettings::default(),
+ }
+ }
+
+ /// Configure the DeflateBuilder with the given deflate settings.
+ pub fn with_settings(&mut self, settings: DeflateSettings) -> &mut DeflateBuilder {
+ self.settings = settings;
+ self
+ }
+
+ /// Wrap another handler in with a deflate handler as configured.
+ pub fn build<H: Handler>(&self, handler: H) -> DeflateHandler<H> {
+ DeflateHandler {
+ com: Compressor::new(self.settings.max_window_bits as i8),
+ dec: Decompressor::new(self.settings.max_window_bits as i8),
+ fragments: Vec::with_capacity(self.settings.fragments_capacity),
+ compress_reset: false,
+ decompress_reset: false,
+ pass: false,
+ settings: self.settings,
+ inner: handler,
+ }
+ }
+}
+
+/// A WebSocket handler that implements the permessage-deflate extension.
+///
+/// This handler wraps a child handler and proxies all handler methods to it. The handler will
+/// decompress incoming WebSocket message frames in their reserved bits match the
+/// permessage-deflate specification and pass them to the child handler. Message frames sent from
+/// the child handler will be compressed and sent to the other endpoint using deflate compression.
+pub struct DeflateHandler<H: Handler> {
+ com: Compressor,
+ dec: Decompressor,
+ fragments: Vec<Frame>,
+ compress_reset: bool,
+ decompress_reset: bool,
+ pass: bool,
+ settings: DeflateSettings,
+ inner: H,
+}
+
+impl<H: Handler> DeflateHandler<H> {
+ /// Wrap a child handler to provide the permessage-deflate extension.
+ pub fn new(handler: H) -> DeflateHandler<H> {
+ trace!("Using permessage-deflate handler.");
+ let settings = DeflateSettings::default();
+ DeflateHandler {
+ com: Compressor::new(settings.max_window_bits as i8),
+ dec: Decompressor::new(settings.max_window_bits as i8),
+ fragments: Vec::with_capacity(settings.fragments_capacity),
+ compress_reset: false,
+ decompress_reset: false,
+ pass: false,
+ settings: settings,
+ inner: handler,
+ }
+ }
+
+ #[doc(hidden)]
+ #[inline]
+ fn decline(&mut self, mut res: Response) -> Result<Response> {
+ trace!("Declined permessage-deflate offer");
+ self.pass = true;
+ res.remove_extension("permessage-deflate");
+ Ok(res)
+ }
+}
+
+impl<H: Handler> Handler for DeflateHandler<H> {
+ fn build_request(&mut self, url: &url::Url) -> Result<Request> {
+ let mut req = self.inner.build_request(url)?;
+ let mut req_ext = String::with_capacity(100);
+ req_ext.push_str("permessage-deflate");
+ if self.settings.max_window_bits < 15 {
+ req_ext.push_str(&format!(
+ "; client_max_window_bits={}; server_max_window_bits={}",
+ self.settings.max_window_bits, self.settings.max_window_bits
+ ))
+ } else {
+ req_ext.push_str("; client_max_window_bits")
+ }
+ if self.settings.request_no_context_takeover {
+ req_ext.push_str("; server_no_context_takeover")
+ }
+ req.add_extension(&req_ext);
+ Ok(req)
+ }
+
+ fn on_request(&mut self, req: &Request) -> Result<Response> {
+ let mut res = self.inner.on_request(req)?;
+
+ 'ext: for req_ext in req.extensions()?
+ .iter()
+ .filter(|&&ext| ext.contains("permessage-deflate"))
+ {
+ let mut res_ext = String::with_capacity(req_ext.len());
+ let mut s_takeover = false;
+ let mut c_takeover = false;
+ let mut s_max = false;
+ let mut c_max = false;
+
+ for param in req_ext.split(';') {
+ match param.trim() {
+ "permessage-deflate" => res_ext.push_str("permessage-deflate"),
+ "server_no_context_takeover" => {
+ if s_takeover {
+ return self.decline(res);
+ } else {
+ s_takeover = true;
+ if self.settings.accept_no_context_takeover {
+ self.compress_reset = true;
+ res_ext.push_str("; server_no_context_takeover");
+ } else {
+ continue 'ext;
+ }
+ }
+ }
+ "client_no_context_takeover" => {
+ if c_takeover {
+ return self.decline(res);
+ } else {
+ c_takeover = true;
+ self.decompress_reset = true;
+ res_ext.push_str("; client_no_context_takeover");
+ }
+ }
+ param if param.starts_with("server_max_window_bits") => {
+ if s_max {
+ return self.decline(res);
+ } else {
+ s_max = true;
+ let mut param_iter = param.split('=');
+ param_iter.next(); // we already know the name
+ if let Some(window_bits_str) = param_iter.next() {
+ if let Ok(window_bits) = window_bits_str.trim().parse() {
+ if window_bits >= 9 && window_bits <= 15 {
+ if window_bits < self.settings.max_window_bits as i8 {
+ self.com = Compressor::new(window_bits);
+ res_ext.push_str("; ");
+ res_ext.push_str(param)
+ }
+ } else {
+ return self.decline(res);
+ }
+ } else {
+ return self.decline(res);
+ }
+ }
+ }
+ }
+ param if param.starts_with("client_max_window_bits") => {
+ if c_max {
+ return self.decline(res);
+ } else {
+ c_max = true;
+ let mut param_iter = param.split('=');
+ param_iter.next(); // we already know the name
+ if let Some(window_bits_str) = param_iter.next() {
+ if let Ok(window_bits) = window_bits_str.trim().parse() {
+ if window_bits >= 9 && window_bits <= 15 {
+ if window_bits < self.settings.max_window_bits as i8 {
+ self.dec = Decompressor::new(window_bits);
+ res_ext.push_str("; ");
+ res_ext.push_str(param);
+ continue;
+ }
+ } else {
+ return self.decline(res);
+ }
+ } else {
+ return self.decline(res);
+ }
+ }
+ res_ext.push_str("; ");
+ res_ext.push_str(&format!(
+ "client_max_window_bits={}",
+ self.settings.max_window_bits
+ ))
+ }
+ }
+ _ => {
+ // decline all extension offers because we got a bad parameter
+ return self.decline(res);
+ }
+ }
+ }
+
+ if !res_ext.contains("client_no_context_takeover")
+ && self.settings.request_no_context_takeover
+ {
+ self.decompress_reset = true;
+ res_ext.push_str("; client_no_context_takeover");
+ }
+
+ if !res_ext.contains("server_max_window_bits") {
+ res_ext.push_str("; ");
+ res_ext.push_str(&format!(
+ "server_max_window_bits={}",
+ self.settings.max_window_bits
+ ))
+ }
+
+ if !res_ext.contains("client_max_window_bits") && self.settings.max_window_bits < 15 {
+ continue;
+ }
+
+ res.add_extension(&res_ext);
+ return Ok(res);
+ }
+ self.decline(res)
+ }
+
+ fn on_response(&mut self, res: &Response) -> Result<()> {
+ if let Some(res_ext) = res.extensions()?
+ .iter()
+ .find(|&&ext| ext.contains("permessage-deflate"))
+ {
+ let mut name = false;
+ let mut s_takeover = false;
+ let mut c_takeover = false;
+ let mut s_max = false;
+ let mut c_max = false;
+
+ for param in res_ext.split(';') {
+ match param.trim() {
+ "permessage-deflate" => {
+ if name {
+ return Err(Error::new(
+ Kind::Protocol,
+ format!("Duplicate extension name permessage-deflate"),
+ ));
+ } else {
+ name = true;
+ }
+ }
+ "server_no_context_takeover" => {
+ if s_takeover {
+ return Err(Error::new(
+ Kind::Protocol,
+ format!("Duplicate extension parameter server_no_context_takeover"),
+ ));
+ } else {
+ s_takeover = true;
+ self.decompress_reset = true;
+ }
+ }
+ "client_no_context_takeover" => {
+ if c_takeover {
+ return Err(Error::new(
+ Kind::Protocol,
+ format!("Duplicate extension parameter client_no_context_takeover"),
+ ));
+ } else {
+ c_takeover = true;
+ if self.settings.accept_no_context_takeover {
+ self.compress_reset = true;
+ } else {
+ return Err(Error::new(
+ Kind::Protocol,
+ format!("The client requires context takeover."),
+ ));
+ }
+ }
+ }
+ param if param.starts_with("server_max_window_bits") => {
+ if s_max {
+ return Err(Error::new(
+ Kind::Protocol,
+ format!("Duplicate extension parameter server_max_window_bits"),
+ ));
+ } else {
+ s_max = true;
+ let mut param_iter = param.split('=');
+ param_iter.next(); // we already know the name
+ if let Some(window_bits_str) = param_iter.next() {
+ if let Ok(window_bits) = window_bits_str.trim().parse() {
+ if window_bits >= 9 && window_bits <= 15 {
+ if window_bits as u8 != self.settings.max_window_bits {
+ self.dec = Decompressor::new(window_bits);
+ }
+ } else {
+ return Err(Error::new(
+ Kind::Protocol,
+ format!(
+ "Invalid server_max_window_bits parameter: {}",
+ window_bits
+ ),
+ ));
+ }
+ } else {
+ return Err(Error::new(
+ Kind::Protocol,
+ format!(
+ "Invalid server_max_window_bits parameter: {}",
+ window_bits_str
+ ),
+ ));
+ }
+ }
+ }
+ }
+ param if param.starts_with("client_max_window_bits") => {
+ if c_max {
+ return Err(Error::new(
+ Kind::Protocol,
+ format!("Duplicate extension parameter client_max_window_bits"),
+ ));
+ } else {
+ c_max = true;
+ let mut param_iter = param.split('=');
+ param_iter.next(); // we already know the name
+ if let Some(window_bits_str) = param_iter.next() {
+ if let Ok(window_bits) = window_bits_str.trim().parse() {
+ if window_bits >= 9 && window_bits <= 15 {
+ if window_bits as u8 != self.settings.max_window_bits {
+ self.com = Compressor::new(window_bits);
+ }
+ } else {
+ return Err(Error::new(
+ Kind::Protocol,
+ format!(
+ "Invalid client_max_window_bits parameter: {}",
+ window_bits
+ ),
+ ));
+ }
+ } else {
+ return Err(Error::new(
+ Kind::Protocol,
+ format!(
+ "Invalid client_max_window_bits parameter: {}",
+ window_bits_str
+ ),
+ ));
+ }
+ }
+ }
+ }
+ param => {
+ // fail the connection because we got a bad parameter
+ return Err(Error::new(
+ Kind::Protocol,
+ format!("Bad extension parameter: {}", param),
+ ));
+ }
+ }
+ }
+ } else {
+ self.pass = true
+ }
+
+ Ok(())
+ }
+
+ fn on_frame(&mut self, mut frame: Frame) -> Result<Option<Frame>> {
+ if !self.pass && !frame.is_control() {
+ if !self.fragments.is_empty() || frame.has_rsv1() {
+ frame.set_rsv1(false);
+
+ if !frame.is_final() {
+ self.fragments.push(frame);
+ return Ok(None);
+ } else {
+ if frame.opcode() == OpCode::Continue {
+ if self.fragments.is_empty() {
+ return Err(Error::new(
+ Kind::Protocol,
+ "Unable to reconstruct fragmented message. No first frame.",
+ ));
+ } else {
+ if !self.settings.fragments_grow
+ && self.settings.fragments_capacity == self.fragments.len()
+ {
+ return Err(Error::new(Kind::Capacity, "Exceeded max fragments."));
+ } else {
+ self.fragments.push(frame);
+ }
+
+ // it's safe to unwrap because of the above check for empty
+ let opcode = self.fragments.first().unwrap().opcode();
+ let size = self.fragments
+ .iter()
+ .fold(0, |len, frame| len + frame.payload().len());
+ let mut compressed = Vec::with_capacity(size);
+ let mut decompressed = Vec::with_capacity(size * 2);
+ for frag in replace(
+ &mut self.fragments,
+ Vec::with_capacity(self.settings.fragments_capacity),
+ ) {
+ compressed.extend(frag.into_data())
+ }
+
+ compressed.extend(&[0, 0, 255, 255]);
+ self.dec.decompress(&compressed, &mut decompressed)?;
+ frame = Frame::message(decompressed, opcode, true);
+ }
+ } else {
+ let mut decompressed = Vec::with_capacity(frame.payload().len() * 2);
+ frame.payload_mut().extend(&[0, 0, 255, 255]);
+
+ self.dec.decompress(frame.payload(), &mut decompressed)?;
+
+ *frame.payload_mut() = decompressed;
+ }
+
+ if self.decompress_reset {
+ self.dec.reset()?
+ }
+ }
+ }
+ }
+ self.inner.on_frame(frame)
+ }
+
+ fn on_send_frame(&mut self, frame: Frame) -> Result<Option<Frame>> {
+ if let Some(mut frame) = self.inner.on_send_frame(frame)? {
+ if !self.pass && !frame.is_control() {
+ debug_assert!(
+ frame.is_final(),
+ "Received non-final frame from upstream handler!"
+ );
+ debug_assert!(
+ frame.opcode() != OpCode::Continue,
+ "Received continue frame from upstream handler!"
+ );
+
+ frame.set_rsv1(true);
+ let mut compressed = Vec::with_capacity(frame.payload().len());
+ self.com.compress(frame.payload(), &mut compressed)?;
+ let len = compressed.len();
+ compressed.truncate(len - 4);
+ *frame.payload_mut() = compressed;
+
+ if self.compress_reset {
+ self.com.reset()?
+ }
+ }
+ Ok(Some(frame))
+ } else {
+ Ok(None)
+ }
+ }
+
+ #[inline]
+ fn on_shutdown(&mut self) {
+ self.inner.on_shutdown()
+ }
+
+ #[inline]
+ fn on_open(&mut self, shake: Handshake) -> Result<()> {
+ self.inner.on_open(shake)
+ }
+
+ #[inline]
+ fn on_message(&mut self, msg: Message) -> Result<()> {
+ self.inner.on_message(msg)
+ }
+
+ #[inline]
+ fn on_close(&mut self, code: CloseCode, reason: &str) {
+ self.inner.on_close(code, reason)
+ }
+
+ #[inline]
+ fn on_error(&mut self, err: Error) {
+ self.inner.on_error(err)
+ }
+
+ #[inline]
+ fn on_timeout(&mut self, event: Token) -> Result<()> {
+ self.inner.on_timeout(event)
+ }
+
+ #[inline]
+ fn on_new_timeout(&mut self, tok: Token, timeout: Timeout) -> Result<()> {
+ self.inner.on_new_timeout(tok, timeout)
+ }
+
+ #[inline]
+ #[cfg(any(feature = "ssl", feature = "nativetls"))]
+ fn upgrade_ssl_client(
+ &mut self,
+ stream: TcpStream,
+ url: &url::Url,
+ ) -> Result<SslStream<TcpStream>> {
+ self.inner.upgrade_ssl_client(stream, url)
+ }
+
+ #[inline]
+ #[cfg(any(feature = "ssl", feature = "nativetls"))]
+ fn upgrade_ssl_server(&mut self, stream: TcpStream) -> Result<SslStream<TcpStream>> {
+ self.inner.upgrade_ssl_server(stream)
+ }
+}
diff --git a/third_party/rust/ws/src/deflate/mod.rs b/third_party/rust/ws/src/deflate/mod.rs
new file mode 100644
index 0000000000..8d79012e73
--- /dev/null
+++ b/third_party/rust/ws/src/deflate/mod.rs
@@ -0,0 +1,9 @@
+//! The deflate module provides tools for applying the permessage-deflate extension.
+
+extern crate libc;
+extern crate libz_sys as ffi;
+
+mod context;
+mod extension;
+
+pub use self::extension::{DeflateBuilder, DeflateHandler, DeflateSettings};
diff --git a/third_party/rust/ws/src/factory.rs b/third_party/rust/ws/src/factory.rs
new file mode 100644
index 0000000000..048ac78110
--- /dev/null
+++ b/third_party/rust/ws/src/factory.rs
@@ -0,0 +1,188 @@
+use communication::Sender;
+use handler::Handler;
+
+/// A trait for creating new WebSocket handlers.
+pub trait Factory {
+ type Handler: Handler;
+
+ /// Called when a TCP connection is made.
+ fn connection_made(&mut self, _: Sender) -> Self::Handler;
+
+ /// Called when the WebSocket is shutting down.
+ #[inline]
+ fn on_shutdown(&mut self) {
+ debug!("Factory received WebSocket shutdown request.");
+ }
+
+ /// Called when a new connection is established for a client endpoint.
+ /// This method can be used to differentiate a client aspect for a handler.
+ ///
+ /// ```
+ /// use ws::{Sender, Factory, Handler};
+ ///
+ /// struct MyHandler {
+ /// ws: Sender,
+ /// is_client: bool,
+ /// }
+ ///
+ /// impl Handler for MyHandler {}
+ ///
+ /// struct MyFactory;
+ ///
+ /// impl Factory for MyFactory {
+ /// type Handler = MyHandler;
+ ///
+ /// fn connection_made(&mut self, ws: Sender) -> MyHandler {
+ /// MyHandler {
+ /// ws: ws,
+ /// // default to server
+ /// is_client: false,
+ /// }
+ /// }
+ ///
+ /// fn client_connected(&mut self, ws: Sender) -> MyHandler {
+ /// MyHandler {
+ /// ws: ws,
+ /// is_client: true,
+ /// }
+ /// }
+ /// }
+ /// ```
+ #[inline]
+ fn client_connected(&mut self, ws: Sender) -> Self::Handler {
+ self.connection_made(ws)
+ }
+
+ /// Called when a new connection is established for a server endpoint.
+ /// This method can be used to differentiate a server aspect for a handler.
+ ///
+ /// ```
+ /// use ws::{Sender, Factory, Handler};
+ ///
+ /// struct MyHandler {
+ /// ws: Sender,
+ /// is_server: bool,
+ /// }
+ ///
+ /// impl Handler for MyHandler {}
+ ///
+ /// struct MyFactory;
+ ///
+ /// impl Factory for MyFactory {
+ /// type Handler = MyHandler;
+ ///
+ /// fn connection_made(&mut self, ws: Sender) -> MyHandler {
+ /// MyHandler {
+ /// ws: ws,
+ /// // default to client
+ /// is_server: false,
+ /// }
+ /// }
+ ///
+ /// fn server_connected(&mut self, ws: Sender) -> MyHandler {
+ /// MyHandler {
+ /// ws: ws,
+ /// is_server: true,
+ /// }
+ /// }
+ /// }
+ #[inline]
+ fn server_connected(&mut self, ws: Sender) -> Self::Handler {
+ self.connection_made(ws)
+ }
+
+ /// Called when a TCP connection is lost with the handler that was
+ /// setup for that connection.
+ ///
+ /// The default implementation is a noop that simply drops the handler.
+ /// You can use this to track connections being destroyed or to finalize
+ /// state that was not internally tracked by the handler.
+ #[inline]
+ fn connection_lost(&mut self, _: Self::Handler) {}
+}
+
+impl<F, H> Factory for F
+where
+ H: Handler,
+ F: FnMut(Sender) -> H,
+{
+ type Handler = H;
+
+ fn connection_made(&mut self, out: Sender) -> H {
+ self(out)
+ }
+}
+
+mod test {
+ #![allow(unused_imports, unused_variables, dead_code)]
+ use super::*;
+ use communication::{Command, Sender};
+ use frame;
+ use handler::Handler;
+ use handshake::{Handshake, Request, Response};
+ use message;
+ use mio;
+ use protocol::CloseCode;
+ use result::Result;
+
+ #[derive(Debug, Eq, PartialEq)]
+ struct M;
+ impl Handler for M {
+ fn on_message(&mut self, _: message::Message) -> Result<()> {
+ println!("test");
+ Ok(())
+ }
+
+ fn on_frame(&mut self, f: frame::Frame) -> Result<Option<frame::Frame>> {
+ Ok(None)
+ }
+ }
+
+ #[test]
+ fn impl_factory() {
+ struct X;
+
+ impl Factory for X {
+ type Handler = M;
+ fn connection_made(&mut self, _: Sender) -> M {
+ M
+ }
+ }
+
+ let (chn, _) = mio::channel::sync_channel(42);
+
+ let mut x = X;
+ let m = x.connection_made(Sender::new(mio::Token(0), chn, 0));
+ assert_eq!(m, M);
+ }
+
+ #[test]
+ fn closure_factory() {
+ let (chn, _) = mio::channel::sync_channel(42);
+
+ let mut factory = |_| |_| Ok(());
+
+ factory.connection_made(Sender::new(mio::Token(0), chn, 0));
+ }
+
+ #[test]
+ fn connection_lost() {
+ struct X;
+
+ impl Factory for X {
+ type Handler = M;
+ fn connection_made(&mut self, _: Sender) -> M {
+ M
+ }
+ fn connection_lost(&mut self, handler: M) {
+ assert_eq!(handler, M);
+ }
+ }
+
+ let (chn, _) = mio::channel::sync_channel(42);
+
+ let mut x = X;
+ let m = x.connection_made(Sender::new(mio::Token(0), chn, 0));
+ x.connection_lost(m);
+ }
+}
diff --git a/third_party/rust/ws/src/frame.rs b/third_party/rust/ws/src/frame.rs
new file mode 100644
index 0000000000..154816c7ad
--- /dev/null
+++ b/third_party/rust/ws/src/frame.rs
@@ -0,0 +1,495 @@
+use std::default::Default;
+use std::fmt;
+use std::io::{Cursor, ErrorKind, Read, Write};
+
+use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt};
+use rand;
+
+use protocol::{CloseCode, OpCode};
+use result::{Error, Kind, Result};
+use stream::TryReadBuf;
+
+fn apply_mask(buf: &mut [u8], mask: &[u8; 4]) {
+ let iter = buf.iter_mut().zip(mask.iter().cycle());
+ for (byte, &key) in iter {
+ *byte ^= key
+ }
+}
+
+/// A struct representing a WebSocket frame.
+#[derive(Debug, Clone)]
+pub struct Frame {
+ finished: bool,
+ rsv1: bool,
+ rsv2: bool,
+ rsv3: bool,
+ opcode: OpCode,
+
+ mask: Option<[u8; 4]>,
+
+ payload: Vec<u8>,
+}
+
+impl Frame {
+ /// Get the length of the frame.
+ /// This is the length of the header + the length of the payload.
+ #[inline]
+ pub fn len(&self) -> usize {
+ let mut header_length = 2;
+ let payload_len = self.payload().len();
+ if payload_len > 125 {
+ if payload_len <= u16::max_value() as usize {
+ header_length += 2;
+ } else {
+ header_length += 8;
+ }
+ }
+
+ if self.is_masked() {
+ header_length += 4;
+ }
+
+ header_length + payload_len
+ }
+
+ /// Return `false`: a frame is never empty since it has a header.
+ #[inline]
+ pub fn is_empty(&self) -> bool {
+ false
+ }
+
+ /// Test whether the frame is a final frame.
+ #[inline]
+ pub fn is_final(&self) -> bool {
+ self.finished
+ }
+
+ /// Test whether the first reserved bit is set.
+ #[inline]
+ pub fn has_rsv1(&self) -> bool {
+ self.rsv1
+ }
+
+ /// Test whether the second reserved bit is set.
+ #[inline]
+ pub fn has_rsv2(&self) -> bool {
+ self.rsv2
+ }
+
+ /// Test whether the third reserved bit is set.
+ #[inline]
+ pub fn has_rsv3(&self) -> bool {
+ self.rsv3
+ }
+
+ /// Get the OpCode of the frame.
+ #[inline]
+ pub fn opcode(&self) -> OpCode {
+ self.opcode
+ }
+
+ /// Test whether this is a control frame.
+ #[inline]
+ pub fn is_control(&self) -> bool {
+ self.opcode.is_control()
+ }
+
+ /// Get a reference to the frame's payload.
+ #[inline]
+ pub fn payload(&self) -> &Vec<u8> {
+ &self.payload
+ }
+
+ // Test whether the frame is masked.
+ #[doc(hidden)]
+ #[inline]
+ pub fn is_masked(&self) -> bool {
+ self.mask.is_some()
+ }
+
+ // Get an optional reference to the frame's mask.
+ #[doc(hidden)]
+ #[allow(dead_code)]
+ #[inline]
+ pub fn mask(&self) -> Option<&[u8; 4]> {
+ self.mask.as_ref()
+ }
+
+ /// Make this frame a final frame.
+ #[allow(dead_code)]
+ #[inline]
+ pub fn set_final(&mut self, is_final: bool) -> &mut Frame {
+ self.finished = is_final;
+ self
+ }
+
+ /// Set the first reserved bit.
+ #[inline]
+ pub fn set_rsv1(&mut self, has_rsv1: bool) -> &mut Frame {
+ self.rsv1 = has_rsv1;
+ self
+ }
+
+ /// Set the second reserved bit.
+ #[inline]
+ pub fn set_rsv2(&mut self, has_rsv2: bool) -> &mut Frame {
+ self.rsv2 = has_rsv2;
+ self
+ }
+
+ /// Set the third reserved bit.
+ #[inline]
+ pub fn set_rsv3(&mut self, has_rsv3: bool) -> &mut Frame {
+ self.rsv3 = has_rsv3;
+ self
+ }
+
+ /// Set the OpCode.
+ #[allow(dead_code)]
+ #[inline]
+ pub fn set_opcode(&mut self, opcode: OpCode) -> &mut Frame {
+ self.opcode = opcode;
+ self
+ }
+
+ /// Edit the frame's payload.
+ #[allow(dead_code)]
+ #[inline]
+ pub fn payload_mut(&mut self) -> &mut Vec<u8> {
+ &mut self.payload
+ }
+
+ // Generate a new mask for this frame.
+ //
+ // This method simply generates and stores the mask. It does not change the payload data.
+ // Instead, the payload data will be masked with the generated mask when the frame is sent
+ // to the other endpoint.
+ #[doc(hidden)]
+ #[inline]
+ pub fn set_mask(&mut self) -> &mut Frame {
+ self.mask = Some(rand::random());
+ self
+ }
+
+ // This method unmasks the payload and should only be called on frames that are actually
+ // masked. In other words, those frames that have just been received from a client endpoint.
+ #[doc(hidden)]
+ #[inline]
+ pub fn remove_mask(&mut self) -> &mut Frame {
+ self.mask
+ .take()
+ .map(|mask| apply_mask(&mut self.payload, &mask));
+ self
+ }
+
+ /// Consume the frame into its payload.
+ pub fn into_data(self) -> Vec<u8> {
+ self.payload
+ }
+
+ /// Create a new data frame.
+ #[inline]
+ pub fn message(data: Vec<u8>, code: OpCode, finished: bool) -> Frame {
+ debug_assert!(
+ match code {
+ OpCode::Text | OpCode::Binary | OpCode::Continue => true,
+ _ => false,
+ },
+ "Invalid opcode for data frame."
+ );
+
+ Frame {
+ finished,
+ opcode: code,
+ payload: data,
+ ..Frame::default()
+ }
+ }
+
+ /// Create a new Pong control frame.
+ #[inline]
+ pub fn pong(data: Vec<u8>) -> Frame {
+ Frame {
+ opcode: OpCode::Pong,
+ payload: data,
+ ..Frame::default()
+ }
+ }
+
+ /// Create a new Ping control frame.
+ #[inline]
+ pub fn ping(data: Vec<u8>) -> Frame {
+ Frame {
+ opcode: OpCode::Ping,
+ payload: data,
+ ..Frame::default()
+ }
+ }
+
+ /// Create a new Close control frame.
+ #[inline]
+ pub fn close(code: CloseCode, reason: &str) -> Frame {
+ let payload = if let CloseCode::Empty = code {
+ Vec::new()
+ } else {
+ let u: u16 = code.into();
+ let raw = [(u >> 8) as u8, u as u8];
+ [&raw, reason.as_bytes()].concat()
+ };
+
+ Frame {
+ payload,
+ ..Frame::default()
+ }
+ }
+
+ /// Parse the input stream into a frame.
+ pub fn parse(cursor: &mut Cursor<Vec<u8>>, max_payload_length: u64) -> Result<Option<Frame>> {
+ let size = cursor.get_ref().len() as u64 - cursor.position();
+ let initial = cursor.position();
+ trace!("Position in buffer {}", initial);
+
+ let mut head = [0u8; 2];
+ if cursor.read(&mut head)? != 2 {
+ cursor.set_position(initial);
+ return Ok(None);
+ }
+
+ trace!("Parsed headers {:?}", head);
+
+ let first = head[0];
+ let second = head[1];
+ trace!("First: {:b}", first);
+ trace!("Second: {:b}", second);
+
+ let finished = first & 0x80 != 0;
+
+ let rsv1 = first & 0x40 != 0;
+ let rsv2 = first & 0x20 != 0;
+ let rsv3 = first & 0x10 != 0;
+
+ let opcode = OpCode::from(first & 0x0F);
+ trace!("Opcode: {:?}", opcode);
+
+ let masked = second & 0x80 != 0;
+ trace!("Masked: {:?}", masked);
+
+ let mut header_length = 2;
+
+ let mut length = u64::from(second & 0x7F);
+
+ if let Some(length_nbytes) = match length {
+ 126 => Some(2),
+ 127 => Some(8),
+ _ => None,
+ } {
+ match cursor.read_uint::<BigEndian>(length_nbytes) {
+ Err(ref err) if err.kind() == ErrorKind::UnexpectedEof => {
+ cursor.set_position(initial);
+ return Ok(None);
+ }
+ Err(err) => {
+ return Err(Error::from(err));
+ }
+ Ok(read) => {
+ length = read;
+ }
+ };
+ header_length += length_nbytes as u64;
+ }
+ trace!("Payload length: {}", length);
+
+ if length > max_payload_length {
+ return Err(Error::new(
+ Kind::Protocol,
+ format!(
+ "Rejected frame with payload length exceeding defined max: {}.",
+ max_payload_length
+ ),
+ ));
+ }
+
+ let mask = if masked {
+ let mut mask_bytes = [0u8; 4];
+ if cursor.read(&mut mask_bytes)? != 4 {
+ cursor.set_position(initial);
+ return Ok(None);
+ } else {
+ header_length += 4;
+ Some(mask_bytes)
+ }
+ } else {
+ None
+ };
+
+ match length.checked_add(header_length) {
+ Some(l) if size < l => {
+ cursor.set_position(initial);
+ return Ok(None);
+ }
+ Some(_) => (),
+ None => return Ok(None),
+ };
+
+ let mut data = Vec::with_capacity(length as usize);
+ if length > 0 {
+ if let Some(read) = cursor.try_read_buf(&mut data)? {
+ debug_assert!(read == length as usize, "Read incorrect payload length!");
+ }
+ }
+
+ // Disallow bad opcode
+ if let OpCode::Bad = opcode {
+ return Err(Error::new(
+ Kind::Protocol,
+ format!("Encountered invalid opcode: {}", first & 0x0F),
+ ));
+ }
+
+ // control frames must have length <= 125
+ match opcode {
+ OpCode::Ping | OpCode::Pong if length > 125 => {
+ return Err(Error::new(
+ Kind::Protocol,
+ format!(
+ "Rejected WebSocket handshake.Received control frame with length: {}.",
+ length
+ ),
+ ))
+ }
+ OpCode::Close if length > 125 => {
+ debug!("Received close frame with payload length exceeding 125. Morphing to protocol close frame.");
+ return Ok(Some(Frame::close(
+ CloseCode::Protocol,
+ "Received close frame with payload length exceeding 125.",
+ )));
+ }
+ _ => (),
+ }
+
+ let frame = Frame {
+ finished,
+ rsv1,
+ rsv2,
+ rsv3,
+ opcode,
+ mask,
+ payload: data,
+ };
+
+ Ok(Some(frame))
+ }
+
+ /// Write a frame out to a buffer
+ pub fn format<W>(&mut self, w: &mut W) -> Result<()>
+ where
+ W: Write,
+ {
+ let mut one = 0u8;
+ let code: u8 = self.opcode.into();
+ if self.is_final() {
+ one |= 0x80;
+ }
+ if self.has_rsv1() {
+ one |= 0x40;
+ }
+ if self.has_rsv2() {
+ one |= 0x20;
+ }
+ if self.has_rsv3() {
+ one |= 0x10;
+ }
+ one |= code;
+
+ let mut two = 0u8;
+ if self.is_masked() {
+ two |= 0x80;
+ }
+
+ match self.payload.len() {
+ len if len < 126 => {
+ two |= len as u8;
+ }
+ len if len <= 65535 => {
+ two |= 126;
+ }
+ _ => {
+ two |= 127;
+ }
+ }
+ w.write_all(&[one, two])?;
+
+ if let Some(length_bytes) = match self.payload.len() {
+ len if len < 126 => None,
+ len if len <= 65535 => Some(2),
+ _ => Some(8),
+ } {
+ w.write_uint::<BigEndian>(self.payload.len() as u64, length_bytes)?;
+ }
+
+ if self.is_masked() {
+ let mask = self.mask.take().unwrap();
+ apply_mask(&mut self.payload, &mask);
+ w.write_all(&mask)?;
+ }
+
+ w.write_all(&self.payload)?;
+ Ok(())
+ }
+}
+
+impl Default for Frame {
+ fn default() -> Frame {
+ Frame {
+ finished: true,
+ rsv1: false,
+ rsv2: false,
+ rsv3: false,
+ opcode: OpCode::Close,
+ mask: None,
+ payload: Vec::new(),
+ }
+ }
+}
+
+impl fmt::Display for Frame {
+ fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+ write!(
+ f,
+ "
+<FRAME>
+final: {}
+reserved: {} {} {}
+opcode: {}
+length: {}
+payload length: {}
+payload: 0x{}
+ ",
+ self.finished,
+ self.rsv1,
+ self.rsv2,
+ self.rsv3,
+ self.opcode,
+ // self.mask.map(|mask| format!("{:?}", mask)).unwrap_or("NONE".into()),
+ self.len(),
+ self.payload.len(),
+ self.payload
+ .iter()
+ .map(|byte| format!("{:x}", byte))
+ .collect::<String>()
+ )
+ }
+}
+
+mod test {
+ #![allow(unused_imports, unused_variables, dead_code)]
+ use super::*;
+ use protocol::OpCode;
+
+ #[test]
+ fn display_frame() {
+ let f = Frame::message("hi there".into(), OpCode::Text, true);
+ let view = format!("{}", f);
+ view.contains("payload:");
+ }
+}
diff --git a/third_party/rust/ws/src/handler.rs b/third_party/rust/ws/src/handler.rs
new file mode 100644
index 0000000000..7c1ef21b4b
--- /dev/null
+++ b/third_party/rust/ws/src/handler.rs
@@ -0,0 +1,423 @@
+use log::Level::Error as ErrorLevel;
+#[cfg(feature = "nativetls")]
+use native_tls::{TlsConnector, TlsStream as SslStream};
+#[cfg(feature = "ssl")]
+use openssl::ssl::{SslConnector, SslMethod, SslStream};
+use url;
+
+use frame::Frame;
+use handshake::{Handshake, Request, Response};
+use message::Message;
+use protocol::CloseCode;
+use result::{Error, Kind, Result};
+use util::{Timeout, Token};
+
+#[cfg(any(feature = "ssl", feature = "nativetls"))]
+use util::TcpStream;
+
+/// The core trait of this library.
+/// Implementing this trait provides the business logic of the WebSocket application.
+pub trait Handler {
+ // general
+
+ /// Called when a request to shutdown all connections has been received.
+ #[inline]
+ fn on_shutdown(&mut self) {
+ debug!("Handler received WebSocket shutdown request.");
+ }
+
+ // WebSocket events
+
+ /// Called when the WebSocket handshake is successful and the connection is open for sending
+ /// and receiving messages.
+ fn on_open(&mut self, shake: Handshake) -> Result<()> {
+ if let Some(addr) = shake.remote_addr()? {
+ debug!("Connection with {} now open", addr);
+ }
+ Ok(())
+ }
+
+ /// Called on incoming messages.
+ fn on_message(&mut self, msg: Message) -> Result<()> {
+ debug!("Received message {:?}", msg);
+ Ok(())
+ }
+
+ /// Called any time this endpoint receives a close control frame.
+ /// This may be because the other endpoint is initiating a closing handshake,
+ /// or it may be the other endpoint confirming the handshake initiated by this endpoint.
+ fn on_close(&mut self, code: CloseCode, reason: &str) {
+ debug!("Connection closing due to ({:?}) {}", code, reason);
+ }
+
+ /// Called when an error occurs on the WebSocket.
+ fn on_error(&mut self, err: Error) {
+ // Ignore connection reset errors by default, but allow library clients to see them by
+ // overriding this method if they want
+ if let Kind::Io(ref err) = err.kind {
+ if let Some(104) = err.raw_os_error() {
+ return;
+ }
+ }
+
+ error!("{:?}", err);
+ if !log_enabled!(ErrorLevel) {
+ println!(
+ "Encountered an error: {}\nEnable a logger to see more information.",
+ err
+ );
+ }
+ }
+
+ // handshake events
+
+ /// A method for handling the low-level workings of the request portion of the WebSocket
+ /// handshake.
+ ///
+ /// Implementors should select a WebSocket protocol and extensions where they are supported.
+ ///
+ /// Implementors can inspect the Request and must return a Response or an error
+ /// indicating that the handshake failed. The default implementation provides conformance with
+ /// the WebSocket protocol, and implementors should use the `Response::from_request` method and
+ /// then modify the resulting response as necessary in order to maintain conformance.
+ ///
+ /// This method will not be called when the handler represents a client endpoint. Use
+ /// `build_request` to provide an initial handshake request.
+ ///
+ /// # Examples
+ ///
+ /// ```ignore
+ /// let mut res = try!(Response::from_request(req));
+ /// if try!(req.extensions()).iter().find(|&&ext| ext.contains("myextension-name")).is_some() {
+ /// res.add_extension("myextension-name")
+ /// }
+ /// Ok(res)
+ /// ```
+ #[inline]
+ fn on_request(&mut self, req: &Request) -> Result<Response> {
+ debug!("Handler received request:\n{}", req);
+ Response::from_request(req)
+ }
+
+ /// A method for handling the low-level workings of the response portion of the WebSocket
+ /// handshake.
+ ///
+ /// Implementors can inspect the Response and choose to fail the connection by
+ /// returning an error. This method will not be called when the handler represents a server
+ /// endpoint. The response should indicate which WebSocket protocol and extensions the server
+ /// has agreed to if any.
+ #[inline]
+ fn on_response(&mut self, res: &Response) -> Result<()> {
+ debug!("Handler received response:\n{}", res);
+ Ok(())
+ }
+
+ // timeout events
+
+ /// Called when a timeout is triggered.
+ ///
+ /// This method will be called when the eventloop encounters a timeout on the specified
+ /// token. To schedule a timeout with your specific token use the `Sender::timeout` method.
+ ///
+ /// # Examples
+ ///
+ /// ```ignore
+ /// const GRATI: Token = Token(1);
+ ///
+ /// ... Handler
+ ///
+ /// fn on_open(&mut self, _: Handshake) -> Result<()> {
+ /// // schedule a timeout to send a gratuitous pong every 5 seconds
+ /// self.ws.timeout(5_000, GRATI)
+ /// }
+ ///
+ /// fn on_timeout(&mut self, event: Token) -> Result<()> {
+ /// if event == GRATI {
+ /// // send gratuitous pong
+ /// try!(self.ws.pong(vec![]))
+ /// // reschedule the timeout
+ /// self.ws.timeout(5_000, GRATI)
+ /// } else {
+ /// Err(Error::new(ErrorKind::Internal, "Invalid timeout token encountered!"))
+ /// }
+ /// }
+ /// ```
+ #[inline]
+ fn on_timeout(&mut self, event: Token) -> Result<()> {
+ debug!("Handler received timeout token: {:?}", event);
+ Ok(())
+ }
+
+ /// Called when a timeout has been scheduled on the eventloop.
+ ///
+ /// This method is the hook for obtaining a Timeout object that may be used to cancel a
+ /// timeout. This is a noop by default.
+ ///
+ /// # Examples
+ ///
+ /// ```ignore
+ /// const PING: Token = Token(1);
+ /// const EXPIRE: Token = Token(2);
+ ///
+ /// ... Handler
+ ///
+ /// fn on_open(&mut self, _: Handshake) -> Result<()> {
+ /// // schedule a timeout to send a ping every 5 seconds
+ /// try!(self.ws.timeout(5_000, PING));
+ /// // schedule a timeout to close the connection if there is no activity for 30 seconds
+ /// self.ws.timeout(30_000, EXPIRE)
+ /// }
+ ///
+ /// fn on_timeout(&mut self, event: Token) -> Result<()> {
+ /// match event {
+ /// PING => {
+ /// self.ws.ping(vec![]);
+ /// self.ws.timeout(5_000, PING)
+ /// }
+ /// EXPIRE => self.ws.close(CloseCode::Away),
+ /// _ => Err(Error::new(ErrorKind::Internal, "Invalid timeout token encountered!")),
+ /// }
+ /// }
+ ///
+ /// fn on_new_timeout(&mut self, event: Token, timeout: Timeout) -> Result<()> {
+ /// if event == EXPIRE {
+ /// if let Some(t) = self.timeout.take() {
+ /// try!(self.ws.cancel(t))
+ /// }
+ /// self.timeout = Some(timeout)
+ /// }
+ /// Ok(())
+ /// }
+ ///
+ /// fn on_frame(&mut self, frame: Frame) -> Result<Option<Frame>> {
+ /// // some activity has occurred, let's reset the expiration
+ /// try!(self.ws.timeout(30_000, EXPIRE));
+ /// Ok(Some(frame))
+ /// }
+ /// ```
+ #[inline]
+ fn on_new_timeout(&mut self, _: Token, _: Timeout) -> Result<()> {
+ // default implementation discards the timeout handle
+ Ok(())
+ }
+
+ // frame events
+
+ /// A method for handling incoming frames.
+ ///
+ /// This method provides very low-level access to the details of the WebSocket protocol. It may
+ /// be necessary to implement this method in order to provide a particular extension, but
+ /// incorrect implementation may cause the other endpoint to fail the connection.
+ ///
+ /// Returning `Ok(None)` will cause the connection to forget about a particular frame. This is
+ /// useful if you want ot filter out a frame or if you don't want any of the default handler
+ /// methods to run.
+ ///
+ /// By default this method simply ensures that no reserved bits are set.
+ #[inline]
+ fn on_frame(&mut self, frame: Frame) -> Result<Option<Frame>> {
+ debug!("Handler received: {}", frame);
+ // default implementation doesn't allow for reserved bits to be set
+ if frame.has_rsv1() || frame.has_rsv2() || frame.has_rsv3() {
+ Err(Error::new(
+ Kind::Protocol,
+ "Encountered frame with reserved bits set.",
+ ))
+ } else {
+ Ok(Some(frame))
+ }
+ }
+
+ /// A method for handling outgoing frames.
+ ///
+ /// This method provides very low-level access to the details of the WebSocket protocol. It may
+ /// be necessary to implement this method in order to provide a particular extension, but
+ /// incorrect implementation may cause the other endpoint to fail the connection.
+ ///
+ /// Returning `Ok(None)` will cause the connection to forget about a particular frame, meaning
+ /// that it will not be sent. You can use this approach to merge multiple frames into a single
+ /// frame before sending the message.
+ ///
+ /// For messages, this method will be called with a single complete, final frame before any
+ /// fragmentation is performed. Automatic fragmentation will be performed on the returned
+ /// frame, if any, based on the `fragment_size` setting.
+ ///
+ /// By default this method simply ensures that no reserved bits are set.
+ #[inline]
+ fn on_send_frame(&mut self, frame: Frame) -> Result<Option<Frame>> {
+ trace!("Handler will send: {}", frame);
+ // default implementation doesn't allow for reserved bits to be set
+ if frame.has_rsv1() || frame.has_rsv2() || frame.has_rsv3() {
+ Err(Error::new(
+ Kind::Protocol,
+ "Encountered frame with reserved bits set.",
+ ))
+ } else {
+ Ok(Some(frame))
+ }
+ }
+
+ // constructors
+
+ /// A method for creating the initial handshake request for WebSocket clients.
+ ///
+ /// The default implementation provides conformance with the WebSocket protocol, but this
+ /// method may be overridden. In order to facilitate conformance,
+ /// implementors should use the `Request::from_url` method and then modify the resulting
+ /// request as necessary.
+ ///
+ /// Implementors should indicate any available WebSocket extensions here.
+ ///
+ /// # Examples
+ /// ```ignore
+ /// let mut req = try!(Request::from_url(url));
+ /// req.add_extension("permessage-deflate; client_max_window_bits");
+ /// Ok(req)
+ /// ```
+ #[inline]
+ fn build_request(&mut self, url: &url::Url) -> Result<Request> {
+ trace!("Handler is building request to {}.", url);
+ Request::from_url(url)
+ }
+
+ /// A method for wrapping a client TcpStream with Ssl Authentication machinery
+ ///
+ /// Override this method to customize how the connection is encrypted. By default
+ /// this will use the Server Name Indication extension in conformance with RFC6455.
+ #[inline]
+ #[cfg(feature = "ssl")]
+ fn upgrade_ssl_client(
+ &mut self,
+ stream: TcpStream,
+ url: &url::Url,
+ ) -> Result<SslStream<TcpStream>> {
+ let domain = url.domain().ok_or(Error::new(
+ Kind::Protocol,
+ format!("Unable to parse domain from {}. Needed for SSL.", url),
+ ))?;
+ let connector = SslConnector::builder(SslMethod::tls())
+ .map_err(|e| {
+ Error::new(
+ Kind::Internal,
+ format!("Failed to upgrade client to SSL: {}", e),
+ )
+ })?
+ .build();
+ connector.connect(domain, stream).map_err(Error::from)
+ }
+
+ #[inline]
+ #[cfg(feature = "nativetls")]
+ fn upgrade_ssl_client(
+ &mut self,
+ stream: TcpStream,
+ url: &url::Url,
+ ) -> Result<SslStream<TcpStream>> {
+ let domain = url.domain().ok_or(Error::new(
+ Kind::Protocol,
+ format!("Unable to parse domain from {}. Needed for SSL.", url),
+ ))?;
+
+ let connector = TlsConnector::new().map_err(|e| {
+ Error::new(
+ Kind::Internal,
+ format!("Failed to upgrade client to SSL: {}", e),
+ )
+ })?;
+
+ connector.connect(domain, stream).map_err(Error::from)
+ }
+ /// A method for wrapping a server TcpStream with Ssl Authentication machinery
+ ///
+ /// Override this method to customize how the connection is encrypted. By default
+ /// this method is not implemented.
+ #[inline]
+ #[cfg(any(feature = "ssl", feature = "nativetls"))]
+ fn upgrade_ssl_server(&mut self, _: TcpStream) -> Result<SslStream<TcpStream>> {
+ unimplemented!()
+ }
+}
+
+impl<F> Handler for F
+where
+ F: Fn(Message) -> Result<()>,
+{
+ fn on_message(&mut self, msg: Message) -> Result<()> {
+ self(msg)
+ }
+}
+
+mod test {
+ #![allow(unused_imports, unused_variables, dead_code)]
+ use super::*;
+ use frame;
+ use handshake::{Handshake, Request, Response};
+ use message;
+ use mio;
+ use protocol::CloseCode;
+ use result::Result;
+ use url;
+
+ #[derive(Debug, Eq, PartialEq)]
+ struct M;
+ impl Handler for M {
+ fn on_message(&mut self, _: message::Message) -> Result<()> {
+ println!("test");
+ Ok(())
+ }
+
+ fn on_frame(&mut self, f: frame::Frame) -> Result<Option<frame::Frame>> {
+ Ok(None)
+ }
+ }
+
+ #[test]
+ fn handler() {
+ struct H;
+
+ impl Handler for H {
+ fn on_open(&mut self, shake: Handshake) -> Result<()> {
+ assert!(shake.request.key().is_ok());
+ assert!(shake.response.key().is_ok());
+ Ok(())
+ }
+
+ fn on_message(&mut self, msg: message::Message) -> Result<()> {
+ Ok(assert_eq!(
+ msg,
+ message::Message::Text(String::from("testme"))
+ ))
+ }
+
+ fn on_close(&mut self, code: CloseCode, _: &str) {
+ assert_eq!(code, CloseCode::Normal)
+ }
+ }
+
+ let mut h = H;
+ let url = url::Url::parse("wss://127.0.0.1:3012").unwrap();
+ let req = Request::from_url(&url).unwrap();
+ let res = Response::from_request(&req).unwrap();
+ h.on_open(Handshake {
+ request: req,
+ response: res,
+ peer_addr: None,
+ local_addr: None,
+ }).unwrap();
+ h.on_message(message::Message::Text("testme".to_owned()))
+ .unwrap();
+ h.on_close(CloseCode::Normal, "");
+ }
+
+ #[test]
+ fn closure_handler() {
+ let mut close = |msg| {
+ assert_eq!(msg, message::Message::Binary(vec![1, 2, 3]));
+ Ok(())
+ };
+
+ close
+ .on_message(message::Message::Binary(vec![1, 2, 3]))
+ .unwrap();
+ }
+}
diff --git a/third_party/rust/ws/src/handshake.rs b/third_party/rust/ws/src/handshake.rs
new file mode 100644
index 0000000000..b7520bde56
--- /dev/null
+++ b/third_party/rust/ws/src/handshake.rs
@@ -0,0 +1,740 @@
+use std::fmt;
+use std::io::Write;
+use std::net::SocketAddr;
+use std::str::from_utf8;
+
+use httparse;
+use rand;
+use sha1::{self, Digest};
+use url;
+
+use result::{Error, Kind, Result};
+
+static WS_GUID: &'static str = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
+static BASE64: &'static [u8] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
+const MAX_HEADERS: usize = 124;
+
+fn generate_key() -> String {
+ let key: [u8; 16] = rand::random();
+ encode_base64(&key)
+}
+
+pub fn hash_key(key: &[u8]) -> String {
+ let mut hasher = sha1::Sha1::new();
+
+ hasher.input(key);
+ hasher.input(WS_GUID.as_bytes());
+
+ encode_base64(&hasher.result())
+}
+
+// This code is based on rustc_serialize base64 STANDARD
+fn encode_base64(data: &[u8]) -> String {
+ let len = data.len();
+ let mod_len = len % 3;
+
+ let mut encoded = vec![b'='; (len + 2) / 3 * 4];
+ {
+ let mut in_iter = data[..len - mod_len].iter().map(|&c| u32::from(c));
+ let mut out_iter = encoded.iter_mut();
+
+ let enc = |val| BASE64[val as usize];
+ let mut write = |val| *out_iter.next().unwrap() = val;
+
+ while let (Some(one), Some(two), Some(three)) =
+ (in_iter.next(), in_iter.next(), in_iter.next())
+ {
+ let g24 = one << 16 | two << 8 | three;
+ write(enc((g24 >> 18) & 63));
+ write(enc((g24 >> 12) & 63));
+ write(enc((g24 >> 6) & 63));
+ write(enc(g24 & 63));
+ }
+
+ match mod_len {
+ 1 => {
+ let pad = (u32::from(data[len - 1])) << 16;
+ write(enc((pad >> 18) & 63));
+ write(enc((pad >> 12) & 63));
+ }
+ 2 => {
+ let pad = (u32::from(data[len - 2])) << 16 | (u32::from(data[len - 1])) << 8;
+ write(enc((pad >> 18) & 63));
+ write(enc((pad >> 12) & 63));
+ write(enc((pad >> 6) & 63));
+ }
+ _ => (),
+ }
+ }
+
+ String::from_utf8(encoded).unwrap()
+}
+
+/// A struct representing the two halves of the WebSocket handshake.
+#[derive(Debug)]
+pub struct Handshake {
+ /// The HTTP request sent to begin the handshake.
+ pub request: Request,
+ /// The HTTP response from the server confirming the handshake.
+ pub response: Response,
+ /// The socket address of the other endpoint. This address may
+ /// be an intermediary such as a proxy server.
+ pub peer_addr: Option<SocketAddr>,
+ /// The socket address of this endpoint.
+ pub local_addr: Option<SocketAddr>,
+}
+
+impl Handshake {
+ /// Get the IP address of the remote connection.
+ ///
+ /// This is the preferred method of obtaining the client's IP address.
+ /// It will attempt to retrieve the most likely IP address based on request
+ /// headers, falling back to the address of the peer.
+ ///
+ /// # Note
+ /// This assumes that the peer is a client. If you are implementing a
+ /// WebSocket client and want to obtain the address of the server, use
+ /// `Handshake::peer_addr` instead.
+ ///
+ /// This method does not ensure that the address is a valid IP address.
+ #[allow(dead_code)]
+ pub fn remote_addr(&self) -> Result<Option<String>> {
+ Ok(self.request.client_addr()?.map(String::from).or_else(|| {
+ if let Some(addr) = self.peer_addr {
+ Some(addr.ip().to_string())
+ } else {
+ None
+ }
+ }))
+ }
+}
+
+/// The handshake request.
+#[derive(Debug)]
+pub struct Request {
+ path: String,
+ method: String,
+ headers: Vec<(String, Vec<u8>)>,
+}
+
+impl Request {
+ /// Get the value of the first instance of an HTTP header.
+ pub fn header(&self, header: &str) -> Option<&Vec<u8>> {
+ self.headers
+ .iter()
+ .find(|&&(ref key, _)| key.to_lowercase() == header.to_lowercase())
+ .map(|&(_, ref val)| val)
+ }
+
+ /// Edit the value of the first instance of an HTTP header.
+ pub fn header_mut(&mut self, header: &str) -> Option<&mut Vec<u8>> {
+ self.headers
+ .iter_mut()
+ .find(|&&mut (ref key, _)| key.to_lowercase() == header.to_lowercase())
+ .map(|&mut (_, ref mut val)| val)
+ }
+
+ /// Access the request headers.
+ #[allow(dead_code)]
+ #[inline]
+ pub fn headers(&self) -> &Vec<(String, Vec<u8>)> {
+ &self.headers
+ }
+
+ /// Edit the request headers.
+ #[allow(dead_code)]
+ #[inline]
+ pub fn headers_mut(&mut self) -> &mut Vec<(String, Vec<u8>)> {
+ &mut self.headers
+ }
+
+ /// Get the origin of the request if it comes from a browser.
+ #[allow(dead_code)]
+ pub fn origin(&self) -> Result<Option<&str>> {
+ if let Some(origin) = self.header("origin") {
+ Ok(Some(from_utf8(origin)?))
+ } else {
+ Ok(None)
+ }
+ }
+
+ /// Get the unhashed WebSocket key sent in the request.
+ pub fn key(&self) -> Result<&Vec<u8>> {
+ self.header("sec-websocket-key")
+ .ok_or_else(|| Error::new(Kind::Protocol, "Unable to parse WebSocket key."))
+ }
+
+ /// Get the hashed WebSocket key from this request.
+ pub fn hashed_key(&self) -> Result<String> {
+ Ok(hash_key(self.key()?))
+ }
+
+ /// Get the WebSocket protocol version from the request (should be 13).
+ #[allow(dead_code)]
+ pub fn version(&self) -> Result<&str> {
+ if let Some(version) = self.header("sec-websocket-version") {
+ from_utf8(version).map_err(Error::from)
+ } else {
+ Err(Error::new(
+ Kind::Protocol,
+ "The Sec-WebSocket-Version header is missing.",
+ ))
+ }
+ }
+
+ /// Get the request method.
+ #[inline]
+ pub fn method(&self) -> &str {
+ &self.method
+ }
+
+ /// Get the path of the request.
+ #[allow(dead_code)]
+ #[inline]
+ pub fn resource(&self) -> &str {
+ &self.path
+ }
+
+ /// Get the possible protocols for the WebSocket connection.
+ #[allow(dead_code)]
+ pub fn protocols(&self) -> Result<Vec<&str>> {
+ if let Some(protos) = self.header("sec-websocket-protocol") {
+ Ok(from_utf8(protos)?
+ .split(',')
+ .map(|proto| proto.trim())
+ .collect())
+ } else {
+ Ok(Vec::new())
+ }
+ }
+
+ /// Add a possible protocol to this request.
+ /// This may result in duplicate protocols listed.
+ #[allow(dead_code)]
+ pub fn add_protocol(&mut self, protocol: &str) {
+ if let Some(protos) = self.header_mut("sec-websocket-protocol") {
+ protos.push(b","[0]);
+ protos.extend(protocol.as_bytes());
+ return;
+ }
+ self.headers_mut()
+ .push(("Sec-WebSocket-Protocol".into(), protocol.into()))
+ }
+
+ /// Remove a possible protocol from this request.
+ #[allow(dead_code)]
+ pub fn remove_protocol(&mut self, protocol: &str) {
+ if let Some(protos) = self.header_mut("sec-websocket-protocol") {
+ let mut new_protos = Vec::with_capacity(protos.len());
+
+ if let Ok(protos_str) = from_utf8(protos) {
+ new_protos = protos_str
+ .split(',')
+ .filter(|proto| proto.trim() == protocol)
+ .collect::<Vec<&str>>()
+ .join(",")
+ .into();
+ }
+ if new_protos.len() < protos.len() {
+ *protos = new_protos
+ }
+ }
+ }
+
+ /// Get the possible extensions for the WebSocket connection.
+ #[allow(dead_code)]
+ pub fn extensions(&self) -> Result<Vec<&str>> {
+ if let Some(exts) = self.header("sec-websocket-extensions") {
+ Ok(from_utf8(exts)?.split(',').map(|ext| ext.trim()).collect())
+ } else {
+ Ok(Vec::new())
+ }
+ }
+
+ /// Add a possible extension to this request.
+ /// This may result in duplicate extensions listed. Also, the order of extensions
+ /// indicates preference, so if the preference matters, consider using the
+ /// `Sec-WebSocket-Protocol` header directly.
+ #[allow(dead_code)]
+ pub fn add_extension(&mut self, ext: &str) {
+ if let Some(exts) = self.header_mut("sec-websocket-extensions") {
+ exts.push(b","[0]);
+ exts.extend(ext.as_bytes());
+ return;
+ }
+ self.headers_mut()
+ .push(("Sec-WebSocket-Extensions".into(), ext.into()))
+ }
+
+ /// Remove a possible extension from this request.
+ /// This will remove all configurations of the extension.
+ #[allow(dead_code)]
+ pub fn remove_extension(&mut self, ext: &str) {
+ if let Some(exts) = self.header_mut("sec-websocket-extensions") {
+ let mut new_exts = Vec::with_capacity(exts.len());
+
+ if let Ok(exts_str) = from_utf8(exts) {
+ new_exts = exts_str
+ .split(',')
+ .filter(|e| e.trim().starts_with(ext))
+ .collect::<Vec<&str>>()
+ .join(",")
+ .into();
+ }
+ if new_exts.len() < exts.len() {
+ *exts = new_exts
+ }
+ }
+ }
+
+ /// Get the IP address of the client.
+ ///
+ /// This method will attempt to retrieve the most likely IP address of the requester
+ /// in the following manner:
+ ///
+ /// If the `X-Forwarded-For` header exists, this method will return the left most
+ /// address in the list.
+ ///
+ /// If the [Forwarded HTTP Header Field](https://tools.ietf.org/html/rfc7239) exits,
+ /// this method will return the left most address indicated by the `for` parameter,
+ /// if it exists.
+ ///
+ /// # Note
+ /// This method does not ensure that the address is a valid IP address.
+ #[allow(dead_code)]
+ pub fn client_addr(&self) -> Result<Option<&str>> {
+ if let Some(x_forward) = self.header("x-forwarded-for") {
+ return Ok(from_utf8(x_forward)?.split(',').next());
+ }
+
+ // We only care about the first forwarded header, so header is ok
+ if let Some(forward) = self.header("forwarded") {
+ if let Some(_for) = from_utf8(forward)?
+ .split(';')
+ .find(|f| f.trim().starts_with("for"))
+ {
+ if let Some(_for_eq) = _for.trim().split(',').next() {
+ let mut it = _for_eq.split('=');
+ it.next();
+ return Ok(it.next());
+ }
+ }
+ }
+ Ok(None)
+ }
+
+ /// Attempt to parse an HTTP request from a buffer. If the buffer does not contain a complete
+ /// request, this will return `Ok(None)`.
+ pub fn parse(buf: &[u8]) -> Result<Option<Request>> {
+ let mut headers = [httparse::EMPTY_HEADER; MAX_HEADERS];
+ let mut req = httparse::Request::new(&mut headers);
+ let parsed = req.parse(buf)?;
+ if !parsed.is_partial() {
+ Ok(Some(Request {
+ path: req.path.unwrap().into(),
+ method: req.method.unwrap().into(),
+ headers: req.headers
+ .iter()
+ .map(|h| (h.name.into(), h.value.into()))
+ .collect(),
+ }))
+ } else {
+ Ok(None)
+ }
+ }
+
+ /// Construct a new WebSocket handshake HTTP request from a url.
+ pub fn from_url(url: &url::Url) -> Result<Request> {
+ let query = if let Some(q) = url.query() {
+ format!("?{}", q)
+ } else {
+ "".into()
+ };
+
+ let mut headers = vec![
+ ("Connection".into(), "Upgrade".into()),
+ (
+ "Host".into(),
+ format!(
+ "{}:{}",
+ url.host_str().ok_or_else(|| Error::new(
+ Kind::Internal,
+ "No host passed for WebSocket connection.",
+ ))?,
+ url.port_or_known_default().unwrap_or(80)
+ ).into(),
+ ),
+ ("Sec-WebSocket-Version".into(), "13".into()),
+ ("Sec-WebSocket-Key".into(), generate_key().into()),
+ ("Upgrade".into(), "websocket".into()),
+ ];
+
+ if url.password().is_some() || url.username() != "" {
+ let basic = encode_base64(format!("{}:{}", url.username(), url.password().unwrap_or("")).as_bytes());
+ headers.push(("Authorization".into(), format!("Basic {}", basic).into()))
+ }
+
+ let req = Request {
+ path: format!("{}{}", url.path(), query),
+ method: "GET".to_owned(),
+ headers: headers,
+ };
+
+ debug!("Built request from URL:\n{}", req);
+
+ Ok(req)
+ }
+
+ /// Write a request out to a buffer
+ pub fn format<W>(&self, w: &mut W) -> Result<()>
+ where
+ W: Write,
+ {
+ write!(w, "{} {} HTTP/1.1\r\n", self.method, self.path)?;
+ for &(ref key, ref val) in &self.headers {
+ write!(w, "{}: ", key)?;
+ w.write_all(val)?;
+ write!(w, "\r\n")?;
+ }
+ write!(w, "\r\n")?;
+ Ok(())
+ }
+}
+
+impl fmt::Display for Request {
+ fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+ let mut s = Vec::with_capacity(2048);
+ self.format(&mut s).map_err(|err| {
+ error!("{:?}", err);
+ fmt::Error
+ })?;
+ write!(
+ f,
+ "{}",
+ from_utf8(&s).map_err(|err| {
+ error!("Unable to format request as utf8: {:?}", err);
+ fmt::Error
+ })?
+ )
+ }
+}
+
+/// The handshake response.
+#[derive(Debug)]
+pub struct Response {
+ status: u16,
+ reason: String,
+ headers: Vec<(String, Vec<u8>)>,
+ body: Vec<u8>,
+}
+
+impl Response {
+ // TODO: resolve the overlap with Request
+
+ /// Construct a generic HTTP response with a body.
+ pub fn new<R>(status: u16, reason: R, body: Vec<u8>) -> Response
+ where
+ R: Into<String>,
+ {
+ Response {
+ status,
+ reason: reason.into(),
+ headers: vec![("Content-Length".into(), body.len().to_string().into())],
+ body,
+ }
+ }
+
+ /// Get the response body.
+ #[inline]
+ pub fn body(&self) -> &[u8] {
+ &self.body
+ }
+
+ /// Get the value of the first instance of an HTTP header.
+ fn header(&self, header: &str) -> Option<&Vec<u8>> {
+ self.headers
+ .iter()
+ .find(|&&(ref key, _)| key.to_lowercase() == header.to_lowercase())
+ .map(|&(_, ref val)| val)
+ }
+ /// Edit the value of the first instance of an HTTP header.
+ pub fn header_mut(&mut self, header: &str) -> Option<&mut Vec<u8>> {
+ self.headers
+ .iter_mut()
+ .find(|&&mut (ref key, _)| key.to_lowercase() == header.to_lowercase())
+ .map(|&mut (_, ref mut val)| val)
+ }
+
+ /// Access the request headers.
+ #[allow(dead_code)]
+ #[inline]
+ pub fn headers(&self) -> &Vec<(String, Vec<u8>)> {
+ &self.headers
+ }
+
+ /// Edit the request headers.
+ #[allow(dead_code)]
+ #[inline]
+ pub fn headers_mut(&mut self) -> &mut Vec<(String, Vec<u8>)> {
+ &mut self.headers
+ }
+
+ /// Get the HTTP status code.
+ #[allow(dead_code)]
+ #[inline]
+ pub fn status(&self) -> u16 {
+ self.status
+ }
+
+ /// Set the HTTP status code.
+ #[allow(dead_code)]
+ #[inline]
+ pub fn set_status(&mut self, status: u16) {
+ self.status = status
+ }
+
+ /// Get the HTTP status reason.
+ #[allow(dead_code)]
+ #[inline]
+ pub fn reason(&self) -> &str {
+ &self.reason
+ }
+
+ /// Set the HTTP status reason.
+ #[allow(dead_code)]
+ #[inline]
+ pub fn set_reason<R>(&mut self, reason: R)
+ where
+ R: Into<String>,
+ {
+ self.reason = reason.into()
+ }
+
+ /// Get the hashed WebSocket key.
+ pub fn key(&self) -> Result<&Vec<u8>> {
+ self.header("sec-websocket-accept")
+ .ok_or_else(|| Error::new(Kind::Protocol, "Unable to parse WebSocket key."))
+ }
+
+ /// Get the protocol that the server has decided to use.
+ #[allow(dead_code)]
+ pub fn protocol(&self) -> Result<Option<&str>> {
+ if let Some(proto) = self.header("sec-websocket-protocol") {
+ Ok(Some(from_utf8(proto)?))
+ } else {
+ Ok(None)
+ }
+ }
+
+ /// Set the protocol that the server has decided to use.
+ #[allow(dead_code)]
+ pub fn set_protocol(&mut self, protocol: &str) {
+ if let Some(proto) = self.header_mut("sec-websocket-protocol") {
+ *proto = protocol.into();
+ return;
+ }
+ self.headers_mut()
+ .push(("Sec-WebSocket-Protocol".into(), protocol.into()))
+ }
+
+ /// Get the extensions that the server has decided to use. If these are unacceptable, it is
+ /// appropriate to send an Extension close code.
+ #[allow(dead_code)]
+ pub fn extensions(&self) -> Result<Vec<&str>> {
+ if let Some(exts) = self.header("sec-websocket-extensions") {
+ Ok(from_utf8(exts)?
+ .split(',')
+ .map(|proto| proto.trim())
+ .collect())
+ } else {
+ Ok(Vec::new())
+ }
+ }
+
+ /// Add an accepted extension to this response.
+ /// This may result in duplicate extensions listed.
+ #[allow(dead_code)]
+ pub fn add_extension(&mut self, ext: &str) {
+ if let Some(exts) = self.header_mut("sec-websocket-extensions") {
+ exts.push(b","[0]);
+ exts.extend(ext.as_bytes());
+ return;
+ }
+ self.headers_mut()
+ .push(("Sec-WebSocket-Extensions".into(), ext.into()))
+ }
+
+ /// Remove an accepted extension from this response.
+ /// This will remove all configurations of the extension.
+ #[allow(dead_code)]
+ pub fn remove_extension(&mut self, ext: &str) {
+ if let Some(exts) = self.header_mut("sec-websocket-extensions") {
+ let mut new_exts = Vec::with_capacity(exts.len());
+
+ if let Ok(exts_str) = from_utf8(exts) {
+ new_exts = exts_str
+ .split(',')
+ .filter(|e| e.trim().starts_with(ext))
+ .collect::<Vec<&str>>()
+ .join(",")
+ .into();
+ }
+ if new_exts.len() < exts.len() {
+ *exts = new_exts
+ }
+ }
+ }
+
+ /// Attempt to parse an HTTP response from a buffer. If the buffer does not contain a complete
+ /// response, thiw will return `Ok(None)`.
+ pub fn parse(buf: &[u8]) -> Result<Option<Response>> {
+ let mut headers = [httparse::EMPTY_HEADER; MAX_HEADERS];
+ let mut res = httparse::Response::new(&mut headers);
+
+ let parsed = res.parse(buf)?;
+ if !parsed.is_partial() {
+ Ok(Some(Response {
+ status: res.code.unwrap(),
+ reason: res.reason.unwrap().into(),
+ headers: res.headers
+ .iter()
+ .map(|h| (h.name.into(), h.value.into()))
+ .collect(),
+ body: Vec::new(),
+ }))
+ } else {
+ Ok(None)
+ }
+ }
+
+ /// Construct a new WebSocket handshake HTTP response from a request.
+ /// This will create a response that ignores protocols and extensions. Edit this response to
+ /// accept a protocol and extensions as necessary.
+ pub fn from_request(req: &Request) -> Result<Response> {
+ let res = Response {
+ status: 101,
+ reason: "Switching Protocols".into(),
+ headers: vec![
+ ("Connection".into(), "Upgrade".into()),
+ ("Sec-WebSocket-Accept".into(), req.hashed_key()?.into()),
+ ("Upgrade".into(), "websocket".into()),
+ ],
+ body: Vec::new(),
+ };
+
+ debug!("Built response from request:\n{}", res);
+ Ok(res)
+ }
+
+ /// Write a response out to a buffer
+ pub fn format<W>(&self, w: &mut W) -> Result<()>
+ where
+ W: Write,
+ {
+ write!(w, "HTTP/1.1 {} {}\r\n", self.status, self.reason)?;
+ for &(ref key, ref val) in &self.headers {
+ write!(w, "{}: ", key)?;
+ w.write_all(val)?;
+ write!(w, "\r\n")?;
+ }
+ write!(w, "\r\n")?;
+ w.write_all(&self.body)?;
+ Ok(())
+ }
+}
+
+impl fmt::Display for Response {
+ fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+ let mut s = Vec::with_capacity(2048);
+ self.format(&mut s).map_err(|err| {
+ error!("{:?}", err);
+ fmt::Error
+ })?;
+ write!(
+ f,
+ "{}",
+ from_utf8(&s).map_err(|err| {
+ error!("Unable to format response as utf8: {:?}", err);
+ fmt::Error
+ })?
+ )
+ }
+}
+
+mod test {
+ #![allow(unused_imports, unused_variables, dead_code)]
+ use super::*;
+ use std::io::Write;
+ use std::net::SocketAddr;
+ use std::str::FromStr;
+
+ #[test]
+ fn remote_addr() {
+ let mut buf = Vec::with_capacity(2048);
+ write!(
+ &mut buf,
+ "GET / HTTP/1.1\r\n\
+ Connection: Upgrade\r\n\
+ Upgrade: websocket\r\n\
+ Sec-WebSocket-Version: 13\r\n\
+ Sec-WebSocket-Key: q16eN37NCfVwUChPvBdk4g==\r\n\r\n"
+ ).unwrap();
+
+ let req = Request::parse(&buf).unwrap().unwrap();
+ let res = Response::from_request(&req).unwrap();
+ let shake = Handshake {
+ request: req,
+ response: res,
+ peer_addr: Some(SocketAddr::from_str("127.0.0.1:8888").unwrap()),
+ local_addr: None,
+ };
+ assert_eq!(shake.remote_addr().unwrap().unwrap(), "127.0.0.1");
+ }
+
+ #[test]
+ fn remote_addr_x_forwarded_for() {
+ let mut buf = Vec::with_capacity(2048);
+ write!(
+ &mut buf,
+ "GET / HTTP/1.1\r\n\
+ Connection: Upgrade\r\n\
+ Upgrade: websocket\r\n\
+ X-Forwarded-For: 192.168.1.1, 192.168.1.2, 192.168.1.3\r\n\
+ Sec-WebSocket-Version: 13\r\n\
+ Sec-WebSocket-Key: q16eN37NCfVwUChPvBdk4g==\r\n\r\n"
+ ).unwrap();
+
+ let req = Request::parse(&buf).unwrap().unwrap();
+ let res = Response::from_request(&req).unwrap();
+ let shake = Handshake {
+ request: req,
+ response: res,
+ peer_addr: None,
+ local_addr: None,
+ };
+ assert_eq!(shake.remote_addr().unwrap().unwrap(), "192.168.1.1");
+ }
+
+ #[test]
+ fn remote_addr_forwarded() {
+ let mut buf = Vec::with_capacity(2048);
+ write!(
+ &mut buf,
+ "GET / HTTP/1.1\r\n\
+ Connection: Upgrade\r\n\
+ Upgrade: websocket\r\n\
+ Forwarded: by=192.168.1.1; for=192.0.2.43, for=\"[2001:db8:cafe::17]\", for=unknown\r\n\
+ Sec-WebSocket-Version: 13\r\n\
+ Sec-WebSocket-Key: q16eN37NCfVwUChPvBdk4g==\r\n\r\n")
+ .unwrap();
+ let req = Request::parse(&buf).unwrap().unwrap();
+ let res = Response::from_request(&req).unwrap();
+ let shake = Handshake {
+ request: req,
+ response: res,
+ peer_addr: None,
+ local_addr: None,
+ };
+ assert_eq!(shake.remote_addr().unwrap().unwrap(), "192.0.2.43");
+ }
+}
diff --git a/third_party/rust/ws/src/io.rs b/third_party/rust/ws/src/io.rs
new file mode 100644
index 0000000000..7739cdc472
--- /dev/null
+++ b/third_party/rust/ws/src/io.rs
@@ -0,0 +1,985 @@
+use std::borrow::Borrow;
+use std::io::{Error as IoError, ErrorKind};
+use std::net::{SocketAddr, ToSocketAddrs};
+use std::time::Duration;
+use std::usize;
+
+use mio;
+use mio::tcp::{TcpListener, TcpStream};
+use mio::{Poll, PollOpt, Ready, Token};
+use mio_extras;
+
+use url::Url;
+
+#[cfg(feature = "native_tls")]
+use native_tls::Error as SslError;
+
+use super::Settings;
+use communication::{Command, Sender, Signal};
+use connection::Connection;
+use factory::Factory;
+use slab::Slab;
+use result::{Error, Kind, Result};
+
+
+const QUEUE: Token = Token(usize::MAX - 3);
+const TIMER: Token = Token(usize::MAX - 4);
+pub const ALL: Token = Token(usize::MAX - 5);
+const SYSTEM: Token = Token(usize::MAX - 6);
+
+type Conn<F> = Connection<<F as Factory>::Handler>;
+
+const MAX_EVENTS: usize = 1024;
+const MESSAGES_PER_TICK: usize = 256;
+const TIMER_TICK_MILLIS: u64 = 100;
+const TIMER_WHEEL_SIZE: usize = 1024;
+const TIMER_CAPACITY: usize = 65_536;
+
+#[cfg(not(windows))]
+const CONNECTION_REFUSED: i32 = 111;
+#[cfg(windows)]
+const CONNECTION_REFUSED: i32 = 61;
+
+fn url_to_addrs(url: &Url) -> Result<Vec<SocketAddr>> {
+ let host = url.host_str();
+ if host.is_none() || (url.scheme() != "ws" && url.scheme() != "wss") {
+ return Err(Error::new(
+ Kind::Internal,
+ format!("Not a valid websocket url: {}", url),
+ ));
+ }
+ let host = host.unwrap();
+
+ let port = url.port_or_known_default().unwrap_or(80);
+ let mut addrs = (&host[..], port)
+ .to_socket_addrs()?
+ .collect::<Vec<SocketAddr>>();
+ addrs.dedup();
+ Ok(addrs)
+}
+
+enum State {
+ Active,
+ Inactive,
+}
+
+impl State {
+ fn is_active(&self) -> bool {
+ match *self {
+ State::Active => true,
+ State::Inactive => false,
+ }
+ }
+}
+
+#[derive(Debug, Clone, Copy)]
+pub struct Timeout {
+ connection: Token,
+ event: Token,
+}
+
+pub struct Handler<F>
+where
+ F: Factory,
+{
+ listener: Option<TcpListener>,
+ connections: Slab<Conn<F>>,
+ factory: F,
+ settings: Settings,
+ state: State,
+ queue_tx: mio::channel::SyncSender<Command>,
+ queue_rx: mio::channel::Receiver<Command>,
+ timer: mio_extras::timer::Timer<Timeout>,
+ next_connection_id: u32,
+}
+
+impl<F> Handler<F>
+where
+ F: Factory,
+{
+ pub fn new(factory: F, settings: Settings) -> Handler<F> {
+ let (tx, rx) = mio::channel::sync_channel(settings.max_connections * settings.queue_size);
+ let timer = mio_extras::timer::Builder::default()
+ .tick_duration(Duration::from_millis(TIMER_TICK_MILLIS))
+ .num_slots(TIMER_WHEEL_SIZE)
+ .capacity(TIMER_CAPACITY)
+ .build();
+ Handler {
+ listener: None,
+ connections: Slab::with_capacity(settings.max_connections),
+ factory,
+ settings,
+ state: State::Inactive,
+ queue_tx: tx,
+ queue_rx: rx,
+ timer,
+ next_connection_id: 0,
+ }
+ }
+
+ pub fn sender(&self) -> Sender {
+ Sender::new(ALL, self.queue_tx.clone(), 0)
+ }
+
+ pub fn listen(&mut self, poll: &mut Poll, addr: &SocketAddr) -> Result<&mut Handler<F>> {
+ debug_assert!(
+ self.listener.is_none(),
+ "Attempted to listen for connections from two addresses on the same websocket."
+ );
+
+ let tcp = TcpListener::bind(addr)?;
+ // TODO: consider net2 in order to set reuse_addr
+ poll.register(&tcp, ALL, Ready::readable(), PollOpt::level())?;
+ self.listener = Some(tcp);
+ Ok(self)
+ }
+
+ pub fn local_addr(&self) -> ::std::io::Result<SocketAddr> {
+ if let Some(ref listener) = self.listener {
+ listener.local_addr()
+ } else {
+ Err(IoError::new(ErrorKind::NotFound, "Not a listening socket"))
+ }
+ }
+
+ #[cfg(any(feature = "ssl", feature = "nativetls"))]
+ pub fn connect(&mut self, poll: &mut Poll, url: Url) -> Result<()> {
+ let settings = self.settings;
+
+ let (tok, addresses) = {
+ let (tok, entry, connection_id, handler) =
+ if self.connections.len() < settings.max_connections {
+ let entry = self.connections.vacant_entry();
+ let tok = Token(entry.key());
+ let connection_id = self.next_connection_id;
+ self.next_connection_id = self.next_connection_id.wrapping_add(1);
+ (
+ tok,
+ entry,
+ connection_id,
+ self.factory.client_connected(Sender::new(
+ tok,
+ self.queue_tx.clone(),
+ connection_id,
+ )),
+ )
+ } else {
+ return Err(Error::new(
+ Kind::Capacity,
+ "Unable to add another connection to the event loop.",
+ ));
+ };
+
+ let mut addresses = match url_to_addrs(&url) {
+ Ok(addresses) => addresses,
+ Err(err) => {
+ self.factory.connection_lost(handler);
+ return Err(err);
+ }
+ };
+
+ loop {
+ if let Some(addr) = addresses.pop() {
+ if let Ok(sock) = TcpStream::connect(&addr) {
+ if settings.tcp_nodelay {
+ sock.set_nodelay(true)?
+ }
+ addresses.push(addr); // Replace the first addr in case ssl fails and we fallback
+ entry.insert(Connection::new(tok, sock, handler, settings, connection_id));
+ break;
+ }
+ } else {
+ self.factory.connection_lost(handler);
+ return Err(Error::new(
+ Kind::Internal,
+ format!("Unable to obtain any socket address for {}", url),
+ ));
+ }
+ }
+
+ (tok, addresses)
+ };
+
+ let will_encrypt = url.scheme() == "wss";
+
+ if let Err(error) = self.connections[tok.into()].as_client(url, addresses) {
+ let handler = self.connections.remove(tok.into()).consume();
+ self.factory.connection_lost(handler);
+ return Err(error);
+ }
+
+ if will_encrypt {
+ while let Err(ssl_error) = self.connections[tok.into()].encrypt() {
+ match ssl_error.kind {
+ #[cfg(feature = "ssl")]
+ Kind::Ssl(ref inner_ssl_error) => {
+ if let Some(io_error) = inner_ssl_error.io_error() {
+ if let Some(errno) = io_error.raw_os_error() {
+ if errno == CONNECTION_REFUSED {
+ if let Err(reset_error) = self.connections[tok.into()].reset() {
+ trace!(
+ "Encountered error while trying to reset connection: {:?}",
+ reset_error
+ );
+ } else {
+ continue;
+ }
+ }
+ }
+ }
+ }
+ #[cfg(feature = "nativetls")]
+ Kind::Ssl(_) => {
+ if let Err(reset_error) = self.connections[tok.into()].reset() {
+ trace!(
+ "Encountered error while trying to reset connection: {:?}",
+ reset_error
+ );
+ } else {
+ continue;
+ }
+ }
+ _ => (),
+ }
+ self.connections[tok.into()].error(ssl_error);
+ // Allow socket to be registered anyway to await hangup
+ break;
+ }
+ }
+
+ poll.register(
+ self.connections[tok.into()].socket(),
+ self.connections[tok.into()].token(),
+ self.connections[tok.into()].events(),
+ PollOpt::edge() | PollOpt::oneshot(),
+ ).map_err(Error::from)
+ .or_else(|err| {
+ error!(
+ "Encountered error while trying to build WebSocket connection: {}",
+ err
+ );
+ let handler = self.connections.remove(tok.into()).consume();
+ self.factory.connection_lost(handler);
+ Err(err)
+ })
+ }
+
+ #[cfg(not(any(feature = "ssl", feature = "nativetls")))]
+ pub fn connect(&mut self, poll: &mut Poll, url: Url) -> Result<()> {
+ let settings = self.settings;
+
+ let (tok, addresses) = {
+ let (tok, entry, connection_id, handler) =
+ if self.connections.len() < settings.max_connections {
+ let entry = self.connections.vacant_entry();
+ let tok = Token(entry.key());
+ let connection_id = self.next_connection_id;
+ self.next_connection_id = self.next_connection_id.wrapping_add(1);
+ (
+ tok,
+ entry,
+ connection_id,
+ self.factory.client_connected(Sender::new(
+ tok,
+ self.queue_tx.clone(),
+ connection_id,
+ )),
+ )
+ } else {
+ return Err(Error::new(
+ Kind::Capacity,
+ "Unable to add another connection to the event loop.",
+ ));
+ };
+
+ let mut addresses = match url_to_addrs(&url) {
+ Ok(addresses) => addresses,
+ Err(err) => {
+ self.factory.connection_lost(handler);
+ return Err(err);
+ }
+ };
+
+ loop {
+ if let Some(addr) = addresses.pop() {
+ if let Ok(sock) = TcpStream::connect(&addr) {
+ if settings.tcp_nodelay {
+ sock.set_nodelay(true)?
+ }
+ entry.insert(Connection::new(tok, sock, handler, settings, connection_id));
+ break;
+ }
+ } else {
+ self.factory.connection_lost(handler);
+ return Err(Error::new(
+ Kind::Internal,
+ format!("Unable to obtain any socket address for {}", url),
+ ));
+ }
+ }
+
+ (tok, addresses)
+ };
+
+ if url.scheme() == "wss" {
+ let error = Error::new(
+ Kind::Protocol,
+ "The ssl feature is not enabled. Please enable it to use wss urls.",
+ );
+ let handler = self.connections.remove(tok.into()).consume();
+ self.factory.connection_lost(handler);
+ return Err(error);
+ }
+
+ if let Err(error) = self.connections[tok.into()].as_client(url, addresses) {
+ let handler = self.connections.remove(tok.into()).consume();
+ self.factory.connection_lost(handler);
+ return Err(error);
+ }
+
+ poll.register(
+ self.connections[tok.into()].socket(),
+ self.connections[tok.into()].token(),
+ self.connections[tok.into()].events(),
+ PollOpt::edge() | PollOpt::oneshot(),
+ ).map_err(Error::from)
+ .or_else(|err| {
+ error!(
+ "Encountered error while trying to build WebSocket connection: {}",
+ err
+ );
+ let handler = self.connections.remove(tok.into()).consume();
+ self.factory.connection_lost(handler);
+ Err(err)
+ })
+ }
+
+ #[cfg(any(feature = "ssl", feature = "nativetls"))]
+ pub fn accept(&mut self, poll: &mut Poll, sock: TcpStream) -> Result<()> {
+ let factory = &mut self.factory;
+ let settings = self.settings;
+
+ if settings.tcp_nodelay {
+ sock.set_nodelay(true)?
+ }
+
+ let tok = {
+ if self.connections.len() < settings.max_connections {
+ let entry = self.connections.vacant_entry();
+ let tok = Token(entry.key());
+ let connection_id = self.next_connection_id;
+ self.next_connection_id = self.next_connection_id.wrapping_add(1);
+ let handler = factory.server_connected(Sender::new(
+ tok,
+ self.queue_tx.clone(),
+ connection_id,
+ ));
+ entry.insert(Connection::new(tok, sock, handler, settings, connection_id));
+ tok
+ } else {
+ return Err(Error::new(
+ Kind::Capacity,
+ "Unable to add another connection to the event loop.",
+ ));
+ }
+ };
+
+ let conn = &mut self.connections[tok.into()];
+
+ conn.as_server()?;
+ if settings.encrypt_server {
+ conn.encrypt()?
+ }
+
+ poll.register(
+ conn.socket(),
+ conn.token(),
+ conn.events(),
+ PollOpt::edge() | PollOpt::oneshot(),
+ ).map_err(Error::from)
+ .or_else(|err| {
+ error!(
+ "Encountered error while trying to build WebSocket connection: {}",
+ err
+ );
+ conn.error(err);
+ if settings.panic_on_new_connection {
+ panic!("Encountered error while trying to build WebSocket connection.");
+ }
+ Ok(())
+ })
+ }
+
+ #[cfg(not(any(feature = "ssl", feature = "nativetls")))]
+ pub fn accept(&mut self, poll: &mut Poll, sock: TcpStream) -> Result<()> {
+ let factory = &mut self.factory;
+ let settings = self.settings;
+
+ if settings.tcp_nodelay {
+ sock.set_nodelay(true)?
+ }
+
+ let tok = {
+ if self.connections.len() < settings.max_connections {
+ let entry = self.connections.vacant_entry();
+ let tok = Token(entry.key());
+ let connection_id = self.next_connection_id;
+ self.next_connection_id = self.next_connection_id.wrapping_add(1);
+ let handler = factory.server_connected(Sender::new(
+ tok,
+ self.queue_tx.clone(),
+ connection_id,
+ ));
+ entry.insert(Connection::new(tok, sock, handler, settings, connection_id));
+ tok
+ } else {
+ return Err(Error::new(
+ Kind::Capacity,
+ "Unable to add another connection to the event loop.",
+ ));
+ }
+ };
+
+ let conn = &mut self.connections[tok.into()];
+
+ conn.as_server()?;
+ if settings.encrypt_server {
+ return Err(Error::new(
+ Kind::Protocol,
+ "The ssl feature is not enabled. Please enable it to use wss urls.",
+ ));
+ }
+
+ poll.register(
+ conn.socket(),
+ conn.token(),
+ conn.events(),
+ PollOpt::edge() | PollOpt::oneshot(),
+ ).map_err(Error::from)
+ .or_else(|err| {
+ error!(
+ "Encountered error while trying to build WebSocket connection: {}",
+ err
+ );
+ conn.error(err);
+ if settings.panic_on_new_connection {
+ panic!("Encountered error while trying to build WebSocket connection.");
+ }
+ Ok(())
+ })
+ }
+
+ pub fn run(&mut self, poll: &mut Poll) -> Result<()> {
+ trace!("Running event loop");
+ poll.register(
+ &self.queue_rx,
+ QUEUE,
+ Ready::readable(),
+ PollOpt::edge() | PollOpt::oneshot(),
+ )?;
+ poll.register(&self.timer, TIMER, Ready::readable(), PollOpt::edge())?;
+
+ self.state = State::Active;
+ let result = self.event_loop(poll);
+ self.state = State::Inactive;
+
+ result
+ .and(poll.deregister(&self.timer).map_err(Error::from))
+ .and(poll.deregister(&self.queue_rx).map_err(Error::from))
+ }
+
+ #[inline]
+ fn event_loop(&mut self, poll: &mut Poll) -> Result<()> {
+ let mut events = mio::Events::with_capacity(MAX_EVENTS);
+ while self.state.is_active() {
+ trace!("Waiting for event");
+ let nevents = match poll.poll(&mut events, None) {
+ Ok(nevents) => nevents,
+ Err(err) => {
+ if err.kind() == ErrorKind::Interrupted {
+ if self.settings.shutdown_on_interrupt {
+ error!("Websocket shutting down for interrupt.");
+ self.state = State::Inactive;
+ } else {
+ error!("Websocket received interrupt.");
+ }
+ 0
+ } else {
+ return Err(Error::from(err));
+ }
+ }
+ };
+ trace!("Processing {} events", nevents);
+
+ for i in 0..nevents {
+ let evt = events.get(i).unwrap();
+ self.handle_event(poll, evt.token(), evt.kind());
+ }
+
+ self.check_count();
+ }
+ Ok(())
+ }
+
+ #[inline]
+ fn schedule(&self, poll: &mut Poll, conn: &Conn<F>) -> Result<()> {
+ trace!(
+ "Scheduling connection to {} as {:?}",
+ conn.socket()
+ .peer_addr()
+ .map(|addr| addr.to_string())
+ .unwrap_or_else(|_| "UNKNOWN".into()),
+ conn.events()
+ );
+ poll.reregister(
+ conn.socket(),
+ conn.token(),
+ conn.events(),
+ PollOpt::edge() | PollOpt::oneshot(),
+ )?;
+ Ok(())
+ }
+
+ fn shutdown(&mut self) {
+ debug!("Received shutdown signal. WebSocket is attempting to shut down.");
+ for (_, conn) in self.connections.iter_mut() {
+ conn.shutdown();
+ }
+ self.factory.on_shutdown();
+ self.state = State::Inactive;
+ if self.settings.panic_on_shutdown {
+ panic!("Panicking on shutdown as per setting.")
+ }
+ }
+
+ #[inline]
+ fn check_active(&mut self, poll: &mut Poll, active: bool, token: Token) {
+ // NOTE: Closing state only applies after a ws connection was successfully
+ // established. It's possible that we may go inactive while in a connecting
+ // state if the handshake fails.
+ if !active {
+ if let Ok(addr) = self.connections[token.into()].socket().peer_addr() {
+ debug!("WebSocket connection to {} disconnected.", addr);
+ } else {
+ trace!("WebSocket connection to token={:?} disconnected.", token);
+ }
+ let handler = self.connections.remove(token.into()).consume();
+ self.factory.connection_lost(handler);
+ } else {
+ self.schedule(poll, &self.connections[token.into()])
+ .or_else(|err| {
+ // This will be an io error, so disconnect will already be called
+ self.connections[token.into()].error(err);
+ let handler = self.connections.remove(token.into()).consume();
+ self.factory.connection_lost(handler);
+ Ok::<(), Error>(())
+ })
+ .unwrap()
+ }
+ }
+
+ #[inline]
+ fn is_client(&self) -> bool {
+ self.listener.is_none()
+ }
+
+ #[inline]
+ fn check_count(&mut self) {
+ trace!("Active connections {:?}", self.connections.len());
+ if self.connections.is_empty() {
+ if !self.state.is_active() {
+ debug!("Shutting down websocket server.");
+ } else if self.is_client() {
+ debug!("Shutting down websocket client.");
+ self.factory.on_shutdown();
+ self.state = State::Inactive;
+ }
+ }
+ }
+
+ fn handle_event(&mut self, poll: &mut Poll, token: Token, events: Ready) {
+ match token {
+ SYSTEM => {
+ debug_assert!(false, "System token used for io event. This is a bug!");
+ error!("System token used for io event. This is a bug!");
+ }
+ ALL => {
+ if events.is_readable() {
+ match self.listener
+ .as_ref()
+ .expect("No listener provided for server websocket connections")
+ .accept()
+ {
+ Ok((sock, addr)) => {
+ info!("Accepted a new tcp connection from {}.", addr);
+ if let Err(err) = self.accept(poll, sock) {
+ error!("Unable to build WebSocket connection {:?}", err);
+ if self.settings.panic_on_new_connection {
+ panic!("Unable to build WebSocket connection {:?}", err);
+ }
+ }
+ }
+ Err(err) => error!(
+ "Encountered an error {:?} while accepting tcp connection.",
+ err
+ ),
+ }
+ }
+ }
+ TIMER => while let Some(t) = self.timer.poll() {
+ self.handle_timeout(poll, t);
+ },
+ QUEUE => {
+ for _ in 0..MESSAGES_PER_TICK {
+ match self.queue_rx.try_recv() {
+ Ok(cmd) => self.handle_queue(poll, cmd),
+ _ => break,
+ }
+ }
+ let _ = poll.reregister(
+ &self.queue_rx,
+ QUEUE,
+ Ready::readable(),
+ PollOpt::edge() | PollOpt::oneshot(),
+ );
+ }
+ _ => {
+ let active = {
+ let conn_events = self.connections[token.into()].events();
+
+ if (events & conn_events).is_readable() {
+ if let Err(err) = self.connections[token.into()].read() {
+ trace!("Encountered error while reading: {}", err);
+ if let Kind::Io(ref err) = err.kind {
+ if let Some(errno) = err.raw_os_error() {
+ if errno == CONNECTION_REFUSED {
+ match self.connections[token.into()].reset() {
+ Ok(_) => {
+ poll.register(
+ self.connections[token.into()].socket(),
+ self.connections[token.into()].token(),
+ self.connections[token.into()].events(),
+ PollOpt::edge() | PollOpt::oneshot(),
+ ).or_else(|err| {
+ self.connections[token.into()]
+ .error(Error::from(err));
+ let handler = self.connections
+ .remove(token.into())
+ .consume();
+ self.factory.connection_lost(handler);
+ Ok::<(), Error>(())
+ })
+ .unwrap();
+ return;
+ }
+ Err(err) => {
+ trace!("Encountered error while trying to reset connection: {:?}", err);
+ }
+ }
+ }
+ }
+ }
+ // This will trigger disconnect if the connection is open
+ self.connections[token.into()].error(err)
+ }
+ }
+
+ let conn_events = self.connections[token.into()].events();
+
+ if (events & conn_events).is_writable() {
+ if let Err(err) = self.connections[token.into()].write() {
+ trace!("Encountered error while writing: {}", err);
+ if let Kind::Io(ref err) = err.kind {
+ if let Some(errno) = err.raw_os_error() {
+ if errno == CONNECTION_REFUSED {
+ match self.connections[token.into()].reset() {
+ Ok(_) => {
+ poll.register(
+ self.connections[token.into()].socket(),
+ self.connections[token.into()].token(),
+ self.connections[token.into()].events(),
+ PollOpt::edge() | PollOpt::oneshot(),
+ ).or_else(|err| {
+ self.connections[token.into()]
+ .error(Error::from(err));
+ let handler = self.connections
+ .remove(token.into())
+ .consume();
+ self.factory.connection_lost(handler);
+ Ok::<(), Error>(())
+ })
+ .unwrap();
+ return;
+ }
+ Err(err) => {
+ trace!("Encountered error while trying to reset connection: {:?}", err);
+ }
+ }
+ }
+ }
+ }
+ // This will trigger disconnect if the connection is open
+ self.connections[token.into()].error(err)
+ }
+ }
+
+ // connection events may have changed
+ self.connections[token.into()].events().is_readable()
+ || self.connections[token.into()].events().is_writable()
+ };
+
+ self.check_active(poll, active, token)
+ }
+ }
+ }
+
+ fn handle_queue(&mut self, poll: &mut Poll, cmd: Command) {
+ match cmd.token() {
+ SYSTEM => {
+ // Scaffolding for system events such as internal timeouts
+ }
+ ALL => {
+ let mut dead = Vec::with_capacity(self.connections.len());
+
+ match cmd.into_signal() {
+ Signal::Message(msg) => {
+ trace!("Broadcasting message: {:?}", msg);
+ for (_, conn) in self.connections.iter_mut() {
+ if let Err(err) = conn.send_message(msg.clone()) {
+ dead.push((conn.token(), err))
+ }
+ }
+ }
+ Signal::Close(code, reason) => {
+ trace!("Broadcasting close: {:?} - {}", code, reason);
+ for (_, conn) in self.connections.iter_mut() {
+ if let Err(err) = conn.send_close(code, reason.borrow()) {
+ dead.push((conn.token(), err))
+ }
+ }
+ }
+ Signal::Ping(data) => {
+ trace!("Broadcasting ping");
+ for (_, conn) in self.connections.iter_mut() {
+ if let Err(err) = conn.send_ping(data.clone()) {
+ dead.push((conn.token(), err))
+ }
+ }
+ }
+ Signal::Pong(data) => {
+ trace!("Broadcasting pong");
+ for (_, conn) in self.connections.iter_mut() {
+ if let Err(err) = conn.send_pong(data.clone()) {
+ dead.push((conn.token(), err))
+ }
+ }
+ }
+ Signal::Connect(url) => {
+ if let Err(err) = self.connect(poll, url.clone()) {
+ if self.settings.panic_on_new_connection {
+ panic!("Unable to establish connection to {}: {:?}", url, err);
+ }
+ error!("Unable to establish connection to {}: {:?}", url, err);
+ }
+ return;
+ }
+ Signal::Shutdown => self.shutdown(),
+ Signal::Timeout {
+ delay,
+ token: event,
+ } => {
+ let timeout = self.timer.set_timeout(
+ Duration::from_millis(delay),
+ Timeout {
+ connection: ALL,
+ event,
+ },
+ );
+ for (_, conn) in self.connections.iter_mut() {
+ if let Err(err) = conn.new_timeout(event, timeout.clone()) {
+ conn.error(err);
+ }
+ }
+ return;
+ }
+ Signal::Cancel(timeout) => {
+ self.timer.cancel_timeout(&timeout);
+ return;
+ }
+ }
+
+ for (_, conn) in self.connections.iter() {
+ if let Err(err) = self.schedule(poll, conn) {
+ dead.push((conn.token(), err))
+ }
+ }
+ for (token, err) in dead {
+ // note the same connection may be called twice
+ self.connections[token.into()].error(err)
+ }
+ }
+ token => {
+ let connection_id = cmd.connection_id();
+ match cmd.into_signal() {
+ Signal::Message(msg) => {
+ if let Some(conn) = self.connections.get_mut(token.into()) {
+ if conn.connection_id() == connection_id {
+ if let Err(err) = conn.send_message(msg) {
+ conn.error(err)
+ }
+ } else {
+ trace!("Connection disconnected while a message was waiting in the queue.")
+ }
+ } else {
+ trace!(
+ "Connection disconnected while a message was waiting in the queue."
+ )
+ }
+ }
+ Signal::Close(code, reason) => {
+ if let Some(conn) = self.connections.get_mut(token.into()) {
+ if conn.connection_id() == connection_id {
+ if let Err(err) = conn.send_close(code, reason) {
+ conn.error(err)
+ }
+ } else {
+ trace!("Connection disconnected while close signal was waiting in the queue.")
+ }
+ } else {
+ trace!("Connection disconnected while close signal was waiting in the queue.")
+ }
+ }
+ Signal::Ping(data) => {
+ if let Some(conn) = self.connections.get_mut(token.into()) {
+ if conn.connection_id() == connection_id {
+ if let Err(err) = conn.send_ping(data) {
+ conn.error(err)
+ }
+ } else {
+ trace!("Connection disconnected while ping signal was waiting in the queue.")
+ }
+ } else {
+ trace!("Connection disconnected while ping signal was waiting in the queue.")
+ }
+ }
+ Signal::Pong(data) => {
+ if let Some(conn) = self.connections.get_mut(token.into()) {
+ if conn.connection_id() == connection_id {
+ if let Err(err) = conn.send_pong(data) {
+ conn.error(err)
+ }
+ } else {
+ trace!("Connection disconnected while pong signal was waiting in the queue.")
+ }
+ } else {
+ trace!("Connection disconnected while pong signal was waiting in the queue.")
+ }
+ }
+ Signal::Connect(url) => {
+ if let Err(err) = self.connect(poll, url.clone()) {
+ if let Some(conn) = self.connections.get_mut(token.into()) {
+ conn.error(err)
+ } else {
+ if self.settings.panic_on_new_connection {
+ panic!("Unable to establish connection to {}: {:?}", url, err);
+ }
+ error!("Unable to establish connection to {}: {:?}", url, err);
+ }
+ }
+ return;
+ }
+ Signal::Shutdown => self.shutdown(),
+ Signal::Timeout {
+ delay,
+ token: event,
+ } => {
+ let timeout = self.timer.set_timeout(
+ Duration::from_millis(delay),
+ Timeout {
+ connection: token,
+ event,
+ },
+ );
+ if let Some(conn) = self.connections.get_mut(token.into()) {
+ if let Err(err) = conn.new_timeout(event, timeout) {
+ conn.error(err)
+ }
+ } else {
+ trace!("Connection disconnected while pong signal was waiting in the queue.")
+ }
+ return;
+ }
+ Signal::Cancel(timeout) => {
+ self.timer.cancel_timeout(&timeout);
+ return;
+ }
+ }
+
+ if self.connections.get(token.into()).is_some() {
+ if let Err(err) = self.schedule(poll, &self.connections[token.into()]) {
+ self.connections[token.into()].error(err)
+ }
+ }
+ }
+ }
+ }
+
+ fn handle_timeout(&mut self, poll: &mut Poll, Timeout { connection, event }: Timeout) {
+ let active = {
+ if let Some(conn) = self.connections.get_mut(connection.into()) {
+ if let Err(err) = conn.timeout_triggered(event) {
+ conn.error(err)
+ }
+
+ conn.events().is_readable() || conn.events().is_writable()
+ } else {
+ trace!("Connection disconnected while timeout was waiting.");
+ return;
+ }
+ };
+ self.check_active(poll, active, connection);
+ }
+}
+
+mod test {
+ #![allow(unused_imports, unused_variables, dead_code)]
+ use std::str::FromStr;
+
+ use url::Url;
+
+ use super::url_to_addrs;
+ use super::*;
+ use result::{Error, Kind};
+
+ #[test]
+ fn test_url_to_addrs() {
+ let ws_url = Url::from_str("ws://example.com?query=me").unwrap();
+ let wss_url = Url::from_str("wss://example.com/suburl#fragment").unwrap();
+ let bad_url = Url::from_str("http://howdy.bad.com").unwrap();
+ let no_resolve = Url::from_str("ws://bad.elucitrans.com").unwrap();
+
+ assert!(url_to_addrs(&ws_url).is_ok());
+ assert!(url_to_addrs(&ws_url).unwrap().len() > 0);
+ assert!(url_to_addrs(&wss_url).is_ok());
+ assert!(url_to_addrs(&wss_url).unwrap().len() > 0);
+
+ match url_to_addrs(&bad_url) {
+ Ok(_) => panic!("url_to_addrs accepts http urls."),
+ Err(Error {
+ kind: Kind::Internal,
+ details: _,
+ }) => (), // pass
+ err => panic!("{:?}", err),
+ }
+
+ match url_to_addrs(&no_resolve) {
+ Ok(_) => panic!("url_to_addrs creates addresses for non-existent domains."),
+ Err(Error {
+ kind: Kind::Io(_),
+ details: _,
+ }) => (), // pass
+ err => panic!("{:?}", err),
+ }
+ }
+
+}
diff --git a/third_party/rust/ws/src/lib.rs b/third_party/rust/ws/src/lib.rs
new file mode 100644
index 0000000000..ea9f1a54e4
--- /dev/null
+++ b/third_party/rust/ws/src/lib.rs
@@ -0,0 +1,391 @@
+//! Lightweight, event-driven WebSockets for Rust.
+#![allow(deprecated)]
+#![deny(missing_copy_implementations, trivial_casts, trivial_numeric_casts, unstable_features,
+ unused_import_braces)]
+
+extern crate byteorder;
+extern crate bytes;
+extern crate httparse;
+extern crate mio;
+extern crate mio_extras;
+#[cfg(feature = "ssl")]
+extern crate openssl;
+#[cfg(feature = "nativetls")]
+extern crate native_tls;
+extern crate rand;
+extern crate sha1;
+extern crate slab;
+extern crate url;
+#[macro_use]
+extern crate log;
+
+mod communication;
+mod connection;
+mod factory;
+mod frame;
+mod handler;
+mod handshake;
+mod io;
+mod message;
+mod protocol;
+mod result;
+mod stream;
+
+#[cfg(feature = "permessage-deflate")]
+pub mod deflate;
+
+pub mod util;
+
+pub use factory::Factory;
+pub use handler::Handler;
+
+pub use communication::Sender;
+pub use frame::Frame;
+pub use handshake::{Handshake, Request, Response};
+pub use message::Message;
+pub use protocol::{CloseCode, OpCode};
+pub use result::Kind as ErrorKind;
+pub use result::{Error, Result};
+
+use std::borrow::Borrow;
+use std::default::Default;
+use std::fmt;
+use std::net::{SocketAddr, ToSocketAddrs};
+
+use mio::Poll;
+
+/// A utility function for setting up a WebSocket server.
+///
+/// # Safety
+///
+/// This function blocks until the event loop finishes running. Avoid calling this method within
+/// another WebSocket handler.
+///
+/// # Examples
+///
+/// ```no_run
+/// use ws::listen;
+///
+/// listen("127.0.0.1:3012", |out| {
+/// move |msg| {
+/// out.send(msg)
+/// }
+/// }).unwrap()
+/// ```
+///
+pub fn listen<A, F, H>(addr: A, factory: F) -> Result<()>
+where
+ A: ToSocketAddrs + fmt::Debug,
+ F: FnMut(Sender) -> H,
+ H: Handler,
+{
+ let ws = WebSocket::new(factory)?;
+ ws.listen(addr)?;
+ Ok(())
+}
+
+/// A utility function for setting up a WebSocket client.
+///
+/// # Safety
+///
+/// This function blocks until the event loop finishes running. Avoid calling this method within
+/// another WebSocket handler. If you need to establish a connection from inside of a handler,
+/// use the `connect` method on the Sender.
+///
+/// # Examples
+///
+/// ```no_run
+/// use ws::{connect, CloseCode};
+///
+/// connect("ws://127.0.0.1:3012", |out| {
+/// out.send("Hello WebSocket").unwrap();
+///
+/// move |msg| {
+/// println!("Got message: {}", msg);
+/// out.close(CloseCode::Normal)
+/// }
+/// }).unwrap()
+/// ```
+///
+pub fn connect<U, F, H>(url: U, factory: F) -> Result<()>
+where
+ U: Borrow<str>,
+ F: FnMut(Sender) -> H,
+ H: Handler,
+{
+ let mut ws = WebSocket::new(factory)?;
+ let parsed = url::Url::parse(url.borrow()).map_err(|err| {
+ Error::new(
+ ErrorKind::Internal,
+ format!("Unable to parse {} as url due to {:?}", url.borrow(), err),
+ )
+ })?;
+ ws.connect(parsed)?;
+ ws.run()?;
+ Ok(())
+}
+
+/// WebSocket settings
+#[derive(Debug, Clone, Copy)]
+pub struct Settings {
+ /// The maximum number of connections that this WebSocket will support.
+ /// The default setting is low and should be increased when expecting more
+ /// connections because this is a hard limit and no new connections beyond
+ /// this limit can be made until an old connection is dropped.
+ /// Default: 100
+ pub max_connections: usize,
+ /// The number of events anticipated per connection. The event loop queue size will
+ /// be `queue_size` * `max_connections`. In order to avoid an overflow error,
+ /// `queue_size` * `max_connections` must be less than or equal to `usize::max_value()`.
+ /// The queue is shared between connections, which means that a connection may schedule
+ /// more events than `queue_size` provided that another connection is using less than
+ /// `queue_size`. However, if the queue is maxed out a Queue error will occur.
+ /// Default: 5
+ pub queue_size: usize,
+ /// Whether to panic when unable to establish a new TCP connection.
+ /// Default: false
+ pub panic_on_new_connection: bool,
+ /// Whether to panic when a shutdown of the WebSocket is requested.
+ /// Default: false
+ pub panic_on_shutdown: bool,
+ /// The maximum number of fragments the connection can handle without reallocating.
+ /// Default: 10
+ pub fragments_capacity: usize,
+ /// Whether to reallocate when `fragments_capacity` is reached. If this is false,
+ /// a Capacity error will be triggered instead.
+ /// Default: true
+ pub fragments_grow: bool,
+ /// The maximum length of outgoing frames. Messages longer than this will be fragmented.
+ /// Default: 65,535
+ pub fragment_size: usize,
+ /// The maximum length of acceptable incoming frames. Messages longer than this will be rejected.
+ /// Default: unlimited
+ pub max_fragment_size: usize,
+ /// The size of the incoming buffer. A larger buffer uses more memory but will allow for fewer
+ /// reallocations.
+ /// Default: 2048
+ pub in_buffer_capacity: usize,
+ /// Whether to reallocate the incoming buffer when `in_buffer_capacity` is reached. If this is
+ /// false, a Capacity error will be triggered instead.
+ /// Default: true
+ pub in_buffer_grow: bool,
+ /// The size of the outgoing buffer. A larger buffer uses more memory but will allow for fewer
+ /// reallocations.
+ /// Default: 2048
+ pub out_buffer_capacity: usize,
+ /// Whether to reallocate the incoming buffer when `out_buffer_capacity` is reached. If this is
+ /// false, a Capacity error will be triggered instead.
+ /// Default: true
+ pub out_buffer_grow: bool,
+ /// Whether to panic when an Internal error is encountered. Internal errors should generally
+ /// not occur, so this setting defaults to true as a debug measure, whereas production
+ /// applications should consider setting it to false.
+ /// Default: true
+ pub panic_on_internal: bool,
+ /// Whether to panic when a Capacity error is encountered.
+ /// Default: false
+ pub panic_on_capacity: bool,
+ /// Whether to panic when a Protocol error is encountered.
+ /// Default: false
+ pub panic_on_protocol: bool,
+ /// Whether to panic when an Encoding error is encountered.
+ /// Default: false
+ pub panic_on_encoding: bool,
+ /// Whether to panic when a Queue error is encountered.
+ /// Default: false
+ pub panic_on_queue: bool,
+ /// Whether to panic when an Io error is encountered.
+ /// Default: false
+ pub panic_on_io: bool,
+ /// Whether to panic when a Timer error is encountered.
+ /// Default: false
+ pub panic_on_timeout: bool,
+ /// Whether to shutdown the eventloop when an interrupt is received.
+ /// Default: true
+ pub shutdown_on_interrupt: bool,
+ /// The WebSocket protocol requires frames sent from client endpoints to be masked as a
+ /// security and sanity precaution. Enforcing this requirement, which may be removed at some
+ /// point may cause incompatibilities. If you need the extra security, set this to true.
+ /// Default: false
+ pub masking_strict: bool,
+ /// The WebSocket protocol requires clients to verify the key returned by a server to ensure
+ /// that the server and all intermediaries can perform the protocol. Verifying the key will
+ /// consume processing time and other resources with the benefit that we can fail the
+ /// connection early. The default in WS-RS is to accept any key from the server and instead
+ /// fail late if a protocol error occurs. Change this setting to enable key verification.
+ /// Default: false
+ pub key_strict: bool,
+ /// The WebSocket protocol requires clients to perform an opening handshake using the HTTP
+ /// GET method for the request. However, since only WebSockets are supported on the connection,
+ /// verifying the method of handshake requests is not always necessary. To enforce the
+ /// requirement that handshakes begin with a GET method, set this to true.
+ /// Default: false
+ pub method_strict: bool,
+ /// Indicate whether server connections should use ssl encryption when accepting connections.
+ /// Setting this to true means that clients should use the `wss` scheme to connect to this
+ /// server. Note that using this flag will in general necessitate overriding the
+ /// `Handler::upgrade_ssl_server` method in order to provide the details of the ssl context. It may be
+ /// simpler for most users to use a reverse proxy such as nginx to provide server side
+ /// encryption.
+ ///
+ /// Default: false
+ pub encrypt_server: bool,
+ /// Disables Nagle's algorithm.
+ /// Usually tcp socket tries to accumulate packets to send them all together (every 200ms).
+ /// When enabled socket will try to send packet as fast as possible.
+ ///
+ /// Default: false
+ pub tcp_nodelay: bool,
+}
+
+impl Default for Settings {
+ fn default() -> Settings {
+ Settings {
+ max_connections: 100,
+ queue_size: 5,
+ panic_on_new_connection: false,
+ panic_on_shutdown: false,
+ fragments_capacity: 10,
+ fragments_grow: true,
+ fragment_size: u16::max_value() as usize,
+ max_fragment_size: usize::max_value(),
+ in_buffer_capacity: 2048,
+ in_buffer_grow: true,
+ out_buffer_capacity: 2048,
+ out_buffer_grow: true,
+ panic_on_internal: true,
+ panic_on_capacity: false,
+ panic_on_protocol: false,
+ panic_on_encoding: false,
+ panic_on_queue: false,
+ panic_on_io: false,
+ panic_on_timeout: false,
+ shutdown_on_interrupt: true,
+ masking_strict: false,
+ key_strict: false,
+ method_strict: false,
+ encrypt_server: false,
+ tcp_nodelay: false,
+ }
+ }
+}
+
+/// The WebSocket struct. A WebSocket can support multiple incoming and outgoing connections.
+pub struct WebSocket<F>
+where
+ F: Factory,
+{
+ poll: Poll,
+ handler: io::Handler<F>,
+}
+
+impl<F> WebSocket<F>
+where
+ F: Factory,
+{
+ /// Create a new WebSocket using the given Factory to create handlers.
+ pub fn new(factory: F) -> Result<WebSocket<F>> {
+ Builder::new().build(factory)
+ }
+
+ /// Consume the WebSocket and bind to the specified address.
+ /// If the `addr_spec` yields multiple addresses this will return after the
+ /// first successful bind. `local_addr` can be called to determine which
+ /// address it ended up binding to.
+ /// After the server is successfully bound you should start it using `run`.
+ pub fn bind<A>(mut self, addr_spec: A) -> Result<WebSocket<F>>
+ where
+ A: ToSocketAddrs,
+ {
+ let mut last_error = Error::new(ErrorKind::Internal, "No address given");
+
+ for addr in addr_spec.to_socket_addrs()? {
+ if let Err(e) = self.handler.listen(&mut self.poll, &addr) {
+ error!("Unable to listen on {}", addr);
+ last_error = e;
+ } else {
+ let actual_addr = self.handler.local_addr().unwrap_or(addr);
+ info!("Listening for new connections on {}.", actual_addr);
+ return Ok(self);
+ }
+ }
+
+ Err(last_error)
+ }
+
+ /// Consume the WebSocket and listen for new connections on the specified address.
+ ///
+ /// # Safety
+ ///
+ /// This method will block until the event loop finishes running.
+ pub fn listen<A>(self, addr_spec: A) -> Result<WebSocket<F>>
+ where
+ A: ToSocketAddrs,
+ {
+ self.bind(addr_spec).and_then(|server| server.run())
+ }
+
+ /// Queue an outgoing connection on this WebSocket. This method may be called multiple times,
+ /// but the actual connections will not be established until `run` is called.
+ pub fn connect(&mut self, url: url::Url) -> Result<&mut WebSocket<F>> {
+ let sender = self.handler.sender();
+ info!("Queuing connection to {}", url);
+ sender.connect(url)?;
+ Ok(self)
+ }
+
+ /// Run the WebSocket. This will run the encapsulated event loop blocking the calling thread until
+ /// the WebSocket is shutdown.
+ pub fn run(mut self) -> Result<WebSocket<F>> {
+ self.handler.run(&mut self.poll)?;
+ Ok(self)
+ }
+
+ /// Get a Sender that can be used to send messages on all connections.
+ /// Calling `send` on this Sender is equivalent to calling `broadcast`.
+ /// Calling `shutdown` on this Sender will shutdown the WebSocket even if no connections have
+ /// been established.
+ #[inline]
+ pub fn broadcaster(&self) -> Sender {
+ self.handler.sender()
+ }
+
+ /// Get the local socket address this socket is bound to. Will return an error
+ /// if the backend returns an error. Will return a `NotFound` error if
+ /// this WebSocket is not a listening socket.
+ pub fn local_addr(&self) -> ::std::io::Result<SocketAddr> {
+ self.handler.local_addr()
+ }
+}
+
+/// Utility for constructing a WebSocket from various settings.
+#[derive(Debug, Default, Clone, Copy)]
+pub struct Builder {
+ settings: Settings,
+}
+
+// TODO: add convenience methods for each setting
+impl Builder {
+ /// Create a new Builder with default settings.
+ pub fn new() -> Builder {
+ Builder::default()
+ }
+
+ /// Build a WebSocket using this builder and a factory.
+ /// It is possible to use the same builder to create multiple WebSockets.
+ pub fn build<F>(&self, factory: F) -> Result<WebSocket<F>>
+ where
+ F: Factory,
+ {
+ Ok(WebSocket {
+ poll: Poll::new()?,
+ handler: io::Handler::new(factory, self.settings),
+ })
+ }
+
+ /// Set the WebSocket settings to use.
+ pub fn with_settings(&mut self, settings: Settings) -> &mut Builder {
+ self.settings = settings;
+ self
+ }
+}
diff --git a/third_party/rust/ws/src/message.rs b/third_party/rust/ws/src/message.rs
new file mode 100644
index 0000000000..08509a5e1d
--- /dev/null
+++ b/third_party/rust/ws/src/message.rs
@@ -0,0 +1,173 @@
+use std::convert::{From, Into};
+use std::fmt;
+use std::result::Result as StdResult;
+use std::str::from_utf8;
+
+use protocol::OpCode;
+use result::Result;
+
+use self::Message::*;
+
+/// An enum representing the various forms of a WebSocket message.
+#[derive(Debug, Eq, PartialEq, Clone)]
+pub enum Message {
+ /// A text WebSocket message
+ Text(String),
+ /// A binary WebSocket message
+ Binary(Vec<u8>),
+}
+
+impl Message {
+ /// Create a new text WebSocket message from a stringable.
+ pub fn text<S>(string: S) -> Message
+ where
+ S: Into<String>,
+ {
+ Message::Text(string.into())
+ }
+
+ /// Create a new binary WebSocket message by converting to Vec<u8>.
+ pub fn binary<B>(bin: B) -> Message
+ where
+ B: Into<Vec<u8>>,
+ {
+ Message::Binary(bin.into())
+ }
+
+ /// Indicates whether a message is a text message.
+ pub fn is_text(&self) -> bool {
+ match *self {
+ Text(_) => true,
+ Binary(_) => false,
+ }
+ }
+
+ /// Indicates whether a message is a binary message.
+ pub fn is_binary(&self) -> bool {
+ match *self {
+ Text(_) => false,
+ Binary(_) => true,
+ }
+ }
+
+ /// Get the length of the WebSocket message.
+ pub fn len(&self) -> usize {
+ match *self {
+ Text(ref string) => string.len(),
+ Binary(ref data) => data.len(),
+ }
+ }
+
+ /// Returns true if the WebSocket message has no content.
+ /// For example, if the other side of the connection sent an empty string.
+ pub fn is_empty(&self) -> bool {
+ match *self {
+ Text(ref string) => string.is_empty(),
+ Binary(ref data) => data.is_empty(),
+ }
+ }
+
+ #[doc(hidden)]
+ pub fn opcode(&self) -> OpCode {
+ match *self {
+ Text(_) => OpCode::Text,
+ Binary(_) => OpCode::Binary,
+ }
+ }
+
+ /// Consume the WebSocket and return it as binary data.
+ pub fn into_data(self) -> Vec<u8> {
+ match self {
+ Text(string) => string.into_bytes(),
+ Binary(data) => data,
+ }
+ }
+
+ /// Attempt to consume the WebSocket message and convert it to a String.
+ pub fn into_text(self) -> Result<String> {
+ match self {
+ Text(string) => Ok(string),
+ Binary(data) => Ok(String::from_utf8(data).map_err(|err| err.utf8_error())?),
+ }
+ }
+
+ /// Attempt to get a &str from the WebSocket message,
+ /// this will try to convert binary data to utf8.
+ pub fn as_text(&self) -> Result<&str> {
+ match *self {
+ Text(ref string) => Ok(string),
+ Binary(ref data) => Ok(from_utf8(data)?),
+ }
+ }
+}
+
+impl From<String> for Message {
+ fn from(string: String) -> Message {
+ Message::text(string)
+ }
+}
+
+impl<'s> From<&'s str> for Message {
+ fn from(string: &'s str) -> Message {
+ Message::text(string)
+ }
+}
+
+impl<'b> From<&'b [u8]> for Message {
+ fn from(data: &'b [u8]) -> Message {
+ Message::binary(data)
+ }
+}
+
+impl From<Vec<u8>> for Message {
+ fn from(data: Vec<u8>) -> Message {
+ Message::binary(data)
+ }
+}
+
+impl fmt::Display for Message {
+ fn fmt(&self, f: &mut fmt::Formatter) -> StdResult<(), fmt::Error> {
+ if let Ok(string) = self.as_text() {
+ write!(f, "{}", string)
+ } else {
+ write!(f, "Binary Data<length={}>", self.len())
+ }
+ }
+}
+
+mod test {
+ #![allow(unused_imports, unused_variables, dead_code)]
+ use super::*;
+
+ #[test]
+ fn display() {
+ let t = Message::text(format!("test"));
+ assert_eq!(t.to_string(), "test".to_owned());
+
+ let bin = Message::binary(vec![0, 1, 3, 4, 241]);
+ assert_eq!(bin.to_string(), "Binary Data<length=5>".to_owned());
+ }
+
+ #[test]
+ fn binary_convert() {
+ let bin = [6u8, 7, 8, 9, 10, 241];
+ let msg = Message::from(&bin[..]);
+ assert!(msg.is_binary());
+ assert!(msg.into_text().is_err());
+ }
+
+ #[test]
+ fn binary_convert_vec() {
+ let bin = vec![6u8, 7, 8, 9, 10, 241];
+ let msg = Message::from(bin);
+ assert!(msg.is_binary());
+ assert!(msg.into_text().is_err());
+ }
+
+ #[test]
+ fn text_convert() {
+ let s = "kiwotsukete";
+ let msg = Message::from(s);
+ assert!(msg.is_text());
+ }
+}
diff --git a/third_party/rust/ws/src/protocol.rs b/third_party/rust/ws/src/protocol.rs
new file mode 100644
index 0000000000..d4c1f2e326
--- /dev/null
+++ b/third_party/rust/ws/src/protocol.rs
@@ -0,0 +1,227 @@
+use std::convert::{From, Into};
+use std::fmt;
+
+use self::OpCode::*;
+/// Operation codes as part of rfc6455.
+#[derive(Debug, Eq, PartialEq, Clone, Copy)]
+pub enum OpCode {
+ /// Indicates a continuation frame of a fragmented message.
+ Continue,
+ /// Indicates a text data frame.
+ Text,
+ /// Indicates a binary data frame.
+ Binary,
+ /// Indicates a close control frame.
+ Close,
+ /// Indicates a ping control frame.
+ Ping,
+ /// Indicates a pong control frame.
+ Pong,
+ /// Indicates an invalid opcode was received.
+ Bad,
+}
+
+impl OpCode {
+ /// Test whether the opcode indicates a control frame.
+ pub fn is_control(&self) -> bool {
+ match *self {
+ Text | Binary | Continue => false,
+ _ => true,
+ }
+ }
+}
+
+impl fmt::Display for OpCode {
+ fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+ match *self {
+ Continue => write!(f, "CONTINUE"),
+ Text => write!(f, "TEXT"),
+ Binary => write!(f, "BINARY"),
+ Close => write!(f, "CLOSE"),
+ Ping => write!(f, "PING"),
+ Pong => write!(f, "PONG"),
+ Bad => write!(f, "BAD"),
+ }
+ }
+}
+
+impl Into<u8> for OpCode {
+ fn into(self) -> u8 {
+ match self {
+ Continue => 0,
+ Text => 1,
+ Binary => 2,
+ Close => 8,
+ Ping => 9,
+ Pong => 10,
+ Bad => {
+ debug_assert!(
+ false,
+ "Attempted to convert invalid opcode to u8. This is a bug."
+ );
+ 8 // if this somehow happens, a close frame will help us tear down quickly
+ }
+ }
+ }
+}
+
+impl From<u8> for OpCode {
+ fn from(byte: u8) -> OpCode {
+ match byte {
+ 0 => Continue,
+ 1 => Text,
+ 2 => Binary,
+ 8 => Close,
+ 9 => Ping,
+ 10 => Pong,
+ _ => Bad,
+ }
+ }
+}
+
+use self::CloseCode::*;
+/// Status code used to indicate why an endpoint is closing the WebSocket connection.
+#[derive(Debug, Eq, PartialEq, Clone, Copy)]
+pub enum CloseCode {
+ /// Indicates a normal closure, meaning that the purpose for
+ /// which the connection was established has been fulfilled.
+ Normal,
+ /// Indicates that an endpoint is "going away", such as a server
+ /// going down or a browser having navigated away from a page.
+ Away,
+ /// Indicates that an endpoint is terminating the connection due
+ /// to a protocol error.
+ Protocol,
+ /// Indicates that an endpoint is terminating the connection
+ /// because it has received a type of data it cannot accept (e.g., an
+ /// endpoint that understands only text data MAY send this if it
+ /// receives a binary message).
+ Unsupported,
+ /// Indicates that no status code was included in a closing frame. This
+ /// close code makes it possible to use a single method, `on_close` to
+ /// handle even cases where no close code was provided.
+ Status,
+ /// Indicates an abnormal closure. If the abnormal closure was due to an
+ /// error, this close code will not be used. Instead, the `on_error` method
+ /// of the handler will be called with the error. However, if the connection
+ /// is simply dropped, without an error, this close code will be sent to the
+ /// handler.
+ Abnormal,
+ /// Indicates that an endpoint is terminating the connection
+ /// because it has received data within a message that was not
+ /// consistent with the type of the message (e.g., non-UTF-8 [RFC3629]
+ /// data within a text message).
+ Invalid,
+ /// Indicates that an endpoint is terminating the connection
+ /// because it has received a message that violates its policy. This
+ /// is a generic status code that can be returned when there is no
+ /// other more suitable status code (e.g., Unsupported or Size) or if there
+ /// is a need to hide specific details about the policy.
+ Policy,
+ /// Indicates that an endpoint is terminating the connection
+ /// because it has received a message that is too big for it to
+ /// process.
+ Size,
+ /// Indicates that an endpoint (client) is terminating the
+ /// connection because it has expected the server to negotiate one or
+ /// more extension, but the server didn't return them in the response
+ /// message of the WebSocket handshake. The list of extensions that
+ /// are needed should be given as the reason for closing.
+ /// Note that this status code is not used by the server, because it
+ /// can fail the WebSocket handshake instead.
+ Extension,
+ /// Indicates that a server is terminating the connection because
+ /// it encountered an unexpected condition that prevented it from
+ /// fulfilling the request.
+ Error,
+ /// Indicates that the server is restarting. A client may choose to reconnect,
+ /// and if it does, it should use a randomized delay of 5-30 seconds between attempts.
+ Restart,
+ /// Indicates that the server is overloaded and the client should either connect
+ /// to a different IP (when multiple targets exist), or reconnect to the same IP
+ /// when a user has performed an action.
+ Again,
+ #[doc(hidden)]
+ Tls,
+ #[doc(hidden)]
+ Empty,
+ #[doc(hidden)]
+ Other(u16),
+}
+
+impl Into<u16> for CloseCode {
+ fn into(self) -> u16 {
+ match self {
+ Normal => 1000,
+ Away => 1001,
+ Protocol => 1002,
+ Unsupported => 1003,
+ Status => 1005,
+ Abnormal => 1006,
+ Invalid => 1007,
+ Policy => 1008,
+ Size => 1009,
+ Extension => 1010,
+ Error => 1011,
+ Restart => 1012,
+ Again => 1013,
+ Tls => 1015,
+ Empty => 0,
+ Other(code) => code,
+ }
+ }
+}
+
+impl From<u16> for CloseCode {
+ fn from(code: u16) -> CloseCode {
+ match code {
+ 1000 => Normal,
+ 1001 => Away,
+ 1002 => Protocol,
+ 1003 => Unsupported,
+ 1005 => Status,
+ 1006 => Abnormal,
+ 1007 => Invalid,
+ 1008 => Policy,
+ 1009 => Size,
+ 1010 => Extension,
+ 1011 => Error,
+ 1012 => Restart,
+ 1013 => Again,
+ 1015 => Tls,
+ 0 => Empty,
+ _ => Other(code),
+ }
+ }
+}
+
+mod test {
+ #![allow(unused_imports, unused_variables, dead_code)]
+ use super::*;
+
+ #[test]
+ fn opcode_from_u8() {
+ let byte = 2u8;
+ assert_eq!(OpCode::from(byte), OpCode::Binary);
+ }
+
+ #[test]
+ fn opcode_into_u8() {
+ let text = OpCode::Text;
+ let byte: u8 = text.into();
+ assert_eq!(byte, 1u8);
+ }
+
+ #[test]
+ fn closecode_from_u16() {
+ let byte = 1008u16;
+ assert_eq!(CloseCode::from(byte), CloseCode::Policy);
+ }
+
+ #[test]
+ fn closecode_into_u16() {
+ let text = CloseCode::Away;
+ let byte: u16 = text.into();
+ assert_eq!(byte, 1001u16);
+ }
+}
diff --git a/third_party/rust/ws/src/result.rs b/third_party/rust/ws/src/result.rs
new file mode 100644
index 0000000000..eb3c151813
--- /dev/null
+++ b/third_party/rust/ws/src/result.rs
@@ -0,0 +1,204 @@
+use std::borrow::Cow;
+use std::convert::{From, Into};
+use std::error::Error as StdError;
+use std::fmt;
+use std::io;
+use std::result::Result as StdResult;
+use std::str::Utf8Error;
+
+use httparse;
+use mio;
+#[cfg(feature = "ssl")]
+use openssl::ssl::{Error as SslError, HandshakeError as SslHandshakeError};
+#[cfg(feature = "nativetls")]
+use native_tls::{Error as SslError, HandshakeError as SslHandshakeError};
+#[cfg(any(feature = "ssl", feature = "nativetls"))]
+type HandshakeError = SslHandshakeError<mio::tcp::TcpStream>;
+
+use communication::Command;
+
+pub type Result<T> = StdResult<T, Error>;
+
+/// The type of an error, which may indicate other kinds of errors as the underlying cause.
+#[derive(Debug)]
+pub enum Kind {
+ /// Indicates an internal application error.
+ /// If panic_on_internal is true, which is the default, then the application will panic.
+ /// Otherwise the WebSocket will automatically attempt to send an Error (1011) close code.
+ Internal,
+ /// Indicates a state where some size limit has been exceeded, such as an inability to accept
+ /// any more new connections.
+ /// If a Connection is active, the WebSocket will automatically attempt to send
+ /// a Size (1009) close code.
+ Capacity,
+ /// Indicates a violation of the WebSocket protocol.
+ /// The WebSocket will automatically attempt to send a Protocol (1002) close code, or if
+ /// this error occurs during a handshake, an HTTP 400 response will be generated.
+ Protocol,
+ /// Indicates that the WebSocket received data that should be utf8 encoded but was not.
+ /// The WebSocket will automatically attempt to send a Invalid Frame Payload Data (1007) close
+ /// code.
+ Encoding(Utf8Error),
+ /// Indicates an underlying IO Error.
+ /// This kind of error will result in a WebSocket Connection disconnecting.
+ Io(io::Error),
+ /// Indicates a failure to parse an HTTP message.
+ /// This kind of error should only occur during a WebSocket Handshake, and a HTTP 500 response
+ /// will be generated.
+ Http(httparse::Error),
+ /// Indicates a failure to send a signal on the internal EventLoop channel. This means that
+ /// the WebSocket is overloaded. In order to avoid this error, it is important to set
+ /// `Settings::max_connections` and `Settings:queue_size` high enough to handle the load.
+ /// If encountered, retuning from a handler method and waiting for the EventLoop to consume
+ /// the queue may relieve the situation.
+ Queue(mio::channel::SendError<Command>),
+ /// Indicates a failure to perform SSL encryption.
+ #[cfg(any(feature = "ssl", feature = "nativetls"))]
+ Ssl(SslError),
+ /// Indicates a failure to perform SSL encryption.
+ #[cfg(any(feature = "ssl", feature = "nativetls"))]
+ SslHandshake(HandshakeError),
+ /// A custom error kind for use by applications. This error kind involves extra overhead
+ /// because it will allocate the memory on the heap. The WebSocket ignores such errors by
+ /// default, simply passing them to the Connection Handler.
+ Custom(Box<dyn StdError + Send + Sync>),
+}
+
+/// A struct indicating the kind of error that has occurred and any precise details of that error.
+pub struct Error {
+ pub kind: Kind,
+ pub details: Cow<'static, str>,
+}
+
+impl Error {
+ pub fn new<I>(kind: Kind, details: I) -> Error
+ where
+ I: Into<Cow<'static, str>>,
+ {
+ Error {
+ kind,
+ details: details.into(),
+ }
+ }
+
+ pub fn into_box(self) -> Box<dyn StdError> {
+ match self.kind {
+ Kind::Custom(err) => err,
+ _ => Box::new(self),
+ }
+ }
+}
+
+impl fmt::Debug for Error {
+ fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+ if self.details.len() > 0 {
+ write!(f, "WS Error <{:?}>: {}", self.kind, self.details)
+ } else {
+ write!(f, "WS Error <{:?}>", self.kind)
+ }
+ }
+}
+
+impl fmt::Display for Error {
+ fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+ if self.details.len() > 0 {
+ write!(f, "{}: {}", self.description(), self.details)
+ } else {
+ write!(f, "{}", self.description())
+ }
+ }
+}
+
+impl StdError for Error {
+ fn description(&self) -> &str {
+ match self.kind {
+ Kind::Internal => "Internal Application Error",
+ Kind::Capacity => "WebSocket at Capacity",
+ Kind::Protocol => "WebSocket Protocol Error",
+ Kind::Encoding(ref err) => err.description(),
+ Kind::Io(ref err) => err.description(),
+ Kind::Http(_) => "Unable to parse HTTP",
+ #[cfg(any(feature = "ssl", feature = "nativetls"))]
+ Kind::Ssl(ref err) => err.description(),
+ #[cfg(any(feature = "ssl", feature = "nativetls"))]
+ Kind::SslHandshake(ref err) => err.description(),
+ Kind::Queue(_) => "Unable to send signal on event loop",
+ Kind::Custom(ref err) => err.description(),
+ }
+ }
+
+ fn cause(&self) -> Option<&dyn StdError> {
+ match self.kind {
+ Kind::Encoding(ref err) => Some(err),
+ Kind::Io(ref err) => Some(err),
+ #[cfg(any(feature = "ssl", feature = "nativetls"))]
+ Kind::Ssl(ref err) => Some(err),
+ #[cfg(any(feature = "ssl", feature = "nativetls"))]
+ Kind::SslHandshake(ref err) => err.cause(),
+ Kind::Custom(ref err) => Some(err.as_ref()),
+ _ => None,
+ }
+ }
+}
+
+impl From<io::Error> for Error {
+ fn from(err: io::Error) -> Error {
+ Error::new(Kind::Io(err), "")
+ }
+}
+
+impl From<httparse::Error> for Error {
+ fn from(err: httparse::Error) -> Error {
+ let details = match err {
+ httparse::Error::HeaderName => "Invalid byte in header name.",
+ httparse::Error::HeaderValue => "Invalid byte in header value.",
+ httparse::Error::NewLine => "Invalid byte in new line.",
+ httparse::Error::Status => "Invalid byte in Response status.",
+ httparse::Error::Token => "Invalid byte where token is required.",
+ httparse::Error::TooManyHeaders => {
+ "Parsed more headers than provided buffer can contain."
+ }
+ httparse::Error::Version => "Invalid byte in HTTP version.",
+ };
+
+ Error::new(Kind::Http(err), details)
+ }
+}
+
+impl From<mio::channel::SendError<Command>> for Error {
+ fn from(err: mio::channel::SendError<Command>) -> Error {
+ match err {
+ mio::channel::SendError::Io(err) => Error::from(err),
+ _ => Error::new(Kind::Queue(err), ""),
+ }
+ }
+}
+
+impl From<Utf8Error> for Error {
+ fn from(err: Utf8Error) -> Error {
+ Error::new(Kind::Encoding(err), "")
+ }
+}
+
+#[cfg(any(feature = "ssl", feature = "nativetls"))]
+impl From<SslError> for Error {
+ fn from(err: SslError) -> Error {
+ Error::new(Kind::Ssl(err), "")
+ }
+}
+
+#[cfg(any(feature = "ssl", feature = "nativetls"))]
+impl From<HandshakeError> for Error {
+ fn from(err: HandshakeError) -> Error {
+ Error::new(Kind::SslHandshake(err), "")
+ }
+}
+
+impl<B> From<Box<B>> for Error
+where
+ B: StdError + Send + Sync + 'static,
+{
+ fn from(err: Box<B>) -> Error {
+ Error::new(Kind::Custom(err), "")
+ }
+}
diff --git a/third_party/rust/ws/src/stream.rs b/third_party/rust/ws/src/stream.rs
new file mode 100644
index 0000000000..3b8d9c441b
--- /dev/null
+++ b/third_party/rust/ws/src/stream.rs
@@ -0,0 +1,358 @@
+use std::io;
+use std::io::ErrorKind::WouldBlock;
+#[cfg(any(feature = "ssl", feature = "nativetls"))]
+use std::mem::replace;
+use std::net::SocketAddr;
+
+use bytes::{Buf, BufMut};
+use mio::tcp::TcpStream;
+#[cfg(feature = "nativetls")]
+use native_tls::{
+ HandshakeError, MidHandshakeTlsStream as MidHandshakeSslStream, TlsStream as SslStream,
+};
+#[cfg(feature = "ssl")]
+use openssl::ssl::{ErrorCode as SslErrorCode, HandshakeError, MidHandshakeSslStream, SslStream};
+
+use result::{Error, Kind, Result};
+
+fn map_non_block<T>(res: io::Result<T>) -> io::Result<Option<T>> {
+ match res {
+ Ok(value) => Ok(Some(value)),
+ Err(err) => {
+ if let WouldBlock = err.kind() {
+ Ok(None)
+ } else {
+ Err(err)
+ }
+ }
+ }
+}
+
+pub trait TryReadBuf: io::Read {
+ fn try_read_buf<B: BufMut>(&mut self, buf: &mut B) -> io::Result<Option<usize>>
+ where
+ Self: Sized,
+ {
+ // Reads the length of the slice supplied by buf.mut_bytes into the buffer
+ // This is not guaranteed to consume an entire datagram or segment.
+ // If your protocol is msg based (instead of continuous stream) you should
+ // ensure that your buffer is large enough to hold an entire segment (1532 bytes if not jumbo
+ // frames)
+ let res = map_non_block(self.read(unsafe { buf.bytes_mut() }));
+
+ if let Ok(Some(cnt)) = res {
+ unsafe {
+ buf.advance_mut(cnt);
+ }
+ }
+
+ res
+ }
+}
+
+pub trait TryWriteBuf: io::Write {
+ fn try_write_buf<B: Buf>(&mut self, buf: &mut B) -> io::Result<Option<usize>>
+ where
+ Self: Sized,
+ {
+ let res = map_non_block(self.write(buf.bytes()));
+
+ if let Ok(Some(cnt)) = res {
+ buf.advance(cnt);
+ }
+
+ res
+ }
+}
+
+impl<T: io::Read> TryReadBuf for T {}
+impl<T: io::Write> TryWriteBuf for T {}
+
+use self::Stream::*;
+pub enum Stream {
+ Tcp(TcpStream),
+ #[cfg(any(feature = "ssl", feature = "nativetls"))]
+ Tls(TlsStream),
+}
+
+impl Stream {
+ pub fn tcp(stream: TcpStream) -> Stream {
+ Tcp(stream)
+ }
+
+ #[cfg(any(feature = "ssl", feature = "nativetls"))]
+ pub fn tls(stream: MidHandshakeSslStream<TcpStream>) -> Stream {
+ Tls(TlsStream::Handshake {
+ sock: stream,
+ negotiating: false,
+ })
+ }
+
+ #[cfg(any(feature = "ssl", feature = "nativetls"))]
+ pub fn tls_live(stream: SslStream<TcpStream>) -> Stream {
+ Tls(TlsStream::Live(stream))
+ }
+
+ #[cfg(any(feature = "ssl", feature = "nativetls"))]
+ pub fn is_tls(&self) -> bool {
+ match *self {
+ Tcp(_) => false,
+ Tls(_) => true,
+ }
+ }
+
+ pub fn evented(&self) -> &TcpStream {
+ match *self {
+ Tcp(ref sock) => sock,
+ #[cfg(any(feature = "ssl", feature = "nativetls"))]
+ Tls(ref inner) => inner.evented(),
+ }
+ }
+
+ pub fn is_negotiating(&self) -> bool {
+ match *self {
+ Tcp(_) => false,
+ #[cfg(any(feature = "ssl", feature = "nativetls"))]
+ Tls(ref inner) => inner.is_negotiating(),
+ }
+ }
+
+ pub fn clear_negotiating(&mut self) -> Result<()> {
+ match *self {
+ Tcp(_) => Err(Error::new(
+ Kind::Internal,
+ "Attempted to clear negotiating flag on non ssl connection.",
+ )),
+ #[cfg(any(feature = "ssl", feature = "nativetls"))]
+ Tls(ref mut inner) => inner.clear_negotiating(),
+ }
+ }
+
+ pub fn peer_addr(&self) -> io::Result<SocketAddr> {
+ match *self {
+ Tcp(ref sock) => sock.peer_addr(),
+ #[cfg(any(feature = "ssl", feature = "nativetls"))]
+ Tls(ref inner) => inner.peer_addr(),
+ }
+ }
+
+ pub fn local_addr(&self) -> io::Result<SocketAddr> {
+ match *self {
+ Tcp(ref sock) => sock.local_addr(),
+ #[cfg(any(feature = "ssl", feature = "nativetls"))]
+ Tls(ref inner) => inner.local_addr(),
+ }
+ }
+}
+
+impl io::Read for Stream {
+ fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
+ match *self {
+ Tcp(ref mut sock) => sock.read(buf),
+ #[cfg(any(feature = "ssl", feature = "nativetls"))]
+ Tls(TlsStream::Live(ref mut sock)) => sock.read(buf),
+ #[cfg(any(feature = "ssl", feature = "nativetls"))]
+ Tls(ref mut tls_stream) => {
+ trace!("Attempting to read ssl handshake.");
+ match replace(tls_stream, TlsStream::Upgrading) {
+ TlsStream::Live(_) | TlsStream::Upgrading => unreachable!(),
+ TlsStream::Handshake {
+ sock,
+ mut negotiating,
+ } => match sock.handshake() {
+ Ok(mut sock) => {
+ trace!("Completed SSL Handshake");
+ let res = sock.read(buf);
+ *tls_stream = TlsStream::Live(sock);
+ res
+ }
+ #[cfg(feature = "ssl")]
+ Err(HandshakeError::SetupFailure(err)) => {
+ Err(io::Error::new(io::ErrorKind::Other, err))
+ }
+ #[cfg(feature = "ssl")]
+ Err(HandshakeError::Failure(mid))
+ | Err(HandshakeError::WouldBlock(mid)) => {
+ if mid.error().code() == SslErrorCode::WANT_READ {
+ negotiating = true;
+ }
+ let err = if let Some(io_error) = mid.error().io_error() {
+ Err(io::Error::new(
+ io_error.kind(),
+ format!("{:?}", io_error.get_ref()),
+ ))
+ } else {
+ Err(io::Error::new(
+ io::ErrorKind::Other,
+ format!("{}", mid.error()),
+ ))
+ };
+ *tls_stream = TlsStream::Handshake {
+ sock: mid,
+ negotiating,
+ };
+ err
+ }
+ #[cfg(feature = "nativetls")]
+ Err(HandshakeError::WouldBlock(mid)) => {
+ negotiating = true;
+ *tls_stream = TlsStream::Handshake {
+ sock: mid,
+ negotiating: negotiating,
+ };
+ Err(io::Error::new(io::ErrorKind::WouldBlock, "SSL would block"))
+ }
+ #[cfg(feature = "nativetls")]
+ Err(HandshakeError::Failure(err)) => {
+ Err(io::Error::new(io::ErrorKind::Other, format!("{}", err)))
+ }
+ },
+ }
+ }
+ }
+ }
+}
+
+impl io::Write for Stream {
+ fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
+ match *self {
+ Tcp(ref mut sock) => sock.write(buf),
+ #[cfg(any(feature = "ssl", feature = "nativetls"))]
+ Tls(TlsStream::Live(ref mut sock)) => sock.write(buf),
+ #[cfg(any(feature = "ssl", feature = "nativetls"))]
+ Tls(ref mut tls_stream) => {
+ trace!("Attempting to write ssl handshake.");
+ match replace(tls_stream, TlsStream::Upgrading) {
+ TlsStream::Live(_) | TlsStream::Upgrading => unreachable!(),
+ TlsStream::Handshake {
+ sock,
+ mut negotiating,
+ } => match sock.handshake() {
+ Ok(mut sock) => {
+ trace!("Completed SSL Handshake");
+ let res = sock.write(buf);
+ *tls_stream = TlsStream::Live(sock);
+ res
+ }
+ #[cfg(feature = "ssl")]
+ Err(HandshakeError::SetupFailure(err)) => {
+ Err(io::Error::new(io::ErrorKind::Other, err))
+ }
+ #[cfg(feature = "ssl")]
+ Err(HandshakeError::Failure(mid))
+ | Err(HandshakeError::WouldBlock(mid)) => {
+ if mid.error().code() == SslErrorCode::WANT_READ {
+ negotiating = true;
+ } else {
+ negotiating = false;
+ }
+ let err = if let Some(io_error) = mid.error().io_error() {
+ Err(io::Error::new(
+ io_error.kind(),
+ format!("{:?}", io_error.get_ref()),
+ ))
+ } else {
+ Err(io::Error::new(
+ io::ErrorKind::Other,
+ format!("{}", mid.error()),
+ ))
+ };
+ *tls_stream = TlsStream::Handshake {
+ sock: mid,
+ negotiating,
+ };
+ err
+ }
+ #[cfg(feature = "nativetls")]
+ Err(HandshakeError::WouldBlock(mid)) => {
+ negotiating = true;
+ *tls_stream = TlsStream::Handshake {
+ sock: mid,
+ negotiating: negotiating,
+ };
+ Err(io::Error::new(io::ErrorKind::WouldBlock, "SSL would block"))
+ }
+ #[cfg(feature = "nativetls")]
+ Err(HandshakeError::Failure(err)) => {
+ Err(io::Error::new(io::ErrorKind::Other, format!("{}", err)))
+ }
+ },
+ }
+ }
+ }
+ }
+
+ fn flush(&mut self) -> io::Result<()> {
+ match *self {
+ Tcp(ref mut sock) => sock.flush(),
+ #[cfg(any(feature = "ssl", feature = "nativetls"))]
+ Tls(TlsStream::Live(ref mut sock)) => sock.flush(),
+ #[cfg(any(feature = "ssl", feature = "nativetls"))]
+ Tls(TlsStream::Handshake { ref mut sock, .. }) => sock.get_mut().flush(),
+ #[cfg(any(feature = "ssl", feature = "nativetls"))]
+ Tls(TlsStream::Upgrading) => panic!("Tried to access actively upgrading TlsStream"),
+ }
+ }
+}
+
+#[cfg(any(feature = "ssl", feature = "nativetls"))]
+pub enum TlsStream {
+ Live(SslStream<TcpStream>),
+ Handshake {
+ sock: MidHandshakeSslStream<TcpStream>,
+ negotiating: bool,
+ },
+ Upgrading,
+}
+
+#[cfg(any(feature = "ssl", feature = "nativetls"))]
+impl TlsStream {
+ pub fn evented(&self) -> &TcpStream {
+ match *self {
+ TlsStream::Live(ref sock) => sock.get_ref(),
+ TlsStream::Handshake { ref sock, .. } => sock.get_ref(),
+ TlsStream::Upgrading => panic!("Tried to access actively upgrading TlsStream"),
+ }
+ }
+
+ pub fn is_negotiating(&self) -> bool {
+ match *self {
+ TlsStream::Live(_) => false,
+ TlsStream::Handshake {
+ sock: _,
+ negotiating,
+ } => negotiating,
+ TlsStream::Upgrading => panic!("Tried to access actively upgrading TlsStream"),
+ }
+ }
+
+ pub fn clear_negotiating(&mut self) -> Result<()> {
+ match *self {
+ TlsStream::Live(_) => Err(Error::new(
+ Kind::Internal,
+ "Attempted to clear negotiating flag on live ssl connection.",
+ )),
+ TlsStream::Handshake {
+ sock: _,
+ ref mut negotiating,
+ } => Ok(*negotiating = false),
+ TlsStream::Upgrading => panic!("Tried to access actively upgrading TlsStream"),
+ }
+ }
+
+ pub fn peer_addr(&self) -> io::Result<SocketAddr> {
+ match *self {
+ TlsStream::Live(ref sock) => sock.get_ref().peer_addr(),
+ TlsStream::Handshake { ref sock, .. } => sock.get_ref().peer_addr(),
+ TlsStream::Upgrading => panic!("Tried to access actively upgrading TlsStream"),
+ }
+ }
+
+ pub fn local_addr(&self) -> io::Result<SocketAddr> {
+ match *self {
+ TlsStream::Live(ref sock) => sock.get_ref().local_addr(),
+ TlsStream::Handshake { ref sock, .. } => sock.get_ref().local_addr(),
+ TlsStream::Upgrading => panic!("Tried to access actively upgrading TlsStream"),
+ }
+ }
+}
diff --git a/third_party/rust/ws/src/util.rs b/third_party/rust/ws/src/util.rs
new file mode 100644
index 0000000000..fc66394a74
--- /dev/null
+++ b/third_party/rust/ws/src/util.rs
@@ -0,0 +1,9 @@
+//! The util module rexports some tools from mio in order to facilitate handling timeouts.
+
+/// Used to identify some timed-out event.
+pub use mio::Token;
+/// A handle to a specific timeout.
+pub use mio_extras::timer::Timeout;
+#[cfg(any(feature = "ssl", feature = "nativetls"))]
+/// TcpStream underlying the WebSocket
+pub use mio::tcp::TcpStream;