diff options
Diffstat (limited to 'third_party/python/aiohttp/aiohttp/http_parser.py')
-rw-r--r-- | third_party/python/aiohttp/aiohttp/http_parser.py | 901 |
1 files changed, 901 insertions, 0 deletions
"""Pure-Python HTTP/1.x start-line, header and payload parsers."""

import abc
import asyncio
import collections
import re
import string
import zlib
from enum import IntEnum
from typing import Any, List, Optional, Tuple, Type, Union

from multidict import CIMultiDict, CIMultiDictProxy, istr
from yarl import URL

from . import hdrs
from .base_protocol import BaseProtocol
from .helpers import NO_EXTENSIONS, BaseTimerContext
from .http_exceptions import (
    BadStatusLine,
    ContentEncodingError,
    ContentLengthError,
    InvalidHeader,
    LineTooLong,
    TransferEncodingError,
)
from .http_writer import HttpVersion, HttpVersion10
from .log import internal_logger
from .streams import EMPTY_PAYLOAD, StreamReader
from .typedefs import RawHeaders

try:
    import brotli

    HAS_BROTLI = True
except ImportError:  # pragma: no cover
    HAS_BROTLI = False


__all__ = (
    "HeadersParser",
    "HttpParser",
    "HttpRequestParser",
    "HttpResponseParser",
    "RawRequestMessage",
    "RawResponseMessage",
)

ASCIISET = set(string.printable)

# See https://tools.ietf.org/html/rfc7230#section-3.1.1
# and https://tools.ietf.org/html/rfc7230#appendix-B
#
# method = token
# tchar = "!" / "#" / "$" / "%" / "&" / "'" / "*" / "+" / "-" / "." /
#            "^" / "_" / "`" / "|" / "~" / DIGIT / ALPHA
# token = 1*tchar
METHRE = re.compile(r"[!#$%&'*+\-.^_`|~0-9A-Za-z]+")
# The dot is escaped and the digits are ASCII-only so that int() below can
# never be handed anything but plain decimal digits.
VERSRE = re.compile(r"HTTP/(\d+)\.(\d+)", re.ASCII)
HDRRE = re.compile(rb"[\x00-\x1F\x7F()<>@,;:\[\]={} \t\\\\\"]")

# int() also accepts signs, underscores, surrounding whitespace and (for
# base 16) a "0x" prefix.  None of those are valid in the numeric fields of
# an HTTP message, so numbers are validated with these before conversion.
_DIGITS = re.compile(r"[0-9]+")
_HEXDIGITS = re.compile(rb"[0-9a-fA-F]+")

RawRequestMessage = collections.namedtuple(
    "RawRequestMessage",
    [
        "method",
        "path",
        "version",
        "headers",
        "raw_headers",
        "should_close",
        "compression",
        "upgrade",
        "chunked",
        "url",
    ],
)

RawResponseMessage = collections.namedtuple(
    "RawResponseMessage",
    [
        "version",
        "code",
        "reason",
        "headers",
        "raw_headers",
        "should_close",
        "compression",
        "upgrade",
        "chunked",
    ],
)


class ParseState(IntEnum):
    """How HttpPayloadParser delimits the message body."""

    PARSE_NONE = 0
    PARSE_LENGTH = 1
    PARSE_CHUNKED = 2
    PARSE_UNTIL_EOF = 3


class ChunkState(IntEnum):
    """Position of HttpPayloadParser inside a chunked body."""

    PARSE_CHUNKED_SIZE = 0
    PARSE_CHUNKED_CHUNK = 1
    PARSE_CHUNKED_CHUNK_EOF = 2
    PARSE_MAYBE_TRAILERS = 3
    PARSE_TRAILERS = 4


