summaryrefslogtreecommitdiffstats
path: root/testing/web-platform/tests/tools/third_party/websockets/src/websockets/server.py
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--testing/web-platform/tests/tools/third_party/websockets/src/websockets/server.py175
1 files changed, 117 insertions, 58 deletions
diff --git a/testing/web-platform/tests/tools/third_party/websockets/src/websockets/server.py b/testing/web-platform/tests/tools/third_party/websockets/src/websockets/server.py
index 5dad50b6a1..191660553f 100644
--- a/testing/web-platform/tests/tools/third_party/websockets/src/websockets/server.py
+++ b/testing/web-platform/tests/tools/third_party/websockets/src/websockets/server.py
@@ -4,9 +4,9 @@ import base64
import binascii
import email.utils
import http
-from typing import Generator, List, Optional, Sequence, Tuple, cast
+import warnings
+from typing import Any, Callable, Generator, List, Optional, Sequence, Tuple, cast
-from .connection import CONNECTING, OPEN, SERVER, Connection, State
from .datastructures import Headers, MultipleValuesError
from .exceptions import (
InvalidHandshake,
@@ -25,13 +25,14 @@ from .headers import (
parse_subprotocol,
parse_upgrade,
)
-from .http import USER_AGENT
from .http11 import Request, Response
+from .protocol import CONNECTING, OPEN, SERVER, Protocol, State
from .typing import (
ConnectionOption,
ExtensionHeader,
LoggerLike,
Origin,
+ StatusLike,
Subprotocol,
UpgradeProtocol,
)
@@ -39,13 +40,15 @@ from .utils import accept_key
# See #940 for why lazy_import isn't used here for backwards compatibility.
-from .legacy.server import * # isort:skip # noqa
+# See #1400 for why listing compatibility imports in __all__ helps PyCharm.
+from .legacy.server import * # isort:skip # noqa: I001
+from .legacy.server import __all__ as legacy__all__
-__all__ = ["ServerConnection"]
+__all__ = ["ServerProtocol"] + legacy__all__
-class ServerConnection(Connection):
+class ServerProtocol(Protocol):
"""
Sans-I/O implementation of a WebSocket server connection.
@@ -58,20 +61,31 @@ class ServerConnection(Connection):
should be tried.
subprotocols: list of supported subprotocols, in order of decreasing
preference.
+ select_subprotocol: Callback for selecting a subprotocol among
+ those supported by the client and the server. It has the same
+ signature as the :meth:`select_subprotocol` method, including a
+ :class:`ServerProtocol` instance as first argument.
state: initial state of the WebSocket connection.
max_size: maximum size of incoming messages in bytes;
- :obj:`None` to disable the limit.
+ :obj:`None` disables the limit.
logger: logger for this connection;
defaults to ``logging.getLogger("websockets.client")``;
- see the :doc:`logging guide <../topics/logging>` for details.
+ see the :doc:`logging guide <../../topics/logging>` for details.
"""
def __init__(
self,
+ *,
origins: Optional[Sequence[Optional[Origin]]] = None,
extensions: Optional[Sequence[ServerExtensionFactory]] = None,
subprotocols: Optional[Sequence[Subprotocol]] = None,
+ select_subprotocol: Optional[
+ Callable[
+ [ServerProtocol, Sequence[Subprotocol]],
+ Optional[Subprotocol],
+ ]
+ ] = None,
state: State = CONNECTING,
max_size: Optional[int] = 2**20,
logger: Optional[LoggerLike] = None,
@@ -85,6 +99,14 @@ class ServerConnection(Connection):
self.origins = origins
self.available_extensions = extensions
self.available_subprotocols = subprotocols
+ if select_subprotocol is not None:
+ # Bind select_subprotocol then shadow self.select_subprotocol.
+ # Use setattr to work around https://github.com/python/mypy/issues/2427.
+ setattr(
+ self,
+ "select_subprotocol",
+ select_subprotocol.__get__(self, self.__class__),
+ )
def accept(self, request: Request) -> Response:
"""
@@ -95,13 +117,13 @@ class ServerConnection(Connection):
You must send the handshake response with :meth:`send_response`.
- You can modify it before sending it, for example to add HTTP headers.
+ You may modify it before sending it, for example to add HTTP headers.
Args:
request: WebSocket handshake request event received from the client.
Returns:
- Response: WebSocket handshake response event to send to the client.
+ WebSocket handshake response event to send to the client.
"""
try:
@@ -145,6 +167,8 @@ class ServerConnection(Connection):
f"Failed to open a WebSocket connection: {exc}.\n",
)
except Exception as exc:
+ # Handle exceptions raised by user-provided select_subprotocol and
+ # unexpected errors.
request._exception = exc
self.handshake_exc = exc
self.logger.error("opening handshake failed", exc_info=True)
@@ -170,13 +194,12 @@ class ServerConnection(Connection):
if protocol_header is not None:
headers["Sec-WebSocket-Protocol"] = protocol_header
- headers["Server"] = USER_AGENT
-
self.logger.info("connection open")
return Response(101, "Switching Protocols", headers)
def process_request(
- self, request: Request
+ self,
+ request: Request,
) -> Tuple[str, Optional[str], Optional[str]]:
"""
Check a handshake request and negotiate extensions and subprotocol.
@@ -274,6 +297,7 @@ class ServerConnection(Connection):
Optional[Origin]: origin, if it is acceptable.
Raises:
+ InvalidHandshake: if the Origin header is invalid.
InvalidOrigin: if the origin isn't acceptable.
"""
@@ -298,8 +322,8 @@ class ServerConnection(Connection):
Accept or reject each extension proposed in the client request.
Negotiate parameters for accepted extensions.
- :rfc:`6455` leaves the rules up to the specification of each
- :extension.
+ Per :rfc:`6455`, negotiation rules are defined by the specification of
+ each extension.
To provide this level of flexibility, for each extension proposed by
the client, we check for a match with each extension available in the
@@ -324,7 +348,7 @@ class ServerConnection(Connection):
HTTP response header and list of accepted extensions.
Raises:
- InvalidHandshake: to abort the handshake with an HTTP 400 error.
+ InvalidHandshake: if the Sec-WebSocket-Extensions header is invalid.
"""
response_header_value: Optional[str] = None
@@ -335,15 +359,12 @@ class ServerConnection(Connection):
header_values = headers.get_all("Sec-WebSocket-Extensions")
if header_values and self.available_extensions:
-
parsed_header_values: List[ExtensionHeader] = sum(
[parse_extension(header_value) for header_value in header_values], []
)
for name, request_params in parsed_header_values:
-
for ext_factory in self.available_extensions:
-
# Skip non-matching extensions based on their name.
if ext_factory.name != name:
continue
@@ -384,64 +405,83 @@ class ServerConnection(Connection):
also the value of the ``Sec-WebSocket-Protocol`` response header.
Raises:
- InvalidHandshake: to abort the handshake with an HTTP 400 error.
+ InvalidHandshake: if the Sec-WebSocket-Subprotocol header is invalid.
"""
- subprotocol: Optional[Subprotocol] = None
-
- header_values = headers.get_all("Sec-WebSocket-Protocol")
-
- if header_values and self.available_subprotocols:
-
- parsed_header_values: List[Subprotocol] = sum(
- [parse_subprotocol(header_value) for header_value in header_values], []
- )
-
- subprotocol = self.select_subprotocol(
- parsed_header_values, self.available_subprotocols
- )
+ subprotocols: Sequence[Subprotocol] = sum(
+ [
+ parse_subprotocol(header_value)
+ for header_value in headers.get_all("Sec-WebSocket-Protocol")
+ ],
+ [],
+ )
- return subprotocol
+ return self.select_subprotocol(subprotocols)
def select_subprotocol(
self,
- client_subprotocols: Sequence[Subprotocol],
- server_subprotocols: Sequence[Subprotocol],
+ subprotocols: Sequence[Subprotocol],
) -> Optional[Subprotocol]:
"""
Pick a subprotocol among those offered by the client.
- If several subprotocols are supported by the client and the server,
- the default implementation selects the preferred subprotocols by
- giving equal value to the priorities of the client and the server.
+ If several subprotocols are supported by both the client and the server,
+ pick the first one in the list declared the server.
+
+ If the server doesn't support any subprotocols, continue without a
+ subprotocol, regardless of what the client offers.
+
+ If the server supports at least one subprotocol and the client doesn't
+ offer any, abort the handshake with an HTTP 400 error.
- If no common subprotocol is supported by the client and the server, it
- proceeds without a subprotocol.
+ You provide a ``select_subprotocol`` argument to :class:`ServerProtocol`
+ to override this logic. For example, you could accept the connection
+ even if client doesn't offer a subprotocol, rather than reject it.
- This is unlikely to be the most useful implementation in practice, as
- many servers providing a subprotocol will require that the client uses
- that subprotocol.
+ Here's how to negotiate the ``chat`` subprotocol if the client supports
+ it and continue without a subprotocol otherwise::
+
+ def select_subprotocol(protocol, subprotocols):
+ if "chat" in subprotocols:
+ return "chat"
Args:
- client_subprotocols: list of subprotocols offered by the client.
- server_subprotocols: list of subprotocols available on the server.
+ subprotocols: list of subprotocols offered by the client.
Returns:
- Optional[Subprotocol]: Subprotocol, if a common subprotocol was
- found.
+ Optional[Subprotocol]: Selected subprotocol, if a common subprotocol
+ was found.
+
+ :obj:`None` to continue without a subprotocol.
+
+ Raises:
+ NegotiationError: custom implementations may raise this exception
+ to abort the handshake with an HTTP 400 error.
"""
- subprotocols = set(client_subprotocols) & set(server_subprotocols)
- if not subprotocols:
+ # Server doesn't offer any subprotocols.
+ if not self.available_subprotocols: # None or empty list
return None
- priority = lambda p: (
- client_subprotocols.index(p) + server_subprotocols.index(p)
+
+ # Server offers at least one subprotocol but client doesn't offer any.
+ if not subprotocols:
+ raise NegotiationError("missing subprotocol")
+
+ # Server and client both offer subprotocols. Look for a shared one.
+ proposed_subprotocols = set(subprotocols)
+ for subprotocol in self.available_subprotocols:
+ if subprotocol in proposed_subprotocols:
+ return subprotocol
+
+ # No common subprotocol was found.
+ raise NegotiationError(
+ "invalid subprotocol; expected one of "
+ + ", ".join(self.available_subprotocols)
)
- return sorted(subprotocols, key=priority)[0]
def reject(
self,
- status: http.HTTPStatus,
+ status: StatusLike,
text: str,
) -> Response:
"""
@@ -462,6 +502,8 @@ class ServerConnection(Connection):
Response: WebSocket handshake response event to send to the client.
"""
+ # If a user passes an int instead of a HTTPStatus, fix it automatically.
+ status = http.HTTPStatus(status)
body = text.encode()
headers = Headers(
[
@@ -469,16 +511,15 @@ class ServerConnection(Connection):
("Connection", "close"),
("Content-Length", str(len(body))),
("Content-Type", "text/plain; charset=utf-8"),
- ("Server", USER_AGENT),
]
)
response = Response(status.value, status.phrase, headers, body)
# When reject() is called from accept(), handshake_exc is already set.
# If a user calls reject(), set handshake_exc to guarantee invariant:
- # "handshake_exc is None if and only if opening handshake succeded."
+ # "handshake_exc is None if and only if opening handshake succeeded."
if self.handshake_exc is None:
self.handshake_exc = InvalidStatus(response)
- self.logger.info("connection failed (%d %s)", status.value, status.phrase)
+ self.logger.info("connection rejected (%d %s)", status.value, status.phrase)
return response
def send_response(self, response: Response) -> None:
@@ -509,7 +550,16 @@ class ServerConnection(Connection):
def parse(self) -> Generator[None, None, None]:
if self.state is CONNECTING:
- request = yield from Request.parse(self.reader.read_line)
+ try:
+ request = yield from Request.parse(
+ self.reader.read_line,
+ )
+ except Exception as exc:
+ self.handshake_exc = exc
+ self.send_eof()
+ self.parser = self.discard()
+ next(self.parser) # start coroutine
+ yield
if self.debug:
self.logger.debug("< GET %s HTTP/1.1", request.path)
@@ -519,3 +569,12 @@ class ServerConnection(Connection):
self.events.append(request)
yield from super().parse()
+
+
+class ServerConnection(ServerProtocol):
+ def __init__(self, *args: Any, **kwargs: Any) -> None:
+ warnings.warn(
+ "ServerConnection was renamed to ServerProtocol",
+ DeprecationWarning,
+ )
+ super().__init__(*args, **kwargs)