diff options
Diffstat (limited to 'third_party/rust/ws/src/deflate')
-rw-r--r-- | third_party/rust/ws/src/deflate/context.rs | 268 | ||||
-rw-r--r-- | third_party/rust/ws/src/deflate/extension.rs | 565 | ||||
-rw-r--r-- | third_party/rust/ws/src/deflate/mod.rs | 9 |
3 files changed, 842 insertions, 0 deletions
diff --git a/third_party/rust/ws/src/deflate/context.rs b/third_party/rust/ws/src/deflate/context.rs new file mode 100644 index 0000000000..2fa2e23056 --- /dev/null +++ b/third_party/rust/ws/src/deflate/context.rs @@ -0,0 +1,268 @@ +use std::mem; +use std::slice; + +use super::ffi; +use super::libc::{c_char, c_int, c_uint}; + +use result::{Error, Kind, Result}; + +const ZLIB_VERSION: &'static str = "1.2.8\0"; + +trait Context { + fn stream(&mut self) -> &mut ffi::z_stream; + + fn stream_apply<F>(&mut self, input: &[u8], output: &mut Vec<u8>, each: F) -> Result<()> + where + F: Fn(&mut ffi::z_stream) -> Option<Result<()>>, + { + debug_assert!(output.len() == 0, "Output vector is not empty."); + + let stream = self.stream(); + + stream.next_in = input.as_ptr() as *mut _; + stream.avail_in = input.len() as c_uint; + + let mut output_size; + + loop { + output_size = output.len(); + + if output_size == output.capacity() { + output.reserve(input.len()) + } + + let out_slice = unsafe { + slice::from_raw_parts_mut( + output.as_mut_ptr().offset(output_size as isize), + output.capacity() - output_size, + ) + }; + + stream.next_out = out_slice.as_mut_ptr(); + stream.avail_out = out_slice.len() as c_uint; + + let before = stream.total_out; + let cont = each(stream); + + unsafe { + output.set_len((stream.total_out - before) as usize + output_size); + } + + if let Some(result) = cont { + return result; + } + } + } +} + +pub struct Compressor { + // Box the z_stream to ensure it isn't moved. Moving the z_stream + // causes zlib to fail, because it maintains internal pointers. + stream: Box<ffi::z_stream>, +} + +impl Compressor { + pub fn new(window_bits: i8) -> Compressor { + debug_assert!(window_bits >= 9, "Received too small window size."); + debug_assert!(window_bits <= 15, "Received too large window size."); + + unsafe { + let mut stream: Box<ffi::z_stream> = Box::new(mem::zeroed()); + let result = ffi::deflateInit2_( + stream.as_mut(), + 9, + ffi::Z_DEFLATED, + -window_bits as c_int, + 9, + ffi::Z_DEFAULT_STRATEGY, + ZLIB_VERSION.as_ptr() as *const c_char, + mem::size_of::<ffi::z_stream>() as c_int, + ); + assert!(result == ffi::Z_OK, "Failed to initialize compresser."); + Compressor { stream: stream } + } + } + + pub fn compress(&mut self, input: &[u8], output: &mut Vec<u8>) -> Result<()> { + self.stream_apply(input, output, |stream| unsafe { + match ffi::deflate(stream, ffi::Z_SYNC_FLUSH) { + ffi::Z_OK | ffi::Z_BUF_ERROR => { + if stream.avail_in == 0 && stream.avail_out > 0 { + Some(Ok(())) + } else { + None + } + } + code => Some(Err(Error::new( + Kind::Protocol, + format!("Failed to perform compression: {}", code), + ))), + } + }) + } + + pub fn reset(&mut self) -> Result<()> { + match unsafe { ffi::deflateReset(self.stream.as_mut()) } { + ffi::Z_OK => Ok(()), + code => Err(Error::new( + Kind::Protocol, + format!("Failed to reset compression context: {}", code), + )), + } + } +} + +impl Context for Compressor { + fn stream(&mut self) -> &mut ffi::z_stream { + self.stream.as_mut() + } +} + +impl Drop for Compressor { + fn drop(&mut self) { + match unsafe { ffi::deflateEnd(self.stream.as_mut()) } { + ffi::Z_STREAM_ERROR => error!("Compression stream encountered bad state."), + // Ignore discarded data error because we are raw + ffi::Z_OK | ffi::Z_DATA_ERROR => trace!("Deallocated compression context."), + code => error!("Bad zlib status encountered: {}", code), + } + } +} + +pub struct Decompressor { + stream: Box<ffi::z_stream>, +} + +impl Decompressor { + pub fn new(window_bits: i8) -> Decompressor { + debug_assert!(window_bits >= 8, "Received too small window size."); + debug_assert!(window_bits <= 15, "Received too large window size."); + + unsafe { + let mut stream: Box<ffi::z_stream> = Box::new(mem::zeroed()); + let result = ffi::inflateInit2_( + stream.as_mut(), + -window_bits as c_int, + ZLIB_VERSION.as_ptr() as *const c_char, + mem::size_of::<ffi::z_stream>() as c_int, + ); + assert!(result == ffi::Z_OK, "Failed to initialize decompresser."); + Decompressor { stream: stream } + } + } + + pub fn decompress(&mut self, input: &[u8], output: &mut Vec<u8>) -> Result<()> { + self.stream_apply(input, output, |stream| unsafe { + match ffi::inflate(stream, ffi::Z_SYNC_FLUSH) { + ffi::Z_OK | ffi::Z_BUF_ERROR => { + if stream.avail_in == 0 && stream.avail_out > 0 { + Some(Ok(())) + } else { + None + } + } + code => Some(Err(Error::new( + Kind::Protocol, + format!("Failed to perform decompression: {}", code), + ))), + } + }) + } + + pub fn reset(&mut self) -> Result<()> { + match unsafe { ffi::inflateReset(self.stream.as_mut()) } { + ffi::Z_OK => Ok(()), + code => Err(Error::new( + Kind::Protocol, + format!("Failed to reset compression context: {}", code), + )), + } + } +} + +impl Context for Decompressor { + fn stream(&mut self) -> &mut ffi::z_stream { + self.stream.as_mut() + } +} + +impl Drop for Decompressor { + fn drop(&mut self) { + match unsafe { ffi::inflateEnd(self.stream.as_mut()) } { + ffi::Z_STREAM_ERROR => error!("Decompression stream encountered bad state."), + ffi::Z_OK => trace!("Deallocated decompression context."), + code => error!("Bad zlib status encountered: {}", code), + } + } +} + +mod test { + #![allow(unused_imports, unused_variables, dead_code)] + use super::*; + + fn as_hex(s: &[u8]) { + for byte in s { + print!("0x{:x} ", byte); + } + print!("\n"); + } + + #[test] + fn round_trip() { + for i in 9..16 { + let data = "HI THERE THIS IS some data. これはデータだよ。".as_bytes(); + let mut compressed = Vec::with_capacity(data.len()); + let mut decompressed = Vec::with_capacity(data.len()); + + let com = Compressor::new(i); + let mut moved_com = com; + + moved_com + .compress(&data, &mut compressed) + .expect("Failed to compress data."); + + let dec = Decompressor::new(i); + let mut moved_dec = dec; + + moved_dec + .decompress(&compressed, &mut decompressed) + .expect("Failed to decompress data."); + + assert_eq!(data, &decompressed[..]); + } + } + + #[test] + fn reset() { + let data1 = "HI THERE 直子さん".as_bytes(); + let data2 = "HI THERE 人太郎".as_bytes(); + let mut compressed1 = Vec::with_capacity(data1.len()); + let mut compressed2 = Vec::with_capacity(data2.len()); + let mut compressed2_ind = Vec::with_capacity(data2.len()); + + let mut decompressed1 = Vec::with_capacity(data1.len()); + let mut decompressed2 = Vec::with_capacity(data2.len()); + let mut decompressed2_ind = Vec::with_capacity(data2.len()); + + let mut com = Compressor::new(9); + + com.compress(&data1, &mut compressed1).unwrap(); + com.compress(&data2, &mut compressed2).unwrap(); + com.reset().unwrap(); + com.compress(&data2, &mut compressed2_ind).unwrap(); + + let mut dec = Decompressor::new(9); + + dec.decompress(&compressed1, &mut decompressed1).unwrap(); + dec.decompress(&compressed2, &mut decompressed2).unwrap(); + dec.reset().unwrap(); + dec.decompress(&compressed2_ind, &mut decompressed2_ind) + .unwrap(); + + assert_eq!(data1, &decompressed1[..]); + assert_eq!(data2, &decompressed2[..]); + assert_eq!(data2, &decompressed2_ind[..]); + assert!(compressed2 != compressed2_ind); + assert!(compressed2.len() < compressed2_ind.len()); + } +} diff --git a/third_party/rust/ws/src/deflate/extension.rs b/third_party/rust/ws/src/deflate/extension.rs new file mode 100644 index 0000000000..712e11fb8e --- /dev/null +++ b/third_party/rust/ws/src/deflate/extension.rs @@ -0,0 +1,565 @@ +use std::mem::replace; + +#[cfg(feature = "ssl")] +use openssl::ssl::SslStream; +#[cfg(feature = "nativetls")] +use native_tls::TlsStream as SslStream; +use url; + +use frame::Frame; +use handler::Handler; +use handshake::{Handshake, Request, Response}; +use message::Message; +use protocol::{CloseCode, OpCode}; +use result::{Error, Kind, Result}; +#[cfg(any(feature = "ssl", feature = "nativetls"))] +use util::TcpStream; +use util::{Timeout, Token}; + +use super::context::{Compressor, Decompressor}; + +/// Deflate Extension Handler Settings +#[derive(Debug, Clone, Copy)] +pub struct DeflateSettings { + /// The max size of the sliding window. If the other endpoint selects a smaller size, that size + /// will be used instead. This must be an integer between 9 and 15 inclusive. + /// Default: 15 + pub max_window_bits: u8, + /// Indicates whether to ask the other endpoint to reset the sliding window for each message. + /// Default: false + pub request_no_context_takeover: bool, + /// Indicates whether this endpoint will agree to reset the sliding window for each message it + /// compresses. If this endpoint won't agree to reset the sliding window, then the handshake + /// will fail if this endpoint is a client and the server requests no context takeover. + /// Default: true + pub accept_no_context_takeover: bool, + /// The number of WebSocket frames to store when defragmenting an incoming fragmented + /// compressed message. + /// This setting may be different from the `fragments_capacity` setting of the WebSocket in order to + /// allow for differences between compressed and uncompressed messages. + /// Default: 10 + pub fragments_capacity: usize, + /// Indicates whether the extension handler will reallocate if the `fragments_capacity` is + /// exceeded. If this is not true, a capacity error will be triggered instead. + /// Default: true + pub fragments_grow: bool, +} + +impl Default for DeflateSettings { + fn default() -> DeflateSettings { + DeflateSettings { + max_window_bits: 15, + request_no_context_takeover: false, + accept_no_context_takeover: true, + fragments_capacity: 10, + fragments_grow: true, + } + } +} + +/// Utility for applying the permessage-deflate extension to a handler with particular deflate +/// settings. +#[derive(Debug, Clone, Copy)] +pub struct DeflateBuilder { + settings: DeflateSettings, +} + +impl DeflateBuilder { + /// Create a new DeflateBuilder with the default settings. + pub fn new() -> DeflateBuilder { + DeflateBuilder { + settings: DeflateSettings::default(), + } + } + + /// Configure the DeflateBuilder with the given deflate settings. + pub fn with_settings(&mut self, settings: DeflateSettings) -> &mut DeflateBuilder { + self.settings = settings; + self + } + + /// Wrap another handler in with a deflate handler as configured. + pub fn build<H: Handler>(&self, handler: H) -> DeflateHandler<H> { + DeflateHandler { + com: Compressor::new(self.settings.max_window_bits as i8), + dec: Decompressor::new(self.settings.max_window_bits as i8), + fragments: Vec::with_capacity(self.settings.fragments_capacity), + compress_reset: false, + decompress_reset: false, + pass: false, + settings: self.settings, + inner: handler, + } + } +} + +/// A WebSocket handler that implements the permessage-deflate extension. +/// +/// This handler wraps a child handler and proxies all handler methods to it. The handler will +/// decompress incoming WebSocket message frames in their reserved bits match the +/// permessage-deflate specification and pass them to the child handler. Message frames sent from +/// the child handler will be compressed and sent to the other endpoint using deflate compression. +pub struct DeflateHandler<H: Handler> { + com: Compressor, + dec: Decompressor, + fragments: Vec<Frame>, + compress_reset: bool, + decompress_reset: bool, + pass: bool, + settings: DeflateSettings, + inner: H, +} + +impl<H: Handler> DeflateHandler<H> { + /// Wrap a child handler to provide the permessage-deflate extension. + pub fn new(handler: H) -> DeflateHandler<H> { + trace!("Using permessage-deflate handler."); + let settings = DeflateSettings::default(); + DeflateHandler { + com: Compressor::new(settings.max_window_bits as i8), + dec: Decompressor::new(settings.max_window_bits as i8), + fragments: Vec::with_capacity(settings.fragments_capacity), + compress_reset: false, + decompress_reset: false, + pass: false, + settings: settings, + inner: handler, + } + } + + #[doc(hidden)] + #[inline] + fn decline(&mut self, mut res: Response) -> Result<Response> { + trace!("Declined permessage-deflate offer"); + self.pass = true; + res.remove_extension("permessage-deflate"); + Ok(res) + } +} + +impl<H: Handler> Handler for DeflateHandler<H> { + fn build_request(&mut self, url: &url::Url) -> Result<Request> { + let mut req = self.inner.build_request(url)?; + let mut req_ext = String::with_capacity(100); + req_ext.push_str("permessage-deflate"); + if self.settings.max_window_bits < 15 { + req_ext.push_str(&format!( + "; client_max_window_bits={}; server_max_window_bits={}", + self.settings.max_window_bits, self.settings.max_window_bits + )) + } else { + req_ext.push_str("; client_max_window_bits") + } + if self.settings.request_no_context_takeover { + req_ext.push_str("; server_no_context_takeover") + } + req.add_extension(&req_ext); + Ok(req) + } + + fn on_request(&mut self, req: &Request) -> Result<Response> { + let mut res = self.inner.on_request(req)?; + + 'ext: for req_ext in req.extensions()? + .iter() + .filter(|&&ext| ext.contains("permessage-deflate")) + { + let mut res_ext = String::with_capacity(req_ext.len()); + let mut s_takeover = false; + let mut c_takeover = false; + let mut s_max = false; + let mut c_max = false; + + for param in req_ext.split(';') { + match param.trim() { + "permessage-deflate" => res_ext.push_str("permessage-deflate"), + "server_no_context_takeover" => { + if s_takeover { + return self.decline(res); + } else { + s_takeover = true; + if self.settings.accept_no_context_takeover { + self.compress_reset = true; + res_ext.push_str("; server_no_context_takeover"); + } else { + continue 'ext; + } + } + } + "client_no_context_takeover" => { + if c_takeover { + return self.decline(res); + } else { + c_takeover = true; + self.decompress_reset = true; + res_ext.push_str("; client_no_context_takeover"); + } + } + param if param.starts_with("server_max_window_bits") => { + if s_max { + return self.decline(res); + } else { + s_max = true; + let mut param_iter = param.split('='); + param_iter.next(); // we already know the name + if let Some(window_bits_str) = param_iter.next() { + if let Ok(window_bits) = window_bits_str.trim().parse() { + if window_bits >= 9 && window_bits <= 15 { + if window_bits < self.settings.max_window_bits as i8 { + self.com = Compressor::new(window_bits); + res_ext.push_str("; "); + res_ext.push_str(param) + } + } else { + return self.decline(res); + } + } else { + return self.decline(res); + } + } + } + } + param if param.starts_with("client_max_window_bits") => { + if c_max { + return self.decline(res); + } else { + c_max = true; + let mut param_iter = param.split('='); + param_iter.next(); // we already know the name + if let Some(window_bits_str) = param_iter.next() { + if let Ok(window_bits) = window_bits_str.trim().parse() { + if window_bits >= 9 && window_bits <= 15 { + if window_bits < self.settings.max_window_bits as i8 { + self.dec = Decompressor::new(window_bits); + res_ext.push_str("; "); + res_ext.push_str(param); + continue; + } + } else { + return self.decline(res); + } + } else { + return self.decline(res); + } + } + res_ext.push_str("; "); + res_ext.push_str(&format!( + "client_max_window_bits={}", + self.settings.max_window_bits + )) + } + } + _ => { + // decline all extension offers because we got a bad parameter + return self.decline(res); + } + } + } + + if !res_ext.contains("client_no_context_takeover") + && self.settings.request_no_context_takeover + { + self.decompress_reset = true; + res_ext.push_str("; client_no_context_takeover"); + } + + if !res_ext.contains("server_max_window_bits") { + res_ext.push_str("; "); + res_ext.push_str(&format!( + "server_max_window_bits={}", + self.settings.max_window_bits + )) + } + + if !res_ext.contains("client_max_window_bits") && self.settings.max_window_bits < 15 { + continue; + } + + res.add_extension(&res_ext); + return Ok(res); + } + self.decline(res) + } + + fn on_response(&mut self, res: &Response) -> Result<()> { + if let Some(res_ext) = res.extensions()? + .iter() + .find(|&&ext| ext.contains("permessage-deflate")) + { + let mut name = false; + let mut s_takeover = false; + let mut c_takeover = false; + let mut s_max = false; + let mut c_max = false; + + for param in res_ext.split(';') { + match param.trim() { + "permessage-deflate" => { + if name { + return Err(Error::new( + Kind::Protocol, + format!("Duplicate extension name permessage-deflate"), + )); + } else { + name = true; + } + } + "server_no_context_takeover" => { + if s_takeover { + return Err(Error::new( + Kind::Protocol, + format!("Duplicate extension parameter server_no_context_takeover"), + )); + } else { + s_takeover = true; + self.decompress_reset = true; + } + } + "client_no_context_takeover" => { + if c_takeover { + return Err(Error::new( + Kind::Protocol, + format!("Duplicate extension parameter client_no_context_takeover"), + )); + } else { + c_takeover = true; + if self.settings.accept_no_context_takeover { + self.compress_reset = true; + } else { + return Err(Error::new( + Kind::Protocol, + format!("The client requires context takeover."), + )); + } + } + } + param if param.starts_with("server_max_window_bits") => { + if s_max { + return Err(Error::new( + Kind::Protocol, + format!("Duplicate extension parameter server_max_window_bits"), + )); + } else { + s_max = true; + let mut param_iter = param.split('='); + param_iter.next(); // we already know the name + if let Some(window_bits_str) = param_iter.next() { + if let Ok(window_bits) = window_bits_str.trim().parse() { + if window_bits >= 9 && window_bits <= 15 { + if window_bits as u8 != self.settings.max_window_bits { + self.dec = Decompressor::new(window_bits); + } + } else { + return Err(Error::new( + Kind::Protocol, + format!( + "Invalid server_max_window_bits parameter: {}", + window_bits + ), + )); + } + } else { + return Err(Error::new( + Kind::Protocol, + format!( + "Invalid server_max_window_bits parameter: {}", + window_bits_str + ), + )); + } + } + } + } + param if param.starts_with("client_max_window_bits") => { + if c_max { + return Err(Error::new( + Kind::Protocol, + format!("Duplicate extension parameter client_max_window_bits"), + )); + } else { + c_max = true; + let mut param_iter = param.split('='); + param_iter.next(); // we already know the name + if let Some(window_bits_str) = param_iter.next() { + if let Ok(window_bits) = window_bits_str.trim().parse() { + if window_bits >= 9 && window_bits <= 15 { + if window_bits as u8 != self.settings.max_window_bits { + self.com = Compressor::new(window_bits); + } + } else { + return Err(Error::new( + Kind::Protocol, + format!( + "Invalid client_max_window_bits parameter: {}", + window_bits + ), + )); + } + } else { + return Err(Error::new( + Kind::Protocol, + format!( + "Invalid client_max_window_bits parameter: {}", + window_bits_str + ), + )); + } + } + } + } + param => { + // fail the connection because we got a bad parameter + return Err(Error::new( + Kind::Protocol, + format!("Bad extension parameter: {}", param), + )); + } + } + } + } else { + self.pass = true + } + + Ok(()) + } + + fn on_frame(&mut self, mut frame: Frame) -> Result<Option<Frame>> { + if !self.pass && !frame.is_control() { + if !self.fragments.is_empty() || frame.has_rsv1() { + frame.set_rsv1(false); + + if !frame.is_final() { + self.fragments.push(frame); + return Ok(None); + } else { + if frame.opcode() == OpCode::Continue { + if self.fragments.is_empty() { + return Err(Error::new( + Kind::Protocol, + "Unable to reconstruct fragmented message. No first frame.", + )); + } else { + if !self.settings.fragments_grow + && self.settings.fragments_capacity == self.fragments.len() + { + return Err(Error::new(Kind::Capacity, "Exceeded max fragments.")); + } else { + self.fragments.push(frame); + } + + // it's safe to unwrap because of the above check for empty + let opcode = self.fragments.first().unwrap().opcode(); + let size = self.fragments + .iter() + .fold(0, |len, frame| len + frame.payload().len()); + let mut compressed = Vec::with_capacity(size); + let mut decompressed = Vec::with_capacity(size * 2); + for frag in replace( + &mut self.fragments, + Vec::with_capacity(self.settings.fragments_capacity), + ) { + compressed.extend(frag.into_data()) + } + + compressed.extend(&[0, 0, 255, 255]); + self.dec.decompress(&compressed, &mut decompressed)?; + frame = Frame::message(decompressed, opcode, true); + } + } else { + let mut decompressed = Vec::with_capacity(frame.payload().len() * 2); + frame.payload_mut().extend(&[0, 0, 255, 255]); + + self.dec.decompress(frame.payload(), &mut decompressed)?; + + *frame.payload_mut() = decompressed; + } + + if self.decompress_reset { + self.dec.reset()? + } + } + } + } + self.inner.on_frame(frame) + } + + fn on_send_frame(&mut self, frame: Frame) -> Result<Option<Frame>> { + if let Some(mut frame) = self.inner.on_send_frame(frame)? { + if !self.pass && !frame.is_control() { + debug_assert!( + frame.is_final(), + "Received non-final frame from upstream handler!" + ); + debug_assert!( + frame.opcode() != OpCode::Continue, + "Received continue frame from upstream handler!" + ); + + frame.set_rsv1(true); + let mut compressed = Vec::with_capacity(frame.payload().len()); + self.com.compress(frame.payload(), &mut compressed)?; + let len = compressed.len(); + compressed.truncate(len - 4); + *frame.payload_mut() = compressed; + + if self.compress_reset { + self.com.reset()? + } + } + Ok(Some(frame)) + } else { + Ok(None) + } + } + + #[inline] + fn on_shutdown(&mut self) { + self.inner.on_shutdown() + } + + #[inline] + fn on_open(&mut self, shake: Handshake) -> Result<()> { + self.inner.on_open(shake) + } + + #[inline] + fn on_message(&mut self, msg: Message) -> Result<()> { + self.inner.on_message(msg) + } + + #[inline] + fn on_close(&mut self, code: CloseCode, reason: &str) { + self.inner.on_close(code, reason) + } + + #[inline] + fn on_error(&mut self, err: Error) { + self.inner.on_error(err) + } + + #[inline] + fn on_timeout(&mut self, event: Token) -> Result<()> { + self.inner.on_timeout(event) + } + + #[inline] + fn on_new_timeout(&mut self, tok: Token, timeout: Timeout) -> Result<()> { + self.inner.on_new_timeout(tok, timeout) + } + + #[inline] + #[cfg(any(feature = "ssl", feature = "nativetls"))] + fn upgrade_ssl_client( + &mut self, + stream: TcpStream, + url: &url::Url, + ) -> Result<SslStream<TcpStream>> { + self.inner.upgrade_ssl_client(stream, url) + } + + #[inline] + #[cfg(any(feature = "ssl", feature = "nativetls"))] + fn upgrade_ssl_server(&mut self, stream: TcpStream) -> Result<SslStream<TcpStream>> { + self.inner.upgrade_ssl_server(stream) + } +} diff --git a/third_party/rust/ws/src/deflate/mod.rs b/third_party/rust/ws/src/deflate/mod.rs new file mode 100644 index 0000000000..8d79012e73 --- /dev/null +++ b/third_party/rust/ws/src/deflate/mod.rs @@ -0,0 +1,9 @@ +//! The deflate module provides tools for applying the permessage-deflate extension. + +extern crate libc; +extern crate libz_sys as ffi; + +mod context; +mod extension; + +pub use self::extension::{DeflateBuilder, DeflateHandler, DeflateSettings}; |