class HeadersParser:
    """Parse the header lines of an HTTP message already split on CRLF."""

    def __init__(
        self,
        max_line_size: int = 8190,
        max_headers: int = 32768,
        max_field_size: int = 8190,
    ) -> None:
        # NOTE: parse_headers() only consults max_field_size; the other two
        # limits are stored but not enforced by this class.
        self.max_line_size = max_line_size
        self.max_headers = max_headers
        self.max_field_size = max_field_size

    def _field_too_long(self, bname: bytes, length: int) -> LineTooLong:
        """Build the error raised when a header value exceeds the limit."""
        return LineTooLong(
            "request header field {}".format(
                bname.decode("utf8", "xmlcharrefreplace")
            ),
            str(self.max_field_size),
            str(length),
        )

    def parse_headers(
        self, lines: List[bytes]
    ) -> Tuple["CIMultiDictProxy[str]", RawHeaders]:
        """Parse header lines into a multidict and a raw (name, value) tuple.

        ``lines[0]`` (the start line) is skipped.  Parsing begins at index 1
        and stops at the first empty line or at the end of the list.  A line
        starting with a space or a tab continues the previous header value.

        Raises InvalidHeader for a line without ":" or a name containing a
        character matched by HDRRE, and LineTooLong when a name or an
        accumulated value is longer than ``max_field_size``.
        """
        headers = CIMultiDict()  # type: CIMultiDict[str]
        raw_headers = []

        line_count = len(lines)
        lines_idx = 1
        # Running off the end of the list is treated as the end of the
        # header section instead of raising IndexError.
        line = lines[1] if line_count > 1 else b""

        while line:
            # Parse initial header name : value pair.
            try:
                bname, bvalue = line.split(b":", 1)
            except ValueError:
                raise InvalidHeader(line) from None

            bname = bname.strip(b" \t")
            bvalue = bvalue.lstrip()
            if HDRRE.search(bname):
                raise InvalidHeader(bname)
            if len(bname) > self.max_field_size:
                raise LineTooLong(
                    "request header name {}".format(
                        bname.decode("utf8", "xmlcharrefreplace")
                    ),
                    str(self.max_field_size),
                    str(len(bname)),
                )

            header_length = len(bvalue)

            # next line
            lines_idx += 1
            line = lines[lines_idx] if lines_idx < line_count else b""

            if line and line[0] in (32, 9):  # (' ', '\t')
                # Consume continuation lines.  The test is re-evaluated for
                # every line, so an empty line always ends the header
                # section rather than being swallowed into the value.
                bvalue_lst = [bvalue]
                while line and line[0] in (32, 9):
                    header_length += len(line)
                    if header_length > self.max_field_size:
                        raise self._field_too_long(bname, header_length)
                    bvalue_lst.append(line)

                    lines_idx += 1
                    line = lines[lines_idx] if lines_idx < line_count else b""
                bvalue = b"".join(bvalue_lst)
            elif header_length > self.max_field_size:
                raise self._field_too_long(bname, header_length)

            bvalue = bvalue.strip()
            name = bname.decode("utf-8", "surrogateescape")
            value = bvalue.decode("utf-8", "surrogateescape")

            headers.add(name, value)
            raw_headers.append((bname, bvalue))

        return (CIMultiDictProxy(headers), tuple(raw_headers))


