summaryrefslogtreecommitdiffstats
path: root/test/pyhttpd/ws_util.py
diff options
context:
space:
mode:
Diffstat (limited to 'test/pyhttpd/ws_util.py')
-rw-r--r--test/pyhttpd/ws_util.py137
1 files changed, 137 insertions, 0 deletions
diff --git a/test/pyhttpd/ws_util.py b/test/pyhttpd/ws_util.py
new file mode 100644
index 0000000..38a3cf7
--- /dev/null
+++ b/test/pyhttpd/ws_util.py
@@ -0,0 +1,137 @@
+import logging
+import struct
+
+
+log = logging.getLogger(__name__)
+
+
+class WsFrame:
+
+ CONT = 0
+ TEXT = 1
+ BINARY = 2
+ RSVD3 = 3
+ RSVD4 = 4
+ RSVD5 = 5
+ RSVD6 = 6
+ RSVD7 = 7
+ CLOSE = 8
+ PING = 9
+ PONG = 10
+ RSVD11 = 11
+ RSVD12 = 12
+ RSVD13 = 13
+ RSVD14 = 14
+ RSVD15 = 15
+
+ OP_NAMES = [
+ "CONT",
+ "TEXT",
+ "BINARY",
+ "RSVD3",
+ "RSVD4",
+ "RSVD5",
+ "RSVD6",
+ "RSVD7",
+ "CLOSE",
+ "PING",
+ "PONG",
+ "RSVD11",
+ "RSVD12",
+ "RSVD13",
+ "RSVD14",
+ "RSVD15",
+ ]
+
+ def __init__(self, opcode: int, fin: bool, mask: bytes, data: bytes):
+ self.opcode = opcode
+ self.fin = fin
+ self.mask = mask
+ self.data = data
+ self.length = len(data)
+
+ def __repr__(self):
+ return f'WsFrame[{self.OP_NAMES[self.opcode]} fin={self.fin}, mask={self.mask}, len={len(self.data)}]'
+
+ @property
+ def data_len(self) -> int:
+ return len(self.data) if self.data else 0
+
+ def to_network(self) -> bytes:
+ nd = bytearray()
+ h1 = self.opcode
+ if self.fin:
+ h1 |= 0x80
+ nd.extend(struct.pack("!B", h1))
+ mask_bit = 0x80 if self.mask is not None else 0x0
+ h2 = self.data_len
+ if h2 > 65535:
+ nd.extend(struct.pack("!BQ", 127|mask_bit, h2))
+ elif h2 > 126:
+ nd.extend(struct.pack("!BH", 126|mask_bit, h2))
+ else:
+ nd.extend(struct.pack("!B", h2|mask_bit))
+ if self.mask is not None:
+ nd.extend(self.mask)
+ if self.data is not None:
+ nd.extend(self.data)
+ return nd
+
+ @classmethod
+ def client_ping(cls, data: bytes, mask: bytes = None) -> 'WsFrame':
+ if mask is None:
+ mask = bytes.fromhex('00 00 00 00')
+ return WsFrame(opcode=WsFrame.PING, fin=True, mask=mask, data=data)
+
+ @classmethod
+ def client_close(cls, code: int, reason: str = None,
+ mask: bytes = None) -> 'WsFrame':
+ data = bytearray(struct.pack("!H", code))
+ if reason is not None:
+ data.extend(reason.encode())
+ if mask is None:
+ mask = bytes.fromhex('00 00 00 00')
+ return WsFrame(opcode=WsFrame.CLOSE, fin=True, mask=mask, data=data)
+
+
+class WsFrameReader:
+
+ def __init__(self, data: bytes):
+ self.data = data
+
+ def _read(self, n: int):
+ if len(self.data) < n:
+ raise EOFError(f'have {len(self.data)} bytes left, but {n} requested')
+ elif n == 0:
+ return b''
+ chunk = self.data[:n]
+ del self.data[:n]
+ return chunk
+
+ def next_frame(self):
+ data = self._read(2)
+ h1, h2 = struct.unpack("!BB", data)
+ log.debug(f'parsed h1={h1} h2={h2} from {data}')
+ fin = True if h1 & 0x80 else False
+ opcode = h1 & 0xf
+ has_mask = True if h2 & 0x80 else False
+ mask = None
+ dlen = h2 & 0x7f
+ if dlen == 126:
+ (dlen,) = struct.unpack("!H", self._read(2))
+ elif dlen == 127:
+ (dlen,) = struct.unpack("!Q", self._read(8))
+ if has_mask:
+ mask = self._read(4)
+ return WsFrame(opcode=opcode, fin=fin, mask=mask, data=self._read(dlen))
+
+ def eof(self):
+ return len(self.data) == 0
+
+ @classmethod
+ def parse(cls, data: bytes):
+ frames = []
+ reader = WsFrameReader(data=data)
+ while not reader.eof():
+ frames.append(reader.next_frame())
+ return frames