diff options
Diffstat (limited to 'third_party/python/aiohttp/aiohttp/web_protocol.py')
-rw-r--r-- | third_party/python/aiohttp/aiohttp/web_protocol.py | 667 |
1 files changed, 667 insertions, 0 deletions
diff --git a/third_party/python/aiohttp/aiohttp/web_protocol.py b/third_party/python/aiohttp/aiohttp/web_protocol.py new file mode 100644 index 0000000000..8e02bc4aab --- /dev/null +++ b/third_party/python/aiohttp/aiohttp/web_protocol.py @@ -0,0 +1,667 @@ +import asyncio +import asyncio.streams +import traceback +import warnings +from collections import deque +from contextlib import suppress +from html import escape as html_escape +from http import HTTPStatus +from logging import Logger +from typing import TYPE_CHECKING, Any, Awaitable, Callable, Optional, Tuple, Type, cast + +import yarl + +from .abc import AbstractAccessLogger, AbstractStreamWriter +from .base_protocol import BaseProtocol +from .helpers import CeilTimeout, current_task +from .http import ( + HttpProcessingError, + HttpRequestParser, + HttpVersion10, + RawRequestMessage, + StreamWriter, +) +from .log import access_logger, server_logger +from .streams import EMPTY_PAYLOAD, StreamReader +from .tcp_helpers import tcp_keepalive +from .web_exceptions import HTTPException +from .web_log import AccessLogger +from .web_request import BaseRequest +from .web_response import Response, StreamResponse + +__all__ = ("RequestHandler", "RequestPayloadError", "PayloadAccessError") + +if TYPE_CHECKING: # pragma: no cover + from .web_server import Server + + +_RequestFactory = Callable[ + [ + RawRequestMessage, + StreamReader, + "RequestHandler", + AbstractStreamWriter, + "asyncio.Task[None]", + ], + BaseRequest, +] + +_RequestHandler = Callable[[BaseRequest], Awaitable[StreamResponse]] + + +ERROR = RawRequestMessage( + "UNKNOWN", "/", HttpVersion10, {}, {}, True, False, False, False, yarl.URL("/") +) + + +class RequestPayloadError(Exception): + """Payload parsing error.""" + + +class PayloadAccessError(Exception): + """Payload was accessed after response was sent.""" + + +class RequestHandler(BaseProtocol): + """HTTP protocol implementation. + + RequestHandler handles incoming HTTP request. It reads request line, + request headers and request payload and calls handle_request() method. + By default it always returns with 404 response. + + RequestHandler handles errors in incoming request, like bad + status line, bad headers or incomplete payload. If any error occurs, + connection gets closed. + + :param keepalive_timeout: number of seconds before closing + keep-alive connection + :type keepalive_timeout: int or None + + :param bool tcp_keepalive: TCP keep-alive is on, default is on + + :param bool debug: enable debug mode + + :param logger: custom logger object + :type logger: aiohttp.log.server_logger + + :param access_log_class: custom class for access_logger + :type access_log_class: aiohttp.abc.AbstractAccessLogger + + :param access_log: custom logging object + :type access_log: aiohttp.log.server_logger + + :param str access_log_format: access log format string + + :param loop: Optional event loop + + :param int max_line_size: Optional maximum header line size + + :param int max_field_size: Optional maximum header field size + + :param int max_headers: Optional maximum header size + + """ + + KEEPALIVE_RESCHEDULE_DELAY = 1 + + __slots__ = ( + "_request_count", + "_keepalive", + "_manager", + "_request_handler", + "_request_factory", + "_tcp_keepalive", + "_keepalive_time", + "_keepalive_handle", + "_keepalive_timeout", + "_lingering_time", + "_messages", + "_message_tail", + "_waiter", + "_error_handler", + "_task_handler", + "_upgrade", + "_payload_parser", + "_request_parser", + "_reading_paused", + "logger", + "debug", + "access_log", + "access_logger", + "_close", + "_force_close", + "_current_request", + ) + + def __init__( + self, + manager: "Server", + *, + loop: asyncio.AbstractEventLoop, + keepalive_timeout: float = 75.0, # NGINX default is 75 secs + tcp_keepalive: bool = True, + logger: Logger = server_logger, + access_log_class: Type[AbstractAccessLogger] = AccessLogger, + access_log: Logger = access_logger, + access_log_format: str = AccessLogger.LOG_FORMAT, + debug: bool = False, + max_line_size: int = 8190, + max_headers: int = 32768, + max_field_size: int = 8190, + lingering_time: float = 10.0, + read_bufsize: int = 2 ** 16, + ): + + super().__init__(loop) + + self._request_count = 0 + self._keepalive = False + self._current_request = None # type: Optional[BaseRequest] + self._manager = manager # type: Optional[Server] + self._request_handler = ( + manager.request_handler + ) # type: Optional[_RequestHandler] + self._request_factory = ( + manager.request_factory + ) # type: Optional[_RequestFactory] + + self._tcp_keepalive = tcp_keepalive + # placeholder to be replaced on keepalive timeout setup + self._keepalive_time = 0.0 + self._keepalive_handle = None # type: Optional[asyncio.Handle] + self._keepalive_timeout = keepalive_timeout + self._lingering_time = float(lingering_time) + + self._messages = deque() # type: Any # Python 3.5 has no typing.Deque + self._message_tail = b"" + + self._waiter = None # type: Optional[asyncio.Future[None]] + self._error_handler = None # type: Optional[asyncio.Task[None]] + self._task_handler = None # type: Optional[asyncio.Task[None]] + + self._upgrade = False + self._payload_parser = None # type: Any + self._request_parser = HttpRequestParser( + self, + loop, + read_bufsize, + max_line_size=max_line_size, + max_field_size=max_field_size, + max_headers=max_headers, + payload_exception=RequestPayloadError, + ) # type: Optional[HttpRequestParser] + + self.logger = logger + self.debug = debug + self.access_log = access_log + if access_log: + self.access_logger = access_log_class( + access_log, access_log_format + ) # type: Optional[AbstractAccessLogger] + else: + self.access_logger = None + + self._close = False + self._force_close = False + + def __repr__(self) -> str: + return "<{} {}>".format( + self.__class__.__name__, + "connected" if self.transport is not None else "disconnected", + ) + + @property + def keepalive_timeout(self) -> float: + return self._keepalive_timeout + + async def shutdown(self, timeout: Optional[float] = 15.0) -> None: + """Worker process is about to exit, we need cleanup everything and + stop accepting requests. It is especially important for keep-alive + connections.""" + self._force_close = True + + if self._keepalive_handle is not None: + self._keepalive_handle.cancel() + + if self._waiter: + self._waiter.cancel() + + # wait for handlers + with suppress(asyncio.CancelledError, asyncio.TimeoutError): + with CeilTimeout(timeout, loop=self._loop): + if self._error_handler is not None and not self._error_handler.done(): + await self._error_handler + + if self._current_request is not None: + self._current_request._cancel(asyncio.CancelledError()) + + if self._task_handler is not None and not self._task_handler.done(): + await self._task_handler + + # force-close non-idle handler + if self._task_handler is not None: + self._task_handler.cancel() + + if self.transport is not None: + self.transport.close() + self.transport = None + + def connection_made(self, transport: asyncio.BaseTransport) -> None: + super().connection_made(transport) + + real_transport = cast(asyncio.Transport, transport) + if self._tcp_keepalive: + tcp_keepalive(real_transport) + + self._task_handler = self._loop.create_task(self.start()) + assert self._manager is not None + self._manager.connection_made(self, real_transport) + + def connection_lost(self, exc: Optional[BaseException]) -> None: + if self._manager is None: + return + self._manager.connection_lost(self, exc) + + super().connection_lost(exc) + + self._manager = None + self._force_close = True + self._request_factory = None + self._request_handler = None + self._request_parser = None + + if self._keepalive_handle is not None: + self._keepalive_handle.cancel() + + if self._current_request is not None: + if exc is None: + exc = ConnectionResetError("Connection lost") + self._current_request._cancel(exc) + + if self._error_handler is not None: + self._error_handler.cancel() + if self._task_handler is not None: + self._task_handler.cancel() + if self._waiter is not None: + self._waiter.cancel() + + self._task_handler = None + + if self._payload_parser is not None: + self._payload_parser.feed_eof() + self._payload_parser = None + + def set_parser(self, parser: Any) -> None: + # Actual type is WebReader + assert self._payload_parser is None + + self._payload_parser = parser + + if self._message_tail: + self._payload_parser.feed_data(self._message_tail) + self._message_tail = b"" + + def eof_received(self) -> None: + pass + + def data_received(self, data: bytes) -> None: + if self._force_close or self._close: + return + # parse http messages + if self._payload_parser is None and not self._upgrade: + assert self._request_parser is not None + try: + messages, upgraded, tail = self._request_parser.feed_data(data) + except HttpProcessingError as exc: + # something happened during parsing + self._error_handler = self._loop.create_task( + self.handle_parse_error( + StreamWriter(self, self._loop), 400, exc, exc.message + ) + ) + self.close() + except Exception as exc: + # 500: internal error + self._error_handler = self._loop.create_task( + self.handle_parse_error(StreamWriter(self, self._loop), 500, exc) + ) + self.close() + else: + if messages: + # sometimes the parser returns no messages + for (msg, payload) in messages: + self._request_count += 1 + self._messages.append((msg, payload)) + + waiter = self._waiter + if waiter is not None: + if not waiter.done(): + # don't set result twice + waiter.set_result(None) + + self._upgrade = upgraded + if upgraded and tail: + self._message_tail = tail + + # no parser, just store + elif self._payload_parser is None and self._upgrade and data: + self._message_tail += data + + # feed payload + elif data: + eof, tail = self._payload_parser.feed_data(data) + if eof: + self.close() + + def keep_alive(self, val: bool) -> None: + """Set keep-alive connection mode. + + :param bool val: new state. + """ + self._keepalive = val + if self._keepalive_handle: + self._keepalive_handle.cancel() + self._keepalive_handle = None + + def close(self) -> None: + """Stop accepting new pipelinig messages and close + connection when handlers done processing messages""" + self._close = True + if self._waiter: + self._waiter.cancel() + + def force_close(self) -> None: + """Force close connection""" + self._force_close = True + if self._waiter: + self._waiter.cancel() + if self.transport is not None: + self.transport.close() + self.transport = None + + def log_access( + self, request: BaseRequest, response: StreamResponse, time: float + ) -> None: + if self.access_logger is not None: + self.access_logger.log(request, response, self._loop.time() - time) + + def log_debug(self, *args: Any, **kw: Any) -> None: + if self.debug: + self.logger.debug(*args, **kw) + + def log_exception(self, *args: Any, **kw: Any) -> None: + self.logger.exception(*args, **kw) + + def _process_keepalive(self) -> None: + if self._force_close or not self._keepalive: + return + + next = self._keepalive_time + self._keepalive_timeout + + # handler in idle state + if self._waiter: + if self._loop.time() > next: + self.force_close() + return + + # not all request handlers are done, + # reschedule itself to next second + self._keepalive_handle = self._loop.call_later( + self.KEEPALIVE_RESCHEDULE_DELAY, self._process_keepalive + ) + + async def _handle_request( + self, + request: BaseRequest, + start_time: float, + ) -> Tuple[StreamResponse, bool]: + assert self._request_handler is not None + try: + try: + self._current_request = request + resp = await self._request_handler(request) + finally: + self._current_request = None + except HTTPException as exc: + resp = Response( + status=exc.status, reason=exc.reason, text=exc.text, headers=exc.headers + ) + reset = await self.finish_response(request, resp, start_time) + except asyncio.CancelledError: + raise + except asyncio.TimeoutError as exc: + self.log_debug("Request handler timed out.", exc_info=exc) + resp = self.handle_error(request, 504) + reset = await self.finish_response(request, resp, start_time) + except Exception as exc: + resp = self.handle_error(request, 500, exc) + reset = await self.finish_response(request, resp, start_time) + else: + reset = await self.finish_response(request, resp, start_time) + + return resp, reset + + async def start(self) -> None: + """Process incoming request. + + It reads request line, request headers and request payload, then + calls handle_request() method. Subclass has to override + handle_request(). start() handles various exceptions in request + or response handling. Connection is being closed always unless + keep_alive(True) specified. + """ + loop = self._loop + handler = self._task_handler + assert handler is not None + manager = self._manager + assert manager is not None + keepalive_timeout = self._keepalive_timeout + resp = None + assert self._request_factory is not None + assert self._request_handler is not None + + while not self._force_close: + if not self._messages: + try: + # wait for next request + self._waiter = loop.create_future() + await self._waiter + except asyncio.CancelledError: + break + finally: + self._waiter = None + + message, payload = self._messages.popleft() + + start = loop.time() + + manager.requests_count += 1 + writer = StreamWriter(self, loop) + request = self._request_factory(message, payload, self, writer, handler) + try: + # a new task is used for copy context vars (#3406) + task = self._loop.create_task(self._handle_request(request, start)) + try: + resp, reset = await task + except (asyncio.CancelledError, ConnectionError): + self.log_debug("Ignored premature client disconnection") + break + # Deprecation warning (See #2415) + if getattr(resp, "__http_exception__", False): + warnings.warn( + "returning HTTPException object is deprecated " + "(#2415) and will be removed, " + "please raise the exception instead", + DeprecationWarning, + ) + + # Drop the processed task from asyncio.Task.all_tasks() early + del task + if reset: + self.log_debug("Ignored premature client disconnection 2") + break + + # notify server about keep-alive + self._keepalive = bool(resp.keep_alive) + + # check payload + if not payload.is_eof(): + lingering_time = self._lingering_time + if not self._force_close and lingering_time: + self.log_debug( + "Start lingering close timer for %s sec.", lingering_time + ) + + now = loop.time() + end_t = now + lingering_time + + with suppress(asyncio.TimeoutError, asyncio.CancelledError): + while not payload.is_eof() and now < end_t: + with CeilTimeout(end_t - now, loop=loop): + # read and ignore + await payload.readany() + now = loop.time() + + # if payload still uncompleted + if not payload.is_eof() and not self._force_close: + self.log_debug("Uncompleted request.") + self.close() + + payload.set_exception(PayloadAccessError()) + + except asyncio.CancelledError: + self.log_debug("Ignored premature client disconnection ") + break + except RuntimeError as exc: + if self.debug: + self.log_exception("Unhandled runtime exception", exc_info=exc) + self.force_close() + except Exception as exc: + self.log_exception("Unhandled exception", exc_info=exc) + self.force_close() + finally: + if self.transport is None and resp is not None: + self.log_debug("Ignored premature client disconnection.") + elif not self._force_close: + if self._keepalive and not self._close: + # start keep-alive timer + if keepalive_timeout is not None: + now = self._loop.time() + self._keepalive_time = now + if self._keepalive_handle is None: + self._keepalive_handle = loop.call_at( + now + keepalive_timeout, self._process_keepalive + ) + else: + break + + # remove handler, close transport if no handlers left + if not self._force_close: + self._task_handler = None + if self.transport is not None and self._error_handler is None: + self.transport.close() + + async def finish_response( + self, request: BaseRequest, resp: StreamResponse, start_time: float + ) -> bool: + """ + Prepare the response and write_eof, then log access. This has to + be called within the context of any exception so the access logger + can get exception information. Returns True if the client disconnects + prematurely. + """ + if self._request_parser is not None: + self._request_parser.set_upgraded(False) + self._upgrade = False + if self._message_tail: + self._request_parser.feed_data(self._message_tail) + self._message_tail = b"" + try: + prepare_meth = resp.prepare + except AttributeError: + if resp is None: + raise RuntimeError("Missing return " "statement on request handler") + else: + raise RuntimeError( + "Web-handler should return " + "a response instance, " + "got {!r}".format(resp) + ) + try: + await prepare_meth(request) + await resp.write_eof() + except ConnectionError: + self.log_access(request, resp, start_time) + return True + else: + self.log_access(request, resp, start_time) + return False + + def handle_error( + self, + request: BaseRequest, + status: int = 500, + exc: Optional[BaseException] = None, + message: Optional[str] = None, + ) -> StreamResponse: + """Handle errors. + + Returns HTTP response with specific status code. Logs additional + information. It always closes current connection.""" + self.log_exception("Error handling request", exc_info=exc) + + ct = "text/plain" + if status == HTTPStatus.INTERNAL_SERVER_ERROR: + title = "{0.value} {0.phrase}".format(HTTPStatus.INTERNAL_SERVER_ERROR) + msg = HTTPStatus.INTERNAL_SERVER_ERROR.description + tb = None + if self.debug: + with suppress(Exception): + tb = traceback.format_exc() + + if "text/html" in request.headers.get("Accept", ""): + if tb: + tb = html_escape(tb) + msg = f"<h2>Traceback:</h2>\n<pre>{tb}</pre>" + message = ( + "<html><head>" + "<title>{title}</title>" + "</head><body>\n<h1>{title}</h1>" + "\n{msg}\n</body></html>\n" + ).format(title=title, msg=msg) + ct = "text/html" + else: + if tb: + msg = tb + message = title + "\n\n" + msg + + resp = Response(status=status, text=message, content_type=ct) + resp.force_close() + + # some data already got sent, connection is broken + if request.writer.output_size > 0 or self.transport is None: + self.force_close() + + return resp + + async def handle_parse_error( + self, + writer: AbstractStreamWriter, + status: int, + exc: Optional[BaseException] = None, + message: Optional[str] = None, + ) -> None: + task = current_task() + assert task is not None + request = BaseRequest( + ERROR, EMPTY_PAYLOAD, self, writer, task, self._loop # type: ignore + ) + + resp = self.handle_error(request, status, exc, message) + await resp.prepare(request) + await resp.write_eof() + + if self.transport is not None: + self.transport.close() + + self._error_handler = None |