class HttpParser(abc.ABC):
    """Incremental parser turning raw bytes into (message, payload) pairs.

    Subclasses implement parse_message() for the start line; this class
    handles line splitting, header interpretation and payload dispatch.
    """

    def __init__(
        self,
        protocol: Optional[BaseProtocol] = None,
        loop: Optional[asyncio.AbstractEventLoop] = None,
        limit: int = 2 ** 16,
        max_line_size: int = 8190,
        max_headers: int = 32768,
        max_field_size: int = 8190,
        timer: Optional[BaseTimerContext] = None,
        code: Optional[int] = None,
        method: Optional[str] = None,
        readall: bool = False,
        payload_exception: Optional[Type[BaseException]] = None,
        response_with_body: bool = True,
        read_until_eof: bool = False,
        auto_decompress: bool = True,
    ) -> None:
        self.protocol = protocol
        self.loop = loop
        self.max_line_size = max_line_size
        self.max_headers = max_headers
        self.max_field_size = max_field_size
        self.timer = timer
        self.code = code
        self.method = method
        self.readall = readall
        self.payload_exception = payload_exception
        self.response_with_body = response_with_body
        self.read_until_eof = read_until_eof

        # Lines of the message currently being assembled (start line first).
        self._lines = []  # type: List[bytes]
        # Bytes received after the last complete line.
        self._tail = b""
        self._upgraded = False
        self._payload = None
        self._payload_parser = None  # type: Optional[HttpPayloadParser]
        self._auto_decompress = auto_decompress
        self._limit = limit
        self._headers_parser = HeadersParser(max_line_size, max_headers, max_field_size)

    @abc.abstractmethod
    def parse_message(self, lines: List[bytes]) -> Any:
        """Parse start line and headers from ``lines`` into a raw message."""

    def _new_payload_stream(self) -> StreamReader:
        """Create the StreamReader that receives a message body."""
        return StreamReader(
            self.protocol,
            timer=self.timer,
            loop=self.loop,
            limit=self._limit,
        )

    def feed_eof(self) -> Any:
        """Signal end of input.

        With an active payload parser, EOF is forwarded to it and nothing is
        returned.  Otherwise any buffered partial message is parsed on a
        best-effort basis: the parsed message is returned, or None when
        parsing fails or nothing is buffered.
        """
        if self._payload_parser is not None:
            self._payload_parser.feed_eof()
            self._payload_parser = None
        else:
            # try to extract partial message
            if self._tail:
                self._lines.append(self._tail)

            if self._lines:
                # Terminate the header section.  The previous guard compared
                # bytes with the str "\r\n", which is never equal, so the
                # append has always been unconditional.
                self._lines.append(b"")
                try:
                    return self.parse_message(self._lines)
                except Exception:
                    # Deliberate best-effort: a truncated message is dropped.
                    return None

    def feed_data(
        self,
        data: bytes,
        SEP: bytes = b"\r\n",
        EMPTY: bytes = b"",
        CONTENT_LENGTH: istr = hdrs.CONTENT_LENGTH,
        METH_CONNECT: str = hdrs.METH_CONNECT,
        SEC_WEBSOCKET_KEY1: istr = hdrs.SEC_WEBSOCKET_KEY1,
    ) -> Tuple[List[Any], bool, bytes]:
        """Consume ``data`` and return ``(messages, upgraded, tail)``.

        ``messages`` is a list of ``(message, payload)`` pairs completed by
        this call, ``upgraded`` is the current upgrade flag and ``tail`` is
        the part of ``data`` that was not consumed.

        Raises InvalidHeader for a Content-Length that is not a plain
        decimal number or when Sec-WebSocket-Key1 is present, plus whatever
        parse_message() raises.
        """
        messages = []

        if self._tail:
            data, self._tail = self._tail + data, b""

        data_len = len(data)
        start_pos = 0

        while start_pos < data_len:

            # read HTTP message (request/response line + headers), \r\n\r\n
            # and split by lines
            if self._payload_parser is None and not self._upgraded:
                pos = data.find(SEP, start_pos)
                # consume \r\n
                if pos == start_pos and not self._lines:
                    start_pos = pos + 2
                    continue

                if pos >= start_pos:
                    # line found
                    self._lines.append(data[start_pos:pos])
                    start_pos = pos + 2

                    # \r\n\r\n found
                    if self._lines[-1] == EMPTY:
                        try:
                            msg = self.parse_message(self._lines)
                        finally:
                            self._lines.clear()

                        # payload length
                        length = None  # type: Optional[int]
                        length_hdr = msg.headers.get(CONTENT_LENGTH)
                        if length_hdr is not None:
                            # Digits only: int() would also accept "+5",
                            # "-5" and "1_0", which other HTTP
                            # implementations reject or read differently.
                            if _DIGITS.fullmatch(length_hdr.strip(" \t")) is None:
                                raise InvalidHeader(CONTENT_LENGTH)
                            length = int(length_hdr)

                        # do not support old websocket spec
                        if SEC_WEBSOCKET_KEY1 in msg.headers:
                            raise InvalidHeader(SEC_WEBSOCKET_KEY1)

                        self._upgraded = msg.upgrade

                        method = getattr(msg, "method", self.method)

                        assert self.protocol is not None
                        # calculate payload
                        # NOTE: "and" binds tighter than "or", so the upgrade
                        # check applies to the chunked case only.  The
                        # grouping is spelled out; behaviour is unchanged.
                        if (length is not None and length > 0) or (
                            msg.chunked and not msg.upgrade
                        ):
                            payload = self._new_payload_stream()
                            payload_parser = HttpPayloadParser(
                                payload,
                                length=length,
                                chunked=msg.chunked,
                                method=method,
                                compression=msg.compression,
                                code=self.code,
                                readall=self.readall,
                                response_with_body=self.response_with_body,
                                auto_decompress=self._auto_decompress,
                            )
                            if not payload_parser.done:
                                self._payload_parser = payload_parser
                        elif method == METH_CONNECT:
                            payload = self._new_payload_stream()
                            self._upgraded = True
                            self._payload_parser = HttpPayloadParser(
                                payload,
                                method=msg.method,
                                compression=msg.compression,
                                readall=True,
                                auto_decompress=self._auto_decompress,
                            )
                        else:
                            if (
                                getattr(msg, "code", 100) >= 199
                                and length is None
                                and self.read_until_eof
                            ):
                                payload = self._new_payload_stream()
                                payload_parser = HttpPayloadParser(
                                    payload,
                                    length=length,
                                    chunked=msg.chunked,
                                    method=method,
                                    compression=msg.compression,
                                    code=self.code,
                                    readall=True,
                                    response_with_body=self.response_with_body,
                                    auto_decompress=self._auto_decompress,
                                )
                                if not payload_parser.done:
                                    self._payload_parser = payload_parser
                            else:
                                payload = EMPTY_PAYLOAD  # type: ignore

                        messages.append((msg, payload))
                else:
                    # No complete line yet: keep the rest for the next call.
                    self._tail = data[start_pos:]
                    data = EMPTY
                    break

            # no parser, just store
            elif self._payload_parser is None and self._upgraded:
                assert not self._lines
                break

            # feed payload
            elif data and start_pos < data_len:
                assert not self._lines
                assert self._payload_parser is not None
                try:
                    eof, data = self._payload_parser.feed_data(data[start_pos:])
                except BaseException as exc:
                    # Report the failure through the payload stream and treat
                    # the body as finished.
                    if self.payload_exception is not None:
                        self._payload_parser.payload.set_exception(
                            self.payload_exception(str(exc))
                        )
                    else:
                        self._payload_parser.payload.set_exception(exc)

                    eof = True
                    data = b""

                if eof:
                    # ``data`` now holds only what followed the body.
                    start_pos = 0
                    data_len = len(data)
                    self._payload_parser = None
                    continue
            else:
                break

        if data and start_pos < data_len:
            data = data[start_pos:]
        else:
            data = EMPTY

        return messages, self._upgraded, data

    def parse_headers(
        self, lines: List[bytes]
    ) -> Tuple[
        "CIMultiDictProxy[str]", RawHeaders, Optional[bool], Optional[str], bool, bool
    ]:
        """Parse headers and derive the connection-level flags from them.

        Line continuations are supported.  Returns
        ``(headers, raw_headers, close_conn, encoding, upgrade, chunked)``:

        * ``close_conn`` is True/False when Connection is exactly "close" /
          "keep-alive" (case-insensitive), otherwise None;
        * ``encoding`` is the lower-cased Content-Encoding when it is one of
          "gzip", "deflate" or "br", otherwise None;
        * ``upgrade`` is True when Connection is exactly "upgrade";
        * ``chunked`` is True when Transfer-Encoding contains "chunked".
        """
        headers, raw_headers = self._headers_parser.parse_headers(lines)
        close_conn = None
        encoding = None
        upgrade = False
        chunked = False

        # keep-alive
        conn = headers.get(hdrs.CONNECTION)
        if conn:
            v = conn.lower()
            if v == "close":
                close_conn = True
            elif v == "keep-alive":
                close_conn = False
            elif v == "upgrade":
                upgrade = True

        # encoding
        enc = headers.get(hdrs.CONTENT_ENCODING)
        if enc:
            enc = enc.lower()
            if enc in ("gzip", "deflate", "br"):
                encoding = enc

        # chunking
        te = headers.get(hdrs.TRANSFER_ENCODING)
        if te and "chunked" in te.lower():
            chunked = True

        return (headers, raw_headers, close_conn, encoding, upgrade, chunked)

    def set_upgraded(self, val: bool) -> None:
        """Set connection upgraded (to websocket) mode.

        :param bool val: new state.
        """
        self._upgraded = val


