summaryrefslogtreecommitdiffstats
path: root/testing/web-platform/tests/tools/wptrunner/wptrunner/testloader.py
diff options
context:
space:
mode:
Diffstat (limited to 'testing/web-platform/tests/tools/wptrunner/wptrunner/testloader.py')
-rw-r--r--testing/web-platform/tests/tools/wptrunner/wptrunner/testloader.py575
1 files changed, 575 insertions, 0 deletions
diff --git a/testing/web-platform/tests/tools/wptrunner/wptrunner/testloader.py b/testing/web-platform/tests/tools/wptrunner/wptrunner/testloader.py
new file mode 100644
index 0000000000..950273c389
--- /dev/null
+++ b/testing/web-platform/tests/tools/wptrunner/wptrunner/testloader.py
@@ -0,0 +1,575 @@
+# mypy: allow-untyped-defs
+
+import abc
+import hashlib
+import itertools
+import json
+import os
+from urllib.parse import urlsplit
+from abc import ABCMeta, abstractmethod
+from queue import Empty
+from collections import defaultdict, deque, namedtuple
+from typing import Any, cast
+
+from . import manifestinclude
+from . import manifestexpected
+from . import manifestupdate
+from . import mpcontext
+from . import wpttest
+from mozlog import structured
+
+manifest = None
+manifest_update = None
+download_from_github = None
+
+
def do_delayed_imports():
    """Import the wpt ``manifest`` package lazily, at first use.

    The imports are deferred because they need ``sys.path`` to already
    contain the manifest package; the loaded modules are stored in the
    module-level globals declared above so later callers can use them.
    """
    # This relies on an already loaded module having set the sys.path correctly :(
    global manifest, manifest_update, download_from_github
    from manifest import manifest  # type: ignore
    from manifest import update as manifest_update
    from manifest.download import download_from_github  # type: ignore
+
+
class TestGroupsFile:
    """
    Mapping object representing {group name: [test ids]}
    """

    def __init__(self, logger, path):
        try:
            with open(path) as f:
                self._data = json.load(f)
        except ValueError:
            logger.critical("test groups file %s not valid json" % path)
            raise

        # Invert the mapping so each test id resolves to its group name.
        self.group_by_test = {
            test_id: group
            for group, test_ids in self._data.items()
            for test_id in test_ids
        }

    def __contains__(self, key):
        return key in self._data

    def __getitem__(self, key):
        return self._data[key]
+
def read_include_from_file(file):
    """Read include patterns from *file*, one per line.

    Blank lines and whole-line ``#`` comments are skipped; URL fragments
    mean partial-line ``#`` comments cannot be supported.
    """
    with open(file) as f:
        stripped = (line.strip() for line in f)
        return [line for line in stripped
                if line and not line.startswith("#")]
+
def update_include_for_groups(test_groups, include):
    """Expand any group names in *include* into their member test ids.

    Items not present in *test_groups* are passed through unchanged.
    Returns None when *include* is None (i.e. run everything).
    """
    if include is None:
        # We're just running everything
        return

    return list(itertools.chain.from_iterable(
        test_groups[item] if item in test_groups else [item]
        for item in include))
+
+
class TestChunker(abc.ABC):
    """Base class for callables that select one chunk of a test manifest."""

    def __init__(self, total_chunks: int, chunk_number: int, **kwargs: Any):
        # Chunk numbers are 1-based, so the number must not exceed the total.
        assert chunk_number <= total_chunks
        self.total_chunks = total_chunks
        self.chunk_number = chunk_number
        # A default logger must already have been configured by the harness.
        self.logger = structured.get_default_logger()
        assert self.logger
        self.kwargs = kwargs

    @abstractmethod
    def __call__(self, manifest):
        ...
