diff options
Diffstat (limited to 'testing/marionette/harness/marionette_harness/runner')
6 files changed, 1978 insertions, 0 deletions
diff --git a/testing/marionette/harness/marionette_harness/runner/__init__.py b/testing/marionette/harness/marionette_harness/runner/__init__.py new file mode 100644 index 0000000000..2fdac637d3 --- /dev/null +++ b/testing/marionette/harness/marionette_harness/runner/__init__.py @@ -0,0 +1,16 @@ +# This Source Code Form is subject to the terms of the Mozilla Public +# License, v. 2.0. If a copy of the MPL was not distributed with this +# file, You can obtain one at http://mozilla.org/MPL/2.0/. + +from .base import ( + BaseMarionetteArguments, + BaseMarionetteTestRunner, + Marionette, + MarionetteTest, + MarionetteTestResult, + MarionetteTextTestRunner, + TestManifest, + TestResult, + TestResultCollection, +) +from .mixins import WindowManagerMixin diff --git a/testing/marionette/harness/marionette_harness/runner/base.py b/testing/marionette/harness/marionette_harness/runner/base.py new file mode 100644 index 0000000000..b5ddc2d788 --- /dev/null +++ b/testing/marionette/harness/marionette_harness/runner/base.py @@ -0,0 +1,1265 @@ +# This Source Code Form is subject to the terms of the Mozilla Public +# License, v. 2.0. If a copy of the MPL was not distributed with this +# file, You can obtain one at http://mozilla.org/MPL/2.0/. + +import json +import os +import random +import re +import socket +import sys +import time +import traceback +import unittest +from argparse import ArgumentParser +from collections import defaultdict +from copy import deepcopy + +import mozinfo +import moznetwork +import mozprofile +import mozversion +import six +from manifestparser import TestManifest +from manifestparser.filters import tags +from marionette_driver.marionette import Marionette +from moztest.adapters.unit import StructuredTestResult, StructuredTestRunner +from moztest.results import TestResult, TestResultCollection, relevant_line +from six import MAXSIZE, reraise + +from . import serve + +here = os.path.abspath(os.path.dirname(__file__)) + + +def update_mozinfo(path=None): + """Walk up directories to find mozinfo.json and update the info.""" + path = path or here + dirs = set() + while path != os.path.expanduser("~"): + if path in dirs: + break + dirs.add(path) + path = os.path.split(path)[0] + + return mozinfo.find_and_update_from_json(*dirs) + + +class MarionetteTest(TestResult): + @property + def test_name(self): + if self.test_class is not None: + return "{0}.py {1}.{2}".format( + self.test_class.split(".")[0], self.test_class, self.name + ) + else: + return self.name + + +class MarionetteTestResult(StructuredTestResult, TestResultCollection): + resultClass = MarionetteTest + + def __init__(self, *args, **kwargs): + self.marionette = kwargs.pop("marionette") + TestResultCollection.__init__(self, "MarionetteTest") + self.passed = 0 + self.testsRun = 0 + self.result_modifiers = [] # used by mixins to modify the result + StructuredTestResult.__init__(self, *args, **kwargs) + + @property + def skipped(self): + return [t for t in self if t.result == "SKIPPED"] + + @skipped.setter + def skipped(self, value): + pass + + @property + def expectedFailures(self): + return [t for t in self if t.result == "KNOWN-FAIL"] + + @expectedFailures.setter + def expectedFailures(self, value): + pass + + @property + def unexpectedSuccesses(self): + return [t for t in self if t.result == "UNEXPECTED-PASS"] + + @unexpectedSuccesses.setter + def unexpectedSuccesses(self, value): + pass + + @property + def tests_passed(self): + return [t for t in self if t.result == "PASS"] + + @property + def errors(self): + return [t for t in self if t.result == "ERROR"] + + @errors.setter + def errors(self, value): + pass + + @property + def failures(self): + return [t for t in self if t.result == "UNEXPECTED-FAIL"] + + @failures.setter + def failures(self, value): + pass + + @property + def duration(self): + if self.stop_time: + return self.stop_time - self.start_time + else: + return 0 + + def add_test_result( + self, + test, + result_expected="PASS", + result_actual="PASS", + output="", + context=None, + **kwargs + ): + def get_class(test): + return test.__class__.__module__ + "." + test.__class__.__name__ + + name = str(test).split()[0] + test_class = get_class(test) + if hasattr(test, "jsFile"): + name = os.path.basename(test.jsFile) + test_class = None + + t = self.resultClass( + name=name, + test_class=test_class, + time_start=test.start_time, + result_expected=result_expected, + context=context, + **kwargs + ) + # call any registered result modifiers + for modifier in self.result_modifiers: + result_expected, result_actual, output, context = modifier( + t, result_expected, result_actual, output, context + ) + t.finish( + result_actual, + time_end=time.time() if test.start_time else 0, + reason=relevant_line(output), + output=output, + ) + self.append(t) + + def addError(self, test, err): + self.add_test_result( + test, output=self._exc_info_to_string(err, test), result_actual="ERROR" + ) + super(MarionetteTestResult, self).addError(test, err) + + def addFailure(self, test, err): + self.add_test_result( + test, + output=self._exc_info_to_string(err, test), + result_actual="UNEXPECTED-FAIL", + ) + super(MarionetteTestResult, self).addFailure(test, err) + + def addSuccess(self, test): + self.passed += 1 + self.add_test_result(test, result_actual="PASS") + super(MarionetteTestResult, self).addSuccess(test) + + def addExpectedFailure(self, test, err): + """Called when an expected failure/error occured.""" + self.add_test_result( + test, output=self._exc_info_to_string(err, test), result_actual="KNOWN-FAIL" + ) + super(MarionetteTestResult, self).addExpectedFailure(test, err) + + def addUnexpectedSuccess(self, test): + """Called when a test was expected to fail, but succeed.""" + self.add_test_result(test, result_actual="UNEXPECTED-PASS") + super(MarionetteTestResult, self).addUnexpectedSuccess(test) + + def addSkip(self, test, reason): + self.add_test_result(test, output=reason, result_actual="SKIPPED") + super(MarionetteTestResult, self).addSkip(test, reason) + + def getInfo(self, test): + return test.test_name + + def getDescription(self, test): + doc_first_line = test.shortDescription() + if self.descriptions and doc_first_line: + return "\n".join((str(test), doc_first_line)) + else: + desc = str(test) + return desc + + def printLogs(self, test): + for testcase in test._tests: + if hasattr(testcase, "loglines") and testcase.loglines: + # Don't dump loglines to the console if they only contain + # TEST-START and TEST-END. + skip_log = True + for line in testcase.loglines: + str_line = " ".join(line) + if "TEST-END" not in str_line and "TEST-START" not in str_line: + skip_log = False + break + if skip_log: + return + self.logger.info("START LOG:") + for line in testcase.loglines: + self.logger.info(" ".join(line).encode("ascii", "replace")) + self.logger.info("END LOG:") + + def stopTest(self, *args, **kwargs): + unittest._TextTestResult.stopTest(self, *args, **kwargs) + if self.marionette.check_for_crash(): + # this tells unittest.TestSuite not to continue running tests + self.shouldStop = True + test = next((a for a in args if isinstance(a, unittest.TestCase)), None) + if test: + self.addError(test, sys.exc_info()) + + +class MarionetteTextTestRunner(StructuredTestRunner): + resultclass = MarionetteTestResult + + def __init__(self, **kwargs): + self.marionette = kwargs.pop("marionette") + self.capabilities = kwargs.pop("capabilities") + + StructuredTestRunner.__init__(self, **kwargs) + + def _makeResult(self): + return self.resultclass( + self.stream, + self.descriptions, + self.verbosity, + marionette=self.marionette, + logger=self.logger, + result_callbacks=self.result_callbacks, + ) + + def run(self, test): + result = super(MarionetteTextTestRunner, self).run(test) + result.printLogs(test) + return result + + +class BaseMarionetteArguments(ArgumentParser): + def __init__(self, **kwargs): + ArgumentParser.__init__(self, **kwargs) + + def dir_path(path): + path = os.path.abspath(os.path.expanduser(path)) + if not os.access(path, os.F_OK): + os.makedirs(path) + return path + + self.argument_containers = [] + self.add_argument( + "tests", + nargs="*", + default=[], + help="Tests to run. " + "One or more paths to test files (Python or JS), " + "manifest files (.toml) or directories. " + "When a directory is specified, " + "all test files in the directory will be run.", + ) + self.add_argument( + "--binary", + help="path to gecko executable to launch before running the test", + ) + self.add_argument( + "--address", help="host:port of running Gecko instance to connect to" + ) + self.add_argument( + "--emulator", + action="store_true", + help="If no --address is given, then the harness will launch an " + "emulator. (See Remote options group.) " + "If --address is given, then the harness assumes you are " + "running an emulator already, and will launch gecko app " + "on that emulator.", + ) + self.add_argument( + "--app", help="application to use. see marionette_driver.geckoinstance" + ) + self.add_argument( + "--app-arg", + dest="app_args", + action="append", + default=[], + help="specify a command line argument to be passed onto the application", + ) + self.add_argument( + "--profile", + help="profile to use when launching the gecko process. If not passed, " + "then a profile will be constructed and used", + type=dir_path, + ) + self.add_argument( + "--setpref", + action="append", + metavar="PREF=VALUE", + dest="prefs_args", + help="set a browser preference; repeat for multiple preferences.", + ) + self.add_argument( + "--preferences", + action="append", + dest="prefs_files", + help="read preferences from a JSON or TOML file. For TOML, use " + "'file.toml:section' to specify a particular section.", + ) + self.add_argument( + "--addon", + action="append", + dest="addons", + help="addon to install; repeat for multiple addons.", + ) + self.add_argument( + "--repeat", type=int, help="number of times to repeat the test(s)" + ) + self.add_argument( + "--run-until-failure", + action="store_true", + help="Run tests repeatedly and stop on the first time a test fails. " + "Default cap is 30 runs, which can be overwritten " + "with the --repeat parameter.", + ) + self.add_argument( + "--testvars", + action="append", + help="path to a json file with any test data required", + ) + self.add_argument( + "--symbols-path", + help="absolute path to directory containing breakpad symbols, or the " + "url of a zip file containing symbols", + ) + self.add_argument( + "--socket-timeout", + type=float, + default=Marionette.DEFAULT_SOCKET_TIMEOUT, + help="Set the global timeout for marionette socket operations." + " Default: %(default)ss.", + ) + self.add_argument( + "--startup-timeout", + type=int, + default=Marionette.DEFAULT_STARTUP_TIMEOUT, + help="the max number of seconds to wait for a Marionette connection " + "after launching a binary. Default: %(default)ss.", + ) + self.add_argument( + "--shuffle", + action="store_true", + default=False, + help="run tests in a random order", + ) + self.add_argument( + "--shuffle-seed", + type=int, + default=random.randint(0, MAXSIZE), + help="Use given seed to shuffle tests", + ) + self.add_argument( + "--total-chunks", + type=int, + help="how many chunks to split the tests up into", + ) + self.add_argument("--this-chunk", type=int, help="which chunk to run") + self.add_argument( + "--server-root", + help="url to a webserver or path to a document root from which content " + "resources are served (default: {}).".format( + os.path.join(os.path.dirname(here), "www") + ), + ) + self.add_argument( + "--gecko-log", + help="Define the path to store log file. If the path is" + " a directory, the real log file will be created" + " given the format gecko-(timestamp).log. If it is" + " a file, if will be used directly. '-' may be passed" + " to write to stdout. Default: './gecko.log'", + ) + self.add_argument( + "--logger-name", + default="Marionette-based Tests", + help="Define the name to associate with the logger used", + ) + self.add_argument( + "--jsdebugger", + action="store_true", + default=False, + help="Enable the jsdebugger for marionette javascript.", + ) + self.add_argument( + "--pydebugger", + help="Enable python post-mortem debugger when a test fails." + " Pass in the debugger you want to use, eg pdb or ipdb.", + ) + self.add_argument( + "--disable-fission", + action="store_true", + dest="disable_fission", + default=False, + help="Disable Fission (site isolation) in Gecko.", + ) + self.add_argument( + "-z", + "--headless", + action="store_true", + dest="headless", + default=os.environ.get("MOZ_HEADLESS", False), + help="Run tests in headless mode.", + ) + self.add_argument( + "--tag", + action="append", + dest="test_tags", + default=None, + help="Filter out tests that don't have the given tag. Can be " + "used multiple times in which case the test must contain " + "at least one of the given tags.", + ) + self.add_argument( + "--workspace", + action="store", + default=None, + help="Path to directory for Marionette output. " + "(Default: .) (Default profile dest: TMP)", + type=dir_path, + ) + self.add_argument( + "-v", + "--verbose", + action="count", + help="Increase verbosity to include debug messages with -v, " + "and trace messages with -vv.", + ) + self.register_argument_container(RemoteMarionetteArguments()) + + def register_argument_container(self, container): + group = self.add_argument_group(container.name) + + for cli, kwargs in container.args: + group.add_argument(*cli, **kwargs) + + self.argument_containers.append(container) + + def parse_known_args(self, args=None, namespace=None): + args, remainder = ArgumentParser.parse_known_args(self, args, namespace) + for container in self.argument_containers: + if hasattr(container, "parse_args_handler"): + container.parse_args_handler(args) + return (args, remainder) + + def _get_preferences(self, prefs_files, prefs_args): + """Return user defined profile preferences as a dict.""" + # object that will hold the preferences + prefs = mozprofile.prefs.Preferences() + + # add preferences files + if prefs_files: + for prefs_file in prefs_files: + prefs.add_file(prefs_file) + + separator = "=" + cli_prefs = [] + if prefs_args: + misformatted = [] + for pref in prefs_args: + if separator not in pref: + misformatted.append(pref) + else: + cli_prefs.append(pref.split(separator, 1)) + if misformatted: + self._print_message( + "Warning: Ignoring preferences not in key{}value format: {}\n".format( + separator, ", ".join(misformatted) + ) + ) + # string preferences + prefs.add(cli_prefs, cast=True) + + return dict(prefs()) + + def verify_usage(self, args): + if not args.tests: + self.error( + "You must specify one or more test files, manifests, or directories." + ) + + missing_tests = [path for path in args.tests if not os.path.exists(path)] + if missing_tests: + self.error( + "Test file(s) not found: " + " ".join([path for path in missing_tests]) + ) + + if not args.address and not args.binary and not args.emulator: + self.error("You must specify --binary, or --address, or --emulator") + + if args.repeat is not None and args.repeat < 0: + self.error("The value of --repeat has to be equal or greater than 0.") + + if args.total_chunks is not None and args.this_chunk is None: + self.error("You must specify which chunk to run.") + + if args.this_chunk is not None and args.total_chunks is None: + self.error("You must specify how many chunks to split the tests into.") + + if args.total_chunks is not None: + if not 1 < args.total_chunks: + self.error("Total chunks must be greater than 1.") + if not 1 <= args.this_chunk <= args.total_chunks: + self.error( + "Chunk to run must be between 1 and {}.".format(args.total_chunks) + ) + + if args.jsdebugger: + args.app_args.append("-jsdebugger") + args.socket_timeout = None + + args.prefs = self._get_preferences(args.prefs_files, args.prefs_args) + + for container in self.argument_containers: + if hasattr(container, "verify_usage_handler"): + container.verify_usage_handler(args) + + return args + + +class RemoteMarionetteArguments(object): + name = "Remote (Emulator/Device)" + args = [ + [ + ["--emulator-binary"], + { + "help": "Path to emulator binary. By default mozrunner uses `which emulator`", + "dest": "emulator_bin", + }, + ], + [ + ["--adb"], + { + "help": "Path to the adb. By default mozrunner uses `which adb`", + "dest": "adb_path", + }, + ], + [ + ["--avd"], + { + "help": ( + "Name of an AVD available in your environment." + "See mozrunner.FennecEmulatorRunner" + ), + }, + ], + [ + ["--avd-home"], + { + "help": "Path to avd parent directory", + }, + ], + [ + ["--device"], + { + "help": ( + "Serial ID to connect to as seen in `adb devices`," + "e.g emulator-5444" + ), + "dest": "device_serial", + }, + ], + [ + ["--package"], + { + "help": "Name of Android package, e.g. org.mozilla.fennec", + "dest": "package_name", + }, + ], + ] + + +class Fixtures(object): + def where_is(self, uri, on="http"): + return serve.where_is(uri, on) + + +class BaseMarionetteTestRunner(object): + textrunnerclass = MarionetteTextTestRunner + driverclass = Marionette + + def __init__( + self, + address=None, + app=None, + app_args=None, + binary=None, + profile=None, + logger=None, + logdir=None, + repeat=None, + run_until_failure=None, + testvars=None, + symbols_path=None, + shuffle=False, + shuffle_seed=random.randint(0, MAXSIZE), + this_chunk=1, + total_chunks=1, + server_root=None, + gecko_log=None, + result_callbacks=None, + prefs=None, + test_tags=None, + socket_timeout=None, + startup_timeout=None, + addons=None, + workspace=None, + verbose=0, + emulator=False, + headless=False, + disable_fission=False, + **kwargs + ): + self._appName = None + self._capabilities = None + self._filename_pattern = None + self._version_info = {} + + self.fixture_servers = {} + self.fixtures = Fixtures() + self.extra_kwargs = kwargs + self.test_kwargs = deepcopy(kwargs) + self.address = address + self.app = app + self.app_args = app_args or [] + self.bin = binary + self.emulator = emulator + self.profile = profile + self.addons = addons + self.logger = logger + self.marionette = None + self.logdir = logdir + self.repeat = repeat or 0 + self.run_until_failure = run_until_failure or False + self.symbols_path = symbols_path + self.socket_timeout = socket_timeout + self.startup_timeout = startup_timeout + self.shuffle = shuffle + self.shuffle_seed = shuffle_seed + self.server_root = server_root + self.this_chunk = this_chunk + self.total_chunks = total_chunks + self.mixin_run_tests = [] + self.manifest_skipped_tests = [] + self.tests = [] + self.result_callbacks = result_callbacks or [] + self.prefs = prefs or {} + self.test_tags = test_tags + self.workspace = workspace + # If no workspace is set, default location for gecko.log is . + # and default location for profile is TMP + self.workspace_path = workspace or os.getcwd() + self.verbose = verbose + self.headless = headless + + self.prefs.update({"fission.autostart": not disable_fission}) + + # If no repeat has been set, default to 30 extra runs + if self.run_until_failure and repeat is None: + self.repeat = 30 + + def gather_debug(test, status): + # No screenshots and page source for skipped tests + if status == "SKIP": + return + + rv = {} + marionette = test._marionette_weakref() + + # In the event we're gathering debug without starting a session, + # skip marionette commands + if marionette.session is not None: + try: + with marionette.using_context(marionette.CONTEXT_CHROME): + rv["screenshot"] = marionette.screenshot() + with marionette.using_context(marionette.CONTEXT_CONTENT): + rv["source"] = marionette.page_source + except Exception as exc: + self.logger.warning( + "Failed to gather test failure debug: {}".format(exc) + ) + return rv + + self.result_callbacks.append(gather_debug) + + # testvars are set up in self.testvars property + self._testvars = None + self.testvars_paths = testvars + + self.test_handlers = [] + + self.reset_test_stats() + + self.logger.info( + "Using workspace for temporary data: " '"{}"'.format(self.workspace_path) + ) + + if not gecko_log: + self.gecko_log = os.path.join(self.workspace_path or "", "gecko.log") + else: + self.gecko_log = gecko_log + + self.results = [] + + @property + def filename_pattern(self): + if self._filename_pattern is None: + self._filename_pattern = re.compile("^test(((_.+?)+?\.((py))))$") + + return self._filename_pattern + + @property + def testvars(self): + if self._testvars is not None: + return self._testvars + + self._testvars = {} + + def update(d, u): + """Update a dictionary that may contain nested dictionaries.""" + for k, v in six.iteritems(u): + o = d.get(k, {}) + if isinstance(v, dict) and isinstance(o, dict): + d[k] = update(d.get(k, {}), v) + else: + d[k] = u[k] + return d + + json_testvars = self._load_testvars() + for j in json_testvars: + self._testvars = update(self._testvars, j) + return self._testvars + + def _load_testvars(self): + data = [] + if self.testvars_paths is not None: + for path in list(self.testvars_paths): + path = os.path.abspath(os.path.expanduser(path)) + if not os.path.exists(path): + raise IOError("--testvars file {} does not exist".format(path)) + try: + with open(path) as f: + data.append(json.loads(f.read())) + except ValueError as e: + msg = "JSON file ({0}) is not properly formatted: {1}" + reraise( + ValueError, + ValueError(msg.format(os.path.abspath(path), e)), + sys.exc_info()[2], + ) + return data + + @property + def capabilities(self): + if self._capabilities: + return self._capabilities + + self.marionette.start_session() + self._capabilities = self.marionette.session_capabilities + self.marionette.delete_session() + return self._capabilities + + @property + def appName(self): + if self._appName: + return self._appName + + self._appName = self.capabilities.get("browserName") + return self._appName + + @property + def bin(self): + return self._bin + + @bin.setter + def bin(self, path): + """Set binary and reset parts of runner accordingly. + Intended use: to change binary between calls to run_tests + """ + self._bin = path + self.tests = [] + self.cleanup() + + @property + def version_info(self): + if not self._version_info: + try: + # TODO: Get version_info in Fennec case + self._version_info = mozversion.get_version(binary=self.bin) + except Exception: + self.logger.warning( + "Failed to retrieve version information for {}".format(self.bin) + ) + return self._version_info + + def reset_test_stats(self): + self.passed = 0 + self.failed = 0 + self.crashed = 0 + self.unexpected_successes = 0 + self.todo = 0 + self.skipped = 0 + self.failures = [] + + def _build_kwargs(self): + if self.logdir and not os.access(self.logdir, os.F_OK): + os.mkdir(self.logdir) + + kwargs = { + "socket_timeout": self.socket_timeout, + "prefs": self.prefs, + "startup_timeout": self.startup_timeout, + "verbose": self.verbose, + "symbols_path": self.symbols_path, + } + if self.bin or self.emulator: + kwargs.update( + { + "host": "127.0.0.1", + "port": 2828, + "app": self.app, + "app_args": self.app_args, + "profile": self.profile, + "addons": self.addons, + "gecko_log": self.gecko_log, + # ensure Marionette class takes care of starting gecko instance + "bin": True, + } + ) + + if self.bin: + kwargs.update( + { + "bin": self.bin, + } + ) + + if self.emulator: + kwargs.update( + { + "avd_home": self.extra_kwargs.get("avd_home"), + "adb_path": self.extra_kwargs.get("adb_path"), + "emulator_binary": self.extra_kwargs.get("emulator_bin"), + "avd": self.extra_kwargs.get("avd"), + "package_name": self.extra_kwargs.get("package_name"), + } + ) + + if self.address: + host, port = self.address.split(":") + kwargs.update( + { + "host": host, + "port": int(port), + } + ) + if self.emulator: + kwargs.update( + { + "connect_to_running_emulator": True, + } + ) + if not self.bin and not self.emulator: + try: + # Establish a socket connection so we can vertify the data come back + connection = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + connection.connect((host, int(port))) + connection.close() + except Exception as e: + exc_cls, _, tb = sys.exc_info() + msg = "Connection attempt to {0}:{1} failed with error: {2}" + reraise(exc_cls, exc_cls(msg.format(host, port, e)), tb) + if self.workspace: + kwargs["workspace"] = self.workspace_path + if self.headless: + kwargs["headless"] = True + + return kwargs + + def record_crash(self): + crash = True + try: + crash = self.marionette.check_for_crash() + self.crashed += int(crash) + except Exception: + traceback.print_exc() + return crash + + def _initialize_test_run(self, tests): + assert len(tests) > 0 + assert len(self.test_handlers) > 0 + self.reset_test_stats() + + def _add_tests(self, tests): + for test in tests: + self.add_test(test) + + invalid_tests = [ + t["filepath"] + for t in self.tests + if not self._is_filename_valid(t["filepath"]) + ] + if invalid_tests: + raise Exception( + "Test file names must be of the form " + "'test_something.py'." + " Invalid test names:\n {}".format("\n ".join(invalid_tests)) + ) + + def _is_filename_valid(self, filename): + filename = os.path.basename(filename) + return self.filename_pattern.match(filename) + + def _fix_test_path(self, path): + """Normalize a logged test path from the test package.""" + test_path_prefixes = [ + "tests{}".format(os.path.sep), + ] + + path = os.path.relpath(path) + for prefix in test_path_prefixes: + if path.startswith(prefix): + path = path[len(prefix) :] + break + path = path.replace("\\", "/") + + return path + + def _log_skipped_tests(self): + for test in self.manifest_skipped_tests: + rel_path = None + if os.path.exists(test["path"]): + rel_path = self._fix_test_path(test["path"]) + + self.logger.test_start(rel_path) + self.logger.test_end(rel_path, "SKIP", message=test["disabled"]) + self.todo += 1 + + def run_tests(self, tests): + start_time = time.time() + self._initialize_test_run(tests) + + if self.marionette is None: + self.marionette = self.driverclass(**self._build_kwargs()) + self.logger.info("Profile path is %s" % self.marionette.profile_path) + + if len(self.fixture_servers) == 0 or any( + not server.is_alive for _, server in self.fixture_servers + ): + self.logger.info("Starting fixture servers") + self.fixture_servers = self.start_fixture_servers() + for url in serve.iter_url(self.fixture_servers): + self.logger.info("Fixture server listening on %s" % url) + + # backwards compatibility + self.marionette.baseurl = serve.where_is("/") + + self._add_tests(tests) + + device_info = None + if self.marionette.instance and self.emulator: + try: + device_info = self.marionette.instance.runner.device.device.get_info() + except Exception: + self.logger.warning("Could not get device info", exc_info=True) + + tests_by_group = defaultdict(list) + for test in self.tests: + group = self._fix_test_path(test["group"]) + filepath = self._fix_test_path(test["filepath"]) + tests_by_group[group].append(filepath) + + self.logger.suite_start( + tests_by_group, + name="marionette-test", + version_info=self.version_info, + device_info=device_info, + ) + + if self.shuffle: + self.logger.info("Using shuffle seed: %d" % self.shuffle_seed) + + self._log_skipped_tests() + + interrupted = None + try: + repeat_index = 0 + while repeat_index <= self.repeat: + if repeat_index > 0: + self.logger.info("\nREPEAT {}\n-------".format(repeat_index)) + self.run_test_sets() + if self.run_until_failure and self.failed > 0: + break + + repeat_index += 1 + + except KeyboardInterrupt: + # in case of KeyboardInterrupt during the test execution + # we want to display current test results. + # so we keep the exception to raise it later. + interrupted = sys.exc_info() + except Exception: + # For any other exception we return immediately and have to + # cleanup running processes + self.cleanup() + raise + + try: + self._print_summary(tests) + self.record_crash() + self.elapsedtime = time.time() - start_time + + for run_tests in self.mixin_run_tests: + run_tests(tests) + + self.logger.suite_end() + except Exception: + # raise only the exception if we were not interrupted + if not interrupted: + raise + finally: + self.cleanup() + + # reraise previous interruption now + if interrupted: + reraise(interrupted[0], interrupted[1], interrupted[2]) + + def _print_summary(self, tests): + self.logger.info("\nSUMMARY\n-------") + self.logger.info("passed: {}".format(self.passed)) + if self.unexpected_successes == 0: + self.logger.info("failed: {}".format(self.failed)) + else: + self.logger.info( + "failed: {0} (unexpected sucesses: {1})".format( + self.failed, self.unexpected_successes + ) + ) + if self.skipped == 0: + self.logger.info("todo: {}".format(self.todo)) + else: + self.logger.info("todo: {0} (skipped: {1})".format(self.todo, self.skipped)) + + if self.failed > 0: + self.logger.info("\nFAILED TESTS\n-------") + for failed_test in self.failures: + self.logger.info("{}".format(failed_test[0])) + + def start_fixture_servers(self): + root = self.server_root or os.path.join(os.path.dirname(here), "www") + if self.appName == "fennec": + return serve.start(root, host=moznetwork.get_ip()) + else: + return serve.start(root) + + def add_test(self, test, expected="pass", group="default"): + filepath = os.path.abspath(test) + + if os.path.isdir(filepath): + for root, dirs, files in os.walk(filepath): + for filename in files: + if filename.endswith(".toml"): + msg_tmpl = ( + "Ignoring manifest '{0}'; running all tests in '{1}'." + " See --help for details." + ) + relpath = os.path.relpath( + os.path.join(root, filename), filepath + ) + self.logger.warning(msg_tmpl.format(relpath, filepath)) + elif self._is_filename_valid(filename): + test_file = os.path.join(root, filename) + self.add_test(test_file) + return + + file_ext = os.path.splitext(os.path.split(filepath)[-1])[1] + + if file_ext == ".toml": + group = filepath + + manifest = TestManifest() + manifest.read(filepath) + + json_path = update_mozinfo(filepath) + mozinfo.update( + { + "appname": self.appName, + "manage_instance": self.marionette.instance is not None, + "headless": self.headless, + } + ) + self.logger.info("mozinfo updated from: {}".format(json_path)) + self.logger.info("mozinfo is: {}".format(mozinfo.info)) + + filters = [] + if self.test_tags: + filters.append(tags(self.test_tags)) + + manifest_tests = manifest.active_tests( + exists=False, disabled=True, filters=filters, **mozinfo.info + ) + if len(manifest_tests) == 0: + self.logger.error( + "No tests to run using specified " + "combination of filters: {}".format(manifest.fmt_filters()) + ) + + target_tests = [] + for test in manifest_tests: + if test.get("disabled"): + self.manifest_skipped_tests.append(test) + else: + target_tests.append(test) + + for i in target_tests: + if not os.path.exists(i["path"]): + raise IOError("test file: {} does not exist".format(i["path"])) + + self.add_test(i["path"], i["expected"], group=group) + return + + self.tests.append({"filepath": filepath, "expected": expected, "group": group}) + + def run_test(self, filepath, expected): + testloader = unittest.TestLoader() + suite = unittest.TestSuite() + self.test_kwargs["expected"] = expected + mod_name = os.path.splitext(os.path.split(filepath)[-1])[0] + for handler in self.test_handlers: + if handler.match(os.path.basename(filepath)): + handler.add_tests_to_suite( + mod_name, + filepath, + suite, + testloader, + self.marionette, + self.fixtures, + self.testvars, + **self.test_kwargs + ) + break + + if suite.countTestCases(): + runner = self.textrunnerclass( + logger=self.logger, + marionette=self.marionette, + capabilities=self.capabilities, + result_callbacks=self.result_callbacks, + ) + + results = runner.run(suite) + self.results.append(results) + + self.failed += len(results.failures) + len(results.errors) + if hasattr(results, "skipped"): + self.skipped += len(results.skipped) + self.todo += len(results.skipped) + self.passed += results.passed + for failure in results.failures + results.errors: + self.failures.append( + (results.getInfo(failure), failure.output, "TEST-UNEXPECTED-FAIL") + ) + if hasattr(results, "unexpectedSuccesses"): + self.failed += len(results.unexpectedSuccesses) + self.unexpected_successes += len(results.unexpectedSuccesses) + for failure in results.unexpectedSuccesses: + self.failures.append( + ( + results.getInfo(failure), + failure.output, + "TEST-UNEXPECTED-PASS", + ) + ) + if hasattr(results, "expectedFailures"): + self.todo += len(results.expectedFailures) + + self.mixin_run_tests = [] + for result in self.results: + result.result_modifiers = [] + + def run_test_set(self, tests): + if self.shuffle: + random.seed(self.shuffle_seed) + random.shuffle(tests) + + for test in tests: + self.run_test(test["filepath"], test["expected"]) + if self.record_crash(): + break + + def run_test_sets(self): + if len(self.tests) < 1: + raise Exception("There are no tests to run.") + elif self.total_chunks is not None and self.total_chunks > len(self.tests): + raise ValueError( + "Total number of chunks must be between 1 and {}.".format( + len(self.tests) + ) + ) + if self.total_chunks is not None and self.total_chunks > 1: + chunks = [[] for i in range(self.total_chunks)] + for i, test in enumerate(self.tests): + target_chunk = i % self.total_chunks + chunks[target_chunk].append(test) + + self.logger.info( + "Running chunk {0} of {1} ({2} tests selected from a " + "total of {3})".format( + self.this_chunk, + self.total_chunks, + len(chunks[self.this_chunk - 1]), + len(self.tests), + ) + ) + self.tests = chunks[self.this_chunk - 1] + + self.run_test_set(self.tests) + + def cleanup(self): + for proc in serve.iter_proc(self.fixture_servers): + proc.stop() + proc.kill() + self.fixture_servers = {} + + if hasattr(self, "marionette") and self.marionette: + if self.marionette.instance is not None: + if self.marionette.instance.runner.is_running(): + # Force a clean shutdown of the application process first if + # it is still running. If that fails, kill the process. + # Therefore a new session needs to be started. + self.marionette.start_session() + self.marionette.quit() + + self.marionette.instance.close(clean=True) + self.marionette.instance = None + + self.marionette.cleanup() + self.marionette = None + + __del__ = cleanup diff --git a/testing/marionette/harness/marionette_harness/runner/httpd.py b/testing/marionette/harness/marionette_harness/runner/httpd.py new file mode 100755 index 0000000000..8ffc85aeb0 --- /dev/null +++ b/testing/marionette/harness/marionette_harness/runner/httpd.py @@ -0,0 +1,243 @@ +#!/usr/bin/env python + +# This Source Code Form is subject to the terms of the Mozilla Public +# License, v. 2.0. If a copy of the MPL was not distributed with this +# file, You can obtain one at http://mozilla.org/MPL/2.0/. + +"""Specialisation of wptserver.server.WebTestHttpd for testing +Marionette. + +""" + +import argparse +import os +import select +import sys +import time + +from six.moves.urllib import parse as urlparse +from wptserve import handlers, request, server +from wptserve import routes as default_routes + +root = os.path.abspath(os.path.dirname(os.path.dirname(__file__))) +default_doc_root = os.path.join(root, "www") +default_ssl_cert = os.path.join(root, "certificates", "test.cert") +default_ssl_key = os.path.join(root, "certificates", "test.key") + + +@handlers.handler +def http_auth_handler(req, response): + # Allow the test to specify the username and password + params = dict(urlparse.parse_qsl(req.url_parts.query)) + username = params.get("username", "guest") + password = params.get("password", "guest") + + auth = request.Authentication(req.headers) + content = """<!doctype html> +<title>HTTP Authentication</title> +<p id="status">{}</p>""" + + if auth.username == username and auth.password == password: + response.status = 200 + response.content = content.format("success") + + else: + response.status = 401 + response.headers.set("WWW-Authenticate", 'Basic realm="secret"') + response.content = content.format("restricted") + + +@handlers.handler +def upload_handler(request, response): + return 200, [], [request.headers.get("Content-Type")] or [] + + +@handlers.handler +def slow_loading_handler(request, response): + # Allow the test specify the delay for delivering the content + params = dict(urlparse.parse_qsl(request.url_parts.query)) + delay = int(params.get("delay", 5)) + time.sleep(delay) + + # Do not allow the page to be cached to circumvent the bfcache of the browser + response.headers.set("Cache-Control", "no-cache, no-store") + response.content = """<!doctype html> +<meta charset="UTF-8"> +<title>Slow page loading</title> + +<p>Delay: <span id="delay">{}</span></p> +""".format( + delay + ) + + +@handlers.handler +def slow_coop_handler(request, response): + # Allow the test specify the delay for delivering the content + params = dict(urlparse.parse_qsl(request.url_parts.query)) + delay = int(params.get("delay", 5)) + time.sleep(delay) + + # Isolate the browsing context exclusively to same-origin documents + response.headers.set("Cross-Origin-Opener-Policy", "same-origin") + response.headers.set("Cache-Control", "no-cache, no-store") + response.content = """<!doctype html> +<meta charset="UTF-8"> +<title>Slow cross-origin page loading</title> + +<p>Delay: <span id="delay">{}</span></p> +""".format( + delay + ) + + +@handlers.handler +def update_xml_handler(request, response): + response.headers.set("Content-Type", "text/xml") + mar_digest = ( + "75cd68e6c98c84c435cd27e353f5b4f6a3f2c50f6802aa9bf62b47e47138757306769fd9befa08793635ee649" + "2319253480860b4aa8ed9ee1caaa4c83ebc90b9" + ) + response.content = """ + <updates> + <update type="minor" displayVersion="9999.0" appVersion="9999.0" platformVersion="9999.0" + buildID="20220627075547"> + <patch type="complete" URL="{}://{}/update/complete.mar" size="86612" + hashFunction="sha512" hashValue="{}"/> + </update> + </updates> + """.format( + request.url_parts.scheme, request.url_parts.netloc, mar_digest + ) + + +class NotAliveError(Exception): + """Occurs when attempting to run a function that requires the HTTPD + to have been started, and it has not. + + """ + + pass + + +class FixtureServer(object): + def __init__( + self, + doc_root, + url="http://127.0.0.1:0", + use_ssl=False, + ssl_cert=None, + ssl_key=None, + ): + if not os.path.isdir(doc_root): + raise ValueError("Server root is not a directory: %s" % doc_root) + + url = urlparse.urlparse(url) + if url.scheme is None: + raise ValueError("Server scheme not provided") + + scheme, host, port = url.scheme, url.hostname, url.port + if host is None: + host = "127.0.0.1" + if port is None: + port = 0 + + routes = [ + ("POST", "/file_upload", upload_handler), + ("GET", "/http_auth", http_auth_handler), + ("GET", "/slow", slow_loading_handler), + ("GET", "/slow-coop", slow_coop_handler), + ("GET", "/update.xml", update_xml_handler), + ] + routes.extend(default_routes.routes) + + self._httpd = server.WebTestHttpd( + host=host, + port=port, + bind_address=True, + doc_root=doc_root, + routes=routes, + use_ssl=True if scheme == "https" else False, + certificate=ssl_cert, + key_file=ssl_key, + ) + + def start(self): + if self.is_alive: + return + self._httpd.start() + + def wait(self): + if not self.is_alive: + return + try: + select.select([], [], []) + except KeyboardInterrupt: + self.stop() + + def stop(self): + if not self.is_alive: + return + self._httpd.stop() + + def get_url(self, path): + if not self.is_alive: + raise NotAliveError() + return self._httpd.get_url(path) + + @property + def doc_root(self): + return self._httpd.router.doc_root + + @property + def router(self): + return self._httpd.router + + @property + def routes(self): + return self._httpd.router.routes + + @property + def is_alive(self): + return self._httpd.started + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Specialised HTTP server for testing Marionette." + ) + parser.add_argument( + "url", + help=""" +service address including scheme, hostname, port, and prefix for document root, +e.g. \"https://0.0.0.0:0/base/\"""", + ) + parser.add_argument( + "-r", + dest="doc_root", + default=default_doc_root, + help="path to document root (default %(default)s)", + ) + parser.add_argument( + "-c", + dest="ssl_cert", + default=default_ssl_cert, + help="path to SSL certificate (default %(default)s)", + ) + parser.add_argument( + "-k", + dest="ssl_key", + default=default_ssl_key, + help="path to SSL certificate key (default %(default)s)", + ) + args = parser.parse_args() + + httpd = FixtureServer( + args.doc_root, args.url, ssl_cert=args.ssl_cert, ssl_key=args.ssl_key + ) + httpd.start() + print( + "{0}: started fixture server on {1}".format(sys.argv[0], httpd.get_url("/")), + file=sys.stderr, + ) + httpd.wait() diff --git a/testing/marionette/harness/marionette_harness/runner/mixins/__init__.py b/testing/marionette/harness/marionette_harness/runner/mixins/__init__.py new file mode 100644 index 0000000000..71b13461d5 --- /dev/null +++ b/testing/marionette/harness/marionette_harness/runner/mixins/__init__.py @@ -0,0 +1,5 @@ +# This Source Code Form is subject to the terms of the Mozilla Public +# License, v. 2.0. If a copy of the MPL was not distributed with this +# file, You can obtain one at http://mozilla.org/MPL/2.0/. + +from .window_manager import WindowManagerMixin diff --git a/testing/marionette/harness/marionette_harness/runner/mixins/window_manager.py b/testing/marionette/harness/marionette_harness/runner/mixins/window_manager.py new file mode 100644 index 0000000000..85729cc585 --- /dev/null +++ b/testing/marionette/harness/marionette_harness/runner/mixins/window_manager.py @@ -0,0 +1,210 @@ +# This Source Code Form is subject to the terms of the Mozilla Public +# License, v. 2.0. If a copy of the MPL was not distributed with this +# file, You can obtain one at http://mozilla.org/MPL/2.0/. + +import sys + +from marionette_driver import Wait +from six import reraise + + +class WindowManagerMixin(object): + def setUp(self): + super(WindowManagerMixin, self).setUp() + + self.start_window = self.marionette.current_chrome_window_handle + self.start_windows = self.marionette.chrome_window_handles + + self.start_tab = self.marionette.current_window_handle + self.start_tabs = self.marionette.window_handles + + def tearDown(self): + if len(self.marionette.chrome_window_handles) > len(self.start_windows): + raise Exception("Not all windows as opened by the test have been closed") + + if len(self.marionette.window_handles) > len(self.start_tabs): + raise Exception("Not all tabs as opened by the test have been closed") + + super(WindowManagerMixin, self).tearDown() + + def close_all_tabs(self): + current_window_handles = self.marionette.window_handles + + # If the start tab is not present anymore, use the next one of the list + if self.start_tab not in current_window_handles: + self.start_tab = current_window_handles[0] + + current_window_handles.remove(self.start_tab) + for handle in current_window_handles: + self.marionette.switch_to_window(handle) + self.marionette.close() + + self.marionette.switch_to_window(self.start_tab) + + def close_all_windows(self): + current_chrome_window_handles = self.marionette.chrome_window_handles + + # If the start window is not present anymore, use the next one of the list + if self.start_window not in current_chrome_window_handles: + self.start_window = current_chrome_window_handles[0] + current_chrome_window_handles.remove(self.start_window) + + for handle in current_chrome_window_handles: + self.marionette.switch_to_window(handle) + self.marionette.close_chrome_window() + + self.marionette.switch_to_window(self.start_window) + + def open_tab(self, callback=None, focus=False): + current_tabs = self.marionette.window_handles + + try: + if callable(callback): + callback() + else: + result = self.marionette.open(type="tab", focus=focus) + if result["type"] != "tab": + raise Exception( + "Newly opened browsing context is of type {} and not tab.".format( + result["type"] + ) + ) + except Exception: + exc_cls, exc, tb = sys.exc_info() + reraise( + exc_cls, + exc_cls("Failed to trigger opening a new tab: {}".format(exc)), + tb, + ) + else: + Wait(self.marionette).until( + lambda mn: len(mn.window_handles) == len(current_tabs) + 1, + message="No new tab has been opened", + ) + + [new_tab] = list(set(self.marionette.window_handles) - set(current_tabs)) + + return new_tab + + def open_window(self, callback=None, focus=False, private=False): + current_windows = self.marionette.chrome_window_handles + current_tabs = self.marionette.window_handles + + def loaded(handle): + with self.marionette.using_context("chrome"): + return self.marionette.execute_script( + """ + const { windowManager } = ChromeUtils.importESModule( + "chrome://remote/content/shared/WindowManager.sys.mjs" + ); + const win = windowManager.findWindowByHandle(arguments[0]).win; + return win.document.readyState == "complete"; + """, + script_args=[handle], + ) + + try: + if callable(callback): + callback(focus) + else: + result = self.marionette.open( + type="window", focus=focus, private=private + ) + if result["type"] != "window": + raise Exception( + "Newly opened browsing context is of type {} and not window.".format( + result["type"] + ) + ) + except Exception: + exc_cls, exc, tb = sys.exc_info() + reraise( + exc_cls, + exc_cls("Failed to trigger opening a new window: {}".format(exc)), + tb, + ) + else: + Wait(self.marionette).until( + lambda mn: len(mn.chrome_window_handles) == len(current_windows) + 1, + message="No new window has been opened", + ) + + [new_window] = list( + set(self.marionette.chrome_window_handles) - set(current_windows) + ) + + # Before continuing ensure the window has been completed loading + Wait(self.marionette).until( + lambda _: loaded(new_window), + message="Window with handle '{}'' did not finish loading".format( + new_window + ), + ) + + # Bug 1507771 - Return the correct handle based on the currently selected context + # as long as "WebDriver:NewWindow" is not handled separtely in chrome context + context = self.marionette._send_message( + "Marionette:GetContext", key="value" + ) + if context == "chrome": + return new_window + elif context == "content": + [new_tab] = list( + set(self.marionette.window_handles) - set(current_tabs) + ) + return new_tab + + def open_chrome_window(self, url, focus=False): + """Open a new chrome window with the specified chrome URL. + + Can be replaced with "WebDriver:NewWindow" once the command + supports opening generic chrome windows beside browsers (bug 1507771). + """ + + def open_with_js(focus): + with self.marionette.using_context("chrome"): + self.marionette.execute_async_script( + """ + let [url, focus, resolve] = arguments; + + function waitForEvent(target, type, args) { + return new Promise(resolve => { + let params = Object.assign({once: true}, args); + target.addEventListener(type, event => { + dump(`** Received DOM event ${event.type} for ${event.target}\n`); + resolve(); + }, params); + }); + } + + function waitForFocus(win) { + return Promise.all([ + waitForEvent(win, "activate"), + waitForEvent(win, "focus", {capture: true}), + ]); + } + + (async function() { + // Open a window, wait for it to receive focus + let win = window.openDialog(url, null, "chrome,centerscreen"); + let focused = waitForFocus(win); + + win.focus(); + await focused; + + // The new window shouldn't get focused. As such set the + // focus back to the opening window. + if (!focus && Services.focus.activeWindow != window) { + let focused = waitForFocus(window); + window.focus(); + await focused; + } + + resolve(win.docShell.browsingContext.id); + })(); + """, + script_args=(url, focus), + ) + + with self.marionette.using_context("chrome"): + return self.open_window(callback=open_with_js, focus=focus) diff --git a/testing/marionette/harness/marionette_harness/runner/serve.py b/testing/marionette/harness/marionette_harness/runner/serve.py new file mode 100755 index 0000000000..3833bbe876 --- /dev/null +++ b/testing/marionette/harness/marionette_harness/runner/serve.py @@ -0,0 +1,239 @@ +#!/usr/bin/env python + +# This Source Code Form is subject to the terms of the Mozilla Public +# License, v. 2.0. If a copy of the MPL was not distributed with this +# file, You can obtain one at http://mozilla.org/MPL/2.0/. + +"""Spawns necessary HTTP servers for testing Marionette in child +processes. + +""" + +import argparse +import multiprocessing +import os +import sys +from collections import defaultdict + +from six import iteritems + +from . import httpd + +__all__ = [ + "default_doc_root", + "iter_proc", + "iter_url", + "registered_servers", + "servers", + "start", + "where_is", +] +here = os.path.abspath(os.path.dirname(__file__)) + + +class BlockingChannel(object): + def __init__(self, channel): + self.chan = channel + self.lock = multiprocessing.Lock() + + def call(self, func, args=()): + self.send((func, args)) + return self.recv() + + def send(self, *args): + try: + self.lock.acquire() + self.chan.send(args) + finally: + self.lock.release() + + def recv(self): + try: + self.lock.acquire() + payload = self.chan.recv() + if isinstance(payload, tuple) and len(payload) == 1: + return payload[0] + return payload + except KeyboardInterrupt: + return ("stop", ()) + finally: + self.lock.release() + + +class ServerProxy(multiprocessing.Process, BlockingChannel): + def __init__(self, channel, init_func, *init_args, **init_kwargs): + multiprocessing.Process.__init__(self) + BlockingChannel.__init__(self, channel) + self.init_func = init_func + self.init_args = init_args + self.init_kwargs = init_kwargs + + def run(self): + try: + server = self.init_func(*self.init_args, **self.init_kwargs) + server.start() + self.send(("ok", ())) + + while True: + # ["func", ("arg", ...)] + # ["prop", ()] + sattr, fargs = self.recv() + attr = getattr(server, sattr) + + # apply fargs to attr if it is a function + if callable(attr): + rv = attr(*fargs) + + # otherwise attr is a property + else: + rv = attr + + self.send(rv) + + if sattr == "stop": + return + + except Exception as e: + self.send(("stop", e)) + + except KeyboardInterrupt: + server.stop() + + +class ServerProc(BlockingChannel): + def __init__(self, init_func): + self._init_func = init_func + self.proc = None + + parent_chan, self.child_chan = multiprocessing.Pipe() + BlockingChannel.__init__(self, parent_chan) + + def start(self, doc_root, ssl_config, **kwargs): + self.proc = ServerProxy( + self.child_chan, self._init_func, doc_root, ssl_config, **kwargs + ) + self.proc.daemon = True + self.proc.start() + + res, exc = self.recv() + if res == "stop": + raise exc + + def get_url(self, url): + return self.call("get_url", (url,)) + + @property + def doc_root(self): + return self.call("doc_root", ()) + + def stop(self): + self.call("stop") + if not self.is_alive: + return + self.proc.join() + + def kill(self): + if not self.is_alive: + return + self.proc.terminate() + self.proc.join(0) + + @property + def is_alive(self): + if self.proc is not None: + return self.proc.is_alive() + return False + + +def http_server(doc_root, ssl_config, host="127.0.0.1", **kwargs): + return httpd.FixtureServer(doc_root, url="http://{}:0/".format(host), **kwargs) + + +def https_server(doc_root, ssl_config, host="127.0.0.1", **kwargs): + return httpd.FixtureServer( + doc_root, + url="https://{}:0/".format(host), + ssl_key=ssl_config["key_path"], + ssl_cert=ssl_config["cert_path"], + **kwargs + ) + + +def start_servers(doc_root, ssl_config, **kwargs): + servers = defaultdict() + for schema, builder_fn in registered_servers: + proc = ServerProc(builder_fn) + proc.start(doc_root, ssl_config, **kwargs) + servers[schema] = (proc.get_url("/"), proc) + return servers + + +def start(doc_root=None, **kwargs): + """Start all relevant test servers. + + If no `doc_root` is given the default + testing/marionette/harness/marionette_harness/www directory will be used. + + Additional keyword arguments can be given which will be passed on + to the individual ``FixtureServer``'s in httpd.py. + + """ + doc_root = doc_root or default_doc_root + ssl_config = { + "cert_path": httpd.default_ssl_cert, + "key_path": httpd.default_ssl_key, + } + + global servers + servers = start_servers(doc_root, ssl_config, **kwargs) + return servers + + +def where_is(uri, on="http"): + """Returns the full URL, including scheme, hostname, and port, for + a fixture resource from the server associated with the ``on`` key. + It will by default look for the resource in the "http" server. + + """ + return servers.get(on)[1].get_url(uri) + + +def iter_proc(servers): + for _, (_, proc) in iteritems(servers): + yield proc + + +def iter_url(servers): + for _, (url, _) in iteritems(servers): + yield url + + +default_doc_root = os.path.join(os.path.dirname(here), "www") +registered_servers = [("http", http_server), ("https", https_server)] +servers = defaultdict() + + +def main(args): + global servers + + parser = argparse.ArgumentParser() + parser.add_argument( + "-r", dest="doc_root", help="Path to document root. Overrides default." + ) + args = parser.parse_args() + + servers = start(args.doc_root) + for url in iter_url(servers): + print("{}: listening on {}".format(sys.argv[0], url), file=sys.stderr) + + try: + while any(proc.is_alive for proc in iter_proc(servers)): + for proc in iter_proc(servers): + proc.proc.join(1) + except KeyboardInterrupt: + for proc in iter_proc(servers): + proc.kill() + + +if __name__ == "__main__": + main(sys.argv[1:]) |