class HttpRequestParser(HttpParser):
    """Read request status line.

    Exception .http_exceptions.BadStatusLine could be raised in case of any
    errors in status line.  Returns RawRequestMessage.
    """

    def parse_message(self, lines: List[bytes]) -> Any:
        """Parse the request line in ``lines[0]`` and the headers after it.

        Raises BadStatusLine for a malformed request line, method or
        version, and LineTooLong when the request target is longer than
        ``max_line_size``.
        """
        # request line
        line = lines[0].decode("utf-8", "surrogateescape")
        try:
            method, path, version = line.split(None, 2)
        except ValueError:
            raise BadStatusLine(line) from None

        if len(path) > self.max_line_size:
            raise LineTooLong(
                "Status line is too long", str(self.max_line_size), str(len(path))
            )

        path_part, _hash_separator, url_fragment = path.partition("#")
        path_part, _question_mark_separator, qs_part = path_part.partition("?")

        # method: the whole token must be valid, not just a prefix of it
        if METHRE.fullmatch(method) is None:
            raise BadStatusLine(method)

        # version: plain ASCII digits only.  Surrounding whitespace left by
        # split() is still tolerated, as it was before.
        match = VERSRE.fullmatch(version.strip())
        if match is None:
            raise BadStatusLine(version)
        version_o = HttpVersion(int(match.group(1)), int(match.group(2)))

        # read headers
        (
            headers,
            raw_headers,
            close,
            compression,
            upgrade,
            chunked,
        ) = self.parse_headers(lines)

        if close is None:  # the Connection header did not decide it
            # HTTP/1.0 and older close by default; newer versions keep the
            # connection alive by default.
            close = version_o <= HttpVersion10

        return RawRequestMessage(
            method,
            path,
            version_o,
            headers,
            raw_headers,
            close,
            compression,
            upgrade,
            chunked,
            # NOTE: `yarl.URL.build()` is used to mimic what the Cython-based
            # NOTE: parser does, otherwise it results into the same
            # NOTE: HTTP Request-Line input producing different
            # NOTE: `yarl.URL()` objects
            URL.build(
                path=path_part,
                query_string=qs_part,
                fragment=url_fragment,
                encoded=True,
            ),
        )


