summaryrefslogtreecommitdiffstats
path: root/testing/marionette/harness/marionette_harness/runner/serve.py
blob: 3833bbe876f48fae24f9d8edf3c4d06cc0f06c86 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
#!/usr/bin/env python

# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.

"""Spawns necessary HTTP servers for testing Marionette in child
processes.

"""

import argparse
import multiprocessing
import os
import sys
from collections import defaultdict

from six import iteritems

from . import httpd

# Public names re-exported by ``from ...serve import *``.
__all__ = [
    "default_doc_root",
    "iter_proc",
    "iter_url",
    "registered_servers",
    "servers",
    "start",
    "where_is",
]

# Absolute path of the directory that contains this module.
here = os.path.dirname(os.path.abspath(__file__))


class BlockingChannel(object):
    """Lock-protected wrapper around one end of a ``multiprocessing.Pipe``.

    :meth:`send` transmits its positional arguments as a tuple, and
    :meth:`recv` unwraps single-element tuples again, so ``send(x)`` on one
    end yields ``x`` from ``recv()`` on the other.
    """

    def __init__(self, channel):
        self.chan = channel
        # Serialises access to ``self.chan`` for users of this object.
        self.lock = multiprocessing.Lock()

    def call(self, func, args=()):
        """Send a ``(func, args)`` request and block until the reply arrives.

        The lock is held across both the send and the receive so that
        concurrent callers cannot consume each other's replies.

        :returns: The reply, unwrapped as by :meth:`recv`, or ``("stop", ())``
            if the wait for the reply is interrupted by KeyboardInterrupt.
        """
        with self.lock:
            self._send_unlocked((func, args))
            return self._recv_unlocked()

    def send(self, *args):
        """Send *args* (packed into a tuple) over the channel."""
        with self.lock:
            self._send_unlocked(*args)

    def recv(self):
        """Receive one message, unwrapping single-element tuples.

        :returns: The received payload, or ``("stop", ())`` if interrupted by
            KeyboardInterrupt.
        """
        try:
            # ``with`` only releases the lock if it was actually acquired;
            # an interrupted acquire therefore cannot trigger a bogus release.
            with self.lock:
                return self._recv_unlocked()
        except KeyboardInterrupt:
            return ("stop", ())

    def _send_unlocked(self, *args):
        # Caller must hold ``self.lock``.
        self.chan.send(args)

    def _recv_unlocked(self):
        # Caller must hold ``self.lock``.
        try:
            payload = self.chan.recv()
        except KeyboardInterrupt:
            return ("stop", ())
        if isinstance(payload, tuple) and len(payload) == 1:
            return payload[0]
        return payload


class ServerProxy(multiprocessing.Process, BlockingChannel):
    """Child process that builds one server and answers requests about it.

    In the child, ``init_func(*init_args, **init_kwargs)`` creates the
    server, which is then started and ``("ok", ())`` is sent back. After
    that the process serves ``(attr_name, args)`` requests read from
    *channel*: callable attributes are invoked with ``args``, anything else
    is returned as is. The loop ends after replying to a ``"stop"`` request.
    """

    def __init__(self, channel, init_func, *init_args, **init_kwargs):
        multiprocessing.Process.__init__(self)
        BlockingChannel.__init__(self, channel)
        self.init_func = init_func
        self.init_args = init_args
        self.init_kwargs = init_kwargs

    def run(self):
        # Bound before the ``try`` so the KeyboardInterrupt handler can tell
        # whether a server was ever created; previously an interrupt during
        # ``init_func`` raised UnboundLocalError here.
        server = None
        try:
            server = self.init_func(*self.init_args, **self.init_kwargs)
            server.start()
            self.send(("ok", ()))

            while True:
                # Requests look like ("func", ("arg", ...)) or ("prop", ()).
                sattr, fargs = self.recv()
                attr = getattr(server, sattr)

                # Methods are called with the supplied arguments; plain
                # attributes (including properties) are returned as is.
                rv = attr(*fargs) if callable(attr) else attr

                self.send(rv)

                if sattr == "stop":
                    return

        except Exception as e:
            # Report the failure to the parent end of the channel.
            self.send(("stop", e))

        except KeyboardInterrupt:
            if server is not None:
                server.stop()


class ServerProc(BlockingChannel):
    """Parent-side handle for a server running in a :class:`ServerProxy`.

    *init_func* is the server builder; :meth:`start` passes it to the child
    process together with ``doc_root``, ``ssl_config`` and any extra keyword
    arguments.
    """

    def __init__(self, init_func):
        self._init_func = init_func
        self.proc = None

        # The parent keeps one end of the pipe; the other end is handed to
        # the child process in start().
        parent_chan, self.child_chan = multiprocessing.Pipe()
        BlockingChannel.__init__(self, parent_chan)

    def start(self, doc_root, ssl_config, **kwargs):
        """Spawn the child process and wait until its server has started.

        :raises Exception: Whatever the child raised while building or
            starting the server.
        :raises KeyboardInterrupt: If the wait for the child is interrupted.
        """
        self.proc = ServerProxy(
            self.child_chan, self._init_func, doc_root, ssl_config, **kwargs
        )
        # Daemonic, so the child is terminated when the parent exits.
        self.proc.daemon = True
        self.proc.start()

        # The child replies ("ok", ()) on success or ("stop", exc) on failure.
        res, exc = self.recv()
        if res == "stop":
            if isinstance(exc, BaseException):
                raise exc
            # recv() returns ("stop", ()) when interrupted; ``raise ()``
            # would be a TypeError, so surface the interrupt instead.
            raise KeyboardInterrupt()

    def get_url(self, url):
        """Return the child server's ``get_url(url)`` result."""
        return self.call("get_url", (url,))

    @property
    def doc_root(self):
        """The child server's ``doc_root`` attribute."""
        return self.call("doc_root", ())

    def stop(self):
        """Ask the child to stop its server, then wait for the child to exit.

        Does nothing if the child is not running. Sending the request to a
        missing or dead child would block forever waiting for a reply.
        """
        if not self.is_alive:
            return
        self.call("stop")
        self.proc.join()

    def kill(self):
        """Terminate the child process, if running, without waiting for it."""
        if not self.is_alive:
            return
        self.proc.terminate()
        self.proc.join(0)

    @property
    def is_alive(self):
        """True if the child process was started and is still running."""
        if self.proc is not None:
            return self.proc.is_alive()
        return False


