summaryrefslogtreecommitdiffstats
path: root/testing/web-platform/tests/tools/webdriver
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--testing/web-platform/tests/tools/webdriver/.gitignore2
-rw-r--r--testing/web-platform/tests/tools/webdriver/README.md73
-rw-r--r--testing/web-platform/tests/tools/webdriver/setup.py14
-rw-r--r--testing/web-platform/tests/tools/webdriver/webdriver/__init__.py39
-rw-r--r--testing/web-platform/tests/tools/webdriver/webdriver/bidi/__init__.py3
-rw-r--r--testing/web-platform/tests/tools/webdriver/webdriver/bidi/client.py226
-rw-r--r--testing/web-platform/tests/tools/webdriver/webdriver/bidi/error.py70
-rw-r--r--testing/web-platform/tests/tools/webdriver/webdriver/bidi/modules/__init__.py5
-rw-r--r--testing/web-platform/tests/tools/webdriver/webdriver/bidi/modules/_module.py99
-rw-r--r--testing/web-platform/tests/tools/webdriver/webdriver/bidi/modules/browsing_context.py82
-rw-r--r--testing/web-platform/tests/tools/webdriver/webdriver/bidi/modules/script.py136
-rw-r--r--testing/web-platform/tests/tools/webdriver/webdriver/bidi/modules/session.py31
-rw-r--r--testing/web-platform/tests/tools/webdriver/webdriver/bidi/transport.py76
-rw-r--r--testing/web-platform/tests/tools/webdriver/webdriver/client.py900
-rw-r--r--testing/web-platform/tests/tools/webdriver/webdriver/error.py232
-rw-r--r--testing/web-platform/tests/tools/webdriver/webdriver/protocol.py49
-rw-r--r--testing/web-platform/tests/tools/webdriver/webdriver/transport.py267
17 files changed, 2304 insertions, 0 deletions
diff --git a/testing/web-platform/tests/tools/webdriver/.gitignore b/testing/web-platform/tests/tools/webdriver/.gitignore
new file mode 100644
index 0000000000..e8413e0ee4
--- /dev/null
+++ b/testing/web-platform/tests/tools/webdriver/.gitignore
@@ -0,0 +1,2 @@
+webdriver.egg-info/
+*.pyc
diff --git a/testing/web-platform/tests/tools/webdriver/README.md b/testing/web-platform/tests/tools/webdriver/README.md
new file mode 100644
index 0000000000..9433aaa926
--- /dev/null
+++ b/testing/web-platform/tests/tools/webdriver/README.md
@@ -0,0 +1,73 @@
+# WebDriver client for Python
+
+This package provides Python bindings
+that conform to the [W3C WebDriver standard](https://w3c.github.io/webdriver/),
+which specifies a remote control protocol for web browsers.
+
+These bindings are written with determining
+implementation compliance to the specification in mind,
+so that different remote end drivers
+can determine whether they meet the recognised standard.
+The client is used for the WebDriver specification tests
+in [web-platform-tests](https://github.com/web-platform-tests/wpt).
+
+## Installation
+
+To install the package individually
+in your virtualenv or system-wide:
+
+ % python setup.py install
+
+Since this package does not have any external dependencies,
+you can also use the client directly from the checkout directory,
+which is useful if you want to contribute patches back:
+
+ % cd /path/to/wdclient
+ % python
+ Python 2.7.12+ (default, Aug 4 2016, 20:04:34)
+ [GCC 6.1.1 20160724] on linux2
+ Type "help", "copyright", "credits" or "license" for more information.
+ >>> import webdriver
+ >>>
+
+If you are writing WebDriver specification tests for
+[WPT](https://github.com/web-platform-tests/wpt),
+there is no need to install the client manually
+as it is included in the `tools/webdriver` directory.
+
+## Usage
+
+You can use the built-in
+[context manager](https://docs.python.org/2/reference/compound_stmts.html#the-with-statement)
+to manage the lifetime of the session.
+The session is started implicitly
+at the first call to a command if it has not already been started,
+and will implicitly be ended when exiting the context:
+
+```py
+import webdriver
+
+with webdriver.Session("127.0.0.1", 4444) as session:
+ session.url = "https://mozilla.org"
+ print "The current URL is %s" % session.url
+```
+
+The following is functionally equivalent to the above,
+but giving you manual control of the session:
+
+```py
+import webdriver
+
+session = webdriver.Session("127.0.0.1", 4444)
+session.start()
+
+session.url = "https://mozilla.org"
+print "The current URL is %s" % session.url
+
+session.end()
+```
+
+## Dependencies
+
+This client has the benefit of only using standard library dependencies.
+No external PyPI dependencies are needed.
diff --git a/testing/web-platform/tests/tools/webdriver/setup.py b/testing/web-platform/tests/tools/webdriver/setup.py
new file mode 100644
index 0000000000..c473961cb6
--- /dev/null
+++ b/testing/web-platform/tests/tools/webdriver/setup.py
@@ -0,0 +1,14 @@
+from setuptools import setup, find_packages
+
+setup(name="webdriver",
+ version="1.0",
+ description="WebDriver client compatible with "
+ "the W3C browser automation specification.",
+ author="Mozilla Engineering Productivity",
+ author_email="tools@lists.mozilla.org",
+ license="BSD",
+ packages=find_packages(),
+ classifiers=["Development Status :: 4 - Beta",
+ "Intended Audience :: Developers",
+ "License :: OSI Approved :: Mozilla Public License 2.0 (MPL 2.0)",
+ "Operating System :: OS Independent"])
diff --git a/testing/web-platform/tests/tools/webdriver/webdriver/__init__.py b/testing/web-platform/tests/tools/webdriver/webdriver/__init__.py
new file mode 100644
index 0000000000..a81751407e
--- /dev/null
+++ b/testing/web-platform/tests/tools/webdriver/webdriver/__init__.py
@@ -0,0 +1,39 @@
+# flake8: noqa
+
+from .client import (
+ Cookies,
+ Element,
+ Find,
+ Frame,
+ Session,
+ ShadowRoot,
+ Timeouts,
+ Window)
+from .error import (
+ ElementNotSelectableException,
+ ElementNotVisibleException,
+ InvalidArgumentException,
+ InvalidCookieDomainException,
+ InvalidElementCoordinatesException,
+ InvalidElementStateException,
+ InvalidSelectorException,
+ InvalidSessionIdException,
+ JavascriptErrorException,
+ MoveTargetOutOfBoundsException,
+ NoSuchAlertException,
+ NoSuchElementException,
+ NoSuchFrameException,
+ NoSuchWindowException,
+ ScriptTimeoutException,
+ SessionNotCreatedException,
+ StaleElementReferenceException,
+ TimeoutException,
+ UnableToSetCookieException,
+ UnexpectedAlertOpenException,
+ UnknownCommandException,
+ UnknownErrorException,
+ UnknownMethodException,
+ UnsupportedOperationException,
+ WebDriverException)
+from .bidi import (
+ BidiSession)
diff --git a/testing/web-platform/tests/tools/webdriver/webdriver/bidi/__init__.py b/testing/web-platform/tests/tools/webdriver/webdriver/bidi/__init__.py
new file mode 100644
index 0000000000..e7c56332f9
--- /dev/null
+++ b/testing/web-platform/tests/tools/webdriver/webdriver/bidi/__init__.py
@@ -0,0 +1,3 @@
+# flake8: noqa
+
+from .client import BidiSession
diff --git a/testing/web-platform/tests/tools/webdriver/webdriver/bidi/client.py b/testing/web-platform/tests/tools/webdriver/webdriver/bidi/client.py
new file mode 100644
index 0000000000..9dc80d8121
--- /dev/null
+++ b/testing/web-platform/tests/tools/webdriver/webdriver/bidi/client.py
@@ -0,0 +1,226 @@
+# mypy: allow-untyped-defs
+
+import asyncio
+from collections import defaultdict
+from typing import Any, Awaitable, Callable, List, Optional, Mapping, MutableMapping
+from urllib.parse import urljoin, urlparse
+
+from . import modules
+from .error import from_error_details
+from .transport import get_running_loop, Transport
+
+
+class BidiSession:
+ """A WebDriver BiDi session.
+
+ This is the main representation of a BiDi session and provides the
+ interface for running commands in the session, and for attaching
+ event handlers to the session. For example:
+
+ async def on_log(method, data):
+ print(data)
+
+ session = BidiSession("ws://localhost:4445", capabilities)
+ remove_listener = session.add_event_listener("log.entryAdded", on_log)
+ await session.start()
+ await session.subscribe("log.entryAdded")
+
+ # Do some stuff with the session
+
+ remove_listener()
+ session.end()
+
+ If the session id is provided it's assumed that the underlying
+ WebDriver session was already created, and the WebSocket URL was
+ taken from the new session response. If no session id is provided, it's
+ assumed that a BiDi-only session should be created when start() is called.
+
+ It can also be used as a context manager, with the WebSocket transport
+ implictly being created when the context is entered, and closed when
+ the context is exited.
+
+ :param websocket_url: WebSockets URL on which to connect to the session.
+ This excludes any path component.
+ :param session_id: String id of existing HTTP session
+ :param capabilities: Capabilities response of existing session
+ :param requested_capabilities: Dictionary representing the capabilities request.
+
+ """
+
+ def __init__(self,
+ websocket_url: str,
+ session_id: Optional[str] = None,
+ capabilities: Optional[Mapping[str, Any]] = None,
+ requested_capabilities: Optional[Mapping[str, Any]] = None):
+ self.transport: Optional[Transport] = None
+
+ # The full URL for a websocket looks like
+ # ws://<host>:<port>/session when we're creating a session and
+ # ws://<host>:<port>/session/<sessionid> when we're connecting to an existing session.
+ # To be user friendly, handle the case where the class was created with either a
+ # full URL including the path, and also the case where just a server url is passed in.
+ parsed_url = urlparse(websocket_url)
+ if parsed_url.path == "" or parsed_url.path == "/":
+ if session_id is None:
+ websocket_url = urljoin(websocket_url, "session")
+ else:
+ websocket_url = urljoin(websocket_url, f"session/{session_id}")
+ else:
+ if session_id is not None:
+ if parsed_url.path != f"/session/{session_id}":
+ raise ValueError(f"WebSocket URL {session_id} doesn't match session id")
+ else:
+ if parsed_url.path != "/session":
+ raise ValueError(f"WebSocket URL {session_id} doesn't match session url")
+
+ if session_id is None and capabilities is not None:
+ raise ValueError("Tried to create BiDi-only session with existing capabilities")
+
+ self.websocket_url = websocket_url
+ self.requested_capabilities = requested_capabilities
+ self.capabilities = capabilities
+ self.session_id = session_id
+
+ self.command_id = 0
+ self.pending_commands: MutableMapping[int, "asyncio.Future[Any]"] = {}
+ self.event_listeners: MutableMapping[
+ Optional[str],
+ List[Callable[[str, Mapping[str, Any]], Any]]
+ ] = defaultdict(list)
+
+ # Modules.
+ # For each module, have a property representing that module
+ self.session = modules.Session(self)
+ self.browsing_context = modules.BrowsingContext(self)
+ self.script = modules.Script(self)
+
+ @property
+ def event_loop(self):
+ if self.transport:
+ return self.transport.loop
+
+ return None
+
+ @classmethod
+ def from_http(cls,
+ session_id: str,
+ capabilities: Mapping[str, Any]) -> "BidiSession":
+ """Create a BiDi session from an existing HTTP session
+
+ :param session_id: String id of the session
+ :param capabilities: Capabilities returned in the New Session HTTP response."""
+ websocket_url = capabilities.get("webSocketUrl")
+ if websocket_url is None:
+ raise ValueError("No webSocketUrl found in capabilities")
+ if not isinstance(websocket_url, str):
+ raise ValueError("webSocketUrl is not a string")
+ return cls(websocket_url, session_id=session_id, capabilities=capabilities)
+
+ @classmethod
+ def bidi_only(cls,
+ websocket_url: str,
+ requested_capabilities: Optional[Mapping[str, Any]] = None) -> "BidiSession":
+ """Create a BiDi session where there is no existing HTTP session
+
+ :param webdocket_url: URL to the WebSocket server listening for BiDi connections
+ :param requested_capabilities: Capabilities request for establishing the session."""
+ return cls(websocket_url, requested_capabilities=requested_capabilities)
+
+ async def __aenter__(self) -> "BidiSession":
+ await self.start()
+ return self
+
+ async def __aexit__(self, *args: Any) -> None:
+ await self.end()
+
+ async def start(self,
+ loop: Optional[asyncio.AbstractEventLoop] = None) -> None:
+ """Connect to the WebDriver BiDi remote via WebSockets"""
+
+ if loop is None:
+ loop = get_running_loop()
+
+ self.transport = Transport(self.websocket_url, self.on_message, loop=loop)
+ await self.transport.start()
+
+ if self.session_id is None:
+ self.session_id, self.capabilities = await self.session.new(
+ capabilities=self.requested_capabilities)
+
+ async def send_command(
+ self,
+ method: str,
+ params: Mapping[str, Any]
+ ) -> Awaitable[Mapping[str, Any]]:
+ """Send a command to the remote server"""
+ # this isn't threadsafe
+ self.command_id += 1
+ command_id = self.command_id
+
+ body = {
+ "id": command_id,
+ "method": method,
+ "params": params
+ }
+ assert command_id not in self.pending_commands
+ assert self.transport is not None
+ self.pending_commands[command_id] = self.transport.loop.create_future()
+ await self.transport.send(body)
+
+ return self.pending_commands[command_id]
+
+ async def on_message(self, data: Mapping[str, Any]) -> None:
+ """Handle a message from the remote server"""
+ if "id" in data:
+ # This is a command response or error
+ future = self.pending_commands.get(data["id"])
+ if future is None:
+ raise ValueError(f"No pending command with id {data['id']}")
+ if "result" in data:
+ future.set_result(data["result"])
+ elif "error" in data and "message" in data:
+ assert isinstance(data["error"], str)
+ assert isinstance(data["message"], str)
+ exception = from_error_details(data["error"],
+ data["message"],
+ data.get("stacktrace"))
+ future.set_exception(exception)
+ else:
+ raise ValueError(f"Unexpected message: {data!r}")
+ elif "method" in data and "params" in data:
+ # This is an event
+ method = data["method"]
+ params = data["params"]
+
+ listeners = self.event_listeners.get(method, [])
+ if not listeners:
+ listeners = self.event_listeners.get(None, [])
+ for listener in listeners:
+ await listener(method, params)
+ else:
+ raise ValueError(f"Unexpected message: {data!r}")
+
+ async def end(self) -> None:
+ """Close websocket connection."""
+ assert self.transport is not None
+ await self.transport.end()
+ self.transport = None
+
+ def add_event_listener(
+ self,
+ name: Optional[str],
+ fn: Callable[[str, Mapping[str, Any]], Awaitable[Any]]
+ ) -> Callable[[], None]:
+ """Add a listener for the event with a given name.
+
+ If name is None, the listener is called for all messages that are not otherwise
+ handled.
+
+ :param name: Name of event to listen for or None to register a default handler
+ :param fn: Async callback function that receives event data
+
+ :return: Function to remove the added listener
+ """
+ self.event_listeners[name].append(fn)
+
+ return lambda: self.event_listeners[name].remove(fn)
diff --git a/testing/web-platform/tests/tools/webdriver/webdriver/bidi/error.py b/testing/web-platform/tests/tools/webdriver/webdriver/bidi/error.py
new file mode 100644
index 0000000000..9e8737e54c
--- /dev/null
+++ b/testing/web-platform/tests/tools/webdriver/webdriver/bidi/error.py
@@ -0,0 +1,70 @@
+# mypy: allow-untyped-defs
+
+import collections
+
+from typing import ClassVar, DefaultDict, Optional, Type
+
+
+class BidiException(Exception):
+ # The error_code class variable is used to map the JSON Error Code (see
+ # https://w3c.github.io/webdriver/#errors) to a BidiException subclass.
+ # TODO: Match on error and let it be a class variables only.
+ error_code = None # type: ClassVar[str]
+
+ def __init__(self, message: str, stacktrace: Optional[str] = None):
+ super()
+
+ self.message = message
+ self.stacktrace = stacktrace
+
+ def __repr__(self):
+ """Return the object representation in string format."""
+ return f"{self.__class__.__name__}({self.error}, {self.message}, {self.stacktrace})"
+
+ def __str__(self):
+ """Return the string representation of the object."""
+ message = f"{self.error_code} ({self.message})"
+
+ if self.stacktrace is not None:
+ message += f"\n\nRemote-end stacktrace:\n\n{self.stacktrace}"
+
+ return message
+
+
+class InvalidArgumentException(BidiException):
+ error_code = "invalid argument"
+
+
+class NoSuchFrameException(BidiException):
+ error_code = "no such frame"
+
+
+class UnknownCommandException(BidiException):
+ error_code = "unknown command"
+
+
+class UnknownErrorException(BidiException):
+ error_code = "unknown error"
+
+
+def from_error_details(error: str, message: str, stacktrace: Optional[str]) -> BidiException:
+ """Create specific WebDriver BiDi exception class from error details.
+
+ Defaults to ``UnknownErrorException`` if `error` is unknown.
+ """
+ cls = get(error)
+ return cls(message, stacktrace)
+
+
+def get(error_code: str) -> Type[BidiException]:
+ """Get exception from `error_code`.
+
+ It's falling back to ``UnknownErrorException`` if it is not found.
+ """
+ return _errors.get(error_code, UnknownErrorException)
+
+
+_errors: DefaultDict[str, Type[BidiException]] = collections.defaultdict()
+for item in list(locals().values()):
+ if type(item) == type and issubclass(item, BidiException):
+ _errors[item.error_code] = item
diff --git a/testing/web-platform/tests/tools/webdriver/webdriver/bidi/modules/__init__.py b/testing/web-platform/tests/tools/webdriver/webdriver/bidi/modules/__init__.py
new file mode 100644
index 0000000000..487b1270ab
--- /dev/null
+++ b/testing/web-platform/tests/tools/webdriver/webdriver/bidi/modules/__init__.py
@@ -0,0 +1,5 @@
+# flake8: noqa
+
+from .session import Session
+from .browsing_context import BrowsingContext
+from .script import Script
diff --git a/testing/web-platform/tests/tools/webdriver/webdriver/bidi/modules/_module.py b/testing/web-platform/tests/tools/webdriver/webdriver/bidi/modules/_module.py
new file mode 100644
index 0000000000..c2034033c7
--- /dev/null
+++ b/testing/web-platform/tests/tools/webdriver/webdriver/bidi/modules/_module.py
@@ -0,0 +1,99 @@
+import functools
+from typing import (
+ Any,
+ Awaitable,
+ Callable,
+ Optional,
+ Mapping,
+ MutableMapping,
+ TYPE_CHECKING,
+)
+
+if TYPE_CHECKING:
+ from ..client import BidiSession
+
+
+class command:
+ """Decorator for implementing bidi commands.
+
+ Implementing a command involves specifying an async function that
+ builds the parameters to the command. The decorator arranges those
+ parameters to be turned into a send_command call, using the class
+ and method names to determine the method in the call.
+
+ Commands decorated in this way don't return a future, but await
+ the actual response. In some cases it can be useful to
+ post-process this response before returning it to the client. This
+ can be done by specifying a second decorated method like
+ @command_name.result. That method will then be called once the
+ result of the original command is known, and the return value of
+ the method used as the response of the command. If this method
+ is specified, the `raw_result` parameter of the command can be set
+ to `True` to get the result without post-processing.
+
+ So for an example, if we had a command test.testMethod, which
+ returned a result which we want to convert to a TestResult type,
+ the implementation might look like:
+
+ class Test(BidiModule):
+ @command
+ def test_method(self, test_data=None):
+ return {"testData": test_data}
+
+ @test_method.result
+ def convert_test_method_result(self, result):
+ return TestData(**result)
+ """
+
+ def __init__(self, fn: Callable[..., Mapping[str, Any]]):
+ self.params_fn = fn
+ self.result_fn: Optional[Callable[..., Any]] = None
+
+ def result(self, fn: Callable[[Any, MutableMapping[str, Any]], Any]) -> None:
+ self.result_fn = fn
+
+ def __set_name__(self, owner: Any, name: str) -> None:
+ # This is called when the class is created
+ # see https://docs.python.org/3/reference/datamodel.html#object.__set_name__
+ params_fn = self.params_fn
+ result_fn = self.result_fn
+
+ @functools.wraps(params_fn)
+ async def inner(self: Any, **kwargs: Any) -> Any:
+ raw_result = kwargs.pop("raw_result", False)
+ params = params_fn(self, **kwargs)
+
+ # Convert the classname and the method name to a bidi command name
+ mod_name = owner.__name__[0].lower() + owner.__name__[1:]
+ if hasattr(owner, "prefix"):
+ mod_name = f"{owner.prefix}:{mod_name}"
+ cmd_name = f"{mod_name}.{to_camelcase(name)}"
+
+ future = await self.session.send_command(cmd_name, params)
+ result = await future
+
+ if result_fn is not None and not raw_result:
+ # Convert the result if we have a conversion function defined
+ result = result_fn(self, result)
+ return result
+
+ # Overwrite the method on the owner class with the wrapper
+ setattr(owner, name, inner)
+
+ def __call__(*args: Any, **kwargs: Any) -> Awaitable[Any]:
+ # This isn't really used, but mypy doesn't understand __set_name__
+ pass
+
+
+class BidiModule:
+ def __init__(self, session: "BidiSession"):
+ self.session = session
+
+
+def to_camelcase(name: str) -> str:
+ """Convert a python style method name foo_bar to a BiDi command name fooBar"""
+ parts = name.split("_")
+ parts[0] = parts[0].lower()
+ for i in range(1, len(parts)):
+ parts[i] = parts[i].title()
+ return "".join(parts)
diff --git a/testing/web-platform/tests/tools/webdriver/webdriver/bidi/modules/browsing_context.py b/testing/web-platform/tests/tools/webdriver/webdriver/bidi/modules/browsing_context.py
new file mode 100644
index 0000000000..70c834c384
--- /dev/null
+++ b/testing/web-platform/tests/tools/webdriver/webdriver/bidi/modules/browsing_context.py
@@ -0,0 +1,82 @@
+import base64
+from typing import Any, Mapping, MutableMapping, Optional
+
+from ._module import BidiModule, command
+
+
+class BrowsingContext(BidiModule):
+ @command
+ def capture_screenshot(self, context: str) -> Mapping[str, Any]:
+ params: MutableMapping[str, Any] = {
+ "context": context
+ }
+
+ return params
+
+ @capture_screenshot.result
+ def _capture_screenshot(self, result: Mapping[str, Any]) -> bytes:
+ assert result["data"] is not None
+ return base64.b64decode(result["data"])
+
+ @command
+ def close(self, context: Optional[str] = None) -> Mapping[str, Any]:
+ params: MutableMapping[str, Any] = {}
+
+ if context is not None:
+ params["context"] = context
+
+ return params
+
+ @command
+ def create(self, type_hint: str, reference_context: Optional[str] = None) -> Mapping[str, Any]:
+ params: MutableMapping[str, Any] = {"type": type_hint}
+
+ if reference_context is not None:
+ params["referenceContext"] = reference_context
+
+ return params
+
+ @create.result
+ def _create(self, result: Mapping[str, Any]) -> Any:
+ assert result["context"] is not None
+
+ return result
+
+ @command
+ def get_tree(self,
+ max_depth: Optional[int] = None,
+ root: Optional[str] = None) -> Mapping[str, Any]:
+ params: MutableMapping[str, Any] = {}
+
+ if max_depth is not None:
+ params["maxDepth"] = max_depth
+ if root is not None:
+ params["root"] = root
+
+ return params
+
+ @get_tree.result
+ def _get_tree(self, result: Mapping[str, Any]) -> Any:
+ assert result["contexts"] is not None
+ assert isinstance(result["contexts"], list)
+
+ return result["contexts"]
+
+ @command
+ def navigate(
+ self, context: str, url: str, wait: Optional[str] = None
+ ) -> Mapping[str, Any]:
+ params: MutableMapping[str, Any] = {"context": context, "url": url}
+ if wait is not None:
+ params["wait"] = wait
+ return params
+
+ @navigate.result
+ def _navigate(self, result: Mapping[str, Any]) -> Any:
+ if result["navigation"] is not None:
+ assert isinstance(result["navigation"], str)
+
+ assert result["url"] is not None
+ assert isinstance(result["url"], str)
+
+ return result
diff --git a/testing/web-platform/tests/tools/webdriver/webdriver/bidi/modules/script.py b/testing/web-platform/tests/tools/webdriver/webdriver/bidi/modules/script.py
new file mode 100644
index 0000000000..d9af11a8e2
--- /dev/null
+++ b/testing/web-platform/tests/tools/webdriver/webdriver/bidi/modules/script.py
@@ -0,0 +1,136 @@
+from enum import Enum
+from typing import Any, Dict, List, Mapping, MutableMapping, Optional, Union
+
+from ..error import UnknownErrorException
+from ._module import BidiModule, command
+
+
+class ScriptEvaluateResultException(Exception):
+ def __init__(self, result: Mapping[str, Any]):
+ self.result = result
+ super().__init__("Script execution failed.")
+
+
+class OwnershipModel(Enum):
+ NONE = "none"
+ ROOT = "root"
+
+
+class RealmTypes(Enum):
+ AUDIO_WORKLET = "audio-worklet"
+ DEDICATED_WORKER = "dedicated-worker"
+ PAINT_WORKLET = "paint-worklet"
+ SERVICE_WORKER = "service-worker"
+ SHARED_WORKER = "shared-worker"
+ WINDOW = "window"
+ WORKER = "worker"
+ WORKLET = "worklet"
+
+
+class RealmTarget(Dict[str, Any]):
+ def __init__(self, realm: str):
+ dict.__init__(self, realm=realm)
+
+
+class ContextTarget(Dict[str, Any]):
+ def __init__(self, context: str, sandbox: Optional[str] = None):
+ if sandbox is None:
+ dict.__init__(self, context=context)
+ else:
+ dict.__init__(self, context=context, sandbox=sandbox)
+
+
+Target = Union[RealmTarget, ContextTarget]
+
+
+class Script(BidiModule):
+ @command
+ def call_function(
+ self,
+ function_declaration: str,
+ await_promise: bool,
+ target: Target,
+ arguments: Optional[List[Mapping[str, Any]]] = None,
+ this: Optional[Mapping[str, Any]] = None,
+ result_ownership: Optional[OwnershipModel] = None,
+ ) -> Mapping[str, Any]:
+ params: MutableMapping[str, Any] = {
+ "functionDeclaration": function_declaration,
+ "target": target,
+ "awaitPromise": await_promise,
+ }
+
+ if arguments is not None:
+ params["arguments"] = arguments
+ if this is not None:
+ params["this"] = this
+ if result_ownership is not None:
+ params["resultOwnership"] = result_ownership
+ return params
+
+ @call_function.result
+ def _call_function(self, result: Mapping[str, Any]) -> Any:
+ assert "type" in result
+
+ if result["type"] == "success":
+ return result["result"]
+ elif result["type"] == "exception":
+ raise ScriptEvaluateResultException(result)
+ else:
+ raise UnknownErrorException(f"""Invalid type '{result["type"]}' in response""")
+
+ @command
+ def disown(self, handles: List[str], target: Target) -> Mapping[str, Any]:
+ params: MutableMapping[str, Any] = {"handles": handles, "target": target}
+ return params
+
+ @command
+ def evaluate(
+ self,
+ expression: str,
+ target: Target,
+ await_promise: bool,
+ result_ownership: Optional[OwnershipModel] = None,
+ ) -> Mapping[str, Any]:
+ params: MutableMapping[str, Any] = {
+ "expression": expression,
+ "target": target,
+ "awaitPromise": await_promise,
+ }
+
+ if result_ownership is not None:
+ params["resultOwnership"] = result_ownership
+ return params
+
+ @evaluate.result
+ def _evaluate(self, result: Mapping[str, Any]) -> Any:
+ assert "type" in result
+
+ if result["type"] == "success":
+ return result["result"]
+ elif result["type"] == "exception":
+ raise ScriptEvaluateResultException(result)
+ else:
+ raise UnknownErrorException(f"""Invalid type '{result["type"]}' in response""")
+
+ @command
+ def get_realms(
+ self,
+ context: Optional[str] = None,
+ type: Optional[RealmTypes] = None,
+ ) -> Mapping[str, Any]:
+ params: MutableMapping[str, Any] = {}
+
+ if context is not None:
+ params["context"] = context
+ if type is not None:
+ params["type"] = type
+
+ return params
+
+ @get_realms.result
+ def _get_realms(self, result: Mapping[str, Any]) -> Any:
+ assert result["realms"] is not None
+ assert isinstance(result["realms"], list)
+
+ return result["realms"]
diff --git a/testing/web-platform/tests/tools/webdriver/webdriver/bidi/modules/session.py b/testing/web-platform/tests/tools/webdriver/webdriver/bidi/modules/session.py
new file mode 100644
index 0000000000..7c1fef30ae
--- /dev/null
+++ b/testing/web-platform/tests/tools/webdriver/webdriver/bidi/modules/session.py
@@ -0,0 +1,31 @@
+from typing import Any, List, Optional, Mapping, MutableMapping
+
+from ._module import BidiModule, command
+
+
+class Session(BidiModule):
+ @command
+ def new(self, capabilities: Mapping[str, Any]) -> Mapping[str, Mapping[str, Any]]:
+ return {"capabilities": capabilities}
+
+ @new.result
+ def _new(self, result: Mapping[str, Any]) -> Any:
+ return result.get("session_id"), result.get("capabilities", {})
+
+ @command
+ def subscribe(self,
+ events: List[str],
+ contexts: Optional[List[str]] = None) -> Mapping[str, Any]:
+ params: MutableMapping[str, Any] = {"events": events}
+ if contexts is not None:
+ params["contexts"] = contexts
+ return params
+
+ @command
+ def unsubscribe(self,
+ events: Optional[List[str]] = None,
+ contexts: Optional[List[str]] = None) -> Mapping[str, Any]:
+ params: MutableMapping[str, Any] = {"events": events if events is not None else []}
+ if contexts is not None:
+ params["contexts"] = contexts
+ return params
diff --git a/testing/web-platform/tests/tools/webdriver/webdriver/bidi/transport.py b/testing/web-platform/tests/tools/webdriver/webdriver/bidi/transport.py
new file mode 100644
index 0000000000..afe054528e
--- /dev/null
+++ b/testing/web-platform/tests/tools/webdriver/webdriver/bidi/transport.py
@@ -0,0 +1,76 @@
+import asyncio
+import json
+import logging
+import sys
+from typing import Any, Callable, Coroutine, List, Optional, Mapping
+
+import websockets
+
+logger = logging.getLogger("webdriver.bidi")
+
+
+def get_running_loop() -> asyncio.AbstractEventLoop:
+ if sys.version_info >= (3, 7):
+ return asyncio.get_running_loop()
+ else:
+ # Unlike the above, this will actually create an event loop
+ # if there isn't one; hopefully running tests in Python >= 3.7
+ # will allow us to catch any behaviour difference
+ # (Needs to be in else for mypy to believe this is reachable)
+ return asyncio.get_event_loop()
+
+
+class Transport:
+ """Low level message handler for the WebSockets connection"""
+ def __init__(self, url: str,
+ msg_handler: Callable[[Mapping[str, Any]], Coroutine[Any, Any, None]],
+ loop: Optional[asyncio.AbstractEventLoop] = None):
+ self.url = url
+ self.connection: Optional[websockets.WebSocketClientProtocol] = None
+ self.msg_handler = msg_handler
+ self.send_buf: List[Mapping[str, Any]] = []
+
+ if loop is None:
+ loop = get_running_loop()
+ self.loop = loop
+
+ self.read_message_task: Optional[asyncio.Task[Any]] = None
+
+ async def start(self) -> None:
+ self.connection = await websockets.client.connect(self.url)
+ self.read_message_task = self.loop.create_task(self.read_messages())
+
+ for msg in self.send_buf:
+ await self._send(self.connection, msg)
+
+ async def send(self, data: Mapping[str, Any]) -> None:
+ if self.connection is not None:
+ await self._send(self.connection, data)
+ else:
+ self.send_buf.append(data)
+
+ @staticmethod
+ async def _send(
+ connection: websockets.WebSocketClientProtocol,
+ data: Mapping[str, Any]
+ ) -> None:
+ msg = json.dumps(data)
+ logger.debug("→ %s", msg)
+ await connection.send(msg)
+
+ async def handle(self, msg: str) -> None:
+ logger.debug("← %s", msg)
+ data = json.loads(msg)
+ await self.msg_handler(data)
+
+ async def end(self) -> None:
+ if self.connection:
+ await self.connection.close()
+ self.connection = None
+
+ async def read_messages(self) -> None:
+ assert self.connection is not None
+ async for msg in self.connection:
+ if not isinstance(msg, str):
+ raise ValueError("Got a binary message")
+ await self.handle(msg)
diff --git a/testing/web-platform/tests/tools/webdriver/webdriver/client.py b/testing/web-platform/tests/tools/webdriver/webdriver/client.py
new file mode 100644
index 0000000000..030e3fc56b
--- /dev/null
+++ b/testing/web-platform/tests/tools/webdriver/webdriver/client.py
@@ -0,0 +1,900 @@
+# mypy: allow-untyped-defs
+
+from typing import Dict
+from urllib import parse as urlparse
+
+from . import error
+from . import protocol
+from . import transport
+from .bidi.client import BidiSession
+
+
+def command(func):
+ def inner(self, *args, **kwargs):
+ if hasattr(self, "session"):
+ session = self.session
+ else:
+ session = self
+
+ if session.session_id is None:
+ session.start()
+
+ return func(self, *args, **kwargs)
+
+ inner.__name__ = func.__name__
+ inner.__doc__ = func.__doc__
+
+ return inner
+
+
+class Timeouts:
+
+ def __init__(self, session):
+ self.session = session
+
+ def _get(self, key=None):
+ timeouts = self.session.send_session_command("GET", "timeouts")
+ if key is not None:
+ return timeouts[key]
+ return timeouts
+
+ def _set(self, key, secs):
+ body = {key: secs * 1000}
+ self.session.send_session_command("POST", "timeouts", body)
+ return None
+
+ @property
+ def script(self):
+ return self._get("script")
+
+ @script.setter
+ def script(self, secs):
+ return self._set("script", secs)
+
+ @property
+ def page_load(self):
+ return self._get("pageLoad")
+
+ @page_load.setter
+ def page_load(self, secs):
+ return self._set("pageLoad", secs)
+
+ @property
+ def implicit(self):
+ return self._get("implicit")
+
+ @implicit.setter
+ def implicit(self, secs):
+ return self._set("implicit", secs)
+
+ def __str__(self):
+ name = "%s.%s" % (self.__module__, self.__class__.__name__)
+ return "<%s script=%d, load=%d, implicit=%d>" % \
+ (name, self.script, self.page_load, self.implicit)
+
+
+class ActionSequence:
+ """API for creating and performing action sequences.
+
+ Each action method adds one or more actions to a queue. When perform()
+ is called, the queued actions fire in order.
+
+ May be chained together as in::
+
+ ActionSequence(session, "key", id) \
+ .key_down("a") \
+ .key_up("a") \
+ .perform()
+ """
+ def __init__(self, session, action_type, input_id, pointer_params=None):
+ """Represents a sequence of actions of one type for one input source.
+
+ :param session: WebDriver session.
+ :param action_type: Action type; may be "none", "key", or "pointer".
+ :param input_id: ID of input source.
+ :param pointer_params: Optional dictionary of pointer parameters.
+ """
+ self.session = session
+ self._id = input_id
+ self._type = action_type
+ self._actions = []
+ self._pointer_params = pointer_params
+
+ @property
+ def dict(self):
+ d = {
+ "type": self._type,
+ "id": self._id,
+ "actions": self._actions,
+ }
+ if self._pointer_params is not None:
+ d["parameters"] = self._pointer_params
+ return d
+
+ @command
+ def perform(self):
+ """Perform all queued actions."""
+ self.session.actions.perform([self.dict])
+
+ def _key_action(self, subtype, value):
+ self._actions.append({"type": subtype, "value": value})
+
+ def _pointer_action(self, subtype, button=None, x=None, y=None, duration=None, origin=None, width=None,
+ height=None, pressure=None, tangential_pressure=None, tilt_x=None,
+ tilt_y=None, twist=None, altitude_angle=None, azimuth_angle=None):
+ action = {
+ "type": subtype
+ }
+ if button is not None:
+ action["button"] = button
+ if x is not None:
+ action["x"] = x
+ if y is not None:
+ action["y"] = y
+ if duration is not None:
+ action["duration"] = duration
+ if origin is not None:
+ action["origin"] = origin
+ if width is not None:
+ action["width"] = width
+ if height is not None:
+ action["height"] = height
+ if pressure is not None:
+ action["pressure"] = pressure
+ if tangential_pressure is not None:
+ action["tangentialPressure"] = tangential_pressure
+ if tilt_x is not None:
+ action["tiltX"] = tilt_x
+ if tilt_y is not None:
+ action["tiltY"] = tilt_y
+ if twist is not None:
+ action["twist"] = twist
+ if altitude_angle is not None:
+ action["altitudeAngle"] = altitude_angle
+ if azimuth_angle is not None:
+ action["azimuthAngle"] = azimuth_angle
+ self._actions.append(action)
+
+ def pause(self, duration):
+ self._actions.append({"type": "pause", "duration": duration})
+ return self
+
+ def pointer_move(self, x, y, duration=None, origin=None, width=None, height=None,
+ pressure=None, tangential_pressure=None, tilt_x=None, tilt_y=None,
+ twist=None, altitude_angle=None, azimuth_angle=None):
+ """Queue a pointerMove action.
+
+ :param x: Destination x-axis coordinate of pointer in CSS pixels.
+ :param y: Destination y-axis coordinate of pointer in CSS pixels.
+ :param duration: Number of milliseconds over which to distribute the
+ move. If None, remote end defaults to 0.
+ :param origin: Origin of coordinates, either "viewport", "pointer" or
+ an Element. If None, remote end defaults to "viewport".
+ """
+ self._pointer_action("pointerMove", x=x, y=y, duration=duration, origin=origin,
+ width=width, height=height, pressure=pressure,
+ tangential_pressure=tangential_pressure, tilt_x=tilt_x, tilt_y=tilt_y,
+ twist=twist, altitude_angle=altitude_angle, azimuth_angle=azimuth_angle)
+ return self
+
+ def pointer_up(self, button=0):
+ """Queue a pointerUp action for `button`.
+
+ :param button: Pointer button to perform action with.
+ Default: 0, which represents main device button.
+ """
+ self._pointer_action("pointerUp", button=button)
+ return self
+
+ def pointer_down(self, button=0, width=None, height=None, pressure=None,
+ tangential_pressure=None, tilt_x=None, tilt_y=None,
+ twist=None, altitude_angle=None, azimuth_angle=None):
+ """Queue a pointerDown action for `button`.
+
+ :param button: Pointer button to perform action with.
+ Default: 0, which represents main device button.
+ """
+ self._pointer_action("pointerDown", button=button, width=width, height=height,
+ pressure=pressure, tangential_pressure=tangential_pressure,
+ tilt_x=tilt_x, tilt_y=tilt_y, twist=twist, altitude_angle=altitude_angle,
+ azimuth_angle=azimuth_angle)
+ return self
+
+ def click(self, element=None, button=0):
+ """Queue a click with the specified button.
+
+ If an element is given, move the pointer to that element first,
+ otherwise click current pointer coordinates.
+
+ :param element: Optional element to click.
+ :param button: Integer representing pointer button to perform action
+ with. Default: 0, which represents main device button.
+ """
+ if element:
+ self.pointer_move(0, 0, origin=element)
+ return self.pointer_down(button).pointer_up(button)
+
+ def key_up(self, value):
+ """Queue a keyUp action for `value`.
+
+ :param value: Character to perform key action with.
+ """
+ self._key_action("keyUp", value)
+ return self
+
+ def key_down(self, value):
+ """Queue a keyDown action for `value`.
+
+ :param value: Character to perform key action with.
+ """
+ self._key_action("keyDown", value)
+ return self
+
+ def send_keys(self, keys):
+ """Queue a keyDown and keyUp action for each character in `keys`.
+
+ :param keys: String of keys to perform key actions with.
+ """
+ for c in keys:
+ self.key_down(c)
+ self.key_up(c)
+ return self
+
+ def scroll(self, x, y, delta_x, delta_y, duration=None, origin=None):
+ """Queue a scroll action.
+
+ :param x: Destination x-axis coordinate of pointer in CSS pixels.
+ :param y: Destination y-axis coordinate of pointer in CSS pixels.
+ :param delta_x: scroll delta on x-axis in CSS pixels.
+ :param delta_y: scroll delta on y-axis in CSS pixels.
+ :param duration: Number of milliseconds over which to distribute the
+ scroll. If None, remote end defaults to 0.
+ :param origin: Origin of coordinates, either "viewport" or an Element.
+ If None, remote end defaults to "viewport".
+ """
+ action = {
+ "type": "scroll",
+ "x": x,
+ "y": y,
+ "deltaX": delta_x,
+ "deltaY": delta_y
+ }
+ if duration is not None:
+ action["duration"] = duration
+ if origin is not None:
+ action["origin"] = origin
+ self._actions.append(action)
+ return self
+
+
+class Actions:
+ def __init__(self, session):
+ self.session = session
+
+ @command
+ def perform(self, actions=None):
+ """Performs actions by tick from each action sequence in `actions`.
+
+ :param actions: List of input source action sequences. A single action
+ sequence may be created with the help of
+ ``ActionSequence.dict``.
+ """
+ body = {"actions": [] if actions is None else actions}
+ actions = self.session.send_session_command("POST", "actions", body)
+ return actions
+
+ @command
+ def release(self):
+ return self.session.send_session_command("DELETE", "actions")
+
+ def sequence(self, *args, **kwargs):
+ """Return an empty ActionSequence of the designated type.
+
+ See ActionSequence for parameter list.
+ """
+ return ActionSequence(self.session, *args, **kwargs)
+
+
+class Window:
+ identifier = "window-fcc6-11e5-b4f8-330a88ab9d7f"
+
+ def __init__(self, session):
+ self.session = session
+
+ @command
+ def close(self):
+ handles = self.session.send_session_command("DELETE", "window")
+ if handles is not None and len(handles) == 0:
+ # With no more open top-level browsing contexts, the session is closed.
+ self.session.session_id = None
+
+ return handles
+
+ # The many "type: ignore" comments here and below are to silence mypy's
+ # "Decorated property not supported" error, which is due to a limitation
+ # in mypy, see https://github.com/python/mypy/issues/1362.
+ @property # type: ignore
+ @command
+ def rect(self):
+ return self.session.send_session_command("GET", "window/rect")
+
+ @rect.setter # type: ignore
+ @command
+ def rect(self, new_rect):
+ self.session.send_session_command("POST", "window/rect", new_rect)
+
+ @property # type: ignore
+ @command
+ def size(self):
+ """Gets the window size as a tuple of `(width, height)`."""
+ rect = self.rect
+ return (rect["width"], rect["height"])
+
+ @size.setter # type: ignore
+ @command
+ def size(self, new_size):
+ """Set window size by passing a tuple of `(width, height)`."""
+ try:
+ width, height = new_size
+ body = {"width": width, "height": height}
+ self.session.send_session_command("POST", "window/rect", body)
+ except (error.UnknownErrorException, error.InvalidArgumentException):
+ # silently ignore this error as the command is not implemented
+ # for Android. Revert this once it is implemented.
+ pass
+
+ @property # type: ignore
+ @command
+ def position(self):
+ """Gets the window position as a tuple of `(x, y)`."""
+ rect = self.rect
+ return (rect["x"], rect["y"])
+
+ @position.setter # type: ignore
+ @command
+ def position(self, new_position):
+ """Set window position by passing a tuple of `(x, y)`."""
+ try:
+ x, y = new_position
+ body = {"x": x, "y": y}
+ self.session.send_session_command("POST", "window/rect", body)
+ except error.UnknownErrorException:
+ # silently ignore this error as the command is not implemented
+ # for Android. Revert this once it is implemented.
+ pass
+
+ @command
+ def maximize(self):
+ return self.session.send_session_command("POST", "window/maximize")
+
+ @command
+ def minimize(self):
+ return self.session.send_session_command("POST", "window/minimize")
+
+ @command
+ def fullscreen(self):
+ return self.session.send_session_command("POST", "window/fullscreen")
+
+ @classmethod
+ def from_json(cls, json, session):
+ uuid = json[Window.identifier]
+ return cls(uuid, session)
+
+
+class Frame:
+ identifier = "frame-075b-4da1-b6ba-e579c2d3230a"
+
+ def __init__(self, session):
+ self.session = session
+
+ @classmethod
+ def from_json(cls, json, session):
+ uuid = json[Frame.identifier]
+ return cls(uuid, session)
+
+
+class ShadowRoot:
+ identifier = "shadow-6066-11e4-a52e-4f735466cecf"
+
+ def __init__(self, session, id):
+ """
+ Construct a new shadow root representation.
+
+ :param id: Shadow root UUID which must be unique across
+ all browsing contexts.
+ :param session: Current ``webdriver.Session``.
+ """
+ self.id = id
+ self.session = session
+
+ @classmethod
+ def from_json(cls, json, session):
+ uuid = json[ShadowRoot.identifier]
+ return cls(session, uuid)
+
+ def send_shadow_command(self, method, uri, body=None):
+ url = f"shadow/{self.id}/{uri}"
+ return self.session.send_session_command(method, url, body)
+
+ @command
+ def find_element(self, strategy, selector):
+ body = {"using": strategy,
+ "value": selector}
+ return self.send_shadow_command("POST", "element", body)
+
+ @command
+ def find_elements(self, strategy, selector):
+ body = {"using": strategy,
+ "value": selector}
+ return self.send_shadow_command("POST", "elements", body)
+
+
+class Find:
+ def __init__(self, session):
+ self.session = session
+
+ @command
+ def css(self, element_selector, all=True):
+ elements = self._find_element("css selector", element_selector, all)
+ return elements
+
+ def _find_element(self, strategy, selector, all):
+ route = "elements" if all else "element"
+ body = {"using": strategy,
+ "value": selector}
+ return self.session.send_session_command("POST", route, body)
+
+
+class Cookies:
+ def __init__(self, session):
+ self.session = session
+
+ def __getitem__(self, name):
+ self.session.send_session_command("GET", "cookie/%s" % name, {})
+
+ def __setitem__(self, name, value):
+ cookie = {"name": name,
+ "value": None}
+
+ if isinstance(name, str):
+ cookie["value"] = value
+ elif hasattr(value, "value"):
+ cookie["value"] = value.value
+ self.session.send_session_command("POST", "cookie/%s" % name, {})
+
+
+class UserPrompt:
+ def __init__(self, session):
+ self.session = session
+
+ @command
+ def dismiss(self):
+ self.session.send_session_command("POST", "alert/dismiss")
+
+ @command
+ def accept(self):
+ self.session.send_session_command("POST", "alert/accept")
+
+ @property # type: ignore
+ @command
+ def text(self):
+ return self.session.send_session_command("GET", "alert/text")
+
+ @text.setter # type: ignore
+ @command
+ def text(self, value):
+ body = {"text": value}
+ self.session.send_session_command("POST", "alert/text", body=body)
+
+
+class Session:
+ def __init__(self,
+ host,
+ port,
+ url_prefix="/",
+ enable_bidi=False,
+ capabilities=None,
+ extension=None):
+
+ if enable_bidi:
+ if capabilities is not None:
+ capabilities.setdefault("alwaysMatch", {}).update({"webSocketUrl": True})
+ else:
+ capabilities = {"alwaysMatch": {"webSocketUrl": True}}
+
+ self.transport = transport.HTTPWireProtocol(host, port, url_prefix)
+ self.requested_capabilities = capabilities
+ self.capabilities = None
+ self.session_id = None
+ self.timeouts = None
+ self.window = None
+ self.find = None
+ self.enable_bidi = enable_bidi
+ self.bidi_session = None
+ self.extension = None
+ self.extension_cls = extension
+
+ self.timeouts = Timeouts(self)
+ self.window = Window(self)
+ self.find = Find(self)
+ self.alert = UserPrompt(self)
+ self.actions = Actions(self)
+
+ def __repr__(self):
+ return "<%s %s>" % (self.__class__.__name__, self.session_id or "(disconnected)")
+
+ def __eq__(self, other):
+ return (self.session_id is not None and isinstance(other, Session) and
+ self.session_id == other.session_id)
+
+ def __enter__(self):
+ self.start()
+ return self
+
+ def __exit__(self, *args, **kwargs):
+ self.end()
+
+ def __del__(self):
+ self.end()
+
+ def match(self, capabilities):
+ return self.requested_capabilities == capabilities
+
+ def start(self):
+ """Start a new WebDriver session.
+
+ :return: Dictionary with `capabilities` and `sessionId`.
+
+ :raises error.WebDriverException: If the remote end returns
+ an error.
+ """
+ if self.session_id is not None:
+ return
+
+ self.transport.close()
+
+ body = {"capabilities": {}}
+
+ if self.requested_capabilities is not None:
+ body["capabilities"] = self.requested_capabilities
+
+ value = self.send_command("POST", "session", body=body)
+ assert isinstance(value["sessionId"], str)
+ assert isinstance(value["capabilities"], Dict)
+
+ self.session_id = value["sessionId"]
+ self.capabilities = value["capabilities"]
+
+ if "webSocketUrl" in self.capabilities:
+ self.bidi_session = BidiSession.from_http(self.session_id,
+ self.capabilities)
+ elif self.enable_bidi:
+ self.end()
+ raise error.SessionNotCreatedException(
+ "Requested bidi session, but webSocketUrl capability not found")
+
+ if self.extension_cls:
+ self.extension = self.extension_cls(self)
+
+ return value
+
+ def end(self):
+ """Try to close the active session."""
+ if self.session_id is None:
+ return
+
+ try:
+ self.send_command("DELETE", "session/%s" % self.session_id)
+ except (OSError, error.InvalidSessionIdException):
+ pass
+ finally:
+ self.session_id = None
+ self.transport.close()
+
+ def send_command(self, method, url, body=None, timeout=None):
+ """
+ Send a command to the remote end and validate its success.
+
+ :param method: HTTP method to use in request.
+ :param uri: "Command part" of the HTTP request URL,
+ e.g. `window/rect`.
+ :param body: Optional body of the HTTP request.
+
+ :return: `None` if the HTTP response body was empty, otherwise
+ the `value` field returned after parsing the response
+ body as JSON.
+
+ :raises error.WebDriverException: If the remote end returns
+ an error.
+ :raises ValueError: If the response body does not contain a
+ `value` key.
+ """
+
+ response = self.transport.send(
+ method, url, body,
+ encoder=protocol.Encoder, decoder=protocol.Decoder,
+ session=self, timeout=timeout)
+
+ if response.status != 200:
+ err = error.from_response(response)
+
+ if isinstance(err, error.InvalidSessionIdException):
+ # The driver could have already been deleted the session.
+ self.session_id = None
+
+ raise err
+
+ if "value" in response.body:
+ value = response.body["value"]
+ """
+ Edge does not yet return the w3c session ID.
+ We want the tests to run in Edge anyway to help with REC.
+ In order to run the tests in Edge, we need to hack around
+ bug:
+ https://developer.microsoft.com/en-us/microsoft-edge/platform/issues/14641972
+ """
+ if url == "session" and method == "POST" and "sessionId" in response.body and "sessionId" not in value:
+ value["sessionId"] = response.body["sessionId"]
+ else:
+ raise ValueError("Expected 'value' key in response body:\n"
+ "%s" % response)
+
+ return value
+
+ def send_session_command(self, method, uri, body=None, timeout=None):
+ """
+ Send a command to an established session and validate its success.
+
+ :param method: HTTP method to use in request.
+ :param url: "Command part" of the HTTP request URL,
+ e.g. `window/rect`.
+ :param body: Optional body of the HTTP request. Must be JSON
+ serialisable.
+
+ :return: `None` if the HTTP response body was empty, otherwise
+ the result of parsing the body as JSON.
+
+ :raises error.WebDriverException: If the remote end returns
+ an error.
+ """
+ url = urlparse.urljoin("session/%s/" % self.session_id, uri)
+ return self.send_command(method, url, body, timeout)
+
+ @property # type: ignore
+ @command
+ def url(self):
+ return self.send_session_command("GET", "url")
+
+ @url.setter # type: ignore
+ @command
+ def url(self, url):
+ if urlparse.urlsplit(url).netloc is None:
+ return self.url(url)
+ body = {"url": url}
+ return self.send_session_command("POST", "url", body)
+
+ @command
+ def back(self):
+ return self.send_session_command("POST", "back")
+
+ @command
+ def forward(self):
+ return self.send_session_command("POST", "forward")
+
+ @command
+ def refresh(self):
+ return self.send_session_command("POST", "refresh")
+
+ @property # type: ignore
+ @command
+ def title(self):
+ return self.send_session_command("GET", "title")
+
+ @property # type: ignore
+ @command
+ def source(self):
+ return self.send_session_command("GET", "source")
+
+ @command
+ def new_window(self, type_hint="tab"):
+ body = {"type": type_hint}
+ value = self.send_session_command("POST", "window/new", body)
+
+ return value["handle"]
+
+ @property # type: ignore
+ @command
+ def window_handle(self):
+ return self.send_session_command("GET", "window")
+
+ @window_handle.setter # type: ignore
+ @command
+ def window_handle(self, handle):
+ body = {"handle": handle}
+ return self.send_session_command("POST", "window", body=body)
+
+ def switch_frame(self, frame):
+ if frame == "parent":
+ url = "frame/parent"
+ body = None
+ else:
+ url = "frame"
+ body = {"id": frame}
+
+ return self.send_session_command("POST", url, body)
+
+ @property # type: ignore
+ @command
+ def handles(self):
+ return self.send_session_command("GET", "window/handles")
+
+ @property # type: ignore
+ @command
+ def active_element(self):
+ return self.send_session_command("GET", "element/active")
+
+ @command
+ def cookies(self, name=None):
+ if name is None:
+ url = "cookie"
+ else:
+ url = "cookie/%s" % name
+ return self.send_session_command("GET", url, {})
+
+ @command
+ def set_cookie(self, name, value, path=None, domain=None,
+ secure=None, expiry=None, http_only=None):
+ body = {
+ "name": name,
+ "value": value,
+ }
+
+ if domain is not None:
+ body["domain"] = domain
+ if expiry is not None:
+ body["expiry"] = expiry
+ if http_only is not None:
+ body["httpOnly"] = http_only
+ if path is not None:
+ body["path"] = path
+ if secure is not None:
+ body["secure"] = secure
+ self.send_session_command("POST", "cookie", {"cookie": body})
+
+ def delete_cookie(self, name=None):
+ if name is None:
+ url = "cookie"
+ else:
+ url = "cookie/%s" % name
+ self.send_session_command("DELETE", url, {})
+
+ #[...]
+
+ @command
+ def execute_script(self, script, args=None):
+ if args is None:
+ args = []
+
+ body = {
+ "script": script,
+ "args": args
+ }
+ return self.send_session_command("POST", "execute/sync", body)
+
+ @command
+ def execute_async_script(self, script, args=None):
+ if args is None:
+ args = []
+
+ body = {
+ "script": script,
+ "args": args
+ }
+ return self.send_session_command("POST", "execute/async", body)
+
+ #[...]
+
+ @command
+ def screenshot(self):
+ return self.send_session_command("GET", "screenshot")
+
+class Element:
+ """
+ Representation of a web element.
+
+ A web element is an abstraction used to identify an element when
+ it is transported via the protocol, between remote- and local ends.
+ """
+ identifier = "element-6066-11e4-a52e-4f735466cecf"
+
+ def __init__(self, id, session):
+ """
+ Construct a new web element representation.
+
+ :param id: Web element UUID which must be unique across
+ all browsing contexts.
+ :param session: Current ``webdriver.Session``.
+ """
+ self.id = id
+ self.session = session
+
+ def __repr__(self):
+ return "<%s %s>" % (self.__class__.__name__, self.id)
+
+ def __eq__(self, other):
+ return (isinstance(other, Element) and self.id == other.id and
+ self.session == other.session)
+
+ @classmethod
+ def from_json(cls, json, session):
+ uuid = json[Element.identifier]
+ return cls(uuid, session)
+
+ def send_element_command(self, method, uri, body=None):
+ url = "element/%s/%s" % (self.id, uri)
+ return self.session.send_session_command(method, url, body)
+
+ @command
+ def find_element(self, strategy, selector):
+ body = {"using": strategy,
+ "value": selector}
+ return self.send_element_command("POST", "element", body)
+
+ @command
+ def click(self):
+ self.send_element_command("POST", "click", {})
+
+ @command
+ def tap(self):
+ self.send_element_command("POST", "tap", {})
+
+ @command
+ def clear(self):
+ self.send_element_command("POST", "clear", {})
+
+ @command
+ def send_keys(self, text):
+ return self.send_element_command("POST", "value", {"text": text})
+
+ @property # type: ignore
+ @command
+ def text(self):
+ return self.send_element_command("GET", "text")
+
+ @property # type: ignore
+ @command
+ def name(self):
+ return self.send_element_command("GET", "name")
+
+ @command
+ def style(self, property_name):
+ return self.send_element_command("GET", "css/%s" % property_name)
+
+ @property # type: ignore
+ @command
+ def rect(self):
+ return self.send_element_command("GET", "rect")
+
+ @property # type: ignore
+ @command
+ def selected(self):
+ return self.send_element_command("GET", "selected")
+
+ @command
+ def screenshot(self):
+ return self.send_element_command("GET", "screenshot")
+
+ @property # type: ignore
+ @command
+ def shadow_root(self):
+ return self.send_element_command("GET", "shadow")
+
+ @command
+ def attribute(self, name):
+ return self.send_element_command("GET", "attribute/%s" % name)
+
+ # This MUST come last because otherwise @property decorators above
+ # will be overridden by this.
+ @command
+ def property(self, name):
+ return self.send_element_command("GET", "property/%s" % name)
diff --git a/testing/web-platform/tests/tools/webdriver/webdriver/error.py b/testing/web-platform/tests/tools/webdriver/webdriver/error.py
new file mode 100644
index 0000000000..1b67d3325a
--- /dev/null
+++ b/testing/web-platform/tests/tools/webdriver/webdriver/error.py
@@ -0,0 +1,232 @@
+# mypy: allow-untyped-defs
+
+import collections
+import json
+
+from typing import ClassVar, DefaultDict, Type
+
+
+class WebDriverException(Exception):
+ # The status_code class variable is used to map the JSON Error Code (see
+ # https://w3c.github.io/webdriver/#errors) to a WebDriverException subclass.
+ # However, http_status need not match, and both are set as instance
+ # variables, shadowing the class variables. TODO: Match on both http_status
+ # and status_code and let these be class variables only.
+ http_status = None # type: ClassVar[int]
+ status_code = None # type: ClassVar[str]
+
+ def __init__(self, http_status=None, status_code=None, message=None, stacktrace=None):
+ super()
+
+ if http_status is not None:
+ self.http_status = http_status
+ if status_code is not None:
+ self.status_code = status_code
+ self.message = message
+ self.stacktrace = stacktrace
+
+ def __repr__(self):
+ return f"<{self.__class__.__name__} http_status={self.http_status}>"
+
+ def __str__(self):
+ message = f"{self.status_code} ({self.http_status})"
+
+ if self.message is not None:
+ message += ": %s" % self.message
+ message += "\n"
+
+ if self.stacktrace:
+ message += ("\nRemote-end stacktrace:\n\n%s" % self.stacktrace)
+
+ return message
+
+
+class DetachedShadowRootException(WebDriverException):
+ http_status = 404
+ status_code = "detached shadow root"
+
+
+class ElementClickInterceptedException(WebDriverException):
+ http_status = 400
+ status_code = "element click intercepted"
+
+
+class ElementNotSelectableException(WebDriverException):
+ http_status = 400
+ status_code = "element not selectable"
+
+
+class ElementNotVisibleException(WebDriverException):
+ http_status = 400
+ status_code = "element not visible"
+
+
+class InsecureCertificateException(WebDriverException):
+ http_status = 400
+ status_code = "insecure certificate"
+
+
+class InvalidArgumentException(WebDriverException):
+ http_status = 400
+ status_code = "invalid argument"
+
+
+class InvalidCookieDomainException(WebDriverException):
+ http_status = 400
+ status_code = "invalid cookie domain"
+
+
+class InvalidElementCoordinatesException(WebDriverException):
+ http_status = 400
+ status_code = "invalid element coordinates"
+
+
+class InvalidElementStateException(WebDriverException):
+ http_status = 400
+ status_code = "invalid element state"
+
+
+class InvalidSelectorException(WebDriverException):
+ http_status = 400
+ status_code = "invalid selector"
+
+
+class InvalidSessionIdException(WebDriverException):
+ http_status = 404
+ status_code = "invalid session id"
+
+
+class JavascriptErrorException(WebDriverException):
+ http_status = 500
+ status_code = "javascript error"
+
+
+class MoveTargetOutOfBoundsException(WebDriverException):
+ http_status = 500
+ status_code = "move target out of bounds"
+
+
+class NoSuchAlertException(WebDriverException):
+ http_status = 404
+ status_code = "no such alert"
+
+
+class NoSuchCookieException(WebDriverException):
+ http_status = 404
+ status_code = "no such cookie"
+
+
+class NoSuchElementException(WebDriverException):
+ http_status = 404
+ status_code = "no such element"
+
+
+class NoSuchFrameException(WebDriverException):
+ http_status = 404
+ status_code = "no such frame"
+
+
+class NoSuchShadowRootException(WebDriverException):
+ http_status = 404
+ status_code = "no such shadow root"
+
+
+class NoSuchWindowException(WebDriverException):
+ http_status = 404
+ status_code = "no such window"
+
+
+class ScriptTimeoutException(WebDriverException):
+ http_status = 500
+ status_code = "script timeout"
+
+
+class SessionNotCreatedException(WebDriverException):
+ http_status = 500
+ status_code = "session not created"
+
+
+class StaleElementReferenceException(WebDriverException):
+ http_status = 404
+ status_code = "stale element reference"
+
+
+class TimeoutException(WebDriverException):
+ http_status = 500
+ status_code = "timeout"
+
+
+class UnableToSetCookieException(WebDriverException):
+ http_status = 500
+ status_code = "unable to set cookie"
+
+
+class UnexpectedAlertOpenException(WebDriverException):
+ http_status = 500
+ status_code = "unexpected alert open"
+
+
+class UnknownErrorException(WebDriverException):
+ http_status = 500
+ status_code = "unknown error"
+
+
+class UnknownCommandException(WebDriverException):
+ http_status = 404
+ status_code = "unknown command"
+
+
+class UnknownMethodException(WebDriverException):
+ http_status = 405
+ status_code = "unknown method"
+
+
+class UnsupportedOperationException(WebDriverException):
+ http_status = 500
+ status_code = "unsupported operation"
+
+
+def from_response(response):
+ """
+ Unmarshals an error from a ``Response``'s `body`, failing
+ if not all three required `error`, `message`, and `stacktrace`
+ fields are given. Defaults to ``WebDriverException`` if `error`
+ is unknown.
+ """
+ if response.status == 200:
+ raise UnknownErrorException(
+ response.status,
+ None,
+ "Response is not an error:\n"
+ "%s" % json.dumps(response.body))
+
+ if "value" in response.body:
+ value = response.body["value"]
+ else:
+ raise UnknownErrorException(
+ response.status,
+ None,
+ "Expected 'value' key in response body:\n"
+ "%s" % json.dumps(response.body))
+
+ # all fields must exist, but stacktrace can be an empty string
+ code = value["error"]
+ message = value["message"]
+ stack = value["stacktrace"] or None
+
+ cls = get(code)
+ return cls(response.status, code, message, stacktrace=stack)
+
+
+def get(error_code):
+ """
+ Gets exception from `error_code`, falling back to
+ ``WebDriverException`` if it is not found.
+ """
+ return _errors.get(error_code, WebDriverException)
+
+
+_errors: DefaultDict[str, Type[WebDriverException]] = collections.defaultdict()
+for item in list(locals().values()):
+ if type(item) == type and issubclass(item, WebDriverException):
+ _errors[item.status_code] = item
diff --git a/testing/web-platform/tests/tools/webdriver/webdriver/protocol.py b/testing/web-platform/tests/tools/webdriver/webdriver/protocol.py
new file mode 100644
index 0000000000..1972c3fce2
--- /dev/null
+++ b/testing/web-platform/tests/tools/webdriver/webdriver/protocol.py
@@ -0,0 +1,49 @@
+# mypy: allow-untyped-defs
+
+import json
+
+import webdriver
+
+
+"""WebDriver wire protocol codecs."""
+
+
+class Encoder(json.JSONEncoder):
+ def __init__(self, *args, **kwargs):
+ kwargs.pop("session")
+ super().__init__(*args, **kwargs)
+
+ def default(self, obj):
+ if isinstance(obj, (list, tuple)):
+ return [self.default(x) for x in obj]
+ elif isinstance(obj, webdriver.Element):
+ return {webdriver.Element.identifier: obj.id}
+ elif isinstance(obj, webdriver.Frame):
+ return {webdriver.Frame.identifier: obj.id}
+ elif isinstance(obj, webdriver.Window):
+ return {webdriver.Frame.identifier: obj.id}
+ elif isinstance(obj, webdriver.ShadowRoot):
+ return {webdriver.ShadowRoot.identifier: obj.id}
+ return super().default(obj)
+
+
+class Decoder(json.JSONDecoder):
+ def __init__(self, *args, **kwargs):
+ self.session = kwargs.pop("session")
+ super().__init__(
+ object_hook=self.object_hook, *args, **kwargs)
+
+ def object_hook(self, payload):
+ if isinstance(payload, (list, tuple)):
+ return [self.object_hook(x) for x in payload]
+ elif isinstance(payload, dict) and webdriver.Element.identifier in payload:
+ return webdriver.Element.from_json(payload, self.session)
+ elif isinstance(payload, dict) and webdriver.Frame.identifier in payload:
+ return webdriver.Frame.from_json(payload, self.session)
+ elif isinstance(payload, dict) and webdriver.Window.identifier in payload:
+ return webdriver.Window.from_json(payload, self.session)
+ elif isinstance(payload, dict) and webdriver.ShadowRoot.identifier in payload:
+ return webdriver.ShadowRoot.from_json(payload, self.session)
+ elif isinstance(payload, dict):
+ return {k: self.object_hook(v) for k, v in payload.items()}
+ return payload
diff --git a/testing/web-platform/tests/tools/webdriver/webdriver/transport.py b/testing/web-platform/tests/tools/webdriver/webdriver/transport.py
new file mode 100644
index 0000000000..47d0659196
--- /dev/null
+++ b/testing/web-platform/tests/tools/webdriver/webdriver/transport.py
@@ -0,0 +1,267 @@
+# mypy: allow-untyped-defs
+
+import json
+import select
+
+from http.client import HTTPConnection
+from typing import Dict, List, Mapping, Sequence, Tuple
+from urllib import parse as urlparse
+
+from . import error
+
+"""Implements HTTP transport for the WebDriver wire protocol."""
+
+
+missing = object()
+
+
+class ResponseHeaders(Mapping[str, str]):
+ """Read-only dictionary-like API for accessing response headers.
+
+ This class:
+ * Normalizes the header keys it is built with to lowercase (such that
+ iterating the items will return lowercase header keys).
+ * Has case-insensitive header lookup.
+ * Always returns all header values that have the same name, separated by
+ commas.
+ """
+ def __init__(self, items: Sequence[Tuple[str, str]]):
+ self.headers_dict: Dict[str, List[str]] = {}
+ for key, value in items:
+ key = key.lower()
+ if key not in self.headers_dict:
+ self.headers_dict[key] = []
+ self.headers_dict[key].append(value)
+
+ def __getitem__(self, key):
+ """Get all headers of a certain (case-insensitive) name. If there is
+ more than one, the values are returned comma separated"""
+ values = self.headers_dict[key.lower()]
+ if len(values) == 1:
+ return values[0]
+ else:
+ return ", ".join(values)
+
+ def get_list(self, key, default=missing):
+ """Get all the header values for a particular field name as a list"""
+ try:
+ return self.headers_dict[key.lower()]
+ except KeyError:
+ if default is not missing:
+ return default
+ else:
+ raise
+
+ def __iter__(self):
+ yield from self.headers_dict
+
+ def __len__(self):
+ return len(self.headers_dict)
+
+
+class Response:
+ """
+ Describes an HTTP response received from a remote end whose
+ body has been read and parsed as appropriate.
+ """
+
+ def __init__(self, status, body, headers):
+ self.status = status
+ self.body = body
+ self.headers = headers
+
+ def __repr__(self):
+ cls_name = self.__class__.__name__
+ if self.error:
+ return f"<{cls_name} status={self.status} error={repr(self.error)}>"
+ return f"<{cls_name: }tatus={self.status} body={json.dumps(self.body)}>"
+
+ def __str__(self):
+ return json.dumps(self.body, indent=2)
+
+ @property
+ def error(self):
+ if self.status != 200:
+ return error.from_response(self)
+ return None
+
+ @classmethod
+ def from_http(cls, http_response, decoder=json.JSONDecoder, **kwargs):
+ try:
+ body = json.load(http_response, cls=decoder, **kwargs)
+ headers = ResponseHeaders(http_response.getheaders())
+ except ValueError:
+ raise ValueError("Failed to decode response body as JSON:\n" +
+ http_response.read())
+
+ return cls(http_response.status, body, headers)
+
+
+class HTTPWireProtocol:
+ """
+ Transports messages (commands and responses) over the WebDriver
+ wire protocol.
+
+ Complex objects, such as ``webdriver.Element``, ``webdriver.Frame``,
+ and ``webdriver.Window`` are by default not marshaled to enable
+ use of `session.transport.send` in WPT tests::
+
+ session = webdriver.Session("127.0.0.1", 4444)
+ response = transport.send("GET", "element/active", None)
+ print response.body["value"]
+ # => {u'element-6066-11e4-a52e-4f735466cecf': u'<uuid>'}
+
+ Automatic marshaling is provided by ``webdriver.protocol.Encoder``
+ and ``webdriver.protocol.Decoder``, which can be passed in to
+ ``HTTPWireProtocol.send`` along with a reference to the current
+ ``webdriver.Session``::
+
+ session = webdriver.Session("127.0.0.1", 4444)
+ response = transport.send("GET", "element/active", None,
+ encoder=protocol.Encoder, decoder=protocol.Decoder,
+ session=session)
+ print response.body["value"]
+ # => webdriver.Element
+ """
+
+ def __init__(self, host, port, url_prefix="/"):
+ """
+ Construct interface for communicating with the remote server.
+
+ :param url: URL of remote WebDriver server.
+ :param wait: Duration to wait for remote to appear.
+ """
+ self.host = host
+ self.port = port
+ self.url_prefix = url_prefix
+ self._conn = None
+ self._last_request_is_blocked = False
+
+ def __del__(self):
+ self.close()
+
+ def close(self):
+ """Closes the current HTTP connection, if there is one."""
+ if self._conn:
+ try:
+ self._conn.close()
+ except OSError:
+ # The remote closed the connection
+ pass
+ self._conn = None
+
+ @property
+ def connection(self):
+ """Gets the current HTTP connection, or lazily creates one."""
+ if not self._conn:
+ conn_kwargs = {}
+ # We are not setting an HTTP timeout other than the default when the
+ # connection its created. The send method has a timeout value if needed.
+ self._conn = HTTPConnection(self.host, self.port, **conn_kwargs)
+
+ return self._conn
+
+ def url(self, suffix):
+ """
+ From the relative path to a command end-point,
+ craft a full URL suitable to be used in a request to the HTTPD.
+ """
+ return urlparse.urljoin(self.url_prefix, suffix)
+
+ def send(self,
+ method,
+ uri,
+ body=None,
+ headers=None,
+ encoder=json.JSONEncoder,
+ decoder=json.JSONDecoder,
+ timeout=None,
+ **codec_kwargs):
+ """
+ Send a command to the remote.
+
+ The request `body` must be JSON serialisable unless a
+ custom `encoder` has been provided. This means complex
+ objects such as ``webdriver.Element``, ``webdriver.Frame``,
+ and `webdriver.Window`` are not automatically made
+ into JSON. This behaviour is, however, provided by
+ ``webdriver.protocol.Encoder``, should you want it.
+
+ Similarly, the response body is returned au natural
+ as plain JSON unless a `decoder` that converts web
+ element references to ``webdriver.Element`` is provided.
+ Use ``webdriver.protocol.Decoder`` to achieve this behaviour.
+
+ The client will attempt to use persistent HTTP connections.
+
+ :param method: `GET`, `POST`, or `DELETE`.
+ :param uri: Relative endpoint of the requests URL path.
+ :param body: Body of the request. Defaults to an empty
+ dictionary if ``method`` is `POST`.
+ :param headers: Additional dictionary of headers to include
+ in the request.
+ :param encoder: JSON encoder class, which defaults to
+ ``json.JSONEncoder`` unless specified.
+ :param decoder: JSON decoder class, which defaults to
+ ``json.JSONDecoder`` unless specified.
+ :param codec_kwargs: Surplus arguments passed on to `encoder`
+ and `decoder` on construction.
+
+ :return: Instance of ``webdriver.transport.Response``
+ describing the HTTP response received from the remote end.
+
+ :raises ValueError: If `body` or the response body are not
+ JSON serialisable.
+ """
+ if body is None and method == "POST":
+ body = {}
+
+ payload = None
+ if body is not None:
+ try:
+ payload = json.dumps(body, cls=encoder, **codec_kwargs)
+ except ValueError:
+ raise ValueError("Failed to encode request body as JSON:\n"
+ "%s" % json.dumps(body, indent=2))
+
+ # When the timeout triggers, the TestRunnerManager thread will reuse
+ # this connection to check if the WebDriver its alive and we may end
+ # raising an httplib.CannotSendRequest exception if the WebDriver is
+ # not responding and this httplib.request() call is blocked on the
+ # runner thread. We use the boolean below to check for that and restart
+ # the connection in that case.
+ self._last_request_is_blocked = True
+ response = self._request(method, uri, payload, headers, timeout=None)
+ self._last_request_is_blocked = False
+ return Response.from_http(response, decoder=decoder, **codec_kwargs)
+
+ def _request(self, method, uri, payload, headers=None, timeout=None):
+ if isinstance(payload, str):
+ payload = payload.encode("utf-8")
+
+ if headers is None:
+ headers = {}
+ headers.update({"Connection": "keep-alive"})
+
+ url = self.url(uri)
+
+ if self._last_request_is_blocked or self._has_unread_data():
+ self.close()
+
+ self.connection.request(method, url, payload, headers)
+
+ # timeout for request has to be set just before calling httplib.getresponse()
+ # and the previous value restored just after that, even on exception raised
+ try:
+ if timeout:
+ previous_timeout = self._conn.gettimeout()
+ self._conn.settimeout(timeout)
+ response = self.connection.getresponse()
+ finally:
+ if timeout:
+ self._conn.settimeout(previous_timeout)
+
+ return response
+
+ def _has_unread_data(self):
+ return self._conn and self._conn.sock and select.select([self._conn.sock], [], [], 0)[0]