class HttpResponseParser(HttpParser):
    """Read response status line and headers.

    BadStatusLine could be raised in case of any errors in status line.
    Returns RawResponseMessage.
    """

    def parse_message(self, lines: List[bytes]) -> Any:
        """Parse the status line in ``lines[0]`` and the headers after it.

        Raises BadStatusLine for a malformed status line, version or status
        code, and LineTooLong when the reason phrase is longer than
        ``max_line_size``.
        """
        line = lines[0].decode("utf-8", "surrogateescape")
        try:
            version, status = line.split(None, 1)
        except ValueError:
            raise BadStatusLine(line) from None

        try:
            status, reason = status.split(None, 1)
        except ValueError:
            reason = ""

        if len(reason) > self.max_line_size:
            raise LineTooLong(
                "Status line is too long", str(self.max_line_size), str(len(reason))
            )

        # version
        match = VERSRE.fullmatch(version)
        if match is None:
            raise BadStatusLine(line)
        version_o = HttpVersion(int(match.group(1)), int(match.group(2)))

        # The status code is a three-digit number.  It is validated before
        # int() because int() would also accept "-1", "+200" or "2_0".
        if len(status) != 3 or _DIGITS.fullmatch(status) is None:
            raise BadStatusLine(line)
        status_i = int(status)

        # read headers
        (
            headers,
            raw_headers,
            close,
            compression,
            upgrade,
            chunked,
        ) = self.parse_headers(lines)

        if close is None:
            close = version_o <= HttpVersion10

        return RawResponseMessage(
            version_o,
            status_i,
            reason.strip(),
            headers,
            raw_headers,
            close,
            compression,
            upgrade,
            chunked,
        )