def http_server(doc_root, ssl_config, host="127.0.0.1", **kwargs):
    """Build a plain-HTTP ``httpd.FixtureServer`` serving *doc_root*.

    *ssl_config* is accepted only so that every registered builder shares
    one signature; it is ignored here. The URL uses port 0 (presumably so
    that a free port is picked at bind time — see httpd.FixtureServer).
    Extra keyword arguments are forwarded to ``FixtureServer``.
    """
    server_url = "http://{host}:0/".format(host=host)
    return httpd.FixtureServer(doc_root, url=server_url, **kwargs)


def https_server(doc_root, ssl_config, host="127.0.0.1", **kwargs):
    """Build an HTTPS ``httpd.FixtureServer`` serving *doc_root*.

    *ssl_config* must provide ``"key_path"`` and ``"cert_path"`` entries,
    which are passed through as ``ssl_key`` and ``ssl_cert``. Extra keyword
    arguments are forwarded to ``FixtureServer``.
    """
    server_url = "https://{host}:0/".format(host=host)
    key_path = ssl_config["key_path"]
    cert_path = ssl_config["cert_path"]
    return httpd.FixtureServer(
        doc_root, url=server_url, ssl_key=key_path, ssl_cert=cert_path, **kwargs
    )


def start_servers(doc_root, ssl_config, **kwargs):
    """Start one child-process server for each entry in ``registered_servers``.

    :returns: Mapping of scheme name to ``(base_url, ServerProc)``.
    :raises: Whatever a server raised while starting. In that case every
        server process already started by this call is killed first, so a
        caller that handles the error is not left with orphaned children.
    """
    servers = defaultdict()
    started = []
    try:
        for schema, builder_fn in registered_servers:
            proc = ServerProc(builder_fn)
            # Track before start() so a failure inside get_url() is covered.
            started.append(proc)
            proc.start(doc_root, ssl_config, **kwargs)
            servers[schema] = (proc.get_url("/"), proc)
    except BaseException:
        for proc in started:
            proc.kill()
        raise
    return servers


def start(doc_root=None, **kwargs):
    """Start all relevant test servers and record them in ``servers``.

    When *doc_root* is falsy the default
    testing/marionette/harness/marionette_harness/www directory is served.

    Any extra keyword arguments are handed on to the individual
    ``FixtureServer`` instances from httpd.py.

    :returns: The new module-level ``servers`` mapping.
    """
    global servers

    if not doc_root:
        doc_root = default_doc_root
    ssl_config = dict(
        cert_path=httpd.default_ssl_cert,
        key_path=httpd.default_ssl_key,
    )

    servers = start_servers(doc_root, ssl_config, **kwargs)
    return servers


def where_is(uri, on="http"):
    """Returns the full URL, including scheme, hostname, and port, for
    a fixture resource from the server associated with the ``on`` key.
    It will by default look for the resource in the "http" server.

    :param uri: Path of the fixture resource on that server.
    :param on: Key of the server in the module-level ``servers`` mapping.
    :raises KeyError: If no server is registered under *on*, for example
        because start() has not been called yet.
    """
    try:
        _, proc = servers[on]
    except KeyError:
        # Previously ``servers.get(on)[1]`` failed with an opaque TypeError.
        raise KeyError(
            "No fixture server registered for {!r}; available: {}".format(
                on, sorted(servers)
            )
        ) from None
    return proc.get_url(uri)


def iter_proc(servers):
    """Yield the ``ServerProc`` from every ``(url, proc)`` value in *servers*."""
    for _url, proc in servers.values():
        yield proc


def iter_url(servers):
    """Yield the base URL from every ``(url, proc)`` value in *servers*."""
    for url, _proc in servers.values():
        yield url


# Served when start() gets no doc_root: the ``www`` directory that sits in
# the parent of this module's directory.
default_doc_root = os.path.join(os.path.split(here)[0], "www")

# (scheme, builder) pairs; start_servers() launches one process per entry.
registered_servers = [
    ("http", http_server),
    ("https", https_server),
]

# scheme -> (base URL, ServerProc); replaced by start().
servers = defaultdict()


def main(args=None):
    """Command-line entry point: start the servers and wait for them.

    :param args: Argument list to parse, without the program name. ``None``
        (the default) lets argparse fall back to ``sys.argv[1:]``.
    """
    global servers

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-r", dest="doc_root", help="Path to document root.  Overrides default."
    )
    # Parse the list we were given; previously ``args`` was ignored and
    # sys.argv was always re-read.
    opts = parser.parse_args(args)

    servers = start(opts.doc_root)
    for url in iter_url(servers):
        print("{}: listening on {}".format(sys.argv[0], url), file=sys.stderr)

    try:
        # Stay in the foreground for as long as any child server runs.
        while any(proc.is_alive for proc in iter_proc(servers)):
            for proc in iter_proc(servers):
                proc.proc.join(1)
    except KeyboardInterrupt:
        for proc in iter_proc(servers):
            proc.kill()


if __name__ == "__main__":
    # Forward the command-line arguments, minus the program name.
    main(args=sys.argv[1:])