diff options
Diffstat (limited to 'testing/marionette/client/marionette_driver/transport.py')
-rw-r--r-- | testing/marionette/client/marionette_driver/transport.py | 409 |
1 files changed, 409 insertions, 0 deletions
diff --git a/testing/marionette/client/marionette_driver/transport.py b/testing/marionette/client/marionette_driver/transport.py new file mode 100644 index 0000000000..aa63a69023 --- /dev/null +++ b/testing/marionette/client/marionette_driver/transport.py @@ -0,0 +1,409 @@ +# This Source Code Form is subject to the terms of the Mozilla Public +# License, v. 2.0. If a copy of the MPL was not distributed with this +# file, You can obtain one at http://mozilla.org/MPL/2.0/. + +import json +import socket +import sys +import time +from threading import RLock + +import six + + +class SocketTimeout(object): + def __init__(self, socket_ctx, timeout): + self.socket_ctx = socket_ctx + self.timeout = timeout + self.old_timeout = None + + def __enter__(self): + self.old_timeout = self.socket_ctx.socket_timeout + self.socket_ctx.socket_timeout = self.timeout + + def __exit__(self, *args, **kwargs): + self.socket_ctx.socket_timeout = self.old_timeout + + +class Message(object): + def __init__(self, msgid): + self.id = msgid + + def __eq__(self, other): + return self.id == other.id + + def __ne__(self, other): + return not self.__eq__(other) + + def __hash__(self): + # pylint --py3k: W1641 + return hash(self.id) + + +class Command(Message): + TYPE = 0 + + def __init__(self, msgid, name, params): + Message.__init__(self, msgid) + self.name = name + self.params = params + + def __str__(self): + return "<Command id={0}, name={1}, params={2}>".format( + self.id, self.name, self.params + ) + + def to_msg(self): + msg = [Command.TYPE, self.id, self.name, self.params] + return json.dumps(msg) + + @staticmethod + def from_msg(data): + assert data[0] == Command.TYPE + cmd = Command(data[1], data[2], data[3]) + return cmd + + +class Response(Message): + TYPE = 1 + + def __init__(self, msgid, error, result): + Message.__init__(self, msgid) + self.error = error + self.result = result + + def __str__(self): + return "<Response id={0}, error={1}, result={2}>".format( + self.id, self.error, self.result + ) + + def to_msg(self): + msg = [Response.TYPE, self.id, self.error, self.result] + return json.dumps(msg) + + @staticmethod + def from_msg(data): + assert data[0] == Response.TYPE + return Response(data[1], data[2], data[3]) + + +class SocketContext(object): + """Object that guards access to a socket via a lock. + + The socket must be accessed using this object as a context manager; + access to the socket outside of a context will bypass the lock.""" + + def __init__(self, host, port, timeout): + self.lock = RLock() + + self._sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + self._sock.settimeout(timeout) + self._sock.connect((host, port)) + + @property + def socket_timeout(self): + return self._sock.gettimeout() + + @socket_timeout.setter + def socket_timeout(self, value): + self._sock.settimeout(value) + + def __enter__(self): + self.lock.acquire() + return self._sock + + def __exit__(self, *args, **kwargs): + self.lock.release() + + +class TcpTransport(object): + """Socket client that communciates with Marionette via TCP. + + It speaks the protocol of the remote debugger in Gecko, in which + messages are always preceded by the message length and a colon, e.g.: + + 7:MESSAGE + + On top of this protocol it uses a Marionette message format, that + depending on the protocol level offered by the remote server, varies. + Supported protocol levels are `min_protocol_level` and above. + """ + + max_packet_length = 4096 + min_protocol_level = 3 + + def __init__(self, host, port, socket_timeout=60.0): + """If `socket_timeout` is `0` or `0.0`, non-blocking socket mode + will be used. Setting it to `1` or `None` disables timeouts on + socket operations altogether. + """ + self._socket_context = None + + self.host = host + self.port = port + self._socket_timeout = socket_timeout + + self.protocol = self.min_protocol_level + self.application_type = None + self.last_id = 0 + self.expected_response = None + + @property + def socket_timeout(self): + return self._socket_timeout + + @socket_timeout.setter + def socket_timeout(self, value): + self._socket_timeout = value + + if self._socket_context is not None: + self._socket_context.socket_timeout = value + + def _unmarshal(self, packet): + """Convert data from bytes to a Message subtype + + Message format is [type, msg_id, body1, body2], where body1 and body2 depend + on the message type. + + :param packet: Bytes received over the wire representing a complete message. + """ + msg = None + + data = json.loads(packet) + msg_type = data[0] + + if msg_type == Command.TYPE: + msg = Command.from_msg(data) + elif msg_type == Response.TYPE: + msg = Response.from_msg(data) + else: + raise ValueError("Invalid message body {!r}".format(packet)) + + return msg + + def receive(self, unmarshal=True): + """Wait for the next complete response from the remote. + + Packet format is length-prefixed JSON: + + packet = digit+ ":" body + digit = "0"-"9" + body = JSON text + + :param unmarshal: Default is to deserialise the packet and + return a ``Message`` type. Setting this to false will return + the raw packet. + """ + # Initally we read 4 bytes. We don't support reading beyond the end of a message, and + # so assuming the JSON body has to be an array or object, the minimum possible message + # is 4 bytes: "2:{}". In practice the marionette format has some required fields so the + # message is longer, but 4 bytes allows reading messages with bodies up to 999 bytes in + # length in two reads, which is the common case. + with self._socket_context as sock: + recv_bytes = 4 + + length_prefix = b"" + + body_length = -1 + body_received = 0 + body_parts = [] + + now = time.time() + timeout_time = ( + now + self.socket_timeout if self.socket_timeout is not None else None + ) + + while recv_bytes > 0: + if timeout_time is not None and time.time() > timeout_time: + raise socket.timeout( + "Connection timed out after {}s".format(self.socket_timeout) + ) + + try: + chunk = sock.recv(recv_bytes) + except socket.timeout: + # Lets handle it with our own timeout check + continue + + if not chunk: + raise socket.error("No data received over socket") + + body_part = None + if body_length > 0: + body_part = chunk + else: + parts = chunk.split(b":", 1) + length_prefix += parts[0] + + # With > 10 decimal digits we aren't going to have a 32 bit number + if len(length_prefix) > 10: + raise ValueError( + "Invalid message length: {!r}".format(length_prefix) + ) + + if len(parts) == 2: + # We found a : so we know the full length + err = None + try: + body_length = int(length_prefix) + except ValueError: + err = "expected an integer" + else: + if body_length <= 0: + err = "expected a positive integer" + elif body_length > 2 ** 32 - 1: + err = "expected a 32 bit integer" + if err is not None: + raise ValueError( + "Invalid message length: {} got {!r}".format( + err, length_prefix + ) + ) + body_part = parts[1] + + # If we didn't find a : yet we keep reading 4 bytes at a time until we do. + # We could increase this here to 7 bytes (since we can't have more than 10 + # length bytes and a seperator byte), or just increase it to + # int(length_prefix) + 1 since that's the minimum total number of remaining + # bytes (if the : is in the next byte), but it's probably not worth optimising + # for large messages. + + if body_part is not None: + body_received += len(body_part) + body_parts.append(body_part) + recv_bytes = body_length - body_received + + body = b"".join(body_parts) + if unmarshal: + msg = self._unmarshal(body) + self.last_id = msg.id + + # keep reading incoming responses until + # we receive the user's expected response + if isinstance(msg, Response) and msg != self.expected_response: + return self.receive(unmarshal) + + return msg + return body + + def connect(self): + """Connect to the server and process the hello message we expect + to receive in response. + + Returns a tuple of the protocol level and the application type. + """ + try: + self._socket_context = SocketContext( + self.host, self.port, self._socket_timeout + ) + except Exception: + # Unset so that the next attempt to send will cause + # another connection attempt. + self._socket_context = None + raise + + try: + with SocketTimeout(self._socket_context, 60.0): + # first packet is always a JSON Object + # which we can use to tell which protocol level we are at + raw = self.receive(unmarshal=False) + except socket.timeout: + exc_cls, exc, tb = sys.exc_info() + msg = "Connection attempt failed because no data has been received over the socket: {}" + six.reraise(exc_cls, exc_cls(msg.format(exc)), tb) + + hello = json.loads(raw) + application_type = hello.get("applicationType") + protocol = hello.get("marionetteProtocol") + + if application_type != "gecko": + raise ValueError( + "Application type '{}' is not supported".format(application_type) + ) + + if not isinstance(protocol, int) or protocol < self.min_protocol_level: + msg = "Earliest supported protocol level is '{}' but got '{}'" + raise ValueError(msg.format(self.min_protocol_level, protocol)) + + self.application_type = application_type + self.protocol = protocol + + return (self.protocol, self.application_type) + + def send(self, obj): + """Send message to the remote server. Allowed input is a + ``Message`` instance or a JSON serialisable object. + """ + if not self._socket_context: + self.connect() + + if isinstance(obj, Message): + data = obj.to_msg() + if isinstance(obj, Command): + self.expected_response = obj + else: + data = json.dumps(obj) + data = six.ensure_binary(data) + payload = six.ensure_binary(str(len(data))) + b":" + data + + with self._socket_context as sock: + totalsent = 0 + while totalsent < len(payload): + sent = sock.send(payload[totalsent:]) + if sent == 0: + raise IOError( + "Socket error after sending {0} of {1} bytes".format( + totalsent, len(payload) + ) + ) + else: + totalsent += sent + + def respond(self, obj): + """Send a response to a command. This can be an arbitrary JSON + serialisable object or an ``Exception``. + """ + res, err = None, None + if isinstance(obj, Exception): + err = obj + else: + res = obj + msg = Response(self.last_id, err, res) + self.send(msg) + return self.receive() + + def request(self, name, params): + """Sends a message to the remote server and waits for a response + to come back. + """ + self.last_id = self.last_id + 1 + cmd = Command(self.last_id, name, params) + self.send(cmd) + return self.receive() + + def close(self): + """Close the socket. + + First forces the socket to not send data anymore, and then explicitly + close it to free up its resources. + + See: https://docs.python.org/2/howto/sockets.html#disconnecting + """ + if self._socket_context: + with self._socket_context as sock: + try: + sock.shutdown(socket.SHUT_RDWR) + except IOError as exc: + # If the socket is already closed, don't care about: + # Errno 57: Socket not connected + # Errno 107: Transport endpoint is not connected + if exc.errno not in (57, 107): + raise + + if sock: + # Guard against unclean shutdown. + sock.close() + self._socket_context = None + + def __del__(self): + self.close() |