class HttpPayloadParser:
    """Feed a message body into ``payload`` according to its framing."""

    def __init__(
        self,
        payload: StreamReader,
        length: Optional[int] = None,
        chunked: bool = False,
        compression: Optional[str] = None,
        code: Optional[int] = None,
        method: Optional[str] = None,
        readall: bool = False,
        response_with_body: bool = True,
        auto_decompress: bool = True,
    ) -> None:
        self._length = 0
        self._type = ParseState.PARSE_NONE
        self._chunk = ChunkState.PARSE_CHUNKED_SIZE
        self._chunk_size = 0
        # Incomplete chunk-framing bytes carried over to the next call.
        self._chunk_tail = b""
        self._auto_decompress = auto_decompress
        # True when no body bytes are expected at all.
        self.done = False

        # payload decompression wrapper
        if response_with_body and compression and self._auto_decompress:
            real_payload = DeflateBuffer(
                payload, compression
            )  # type: Union[StreamReader, DeflateBuffer]
        else:
            real_payload = payload

        # payload parser
        if not response_with_body:
            # don't parse payload if it's not expected to be received
            self._type = ParseState.PARSE_NONE
            real_payload.feed_eof()
            self.done = True

        elif chunked:
            self._type = ParseState.PARSE_CHUNKED
        elif length is not None:
            self._type = ParseState.PARSE_LENGTH
            self._length = length
            if self._length == 0:
                real_payload.feed_eof()
                self.done = True
        else:
            if readall and code != 204:
                self._type = ParseState.PARSE_UNTIL_EOF
            elif method in ("PUT", "POST"):
                internal_logger.warning(  # pragma: no cover
                    "Content-Length or Transfer-Encoding header is required"
                )
                self._type = ParseState.PARSE_NONE
                real_payload.feed_eof()
                self.done = True

        self.payload = real_payload

    def feed_eof(self) -> None:
        """Handle end of input.

        Finishes the payload in read-until-EOF mode.  Raises
        ContentLengthError or TransferEncodingError when a length-delimited
        or chunked body is cut short.
        """
        if self._type == ParseState.PARSE_UNTIL_EOF:
            self.payload.feed_eof()
        elif self._type == ParseState.PARSE_LENGTH:
            raise ContentLengthError(
                "Not enough data for satisfy content length header."
            )
        elif self._type == ParseState.PARSE_CHUNKED:
            raise TransferEncodingError(
                "Not enough data for satisfy transfer length header."
            )

    def feed_data(
        self, chunk: bytes, SEP: bytes = b"\r\n", CHUNK_EXT: bytes = b";"
    ) -> Tuple[bool, bytes]:
        """Consume body bytes and return ``(eof, tail)``.

        ``eof`` is True once the body is complete; ``tail`` then holds the
        bytes following the body.  Raises TransferEncodingError (also set
        on the payload) when a chunk size is not a hexadecimal number.
        """
        # Read specified amount of bytes
        if self._type == ParseState.PARSE_LENGTH:
            required = self._length
            chunk_len = len(chunk)

            if required >= chunk_len:
                self._length = required - chunk_len
                self.payload.feed_data(chunk, chunk_len)
                if self._length == 0:
                    self.payload.feed_eof()
                    return True, b""
            else:
                self._length = 0
                self.payload.feed_data(chunk[:required], required)
                self.payload.feed_eof()
                return True, chunk[required:]

        # Chunked transfer encoding parser
        elif self._type == ParseState.PARSE_CHUNKED:
            if self._chunk_tail:
                chunk = self._chunk_tail + chunk
                self._chunk_tail = b""

            while chunk:

                # read next chunk size
                if self._chunk == ChunkState.PARSE_CHUNKED_SIZE:
                    pos = chunk.find(SEP)
                    if pos >= 0:
                        i = chunk.find(CHUNK_EXT, 0, pos)
                        if i >= 0:
                            size_b = chunk[:i]  # strip chunk-extensions
                        else:
                            size_b = chunk[:pos]

                        # Hex digits only: int(x, 16) would also accept a
                        # sign, "0x" or underscores, and a negative size
                        # breaks the slicing below.  Surrounding blanks
                        # remain tolerated, as they were before.
                        size_b = bytes(size_b).strip(b" \t")
                        if _HEXDIGITS.fullmatch(size_b) is None:
                            exc = TransferEncodingError(
                                chunk[:pos].decode("ascii", "surrogateescape")
                            )
                            self.payload.set_exception(exc)
                            raise exc
                        size = int(size_b, 16)

                        chunk = chunk[pos + 2 :]
                        if size == 0:  # eof marker
                            self._chunk = ChunkState.PARSE_MAYBE_TRAILERS
                        else:
                            self._chunk = ChunkState.PARSE_CHUNKED_CHUNK
                            self._chunk_size = size
                            self.payload.begin_http_chunk_receiving()
                    else:
                        self._chunk_tail = chunk
                        return False, b""

                # read chunk and feed buffer
                if self._chunk == ChunkState.PARSE_CHUNKED_CHUNK:
                    required = self._chunk_size
                    chunk_len = len(chunk)

                    if required > chunk_len:
                        self._chunk_size = required - chunk_len
                        self.payload.feed_data(chunk, chunk_len)
                        return False, b""
                    else:
                        self._chunk_size = 0
                        self.payload.feed_data(chunk[:required], required)
                        chunk = chunk[required:]
                        self._chunk = ChunkState.PARSE_CHUNKED_CHUNK_EOF
                        self.payload.end_http_chunk_receiving()

                # toss the CRLF at the end of the chunk
                if self._chunk == ChunkState.PARSE_CHUNKED_CHUNK_EOF:
                    if chunk[:2] == SEP:
                        chunk = chunk[2:]
                        self._chunk = ChunkState.PARSE_CHUNKED_SIZE
                    else:
                        self._chunk_tail = chunk
                        return False, b""

                # if stream does not contain trailer, after 0\r\n
                # we should get another \r\n otherwise
                # trailers needs to be skipped until \r\n\r\n
                if self._chunk == ChunkState.PARSE_MAYBE_TRAILERS:
                    head = chunk[:2]
                    if head == SEP:
                        # end of stream
                        self.payload.feed_eof()
                        return True, chunk[2:]
                    # Both CR and LF, or only LF may not be received yet. It is
                    # expected that CRLF or LF will be shown at the very first
                    # byte next time, otherwise trailers should come. The last
                    # CRLF which marks the end of response might not be
                    # contained in the same TCP segment which delivered the
                    # size indicator.
                    if not head:
                        return False, b""
                    if head == SEP[:1]:
                        self._chunk_tail = head
                        return False, b""
                    self._chunk = ChunkState.PARSE_TRAILERS

                # read and discard trailer up to the CRLF terminator
                if self._chunk == ChunkState.PARSE_TRAILERS:
                    pos = chunk.find(SEP)
                    if pos >= 0:
                        chunk = chunk[pos + 2 :]
                        self._chunk = ChunkState.PARSE_MAYBE_TRAILERS
                    else:
                        self._chunk_tail = chunk
                        return False, b""

        # Read all bytes until eof
        elif self._type == ParseState.PARSE_UNTIL_EOF:
            self.payload.feed_data(chunk, len(chunk))

        return False, b""


