summaryrefslogtreecommitdiffstats
path: root/third_party/python/aiohttp/aiohttp/web_protocol.py
diff options
context:
space:
mode:
Diffstat (limited to 'third_party/python/aiohttp/aiohttp/web_protocol.py')
-rw-r--r--third_party/python/aiohttp/aiohttp/web_protocol.py667
1 files changed, 667 insertions, 0 deletions
diff --git a/third_party/python/aiohttp/aiohttp/web_protocol.py b/third_party/python/aiohttp/aiohttp/web_protocol.py
new file mode 100644
index 0000000000..8e02bc4aab
--- /dev/null
+++ b/third_party/python/aiohttp/aiohttp/web_protocol.py
@@ -0,0 +1,667 @@
+import asyncio
+import asyncio.streams
+import traceback
+import warnings
+from collections import deque
+from contextlib import suppress
+from html import escape as html_escape
+from http import HTTPStatus
+from logging import Logger
+from typing import TYPE_CHECKING, Any, Awaitable, Callable, Optional, Tuple, Type, cast
+
+import yarl
+
+from .abc import AbstractAccessLogger, AbstractStreamWriter
+from .base_protocol import BaseProtocol
+from .helpers import CeilTimeout, current_task
+from .http import (
+ HttpProcessingError,
+ HttpRequestParser,
+ HttpVersion10,
+ RawRequestMessage,
+ StreamWriter,
+)
+from .log import access_logger, server_logger
+from .streams import EMPTY_PAYLOAD, StreamReader
+from .tcp_helpers import tcp_keepalive
+from .web_exceptions import HTTPException
+from .web_log import AccessLogger
+from .web_request import BaseRequest
+from .web_response import Response, StreamResponse
+
+__all__ = ("RequestHandler", "RequestPayloadError", "PayloadAccessError")
+
+if TYPE_CHECKING: # pragma: no cover
+ from .web_server import Server
+
+
+_RequestFactory = Callable[
+ [
+ RawRequestMessage,
+ StreamReader,
+ "RequestHandler",
+ AbstractStreamWriter,
+ "asyncio.Task[None]",
+ ],
+ BaseRequest,
+]
+
+_RequestHandler = Callable[[BaseRequest], Awaitable[StreamResponse]]
+
+
+ERROR = RawRequestMessage(
+ "UNKNOWN", "/", HttpVersion10, {}, {}, True, False, False, False, yarl.URL("/")
+)
+
+
+class RequestPayloadError(Exception):
+ """Payload parsing error."""
+
+
+class PayloadAccessError(Exception):
+ """Payload was accessed after response was sent."""
+
+
+class RequestHandler(BaseProtocol):
+ """HTTP protocol implementation.
+
+ RequestHandler handles incoming HTTP request. It reads request line,
+ request headers and request payload and calls handle_request() method.
+ By default it always returns with 404 response.
+
+ RequestHandler handles errors in incoming request, like bad
+ status line, bad headers or incomplete payload. If any error occurs,
+ connection gets closed.
+
+ :param keepalive_timeout: number of seconds before closing
+ keep-alive connection
+ :type keepalive_timeout: int or None
+
+ :param bool tcp_keepalive: TCP keep-alive is on, default is on
+
+ :param bool debug: enable debug mode
+
+ :param logger: custom logger object
+ :type logger: aiohttp.log.server_logger
+
+ :param access_log_class: custom class for access_logger
+ :type access_log_class: aiohttp.abc.AbstractAccessLogger
+
+ :param access_log: custom logging object
+ :type access_log: aiohttp.log.server_logger
+
+ :param str access_log_format: access log format string
+
+ :param loop: Optional event loop
+
+ :param int max_line_size: Optional maximum header line size
+
+ :param int max_field_size: Optional maximum header field size
+
+ :param int max_headers: Optional maximum header size
+
+ """
+
+ KEEPALIVE_RESCHEDULE_DELAY = 1
+
+ __slots__ = (
+ "_request_count",
+ "_keepalive",
+ "_manager",
+ "_request_handler",
+ "_request_factory",
+ "_tcp_keepalive",
+ "_keepalive_time",
+ "_keepalive_handle",
+ "_keepalive_timeout",
+ "_lingering_time",
+ "_messages",
+ "_message_tail",
+ "_waiter",
+ "_error_handler",
+ "_task_handler",
+ "_upgrade",
+ "_payload_parser",
+ "_request_parser",
+ "_reading_paused",
+ "logger",
+ "debug",
+ "access_log",
+ "access_logger",
+ "_close",
+ "_force_close",
+ "_current_request",
+ )
+
+ def __init__(
+ self,
+ manager: "Server",
+ *,
+ loop: asyncio.AbstractEventLoop,
+ keepalive_timeout: float = 75.0, # NGINX default is 75 secs
+ tcp_keepalive: bool = True,
+ logger: Logger = server_logger,
+ access_log_class: Type[AbstractAccessLogger] = AccessLogger,
+ access_log: Logger = access_logger,
+ access_log_format: str = AccessLogger.LOG_FORMAT,
+ debug: bool = False,
+ max_line_size: int = 8190,
+ max_headers: int = 32768,
+ max_field_size: int = 8190,
+ lingering_time: float = 10.0,
+ read_bufsize: int = 2 ** 16,
+ ):
+
+ super().__init__(loop)
+
+ self._request_count = 0
+ self._keepalive = False
+ self._current_request = None # type: Optional[BaseRequest]
+ self._manager = manager # type: Optional[Server]
+ self._request_handler = (
+ manager.request_handler
+ ) # type: Optional[_RequestHandler]
+ self._request_factory = (
+ manager.request_factory
+ ) # type: Optional[_RequestFactory]
+
+ self._tcp_keepalive = tcp_keepalive
+ # placeholder to be replaced on keepalive timeout setup
+ self._keepalive_time = 0.0
+ self._keepalive_handle = None # type: Optional[asyncio.Handle]
+ self._keepalive_timeout = keepalive_timeout
+ self._lingering_time = float(lingering_time)
+
+ self._messages = deque() # type: Any # Python 3.5 has no typing.Deque
+ self._message_tail = b""
+
+ self._waiter = None # type: Optional[asyncio.Future[None]]
+ self._error_handler = None # type: Optional[asyncio.Task[None]]
+ self._task_handler = None # type: Optional[asyncio.Task[None]]
+
+ self._upgrade = False
+ self._payload_parser = None # type: Any
+ self._request_parser = HttpRequestParser(
+ self,
+ loop,
+ read_bufsize,
+ max_line_size=max_line_size,
+ max_field_size=max_field_size,
+ max_headers=max_headers,
+ payload_exception=RequestPayloadError,
+ ) # type: Optional[HttpRequestParser]
+
+ self.logger = logger
+ self.debug = debug
+ self.access_log = access_log
+ if access_log:
+ self.access_logger = access_log_class(
+ access_log, access_log_format
+ ) # type: Optional[AbstractAccessLogger]
+ else:
+ self.access_logger = None
+
+ self._close = False
+ self._force_close = False
+
+ def __repr__(self) -> str:
+ return "<{} {}>".format(
+ self.__class__.__name__,
+ "connected" if self.transport is not None else "disconnected",
+ )
+
+ @property
+ def keepalive_timeout(self) -> float:
+ return self._keepalive_timeout
+
+ async def shutdown(self, timeout: Optional[float] = 15.0) -> None:
+ """Worker process is about to exit, we need cleanup everything and
+ stop accepting requests. It is especially important for keep-alive
+ connections."""
+ self._force_close = True
+
+ if self._keepalive_handle is not None:
+ self._keepalive_handle.cancel()
+
+ if self._waiter:
+ self._waiter.cancel()
+
+ # wait for handlers
+ with suppress(asyncio.CancelledError, asyncio.TimeoutError):
+ with CeilTimeout(timeout, loop=self._loop):
+ if self._error_handler is not None and not self._error_handler.done():
+ await self._error_handler
+
+ if self._current_request is not None:
+ self._current_request._cancel(asyncio.CancelledError())
+
+ if self._task_handler is not None and not self._task_handler.done():
+ await self._task_handler
+
+ # force-close non-idle handler
+ if self._task_handler is not None:
+ self._task_handler.cancel()
+
+ if self.transport is not None:
+ self.transport.close()
+ self.transport = None
+
+ def connection_made(self, transport: asyncio.BaseTransport) -> None:
+ super().connection_made(transport)
+
+ real_transport = cast(asyncio.Transport, transport)
+ if self._tcp_keepalive:
+ tcp_keepalive(real_transport)
+
+ self._task_handler = self._loop.create_task(self.start())
+ assert self._manager is not None
+ self._manager.connection_made(self, real_transport)
+
+ def connection_lost(self, exc: Optional[BaseException]) -> None:
+ if self._manager is None:
+ return
+ self._manager.connection_lost(self, exc)
+
+ super().connection_lost(exc)
+
+ self._manager = None
+ self._force_close = True
+ self._request_factory = None
+ self._request_handler = None
+ self._request_parser = None
+
+ if self._keepalive_handle is not None:
+ self._keepalive_handle.cancel()
+
+ if self._current_request is not None:
+ if exc is None:
+ exc = ConnectionResetError("Connection lost")
+ self._current_request._cancel(exc)
+
+ if self._error_handler is not None:
+ self._error_handler.cancel()
+ if self._task_handler is not None:
+ self._task_handler.cancel()
+ if self._waiter is not None:
+ self._waiter.cancel()
+
+ self._task_handler = None
+
+ if self._payload_parser is not None:
+ self._payload_parser.feed_eof()
+ self._payload_parser = None
+
+ def set_parser(self, parser: Any) -> None:
+ # Actual type is WebReader
+ assert self._payload_parser is None
+
+ self._payload_parser = parser
+
+ if self._message_tail:
+ self._payload_parser.feed_data(self._message_tail)
+ self._message_tail = b""
+
+ def eof_received(self) -> None:
+ pass
+
+ def data_received(self, data: bytes) -> None:
+ if self._force_close or self._close:
+ return
+ # parse http messages
+ if self._payload_parser is None and not self._upgrade:
+ assert self._request_parser is not None
+ try:
+ messages, upgraded, tail = self._request_parser.feed_data(data)
+ except HttpProcessingError as exc:
+ # something happened during parsing
+ self._error_handler = self._loop.create_task(
+ self.handle_parse_error(
+ StreamWriter(self, self._loop), 400, exc, exc.message
+ )
+ )
+ self.close()
+ except Exception as exc:
+ # 500: internal error
+ self._error_handler = self._loop.create_task(
+ self.handle_parse_error(StreamWriter(self, self._loop), 500, exc)
+ )
+ self.close()
+ else:
+ if messages:
+ # sometimes the parser returns no messages
+ for (msg, payload) in messages:
+ self._request_count += 1
+ self._messages.append((msg, payload))
+
+ waiter = self._waiter
+ if waiter is not None:
+ if not waiter.done():
+ # don't set result twice
+ waiter.set_result(None)
+
+ self._upgrade = upgraded
+ if upgraded and tail:
+ self._message_tail = tail
+
+ # no parser, just store
+ elif self._payload_parser is None and self._upgrade and data:
+ self._message_tail += data
+
+ # feed payload
+ elif data:
+ eof, tail = self._payload_parser.feed_data(data)
+ if eof:
+ self.close()
+
+ def keep_alive(self, val: bool) -> None:
+ """Set keep-alive connection mode.
+
+ :param bool val: new state.
+ """
+ self._keepalive = val
+ if self._keepalive_handle:
+ self._keepalive_handle.cancel()
+ self._keepalive_handle = None
+
+ def close(self) -> None:
+ """Stop accepting new pipelinig messages and close
+ connection when handlers done processing messages"""
+ self._close = True
+ if self._waiter:
+ self._waiter.cancel()
+
+ def force_close(self) -> None:
+ """Force close connection"""
+ self._force_close = True
+ if self._waiter:
+ self._waiter.cancel()
+ if self.transport is not None:
+ self.transport.close()
+ self.transport = None
+
+ def log_access(
+ self, request: BaseRequest, response: StreamResponse, time: float
+ ) -> None:
+ if self.access_logger is not None:
+ self.access_logger.log(request, response, self._loop.time() - time)
+
+ def log_debug(self, *args: Any, **kw: Any) -> None:
+ if self.debug:
+ self.logger.debug(*args, **kw)
+
+ def log_exception(self, *args: Any, **kw: Any) -> None:
+ self.logger.exception(*args, **kw)
+
+ def _process_keepalive(self) -> None:
+ if self._force_close or not self._keepalive:
+ return
+
+ next = self._keepalive_time + self._keepalive_timeout
+
+ # handler in idle state
+ if self._waiter:
+ if self._loop.time() > next:
+ self.force_close()
+ return
+
+ # not all request handlers are done,
+ # reschedule itself to next second
+ self._keepalive_handle = self._loop.call_later(
+ self.KEEPALIVE_RESCHEDULE_DELAY, self._process_keepalive
+ )
+
+ async def _handle_request(
+ self,
+ request: BaseRequest,
+ start_time: float,
+ ) -> Tuple[StreamResponse, bool]:
+ assert self._request_handler is not None
+ try:
+ try:
+ self._current_request = request
+ resp = await self._request_handler(request)
+ finally:
+ self._current_request = None
+ except HTTPException as exc:
+ resp = Response(
+ status=exc.status, reason=exc.reason, text=exc.text, headers=exc.headers
+ )
+ reset = await self.finish_response(request, resp, start_time)
+ except asyncio.CancelledError:
+ raise
+ except asyncio.TimeoutError as exc:
+ self.log_debug("Request handler timed out.", exc_info=exc)
+ resp = self.handle_error(request, 504)
+ reset = await self.finish_response(request, resp, start_time)
+ except Exception as exc:
+ resp = self.handle_error(request, 500, exc)
+ reset = await self.finish_response(request, resp, start_time)
+ else:
+ reset = await self.finish_response(request, resp, start_time)
+
+ return resp, reset
+
+ async def start(self) -> None:
+ """Process incoming request.
+
+ It reads request line, request headers and request payload, then
+ calls handle_request() method. Subclass has to override
+ handle_request(). start() handles various exceptions in request
+ or response handling. Connection is being closed always unless
+ keep_alive(True) specified.
+ """
+ loop = self._loop
+ handler = self._task_handler
+ assert handler is not None
+ manager = self._manager
+ assert manager is not None
+ keepalive_timeout = self._keepalive_timeout
+ resp = None
+ assert self._request_factory is not None
+ assert self._request_handler is not None
+
+ while not self._force_close:
+ if not self._messages:
+ try:
+ # wait for next request
+ self._waiter = loop.create_future()
+ await self._waiter
+ except asyncio.CancelledError:
+ break
+ finally:
+ self._waiter = None
+
+ message, payload = self._messages.popleft()
+
+ start = loop.time()
+
+ manager.requests_count += 1
+ writer = StreamWriter(self, loop)
+ request = self._request_factory(message, payload, self, writer, handler)
+ try:
+ # a new task is used for copy context vars (#3406)
+ task = self._loop.create_task(self._handle_request(request, start))
+ try:
+ resp, reset = await task
+ except (asyncio.CancelledError, ConnectionError):
+ self.log_debug("Ignored premature client disconnection")
+ break
+ # Deprecation warning (See #2415)
+ if getattr(resp, "__http_exception__", False):
+ warnings.warn(
+ "returning HTTPException object is deprecated "
+ "(#2415) and will be removed, "
+ "please raise the exception instead",
+ DeprecationWarning,
+ )
+
+ # Drop the processed task from asyncio.Task.all_tasks() early
+ del task
+ if reset:
+ self.log_debug("Ignored premature client disconnection 2")
+ break
+
+ # notify server about keep-alive
+ self._keepalive = bool(resp.keep_alive)
+
+ # check payload
+ if not payload.is_eof():
+ lingering_time = self._lingering_time
+ if not self._force_close and lingering_time:
+ self.log_debug(
+ "Start lingering close timer for %s sec.", lingering_time
+ )
+
+ now = loop.time()
+ end_t = now + lingering_time
+
+ with suppress(asyncio.TimeoutError, asyncio.CancelledError):
+ while not payload.is_eof() and now < end_t:
+ with CeilTimeout(end_t - now, loop=loop):
+ # read and ignore
+ await payload.readany()
+ now = loop.time()
+
+ # if payload still uncompleted
+ if not payload.is_eof() and not self._force_close:
+ self.log_debug("Uncompleted request.")
+ self.close()
+
+ payload.set_exception(PayloadAccessError())
+
+ except asyncio.CancelledError:
+ self.log_debug("Ignored premature client disconnection ")
+ break
+ except RuntimeError as exc:
+ if self.debug:
+ self.log_exception("Unhandled runtime exception", exc_info=exc)
+ self.force_close()
+ except Exception as exc:
+ self.log_exception("Unhandled exception", exc_info=exc)
+ self.force_close()
+ finally:
+ if self.transport is None and resp is not None:
+ self.log_debug("Ignored premature client disconnection.")
+ elif not self._force_close:
+ if self._keepalive and not self._close:
+ # start keep-alive timer
+ if keepalive_timeout is not None:
+ now = self._loop.time()
+ self._keepalive_time = now
+ if self._keepalive_handle is None:
+ self._keepalive_handle = loop.call_at(
+ now + keepalive_timeout, self._process_keepalive
+ )
+ else:
+ break
+
+ # remove handler, close transport if no handlers left
+ if not self._force_close:
+ self._task_handler = None
+ if self.transport is not None and self._error_handler is None:
+ self.transport.close()
+
+ async def finish_response(
+ self, request: BaseRequest, resp: StreamResponse, start_time: float
+ ) -> bool:
+ """
+ Prepare the response and write_eof, then log access. This has to
+ be called within the context of any exception so the access logger
+ can get exception information. Returns True if the client disconnects
+ prematurely.
+ """
+ if self._request_parser is not None:
+ self._request_parser.set_upgraded(False)
+ self._upgrade = False
+ if self._message_tail:
+ self._request_parser.feed_data(self._message_tail)
+ self._message_tail = b""
+ try:
+ prepare_meth = resp.prepare
+ except AttributeError:
+ if resp is None:
+ raise RuntimeError("Missing return " "statement on request handler")
+ else:
+ raise RuntimeError(
+ "Web-handler should return "
+ "a response instance, "
+ "got {!r}".format(resp)
+ )
+ try:
+ await prepare_meth(request)
+ await resp.write_eof()
+ except ConnectionError:
+ self.log_access(request, resp, start_time)
+ return True
+ else:
+ self.log_access(request, resp, start_time)
+ return False
+
+ def handle_error(
+ self,
+ request: BaseRequest,
+ status: int = 500,
+ exc: Optional[BaseException] = None,
+ message: Optional[str] = None,
+ ) -> StreamResponse:
+ """Handle errors.
+
+ Returns HTTP response with specific status code. Logs additional
+ information. It always closes current connection."""
+ self.log_exception("Error handling request", exc_info=exc)
+
+ ct = "text/plain"
+ if status == HTTPStatus.INTERNAL_SERVER_ERROR:
+ title = "{0.value} {0.phrase}".format(HTTPStatus.INTERNAL_SERVER_ERROR)
+ msg = HTTPStatus.INTERNAL_SERVER_ERROR.description
+ tb = None
+ if self.debug:
+ with suppress(Exception):
+ tb = traceback.format_exc()
+
+ if "text/html" in request.headers.get("Accept", ""):
+ if tb:
+ tb = html_escape(tb)
+ msg = f"<h2>Traceback:</h2>\n<pre>{tb}</pre>"
+ message = (
+ "<html><head>"
+ "<title>{title}</title>"
+ "</head><body>\n<h1>{title}</h1>"
+ "\n{msg}\n</body></html>\n"
+ ).format(title=title, msg=msg)
+ ct = "text/html"
+ else:
+ if tb:
+ msg = tb
+ message = title + "\n\n" + msg
+
+ resp = Response(status=status, text=message, content_type=ct)
+ resp.force_close()
+
+ # some data already got sent, connection is broken
+ if request.writer.output_size > 0 or self.transport is None:
+ self.force_close()
+
+ return resp
+
+ async def handle_parse_error(
+ self,
+ writer: AbstractStreamWriter,
+ status: int,
+ exc: Optional[BaseException] = None,
+ message: Optional[str] = None,
+ ) -> None:
+ task = current_task()
+ assert task is not None
+ request = BaseRequest(
+ ERROR, EMPTY_PAYLOAD, self, writer, task, self._loop # type: ignore
+ )
+
+ resp = self.handle_error(request, status, exc, message)
+ await resp.prepare(request)
+ await resp.write_eof()
+
+ if self.transport is not None:
+ self.transport.close()
+
+ self._error_handler = None