+
+
class Unchunked(TestChunker):
    """Degenerate chunker used when chunking is disabled.

    Only valid with a single chunk; passes the manifest through unchanged.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        assert self.total_chunks == 1

    def __call__(self, manifest, **kwargs):
        yield from manifest
+
+
class HashChunker(TestChunker):
    """Chunker that assigns tests to chunks by hashing a per-test key.

    Subclasses choose the key (path, id, or directory) via chunk_key().
    """

    def __call__(self, manifest):
        for test_type, test_path, tests in manifest:
            selected = set()
            for test in tests:
                if self._key_in_chunk(self.chunk_key(test_type, test_path, test)):
                    selected.add(test)
            if selected:
                yield test_type, test_path, selected

    def _key_in_chunk(self, key: str) -> bool:
        # Bucket the md5 of the key modulo the chunk count; chunk numbers
        # are 1-based so compare against chunk_number - 1.
        digest = hashlib.md5(key.encode()).hexdigest()
        return int(digest, 16) % self.total_chunks == self.chunk_number - 1

    @abstractmethod
    def chunk_key(self, test_type: str, test_path: str,
                  test: wpttest.Test) -> str:
        ...
+
+
class PathHashChunker(HashChunker):
    """HashChunker keyed on the test file path.

    All tests defined in the same file end up in the same chunk.
    """
    def chunk_key(self, test_type: str, test_path: str,
                  test: wpttest.Test) -> str:
        return test_path
+
+
class IDHashChunker(HashChunker):
    """HashChunker keyed on the individual test id.

    Distributes tests independently, even those sharing a file.
    """
    def chunk_key(self, test_type: str, test_path: str,
                  test: wpttest.Test) -> str:
        return cast(str, test.id)
+
+
class DirectoryHashChunker(HashChunker):
    """Like HashChunker except the directory is hashed.

    This ensures that all tests in the same directory end up in the same
    chunk.
    """
    def chunk_key(self, test_type: str, test_path: str,
                  test: wpttest.Test) -> str:
        directory = os.path.dirname(test_path)
        depth = self.kwargs.get("depth")
        if not depth:
            return directory
        # Truncate to the first `depth` path components so deeper
        # subdirectories still hash into their ancestor's chunk.
        return os.path.sep.join(directory.split(os.path.sep, depth)[:depth])
+
+
class TestFilter:
    """Callable that restricts the set of tests in a given manifest according
    to initial criteria"""

    def __init__(self, test_manifests, include=None, exclude=None, manifest_path=None, explicit=False):
        if manifest_path is None or include or explicit:
            self.manifest = manifestinclude.IncludeManifest.create()
            self.manifest.set_defaults()
        else:
            self.manifest = manifestinclude.get_manifest(manifest_path)

        # With an explicit include list everything defaults to skipped and
        # only the included items are re-enabled below.
        if include or explicit:
            self.manifest.set("skip", "true")

        for item in include or []:
            self.manifest.add_include(test_manifests, item)

        for item in exclude or []:
            self.manifest.add_exclude(test_manifests, item)

    def __call__(self, manifest_iter):
        for test_type, test_path, tests in manifest_iter:
            include_tests = {test for test in tests if self.manifest.include(test)}
            if include_tests:
                yield test_type, test_path, include_tests
+
+
class TagFilter:
    """Per-test predicate that keeps tests carrying at least one given tag."""

    def __init__(self, tags):
        self.tags = set(tags)

    def __call__(self, test):
        # Returns the (truthy when non-empty) intersection of the test's
        # tags with the filter's tags.
        return test.tags & self.tags
+
+
class ManifestLoader:
    """Loads (and optionally downloads/updates) the WPT manifests for each
    configured test root."""

    def __init__(self, test_paths, force_manifest_update=False, manifest_download=False,
                 types=None):
        do_delayed_imports()
        self.test_paths = test_paths
        self.force_manifest_update = force_manifest_update
        self.manifest_download = manifest_download
        self.types = types
        self.logger = structured.get_default_logger()
        if self.logger is None:
            self.logger = structured.structuredlog.StructuredLogger("ManifestLoader")

    def load(self):
        """Return {manifest: {"url_base": ..., **per-root paths}}."""
        rv = {}
        for url_base, paths in self.test_paths.items():
            manifest_file = self.load_manifest(url_base=url_base, **paths)
            rv[manifest_file] = {"url_base": url_base, **paths}
        return rv

    def load_manifest(self, tests_path, manifest_path, metadata_path, url_base="/", **kwargs):
        """Load one manifest, downloading and/or updating it if configured."""
        if self.manifest_download:
            download_from_github(manifest_path, tests_path)
        return manifest.load_and_update(tests_path, manifest_path, url_base,
                                        cache_root=os.path.join(metadata_path, ".cache"),
                                        update=self.force_manifest_update,
                                        types=self.types)
+
+
def iterfilter(filters, iter):
    """Yield the items of *iter* after piping it through each callable in
    *filters*, applied in order.

    NOTE: the second parameter shadows the builtin ``iter``; the name is
    kept for backward compatibility with existing callers.
    """
    filtered = iter
    for apply_filter in filters:
        filtered = apply_filter(filtered)
    yield from filtered
+
+
class TestLoader:
    """Loads tests according to a WPT manifest and any associated expectation files"""
    def __init__(self,
                 test_manifests,
                 test_types,
                 run_info,
                 manifest_filters=None,
                 test_filters=None,
                 chunk_type="none",
                 total_chunks=1,
                 chunk_number=1,
                 include_https=True,
                 include_h2=True,
                 include_webtransport_h3=False,
                 skip_timeout=False,
                 skip_implementation_status=None,
                 chunker_kwargs=None):
        """Create the loader and eagerly load all tests.

        :param test_manifests: mapping of manifest object to per-root data
                               (must include "metadata_path" and "url_base")
        :param test_types: list of test type names to load
        :param run_info: run properties used to evaluate expectation metadata
        :param manifest_filters: filters applied to the raw manifest iterator
        :param test_filters: per-test predicates applied after metadata load
        :param chunk_type: one of "none", "hash", "id_hash", "dir_hash"
        :param skip_timeout: if True, disable tests expected to TIMEOUT
        :param skip_implementation_status: collection of implementation
               statuses whose tests should be disabled
        :param chunker_kwargs: extra keyword arguments for the chunker
        """

        self.test_types = test_types
        self.run_info = run_info

        self.manifest_filters = manifest_filters if manifest_filters is not None else []
        self.test_filters = test_filters if test_filters is not None else []

        self.manifests = test_manifests
        self.tests = None
        self.disabled_tests = None
        self.include_https = include_https
        self.include_h2 = include_h2
        self.include_webtransport_h3 = include_webtransport_h3
        self.skip_timeout = skip_timeout
        self.skip_implementation_status = skip_implementation_status

        self.chunk_type = chunk_type
        self.total_chunks = total_chunks
        self.chunk_number = chunk_number

        if chunker_kwargs is None:
            chunker_kwargs = {}
        # Select the chunking strategy by name; "none" passes everything
        # through (and asserts total_chunks == 1).
        self.chunker = {"none": Unchunked,
                        "hash": PathHashChunker,
                        "id_hash": IDHashChunker,
                        "dir_hash": DirectoryHashChunker}[chunk_type](total_chunks,
                                                                      chunk_number,
                                                                      **chunker_kwargs)

        self._test_ids = None

        # Cache of parsed __dir__.ini manifests keyed by path, shared
        # across tests in the same directory tree.
        self.directory_manifests = {}

        self._load_tests()

    @property
    def test_ids(self):
        # Lazily-built list of ids of all loaded tests, disabled tests first.
        if self._test_ids is None:
            self._test_ids = []
            for test_dict in [self.disabled_tests, self.tests]:
                for test_type in self.test_types:
                    self._test_ids += [item.id for item in test_dict[test_type]]
        return self._test_ids

    def get_test(self, manifest_file, manifest_test, inherit_metadata, test_metadata):
        """Build a wpttest test object from a manifest entry and its metadata."""
        if test_metadata is not None:
            # The file-level metadata is inherited; the per-test section is
            # extracted for this specific test.
            inherit_metadata.append(test_metadata)
            test_metadata = test_metadata.get_test(manifestupdate.get_test_name(manifest_test.id))

        return wpttest.from_manifest(manifest_file, manifest_test, inherit_metadata, test_metadata)

    def load_dir_metadata(self, test_manifest, metadata_path, test_path):
        """Collect the __dir__.ini manifests for every ancestor directory of
        test_path, outermost first, skipping directories without one."""
        rv = []
        path_parts = os.path.dirname(test_path).split(os.path.sep)
        for i in range(len(path_parts) + 1):
            path = os.path.join(metadata_path, os.path.sep.join(path_parts[:i]), "__dir__.ini")
            if path not in self.directory_manifests:
                self.directory_manifests[path] = manifestexpected.get_dir_manifest(path,
                                                                                  self.run_info)
            manifest = self.directory_manifests[path]
            if manifest is not None:
                rv.append(manifest)
        return rv

    def load_metadata(self, test_manifest, metadata_path, test_path):
        """Return (inherited directory metadata list, per-file expectation
        manifest or None) for the given test path."""
        inherit_metadata = self.load_dir_metadata(test_manifest, metadata_path, test_path)
        test_metadata = manifestexpected.get_manifest(
            metadata_path, test_path, self.run_info)
        return inherit_metadata, test_metadata

    def iter_tests(self):
        """Yield (test_path, test_type, wpt_test) for every test that passes
        the manifest filters, the chunker, and the per-test filters."""
        manifest_items = []
        manifests_by_url_base = {}

        # Sort manifests by url_base so chunking sees a stable order.
        for manifest in sorted(self.manifests.keys(), key=lambda x:x.url_base):
            manifest_iter = iterfilter(self.manifest_filters,
                                       manifest.itertypes(*self.test_types))
            manifest_items.extend(manifest_iter)
            manifests_by_url_base[manifest.url_base] = manifest

        if self.chunker is not None:
            manifest_items = self.chunker(manifest_items)

        for test_type, test_path, tests in manifest_items:
            # All tests in one item share a url_base, so any member can be
            # used to recover the owning manifest.
            manifest_file = manifests_by_url_base[next(iter(tests)).url_base]
            metadata_path = self.manifests[manifest_file]["metadata_path"]

            inherit_metadata, test_metadata = self.load_metadata(manifest_file, metadata_path, test_path)
            for test in tests:
                wpt_test = self.get_test(manifest_file, test, inherit_metadata, test_metadata)
                if all(f(wpt_test) for f in self.test_filters):
                    yield test_path, test_type, wpt_test

    def _load_tests(self):
        """Read in the tests from the manifest file and add them to a queue"""
        tests = {"enabled":defaultdict(list),
                 "disabled":defaultdict(list)}

        for test_path, test_type, test in self.iter_tests():
            enabled = not test.disabled()
            # A test may also be disabled by the run configuration rather
            # than by its own metadata.
            if not self.include_https and test.environment["protocol"] == "https":
                enabled = False
            if not self.include_h2 and test.environment["protocol"] == "h2":
                enabled = False
            if self.skip_timeout and test.expected() == "TIMEOUT":
                enabled = False
            if self.skip_implementation_status and test.implementation_status() in self.skip_implementation_status:
                enabled = False
            key = "enabled" if enabled else "disabled"
            tests[key][test_type].append(test)

        self.tests = tests["enabled"]
        self.disabled_tests = tests["disabled"]

    def groups(self, test_types, chunk_type="none", total_chunks=1, chunk_number=1):
        """Return the set of top-level URL path components ("groups") of the
        enabled tests for the given test types.

        NOTE(review): chunk_type, total_chunks and chunk_number are accepted
        but currently unused by this method.
        """
        groups = set()

        for test_type in test_types:
            for test in self.tests[test_type]:
                group = test.url.split("/")[1]
                groups.add(group)

        return groups
+
+
def get_test_src(**kwargs):
    """Choose the TestSource subclass for this run and build its kwargs.

    Returns (test_source_cls, test_source_kwargs, chunker_kwargs).
    """
    test_source_kwargs = {"processes": kwargs["processes"],
                          "logger": kwargs["logger"]}
    chunker_kwargs = {}
    run_by_dir = kwargs["run_by_dir"]
    if run_by_dir is not False:
        # A value of None indicates infinite depth
        test_source_cls = PathGroupedSource
        test_source_kwargs["depth"] = run_by_dir
        chunker_kwargs["depth"] = run_by_dir
    elif kwargs["test_groups"]:
        test_source_cls = GroupFileTestSource
        test_source_kwargs["test_groups"] = kwargs["test_groups"]
    else:
        test_source_cls = SingleTestSource
    return test_source_cls, test_source_kwargs, chunker_kwargs
+
+
+TestGroup = namedtuple("TestGroup", ["group", "test_type", "metadata"])
+
+
# NOTE: the original declared `__metaclass__ = ABCMeta`, which is Python 2
# syntax and has no effect under Python 3, so the abstract methods were not
# actually enforced.  `metaclass=` fixes that, and @classmethod composes
# with @abstractmethod since Python 3.3.
class TestSource(metaclass=ABCMeta):
    """Abstract source of test groups for worker processes.

    Concrete subclasses decide how tests are batched into TestGroups and
    put onto a multiprocessing queue via make_queue().
    """

    def __init__(self, test_queue):
        self.test_queue = test_queue
        # An all-None TestGroup means "no current group".
        self.current_group = TestGroup(None, None, None)
        self.logger = structured.get_default_logger()
        if self.logger is None:
            self.logger = structured.structuredlog.StructuredLogger("TestSource")

    @classmethod
    @abstractmethod
    def make_queue(cls, tests_by_type, **kwargs):
        """Build the queue of TestGroups; return (queue, process count)."""

    @classmethod
    @abstractmethod
    def tests_by_group(cls, tests_by_type, **kwargs):
        """Return a mapping {group scope: [test ids]}."""

    @classmethod
    def group_metadata(cls, state):
        return {"scope": "/"}

    def group(self):
        """Return the current TestGroup, pulling a new one off the queue when
        the current one is exhausted.

        Returns an all-None TestGroup if the queue stays empty for 5 seconds.
        """
        if not self.current_group.group or len(self.current_group.group) == 0:
            try:
                self.current_group = self.test_queue.get(block=True, timeout=5)
            except Empty:
                self.logger.warning("Timed out getting test group from queue")
                return TestGroup(None, None, None)
        return self.current_group

    @classmethod
    def add_sentinal(cls, test_queue, num_of_workers):
        # Add one sentinel (all-None TestGroup) per worker so every worker
        # observes end-of-queue.  (Method name preserves the historical typo
        # for compatibility with existing callers.)
        for _ in range(num_of_workers):
            test_queue.put(TestGroup(None, None, None))

    @classmethod
    def process_count(cls, requested_processes, num_test_groups):
        """Get the number of processes to use.

        This must always be at least one, but otherwise not more than the number of test groups"""
        return max(1, min(requested_processes, num_test_groups))
+
+
class GroupedSource(TestSource):
    """TestSource that batches consecutive tests into groups.

    Subclasses decide where group boundaries fall via new_group().
    """

    @classmethod
    def new_group(cls, state, test_type, test, **kwargs):
        raise NotImplementedError

    @classmethod
    def make_queue(cls, tests_by_type, **kwargs):
        mp = mpcontext.get_context()
        test_queue = mp.Queue()
        groups = []
        state = {}

        for test_type, tests in tests_by_type.items():
            for test in tests:
                # Start a fresh group whenever the subclass reports that a
                # boundary has been crossed.
                if cls.new_group(state, test_type, test, **kwargs):
                    groups.append(TestGroup(deque(), test_type, cls.group_metadata(state)))

                current = groups[-1]
                current.group.append(test)
                test.update_metadata(current.metadata)

        for group in groups:
            test_queue.put(group)

        processes = cls.process_count(kwargs["processes"], len(groups))
        cls.add_sentinal(test_queue, processes)
        return test_queue, processes

    @classmethod
    def tests_by_group(cls, tests_by_type, **kwargs):
        groups = defaultdict(list)
        state = {}
        scope = None
        for test_type, tests in tests_by_type.items():
            for test in tests:
                if cls.new_group(state, test_type, test, **kwargs):
                    scope = cls.group_metadata(state)['scope']
                groups[scope].append(test.id)
        return groups
+
+
class SingleTestSource(TestSource):
    """TestSource that shards tests individually across worker processes by
    hashing their ids."""

    @classmethod
    def make_queue(cls, tests_by_type, **kwargs):
        mp = mpcontext.get_context()
        test_queue = mp.Queue()
        num_test_groups = 0
        processes = kwargs["processes"]
        for test_type, tests in tests_by_type.items():
            # One deque/metadata pair per worker slot.
            queues = [deque([]) for _ in range(processes)]
            metadatas = [cls.group_metadata(None) for _ in range(processes)]
            for test in tests:
                idx = hash(test.id) % processes
                queues[idx].append(test)
                test.update_metadata(metadatas[idx])

            # Only enqueue non-empty groups.
            for group, metadata in zip(queues, metadatas):
                if group:
                    test_queue.put(TestGroup(group, test_type, metadata))
                    num_test_groups += 1

        processes = cls.process_count(kwargs["processes"], num_test_groups)
        cls.add_sentinal(test_queue, processes)
        return test_queue, processes

    @classmethod
    def tests_by_group(cls, tests_by_type, **kwargs):
        all_tests = itertools.chain.from_iterable(tests_by_type.values())
        return {cls.group_metadata(None)['scope']: [t.id for t in all_tests]}
+
+
class PathGroupedSource(GroupedSource):
    """GroupedSource that groups consecutive tests sharing a URL directory
    prefix (up to an optional depth)."""

    @classmethod
    def new_group(cls, state, test_type, test, **kwargs):
        # depth=True or depth=0 both mean "unlimited depth".
        depth = kwargs.get("depth")
        if depth is True or depth == 0:
            depth = None
        path = urlsplit(test.url).path.split("/")[1:-1][:depth]
        changed = (test_type != state.get("prev_test_type") or
                   path != state.get("prev_path"))
        state["prev_test_type"] = test_type
        state["prev_path"] = path
        return changed

    @classmethod
    def group_metadata(cls, state):
        return {"scope": "/%s" % "/".join(state["prev_path"])}
+
+
class GroupFileTestSource(TestSource):
    """TestSource that batches tests according to an explicit test groups
    file (see TestGroupsFile)."""

    @classmethod
    def make_queue(cls, tests_by_type, **kwargs):
        mp = mpcontext.get_context()
        test_queue = mp.Queue()
        num_test_groups = 0

        for test_type, tests in tests_by_type.items():
            tests_by_group = cls.tests_by_group({test_type: tests}, **kwargs)
            ids_to_tests = {test.id: test for test in tests}

            for group_name, test_ids in tests_by_group.items():
                group_metadata = {"scope": group_name}
                group = deque(ids_to_tests[test_id] for test_id in test_ids)
                for test in group:
                    test.update_metadata(group_metadata)

                test_queue.put(TestGroup(group, test_type, group_metadata))
                num_test_groups += 1

        processes = cls.process_count(kwargs["processes"], num_test_groups)
        cls.add_sentinal(test_queue, processes)
        return test_queue, processes

    @classmethod
    def tests_by_group(cls, tests_by_type, **kwargs):
        logger = kwargs["logger"]
        test_groups = kwargs["test_groups"]

        tests_by_group = defaultdict(list)
        for test in itertools.chain.from_iterable(tests_by_type.values()):
            try:
                group = test_groups.group_by_test[test.id]
            except KeyError:
                # Every test must be assigned to a group by the groups file.
                logger.error("%s is missing from test groups file" % test.id)
                raise
            tests_by_group[group].append(test.id)

        return tests_by_group