class DeflateBuffer:
    """DeflateStream decompress stream and feed data into specified stream."""

    def __init__(self, out: StreamReader, encoding: Optional[str]) -> None:
        self.out = out
        # Number of compressed bytes received so far.
        self.size = 0
        self.encoding = encoding
        self._started_decoding = False

        if encoding == "br":
            if not HAS_BROTLI:  # pragma: no cover
                raise ContentEncodingError(
                    "Can not decode content-encoding: brotli (br). "
                    "Please install `brotlipy`"
                )
            self.decompressor = brotli.Decompressor()
        else:
            # 16 + MAX_WBITS selects the gzip container; anything else is
            # decoded as a zlib stream.
            zlib_mode = 16 + zlib.MAX_WBITS if encoding == "gzip" else zlib.MAX_WBITS
            self.decompressor = zlib.decompressobj(wbits=zlib_mode)

    def set_exception(self, exc: BaseException) -> None:
        """Forward ``exc`` to the wrapped stream."""
        self.out.set_exception(exc)

    def feed_data(self, chunk: bytes, size: int) -> None:
        """Decompress ``chunk`` and pass the result to the wrapped stream.

        Does nothing when ``size`` is zero.  Raises ContentEncodingError
        when decompression fails.
        """
        if not size:
            return

        self.size += size

        # RFC1950
        # bits 0..3 = CM = 0b1000 = 8 = "deflate"
        # bits 4..7 = CINFO = 1..7 = windows size.
        if (
            not self._started_decoding
            and self.encoding == "deflate"
            and chunk[0] & 0xF != 8
        ):
            # Change the decoder to decompress incorrectly compressed data
            # Actually we should issue a warning about non-RFC-compliant data.
            self.decompressor = zlib.decompressobj(wbits=-zlib.MAX_WBITS)

        try:
            chunk = self.decompressor.decompress(chunk)
        except Exception:
            raise ContentEncodingError(
                "Can not decode content-encoding: %s" % self.encoding
            )

        self._started_decoding = True

        if chunk:
            self.out.feed_data(chunk, len(chunk))

    def feed_eof(self) -> None:
        """Flush the decompressor and finish the wrapped stream.

        Raises ContentEncodingError when data was received for a "deflate"
        body whose compressed stream did not reach its end marker.
        """
        chunk = self.decompressor.flush()

        if chunk or self.size > 0:
            self.out.feed_data(chunk, len(chunk))
            if self.encoding == "deflate" and not self.decompressor.eof:
                raise ContentEncodingError("deflate")

        self.out.feed_eof()

    def begin_http_chunk_receiving(self) -> None:
        """Forward the chunk-start notification to the wrapped stream."""
        self.out.begin_http_chunk_receiving()

    def end_http_chunk_receiving(self) -> None:
        """Forward the chunk-end notification to the wrapped stream."""
        self.out.end_http_chunk_receiving()


# Keep the pure-Python implementations reachable under a "Py" suffix before
# the names are possibly rebound to the C extension below.
HttpRequestParserPy = HttpRequestParser
HttpResponseParserPy = HttpResponseParser
RawRequestMessagePy = RawRequestMessage
RawResponseMessagePy = RawResponseMessage

try:
    if not NO_EXTENSIONS:
        from ._http_parser import (  # type: ignore
            HttpRequestParser,
            HttpResponseParser,
            RawRequestMessage,
            RawResponseMessage,
        )

        HttpRequestParserC = HttpRequestParser
        HttpResponseParserC = HttpResponseParser
        RawRequestMessageC = RawRequestMessage
        RawResponseMessageC = RawResponseMessage
except ImportError:  # pragma: no cover
    pass