summaryrefslogtreecommitdiffstats
path: root/tests/utils.py
blob: 5636a138effb646a5e433483e400ee84888d19a8 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
from __future__ import annotations

__all__ = ('http_server',)

import socket
from contextlib import contextmanager
from http.server import ThreadingHTTPServer
from pathlib import Path
from ssl import PROTOCOL_TLS_SERVER, SSLContext
from threading import Thread
from typing import TYPE_CHECKING
from urllib.parse import urlparse

if TYPE_CHECKING:
    from collections.abc import Iterator
    from http.server import HTTPServer
    from socketserver import BaseRequestHandler
    from typing import Final

    from sphinx.application import Sphinx

# Directory containing this module (the root of the test suite).
TESTS_ROOT: Final[Path] = Path(__file__).parent

# Self-signed certificate and private key for the local HTTPS test server.
# Both live in the same PEM file (note ``-out`` and ``-keyout`` match).
# Generated with:
# $ openssl req -new -x509 -days 3650 -nodes -out cert.pem \
#     -keyout cert.pem -addext "subjectAltName = DNS:localhost"
CERT_FILE: Final[str] = str(TESTS_ROOT.joinpath('certs', 'cert.pem'))


class HttpServerThread(Thread):
    """Background thread that runs a :class:`ThreadingHTTPServer` on localhost."""

    def __init__(self, handler: type[BaseRequestHandler], *, port: int = 0) -> None:
        """
        Constructs a threaded HTTP server bound to ``localhost``.

        The server socket is created and bound here; requests are only served
        once the thread has been started.  The default port number of ``0``
        lets the operating system pick a free port; the chosen port can be
        read from ``self.server.server_port``.

        Ref: https://docs.python.org/3.11/library/socketserver.html#asynchronous-mixins

        :param handler: Request handler class used for each incoming request.
        :param port: TCP port to bind to (``0`` selects one automatically).
        """
        # Daemon thread: it must never keep the interpreter alive on its own.
        super().__init__(daemon=True)
        self.server = ThreadingHTTPServer(('localhost', port), handler)

    def run(self) -> None:
        # A very short poll interval makes shutdown() return almost at once,
        # keeping test teardown fast.
        self.server.serve_forever(poll_interval=0.001)

    def terminate(self) -> None:
        """
        Stop serving, wait for the thread to exit, and close the server socket.

        Safe to call on a thread that was never started.  Otherwise
        ``shutdown()`` would block forever waiting for a ``serve_forever()``
        loop that never ran, and ``join()`` would raise on an unstarted thread.
        """
        if self.is_alive():
            self.server.shutdown()
            self.join()
        self.server.server_close()


class HttpsServerThread(HttpServerThread):
    """Variant of :class:`HttpServerThread` whose listening socket speaks TLS."""

    def __init__(self, handler: type[BaseRequestHandler], *, port: int = 0) -> None:
        """
        Constructs a threaded HTTPS server using the certificate in ``CERT_FILE``.

        :param handler: Request handler class used for each incoming request.
        :param port: TCP port to bind to (``0`` selects one automatically).
        :raises OSError: If the certificate cannot be loaded.  Any
            ``ssl.SSLError`` from setting up TLS propagates too.
        """
        # The base constructor has already bound a listening socket.
        super().__init__(handler, port=port)
        try:
            sslcontext = SSLContext(PROTOCOL_TLS_SERVER)
            # CERT_FILE holds both the certificate and its private key.
            sslcontext.load_cert_chain(CERT_FILE)
            self.server.socket = sslcontext.wrap_socket(self.server.socket, server_side=True)
        except BaseException:
            # Don't leak the bound socket if TLS setup fails.
            self.server.server_close()
            raise


@contextmanager
def http_server(
    handler: type[BaseRequestHandler],
    *,
    tls_enabled: bool = False,
    port: int = 0,
) -> Iterator[HTTPServer]:
    """
    Run a temporary local HTTP(S) server in a background thread.

    :param handler: Request handler class used for each incoming request.
    :param tls_enabled: Serve over TLS (``HttpsServerThread``) instead of plain HTTP.
    :param port: TCP port to bind to (``0`` selects one automatically).
    :return: Context manager yielding the running server.  The server is shut
        down and its socket closed when the context exits.
    :raises OSError: If a TCP connection to the server cannot be established
        within 0.5 seconds.
    """
    server_cls = HttpsServerThread if tls_enabled else HttpServerThread
    server_thread = server_cls(handler, port=port)
    server_thread.start()
    # Everything after start() is guarded so the thread and its socket are
    # always cleaned up, even if a sanity check below fails.
    try:
        server_port = server_thread.server.server_port
        assert port == 0 or server_port == port
        # Probe the port so callers only proceed once the server accepts
        # connections.
        socket.create_connection(('localhost', server_port), timeout=0.5).close()
        yield server_thread.server  # Connection has been confirmed possible; proceed.
    finally:
        server_thread.terminate()


@contextmanager
def rewrite_hyperlinks(app: Sphinx, server: HTTPServer) -> Iterator[None]:
    """
    Rewrite hyperlinks that refer to network location 'localhost:7777',
    allowing that location to vary dynamically with the arbitrary test HTTP
    server port assigned during unit testing.

    The rewrite is done by a ``linkcheck-process-uri`` listener.  It is
    connected on entry and always disconnected on exit, even when the body
    of the ``with`` statement raises, so it is never left attached to ``app``.

    :param app: The Sphinx application where link replacement is to occur.
    :param server: Destination server to redirect the hyperlinks to.
    """
    match_netloc = 'localhost:7777'
    replacement_netloc = f'localhost:{server.server_port}'

    def rewrite_hyperlink(_app: Sphinx, uri: str) -> str | None:
        parsed_uri = urlparse(uri)
        if parsed_uri.netloc != match_netloc:
            return uri
        # Only the host:port part changes; scheme, path, query and fragment
        # are preserved.
        return parsed_uri._replace(netloc=replacement_netloc).geturl()

    listener_id = app.connect('linkcheck-process-uri', rewrite_hyperlink)
    try:
        yield
    finally:
        app.disconnect(listener_id)


@contextmanager
def serve_application(
    app: Sphinx,
    handler: type[BaseRequestHandler],
    *,
    tls_enabled: bool = False,
    port: int = 0,
) -> Iterator[str]:
    """
    Serve a Sphinx project's test links from a throwaway local server.

    A server runs for the duration of the ``with`` block.  While it runs,
    hyperlinks in ``app`` that point at ``localhost:7777`` are redirected to
    it.

    :param app: The Sphinx application whose links should reach the server.
    :param handler: Request handler class deciding how each request is answered.
    :param tls_enabled: Serve over TLS (SSL) instead of plain HTTP.
    :param port: Server port to bind to (default ``0``: chosen automatically).

    :return: The ``host:port`` address of the temporary server.
    """
    with http_server(handler, tls_enabled=tls_enabled, port=port) as server:
        with rewrite_hyperlinks(app, server):
            address = f'localhost:{server.server_port}'
            yield address