summaryrefslogtreecommitdiffstats
path: root/third_party/rust/ws/src/frame.rs
diff options
context:
space:
mode:
Diffstat (limited to 'third_party/rust/ws/src/frame.rs')
-rw-r--r--third_party/rust/ws/src/frame.rs495
1 files changed, 495 insertions, 0 deletions
diff --git a/third_party/rust/ws/src/frame.rs b/third_party/rust/ws/src/frame.rs
new file mode 100644
index 0000000000..154816c7ad
--- /dev/null
+++ b/third_party/rust/ws/src/frame.rs
@@ -0,0 +1,495 @@
+use std::default::Default;
+use std::fmt;
+use std::io::{Cursor, ErrorKind, Read, Write};
+
+use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt};
+use rand;
+
+use protocol::{CloseCode, OpCode};
+use result::{Error, Kind, Result};
+use stream::TryReadBuf;
+
+fn apply_mask(buf: &mut [u8], mask: &[u8; 4]) {
+ let iter = buf.iter_mut().zip(mask.iter().cycle());
+ for (byte, &key) in iter {
+ *byte ^= key
+ }
+}
+
+/// A struct representing a WebSocket frame.
+#[derive(Debug, Clone)]
+pub struct Frame {
+ finished: bool,
+ rsv1: bool,
+ rsv2: bool,
+ rsv3: bool,
+ opcode: OpCode,
+
+ mask: Option<[u8; 4]>,
+
+ payload: Vec<u8>,
+}
+
+impl Frame {
+ /// Get the length of the frame.
+ /// This is the length of the header + the length of the payload.
+ #[inline]
+ pub fn len(&self) -> usize {
+ let mut header_length = 2;
+ let payload_len = self.payload().len();
+ if payload_len > 125 {
+ if payload_len <= u16::max_value() as usize {
+ header_length += 2;
+ } else {
+ header_length += 8;
+ }
+ }
+
+ if self.is_masked() {
+ header_length += 4;
+ }
+
+ header_length + payload_len
+ }
+
+ /// Return `false`: a frame is never empty since it has a header.
+ #[inline]
+ pub fn is_empty(&self) -> bool {
+ false
+ }
+
+ /// Test whether the frame is a final frame.
+ #[inline]
+ pub fn is_final(&self) -> bool {
+ self.finished
+ }
+
+ /// Test whether the first reserved bit is set.
+ #[inline]
+ pub fn has_rsv1(&self) -> bool {
+ self.rsv1
+ }
+
+ /// Test whether the second reserved bit is set.
+ #[inline]
+ pub fn has_rsv2(&self) -> bool {
+ self.rsv2
+ }
+
+ /// Test whether the third reserved bit is set.
+ #[inline]
+ pub fn has_rsv3(&self) -> bool {
+ self.rsv3
+ }
+
+ /// Get the OpCode of the frame.
+ #[inline]
+ pub fn opcode(&self) -> OpCode {
+ self.opcode
+ }
+
+ /// Test whether this is a control frame.
+ #[inline]
+ pub fn is_control(&self) -> bool {
+ self.opcode.is_control()
+ }
+
+ /// Get a reference to the frame's payload.
+ #[inline]
+ pub fn payload(&self) -> &Vec<u8> {
+ &self.payload
+ }
+
+ // Test whether the frame is masked.
+ #[doc(hidden)]
+ #[inline]
+ pub fn is_masked(&self) -> bool {
+ self.mask.is_some()
+ }
+
+ // Get an optional reference to the frame's mask.
+ #[doc(hidden)]
+ #[allow(dead_code)]
+ #[inline]
+ pub fn mask(&self) -> Option<&[u8; 4]> {
+ self.mask.as_ref()
+ }
+
+ /// Make this frame a final frame.
+ #[allow(dead_code)]
+ #[inline]
+ pub fn set_final(&mut self, is_final: bool) -> &mut Frame {
+ self.finished = is_final;
+ self
+ }
+
+ /// Set the first reserved bit.
+ #[inline]
+ pub fn set_rsv1(&mut self, has_rsv1: bool) -> &mut Frame {
+ self.rsv1 = has_rsv1;
+ self
+ }
+
+ /// Set the second reserved bit.
+ #[inline]
+ pub fn set_rsv2(&mut self, has_rsv2: bool) -> &mut Frame {
+ self.rsv2 = has_rsv2;
+ self
+ }
+
+ /// Set the third reserved bit.
+ #[inline]
+ pub fn set_rsv3(&mut self, has_rsv3: bool) -> &mut Frame {
+ self.rsv3 = has_rsv3;
+ self
+ }
+
+ /// Set the OpCode.
+ #[allow(dead_code)]
+ #[inline]
+ pub fn set_opcode(&mut self, opcode: OpCode) -> &mut Frame {
+ self.opcode = opcode;
+ self
+ }
+
+ /// Edit the frame's payload.
+ #[allow(dead_code)]
+ #[inline]
+ pub fn payload_mut(&mut self) -> &mut Vec<u8> {
+ &mut self.payload
+ }
+
+ // Generate a new mask for this frame.
+ //
+ // This method simply generates and stores the mask. It does not change the payload data.
+ // Instead, the payload data will be masked with the generated mask when the frame is sent
+ // to the other endpoint.
+ #[doc(hidden)]
+ #[inline]
+ pub fn set_mask(&mut self) -> &mut Frame {
+ self.mask = Some(rand::random());
+ self
+ }
+
+ // This method unmasks the payload and should only be called on frames that are actually
+ // masked. In other words, those frames that have just been received from a client endpoint.
+ #[doc(hidden)]
+ #[inline]
+ pub fn remove_mask(&mut self) -> &mut Frame {
+ self.mask
+ .take()
+ .map(|mask| apply_mask(&mut self.payload, &mask));
+ self
+ }
+
+ /// Consume the frame into its payload.
+ pub fn into_data(self) -> Vec<u8> {
+ self.payload
+ }
+
+ /// Create a new data frame.
+ #[inline]
+ pub fn message(data: Vec<u8>, code: OpCode, finished: bool) -> Frame {
+ debug_assert!(
+ match code {
+ OpCode::Text | OpCode::Binary | OpCode::Continue => true,
+ _ => false,
+ },
+ "Invalid opcode for data frame."
+ );
+
+ Frame {
+ finished,
+ opcode: code,
+ payload: data,
+ ..Frame::default()
+ }
+ }
+
+ /// Create a new Pong control frame.
+ #[inline]
+ pub fn pong(data: Vec<u8>) -> Frame {
+ Frame {
+ opcode: OpCode::Pong,
+ payload: data,
+ ..Frame::default()
+ }
+ }
+
+ /// Create a new Ping control frame.
+ #[inline]
+ pub fn ping(data: Vec<u8>) -> Frame {
+ Frame {
+ opcode: OpCode::Ping,
+ payload: data,
+ ..Frame::default()
+ }
+ }
+
+ /// Create a new Close control frame.
+ #[inline]
+ pub fn close(code: CloseCode, reason: &str) -> Frame {
+ let payload = if let CloseCode::Empty = code {
+ Vec::new()
+ } else {
+ let u: u16 = code.into();
+ let raw = [(u >> 8) as u8, u as u8];
+ [&raw, reason.as_bytes()].concat()
+ };
+
+ Frame {
+ payload,
+ ..Frame::default()
+ }
+ }
+
+ /// Parse the input stream into a frame.
+ pub fn parse(cursor: &mut Cursor<Vec<u8>>, max_payload_length: u64) -> Result<Option<Frame>> {
+ let size = cursor.get_ref().len() as u64 - cursor.position();
+ let initial = cursor.position();
+ trace!("Position in buffer {}", initial);
+
+ let mut head = [0u8; 2];
+ if cursor.read(&mut head)? != 2 {
+ cursor.set_position(initial);
+ return Ok(None);
+ }
+
+ trace!("Parsed headers {:?}", head);
+
+ let first = head[0];
+ let second = head[1];
+ trace!("First: {:b}", first);
+ trace!("Second: {:b}", second);
+
+ let finished = first & 0x80 != 0;
+
+ let rsv1 = first & 0x40 != 0;
+ let rsv2 = first & 0x20 != 0;
+ let rsv3 = first & 0x10 != 0;
+
+ let opcode = OpCode::from(first & 0x0F);
+ trace!("Opcode: {:?}", opcode);
+
+ let masked = second & 0x80 != 0;
+ trace!("Masked: {:?}", masked);
+
+ let mut header_length = 2;
+
+ let mut length = u64::from(second & 0x7F);
+
+ if let Some(length_nbytes) = match length {
+ 126 => Some(2),
+ 127 => Some(8),
+ _ => None,
+ } {
+ match cursor.read_uint::<BigEndian>(length_nbytes) {
+ Err(ref err) if err.kind() == ErrorKind::UnexpectedEof => {
+ cursor.set_position(initial);
+ return Ok(None);
+ }
+ Err(err) => {
+ return Err(Error::from(err));
+ }
+ Ok(read) => {
+ length = read;
+ }
+ };
+ header_length += length_nbytes as u64;
+ }
+ trace!("Payload length: {}", length);
+
+ if length > max_payload_length {
+ return Err(Error::new(
+ Kind::Protocol,
+ format!(
+ "Rejected frame with payload length exceeding defined max: {}.",
+ max_payload_length
+ ),
+ ));
+ }
+
+ let mask = if masked {
+ let mut mask_bytes = [0u8; 4];
+ if cursor.read(&mut mask_bytes)? != 4 {
+ cursor.set_position(initial);
+ return Ok(None);
+ } else {
+ header_length += 4;
+ Some(mask_bytes)
+ }
+ } else {
+ None
+ };
+
+ match length.checked_add(header_length) {
+ Some(l) if size < l => {
+ cursor.set_position(initial);
+ return Ok(None);
+ }
+ Some(_) => (),
+ None => return Ok(None),
+ };
+
+ let mut data = Vec::with_capacity(length as usize);
+ if length > 0 {
+ if let Some(read) = cursor.try_read_buf(&mut data)? {
+ debug_assert!(read == length as usize, "Read incorrect payload length!");
+ }
+ }
+
+ // Disallow bad opcode
+ if let OpCode::Bad = opcode {
+ return Err(Error::new(
+ Kind::Protocol,
+ format!("Encountered invalid opcode: {}", first & 0x0F),
+ ));
+ }
+
+ // control frames must have length <= 125
+ match opcode {
+ OpCode::Ping | OpCode::Pong if length > 125 => {
+ return Err(Error::new(
+ Kind::Protocol,
+ format!(
+ "Rejected WebSocket handshake.Received control frame with length: {}.",
+ length
+ ),
+ ))
+ }
+ OpCode::Close if length > 125 => {
+ debug!("Received close frame with payload length exceeding 125. Morphing to protocol close frame.");
+ return Ok(Some(Frame::close(
+ CloseCode::Protocol,
+ "Received close frame with payload length exceeding 125.",
+ )));
+ }
+ _ => (),
+ }
+
+ let frame = Frame {
+ finished,
+ rsv1,
+ rsv2,
+ rsv3,
+ opcode,
+ mask,
+ payload: data,
+ };
+
+ Ok(Some(frame))
+ }
+
+ /// Write a frame out to a buffer
+ pub fn format<W>(&mut self, w: &mut W) -> Result<()>
+ where
+ W: Write,
+ {
+ let mut one = 0u8;
+ let code: u8 = self.opcode.into();
+ if self.is_final() {
+ one |= 0x80;
+ }
+ if self.has_rsv1() {
+ one |= 0x40;
+ }
+ if self.has_rsv2() {
+ one |= 0x20;
+ }
+ if self.has_rsv3() {
+ one |= 0x10;
+ }
+ one |= code;
+
+ let mut two = 0u8;
+ if self.is_masked() {
+ two |= 0x80;
+ }
+
+ match self.payload.len() {
+ len if len < 126 => {
+ two |= len as u8;
+ }
+ len if len <= 65535 => {
+ two |= 126;
+ }
+ _ => {
+ two |= 127;
+ }
+ }
+ w.write_all(&[one, two])?;
+
+ if let Some(length_bytes) = match self.payload.len() {
+ len if len < 126 => None,
+ len if len <= 65535 => Some(2),
+ _ => Some(8),
+ } {
+ w.write_uint::<BigEndian>(self.payload.len() as u64, length_bytes)?;
+ }
+
+ if self.is_masked() {
+ let mask = self.mask.take().unwrap();
+ apply_mask(&mut self.payload, &mask);
+ w.write_all(&mask)?;
+ }
+
+ w.write_all(&self.payload)?;
+ Ok(())
+ }
+}
+
+impl Default for Frame {
+ fn default() -> Frame {
+ Frame {
+ finished: true,
+ rsv1: false,
+ rsv2: false,
+ rsv3: false,
+ opcode: OpCode::Close,
+ mask: None,
+ payload: Vec::new(),
+ }
+ }
+}
+
+impl fmt::Display for Frame {
+ fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+ write!(
+ f,
+ "
+<FRAME>
+final: {}
+reserved: {} {} {}
+opcode: {}
+length: {}
+payload length: {}
+payload: 0x{}
+ ",
+ self.finished,
+ self.rsv1,
+ self.rsv2,
+ self.rsv3,
+ self.opcode,
+ // self.mask.map(|mask| format!("{:?}", mask)).unwrap_or("NONE".into()),
+ self.len(),
+ self.payload.len(),
+ self.payload
+ .iter()
+ .map(|byte| format!("{:x}", byte))
+ .collect::<String>()
+ )
+ }
+}
+
+mod test {
+ #![allow(unused_imports, unused_variables, dead_code)]
+ use super::*;
+ use protocol::OpCode;
+
+ #[test]
+ fn display_frame() {
+ let f = Frame::message("hi there".into(), OpCode::Text, true);
+ let view = format!("{}", f);
+ view.contains("payload:");
+ }
+}