summaryrefslogtreecommitdiffstats
path: root/testing/marionette/client/marionette_driver/transport.py
diff options
context:
space:
mode:
Diffstat (limited to 'testing/marionette/client/marionette_driver/transport.py')
-rw-r--r--testing/marionette/client/marionette_driver/transport.py409
1 files changed, 409 insertions, 0 deletions
diff --git a/testing/marionette/client/marionette_driver/transport.py b/testing/marionette/client/marionette_driver/transport.py
new file mode 100644
index 0000000000..aa63a69023
--- /dev/null
+++ b/testing/marionette/client/marionette_driver/transport.py
@@ -0,0 +1,409 @@
+# This Source Code Form is subject to the terms of the Mozilla Public
+# License, v. 2.0. If a copy of the MPL was not distributed with this
+# file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+import json
+import socket
+import sys
+import time
+from threading import RLock
+
+import six
+
+
+class SocketTimeout(object):
+ def __init__(self, socket_ctx, timeout):
+ self.socket_ctx = socket_ctx
+ self.timeout = timeout
+ self.old_timeout = None
+
+ def __enter__(self):
+ self.old_timeout = self.socket_ctx.socket_timeout
+ self.socket_ctx.socket_timeout = self.timeout
+
+ def __exit__(self, *args, **kwargs):
+ self.socket_ctx.socket_timeout = self.old_timeout
+
+
+class Message(object):
+ def __init__(self, msgid):
+ self.id = msgid
+
+ def __eq__(self, other):
+ return self.id == other.id
+
+ def __ne__(self, other):
+ return not self.__eq__(other)
+
+ def __hash__(self):
+ # pylint --py3k: W1641
+ return hash(self.id)
+
+
+class Command(Message):
+ TYPE = 0
+
+ def __init__(self, msgid, name, params):
+ Message.__init__(self, msgid)
+ self.name = name
+ self.params = params
+
+ def __str__(self):
+ return "<Command id={0}, name={1}, params={2}>".format(
+ self.id, self.name, self.params
+ )
+
+ def to_msg(self):
+ msg = [Command.TYPE, self.id, self.name, self.params]
+ return json.dumps(msg)
+
+ @staticmethod
+ def from_msg(data):
+ assert data[0] == Command.TYPE
+ cmd = Command(data[1], data[2], data[3])
+ return cmd
+
+
+class Response(Message):
+ TYPE = 1
+
+ def __init__(self, msgid, error, result):
+ Message.__init__(self, msgid)
+ self.error = error
+ self.result = result
+
+ def __str__(self):
+ return "<Response id={0}, error={1}, result={2}>".format(
+ self.id, self.error, self.result
+ )
+
+ def to_msg(self):
+ msg = [Response.TYPE, self.id, self.error, self.result]
+ return json.dumps(msg)
+
+ @staticmethod
+ def from_msg(data):
+ assert data[0] == Response.TYPE
+ return Response(data[1], data[2], data[3])
+
+
+class SocketContext(object):
+ """Object that guards access to a socket via a lock.
+
+ The socket must be accessed using this object as a context manager;
+ access to the socket outside of a context will bypass the lock."""
+
+ def __init__(self, host, port, timeout):
+ self.lock = RLock()
+
+ self._sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+ self._sock.settimeout(timeout)
+ self._sock.connect((host, port))
+
+ @property
+ def socket_timeout(self):
+ return self._sock.gettimeout()
+
+ @socket_timeout.setter
+ def socket_timeout(self, value):
+ self._sock.settimeout(value)
+
+ def __enter__(self):
+ self.lock.acquire()
+ return self._sock
+
+ def __exit__(self, *args, **kwargs):
+ self.lock.release()
+
+
+class TcpTransport(object):
+ """Socket client that communciates with Marionette via TCP.
+
+ It speaks the protocol of the remote debugger in Gecko, in which
+ messages are always preceded by the message length and a colon, e.g.:
+
+ 7:MESSAGE
+
+ On top of this protocol it uses a Marionette message format, that
+ depending on the protocol level offered by the remote server, varies.
+ Supported protocol levels are `min_protocol_level` and above.
+ """
+
+ max_packet_length = 4096
+ min_protocol_level = 3
+
+ def __init__(self, host, port, socket_timeout=60.0):
+ """If `socket_timeout` is `0` or `0.0`, non-blocking socket mode
+ will be used. Setting it to `1` or `None` disables timeouts on
+ socket operations altogether.
+ """
+ self._socket_context = None
+
+ self.host = host
+ self.port = port
+ self._socket_timeout = socket_timeout
+
+ self.protocol = self.min_protocol_level
+ self.application_type = None
+ self.last_id = 0
+ self.expected_response = None
+
+ @property
+ def socket_timeout(self):
+ return self._socket_timeout
+
+ @socket_timeout.setter
+ def socket_timeout(self, value):
+ self._socket_timeout = value
+
+ if self._socket_context is not None:
+ self._socket_context.socket_timeout = value
+
+ def _unmarshal(self, packet):
+ """Convert data from bytes to a Message subtype
+
+ Message format is [type, msg_id, body1, body2], where body1 and body2 depend
+ on the message type.
+
+ :param packet: Bytes received over the wire representing a complete message.
+ """
+ msg = None
+
+ data = json.loads(packet)
+ msg_type = data[0]
+
+ if msg_type == Command.TYPE:
+ msg = Command.from_msg(data)
+ elif msg_type == Response.TYPE:
+ msg = Response.from_msg(data)
+ else:
+ raise ValueError("Invalid message body {!r}".format(packet))
+
+ return msg
+
+ def receive(self, unmarshal=True):
+ """Wait for the next complete response from the remote.
+
+ Packet format is length-prefixed JSON:
+
+ packet = digit+ ":" body
+ digit = "0"-"9"
+ body = JSON text
+
+ :param unmarshal: Default is to deserialise the packet and
+ return a ``Message`` type. Setting this to false will return
+ the raw packet.
+ """
+ # Initally we read 4 bytes. We don't support reading beyond the end of a message, and
+ # so assuming the JSON body has to be an array or object, the minimum possible message
+ # is 4 bytes: "2:{}". In practice the marionette format has some required fields so the
+ # message is longer, but 4 bytes allows reading messages with bodies up to 999 bytes in
+ # length in two reads, which is the common case.
+ with self._socket_context as sock:
+ recv_bytes = 4
+
+ length_prefix = b""
+
+ body_length = -1
+ body_received = 0
+ body_parts = []
+
+ now = time.time()
+ timeout_time = (
+ now + self.socket_timeout if self.socket_timeout is not None else None
+ )
+
+ while recv_bytes > 0:
+ if timeout_time is not None and time.time() > timeout_time:
+ raise socket.timeout(
+ "Connection timed out after {}s".format(self.socket_timeout)
+ )
+
+ try:
+ chunk = sock.recv(recv_bytes)
+ except socket.timeout:
+ # Lets handle it with our own timeout check
+ continue
+
+ if not chunk:
+ raise socket.error("No data received over socket")
+
+ body_part = None
+ if body_length > 0:
+ body_part = chunk
+ else:
+ parts = chunk.split(b":", 1)
+ length_prefix += parts[0]
+
+ # With > 10 decimal digits we aren't going to have a 32 bit number
+ if len(length_prefix) > 10:
+ raise ValueError(
+ "Invalid message length: {!r}".format(length_prefix)
+ )
+
+ if len(parts) == 2:
+ # We found a : so we know the full length
+ err = None
+ try:
+ body_length = int(length_prefix)
+ except ValueError:
+ err = "expected an integer"
+ else:
+ if body_length <= 0:
+ err = "expected a positive integer"
+ elif body_length > 2 ** 32 - 1:
+ err = "expected a 32 bit integer"
+ if err is not None:
+ raise ValueError(
+ "Invalid message length: {} got {!r}".format(
+ err, length_prefix
+ )
+ )
+ body_part = parts[1]
+
+ # If we didn't find a : yet we keep reading 4 bytes at a time until we do.
+ # We could increase this here to 7 bytes (since we can't have more than 10
+ # length bytes and a seperator byte), or just increase it to
+ # int(length_prefix) + 1 since that's the minimum total number of remaining
+ # bytes (if the : is in the next byte), but it's probably not worth optimising
+ # for large messages.
+
+ if body_part is not None:
+ body_received += len(body_part)
+ body_parts.append(body_part)
+ recv_bytes = body_length - body_received
+
+ body = b"".join(body_parts)
+ if unmarshal:
+ msg = self._unmarshal(body)
+ self.last_id = msg.id
+
+ # keep reading incoming responses until
+ # we receive the user's expected response
+ if isinstance(msg, Response) and msg != self.expected_response:
+ return self.receive(unmarshal)
+
+ return msg
+ return body
+
+ def connect(self):
+ """Connect to the server and process the hello message we expect
+ to receive in response.
+
+ Returns a tuple of the protocol level and the application type.
+ """
+ try:
+ self._socket_context = SocketContext(
+ self.host, self.port, self._socket_timeout
+ )
+ except Exception:
+ # Unset so that the next attempt to send will cause
+ # another connection attempt.
+ self._socket_context = None
+ raise
+
+ try:
+ with SocketTimeout(self._socket_context, 60.0):
+ # first packet is always a JSON Object
+ # which we can use to tell which protocol level we are at
+ raw = self.receive(unmarshal=False)
+ except socket.timeout:
+ exc_cls, exc, tb = sys.exc_info()
+ msg = "Connection attempt failed because no data has been received over the socket: {}"
+ six.reraise(exc_cls, exc_cls(msg.format(exc)), tb)
+
+ hello = json.loads(raw)
+ application_type = hello.get("applicationType")
+ protocol = hello.get("marionetteProtocol")
+
+ if application_type != "gecko":
+ raise ValueError(
+ "Application type '{}' is not supported".format(application_type)
+ )
+
+ if not isinstance(protocol, int) or protocol < self.min_protocol_level:
+ msg = "Earliest supported protocol level is '{}' but got '{}'"
+ raise ValueError(msg.format(self.min_protocol_level, protocol))
+
+ self.application_type = application_type
+ self.protocol = protocol
+
+ return (self.protocol, self.application_type)
+
+ def send(self, obj):
+ """Send message to the remote server. Allowed input is a
+ ``Message`` instance or a JSON serialisable object.
+ """
+ if not self._socket_context:
+ self.connect()
+
+ if isinstance(obj, Message):
+ data = obj.to_msg()
+ if isinstance(obj, Command):
+ self.expected_response = obj
+ else:
+ data = json.dumps(obj)
+ data = six.ensure_binary(data)
+ payload = six.ensure_binary(str(len(data))) + b":" + data
+
+ with self._socket_context as sock:
+ totalsent = 0
+ while totalsent < len(payload):
+ sent = sock.send(payload[totalsent:])
+ if sent == 0:
+ raise IOError(
+ "Socket error after sending {0} of {1} bytes".format(
+ totalsent, len(payload)
+ )
+ )
+ else:
+ totalsent += sent
+
+ def respond(self, obj):
+ """Send a response to a command. This can be an arbitrary JSON
+ serialisable object or an ``Exception``.
+ """
+ res, err = None, None
+ if isinstance(obj, Exception):
+ err = obj
+ else:
+ res = obj
+ msg = Response(self.last_id, err, res)
+ self.send(msg)
+ return self.receive()
+
+ def request(self, name, params):
+ """Sends a message to the remote server and waits for a response
+ to come back.
+ """
+ self.last_id = self.last_id + 1
+ cmd = Command(self.last_id, name, params)
+ self.send(cmd)
+ return self.receive()
+
+ def close(self):
+ """Close the socket.
+
+ First forces the socket to not send data anymore, and then explicitly
+ close it to free up its resources.
+
+ See: https://docs.python.org/2/howto/sockets.html#disconnecting
+ """
+ if self._socket_context:
+ with self._socket_context as sock:
+ try:
+ sock.shutdown(socket.SHUT_RDWR)
+ except IOError as exc:
+ # If the socket is already closed, don't care about:
+ # Errno 57: Socket not connected
+ # Errno 107: Transport endpoint is not connected
+ if exc.errno not in (57, 107):
+ raise
+
+ if sock:
+ # Guard against unclean shutdown.
+ sock.close()
+ self._socket_context = None
+
+ def __del__(self):
+ self.close()